调整数据采样逻辑以提升模型训练效果
This commit is contained in:
parent
17324ffa10
commit
917c9f4256
|
|
@ -184,15 +184,15 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 随机选择采样方式 (各1/3概率)
|
||||
choice = random.random()
|
||||
|
||||
if choice < 0.333:
|
||||
if choice < 0.85:
|
||||
# 方式1: 靠近汉字的54个字符
|
||||
return context[-54:] if context_len >= 54 else context
|
||||
elif choice < 0.667:
|
||||
elif choice < 0.95:
|
||||
# 方式2: 随机位置取46个连续字符
|
||||
if context_len <= 46:
|
||||
return context
|
||||
start = random.randint(0, context_len - 46)
|
||||
return context[start : start + 46]
|
||||
return context[start : start + 46] + context[-8:]
|
||||
else:
|
||||
# 方式3: 12+6×7组合
|
||||
if context_len < 12:
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ if __name__ == "__main__":
|
|||
batch_size=1024,
|
||||
num_workers=1,
|
||||
worker_init_fn=worker_init_fn,
|
||||
pin_memory=True if torch.cuda.is_available() else False,
|
||||
# pin_memory=True if torch.cuda.is_available() else False,
|
||||
collate_fn=custom_collate_with_txt,
|
||||
prefetch_factor=8,
|
||||
persistent_workers=True,
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue