From 134c8a09cfbd80e0624d0ea2a3d2a1d0db141e61 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sat, 14 Feb 2026 15:24:07 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E6=8B=BC=E9=9F=B3?= =?UTF-8?q?=E8=BE=93=E5=85=A5=E6=95=B0=E6=8D=AE=E9=9B=86=E4=B8=8E=20MoE=20?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E7=BB=93=E6=9E=84=EF=BC=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E4=B8=93=E5=AE=B6=E7=BD=91=E7=BB=9C=E9=85=8D=E7=BD=AE=E5=8F=8A?= =?UTF-8?q?=E8=AF=84=E4=BC=B0=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/suinput/dataset.py | 50 ++-- src/trainer/model.py | 151 ++++++++-- src/trainer/new_model.py | 595 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 741 insertions(+), 55 deletions(-) create mode 100644 src/trainer/new_model.py diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index 0b86f4e..5833121 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -1,6 +1,6 @@ import os import random -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch @@ -97,28 +97,28 @@ class PinyinInputDataset(IterableDataset): # 加载拼音分组 self.pg_groups = { "y": 0, - "k": 0, - "e": 0, - "l": 1, - "w": 1, - "f": 1, - "q": 2, - "a": 2, - "s": 2, - "x": 3, - "b": 3, - "r": 3, - "o": 4, - "m": 4, - "z": 4, - "g": 5, - "n": 5, - "c": 5, - "t": 6, - "p": 6, - "d": 6, - "j": 7, + "z": 1, + "j": 2, + "l": 3, + "s": 4, + "x": 5, + "c": 6, "h": 7, + "d": 8, + "b": 9, + "q": 10, + "g": 11, + "t": 12, + "m": 13, + "p": 14, + "w": 15, + "f": 16, + "k": 17, + "n": 18, + "r": 19, + "a": 19, + "e": 18, + "o": 17, } def get_next_chinese_chars( @@ -440,9 +440,7 @@ class PinyinInputDataset(IterableDataset): "char_id": torch.tensor([char_info["id"]]), "char": char, "freq": char_info["freq"], - "pg": torch.tensor( - [self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8] - ), + "pg": 
torch.tensor([self.pg_groups[char_info.pinyin[0]]]), } # 根据调整因子重复样本 @@ -480,7 +478,7 @@ class PinyinInputDataset(IterableDataset): seed = base_seed + worker_id random.seed(seed % (2**32)) np.random.seed(seed % (2**32)) - + batch_samples = [] for item in self.dataset: text = item.get(self.text_field, "") diff --git a/src/trainer/model.py b/src/trainer/model.py index dbc4715..496deb2 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -1,5 +1,7 @@ import pickle from importlib.resources import files +from pathlib import Path +from typing import Optional, Union import torch import torch.amp as amp @@ -12,10 +14,46 @@ from tqdm import tqdm from .monitor import TrainingMonitor -EVAL_DATALOADER = [ - pickle.load(file.open("rb")) - for file in (files(__package__) / "eval_dataset").glob("*.pkl") -] + +def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")): + return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")] + + +def round_to_power_of_two(x): + if x < 1: + return 0 + n = x.bit_length() + n = min(max(7, n), 9) + lower = 1 << (n) # 小于等于x的最大2的幂次 + upper = lower << 1 # 大于x的最小2的幂次 + if x - lower < upper - x: + return lower + else: + return upper + + +EXPORT_HIDE_DIM = { + 0: 1024, + 1: 1024, + 2: 1024, + 3: 512, + 4: 512, + 5: 512, + 6: 512, + 7: 512, + 8: 512, + 9: 512, + 10: 512, + 11: 512, + 12: 512, + 13: 512, + 14: 512, + 15: 512, + 16: 512, + 17: 512, + 18: 512, + 19: 256, +} # ---------------------------- 残差块 ---------------------------- @@ -92,8 +130,8 @@ class MoEModel(nn.Module): output_multiplier=2, d_model=768, num_resblocks=4, - num_domain_experts=8, - num_shared_experts=1, + num_domain_experts=20, + experts_dim=EXPORT_HIDE_DIM, ): super().__init__() self.output_multiplier = output_multiplier @@ -104,38 +142,35 @@ class MoEModel(nn.Module): self.bert_config = bert.config self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度 self.device = None # 将在 to() 调用时设置 + self.experts_dim = experts_dim # 2. 
4 层标准 Transformer Encoder(从 config 读取参数) encoder_layer = nn.TransformerEncoderLayer( d_model=self.hidden_size, - nhead=self.bert_config.num_attention_heads, + nhead=8, dim_feedforward=self.bert_config.intermediate_size, dropout=self.bert_config.hidden_dropout_prob, activation="gelu", batch_first=True, - norm_first=True, # Pre-LN,与预训练一致 ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) self.pooler = nn.AdaptiveAvgPool1d(1) - # 3. 专家层:8个领域专家 + 1个共享专家 - total_experts = num_domain_experts + num_shared_experts + self.total_experts = 20 self.experts = nn.ModuleList() - for i in range(total_experts): - # 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则) - dropout_prob = 0.1 if i < num_domain_experts else 0.2 + for i in range(self.total_experts): expert = Expert( input_dim=self.hidden_size, - d_model=d_model, + d_model=self.experts_dim[i], num_resblocks=num_resblocks, output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size - dropout_prob=dropout_prob, + dropout_prob=0.1, ) self.experts.append(expert) self.expert_bias = nn.Embedding( - total_experts, self.output_multiplier * self.hidden_size + self.total_experts, self.output_multiplier * self.hidden_size ) # 4. 
分类头 @@ -146,11 +181,6 @@ class MoEModel(nn.Module): self.output_multiplier * self.hidden_size, ), nn.ReLU(inplace=True), - nn.Linear( - self.output_multiplier * self.hidden_size, - self.output_multiplier * self.hidden_size, - ), - nn.ReLU(inplace=True), nn.Linear( self.output_multiplier * self.hidden_size, self.output_multiplier * self.hidden_size * 2, @@ -223,11 +253,54 @@ class MoEModel(nn.Module): expert_out = self.experts[7](pooled) + self.expert_bias( torch.tensor(7, device=pooled.device) ) - else: # group_id == 8 + elif group_id == 8: # group_id == 8 expert_out = self.experts[8](pooled) + self.expert_bias( torch.tensor(8, device=pooled.device) ) - + elif group_id == 9: # group_id == 9 + expert_out = self.experts[9](pooled) + self.expert_bias( + torch.tensor(9, device=pooled.device) + ) + elif group_id == 10: # group_id == 10 + expert_out = self.experts[10](pooled) + self.expert_bias( + torch.tensor(10, device=pooled.device) + ) + elif group_id == 11: # group_id == 11 + expert_out = self.experts[11](pooled) + self.expert_bias( + torch.tensor(11, device=pooled.device) + ) + elif group_id == 12: # group_id == 12 + expert_out = self.experts[12](pooled) + self.expert_bias( + torch.tensor(12, device=pooled.device) + ) + elif group_id == 13: # group_id == 13 + expert_out = self.experts[13](pooled) + self.expert_bias( + torch.tensor(13, device=pooled.device) + ) + elif group_id == 14: # group_id == 14 + expert_out = self.experts[14](pooled) + self.expert_bias( + torch.tensor(14, device=pooled.device) + ) + elif group_id == 15: # group_id == 15 + expert_out = self.experts[15](pooled) + self.expert_bias( + torch.tensor(15, device=pooled.device) + ) + elif group_id == 16: # group_id == 16 + expert_out = self.experts[16](pooled) + self.expert_bias( + torch.tensor(16, device=pooled.device) + ) + elif group_id == 17: # group_id == 17 + expert_out = self.experts[17](pooled) + self.expert_bias( + torch.tensor(17, device=pooled.device) + ) + elif group_id == 18: # group_id == 
18 + expert_out = self.experts[18](pooled) + self.expert_bias( + torch.tensor(18, device=pooled.device) + ) + else: # group_id == 19 + expert_out = self.experts[19](pooled) + self.expert_bias( + torch.tensor(19, device=pooled.device) + ) else: batch_size = pooled.size(0) # 并行计算所有专家输出 @@ -387,7 +460,7 @@ class MoEModel(nn.Module): self, train_dataloader, eval_dataloader=None, - monitor: TrainingMonitor = None, + monitor: Optional[TrainingMonitor] = None, criterion=nn.CrossEntropyLoss(), optimizer=None, num_epochs=1, @@ -402,11 +475,31 @@ class MoEModel(nn.Module): 训练模型,支持混合精度、梯度累积、学习率调度、实时监控。 参数: - ... 原有参数 ... - lr_schedule : callable, optional + train_dataloader: DataLoader + 训练数据加载器。 + eval_dataloader: DataLoader, optional + 评估数据加载器。 + monitor: TrainingMonitor, optional + 训练监控器。 + criterion: nn.Module, optional + 损失函数。 + optimizer: optim.Optimizer, optional + 优化器。 + num_epochs: int, optional + 训练轮数。 + eval_frequency: int, optional + 评估频率。 + grad_accum_steps: int, optional + 梯度累积步数。 + clip_grad_norm: float, optional + 梯度裁剪范数。 + mixed_precision: bool, optional + 是否使用混合精度。 + lr: float, optional + 初始学习率。 + lr_schedule : callable, optional 自定义学习率调度函数,接收参数 (processed_batches, optimizer), 可在内部直接修改 optimizer.param_groups 中的学习率。 - 若为 None,则启用内置的固定阈值调度(前1000批 1e-4,之后 6e-6)。 """ # 确保模型在正确的设备上 if self.device is None: @@ -471,7 +564,7 @@ class MoEModel(nn.Module): and global_step % eval_frequency == 0 ): avg_loss = batch_loss_sum / eval_frequency - acc, _ = self.model_eval(eval_dataloader, criterion) + acc, eval_loss = self.model_eval(eval_dataloader, criterion) super().train() if monitor is not None: monitor.add_step( @@ -479,7 +572,7 @@ class MoEModel(nn.Module): {"loss": avg_loss, "acc": acc}, ) logger.info( - f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}" + f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}" ) batch_loss_sum = 0.0 diff --git a/src/trainer/new_model.py b/src/trainer/new_model.py new file mode 100644 index 
"""Hard-routed mixture-of-experts character classifier (src/trainer/new_model.py).

The model embeds the input tokens with the embedding layer of a pretrained
BERT-style backbone, runs a small Transformer encoder, mean-pools over the
sequence axis, sends the pooled vector through exactly one expert chosen by
the caller-supplied group id ``pg`` and classifies the expert output.
"""
import pickle
from importlib.resources import files
from pathlib import Path
from typing import Optional, Union

import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
from modelscope import AutoModel
from tqdm import tqdm

from .monitor import TrainingMonitor


def eval_dataloader(path: Union[str, Path, None] = None):
    """Load every ``*.pkl`` file found directly inside *path*.

    Args:
        path: Directory to scan. When ``None`` the ``eval_dataset`` directory
            shipped inside this package is used; it is resolved at call time
            so that importing the module never touches the filesystem.

    Returns:
        A list with one unpickled object per file, in sorted filename order.
    """
    if path is None:
        path = files(__package__) / "eval_dataset"

    loaded = []
    for pkl_file in sorted(Path(path).glob("*.pkl")):
        # NOTE(security): pickle executes arbitrary code while loading; only
        # point this at files produced by this project.
        with pkl_file.open("rb") as handle:
            loaded.append(pickle.load(handle))
    return loaded


# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    """Two Linear+LayerNorm layers with dropout and a skip connection."""

    def __init__(self, dim, dropout_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        residual = x
        x = self.relu(self.linear1(x))
        x = self.ln1(x)
        x = self.linear2(x)
        x = self.ln2(x)
        x = self.dropout(x)  # dropout is applied before the skip connection
        x = x + residual
        return self.relu(x)


# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    """Feed-forward expert: input projection, residual stack, output head."""

    def __init__(
        self,
        input_dim,
        d_model=1024,
        num_resblocks=4,
        output_multiplier=2,
        dropout_prob=0.1,
    ):
        """
        Args:
            input_dim: Size of the incoming feature vector.
            d_model: Internal width of the expert.
            num_resblocks: Number of stacked ``ResidualBlock`` modules.
            output_multiplier: Output size is ``input_dim * output_multiplier``.
            dropout_prob: Dropout probability used inside each residual block.
        """
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.output_dim = input_dim * output_multiplier

        # Input projection: input_dim -> d_model
        self.linear_in = nn.Linear(input_dim, d_model)

        # Residual stack
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )

        # Output projection: d_model -> output_dim
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        x = self.linear_in(x)
        for block in self.res_blocks:
            x = block(x)
        return self.output(x)


# ---------------------------- Main model (MoE + hard routing) ----------------------------
class MoEModel(nn.Module):
    """Transformer encoder followed by one hard-routed expert and a classifier."""

    def __init__(
        self,
        pretrained_model_name="iic/nlp_structbert_backbone_tiny_std",
        num_classes=10018,
        output_multiplier=2,
        d_model=768,
        num_resblocks=4,
        num_domain_experts=23,
    ):
        """
        Args:
            pretrained_model_name: Model id passed to ``AutoModel.from_pretrained``;
                only its ``embeddings`` module and ``config`` are kept.
            num_classes: Size of the classifier output.
            output_multiplier: Expert output width is
                ``output_multiplier * hidden_size``.
            d_model: Internal width of every expert.
            num_resblocks: Residual blocks per expert.
            num_domain_experts: Number of experts, i.e. the number of valid
                ``pg`` group ids (``0 .. num_domain_experts - 1``).
        """
        super().__init__()
        self.output_multiplier = output_multiplier

        # 1. Load the pretrained backbone and keep only its embedding layer.
        bert = AutoModel.from_pretrained(pretrained_model_name)
        self.embedding = bert.embeddings
        self.bert_config = bert.config
        self.hidden_size = self.bert_config.hidden_size
        self.device = None  # recorded by to()

        # NOTE(review): this layer is never used in forward(); it is kept so
        # the set of registered parameters stays unchanged. Confirm whether it
        # can be removed.
        self.linear = nn.Linear(256, d_model)

        # 2. Four-layer Transformer encoder configured from the backbone config.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=self.bert_config.num_attention_heads,
            dim_feedforward=self.bert_config.intermediate_size,
            dropout=self.bert_config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
            norm_first=True,  # Pre-LN
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.pooler = nn.AdaptiveAvgPool1d(1)

        # 3. Experts. The expert count follows the constructor argument
        #    instead of a hard-coded constant.
        self.total_experts = num_domain_experts
        feature_dim = self.output_multiplier * self.hidden_size

        # Each expert consumes the pooled encoder output, which has
        # hidden_size features, so input_dim must be hidden_size. The expert
        # output width (hidden_size * output_multiplier) then matches
        # expert_bias and the classifier input.
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=self.hidden_size,
                    d_model=d_model,
                    num_resblocks=num_resblocks,
                    output_multiplier=self.output_multiplier,
                    dropout_prob=0.1,
                )
                for _ in range(self.total_experts)
            ]
        )

        # One learned bias vector per expert, added to that expert's output.
        self.expert_bias = nn.Embedding(self.total_experts, feature_dim)

        # 4. Classification head.
        self.classifier = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(feature_dim * 2, num_classes),
        )

    def to(self, device):
        """Move the module and remember the target device in ``self.device``.

        Strings, ints and ``torch.device`` objects are normalized to a
        ``torch.device`` so that ``self.device.type`` is always available to
        ``fit``. Any other argument is stored as given.
        """
        moved = super().to(device)
        if isinstance(device, (str, int, torch.device)):
            device = torch.device(device)
        self.device = device
        return moved

    def _forward_logits(self, input_ids, attention_mask, pg):
        """Return raw classifier logits regardless of train/eval mode.

        Args:
            input_ids: ``[batch, seq_len]`` token ids.
            attention_mask: ``[batch, seq_len]``; 1 marks a real token and 0
                marks padding.
            pg: Expert group ids. A ``[batch]`` LongTensor in normal
                execution; a scalar tensor or int while tracing for export.

        Raises:
            ValueError: While tracing, if the group id is outside
                ``[0, total_experts)``.
        """
        # ----- 1. Embeddings -----
        embeddings = self.embedding(input_ids)  # [B, S, H]

        # ----- 2. Transformer encoder -----
        # In the key padding mask, True means "ignore this position".
        padding_mask = attention_mask == 0
        encoded = self.encoder(
            embeddings, src_key_padding_mask=padding_mask
        )  # [B, S, H]

        # ----- 3. Pooling: average over the sequence axis -> [B, H] -----
        pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)

        # ----- 4. Hard expert routing -----
        if torch.jit.is_tracing():
            # Export mode: pg is turned into a Python int so that only the
            # selected expert is executed while tracing.
            group_id = int(pg.item()) if torch.is_tensor(pg) else int(pg)
            if not 0 <= group_id < self.total_experts:
                raise ValueError(
                    f"group id {group_id} is outside [0, {self.total_experts})"
                )
            expert_out = self.experts[group_id](pooled) + self.expert_bias(
                torch.tensor(group_id, device=pooled.device)
            )
        else:
            batch_size = pooled.size(0)
            # Run every expert on the whole batch, then pick one per sample.
            expert_outputs = torch.stack(
                [expert(pooled) for expert in self.experts], dim=0
            )  # [E, B, D]
            sample_index = torch.arange(batch_size, device=pooled.device)
            expert_out = expert_outputs[pg, sample_index]  # [B, D]
            expert_out = expert_out + self.expert_bias(pg)  # [B, D]

        # ----- 5. Classification head -----
        return self.classifier(expert_out)  # [B, num_classes]

    def forward(self, input_ids, attention_mask, pg):
        """Run the model.

        Args:
            input_ids: ``[batch, seq_len]`` token ids.
            attention_mask: ``[batch, seq_len]``; 1 marks a real token and 0
                marks padding.
            pg: Expert group ids; ``[batch]`` LongTensor, or a scalar while
                tracing for export.

        Returns:
            Raw logits in training mode; softmax probabilities otherwise.
        """
        logits = self._forward_logits(input_ids, attention_mask, pg)
        if not self.training:
            return torch.softmax(logits, dim=-1)
        return logits

    def model_eval(self, eval_dataloader, criterion=None):
        """Evaluate on a validation set.

        Args:
            eval_dataloader: Iterable of batches providing
                ``batch["hint"]["input_ids"]``,
                ``batch["hint"]["attention_mask"]``, ``batch["pg"]`` and
                ``batch["char_id"]``.
            criterion: Loss function; defaults to ``nn.CrossEntropyLoss()``.

        Returns:
            ``(accuracy, avg_loss)`` as floats; both are 0.0 when the loader
            yields no samples.
        """
        if criterion is None:
            criterion = nn.CrossEntropyLoss()

        self.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in eval_dataloader:
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)

                # The loss needs raw logits. forward() returns softmax
                # probabilities in eval mode, and feeding those to a
                # cross-entropy criterion would apply softmax twice.
                logits = self._forward_logits(input_ids, attention_mask, pg)
                loss = criterion(logits, labels)
                total_loss += loss.item() * labels.size(0)

                preds = logits.argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        avg_loss = total_loss / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0
        return accuracy, avg_loss

    def predict(self, sample, debug=False):
        """Predict class ids for a batch or a single sample.

        Args:
            sample: Dict with
                - ``'input_ids'``: ``[batch, seq_len]`` or ``[seq_len]``
                - ``'attention_mask'``: same shape as ``input_ids``
                - ``'pg'``: ``[batch]`` or scalar tensor
                and, when ``debug`` is true, additionally
                - ``'char_id'``: ``[batch]`` or scalar tensor of true labels
                - ``'txt'``, ``'char'``, ``'py'``: list of strings or a
                  single string.
            debug: Print every mispredicted sample.

        Returns:
            ``[batch]`` tensor of predicted class ids, or a scalar tensor when
            the input had no batch dimension.

        Raises:
            ValueError: If ``debug`` is true and a required key is missing.
        """
        if debug:
            # Fail before running inference when debug information is missing.
            required_keys = ["char_id", "txt", "char", "py"]
            missing = [k for k in required_keys if k not in sample]
            if missing:
                raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}")

        self.eval()

        # ------------------ 1. Normalize inputs ------------------
        input_ids = sample["input_ids"]
        attention_mask = sample["attention_mask"]
        pg = sample["pg"]
        has_batch_dim = input_ids.dim() > 1

        if not has_batch_dim:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
            if pg.dim() == 0:
                pg = pg.unsqueeze(0).expand(input_ids.size(0))

        # ------------------ 2. Move to the model device ------------------
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        pg = pg.to(self.device)

        # ------------------ 3. Inference ------------------
        with torch.no_grad():
            # argmax over logits equals argmax over probabilities, so no
            # softmax is needed here.
            logits = self._forward_logits(input_ids, attention_mask, pg)
            preds = logits.argmax(dim=-1)  # [batch]

        # ------------------ 4. Debug print of wrong samples ------------------
        if debug:
            true_labels = sample["char_id"]
            if true_labels.dim() == 0:
                true_labels = true_labels.unsqueeze(0)
            true_labels = true_labels.to(self.device)

            incorrect_indices = torch.where(preds != true_labels)[0]

            if len(incorrect_indices) > 0:
                print("\n=== 预测错误样本 ===")
                txts = sample["txt"]
                chars = sample["char"]
                pys = sample["py"]

                # A single string means a single sample; wrap it in a list.
                if isinstance(txts, str):
                    txts = [txts]
                    chars = [chars]
                    pys = [pys]

                for idx in incorrect_indices.tolist():
                    print(f"样本索引 {idx}:")
                    print(f"  Text  : {txts[idx]}")
                    print(f"  Char  : {chars[idx]}")
                    print(f"  Pinyin: {pys[idx]}")
                    print(
                        f"  预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}"
                    )
                print("===================\n")

        # ------------------ 5. Match the input's dimensionality ------------------
        if not has_batch_dim:
            return preds.squeeze(0)
        return preds

    def fit(
        self,
        train_dataloader,
        eval_dataloader=None,
        monitor: Optional[TrainingMonitor] = None,
        criterion=nn.CrossEntropyLoss(),
        optimizer=None,
        num_epochs=1,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        mixed_precision=False,
        lr=1e-4,
        lr_schedule=None,
    ):
        """Train the model with optional AMP, gradient accumulation and monitoring.

        Args:
            train_dataloader: Iterable of training batches (same layout as
                described in ``model_eval``).
            eval_dataloader: Optional evaluation batches; when given, the model
                is evaluated every ``eval_frequency`` optimizer steps.
            monitor: Optional ``TrainingMonitor`` that receives
                ``{"loss", "acc"}`` at every evaluation.
            criterion: Loss function.
            optimizer: Optimizer; defaults to ``AdamW(self.parameters(), lr=lr)``.
            num_epochs: Number of passes over ``train_dataloader``.
            eval_frequency: Optimizer steps between evaluations.
            grad_accum_steps: Micro-batches accumulated per optimizer step.
            clip_grad_norm: Max gradient norm used for clipping.
            mixed_precision: Enable autocast and gradient scaling.
            lr: Learning rate for the default optimizer.
            lr_schedule: Optional callable ``(processed_batches, optimizer)``
                invoked before every batch; it may edit
                ``optimizer.param_groups`` in place.
        """
        # Make sure the model lives on a known device.
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)

        super().train()

        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr)

        scaler = amp.GradScaler(enabled=mixed_precision)

        global_step = 0  # number of optimizer steps taken
        processed_batches = 0  # number of micro-batches seen
        batch_loss_sum = 0.0
        optimizer.zero_grad()

        for _epoch in range(num_epochs):
            for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)):
                processed_batches += 1

                # Learning-rate scheduling: only when the caller supplied a hook.
                if lr_schedule is not None:
                    lr_schedule(processed_batches, optimizer)

                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)

                with amp.autocast(
                    device_type=self.device.type, enabled=mixed_precision
                ):
                    logits = self(input_ids, attention_mask, pg)
                    loss = criterion(logits, labels)
                    # Scale so accumulated gradients average over micro-batches.
                    loss = loss / grad_accum_steps

                scaler.scale(loss).backward()

                if (batch_idx + 1) % grad_accum_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    global_step += 1
                    original_loss = loss.item() * grad_accum_steps
                    batch_loss_sum += original_loss

                    # Periodic evaluation.
                    if (
                        eval_dataloader is not None
                        and global_step % eval_frequency == 0
                    ):
                        avg_loss = batch_loss_sum / eval_frequency
                        acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                        super().train()  # model_eval switched to eval mode
                        if monitor is not None:
                            monitor.add_step(
                                global_step,
                                {"loss": avg_loss, "acc": acc},
                            )
                        logger.info(
                            f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
                        )
                        batch_loss_sum = 0.0


# ============================ Usage example ============================
if __name__ == "__main__":
    # 1. Build the model.
    model = MoEModel()
    model.eval()

    # 2. Dummy inputs (batch=1) for the ONNX export.
    dummy_input_ids = torch.randint(0, 100, (1, 64))  # [1, 64]
    dummy_attention_mask = torch.ones_like(dummy_input_ids)  # [1, 64]
    dummy_pg = torch.tensor(3, dtype=torch.long)  # scalar group id

    # 3. Export to ONNX; only the expert selected by dummy_pg is traced.
    torch.onnx.export(
        model,
        (dummy_input_ids, dummy_attention_mask, dummy_pg),
        "moe_cpu.onnx",
        input_names=["input_ids", "attention_mask", "pg"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch"},
            "attention_mask": {0: "batch"},
        },
        opset_version=12,
        do_constant_folding=True,
    )
    print("ONNX 导出成功!")

    # 4. Training-mode smoke test (batch=4, mixed groups).
    model.train()
    batch_input_ids = torch.randint(0, 100, (4, 64))
    batch_attention_mask = torch.ones_like(batch_input_ids)
    batch_pg = torch.tensor([0, 3, 8, 1], dtype=torch.long)
    logits = model(batch_input_ids, batch_attention_mask, batch_pg)
    print("训练模式输出形状:", logits.shape)  # [4, 10018]