import pickle
from importlib.resources import files
from pathlib import Path
from typing import Optional, Union

import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
from modelscope import AutoModel
from tqdm import tqdm

from .monitor import TrainingMonitor


def eval_dataloader(path: Optional[Union[str, Path]] = None):
    """Load every pickled evaluation batch found directly under *path*.

    Args:
        path: Directory containing ``*.pkl`` files. When omitted, the
            ``eval_dataset`` directory shipped inside this package is used.
            The default is resolved at call time so that importing the module
            does not touch package resources.

    Returns:
        A list with one unpickled object per ``*.pkl`` file, in the order
        produced by ``Path.glob``.
    """
    if path is None:
        path = files(__package__) / "eval_dataset"
    batches = []
    for pkl_file in Path(path).glob("*.pkl"):
        # SECURITY: pickle can execute arbitrary code while loading; only
        # point this function at trusted files.
        with pkl_file.open("rb") as handle:
            batches.append(pickle.load(handle))
    return batches


def round_to_power_of_two(x):
    """Snap a positive integer to one of 128, 256, 512 or 1024.

    Returns 0 when ``x < 1``. Otherwise ``x.bit_length()`` is clamped to the
    range [7, 9], ``lower = 2 ** n`` and ``upper = 2 * lower`` are formed, and
    whichever of the two is closer to ``x`` is returned (``upper`` on ties).

    For ``x < 512`` this yields ``lower``: 128 for ``x < 128``, 256 for
    ``128 <= x < 256`` and 512 for ``256 <= x < 512``. For ``x >= 512`` the
    result is 512 when ``x < 768`` and 1024 otherwise.

    ``x`` must be an ``int`` (``bit_length`` is used).
    """
    if x < 1:
        return 0
    n = min(max(7, x.bit_length()), 9)
    lower = 1 << n
    upper = lower << 1
    return lower if x - lower < upper - x else upper


# Internal width (d_model) of each expert, keyed by expert index.
EXPORT_HIDE_DIM = {
    0: 1024,
    1: 1024,
    2: 1024,
    3: 512,
    4: 512,
    5: 512,
    6: 512,
    7: 512,
    8: 512,
    9: 512,
    10: 512,
    11: 512,
    12: 512,
    13: 512,
    14: 512,
    15: 512,
    16: 512,
    17: 512,
    18: 512,
    19: 256,
}


# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    """Two Linear + LayerNorm layers with dropout and a skip connection."""

    def __init__(self, dim, dropout_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Return ``relu(dropout(ln2(linear2(ln1(relu(linear1(x)))))) + x)``."""
        residual = x
        x = self.relu(self.linear1(x))
        x = self.ln1(x)
        x = self.linear2(x)
        x = self.ln2(x)
        # Dropout is applied to the transformed branch before the skip add.
        x = self.dropout(x)
        x = x + residual
        return self.relu(x)


# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    """MLP expert: input projection, residual stack, two-layer output head."""

    def __init__(
        self,
        input_dim,
        d_model=1024,
        num_resblocks=4,
        output_multiplier=2,
        dropout_prob=0.1,
    ):
        """
        Args:
            input_dim: Size of the incoming feature vector.
            d_model: Internal width of the expert.
            num_resblocks: Number of stacked ``ResidualBlock`` modules.
            output_multiplier: Output size is ``input_dim * output_multiplier``.
            dropout_prob: Dropout probability inside each residual block.
        """
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.output_dim = input_dim * output_multiplier

        # Input projection: input_dim -> d_model
        self.linear_in = nn.Linear(input_dim, d_model)

        # Residual stack
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )

        # Output projection: d_model -> output_dim
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        """Project *x*, run it through every residual block, then the head."""
        x = self.linear_in(x)
        for block in self.res_blocks:
            x = block(x)
        return self.output(x)


# ---------------------------- Main model (MoE + hard routing) ----------------
class MoEModel(nn.Module):
    """Classifier made of pretrained embeddings, a Transformer encoder,
    a bank of hard-routed experts (selected by group id) and an MLP head."""

    def __init__(
        self,
        pretrained_model_name="iic/nlp_structbert_backbone_tiny_std",
        num_classes=10018,
        output_multiplier=2,
        d_model=768,
        num_resblocks=4,
        num_domain_experts=20,
        experts_dim=EXPORT_HIDE_DIM,
    ):
        """
        Args:
            pretrained_model_name: Model id passed to
                ``AutoModel.from_pretrained``; only its ``embeddings`` module
                and ``config`` are kept.
            num_classes: Size of the classifier output.
            output_multiplier: Expert output size is
                ``output_multiplier * hidden_size``.
            d_model: Accepted for interface compatibility; not used here
                (per-expert widths come from ``experts_dim``).
            num_resblocks: Residual blocks per expert.
            num_domain_experts: Accepted for interface compatibility; not
                used here (the number of experts is fixed at 20).
            experts_dim: Mapping ``expert index -> internal width``. It is
                only read, never modified.
        """
        super().__init__()
        self.output_multiplier = output_multiplier

        # Load the pretrained model and keep only its embedding module.
        bert = AutoModel.from_pretrained(pretrained_model_name)
        self.embedding = bert.embeddings
        self.bert_config = bert.config
        self.hidden_size = self.bert_config.hidden_size
        self.device = None  # recorded by to() / fit()
        self.experts_dim = experts_dim

        # Four Transformer encoder layers configured from the pretrained config.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=8,
            dim_feedforward=self.bert_config.intermediate_size,
            dropout=self.bert_config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.pooler = nn.AdaptiveAvgPool1d(1)

        # Expert bank: one expert (plus one learned bias vector) per group id.
        self.total_experts = 20
        self.experts = nn.ModuleList()
        for i in range(self.total_experts):
            self.experts.append(
                Expert(
                    input_dim=self.hidden_size,
                    d_model=self.experts_dim[i],
                    num_resblocks=num_resblocks,
                    output_multiplier=self.output_multiplier,
                    dropout_prob=0.1,
                )
            )
        feature_dim = self.output_multiplier * self.hidden_size
        self.expert_bias = nn.Embedding(self.total_experts, feature_dim)

        # Classification head.
        self.classifier = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(feature_dim * 2, num_classes),
        )

    def to(self, device):
        """Move the module to *device* and remember it in ``self.device``.

        A device given as a string is stored as ``torch.device`` so later
        code can rely on attributes such as ``self.device.type``.
        """
        self.device = torch.device(device) if isinstance(device, str) else device
        return super().to(device)

    def forward(self, input_ids, attention_mask, pg):
        """Run the model.

        Args:
            input_ids: Token ids, documented by the original author as
                ``[batch, seq_len]``.
            attention_mask: Same shape; positions equal to 0 are treated as
                padding by the encoder.
            pg: Group id used for routing. Outside tracing it is used as an
                index tensor with one entry per sample; while tracing it is
                read as a single scalar via ``.item()``.

        Returns:
            Raw logits in training mode, softmax probabilities otherwise.
        """
        embeddings = self.embedding(input_ids)

        # True marks positions the encoder must ignore.
        padding_mask = attention_mask == 0
        encoded = self.encoder(embeddings, src_key_padding_mask=padding_mask)

        # Average over the sequence axis (padding positions are included).
        pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)

        if torch.jit.is_tracing():
            # Export path: the group id is baked in as a Python constant so
            # that only the selected expert ends up in the traced graph.
            group_id = pg.item() if torch.is_tensor(pg) else pg
            last_idx = len(self.experts) - 1
            # Any id outside 0..last_idx-1 falls back to the last expert.
            expert_idx = int(group_id) if group_id in range(last_idx) else last_idx
            expert_out = self.experts[expert_idx](pooled) + self.expert_bias(
                torch.tensor(expert_idx, device=pooled.device)
            )
        else:
            batch_size = pooled.size(0)
            # Evaluate every expert on the whole batch, then pick per sample.
            expert_outputs = torch.stack(
                [expert(pooled) for expert in self.experts], dim=0
            )
            sample_idx = torch.arange(batch_size, device=pooled.device)
            expert_out = expert_outputs[pg, sample_idx]
            expert_out = expert_out + self.expert_bias(pg)

        logits = self.classifier(expert_out)
        if not self.training:
            return torch.softmax(logits, dim=-1)
        return logits

    def model_eval(self, eval_dataloader, criterion=None):
        """Evaluate on *eval_dataloader* and return ``(accuracy, avg_loss)``.

        Args:
            eval_dataloader: Iterable of batches; each batch must provide
                ``batch["hint"]["input_ids"]``,
                ``batch["hint"]["attention_mask"]``, ``batch["pg"]`` and
                ``batch["char_id"]``.
            criterion: Accepted for interface compatibility but not used.
                In eval mode ``forward`` returns probabilities, so the loss
                is always the unweighted negative log-likelihood of
                ``log(probs + 1e-12)``.

        Returns:
            ``(accuracy, avg_loss)``, both 0.0 when no sample was seen. The
            model is left in eval mode.
        """
        self.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in eval_dataloader:
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)

                probs = self(input_ids, attention_mask, pg)
                # The epsilon keeps log() finite for zero probabilities.
                log_probs = torch.log(probs + 1e-12)
                loss = F.nll_loss(log_probs, labels)

                n_samples = labels.size(0)
                total_loss += loss.item() * n_samples
                correct += (probs.argmax(dim=-1) == labels).sum().item()
                total += n_samples

        if total == 0:
            return 0.0, 0.0
        return correct / total, total_loss / total

    def predict(self, sample, debug=False):
        """Predict class ids for one sample or a batch.

        Args:
            sample: Dict with
                - ``sample["hint"]["input_ids"]``: 1-D (single sample) or
                  batched tensor,
                - ``sample["hint"]["attention_mask"]``: same shape,
                - ``sample["pg"]``: per-sample tensor or 0-dim tensor (a
                  0-dim value is broadcast to the batch size).
                When ``debug`` is true it must also contain ``char_id``
                (true labels) and ``txt``, ``char``, ``py`` (a string, or a
                sequence indexable by sample position).
            debug: Print details of every misclassified sample.

        Returns:
            Tensor of predicted class ids with one entry per sample, or a
            0-dim tensor when the input had no batch dimension.

        Raises:
            ValueError: ``debug`` is true and a required key is missing.
        """
        if debug:
            # Fail before running inference if the debug fields are absent.
            required_keys = ["char_id", "txt", "char", "py"]
            missing = [k for k in required_keys if k not in sample]
            if missing:
                raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}")

        self.eval()

        hint = sample["hint"]
        input_ids = hint["input_ids"]
        attention_mask = hint["attention_mask"]
        pg = sample["pg"]

        has_batch_dim = input_ids.dim() > 1
        if not has_batch_dim:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        if pg.dim() == 0:
            pg = pg.unsqueeze(0).expand(input_ids.size(0))

        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        pg = pg.to(self.device)

        with torch.no_grad():
            # In eval mode forward() already returns probabilities.
            probs = self(input_ids, attention_mask, pg)
        preds = probs.argmax(dim=-1)

        if debug:
            self._print_incorrect_samples(sample, preds)

        if not has_batch_dim:
            return preds.squeeze(0)
        return preds

    def _print_incorrect_samples(self, sample, preds):
        """Print text/char/pinyin and labels for each wrong prediction."""
        true_labels = sample["char_id"]
        if true_labels.dim() == 0:
            true_labels = true_labels.unsqueeze(0)
        true_labels = true_labels.to(self.device)

        incorrect_indices = torch.where(preds != true_labels)[0]
        if len(incorrect_indices) == 0:
            return

        print("\n=== 预测错误样本 ===")
        txts = sample["txt"]
        chars = sample["char"]
        pys = sample["py"]
        # A single sample may carry plain strings instead of lists.
        if isinstance(txts, str):
            txts = [txts]
            chars = [chars]
            pys = [pys]
        for idx in incorrect_indices.tolist():
            print(f"样本索引 {idx}:")
            print(f" Text : {txts[idx]}")
            print(f" Char : {chars[idx]}")
            print(f" Pinyin: {pys[idx]}")
            print(
                f" 预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}"
            )
        print("===================\n")

    def fit(
        self,
        train_dataloader,
        eval_dataloader=None,
        monitor: Optional[TrainingMonitor] = None,
        criterion=None,
        optimizer=None,
        num_epochs=1,
        stop_batch=1e6,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        mixed_precision=False,
        loss_weight=None,
        lr=1e-4,
        lr_schedule=None,
    ):
        """Train the model.

        Supports mixed precision, gradient accumulation, gradient clipping,
        an optional learning-rate callback and periodic evaluation.

        Args:
            train_dataloader: Iterable of batches with the same keys that
                ``model_eval`` expects.
            eval_dataloader: Optional evaluation batches; when given, the
                model is evaluated every ``eval_frequency`` optimizer steps.
            monitor: Optional object whose ``add_step(step, metrics)`` is
                called after each evaluation.
            criterion: Loss called as ``criterion(logits, labels)``. Defaults
                to a fresh ``nn.CrossEntropyLoss()`` per call, so class
                weights set here never leak into later calls.
            optimizer: Defaults to ``AdamW(self.parameters(), lr=lr)``.
            num_epochs: Maximum number of passes over ``train_dataloader``.
            stop_batch: Training stops once this many batches were processed
                in total (also used as the progress-bar total).
            eval_frequency: Optimizer steps between evaluations.
            grad_accum_steps: Batches accumulated per optimizer step.
            clip_grad_norm: Max gradient norm passed to ``clip_grad_norm_``.
            mixed_precision: Enable autocast and gradient scaling.
            loss_weight: Optional per-class values; each is turned into
                ``1 / sqrt(value)``, normalised by the mean, clamped to
                [0.01, 1.0] and assigned to ``criterion.weight``.
            lr: Learning rate of the default optimizer.
            lr_schedule: Optional callable
                ``lr_schedule(processed_batches, optimizer)`` invoked before
                every batch; it may edit ``optimizer.param_groups`` in place.
        """
        if criterion is None:
            criterion = nn.CrossEntropyLoss()

        # Make sure the model lives on a known device.
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)
        # self.device may have been assigned as a plain string by the caller.
        device = torch.device(self.device)

        if loss_weight is not None and len(loss_weight) > 0:
            weight = torch.as_tensor(loss_weight, dtype=torch.float32)
            weight = 1 / torch.sqrt(weight)
            weight = weight / weight.mean()
            weight = torch.clamp(weight, min=0.01, max=1.0)
            self.loss_weight = weight.to(device)
            criterion.weight = self.loss_weight

        super().train()

        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr)

        scaler = amp.GradScaler(enabled=mixed_precision)

        global_step = 0  # optimizer steps taken
        processed_batches = 0  # batches consumed across all epochs
        batch_loss_sum = 0.0  # un-scaled loss summed since the last eval
        batches_since_eval = 0
        optimizer.zero_grad()

        for _epoch in range(num_epochs):
            for batch_idx, batch in enumerate(tqdm(train_dataloader, total=stop_batch)):
                processed_batches += 1

                if lr_schedule is not None:
                    lr_schedule(processed_batches, optimizer)

                input_ids = batch["hint"]["input_ids"].to(device)
                attention_mask = batch["hint"]["attention_mask"].to(device)
                pg = batch["pg"].to(device)
                labels = batch["char_id"].to(device)

                with amp.autocast(device_type=device.type, enabled=mixed_precision):
                    logits = self(input_ids, attention_mask, pg)
                    loss = criterion(logits, labels)
                    # Scale so accumulated gradients average over the window.
                    loss = loss / grad_accum_steps

                scaler.scale(loss).backward()

                batch_loss_sum += loss.item() * grad_accum_steps
                batches_since_eval += 1

                if (batch_idx + 1) % grad_accum_steps == 0:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    global_step += 1

                    # Evaluate only right after an optimizer step, so one
                    # global_step never triggers more than one evaluation.
                    if (
                        eval_dataloader is not None
                        and global_step % eval_frequency == 0
                    ):
                        avg_loss = batch_loss_sum / batches_since_eval
                        acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                        super().train()  # model_eval switched to eval mode
                        if monitor is not None:
                            monitor.add_step(
                                global_step,
                                {"loss": avg_loss, "acc": acc},
                            )
                        logger.info(
                            f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
                        )
                        batch_loss_sum = 0.0
                        batches_since_eval = 0

                if processed_batches >= stop_batch:
                    break
            # The batch budget is global: do not start another epoch.
            if processed_batches >= stop_batch:
                break

    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
        """Load weights from a state-dict file saved with ``torch.save``."""
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=self.device
        )
        self.load_state_dict(state_dict)

    def load_from_pretrained_base_model(
        self,
        BaseModel,
        snapshot_path: Union[str, Path],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        *args,
        **kwargs,
    ):
        """Copy matching weights from a base-model snapshot and freeze them.

        ``BaseModel(*args, **kwargs)`` is built and loaded from
        *snapshot_path*. Every entry of this model's state dict that has the
        same key and shape in the base model is overwritten with the base
        value, and the corresponding parameters get ``requires_grad = False``.

        Args:
            BaseModel: Callable returning the base ``nn.Module``.
            snapshot_path: State-dict file of the base model.
            device: ``map_location`` used when reading the snapshot.
            *args, **kwargs: Forwarded to ``BaseModel``.
        """
        base_model = BaseModel(*args, **kwargs)
        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))

        own_state = self.state_dict()
        pretrained_state = base_model.state_dict()

        frozen_keys = set()
        for key in list(own_state):
            source = pretrained_state.get(key)
            if source is None or own_state[key].shape != source.shape:
                continue
            # load_state_dict copies across devices, so the move is skipped
            # when no device has been recorded yet.
            own_state[key] = (
                source.to(self.device) if self.device is not None else source
            )
            frozen_keys.add(key)

        self.load_state_dict(own_state)

        for name, param in self.named_parameters():
            if name in frozen_keys:
                param.requires_grad = False


# ============================ Usage example ============================
if __name__ == "__main__":
    # 1. Build the model.
    model = MoEModel()
    model.eval()

    # 2. Dummy inputs (batch=1) for the ONNX export.
    dummy_input_ids = torch.randint(0, 100, (1, 64))  # [1, 64]
    dummy_attention_mask = torch.ones_like(dummy_input_ids)  # [1, 64]
    dummy_pg = torch.tensor(3, dtype=torch.long)  # scalar group id

    # 3. Export to ONNX (tracing path: only the selected expert is computed).
    torch.onnx.export(
        model,
        (dummy_input_ids, dummy_attention_mask, dummy_pg),
        "moe_cpu.onnx",
        input_names=["input_ids", "attention_mask", "pg"],
        output_names=["logits"],
        dynamic_axes={
            # Could be omitted if the batch size is fixed at 1.
            "input_ids": {0: "batch"},
            "attention_mask": {0: "batch"},
        },
        opset_version=12,
        do_constant_folding=True,
    )
    print("ONNX 导出成功!")

    # 4. Smoke-test training mode (batch=4).
    model.train()
    batch_input_ids = torch.randint(0, 100, (4, 64))
    batch_attention_mask = torch.ones_like(batch_input_ids)
    batch_pg = torch.tensor([0, 3, 8, 1], dtype=torch.long)  # mixed groups
    logits = model(batch_input_ids, batch_attention_mask, batch_pg)
    print("训练模式输出形状:", logits.shape)  # [4, 10018]