from itertools import islice

from loguru import logger
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from suinput.dataset import (
    PinyinInputDataset,
    custom_collate,
    custom_collate_with_txt,
    worker_init_fn,
)
from suinput.query import QueryEngine

# Usage example / smoke test: stream a few batches from PinyinInputDataset
# through a DataLoader and print a slice of each batch.
if __name__ == "__main__":
    # Initialize the query engine before the dataset uses it.
    query_engine = QueryEngine()
    query_engine.load()

    # Create the dataset. NOTE: it shuffles internally via a buffer
    # (shuffle=True + shuffle_buffer_size), so the DataLoader below
    # must NOT shuffle again.
    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/Data/corpus/CCI-Data/",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
        max_len=88,
        batch_query_size=300,
        shuffle=True,
        shuffle_buffer_size=4000,
    )
    logger.info("数据集初始化")

    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        num_workers=1,
        worker_init_fn=worker_init_fn,
        # was: True if torch.cuda.is_available() else False (redundant ternary)
        pin_memory=torch.cuda.is_available(),
        collate_fn=custom_collate_with_txt,
        prefetch_factor=8,
        persistent_workers=True,
        shuffle=False,  # shuffling is already done inside the dataset
    )

    # Smoke-test the pipeline: pull `total` batches and print the first
    # few text entries of each.
    logger.info("测试数据集")
    total = 20
    seen = 0
    # islice stops exactly after `total` batches; the previous
    # `if i >= total: break` pattern fetched one extra batch before
    # breaking, and tqdm's total was hard-coded instead of using `total`.
    for i, sample in tqdm(enumerate(islice(dataloader, total)), total=total):
        seen += 1
        print(f"Sample {i+1}: {sample['txt'][0:10]}")
    # A `for` loop consumes StopIteration itself, so the original
    # `except StopIteration` handler was unreachable dead code; detect an
    # empty dataset by counting the batches actually yielded instead.
    if seen == 0:
        print("数据集为空")