调整数据集打乱缓冲区大小并优化样本处理逻辑
This commit is contained in:
parent
54ac5af876
commit
92b12ef703
|
|
@ -1,9 +1,6 @@
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from importlib.resources import files
|
from typing import Any, Dict, List, Tuple, Optional
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -39,7 +36,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
batch_query_size: int = 1000,
|
batch_query_size: int = 1000,
|
||||||
# 打乱参数
|
# 打乱参数
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
shuffle_buffer_size: int = 100,
|
shuffle_buffer_size: int = 10000,
|
||||||
# 削峰填谷参数
|
# 削峰填谷参数
|
||||||
max_freq: int = 434748359, # "的"的频率
|
max_freq: int = 434748359, # "的"的频率
|
||||||
min_freq: int = 109, # "蓚"的频率
|
min_freq: int = 109, # "蓚"的频率
|
||||||
|
|
@ -47,7 +44,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||||
py_group_json_file: Optional[Dict[str, int]] = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化数据集
|
初始化数据集
|
||||||
|
|
@ -415,7 +411,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if not char_info:
|
if not char_info:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.info(f"获取字符信息: {char_info}")
|
|
||||||
# 削峰填谷调整
|
# 削峰填谷调整
|
||||||
adjust_factor = self.adjust_frequency(char_info["freq"])
|
adjust_factor = self.adjust_frequency(char_info["freq"])
|
||||||
if adjust_factor <= 0:
|
if adjust_factor <= 0:
|
||||||
|
|
@ -446,7 +441,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
"char": char,
|
"char": char,
|
||||||
"freq": char_info["freq"],
|
"freq": char_info["freq"],
|
||||||
"pg": torch.tensor(
|
"pg": torch.tensor(
|
||||||
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
|
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -485,7 +480,8 @@ class PinyinInputDataset(IterableDataset):
|
||||||
seed = base_seed + worker_id
|
seed = base_seed + worker_id
|
||||||
random.seed(seed % (2**32))
|
random.seed(seed % (2**32))
|
||||||
np.random.seed(seed % (2**32))
|
np.random.seed(seed % (2**32))
|
||||||
|
|
||||||
|
batch_samples = []
|
||||||
for item in self.dataset:
|
for item in self.dataset:
|
||||||
text = item.get(self.text_field, "")
|
text = item.get(self.text_field, "")
|
||||||
if not text:
|
if not text:
|
||||||
|
|
@ -531,18 +527,21 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
# 达到批量大小时处理
|
# 达到批量大小时处理
|
||||||
if len(char_pinyin_batch) >= self.batch_query_size:
|
if len(char_pinyin_batch) >= self.batch_query_size:
|
||||||
batch_samples = self._process_batch(
|
batch_samples += self._process_batch(
|
||||||
char_pinyin_batch, char_positions, text
|
char_pinyin_batch, char_positions, text
|
||||||
)
|
)
|
||||||
yield from self._shuffle_and_yield(batch_samples)
|
|
||||||
char_pinyin_batch = []
|
char_pinyin_batch = []
|
||||||
char_positions = []
|
char_positions = []
|
||||||
|
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||||
|
# logger.info(f"批量处理完成,开始打乱数据并生成样本, len(batch_samples): {len(batch_samples)}")
|
||||||
|
yield from self._shuffle_and_yield(batch_samples)
|
||||||
|
batch_samples = []
|
||||||
# 处理剩余的字符
|
# 处理剩余的字符
|
||||||
if char_pinyin_batch:
|
if char_pinyin_batch:
|
||||||
batch_samples = self._process_batch(
|
batch_samples += self._process_batch(
|
||||||
char_pinyin_batch, char_positions, text
|
char_pinyin_batch, char_positions, text
|
||||||
)
|
)
|
||||||
yield from self._shuffle_and_yield(batch_samples)
|
yield from self._shuffle_and_yield(batch_samples)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -582,6 +581,7 @@ def custom_collate_with_txt(batch):
|
||||||
"char": [item["char"] for item in batch],
|
"char": [item["char"] for item in batch],
|
||||||
"txt": [item["txt"] for item in batch],
|
"txt": [item["txt"] for item in batch],
|
||||||
"py": [item["py"] for item in batch],
|
"py": [item["py"] for item in batch],
|
||||||
|
"pg": torch.cat([item["pg"] for item in batch]),
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
@ -602,71 +602,7 @@ def custom_collate(batch):
|
||||||
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
|
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
|
||||||
},
|
},
|
||||||
"char_id": torch.cat([item["char_id"] for item in batch]),
|
"char_id": torch.cat([item["char_id"] for item in batch]),
|
||||||
"py": [item["py"] for item in batch],
|
"pg": torch.cat([item["pg"] for item in batch]),
|
||||||
# "py_group_id": [item["py"] for item in batch],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# 使用示例
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from query import QueryEngine
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# 初始化查询引擎
|
|
||||||
query_engine = QueryEngine()
|
|
||||||
query_engine.load()
|
|
||||||
|
|
||||||
# 创建数据集
|
|
||||||
dataset = PinyinInputDataset(
|
|
||||||
data_dir="/home/songsenand/Data/corpus/CCI-Data/",
|
|
||||||
query_engine=query_engine,
|
|
||||||
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
|
|
||||||
max_len=88,
|
|
||||||
batch_query_size=300,
|
|
||||||
shuffle=True,
|
|
||||||
shuffle_buffer_size=4000,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("数据集初始化")
|
|
||||||
dataloader = DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_size=1024,
|
|
||||||
num_workers=15,
|
|
||||||
worker_init_fn=worker_init_fn,
|
|
||||||
pin_memory=True if torch.cuda.is_available() else False,
|
|
||||||
collate_fn=custom_collate_with_txt,
|
|
||||||
prefetch_factor=8,
|
|
||||||
persistent_workers=True,
|
|
||||||
shuffle=False, # 数据集内部已实现打乱
|
|
||||||
)
|
|
||||||
|
|
||||||
"""import cProfile
|
|
||||||
|
|
||||||
def profile_func(dataloader):
|
|
||||||
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
|
||||||
if i >= 3000:
|
|
||||||
break
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
cProfile.run('profile_func(dataloader)')
|
|
||||||
|
|
||||||
"""
|
|
||||||
# 测试数据集
|
|
||||||
try:
|
|
||||||
logger.info("测试数据集")
|
|
||||||
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
|
||||||
if i >= 3000:
|
|
||||||
break
|
|
||||||
"""
|
|
||||||
print(f"Sample {i+1}:")
|
|
||||||
print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")
|
|
||||||
print(f" Pinyin: {sample['py']}")
|
|
||||||
print(f" Context length: {len(sample['txt'])}")
|
|
||||||
print(f" Hint shape: {sample['hint']['input_ids'].shape}")
|
|
||||||
print()
|
|
||||||
"""
|
|
||||||
except StopIteration:
|
|
||||||
print("数据集为空")
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue