"""Data-loading smoke run for ``PinyinInputDataset``.

Builds the dataset, wraps it in a multi-worker ``DataLoader`` with a
dynamically padding ``collate_fn``, and iterates over it once behind a
progress bar. The batches themselves are discarded.
"""
import os
import sys

# Make ``../src`` (relative to this file) importable. This must run before
# the ``model.*`` imports below.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import time

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.model import InputMethodEngine
from model.query import QueryEngine

import random
import re
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset

from model.dataset import PinyinInputDataset

# Upper bound on samples handed to the dataset; also used to estimate the
# progress-bar total.
MAX_ITER_LENGTH = 1000000
BATCH_SIZE = 512

# Per-sample fields that are right-padded to the longest sequence in a batch.
_PADDED_KEYS = ("input_ids", "token_type_ids", "attention_mask")


def worker_init_fn(worker_id: int) -> None:
    """Seed NumPy and ``random`` inside a DataLoader worker.

    The seed is taken from ``torch.initial_seed()`` of the current process,
    so both generators follow whatever seed torch holds in this worker.

    Args:
        worker_id: ID of the worker (required by the DataLoader callback
            signature; not used here).
    """
    # Reduce to 32 bits: ``np.random.seed`` rejects larger values.
    worker_seed = torch.initial_seed() % (2**32)
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def _right_pad(tensor: torch.Tensor, pad_len: int) -> torch.Tensor:
    """Return ``tensor`` with ``pad_len`` zeros of the same dtype appended."""
    if pad_len <= 0:
        return tensor
    return torch.cat([tensor, torch.zeros(pad_len, dtype=tensor.dtype)])


def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine samples into one batch, padding to the batch's longest sequence.

    ``input_ids``, ``token_type_ids`` and ``attention_mask`` are right-padded
    with zeros to the largest ``input_ids`` length in the batch and stacked.
    ``label`` (after ``squeeze(0)``), ``history_slot_ids`` and ``pinyin_ids``
    are stacked as-is. ``prefix``, ``suffix`` and ``pinyin`` are returned as
    plain lists.

    Args:
        batch: Samples produced by the dataset, one dict per sample.

    Returns:
        A dict with the batched tensors and the per-sample lists.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    # target_len is the batch maximum, so no sample can be longer than it and
    # truncation is never needed.
    target_len = max(item["input_ids"].shape[0] for item in batch)

    padded: Dict[str, List[torch.Tensor]] = {key: [] for key in _PADDED_KEYS}
    for item in batch:
        # The pad amount is derived from input_ids and applied to all three
        # fields. NOTE(review): this assumes the three tensors of a sample
        # share the same length - confirm against the dataset.
        pad_len = target_len - item["input_ids"].shape[0]
        for key in _PADDED_KEYS:
            padded[key].append(_right_pad(item[key], pad_len))

    return {
        "input_ids": torch.stack(padded["input_ids"]),
        "token_type_ids": torch.stack(padded["token_type_ids"]),
        "attention_mask": torch.stack(padded["attention_mask"]),
        "labels": torch.stack([item["label"].squeeze(0) for item in batch]),
        "history_slot_ids": torch.stack(
            [item["history_slot_ids"] for item in batch]
        ),
        "prefix": [item["prefix"] for item in batch],
        "suffix": [item["suffix"] for item in batch],
        "pinyin": [item["pinyin"] for item in batch],
        "pinyin_ids": torch.stack([item["pinyin_ids"] for item in batch]),
    }


def main() -> None:
    """Build the dataset and DataLoader, then iterate once over all batches."""
    train_dataset = PinyinInputDataset(
        data_path="/home/songsenand/Data/corpus/CCI-Data/",
        max_workers=-1,  # automatically choose the number of workers
        max_iter_length=MAX_ITER_LENGTH,
        text_field="text",
        py_style_weight=(90, 2, 1),
        shuffle_buffer_size=20000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=16,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=2,  # reduce prefetching to avoid memory problems
        persistent_workers=True,
    )

    # The total is only an estimate for the progress bar (ceil division); the
    # real batch count depends on how the dataset splits work across workers.
    estimated_batches = -(-MAX_ITER_LENGTH // BATCH_SIZE)
    for _ in tqdm(dataloader, total=estimated_batches):
        pass


# Guard the entry point: with the "spawn" start method each DataLoader worker
# re-imports this module, and unguarded top-level code would run again in
# every worker.
if __name__ == "__main__":
    main()