diff --git a/pinyin_char_statistics.json b/src/suinput/data/pinyin_char_statistics.json similarity index 100% rename from pinyin_char_statistics.json rename to src/suinput/data/pinyin_char_statistics.json diff --git a/pinyin_group.json b/src/suinput/data/pinyin_group.json similarity index 100% rename from pinyin_group.json rename to src/suinput/data/pinyin_group.json diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index b75ecdc..a71f8fb 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -1,8 +1,9 @@ -import random -from typing import Any, Dict, List, Tuple -import os import json +import os +import random +from importlib.resources import files from pathlib import Path +from typing import Any, Dict, List, Tuple import numpy as np import torch @@ -12,9 +13,9 @@ from modelscope import AutoTokenizer from pypinyin import lazy_pinyin from torch.utils.data import DataLoader, IterableDataset - os.environ["TOKENIZERS_PARALLELISM"] = "false" + class PinyinInputDataset(IterableDataset): """ 拼音输入法模拟数据集 @@ -46,7 +47,7 @@ class PinyinInputDataset(IterableDataset): repeat_end_freq: int = 10000, # 开始重复的阈值 max_drop_prob: float = 0.8, # 最大丢弃概率 max_repeat_expect: float = 50.0, # 最大重复期望 - py_group_json_file: Path = Path(__file__).parent.parent.parent / "py_group.json", + py_group_json_file: Optional[Dict[str, int]] = None, ): """ 初始化数据集 @@ -97,9 +98,32 @@ class PinyinInputDataset(IterableDataset): # 加载数据集 self.dataset = load_dataset(data_dir, split="train", streaming=True) - with open(py_group_json_file, "r") as f: - self.py_groups = json.load(f) - + # 加载拼音分组 + self.pg_groups = { + "y": 0, + "k": 0, + "e": 0, + "l": 1, + "w": 1, + "f": 1, + "q": 2, + "a": 2, + "s": 2, + "x": 3, + "b": 3, + "r": 3, + "o": 4, + "m": 4, + "z": 4, + "g": 5, + "n": 5, + "c": 5, + "t": 6, + "p": 6, + "d": 6, + "j": 7, + "h": 7, + } def get_next_chinese_chars( self, @@ -390,8 +414,8 @@ class PinyinInputDataset(IterableDataset): char_info = char_info_map.get((char, pinyin)) if not char_info: continue - - logger.info(f'获取字符信息: {char_info}') + + logger.info(f"获取字符信息: {char_info}") # 削峰填谷调整 adjust_factor = self.adjust_frequency(char_info["freq"]) if adjust_factor <= 0: @@ -402,8 +426,6 @@ class PinyinInputDataset(IterableDataset): # 拼音处理 processed_pinyin = self.process_pinyin_sequence(next_pinyins) - if not processed_pinyin: - continue # Tokenize hint = self.tokenizer( @@ -423,6 +445,9 @@ class PinyinInputDataset(IterableDataset): "char_id": torch.tensor([char_info["id"]]), "char": char, "freq": char_info["freq"], + "pg": torch.tensor( + self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8 + ), } # 根据调整因子重复样本 @@ -436,7 +461,7 @@ class PinyinInputDataset(IterableDataset): if not self.shuffle: yield from batch_samples return - + # 使用numpy批量操作代替random.shuffle if batch_samples: indices = np.random.permutation(len(batch_samples)) @@ -470,14 +495,15 @@ class PinyinInputDataset(IterableDataset): pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x]) # 批量收集需要查询的字符信息 chinese_positions = [ - i for i, char in enumerate(text) + i + for i, char in enumerate(text) if self.query_engine.is_chinese_char(char) ] - + # 批量收集需要查询的字符信息 char_pinyin_batch = [] char_positions = [] - + for i in chinese_positions: # 只遍历中文字符位置 char = text[i] py = pinyin_list[i] @@ -518,7 +544,6 @@ class PinyinInputDataset(IterableDataset): ) yield from self._shuffle_and_yield(batch_samples) - def __len__(self): """ 由于是流式数据集,无法预先知道长度 @@ -591,7 +616,7 @@ if __name__ == "__main__": # 初始化查询引擎 query_engine = QueryEngine() - query_engine.load("./pinyin_char_pairs_info.json") + query_engine.load() # 创建数据集 dataset = PinyinInputDataset( @@ -625,7 +650,7 @@ if __name__ == "__main__": break return - + cProfile.run('profile_func(dataloader)') """ @@ -645,4 +670,3 @@ if __name__ == "__main__": """ except StopIteration: print("数据集为空") - diff --git a/src/suinput/query.py b/src/suinput/query.py index cf672fe..7f4fba2 100644 --- a/src/suinput/query.py +++ b/src/suinput/query.py @@ -1,11 +1,14 @@ # file name: query_engine.py -import json -import pickle -import msgpack import gzip -from typing import Dict, List, Optional, Tuple, Any -import time +import json import os +import pickle +import time +from importlib.resources import files +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import msgpack from .char_info import CharInfo, PinyinCharPairsCounter @@ -13,7 +16,7 @@ from .char_info import CharInfo, PinyinCharPairsCounter class QueryEngine: """ 高效拼音-字符查询引擎 - + 特性: 1. O(1)时间复杂度的ID查询 2. O(1)时间复杂度的字符查询 @@ -21,22 +24,22 @@ class QueryEngine: 4. 内存友好,构建高效索引 5. 支持批量查询和前缀搜索 """ - + def __init__(self, min_count: int = 109): """初始化查询引擎""" self._counter_data: Optional[PinyinCharPairsCounter] = None - + # 核心索引 - 提供O(1)查询 - self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo - self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表 - self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表 + self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo + self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表 + self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表 self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {} - + # 辅助索引 - 快速获取详细信息 - self._char_freq: Dict[str, int] = {} # 字符总频率 - self._pinyin_freq: Dict[str, int] = {} # 拼音总频率 + self._char_freq: Dict[str, int] = {} # 字符总频率 + self._pinyin_freq: Dict[str, int] = {} # 拼音总频率 self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count - + # 统计信息 self._loaded = False self._total_pairs = 0 @@ -44,91 +47,96 @@ class QueryEngine: self._index_time = 0.0 self.min_count = min_count - - def load(self, file_path: str) -> Dict[str, Any]: + + def load( + self, + file_path: Union[str, Path] = ( + files(__package__) / "data" / "pinyin_char_statistics.json" + ), + ) -> Dict[str, Any]: """ 加载统计结果文件 - + Args: - file_path: 文件路径,支持msgpack/pickle/json格式,自动检测压缩 - + file_path: 文件路径,文件支持msgpack/pickle/json格式,自动检测压缩 + Returns: 元数据字典 - + Raises: FileNotFoundError: 文件不存在 ValueError: 文件格式不支持 """ if not os.path.exists(file_path): raise FileNotFoundError(f"文件不存在: {file_path}") - + start_time = time.time() - + # 读取并解析文件 self._counter_data = self.parse_file(file_path) - + # 构建索引 self._build_indices() - + self._load_time = time.time() - start_time self._loaded = True - + return self._counter_data.metadata - - def parse_file(self, file_path: str) -> PinyinCharPairsCounter: + + def parse_file(self, file_path: Union[str, Path]) -> PinyinCharPairsCounter: """解析文件,支持多种格式""" - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: data = f.read() - + # 尝试解压 try: data = gzip.decompress(data) except Exception: pass - + # 尝试不同格式 for parser_name, parser in [ - ('msgpack', self._parse_msgpack), - ('pickle', self._parse_pickle), - ('json', self._parse_json) + ("msgpack", self._parse_msgpack), + ("pickle", self._parse_pickle), + ("json", self._parse_json), ]: try: return parser(data) except Exception: continue - + raise ValueError("无法解析文件格式") - + def _parse_msgpack(self, data: bytes) -> PinyinCharPairsCounter: """解析msgpack格式""" data_dict = msgpack.unpackb(data, raw=False) return self._dict_to_counter(data_dict) - + def _parse_pickle(self, data: bytes) -> PinyinCharPairsCounter: """解析pickle格式""" return pickle.loads(data) - + def _parse_json(self, data: bytes) -> PinyinCharPairsCounter: """解析json格式""" - data_str = data.decode('utf-8') + data_str = data.decode("utf-8") data_dict = json.loads(data_str) return self._dict_to_counter(data_dict) - + def _dict_to_counter(self, data_dict: Dict) -> PinyinCharPairsCounter: """字典转PinyinCharPairsCounter""" # 转换CharInfo字典 pairs_dict = {} - if 'pairs' in data_dict and data_dict['pairs']: - for id_str, info_dict in data_dict['pairs'].items(): + if "pairs" in data_dict and data_dict["pairs"]: + for id_str, info_dict in data_dict["pairs"].items(): pairs_dict[int(id_str)] = CharInfo(**info_dict) - data_dict['pairs'] = pairs_dict - + data_dict["pairs"] = pairs_dict + return PinyinCharPairsCounter(**data_dict) - + def _build_indices(self): """构建所有查询索引""" start_time = time.time() - + # 重置索引 self._id_to_info.clear() self._char_to_ids.clear() @@ -137,13 +145,13 @@ class QueryEngine: self._char_freq.clear() self._pinyin_freq.clear() self._char_pinyin_map.clear() - + # 复制频率数据 if self._counter_data.chars: self._char_freq = self._counter_data.chars.copy() if self._counter_data.pinyins: self._pinyin_freq = self._counter_data.pinyins.copy() - + # 构建核心索引 for char_info in self._counter_data.pairs.values(): if char_info.count < self.min_count: @@ -151,165 +159,169 @@ class QueryEngine: char = char_info.char pinyin = char_info.pinyin char_info_id = char_info.id - + # ID索引 self._id_to_info[char_info_id] = char_info - + # 字符索引 if char not in self._char_to_ids: self._char_to_ids[char] = [] self._char_to_ids[char].append(char_info_id) - + # 拼音索引 if pinyin not in self._pinyin_to_ids: self._pinyin_to_ids[pinyin] = [] self._pinyin_to_ids[pinyin].append(char_info_id) - + # 字符-拼音映射 self._char_pinyin_map[(char, pinyin)] = char_info.count self._char_pinyin_to_ids[(char, pinyin)] = char_info_id - + self._total_pairs = len(self._id_to_info) self._index_time = time.time() - start_time - + def query_by_id(self, id: int) -> Optional[CharInfo]: """ 通过ID查询字符信息 - O(1)时间复杂度 - + Args: id: 记录ID - + Returns: CharInfo对象,不存在则返回None """ if not self._loaded: raise RuntimeError("数据未加载,请先调用load()方法") - + return self._id_to_info.get(id) - + def query_by_char(self, char: str, limit: int = 0) -> List[Tuple[int, str, int]]: """ 通过字符查询拼音信息 - O(1) + O(k)时间复杂度,k为结果数 - + Args: char: 汉字字符 limit: 返回结果数量限制,0表示返回所有 - + Returns: 列表,每个元素为(id, 拼音, 次数),按次数降序排序 """ if not self._loaded: raise RuntimeError("数据未加载,请先调用load()方法") - + if char not in self._char_to_ids: return [] - + # 获取所有相关ID ids = self._char_to_ids[char] - + # 构建结果并排序 results = [] for char_info_id in ids: char_info = self._id_to_info[char_info_id] results.append((char_info_id, char_info.pinyin, char_info.count)) - + # 按次数降序排序 results.sort(key=lambda x: x[2], reverse=True) - + # 应用限制 if limit > 0 and len(results) > limit: results = results[:limit] - + return results - - def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]: + + def query_by_pinyin( + self, pinyin: str, limit: int = 0 + ) -> List[Tuple[int, str, int]]: """ 通过拼音查询字符信息 - O(1) + O(k)时间复杂度 - + Args: pinyin: 拼音字符串 limit: 返回结果数量限制,0表示返回所有 - + Returns: 列表,每个元素为(id, 字符, 次数),按次数降序排序 """ if not self._loaded: raise RuntimeError("数据未加载,请先调用load()方法") - + if pinyin not in self._pinyin_to_ids: return [] - + # 获取所有相关ID ids = self._pinyin_to_ids[pinyin] - + # 构建结果并排序 results = [] for char_info_id in ids: char_info = self._id_to_info[char_info_id] results.append((char_info_id, char_info.char, char_info.count)) - + # 按次数降序排序 results.sort(key=lambda x: x[2], reverse=True) - + # 应用限制 if limit > 0 and len(results) > limit: results = results[:limit] - + return results - + def get_char_frequency(self, char: str) -> int: """ 获取字符的总出现频率(所有拼音变体之和) - O(1)时间复杂度 - + Args: char: 汉字字符 - + Returns: 总出现次数 """ if not self._loaded: raise RuntimeError("数据未加载,请先调用load()方法") - + return self._char_freq.get(char, 0) - + def get_pinyin_frequency(self, pinyin: str) -> int: """ 获取拼音的总出现频率(所有字符之和) - O(1)时间复杂度 - + Args: pinyin: 拼音字符串 - + Returns: 总出现次数 """ if not self._loaded: raise RuntimeError("数据未加载,请先调用load()方法") - + return self._pinyin_freq.get(pinyin, 0) - + def get_char_pinyin_count(self, char: str, pinyin: str) -> int: """ 获取特定字符-拼音对的出现次数 - O(1)时间复杂度 - + Args: char: 汉字字符 pinyin: 拼音字符串 - + Returns: 出现次数 """ if not self._loaded: raise RuntimeError("数据未加载,请先调用load()方法") - + return self._char_pinyin_map.get((char, pinyin), 0) - def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]: + def get_char_info_by_char_pinyin( + self, char: str, pinyin: str + ) -> Optional[CharInfo]: """获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度 - + Args: char: 汉字字符 pinyin: 拼音字符串 - + Returns: ID和频率 """ @@ -319,12 +331,14 @@ class QueryEngine: char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None) return self.query_by_id(char_info_id) - def batch_get_char_pinyin_info(self, pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], CharInfo]: + def batch_get_char_pinyin_info( + self, pairs: List[Tuple[str, str]] + ) -> Dict[Tuple[str, str], CharInfo]: """批量获取汉字-拼音信息 - + Args: pairs: 汉字-拼音列表 - + Returns: 字典,key为汉字-拼音对,value为CharInfo对象(不存在则为None) """ @@ -339,69 +353,72 @@ class QueryEngine: else: result[pair] = None return result - def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]: """ 批量ID查询 - O(n)时间复杂度 - + Args: ids: ID列表 - + Returns: 字典,key为ID,value为CharInfo对象(不存在则为None) """ if not self._loaded: raise RuntimeError("数据未加载,请先调用load()方法") - + results = {} for id_value in ids: results[id_value] = self._id_to_info.get(id_value, None) - + return results - def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]: + def batch_query_by_chars( + self, chars: List[str], limit_per_char: int = 0 + ) -> Dict[str, List[Tuple[int, str, int]]]: """ 批量字符查询 - + Args: chars: 字符列表 limit_per_char: 每个字符的结果数量限制 - + Returns: 字典,key为字符,value为查询结果列表 """ if not self._loaded: raise RuntimeError("数据未加载,请先调用load()方法") - + results = {} for char in chars: results[char] = self.query_by_char(char, limit_per_char) - + return results - - def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]: + + def search_chars_by_prefix( + self, prefix: str, limit: int = 20 + ) -> List[Tuple[str, int]]: """ 根据字符前缀搜索 - O(n)时间复杂度,n为字符总数 - + Args: prefix: 字符前缀 limit: 返回结果数量限制 - + Returns: 列表,每个元素为(字符, 总频率),按频率降序排序 """ if not self._loaded: raise RuntimeError("数据未加载,请先调用load()方法") - + matches = [] for char, freq in self._char_freq.items(): if char.startswith(prefix): matches.append((char, freq)) - + # 按频率降序排序 matches.sort(key=lambda x: x[1], reverse=True) - + return matches[:limit] if limit > 0 else matches def is_chinese_char(self, char: str) -> bool: @@ -409,31 +426,27 @@ class QueryEngine: 判断是否是汉字 """ if not self.is_loaded(): - raise ValueError('请先调用 load() 方法加载数据') + raise ValueError("请先调用 load() 方法加载数据") return char in self._char_to_ids - + def get_statistics(self) -> Dict[str, Any]: """ 获取系统统计信息 - + Returns: 统计信息字典 """ if not self._loaded: return {"status": "not_loaded"} - - top_chars = sorted( - self._char_freq.items(), - key=lambda x: x[1], - reverse=True - )[:10] - + + top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[ + :10 + ] + top_pinyins = sorted( - self._pinyin_freq.items(), - key=lambda x: x[1], - reverse=True + self._pinyin_freq.items(), key=lambda x: x[1], reverse=True )[:10] - + return { "status": "loaded", "timestamp": self._counter_data.timestamp, @@ -445,13 +458,13 @@ class QueryEngine: "index_time_seconds": self._index_time, "top_chars": top_chars, "top_pinyins": top_pinyins, - "metadata": self._counter_data.metadata + "metadata": self._counter_data.metadata, } - + def is_loaded(self) -> bool: """检查数据是否已加载""" return self._loaded - + def clear(self): """清除所有数据和索引,释放内存""" self._counter_data = None @@ -465,4 +478,4 @@ class QueryEngine: self._loaded = False self._total_pairs = 0 self._load_time = 0.0 - self._index_time = 0.0 \ No newline at end of file + self._index_time = 0.0