feat(pinyin): 添加中文拼音字符统计文件以优化模型输入处理

This commit is contained in:
songsenand 2026-03-30 00:16:16 +08:00
parent 917d5976a9
commit 4bcbbaa4eb
6 changed files with 145543 additions and 806 deletions

View File

@ -5,9 +5,23 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
"modelscope>=1.35.1", "amd-quark>=0.1.0",
"onnx>=1.20.1", "bokeh>=3.8.2",
"torch>=2.10.0", "datasets>=4.5.0",
"ipykernel>=7.2.0",
"ipython>=9.10.0",
"loguru>=0.7.3",
"modelscope>=1.34.0",
"msgpack>=1.1.2",
"numpy>=2.4.2",
"onnxruntime>=1.24.2",
"pandas>=3.0.0",
"pypinyin>=0.55.0",
"requests>=2.32.5",
"rich>=14.3.1",
"tensorboard>=2.20.0",
"transformers==5.1.0",
"typer>=0.21.1",
] ]
[tool.uv] [tool.uv]

File diff suppressed because it is too large Load Diff

160
src/model/dataset.py Normal file
View File

@ -0,0 +1,160 @@
import random
import os
from typing import Any, Dict, List, Optional, Tuple
from datasets import load_dataset
from pypinyin import lazy_pinyin, Style
from pypinyin.contrib.tone_convert import to_initials
from loguru import logger
import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset
from modelscope import AutoModel, AutoTokenizer
# 加载分词器和模型
# model = AutoModel.from_pretrained('iic/nlp_structbert_backbone_lite_std')
# tokenizer = AutoTokenizer.from_pretrained('iic/nlp_structbert_backbone_lite_std')
class PinyinInputDataset(IterableDataset):
def __init__(
self,
data_path: str,
max_workes: int = -1,
tokenizer_name: str = "iic/nlp_structbert_backbone_lite_std",
max_length=128,
text_field: str = "text",
py_style_weight = (9, 2, 1),
):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.data_path = data_path
self.max_length = max_length
self.text_field = text_field
self.dataset = load_dataset(data_path, split="train", streaming=True)
self.max_workers = max_workes
self.py_style_weight = py_style_weight
@staticmethod
def smart_multi_segment_encode(texts, tokenizer, max_length=128):
"""
智能多段落编码
思路使用分词器的基础功能但灵活控制token_type_ids
"""
# 第一步:使用分词器单独编码每个段落
encoded_segments = []
for text in texts:
# 注意:不添加特殊标记,我们后面统一处理
encoded = tokenizer.encode(text, add_special_tokens=False)
encoded_segments.append(encoded)
# 第二步:构建完整序列
tokens = []
token_type_ids = []
# 添加[CLS]
tokens.append(tokenizer.cls_token_id)
token_type_ids.append(0) # CLS通常为0
# 添加各个段落
for seg_idx, segment in enumerate(encoded_segments):
# 当前段落的类型0-3循环
current_type = seg_idx % 4
# 添加段落内容
tokens.extend(segment)
token_type_ids.extend([current_type] * len(segment))
# 添加[SEP](最后一个段落可以不加)
if seg_idx < len(encoded_segments) - 1:
tokens.append(tokenizer.sep_token_id)
token_type_ids.append(current_type)
else:
# 最后一个段落加[SEP]
tokens.append(tokenizer.sep_token_id)
token_type_ids.append(current_type)
# 第三步:截断和填充
if len(tokens) > max_length:
tokens = tokens[:max_length]
token_type_ids = token_type_ids[:max_length]
else:
# 填充
padding_length = max_length - len(tokens)
tokens = tokens + [tokenizer.pad_token_id] * padding_length
token_type_ids = token_type_ids + [0] * padding_length # 填充部分用0
# 第四步创建attention mask
attention_mask = [
1 if token != tokenizer.pad_token_id else 0 for token in tokens
]
return {
"input_ids": torch.tensor([tokens]),
"token_type_ids": torch.tensor([token_type_ids]),
"attention_mask": torch.tensor([attention_mask]),
}
# 生成对应文本的拼音
def generate_pinyin(self, text: str) -> List[List[str]]:
return lazy_pinyin(text, errors=lambda x: [c for c in x])
# 生成需要预测汉字对应的拼音,并进行加强
def get_mask_pinyin(self, text: str, pinyin_list: List[str]) -> List[str]:
mask_pinyin = []
for i in range(len(text)):
if text[i] == pinyin_list[i]:
return i - 1, mask_pinyin
else:
py = random.choice(
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
weight=self.py_style_weight,
)
mask_pinyin.append(py)
return len(text) - 1, mask_pinyin
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
worker_id = worker_info.id
num_workers = (
self.max_workers if self.max_workers > 0 else worker_info.num_workers
)
base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
seed = base_seed + worker_id
random.seed(seed % (2**32))
np.random.seed(seed % (2**32))
self.dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
batch_samples = []
for sample in self.dataset:
text = sample.get(self.text_field, "")
if text:
pinyin_list = self.generate_pinyin(text)
for i in range(len(text)):
# 如果text[i]不再字符库中,则跳过
# 当i小于44时候则将part1取text[0:i]
# 当i大于44时候则将part1取text[i-44:i]
if text[i] == pinyin_list[i]:
continue
if i < 44:
part1 = text[0:i]
else:
part1 = text[i - 44 : i]
# 首先取随机值pinyin_len1-8pinyin_len取值呈高斯分布最大概率取3
# 获取text[i + pinyin_len]字符如果无法获取所指向的后如果pinyin_len
# part2的长度为x取pinyin_list[i:i+pinyin_len]为part2
# 但是需要注意边界条件
pinyin_len = np.random.choice(
range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
)
py_end = min(i + pinyin_len, len(text))
part2 = self.get_mask_pinyin(text[i: py_end], pinyin_list[i: py_end])
# part3为文本大概率0.85为空不为空则是i+pinyin_len所指向的字符
# encoded = self.smart_multi_segment_encode([pinyin_text], self.tokenizer, self.max_length)

View File

@ -6,11 +6,11 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.amp as amp import torch.amp as amp
import torch.optim as optim import torch.optim as optim
import torch.nn.functional as F import torch.nn.functional as F
from loguru import logger from loguru import logger
from modelscope import AutoTokenizer from modelscope import AutoTokenizer
from tqdm.autonotebook import tqdm
from .components import AttentionPooling, Expert # , ResidualBlock # 假设已实现 from .components import AttentionPooling, Expert # , ResidualBlock # 假设已实现
@ -254,7 +254,7 @@ class InputMethodEngine(nn.Module):
self, self,
train_dataloader, train_dataloader,
eval_dataloader=None, eval_dataloader=None,
# monitor: Optional[TrainingMonitor] = None, monitor=None,
criterion=None, criterion=None,
optimizer=None, optimizer=None,
num_epochs=1, num_epochs=1,
@ -285,7 +285,6 @@ class InputMethodEngine(nn.Module):
current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress))) current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
return current_lr return current_lr
"""训练函数,调整了输入参数"""
if self.device is None: if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(self.device) self.to(self.device)
@ -297,13 +296,16 @@ class InputMethodEngine(nn.Module):
if optimizer is None: if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay) optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
# 损失函数:需要设置 ignore_index=-1因为标签中无效位置用 -1 表示
if criterion is None: if criterion is None:
if loss_weight is not None: if loss_weight is not None:
criterion = nn.CrossEntropyLoss( criterion = nn.CrossEntropyLoss(
weight=loss_weight, label_smoothing=label_smoothing weight=loss_weight, label_smoothing=label_smoothing, ignore_index=-1
) )
else: else:
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing) criterion = nn.CrossEntropyLoss(
label_smoothing=label_smoothing, ignore_index=-1
)
scaler = amp.GradScaler(enabled=mixed_precision) scaler = amp.GradScaler(enabled=mixed_precision)
@ -313,86 +315,105 @@ class InputMethodEngine(nn.Module):
processed_batches = 0 processed_batches = 0
batch_loss_sum = 0.0 batch_loss_sum = 0.0
optimizer.zero_grad() optimizer.zero_grad()
try: try:
for epoch in range(num_epochs): for epoch in range(num_epochs):
for batch_idx, batch in enumerate( for batch_idx, batch in enumerate(
tqdm(train_dataloader, total=int(stop_batch)) tqdm(train_dataloader, total=int(stop_batch))
): ):
# LR Schedule # 学习率调度
current_lr = lr_schedule( current_lr = lr_schedule(
lr, stop_batch, processed_batches, warmup_steps lr, stop_batch, processed_batches, warmup_steps
) )
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
param_group["lr"] = current_lr param_group["lr"] = current_lr
# 移动数据注意batch 中现在包含 token_type_ids # 从 batch 中获取数据
input_ids = batch["hint"]["input_ids"].to(self.device) input_ids = batch["hint"]["input_ids"].to(self.device)
attention_mask = batch["hint"]["attention_mask"].to(self.device) attention_mask = batch["hint"]["attention_mask"].to(self.device)
token_type_ids = batch["hint"]["token_type_ids"].to( token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
self.device labels = batch["char_id"].to(self.device) # [batch, max_slot_steps]
) # 新增
pg = batch["pg"].to(self.device) # 构建 slot_target_mask有效位置为 1无效位置为 0假设无效标签为 -1
labels = batch["char_id"].to(self.device) slot_target_mask = (labels != -1).float() # [batch, max_slot_steps]
with torch.amp.autocast( with torch.amp.autocast(
device_type=self.device.type, enabled=mixed_precision device_type=self.device.type, enabled=mixed_precision
): ):
logits = self(input_ids, attention_mask, token_type_ids, pg) # 调用模型(训练模式)
loss = criterion(logits, labels) logits = self(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
slot_target_ids=labels,
slot_target_mask=slot_target_mask,
mode="train",
) # logits: [batch, max_slot_steps, output_vocab_size]
# 计算损失忽略填充位置ignore_index=-1 已在 criterion 中设置)
loss = criterion(
logits.view(-1, self.output_vocab_size), labels.view(-1)
)
loss = loss / grad_accum_steps loss = loss / grad_accum_steps
scaler.scale(loss).backward() scaler.scale(loss).backward()
if (processed_batches) % grad_accum_steps == 0: # 梯度累积更新
if (processed_batches + 1) % grad_accum_steps == 0:
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_( torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
self.parameters(), clip_grad_norm
)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
batch_loss_sum += loss.item() * grad_accum_steps batch_loss_sum += loss.item() * grad_accum_steps
# 定期评估
if processed_batches % eval_frequency == 0: if processed_batches % eval_frequency == 0:
if eval_dataloader: if eval_dataloader:
self.eval() self.eval()
acc, eval_loss = self.model_eval( acc, eval_loss = self.model_eval(eval_dataloader, criterion)
eval_dataloader, criterion
)
self.train() self.train()
if monitor: if monitor:
# 使用 eval_loss 作为监控指标
monitor.add_step( monitor.add_step(
processed_batches, processed_batches,
{ {
"train_loss": batch_loss_sum "train_loss": batch_loss_sum
/ ( / (eval_frequency if processed_batches > 0 else 1),
eval_frequency
if processed_batches > 0
else 1
),
"acc": acc, "acc": acc,
"loss": eval_loss, "loss": eval_loss,
"lr": current_lr, "lr": current_lr,
}, },
) )
logger.info( logger.info(
f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}" f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, "
f"batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
f"current_lr: {current_lr}"
) )
else: else:
logger.info( logger.info(
f"step: {processed_batches}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}" f"step: {processed_batches}, batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
f"current_lr: {current_lr}"
) )
batch_loss_sum = 0.0 batch_loss_sum = 0.0
processed_batches += 1
if processed_batches >= stop_batch: if processed_batches >= stop_batch:
break break
processed_batches += 1
else:
# 未达到梯度累积步数,只累加损失值,但不更新计数器(因为 processed_batches 在梯度更新时才增加)
# 注意:这里需要小心,原代码中 processed_batches 是在梯度更新后才增加,所以上面已经统一在更新后增加
# 但为了兼容原有逻辑,这里不做额外处理
pass
# 训练结束通知
if monitor:
monitor.finish()
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Training interrupted by user") logger.info("Training interrupted by user")
# 训练结束发送通知
if monitor:
monitor.finish()
def load_from_state_dict(self, state_dict_path: Union[str, Path]): def load_from_state_dict(self, state_dict_path: Union[str, Path]):
state_dict = torch.load( state_dict = torch.load(

486
src/model/query.py Normal file
View File

@ -0,0 +1,486 @@
# file name: query_engine.py
import gzip
import json
import os
import pickle
import time
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import msgpack
from .char_info import CharInfo, PinyinCharPairsCounter
class QueryEngine:
"""
高效拼音-字符查询引擎
特性:
1. O(1)时间复杂度的ID查询
2. O(1)时间复杂度的字符查询
3. O(1)时间复杂度的拼音查询
4. 内存友好构建高效索引
5. 支持批量查询和前缀搜索
"""
def __init__(self, min_count: int = 109):
"""初始化查询引擎"""
self._counter_data: Optional[PinyinCharPairsCounter] = None
# 核心索引 - 提供O(1)查询
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}
# 辅助索引 - 快速获取详细信息
self._char_freq: Dict[str, int] = {} # 字符总频率
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count
# 统计信息
self._loaded = False
self._total_pairs = 0
self._load_time = 0.0
self._index_time = 0.0
self.min_count = min_count
def load(
self,
file_path: Union[str, Path] = (
files(__package__) / "data" / "pinyin_char_statistics.json"
),
) -> Dict[str, Any]:
"""
加载统计结果文件
Args:
file_path: 文件路径文件支持msgpack/pickle/json格式自动检测压缩
Returns:
元数据字典
Raises:
FileNotFoundError: 文件不存在
ValueError: 文件格式不支持
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
start_time = time.time()
# 读取并解析文件
self._counter_data = self.parse_file(file_path)
# 构建索引
self._build_indices()
self._load_time = time.time() - start_time
self._loaded = True
return self._counter_data.metadata
def parse_file(self, file_path: Union[str, Path]) -> PinyinCharPairsCounter:
"""解析文件,支持多种格式"""
with open(file_path, "rb") as f:
data = f.read()
# 尝试解压
try:
data = gzip.decompress(data)
except Exception:
pass
# 尝试不同格式
for parser_name, parser in [
("msgpack", self._parse_msgpack),
("pickle", self._parse_pickle),
("json", self._parse_json),
]:
try:
return parser(data)
except Exception:
continue
raise ValueError("无法解析文件格式")
def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
"""解析msgpack格式"""
data_dict = msgpack.unpackb(data, raw=False)
return self._dict_to_counter(data_dict)
def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
"""解析pickle格式"""
return pickle.loads(data)
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
"""解析json格式"""
data_str = data.decode("utf-8")
data_dict = json.loads(data_str)
return self._dict_to_counter(data_dict)
def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
"""字典转PinyinCharPairsCounter"""
# 转换CharInfo字典
pairs_dict = {}
if "pairs" in data_dict and data_dict["pairs"]:
for id_str, info_dict in data_dict["pairs"].items():
pairs_dict[int(id_str)] = CharInfo(**info_dict)
data_dict["pairs"] = pairs_dict
return PinyinCharPairsCounter(**data_dict)
def _build_indices(self):
"""构建所有查询索引"""
start_time = time.time()
# 重置索引
self._id_to_info.clear()
self._char_to_ids.clear()
self._pinyin_to_ids.clear()
self._char_pinyin_to_ids.clear()
self._char_freq.clear()
self._pinyin_freq.clear()
self._char_pinyin_map.clear()
# 复制频率数据
if self._counter_data.chars:
self._char_freq = self._counter_data.chars.copy()
if self._counter_data.pinyins:
self._pinyin_freq = self._counter_data.pinyins.copy()
# 构建核心索引
for char_info in self._counter_data.pairs.values():
if char_info.count < self.min_count:
continue
char = char_info.char
pinyin = char_info.pinyin
char_info_id = char_info.id
# ID索引
self._id_to_info[char_info_id] = char_info
# 字符索引
if char not in self._char_to_ids:
self._char_to_ids[char] = []
self._char_to_ids[char].append(char_info_id)
# 拼音索引
if pinyin not in self._pinyin_to_ids:
self._pinyin_to_ids[pinyin] = []
self._pinyin_to_ids[pinyin].append(char_info_id)
# 字符-拼音映射
self._char_pinyin_map[(char, pinyin)] = char_info.count
self._char_pinyin_to_ids[(char, pinyin)] = char_info_id
self._total_pairs = len(self._id_to_info)
self._index_time = time.time() - start_time
def query_by_id(self, id: int) -> Optional[CharInfo]:
"""
通过ID查询字符信息 - O(1)时间复杂度
Args:
id: 记录ID
Returns:
CharInfo对象不存在则返回None
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._id_to_info.get(id)
def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
"""
通过字符查询拼音信息 - O(1) + O(k)时间复杂度k为结果数
Args:
char: 汉字字符
limit: 返回结果数量限制0表示返回所有
Returns:
列表每个元素为(id, 拼音, 次数)按次数降序排序
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
if char not in self._char_to_ids:
return []
# 获取所有相关ID
ids = self._char_to_ids[char]
# 构建结果并排序
results = []
for char_info_id in ids:
char_info = self._id_to_info[char_info_id]
results.append((char_info_id, char_info.pinyin, char_info.count))
# 按次数降序排序
results.sort(key=lambda x: x[2], reverse=True)
# 应用限制
if limit > 0 and len(results) > limit:
results = results[:limit]
return results
def query_by_pinyin(
self, pinyin: str, limit: int = 0
) -> List[Tuple[int, str, int]]:
"""
通过拼音查询字符信息 - O(1) + O(k)时间复杂度
Args:
pinyin: 拼音字符串
limit: 返回结果数量限制0表示返回所有
Returns:
列表每个元素为(id, 字符, 次数)按次数降序排序
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
if pinyin not in self._pinyin_to_ids:
return []
# 获取所有相关ID
ids = self._pinyin_to_ids[pinyin]
# 构建结果并排序
results = []
for char_info_id in ids:
char_info = self._id_to_info[char_info_id]
results.append((char_info_id, char_info.char, char_info.count))
# 按次数降序排序
results.sort(key=lambda x: x[2], reverse=True)
# 应用限制
if limit > 0 and len(results) > limit:
results = results[:limit]
return results
def get_char_frequency(self, char: str) -> int:
"""
获取字符的总出现频率所有拼音变体之和 - O(1)时间复杂度
Args:
char: 汉字字符
Returns:
总出现次数
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._char_freq.get(char, 0)
def get_pinyin_frequency(self, pinyin: str) -> int:
"""
获取拼音的总出现频率所有字符之和 - O(1)时间复杂度
Args:
pinyin: 拼音字符串
Returns:
总出现次数
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._pinyin_freq.get(pinyin, 0)
def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
"""
获取特定字符-拼音对的出现次数 - O(1)时间复杂度
Args:
char: 汉字字符
pinyin: 拼音字符串
Returns:
出现次数
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._char_pinyin_map.get((char, pinyin), 0)
def get_all_weights(self):
"""获取所有字符-拼音对出现的次数 - O(n)时间复杂度"""
items_sorted = sorted(self._id_to_info.items(), key=lambda x: x[0])
return [info.count for _, info in items_sorted]
def get_char_info_by_char_pinyin(
self, char: str, pinyin: str
) -> Optional[CharInfo]:
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
Args:
char: 汉字字符
pinyin: 拼音字符串
Returns:
ID和频率
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
return self.query_by_id(char_info_id)
def batch_get_char_pinyin_info(
self, pairs: List[Tuple[str, str]]
) -> Dict[Tuple[str, str], CharInfo]:
"""批量获取汉字-拼音信息
Args:
pairs: 汉字-拼音列表
Returns:
字典key为汉字-拼音对value为CharInfo对象不存在则为None
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
result = {}
for pair in pairs:
char_info_id = self._char_pinyin_to_ids.get(pair)
if char_info_id is not None:
result[pair] = self._id_to_info.get(char_info_id)
else:
result[pair] = None
return result
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
"""
批量ID查询 - O(n)时间复杂度
Args:
ids: ID列表
Returns:
字典key为IDvalue为CharInfo对象不存在则为None
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
results = {}
for id_value in ids:
results[id_value] = self._id_to_info.get(id_value, None)
return results
def batch_query_by_chars(
self, chars: List[str], limit_per_char: int = 0
) -> Dict[str, List[Tuple[int, str, int]]]:
"""
批量字符查询
Args:
chars: 字符列表
limit_per_char: 每个字符的结果数量限制
Returns:
字典key为字符value为查询结果列表
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
results = {}
for char in chars:
results[char] = self.query_by_char(char, limit_per_char)
return results
def search_chars_by_prefix(
self, prefix: str, limit: int = 20
) -> List[Tuple[str, int]]:
"""
根据字符前缀搜索 - O(n)时间复杂度n为字符总数
Args:
prefix: 字符前缀
limit: 返回结果数量限制
Returns:
列表每个元素为(字符, 总频率)按频率降序排序
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
matches = []
for char, freq in self._char_freq.items():
if char.startswith(prefix):
matches.append((char, freq))
# 按频率降序排序
matches.sort(key=lambda x: x[1], reverse=True)
return matches[:limit] if limit > 0 else matches
def is_chinese_char(self, char: str) -> bool:
"""
判断是否是汉字
"""
if not self.is_loaded():
raise ValueError("请先调用 load() 方法加载数据")
return char in self._char_to_ids
def get_statistics(self) -> Dict[str, Any]:
"""
获取系统统计信息
Returns:
统计信息字典
"""
if not self._loaded:
return {"status": "not_loaded"}
top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[
:10
]
top_pinyins = sorted(
self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
)[:10]
return {
"status": "loaded",
"timestamp": self._counter_data.timestamp,
"total_pairs": self._total_pairs,
"total_characters": len(self._char_freq),
"total_pinyins": len(self._pinyin_freq),
"valid_input_character_count": self._counter_data.valid_input_character_count,
"load_time_seconds": self._load_time,
"index_time_seconds": self._index_time,
"top_chars": top_chars,
"top_pinyins": top_pinyins,
"metadata": self._counter_data.metadata,
}
def is_loaded(self) -> bool:
"""检查数据是否已加载"""
return self._loaded
def clear(self):
"""清除所有数据和索引,释放内存"""
self._counter_data = None
self._id_to_info.clear()
self._char_to_ids.clear()
self._pinyin_to_ids.clear()
self._char_freq.clear()
self._pinyin_freq.clear()
self._char_pinyin_map.clear()
self._char_pinyin_to_ids.clear()
self._loaded = False
self._total_pairs = 0
self._load_time = 0.0
self._index_time = 0.0

3048
uv.lock

File diff suppressed because it is too large Load Diff