feat: 添加拼音输入法模拟数据集及相关功能实现
This commit is contained in:
parent
5ea0b0b31c
commit
f2c260de72
|
|
@ -213,4 +213,5 @@ cython_debug/
|
||||||
uv.lock
|
uv.lock
|
||||||
|
|
||||||
*.log
|
*.log
|
||||||
marimo/
|
marimo/
|
||||||
|
__marimo__/
|
||||||
|
|
@ -11,6 +11,7 @@ dependencies = [
|
||||||
"msgpack>=1.1.2",
|
"msgpack>=1.1.2",
|
||||||
"pypinyin>=0.55.0",
|
"pypinyin>=0.55.0",
|
||||||
"rich>=14.3.1",
|
"rich>=14.3.1",
|
||||||
|
"transformers>=5.1.0",
|
||||||
"typer>=0.21.1",
|
"typer>=0.21.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from loguru import logger
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
from .char_info import PinyinCharPairsCounter, CharInfo
|
from .char_info import PinyinCharPairsCounter, CharInfo
|
||||||
|
from .query import QueryEngine
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -48,6 +49,12 @@ class PinyinCharStatistics:
|
||||||
|
|
||||||
# 启动工作进程
|
# 启动工作进程
|
||||||
self._start_workers()
|
self._start_workers()
|
||||||
|
|
||||||
|
|
||||||
|
# 实现一个加载历史快照,并且以历史快照的数据进行初始化的函数
|
||||||
|
def load_history_snapshot(self, file_path: str):
|
||||||
|
"""加载历史快照,并且以历史快照的数据进行初始化"""
|
||||||
|
self
|
||||||
|
|
||||||
def _start_workers(self):
|
def _start_workers(self):
|
||||||
"""启动工作进程"""
|
"""启动工作进程"""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,624 @@
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import IterableDataset, DataLoader
|
||||||
|
from datasets import load_dataset
|
||||||
|
from pypinyin import lazy_pinyin
|
||||||
|
import random
|
||||||
|
from modelscope import AutoTokenizer
|
||||||
|
from typing import Tuple, List, Dict, Any
|
||||||
|
import re
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
class PinyinInputDataset(IterableDataset):
|
||||||
|
"""
|
||||||
|
拼音输入法模拟数据集
|
||||||
|
|
||||||
|
特性:
|
||||||
|
1. 流式读取数据集,内存友好
|
||||||
|
2. 实时拼音转换和多音字处理
|
||||||
|
3. 前文上下文多种采样方式
|
||||||
|
4. 拼音截断模拟不完整输入
|
||||||
|
5. 内置削峰填谷算法平衡数据分布
|
||||||
|
6. 缓冲区打乱支持多进程
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_dir: str,
|
||||||
|
query_engine,
|
||||||
|
tokenizer_name: str = "iic/nlp_structbert_backbone_tiny_std",
|
||||||
|
max_len: int = 88,
|
||||||
|
text_field: str = "text",
|
||||||
|
batch_query_size: int = 1000,
|
||||||
|
# 打乱参数
|
||||||
|
shuffle: bool = True,
|
||||||
|
shuffle_buffer_size: int = 100,
|
||||||
|
# 削峰填谷参数
|
||||||
|
max_freq: int = 434748359, # "的"的频率
|
||||||
|
min_freq: int = 109, # "蓚"的频率
|
||||||
|
drop_start_freq: int = 30000000, # 开始丢弃的阈值
|
||||||
|
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||||
|
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||||
|
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化数据集
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dir: 数据集目录
|
||||||
|
query_engine: QueryEngine实例
|
||||||
|
tokenizer_name: tokenizer名称
|
||||||
|
max_len: 最大序列长度
|
||||||
|
text_field: 文本字段名
|
||||||
|
batch_query_size: 批量查询大小
|
||||||
|
shuffle: 是否打乱数据
|
||||||
|
shuffle_buffer_size: 打乱缓冲区大小
|
||||||
|
max_freq: 最大字符频率
|
||||||
|
min_freq: 最小字符频率
|
||||||
|
drop_start_freq: 开始削峰的频率阈值
|
||||||
|
repeat_end_freq: 开始填谷的频率阈值
|
||||||
|
max_drop_prob: 最高频率字符的丢弃概率
|
||||||
|
max_repeat_expect: 最低频率字符的重复期望
|
||||||
|
"""
|
||||||
|
self.query_engine = query_engine
|
||||||
|
self.max_len = max_len
|
||||||
|
self.text_field = text_field
|
||||||
|
self.batch_query_size = batch_query_size
|
||||||
|
|
||||||
|
# 打乱相关参数
|
||||||
|
self.shuffle = shuffle
|
||||||
|
self.shuffle_buffer_size = shuffle_buffer_size
|
||||||
|
self.shuffle_buffer = []
|
||||||
|
|
||||||
|
# 削峰填谷参数
|
||||||
|
self.max_freq = max_freq
|
||||||
|
self.min_freq = min_freq
|
||||||
|
self.drop_start_freq = drop_start_freq
|
||||||
|
self.repeat_end_freq = repeat_end_freq
|
||||||
|
self.max_drop_prob = max_drop_prob
|
||||||
|
self.max_repeat_expect = max_repeat_expect
|
||||||
|
|
||||||
|
# 加载tokenizer
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||||
|
|
||||||
|
# 获取总字频用于后续计算
|
||||||
|
stats = query_engine.get_statistics()
|
||||||
|
self.total_chars = stats.get("valid_input_character_count", 0)
|
||||||
|
|
||||||
|
# 汉字正则表达式
|
||||||
|
self.chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
|
||||||
|
|
||||||
|
# 缓存字典
|
||||||
|
self.char_info_cache = {}
|
||||||
|
|
||||||
|
# 加载数据集
|
||||||
|
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
||||||
|
|
||||||
|
def is_chinese_char(self, char: str) -> bool:
|
||||||
|
"""判断是否为中文字符"""
|
||||||
|
return bool(self.chinese_pattern.match(char))
|
||||||
|
|
||||||
|
def get_next_chinese_chars(
|
||||||
|
self, text: str, start_idx: int, max_count: int = 3
|
||||||
|
) -> List[Tuple[str, str]]:
|
||||||
|
"""
|
||||||
|
获取后续的中文字符及其拼音
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 完整文本
|
||||||
|
start_idx: 起始位置
|
||||||
|
max_count: 最大字符数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
列表,每个元素为(字符, 拼音)
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
for i in range(start_idx + 1, len(text)):
|
||||||
|
if count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
char = text[i]
|
||||||
|
if self.is_chinese_char(char):
|
||||||
|
# 获取拼音(注意:这里需要确保拼音列表长度与text一致)
|
||||||
|
try:
|
||||||
|
# 重新计算整个text的拼音可能效率低,但确保准确
|
||||||
|
# 实际实现中可以考虑缓存或优化
|
||||||
|
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||||||
|
if i < len(pinyin_list):
|
||||||
|
result.append((char, pinyin_list[i]))
|
||||||
|
count += 1
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 非汉字,继续查找
|
||||||
|
continue
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def sample_context(self, context: str) -> str:
|
||||||
|
"""
|
||||||
|
三种方式采样前文上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 原始前文(最多100字符)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
采样后的54个字符
|
||||||
|
"""
|
||||||
|
if not context:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 确保有足够长度
|
||||||
|
context_len = len(context)
|
||||||
|
|
||||||
|
# 随机选择采样方式 (各1/3概率)
|
||||||
|
choice = random.random()
|
||||||
|
|
||||||
|
if choice < 0.333:
|
||||||
|
# 方式1: 靠近汉字的54个字符
|
||||||
|
return context[-54:] if context_len >= 54 else context
|
||||||
|
elif choice < 0.667:
|
||||||
|
# 方式2: 随机位置取46个连续字符
|
||||||
|
if context_len <= 46:
|
||||||
|
return context
|
||||||
|
start = random.randint(0, context_len - 46)
|
||||||
|
return context[start : start + 46]
|
||||||
|
else:
|
||||||
|
# 方式3: 12+6×7组合
|
||||||
|
if context_len < 12:
|
||||||
|
return context
|
||||||
|
|
||||||
|
# 最后12个字符
|
||||||
|
last_12 = context[-12:]
|
||||||
|
|
||||||
|
# 从剩下的前88个字符中随机取6段,每段7个字符
|
||||||
|
remaining = context[:-12] if context_len > 12 else ""
|
||||||
|
remaining_len = len(remaining)
|
||||||
|
|
||||||
|
if remaining_len < 7:
|
||||||
|
# 如果不够7个字符,直接返回最后12个字符
|
||||||
|
return last_12
|
||||||
|
|
||||||
|
segments = []
|
||||||
|
for _ in range(6):
|
||||||
|
if remaining_len < 7:
|
||||||
|
break
|
||||||
|
start = random.randint(0, remaining_len - 7)
|
||||||
|
segment = remaining[start : start + 7]
|
||||||
|
segments.append(segment)
|
||||||
|
|
||||||
|
# 拼接
|
||||||
|
combined = "".join(segments)
|
||||||
|
result = combined + last_12
|
||||||
|
|
||||||
|
# 确保总长度为54(可能不足)
|
||||||
|
if len(result) < 54:
|
||||||
|
# 如果不够,从前面补一些字符
|
||||||
|
needed = 54 - len(result)
|
||||||
|
if context_len >= needed:
|
||||||
|
result = context[:needed] + result
|
||||||
|
|
||||||
|
# 截断到54字符
|
||||||
|
return result[:54]
|
||||||
|
|
||||||
|
def truncate_pinyin(self, pinyin: str) -> str:
|
||||||
|
"""
|
||||||
|
截断拼音
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pinyin: 原始拼音
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
截断后的拼音,可能为空字符串
|
||||||
|
"""
|
||||||
|
if not pinyin:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 随机决定截断方式
|
||||||
|
rand_val = random.random()
|
||||||
|
|
||||||
|
if rand_val < 0.1:
|
||||||
|
# 10%概率截断为空
|
||||||
|
return ""
|
||||||
|
elif rand_val < 0.6:
|
||||||
|
# 50%概率不截断
|
||||||
|
return pinyin
|
||||||
|
else:
|
||||||
|
# 40%概率随机截断
|
||||||
|
# 均匀分配剩余概率给各种截断长度
|
||||||
|
max_len = len(pinyin)
|
||||||
|
if max_len <= 1:
|
||||||
|
return pinyin
|
||||||
|
|
||||||
|
# 随机选择截断长度 (1 到 max_len-1)
|
||||||
|
trunc_len = random.randint(1, max_len - 1)
|
||||||
|
return pinyin[:trunc_len]
|
||||||
|
|
||||||
|
def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
处理拼音序列,逐个截断并拼接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pinyin_list: 拼音列表,长度1-4
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
拼接后的拼音字符串
|
||||||
|
"""
|
||||||
|
result_parts = []
|
||||||
|
|
||||||
|
for pinyin in pinyin_list:
|
||||||
|
truncated = self.truncate_pinyin(pinyin)
|
||||||
|
if not truncated:
|
||||||
|
# 如果某个拼音截断为空,则停止
|
||||||
|
break
|
||||||
|
result_parts.append(truncated)
|
||||||
|
|
||||||
|
if not result_parts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
result = "".join(result_parts)
|
||||||
|
|
||||||
|
# 限制最大长度
|
||||||
|
if len(result) > 18:
|
||||||
|
result = result[:18]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def adjust_frequency(self, freq: int) -> int:
|
||||||
|
"""
|
||||||
|
削峰填谷 - 根据频率调整采样
|
||||||
|
|
||||||
|
Args:
|
||||||
|
freq: 当前字符频率
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
调整后的采样次数,0表示丢弃
|
||||||
|
"""
|
||||||
|
# 1. 削峰处理(高频字,>= 3000W开始丢弃)
|
||||||
|
if freq >= self.drop_start_freq:
|
||||||
|
# 线性丢弃概率:3000W时丢弃概率为0,434748359时丢弃概率为0.8
|
||||||
|
# 使用线性插值计算丢弃概率
|
||||||
|
if self.max_freq == self.drop_start_freq:
|
||||||
|
drop_prob = 0.0 # 防止除零
|
||||||
|
else:
|
||||||
|
drop_prob = (
|
||||||
|
self.max_drop_prob
|
||||||
|
* (freq - self.drop_start_freq)
|
||||||
|
/ (self.max_freq - self.drop_start_freq)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 根据丢弃概率决定是否保留
|
||||||
|
if random.random() < drop_prob:
|
||||||
|
return 0 # 丢弃该样本
|
||||||
|
else:
|
||||||
|
return 1 # 保留,但不重复
|
||||||
|
|
||||||
|
# 2. 填谷处理(低频字,<= 1W开始重复)
|
||||||
|
elif freq <= self.repeat_end_freq:
|
||||||
|
# 线性重复期望:1W时重复期望为0,109时重复期望为50
|
||||||
|
# 使用线性插值计算期望重复次数
|
||||||
|
if freq <= self.min_freq:
|
||||||
|
repeat_expect = self.max_repeat_expect # 最低频字重复期望为50
|
||||||
|
else:
|
||||||
|
if self.repeat_end_freq == self.min_freq:
|
||||||
|
repeat_expect = 0 # 防止除零
|
||||||
|
else:
|
||||||
|
# 线性插值公式
|
||||||
|
repeat_expect = (
|
||||||
|
self.max_repeat_expect
|
||||||
|
* (self.repeat_end_freq - freq)
|
||||||
|
/ (self.repeat_end_freq - self.min_freq)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 期望重复次数转换为实际重复次数
|
||||||
|
# 使用泊松分布实现期望重复,确保有随机性
|
||||||
|
repeat_count = np.random.poisson(repeat_expect)
|
||||||
|
|
||||||
|
# 确保至少返回1次
|
||||||
|
return max(1, repeat_count)
|
||||||
|
|
||||||
|
# 3. 中等频率字(1W < freq < 3000W)
|
||||||
|
else:
|
||||||
|
return 1 # 保持原样
|
||||||
|
|
||||||
|
def batch_get_char_info(
|
||||||
|
self, char_pinyin_pairs: List[Tuple[str, str]]
|
||||||
|
) -> Dict[Tuple[str, str], Any]:
|
||||||
|
"""
|
||||||
|
批量获取字符信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
char_pinyin_pairs: [(字符, 拼音), ...]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字典,key为(字符, 拼音),value为(id, 频率)或None
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
# 先检查缓存
|
||||||
|
uncached_pairs = []
|
||||||
|
for pair in char_pinyin_pairs:
|
||||||
|
if pair in self.char_info_cache:
|
||||||
|
results[pair] = self.char_info_cache[pair]
|
||||||
|
else:
|
||||||
|
uncached_pairs.append(pair)
|
||||||
|
|
||||||
|
# 批量查询未缓存的
|
||||||
|
if uncached_pairs:
|
||||||
|
# 使用query_engine批量查询
|
||||||
|
char_infos = self.query_engine.batch_get_char_pinyin_info(uncached_pairs)
|
||||||
|
for pair, char_info in char_infos.items():
|
||||||
|
if char_info:
|
||||||
|
info = {
|
||||||
|
"id": char_info.id,
|
||||||
|
"freq": char_info.count,
|
||||||
|
"char": char_info.char,
|
||||||
|
"pinyin": char_info.pinyin,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
info = None
|
||||||
|
|
||||||
|
results[pair] = info
|
||||||
|
self.char_info_cache[pair] = info
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _process_batch(self, char_pinyin_batch, char_positions, text):
|
||||||
|
"""处理批量字符"""
|
||||||
|
# 批量查询字符信息
|
||||||
|
char_info_map = self.batch_get_char_info(char_pinyin_batch)
|
||||||
|
|
||||||
|
batch_samples = []
|
||||||
|
|
||||||
|
for pos_info in char_positions:
|
||||||
|
char = pos_info["char"]
|
||||||
|
pinyin = pos_info["pinyin"]
|
||||||
|
next_pinyins = pos_info["next_pinyins"]
|
||||||
|
context = pos_info["context"]
|
||||||
|
|
||||||
|
# 获取字符信息
|
||||||
|
char_info = char_info_map.get((char, pinyin))
|
||||||
|
if not char_info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 削峰填谷调整
|
||||||
|
adjust_factor = self.adjust_frequency(char_info["freq"])
|
||||||
|
if adjust_factor <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 前文采样
|
||||||
|
sampled_context = self.sample_context(context)
|
||||||
|
|
||||||
|
# 拼音处理
|
||||||
|
processed_pinyin = self.process_pinyin_sequence(next_pinyins)
|
||||||
|
if not processed_pinyin:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Tokenize
|
||||||
|
hint = self.tokenizer(
|
||||||
|
sampled_context,
|
||||||
|
processed_pinyin,
|
||||||
|
max_length=self.max_len,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 生成样本
|
||||||
|
sample = {
|
||||||
|
"hint": hint,
|
||||||
|
"txt": sampled_context,
|
||||||
|
"py": processed_pinyin,
|
||||||
|
"char_id": torch.tensor([char_info["id"]]),
|
||||||
|
"char": char,
|
||||||
|
"freq": char_info["freq"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 根据调整因子重复样本
|
||||||
|
for _ in range(adjust_factor):
|
||||||
|
batch_samples.append(sample)
|
||||||
|
|
||||||
|
return batch_samples
|
||||||
|
|
||||||
|
def _shuffle_and_yield(self, batch_samples):
|
||||||
|
"""打乱并yield样本"""
|
||||||
|
if not self.shuffle:
|
||||||
|
for sample in batch_samples:
|
||||||
|
yield sample
|
||||||
|
return
|
||||||
|
|
||||||
|
# 添加到打乱缓冲区
|
||||||
|
self.shuffle_buffer.extend(batch_samples)
|
||||||
|
|
||||||
|
# 如果缓冲区达到指定大小,打乱并输出
|
||||||
|
if len(self.shuffle_buffer) >= self.shuffle_buffer_size:
|
||||||
|
random.shuffle(self.shuffle_buffer)
|
||||||
|
for sample in self.shuffle_buffer:
|
||||||
|
yield sample
|
||||||
|
self.shuffle_buffer = []
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
"""
|
||||||
|
迭代器实现,支持多进程
|
||||||
|
|
||||||
|
返回:
|
||||||
|
生成器,每次返回一个样本
|
||||||
|
"""
|
||||||
|
# 获取worker信息,为每个worker设置不同的随机种子
|
||||||
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
|
|
||||||
|
if worker_info is not None:
|
||||||
|
worker_id = worker_info.id
|
||||||
|
# 使用base_seed + worker_id确保每个worker有不同但确定的随机序列
|
||||||
|
base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
|
||||||
|
seed = base_seed + worker_id
|
||||||
|
random.seed(seed % (2**32))
|
||||||
|
np.random.seed(seed % (2**32))
|
||||||
|
|
||||||
|
# 重置打乱缓冲区
|
||||||
|
self.shuffle_buffer = []
|
||||||
|
|
||||||
|
for item in self.dataset:
|
||||||
|
text = item.get(self.text_field, "")
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 转换为拼音列表
|
||||||
|
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||||||
|
# 批量收集需要查询的字符信息
|
||||||
|
char_pinyin_batch = []
|
||||||
|
char_positions = [] # 保存字符位置和上下文信息
|
||||||
|
# 遍历文本中的每个字符
|
||||||
|
for i, (char, py) in enumerate(zip(text, pinyin_list)):
|
||||||
|
if not self.is_chinese_char(char):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取后续最多3个中文字符的拼音
|
||||||
|
next_chars = self.get_next_chinese_chars(text, i, max_count=3)
|
||||||
|
next_pinyins = [py] + [p for _, p in next_chars]
|
||||||
|
# 获取前文上下文(最多100字符)
|
||||||
|
context = text[max(0, i - 100) : i]
|
||||||
|
|
||||||
|
# 收集信息用于批量查询
|
||||||
|
char_pinyin_batch.append((char, py))
|
||||||
|
char_positions.append(
|
||||||
|
{
|
||||||
|
"index": i,
|
||||||
|
"char": char,
|
||||||
|
"pinyin": py,
|
||||||
|
"next_pinyins": next_pinyins,
|
||||||
|
"context": context,
|
||||||
|
"next_chars": next_chars,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 达到批量大小时处理
|
||||||
|
if len(char_pinyin_batch) >= self.batch_query_size:
|
||||||
|
batch_samples = self._process_batch(
|
||||||
|
char_pinyin_batch, char_positions, text
|
||||||
|
)
|
||||||
|
yield from self._shuffle_and_yield(batch_samples)
|
||||||
|
char_pinyin_batch = []
|
||||||
|
char_positions = []
|
||||||
|
# 处理剩余的字符
|
||||||
|
if char_pinyin_batch:
|
||||||
|
batch_samples = self._process_batch(
|
||||||
|
char_pinyin_batch, char_positions, text
|
||||||
|
)
|
||||||
|
yield from self._shuffle_and_yield(batch_samples)
|
||||||
|
|
||||||
|
# 清空缓冲区(处理完所有数据后)
|
||||||
|
if self.shuffle_buffer:
|
||||||
|
random.shuffle(self.shuffle_buffer)
|
||||||
|
for sample in self.shuffle_buffer:
|
||||||
|
yield sample
|
||||||
|
self.shuffle_buffer = []
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
由于是流式数据集,无法预先知道长度
|
||||||
|
|
||||||
|
返回:
|
||||||
|
返回一个估计值或-1
|
||||||
|
"""
|
||||||
|
return -1
|
||||||
|
|
||||||
|
|
||||||
|
# 辅助函数,用于DataLoader
|
||||||
|
def worker_init_fn(worker_id):
|
||||||
|
"""DataLoader worker初始化函数"""
|
||||||
|
# 设置每个worker的随机种子
|
||||||
|
seed = torch.initial_seed() + worker_id
|
||||||
|
random.seed(seed % (2**32))
|
||||||
|
np.random.seed(seed % (2**32))
|
||||||
|
torch.manual_seed(seed % (2**32))
|
||||||
|
|
||||||
|
|
||||||
|
def custom_collate(batch):
|
||||||
|
"""自定义批处理函数"""
|
||||||
|
if not batch:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# 处理hint字段
|
||||||
|
hints = [item["hint"] for item in batch]
|
||||||
|
|
||||||
|
# 合并所有张量字段
|
||||||
|
result = {
|
||||||
|
"hint": {
|
||||||
|
"input_ids": torch.cat([h["input_ids"] for h in hints]),
|
||||||
|
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
|
||||||
|
},
|
||||||
|
"char_id": torch.cat([item["char_id"] for item in batch]),
|
||||||
|
"char": [item["char"] for item in batch],
|
||||||
|
"txt": [item["txt"] for item in batch],
|
||||||
|
"py": [item["py"] for item in batch],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果存在token_type_ids则添加
|
||||||
|
if "token_type_ids" in hints[0]:
|
||||||
|
result["hint"]["token_type_ids"] = torch.cat(
|
||||||
|
[h["token_type_ids"] for h in hints]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果存在freq则添加
|
||||||
|
if "freq" in batch[0]:
|
||||||
|
result["freq"] = torch.tensor([item["freq"] for item in batch])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from query import QueryEngine
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# 初始化查询引擎
|
||||||
|
query_engine = QueryEngine()
|
||||||
|
query_engine.load("./pinyin_char_statistics.json")
|
||||||
|
|
||||||
|
# 创建数据集
|
||||||
|
dataset = PinyinInputDataset(
|
||||||
|
data_dir="/home/songsenand/Data/corpus/CCI-Data/",
|
||||||
|
query_engine=query_engine,
|
||||||
|
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
|
||||||
|
max_len=88,
|
||||||
|
batch_query_size=500,
|
||||||
|
shuffle=True,
|
||||||
|
shuffle_buffer_size=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("数据集初始化")
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=256,
|
||||||
|
num_workers=12,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
# pin_memory=True,
|
||||||
|
collate_fn=custom_collate,
|
||||||
|
prefetch_factor=32,
|
||||||
|
persistent_workers=True,
|
||||||
|
shuffle=False, # 数据集内部已实现打乱
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试数据集
|
||||||
|
try:
|
||||||
|
iterator = iter(dataset)
|
||||||
|
logger.info("测试数据集")
|
||||||
|
for i, _ in tqdm(enumerate(dataloader), total=200):
|
||||||
|
if i >= 200:
|
||||||
|
break
|
||||||
|
"""
|
||||||
|
print(f"Sample {i+1}:")
|
||||||
|
print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")
|
||||||
|
print(f" Pinyin: {sample['py']}")
|
||||||
|
print(f" Context length: {len(sample['txt'])}")
|
||||||
|
print(f" Hint shape: {sample['hint']['input_ids'].shape}")
|
||||||
|
print()
|
||||||
|
"""
|
||||||
|
except StopIteration:
|
||||||
|
print("数据集为空")
|
||||||
|
|
@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Tuple, Any
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from .char_info import CharInfo, PinyinCharPairsCounter
|
from char_info import CharInfo, PinyinCharPairsCounter
|
||||||
|
|
||||||
|
|
||||||
class QueryEngine:
|
class QueryEngine:
|
||||||
|
|
@ -30,6 +30,7 @@ class QueryEngine:
|
||||||
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
|
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
|
||||||
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
|
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
|
||||||
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
|
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
|
||||||
|
self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}
|
||||||
|
|
||||||
# 辅助索引 - 快速获取详细信息
|
# 辅助索引 - 快速获取详细信息
|
||||||
self._char_freq: Dict[str, int] = {} # 字符总频率
|
self._char_freq: Dict[str, int] = {} # 字符总频率
|
||||||
|
|
@ -62,7 +63,7 @@ class QueryEngine:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 读取并解析文件
|
# 读取并解析文件
|
||||||
self._counter_data = self._parse_file(file_path)
|
self._counter_data = self.parse_file(file_path)
|
||||||
|
|
||||||
# 构建索引
|
# 构建索引
|
||||||
self._build_indices()
|
self._build_indices()
|
||||||
|
|
@ -72,7 +73,7 @@ class QueryEngine:
|
||||||
|
|
||||||
return self._counter_data.metadata
|
return self._counter_data.metadata
|
||||||
|
|
||||||
def _parse_file(self, file_path: str) -> PinyinCharPairsCounter:
|
def parse_file(self, file_path: str) -> PinyinCharPairsCounter:
|
||||||
"""解析文件,支持多种格式"""
|
"""解析文件,支持多种格式"""
|
||||||
with open(file_path, 'rb') as f:
|
with open(file_path, 'rb') as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
|
|
@ -130,6 +131,7 @@ class QueryEngine:
|
||||||
self._id_to_info.clear()
|
self._id_to_info.clear()
|
||||||
self._char_to_ids.clear()
|
self._char_to_ids.clear()
|
||||||
self._pinyin_to_ids.clear()
|
self._pinyin_to_ids.clear()
|
||||||
|
self._char_pinyin_to_ids.clear()
|
||||||
self._char_freq.clear()
|
self._char_freq.clear()
|
||||||
self._pinyin_freq.clear()
|
self._pinyin_freq.clear()
|
||||||
self._char_pinyin_map.clear()
|
self._char_pinyin_map.clear()
|
||||||
|
|
@ -161,6 +163,7 @@ class QueryEngine:
|
||||||
|
|
||||||
# 字符-拼音映射
|
# 字符-拼音映射
|
||||||
self._char_pinyin_map[(char, pinyin)] = char_info.count
|
self._char_pinyin_map[(char, pinyin)] = char_info.count
|
||||||
|
self._char_pinyin_to_ids[(char, pinyin)] = char_info_id
|
||||||
|
|
||||||
self._total_pairs = len(self._id_to_info)
|
self._total_pairs = len(self._id_to_info)
|
||||||
self._index_time = time.time() - start_time
|
self._index_time = time.time() - start_time
|
||||||
|
|
@ -295,7 +298,45 @@ class QueryEngine:
|
||||||
raise RuntimeError("数据未加载,请先调用load()方法")
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
return self._char_pinyin_map.get((char, pinyin), 0)
|
return self._char_pinyin_map.get((char, pinyin), 0)
|
||||||
|
|
||||||
|
def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]:
|
||||||
|
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
|
||||||
|
|
||||||
|
Args:
|
||||||
|
char: 汉字字符
|
||||||
|
pinyin: 拼音字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ID和频率
|
||||||
|
"""
|
||||||
|
if not self._loaded:
|
||||||
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
|
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
|
||||||
|
return self.query_by_id(char_info_id)
|
||||||
|
|
||||||
|
def batch_get_char_pinyin_info(self, pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], CharInfo]:
|
||||||
|
"""批量获取汉字-拼音信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pairs: 汉字-拼音列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字典,key为汉字-拼音对,value为CharInfo对象(不存在则为None)
|
||||||
|
"""
|
||||||
|
if not self._loaded:
|
||||||
|
raise RuntimeError("数据未加载,请先调用load()方法")
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for pair in pairs:
|
||||||
|
char_info_id = self._char_pinyin_to_ids.get(pair)
|
||||||
|
if char_info_id is not None:
|
||||||
|
result[pair] = self._id_to_info.get(char_info_id)
|
||||||
|
else:
|
||||||
|
result[pair] = None
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
|
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
|
||||||
"""
|
"""
|
||||||
批量ID查询 - O(n)时间复杂度
|
批量ID查询 - O(n)时间复杂度
|
||||||
|
|
@ -311,10 +352,10 @@ class QueryEngine:
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
for id_value in ids:
|
for id_value in ids:
|
||||||
results[id_value] = self._id_to_info.get(id_value)
|
results[id_value] = self._id_to_info.get(id_value, None)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
|
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
|
||||||
"""
|
"""
|
||||||
批量字符查询
|
批量字符查询
|
||||||
|
|
@ -408,6 +449,7 @@ class QueryEngine:
|
||||||
self._char_freq.clear()
|
self._char_freq.clear()
|
||||||
self._pinyin_freq.clear()
|
self._pinyin_freq.clear()
|
||||||
self._char_pinyin_map.clear()
|
self._char_pinyin_map.clear()
|
||||||
|
self._char_pinyin_to_ids.clear()
|
||||||
self._loaded = False
|
self._loaded = False
|
||||||
self._total_pairs = 0
|
self._total_pairs = 0
|
||||||
self._load_time = 0.0
|
self._load_time = 0.0
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,151 @@
|
||||||
|
# test_query_engine.py
|
||||||
|
import pytest
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from suinput.query import QueryEngine
|
||||||
|
from suinput.char_info import CharInfo, PinyinCharPairsCounter
|
||||||
|
|
||||||
|
# 将测试数据保存为 JSON 文件
|
||||||
|
@pytest.fixture
|
||||||
|
def json_file_path():
|
||||||
|
yield "pinyin_char_statistics.json"
|
||||||
|
|
||||||
|
# 测试 QueryEngine 的基本功能
|
||||||
|
class TestQueryEngine:
|
||||||
|
def test_load_from_json(self, json_file_path):
|
||||||
|
"""测试从 JSON 文件加载数据"""
|
||||||
|
engine = QueryEngine()
|
||||||
|
metadata = engine.load(json_file_path)
|
||||||
|
|
||||||
|
assert engine.is_loaded() is True
|
||||||
|
assert metadata["format"] == "json"
|
||||||
|
assert metadata["pair_count"] == 20646
|
||||||
|
|
||||||
|
def test_query_by_id(self, json_file_path):
|
||||||
|
"""测试通过 ID 查询字符信息"""
|
||||||
|
engine = QueryEngine()
|
||||||
|
engine.load(json_file_path)
|
||||||
|
|
||||||
|
result = engine.query_by_id(8)
|
||||||
|
assert result is not None
|
||||||
|
assert result.char == "中"
|
||||||
|
assert result.pinyin == "zhong"
|
||||||
|
assert result.count == 73927282
|
||||||
|
|
||||||
|
result = engine.query_by_id(100000) # 不存在的 ID
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_query_by_char(self, json_file_path):
|
||||||
|
"""测试通过字符查询拼音信息"""
|
||||||
|
engine = QueryEngine()
|
||||||
|
engine.load(json_file_path)
|
||||||
|
|
||||||
|
results = engine.query_by_char("长")
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0] == (159, "zhang", 15424264)
|
||||||
|
assert results[1] == (414, "chang", 6663465)
|
||||||
|
|
||||||
|
results_limited = engine.query_by_char("长", limit=1)
|
||||||
|
assert len(results_limited) == 1
|
||||||
|
assert results_limited[0] == (159, "zhang", 15424264)
|
||||||
|
|
||||||
|
results_empty = engine.query_by_char("X") # 不存在的字符
|
||||||
|
assert results_empty == []
|
||||||
|
|
||||||
|
def test_query_by_pinyin(self, json_file_path):
|
||||||
|
"""测试通过拼音查询字符信息"""
|
||||||
|
engine = QueryEngine()
|
||||||
|
engine.load(json_file_path)
|
||||||
|
|
||||||
|
results = engine.query_by_pinyin("zhong")
|
||||||
|
assert len(results) == 57
|
||||||
|
assert results[0] == (8, "中", 73927282)
|
||||||
|
|
||||||
|
results_empty = engine.query_by_pinyin("xxx") # 不存在的拼音
|
||||||
|
assert results_empty == []
|
||||||
|
|
||||||
|
def test_get_char_frequency(self, json_file_path):
|
||||||
|
"""测试获取字符总频率"""
|
||||||
|
engine = QueryEngine()
|
||||||
|
engine.load(json_file_path)
|
||||||
|
|
||||||
|
freq = engine.get_char_frequency("中")
|
||||||
|
assert freq == 73927282
|
||||||
|
|
||||||
|
freq_zero = engine.get_char_frequency("X") # 不存在的字符
|
||||||
|
assert freq_zero == 0
|
||||||
|
|
||||||
|
def test_get_pinyin_frequency(self, json_file_path):
|
||||||
|
"""测试获取拼音总频率"""
|
||||||
|
engine = QueryEngine()
|
||||||
|
engine.load(json_file_path)
|
||||||
|
|
||||||
|
freq = engine.get_pinyin_frequency("zhong")
|
||||||
|
assert freq == 136246123
|
||||||
|
|
||||||
|
freq_zero = engine.get_pinyin_frequency("xxx") # 不存在的拼音
|
||||||
|
assert freq_zero == 0
|
||||||
|
|
||||||
|
def test_get_char_pinyin_count(self, json_file_path):
|
||||||
|
"""测试获取字符-拼音对的出现次数"""
|
||||||
|
engine = QueryEngine()
|
||||||
|
engine.load(json_file_path)
|
||||||
|
|
||||||
|
count = engine.get_char_pinyin_count("中", "zhong")
|
||||||
|
assert count == 73927282
|
||||||
|
|
||||||
|
count_zero = engine.get_char_pinyin_count("中", "xxx") # 不存在的拼音
|
||||||
|
assert count_zero == 0
|
||||||
|
|
||||||
|
def test_batch_query_by_ids(self, json_file_path):
|
||||||
|
"""测试批量 ID 查询"""
|
||||||
|
engine = QueryEngine()
|
||||||
|
engine.load(json_file_path)
|
||||||
|
|
||||||
|
results = engine.batch_query_by_ids([8, 9, 10000000])
|
||||||
|
assert len(results) == 3
|
||||||
|
assert results[9].char == "为"
|
||||||
|
|
||||||
|
def test_search_chars_by_prefix(self, json_file_path):
|
||||||
|
"""测试根据字符前缀搜索"""
|
||||||
|
engine = QueryEngine()
|
||||||
|
engine.load(json_file_path)
|
||||||
|
|
||||||
|
results = engine.search_chars_by_prefix("中")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0] == ("中", 73927282)
|
||||||
|
|
||||||
|
results_empty = engine.search_chars_by_prefix("X") # 不存在的前缀
|
||||||
|
assert results_empty == []
|
||||||
|
|
||||||
|
def test_get_statistics(self, json_file_path):
|
||||||
|
"""测试获取统计信息"""
|
||||||
|
engine = QueryEngine()
|
||||||
|
engine.load(json_file_path)
|
||||||
|
|
||||||
|
stats = engine.get_statistics()
|
||||||
|
assert stats["status"] == "loaded"
|
||||||
|
assert stats["total_pairs"] == 20646
|
||||||
|
assert stats["total_characters"] == 18240
|
||||||
|
assert stats["top_chars"][0] == ("的", 439524694)
|
||||||
|
|
||||||
|
def test_clear(self, json_file_path):
|
||||||
|
"""测试清除数据"""
|
||||||
|
engine = QueryEngine()
|
||||||
|
engine.load(json_file_path)
|
||||||
|
assert engine.is_loaded() is True
|
||||||
|
|
||||||
|
engine.clear()
|
||||||
|
assert engine.is_loaded() is False
|
||||||
|
assert engine.get_statistics()["status"] == "not_loaded"
|
||||||
|
|
||||||
|
def test_batch_get_char_pinyin_info(self, json_file_path):
|
||||||
|
engine = QueryEngine()
|
||||||
|
engine.load(json_file_path)
|
||||||
|
assert engine.is_loaded() is True
|
||||||
|
|
||||||
|
pairs = engine.batch_get_char_pinyin_info([("我", "wo"), ("你", "ni"), ("他", "ta")])
|
||||||
|
assert pairs[("我", "wo")] == engine.get_char_info_by_char_pinyin("我", "wo")
|
||||||
|
assert pairs[("你", "ni")] == engine.get_char_info_by_char_pinyin("你", "ni")
|
||||||
|
assert pairs[("他", "ta")] == engine.get_char_info_by_char_pinyin("他", "ta")
|
||||||
Loading…
Reference in New Issue