feat(pinyin): 添加中文拼音字符统计文件以优化模型输入处理
This commit is contained in:
parent
917d5976a9
commit
4bcbbaa4eb
|
|
@ -5,9 +5,23 @@ description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"modelscope>=1.35.1",
|
"amd-quark>=0.1.0",
|
||||||
"onnx>=1.20.1",
|
"bokeh>=3.8.2",
|
||||||
"torch>=2.10.0",
|
"datasets>=4.5.0",
|
||||||
|
"ipykernel>=7.2.0",
|
||||||
|
"ipython>=9.10.0",
|
||||||
|
"loguru>=0.7.3",
|
||||||
|
"modelscope>=1.34.0",
|
||||||
|
"msgpack>=1.1.2",
|
||||||
|
"numpy>=2.4.2",
|
||||||
|
"onnxruntime>=1.24.2",
|
||||||
|
"pandas>=3.0.0",
|
||||||
|
"pypinyin>=0.55.0",
|
||||||
|
"requests>=2.32.5",
|
||||||
|
"rich>=14.3.1",
|
||||||
|
"tensorboard>=2.20.0",
|
||||||
|
"transformers==5.1.0",
|
||||||
|
"typer>=0.21.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,160 @@
|
||||||
|
import random
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
from pypinyin import lazy_pinyin, Style
|
||||||
|
from pypinyin.contrib.tone_convert import to_initials
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader, IterableDataset
|
||||||
|
|
||||||
|
from modelscope import AutoModel, AutoTokenizer
|
||||||
|
|
||||||
|
# 加载分词器和模型
|
||||||
|
# model = AutoModel.from_pretrained('iic/nlp_structbert_backbone_lite_std')
|
||||||
|
|
||||||
|
# tokenizer = AutoTokenizer.from_pretrained('iic/nlp_structbert_backbone_lite_std')
|
||||||
|
|
||||||
|
|
||||||
|
class PinyinInputDataset(IterableDataset):
|
||||||
|
def __init__(
    self,
    data_path: str,
    max_workes: int = -1,  # sic: parameter name kept for caller compatibility
    tokenizer_name: str = "iic/nlp_structbert_backbone_lite_std",
    max_length=128,
    text_field: str = "text",
    py_style_weight=(9, 2, 1),
):
    """Streaming dataset that turns raw text into pinyin-conditioned inputs.

    Args:
        data_path: dataset identifier/path passed to ``load_dataset``.
        max_workes: worker count override; ``-1`` means "use the DataLoader's".
        tokenizer_name: model id used to resolve the tokenizer.
        max_length: maximum encoded sequence length.
        text_field: name of the text column in each sample.
        py_style_weight: sampling weights for (full pinyin, initials, first letter).
    """
    # Plain configuration first, heavier resources afterwards.
    self.data_path = data_path
    self.max_workers = max_workes
    self.max_length = max_length
    self.text_field = text_field
    self.py_style_weight = py_style_weight
    # Eagerly resolve the tokenizer and open the streamed corpus.
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    self.dataset = load_dataset(data_path, split="train", streaming=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def smart_multi_segment_encode(texts, tokenizer, max_length=128):
|
||||||
|
"""
|
||||||
|
智能多段落编码
|
||||||
|
|
||||||
|
思路:使用分词器的基础功能,但灵活控制token_type_ids
|
||||||
|
"""
|
||||||
|
# 第一步:使用分词器单独编码每个段落
|
||||||
|
encoded_segments = []
|
||||||
|
for text in texts:
|
||||||
|
# 注意:不添加特殊标记,我们后面统一处理
|
||||||
|
encoded = tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
encoded_segments.append(encoded)
|
||||||
|
|
||||||
|
# 第二步:构建完整序列
|
||||||
|
tokens = []
|
||||||
|
token_type_ids = []
|
||||||
|
|
||||||
|
# 添加[CLS]
|
||||||
|
tokens.append(tokenizer.cls_token_id)
|
||||||
|
token_type_ids.append(0) # CLS通常为0
|
||||||
|
|
||||||
|
# 添加各个段落
|
||||||
|
for seg_idx, segment in enumerate(encoded_segments):
|
||||||
|
# 当前段落的类型(0-3循环)
|
||||||
|
current_type = seg_idx % 4
|
||||||
|
|
||||||
|
# 添加段落内容
|
||||||
|
tokens.extend(segment)
|
||||||
|
token_type_ids.extend([current_type] * len(segment))
|
||||||
|
|
||||||
|
# 添加[SEP](最后一个段落可以不加)
|
||||||
|
if seg_idx < len(encoded_segments) - 1:
|
||||||
|
tokens.append(tokenizer.sep_token_id)
|
||||||
|
token_type_ids.append(current_type)
|
||||||
|
else:
|
||||||
|
# 最后一个段落加[SEP]
|
||||||
|
tokens.append(tokenizer.sep_token_id)
|
||||||
|
token_type_ids.append(current_type)
|
||||||
|
|
||||||
|
# 第三步:截断和填充
|
||||||
|
if len(tokens) > max_length:
|
||||||
|
tokens = tokens[:max_length]
|
||||||
|
token_type_ids = token_type_ids[:max_length]
|
||||||
|
else:
|
||||||
|
# 填充
|
||||||
|
padding_length = max_length - len(tokens)
|
||||||
|
tokens = tokens + [tokenizer.pad_token_id] * padding_length
|
||||||
|
token_type_ids = token_type_ids + [0] * padding_length # 填充部分用0
|
||||||
|
|
||||||
|
# 第四步:创建attention mask
|
||||||
|
attention_mask = [
|
||||||
|
1 if token != tokenizer.pad_token_id else 0 for token in tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": torch.tensor([tokens]),
|
||||||
|
"token_type_ids": torch.tensor([token_type_ids]),
|
||||||
|
"attention_mask": torch.tensor([attention_mask]),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 生成对应文本的拼音
|
||||||
|
# Produce the pinyin sequence for a text string.
def generate_pinyin(self, text: str) -> List[List[str]]:
    """Return lazy pinyin for *text*; non-convertible spans fall back to their single characters."""
    return lazy_pinyin(text, errors=lambda chunk: list(chunk))
|
||||||
|
|
||||||
|
# 生成需要预测汉字对应的拼音,并进行加强
|
||||||
|
# Build the (augmented) pinyin hints for the characters to be predicted.
def get_mask_pinyin(self, text: str, pinyin_list: List[str]) -> Tuple[int, List[str]]:
    """Return ``(last_index, hints)`` for the prediction span.

    For each character, one of three pinyin styles is sampled with weights
    ``self.py_style_weight``: full pinyin, initials only, or first letter.
    Iteration stops early at the first non-Chinese character (lazy_pinyin
    echoes such characters back unchanged), returning the index before it.

    Args:
        text: the character span to build hints for.
        pinyin_list: pinyin of each character in *text* (same length).

    Returns:
        Tuple of (index of last covered character, list of pinyin hints).
    """
    mask_pinyin: List[str] = []
    for i, ch in enumerate(text):
        # Non-Chinese input: pinyin equals the character itself — stop here.
        if ch == pinyin_list[i]:
            return i - 1, mask_pinyin
        # BUGFIX: random.choice() has no `weight` keyword (it raised
        # TypeError); weighted sampling needs random.choices(...)[0].
        py = random.choices(
            (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
            weights=self.py_style_weight,
        )[0]
        mask_pinyin.append(py)
    return len(text) - 1, mask_pinyin
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
|
if worker_info is not None:
|
||||||
|
worker_id = worker_info.id
|
||||||
|
num_workers = (
|
||||||
|
self.max_workers if self.max_workers > 0 else worker_info.num_workers
|
||||||
|
)
|
||||||
|
base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
|
||||||
|
seed = base_seed + worker_id
|
||||||
|
random.seed(seed % (2**32))
|
||||||
|
np.random.seed(seed % (2**32))
|
||||||
|
|
||||||
|
self.dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
|
||||||
|
|
||||||
|
batch_samples = []
|
||||||
|
for sample in self.dataset:
|
||||||
|
text = sample.get(self.text_field, "")
|
||||||
|
if text:
|
||||||
|
pinyin_list = self.generate_pinyin(text)
|
||||||
|
for i in range(len(text)):
|
||||||
|
# 如果text[i]不再字符库中,则跳过
|
||||||
|
# 当i小于44时候,则将part1取text[0:i]
|
||||||
|
# 当i大于44时候,则将part1取text[i-44:i]
|
||||||
|
if text[i] == pinyin_list[i]:
|
||||||
|
continue
|
||||||
|
if i < 44:
|
||||||
|
part1 = text[0:i]
|
||||||
|
else:
|
||||||
|
part1 = text[i - 44 : i]
|
||||||
|
# 首先取随机值pinyin_len(1-8),pinyin_len取值呈高斯分布,最大概率取3
|
||||||
|
# 获取text[i + pinyin_len]字符,如果无法获取所指向的后,如果pinyin_len
|
||||||
|
# part2的长度为x,取pinyin_list[i:i+pinyin_len],为part2
|
||||||
|
# 但是需要注意边界条件
|
||||||
|
pinyin_len = np.random.choice(
|
||||||
|
range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
|
||||||
|
)
|
||||||
|
py_end = min(i + pinyin_len, len(text))
|
||||||
|
part2 = self.get_mask_pinyin(text[i: py_end], pinyin_list[i: py_end])
|
||||||
|
# part3为文本,大概率(0.85)为空,不为空则是i+pinyin_len所指向的字符
|
||||||
|
|
||||||
|
# encoded = self.smart_multi_segment_encode([pinyin_text], self.tokenizer, self.max_length)
|
||||||
|
|
@ -6,11 +6,11 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.amp as amp
|
import torch.amp as amp
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from modelscope import AutoTokenizer
|
from modelscope import AutoTokenizer
|
||||||
|
from tqdm.autonotebook import tqdm
|
||||||
|
|
||||||
|
|
||||||
from .components import AttentionPooling, Expert # , ResidualBlock # 假设已实现
|
from .components import AttentionPooling, Expert # , ResidualBlock # 假设已实现
|
||||||
|
|
@ -254,7 +254,7 @@ class InputMethodEngine(nn.Module):
|
||||||
self,
|
self,
|
||||||
train_dataloader,
|
train_dataloader,
|
||||||
eval_dataloader=None,
|
eval_dataloader=None,
|
||||||
# monitor: Optional[TrainingMonitor] = None,
|
monitor=None,
|
||||||
criterion=None,
|
criterion=None,
|
||||||
optimizer=None,
|
optimizer=None,
|
||||||
num_epochs=1,
|
num_epochs=1,
|
||||||
|
|
@ -285,7 +285,6 @@ class InputMethodEngine(nn.Module):
|
||||||
current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
|
current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
|
||||||
return current_lr
|
return current_lr
|
||||||
|
|
||||||
"""训练函数,调整了输入参数"""
|
|
||||||
if self.device is None:
|
if self.device is None:
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
@ -297,13 +296,16 @@ class InputMethodEngine(nn.Module):
|
||||||
if optimizer is None:
|
if optimizer is None:
|
||||||
optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
|
optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
|
|
||||||
|
# 损失函数:需要设置 ignore_index=-1,因为标签中无效位置用 -1 表示
|
||||||
if criterion is None:
|
if criterion is None:
|
||||||
if loss_weight is not None:
|
if loss_weight is not None:
|
||||||
criterion = nn.CrossEntropyLoss(
|
criterion = nn.CrossEntropyLoss(
|
||||||
weight=loss_weight, label_smoothing=label_smoothing
|
weight=loss_weight, label_smoothing=label_smoothing, ignore_index=-1
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
criterion = nn.CrossEntropyLoss(
|
||||||
|
label_smoothing=label_smoothing, ignore_index=-1
|
||||||
|
)
|
||||||
|
|
||||||
scaler = amp.GradScaler(enabled=mixed_precision)
|
scaler = amp.GradScaler(enabled=mixed_precision)
|
||||||
|
|
||||||
|
|
@ -313,86 +315,105 @@ class InputMethodEngine(nn.Module):
|
||||||
processed_batches = 0
|
processed_batches = 0
|
||||||
batch_loss_sum = 0.0
|
batch_loss_sum = 0.0
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for batch_idx, batch in enumerate(
|
for batch_idx, batch in enumerate(
|
||||||
tqdm(train_dataloader, total=int(stop_batch))
|
tqdm(train_dataloader, total=int(stop_batch))
|
||||||
):
|
):
|
||||||
# LR Schedule
|
# 学习率调度
|
||||||
current_lr = lr_schedule(
|
current_lr = lr_schedule(
|
||||||
lr, stop_batch, processed_batches, warmup_steps
|
lr, stop_batch, processed_batches, warmup_steps
|
||||||
)
|
)
|
||||||
for param_group in optimizer.param_groups:
|
for param_group in optimizer.param_groups:
|
||||||
param_group["lr"] = current_lr
|
param_group["lr"] = current_lr
|
||||||
|
|
||||||
# 移动数据(注意:batch 中现在包含 token_type_ids)
|
# 从 batch 中获取数据
|
||||||
input_ids = batch["hint"]["input_ids"].to(self.device)
|
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||||||
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||||||
token_type_ids = batch["hint"]["token_type_ids"].to(
|
token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
|
||||||
self.device
|
labels = batch["char_id"].to(self.device) # [batch, max_slot_steps]
|
||||||
) # 新增
|
|
||||||
pg = batch["pg"].to(self.device)
|
# 构建 slot_target_mask:有效位置为 1,无效位置为 0(假设无效标签为 -1)
|
||||||
labels = batch["char_id"].to(self.device)
|
slot_target_mask = (labels != -1).float() # [batch, max_slot_steps]
|
||||||
|
|
||||||
with torch.amp.autocast(
|
with torch.amp.autocast(
|
||||||
device_type=self.device.type, enabled=mixed_precision
|
device_type=self.device.type, enabled=mixed_precision
|
||||||
):
|
):
|
||||||
logits = self(input_ids, attention_mask, token_type_ids, pg)
|
# 调用模型(训练模式)
|
||||||
loss = criterion(logits, labels)
|
logits = self(
|
||||||
|
input_ids=input_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
slot_target_ids=labels,
|
||||||
|
slot_target_mask=slot_target_mask,
|
||||||
|
mode="train",
|
||||||
|
) # logits: [batch, max_slot_steps, output_vocab_size]
|
||||||
|
|
||||||
|
# 计算损失(忽略填充位置,ignore_index=-1 已在 criterion 中设置)
|
||||||
|
loss = criterion(
|
||||||
|
logits.view(-1, self.output_vocab_size), labels.view(-1)
|
||||||
|
)
|
||||||
loss = loss / grad_accum_steps
|
loss = loss / grad_accum_steps
|
||||||
|
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
if (processed_batches) % grad_accum_steps == 0:
|
# 梯度累积更新
|
||||||
|
if (processed_batches + 1) % grad_accum_steps == 0:
|
||||||
scaler.unscale_(optimizer)
|
scaler.unscale_(optimizer)
|
||||||
torch.nn.utils.clip_grad_norm_(
|
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
|
||||||
self.parameters(), clip_grad_norm
|
|
||||||
)
|
|
||||||
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
batch_loss_sum += loss.item() * grad_accum_steps
|
batch_loss_sum += loss.item() * grad_accum_steps
|
||||||
|
|
||||||
|
# 定期评估
|
||||||
if processed_batches % eval_frequency == 0:
|
if processed_batches % eval_frequency == 0:
|
||||||
if eval_dataloader:
|
if eval_dataloader:
|
||||||
self.eval()
|
self.eval()
|
||||||
acc, eval_loss = self.model_eval(
|
acc, eval_loss = self.model_eval(eval_dataloader, criterion)
|
||||||
eval_dataloader, criterion
|
|
||||||
)
|
|
||||||
self.train()
|
self.train()
|
||||||
if monitor:
|
if monitor:
|
||||||
# 使用 eval_loss 作为监控指标
|
|
||||||
monitor.add_step(
|
monitor.add_step(
|
||||||
processed_batches,
|
processed_batches,
|
||||||
{
|
{
|
||||||
"train_loss": batch_loss_sum
|
"train_loss": batch_loss_sum
|
||||||
/ (
|
/ (eval_frequency if processed_batches > 0 else 1),
|
||||||
eval_frequency
|
|
||||||
if processed_batches > 0
|
|
||||||
else 1
|
|
||||||
),
|
|
||||||
"acc": acc,
|
"acc": acc,
|
||||||
"loss": eval_loss,
|
"loss": eval_loss,
|
||||||
"lr": current_lr,
|
"lr": current_lr,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}"
|
f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, "
|
||||||
|
f"batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
|
||||||
|
f"current_lr: {current_lr}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"step: {processed_batches}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}"
|
f"step: {processed_batches}, batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
|
||||||
|
f"current_lr: {current_lr}"
|
||||||
)
|
)
|
||||||
batch_loss_sum = 0.0
|
batch_loss_sum = 0.0
|
||||||
|
|
||||||
|
processed_batches += 1
|
||||||
if processed_batches >= stop_batch:
|
if processed_batches >= stop_batch:
|
||||||
break
|
break
|
||||||
processed_batches += 1
|
|
||||||
|
else:
|
||||||
|
# 未达到梯度累积步数,只累加损失值,但不更新计数器(因为 processed_batches 在梯度更新时才增加)
|
||||||
|
# 注意:这里需要小心,原代码中 processed_batches 是在梯度更新后才增加,所以上面已经统一在更新后增加
|
||||||
|
# 但为了兼容原有逻辑,这里不做额外处理
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 训练结束通知
|
||||||
|
if monitor:
|
||||||
|
monitor.finish()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("Training interrupted by user")
|
logger.info("Training interrupted by user")
|
||||||
|
|
||||||
# 训练结束发送通知
|
|
||||||
if monitor:
|
|
||||||
monitor.finish()
|
|
||||||
|
|
||||||
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
|
||||||
state_dict = torch.load(
|
state_dict = torch.load(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,486 @@
|
||||||
|
# file name: query_engine.py
|
||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from importlib.resources import files
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
from .char_info import CharInfo, PinyinCharPairsCounter
|
||||||
|
|
||||||
|
|
||||||
|
class QueryEngine:
    """Efficient pinyin/character query engine over precomputed statistics.

    Features:
        1. O(1) lookup by record id
        2. O(1) lookup by character
        3. O(1) lookup by pinyin
        4. Memory-friendly index construction
        5. Batch queries and prefix search
    """

    def __init__(self, min_count: int = 109):
        """Create an empty engine; :meth:`load` must be called before querying.

        Args:
            min_count: (char, pinyin) pairs observed fewer than this many
                times are dropped when the indices are built.
        """
        self._counter_data: Optional["PinyinCharPairsCounter"] = None

        # Core indices — O(1) lookups.
        self._id_to_info: Dict[int, "CharInfo"] = {}  # id -> CharInfo
        self._char_to_ids: Dict[str, List[int]] = {}  # char -> list of ids
        self._pinyin_to_ids: Dict[str, List[int]] = {}  # pinyin -> list of ids
        self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}

        # Secondary indices — aggregated frequencies.
        self._char_freq: Dict[str, int] = {}  # char -> total count
        self._pinyin_freq: Dict[str, int] = {}  # pinyin -> total count
        self._char_pinyin_map: Dict[Tuple[str, str], int] = {}  # (char, pinyin) -> count

        # Bookkeeping / statistics.
        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0

        self.min_count = min_count

    def load(self, file_path: Union[str, Path, None] = None) -> Dict[str, Any]:
        """Load a statistics file and build the query indices.

        Args:
            file_path: path to a msgpack/pickle/json file (gzip detected
                automatically). Defaults to the packaged
                ``data/pinyin_char_statistics.json``.
                BUGFIX: the default is now resolved lazily at call time;
                previously ``files(__package__) / ...`` was evaluated in the
                signature at import time, touching package resources (and
                potentially raising) even if ``load`` was never called.

        Returns:
            The metadata dictionary stored in the file.

        Raises:
            FileNotFoundError: the file does not exist.
            ValueError: the file format is not supported.
        """
        if file_path is None:
            file_path = files(__package__) / "data" / "pinyin_char_statistics.json"
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件不存在: {file_path}")

        start_time = time.time()

        # Read and parse the file, then build all indices.
        self._counter_data = self.parse_file(file_path)
        self._build_indices()

        self._load_time = time.time() - start_time
        self._loaded = True

        return self._counter_data.metadata

    def parse_file(self, file_path: Union[str, Path]) -> "PinyinCharPairsCounter":
        """Parse *file_path*, trying msgpack, pickle, then json in turn."""
        with open(file_path, "rb") as f:
            data = f.read()

        # Transparently handle gzip-compressed payloads.
        try:
            data = gzip.decompress(data)
        except Exception:
            pass

        # NOTE(security): pickle is attempted before json below;
        # pickle.loads can execute arbitrary code, so only load files
        # from trusted sources.
        for _parser_name, parser in [
            ("msgpack", self._parse_msgpack),
            ("pickle", self._parse_pickle),
            ("json", self._parse_json),
        ]:
            try:
                return parser(data)
            except Exception:
                continue

        raise ValueError("无法解析文件格式")

    def _parse_msgpack(self, data: bytes) -> "PinyinCharPairsCounter":
        """Decode a msgpack payload."""
        data_dict = msgpack.unpackb(data, raw=False)
        return self._dict_to_counter(data_dict)

    def _parse_pickle(self, data: bytes) -> "PinyinCharPairsCounter":
        """Decode a pickle payload (trusted input only — see parse_file)."""
        return pickle.loads(data)

    def _parse_json(self, data: bytes) -> "PinyinCharPairsCounter":
        """Decode a UTF-8 JSON payload."""
        data_dict = json.loads(data.decode("utf-8"))
        return self._dict_to_counter(data_dict)

    def _dict_to_counter(self, data_dict: Dict) -> "PinyinCharPairsCounter":
        """Convert a plain dict into a PinyinCharPairsCounter.

        JSON/msgpack serialize the ``pairs`` mapping with string keys and
        plain dicts; rebuild ``{int_id: CharInfo}`` before constructing.
        """
        pairs_dict = {}
        if "pairs" in data_dict and data_dict["pairs"]:
            for id_str, info_dict in data_dict["pairs"].items():
                pairs_dict[int(id_str)] = CharInfo(**info_dict)
        data_dict["pairs"] = pairs_dict

        return PinyinCharPairsCounter(**data_dict)

    def _build_indices(self):
        """(Re)build all query indices from ``self._counter_data``.

        Pairs with ``count < self.min_count`` are excluded from every index.
        """
        start_time = time.time()

        # Reset all indices.
        self._id_to_info.clear()
        self._char_to_ids.clear()
        self._pinyin_to_ids.clear()
        self._char_pinyin_to_ids.clear()
        self._char_freq.clear()
        self._pinyin_freq.clear()
        self._char_pinyin_map.clear()

        # Copy aggregate frequency tables.
        if self._counter_data.chars:
            self._char_freq = self._counter_data.chars.copy()
        if self._counter_data.pinyins:
            self._pinyin_freq = self._counter_data.pinyins.copy()

        # Build the core indices.
        for char_info in self._counter_data.pairs.values():
            # Drop rare pairs below the configured threshold.
            if char_info.count < self.min_count:
                continue
            char = char_info.char
            pinyin = char_info.pinyin
            char_info_id = char_info.id

            # Id index.
            self._id_to_info[char_info_id] = char_info

            # Character index.
            self._char_to_ids.setdefault(char, []).append(char_info_id)

            # Pinyin index.
            self._pinyin_to_ids.setdefault(pinyin, []).append(char_info_id)

            # (char, pinyin) maps.
            self._char_pinyin_map[(char, pinyin)] = char_info.count
            self._char_pinyin_to_ids[(char, pinyin)] = char_info_id

        self._total_pairs = len(self._id_to_info)
        self._index_time = time.time() - start_time

    def query_by_id(self, id: int) -> Optional["CharInfo"]:
        """Look up a record by id — O(1).

        Args:
            id: record id.

        Returns:
            The CharInfo, or None if the id is unknown.
        """
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

        return self._id_to_info.get(id)

    def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
        """Look up all pinyin readings of a character — O(1) + O(k).

        Args:
            char: the Chinese character.
            limit: maximum number of results; 0 means no limit.

        Returns:
            List of (id, pinyin, count) tuples, sorted by count descending.
        """
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

        if char not in self._char_to_ids:
            return []

        # Collect (id, pinyin, count) for every reading and sort by frequency.
        results = [
            (cid, self._id_to_info[cid].pinyin, self._id_to_info[cid].count)
            for cid in self._char_to_ids[char]
        ]
        results.sort(key=lambda x: x[2], reverse=True)

        if limit > 0 and len(results) > limit:
            results = results[:limit]

        return results

    def query_by_pinyin(
        self, pinyin: str, limit: int = 0
    ) -> List[Tuple[int, str, int]]:
        """Look up all characters for a pinyin — O(1) + O(k).

        Args:
            pinyin: pinyin string.
            limit: maximum number of results; 0 means no limit.

        Returns:
            List of (id, char, count) tuples, sorted by count descending.
        """
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

        if pinyin not in self._pinyin_to_ids:
            return []

        # Collect (id, char, count) for every candidate and sort by frequency.
        results = [
            (cid, self._id_to_info[cid].char, self._id_to_info[cid].count)
            for cid in self._pinyin_to_ids[pinyin]
        ]
        results.sort(key=lambda x: x[2], reverse=True)

        if limit > 0 and len(results) > limit:
            results = results[:limit]

        return results

    def get_char_frequency(self, char: str) -> int:
        """Total frequency of *char* across all its readings — O(1).

        Returns 0 for unknown characters.
        """
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

        return self._char_freq.get(char, 0)

    def get_pinyin_frequency(self, pinyin: str) -> int:
        """Total frequency of *pinyin* across all characters — O(1).

        Returns 0 for unknown pinyin.
        """
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

        return self._pinyin_freq.get(pinyin, 0)

    def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
        """Count of a specific (char, pinyin) pair — O(1); 0 if unknown."""
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

        return self._char_pinyin_map.get((char, pinyin), 0)

    def get_all_weights(self):
        """Counts of all indexed pairs, ordered by id — O(n log n) due to sort."""
        return [self._id_to_info[i].count for i in sorted(self._id_to_info)]

    def get_char_info_by_char_pinyin(
        self, char: str, pinyin: str
    ) -> Optional["CharInfo"]:
        """CharInfo for a specific (char, pinyin) pair — O(1); None if unknown."""
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

        char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
        # query_by_id tolerates None (dict .get on a missing key -> None).
        return self.query_by_id(char_info_id)

    def batch_get_char_pinyin_info(
        self, pairs: List[Tuple[str, str]]
    ) -> Dict[Tuple[str, str], "CharInfo"]:
        """Batch variant of get_char_info_by_char_pinyin.

        Args:
            pairs: list of (char, pinyin) tuples.

        Returns:
            Dict mapping each pair to its CharInfo (None if unknown).
        """
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

        result = {}
        for pair in pairs:
            char_info_id = self._char_pinyin_to_ids.get(pair)
            if char_info_id is not None:
                result[pair] = self._id_to_info.get(char_info_id)
            else:
                result[pair] = None
        return result

    def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional["CharInfo"]]:
        """Batch id lookup — O(n).

        Returns:
            Dict mapping each id to its CharInfo (None if unknown).
        """
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

        return {id_value: self._id_to_info.get(id_value, None) for id_value in ids}

    def batch_query_by_chars(
        self, chars: List[str], limit_per_char: int = 0
    ) -> Dict[str, List[Tuple[int, str, int]]]:
        """Batch character lookup.

        Args:
            chars: list of characters.
            limit_per_char: per-character result limit (0 = no limit).

        Returns:
            Dict mapping each character to its query_by_char result.
        """
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

        return {char: self.query_by_char(char, limit_per_char) for char in chars}

    def search_chars_by_prefix(
        self, prefix: str, limit: int = 20
    ) -> List[Tuple[str, int]]:
        """Find characters starting with *prefix* — O(n) over all characters.

        Args:
            prefix: character prefix.
            limit: maximum number of results (<=0 means no limit).

        Returns:
            List of (char, total frequency), sorted by frequency descending.
        """
        if not self._loaded:
            raise RuntimeError("数据未加载,请先调用load()方法")

        matches = [
            (char, freq)
            for char, freq in self._char_freq.items()
            if char.startswith(prefix)
        ]
        matches.sort(key=lambda x: x[1], reverse=True)

        return matches[:limit] if limit > 0 else matches

    def is_chinese_char(self, char: str) -> bool:
        """True if *char* is an indexed Chinese character.

        Note: raises ValueError (not RuntimeError like the other methods)
        when data is not loaded — kept for backward compatibility.
        """
        if not self.is_loaded():
            raise ValueError("请先调用 load() 方法加载数据")
        return char in self._char_to_ids

    def get_statistics(self) -> Dict[str, Any]:
        """Summary statistics of the loaded data (or a not_loaded marker)."""
        if not self._loaded:
            return {"status": "not_loaded"}

        top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[
            :10
        ]
        top_pinyins = sorted(
            self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
        )[:10]

        return {
            "status": "loaded",
            "timestamp": self._counter_data.timestamp,
            "total_pairs": self._total_pairs,
            "total_characters": len(self._char_freq),
            "total_pinyins": len(self._pinyin_freq),
            "valid_input_character_count": self._counter_data.valid_input_character_count,
            "load_time_seconds": self._load_time,
            "index_time_seconds": self._index_time,
            "top_chars": top_chars,
            "top_pinyins": top_pinyins,
            "metadata": self._counter_data.metadata,
        }

    def is_loaded(self) -> bool:
        """True once load() has completed successfully."""
        return self._loaded

    def clear(self):
        """Drop all data and indices, releasing memory."""
        self._counter_data = None
        self._id_to_info.clear()
        self._char_to_ids.clear()
        self._pinyin_to_ids.clear()
        self._char_freq.clear()
        self._pinyin_freq.clear()
        self._char_pinyin_map.clear()
        self._char_pinyin_to_ids.clear()
        self._loaded = False
        self._total_pairs = 0
        self._load_time = 0.0
        self._index_time = 0.0
|
||||||
Loading…
Reference in New Issue