调整数据集打乱缓冲区大小并优化样本处理逻辑

This commit is contained in:
songsenand 2026-02-13 11:05:42 +08:00
parent 54ac5af876
commit 92b12ef703
1 changed file with 14 additions and 78 deletions

View File

@ -1,9 +1,6 @@
import json
import os import os
import random import random
from importlib.resources import files from typing import Any, Dict, List, Tuple, Optional
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
@ -39,7 +36,7 @@ class PinyinInputDataset(IterableDataset):
batch_query_size: int = 1000, batch_query_size: int = 1000,
# 打乱参数 # 打乱参数
shuffle: bool = True, shuffle: bool = True,
shuffle_buffer_size: int = 100, shuffle_buffer_size: int = 10000,
# 削峰填谷参数 # 削峰填谷参数
max_freq: int = 434748359, # "的"的频率 max_freq: int = 434748359, # "的"的频率
min_freq: int = 109, # "蓚"的频率 min_freq: int = 109, # "蓚"的频率
@ -47,7 +44,6 @@ class PinyinInputDataset(IterableDataset):
repeat_end_freq: int = 10000, # 开始重复的阈值 repeat_end_freq: int = 10000, # 开始重复的阈值
max_drop_prob: float = 0.8, # 最大丢弃概率 max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望 max_repeat_expect: float = 50.0, # 最大重复期望
py_group_json_file: Optional[Dict[str, int]] = None,
): ):
""" """
初始化数据集 初始化数据集
@ -415,7 +411,6 @@ class PinyinInputDataset(IterableDataset):
if not char_info: if not char_info:
continue continue
logger.info(f"获取字符信息: {char_info}")
# 削峰填谷调整 # 削峰填谷调整
adjust_factor = self.adjust_frequency(char_info["freq"]) adjust_factor = self.adjust_frequency(char_info["freq"])
if adjust_factor <= 0: if adjust_factor <= 0:
@ -446,7 +441,7 @@ class PinyinInputDataset(IterableDataset):
"char": char, "char": char,
"freq": char_info["freq"], "freq": char_info["freq"],
"pg": torch.tensor( "pg": torch.tensor(
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8 [self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
), ),
} }
@ -486,6 +481,7 @@ class PinyinInputDataset(IterableDataset):
random.seed(seed % (2**32)) random.seed(seed % (2**32))
np.random.seed(seed % (2**32)) np.random.seed(seed % (2**32))
batch_samples = []
for item in self.dataset: for item in self.dataset:
text = item.get(self.text_field, "") text = item.get(self.text_field, "")
if not text: if not text:
@ -531,15 +527,18 @@ class PinyinInputDataset(IterableDataset):
# 达到批量大小时处理 # 达到批量大小时处理
if len(char_pinyin_batch) >= self.batch_query_size: if len(char_pinyin_batch) >= self.batch_query_size:
batch_samples = self._process_batch( batch_samples += self._process_batch(
char_pinyin_batch, char_positions, text char_pinyin_batch, char_positions, text
) )
yield from self._shuffle_and_yield(batch_samples)
char_pinyin_batch = [] char_pinyin_batch = []
char_positions = [] char_positions = []
if len(batch_samples) >= self.shuffle_buffer_size:
# logger.info(f"批量处理完成,开始打乱数据并生成样本, len(batch_samples): {len(batch_samples)}")
yield from self._shuffle_and_yield(batch_samples)
batch_samples = []
# 处理剩余的字符 # 处理剩余的字符
if char_pinyin_batch: if char_pinyin_batch:
batch_samples = self._process_batch( batch_samples += self._process_batch(
char_pinyin_batch, char_positions, text char_pinyin_batch, char_positions, text
) )
yield from self._shuffle_and_yield(batch_samples) yield from self._shuffle_and_yield(batch_samples)
@ -582,6 +581,7 @@ def custom_collate_with_txt(batch):
"char": [item["char"] for item in batch], "char": [item["char"] for item in batch],
"txt": [item["txt"] for item in batch], "txt": [item["txt"] for item in batch],
"py": [item["py"] for item in batch], "py": [item["py"] for item in batch],
"pg": torch.cat([item["pg"] for item in batch]),
} }
return result return result
@ -602,71 +602,7 @@ def custom_collate(batch):
"attention_mask": torch.cat([h["attention_mask"] for h in hints]), "attention_mask": torch.cat([h["attention_mask"] for h in hints]),
}, },
"char_id": torch.cat([item["char_id"] for item in batch]), "char_id": torch.cat([item["char_id"] for item in batch]),
"py": [item["py"] for item in batch], "pg": torch.cat([item["pg"] for item in batch]),
# "py_group_id": [item["py"] for item in batch],
} }
return result return result
# Usage example: build the dataset, wrap it in a DataLoader and stream a
# bounded number of batches through it as a smoke test.
if __name__ == "__main__":
    from itertools import islice

    from query import QueryEngine
    from tqdm import tqdm

    # Upper bound on the number of batches pulled during the smoke test.
    max_batches = 3000

    # Initialise the query engine.
    query_engine = QueryEngine()
    query_engine.load()

    # Build the dataset.
    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/Data/corpus/CCI-Data/",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
        max_len=88,
        batch_query_size=300,
        shuffle=True,
        shuffle_buffer_size=4000,
    )
    logger.info("数据集初始化")

    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        num_workers=15,
        worker_init_fn=worker_init_fn,
        pin_memory=torch.cuda.is_available(),
        collate_fn=custom_collate_with_txt,
        prefetch_factor=8,
        persistent_workers=True,
        shuffle=False,  # the dataset already shuffles internally
    )

    # Exercise the dataset.
    # A ``for`` loop absorbs the iterator's StopIteration itself, so wrapping
    # the loop in ``except StopIteration`` can never detect an empty dataset.
    # Count the batches instead and report emptiness explicitly.
    # ``islice`` stops after ``max_batches`` without fetching an extra batch
    # that would only be thrown away.
    # To profile loading, wrap this loop in a function and run it under
    # ``cProfile.run``; to inspect batches, print the fields of ``_batch``.
    logger.info("测试数据集")
    batches_seen = 0
    for _batch in tqdm(islice(dataloader, max_batches), total=max_batches):
        batches_seen += 1

    if batches_seen == 0:
        print("数据集为空")