优化打乱逻辑并提升数据处理效率
This commit is contained in:
parent
1bdbbe284c
commit
9b813732fd
|
|
@ -1,6 +1,6 @@
|
||||||
import random
|
import random
|
||||||
import re
|
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -11,6 +11,8 @@ from pypinyin import lazy_pinyin
|
||||||
from torch.utils.data import DataLoader, IterableDataset
|
from torch.utils.data import DataLoader, IterableDataset
|
||||||
|
|
||||||
|
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
class PinyinInputDataset(IterableDataset):
|
class PinyinInputDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
拼音输入法模拟数据集
|
拼音输入法模拟数据集
|
||||||
|
|
@ -70,7 +72,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 打乱相关参数
|
# 打乱相关参数
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
self.shuffle_buffer_size = shuffle_buffer_size
|
self.shuffle_buffer_size = shuffle_buffer_size
|
||||||
self.shuffle_buffer = []
|
|
||||||
|
|
||||||
# 削峰填谷参数
|
# 削峰填谷参数
|
||||||
self.max_freq = max_freq
|
self.max_freq = max_freq
|
||||||
|
|
@ -87,18 +88,12 @@ class PinyinInputDataset(IterableDataset):
|
||||||
stats = query_engine.get_statistics()
|
stats = query_engine.get_statistics()
|
||||||
self.total_chars = stats.get("valid_input_character_count", 0)
|
self.total_chars = stats.get("valid_input_character_count", 0)
|
||||||
|
|
||||||
# 汉字正则表达式
|
|
||||||
self.chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
|
|
||||||
|
|
||||||
# 缓存字典
|
# 缓存字典
|
||||||
self.char_info_cache = {}
|
self.char_info_cache = {}
|
||||||
|
|
||||||
# 加载数据集
|
# 加载数据集
|
||||||
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
||||||
|
|
||||||
def is_chinese_char(self, char: str) -> bool:
|
|
||||||
"""判断是否为中文字符"""
|
|
||||||
return bool(self.chinese_pattern.match(char))
|
|
||||||
|
|
||||||
def get_next_chinese_chars(
|
def get_next_chinese_chars(
|
||||||
self,
|
self,
|
||||||
|
|
@ -126,7 +121,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
break
|
break
|
||||||
|
|
||||||
char = text[i]
|
char = text[i]
|
||||||
if self.is_chinese_char(char):
|
if self.query_engine.is_chinese_char(char):
|
||||||
# 获取拼音(注意:这里需要确保拼音列表长度与text一致)
|
# 获取拼音(注意:这里需要确保拼音列表长度与text一致)
|
||||||
try:
|
try:
|
||||||
# 重新计算整个text的拼音可能效率低,但确保准确
|
# 重新计算整个text的拼音可能效率低,但确保准确
|
||||||
|
|
@ -430,21 +425,16 @@ class PinyinInputDataset(IterableDataset):
|
||||||
return batch_samples
|
return batch_samples
|
||||||
|
|
||||||
def _shuffle_and_yield(self, batch_samples):
|
def _shuffle_and_yield(self, batch_samples):
|
||||||
"""打乱并yield样本"""
|
"""优化打乱逻辑"""
|
||||||
if not self.shuffle:
|
if not self.shuffle:
|
||||||
for sample in batch_samples:
|
yield from batch_samples
|
||||||
yield sample
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# 添加到打乱缓冲区
|
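# Shuffle within this batch only, via a numpy permutation instead of random.shuffle (no cross-batch shuffle buffer is kept any more)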
|
||||||
self.shuffle_buffer.extend(batch_samples)
|
if batch_samples:
|
||||||
|
indices = np.random.permutation(len(batch_samples))
|
||||||
# 如果缓冲区达到指定大小,打乱并输出
|
for idx in indices:
|
||||||
if len(self.shuffle_buffer) >= self.shuffle_buffer_size:
|
yield batch_samples[idx]
|
||||||
random.shuffle(self.shuffle_buffer)
|
|
||||||
for sample in self.shuffle_buffer:
|
|
||||||
yield sample
|
|
||||||
self.shuffle_buffer = []
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -464,9 +454,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
random.seed(seed % (2**32))
|
random.seed(seed % (2**32))
|
||||||
np.random.seed(seed % (2**32))
|
np.random.seed(seed % (2**32))
|
||||||
|
|
||||||
# 重置打乱缓冲区
|
|
||||||
self.shuffle_buffer = []
|
|
||||||
|
|
||||||
for item in self.dataset:
|
for item in self.dataset:
|
||||||
text = item.get(self.text_field, "")
|
text = item.get(self.text_field, "")
|
||||||
if not text:
|
if not text:
|
||||||
|
|
@ -474,13 +461,19 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
# 转换为拼音列表
|
# 转换为拼音列表
|
||||||
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||||||
|
# Collect the positions of the Chinese characters first, so the loop below visits only those indices
|
||||||
|
chinese_positions = [
|
||||||
|
i for i, char in enumerate(text)
|
||||||
|
if self.query_engine.is_chinese_char(char)
|
||||||
|
]
|
||||||
|
|
||||||
# 批量收集需要查询的字符信息
|
# 批量收集需要查询的字符信息
|
||||||
char_pinyin_batch = []
|
char_pinyin_batch = []
|
||||||
char_positions = [] # 保存字符位置和上下文信息
|
char_positions = []
|
||||||
# 遍历文本中的每个字符
|
|
||||||
for i, (char, py) in enumerate(zip(text, pinyin_list)):
|
for i in chinese_positions: # 只遍历中文字符位置
|
||||||
if not self.is_chinese_char(char):
|
char = text[i]
|
||||||
continue
|
py = pinyin_list[i]
|
||||||
|
|
||||||
# 获取后续最多3个中文字符的拼音
|
# 获取后续最多3个中文字符的拼音
|
||||||
next_chars = self.get_next_chinese_chars(
|
next_chars = self.get_next_chinese_chars(
|
||||||
|
|
@ -518,12 +511,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
)
|
)
|
||||||
yield from self._shuffle_and_yield(batch_samples)
|
yield from self._shuffle_and_yield(batch_samples)
|
||||||
|
|
||||||
# 清空缓冲区(处理完所有数据后)
|
|
||||||
if self.shuffle_buffer:
|
|
||||||
random.shuffle(self.shuffle_buffer)
|
|
||||||
for sample in self.shuffle_buffer:
|
|
||||||
yield sample
|
|
||||||
self.shuffle_buffer = []
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -545,7 +532,7 @@ def worker_init_fn(worker_id):
|
||||||
torch.manual_seed(seed % (2**32))
|
torch.manual_seed(seed % (2**32))
|
||||||
|
|
||||||
|
|
||||||
def custom_collate(batch):
|
def custom_collate_with_txt(batch):
|
||||||
"""自定义批处理函数"""
|
"""自定义批处理函数"""
|
||||||
if not batch:
|
if not batch:
|
||||||
return {}
|
return {}
|
||||||
|
|
@ -565,15 +552,26 @@ def custom_collate(batch):
|
||||||
"py": [item["py"] for item in batch],
|
"py": [item["py"] for item in batch],
|
||||||
}
|
}
|
||||||
|
|
||||||
# 如果存在token_type_ids则添加
|
return result
|
||||||
if "token_type_ids" in hints[0]:
|
|
||||||
result["hint"]["token_type_ids"] = torch.cat(
|
|
||||||
[h["token_type_ids"] for h in hints]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果存在freq则添加
|
|
||||||
if "freq" in batch[0]:
|
def custom_collate(batch):
|
||||||
result["freq"] = torch.tensor([item["freq"] for item in batch])
|
"""自定义批处理函数"""
|
||||||
|
if not batch:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# 处理hint字段
|
||||||
|
hints = [item["hint"] for item in batch]
|
||||||
|
|
||||||
|
# 合并所有张量字段
|
||||||
|
result = {
|
||||||
|
"hint": {
|
||||||
|
"input_ids": torch.cat([h["input_ids"] for h in hints]),
|
||||||
|
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
|
||||||
|
},
|
||||||
|
"char_id": torch.cat([item["char_id"] for item in batch]),
|
||||||
|
"py": [item["py"] for item in batch],
|
||||||
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -593,31 +591,41 @@ if __name__ == "__main__":
|
||||||
query_engine=query_engine,
|
query_engine=query_engine,
|
||||||
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
|
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
|
||||||
max_len=88,
|
max_len=88,
|
||||||
batch_query_size=500,
|
batch_query_size=300,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
shuffle_buffer_size=1000,
|
shuffle_buffer_size=4000,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("数据集初始化")
|
logger.info("数据集初始化")
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=256,
|
batch_size=1024,
|
||||||
num_workers=15,
|
num_workers=15,
|
||||||
worker_init_fn=worker_init_fn,
|
worker_init_fn=worker_init_fn,
|
||||||
# pin_memory=True,
|
# pin_memory=True,
|
||||||
collate_fn=custom_collate,
|
collate_fn=custom_collate_with_txt,
|
||||||
prefetch_factor=32,
|
prefetch_factor=8,
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
shuffle=False, # 数据集内部已实现打乱
|
shuffle=False, # 数据集内部已实现打乱
|
||||||
timeout=60,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""import cProfile
|
||||||
|
|
||||||
|
def profile_func(dataloader):
|
||||||
|
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
||||||
|
if i >= 3000:
|
||||||
|
break
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
cProfile.run('profile_func(dataloader)')
|
||||||
|
|
||||||
|
"""
|
||||||
# 测试数据集
|
# 测试数据集
|
||||||
try:
|
try:
|
||||||
iterator = iter(dataset)
|
|
||||||
logger.info("测试数据集")
|
logger.info("测试数据集")
|
||||||
for i, _ in tqdm(enumerate(dataloader), total=200):
|
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
||||||
if i >= 200:
|
if i >= 3000:
|
||||||
break
|
break
|
||||||
"""
|
"""
|
||||||
print(f"Sample {i+1}:")
|
print(f"Sample {i+1}:")
|
||||||
|
|
@ -629,3 +637,4 @@ if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
print("数据集为空")
|
print("数据集为空")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -400,6 +400,12 @@ class QueryEngine:
|
||||||
|
|
||||||
return matches[:limit] if limit > 0 else matches
|
return matches[:limit] if limit > 0 else matches
|
||||||
|
|
||||||
|
def is_chinese_char(self, char: str) -> bool:
|
||||||
|
"""
|
||||||
|
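Return True if `char` is a key of `self._char_to_ids`; this membership test is what the engine treats as "is a Chinese character".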
|
||||||
|
"""
|
||||||
|
return char in self._char_to_ids
|
||||||
|
|
||||||
def get_statistics(self) -> Dict[str, Any]:
|
def get_statistics(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取系统统计信息
|
获取系统统计信息
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue