feat(dataset): 调整拼音输入数据集的采样和处理逻辑以提升效果

This commit is contained in:
songsenand 2026-02-26 14:13:50 +08:00
parent 66c2f78dda
commit dfcce1f1ed
8 changed files with 22 additions and 19 deletions

View File

@ -70,7 +70,7 @@ class PinyinInputDataset(IterableDataset):
repeat_end_freq: int = 10000, # 开始重复的阈值
max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望
sample_context_section=[0.90, 0.95, 1],
sample_context_section=[0.95, 0.98, 1],
drop_py_rate: float = 0,
):
"""
@ -133,7 +133,7 @@ class PinyinInputDataset(IterableDataset):
self,
text: str,
start_idx: int,
max_count: int = 3,
max_count: int = 4,
pinyin_list: List[str] = None,
) -> List[Tuple[str, str]]:
"""
@ -142,7 +142,7 @@ class PinyinInputDataset(IterableDataset):
Args:
text (str): 完整的输入文本
start_idx (int): 开始搜索的索引位置
max_count (int, optional): 最多返回的中文字符数量默认为3
max_count (int, optional): 最多返回的中文字符数量默认为 4
pinyin_list (List[str], optional): 预先计算好的拼音列表用于提高效率如果未提供则会动态计算
Returns:
@ -158,7 +158,7 @@ class PinyinInputDataset(IterableDataset):
char = text[i]
# 判断当前字符是否为中文字符
if self.query_engine.is_chinese_char(char):
if self.query_engine.is_chinese_char(char) and (char != pinyin_list[i]):
# 获取拼音信息
try:
# 如果没有提供拼音列表,则动态计算整个文本的拼音
@ -172,8 +172,8 @@ class PinyinInputDataset(IterableDataset):
# 发生异常时终止循环
break
else:
# 当前字符不是中文,跳过
continue
# 当前字符不是中文,退出
break
return result
@ -200,11 +200,11 @@ class PinyinInputDataset(IterableDataset):
# 方式1: 靠近汉字的54个字符
return context[-54:] if context_len >= 54 else context
elif choice < self.sample_context_section[1]:
# 方式2: 随机位置取46个连续字符
if context_len <= 46:
# 方式2: 随机位置取42个连续字符
if context_len <= 42:
return context
start = random.randint(0, context_len - 46)
return context[start : start + 46]
start = random.randint(0, context_len - 42)
return context[start : start + 42] + context[-12:]
else:
# 方式3: 12+6×7组合
if context_len < 12:
@ -253,7 +253,7 @@ class PinyinInputDataset(IterableDataset):
Returns:
截断后的拼音可能为空字符串
"""
if not pinyin:
"""if not pinyin:
return ""
# 随机决定截断方式
@ -275,6 +275,8 @@ class PinyinInputDataset(IterableDataset):
# 随机选择截断长度 (1 到 max_len-1)
trunc_len = random.randint(1, max_len - 1)
return pinyin[:trunc_len]
"""
return pinyin
def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
"""
@ -523,9 +525,9 @@ class PinyinInputDataset(IterableDataset):
char = text[i]
py = pinyin_list[i]
# 获取后续最多3个中文字符的拼音
# 获取后续最多4个中文字符的拼音
next_chars = self.get_next_chinese_chars(
text, i, max_count=3, pinyin_list=pinyin_list
text, i, max_count=4, pinyin_list=pinyin_list
)
next_pinyins = [py] + [p for _, p in next_chars]
# 获取前文上下文最多100字符

View File

@ -16,14 +16,13 @@ if __name__ == "__main__":
# 创建数据集
dataset = PinyinInputDataset(
data_dir="/home/songsenand/DataSet/data",
data_dir="/root/autodl-tmp/data",
query_engine=query_engine,
tokenizer_name="iic/nlp_structbert_backbone_lite_std",
max_len=88,
batch_query_size=300,
shuffle=True,
shuffle_buffer_size=4000,
drop_py_rate=0
)
logger.info("数据集初始化")
dataloader = DataLoader(

View File

@ -73,8 +73,8 @@ class Expert(nn.Module):
self,
input_dim,
d_model=768,
num_resblocks=3,
output_multiplier=1,
num_resblocks=4,
output_multiplier=2,
dropout_prob=0.3,
):
super().__init__()
@ -157,7 +157,9 @@ class MoEModel(nn.Module):
self.classifier = nn.Sequential(
nn.LayerNorm(self.hidden_size * self.output_multiplier),
nn.Dropout(0.4),
nn.Linear(self.hidden_size * self.output_multiplier, num_classes),
nn.Linear(self.hidden_size * self.output_multiplier, self.hidden_size * self.output_multiplier * 2),
nn.GELU(),
nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes),
)
def to(self, device):
@ -368,7 +370,7 @@ class MoEModel(nn.Module):
"token_type_ids": encoded["token_type_ids"], # 新增
},
"pg": torch.tensor(
[PG[py[0]]]
[PG[py[0]] if py != "" else 12]
), # 拼音组 ID 仍根据首字母生成(可根据实际需要改进)
}
return sample