diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index a71f8fb..0b86f4e 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -1,9 +1,6 @@ -import json import os import random -from importlib.resources import files -from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Optional import numpy as np import torch @@ -39,7 +36,7 @@ class PinyinInputDataset(IterableDataset): batch_query_size: int = 1000, # 打乱参数 shuffle: bool = True, - shuffle_buffer_size: int = 100, + shuffle_buffer_size: int = 10000, # 削峰填谷参数 max_freq: int = 434748359, # "的"的频率 min_freq: int = 109, # "蓚"的频率 @@ -47,7 +44,6 @@ class PinyinInputDataset(IterableDataset): repeat_end_freq: int = 10000, # 开始重复的阈值 max_drop_prob: float = 0.8, # 最大丢弃概率 max_repeat_expect: float = 50.0, # 最大重复期望 - py_group_json_file: Optional[Dict[str, int]] = None, ): """ 初始化数据集 @@ -415,7 +411,6 @@ class PinyinInputDataset(IterableDataset): if not char_info: continue - logger.info(f"获取字符信息: {char_info}") # 削峰填谷调整 adjust_factor = self.adjust_frequency(char_info["freq"]) if adjust_factor <= 0: @@ -446,7 +441,7 @@ class PinyinInputDataset(IterableDataset): "char": char, "freq": char_info["freq"], "pg": torch.tensor( - self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8 + [self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8] ), } @@ -485,7 +480,8 @@ class PinyinInputDataset(IterableDataset): seed = base_seed + worker_id random.seed(seed % (2**32)) np.random.seed(seed % (2**32)) - + + batch_samples = [] for item in self.dataset: text = item.get(self.text_field, "") if not text: @@ -531,18 +527,21 @@ class PinyinInputDataset(IterableDataset): # 达到批量大小时处理 if len(char_pinyin_batch) >= self.batch_query_size: - batch_samples = self._process_batch( + batch_samples += self._process_batch( char_pinyin_batch, char_positions, text ) - yield from self._shuffle_and_yield(batch_samples) char_pinyin_batch = [] char_positions = [] + if len(batch_samples) >= self.shuffle_buffer_size: + # logger.info(f"批量处理完成,开始打乱数据并生成样本, len(batch_samples): {len(batch_samples)}") + yield from self._shuffle_and_yield(batch_samples) + batch_samples = [] # 处理剩余的字符 if char_pinyin_batch: - batch_samples = self._process_batch( + batch_samples += self._process_batch( char_pinyin_batch, char_positions, text ) - yield from self._shuffle_and_yield(batch_samples) + yield from self._shuffle_and_yield(batch_samples) def __len__(self): """ @@ -582,6 +581,7 @@ def custom_collate_with_txt(batch): "char": [item["char"] for item in batch], "txt": [item["txt"] for item in batch], "py": [item["py"] for item in batch], + "pg": torch.cat([item["pg"] for item in batch]), } return result @@ -602,71 +602,7 @@ def custom_collate(batch): "attention_mask": torch.cat([h["attention_mask"] for h in hints]), }, "char_id": torch.cat([item["char_id"] for item in batch]), - "py": [item["py"] for item in batch], - # "py_group_id": [item["py"] for item in batch], + "pg": torch.cat([item["pg"] for item in batch]), } return result - - -# 使用示例 -if __name__ == "__main__": - from query import QueryEngine - from tqdm import tqdm - - # 初始化查询引擎 - query_engine = QueryEngine() - query_engine.load() - - # 创建数据集 - dataset = PinyinInputDataset( - data_dir="/home/songsenand/Data/corpus/CCI-Data/", - query_engine=query_engine, - tokenizer_name="iic/nlp_structbert_backbone_tiny_std", - max_len=88, - batch_query_size=300, - shuffle=True, - shuffle_buffer_size=4000, - ) - - logger.info("数据集初始化") - dataloader = DataLoader( - dataset, - batch_size=1024, - num_workers=15, - worker_init_fn=worker_init_fn, - pin_memory=True if torch.cuda.is_available() else False, - collate_fn=custom_collate_with_txt, - prefetch_factor=8, - persistent_workers=True, - shuffle=False, # 数据集内部已实现打乱 - ) - - """import cProfile - - def profile_func(dataloader): - for i, sample in tqdm(enumerate(dataloader), total=3000): - if i >= 3000: - break - return - - - cProfile.run('profile_func(dataloader)') - - """ - # 测试数据集 - try: - logger.info("测试数据集") - for i, sample in tqdm(enumerate(dataloader), total=3000): - if i >= 3000: - break - """ - print(f"Sample {i+1}:") - print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}") - print(f" Pinyin: {sample['py']}") - print(f" Context length: {len(sample['txt'])}") - print(f" Hint shape: {sample['hint']['input_ids'].shape}") - print() - """ - except StopIteration: - print("数据集为空")