优化打乱逻辑并提升数据处理效率
This commit is contained in:
parent
1bdbbe284c
commit
9b813732fd
|
|
@ -1,6 +1,6 @@
|
|||
import random
|
||||
import re
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -11,6 +11,8 @@ from pypinyin import lazy_pinyin
|
|||
from torch.utils.data import DataLoader, IterableDataset
|
||||
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
class PinyinInputDataset(IterableDataset):
|
||||
"""
|
||||
拼音输入法模拟数据集
|
||||
|
|
@ -70,7 +72,6 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 打乱相关参数
|
||||
self.shuffle = shuffle
|
||||
self.shuffle_buffer_size = shuffle_buffer_size
|
||||
self.shuffle_buffer = []
|
||||
|
||||
# 削峰填谷参数
|
||||
self.max_freq = max_freq
|
||||
|
|
@ -87,18 +88,12 @@ class PinyinInputDataset(IterableDataset):
|
|||
stats = query_engine.get_statistics()
|
||||
self.total_chars = stats.get("valid_input_character_count", 0)
|
||||
|
||||
# 汉字正则表达式
|
||||
self.chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
|
||||
|
||||
# 缓存字典
|
||||
self.char_info_cache = {}
|
||||
|
||||
# 加载数据集
|
||||
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
||||
|
||||
def is_chinese_char(self, char: str) -> bool:
    """Report whether ``char`` is matched by ``self.chinese_pattern``.

    The pattern is applied with ``match``, so only the beginning of
    ``char`` is tested; an empty string never matches.

    Args:
        char: Text to test (callers pass a single character).

    Returns:
        True when the pattern matches at the start of ``char``.
    """
    found = self.chinese_pattern.match(char)
    return found is not None
|
||||
|
||||
def get_next_chinese_chars(
|
||||
self,
|
||||
|
|
@ -126,7 +121,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
break
|
||||
|
||||
char = text[i]
|
||||
if self.is_chinese_char(char):
|
||||
if self.query_engine.is_chinese_char(char):
|
||||
# 获取拼音(注意:这里需要确保拼音列表长度与text一致)
|
||||
try:
|
||||
# 重新计算整个text的拼音可能效率低,但确保准确
|
||||
|
|
@ -430,21 +425,16 @@ class PinyinInputDataset(IterableDataset):
|
|||
return batch_samples
|
||||
|
||||
def _shuffle_and_yield(self, batch_samples):
|
||||
"""打乱并yield样本"""
|
||||
"""优化打乱逻辑"""
|
||||
if not self.shuffle:
|
||||
for sample in batch_samples:
|
||||
yield sample
|
||||
yield from batch_samples
|
||||
return
|
||||
|
||||
# 添加到打乱缓冲区
|
||||
self.shuffle_buffer.extend(batch_samples)
|
||||
|
||||
# 如果缓冲区达到指定大小,打乱并输出
|
||||
if len(self.shuffle_buffer) >= self.shuffle_buffer_size:
|
||||
random.shuffle(self.shuffle_buffer)
|
||||
for sample in self.shuffle_buffer:
|
||||
yield sample
|
||||
self.shuffle_buffer = []
|
||||
# 使用numpy批量操作代替random.shuffle
|
||||
if batch_samples:
|
||||
indices = np.random.permutation(len(batch_samples))
|
||||
for idx in indices:
|
||||
yield batch_samples[idx]
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
|
|
@ -464,9 +454,6 @@ class PinyinInputDataset(IterableDataset):
|
|||
random.seed(seed % (2**32))
|
||||
np.random.seed(seed % (2**32))
|
||||
|
||||
# 重置打乱缓冲区
|
||||
self.shuffle_buffer = []
|
||||
|
||||
for item in self.dataset:
|
||||
text = item.get(self.text_field, "")
|
||||
if not text:
|
||||
|
|
@ -474,13 +461,19 @@ class PinyinInputDataset(IterableDataset):
|
|||
|
||||
# 转换为拼音列表
|
||||
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||||
# 批量收集需要查询的字符信息
|
||||
chinese_positions = [
|
||||
i for i, char in enumerate(text)
|
||||
if self.query_engine.is_chinese_char(char)
|
||||
]
|
||||
|
||||
# 批量收集需要查询的字符信息
|
||||
char_pinyin_batch = []
|
||||
char_positions = [] # 保存字符位置和上下文信息
|
||||
# 遍历文本中的每个字符
|
||||
for i, (char, py) in enumerate(zip(text, pinyin_list)):
|
||||
if not self.is_chinese_char(char):
|
||||
continue
|
||||
char_positions = []
|
||||
|
||||
for i in chinese_positions: # 只遍历中文字符位置
|
||||
char = text[i]
|
||||
py = pinyin_list[i]
|
||||
|
||||
# 获取后续最多3个中文字符的拼音
|
||||
next_chars = self.get_next_chinese_chars(
|
||||
|
|
@ -518,12 +511,6 @@ class PinyinInputDataset(IterableDataset):
|
|||
)
|
||||
yield from self._shuffle_and_yield(batch_samples)
|
||||
|
||||
# 清空缓冲区(处理完所有数据后)
|
||||
if self.shuffle_buffer:
|
||||
random.shuffle(self.shuffle_buffer)
|
||||
for sample in self.shuffle_buffer:
|
||||
yield sample
|
||||
self.shuffle_buffer = []
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
|
|
@ -545,7 +532,7 @@ def worker_init_fn(worker_id):
|
|||
torch.manual_seed(seed % (2**32))
|
||||
|
||||
|
||||
def custom_collate(batch):
|
||||
def custom_collate_with_txt(batch):
|
||||
"""自定义批处理函数"""
|
||||
if not batch:
|
||||
return {}
|
||||
|
|
@ -565,15 +552,26 @@ def custom_collate(batch):
|
|||
"py": [item["py"] for item in batch],
|
||||
}
|
||||
|
||||
# 如果存在token_type_ids则添加
|
||||
if "token_type_ids" in hints[0]:
|
||||
result["hint"]["token_type_ids"] = torch.cat(
|
||||
[h["token_type_ids"] for h in hints]
|
||||
)
|
||||
return result
|
||||
|
||||
# 如果存在freq则添加
|
||||
if "freq" in batch[0]:
|
||||
result["freq"] = torch.tensor([item["freq"] for item in batch])
|
||||
|
||||
def custom_collate(batch):
    """Merge a list of per-sample dicts into one batched dict.

    Each sample is read through the keys ``"hint"`` (itself a mapping with
    ``"input_ids"`` and ``"attention_mask"``), ``"char_id"`` and ``"py"``.
    Tensor fields are joined with ``torch.cat`` (default dimension), and
    ``"py"`` values are gathered into a plain list in sample order.

    Args:
        batch: Sequence of sample dicts.
            NOTE(review): tensor shapes are not visible here; presumably
            each carries a leading batch dimension -- confirm with producer.

    Returns:
        ``{}`` for an empty batch, otherwise a dict with keys ``"hint"``
        (containing ``"input_ids"`` and ``"attention_mask"``), ``"char_id"``
        and ``"py"``.
    """
    if not batch:
        return {}

    # Pull out the nested hint mappings once so both hint tensors reuse them.
    hint_dicts = [sample["hint"] for sample in batch]

    merged_input_ids = torch.cat([hint["input_ids"] for hint in hint_dicts])
    merged_attention = torch.cat(
        [hint["attention_mask"] for hint in hint_dicts]
    )
    merged_char_ids = torch.cat([sample["char_id"] for sample in batch])
    pinyin_values = [sample["py"] for sample in batch]

    return {
        "hint": {
            "input_ids": merged_input_ids,
            "attention_mask": merged_attention,
        },
        "char_id": merged_char_ids,
        "py": pinyin_values,
    }
|
||||
|
||||
|
|
@ -593,31 +591,41 @@ if __name__ == "__main__":
|
|||
query_engine=query_engine,
|
||||
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
|
||||
max_len=88,
|
||||
batch_query_size=500,
|
||||
batch_query_size=300,
|
||||
shuffle=True,
|
||||
shuffle_buffer_size=1000,
|
||||
shuffle_buffer_size=4000,
|
||||
)
|
||||
|
||||
logger.info("数据集初始化")
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=256,
|
||||
batch_size=1024,
|
||||
num_workers=15,
|
||||
worker_init_fn=worker_init_fn,
|
||||
# pin_memory=True,
|
||||
collate_fn=custom_collate,
|
||||
prefetch_factor=32,
|
||||
collate_fn=custom_collate_with_txt,
|
||||
prefetch_factor=8,
|
||||
persistent_workers=True,
|
||||
shuffle=False, # 数据集内部已实现打乱
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
"""import cProfile
|
||||
|
||||
def profile_func(dataloader):
|
||||
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
||||
if i >= 3000:
|
||||
break
|
||||
return
|
||||
|
||||
|
||||
cProfile.run('profile_func(dataloader)')
|
||||
|
||||
"""
|
||||
# 测试数据集
|
||||
try:
|
||||
iterator = iter(dataset)
|
||||
logger.info("测试数据集")
|
||||
for i, _ in tqdm(enumerate(dataloader), total=200):
|
||||
if i >= 200:
|
||||
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
||||
if i >= 3000:
|
||||
break
|
||||
"""
|
||||
print(f"Sample {i+1}:")
|
||||
|
|
@ -629,3 +637,4 @@ if __name__ == "__main__":
|
|||
"""
|
||||
except StopIteration:
|
||||
print("数据集为空")
|
||||
|
||||
|
|
|
|||
|
|
@ -400,6 +400,12 @@ class QueryEngine:
|
|||
|
||||
return matches[:limit] if limit > 0 else matches
|
||||
|
||||
def is_chinese_char(self, char: str) -> bool:
    """Report whether ``char`` counts as a Chinese character for this engine.

    The check is a membership test against ``self._char_to_ids``: a
    character is accepted exactly when it is a key of that table.

    Args:
        char: The character to look up.

    Returns:
        True if ``char`` is present in ``self._char_to_ids``, else False.
    """
    known_chars = self._char_to_ids
    return char in known_chars
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取系统统计信息
|
||||
|
|
|
|||
Loading…
Reference in New Issue