import math
from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.amp as amp
import torch.optim as optim
import torch.nn.functional as F
from loguru import logger
from modelscope import AutoTokenizer
from tqdm.notebook import tqdm

from .components import AttentionPooling, Expert  # , ResidualBlock  # assumed to be implemented


class InputMethodEngine(nn.Module):
    def __init__(
        self,
        pretrained_encoder,  # pretrained encoder, already loaded and extended
        output_vocab_size: int,
        hidden_size: int = 512,  # must match the hidden size of the pretrained model
        max_slot_steps: int = 24,
        num_experts: int = 20,
        top_k: int = 3,
        expert_res_blocks: int = 4,
        dropout: float = 0.3,
        use_attention_pooling: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_vocab_size = output_vocab_size
        self.max_slot_steps = max_slot_steps
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_attention_pooling = use_attention_pooling
        self.device = None  # set via to(); fit() falls back to CUDA/CPU detection

        # Pretrained encoder
        self.encoder = pretrained_encoder

        self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
        self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)

        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=4,  # configurable
            dropout=dropout,
            batch_first=True,
        )

        if use_attention_pooling:
            self.attention_pooling = AttentionPooling(hidden_size)

        self.gate = nn.Linear(hidden_size, num_experts)
        self.experts = nn.ModuleList(
            [
                Expert(
                    input_dim=hidden_size,
                    d_model=hidden_size,
                    num_resblocks=expert_res_blocks,
                    output_multiplier=1,
                    dropout_prob=dropout,
                )
                for _ in range(num_experts)
            ]
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_vocab_size),
        )

        self._init_weights()

    def encode_text(self, input_ids, token_type_ids, attention_mask):
        """
        Encode text with the pretrained encoder.
        Note: the pretrained model's output may contain both last_hidden_state and pooler_output.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # Take last_hidden_state [batch, seq_len, hidden]
        return outputs.last_hidden_state

    def forward_single_step(self, context, slot_seq_emb, slot_seq_mask=None):
        """
        Single-step prediction: given the current slot sequence (already embedded),
        predict the next character.
        context: [batch, seq_len, hidden] encoded text
        slot_seq_emb: [batch, current_len, hidden] embeddings of the current slot sequence
        slot_seq_mask: [batch, current_len] mask of valid positions (1 = valid)
        Returns: [batch, output_vocab_size] logits over the output vocabulary
        """
        batch_size = slot_seq_emb.size(0)

        # Cross-attention: only the embedding of the most recent slot step is used as the
        # query, since autoregressive decoding only needs the newest position.
        last_query = slot_seq_emb[:, -1:, :]  # [batch, 1, hidden]

        attn_out, _ = self.cross_attention(
            query=last_query,
            key=context,
            value=context,
            key_padding_mask=(
                context.sum(-1) == 0
            ),  # heuristic for padded positions; ideally the caller passes the real attention_mask
        )  # [batch, 1, hidden]
        attn_out = attn_out.squeeze(1)  # [batch, hidden]

        # Gating network: select the top-k experts
        gate_logits = self.gate(attn_out)  # [batch, num_experts]
        topk_weights, topk_indices = torch.topk(
            F.softmax(gate_logits, dim=-1), self.top_k, dim=-1
        )
        # Renormalize the top-k weights
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        # Weighted sum of the selected experts' outputs
        expert_outputs = torch.zeros_like(attn_out)  # [batch, hidden]
        for i in range(self.top_k):
            expert_idx = topk_indices[:, i]  # [batch]
            weight = topk_weights[:, i].unsqueeze(1)  # [batch, 1]
            # Dispatch each sample to its selected expert. A per-sample loop keeps the code
            # readable; this could be optimized with gather / per-expert batching.
            for b in range(batch_size):
                expert_out = self.experts[expert_idx[b]](attn_out[b : b + 1])
                expert_outputs[b] += weight[b] * expert_out.squeeze(0)

        # Classification head
        logits = self.classifier(expert_outputs)  # [batch, output_vocab_size]
        return logits
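    # __init__ calls self._init_weights(), but the method is not defined in this file.
    # The sketch below is one plausible implementation (an assumption, not the original):
    # BERT-style normal(0, 0.02) / zeros init for the newly added modules, leaving the
    # pretrained encoder weights untouched.
    def _init_weights(self):
        new_modules = [
            self.slot_embedding,
            self.slot_position_embedding,
            self.cross_attention,
            self.gate,
            self.experts,
            self.classifier,
        ]
        for module in new_modules:
            for m in module.modules():
                if isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, mean=0.0, std=0.02)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Embedding):
                    nn.init.normal_(m.weight, mean=0.0, std=0.02)
                elif isinstance(m, nn.LayerNorm):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)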
    def forward_train(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        slot_target_ids,
        slot_target_mask=None,
    ):
        """
        Training mode: compute the predictions for all slot steps in parallel.
        slot_target_ids: [batch, max_slot_steps] ground-truth labels (teacher forcing)
        slot_target_mask: [batch, max_slot_steps] mask of valid steps (1 = step exists)
        Returns: per-step logits [batch, max_slot_steps, output_vocab_size]
        """
        batch_size, max_steps = slot_target_ids.shape
        device = input_ids.device

        # 1. Encode the text
        context = self.encode_text(
            input_ids, token_type_ids, attention_mask
        )  # [b, L, h]

        # 2. Build the slot-sequence embeddings (teacher forcing: built from the gold labels).
        #    Invalid positions may be labelled -1, which is not a valid embedding index,
        #    so clamp them to 0 first; the mask below zeroes those embeddings anyway.
        safe_slot_ids = slot_target_ids.clamp(min=0)
        slot_emb = self.slot_embedding(safe_slot_ids)  # [b, T, h]

        # Position embeddings (positions start at 0)
        positions = torch.arange(max_steps, device=device).unsqueeze(0)  # [1, T]
        pos_emb = self.slot_position_embedding(positions)  # [1, T, h]
        slot_emb = slot_emb + pos_emb

        # Optional: zero out invalid positions
        if slot_target_mask is not None:
            slot_emb = slot_emb * slot_target_mask.unsqueeze(-1)

        # 3. Cross-attention with the whole slot sequence as the query.
        #    Each slot position attends only to the text encoding, not to other slot
        #    positions, so no causal mask is needed; query and key/value simply differ.
        attn_out, _ = self.cross_attention(
            query=slot_emb,  # [b, T, h]
            key=context,  # [b, L, h]
            value=context,
            key_padding_mask=(attention_mask == 0),  # ignore padded text positions
        )  # [b, T, h]

        # 4. Gate + experts + classifier for each slot step.
        #    All steps share parameters, so batch and step dimensions can be merged.
        b, T, h = attn_out.shape
        attn_flat = attn_out.view(b * T, h)  # [b*T, h]

        # Gating network
        gate_logits = self.gate(attn_flat)  # [b*T, num_experts]
        gate_probs = F.softmax(gate_logits, dim=-1)
        topk_probs, topk_indices = torch.topk(
            gate_probs, self.top_k, dim=-1
        )  # [b*T, top_k]
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize

        # Expert outputs. Loop over experts and dispatch the tokens routed to each one;
        # a production version could use scatter/einsum-based dispatch instead, but the
        # loop is kept for clarity.
        expert_out_flat = torch.zeros_like(attn_flat)  # [b*T, h]
        for i in range(self.top_k):
            weight = topk_probs[:, i].unsqueeze(1)  # [b*T, 1]
            idx = topk_indices[:, i]  # [b*T]
            for expert_id in range(self.num_experts):
                mask = idx == expert_id
                if mask.any():
                    # Select the tokens routed to this expert
                    sub_input = attn_flat[mask]  # [k, h]
                    sub_output = self.experts[expert_id](sub_input)  # [k, h]
                    expert_out_flat[mask] += weight[mask] * sub_output

        # Classification head
        logits_flat = self.classifier(expert_out_flat)  # [b*T, output_vocab_size]
        logits = logits_flat.view(b, T, -1)  # [b, T, output_vocab_size]
        return logits

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        slot_target_ids=None,
        slot_target_mask=None,
        mode="train",
    ):
        """
        Unified entry point.
        mode='train': teacher forcing, all steps computed in parallel, returns [b, T, vocab]
        mode='infer': drive forward_single_step from an external decoding loop
        """
        if mode == "train":
            return self.forward_train(
                input_ids,
                token_type_ids,
                attention_mask,
                slot_target_ids,
                slot_target_mask,
            )
        else:
            raise NotImplementedError(
                "Inference mode should call forward_single_step directly"
            )

    def to(self, device):
        """Override to() so the target device is remembered."""
        self.device = device
        return super().to(device)
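    # The class itself ships no decoding loop (forward() raises NotImplementedError for
    # inference), so the helper below is only a sketch of how forward_single_step could be
    # driven autoregressively. It assumes greedy decoding, that slot id 0 works as a
    # start-of-sequence placeholder, and that `eos_id` (if given) terminates decoding;
    # none of these conventions are fixed by the original code.
    @torch.no_grad()
    def greedy_decode(self, input_ids, token_type_ids, attention_mask, eos_id=None):
        self.eval()
        context = self.encode_text(input_ids, token_type_ids, attention_mask)  # [b, L, h]
        batch_size = input_ids.size(0)
        device = input_ids.device

        # Start from a single placeholder slot (id 0) at position 0.
        generated = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        outputs = []
        for _ in range(self.max_slot_steps):
            positions = torch.arange(generated.size(1), device=device).unsqueeze(0)
            slot_seq_emb = (
                self.slot_embedding(generated) + self.slot_position_embedding(positions)
            )
            logits = self.forward_single_step(context, slot_seq_emb)  # [b, vocab]
            next_ids = logits.argmax(dim=-1)  # greedy choice, [b]
            outputs.append(next_ids)
            generated = torch.cat([generated, next_ids.unsqueeze(1)], dim=1)
            if eos_id is not None and bool((next_ids == eos_id).all()):
                break
        return torch.stack(outputs, dim=1)  # [b, num_generated_steps]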
    def fit(
        self,
        train_dataloader,
        eval_dataloader=None,
        monitor=None,
        criterion=None,
        optimizer=None,
        num_epochs=1,
        stop_batch=2e5,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        loss_weight=None,
        mixed_precision=True,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.15,
        lr=1e-4,
        lr_schedule=None,
        save_dir=None,
        save_frequency=1000,
    ):
        """Train the model."""

        def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps):
            # Linear warmup followed by cosine decay.
            if _processed_batches < _warmup_steps:
                current_lr = _lr * (_processed_batches / _warmup_steps)
            else:
                progress = (_processed_batches - _warmup_steps) / (
                    _stop_batch - _warmup_steps
                )
                current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
            return current_lr

        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)

        if lr_schedule is None:
            lr_schedule = default_lr_schedule

        self.train()

        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

        # Loss function: ignore_index=-1 because invalid label positions are marked with -1.
        if criterion is None:
            if loss_weight is not None:
                criterion = nn.CrossEntropyLoss(
                    weight=loss_weight, label_smoothing=label_smoothing, ignore_index=-1
                )
            else:
                criterion = nn.CrossEntropyLoss(
                    label_smoothing=label_smoothing, ignore_index=-1
                )

        scaler = amp.GradScaler(enabled=mixed_precision)

        total_steps = stop_batch
        warmup_steps = int(total_steps * warmup_ratio)
        logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")

        processed_batches = 0  # counts optimizer steps, not raw micro-batches
        batch_loss_sum = 0.0
        optimizer.zero_grad()

        try:
            for epoch in range(num_epochs):
                for batch_idx, batch in enumerate(
                    tqdm(train_dataloader, total=int(stop_batch))
                ):
                    # Learning-rate schedule (args: lr, processed steps, total steps, warmup)
                    current_lr = lr_schedule(
                        lr, processed_batches, stop_batch, warmup_steps
                    )
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = current_lr

                    # Unpack the batch
                    input_ids = batch["hint"]["input_ids"].to(self.device)
                    attention_mask = batch["hint"]["attention_mask"].to(self.device)
                    token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
                    labels = batch["char_id"].to(self.device)  # [batch, max_slot_steps]

                    # slot_target_mask: 1 for valid positions, 0 for padding (label -1)
                    slot_target_mask = (labels != -1).float()  # [batch, max_slot_steps]

                    with torch.amp.autocast(
                        device_type=self.device.type, enabled=mixed_precision
                    ):
                        # Forward pass (training mode)
                        logits = self(
                            input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask,
                            slot_target_ids=labels,
                            slot_target_mask=slot_target_mask,
                            mode="train",
                        )
                        # logits: [batch, max_slot_steps, output_vocab_size]

                        # Padding positions are skipped via ignore_index=-1 in the criterion.
                        loss = criterion(
                            logits.view(-1, self.output_vocab_size), labels.view(-1)
                        )
                        loss = loss / grad_accum_steps

                    scaler.scale(loss).backward()

                    # Optimizer step after every grad_accum_steps micro-batches
                    # (keyed on batch_idx so accumulation works for grad_accum_steps > 1).
                    if (batch_idx + 1) % grad_accum_steps == 0:
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()

                        batch_loss_sum += loss.item() * grad_accum_steps

                        # Periodic evaluation
                        if processed_batches % eval_frequency == 0:
                            if eval_dataloader:
                                self.eval()
                                acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                                self.train()
                                if monitor:
                                    monitor.add_step(
                                        processed_batches,
                                        {
                                            "train_loss": batch_loss_sum
                                            / (eval_frequency if processed_batches > 0 else 1),
                                            "acc": acc,
                                            "loss": eval_loss,
                                            "lr": current_lr,
                                        },
                                    )
                                logger.info(
                                    f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, "
                                    f"batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
                                    f"current_lr: {current_lr}"
                                )
                            else:
                                logger.info(
                                    f"step: {processed_batches}, "
                                    f"batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
                                    f"current_lr: {current_lr}"
                                )
                            batch_loss_sum = 0.0

                        processed_batches += 1
                        if processed_batches >= stop_batch:
                            break

                # Also stop the epoch loop once the step budget is exhausted.
                if processed_batches >= stop_batch:
                    break

            # Notify the monitor that training has finished
            if monitor:
                monitor.finish()
        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
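    # fit() calls self.model_eval(...), which is not defined in this file. The sketch
    # below is one plausible implementation (an assumption, not the original method): it
    # mirrors fit()'s batch layout ("hint" inputs, "char_id" labels padded with -1) and
    # reports token-level accuracy plus mean teacher-forced loss over the eval dataloader.
    @torch.no_grad()
    def model_eval(self, eval_dataloader, criterion):
        total_loss, total_correct, total_tokens, num_batches = 0.0, 0, 0, 0
        for batch in eval_dataloader:
            input_ids = batch["hint"]["input_ids"].to(self.device)
            attention_mask = batch["hint"]["attention_mask"].to(self.device)
            token_type_ids = batch["hint"]["token_type_ids"].to(self.device)
            labels = batch["char_id"].to(self.device)  # [b, T], padded with -1
            slot_target_mask = (labels != -1).float()

            logits = self(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                slot_target_ids=labels,
                slot_target_mask=slot_target_mask,
                mode="train",
            )  # [b, T, vocab]

            loss = criterion(logits.view(-1, self.output_vocab_size), labels.view(-1))
            total_loss += loss.item()
            num_batches += 1

            # Accuracy over valid (non-padding) slot positions only.
            preds = logits.argmax(dim=-1)  # [b, T]
            valid = labels != -1
            total_correct += (preds[valid] == labels[valid]).sum().item()
            total_tokens += valid.sum().item()

        acc = total_correct / max(total_tokens, 1)
        eval_loss = total_loss / max(num_batches, 1)
        return acc, eval_loss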
    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=self.device
        )
        self.load_state_dict(state_dict)

    def load_from_pretrained_base_model(
        self,
        BaseModel,
        snapshot_path: Union[str, Path],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        *args,
        **kwargs,
    ):
        base_model = BaseModel(*args, **kwargs)
        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))

        self_state_dict = self.state_dict()
        pretrained_dict = base_model.state_dict()
        freeze_layers = []
        target_device = self.device if self.device is not None else device

        # Copy every parameter whose name and shape match, and remember it for freezing.
        for key in self_state_dict.keys():
            if key in pretrained_dict.keys():
                if self_state_dict[key].shape == pretrained_dict[key].shape:
                    self_state_dict[key] = pretrained_dict[key].to(target_device)
                    freeze_layers.append(key)

        self.load_state_dict(self_state_dict)

        # Freeze the layers that were loaded from the base model.
        for name, param in self.named_parameters():
            if name in freeze_layers:
                param.requires_grad = False
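

# Rough usage sketch (kept as comments; all names below are assumptions, not part of this
# module): it assumes a BERT-style encoder whose hidden size matches `hidden_size`, that
# modelscope exposes AutoModel for loading it (transformers' AutoModel would work the same
# way), and that a hypothetical build_dataloaders() yields the {"hint": {...}, "char_id": ...}
# batches fit() expects.
#
#     from modelscope import AutoModel
#
#     encoder = AutoModel.from_pretrained("your-org/your-pretrained-encoder")  # placeholder id
#     model = InputMethodEngine(
#         pretrained_encoder=encoder,
#         output_vocab_size=8000,
#         hidden_size=encoder.config.hidden_size,
#     )
#     train_loader, eval_loader = build_dataloaders(...)  # hypothetical helper
#     model.fit(train_loader, eval_dataloader=eval_loader, stop_batch=10_000)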