feat(dataset): 调整拼音输入数据集的采样和处理逻辑以提升效果
This commit is contained in:
parent
66c2f78dda
commit
dfcce1f1ed
|
|
@ -70,7 +70,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||||
sample_context_section=[0.90, 0.95, 1],
|
sample_context_section=[0.95, 0.98, 1],
|
||||||
drop_py_rate: float = 0,
|
drop_py_rate: float = 0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -133,7 +133,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
start_idx: int,
|
start_idx: int,
|
||||||
max_count: int = 3,
|
max_count: int = 4,
|
||||||
pinyin_list: List[str] = None,
|
pinyin_list: List[str] = None,
|
||||||
) -> List[Tuple[str, str]]:
|
) -> List[Tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -142,7 +142,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
Args:
|
Args:
|
||||||
text (str): 完整的输入文本。
|
text (str): 完整的输入文本。
|
||||||
start_idx (int): 开始搜索的索引位置。
|
start_idx (int): 开始搜索的索引位置。
|
||||||
max_count (int, optional): 最多返回的中文字符数量,默认为3。
|
max_count (int, optional): 最多返回的中文字符数量,默认为 4。
|
||||||
pinyin_list (List[str], optional): 预先计算好的拼音列表,用于提高效率。如果未提供,则会动态计算。
|
pinyin_list (List[str], optional): 预先计算好的拼音列表,用于提高效率。如果未提供,则会动态计算。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -158,7 +158,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
char = text[i]
|
char = text[i]
|
||||||
# 判断当前字符是否为中文字符
|
# 判断当前字符是否为中文字符
|
||||||
if self.query_engine.is_chinese_char(char):
|
if self.query_engine.is_chinese_char(char) and (char != pinyin_list[i]):
|
||||||
# 获取拼音信息
|
# 获取拼音信息
|
||||||
try:
|
try:
|
||||||
# 如果没有提供拼音列表,则动态计算整个文本的拼音
|
# 如果没有提供拼音列表,则动态计算整个文本的拼音
|
||||||
|
|
@ -172,8 +172,8 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 发生异常时终止循环
|
# 发生异常时终止循环
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# 当前字符不是中文,跳过
|
# 当前字符不是中文,或其拼音项与字符本身相同,退出循环
|
||||||
continue
|
break
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -200,11 +200,11 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 方式1: 靠近汉字的54个字符
|
# 方式1: 靠近汉字的54个字符
|
||||||
return context[-54:] if context_len >= 54 else context
|
return context[-54:] if context_len >= 54 else context
|
||||||
elif choice < self.sample_context_section[1]:
|
elif choice < self.sample_context_section[1]:
|
||||||
# 方式2: 随机位置取46个连续字符
|
# 方式2: 随机位置取42个连续字符,再拼接上下文末尾的12个字符
|
||||||
if context_len <= 46:
|
if context_len <= 42:
|
||||||
return context
|
return context
|
||||||
start = random.randint(0, context_len - 46)
|
start = random.randint(0, context_len - 42)
|
||||||
return context[start : start + 46]
|
return context[start : start + 42] + context[-12:]
|
||||||
else:
|
else:
|
||||||
# 方式3: 12+6×7组合
|
# 方式3: 12+6×7组合
|
||||||
if context_len < 12:
|
if context_len < 12:
|
||||||
|
|
@ -253,7 +253,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
Returns:
|
Returns:
|
||||||
截断后的拼音,可能为空字符串
|
截断后的拼音,可能为空字符串
|
||||||
"""
|
"""
|
||||||
if not pinyin:
|
"""if not pinyin:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 随机决定截断方式
|
# 随机决定截断方式
|
||||||
|
|
@ -275,6 +275,8 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 随机选择截断长度 (1 到 max_len-1)
|
# 随机选择截断长度 (1 到 max_len-1)
|
||||||
trunc_len = random.randint(1, max_len - 1)
|
trunc_len = random.randint(1, max_len - 1)
|
||||||
return pinyin[:trunc_len]
|
return pinyin[:trunc_len]
|
||||||
|
"""
|
||||||
|
return pinyin
|
||||||
|
|
||||||
def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
|
def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -523,9 +525,9 @@ class PinyinInputDataset(IterableDataset):
|
||||||
char = text[i]
|
char = text[i]
|
||||||
py = pinyin_list[i]
|
py = pinyin_list[i]
|
||||||
|
|
||||||
# 获取后续最多3个中文字符的拼音
|
# 获取后续最多4个中文字符的拼音
|
||||||
next_chars = self.get_next_chinese_chars(
|
next_chars = self.get_next_chinese_chars(
|
||||||
text, i, max_count=3, pinyin_list=pinyin_list
|
text, i, max_count=4, pinyin_list=pinyin_list
|
||||||
)
|
)
|
||||||
next_pinyins = [py] + [p for _, p in next_chars]
|
next_pinyins = [py] + [p for _, p in next_chars]
|
||||||
# 获取前文上下文(最多100字符)
|
# 获取前文上下文(最多100字符)
|
||||||
|
|
|
||||||
|
|
@ -16,14 +16,13 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# 创建数据集
|
# 创建数据集
|
||||||
dataset = PinyinInputDataset(
|
dataset = PinyinInputDataset(
|
||||||
data_dir="/home/songsenand/DataSet/data",
|
data_dir="/root/autodl-tmp/data",
|
||||||
query_engine=query_engine,
|
query_engine=query_engine,
|
||||||
tokenizer_name="iic/nlp_structbert_backbone_lite_std",
|
tokenizer_name="iic/nlp_structbert_backbone_lite_std",
|
||||||
max_len=88,
|
max_len=88,
|
||||||
batch_query_size=300,
|
batch_query_size=300,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
shuffle_buffer_size=4000,
|
shuffle_buffer_size=4000,
|
||||||
drop_py_rate=0
|
|
||||||
)
|
)
|
||||||
logger.info("数据集初始化")
|
logger.info("数据集初始化")
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
|
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -73,8 +73,8 @@ class Expert(nn.Module):
|
||||||
self,
|
self,
|
||||||
input_dim,
|
input_dim,
|
||||||
d_model=768,
|
d_model=768,
|
||||||
num_resblocks=3,
|
num_resblocks=4,
|
||||||
output_multiplier=1,
|
output_multiplier=2,
|
||||||
dropout_prob=0.3,
|
dropout_prob=0.3,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -157,7 +157,9 @@ class MoEModel(nn.Module):
|
||||||
self.classifier = nn.Sequential(
|
self.classifier = nn.Sequential(
|
||||||
nn.LayerNorm(self.hidden_size * self.output_multiplier),
|
nn.LayerNorm(self.hidden_size * self.output_multiplier),
|
||||||
nn.Dropout(0.4),
|
nn.Dropout(0.4),
|
||||||
nn.Linear(self.hidden_size * self.output_multiplier, num_classes),
|
nn.Linear(self.hidden_size * self.output_multiplier, self.hidden_size * self.output_multiplier * 2),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
|
|
@ -368,7 +370,7 @@ class MoEModel(nn.Module):
|
||||||
"token_type_ids": encoded["token_type_ids"], # 新增
|
"token_type_ids": encoded["token_type_ids"], # 新增
|
||||||
},
|
},
|
||||||
"pg": torch.tensor(
|
"pg": torch.tensor(
|
||||||
[PG[py[0]]]
|
[PG[py[0]] if py != "" else 12]
|
||||||
), # 拼音组 ID 仍根据首字母生成(可根据实际需要改进)
|
), # 拼音组 ID 仍根据首字母生成(可根据实际需要改进)
|
||||||
}
|
}
|
||||||
return sample
|
return sample
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue