diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index b7223d3..3699cfa 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -424,8 +424,7 @@ class PinyinInputDataset(IterableDataset): # Tokenize hint = self.tokenizer( - sampled_context, - processed_pinyin, + sampled_context + processed_pinyin, max_length=self.max_len, padding="max_length", truncation=True, diff --git a/src/tmp_utils/gen_eval_dataset.py b/src/tmp_utils/gen_eval_dataset.py index d13b4c3..4e8ef56 100644 --- a/src/tmp_utils/gen_eval_dataset.py +++ b/src/tmp_utils/gen_eval_dataset.py @@ -28,7 +28,7 @@ if __name__ == "__main__": logger.info("数据集初始化") dataloader = DataLoader( dataset, - batch_size=2, + batch_size=1024, num_workers=1, worker_init_fn=worker_init_fn, pin_memory=True if torch.cuda.is_available() else False, diff --git a/src/trainer/eval_dataset/sample_0.pkl b/src/trainer/eval_dataset/sample_0.pkl index 67552e4..3ebc5cf 100644 Binary files a/src/trainer/eval_dataset/sample_0.pkl and b/src/trainer/eval_dataset/sample_0.pkl differ diff --git a/src/trainer/eval_dataset/sample_1.pkl b/src/trainer/eval_dataset/sample_1.pkl index a6e075c..b8149f8 100644 Binary files a/src/trainer/eval_dataset/sample_1.pkl and b/src/trainer/eval_dataset/sample_1.pkl differ diff --git a/src/trainer/eval_dataset/sample_2.pkl b/src/trainer/eval_dataset/sample_2.pkl index c37b245..1ccb18a 100644 Binary files a/src/trainer/eval_dataset/sample_2.pkl and b/src/trainer/eval_dataset/sample_2.pkl differ diff --git a/src/trainer/eval_dataset/sample_3.pkl b/src/trainer/eval_dataset/sample_3.pkl index 7527241..f0a2796 100644 Binary files a/src/trainer/eval_dataset/sample_3.pkl and b/src/trainer/eval_dataset/sample_3.pkl differ diff --git a/src/trainer/eval_dataset/sample_4.pkl b/src/trainer/eval_dataset/sample_4.pkl index 444debc..05f6211 100644 Binary files a/src/trainer/eval_dataset/sample_4.pkl and b/src/trainer/eval_dataset/sample_4.pkl differ diff --git a/src/trainer/model.py b/src/trainer/model.py index 29a5f42..cbe8fe8 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -348,13 +348,13 @@ class MoEModel(nn.Module): labels = batch["char_id"].to(self.device) # 前向传播 - logits = self(input_ids, attention_mask, pg) + probs = self(input_ids, attention_mask, pg) log_probs = torch.log(probs + 1e-12) loss = nn.NLLLoss()(log_probs, labels) total_loss += loss.item() * labels.size(0) # 计算准确率 - preds = logits.argmax(dim=-1) + preds = probs.argmax(dim=-1) correct += (preds == labels).sum().item() total += labels.size(0) @@ -388,8 +388,8 @@ class MoEModel(nn.Module): # ------------------ 1. 提取并规范化输入 ------------------ # 判断是否为单样本(input_ids 无 batch 维度) - input_ids = sample["input_ids"] - attention_mask = sample["attention_mask"] + input_ids = sample['hint']["input_ids"] + attention_mask = sample['hint']["attention_mask"] pg = sample["pg"] has_batch_dim = input_ids.dim() > 1 @@ -577,6 +577,37 @@ class MoEModel(nn.Module): ) batch_loss_sum = 0.0 + def load_from_state_dict(self, state_dict_path: Union[str, Path]): + state_dict = torch.load( + state_dict_path, weights_only=True, map_location=self.device + ) + self.load_state_dict(state_dict) + + def load_from_pretrained_base_model( + self, + BaseModel, + snapshot_path: Union[str, Path], + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + *args, + **kwargs, + ): + base_model = BaseModel(*args, **kwargs) + base_model.load_state_dict(torch.load(snapshot_path, map_location=device)) + self_static_dict = self.state_dict() + pretrained_dict = base_model.state_dict() + + freeze_layers = [] + + for key in self_static_dict.keys(): + if key in pretrained_dict.keys(): + if self_static_dict[key].shape == pretrained_dict[key].shape: + self_static_dict[key] = pretrained_dict[key].to(self.device) + freeze_layers.append(key) + self.load_state_dict(self_static_dict) + for name, param in self.named_parameters(): + if name in freeze_layers: + param.requires_grad = False + # ============================ 使用示例 ============================ if __name__ == "__main__": diff --git a/src/trainer/new_model.py b/src/trainer/model_with_neck.py similarity index 88% rename from src/trainer/new_model.py rename to src/trainer/model_with_neck.py index 61ef140..f08819f 100644 --- a/src/trainer/new_model.py +++ b/src/trainer/model_with_neck.py @@ -19,6 +19,43 @@ def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")] +def round_to_power_of_two(x): + if x < 1: + return 0 + n = x.bit_length() + n = min(max(7, n), 9) + lower = 1 << (n) # 小于等于x的最大2的幂次 + upper = lower << 1 # 大于x的最小2的幂次 + if x - lower < upper - x: + return lower + else: + return upper + + +EXPORT_HIDE_DIM = { + 0: 1024, + 1: 1024, + 2: 1024, + 3: 512, + 4: 512, + 5: 512, + 6: 512, + 7: 512, + 8: 512, + 9: 512, + 10: 512, + 11: 512, + 12: 512, + 13: 512, + 14: 512, + 15: 512, + 16: 512, + 17: 512, + 18: 512, + 19: 256, +} + + # ---------------------------- 残差块 ---------------------------- class ResidualBlock(nn.Module): def __init__(self, dim, dropout_prob=0.1): @@ -93,7 +130,8 @@ class MoEModel(nn.Module): output_multiplier=2, d_model=768, num_resblocks=4, - num_domain_experts=23, + num_domain_experts=20, + experts_dim=EXPORT_HIDE_DIM, ): super().__init__() self.output_multiplier = output_multiplier @@ -104,30 +142,27 @@ class MoEModel(nn.Module): self.bert_config = bert.config self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度 self.device = None # 将在 to() 调用时设置 - - self.linear = nn.Linear(256, d_model) + self.experts_dim = experts_dim # 2. 4 层标准 Transformer Encoder(从 config 读取参数) encoder_layer = nn.TransformerEncoderLayer( d_model=self.hidden_size, - nhead=self.bert_config.num_attention_heads, + nhead=8, dim_feedforward=self.bert_config.intermediate_size, dropout=self.bert_config.hidden_dropout_prob, activation="gelu", batch_first=True, - norm_first=True, # Pre-LN,与预训练一致 ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) self.pooler = nn.AdaptiveAvgPool1d(1) - self.total_experts = 23 + self.total_experts = 20 self.experts = nn.ModuleList() for i in range(self.total_experts): - # 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则) expert = Expert( - input_dim=self.hidden_size * 2, - d_model=d_model, + input_dim=self.hidden_size, + d_model=self.experts_dim[i], num_resblocks=num_resblocks, output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size dropout_prob=0.1, @@ -146,11 +181,6 @@ class MoEModel(nn.Module): self.output_multiplier * self.hidden_size, ), nn.ReLU(inplace=True), - nn.Linear( - self.output_multiplier * self.hidden_size, - self.output_multiplier * self.hidden_size, - ), - nn.ReLU(inplace=True), nn.Linear( self.output_multiplier * self.hidden_size, self.output_multiplier * self.hidden_size * 2, @@ -267,22 +297,10 @@ class MoEModel(nn.Module): expert_out = self.experts[18](pooled) + self.expert_bias( torch.tensor(18, device=pooled.device) ) - elif group_id == 19: # group_id == 19 + else: # group_id == 19 expert_out = self.experts[19](pooled) + self.expert_bias( torch.tensor(19, device=pooled.device) ) - elif group_id == 20: # group_id == 20 - expert_out = self.experts[20](pooled) + self.expert_bias( - torch.tensor(20, device=pooled.device) - ) - elif group_id == 21: # group_id == 21 - expert_out = self.experts[21](pooled) + self.expert_bias( - torch.tensor(21, device=pooled.device) - ) - elif group_id == 22: # group_id == 22 - expert_out = self.experts[22](pooled) + self.expert_bias( - torch.tensor(22, device=pooled.device) - ) else: batch_size = pooled.size(0) # 并行计算所有专家输出 @@ -330,12 +348,13 @@ class MoEModel(nn.Module): labels = batch["char_id"].to(self.device) # 前向传播 - logits = self(input_ids, attention_mask, pg) - loss = criterion(logits, labels) + probs = self(input_ids, attention_mask, pg) + log_probs = torch.log(probs + 1e-12) + loss = nn.NLLLoss()(log_probs, labels) total_loss += loss.item() * labels.size(0) # 计算准确率 - preds = logits.argmax(dim=-1) + preds = probs.argmax(dim=-1) correct += (preds == labels).sum().item() total += labels.size(0) @@ -558,38 +577,33 @@ class MoEModel(nn.Module): ) batch_loss_sum = 0.0 + def load_from_state_dict(self, state_dict_path: Union[str, Path]): + state_dict = torch.load( + state_dict_path, weights_only=True, map_location=self.device + ) + self.model.load_state_dict(state_dict) -# ============================ 使用示例 ============================ -if __name__ == "__main__": - # 1. 初始化模型 - model = MoEModel() - model.eval() + def load_from_pretrained_base_model( + self, + BaseModel, + snapshot_path: Union[str, Path], + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + *args, + **kwargs, + ): + base_model = BaseModel(*args, **kwargs) + base_model.load_state_dict(torch.load(snapshot_path, map_location=device)) + self_static_dict = self.state_dict() + pretrained_dict = base_model.state_dict() - # 2. 构造 dummy 输入(batch=1,用于导出 ONNX) - dummy_input_ids = torch.randint(0, 100, (1, 64)) # [1, 64] - dummy_attention_mask = torch.ones_like(dummy_input_ids) # [1, 64] - dummy_pg = torch.tensor(3, dtype=torch.long) # 标量 group_id + freeze_layers = [] - # 3. 导出 ONNX(使用条件分支,仅计算一个专家) - torch.onnx.export( - model, - (dummy_input_ids, dummy_attention_mask, dummy_pg), - "moe_cpu.onnx", - input_names=["input_ids", "attention_mask", "pg"], - output_names=["logits"], - dynamic_axes={ # 固定 batch=1,可不设 dynamic_axes - "input_ids": {0: "batch"}, - "attention_mask": {0: "batch"}, - }, - opset_version=12, - do_constant_folding=True, - ) - print("ONNX 导出成功!") - - # 4. 测试训练模式(batch=4) - model.train() - batch_input_ids = torch.randint(0, 100, (4, 64)) - batch_attention_mask = torch.ones_like(batch_input_ids) - batch_pg = torch.tensor([0, 3, 8, 1], dtype=torch.long) # 不同 group - logits = model(batch_input_ids, batch_attention_mask, batch_pg) - print("训练模式输出形状:", logits.shape) # [4, 10018] + for key in self_static_dict.keys(): + if key in pretrained_dict.keys(): + if self_static_dict[key].shape == pretrained_dict[key].shape: + self_static_dict[key] = pretrained_dict[key].to(self.device) + freeze_layers.append(key) + self.load_state_dict(self_static_dict) + for name, param in self.named_parameters(): + if name in freeze_layers: + param.requires_grad = False