调整数据采样逻辑以提升模型训练效果

This commit is contained in:
songsenand 2026-02-21 00:56:08 +08:00
parent 17324ffa10
commit 917c9f4256
7 changed files with 4 additions and 4 deletions

View File

@@ -184,15 +184,15 @@ class PinyinInputDataset(IterableDataset):
# 随机选择采样方式 (各1/3概率)
choice = random.random()
if choice < 0.333:
if choice < 0.85:
# 方式1: 靠近汉字的54个字符
return context[-54:] if context_len >= 54 else context
elif choice < 0.667:
elif choice < 0.95:
# 方式2: 随机位置取46个连续字符
if context_len <= 46:
return context
start = random.randint(0, context_len - 46)
return context[start : start + 46]
return context[start : start + 46] + context[-8:]
else:
# 方式3: 12+6×7组合
if context_len < 12:

View File

@@ -31,7 +31,7 @@ if __name__ == "__main__":
batch_size=1024,
num_workers=1,
worker_init_fn=worker_init_fn,
pin_memory=True if torch.cuda.is_available() else False,
# pin_memory=True if torch.cuda.is_available() else False,
collate_fn=custom_collate_with_txt,
prefetch_factor=8,
persistent_workers=True,