Compare commits

..

2 Commits

7 changed files with 186 additions and 149 deletions

View File

@ -1,8 +1,9 @@
import random
from typing import Any, Dict, List, Tuple
import os
import json
import os
import random
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
@ -12,9 +13,9 @@ from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from torch.utils.data import DataLoader, IterableDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class PinyinInputDataset(IterableDataset):
"""
拼音输入法模拟数据集
@ -46,7 +47,7 @@ class PinyinInputDataset(IterableDataset):
repeat_end_freq: int = 10000, # 开始重复的阈值
max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望
py_group_json_file: Path = Path(__file__).parent.parent.parent / "py_group.json",
py_group_json_file: Optional[Dict[str, int]] = None,
):
"""
初始化数据集
@ -97,9 +98,32 @@ class PinyinInputDataset(IterableDataset):
# 加载数据集
self.dataset = load_dataset(data_dir, split="train", streaming=True)
with open(py_group_json_file, "r") as f:
self.py_groups = json.load(f)
# 加载拼音分组
self.pg_groups = {
"y": 0,
"k": 0,
"e": 0,
"l": 1,
"w": 1,
"f": 1,
"q": 2,
"a": 2,
"s": 2,
"x": 3,
"b": 3,
"r": 3,
"o": 4,
"m": 4,
"z": 4,
"g": 5,
"n": 5,
"c": 5,
"t": 6,
"p": 6,
"d": 6,
"j": 7,
"h": 7,
}
def get_next_chinese_chars(
self,
@ -390,8 +414,8 @@ class PinyinInputDataset(IterableDataset):
char_info = char_info_map.get((char, pinyin))
if not char_info:
continue
logger.info(f'获取字符信息: {char_info}')
logger.info(f"获取字符信息: {char_info}")
# 削峰填谷调整
adjust_factor = self.adjust_frequency(char_info["freq"])
if adjust_factor <= 0:
@ -402,8 +426,6 @@ class PinyinInputDataset(IterableDataset):
# 拼音处理
processed_pinyin = self.process_pinyin_sequence(next_pinyins)
if not processed_pinyin:
continue
# Tokenize
hint = self.tokenizer(
@ -423,6 +445,9 @@ class PinyinInputDataset(IterableDataset):
"char_id": torch.tensor([char_info["id"]]),
"char": char,
"freq": char_info["freq"],
"pg": torch.tensor(
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
),
}
# 根据调整因子重复样本
@ -436,7 +461,7 @@ class PinyinInputDataset(IterableDataset):
if not self.shuffle:
yield from batch_samples
return
# 使用numpy批量操作代替random.shuffle
if batch_samples:
indices = np.random.permutation(len(batch_samples))
@ -470,14 +495,15 @@ class PinyinInputDataset(IterableDataset):
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
# 批量收集需要查询的字符信息
chinese_positions = [
i for i, char in enumerate(text)
i
for i, char in enumerate(text)
if self.query_engine.is_chinese_char(char)
]
# 批量收集需要查询的字符信息
char_pinyin_batch = []
char_positions = []
for i in chinese_positions: # 只遍历中文字符位置
char = text[i]
py = pinyin_list[i]
@ -518,7 +544,6 @@ class PinyinInputDataset(IterableDataset):
)
yield from self._shuffle_and_yield(batch_samples)
def __len__(self):
"""
由于是流式数据集无法预先知道长度
@ -591,7 +616,7 @@ if __name__ == "__main__":
# 初始化查询引擎
query_engine = QueryEngine()
query_engine.load("./pinyin_char_pairs_info.json")
query_engine.load()
# 创建数据集
dataset = PinyinInputDataset(
@ -625,7 +650,7 @@ if __name__ == "__main__":
break
return
cProfile.run('profile_func(dataloader)')
"""
@ -645,4 +670,3 @@ if __name__ == "__main__":
"""
except StopIteration:
print("数据集为空")

View File

@ -1,11 +1,14 @@
# file name: query_engine.py
import json
import pickle
import msgpack
import gzip
from typing import Dict, List, Optional, Tuple, Any
import time
import json
import os
import pickle
import time
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import msgpack
from .char_info import CharInfo, PinyinCharPairsCounter
@ -13,7 +16,7 @@ from .char_info import CharInfo, PinyinCharPairsCounter
class QueryEngine:
"""
高效拼音-字符查询引擎
特性:
1. O(1)时间复杂度的ID查询
2. O(1)时间复杂度的字符查询
@ -21,22 +24,22 @@ class QueryEngine:
4. 内存友好构建高效索引
5. 支持批量查询和前缀搜索
"""
def __init__(self, min_count: int = 109):
"""初始化查询引擎"""
self._counter_data: Optional[PinyinCharPairsCounter] = None
# 核心索引 - 提供O(1)查询
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}
# 辅助索引 - 快速获取详细信息
self._char_freq: Dict[str, int] = {} # 字符总频率
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
self._char_freq: Dict[str, int] = {} # 字符总频率
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count
# 统计信息
self._loaded = False
self._total_pairs = 0
@ -44,91 +47,96 @@ class QueryEngine:
self._index_time = 0.0
self.min_count = min_count
def load(self, file_path: str) -> Dict[str, Any]:
def load(
self,
file_path: Union[str, Path] = (
files(__package__) / "data" / "pinyin_char_statistics.json"
),
) -> Dict[str, Any]:
"""
加载统计结果文件
Args:
file_path: 文件路径支持msgpack/pickle/json格式自动检测压缩
file_path: 文件路径文件支持msgpack/pickle/json格式自动检测压缩
Returns:
元数据字典
Raises:
FileNotFoundError: 文件不存在
ValueError: 文件格式不支持
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
start_time = time.time()
# 读取并解析文件
self._counter_data = self.parse_file(file_path)
# 构建索引
self._build_indices()
self._load_time = time.time() - start_time
self._loaded = True
return self._counter_data.metadata
def parse_file(self, file_path: str) -> PinyinCharPairsCounter:
def parse_file(self, file_path: Union[str, Path]) -> PinyinCharPairsCounter:
"""解析文件,支持多种格式"""
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
data = f.read()
# 尝试解压
try:
data = gzip.decompress(data)
except Exception:
pass
# 尝试不同格式
for parser_name, parser in [
('msgpack', self._parse_msgpack),
('pickle', self._parse_pickle),
('json', self._parse_json)
("msgpack", self._parse_msgpack),
("pickle", self._parse_pickle),
("json", self._parse_json),
]:
try:
return parser(data)
except Exception:
continue
raise ValueError("无法解析文件格式")
def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter:
"""解析msgpack格式"""
data_dict = msgpack.unpackb(data, raw=False)
return self._dict_to_counter(data_dict)
def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter:
"""解析pickle格式"""
return pickle.loads(data)
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
"""解析json格式"""
data_str = data.decode('utf-8')
data_str = data.decode("utf-8")
data_dict = json.loads(data_str)
return self._dict_to_counter(data_dict)
def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter:
"""字典转PinyinCharPairsCounter"""
# 转换CharInfo字典
pairs_dict = {}
if 'pairs' in data_dict and data_dict['pairs']:
for id_str, info_dict in data_dict['pairs'].items():
if "pairs" in data_dict and data_dict["pairs"]:
for id_str, info_dict in data_dict["pairs"].items():
pairs_dict[int(id_str)] = CharInfo(**info_dict)
data_dict['pairs'] = pairs_dict
data_dict["pairs"] = pairs_dict
return PinyinCharPairsCounter(**data_dict)
def _build_indices(self):
"""构建所有查询索引"""
start_time = time.time()
# 重置索引
self._id_to_info.clear()
self._char_to_ids.clear()
@ -137,13 +145,13 @@ class QueryEngine:
self._char_freq.clear()
self._pinyin_freq.clear()
self._char_pinyin_map.clear()
# 复制频率数据
if self._counter_data.chars:
self._char_freq = self._counter_data.chars.copy()
if self._counter_data.pinyins:
self._pinyin_freq = self._counter_data.pinyins.copy()
# 构建核心索引
for char_info in self._counter_data.pairs.values():
if char_info.count < self.min_count:
@ -151,165 +159,169 @@ class QueryEngine:
char = char_info.char
pinyin = char_info.pinyin
char_info_id = char_info.id
# ID索引
self._id_to_info[char_info_id] = char_info
# 字符索引
if char not in self._char_to_ids:
self._char_to_ids[char] = []
self._char_to_ids[char].append(char_info_id)
# 拼音索引
if pinyin not in self._pinyin_to_ids:
self._pinyin_to_ids[pinyin] = []
self._pinyin_to_ids[pinyin].append(char_info_id)
# 字符-拼音映射
self._char_pinyin_map[(char, pinyin)] = char_info.count
self._char_pinyin_to_ids[(char, pinyin)] = char_info_id
self._total_pairs = len(self._id_to_info)
self._index_time = time.time() - start_time
def query_by_id(self, id: int) -> Optional[CharInfo]:
"""
通过ID查询字符信息 - O(1)时间复杂度
Args:
id: 记录ID
Returns:
CharInfo对象不存在则返回None
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._id_to_info.get(id)
def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]:
"""
通过字符查询拼音信息 - O(1) + O(k)时间复杂度k为结果数
Args:
char: 汉字字符
limit: 返回结果数量限制0表示返回所有
Returns:
列表每个元素为(id, 拼音, 次数)按次数降序排序
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
if char not in self._char_to_ids:
return []
# 获取所有相关ID
ids = self._char_to_ids[char]
# 构建结果并排序
results = []
for char_info_id in ids:
char_info = self._id_to_info[char_info_id]
results.append((char_info_id, char_info.pinyin, char_info.count))
# 按次数降序排序
results.sort(key=lambda x: x[2], reverse=True)
# 应用限制
if limit > 0 and len(results) > limit:
results = results[:limit]
return results
def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
def query_by_pinyin(
self, pinyin: str, limit: int = 0
) -> List[Tuple[int, str, int]]:
"""
通过拼音查询字符信息 - O(1) + O(k)时间复杂度
Args:
pinyin: 拼音字符串
limit: 返回结果数量限制0表示返回所有
Returns:
列表每个元素为(id, 字符, 次数)按次数降序排序
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
if pinyin not in self._pinyin_to_ids:
return []
# 获取所有相关ID
ids = self._pinyin_to_ids[pinyin]
# 构建结果并排序
results = []
for char_info_id in ids:
char_info = self._id_to_info[char_info_id]
results.append((char_info_id, char_info.char, char_info.count))
# 按次数降序排序
results.sort(key=lambda x: x[2], reverse=True)
# 应用限制
if limit > 0 and len(results) > limit:
results = results[:limit]
return results
def get_char_frequency(self, char: str) -> int:
"""
获取字符的总出现频率所有拼音变体之和 - O(1)时间复杂度
Args:
char: 汉字字符
Returns:
总出现次数
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._char_freq.get(char, 0)
def get_pinyin_frequency(self, pinyin: str) -> int:
"""
获取拼音的总出现频率所有字符之和 - O(1)时间复杂度
Args:
pinyin: 拼音字符串
Returns:
总出现次数
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._pinyin_freq.get(pinyin, 0)
def get_char_pinyin_count(self, char: str, pinyin: str) -> int:
"""
获取特定字符-拼音对的出现次数 - O(1)时间复杂度
Args:
char: 汉字字符
pinyin: 拼音字符串
Returns:
出现次数
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
return self._char_pinyin_map.get((char, pinyin), 0)
def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]:
def get_char_info_by_char_pinyin(
self, char: str, pinyin: str
) -> Optional[CharInfo]:
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
Args:
char: 汉字字符
pinyin: 拼音字符串
Returns:
ID和频率
"""
@ -319,12 +331,14 @@ class QueryEngine:
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
return self.query_by_id(char_info_id)
def batch_get_char_pinyin_info(self, pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], CharInfo]:
def batch_get_char_pinyin_info(
self, pairs: List[Tuple[str, str]]
) -> Dict[Tuple[str, str], CharInfo]:
"""批量获取汉字-拼音信息
Args:
pairs: 汉字-拼音列表
Returns:
字典key为汉字-拼音对value为CharInfo对象不存在则为None
"""
@ -339,69 +353,72 @@ class QueryEngine:
else:
result[pair] = None
return result
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
"""
批量ID查询 - O(n)时间复杂度
Args:
ids: ID列表
Returns:
字典key为IDvalue为CharInfo对象不存在则为None
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
results = {}
for id_value in ids:
results[id_value] = self._id_to_info.get(id_value, None)
return results
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
def batch_query_by_chars(
self, chars: List[str], limit_per_char: int = 0
) -> Dict[str, List[Tuple[int, str, int]]]:
"""
批量字符查询
Args:
chars: 字符列表
limit_per_char: 每个字符的结果数量限制
Returns:
字典key为字符value为查询结果列表
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
results = {}
for char in chars:
results[char] = self.query_by_char(char, limit_per_char)
return results
def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
def search_chars_by_prefix(
self, prefix: str, limit: int = 20
) -> List[Tuple[str, int]]:
"""
根据字符前缀搜索 - O(n)时间复杂度n为字符总数
Args:
prefix: 字符前缀
limit: 返回结果数量限制
Returns:
列表每个元素为(字符, 总频率)按频率降序排序
"""
if not self._loaded:
raise RuntimeError("数据未加载请先调用load()方法")
matches = []
for char, freq in self._char_freq.items():
if char.startswith(prefix):
matches.append((char, freq))
# 按频率降序排序
matches.sort(key=lambda x: x[1], reverse=True)
return matches[:limit] if limit > 0 else matches
def is_chinese_char(self, char: str) -> bool:
@ -409,31 +426,27 @@ class QueryEngine:
判断是否是汉字
"""
if not self.is_loaded():
raise ValueError('请先调用 load() 方法加载数据')
raise ValueError("请先调用 load() 方法加载数据")
return char in self._char_to_ids
def get_statistics(self) -> Dict[str, Any]:
"""
获取系统统计信息
Returns:
统计信息字典
"""
if not self._loaded:
return {"status": "not_loaded"}
top_chars = sorted(
self._char_freq.items(),
key=lambda x: x[1],
reverse=True
)[:10]
top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[
:10
]
top_pinyins = sorted(
self._pinyin_freq.items(),
key=lambda x: x[1],
reverse=True
self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
)[:10]
return {
"status": "loaded",
"timestamp": self._counter_data.timestamp,
@ -445,13 +458,13 @@ class QueryEngine:
"index_time_seconds": self._index_time,
"top_chars": top_chars,
"top_pinyins": top_pinyins,
"metadata": self._counter_data.metadata
"metadata": self._counter_data.metadata,
}
def is_loaded(self) -> bool:
"""检查数据是否已加载"""
return self._loaded
def clear(self):
"""清除所有数据和索引,释放内存"""
self._counter_data = None
@ -465,4 +478,4 @@ class QueryEngine:
self._loaded = False
self._total_pairs = 0
self._load_time = 0.0
self._index_time = 0.0
self._index_time = 0.0