From 9b813732fd9609529f0226a9048c2c6dc5f5c9cf Mon Sep 17 00:00:00 2001 From: songsenand Date: Mon, 9 Feb 2026 23:53:11 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=89=93=E4=B9=B1=E9=80=BB?= =?UTF-8?q?=E8=BE=91=E5=B9=B6=E6=8F=90=E5=8D=87=E6=95=B0=E6=8D=AE=E5=A4=84?= =?UTF-8?q?=E7=90=86=E6=95=88=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/suinput/dataset.py | 117 ++++++++++++++++++++++------------------- src/suinput/query.py | 6 +++ 2 files changed, 69 insertions(+), 54 deletions(-) diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index ff5cafb..d437bd0 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -1,6 +1,6 @@ import random -import re from typing import Any, Dict, List, Tuple +import os import numpy as np import torch @@ -11,6 +11,8 @@ from pypinyin import lazy_pinyin from torch.utils.data import DataLoader, IterableDataset +os.environ["TOKENIZERS_PARALLELISM"] = "false" + class PinyinInputDataset(IterableDataset): """ 拼音输入法模拟数据集 @@ -70,7 +72,6 @@ class PinyinInputDataset(IterableDataset): # 打乱相关参数 self.shuffle = shuffle self.shuffle_buffer_size = shuffle_buffer_size - self.shuffle_buffer = [] # 削峰填谷参数 self.max_freq = max_freq @@ -87,18 +88,12 @@ class PinyinInputDataset(IterableDataset): stats = query_engine.get_statistics() self.total_chars = stats.get("valid_input_character_count", 0) - # 汉字正则表达式 - self.chinese_pattern = re.compile(r"[\u4e00-\u9fff]") - # 缓存字典 self.char_info_cache = {} # 加载数据集 self.dataset = load_dataset(data_dir, split="train", streaming=True) - def is_chinese_char(self, char: str) -> bool: - """判断是否为中文字符""" - return bool(self.chinese_pattern.match(char)) def get_next_chinese_chars( self, @@ -126,7 +121,7 @@ class PinyinInputDataset(IterableDataset): break char = text[i] - if self.is_chinese_char(char): + if self.query_engine.is_chinese_char(char): # 获取拼音(注意:这里需要确保拼音列表长度与text一致) try: # 重新计算整个text的拼音可能效率低,但确保准确 @@ -430,21 +425,16 @@ class PinyinInputDataset(IterableDataset): return batch_samples def _shuffle_and_yield(self, batch_samples): - """打乱并yield样本""" + """优化打乱逻辑""" if not self.shuffle: - for sample in batch_samples: - yield sample + yield from batch_samples return - - # 添加到打乱缓冲区 - self.shuffle_buffer.extend(batch_samples) - - # 如果缓冲区达到指定大小,打乱并输出 - if len(self.shuffle_buffer) >= self.shuffle_buffer_size: - random.shuffle(self.shuffle_buffer) - for sample in self.shuffle_buffer: - yield sample - self.shuffle_buffer = [] + + # 使用numpy批量操作代替random.shuffle + if batch_samples: + indices = np.random.permutation(len(batch_samples)) + for idx in indices: + yield batch_samples[idx] def __iter__(self): """ @@ -464,9 +454,6 @@ class PinyinInputDataset(IterableDataset): random.seed(seed % (2**32)) np.random.seed(seed % (2**32)) - # 重置打乱缓冲区 - self.shuffle_buffer = [] - for item in self.dataset: text = item.get(self.text_field, "") if not text: @@ -474,13 +461,19 @@ class PinyinInputDataset(IterableDataset): # 转换为拼音列表 pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x]) + # 批量收集需要查询的字符信息 + chinese_positions = [ + i for i, char in enumerate(text) + if self.query_engine.is_chinese_char(char) + ] + # 批量收集需要查询的字符信息 char_pinyin_batch = [] - char_positions = [] # 保存字符位置和上下文信息 - # 遍历文本中的每个字符 - for i, (char, py) in enumerate(zip(text, pinyin_list)): - if not self.is_chinese_char(char): - continue + char_positions = [] + + for i in chinese_positions: # 只遍历中文字符位置 + char = text[i] + py = pinyin_list[i] # 获取后续最多3个中文字符的拼音 next_chars = self.get_next_chinese_chars( @@ -518,12 +511,6 @@ class PinyinInputDataset(IterableDataset): ) yield from self._shuffle_and_yield(batch_samples) - # 清空缓冲区(处理完所有数据后) - if self.shuffle_buffer: - random.shuffle(self.shuffle_buffer) - for sample in self.shuffle_buffer: - yield sample - self.shuffle_buffer = [] def __len__(self): """ @@ -545,7 +532,7 @@ def worker_init_fn(worker_id): torch.manual_seed(seed % (2**32)) -def custom_collate(batch): +def custom_collate_with_txt(batch): """自定义批处理函数""" if not batch: return {} @@ -565,15 +552,26 @@ def custom_collate(batch): "py": [item["py"] for item in batch], } - # 如果存在token_type_ids则添加 - if "token_type_ids" in hints[0]: - result["hint"]["token_type_ids"] = torch.cat( - [h["token_type_ids"] for h in hints] - ) + return result - # 如果存在freq则添加 - if "freq" in batch[0]: - result["freq"] = torch.tensor([item["freq"] for item in batch]) + +def custom_collate(batch): + """自定义批处理函数""" + if not batch: + return {} + + # 处理hint字段 + hints = [item["hint"] for item in batch] + + # 合并所有张量字段 + result = { + "hint": { + "input_ids": torch.cat([h["input_ids"] for h in hints]), + "attention_mask": torch.cat([h["attention_mask"] for h in hints]), + }, + "char_id": torch.cat([item["char_id"] for item in batch]), + "py": [item["py"] for item in batch], + } return result @@ -593,31 +591,41 @@ if __name__ == "__main__": query_engine=query_engine, tokenizer_name="iic/nlp_structbert_backbone_tiny_std", max_len=88, - batch_query_size=500, + batch_query_size=300, shuffle=True, - shuffle_buffer_size=1000, + shuffle_buffer_size=4000, ) logger.info("数据集初始化") dataloader = DataLoader( dataset, - batch_size=256, + batch_size=1024, num_workers=15, worker_init_fn=worker_init_fn, # pin_memory=True, - collate_fn=custom_collate, - prefetch_factor=32, + collate_fn=custom_collate_with_txt, + prefetch_factor=8, persistent_workers=True, shuffle=False, # 数据集内部已实现打乱 - timeout=60, ) + """import cProfile + + def profile_func(dataloader): + for i, sample in tqdm(enumerate(dataloader), total=3000): + if i >= 3000: + break + return + + + cProfile.run('profile_func(dataloader)') + + """ # 测试数据集 try: - iterator = iter(dataset) logger.info("测试数据集") - for i, _ in tqdm(enumerate(dataloader), total=200): - if i >= 200: + for i, sample in tqdm(enumerate(dataloader), total=3000): + if i >= 3000: break """ print(f"Sample {i+1}:") @@ -629,3 +637,4 @@ if __name__ == "__main__": """ except StopIteration: print("数据集为空") + diff --git a/src/suinput/query.py b/src/suinput/query.py index 3fd590c..bf7e2b7 100644 --- a/src/suinput/query.py +++ b/src/suinput/query.py @@ -399,6 +399,12 @@ class QueryEngine: matches.sort(key=lambda x: x[1], reverse=True) return matches[:limit] if limit > 0 else matches + + def is_chinese_char(self, char: str) -> bool: + """ + 判断是否是汉字 + """ + return char in self._char_to_ids def get_statistics(self) -> Dict[str, Any]: """