diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py
index b0eb3ec..91620e5 100644
--- a/src/suinput/dataset.py
+++ b/src/suinput/dataset.py
@@ -97,28 +97,28 @@ class PinyinInputDataset(IterableDataset):
         # Load the pinyin-initial groupings
         self.pg_groups = {
             "y": 0,
-            "z": 1,
-            "j": 2,
-            "l": 3,
-            "s": 4,
-            "x": 5,
-            "c": 6,
+            "k": 0,
+            "e": 0,
+            "l": 1,
+            "w": 1,
+            "f": 1,
+            "q": 2,
+            "a": 2,
+            "s": 2,
+            "x": 3,
+            "b": 3,
+            "r": 3,
+            "o": 4,
+            "m": 4,
+            "z": 4,
+            "g": 5,
+            "n": 5,
+            "c": 5,
+            "t": 6,
+            "p": 6,
+            "d": 6,
+            "j": 7,
             "h": 7,
-            "d": 8,
-            "b": 9,
-            "q": 10,
-            "g": 11,
-            "t": 12,
-            "m": 13,
-            "p": 14,
-            "w": 15,
-            "f": 16,
-            "k": 17,
-            "n": 18,
-            "r": 19,
-            "a": 19,
-            "e": 18,
-            "o": 17,
         }
 
     def get_next_chinese_chars(
@@ -184,15 +184,15 @@ class PinyinInputDataset(IterableDataset):
 
         # Randomly pick a sampling method (1/3 probability each)
         choice = random.random()
-        if choice < 0.85:
+        if choice < 0.3333:
             # Method 1: the 54 characters nearest the Chinese text
             return context[-54:] if context_len >= 54 else context
-        elif choice < 0.95:
+        elif choice < 0.6667:
             # Method 2: 46 consecutive characters from a random position
             if context_len <= 46:
                 return context
             start = random.randint(0, context_len - 46)
-            return context[start : start + 46] + context[-8:]
+            return context[start : start + 46]
         else:
             # Method 3: the 12 + 6×7 combination
             if context_len < 12:
diff --git a/src/tmp_utils/gen_eval_dataset.py b/src/tmp_utils/gen_eval_dataset.py
index 912399a..b118ec2 100644
--- a/src/tmp_utils/gen_eval_dataset.py
+++ b/src/tmp_utils/gen_eval_dataset.py
@@ -1,7 +1,6 @@
 import pickle
 from pathlib import Path
 
-import torch
 from loguru import logger
 from torch.utils.data import DataLoader
 from tqdm import tqdm
diff --git a/src/trainer/eval_dataset/sample_0.pkl b/src/trainer/eval_dataset/sample_0.pkl
index d766b4d..19be87a 100644
Binary files a/src/trainer/eval_dataset/sample_0.pkl and b/src/trainer/eval_dataset/sample_0.pkl differ
diff --git a/src/trainer/eval_dataset/sample_1.pkl b/src/trainer/eval_dataset/sample_1.pkl
index d417f58..7013b23 100644
Binary files a/src/trainer/eval_dataset/sample_1.pkl and b/src/trainer/eval_dataset/sample_1.pkl differ
diff --git a/src/trainer/eval_dataset/sample_2.pkl b/src/trainer/eval_dataset/sample_2.pkl
index eab9fe9..8b69754 100644
Binary files a/src/trainer/eval_dataset/sample_2.pkl and b/src/trainer/eval_dataset/sample_2.pkl differ
diff --git a/src/trainer/eval_dataset/sample_3.pkl b/src/trainer/eval_dataset/sample_3.pkl
index 21b6e65..d6bd3df 100644
Binary files a/src/trainer/eval_dataset/sample_3.pkl and b/src/trainer/eval_dataset/sample_3.pkl differ
diff --git a/src/trainer/eval_dataset/sample_4.pkl b/src/trainer/eval_dataset/sample_4.pkl
index 592e7c6..69e39fd 100644
Binary files a/src/trainer/eval_dataset/sample_4.pkl and b/src/trainer/eval_dataset/sample_4.pkl differ
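The dataset change above collapses the 20 pinyin-initial groups (ids 0-19) into 8 (ids 0-7), matching the reduction from 20 domain experts to 8 in model.py below, and it appears to leave group id 8 free for the new shared expert. The new thresholds (0.3333 / 0.6667) also make the three context-sampling methods actually fire with the 1/3 probability each that the comment promises, and method 2 no longer appends the trailing context[-8:]. A minimal sketch of how the new mapping might route a syllable to a group, assuming lookup by first letter; the group_of helper is illustrative, not part of the diff:

# Hypothetical helper mirroring the pg_groups dict added in
# src/suinput/dataset.py above (all valid pinyin initials are covered,
# including "zh"/"ch"/"sh", which fall under "z"/"c"/"s").
PG_GROUPS = {
    "y": 0, "k": 0, "e": 0,
    "l": 1, "w": 1, "f": 1,
    "q": 2, "a": 2, "s": 2,
    "x": 3, "b": 3, "r": 3,
    "o": 4, "m": 4, "z": 4,
    "g": 5, "n": 5, "c": 5,
    "t": 6, "p": 6, "d": 6,
    "j": 7, "h": 7,
}

def group_of(pinyin: str) -> int:
    """Route a pinyin syllable to an expert group by its first letter."""
    return PG_GROUPS[pinyin[0]]

assert group_of("zhong") == 4  # "z" -> group 4
assert group_of("hao") == 7    # "h" -> group 7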
diff --git a/src/trainer/model.py b/src/trainer/model.py
index 362e6cc..bc86280 100644
--- a/src/trainer/model.py
+++ b/src/trainer/model.py
@@ -77,8 +77,9 @@ class ResidualBlock(nn.Module):
         x = x + residual
         return self.relu(x)
 
 
+# ---------------------------- Expert network ----------------------------
+
-# ---------------------------- Expert network ----------------------------
 class Expert(nn.Module):
     def __init__(
         self,
@@ -90,7 +91,7 @@ class Expert(nn.Module):
     ):
         """
        input_dim : input dimension
-        d_model : expert-internal width
+        d_model : expert-internal width (fixed at 1024)
        output_multiplier : output dim = input_dim * output_multiplier
        dropout_prob : Dropout inside the residual blocks
        """
@@ -120,8 +121,9 @@ class Expert(nn.Module):
             x = block(x)
         return self.output(x)
 
 
+# ---------------------- Main model (MoE + hard routing) ------------------------
+
-# ---------------------- Main model (MoE + hard routing) ------------------------
 class MoEModel(nn.Module):
     def __init__(
         self,
@@ -130,8 +132,8 @@ class MoEModel(nn.Module):
         output_multiplier=2,
         d_model=768,
         num_resblocks=4,
-        num_domain_experts=20,
-        experts_dim=EXPORT_HIDE_DIM,
+        num_domain_experts=8,
+        num_shared_experts=1,
     ):
         super().__init__()
         self.output_multiplier = output_multiplier
@@ -142,54 +144,60 @@ class MoEModel(nn.Module):
         self.bert_config = bert.config
         self.hidden_size = self.bert_config.hidden_size  # BERT hidden size
         self.device = None  # set when to() is called
-        self.experts_dim = experts_dim
 
         # 2. Four standard Transformer encoder layers (parameters read from the config)
         encoder_layer = nn.TransformerEncoderLayer(
             d_model=self.hidden_size,
-            nhead=8,
+            nhead=self.bert_config.num_attention_heads,
             dim_feedforward=self.bert_config.intermediate_size,
             dropout=self.bert_config.hidden_dropout_prob,
             activation="gelu",
             batch_first=True,
+            norm_first=True,  # Pre-LN, matching the pretrained model
         )
         self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
 
-        self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)])
+        self.pooler = nn.AdaptiveAvgPool1d(1)
 
-        self.pooler = nn.AdaptiveAvgPool1d(2)
-
-        self.total_experts = 20
+        # 3. Expert layer: 8 domain experts + 1 shared expert
+        total_experts = num_domain_experts + num_shared_experts
         self.experts = nn.ModuleList()
-        for i in range(self.total_experts):
+        for i in range(total_experts):
+            # Domain experts use dropout=0.1; the shared expert uses 0.2 (stronger regularization)
+            dropout_prob = 0.1 if i < num_domain_experts else 0.2
             expert = Expert(
-                input_dim=self.hidden_size * 2,
-                d_model=self.experts_dim[i],
+                input_dim=self.hidden_size,
+                d_model=d_model,
                 num_resblocks=num_resblocks,
                 output_multiplier=self.output_multiplier,  # output dim = 2 * hidden_size
-                dropout_prob=0.1,
+                dropout_prob=dropout_prob,
             )
             self.experts.append(expert)
 
         self.expert_bias = nn.Embedding(
-            self.total_experts, self.output_multiplier * self.hidden_size * 2
+            total_experts, self.output_multiplier * self.hidden_size
         )
 
         # 4. Classification head
         self.classifier = nn.Sequential(
-            nn.LayerNorm(self.output_multiplier * self.hidden_size * 2),
+            nn.LayerNorm(self.output_multiplier * self.hidden_size),
             nn.Linear(
-                self.output_multiplier * self.hidden_size * 2,
-                self.output_multiplier * self.hidden_size * 2,
+                self.output_multiplier * self.hidden_size,
+                self.output_multiplier * self.hidden_size,
             ),
             nn.ReLU(inplace=True),
             nn.Linear(
-                self.output_multiplier * self.hidden_size * 2,
-                self.output_multiplier * self.hidden_size * 4,
+                self.output_multiplier * self.hidden_size,
+                self.output_multiplier * self.hidden_size,
             ),
             nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-            nn.Linear(self.output_multiplier * self.hidden_size * 4, num_classes),
+            nn.Linear(
+                self.output_multiplier * self.hidden_size,
+                self.output_multiplier * self.hidden_size * 2,
+            ),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.2),
+            nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes),
         )
 
         # Optional: different weight decay for domain vs. shared experts (done in the optimizer, not here)
@@ -214,12 +222,8 @@ class MoEModel(nn.Module):
             embeddings, src_key_padding_mask=padding_mask
         )  # [B, S, H]
 
-        for block in self.res_blocks:
-            pooled = block(encoded)
-
         # ----- 3. Pooling -----
-        pooled = self.pooler(encoded.transpose(1, 2))  # [B, H, 2]
-        pooled = pooled.flatten(1)  # [B, H*2]
+        pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
 
         # ----- 4. Expert routing (hard routing) -----
         if torch.jit.is_tracing():
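The pooling rewrite in the hunk above drops the stray residual-block loop and the two-slot pool: with AdaptiveAvgPool1d(1) plus squeeze, pooled is simply the mean of the encoder output over the sequence axis, which is why expert input_dim shrinks from hidden_size * 2 to hidden_size. A self-contained sanity check of that equivalence, with shapes as in the model; worth noting that padded positions still enter the average, since src_key_padding_mask only masks attention:

import torch
import torch.nn as nn

B, S, H = 2, 16, 768
encoded = torch.randn(B, S, H)  # encoder output [B, S, H]

pooler = nn.AdaptiveAvgPool1d(1)
pooled = pooler(encoded.transpose(1, 2)).squeeze(-1)  # [B, H]

# AdaptiveAvgPool1d(1) over the sequence axis is a plain mean
assert torch.allclose(pooled, encoded.mean(dim=1), atol=1e-6)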
@@ -259,54 +263,10 @@
                 expert_out = self.experts[7](pooled) + self.expert_bias(
                     torch.tensor(7, device=pooled.device)
                 )
-            elif group_id == 8:  # group_id == 8
+            else:  # group_id == 8
                 expert_out = self.experts[8](pooled) + self.expert_bias(
                     torch.tensor(8, device=pooled.device)
                 )
-            elif group_id == 9:  # group_id == 9
-                expert_out = self.experts[9](pooled) + self.expert_bias(
-                    torch.tensor(9, device=pooled.device)
-                )
-            elif group_id == 10:  # group_id == 10
-                expert_out = self.experts[10](pooled) + self.expert_bias(
-                    torch.tensor(10, device=pooled.device)
-                )
-            elif group_id == 11:  # group_id == 11
-                expert_out = self.experts[11](pooled) + self.expert_bias(
-                    torch.tensor(11, device=pooled.device)
-                )
-            elif group_id == 12:  # group_id == 12
-                expert_out = self.experts[12](pooled) + self.expert_bias(
-                    torch.tensor(12, device=pooled.device)
-                )
-            elif group_id == 13:  # group_id == 13
-                expert_out = self.experts[13](pooled) + self.expert_bias(
-                    torch.tensor(13, device=pooled.device)
-                )
-            elif group_id == 14:  # group_id == 14
-                expert_out = self.experts[14](pooled) + self.expert_bias(
-                    torch.tensor(14, device=pooled.device)
-                )
-            elif group_id == 15:  # group_id == 15
-                expert_out = self.experts[15](pooled) + self.expert_bias(
-                    torch.tensor(15, device=pooled.device)
-                )
-            elif group_id == 16:  # group_id == 16
-                expert_out = self.experts[16](pooled) + self.expert_bias(
-                    torch.tensor(16, device=pooled.device)
-                )
-            elif group_id == 17:  # group_id == 17
-                expert_out = self.experts[17](pooled) + self.expert_bias(
-                    torch.tensor(17, device=pooled.device)
-                )
-            elif group_id == 18:  # group_id == 18
-                expert_out = self.experts[18](pooled) + self.expert_bias(
-                    torch.tensor(18, device=pooled.device)
-                )
-            else:  # group_id == 19
-                expert_out = self.experts[19](pooled) + self.expert_bias(
-                    torch.tensor(19, device=pooled.device)
-                )
         else:
             batch_size = pooled.size(0)
             # Compute all expert outputs in parallel
@@ -394,8 +354,8 @@ class MoEModel(nn.Module):
 
         # ------------------ 1. Extract and normalize the inputs ------------------
         # Detect a single sample (input_ids has no batch dimension)
-        input_ids = sample['hint']["input_ids"]
-        attention_mask = sample['hint']["attention_mask"]
+        input_ids = sample["hint"]["input_ids"]
+        attention_mask = sample["hint"]["attention_mask"]
         pg = sample["pg"]
 
         has_batch_dim = input_ids.dim() > 1
@@ -521,7 +481,7 @@
         # Default optimizer
         if optimizer is None:
             optimizer = optim.AdamW(self.parameters(), lr=lr)
-
+
         if criterion is None:
             if loss_weight is not None:
                 criterion = nn.CrossEntropyLoss(weight=loss_weight)
@@ -625,39 +585,3 @@ class MoEModel(nn.Module):
         for name, param in self.named_parameters():
             if name in freeze_layers:
                 param.requires_grad = False
-
-
-# ============================ Usage example ============================
-if __name__ == "__main__":
-    # 1. Initialize the model
-    model = MoEModel()
-    model.eval()
-
-    # 2. Build dummy inputs (batch=1, for ONNX export)
-    dummy_input_ids = torch.randint(0, 100, (1, 64))  # [1, 64]
-    dummy_attention_mask = torch.ones_like(dummy_input_ids)  # [1, 64]
-    dummy_pg = torch.tensor(3, dtype=torch.long)  # scalar group_id
-
-    # 3. Export to ONNX (the conditional branch computes only one expert)
-    torch.onnx.export(
-        model,
-        (dummy_input_ids, dummy_attention_mask, dummy_pg),
-        "moe_cpu.onnx",
-        input_names=["input_ids", "attention_mask", "pg"],
-        output_names=["logits"],
-        dynamic_axes={  # batch is fixed at 1, so dynamic_axes could be omitted
-            "input_ids": {0: "batch"},
-            "attention_mask": {0: "batch"},
-        },
-        opset_version=12,
-        do_constant_folding=True,
-    )
-    print("ONNX export succeeded!")
-
-    # 4. Test training mode (batch=4)
-    model.train()
-    batch_input_ids = torch.randint(0, 100, (4, 64))
-    batch_attention_mask = torch.ones_like(batch_input_ids)
-    batch_pg = torch.tensor([0, 3, 8, 1], dtype=torch.long)  # different groups
-    logits = model(batch_input_ids, batch_attention_mask, batch_pg)
-    print("Training-mode output shape:", logits.shape)  # [4, 10018]
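With the expert pool cut from 20 to 9 (8 domain + 1 shared), the traced path only needs branches for group_id 0 through 8. As the comment in the removed usage example notes, the explicit if/elif chain exists so that ONNX tracing bakes exactly one expert subgraph into the export; in eager mode it is just indexing. A self-contained sketch of what one branch computes, using nn.Linear stand-ins for the Expert blocks (the names below are illustrative, not the module's actual API):

import torch
import torch.nn as nn

# Stand-in experts: 8 domain + 1 shared, mirroring the new configuration
hidden, mult = 768, 2
experts = nn.ModuleList([nn.Linear(hidden, mult * hidden) for _ in range(9)])
expert_bias = nn.Embedding(9, mult * hidden)

def route_single(pooled: torch.Tensor, group_id: int) -> torch.Tensor:
    """Eager-mode equivalent of one hard-routing branch: a single expert's
    output plus that expert's learned bias row (group 8 = shared expert)."""
    idx = torch.tensor(group_id, device=pooled.device)
    return experts[group_id](pooled) + expert_bias(idx)

pooled = torch.randn(1, hidden)
out = route_single(pooled, 8)  # routes to the shared expert
print(out.shape)               # torch.Size([1, 1536])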