优化打乱逻辑并提升数据处理效率

This commit is contained in:
songsenand 2026-02-09 23:53:11 +08:00
parent 1bdbbe284c
commit 9b813732fd
2 changed files with 69 additions and 54 deletions

View File

@ -1,6 +1,6 @@
import random
import re
from typing import Any, Dict, List, Tuple
import os
import numpy as np
import torch
@ -11,6 +11,8 @@ from pypinyin import lazy_pinyin
from torch.utils.data import DataLoader, IterableDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class PinyinInputDataset(IterableDataset):
"""
拼音输入法模拟数据集
@ -70,7 +72,6 @@ class PinyinInputDataset(IterableDataset):
# 打乱相关参数
self.shuffle = shuffle
self.shuffle_buffer_size = shuffle_buffer_size
self.shuffle_buffer = []
# 削峰填谷参数
self.max_freq = max_freq
@ -87,18 +88,12 @@ class PinyinInputDataset(IterableDataset):
stats = query_engine.get_statistics()
self.total_chars = stats.get("valid_input_character_count", 0)
# 汉字正则表达式
self.chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
# 缓存字典
self.char_info_cache = {}
# 加载数据集
self.dataset = load_dataset(data_dir, split="train", streaming=True)
def is_chinese_char(self, char: str) -> bool:
    """Return True when *char* matches the compiled CJK-range pattern.

    Relies on ``self.chinese_pattern`` (a compiled ``re`` pattern over the
    U+4E00–U+9FFF block); ``match`` anchored at position 0 is sufficient
    because callers pass a single character.
    """
    return self.chinese_pattern.match(char) is not None
def get_next_chinese_chars(
self,
@ -126,7 +121,7 @@ class PinyinInputDataset(IterableDataset):
break
char = text[i]
if self.is_chinese_char(char):
if self.query_engine.is_chinese_char(char):
# 获取拼音注意这里需要确保拼音列表长度与text一致
try:
# 重新计算整个text的拼音可能效率低但确保准确
@ -430,21 +425,16 @@ class PinyinInputDataset(IterableDataset):
return batch_samples
def _shuffle_and_yield(self, batch_samples):
"""打乱并yield样本"""
"""优化打乱逻辑"""
if not self.shuffle:
for sample in batch_samples:
yield sample
yield from batch_samples
return
# 添加到打乱缓冲区
self.shuffle_buffer.extend(batch_samples)
# 如果缓冲区达到指定大小,打乱并输出
if len(self.shuffle_buffer) >= self.shuffle_buffer_size:
random.shuffle(self.shuffle_buffer)
for sample in self.shuffle_buffer:
yield sample
self.shuffle_buffer = []
# 使用numpy批量操作代替random.shuffle
if batch_samples:
indices = np.random.permutation(len(batch_samples))
for idx in indices:
yield batch_samples[idx]
def __iter__(self):
"""
@ -464,9 +454,6 @@ class PinyinInputDataset(IterableDataset):
random.seed(seed % (2**32))
np.random.seed(seed % (2**32))
# 重置打乱缓冲区
self.shuffle_buffer = []
for item in self.dataset:
text = item.get(self.text_field, "")
if not text:
@ -474,13 +461,19 @@ class PinyinInputDataset(IterableDataset):
# 转换为拼音列表
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
# 批量收集需要查询的字符信息
chinese_positions = [
i for i, char in enumerate(text)
if self.query_engine.is_chinese_char(char)
]
# 批量收集需要查询的字符信息
char_pinyin_batch = []
char_positions = [] # 保存字符位置和上下文信息
# 遍历文本中的每个字符
for i, (char, py) in enumerate(zip(text, pinyin_list)):
if not self.is_chinese_char(char):
continue
char_positions = []
for i in chinese_positions: # 只遍历中文字符位置
char = text[i]
py = pinyin_list[i]
# 获取后续最多3个中文字符的拼音
next_chars = self.get_next_chinese_chars(
@ -518,12 +511,6 @@ class PinyinInputDataset(IterableDataset):
)
yield from self._shuffle_and_yield(batch_samples)
# 清空缓冲区(处理完所有数据后)
if self.shuffle_buffer:
random.shuffle(self.shuffle_buffer)
for sample in self.shuffle_buffer:
yield sample
self.shuffle_buffer = []
def __len__(self):
"""
@ -545,7 +532,7 @@ def worker_init_fn(worker_id):
torch.manual_seed(seed % (2**32))
def custom_collate(batch):
def custom_collate_with_txt(batch):
"""自定义批处理函数"""
if not batch:
return {}
@ -565,15 +552,26 @@ def custom_collate(batch):
"py": [item["py"] for item in batch],
}
# 如果存在token_type_ids则添加
if "token_type_ids" in hints[0]:
result["hint"]["token_type_ids"] = torch.cat(
[h["token_type_ids"] for h in hints]
)
return result
# 如果存在freq则添加
if "freq" in batch[0]:
result["freq"] = torch.tensor([item["freq"] for item in batch])
def custom_collate(batch):
    """Collate a list of sample dicts into a single batch dict.

    Concatenates each sample's ``hint`` tensors (``input_ids``,
    ``attention_mask``) and its ``char_id`` tensor along dim 0, and
    gathers the raw ``py`` values into a plain list.  An empty batch
    yields an empty dict.
    """
    if len(batch) == 0:
        return {}
    input_ids = []
    attention_masks = []
    char_ids = []
    pys = []
    for sample in batch:
        hint = sample["hint"]
        input_ids.append(hint["input_ids"])
        attention_masks.append(hint["attention_mask"])
        char_ids.append(sample["char_id"])
        pys.append(sample["py"])
    return {
        "hint": {
            "input_ids": torch.cat(input_ids),
            "attention_mask": torch.cat(attention_masks),
        },
        "char_id": torch.cat(char_ids),
        "py": pys,
    }
@ -593,31 +591,41 @@ if __name__ == "__main__":
query_engine=query_engine,
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
max_len=88,
batch_query_size=500,
batch_query_size=300,
shuffle=True,
shuffle_buffer_size=1000,
shuffle_buffer_size=4000,
)
logger.info("数据集初始化")
dataloader = DataLoader(
dataset,
batch_size=256,
batch_size=1024,
num_workers=15,
worker_init_fn=worker_init_fn,
# pin_memory=True,
collate_fn=custom_collate,
prefetch_factor=32,
collate_fn=custom_collate_with_txt,
prefetch_factor=8,
persistent_workers=True,
shuffle=False, # 数据集内部已实现打乱
timeout=60,
)
"""import cProfile
def profile_func(dataloader):
for i, sample in tqdm(enumerate(dataloader), total=3000):
if i >= 3000:
break
return
cProfile.run('profile_func(dataloader)')
"""
# 测试数据集
try:
iterator = iter(dataset)
logger.info("测试数据集")
for i, _ in tqdm(enumerate(dataloader), total=200):
if i >= 200:
for i, sample in tqdm(enumerate(dataloader), total=3000):
if i >= 3000:
break
"""
print(f"Sample {i+1}:")
@ -629,3 +637,4 @@ if __name__ == "__main__":
"""
except StopIteration:
print("数据集为空")

View File

@ -399,6 +399,12 @@ class QueryEngine:
matches.sort(key=lambda x: x[1], reverse=True)
return matches[:limit] if limit > 0 else matches
def is_chinese_char(self, char: str) -> bool:
    """Return True when *char* is a character the engine knows about.

    "Chinese character" here means membership in ``self._char_to_ids``
    (the engine's char-to-id table), not a Unicode-range test.
    """
    known = self._char_to_ids
    return char in known
def get_statistics(self) -> Dict[str, Any]:
"""