优化打乱逻辑并提升数据处理效率

This commit is contained in:
songsen 2026-02-09 23:53:11 +08:00
parent 1bdbbe284c
commit 9b813732fd
2 changed files with 69 additions and 54 deletions

View File

@ -1,6 +1,6 @@
import random import random
import re
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
import os
import numpy as np import numpy as np
import torch import torch
@ -11,6 +11,8 @@ from pypinyin import lazy_pinyin
from torch.utils.data import DataLoader, IterableDataset from torch.utils.data import DataLoader, IterableDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class PinyinInputDataset(IterableDataset): class PinyinInputDataset(IterableDataset):
""" """
拼音输入法模拟数据集 拼音输入法模拟数据集
@ -70,7 +72,6 @@ class PinyinInputDataset(IterableDataset):
# 打乱相关参数 # 打乱相关参数
self.shuffle = shuffle self.shuffle = shuffle
self.shuffle_buffer_size = shuffle_buffer_size self.shuffle_buffer_size = shuffle_buffer_size
self.shuffle_buffer = []
# 削峰填谷参数 # 削峰填谷参数
self.max_freq = max_freq self.max_freq = max_freq
@ -87,18 +88,12 @@ class PinyinInputDataset(IterableDataset):
stats = query_engine.get_statistics() stats = query_engine.get_statistics()
self.total_chars = stats.get("valid_input_character_count", 0) self.total_chars = stats.get("valid_input_character_count", 0)
# 汉字正则表达式
self.chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
# 缓存字典 # 缓存字典
self.char_info_cache = {} self.char_info_cache = {}
# 加载数据集 # 加载数据集
self.dataset = load_dataset(data_dir, split="train", streaming=True) self.dataset = load_dataset(data_dir, split="train", streaming=True)
def is_chinese_char(self, char: str) -> bool:
"""判断是否为中文字符"""
return bool(self.chinese_pattern.match(char))
def get_next_chinese_chars( def get_next_chinese_chars(
self, self,
@ -126,7 +121,7 @@ class PinyinInputDataset(IterableDataset):
break break
char = text[i] char = text[i]
if self.is_chinese_char(char): if self.query_engine.is_chinese_char(char):
# 获取拼音注意这里需要确保拼音列表长度与text一致 # 获取拼音注意这里需要确保拼音列表长度与text一致
try: try:
# 重新计算整个text的拼音可能效率低但确保准确 # 重新计算整个text的拼音可能效率低但确保准确
@ -430,21 +425,16 @@ class PinyinInputDataset(IterableDataset):
return batch_samples return batch_samples
def _shuffle_and_yield(self, batch_samples): def _shuffle_and_yield(self, batch_samples):
"""打乱并yield样本""" """优化打乱逻辑"""
if not self.shuffle: if not self.shuffle:
for sample in batch_samples: yield from batch_samples
yield sample
return return
# 添加到打乱缓冲区 # 使用numpy批量操作代替random.shuffle
self.shuffle_buffer.extend(batch_samples) if batch_samples:
indices = np.random.permutation(len(batch_samples))
# 如果缓冲区达到指定大小,打乱并输出 for idx in indices:
if len(self.shuffle_buffer) >= self.shuffle_buffer_size: yield batch_samples[idx]
random.shuffle(self.shuffle_buffer)
for sample in self.shuffle_buffer:
yield sample
self.shuffle_buffer = []
def __iter__(self): def __iter__(self):
""" """
@ -464,9 +454,6 @@ class PinyinInputDataset(IterableDataset):
random.seed(seed % (2**32)) random.seed(seed % (2**32))
np.random.seed(seed % (2**32)) np.random.seed(seed % (2**32))
# 重置打乱缓冲区
self.shuffle_buffer = []
for item in self.dataset: for item in self.dataset:
text = item.get(self.text_field, "") text = item.get(self.text_field, "")
if not text: if not text:
@ -474,13 +461,19 @@ class PinyinInputDataset(IterableDataset):
# 转换为拼音列表 # 转换为拼音列表
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x]) pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
# 批量收集需要查询的字符信息
chinese_positions = [
i for i, char in enumerate(text)
if self.query_engine.is_chinese_char(char)
]
# 批量收集需要查询的字符信息 # 批量收集需要查询的字符信息
char_pinyin_batch = [] char_pinyin_batch = []
char_positions = [] # 保存字符位置和上下文信息 char_positions = []
# 遍历文本中的每个字符
for i, (char, py) in enumerate(zip(text, pinyin_list)): for i in chinese_positions: # 只遍历中文字符位置
if not self.is_chinese_char(char): char = text[i]
continue py = pinyin_list[i]
# 获取后续最多3个中文字符的拼音 # 获取后续最多3个中文字符的拼音
next_chars = self.get_next_chinese_chars( next_chars = self.get_next_chinese_chars(
@ -518,12 +511,6 @@ class PinyinInputDataset(IterableDataset):
) )
yield from self._shuffle_and_yield(batch_samples) yield from self._shuffle_and_yield(batch_samples)
# 清空缓冲区(处理完所有数据后)
if self.shuffle_buffer:
random.shuffle(self.shuffle_buffer)
for sample in self.shuffle_buffer:
yield sample
self.shuffle_buffer = []
def __len__(self): def __len__(self):
""" """
@ -545,7 +532,7 @@ def worker_init_fn(worker_id):
torch.manual_seed(seed % (2**32)) torch.manual_seed(seed % (2**32))
def custom_collate(batch): def custom_collate_with_txt(batch):
"""自定义批处理函数""" """自定义批处理函数"""
if not batch: if not batch:
return {} return {}
@ -565,15 +552,26 @@ def custom_collate(batch):
"py": [item["py"] for item in batch], "py": [item["py"] for item in batch],
} }
# 如果存在token_type_ids则添加 return result
if "token_type_ids" in hints[0]:
result["hint"]["token_type_ids"] = torch.cat(
[h["token_type_ids"] for h in hints]
)
# 如果存在freq则添加
if "freq" in batch[0]: def custom_collate(batch):
result["freq"] = torch.tensor([item["freq"] for item in batch]) """自定义批处理函数"""
if not batch:
return {}
# 处理hint字段
hints = [item["hint"] for item in batch]
# 合并所有张量字段
result = {
"hint": {
"input_ids": torch.cat([h["input_ids"] for h in hints]),
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
},
"char_id": torch.cat([item["char_id"] for item in batch]),
"py": [item["py"] for item in batch],
}
return result return result
@ -593,31 +591,41 @@ if __name__ == "__main__":
query_engine=query_engine, query_engine=query_engine,
tokenizer_name="iic/nlp_structbert_backbone_tiny_std", tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
max_len=88, max_len=88,
batch_query_size=500, batch_query_size=300,
shuffle=True, shuffle=True,
shuffle_buffer_size=1000, shuffle_buffer_size=4000,
) )
logger.info("数据集初始化") logger.info("数据集初始化")
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
batch_size=256, batch_size=1024,
num_workers=15, num_workers=15,
worker_init_fn=worker_init_fn, worker_init_fn=worker_init_fn,
# pin_memory=True, # pin_memory=True,
collate_fn=custom_collate, collate_fn=custom_collate_with_txt,
prefetch_factor=32, prefetch_factor=8,
persistent_workers=True, persistent_workers=True,
shuffle=False, # 数据集内部已实现打乱 shuffle=False, # 数据集内部已实现打乱
timeout=60,
) )
"""import cProfile
def profile_func(dataloader):
for i, sample in tqdm(enumerate(dataloader), total=3000):
if i >= 3000:
break
return
cProfile.run('profile_func(dataloader)')
"""
# 测试数据集 # 测试数据集
try: try:
iterator = iter(dataset)
logger.info("测试数据集") logger.info("测试数据集")
for i, _ in tqdm(enumerate(dataloader), total=200): for i, sample in tqdm(enumerate(dataloader), total=3000):
if i >= 200: if i >= 3000:
break break
""" """
print(f"Sample {i+1}:") print(f"Sample {i+1}:")
@ -629,3 +637,4 @@ if __name__ == "__main__":
""" """
except StopIteration: except StopIteration:
print("数据集为空") print("数据集为空")

View File

@ -400,6 +400,12 @@ class QueryEngine:
return matches[:limit] if limit > 0 else matches return matches[:limit] if limit > 0 else matches
def is_chinese_char(self, char: str) -> bool:
"""
判断是否是汉字
"""
return char in self._char_to_ids
def get_statistics(self) -> Dict[str, Any]: def get_statistics(self) -> Dict[str, Any]:
""" """
获取系统统计信息 获取系统统计信息