Compare commits

...

2 Commits

7 changed files with 186 additions and 149 deletions

View File

@ -1,8 +1,9 @@
import random
from typing import Any, Dict, List, Tuple
import os
import json import json
import os
import random
from importlib.resources import files
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
@ -12,9 +13,9 @@ from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin from pypinyin import lazy_pinyin
from torch.utils.data import DataLoader, IterableDataset from torch.utils.data import DataLoader, IterableDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
class PinyinInputDataset(IterableDataset): class PinyinInputDataset(IterableDataset):
""" """
拼音输入法模拟数据集 拼音输入法模拟数据集
@ -46,7 +47,7 @@ class PinyinInputDataset(IterableDataset):
repeat_end_freq: int = 10000, # 开始重复的阈值 repeat_end_freq: int = 10000, # 开始重复的阈值
max_drop_prob: float = 0.8, # 最大丢弃概率 max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望 max_repeat_expect: float = 50.0, # 最大重复期望
py_group_json_file: Path = Path(__file__).parent.parent.parent / "py_group.json", py_group_json_file: Optional[Dict[str, int]] = None,
): ):
""" """
初始化数据集 初始化数据集
@ -97,9 +98,32 @@ class PinyinInputDataset(IterableDataset):
# 加载数据集 # 加载数据集
self.dataset = load_dataset(data_dir, split="train", streaming=True) self.dataset = load_dataset(data_dir, split="train", streaming=True)
with open(py_group_json_file, "r") as f: # 加载拼音分组
self.py_groups = json.load(f) self.pg_groups = {
"y": 0,
"k": 0,
"e": 0,
"l": 1,
"w": 1,
"f": 1,
"q": 2,
"a": 2,
"s": 2,
"x": 3,
"b": 3,
"r": 3,
"o": 4,
"m": 4,
"z": 4,
"g": 5,
"n": 5,
"c": 5,
"t": 6,
"p": 6,
"d": 6,
"j": 7,
"h": 7,
}
def get_next_chinese_chars( def get_next_chinese_chars(
self, self,
@ -391,7 +415,7 @@ class PinyinInputDataset(IterableDataset):
if not char_info: if not char_info:
continue continue
logger.info(f'获取字符信息: {char_info}') logger.info(f"获取字符信息: {char_info}")
# 削峰填谷调整 # 削峰填谷调整
adjust_factor = self.adjust_frequency(char_info["freq"]) adjust_factor = self.adjust_frequency(char_info["freq"])
if adjust_factor <= 0: if adjust_factor <= 0:
@ -402,8 +426,6 @@ class PinyinInputDataset(IterableDataset):
# 拼音处理 # 拼音处理
processed_pinyin = self.process_pinyin_sequence(next_pinyins) processed_pinyin = self.process_pinyin_sequence(next_pinyins)
if not processed_pinyin:
continue
# Tokenize # Tokenize
hint = self.tokenizer( hint = self.tokenizer(
@ -423,6 +445,9 @@ class PinyinInputDataset(IterableDataset):
"char_id": torch.tensor([char_info["id"]]), "char_id": torch.tensor([char_info["id"]]),
"char": char, "char": char,
"freq": char_info["freq"], "freq": char_info["freq"],
"pg": torch.tensor(
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
),
} }
# 根据调整因子重复样本 # 根据调整因子重复样本
@ -470,7 +495,8 @@ class PinyinInputDataset(IterableDataset):
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x]) pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
# 批量收集需要查询的字符信息 # 批量收集需要查询的字符信息
chinese_positions = [ chinese_positions = [
i for i, char in enumerate(text) i
for i, char in enumerate(text)
if self.query_engine.is_chinese_char(char) if self.query_engine.is_chinese_char(char)
] ]
@ -518,7 +544,6 @@ class PinyinInputDataset(IterableDataset):
) )
yield from self._shuffle_and_yield(batch_samples) yield from self._shuffle_and_yield(batch_samples)
def __len__(self): def __len__(self):
""" """
由于是流式数据集无法预先知道长度 由于是流式数据集无法预先知道长度
@ -591,7 +616,7 @@ if __name__ == "__main__":
# 初始化查询引擎 # 初始化查询引擎
query_engine = QueryEngine() query_engine = QueryEngine()
query_engine.load("./pinyin_char_pairs_info.json") query_engine.load()
# 创建数据集 # 创建数据集
dataset = PinyinInputDataset( dataset = PinyinInputDataset(
@ -645,4 +670,3 @@ if __name__ == "__main__":
""" """
except StopIteration: except StopIteration:
print("数据集为空") print("数据集为空")

View File

@ -1,11 +1,14 @@
# file name: query_engine.py # file name: query_engine.py
import json
import pickle
import msgpack
import gzip import gzip
from typing import Dict, List, Optional, Tuple, Any import json
import time
import os import os
import pickle
import time
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import msgpack
from .char_info import CharInfo, PinyinCharPairsCounter from .char_info import CharInfo, PinyinCharPairsCounter
@ -27,14 +30,14 @@ class QueryEngine:
self._counter_data: Optional[PinyinCharPairsCounter] = None self._counter_data: Optional[PinyinCharPairsCounter] = None
# 核心索引 - 提供O(1)查询 # 核心索引 - 提供O(1)查询
self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo self._id_to_info: Dict[int, CharInfo] = {} # ID -> CharInfo
self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表 self._char_to_ids: Dict[str, List[int]] = {} # 字符 -> ID列表
self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表 self._pinyin_to_ids: Dict[str, List[int]] = {} # 拼音 -> ID列表
self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {} self._char_pinyin_to_ids: Dict[Tuple[str, str], int] = {}
# 辅助索引 - 快速获取详细信息 # 辅助索引 - 快速获取详细信息
self._char_freq: Dict[str, int] = {} # 字符总频率 self._char_freq: Dict[str, int] = {} # 字符总频率
self._pinyin_freq: Dict[str, int] = {} # 拼音总频率 self._pinyin_freq: Dict[str, int] = {} # 拼音总频率
self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count self._char_pinyin_map: Dict[Tuple[str, str], int] = {} # (字符, 拼音) -> count
# 统计信息 # 统计信息
@ -45,12 +48,17 @@ class QueryEngine:
self.min_count = min_count self.min_count = min_count
def load(self, file_path: str) -> Dict[str, Any]: def load(
self,
file_path: Union[str, Path] = (
files(__package__) / "data" / "pinyin_char_statistics.json"
),
) -> Dict[str, Any]:
""" """
加载统计结果文件 加载统计结果文件
Args: Args:
file_path: 文件路径支持msgpack/pickle/json格式自动检测压缩 file_path: 文件路径文件支持msgpack/pickle/json格式自动检测压缩
Returns: Returns:
元数据字典 元数据字典
@ -75,9 +83,9 @@ class QueryEngine:
return self._counter_data.metadata return self._counter_data.metadata
def parse_file(self, file_path: str) -> PinyinCharPairsCounter: def parse_file(self, file_path: Union[str, Path]) -> PinyinCharPairsCounter:
"""解析文件,支持多种格式""" """解析文件,支持多种格式"""
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
data = f.read() data = f.read()
# 尝试解压 # 尝试解压
@ -88,9 +96,9 @@ class QueryEngine:
# 尝试不同格式 # 尝试不同格式
for parser_name, parser in [ for parser_name, parser in [
('msgpack', self._parse_msgpack), ("msgpack", self._parse_msgpack),
('pickle', self._parse_pickle), ("pickle", self._parse_pickle),
('json', self._parse_json) ("json", self._parse_json),
]: ]:
try: try:
return parser(data) return parser(data)
@ -110,7 +118,7 @@ class QueryEngine:
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter: def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
"""解析json格式""" """解析json格式"""
data_str = data.decode('utf-8') data_str = data.decode("utf-8")
data_dict = json.loads(data_str) data_dict = json.loads(data_str)
return self._dict_to_counter(data_dict) return self._dict_to_counter(data_dict)
@ -118,10 +126,10 @@ class QueryEngine:
"""字典转PinyinCharPairsCounter""" """字典转PinyinCharPairsCounter"""
# 转换CharInfo字典 # 转换CharInfo字典
pairs_dict = {} pairs_dict = {}
if 'pairs' in data_dict and data_dict['pairs']: if "pairs" in data_dict and data_dict["pairs"]:
for id_str, info_dict in data_dict['pairs'].items(): for id_str, info_dict in data_dict["pairs"].items():
pairs_dict[int(id_str)] = CharInfo(**info_dict) pairs_dict[int(id_str)] = CharInfo(**info_dict)
data_dict['pairs'] = pairs_dict data_dict["pairs"] = pairs_dict
return PinyinCharPairsCounter(**data_dict) return PinyinCharPairsCounter(**data_dict)
@ -222,7 +230,9 @@ class QueryEngine:
return results return results
def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]: def query_by_pinyin(
self, pinyin: str, limit: int = 0
) -> List[Tuple[int, str, int]]:
""" """
通过拼音查询字符信息 - O(1) + O(k)时间复杂度 通过拼音查询字符信息 - O(1) + O(k)时间复杂度
@ -303,7 +313,9 @@ class QueryEngine:
return self._char_pinyin_map.get((char, pinyin), 0) return self._char_pinyin_map.get((char, pinyin), 0)
def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]: def get_char_info_by_char_pinyin(
self, char: str, pinyin: str
) -> Optional[CharInfo]:
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度 """获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
Args: Args:
@ -319,7 +331,9 @@ class QueryEngine:
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None) char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
return self.query_by_id(char_info_id) return self.query_by_id(char_info_id)
def batch_get_char_pinyin_info(self, pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], CharInfo]: def batch_get_char_pinyin_info(
self, pairs: List[Tuple[str, str]]
) -> Dict[Tuple[str, str], CharInfo]:
"""批量获取汉字-拼音信息 """批量获取汉字-拼音信息
Args: Args:
@ -340,7 +354,6 @@ class QueryEngine:
result[pair] = None result[pair] = None
return result return result
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]: def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
""" """
批量ID查询 - O(n)时间复杂度 批量ID查询 - O(n)时间复杂度
@ -360,7 +373,9 @@ class QueryEngine:
return results return results
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]: def batch_query_by_chars(
self, chars: List[str], limit_per_char: int = 0
) -> Dict[str, List[Tuple[int, str, int]]]:
""" """
批量字符查询 批量字符查询
@ -380,7 +395,9 @@ class QueryEngine:
return results return results
def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]: def search_chars_by_prefix(
self, prefix: str, limit: int = 20
) -> List[Tuple[str, int]]:
""" """
根据字符前缀搜索 - O(n)时间复杂度n为字符总数 根据字符前缀搜索 - O(n)时间复杂度n为字符总数
@ -409,7 +426,7 @@ class QueryEngine:
判断是否是汉字 判断是否是汉字
""" """
if not self.is_loaded(): if not self.is_loaded():
raise ValueError('请先调用 load() 方法加载数据') raise ValueError("请先调用 load() 方法加载数据")
return char in self._char_to_ids return char in self._char_to_ids
def get_statistics(self) -> Dict[str, Any]: def get_statistics(self) -> Dict[str, Any]:
@ -422,16 +439,12 @@ class QueryEngine:
if not self._loaded: if not self._loaded:
return {"status": "not_loaded"} return {"status": "not_loaded"}
top_chars = sorted( top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[
self._char_freq.items(), :10
key=lambda x: x[1], ]
reverse=True
)[:10]
top_pinyins = sorted( top_pinyins = sorted(
self._pinyin_freq.items(), self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
key=lambda x: x[1],
reverse=True
)[:10] )[:10]
return { return {
@ -445,7 +458,7 @@ class QueryEngine:
"index_time_seconds": self._index_time, "index_time_seconds": self._index_time,
"top_chars": top_chars, "top_chars": top_chars,
"top_pinyins": top_pinyins, "top_pinyins": top_pinyins,
"metadata": self._counter_data.metadata "metadata": self._counter_data.metadata,
} }
def is_loaded(self) -> bool: def is_loaded(self) -> bool: