diff --git a/src/trainer/model_with_neck.py b/src/trainer/model_with_neck.py deleted file mode 100644 index 9fd087d..0000000 --- a/src/trainer/model_with_neck.py +++ /dev/null @@ -1,514 +0,0 @@ -import pickle -from importlib.resources import files -from pathlib import Path -from typing import Optional, Union - -import torch -import torch.amp as amp -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from loguru import logger -from modelscope import AutoModel -from tqdm import tqdm - -from .monitor import TrainingMonitor -from .model import ( - EXPORT_HIDE_DIM, - eval_dataloader, - ResidualBlock, - Expert -) - - - -# ---------------------------- 主模型(MoE + 硬路由)------------------------ -class MoEModel(nn.Module): - def __init__( - self, - pretrained_model_name="iic/nlp_structbert_backbone_tiny_std", - num_classes=10018, - output_multiplier=2, - d_model=768, - num_resblocks=4, - num_domain_experts=20, - experts_dim=EXPORT_HIDE_DIM, - ): - super().__init__() - self.output_multiplier = output_multiplier - - # 1. 加载预训练 BERT,仅保留 embeddings - bert = AutoModel.from_pretrained(pretrained_model_name) - self.embedding = bert.embeddings - self.bert_config = bert.config - self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度 - self.device = None # 将在 to() 调用时设置 - self.experts_dim = experts_dim - - # 2. 4 层标准 Transformer Encoder(从 config 读取参数) - encoder_layer = nn.TransformerEncoderLayer( - d_model=self.hidden_size, - nhead=8, - dim_feedforward=self.bert_config.intermediate_size, - dropout=self.bert_config.hidden_dropout_prob, - activation="gelu", - batch_first=True, - ) - self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) - self.pooler = nn.AdaptiveAvgPool1d(1) - - self.res_blocks = nn.ModuleList([ResidualBlock(self.hidden_size) for _ in range(4)]) - - self.total_experts = 20 - self.experts = nn.ModuleList() - - for i in range(self.total_experts): - expert = Expert( - input_dim=self.hidden_size, - d_model=self.experts_dim[i], - num_resblocks=num_resblocks, - output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size - dropout_prob=0.1, - ) - self.experts.append(expert) - - self.expert_bias = nn.Embedding( - self.total_experts, self.output_multiplier * self.hidden_size - ) - - # 4. 分类头 - self.classifier = nn.Sequential( - nn.Dropout(0.2), - nn.LayerNorm(self.output_multiplier * self.hidden_size), - nn.Linear( - self.output_multiplier * self.hidden_size, - self.output_multiplier * self.hidden_size, - ), - nn.ReLU(inplace=True), - nn.Linear( - self.output_multiplier * self.hidden_size, - self.output_multiplier * self.hidden_size * 2, - ), - nn.ReLU(inplace=True), - nn.Linear(self.output_multiplier * self.hidden_size * 2, num_classes), - ) - - def to(self, device): - """重写 to 方法,记录设备""" - self.device = device - return super().to(device) - - def forward(self, input_ids, attention_mask, pg): - """ - input_ids : [batch, seq_len] - attention_mask: [batch, seq_len] (1 为有效,0 为 padding) - pg : group_id,训练时为 [batch] 的 LongTensor,推理导出时为标量 Tensor - """ - # ----- 1. Embeddings ----- - embeddings = self.embedding(input_ids) # [B, S, H] - - # ----- 2. Transformer Encoder ----- - # padding mask: True 表示忽略该位置 - padding_mask = attention_mask == 0 - encoded = self.encoder( - embeddings, src_key_padding_mask=padding_mask - ) # [B, S, H] - - for block in self.res_blocks: - encoded = block(encoded) - - # ----- 3. 池化量 ----- - pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) - - # ----- 4. 专家路由(硬路由)----- - if torch.jit.is_tracing(): - # ------------------ ONNX 导出模式:条件分支(batch=1)------------------ - # 此时 pg 为标量 Tensor,转换为 Python int - group_id = pg.item() if torch.is_tensor(pg) else pg - - if group_id == 0: - expert_out = self.experts[0](pooled) + self.expert_bias( - torch.tensor(0, device=pooled.device) - ) - elif group_id == 1: - expert_out = self.experts[1](pooled) + self.expert_bias( - torch.tensor(1, device=pooled.device) - ) - elif group_id == 2: - expert_out = self.experts[2](pooled) + self.expert_bias( - torch.tensor(2, device=pooled.device) - ) - elif group_id == 3: - expert_out = self.experts[3](pooled) + self.expert_bias( - torch.tensor(3, device=pooled.device) - ) - elif group_id == 4: - expert_out = self.experts[4](pooled) + self.expert_bias( - torch.tensor(4, device=pooled.device) - ) - elif group_id == 5: - expert_out = self.experts[5](pooled) + self.expert_bias( - torch.tensor(5, device=pooled.device) - ) - elif group_id == 6: - expert_out = self.experts[6](pooled) + self.expert_bias( - torch.tensor(6, device=pooled.device) - ) - elif group_id == 7: - expert_out = self.experts[7](pooled) + self.expert_bias( - torch.tensor(7, device=pooled.device) - ) - elif group_id == 8: # group_id == 8 - expert_out = self.experts[8](pooled) + self.expert_bias( - torch.tensor(8, device=pooled.device) - ) - elif group_id == 9: # group_id == 9 - expert_out = self.experts[9](pooled) + self.expert_bias( - torch.tensor(9, device=pooled.device) - ) - elif group_id == 10: # group_id == 10 - expert_out = self.experts[10](pooled) + self.expert_bias( - torch.tensor(10, device=pooled.device) - ) - elif group_id == 11: # group_id == 11 - expert_out = self.experts[11](pooled) + self.expert_bias( - torch.tensor(11, device=pooled.device) - ) - elif group_id == 12: # group_id == 12 - expert_out = self.experts[12](pooled) + self.expert_bias( - torch.tensor(12, device=pooled.device) - ) - elif group_id == 13: # group_id == 13 - expert_out = self.experts[13](pooled) + self.expert_bias( - torch.tensor(13, device=pooled.device) - ) - elif group_id == 14: # group_id == 14 - expert_out = self.experts[14](pooled) + self.expert_bias( - torch.tensor(14, device=pooled.device) - ) - elif group_id == 15: # group_id == 15 - expert_out = self.experts[15](pooled) + self.expert_bias( - torch.tensor(15, device=pooled.device) - ) - elif group_id == 16: # group_id == 16 - expert_out = self.experts[16](pooled) + self.expert_bias( - torch.tensor(16, device=pooled.device) - ) - elif group_id == 17: # group_id == 17 - expert_out = self.experts[17](pooled) + self.expert_bias( - torch.tensor(17, device=pooled.device) - ) - elif group_id == 18: # group_id == 18 - expert_out = self.experts[18](pooled) + self.expert_bias( - torch.tensor(18, device=pooled.device) - ) - else: # group_id == 19 - expert_out = self.experts[19](pooled) + self.expert_bias( - torch.tensor(19, device=pooled.device) - ) - else: - batch_size = pooled.size(0) - # 并行计算所有专家输出 - expert_outputs = torch.stack( - [e(pooled) for e in self.experts], dim=0 - ) # [E, B, D] - # 根据 pg 索引专家输出 - expert_out = expert_outputs[pg, torch.arange(batch_size)] # [B, D] - # 添加专家偏置 - bias = self.expert_bias(pg) # [B, D] - expert_out = expert_out + bias - - # ----- 5. 分类头 ----- - logits = self.classifier(expert_out) # [batch, num_classes] - if not self.training: # 推理时加 Softmax - probs = torch.softmax(logits, dim=-1) - return probs - return logits - - def model_eval(self, eval_dataloader, criterion=None): - """ - 在验证集上评估模型,返回准确率和平均损失。 - - 参数: - eval_dataloader: DataLoader,提供 'input_ids', 'attention_mask', 'pg', 'char_id' - criterion: 损失函数,默认为 CrossEntropyLoss() - 返回: - accuracy: float, 准确率 - avg_loss: float, 平均损失 - """ - if criterion is None: - criterion = nn.CrossEntropyLoss() - - self.eval() - total_loss = 0.0 - correct = 0 - total = 0 - - with torch.no_grad(): - for batch in eval_dataloader: - # 移动数据到模型设备 - input_ids = batch["hint"]["input_ids"].to(self.device) - attention_mask = batch["hint"]["attention_mask"].to(self.device) - pg = batch["pg"].to(self.device) - labels = batch["char_id"].to(self.device) - - # 前向传播 - probs = self(input_ids, attention_mask, pg) - log_probs = torch.log(probs + 1e-12) - loss = nn.NLLLoss()(log_probs, labels) - total_loss += loss.item() * labels.size(0) - - # 计算准确率 - preds = probs.argmax(dim=-1) - correct += (preds == labels).sum().item() - total += labels.size(0) - - avg_loss = total_loss / total if total > 0 else 0.0 - accuracy = correct / total if total > 0 else 0.0 - return accuracy, avg_loss - - def predict(self, sample, debug=False): - """ - 基于 sample 字典进行预测,支持批量/单样本,可选调试打印错误样本信息。 - - 参数: - sample : dict - 必须包含字段: - - 'input_ids' : [batch, seq_len] 或 [seq_len] (单样本) - - 'attention_mask': 同上 - - 'pg' : [batch] 或标量 - - 'char_id' : [batch] 或标量,真实标签(当 debug=True 时必须提供) - 调试时(debug=True)必须包含字段: - - 'txt' : 字符串列表(batch)或单个字符串 - - 'char' : 字符串列表(batch)或单个字符串 - - 'py' : 字符串列表(batch)或单个字符串 - debug : bool - 是否打印预测错误的样本信息。若为 True 但 sample 缺少 char_id/txt/char/py,抛出 ValueError。 - - 返回: - preds : torch.Tensor - [batch] 预测类别标签(若输入为单样本且无 batch 维度,则返回标量) - """ - self.eval() - - # ------------------ 1. 提取并规范化输入 ------------------ - # 判断是否为单样本(input_ids 无 batch 维度) - input_ids = sample["input_ids"] - attention_mask = sample["attention_mask"] - pg = sample["pg"] - has_batch_dim = input_ids.dim() > 1 - - if not has_batch_dim: - input_ids = input_ids.unsqueeze(0) - attention_mask = attention_mask.unsqueeze(0) - if pg.dim() == 0: - pg = pg.unsqueeze(0).expand(input_ids.size(0)) - - # ------------------ 2. 移动设备 ------------------ - input_ids = input_ids.to(self.device) - attention_mask = attention_mask.to(self.device) - pg = pg.to(self.device) - - # ------------------ 3. 推理 ------------------ - with torch.no_grad(): - logits = self(input_ids, attention_mask, pg) - preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch] - - # ------------------ 4. 调试打印(错误样本) ------------------ - if debug: - # 检查必需字段 - required_keys = ["char_id", "txt", "char", "py"] - missing = [k for k in required_keys if k not in sample] - if missing: - raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}") - - # 提取真实标签 - true_labels = sample["char_id"] - if true_labels.dim() == 0: - true_labels = true_labels.unsqueeze(0) - - # 移动真实标签到相同设备 - true_labels = true_labels.to(self.device) - - # 找出预测错误的索引 - incorrect_mask = preds != true_labels - incorrect_indices = torch.where(incorrect_mask)[0] - - if len(incorrect_indices) > 0: - print("\n=== 预测错误样本 ===") - # 获取调试字段(可能是列表或单个字符串) - txts = sample["txt"] - chars = sample["char"] - pys = sample["py"] - - # 统一转换为列表(如果输入是单个字符串) - if isinstance(txts, str): - txts = [txts] - chars = [chars] - pys = [pys] - - for idx in incorrect_indices.cpu().numpy(): - print(f"样本索引 {idx}:") - print(f" Text : {txts[idx]}") - print(f" Char : {chars[idx]}") - print(f" Pinyin: {pys[idx]}") - print( - f" 预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}" - ) - print("===================\n") - - # ------------------ 5. 返回结果(保持与输入维度一致) ------------------ - if not has_batch_dim: - return preds.squeeze(0) # 返回标量 - return preds - - def fit( - self, - train_dataloader, - eval_dataloader=None, - monitor: Optional[TrainingMonitor] = None, - criterion=nn.CrossEntropyLoss(), - optimizer=None, - num_epochs=1, - eval_frequency=500, - grad_accum_steps=1, - clip_grad_norm=1.0, - mixed_precision=False, - lr=1e-4, - lr_schedule=None, # 新增:可选的自定义学习率调度函数 - ): - """ - 训练模型,支持混合精度、梯度累积、学习率调度、实时监控。 - - 参数: - train_dataloader: DataLoader - 训练数据加载器。 - eval_dataloader: DataLoader, optional - 评估数据加载器。 - monitor: TrainingMonitor, optional - 训练监控器。 - criterion: nn.Module, optional - 损失函数。 - optimizer: optim.Optimizer, optional - 优化器。 - num_epochs: int, optional - 训练轮数。 - eval_frequency: int, optional - 评估频率。 - grad_accum_steps: int, optional - 梯度累积步数。 - clip_grad_norm: float, optional - 梯度裁剪范数。 - mixed_precision: bool, optional - 是否使用混合精度。 - lr: float, optional - 初始学习率。 - lr_schedule : callable, optional - 自定义学习率调度函数,接收参数 (processed_batches, optimizer), - 可在内部直接修改 optimizer.param_groups 中的学习率。 - """ - # 确保模型在正确的设备上 - if self.device is None: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.to(self.device) - - # 切换到训练模式 - super().train() - - # 默认优化器 - if optimizer is None: - optimizer = optim.AdamW(self.parameters(), lr=lr) # 初始学习率 1e-4 - - # 混合精度缩放器 - scaler = amp.GradScaler(enabled=mixed_precision) - - global_step = 0 - processed_batches = 0 # 新增:实际处理的 batch 数量计数器 - batch_loss_sum = 0.0 - optimizer.zero_grad() - - for epoch in range(num_epochs): - for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)): - # ---------- 更新 batch 计数器 ---------- - processed_batches += 1 - - # ---------- 学习率调度(仅当使用默认优化器且未传入自定义调度函数时)---------- - if lr_schedule is not None: - # 调用用户自定义的调度函数 - lr_schedule(processed_batches, optimizer) - - # ---------- 移动数据 ---------- - input_ids = batch["hint"]["input_ids"].to(self.device) - attention_mask = batch["hint"]["attention_mask"].to(self.device) - pg = batch["pg"].to(self.device) - labels = batch["char_id"].to(self.device) - - # 混合精度前向 - with amp.autocast( - device_type=self.device.type, enabled=mixed_precision - ): - logits = self(input_ids, attention_mask, pg) - loss = criterion(logits, labels) - loss = loss / grad_accum_steps - - # 反向传播 - scaler.scale(loss).backward() - - # 梯度累积 - if (batch_idx + 1) % grad_accum_steps == 0: - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - global_step += 1 - original_loss = loss.item() * grad_accum_steps - batch_loss_sum += original_loss - # 周期性评估(与原代码相同) - if ( - eval_dataloader is not None - and global_step % eval_frequency == 0 - ): - avg_loss = batch_loss_sum / eval_frequency - acc, eval_loss = self.model_eval(eval_dataloader, criterion) - super().train() - if monitor is not None: - monitor.add_step( - global_step, - {"loss": avg_loss, "acc": acc}, - ) - logger.info( - f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}" - ) - batch_loss_sum = 0.0 - - def load_from_state_dict(self, state_dict_path: Union[str, Path]): - state_dict = torch.load( - state_dict_path, weights_only=True, map_location=self.device - ) - self.load_state_dict(state_dict) - - def load_from_pretrained_base_model( - self, - BaseModel, - snapshot_path: Union[str, Path], - device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - *args, - **kwargs, - ): - base_model = BaseModel(*args, **kwargs) - base_model.load_state_dict(torch.load(snapshot_path, map_location=device)) - self_static_dict = self.state_dict() - pretrained_dict = base_model.state_dict() - - freeze_layers = [] - - for key in self_static_dict.keys(): - if key in pretrained_dict.keys(): - if self_static_dict[key].shape == pretrained_dict[key].shape: - self_static_dict[key] = pretrained_dict[key].to(self.device) - freeze_layers.append(key) - self.load_state_dict(self_static_dict) - for name, param in self.named_parameters(): - if name in freeze_layers: - param.requires_grad = False