import math
import pickle
from importlib.resources import files
from pathlib import Path
from typing import Optional, Union

import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
from modelscope import AutoModel, AutoTokenizer
from tqdm.autonotebook import tqdm

from suinput.dataset import PG

from .monitor import TrainingMonitor, send_serverchan_message


def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
    return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")]


# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        residual = x
        x = self.gelu(self.linear1(x))
        x = self.ln1(x)
        x = self.linear2(x)
        x = self.ln2(x)
        x = self.dropout(x)
        x = x + residual
        return self.gelu(x)


# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    def __init__(
        self,
        input_dim,
        d_model=768,
        num_resblocks=3,
        output_multiplier=1,
        dropout_prob=0.3,
    ):
        super().__init__()
        self.output_dim = input_dim * output_multiplier
        self.linear_in = nn.Linear(input_dim, d_model)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),  # nn.GELU does not accept inplace=True
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        x = self.linear_in(x)
        for block in self.res_blocks:
            x = block(x)
        return self.output(x)
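
# Shape sanity check (illustrative only; input_dim=512 matches the Lite
# backbone's hidden size assumed below, not a value from a real run):
# >>> expert = Expert(input_dim=512, d_model=768, num_resblocks=3, output_multiplier=2)
# >>> expert(torch.randn(4, 512)).shape
# torch.Size([4, 1024])  # output_dim = input_dim * output_multiplier
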
# ---------------------- Main model (MoE + hard routing) ----------------------
class MoEModel(nn.Module):
    def __init__(
        self,
        pretrained_model_name="iic/nlp_structbert_backbone_lite_std",
        num_classes=10018,
        output_multiplier=2,
        d_model=768,
        num_resblocks=4,
        num_domain_experts=12,
        num_shared_experts=1,
    ):
        super().__init__()
        self.output_multiplier = output_multiplier

        # 1. Load the pretrained embedding (Lite: hidden_size=512)
        logger.info(f"Loading backbone: {pretrained_model_name}")
        bert = AutoModel.from_pretrained(pretrained_model_name)
        self.embedding = bert.embeddings
        self.bert_config = bert.config
        self.hidden_size = self.bert_config.hidden_size  # 512
        self.device = None

        # 2. Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=8,
            dim_feedforward=self.bert_config.intermediate_size,
            dropout=self.bert_config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)

        # 3. Expert system
        total_experts = num_domain_experts + num_shared_experts
        self.experts = nn.ModuleList()
        for i in range(total_experts):
            dropout = 0.3 if i < num_domain_experts else 0.4
            self.experts.append(
                Expert(
                    input_dim=self.hidden_size,
                    d_model=d_model,
                    num_resblocks=num_resblocks,
                    output_multiplier=self.output_multiplier,
                    dropout_prob=dropout,
                )
            )
        self.expert_bias = nn.Embedding(
            total_experts, self.hidden_size * self.output_multiplier
        )

        # 4. Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size * self.output_multiplier),
            nn.Dropout(0.4),
            nn.Linear(self.hidden_size * self.output_multiplier, num_classes),
        )

    def to(self, device):
        """Override `to` so the model remembers its device."""
        self.device = device
        return super().to(device)

    def forward(self, input_ids, attention_mask, pg, p_start):
        """
        ONNX-compatible forward pass.

        Args:
            input_ids: [B, L]
            attention_mask: [B, L]
            pg: [B] pinyin-group IDs
            p_start: [B] start index of the pinyin span (integer tensor)
        """
        # ----- 1. Embeddings -----
        embeddings = self.embedding(input_ids)

        # ----- 2. Transformer encoder -----
        # padding mask: True marks positions to ignore
        padding_mask = attention_mask == 0
        encoded = self.encoder(
            embeddings, src_key_padding_mask=padding_mask
        )  # [B, S, H]

        # ----- 3. ONNX-compatible span pooling (vectorized) -----
        # We cannot slice per sample in a loop, so we build a mask instead:
        #   1. positions = [0, 1, ..., L-1]                  (shape [L])
        #   2. broadcast against p_start: [1, L] vs. [B, 1]
        #   3. span_mask = positions >= p_start              (shape [B, L], bool)
        #   4. zero out everything before p_start, then mean-pool over the
        #      remaining positions, clamping the count to avoid division by zero.
        B, L, H = encoded.shape
        device = encoded.device

        # Position axis [0, 1, ..., L-1]
        positions = torch.arange(L, device=device).unsqueeze(0)  # [1, L]

        # Reshape p_start to [B, 1] for broadcasting
        p_start_exp = p_start.unsqueeze(1)  # [B, 1]
        span_mask = positions >= p_start_exp

        # Convert to float for multiplication
        span_mask_float = span_mask.float()  # [B, L]

        # Apply the mask: encoded [B, L, H] * mask [B, L, 1]
        masked_encoded = encoded * span_mask_float.unsqueeze(-1)

        # Sum over the sequence dimension
        span_sum = masked_encoded.sum(dim=1)  # [B, H]

        # Effective span length (guard against division by zero)
        span_count = span_mask_float.sum(dim=1, keepdim=True).clamp(min=1.0)  # [B, 1]

        # Mean pooling
        pooled = span_sum / span_count  # [B, H]
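
        # Tiny worked example of the masking above (illustrative only):
        # with L = 4 and p_start = 2,
        #   positions = [0, 1, 2, 3]
        #   span_mask = positions >= 2   ->  [0, 0, 1, 1]
        #   pooled    = (h_2 + h_3) / 2      (mean over the pinyin span only)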
        # ----- 4. Expert routing (hard routing) -----
        if torch.jit.is_tracing():
            # ---- ONNX export mode: a single branch (batch=1) ----
            # Here pg is a scalar tensor; convert it to a Python int and index
            # the expert directly (equivalent to an if/elif chain over all
            # expert IDs 0..12, which it replaces).
            group_id = int(pg.item()) if torch.is_tensor(pg) else int(pg)
            expert_out = self.experts[group_id](pooled) + self.expert_bias(
                torch.tensor(group_id, device=pooled.device)
            )
        else:
            batch_size = pooled.size(0)
            # Run every expert in parallel
            expert_outputs = torch.stack(
                [e(pooled) for e in self.experts], dim=0
            )  # [E, B, D]
            # Select each sample's expert output by pg
            expert_out = expert_outputs[pg, torch.arange(batch_size)]  # [B, D]
            # Add the expert bias
            bias = self.expert_bias(pg)  # [B, D]
            expert_out = expert_out + bias

        # ----- 5. Classification head -----
        logits = self.classifier(expert_out)  # [batch, num_classes]

        if not self.training:
            # Apply softmax at inference time
            probs = torch.softmax(logits, dim=-1)
            return probs

        return logits

    def model_eval(self, eval_dataloader, criterion):
        """
        Evaluate the model on the validation set; return accuracy and mean loss.

        Args:
            eval_dataloader: DataLoader yielding 'hint' (with 'input_ids' and
                'attention_mask'), 'pg', 'p_start', and 'char_id'
            criterion: loss function, typically CrossEntropyLoss()

        Returns:
            accuracy: float
            avg_loss: float
        """
        self.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in eval_dataloader:
                # Move data to the model's device
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                p_start = batch["p_start"].to(self.device)
                labels = batch["char_id"].to(self.device)

                # Forward pass. In eval mode forward() returns softmax
                # probabilities; feeding their log to CrossEntropyLoss is
                # equivalent to feeding the logits, because log_softmax applied
                # to log-probabilities is the identity.
                probs = self(input_ids, attention_mask, pg, p_start)
                loss = criterion(torch.log(probs.clamp_min(1e-12)), labels)
                total_loss += loss.item() * labels.size(0)

                # Accuracy
                preds = probs.argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        avg_loss = total_loss / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0
        return accuracy, avg_loss
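
    # Usage sketch (hypothetical; assumes the packaged eval dataset exists):
    # >>> acc, loss = model.model_eval(eval_dataloader(), nn.CrossEntropyLoss())
    # >>> print(f"acc={acc:.4f}, loss={loss:.4f}")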
    def gen_predict_sample(self, text, py, tokenizer=None):
        """
        Build a sample dict for prediction.

        Args:
            text (str): input text.
            py (list): pinyin list corresponding to the text.
            tokenizer (PreTrainedTokenizer, optional): tokenizer used to encode
                the inputs. If omitted and the instance has no cached tokenizer,
                a pretrained tokenizer is loaded automatically.

        Returns:
            dict with keys:
                - "hint": encoded inputs ("input_ids" and "attention_mask").
                - "pg": tensor holding the PG index of the first pinyin.
                - "p_start": tensor holding the start index of the pinyin span.

        Steps:
            1. If no tokenizer is given and none is cached on the instance,
               load one from the pretrained model.
            2. Encode text and pinyin with max_length=88, padding and truncation.
            3. Assemble the sample dict from the encoded inputs and the pinyin
               mapping tensors.
        """
        # Resolve the tokenizer: prefer the argument, then the cached one,
        # otherwise load a pretrained tokenizer
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif not hasattr(self, "tokenizer"):
            self.tokenizer = AutoTokenizer.from_pretrained(
                "iic/nlp_structbert_backbone_tiny_std"
            )

        # Encode text and pinyin into the model's input format
        hint = self.tokenizer(
            text,
            py,
            max_length=88,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Assemble the sample dict
        sample = {}
        sample["hint"] = {
            "input_ids": hint["input_ids"],
            "attention_mask": hint["attention_mask"],
        }
        # Map the first pinyin to its PG index and wrap it in a tensor
        sample["pg"] = torch.tensor([PG[py[0]]])
        sample["p_start"] = torch.tensor([len(text)])
        return sample

    def predict(self, text, py, tokenizer=None):
        """
        Build a sample from text and pinyin and run prediction; supports both
        batched and single-sample inputs.

        Args:
            text : str
                Input text.
            py : list
                Input pinyin list.
            tokenizer : Tokenizer, optional
                Tokenizer to use; defaults to None.

        Returns:
            preds : torch.Tensor [batch]
                Predicted class labels (a scalar if the input had no batch
                dimension).
        """
        self.eval()  # evaluation mode: disables dropout etc.

        # ------------------ 1. Extract and normalize inputs ------------------
        sample = self.gen_predict_sample(text, py, tokenizer)
        input_ids = sample["hint"]["input_ids"]
        attention_mask = sample["hint"]["attention_mask"]
        pg = sample["pg"]
        p_start = sample["p_start"]
        has_batch_dim = input_ids.dim() > 1  # does the input carry a batch dim?

        # Add a batch dimension if there is none
        if not has_batch_dim:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        # If pg / p_start are scalars, expand them to the batch size
        if pg.dim() == 0:
            pg = pg.unsqueeze(0).expand(input_ids.size(0))
        if p_start.dim() == 0:
            p_start = p_start.unsqueeze(0).expand(input_ids.size(0))

        # ------------------ 2. Move to device ------------------
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        pg = pg.to(self.device)
        p_start = p_start.to(self.device)

        # ------------------ 3. Inference ------------------
        with torch.no_grad():
            # In eval mode forward() already returns softmax probabilities
            probs = self(input_ids, attention_mask, pg, p_start)
            preds = probs.argmax(dim=-1)  # [batch]

        # ---------- 4. Return (match the input's dimensionality) ----------
        if not has_batch_dim:
            return preds.squeeze(0)  # scalar
        return preds
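
    # Usage sketch (hypothetical inputs; assumes "ni" is a key of PG):
    # >>> model.predict("你好", ["ni"])
    # tensor([char_id])  # one predicted class per sample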
    def fit(
        self,
        train_dataloader,  # training dataloader
        eval_dataloader=None,  # evaluation dataloader, optional
        monitor: Optional[TrainingMonitor] = None,  # monitor for training logs
        criterion=None,  # loss function
        optimizer=None,  # optimizer
        num_epochs=1,  # number of epochs
        stop_batch=1e6,  # maximum number of training batches
        eval_frequency=500,
        grad_accum_steps=1,  # gradient accumulation steps
        clip_grad_norm=1.0,  # gradient clipping norm
        loss_weight=None,
        mixed_precision=True,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.15,
        lr=1e-4,
    ):
        """
        Train the model with mixed precision, gradient accumulation, an LR
        schedule (linear warmup + cosine decay), and live monitoring.

        Args:
            # TODO: document the parameters
        """
        # Make sure the model is on the right device (GPU or CPU)
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)

        # Switch to training mode
        self.train()

        # Default optimizer
        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

        # Loss function
        if criterion is None:
            if loss_weight is not None:
                criterion = nn.CrossEntropyLoss(
                    weight=loss_weight, label_smoothing=label_smoothing
                )
            else:
                criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

        # Mixed-precision gradient scaler
        scaler = amp.GradScaler(enabled=mixed_precision)

        total_steps = stop_batch
        warmup_steps = int(total_steps * warmup_ratio)
        logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")

        processed_batches = 0  # counts batches actually processed
        batch_loss_sum = 0.0
        stop_training = False
        optimizer.zero_grad()

        for epoch in range(num_epochs):
            for batch_idx, batch in enumerate(
                tqdm(train_dataloader, total=int(stop_batch))
            ):
                processed_batches += 1

                # LR schedule: linear warmup, then cosine decay
                if processed_batches < warmup_steps:
                    current_lr = lr * (processed_batches / warmup_steps)
                else:
                    progress = (processed_batches - warmup_steps) / (
                        total_steps - warmup_steps
                    )
                    current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
                for param_group in optimizer.param_groups:
                    param_group["lr"] = current_lr

                # ---------- Move data ----------
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                p_start = batch["p_start"].to(self.device)  # [B]
                labels = batch["char_id"].to(self.device)

                # Mixed-precision forward pass
                with torch.amp.autocast(
                    device_type=self.device.type, enabled=mixed_precision
                ):
                    logits = self(input_ids, attention_mask, pg, p_start)
                    loss = criterion(logits, labels)
                    loss = loss / grad_accum_steps

                # Backward pass
                scaler.scale(loss).backward()

                # Gradient accumulation
                if processed_batches % grad_accum_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)

                    has_nan = False
                    for p in self.parameters():
                        if p.grad is not None and torch.isnan(p.grad).any():
                            has_nan = True
                            break

                    if not has_nan:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        logger.warning("NaN detected, skipping step.")

                    optimizer.zero_grad()

                batch_loss_sum += loss.item() * grad_accum_steps

                # Periodic evaluation
                if eval_dataloader and processed_batches % eval_frequency == 0:
                    acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                    self.train()
                    avg_loss = batch_loss_sum / eval_frequency
                    if monitor:
                        monitor.add_step(
                            processed_batches, {"loss": avg_loss, "acc": acc}
                        )
                    logger.info(
                        f"step: {processed_batches}, loss: {avg_loss:.4f}, "
                        f"acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
                    )
                    batch_loss_sum = 0.0

                if processed_batches >= stop_batch:
                    stop_training = True
                    break
            if stop_training:
                break

        try:
            res_acc, res_loss = self.model_eval(eval_dataloader, criterion)
            to_wechat_response = send_serverchan_message(
                title="Training finished",
                content=f"Training finished, acc: {res_acc:.4f}, loss: {res_loss:.4f}",
            )
            logger.info(f"Training finished, acc: {res_acc:.4f}, loss: {res_loss:.4f}")
            logger.info(f"Message sent: {to_wechat_response}")
        except Exception as e:
            logger.error(f"Failed to send message: {e}")
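
    # Usage sketch (hypothetical names; `train_loader` is assumed to yield the
    # batch dict format consumed above):
    # >>> model = MoEModel().to(torch.device("cuda"))
    # >>> model.fit(train_loader, eval_dataloader=eval_dataloader(),
    # ...           num_epochs=1, stop_batch=10_000, grad_accum_steps=4)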
    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=self.device
        )
        self.load_state_dict(state_dict)

    def load_from_pretrained_base_model(
        self,
        BaseModel,
        snapshot_path: Union[str, Path],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        *args,
        **kwargs,
    ):
        base_model = BaseModel(*args, **kwargs)
        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))

        self_state_dict = self.state_dict()
        pretrained_dict = base_model.state_dict()
        freeze_layers = []
        # Copy every weight whose name and shape match, then freeze it
        for key in self_state_dict.keys():
            if key in pretrained_dict.keys():
                if self_state_dict[key].shape == pretrained_dict[key].shape:
                    self_state_dict[key] = pretrained_dict[key].to(self.device)
                    freeze_layers.append(key)
        self.load_state_dict(self_state_dict)
        for name, param in self.named_parameters():
            if name in freeze_layers:
                param.requires_grad = False

    # --- ONNX export helper ---
    def export_onnx(self, output_path, dummy_input):
        """
        dummy_input should be a tuple (or dict) containing:
        (input_ids, attention_mask, pg, p_start)
        """
        self.eval()
        input_names = ["input_ids", "attention_mask", "pg", "p_start"]
        # Note: in eval mode forward() returns softmax probabilities, so the
        # output named "logits" actually holds probabilities.
        output_names = ["logits"]

        torch.onnx.export(
            self,
            dummy_input,
            output_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "seq_len"},
                "attention_mask": {0: "batch_size", 1: "seq_len"},
                "pg": {0: "batch_size"},
                "p_start": {0: "batch_size"},
                "logits": {0: "batch_size"},
            },
            opset_version=14,  # 14+ recommended for better operator support
            do_constant_folding=True,
        )
        logger.info(f"ONNX model exported to {output_path}")
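

# End-to-end usage sketch (hypothetical paths and inputs; assumes "ma" is a
# valid key of PG, and that export runs with batch=1 per the traced branch):
# >>> model = MoEModel()
# >>> model.to(torch.device("cpu"))
# >>> sample = model.gen_predict_sample("你好", ["ma"])
# >>> dummy = (
# ...     sample["hint"]["input_ids"],
# ...     sample["hint"]["attention_mask"],
# ...     sample["pg"],
# ...     sample["p_start"],
# ... )
# >>> model.export_onnx("moe_model.onnx", dummy)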