调整数据集打乱缓冲区大小并优化样本处理逻辑

This commit is contained in:
songsenand 2026-02-13 11:05:42 +08:00
parent 54ac5af876
commit 92b12ef703
1 changed file with 14 additions and 78 deletions

View File

@ -1,9 +1,6 @@
import json
import os
import random
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Optional
import numpy as np
import torch
@ -39,7 +36,7 @@ class PinyinInputDataset(IterableDataset):
batch_query_size: int = 1000,
# 打乱参数
shuffle: bool = True,
shuffle_buffer_size: int = 100,
shuffle_buffer_size: int = 10000,
# 削峰填谷参数
max_freq: int = 434748359, # "的"的频率
min_freq: int = 109, # "蓚"的频率
@ -47,7 +44,6 @@ class PinyinInputDataset(IterableDataset):
repeat_end_freq: int = 10000, # 开始重复的阈值
max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望
py_group_json_file: Optional[Dict[str, int]] = None,
):
"""
初始化数据集
@ -415,7 +411,6 @@ class PinyinInputDataset(IterableDataset):
if not char_info:
continue
logger.info(f"获取字符信息: {char_info}")
# 削峰填谷调整
adjust_factor = self.adjust_frequency(char_info["freq"])
if adjust_factor <= 0:
@ -446,7 +441,7 @@ class PinyinInputDataset(IterableDataset):
"char": char,
"freq": char_info["freq"],
"pg": torch.tensor(
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
),
}
@ -485,7 +480,8 @@ class PinyinInputDataset(IterableDataset):
seed = base_seed + worker_id
random.seed(seed % (2**32))
np.random.seed(seed % (2**32))
batch_samples = []
for item in self.dataset:
text = item.get(self.text_field, "")
if not text:
@ -531,18 +527,21 @@ class PinyinInputDataset(IterableDataset):
# 达到批量大小时处理
if len(char_pinyin_batch) >= self.batch_query_size:
batch_samples = self._process_batch(
batch_samples += self._process_batch(
char_pinyin_batch, char_positions, text
)
yield from self._shuffle_and_yield(batch_samples)
char_pinyin_batch = []
char_positions = []
if len(batch_samples) >= self.shuffle_buffer_size:
# logger.info(f"批量处理完成,开始打乱数据并生成样本, len(batch_samples): {len(batch_samples)}")
yield from self._shuffle_and_yield(batch_samples)
batch_samples = []
# 处理剩余的字符
if char_pinyin_batch:
batch_samples = self._process_batch(
batch_samples += self._process_batch(
char_pinyin_batch, char_positions, text
)
yield from self._shuffle_and_yield(batch_samples)
yield from self._shuffle_and_yield(batch_samples)
def __len__(self):
"""
@ -582,6 +581,7 @@ def custom_collate_with_txt(batch):
"char": [item["char"] for item in batch],
"txt": [item["txt"] for item in batch],
"py": [item["py"] for item in batch],
"pg": torch.cat([item["pg"] for item in batch]),
}
return result
@ -602,71 +602,7 @@ def custom_collate(batch):
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
},
"char_id": torch.cat([item["char_id"] for item in batch]),
"py": [item["py"] for item in batch],
# "py_group_id": [item["py"] for item in batch],
"pg": torch.cat([item["pg"] for item in batch]),
}
return result
# Usage example: smoke-test the dataset + DataLoader pipeline end to end.
if __name__ == "__main__":
    from query import QueryEngine
    from tqdm import tqdm

    # Initialize the query engine used for character/pinyin lookups.
    query_engine = QueryEngine()
    query_engine.load()

    # Build the streaming dataset. Shuffling happens inside the dataset via
    # its internal shuffle buffer, so the DataLoader must NOT shuffle again.
    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/Data/corpus/CCI-Data/",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
        max_len=88,
        batch_query_size=300,
        shuffle=True,
        shuffle_buffer_size=4000,
    )
    logger.info("数据集初始化")

    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        num_workers=15,
        worker_init_fn=worker_init_fn,
        # Redundant ternary removed: is_available() already returns a bool.
        pin_memory=torch.cuda.is_available(),
        collate_fn=custom_collate_with_txt,
        prefetch_factor=8,
        persistent_workers=True,
        shuffle=False,  # shuffling is already implemented inside the dataset
    )

    # Iterate a bounded number of batches as a sanity check.
    logger.info("测试数据集")
    num_batches = 0
    for i, sample in tqdm(enumerate(dataloader), total=3000):
        num_batches += 1
        if i >= 3000:
            break

    # BUGFIX: the original wrapped this loop in `try/except StopIteration`,
    # but a for-loop consumes StopIteration internally, so the handler was
    # unreachable dead code. Detect an empty dataset explicitly instead.
    if num_batches == 0:
        print("数据集为空")