128 lines
4.0 KiB
Python
128 lines
4.0 KiB
Python
import os
|
||
import sys
|
||
|
||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||
|
||
import time
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
from tqdm import tqdm
|
||
|
||
from model.model import InputMethodEngine
|
||
from model.query import QueryEngine
|
||
|
||
import random
|
||
import re
|
||
from importlib.resources import files
|
||
from pathlib import Path
|
||
from typing import Dict, List, Tuple
|
||
|
||
import numpy as np
|
||
import torch
|
||
from datasets import load_dataset
|
||
from loguru import logger
|
||
from modelscope import AutoTokenizer
|
||
from pypinyin import lazy_pinyin
|
||
from pypinyin.contrib.tone_convert import to_initials
|
||
from torch.utils.data import IterableDataset
|
||
from model.dataset import PinyinInputDataset
|
||
|
||
|
||
def worker_init_fn(worker_id: int) -> None:
    """
    Reseed Python's and NumPy's global RNGs from the worker's torch seed.

    Intended to be passed as ``DataLoader(worker_init_fn=...)``. Inside a
    worker process, ``torch.initial_seed()`` reports that worker's torch
    seed. Deriving the ``random`` and ``numpy.random`` seeds from it ties all
    three generators to a single value, which keeps runs reproducible.

    Args:
        worker_id: Index of the calling worker. It is not read here; the
            per-worker variation comes from ``torch.initial_seed()``.
    """
    # np.random.seed only accepts integers in [0, 2**32), so keep the low
    # 32 bits (for Python ints this equals ``% 2**32``).
    seed32 = torch.initial_seed() & 0xFFFFFFFF
    random.seed(seed32)
    np.random.seed(seed32)
|
||
|
||
|
||
def collate_fn(
    batch: List[Dict[str, Any]], pad_token_id: int = 0
) -> Dict[str, Any]:
    """
    Combine a list of samples into one batch, with dynamic padding.

    The token-level tensors (``input_ids``, ``token_type_ids``,
    ``attention_mask``) are right-padded along dim 0 to the longest sequence
    in *this* batch, not to a global maximum. The per-sample tensors
    ``label``, ``history_slot_ids`` and ``pinyin_ids`` are stacked, so they
    must have the same shape across samples. The string fields are gathered
    into plain lists.

    Args:
        batch: Samples from the dataset. Each sample must provide the keys
            ``input_ids``, ``token_type_ids``, ``attention_mask``, ``label``,
            ``history_slot_ids``, ``pinyin_ids``, ``prefix``, ``suffix`` and
            ``pinyin``.
        pad_token_id: Fill value for padded ``input_ids`` positions. The
            default of 0 matches the previous behavior. Pass the tokenizer's
            pad id if it differs. ``token_type_ids`` and ``attention_mask``
            are always padded with 0, so padded positions are masked out.

    Returns:
        A dict containing:

        - ``input_ids``, ``token_type_ids``, ``attention_mask``: shape
          ``(batch, max_len)`` for 1-D inputs.
        - ``labels``: each sample's ``label`` with dim 0 squeezed (if that
          dim has size 1), then stacked.
        - ``history_slot_ids``, ``pinyin_ids``: stacked tensors.
        - ``prefix``, ``suffix``, ``pinyin``: lists.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    def _pad(key: str, value: int) -> torch.Tensor:
        # pad_sequence builds the output from the samples' own dtype/device.
        # The previous torch.zeros(pad_len) always used the default device
        # and only concatenated cleanly with 1-D tensors. No truncation is
        # needed: the target length is the batch maximum by construction.
        return torch.nn.utils.rnn.pad_sequence(
            [item[key] for item in batch],
            batch_first=True,
            padding_value=value,
        )

    input_ids = _pad("input_ids", pad_token_id)
    token_type_ids = _pad("token_type_ids", 0)
    attention_mask = _pad("attention_mask", 0)

    # squeeze(0) drops a leading size-1 dim, presumably shape (1,) per
    # sample; TODO confirm against the dataset.
    labels = torch.stack([item["label"].squeeze(0) for item in batch])
    history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
    pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])

    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "history_slot_ids": history_slot_ids,
        "prefix": [item["prefix"] for item in batch],
        "suffix": [item["suffix"] for item in batch],
        "pinyin": [item["pinyin"] for item in batch],
        "pinyin_ids": pinyin_ids,
    }
|
||
|
||
|
||
# Settings for the data-loading throughput run below.
DATA_PATH = "/home/songsenand/Data/corpus/CCI-Data/"
MAX_ITER_LENGTH = 1_000_000
BATCH_SIZE = 512

if __name__ == "__main__":
    # Under the "spawn"/"forkserver" start methods, DataLoader workers
    # re-import the main module. Without this guard each worker would
    # rebuild the dataset and start another DataLoader.
    train_dataset = PinyinInputDataset(
        data_path=DATA_PATH,
        max_workers=-1,  # choose the number of workers automatically
        max_iter_length=MAX_ITER_LENGTH,
        text_field="text",
        py_style_weight=(90, 2, 1),
        shuffle_buffer_size=20000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=16,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=2,  # keep prefetching low to avoid memory problems
        persistent_workers=True,
    )

    # Expected number of batches, assuming max_iter_length caps the total
    # sample count (TODO confirm in PinyinInputDataset). Ceil division: the
    # final batch may be partial because drop_last is not set.
    # NOTE(review): if PinyinInputDataset is an IterableDataset that does not
    # shard via torch.utils.data.get_worker_info(), each of the 16 workers
    # yields its own full stream, and the real batch count will differ.
    num_batches = (MAX_ITER_LENGTH + BATCH_SIZE - 1) // BATCH_SIZE
    for _batch in tqdm(dataloader, total=num_batches):
        pass
|