Compare commits

..

2 Commits

7 changed files with 186 additions and 149 deletions

View File

@ -1,8 +1,9 @@
import random
from typing import Any, Dict, List, Tuple
import os
import json
import os
import random
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
@ -12,9 +13,9 @@ from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from torch.utils.data import DataLoader, IterableDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class PinyinInputDataset(IterableDataset):
"""
拼音输入法模拟数据集
@ -46,7 +47,7 @@ class PinyinInputDataset(IterableDataset):
repeat_end_freq: int = 10000, # 开始重复的阈值
max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望
py_group_json_file: Path = Path(__file__).parent.parent.parent / "py_group.json",
py_group_json_file: Optional[Dict[str, int]] = None,
):
"""
初始化数据集
@ -97,9 +98,32 @@ class PinyinInputDataset(IterableDataset):
# 加载数据集
self.dataset = load_dataset(data_dir, split="train", streaming=True)
with open(py_group_json_file, "r") as f:
self.py_groups = json.load(f)
# 加载拼音分组
self.pg_groups = {
"y": 0,
"k": 0,
"e": 0,
"l": 1,
"w": 1,
"f": 1,
"q": 2,
"a": 2,
"s": 2,
"x": 3,
"b": 3,
"r": 3,
"o": 4,
"m": 4,
"z": 4,
"g": 5,
"n": 5,
"c": 5,
"t": 6,
"p": 6,
"d": 6,
"j": 7,
"h": 7,
}
def get_next_chinese_chars(
self,
@ -391,7 +415,7 @@ class PinyinInputDataset(IterableDataset):
if not char_info:
continue
logger.info(f'获取字符信息: {char_info}')
logger.info(f"获取字符信息: {char_info}")
# 削峰填谷调整
adjust_factor = self.adjust_frequency(char_info["freq"])
if adjust_factor <= 0:
@ -402,8 +426,6 @@ class PinyinInputDataset(IterableDataset):
# 拼音处理
processed_pinyin = self.process_pinyin_sequence(next_pinyins)
if not processed_pinyin:
continue
# Tokenize
hint = self.tokenizer(
@ -423,6 +445,9 @@ class PinyinInputDataset(IterableDataset):
"char_id": torch.tensor([char_info["id"]]),
"char": char,
"freq": char_info["freq"],
"pg": torch.tensor(
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
),
}
# 根据调整因子重复样本
@ -470,7 +495,8 @@ class PinyinInputDataset(IterableDataset):
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
# 批量收集需要查询的字符信息
chinese_positions = [
i for i, char in enumerate(text)
i
for i, char in enumerate(text)
if self.query_engine.is_chinese_char(char)
]
@ -518,7 +544,6 @@ class PinyinInputDataset(IterableDataset):
)
yield from self._shuffle_and_yield(batch_samples)
def __len__(self):
"""
由于是流式数据集无法预先知道长度
@ -591,7 +616,7 @@ if __name__ == "__main__":
# 初始化查询引擎
query_engine = QueryEngine()
query_engine.load("./pinyin_char_pairs_info.json")
query_engine.load()
# 创建数据集
dataset = PinyinInputDataset(
@ -645,4 +670,3 @@ if __name__ == "__main__":
"""
except StopIteration:
print("数据集为空")

View File

@ -1,11 +1,14 @@
# file name: query_engine.py
import json
import pickle
import msgpack
import gzip
from typing import Dict, List, Optional, Tuple, Any
import time
import json
import os
import pickle
import time
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import msgpack
from .char_info import CharInfo, PinyinCharPairsCounter
@ -27,14 +30,14 @@ class QueryEngine:
self._counter_data: Optional[PinyinCharPairsCounter] = None
# 核心索引 - 提供O(1)查询
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}
# 辅助索引 - 快速获取详细信息
self._char_freq: Dict[str, int] = {} # 字符总频率
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
self._char_freq: Dict[str, int] = {} # 字符总频率
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count
# 统计信息
@ -45,12 +48,17 @@ class QueryEngine:
self.min_count = min_count
def load(self, file_path: str) -> Dict[str, Any]:
def load(
self,
file_path: Union[str, Path] = (
files(__package__) / "data" / "pinyin_char_statistics.json"
),
) -> Dict[str, Any]:
"""
加载统计结果文件
Args:
file_path: 文件路径支持msgpack/pickle/json格式自动检测压缩
file_path: 文件路径文件支持msgpack/pickle/json格式自动检测压缩
Returns:
元数据字典
@ -75,9 +83,9 @@ class QueryEngine:
return self._counter_data.metadata
def parse_file(self, file_path: str) -> PinyinCharPairsCounter:
def parse_file(self, file_path: Union[str, Path]) -> PinyinCharPairsCounter:
"""解析文件,支持多种格式"""
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
data = f.read()
# 尝试解压
@ -88,9 +96,9 @@ class QueryEngine:
# 尝试不同格式
for parser_name, parser in [
('msgpack', self._parse_msgpack),
('pickle', self._parse_pickle),
('json', self._parse_json)
("msgpack", self._parse_msgpack),
("pickle", self._parse_pickle),
("json", self._parse_json),
]:
try:
return parser(data)
@ -110,7 +118,7 @@ class QueryEngine:
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
"""解析json格式"""
data_str = data.decode('utf-8')
data_str = data.decode("utf-8")
data_dict = json.loads(data_str)
return self._dict_to_counter(data_dict)
@ -118,10 +126,10 @@ class QueryEngine:
"""字典转PinyinCharPairsCounter"""
# 转换CharInfo字典
pairs_dict = {}
if 'pairs' in data_dict and data_dict['pairs']:
for id_str, info_dict in data_dict['pairs'].items():
if "pairs" in data_dict and data_dict["pairs"]:
for id_str, info_dict in data_dict["pairs"].items():
pairs_dict[int(id_str)] = CharInfo(**info_dict)
data_dict['pairs'] = pairs_dict
data_dict["pairs"] = pairs_dict
return PinyinCharPairsCounter(**data_dict)
@ -222,7 +230,9 @@ class QueryEngine:
return results
def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
def query_by_pinyin(
self, pinyin: str, limit: int = 0
) -> List[Tuple[int, str, int]]:
"""
通过拼音查询字符信息 - O(1) + O(k)时间复杂度
@ -303,7 +313,9 @@ class QueryEngine:
return self._char_pinyin_map.get((char, pinyin), 0)
def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]:
def get_char_info_by_char_pinyin(
self, char: str, pinyin: str
) -> Optional[CharInfo]:
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
Args:
@ -319,7 +331,9 @@ class QueryEngine:
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
return self.query_by_id(char_info_id)
def batch_get_char_pinyin_info(self, pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], CharInfo]:
def batch_get_char_pinyin_info(
self, pairs: List[Tuple[str, str]]
) -> Dict[Tuple[str, str], CharInfo]:
"""批量获取汉字-拼音信息
Args:
@ -340,7 +354,6 @@ class QueryEngine:
result[pair] = None
return result
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
"""
批量ID查询 - O(n)时间复杂度
@ -360,7 +373,9 @@ class QueryEngine:
return results
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
def batch_query_by_chars(
self, chars: List[str], limit_per_char: int = 0
) -> Dict[str, List[Tuple[int, str, int]]]:
"""
批量字符查询
@ -380,7 +395,9 @@ class QueryEngine:
return results
def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
def search_chars_by_prefix(
self, prefix: str, limit: int = 20
) -> List[Tuple[str, int]]:
"""
根据字符前缀搜索 - O(n)时间复杂度n为字符总数
@ -409,7 +426,7 @@ class QueryEngine:
判断是否是汉字
"""
if not self.is_loaded():
raise ValueError('请先调用 load() 方法加载数据')
raise ValueError("请先调用 load() 方法加载数据")
return char in self._char_to_ids
def get_statistics(self) -> Dict[str, Any]:
@ -422,16 +439,12 @@ class QueryEngine:
if not self._loaded:
return {"status": "not_loaded"}
top_chars = sorted(
self._char_freq.items(),
key=lambda x: x[1],
reverse=True
)[:10]
top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[
:10
]
top_pinyins = sorted(
self._pinyin_freq.items(),
key=lambda x: x[1],
reverse=True
self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
)[:10]
return {
@ -445,7 +458,7 @@ class QueryEngine:
"index_time_seconds": self._index_time,
"top_chars": top_chars,
"top_pinyins": top_pinyins,
"metadata": self._counter_data.metadata
"metadata": self._counter_data.metadata,
}
def is_loaded(self) -> bool: