"""Iterate a PinyinInputDataset through a multi-worker DataLoader.

Builds the training dataset and DataLoader with the project's custom
``collate_fn`` / ``worker_init_fn`` and drains it once under a tqdm progress
bar, doing no work per batch (useful for checking data-loading throughput).
"""

import sys

# Make the project's ``src`` directory importable. Resolved relative to the
# current working directory, so launch the script from the repository root.
# Kept ahead of every other import, exactly as in the original script.
sys.path.append("src")

import math  # noqa: E402
import random  # noqa: E402
import re  # noqa: E402,F401
import time  # noqa: E402,F401
from importlib.resources import files  # noqa: E402,F401
from pathlib import Path  # noqa: E402,F401
from typing import Any, Callable, Dict, List, Optional, Tuple, Union  # noqa: E402,F401

import numpy as np  # noqa: E402
import torch  # noqa: E402
from datasets import load_dataset  # noqa: E402,F401
from loguru import logger  # noqa: E402,F401
from modelscope import AutoTokenizer  # noqa: E402,F401
from pypinyin import lazy_pinyin  # noqa: E402,F401
from pypinyin.contrib.tone_convert import to_initials  # noqa: E402,F401
from torch.utils.data import DataLoader, IterableDataset  # noqa: E402,F401
from tqdm import tqdm  # noqa: E402

from model.dataset import PinyinInputDataset  # noqa: E402
from model.model import InputMethodEngine  # noqa: E402,F401
from model.query import QueryEngine  # noqa: E402,F401

# Run configuration. Shared by the dataset, the DataLoader and the progress
# bar so the expected batch count cannot drift from the actual settings.
DATA_PATH = "/home/songsenand/Data/corpus/CCI-Data/"
MAX_ITER_LENGTH = 1_000_000
BATCH_SIZE = 512
NUM_WORKERS = 2


def worker_init_fn(worker_id: int) -> None:
    """Seed NumPy and ``random`` inside each DataLoader worker for reproducibility.

    ``torch.initial_seed()`` already differs per worker (PyTorch derives it
    from the base seed plus the worker id); reducing it modulo 2**32 makes it
    a valid NumPy seed.

    Args:
        worker_id: Index of the DataLoader worker (unused; the per-worker
            variation comes from ``torch.initial_seed()``).
    """
    worker_seed = torch.initial_seed() % (2**32)
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine a list of samples into one batch dict.

    Tensor fields are stacked along a new leading batch dimension; string
    fields are collected into plain lists.

    Args:
        batch: Samples produced by the dataset, each a dict.

    Returns:
        Dict with stacked tensors under ``input_ids``, ``token_type_ids``,
        ``attention_mask``, ``labels`` (taken from each sample's ``"label"``),
        ``history_slot_ids`` and ``pinyin_ids``, plus the lists ``prefix``,
        ``suffix`` and ``pinyin``.
    """
    # These fields presumably carry a leading size-1 "batch" dimension from
    # the tokenizer (per the original note); squeeze(0) drops it if present
    # so stacking yields (batch, ...) rather than (batch, 1, ...).
    input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
    token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
    attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
    labels = torch.stack([item["label"].squeeze(0) for item in batch])
    # Stacked as-is, without squeezing.
    history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
    pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])

    # String fields are kept as Python lists.
    prefixes = [item["prefix"] for item in batch]
    suffixes = [item["suffix"] for item in batch]
    pinyins = [item["pinyin"] for item in batch]

    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "history_slot_ids": history_slot_ids,
        "prefix": prefixes,
        "suffix": suffixes,
        "pinyin": pinyins,
        "pinyin_ids": pinyin_ids,
    }


def build_dataloader() -> DataLoader:
    """Construct the training PinyinInputDataset and wrap it in a DataLoader."""
    train_dataset = PinyinInputDataset(
        data_path=DATA_PATH,
        max_workers=-1,  # per the original note: let the dataset pick its worker count automatically
        max_iter_length=MAX_ITER_LENGTH,
        text_field="text",
        py_style_weight=(90, 2, 1),
        shuffle_buffer_size=20000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    return DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        # 2 is PyTorch's default prefetch depth per worker; lower it to 1 if
        # worker memory becomes a problem.
        prefetch_factor=2,
        persistent_workers=True,
    )


def main() -> None:
    """Drain the DataLoader once under a progress bar, discarding every batch."""
    dataloader = build_dataloader()
    # Integer batch count for the progress bar. The last batch may be partial
    # (drop_last is False), hence ceil. NOTE(review): assumes max_iter_length
    # caps the total sample count across all workers; if each worker yields
    # its own max_iter_length samples, the true count is NUM_WORKERS times
    # larger — confirm against PinyinInputDataset.
    num_batches = math.ceil(MAX_ITER_LENGTH / BATCH_SIZE)
    for _batch in tqdm(dataloader, total=num_batches):
        pass


# The guard is required for multi-worker DataLoaders under the spawn or
# forkserver start methods. Those methods re-import this module in every
# worker process, and without the guard each worker would rebuild the loader
# and rerun the loop.
if __name__ == "__main__":
    main()