69 lines
2.0 KiB
Python
69 lines
2.0 KiB
Python
from tqdm import tqdm
|
|
from loguru import logger
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from suinput.dataset import PinyinInputDataset, worker_init_fn, custom_collate, custom_collate_with_txt
|
|
from suinput.query import QueryEngine
|
|
|
|
|
|
# Usage example
def main() -> None:
    """Smoke-test the pinyin input data pipeline.

    Loads the query engine, builds a ``PinyinInputDataset`` over the local
    corpus directory, wraps it in a multi-worker ``DataLoader`` and pulls up
    to ``total`` batches through it behind a progress bar. The batches are
    discarded; only the number of batches received is checked, so that an
    empty dataset is reported.
    """
    # Local import: keeps this entry point self-contained without touching
    # the file-level import block.
    from itertools import islice

    # Initialise the query engine.
    query_engine = QueryEngine()
    query_engine.load()

    # Build the dataset.
    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/Data/corpus/CCI-Data/",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
        max_len=88,
        batch_query_size=300,
        shuffle=True,
        shuffle_buffer_size=4000,
    )
    logger.info("数据集初始化")

    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        num_workers=15,
        worker_init_fn=worker_init_fn,
        pin_memory=torch.cuda.is_available(),
        collate_fn=custom_collate,
        prefetch_factor=8,
        persistent_workers=True,
        shuffle=False,  # the dataset already shuffles internally
    )

    # To profile the loader instead of just timing it with the progress bar,
    # run the loop below under cProfile.

    # Exercise the dataset.
    logger.info("测试数据集")
    total = 3000
    seen = 0
    # islice stops after `total` batches without requesting one more from
    # the loader, which an enumerate-and-break loop would do.
    for _ in tqdm(islice(dataloader, total), total=total):
        seen += 1

    # A for loop consumes StopIteration itself, so an empty dataset can only
    # be detected by counting the batches that actually arrived.
    if seen == 0:
        print("数据集为空")


if __name__ == "__main__":
    main()
|