SUimeModelTraner/tests/test_dataset.py

128 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import time
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.model import InputMethodEngine
from model.query import QueryEngine
import random
import re
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset
from model.dataset import PinyinInputDataset
def worker_init_fn(worker_id: int) -> None:
    """Seed NumPy's and Python's global RNGs for one DataLoader worker.

    Derives a seed from ``torch.initial_seed()``, so runs are reproducible.

    Args:
        worker_id: Index of the worker. The DataLoader callback signature
            requires this argument, but it is not used here.
    """
    # Fold torch's seed into 32 bits, the range NumPy's legacy seeding accepts.
    # For a non-negative int, masking is equivalent to `% 2**32`.
    derived_seed = torch.initial_seed() & 0xFFFFFFFF
    for seed_global_rng in (np.random.seed, random.seed):
        seed_global_rng(derived_seed)
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine a list of samples into one batch, padding to the batch's longest sequence.

    ``input_ids``, ``token_type_ids`` and ``attention_mask`` are right-padded
    with zeros to the longest sequence of that field in the batch. Each sample
    is expected to give these three fields the same length, so they end up
    with a common padded length.
    ``label`` (with a leading size-1 dim squeezed), ``history_slot_ids`` and
    ``pinyin_ids`` are stacked as-is. ``prefix``, ``suffix`` and ``pinyin``
    are collected into plain lists.

    Args:
        batch: Samples from the dataset. The sequence fields are assumed to be
            1-D tensors, which is what the original padding code required.
            TODO confirm against PinyinInputDataset.

    Returns:
        Dict with keys ``input_ids``, ``token_type_ids``, ``attention_mask``,
        ``labels``, ``history_slot_ids``, ``prefix``, ``suffix``, ``pinyin`` and
        ``pinyin_ids``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    def _pad(key: str) -> torch.Tensor:
        """Right-pad ``key`` across the batch with zeros and stack the result."""
        # pad_sequence pads every tensor to the longest one in the list and
        # keeps the input dtype. Padding value 0 is a zeroed attention mask.
        # NOTE(review): this assumes the tokenizer's pad token id is 0.
        return torch.nn.utils.rnn.pad_sequence(
            [item[key] for item in batch], batch_first=True, padding_value=0
        )

    return {
        "input_ids": _pad("input_ids"),
        "token_type_ids": _pad("token_type_ids"),
        "attention_mask": _pad("attention_mask"),
        # squeeze(0) removes dim 0 only when it has size 1.
        "labels": torch.stack([item["label"].squeeze(0) for item in batch]),
        "history_slot_ids": torch.stack([item["history_slot_ids"] for item in batch]),
        "prefix": [item["prefix"] for item in batch],
        "suffix": [item["suffix"] for item in batch],
        "pinyin": [item["pinyin"] for item in batch],
        "pinyin_ids": torch.stack([item["pinyin_ids"] for item in batch]),
    }
# Throughput smoke check: iterate the full training data pipeline once.
# The __main__ guard keeps pytest (which imports this test_*.py file) from
# starting a full pass over the corpus at collection time. It also stops
# DataLoader workers started with the "spawn" method from re-running this code.
MAX_ITER_LENGTH = 1000000
BATCH_SIZE = 512

if __name__ == "__main__":
    train_dataset = PinyinInputDataset(
        data_path="/home/songsenand/Data/corpus/CCI-Data/",
        max_workers=-1,  # choose the number of workers automatically
        max_iter_length=MAX_ITER_LENGTH,
        text_field="text",
        py_style_weight=(90, 2, 1),
        shuffle_buffer_size=20000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=16,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        # Intended to limit prefetching and avoid memory issues.
        # NOTE(review): 2 is already PyTorch's default when num_workers > 0.
        # That means up to 2 * num_workers = 32 batches are buffered.
        prefetch_factor=2,
        persistent_workers=True,
    )
    # tqdm needs an explicit total here; presumably the dataset is an
    # IterableDataset without __len__. Ceiling division counts whole batches,
    # since the last one may be partial.
    # NOTE(review): this assumes max_iter_length is the total sample count
    # across all workers. TODO confirm how PinyinInputDataset shards its work.
    num_batches = -(-MAX_ITER_LENGTH // BATCH_SIZE)
    for _batch in tqdm(dataloader, total=num_batches):
        pass