调整数据集打乱缓冲区大小并优化样本处理逻辑
This commit is contained in:
parent
54ac5af876
commit
92b12ef703
|
|
@ -1,9 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -39,7 +36,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
batch_query_size: int = 1000,
|
||||
# 打乱参数
|
||||
shuffle: bool = True,
|
||||
shuffle_buffer_size: int = 100,
|
||||
shuffle_buffer_size: int = 10000,
|
||||
# 削峰填谷参数
|
||||
max_freq: int = 434748359, # "的"的频率
|
||||
min_freq: int = 109, # "蓚"的频率
|
||||
|
|
@ -47,7 +44,6 @@ class PinyinInputDataset(IterableDataset):
|
|||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||
py_group_json_file: Optional[Dict[str, int]] = None,
|
||||
):
|
||||
"""
|
||||
初始化数据集
|
||||
|
|
@ -415,7 +411,6 @@ class PinyinInputDataset(IterableDataset):
|
|||
if not char_info:
|
||||
continue
|
||||
|
||||
logger.info(f"获取字符信息: {char_info}")
|
||||
# 削峰填谷调整
|
||||
adjust_factor = self.adjust_frequency(char_info["freq"])
|
||||
if adjust_factor <= 0:
|
||||
|
|
@ -446,7 +441,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
"char": char,
|
||||
"freq": char_info["freq"],
|
||||
"pg": torch.tensor(
|
||||
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
|
||||
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -486,6 +481,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
random.seed(seed % (2**32))
|
||||
np.random.seed(seed % (2**32))
|
||||
|
||||
batch_samples = []
|
||||
for item in self.dataset:
|
||||
text = item.get(self.text_field, "")
|
||||
if not text:
|
||||
|
|
@ -531,15 +527,18 @@ class PinyinInputDataset(IterableDataset):
|
|||
|
||||
# 达到批量大小时处理
|
||||
if len(char_pinyin_batch) >= self.batch_query_size:
|
||||
batch_samples = self._process_batch(
|
||||
batch_samples += self._process_batch(
|
||||
char_pinyin_batch, char_positions, text
|
||||
)
|
||||
yield from self._shuffle_and_yield(batch_samples)
|
||||
char_pinyin_batch = []
|
||||
char_positions = []
|
||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||
# logger.info(f"批量处理完成,开始打乱数据并生成样本, len(batch_samples): {len(batch_samples)}")
|
||||
yield from self._shuffle_and_yield(batch_samples)
|
||||
batch_samples = []
|
||||
# 处理剩余的字符
|
||||
if char_pinyin_batch:
|
||||
batch_samples = self._process_batch(
|
||||
batch_samples += self._process_batch(
|
||||
char_pinyin_batch, char_positions, text
|
||||
)
|
||||
yield from self._shuffle_and_yield(batch_samples)
|
||||
|
|
@ -582,6 +581,7 @@ def custom_collate_with_txt(batch):
|
|||
"char": [item["char"] for item in batch],
|
||||
"txt": [item["txt"] for item in batch],
|
||||
"py": [item["py"] for item in batch],
|
||||
"pg": torch.cat([item["pg"] for item in batch]),
|
||||
}
|
||||
|
||||
return result
|
||||
|
|
@ -602,71 +602,7 @@ def custom_collate(batch):
|
|||
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
|
||||
},
|
||||
"char_id": torch.cat([item["char_id"] for item in batch]),
|
||||
"py": [item["py"] for item in batch],
|
||||
# "py_group_id": [item["py"] for item in batch],
|
||||
"pg": torch.cat([item["pg"] for item in batch]),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
from query import QueryEngine
|
||||
from tqdm import tqdm
|
||||
|
||||
# 初始化查询引擎
|
||||
query_engine = QueryEngine()
|
||||
query_engine.load()
|
||||
|
||||
# 创建数据集
|
||||
dataset = PinyinInputDataset(
|
||||
data_dir="/home/songsenand/Data/corpus/CCI-Data/",
|
||||
query_engine=query_engine,
|
||||
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
|
||||
max_len=88,
|
||||
batch_query_size=300,
|
||||
shuffle=True,
|
||||
shuffle_buffer_size=4000,
|
||||
)
|
||||
|
||||
logger.info("数据集初始化")
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1024,
|
||||
num_workers=15,
|
||||
worker_init_fn=worker_init_fn,
|
||||
pin_memory=True if torch.cuda.is_available() else False,
|
||||
collate_fn=custom_collate_with_txt,
|
||||
prefetch_factor=8,
|
||||
persistent_workers=True,
|
||||
shuffle=False, # 数据集内部已实现打乱
|
||||
)
|
||||
|
||||
"""import cProfile
|
||||
|
||||
def profile_func(dataloader):
|
||||
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
||||
if i >= 3000:
|
||||
break
|
||||
return
|
||||
|
||||
|
||||
cProfile.run('profile_func(dataloader)')
|
||||
|
||||
"""
|
||||
# 测试数据集
|
||||
try:
|
||||
logger.info("测试数据集")
|
||||
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
||||
if i >= 3000:
|
||||
break
|
||||
"""
|
||||
print(f"Sample {i+1}:")
|
||||
print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")
|
||||
print(f" Pinyin: {sample['py']}")
|
||||
print(f" Context length: {len(sample['txt'])}")
|
||||
print(f" Hint shape: {sample['hint']['input_ids'].shape}")
|
||||
print()
|
||||
"""
|
||||
except StopIteration:
|
||||
print("数据集为空")
|
||||
|
|
|
|||
Loading…
Reference in New Issue