feat(dataset): 调整拼音输入数据集的采样和处理逻辑以提升效果

This commit is contained in:
songsenand 2026-02-26 14:13:50 +08:00
parent 66c2f78dda
commit dfcce1f1ed
8 changed files with 22 additions and 19 deletions

View File

@ -70,7 +70,7 @@ class PinyinInputDataset(IterableDataset):
repeat_end_freq: int = 10000, # 开始重复的阈值 repeat_end_freq: int = 10000, # 开始重复的阈值
max_drop_prob: float = 0.8, # 最大丢弃概率 max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望 max_repeat_expect: float = 50.0, # 最大重复期望
sample_context_section=[0.90, 0.95, 1], sample_context_section=[0.95, 0.98, 1],
drop_py_rate: float = 0, drop_py_rate: float = 0,
): ):
""" """
@ -133,7 +133,7 @@ class PinyinInputDataset(IterableDataset):
self, self,
text: str, text: str,
start_idx: int, start_idx: int,
max_count: int = 3, max_count: int = 4,
pinyin_list: List[str] = None, pinyin_list: List[str] = None,
) -> List[Tuple[str, str]]: ) -> List[Tuple[str, str]]:
""" """
@ -142,7 +142,7 @@ class PinyinInputDataset(IterableDataset):
Args: Args:
text (str): 完整的输入文本 text (str): 完整的输入文本
start_idx (int): 开始搜索的索引位置 start_idx (int): 开始搜索的索引位置
max_count (int, optional): 最多返回的中文字符数量默认为3 max_count (int, optional): 最多返回的中文字符数量默认为 4
pinyin_list (List[str], optional): 预先计算好的拼音列表用于提高效率如果未提供则会动态计算 pinyin_list (List[str], optional): 预先计算好的拼音列表用于提高效率如果未提供则会动态计算
Returns: Returns:
@ -158,7 +158,7 @@ class PinyinInputDataset(IterableDataset):
char = text[i] char = text[i]
# 判断当前字符是否为中文字符 # 判断当前字符是否为中文字符
if self.query_engine.is_chinese_char(char): if self.query_engine.is_chinese_char(char) and (char != pinyin_list[i]):
# 获取拼音信息 # 获取拼音信息
try: try:
# 如果没有提供拼音列表,则动态计算整个文本的拼音 # 如果没有提供拼音列表,则动态计算整个文本的拼音
@ -172,8 +172,8 @@ class PinyinInputDataset(IterableDataset):
# 发生异常时终止循环 # 发生异常时终止循环
break break
else: else:
# 当前字符不是中文,跳过 # 当前字符不是中文,退出
continue break
return result return result
@ -200,11 +200,11 @@ class PinyinInputDataset(IterableDataset):
# 方式1: 靠近汉字的54个字符 # 方式1: 靠近汉字的54个字符
return context[-54:] if context_len >= 54 else context return context[-54:] if context_len >= 54 else context
elif choice < self.sample_context_section[1]: elif choice < self.sample_context_section[1]:
# 方式2: 随机位置取46个连续字符 # 方式2: 随机位置取42个连续字符
if context_len <= 46: if context_len <= 42:
return context return context
start = random.randint(0, context_len - 46) start = random.randint(0, context_len - 42)
return context[start : start + 46] return context[start : start + 42] + context[-12:]
else: else:
# 方式3: 12+6×7组合 # 方式3: 12+6×7组合
if context_len < 12: if context_len < 12:
@ -253,7 +253,7 @@ class PinyinInputDataset(IterableDataset):
Returns: Returns:
截断后的拼音可能为空字符串 截断后的拼音可能为空字符串
""" """
if not pinyin: """if not pinyin:
return "" return ""
# 随机决定截断方式 # 随机决定截断方式
@ -275,6 +275,8 @@ class PinyinInputDataset(IterableDataset):
# 随机选择截断长度 (1 到 max_len-1) # 随机选择截断长度 (1 到 max_len-1)
trunc_len = random.randint(1, max_len - 1) trunc_len = random.randint(1, max_len - 1)
return pinyin[:trunc_len] return pinyin[:trunc_len]
"""
return pinyin
def process_pinyin_sequence(self, pinyin_list: List[str]) -> str: def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
""" """
@ -523,9 +525,9 @@ class PinyinInputDataset(IterableDataset):
char = text[i] char = text[i]
py = pinyin_list[i] py = pinyin_list[i]
# 获取后续最多3个中文字符的拼音 # 获取后续最多4个中文字符的拼音
next_chars = self.get_next_chinese_chars( next_chars = self.get_next_chinese_chars(
text, i, max_count=3, pinyin_list=pinyin_list text, i, max_count=4, pinyin_list=pinyin_list
) )
next_pinyins = [py] + [p for _, p in next_chars] next_pinyins = [py] + [p for _, p in next_chars]
# 获取前文上下文最多100字符 # 获取前文上下文最多100字符

View File

@ -16,14 +16,13 @@ if __name__ == "__main__":
# 创建数据集 # 创建数据集
dataset = PinyinInputDataset( dataset = PinyinInputDataset(
data_dir="/home/songsenand/DataSet/data", data_dir="/root/autodl-tmp/data",
query_engine=query_engine, query_engine=query_engine,
tokenizer_name="iic/nlp_structbert_backbone_lite_std", tokenizer_name="iic/nlp_structbert_backbone_lite_std",
max_len=88, max_len=88,
batch_query_size=300, batch_query_size=300,
shuffle=True, shuffle=True,
shuffle_buffer_size=4000, shuffle_buffer_size=4000,
drop_py_rate=0
) )
logger.info("数据集初始化") logger.info("数据集初始化")
dataloader = DataLoader( dataloader = DataLoader(

View File

@ -73,8 +73,8 @@ class Expert(nn.Module):
self, self,
input_dim, input_dim,
d_model=768, d_model=768,
num_resblocks=3, num_resblocks=4,
output_multiplier=1, output_multiplier=2,
dropout_prob=0.3, dropout_prob=0.3,
): ):
super().__init__() super().__init__()
@ -157,7 +157,9 @@ class MoEModel(nn.Module):
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
nn.LayerNorm(self.hidden_size * self.output_multiplier), nn.LayerNorm(self.hidden_size * self.output_multiplier),
nn.Dropout(0.4), nn.Dropout(0.4),
nn.Linear(self.hidden_size * self.output_multiplier, num_classes), nn.Linear(self.hidden_size * self.output_multiplier, self.hidden_size * self.output_multiplier * 2),
nn.GELU(),
nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes),
) )
def to(self, device): def to(self, device):
@ -368,7 +370,7 @@ class MoEModel(nn.Module):
"token_type_ids": encoded["token_type_ids"], # 新增 "token_type_ids": encoded["token_type_ids"], # 新增
}, },
"pg": torch.tensor( "pg": torch.tensor(
[PG[py[0]]] [PG[py[0]] if py != "" else 12]
), # 拼音组 ID 仍根据首字母生成(可根据实际需要改进) ), # 拼音组 ID 仍根据首字母生成(可根据实际需要改进)
} }
return sample return sample