feat(dataset): 调整拼音输入数据集的采样和处理逻辑以提升效果
This commit is contained in:
parent
66c2f78dda
commit
dfcce1f1ed
|
|
@ -70,7 +70,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||
sample_context_section=[0.90, 0.95, 1],
|
||||
sample_context_section=[0.95, 0.98, 1],
|
||||
drop_py_rate: float = 0,
|
||||
):
|
||||
"""
|
||||
|
|
@ -133,7 +133,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
self,
|
||||
text: str,
|
||||
start_idx: int,
|
||||
max_count: int = 3,
|
||||
max_count: int = 4,
|
||||
pinyin_list: List[str] = None,
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
|
|
@ -142,7 +142,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
Args:
|
||||
text (str): 完整的输入文本。
|
||||
start_idx (int): 开始搜索的索引位置。
|
||||
max_count (int, optional): 最多返回的中文字符数量,默认为3。
|
||||
max_count (int, optional): 最多返回的中文字符数量,默认为 4。
|
||||
pinyin_list (List[str], optional): 预先计算好的拼音列表,用于提高效率。如果未提供,则会动态计算。
|
||||
|
||||
Returns:
|
||||
|
|
@ -158,7 +158,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
|
||||
char = text[i]
|
||||
# 判断当前字符是否为中文字符
|
||||
if self.query_engine.is_chinese_char(char):
|
||||
if self.query_engine.is_chinese_char(char) and (char != pinyin_list[i]):
|
||||
# 获取拼音信息
|
||||
try:
|
||||
# 如果没有提供拼音列表,则动态计算整个文本的拼音
|
||||
|
|
@ -172,8 +172,8 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 发生异常时终止循环
|
||||
break
|
||||
else:
|
||||
# 当前字符不是中文,跳过
|
||||
continue
|
||||
# 当前字符不是中文,退出
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -200,11 +200,11 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 方式1: 靠近汉字的54个字符
|
||||
return context[-54:] if context_len >= 54 else context
|
||||
elif choice < self.sample_context_section[1]:
|
||||
# 方式2: 随机位置取46个连续字符
|
||||
if context_len <= 46:
|
||||
# 方式2: 随机位置取42个连续字符
|
||||
if context_len <= 42:
|
||||
return context
|
||||
start = random.randint(0, context_len - 46)
|
||||
return context[start : start + 46]
|
||||
start = random.randint(0, context_len - 42)
|
||||
return context[start : start + 42] + context[-12:]
|
||||
else:
|
||||
# 方式3: 12+6×7组合
|
||||
if context_len < 12:
|
||||
|
|
@ -253,7 +253,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
Returns:
|
||||
截断后的拼音,可能为空字符串
|
||||
"""
|
||||
if not pinyin:
|
||||
"""if not pinyin:
|
||||
return ""
|
||||
|
||||
# 随机决定截断方式
|
||||
|
|
@ -275,6 +275,8 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 随机选择截断长度 (1 到 max_len-1)
|
||||
trunc_len = random.randint(1, max_len - 1)
|
||||
return pinyin[:trunc_len]
|
||||
"""
|
||||
return pinyin
|
||||
|
||||
def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
|
||||
"""
|
||||
|
|
@ -523,9 +525,9 @@ class PinyinInputDataset(IterableDataset):
|
|||
char = text[i]
|
||||
py = pinyin_list[i]
|
||||
|
||||
# 获取后续最多3个中文字符的拼音
|
||||
# 获取后续最多4个中文字符的拼音
|
||||
next_chars = self.get_next_chinese_chars(
|
||||
text, i, max_count=3, pinyin_list=pinyin_list
|
||||
text, i, max_count=4, pinyin_list=pinyin_list
|
||||
)
|
||||
next_pinyins = [py] + [p for _, p in next_chars]
|
||||
# 获取前文上下文(最多100字符)
|
||||
|
|
|
|||
|
|
@ -16,14 +16,13 @@ if __name__ == "__main__":
|
|||
|
||||
# 创建数据集
|
||||
dataset = PinyinInputDataset(
|
||||
data_dir="/home/songsenand/DataSet/data",
|
||||
data_dir="/root/autodl-tmp/data",
|
||||
query_engine=query_engine,
|
||||
tokenizer_name="iic/nlp_structbert_backbone_lite_std",
|
||||
max_len=88,
|
||||
batch_query_size=300,
|
||||
shuffle=True,
|
||||
shuffle_buffer_size=4000,
|
||||
drop_py_rate=0
|
||||
)
|
||||
logger.info("数据集初始化")
|
||||
dataloader = DataLoader(
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -73,8 +73,8 @@ class Expert(nn.Module):
|
|||
self,
|
||||
input_dim,
|
||||
d_model=768,
|
||||
num_resblocks=3,
|
||||
output_multiplier=1,
|
||||
num_resblocks=4,
|
||||
output_multiplier=2,
|
||||
dropout_prob=0.3,
|
||||
):
|
||||
super().__init__()
|
||||
|
|
@ -157,7 +157,9 @@ class MoEModel(nn.Module):
|
|||
self.classifier = nn.Sequential(
|
||||
nn.LayerNorm(self.hidden_size * self.output_multiplier),
|
||||
nn.Dropout(0.4),
|
||||
nn.Linear(self.hidden_size * self.output_multiplier, num_classes),
|
||||
nn.Linear(self.hidden_size * self.output_multiplier, self.hidden_size * self.output_multiplier * 2),
|
||||
nn.GELU(),
|
||||
nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes),
|
||||
)
|
||||
|
||||
def to(self, device):
|
||||
|
|
@ -368,7 +370,7 @@ class MoEModel(nn.Module):
|
|||
"token_type_ids": encoded["token_type_ids"], # 新增
|
||||
},
|
||||
"pg": torch.tensor(
|
||||
[PG[py[0]]]
|
||||
[PG[py[0]] if py != "" else 12]
|
||||
), # 拼音组 ID 仍根据首字母生成(可根据实际需要改进)
|
||||
}
|
||||
return sample
|
||||
|
|
|
|||
Loading…
Reference in New Issue