feat(pinyin): add a Chinese pinyin character statistics file to optimize model input handling

parent 917d5976a9, commit 4bcbbaa4eb

pyproject.toml
@@ -5,9 +5,23 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "modelscope>=1.35.1",
    "onnx>=1.20.1",
    "torch>=2.10.0",
    "amd-quark>=0.1.0",
    "bokeh>=3.8.2",
    "datasets>=4.5.0",
    "ipykernel>=7.2.0",
    "ipython>=9.10.0",
    "loguru>=0.7.3",
    "msgpack>=1.1.2",
    "numpy>=2.4.2",
    "onnxruntime>=1.24.2",
    "pandas>=3.0.0",
    "pypinyin>=0.55.0",
    "requests>=2.32.5",
    "rich>=14.3.1",
    "tensorboard>=2.20.0",
    "transformers==5.1.0",
    "typer>=0.21.1",
]

[tool.uv]

File diff suppressed because it is too large

@@ -0,0 +1,160 @@
import random
import os
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from pypinyin import lazy_pinyin, Style
from pypinyin.contrib.tone_convert import to_initials

from loguru import logger
import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset

from modelscope import AutoModel, AutoTokenizer

# Load the tokenizer and model
# model = AutoModel.from_pretrained('iic/nlp_structbert_backbone_lite_std')
# tokenizer = AutoTokenizer.from_pretrained('iic/nlp_structbert_backbone_lite_std')


class PinyinInputDataset(IterableDataset):
    def __init__(
        self,
        data_path: str,
        max_workers: int = -1,
        tokenizer_name: str = "iic/nlp_structbert_backbone_lite_std",
        max_length: int = 128,
        text_field: str = "text",
        py_style_weight: Tuple[int, int, int] = (9, 2, 1),
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.data_path = data_path
        self.max_length = max_length
        self.text_field = text_field
        self.dataset = load_dataset(data_path, split="train", streaming=True)
        self.max_workers = max_workers
        self.py_style_weight = py_style_weight

    @staticmethod
    def smart_multi_segment_encode(texts, tokenizer, max_length=128):
        """
        Smart multi-segment encoding.

        Idea: rely on the tokenizer's basic functionality, but control the
        token_type_ids flexibly.
        """
        # Step 1: encode each segment separately with the tokenizer
        encoded_segments = []
        for text in texts:
            # Note: no special tokens here; they are handled uniformly below
            encoded = tokenizer.encode(text, add_special_tokens=False)
            encoded_segments.append(encoded)

        # Step 2: build the full sequence
        tokens = []
        token_type_ids = []

        # Add [CLS]
        tokens.append(tokenizer.cls_token_id)
        token_type_ids.append(0)  # [CLS] is conventionally type 0

        # Add each segment
        for seg_idx, segment in enumerate(encoded_segments):
            # Type of the current segment (cycles through 0-3)
            current_type = seg_idx % 4

            # Segment content
            tokens.extend(segment)
            token_type_ids.extend([current_type] * len(segment))

            # Add [SEP] after every segment, including the last one
            tokens.append(tokenizer.sep_token_id)
            token_type_ids.append(current_type)

        # Step 3: truncate or pad
        if len(tokens) > max_length:
            tokens = tokens[:max_length]
            token_type_ids = token_type_ids[:max_length]
        else:
            padding_length = max_length - len(tokens)
            tokens = tokens + [tokenizer.pad_token_id] * padding_length
            token_type_ids = token_type_ids + [0] * padding_length  # type 0 for padding

        # Step 4: build the attention mask
        attention_mask = [
            1 if token != tokenizer.pad_token_id else 0 for token in tokens
        ]

        return {
            "input_ids": torch.tensor([tokens]),
            "token_type_ids": torch.tensor([token_type_ids]),
            "attention_mask": torch.tensor([attention_mask]),
        }
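
    # A minimal usage sketch of the encoder above (illustrative only; assumes
    # the StructBERT tokenizer referenced in __init__):
    #
    #   tok = AutoTokenizer.from_pretrained("iic/nlp_structbert_backbone_lite_std")
    #   enc = PinyinInputDataset.smart_multi_segment_encode(
    #       ["今天天气", "zen me yang"], tok, max_length=32
    #   )
    #   enc["input_ids"].shape   # torch.Size([1, 32])
    #   # token_type_ids: 0 for [CLS] and segment 0, 1 for segment 1, 0 for padding
    #   # attention_mask: 1 on real tokens, 0 on padding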

    # Generate the pinyin of the text (non-Chinese characters pass through unchanged)
    def generate_pinyin(self, text: str) -> List[str]:
        return lazy_pinyin(text, errors=lambda x: [c for c in x])

    # Generate the pinyin of the characters to be predicted, with augmentation:
    # each syllable is randomly kept whole, reduced to its initial, or reduced
    # to its first letter, weighted by py_style_weight
    def get_mask_pinyin(self, text: str, pinyin_list: List[str]) -> Tuple[int, List[str]]:
        mask_pinyin = []
        for i in range(len(text)):
            if text[i] == pinyin_list[i]:
                # Non-Chinese character: stop just before it
                return i - 1, mask_pinyin
            else:
                py = random.choices(
                    (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
                    weights=self.py_style_weight,
                )[0]
                mask_pinyin.append(py)
        return len(text) - 1, mask_pinyin
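
    # The augmentation in miniature (hypothetical draw; the choice is random,
    # weighted 9:2:1 toward the full syllable):
    #   lazy_pinyin("什么")                           # ["shen", "me"]
    #   to_initials("shen")                           # "sh"
    #   self.get_mask_pinyin("什么", ["shen", "me"])
    #   # -> (1, ["shen", "m"]) e.g.: full syllable kept, then first letter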

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = (
                self.max_workers if self.max_workers > 0 else worker_info.num_workers
            )
            base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
            seed = base_seed + worker_id
            random.seed(seed % (2**32))
            np.random.seed(seed % (2**32))

            self.dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)

        batch_samples = []
        for sample in self.dataset:
            text = sample.get(self.text_field, "")
            if text:
                pinyin_list = self.generate_pinyin(text)
                for i in range(len(text)):
                    # Skip text[i] if it is not in the character set
                    # (lazy_pinyin leaves such characters unchanged).
                    # When i < 44, part1 is text[0:i];
                    # when i >= 44, part1 is text[i-44:i].
                    if text[i] == pinyin_list[i]:
                        continue
                    if i < 44:
                        part1 = text[0:i]
                    else:
                        part1 = text[i - 44 : i]
                    # Draw a random pinyin_len in 1-8; the distribution is roughly
                    # bell-shaped with its mode at 3. part2 is then the weakened
                    # pinyin of pinyin_list[i:i+pinyin_len], clipped at the end of
                    # the text (mind the boundary conditions).
                    pinyin_len = np.random.choice(
                        range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
                    )
                    py_end = min(i + pinyin_len, len(text))
                    # get_mask_pinyin returns (last_index, segments); only the
                    # segments are needed here
                    _, part2 = self.get_mask_pinyin(text[i:py_end], pinyin_list[i:py_end])
                    # part3 is plain text: empty with high probability (0.85),
                    # otherwise the character at i + pinyin_len

                    # encoded = self.smart_multi_segment_encode([pinyin_text], self.tokenizer, self.max_length)
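                    # A hedged sketch of how this might continue, matching the
                    # batch schema the training loop below consumes (the diff is
                    # truncated here; everything except "hint"/"char_id" is an
                    # assumption):
                    #   hint = self.smart_multi_segment_encode(
                    #       [part1, " ".join(part2)], self.tokenizer, self.max_length
                    #   )
                    #   yield {"hint": hint, "char_id": ...}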

@@ -6,11 +6,11 @@ import torch
import torch.nn as nn
import torch.amp as amp
import torch.optim as optim

import torch.nn.functional as F

from loguru import logger
from modelscope import AutoTokenizer
from tqdm.autonotebook import tqdm


from .components import AttentionPooling, Expert  # , ResidualBlock  # assumed to be implemented

@@ -254,7 +254,7 @@ class InputMethodEngine(nn.Module):
        self,
        train_dataloader,
        eval_dataloader=None,
-        # monitor: Optional[TrainingMonitor] = None,
+        monitor=None,
        criterion=None,
        optimizer=None,
        num_epochs=1,

@@ -285,7 +285,6 @@ class InputMethodEngine(nn.Module):
            current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
            return current_lr

-        """Training function; input parameters adjusted."""
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
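        # Sanity check of the cosine decay above (illustrative, ignoring warmup):
        # with _lr = 1e-3,
        #   progress = 0.0 -> current_lr = 1e-3   (cos(0)    =  1)
        #   progress = 0.5 -> current_lr = 5e-4   (cos(pi/2) =  0)
        #   progress = 1.0 -> current_lr = 0.0    (cos(pi)   = -1)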

@@ -297,13 +296,16 @@ class InputMethodEngine(nn.Module):
        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

        # Loss function: ignore_index=-1 must be set, because invalid label
        # positions are marked with -1
        if criterion is None:
            if loss_weight is not None:
                criterion = nn.CrossEntropyLoss(
-                    weight=loss_weight, label_smoothing=label_smoothing
+                    weight=loss_weight, label_smoothing=label_smoothing, ignore_index=-1
                )
            else:
-                criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+                criterion = nn.CrossEntropyLoss(
+                    label_smoothing=label_smoothing, ignore_index=-1
+                )

        scaler = amp.GradScaler(enabled=mixed_precision)
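        # What ignore_index=-1 buys us, in miniature (illustrative values):
        #   logits = torch.randn(4, 10)
        #   labels = torch.tensor([3, -1, 7, -1])
        #   nn.CrossEntropyLoss(ignore_index=-1)(logits, labels)
        #   # -> mean loss over the two valid positions only; -1 slots are skipped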

@@ -313,86 +315,105 @@ class InputMethodEngine(nn.Module):
        processed_batches = 0
        batch_loss_sum = 0.0
        optimizer.zero_grad()

        try:
            for epoch in range(num_epochs):
                for batch_idx, batch in enumerate(
                    tqdm(train_dataloader, total=int(stop_batch))
                ):
-                    # LR Schedule
+                    # Learning-rate schedule
                    current_lr = lr_schedule(
                        lr, stop_batch, processed_batches, warmup_steps
                    )
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = current_lr

-                    # Move the data (note: the batch now contains token_type_ids)
+                    # Fetch the data from the batch
                    input_ids = batch["hint"]["input_ids"].to(self.device)
                    attention_mask = batch["hint"]["attention_mask"].to(self.device)
-                    token_type_ids = batch["hint"]["token_type_ids"].to(
-                        self.device
-                    )  # newly added
-                    pg = batch["pg"].to(self.device)
-                    labels = batch["char_id"].to(self.device)
+                    token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
+                    labels = batch["char_id"].to(self.device)  # [batch, max_slot_steps]

+                    # Build slot_target_mask: 1 at valid positions, 0 at invalid
+                    # ones (invalid labels are assumed to be -1)
+                    slot_target_mask = (labels != -1).float()  # [batch, max_slot_steps]

                    with torch.amp.autocast(
                        device_type=self.device.type, enabled=mixed_precision
                    ):
-                        logits = self(input_ids, attention_mask, token_type_ids, pg)
-                        loss = criterion(logits, labels)
+                        # Call the model (training mode)
+                        logits = self(
+                            input_ids=input_ids,
+                            token_type_ids=token_type_ids,
+                            attention_mask=attention_mask,
+                            slot_target_ids=labels,
+                            slot_target_mask=slot_target_mask,
+                            mode="train",
+                        )  # logits: [batch, max_slot_steps, output_vocab_size]

+                        # Compute the loss (padding positions are skipped;
+                        # ignore_index=-1 is already set on the criterion)
+                        loss = criterion(
+                            logits.view(-1, self.output_vocab_size), labels.view(-1)
+                        )
                        loss = loss / grad_accum_steps

                    scaler.scale(loss).backward()

-                    if (processed_batches) % grad_accum_steps == 0:
+                    # Gradient-accumulation update
+                    if (processed_batches + 1) % grad_accum_steps == 0:
                        scaler.unscale_(optimizer)
-                        torch.nn.utils.clip_grad_norm_(
-                            self.parameters(), clip_grad_norm
-                        )
+                        torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                        batch_loss_sum += loss.item() * grad_accum_steps

                        # Periodic evaluation
                        if processed_batches % eval_frequency == 0:
                            if eval_dataloader:
                                self.eval()
-                                acc, eval_loss = self.model_eval(
-                                    eval_dataloader, criterion
-                                )
+                                acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                                self.train()
                                if monitor:
                                    # Use eval_loss as the monitored metric
                                    monitor.add_step(
                                        processed_batches,
                                        {
                                            "train_loss": batch_loss_sum
-                                            / (
-                                                eval_frequency
-                                                if processed_batches > 0
-                                                else 1
-                                            ),
+                                            / (eval_frequency if processed_batches > 0 else 1),
                                            "acc": acc,
                                            "loss": eval_loss,
                                            "lr": current_lr,
                                        },
                                    )
                                logger.info(
-                                    f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}"
+                                    f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, "
+                                    f"batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
+                                    f"current_lr: {current_lr}"
                                )
                            else:
                                logger.info(
-                                    f"step: {processed_batches}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}"
+                                    f"step: {processed_batches}, batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
+                                    f"current_lr: {current_lr}"
                                )
                            batch_loss_sum = 0.0

-                        processed_batches += 1
                        if processed_batches >= stop_batch:
                            break
+                        processed_batches += 1

+                    else:
+                        # Gradient-accumulation boundary not reached: the loss has
+                        # already been accumulated via backward(), but the counter
+                        # is not advanced (processed_batches is only incremented on
+                        # gradient updates, uniformly after the update above, to
+                        # stay compatible with the original logic); nothing extra
+                        # is done here
+                        pass
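                        # The accumulation arithmetic, in miniature: with
                        # grad_accum_steps = 4, backward() runs on every batch, but
                        # optimizer.step() fires only when
                        # (processed_batches + 1) % 4 == 0, i.e. on batches 3, 7, 11, ...
                        # Dividing each loss by 4 beforehand makes the summed
                        # gradient match the mean gradient of one 4x-larger batch.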

                # Training-finished notification
                if monitor:
                    monitor.finish()

        except KeyboardInterrupt:
            logger.info("Training interrupted by user")

            # Send the training-finished notification
            if monitor:
                monitor.finish()


    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
        state_dict = torch.load(

@@ -0,0 +1,486 @@
# file name: query_engine.py
import gzip
import json
import os
import pickle
import time
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import msgpack

from .char_info import CharInfo, PinyinCharPairsCounter


class QueryEngine:
    """
    Efficient pinyin-character query engine.

    Features:
    1. O(1) lookup by record ID
    2. O(1) lookup by character
    3. O(1) lookup by pinyin
    4. Memory-friendly, with efficient index construction
    5. Batch queries and prefix search
    """

    def __init__(self, min_count: int = 109):
        """Initialize the query engine; pairs seen fewer than min_count times are excluded from the indices."""
        self._counter_data: Optional[PinyinCharPairsCounter] = None

        # Core indices - O(1) lookups
        self._id_to_info: Dict[int, CharInfo] = {}  # ID -> CharInfo
        self._char_to_ids: Dict[str, List[int]] = {}  # character -> list of IDs
        self._pinyin_to_ids: Dict[str, List[int]] = {}  # pinyin -> list of IDs
        self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}  # (char, pinyin) -> ID

        # Auxiliary indices - fast access to details
        self._char_freq: Dict[str, int] = {}  # total character frequency
        self._pinyin_freq: Dict[str, int] = {}  # total pinyin frequency
        self._char_pinyin_map: Dict[Tuple[str, str], int] = {}  # (char, pinyin) -> count

        # Bookkeeping
        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0

        self.min_count = min_count

    def load(
        self,
        file_path: Union[str, Path] = (
            files(__package__) / "data" / "pinyin_char_statistics.json"
        ),
    ) -> Dict[str, Any]:
        """
        Load a statistics result file.

        Args:
            file_path: file path; msgpack/pickle/json formats are supported
                and gzip compression is detected automatically

        Returns:
            the metadata dictionary

        Raises:
            FileNotFoundError: the file does not exist
            ValueError: the file format is not supported
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        start_time = time.time()

        # Read and parse the file
        self._counter_data = self.parse_file(file_path)

        # Build the indices
        self._build_indices()

        self._load_time = time.time() - start_time
        self._loaded = True

        return self._counter_data.metadata

    def parse_file(self, file_path: Union[str, Path]) -> PinyinCharPairsCounter:
        """Parse a file; several formats are supported."""
        with open(file_path, "rb") as f:
            data = f.read()

        # Try gzip decompression first
        try:
            data = gzip.decompress(data)
        except Exception:
            pass

        # Try the supported formats in turn
        for parser_name, parser in [
            ("msgpack", self._parse_msgpack),
            ("pickle", self._parse_pickle),
            ("json", self._parse_json),
        ]:
            try:
                return parser(data)
            except Exception:
                continue

        raise ValueError("Unable to parse the file format")

    def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
        """Parse msgpack data."""
        data_dict = msgpack.unpackb(data, raw=False)
        return self._dict_to_counter(data_dict)

    def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
        """Parse pickle data."""
        return pickle.loads(data)

    def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
        """Parse JSON data."""
        data_str = data.decode("utf-8")
        data_dict = json.loads(data_str)
        return self._dict_to_counter(data_dict)

    def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
        """Convert a plain dict into a PinyinCharPairsCounter."""
        # Convert the CharInfo entries
        pairs_dict = {}
        if "pairs" in data_dict and data_dict["pairs"]:
            for id_str, info_dict in data_dict["pairs"].items():
                pairs_dict[int(id_str)] = CharInfo(**info_dict)
        data_dict["pairs"] = pairs_dict

        return PinyinCharPairsCounter(**data_dict)

    def _build_indices(self):
        """Build all query indices."""
        start_time = time.time()

        # Reset the indices
        self._id_to_info.clear()
        self._char_to_ids.clear()
        self._pinyin_to_ids.clear()
        self._char_pinyin_to_ids.clear()
        self._char_freq.clear()
        self._pinyin_freq.clear()
        self._char_pinyin_map.clear()

        # Copy the frequency data
        if self._counter_data.chars:
            self._char_freq = self._counter_data.chars.copy()
        if self._counter_data.pinyins:
            self._pinyin_freq = self._counter_data.pinyins.copy()

        # Build the core indices
        for char_info in self._counter_data.pairs.values():
            if char_info.count < self.min_count:
                continue
            char = char_info.char
            pinyin = char_info.pinyin
            char_info_id = char_info.id

            # ID index
            self._id_to_info[char_info_id] = char_info

            # Character index
            if char not in self._char_to_ids:
                self._char_to_ids[char] = []
            self._char_to_ids[char].append(char_info_id)

            # Pinyin index
            if pinyin not in self._pinyin_to_ids:
                self._pinyin_to_ids[pinyin] = []
            self._pinyin_to_ids[pinyin].append(char_info_id)

            # Character-pinyin mappings
            self._char_pinyin_map[(char, pinyin)] = char_info.count
            self._char_pinyin_to_ids[(char, pinyin)] = char_info_id

        self._total_pairs = len(self._id_to_info)
        self._index_time = time.time() - start_time

    def query_by_id(self, id: int) -> Optional[CharInfo]:
        """
        Look up character info by ID - O(1).

        Args:
            id: record ID

        Returns:
            the CharInfo object, or None if absent
        """
        if not self._loaded:
            raise RuntimeError("Data not loaded; call load() first")

        return self._id_to_info.get(id)

    def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
        """
        Look up pinyin info by character - O(1) + O(k), where k is the number of results.

        Args:
            char: a Chinese character
            limit: maximum number of results; 0 returns all

        Returns:
            a list of (id, pinyin, count) tuples, sorted by count in descending order
        """
        if not self._loaded:
            raise RuntimeError("Data not loaded; call load() first")

        if char not in self._char_to_ids:
            return []

        # Collect all matching IDs
        ids = self._char_to_ids[char]

        # Build and sort the results
        results = []
        for char_info_id in ids:
            char_info = self._id_to_info[char_info_id]
            results.append((char_info_id, char_info.pinyin, char_info.count))

        # Sort by count, descending
        results.sort(key=lambda x: x[2], reverse=True)

        # Apply the limit
        if limit > 0 and len(results) > limit:
            results = results[:limit]

        return results

    def query_by_pinyin(
        self, pinyin: str, limit: int = 0
    ) -> List[Tuple[int, str, int]]:
        """
        Look up character info by pinyin - O(1) + O(k).

        Args:
            pinyin: a pinyin string
            limit: maximum number of results; 0 returns all

        Returns:
            a list of (id, char, count) tuples, sorted by count in descending order
        """
        if not self._loaded:
            raise RuntimeError("Data not loaded; call load() first")

        if pinyin not in self._pinyin_to_ids:
            return []

        # Collect all matching IDs
        ids = self._pinyin_to_ids[pinyin]

        # Build and sort the results
        results = []
        for char_info_id in ids:
            char_info = self._id_to_info[char_info_id]
            results.append((char_info_id, char_info.char, char_info.count))

        # Sort by count, descending
        results.sort(key=lambda x: x[2], reverse=True)

        # Apply the limit
        if limit > 0 and len(results) > limit:
            results = results[:limit]

        return results
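
    # Typical lookups (illustrative; actual ids and counts depend on the
    # loaded statistics file):
    #
    #   engine = QueryEngine()
    #   engine.load()
    #   engine.query_by_char("中", limit=2)   # [(id, pinyin, count), ...] count-desc
    #   engine.query_by_pinyin("zhong")       # [(id, char, count), ...] count-desc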

    def get_char_frequency(self, char: str) -> int:
        """
        Total frequency of a character (summed over all pinyin variants) - O(1).

        Args:
            char: a Chinese character

        Returns:
            total occurrence count
        """
        if not self._loaded:
            raise RuntimeError("Data not loaded; call load() first")

        return self._char_freq.get(char, 0)

    def get_pinyin_frequency(self, pinyin: str) -> int:
        """
        Total frequency of a pinyin (summed over all characters) - O(1).

        Args:
            pinyin: a pinyin string

        Returns:
            total occurrence count
        """
        if not self._loaded:
            raise RuntimeError("Data not loaded; call load() first")

        return self._pinyin_freq.get(pinyin, 0)

    def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
        """
        Occurrence count of a specific character-pinyin pair - O(1).

        Args:
            char: a Chinese character
            pinyin: a pinyin string

        Returns:
            occurrence count
        """
        if not self._loaded:
            raise RuntimeError("Data not loaded; call load() first")

        return self._char_pinyin_map.get((char, pinyin), 0)

    def get_all_weights(self):
        """Occurrence counts of all character-pinyin pairs, ordered by ID - O(n)."""
        items_sorted = sorted(self._id_to_info.items(), key=lambda x: x[0])
        return [info.count for _, info in items_sorted]
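
    # One hedged way to use get_all_weights() downstream (assumes the IDs are
    # dense, i.e. index k corresponds to output class k; the inverse-frequency
    # weighting here is an illustration, not part of this module):
    #
    #   counts = torch.tensor(engine.get_all_weights(), dtype=torch.float)
    #   loss_weight = counts.sum() / counts.clamp(min=1)
    #   criterion = nn.CrossEntropyLoss(weight=loss_weight, ignore_index=-1)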

    def get_char_info_by_char_pinyin(
        self, char: str, pinyin: str
    ) -> Optional[CharInfo]:
        """Get the CharInfo for a specific character-pinyin pair - O(1).

        Args:
            char: a Chinese character
            pinyin: a pinyin string

        Returns:
            the CharInfo object, or None if the pair is absent
        """
        if not self._loaded:
            raise RuntimeError("Data not loaded; call load() first")

        char_info_id = self._char_pinyin_to_ids.get((char, pinyin))
        if char_info_id is None:
            return None
        return self.query_by_id(char_info_id)

    def batch_get_char_pinyin_info(
        self, pairs: List[Tuple[str, str]]
    ) -> Dict[Tuple[str, str], CharInfo]:
        """Batch lookup of character-pinyin info.

        Args:
            pairs: list of (character, pinyin) pairs

        Returns:
            a dict keyed by pair, with CharInfo values (None for absent pairs)
        """
        if not self._loaded:
            raise RuntimeError("Data not loaded; call load() first")

        result = {}
        for pair in pairs:
            char_info_id = self._char_pinyin_to_ids.get(pair)
            if char_info_id is not None:
                result[pair] = self._id_to_info.get(char_info_id)
            else:
                result[pair] = None
        return result

    def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
        """
        Batch lookup by ID - O(n).

        Args:
            ids: list of IDs

        Returns:
            a dict keyed by ID, with CharInfo values (None for absent IDs)
        """
        if not self._loaded:
            raise RuntimeError("Data not loaded; call load() first")

        results = {}
        for id_value in ids:
            results[id_value] = self._id_to_info.get(id_value, None)

        return results

    def batch_query_by_chars(
        self, chars: List[str], limit_per_char: int = 0
    ) -> Dict[str, List[Tuple[int, str, int]]]:
        """
        Batch lookup by character.

        Args:
            chars: list of characters
            limit_per_char: maximum number of results per character

        Returns:
            a dict keyed by character, with query result lists as values
        """
        if not self._loaded:
            raise RuntimeError("Data not loaded; call load() first")

        results = {}
        for char in chars:
            results[char] = self.query_by_char(char, limit_per_char)

        return results

    def search_chars_by_prefix(
        self, prefix: str, limit: int = 20
    ) -> List[Tuple[str, int]]:
        """
        Search characters by prefix - O(n), where n is the number of characters.

        Args:
            prefix: character prefix
            limit: maximum number of results

        Returns:
            a list of (character, total frequency) tuples, sorted by frequency
            in descending order
        """
        if not self._loaded:
            raise RuntimeError("Data not loaded; call load() first")

        matches = []
        for char, freq in self._char_freq.items():
            if char.startswith(prefix):
                matches.append((char, freq))

        # Sort by frequency, descending
        matches.sort(key=lambda x: x[1], reverse=True)

        return matches[:limit] if limit > 0 else matches

    def is_chinese_char(self, char: str) -> bool:
        """
        Check whether a character is a known Chinese character.
        """
        if not self.is_loaded():
            raise RuntimeError("Data not loaded; call load() first")
        return char in self._char_to_ids

    def get_statistics(self) -> Dict[str, Any]:
        """
        Collect engine statistics.

        Returns:
            a statistics dictionary
        """
        if not self._loaded:
            return {"status": "not_loaded"}

        top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[
            :10
        ]

        top_pinyins = sorted(
            self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
        )[:10]

        return {
            "status": "loaded",
            "timestamp": self._counter_data.timestamp,
            "total_pairs": self._total_pairs,
            "total_characters": len(self._char_freq),
            "total_pinyins": len(self._pinyin_freq),
            "valid_input_character_count": self._counter_data.valid_input_character_count,
            "load_time_seconds": self._load_time,
            "index_time_seconds": self._index_time,
            "top_chars": top_chars,
            "top_pinyins": top_pinyins,
            "metadata": self._counter_data.metadata,
        }

    def is_loaded(self) -> bool:
        """Whether the data has been loaded."""
        return self._loaded

    def clear(self):
        """Clear all data and indices, freeing memory."""
        self._counter_data = None
        self._id_to_info.clear()
        self._char_to_ids.clear()
        self._pinyin_to_ids.clear()
        self._char_freq.clear()
        self._pinyin_freq.clear()
        self._char_pinyin_map.clear()
        self._char_pinyin_to_ids.clear()
        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0
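

# Minimal end-to-end usage (illustrative; assumes the packaged
# data/pinyin_char_statistics.json shipped with this module):
#
#   engine = QueryEngine(min_count=109)
#   meta = engine.load()
#   engine.get_statistics()["total_pairs"]
#   engine.get_char_pinyin_count("你", "ni")
#   engine.is_chinese_char("你")   # True if the character was indexed
#   engine.clear()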