调整数据采样逻辑以提升模型训练效果
This commit is contained in:
parent
17324ffa10
commit
917c9f4256
|
|
@ -184,15 +184,15 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 随机选择采样方式 (各1/3概率)
|
# 随机选择采样方式 (方式1/2/3 的概率分别为 85%/10%/5%)
|
||||||
choice = random.random()
|
choice = random.random()
|
||||||
|
|
||||||
if choice < 0.333:
|
if choice < 0.85:
|
||||||
# 方式1: 靠近汉字的54个字符
|
# 方式1: 靠近汉字的54个字符
|
||||||
return context[-54:] if context_len >= 54 else context
|
return context[-54:] if context_len >= 54 else context
|
||||||
elif choice < 0.667:
|
elif choice < 0.95:
|
||||||
# 方式2: 随机位置取46个连续字符
|
# 方式2: 随机位置取46个连续字符, 并拼接末尾8个字符
|
||||||
if context_len <= 46:
|
if context_len <= 46:
|
||||||
return context
|
return context
|
||||||
start = random.randint(0, context_len - 46)
|
start = random.randint(0, context_len - 46)
|
||||||
return context[start : start + 46]
|
return context[start : start + 46] + context[-8:]
|
||||||
else:
|
else:
|
||||||
# 方式3: 12+6×7组合
|
# 方式3: 12+6×7组合
|
||||||
if context_len < 12:
|
if context_len < 12:
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ if __name__ == "__main__":
|
||||||
batch_size=1024,
|
batch_size=1024,
|
||||||
num_workers=1,
|
num_workers=1,
|
||||||
worker_init_fn=worker_init_fn,
|
worker_init_fn=worker_init_fn,
|
||||||
pin_memory=True if torch.cuda.is_available() else False,
|
# pin_memory=True if torch.cuda.is_available() else False,
|
||||||
collate_fn=custom_collate_with_txt,
|
collate_fn=custom_collate_with_txt,
|
||||||
prefetch_factor=8,
|
prefetch_factor=8,
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
|
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue