调整数据采样逻辑以提升模型训练效果

This commit is contained in:
songsenand 2026-02-21 00:56:08 +08:00
parent 17324ffa10
commit 917c9f4256
7 changed files with 4 additions and 4 deletions

View File

@ -184,15 +184,15 @@ class PinyinInputDataset(IterableDataset):
# 随机选择采样方式 (各1/3概率) # 随机选择采样方式 (各1/3概率)
choice = random.random() choice = random.random()
if choice < 0.333: if choice < 0.85:
# 方式1: 靠近汉字的54个字符 # 方式1: 靠近汉字的54个字符
return context[-54:] if context_len >= 54 else context return context[-54:] if context_len >= 54 else context
elif choice < 0.667: elif choice < 0.95:
# 方式2: 随机位置取46个连续字符 # 方式2: 随机位置取46个连续字符
if context_len <= 46: if context_len <= 46:
return context return context
start = random.randint(0, context_len - 46) start = random.randint(0, context_len - 46)
return context[start : start + 46] return context[start : start + 46] + context[-8:]
else: else:
# 方式3: 12+6×7组合 # 方式3: 12+6×7组合
if context_len < 12: if context_len < 12:

View File

@ -31,7 +31,7 @@ if __name__ == "__main__":
batch_size=1024, batch_size=1024,
num_workers=1, num_workers=1,
worker_init_fn=worker_init_fn, worker_init_fn=worker_init_fn,
pin_memory=True if torch.cuda.is_available() else False, # pin_memory=True if torch.cuda.is_available() else False,
collate_fn=custom_collate_with_txt, collate_fn=custom_collate_with_txt,
prefetch_factor=8, prefetch_factor=8,
persistent_workers=True, persistent_workers=True,