import sys
sys.path.append("src")

import random
import re
import time
from importlib.resources import files
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from model.dataset import PinyinInputDataset
from model.model import InputMethodEngine
from model.query import QueryEngine


def worker_init_fn(worker_id: int) -> None:
    """
    Seed the RNGs of each DataLoader worker to keep runs reproducible.

    Args:
        worker_id: ID of the worker process.
    """
    worker_seed = torch.initial_seed() % (2**32)
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Custom collate function that combines individual samples into one batch.

    Args:
        batch: List of samples, each sample being a dict.

    Returns:
        A batched dict: tensor fields are stacked, string fields are kept as lists.
    """
    # Tensor fields: squeeze(0) drops the extra per-sample batch dimension before stacking.
    input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
    token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
    attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
    labels = torch.stack([item["label"].squeeze(0) for item in batch])
    history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
    pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])

    # String fields stay as plain Python lists.
    prefixes = [item["prefix"] for item in batch]
    suffixes = [item["suffix"] for item in batch]
    pinyins = [item["pinyin"] for item in batch]

    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "history_slot_ids": history_slot_ids,
        "prefix": prefixes,
        "suffix": suffixes,
        "pinyin": pinyins,
        "pinyin_ids": pinyin_ids,
    }
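

# --- Optional sanity check (illustrative addition, not part of the original script). ---
# A minimal sketch of what collate_fn expects and returns. The tensor shapes below
# (seq_len=8, 4 history slots, 8 pinyin ids) are assumptions for demonstration only;
# PinyinInputDataset may emit different shapes. Call this manually if needed.
def _collate_fn_smoke_test(seq_len: int = 8) -> None:
    dummy_item = {
        "input_ids": torch.zeros(1, seq_len, dtype=torch.long),
        "token_type_ids": torch.zeros(1, seq_len, dtype=torch.long),
        "attention_mask": torch.ones(1, seq_len, dtype=torch.long),
        "label": torch.zeros(1, seq_len, dtype=torch.long),
        "history_slot_ids": torch.zeros(4, dtype=torch.long),
        "pinyin_ids": torch.zeros(8, dtype=torch.long),
        "prefix": "prefix text",
        "suffix": "suffix text",
        "pinyin": "pin yin",
    }
    batch = collate_fn([dummy_item, dummy_item])
    assert batch["input_ids"].shape == (2, seq_len)
    assert batch["labels"].shape == (2, seq_len)
    assert len(batch["prefix"]) == 2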


train_dataset = PinyinInputDataset(
    data_path="/home/songsenand/Data/corpus/CCI-Data/",
    max_workers=-1,  # auto-select the number of workers
    max_iter_length=1000000,
    text_field="text",
    py_style_weight=(90, 2, 1),
    shuffle_buffer_size=20000,
    length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)

dataloader = DataLoader(
    train_dataset,
    batch_size=512,
    num_workers=2,
    worker_init_fn=worker_init_fn,
    collate_fn=collate_fn,
    prefetch_factor=2,  # keep prefetching small to avoid memory pressure
    persistent_workers=True,
)

# Iterate once over the dataset purely to exercise the pipeline and measure throughput.
for i, batch in tqdm(enumerate(dataloader), total=1000000 // 512):
    pass
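
# --- Optional: inspect a single batch (illustrative addition, not in the original script). ---
# A quick way to confirm what collate_fn produces. Field names come from collate_fn above;
# the printed shapes depend on PinyinInputDataset. Note this starts a fresh iterator.
first_batch = next(iter(dataloader))
for key, value in first_batch.items():
    if isinstance(value, torch.Tensor):
        print(f"{key}: tensor of shape {tuple(value.shape)}")
    else:
        print(f"{key}: list of {len(value)} strings")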