Compare commits
2 Commits
ea54c7da39
...
5b1c6fcb2b
| Author | SHA1 | Date |
|---|---|---|
|
|
5b1c6fcb2b | |
|
|
1cbb0b07c4 |
|
|
@ -1,8 +1,9 @@
|
||||||
import random
|
|
||||||
from typing import Any, Dict, List, Tuple
|
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from importlib.resources import files
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -12,9 +13,9 @@ from modelscope import AutoTokenizer
|
||||||
from pypinyin import lazy_pinyin
|
from pypinyin import lazy_pinyin
|
||||||
from torch.utils.data import DataLoader, IterableDataset
|
from torch.utils.data import DataLoader, IterableDataset
|
||||||
|
|
||||||
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
class PinyinInputDataset(IterableDataset):
|
class PinyinInputDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
拼音输入法模拟数据集
|
拼音输入法模拟数据集
|
||||||
|
|
@ -46,7 +47,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||||
py_group_json_file: Path = Path(__file__).parent.parent.parent / "py_group.json",
|
py_group_json_file: Optional[Dict[str, int]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化数据集
|
初始化数据集
|
||||||
|
|
@ -97,9 +98,32 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 加载数据集
|
# 加载数据集
|
||||||
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
||||||
|
|
||||||
with open(py_group_json_file, "r") as f:
|
# 加载拼音分组
|
||||||
self.py_groups = json.load(f)
|
self.pg_groups = {
|
||||||
|
"y": 0,
|
||||||
|
"k": 0,
|
||||||
|
"e": 0,
|
||||||
|
"l": 1,
|
||||||
|
"w": 1,
|
||||||
|
"f": 1,
|
||||||
|
"q": 2,
|
||||||
|
"a": 2,
|
||||||
|
"s": 2,
|
||||||
|
"x": 3,
|
||||||
|
"b": 3,
|
||||||
|
"r": 3,
|
||||||
|
"o": 4,
|
||||||
|
"m": 4,
|
||||||
|
"z": 4,
|
||||||
|
"g": 5,
|
||||||
|
"n": 5,
|
||||||
|
"c": 5,
|
||||||
|
"t": 6,
|
||||||
|
"p": 6,
|
||||||
|
"d": 6,
|
||||||
|
"j": 7,
|
||||||
|
"h": 7,
|
||||||
|
}
|
||||||
|
|
||||||
def get_next_chinese_chars(
|
def get_next_chinese_chars(
|
||||||
self,
|
self,
|
||||||
|
|
@ -391,7 +415,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if not char_info:
|
if not char_info:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.info(f'获取字符信息: {char_info}')
|
logger.info(f"获取字符信息: {char_info}")
|
||||||
# 削峰填谷调整
|
# 削峰填谷调整
|
||||||
adjust_factor = self.adjust_frequency(char_info["freq"])
|
adjust_factor = self.adjust_frequency(char_info["freq"])
|
||||||
if adjust_factor <= 0:
|
if adjust_factor <= 0:
|
||||||
|
|
@ -402,8 +426,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
# 拼音处理
|
# 拼音处理
|
||||||
processed_pinyin = self.process_pinyin_sequence(next_pinyins)
|
processed_pinyin = self.process_pinyin_sequence(next_pinyins)
|
||||||
if not processed_pinyin:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Tokenize
|
# Tokenize
|
||||||
hint = self.tokenizer(
|
hint = self.tokenizer(
|
||||||
|
|
@ -423,6 +445,9 @@ class PinyinInputDataset(IterableDataset):
|
||||||
"char_id": torch.tensor([char_info["id"]]),
|
"char_id": torch.tensor([char_info["id"]]),
|
||||||
"char": char,
|
"char": char,
|
||||||
"freq": char_info["freq"],
|
"freq": char_info["freq"],
|
||||||
|
"pg": torch.tensor(
|
||||||
|
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 根据调整因子重复样本
|
# 根据调整因子重复样本
|
||||||
|
|
@ -470,7 +495,8 @@ class PinyinInputDataset(IterableDataset):
|
||||||
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||||||
# 批量收集需要查询的字符信息
|
# 批量收集需要查询的字符信息
|
||||||
chinese_positions = [
|
chinese_positions = [
|
||||||
i for i, char in enumerate(text)
|
i
|
||||||
|
for i, char in enumerate(text)
|
||||||
if self.query_engine.is_chinese_char(char)
|
if self.query_engine.is_chinese_char(char)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -518,7 +544,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
)
|
)
|
||||||
yield from self._shuffle_and_yield(batch_samples)
|
yield from self._shuffle_and_yield(batch_samples)
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""
|
"""
|
||||||
由于是流式数据集,无法预先知道长度
|
由于是流式数据集,无法预先知道长度
|
||||||
|
|
@ -591,7 +616,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# 初始化查询引擎
|
# 初始化查询引擎
|
||||||
query_engine = QueryEngine()
|
query_engine = QueryEngine()
|
||||||
query_engine.load("./pinyin_char_pairs_info.json")
|
query_engine.load()
|
||||||
|
|
||||||
# 创建数据集
|
# 创建数据集
|
||||||
dataset = PinyinInputDataset(
|
dataset = PinyinInputDataset(
|
||||||
|
|
@ -645,4 +670,3 @@ if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
print("数据集为空")
|
print("数据集为空")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,14 @@
|
||||||
# file name: query_engine.py
|
# file name: query_engine.py
|
||||||
import json
|
|
||||||
import pickle
|
|
||||||
import msgpack
|
|
||||||
import gzip
|
import gzip
|
||||||
from typing import Dict, List, Optional, Tuple, Any
|
import json
|
||||||
import time
|
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from importlib.resources import files
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
|
||||||
from .char_info import CharInfo, PinyinCharPairsCounter
|
from .char_info import CharInfo, PinyinCharPairsCounter
|
||||||
|
|
||||||
|
|
@ -45,12 +48,17 @@ class QueryEngine:
|
||||||
|
|
||||||
self.min_count = min_count
|
self.min_count = min_count
|
||||||
|
|
||||||
def load(self, file_path: str) -> Dict[str, Any]:
|
def load(
|
||||||
|
self,
|
||||||
|
file_path: Union[str, Path] = (
|
||||||
|
files(__package__) / "data" / "pinyin_char_statistics.json"
|
||||||
|
),
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
加载统计结果文件
|
加载统计结果文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: 文件路径,支持msgpack/pickle/json格式,自动检测压缩
|
file_path: 文件路径,文件支持msgpack/pickle/json格式,自动检测压缩
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
元数据字典
|
元数据字典
|
||||||
|
|
@ -75,9 +83,9 @@ class QueryEngine:
|
||||||
|
|
||||||
return self._counter_data.metadata
|
return self._counter_data.metadata
|
||||||
|
|
||||||
def parse_file(self, file_path: str) -> PinyinCharPairsCounter:
|
def parse_file(self, file_path: Union[str, Path]) -> PinyinCharPairsCounter:
|
||||||
"""解析文件,支持多种格式"""
|
"""解析文件,支持多种格式"""
|
||||||
with open(file_path, 'rb') as f:
|
with open(file_path, "rb") as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
|
|
||||||
# 尝试解压
|
# 尝试解压
|
||||||
|
|
@ -88,9 +96,9 @@ class QueryEngine:
|
||||||
|
|
||||||
# 尝试不同格式
|
# 尝试不同格式
|
||||||
for parser_name, parser in [
|
for parser_name, parser in [
|
||||||
('msgpack', self._parse_msgpack),
|
("msgpack", self._parse_msgpack),
|
||||||
('pickle', self._parse_pickle),
|
("pickle", self._parse_pickle),
|
||||||
('json', self._parse_json)
|
("json", self._parse_json),
|
||||||
]:
|
]:
|
||||||
try:
|
try:
|
||||||
return parser(data)
|
return parser(data)
|
||||||
|
|
@ -110,7 +118,7 @@ class QueryEngine:
|
||||||
|
|
||||||
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
|
def _parse_json(self, data: bytes) -> PinyinCharPairsCounter:
|
||||||
"""解析json格式"""
|
"""解析json格式"""
|
||||||
data_str = data.decode('utf-8')
|
data_str = data.decode("utf-8")
|
||||||
data_dict = json.loads(data_str)
|
data_dict = json.loads(data_str)
|
||||||
return self._dict_to_counter(data_dict)
|
return self._dict_to_counter(data_dict)
|
||||||
|
|
||||||
|
|
@ -118,10 +126,10 @@ class QueryEngine:
|
||||||
"""字典转PinyinCharPairsCounter"""
|
"""字典转PinyinCharPairsCounter"""
|
||||||
# 转换CharInfo字典
|
# 转换CharInfo字典
|
||||||
pairs_dict = {}
|
pairs_dict = {}
|
||||||
if 'pairs' in data_dict and data_dict['pairs']:
|
if "pairs" in data_dict and data_dict["pairs"]:
|
||||||
for id_str, info_dict in data_dict['pairs'].items():
|
for id_str, info_dict in data_dict["pairs"].items():
|
||||||
pairs_dict[int(id_str)] = CharInfo(**info_dict)
|
pairs_dict[int(id_str)] = CharInfo(**info_dict)
|
||||||
data_dict['pairs'] = pairs_dict
|
data_dict["pairs"] = pairs_dict
|
||||||
|
|
||||||
return PinyinCharPairsCounter(**data_dict)
|
return PinyinCharPairsCounter(**data_dict)
|
||||||
|
|
||||||
|
|
@ -222,7 +230,9 @@ class QueryEngine:
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def query_by_pinyin(self, pinyin: str, limit: int = 0) -> List[Tuple[int, str, int]]:
|
def query_by_pinyin(
|
||||||
|
self, pinyin: str, limit: int = 0
|
||||||
|
) -> List[Tuple[int, str, int]]:
|
||||||
"""
|
"""
|
||||||
通过拼音查询字符信息 - O(1) + O(k)时间复杂度
|
通过拼音查询字符信息 - O(1) + O(k)时间复杂度
|
||||||
|
|
||||||
|
|
@ -303,7 +313,9 @@ class QueryEngine:
|
||||||
|
|
||||||
return self._char_pinyin_map.get((char, pinyin), 0)
|
return self._char_pinyin_map.get((char, pinyin), 0)
|
||||||
|
|
||||||
def get_char_info_by_char_pinyin(self, char: str, pinyin: str) -> Optional[CharInfo]:
|
def get_char_info_by_char_pinyin(
|
||||||
|
self, char: str, pinyin: str
|
||||||
|
) -> Optional[CharInfo]:
|
||||||
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
|
"""获取特定字符-拼音对对应的ID和频率 - O(1)时间复杂度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -319,7 +331,9 @@ class QueryEngine:
|
||||||
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
|
char_info_id = self._char_pinyin_to_ids.get((char, pinyin), None)
|
||||||
return self.query_by_id(char_info_id)
|
return self.query_by_id(char_info_id)
|
||||||
|
|
||||||
def batch_get_char_pinyin_info(self, pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], CharInfo]:
|
def batch_get_char_pinyin_info(
|
||||||
|
self, pairs: List[Tuple[str, str]]
|
||||||
|
) -> Dict[Tuple[str, str], CharInfo]:
|
||||||
"""批量获取汉字-拼音信息
|
"""批量获取汉字-拼音信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -340,7 +354,6 @@ class QueryEngine:
|
||||||
result[pair] = None
|
result[pair] = None
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
|
def batch_query_by_ids(self, ids: List[int]) -> Dict[int, Optional[CharInfo]]:
|
||||||
"""
|
"""
|
||||||
批量ID查询 - O(n)时间复杂度
|
批量ID查询 - O(n)时间复杂度
|
||||||
|
|
@ -360,7 +373,9 @@ class QueryEngine:
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def batch_query_by_chars(self, chars: List[str], limit_per_char: int = 0) -> Dict[str, List[Tuple[int, str, int]]]:
|
def batch_query_by_chars(
|
||||||
|
self, chars: List[str], limit_per_char: int = 0
|
||||||
|
) -> Dict[str, List[Tuple[int, str, int]]]:
|
||||||
"""
|
"""
|
||||||
批量字符查询
|
批量字符查询
|
||||||
|
|
||||||
|
|
@ -380,7 +395,9 @@ class QueryEngine:
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def search_chars_by_prefix(self, prefix: str, limit: int = 20) -> List[Tuple[str, int]]:
|
def search_chars_by_prefix(
|
||||||
|
self, prefix: str, limit: int = 20
|
||||||
|
) -> List[Tuple[str, int]]:
|
||||||
"""
|
"""
|
||||||
根据字符前缀搜索 - O(n)时间复杂度,n为字符总数
|
根据字符前缀搜索 - O(n)时间复杂度,n为字符总数
|
||||||
|
|
||||||
|
|
@ -409,7 +426,7 @@ class QueryEngine:
|
||||||
判断是否是汉字
|
判断是否是汉字
|
||||||
"""
|
"""
|
||||||
if not self.is_loaded():
|
if not self.is_loaded():
|
||||||
raise ValueError('请先调用 load() 方法加载数据')
|
raise ValueError("请先调用 load() 方法加载数据")
|
||||||
return char in self._char_to_ids
|
return char in self._char_to_ids
|
||||||
|
|
||||||
def get_statistics(self) -> Dict[str, Any]:
|
def get_statistics(self) -> Dict[str, Any]:
|
||||||
|
|
@ -422,16 +439,12 @@ class QueryEngine:
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
return {"status": "not_loaded"}
|
return {"status": "not_loaded"}
|
||||||
|
|
||||||
top_chars = sorted(
|
top_chars = sorted(self._char_freq.items(), key=lambda x: x[1], reverse=True)[
|
||||||
self._char_freq.items(),
|
:10
|
||||||
key=lambda x: x[1],
|
]
|
||||||
reverse=True
|
|
||||||
)[:10]
|
|
||||||
|
|
||||||
top_pinyins = sorted(
|
top_pinyins = sorted(
|
||||||
self._pinyin_freq.items(),
|
self._pinyin_freq.items(), key=lambda x: x[1], reverse=True
|
||||||
key=lambda x: x[1],
|
|
||||||
reverse=True
|
|
||||||
)[:10]
|
)[:10]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -445,7 +458,7 @@ class QueryEngine:
|
||||||
"index_time_seconds": self._index_time,
|
"index_time_seconds": self._index_time,
|
||||||
"top_chars": top_chars,
|
"top_chars": top_chars,
|
||||||
"top_pinyins": top_pinyins,
|
"top_pinyins": top_pinyins,
|
||||||
"metadata": self._counter_data.metadata
|
"metadata": self._counter_data.metadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
def is_loaded(self) -> bool:
|
def is_loaded(self) -> bool:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue