import math
import pickle
from importlib.resources import files
from pathlib import Path
from typing import Optional, Union

import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
from modelscope import AutoModel, AutoTokenizer
from tqdm.autonotebook import tqdm

from suinput.dataset import PG

from .monitor import TrainingMonitor, send_serverchan_message


def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
    """Load every pickled evaluation batch found under *path*.

    Args:
        path: Directory containing ``*.pkl`` batch files. Defaults to the
            package-bundled ``eval_dataset`` resource directory.

    Returns:
        list: One unpickled batch object per ``*.pkl`` file.

    Fix: the original ``pickle.load(file.open("rb"))`` comprehension never
    closed the file objects, leaking one handle per batch. Each file is now
    closed deterministically with a ``with`` block.
    """
    batches = []
    for file in Path(path).glob("*.pkl"):
        # NOTE(review): pickle assumes these are trusted local artifacts —
        # never point this loader at untrusted data.
        with file.open("rb") as fh:
            batches.append(pickle.load(fh))
    return batches


# ---------------------------- Attention pooling module ----------------------------
class AttentionPooling(nn.Module):
    """Attention-weighted pooling over the sequence dimension.

    A single linear head scores each token; scores get a learnable per-segment
    bias (indexed by token_type_ids), padding positions are masked out, and
    the softmax-weighted sum of token states is returned.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, 1)
        # Three learnable score biases: text, pinyin, personalization.
        self.bias = nn.Parameter(torch.zeros(3))  # [text_bias, pinyin_bias, user_bias]

    def forward(self, x, mask=None, token_type_ids=None):
        """Pool ``x`` ([batch, seq_len, hidden]) into [batch, hidden].

        Args:
            x: Token hidden states.
            mask: Attention mask; positions where ``mask == 0`` are excluded.
            token_type_ids: Per-token segment ids (0/1/2) used to index
                ``self.bias`` and shift the attention scores.
        """
        scores = self.attn(x).squeeze(-1)  # [batch, seq_len]
        if token_type_ids is not None:
            # bias has shape [3]; fancy-indexing broadcasts it to [batch, seq_len].
            bias_per_token = self.bias[token_type_ids]  # [batch, seq_len]
            scores = scores + bias_per_token
        if mask is not None:
            # Large negative fill so masked positions get ~0 softmax weight.
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1)
        return pooled


# ---------------------------- Residual block ----------------------------
class ResidualBlock(nn.Module):
    """Two-layer GELU MLP with LayerNorms, dropout, and a skip connection."""

    def __init__(self, dim, dropout_prob=0.3):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        residual = x
        # Fix (historical): use self.gelu rather than the undefined self.relu.
        x = self.gelu(self.linear1(x))
        x = self.ln1(x)
        x = self.linear2(x)
        x = self.ln2(x)
        x = self.dropout(x)
        x = x + residual
        return self.gelu(x)


# ---------------------------- Expert network ----------------------------
class Expert(nn.Module):
    """Feed-forward expert: input projection -> residual blocks -> widened head.

    The output dimension is ``input_dim * output_multiplier``, so the expert
    widens its representation rather than preserving it.
    """

    def __init__(
        self,
        input_dim,
        d_model=768,
        num_resblocks=4,
        output_multiplier=2,
        dropout_prob=0.3,
    ):
        super().__init__()
        # Experts emit a representation `output_multiplier` times wider than input.
        self.output_dim = input_dim * output_multiplier
        self.linear_in = nn.Linear(input_dim, d_model)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)]
        )
        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(d_model, self.output_dim),
        )

    def forward(self, x):
        """Run the expert MLP on pooled features ``x`` -> [batch, output_dim]."""
        x = self.linear_in(x)
        for block in self.res_blocks:
            x = block(x)
        return self.output(x)


# ---------------------------- Main model (MoE + hard routing) ------------------------
class MoEModel(nn.Module):
    """Pretrained-embedding encoder with hard-routed mixture-of-experts head.

    Pipeline: BERT embeddings -> 4-layer Transformer encoder -> attention
    pooling -> one expert selected per sample by the pinyin-group id ``pg``
    (plus a learned per-expert bias) -> classification head.
    """

    def __init__(
        self,
        pretrained_model_name="iic/nlp_structbert_backbone_lite_std",
        num_classes=10018,
        output_multiplier=2,
        d_model=768,
        num_resblocks=4,
        num_domain_experts=12,
        num_shared_experts=1,
    ):
        super().__init__()
        self.output_multiplier = output_multiplier
        # 1. Load the pretrained embedding layer (Lite backbone: hidden_size=512).
        logger.info(f"Loading backbone: {pretrained_model_name}")
        bert = AutoModel.from_pretrained(pretrained_model_name)
        self.embedding = bert.embeddings
        self.bert_config = bert.config
        self.hidden_size = self.bert_config.hidden_size  # 512
        # Recorded by the overridden `to()`; stays None until the model is moved.
        self.device = None
        # 2. Transformer encoder on top of the borrowed embeddings.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=8,
            dim_feedforward=self.bert_config.intermediate_size,
            dropout=self.bert_config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        # 3. Attention pooling (replaces the former span pooling).
        self.attn_pool = AttentionPooling(self.hidden_size)
        # 4. Expert system: domain experts use dropout 0.3, shared experts 0.4.
        total_experts = num_domain_experts + num_shared_experts
        self.experts = nn.ModuleList()
        for i in range(total_experts):
            dropout = 0.3 if i < num_domain_experts else 0.4
            self.experts.append(
                Expert(
                    input_dim=self.hidden_size,
                    d_model=d_model,
                    num_resblocks=num_resblocks,
                    output_multiplier=self.output_multiplier,
                    dropout_prob=dropout,
                )
            )
        # Learned additive bias per expert, same width as the expert output.
        self.expert_bias = nn.Embedding(
            total_experts, self.hidden_size * self.output_multiplier
        )
        # 5. Classification head.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size * self.output_multiplier),
            nn.Dropout(0.4),
            nn.Linear(
                self.hidden_size * self.output_multiplier,
                self.hidden_size * self.output_multiplier * 2,
            ),
            nn.GELU(),
            nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes),
        )

    def to(self, device):
        """Override `to` so the target device is remembered on the instance."""
        self.device = device
        return super().to(device)

    def forward(self, input_ids, attention_mask, token_type_ids, pg):
        """Forward pass; uses token_type_ids instead of the former p_start.

        Args:
            input_ids: [B, L]
            attention_mask: [B, L]
            token_type_ids: [B, L] (0 = text, 1 = pinyin)
            pg: [B] pinyin-group ids used for hard expert routing

        Returns:
            Logits of shape [B, num_classes].
        """
        # ----- 1. Embeddings -----
        # The pretrained embedding module handles token_type_ids itself, so we
        # pass them straight through. If a backbone's embedding lacked that
        # parameter, one could instead add
        # self.embedding.token_type_embeddings(token_type_ids) manually.
        embeddings = self.embedding(input_ids, token_type_ids=token_type_ids)
        # ----- 2. Transformer encoder -----
        # TransformerEncoder expects True where positions should be IGNORED.
        padding_mask = attention_mask == 0
        encoded = self.encoder(
            embeddings, src_key_padding_mask=padding_mask
        )  # [B, S, H]
        # ----- 3. Attention pooling (replaces the former span pooling) -----
        # attention_mask suppresses padding positions inside the pooler.
        pooled = self.attn_pool(encoded, attention_mask, token_type_ids)  # [B, H]
        # ----- 4. Expert routing (hard routing) -----
        if torch.jit.is_tracing():
            # ONNX export mode: batch=1, pick the expert indicated by pg.
            group_id = pg.item() if torch.is_tensor(pg) else pg
            # Expert indices start at 0; the explicit if/elif chain keeps the
            # traced graph static (tracing bakes in the taken branch only).
            if group_id == 0:
                expert_out = self.experts[0](pooled) + self.expert_bias(
                    torch.tensor(0, device=pooled.device)
                )
            elif group_id == 1:
                expert_out = self.experts[1](pooled) + self.expert_bias(
                    torch.tensor(1, device=pooled.device)
                )
            elif group_id == 2:
                expert_out = self.experts[2](pooled) + self.expert_bias(
                    torch.tensor(2, device=pooled.device)
                )
            elif group_id == 3:
                expert_out = self.experts[3](pooled) + self.expert_bias(
                    torch.tensor(3, device=pooled.device)
                )
            elif group_id == 4:
                expert_out = self.experts[4](pooled) + self.expert_bias(
                    torch.tensor(4, device=pooled.device)
                )
            elif group_id == 5:
                expert_out = self.experts[5](pooled) + self.expert_bias(
                    torch.tensor(5, device=pooled.device)
                )
            elif group_id == 6:
                expert_out = self.experts[6](pooled) + self.expert_bias(
                    torch.tensor(6, device=pooled.device)
                )
            elif group_id == 7:
                expert_out = self.experts[7](pooled) + self.expert_bias(
                    torch.tensor(7, device=pooled.device)
                )
            elif group_id == 8:
                expert_out = self.experts[8](pooled) + self.expert_bias(
                    torch.tensor(8, device=pooled.device)
                )
            elif group_id == 9:
                expert_out = self.experts[9](pooled) + self.expert_bias(
                    torch.tensor(9, device=pooled.device)
                )
            elif group_id == 10:
                expert_out = self.experts[10](pooled) + self.expert_bias(
                    torch.tensor(10, device=pooled.device)
                )
            elif group_id == 11:
                expert_out = self.experts[11](pooled) + self.expert_bias(
                    torch.tensor(11, device=pooled.device)
                )
            else:  # group_id == 12
                expert_out = self.experts[12](pooled) + self.expert_bias(
                    torch.tensor(12, device=pooled.device)
                )
        else:
            # Training/eval path: run ALL experts, then gather each sample's
            # routed expert. NOTE(review): this costs E forward passes per
            # batch — acceptable for small expert counts, but worth knowing.
            batch_size = pooled.size(0)
            expert_outputs = torch.stack(
                [e(pooled) for e in self.experts], dim=0
            )  # [E, B, D]
            expert_out = expert_outputs[pg, torch.arange(batch_size)]  # [B, D]
            bias = self.expert_bias(pg)  # [B, D]
            expert_out = expert_out + bias
        # ----- 5. Classification head -----
        return self.classifier(expert_out)  # [batch, num_classes]

    def model_eval(self, eval_dataloader, criterion):
        """Evaluate the model on a validation set.

        Args:
            eval_dataloader: Iterable of batches, each a dict with:
                - "hint": dict of input_ids / attention_mask / token_type_ids
                - "pg": pinyin-group ids for expert routing
                - "char_id": target character-id labels
            criterion: Loss function applied to (logits, labels).

        Returns:
            tuple: (accuracy, avg_loss) over all evaluated samples; both are
            0.0 when the loader yields no samples.

        Note:
            Switches the model to eval mode (``self.eval()``) and disables
            gradients via ``torch.no_grad()``; the caller is responsible for
            restoring train mode afterwards.
        """
        self.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in eval_dataloader:
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)
                logits = self(input_ids, attention_mask, token_type_ids, pg)
                loss = criterion(logits, labels)
                # Weight by batch size so the mean is per-sample, not per-batch.
                total_loss += loss.item() * labels.size(0)
                # softmax is monotone, so argmax over probs == argmax over logits.
                preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        avg_loss = total_loss / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0
        return accuracy, avg_loss

    def gen_predict_sample(self, text, py, tokenizer=None):
        """Build a model-input sample from text and pinyin.

        Encodes (text, py) as a sentence pair so the tokenizer produces
        token_type_ids (0 = text, 1 = pinyin). Lazily creates and caches a
        default tokenizer on first use when none is supplied.

        Args:
            text (str): Input text, used as the first sentence.
            py (str): Pinyin string, used as the second sentence.
            tokenizer: Optional tokenizer instance; when None and
                ``self.tokenizer`` is absent, a default AutoTokenizer is
                created. Defaults to None.

        Returns:
            dict: {"hint": {"input_ids", "attention_mask", "token_type_ids"},
            "pg": tensor} where pg is derived from the first pinyin letter via
            the ``PG`` mapping, or 12 when ``py`` is empty.

        Notes:
            - max_length is fixed at 88 with padding and truncation.
            - return_token_type_ids=True is requested explicitly; the chosen
              tokenizer must support it.
        """
        if tokenizer is None and not hasattr(self, "tokenizer"):
            self.tokenizer = AutoTokenizer.from_pretrained(
                "iic/nlp_structbert_backbone_lite_std"
            )
        else:
            # Prefer the caller's tokenizer; otherwise keep the cached one.
            self.tokenizer = tokenizer or self.tokenizer
        # text_pair encoding lets the tokenizer emit token_type_ids for us.
        encoded = self.tokenizer(
            text,  # text as the first sentence
            py,  # pinyin as the second sentence
            max_length=88,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,  # explicitly request token_type_ids
        )
        sample = {
            "hint": {
                "input_ids": encoded["input_ids"],
                "attention_mask": encoded["attention_mask"],
                "token_type_ids": encoded["token_type_ids"],
            },
            # Pinyin-group id from the first letter; 12 is the fallback group
            # for empty pinyin (matches the shared expert's index).
            "pg": torch.tensor(
                [PG[py[0]] if py != "" else 12]
            ),
        }
        return sample

    def predict(self, text, py, tokenizer=None):
        """Predict character ids, normalizing the batch dimension.

        Args:
            text (str or List[str]): Input text or list of texts.
            py (str): Pinyin input forwarded to `gen_predict_sample`.
            tokenizer: Optional tokenizer for preprocessing. Defaults to None.

        Returns:
            torch.Tensor: 0-d prediction for a single input, 1-d tensor for a
            batch.
        """
        self.eval()
        sample = self.gen_predict_sample(text, py, tokenizer)
        input_ids = sample["hint"]["input_ids"]
        attention_mask = sample["hint"]["attention_mask"]
        token_type_ids = sample["hint"]["token_type_ids"]
        pg = sample["pg"]
        # Add a batch dim when the tokenizer returned unbatched tensors.
        has_batch_dim = input_ids.dim() > 1
        if not has_batch_dim:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
            token_type_ids = token_type_ids.unsqueeze(0)
        if pg.dim() == 0:
            pg = pg.unsqueeze(0).expand(input_ids.size(0))
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        token_type_ids = token_type_ids.to(self.device)
        pg = pg.to(self.device)
        with torch.no_grad():
            logits = self(input_ids, attention_mask, token_type_ids, pg)
            preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
        if not has_batch_dim:
            return preds.squeeze(0)
        return preds

    def fit(
        self,
        train_dataloader,
        eval_dataloader=None,
        monitor: Optional[TrainingMonitor] = None,
        criterion=None,
        optimizer=None,
        num_epochs=1,
        stop_batch=2e5,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        loss_weight=None,
        mixed_precision=True,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.15,
        lr=1e-4,
    ):
        """Train the model with AMP, gradient accumulation, and warmup+cosine LR.

        Args:
            train_dataloader: Yields batches shaped like `model_eval` expects.
            eval_dataloader: Optional validation loader, evaluated every
                `eval_frequency` steps and once after training.
            monitor: Optional TrainingMonitor that receives per-eval metrics.
            criterion: Loss; defaults to CrossEntropyLoss with the given
                `label_smoothing` (and `loss_weight` if provided).
            optimizer: Defaults to AdamW(lr, weight_decay).
            num_epochs / stop_batch: Training ends after `stop_batch` batches
                or `num_epochs` full passes, whichever comes first.
            grad_accum_steps: Batches per optimizer step.
            clip_grad_norm: Max grad norm applied before each optimizer step.
            mixed_precision: Enables torch.amp autocast + GradScaler.
        """
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)
        self.train()
        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
        if criterion is None:
            if loss_weight is not None:
                criterion = nn.CrossEntropyLoss(
                    weight=loss_weight, label_smoothing=label_smoothing
                )
            else:
                criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        scaler = amp.GradScaler(enabled=mixed_precision)
        # NOTE(review): max(stop_batch, 2e5) makes the cosine schedule span at
        # least 2e5 steps even when training stops earlier, so short runs never
        # complete the decay — confirm this is intended.
        total_steps = max(stop_batch, 2e5)
        warmup_steps = int(total_steps * warmup_ratio)
        logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
        processed_batches = 0
        global_step = 0  # initialize step counter
        batch_loss_sum = 0.0
        optimizer.zero_grad()
        for epoch in range(num_epochs):
            for batch_idx, batch in enumerate(
                tqdm(train_dataloader, total=int(stop_batch))
            ):
                # LR Schedule: linear warmup then cosine decay to 0.
                if processed_batches < warmup_steps:
                    current_lr = lr * (processed_batches / warmup_steps)
                else:
                    progress = (processed_batches - warmup_steps) / (
                        total_steps - warmup_steps
                    )
                    current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
                for param_group in optimizer.param_groups:
                    param_group["lr"] = current_lr
                # Move batch to device (batches now include token_type_ids).
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)
                with torch.amp.autocast(
                    device_type=self.device.type, enabled=mixed_precision
                ):
                    logits = self(input_ids, attention_mask, token_type_ids, pg)
                    loss = criterion(logits, labels)
                # Scale down so accumulated gradients average over the window.
                loss = loss / grad_accum_steps
                scaler.scale(loss).backward()
                # NOTE(review): with the counter starting at 0 this steps on the
                # very first batch before any accumulation; likely intended to be
                # (processed_batches + 1) % grad_accum_steps == 0 — confirm.
                if (processed_batches) % grad_accum_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                # Undo the accumulation scaling when logging the raw loss.
                batch_loss_sum += loss.item() * grad_accum_steps
                if global_step % eval_frequency == 0:
                    if eval_dataloader:
                        self.eval()
                        acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                        self.train()
                        if monitor:
                            # Report eval accuracy plus the windowed train loss.
                            monitor.add_step(
                                global_step,
                                {"loss": batch_loss_sum / (eval_frequency if global_step > 0 else 1), "acc": acc}
                            )
                        logger.info(
                            f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
                        )
                    else:
                        logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}")
                    batch_loss_sum = 0.0
                # NOTE(review): this breaks only the inner loop; remaining
                # epochs still start and break on their first batch — confirm.
                if processed_batches >= stop_batch:
                    break
                processed_batches += 1
                global_step += 1
        # Send a notification when training finishes; failures (including a
        # missing eval_dataloader) are logged rather than raised.
        try:
            res_acc, res_loss = self.model_eval(eval_dataloader, criterion)
            send_serverchan_message(
                title="训练完成",
                content=f"acc: {res_acc:.4f}, loss: {res_loss:.4f}",
            )
            logger.info(f"训练完成,acc: {res_acc:.4f}, loss: {res_loss:.4f}")
        except Exception as e:
            logger.error(f"发送消息失败: {e}")

    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
        """Load model weights from a saved state-dict file onto self.device."""
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=self.device
        )
        self.load_state_dict(state_dict)

    def load_from_pretrained_base_model(
        self,
        BaseModel,
        snapshot_path: Union[str, Path],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        *args,
        **kwargs,
    ):
        """Warm-start from a base model's snapshot, freezing copied layers.

        Builds ``BaseModel(*args, **kwargs)``, loads its snapshot, copies every
        parameter whose name AND shape match into this model, then sets
        ``requires_grad = False`` on the copied parameters.

        Args:
            BaseModel: Class of the donor model.
            snapshot_path: Path to the donor's saved state dict.
            device: map_location for loading the donor snapshot.
            *args, **kwargs: Forwarded to the BaseModel constructor.
        """
        base_model = BaseModel(*args, **kwargs)
        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
        self_static_dict = self.state_dict()
        pretrained_dict = base_model.state_dict()
        freeze_layers = []
        for key in self_static_dict.keys():
            if key in pretrained_dict.keys():
                # Only copy when shapes agree; mismatched heads stay random.
                if self_static_dict[key].shape == pretrained_dict[key].shape:
                    # NOTE(review): self.device may still be None here if the
                    # model was never moved via .to() — confirm callers move
                    # the model first.
                    self_static_dict[key] = pretrained_dict[key].to(self.device)
                    freeze_layers.append(key)
        self.load_state_dict(self_static_dict)
        for name, param in self.named_parameters():
            if name in freeze_layers:
                param.requires_grad = False

    # --- ONNX export helper ---
    def export_onnx(self, output_path, dummy_input):
        """Export the model to ONNX.

        Args:
            output_path: Destination .onnx file path.
            dummy_input: Tuple of example tensors
                (input_ids, attention_mask, token_type_ids, pg) used for
                tracing; the tracing branch of `forward` assumes batch=1.
        """
        self.eval()
        input_names = ["input_ids", "attention_mask", "token_type_ids", "pg"]
        output_names = ["logits"]
        torch.onnx.export(
            self,
            dummy_input,
            output_path,
            input_names=input_names,
            output_names=output_names,
            # batch and sequence length stay dynamic in the exported graph.
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "seq_len"},
                "attention_mask": {0: "batch_size", 1: "seq_len"},
                "token_type_ids": {0: "batch_size", 1: "seq_len"},
                "pg": {0: "batch_size"},
                "logits": {0: "batch_size"},
            },
            opset_version=14,
            do_constant_folding=True,
        )
        logger.info(f"ONNX model exported to {output_path}")