diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index f1921b5..5bf37a4 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -13,30 +13,30 @@ from torch.utils.data import DataLoader, IterableDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" PG = { - "y": 0, - "k": 0, - "e": 0, - "l": 1, - "w": 1, - "f": 1, - "q": 2, - "a": 2, - "s": 2, - "x": 3, - "b": 3, - "r": 3, - "o": 4, - "m": 4, - "z": 4, - "g": 5, - "n": 5, - "c": 5, - "t": 6, - "p": 6, - "d": 6, - "j": 7, - "h": 7, - } + "r": 0, + "l": 0, + "p": 1, + "d": 1, + "h": 2, + "f": 2, + "g": 3, + "m": 3, + "z": 4, + "o": 4, + "t": 5, + "q": 5, + "b": 6, + "w": 6, + "j": 7, + "e": 7, + "k": 8, + "c": 8, + "s": 9, + "a": 9, + "n": 10, + "x": 10, + "y": 11, +} class PinyinInputDataset(IterableDataset): @@ -70,7 +70,7 @@ class PinyinInputDataset(IterableDataset): repeat_end_freq: int = 10000, # 开始重复的阈值 max_drop_prob: float = 0.8, # 最大丢弃概率 max_repeat_expect: float = 50.0, # 最大重复期望 - sample_context_section = [0.90, 0.95, 1] + sample_context_section=[0.90, 0.95, 1], ): """ 初始化数据集 @@ -127,8 +127,6 @@ class PinyinInputDataset(IterableDataset): # 上下文采样方式概率区间 self.sample_context_section = sample_context_section - - def get_next_chinese_chars( self, text: str, @@ -441,7 +439,7 @@ class PinyinInputDataset(IterableDataset): ) prob = random.random() - pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8 + pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 12 if prob < 0.1: py = "" else: @@ -455,9 +453,7 @@ class PinyinInputDataset(IterableDataset): "char_id": torch.tensor([char_info["id"]]), "char": char, "freq": char_info["freq"], - "pg": torch.tensor( - [pg] - ), + "pg": torch.tensor([pg]), } # 根据调整因子重复样本 diff --git a/src/tmp_utils/group.py b/src/tmp_utils/group.py index 7549ec2..0f569ea 100644 --- a/src/tmp_utils/group.py +++ b/src/tmp_utils/group.py @@ -149,7 +149,7 @@ def simulated_annealing_grouping(letters, num_groups=8, iterations=10000): # 运行模拟退火算法 
-best_groups, sums_c1, sums_c2, best_energy = simulated_annealing_grouping(list(c1.keys())) +best_groups, sums_c1, sums_c2, best_energy = simulated_annealing_grouping(list(c1.keys()), 12) print("模拟退火算法分组结果:") for i, (group, sum1, sum2) in enumerate(zip(best_groups, sums_c1, sums_c2)): @@ -163,5 +163,4 @@ print(f"c1各组总和变异系数: {np.std(sums_c1) / np.mean(sums_c1):.4f}") print(f"c2各组总和变异系数: {np.std(sums_c2) / np.mean(sums_c2):.4f}") -with open("pinyin_group.json", "w") as f: - json.dump({letter: i for i, sub_item in enumerate(best_groups) for letter in sub_item}, f, indent=4) +print({letter: i for i, sub_item in enumerate(best_groups) for letter in sub_item}) diff --git a/src/trainer/eval_dataset/sample_0.pkl b/src/trainer/eval_dataset/sample_0.pkl index 036ff3d..4069112 100644 Binary files a/src/trainer/eval_dataset/sample_0.pkl and b/src/trainer/eval_dataset/sample_0.pkl differ diff --git a/src/trainer/eval_dataset/sample_1.pkl b/src/trainer/eval_dataset/sample_1.pkl index cf2c7e3..469cd31 100644 Binary files a/src/trainer/eval_dataset/sample_1.pkl and b/src/trainer/eval_dataset/sample_1.pkl differ diff --git a/src/trainer/eval_dataset/sample_2.pkl b/src/trainer/eval_dataset/sample_2.pkl index a9da721..640bac7 100644 Binary files a/src/trainer/eval_dataset/sample_2.pkl and b/src/trainer/eval_dataset/sample_2.pkl differ diff --git a/src/trainer/eval_dataset/sample_3.pkl b/src/trainer/eval_dataset/sample_3.pkl index 9964f8e..3cc5b31 100644 Binary files a/src/trainer/eval_dataset/sample_3.pkl and b/src/trainer/eval_dataset/sample_3.pkl differ diff --git a/src/trainer/eval_dataset/sample_4.pkl b/src/trainer/eval_dataset/sample_4.pkl index bda53a5..785ab66 100644 Binary files a/src/trainer/eval_dataset/sample_4.pkl and b/src/trainer/eval_dataset/sample_4.pkl differ diff --git a/src/trainer/model.py b/src/trainer/model.py index 9d0371f..47d7a7b 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -96,7 +96,7 @@ class MoEModel(nn.Module): output_multiplier=2, 
d_model=768, num_resblocks=4, - num_domain_experts=8, + num_domain_experts=12, num_shared_experts=1, ): super().__init__() @@ -238,10 +238,26 @@ class MoEModel(nn.Module): expert_out = self.experts[7](pooled) + self.expert_bias( torch.tensor(7, device=pooled.device) ) - else: # group_id == 8 + elif group_id == 8: expert_out = self.experts[8](pooled) + self.expert_bias( torch.tensor(8, device=pooled.device) ) + elif group_id == 9: + expert_out = self.experts[9](pooled) + self.expert_bias( + torch.tensor(9, device=pooled.device) + ) + elif group_id == 10: + expert_out = self.experts[10](pooled) + self.expert_bias( + torch.tensor(10, device=pooled.device) + ) + elif group_id == 11: + expert_out = self.experts[11](pooled) + self.expert_bias( + torch.tensor(11, device=pooled.device) + ) + else: # group_id == 12 + expert_out = self.experts[12](pooled) + self.expert_bias( + torch.tensor(12, device=pooled.device) + ) else: batch_size = pooled.size(0) # 并行计算所有专家输出 @@ -498,7 +514,7 @@ class MoEModel(nn.Module): f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}" ) batch_loss_sum = 0.0 - if processed_batches + 1 >= stop_batch: + if processed_batches - 1 >= stop_batch: break global_step += 1