diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index 43930cb..bb51a1d 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -70,7 +70,7 @@ class PinyinInputDataset(IterableDataset): repeat_end_freq: int = 10000, # 开始重复的阈值 max_drop_prob: float = 0.8, # 最大丢弃概率 max_repeat_expect: float = 50.0, # 最大重复期望 - sample_context_section=[0.90, 0.95, 1], + sample_context_section=[0.95, 0.98, 1], drop_py_rate: float = 0, ): """ @@ -133,7 +133,7 @@ class PinyinInputDataset(IterableDataset): self, text: str, start_idx: int, - max_count: int = 3, + max_count: int = 4, pinyin_list: List[str] = None, ) -> List[Tuple[str, str]]: """ @@ -142,7 +142,7 @@ class PinyinInputDataset(IterableDataset): Args: text (str): 完整的输入文本。 start_idx (int): 开始搜索的索引位置。 - max_count (int, optional): 最多返回的中文字符数量,默认为3。 + max_count (int, optional): 最多返回的中文字符数量,默认为 4。 pinyin_list (List[str], optional): 预先计算好的拼音列表,用于提高效率。如果未提供,则会动态计算。 Returns: @@ -158,7 +158,7 @@ class PinyinInputDataset(IterableDataset): char = text[i] # 判断当前字符是否为中文字符 - if self.query_engine.is_chinese_char(char): + if self.query_engine.is_chinese_char(char) and (char != pinyin_list[i]): # 获取拼音信息 try: # 如果没有提供拼音列表,则动态计算整个文本的拼音 @@ -172,8 +172,8 @@ class PinyinInputDataset(IterableDataset): # 发生异常时终止循环 break else: - # 当前字符不是中文,跳过 - continue + # 当前字符不是中文,退出 + break return result @@ -200,11 +200,11 @@ class PinyinInputDataset(IterableDataset): # 方式1: 靠近汉字的54个字符 return context[-54:] if context_len >= 54 else context elif choice < self.sample_context_section[1]: - # 方式2: 随机位置取46个连续字符 - if context_len <= 46: + # 方式2: 随机位置取42个连续字符 + if context_len <= 42: return context - start = random.randint(0, context_len - 46) - return context[start : start + 46] + start = random.randint(0, context_len - 42) + return context[start : start + 42] + context[-12:] else: # 方式3: 12+6×7组合 if context_len < 12: @@ -253,7 +253,7 @@ class PinyinInputDataset(IterableDataset): Returns: 截断后的拼音,可能为空字符串 """ - if not pinyin: + """if not pinyin: return "" # 随机决定截断方式 @@ -275,6 +275,8 @@ class PinyinInputDataset(IterableDataset): # 随机选择截断长度 (1 到 max_len-1) trunc_len = random.randint(1, max_len - 1) return pinyin[:trunc_len] + """ + return pinyin def process_pinyin_sequence(self, pinyin_list: List[str]) -> str: """ @@ -523,9 +525,9 @@ class PinyinInputDataset(IterableDataset): char = text[i] py = pinyin_list[i] - # 获取后续最多3个中文字符的拼音 + # 获取后续最多4个中文字符的拼音 next_chars = self.get_next_chinese_chars( - text, i, max_count=3, pinyin_list=pinyin_list + text, i, max_count=4, pinyin_list=pinyin_list ) next_pinyins = [py] + [p for _, p in next_chars] # 获取前文上下文(最多100字符) diff --git a/src/tmp_utils/gen_eval_dataset.py b/src/tmp_utils/gen_eval_dataset.py index 80486fb..834b370 100644 --- a/src/tmp_utils/gen_eval_dataset.py +++ b/src/tmp_utils/gen_eval_dataset.py @@ -16,14 +16,13 @@ if __name__ == "__main__": # 创建数据集 dataset = PinyinInputDataset( - data_dir="/home/songsenand/DataSet/data", + data_dir="/root/autodl-tmp/data", query_engine=query_engine, tokenizer_name="iic/nlp_structbert_backbone_lite_std", max_len=88, batch_query_size=300, shuffle=True, shuffle_buffer_size=4000, - drop_py_rate=0 ) logger.info("数据集初始化") dataloader = DataLoader( diff --git a/src/trainer/eval_dataset/sample_0.pkl b/src/trainer/eval_dataset/sample_0.pkl index d6fc7d5..379539a 100644 Binary files a/src/trainer/eval_dataset/sample_0.pkl and b/src/trainer/eval_dataset/sample_0.pkl differ diff --git a/src/trainer/eval_dataset/sample_1.pkl b/src/trainer/eval_dataset/sample_1.pkl index 09231d0..47b037e 100644 Binary files a/src/trainer/eval_dataset/sample_1.pkl and b/src/trainer/eval_dataset/sample_1.pkl differ diff --git a/src/trainer/eval_dataset/sample_2.pkl b/src/trainer/eval_dataset/sample_2.pkl index 666872a..bedd1bf 100644 Binary files a/src/trainer/eval_dataset/sample_2.pkl and b/src/trainer/eval_dataset/sample_2.pkl differ diff --git a/src/trainer/eval_dataset/sample_3.pkl b/src/trainer/eval_dataset/sample_3.pkl index 88a7bde..a9b1df3 100644 Binary files a/src/trainer/eval_dataset/sample_3.pkl and b/src/trainer/eval_dataset/sample_3.pkl differ diff --git a/src/trainer/eval_dataset/sample_4.pkl b/src/trainer/eval_dataset/sample_4.pkl index 0aec844..b999ccb 100644 Binary files a/src/trainer/eval_dataset/sample_4.pkl and b/src/trainer/eval_dataset/sample_4.pkl differ diff --git a/src/trainer/model.py b/src/trainer/model.py index 829ce68..933d45c 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -73,8 +73,8 @@ class Expert(nn.Module): self, input_dim, d_model=768, - num_resblocks=3, - output_multiplier=1, + num_resblocks=4, + output_multiplier=2, dropout_prob=0.3, ): super().__init__() @@ -157,7 +157,9 @@ class MoEModel(nn.Module): self.classifier = nn.Sequential( nn.LayerNorm(self.hidden_size * self.output_multiplier), nn.Dropout(0.4), - nn.Linear(self.hidden_size * self.output_multiplier, num_classes), + nn.Linear(self.hidden_size * self.output_multiplier, self.hidden_size * self.output_multiplier * 2), + nn.GELU(), + nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes), ) def to(self, device): @@ -368,7 +370,7 @@ class MoEModel(nn.Module): "token_type_ids": encoded["token_type_ids"], # 新增 }, "pg": torch.tensor( - [PG[py[0]]] + [PG[py[0]] if py != "" else 12] ), # 拼音组 ID 仍根据首字母生成(可根据实际需要改进) } return sample