diff --git a/src/model/model.py b/src/model/model.py
index 3d55db7..fad3b18 100644
--- a/src/model/model.py
+++ b/src/model/model.py
@@ -1,7 +1,18 @@
+import math
+from pathlib import Path
+from typing import Optional, Union
+
 import torch
 import torch.nn as nn
+import torch.amp as amp
+import torch.optim as optim
+
 import torch.nn.functional as F
+from loguru import logger
+from modelscope import AutoTokenizer
+from tqdm import tqdm
+
 from .components import AttentionPooling, Expert  # , ResidualBlock  # assumed implemented
@@ -30,7 +41,6 @@ class InputMethodEngine(nn.Module):
 
         # Pretrained encoder
         self.encoder = pretrained_encoder
-        # Other components are the same as before
 
         self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
         self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
@@ -234,3 +244,184 @@ class InputMethodEngine(nn.Module):
         raise NotImplementedError(
             "Inference mode should call forward_single_step directly"
         )
+
+    def to(self, device):
+        """Override to() so the target device is recorded on the module."""
+        self.device = device
+        return super().to(device)
+
+    def fit(
+        self,
+        train_dataloader,
+        eval_dataloader=None,
+        monitor=None,  # optional TrainingMonitor-style object (add_step()/finish())
+        criterion=None,
+        optimizer=None,
+        num_epochs=1,
+        stop_batch=2e5,
+        eval_frequency=500,
+        grad_accum_steps=1,
+        clip_grad_norm=1.0,
+        loss_weight=None,
+        mixed_precision=True,
+        weight_decay=0.1,
+        warmup_ratio=0.1,
+        label_smoothing=0.15,
+        lr=1e-4,
+        lr_schedule=None,
+        save_dir=None,  # reserved; checkpoint saving is not implemented in this loop
+        save_frequency=1000,
+    ):
+        """Train the model (training loop; input arguments adjusted)."""
+
+        def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps):
+            # Linear warmup followed by cosine decay to zero.
+            if _processed_batches < _warmup_steps:
+                current_lr = _lr * (_processed_batches / _warmup_steps)
+            else:
+                progress = (_processed_batches - _warmup_steps) / (
+                    _stop_batch - _warmup_steps
+                )
+                current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
+            return current_lr
+
+        if getattr(self, "device", None) is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            self.to(self.device)
+        if lr_schedule is None:
+            lr_schedule = default_lr_schedule
+
+        self.train()
+
+        if optimizer is None:
+            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
+
+        if criterion is None:
+            if loss_weight is not None:
+                criterion = nn.CrossEntropyLoss(
+                    weight=loss_weight, label_smoothing=label_smoothing
+                )
+            else:
+                criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+
+        scaler = amp.GradScaler(enabled=mixed_precision)
+
+        total_steps = int(stop_batch)
+        warmup_steps = int(total_steps * warmup_ratio)
+        logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
+        processed_batches = 0
+        batch_loss_sum = 0.0
+        optimizer.zero_grad()
+        try:
+            for epoch in range(num_epochs):
+                for batch_idx, batch in enumerate(
+                    tqdm(train_dataloader, total=int(stop_batch))
+                ):
+                    # LR schedule (argument order matches default_lr_schedule)
+                    current_lr = lr_schedule(
+                        lr, processed_batches, stop_batch, warmup_steps
+                    )
+                    for param_group in optimizer.param_groups:
+                        param_group["lr"] = current_lr
+
+                    # Move data to the device (note: batches now include token_type_ids)
+                    input_ids = batch["hint"]["input_ids"].to(self.device)
+                    attention_mask = batch["hint"]["attention_mask"].to(self.device)
+                    token_type_ids = batch["hint"]["token_type_ids"].to(
+                        self.device
+                    )  # new
+                    pg = batch["pg"].to(self.device)
+                    labels = batch["char_id"].to(self.device)
+
+                    with amp.autocast(
+                        device_type=self.device.type, enabled=mixed_precision
+                    ):
+                        logits = self(input_ids, attention_mask, token_type_ids, pg)
+                        loss = criterion(logits, labels)
+                        loss = loss / grad_accum_steps
+
+                    scaler.scale(loss).backward()
+
+                    # Step once every grad_accum_steps micro-batches
+                    # (the +1 avoids stepping right after the first micro-batch).
+                    if (processed_batches + 1) % grad_accum_steps == 0:
+                        scaler.unscale_(optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.parameters(), clip_grad_norm
+                        )
+
+                        scaler.step(optimizer)
+                        scaler.update()
+                        optimizer.zero_grad()
+                    batch_loss_sum += loss.item() * grad_accum_steps
+                    if processed_batches % eval_frequency == 0:
+                        avg_train_loss = batch_loss_sum / (
+                            eval_frequency if processed_batches > 0 else 1
+                        )
+                        if eval_dataloader:
+                            self.eval()
+                            acc, eval_loss = self.model_eval(
+                                eval_dataloader, criterion
+                            )
+                            self.train()
+                            if monitor:
+                                # Use eval_loss as the monitored metric
+                                monitor.add_step(
+                                    processed_batches,
+                                    {
+                                        "train_loss": avg_train_loss,
+                                        "acc": acc,
+                                        "loss": eval_loss,
+                                        "lr": current_lr,
+                                    },
+                                )
+                            logger.info(
+                                f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, train_loss: {avg_train_loss:.4f}, lr: {current_lr}"
+                            )
+                        else:
+                            logger.info(
+                                f"step: {processed_batches}, train_loss: {avg_train_loss:.4f}, lr: {current_lr}"
+                            )
+                        batch_loss_sum = 0.0
+                    if processed_batches >= stop_batch:
+                        break
+                    processed_batches += 1
+                if processed_batches >= stop_batch:
+                    break  # also stop the epoch loop
+        except KeyboardInterrupt:
+            logger.info("Training interrupted by user")
+
+        # Send a notification when training finishes
+        if monitor:
+            monitor.finish()
+
+    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
+        """Load weights from a saved state_dict (expects self.device to be set)."""
+        state_dict = torch.load(
+            state_dict_path, weights_only=True, map_location=self.device
+        )
+        self.load_state_dict(state_dict)
+
+    def load_from_pretrained_base_model(
+        self,
+        BaseModel,
+        snapshot_path: Union[str, Path],
+        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+        *args,
+        **kwargs,
+    ):
+        """Copy shape-matching weights from a pretrained base model, then freeze them."""
+        base_model = BaseModel(*args, **kwargs)
+        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
+        self_state_dict = self.state_dict()
+        pretrained_dict = base_model.state_dict()
+
+        freeze_layers = []
+
+        # Copy every tensor whose key and shape match the pretrained model.
+        for key in self_state_dict.keys():
+            if key in pretrained_dict.keys():
+                if self_state_dict[key].shape == pretrained_dict[key].shape:
+                    self_state_dict[key] = pretrained_dict[key].to(self.device)
+                    freeze_layers.append(key)
+        self.load_state_dict(self_state_dict)
+        # Freeze the copied parameters so only the new components are trained.
+        for name, param in self.named_parameters():
+            if name in freeze_layers:
+                param.requires_grad = False
+
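
For reference, the default warmup-plus-cosine schedule in fit() can be sanity-checked standalone. The snippet below mirrors default_lr_schedule and assumes the method's defaults (lr=1e-4, stop_batch=2e5, warmup_ratio=0.1); it is not part of the patch:

    import math

    lr, stop_batch, warmup_steps = 1e-4, 200_000, 20_000  # warmup = 10% of steps

    def schedule(step):
        # Mirrors default_lr_schedule: linear warmup, then cosine decay to zero.
        if step < warmup_steps:
            return lr * (step / warmup_steps)
        progress = (step - warmup_steps) / (stop_batch - warmup_steps)
        return lr * 0.5 * (1.0 + math.cos(math.pi * progress))

    for step in (0, 10_000, 20_000, 110_000, 200_000):
        print(step, f"{schedule(step):.2e}")
    # 0 -> 0.00e+00, 10000 -> 5.00e-05 (mid-warmup), 20000 -> 1.00e-04 (peak),
    # 110000 -> 5.00e-05 (cosine midpoint), 200000 -> 0.00e+00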
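
The optimizer bookkeeping in the training loop follows the usual torch.amp recipe with gradient accumulation: scale -> backward -> unscale_ -> clip -> step -> update. A self-contained toy version of just that pattern, with a stand-in linear model and random data (all names here are illustrative, not from the patch):

    import torch
    import torch.amp as amp
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(4, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()
    scaler = amp.GradScaler(enabled=(device == "cuda"))
    accum = 4  # number of micro-batches per optimizer step

    for i in range(8):
        x = torch.randn(16, 4, device=device)
        y = torch.randint(0, 2, (16,), device=device)
        with amp.autocast(device_type=device, enabled=(device == "cuda")):
            loss = criterion(model(x), y) / accum  # normalize for accumulation
        scaler.scale(loss).backward()
        if (i + 1) % accum == 0:
            scaler.unscale_(optimizer)  # so clipping sees true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()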
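
Finally, load_from_pretrained_base_model transfers only tensors whose names and shapes match, then freezes them. A minimal sketch of that copy-and-freeze strategy on toy modules (Base/Extended are hypothetical; the real method additionally loads the snapshot from disk):

    import torch
    import torch.nn as nn

    class Base(nn.Module):
        def __init__(self):
            super().__init__()
            self.shared = nn.Linear(8, 8)  # same shape in both models -> copied & frozen

    class Extended(nn.Module):
        def __init__(self):
            super().__init__()
            self.shared = nn.Linear(8, 8)
            self.head = nn.Linear(8, 4)    # new layer -> stays trainable

    base, model = Base(), Extended()
    state, pretrained = model.state_dict(), base.state_dict()
    frozen = []
    for key in state:
        if key in pretrained and state[key].shape == pretrained[key].shape:
            state[key] = pretrained[key]
            frozen.append(key)
    model.load_state_dict(state)
    for name, param in model.named_parameters():
        if name in frozen:
            param.requires_grad = False

    print(sorted(frozen))  # ['shared.bias', 'shared.weight']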