# SUInput/example.py
from tqdm import tqdm
from loguru import logger
import torch
from torch.utils.data import DataLoader
from suinput.dataset import PinyinInputDataset, worker_init_fn, custom_collate, custom_collate_with_txt
from suinput.query import QueryEngine
# Usage example: build the pinyin-input dataset and a DataLoader, then
# iterate a few thousand batches as a throughput / sanity smoke test.
if __name__ == "__main__":
    # Initialize the query engine (load() presumably reads its index from
    # disk -- TODO confirm against QueryEngine's implementation).
    query_engine = QueryEngine()
    query_engine.load()

    # Build the dataset. Shuffling is handled *inside* the dataset via a
    # shuffle buffer, so the DataLoader below must not shuffle again.
    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/Data/corpus/CCI-Data/",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
        max_len=88,
        batch_query_size=300,
        shuffle=True,
        shuffle_buffer_size=4000,
    )
    logger.info("数据集初始化")

    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        num_workers=15,
        worker_init_fn=worker_init_fn,
        # pin_memory only helps when batches are copied to a CUDA device.
        pin_memory=torch.cuda.is_available(),
        collate_fn=custom_collate,
        prefetch_factor=8,
        persistent_workers=True,
        shuffle=False,  # shuffling is already done inside the dataset
    )

    # Smoke-test: pull up to `total` batches from the loader.
    logger.info("测试数据集")
    total = 3000
    yielded = 0
    for i, sample in tqdm(enumerate(dataloader), total=total):
        if i >= total:
            break
        yielded += 1
    # A `for` loop never propagates StopIteration, so the original
    # `except StopIteration:` branch could never run. Detect an empty
    # dataset by counting the batches actually yielded instead.
    if yielded == 0:
        print("数据集为空")