feat(model): add fit method to support the model training workflow

songsenand 2026-03-24 00:38:02 +08:00
parent 1d2ae677f9
commit 917d5976a9
1 changed file with 192 additions and 1 deletion


@@ -1,7 +1,18 @@
import math
from pathlib import Path
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.amp as amp
import torch.optim as optim
import torch.nn.functional as F
from loguru import logger
from tqdm import tqdm
from modelscope import AutoTokenizer
from .components import AttentionPooling, Expert  # , ResidualBlock  # assumed to be implemented
@@ -30,7 +41,6 @@ class InputMethodEngine(nn.Module):
        # Pretrained encoder
        self.encoder = pretrained_encoder
        # The remaining components are the same as before
        self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
        self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
@@ -234,3 +244,184 @@ class InputMethodEngine(nn.Module):
        raise NotImplementedError(
            "Inference mode should call forward_single_step directly"
        )

    def to(self, device):
        """Override to() so the target device is recorded on the module."""
        self.device = torch.device(device)  # normalize strings to torch.device
        return super().to(device)
    def fit(
        self,
        train_dataloader,
        eval_dataloader=None,
        monitor=None,  # optional TrainingMonitor-like object (add_step/finish)
        criterion=None,
        optimizer=None,
        num_epochs=1,
        stop_batch=200_000,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        loss_weight=None,
        mixed_precision=True,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.15,
        lr=1e-4,
        lr_schedule=None,
        save_dir=None,
        save_frequency=1000,
    ):
"""
训练模型
"""
def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps):
if _processed_batches < _warmup_steps:
current_lr = _lr * (_processed_batches / _warmup_steps)
else:
progress = (_processed_batches - _warmup_steps) / (
_stop_batch - _warmup_steps
)
current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
return current_lr
"""训练函数,调整了输入参数"""
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
        if lr_schedule is None:
            lr_schedule = default_lr_schedule
        self.train()
        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
        if criterion is None:
            if loss_weight is not None:
                # move the class weights to the same device as the logits
                criterion = nn.CrossEntropyLoss(
                    weight=loss_weight.to(self.device),
                    label_smoothing=label_smoothing,
                )
            else:
                criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        scaler = amp.GradScaler(self.device.type, enabled=mixed_precision)
        total_steps = stop_batch
        warmup_steps = int(total_steps * warmup_ratio)
        logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
        processed_batches = 0
        batch_loss_sum = 0.0
        optimizer.zero_grad()
        try:
            for epoch in range(num_epochs):
                for batch_idx, batch in enumerate(
                    tqdm(train_dataloader, total=int(stop_batch))
                ):
                    # LR schedule (argument order matches default_lr_schedule:
                    # lr, processed_batches, stop_batch, warmup_steps)
                    current_lr = lr_schedule(
                        lr, processed_batches, stop_batch, warmup_steps
                    )
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = current_lr
                    # Move the batch to the device; note that it now also
                    # contains token_type_ids
                    input_ids = batch["hint"]["input_ids"].to(self.device)
                    attention_mask = batch["hint"]["attention_mask"].to(self.device)
                    token_type_ids = batch["hint"]["token_type_ids"].to(
                        self.device
                    )  # newly added
                    pg = batch["pg"].to(self.device)
                    labels = batch["char_id"].to(self.device)
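                    # Forward under autocast; the loss is pre-divided by
                    # grad_accum_steps so the accumulated gradient averages
                    # (rather than sums) over micro-batches, and
                    # scaler.scale() rescales it to avoid fp16 underflow.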
                    with torch.amp.autocast(
                        device_type=self.device.type, enabled=mixed_precision
                    ):
                        logits = self(input_ids, attention_mask, token_type_ids, pg)
                        loss = criterion(logits, labels)
                        loss = loss / grad_accum_steps
                    scaler.scale(loss).backward()
                    if (processed_batches + 1) % grad_accum_steps == 0:
                        # Unscale before clipping so the norm is computed on
                        # the true (unscaled) gradients
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            self.parameters(), clip_grad_norm
                        )
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                    batch_loss_sum += loss.item() * grad_accum_steps
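                    # (multiplying by grad_accum_steps undoes the earlier
                    # division, so the running sum tracks the raw batch loss)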
                    if processed_batches % eval_frequency == 0:
                        avg_train_loss = batch_loss_sum / (
                            eval_frequency if processed_batches > 0 else 1
                        )
                        if eval_dataloader:
                            self.eval()
                            acc, eval_loss = self.model_eval(
                                eval_dataloader, criterion
                            )
                            self.train()
                            if monitor:
                                # Use eval_loss as the monitored metric
                                monitor.add_step(
                                    processed_batches,
                                    {
                                        "train_loss": avg_train_loss,
                                        "acc": acc,
                                        "loss": eval_loss,
                                        "lr": current_lr,
                                    },
                                )
                            logger.info(
                                f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, "
                                f"acc: {acc:.4f}, avg_train_loss: {avg_train_loss:.4f}, "
                                f"current_lr: {current_lr:.3e}"
                            )
                        else:
                            logger.info(
                                f"step: {processed_batches}, avg_train_loss: "
                                f"{avg_train_loss:.4f}, current_lr: {current_lr:.3e}"
                            )
                        batch_loss_sum = 0.0
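                    # NOTE: save_dir / save_frequency were accepted but never
                    # used; a minimal checkpointing sketch, assuming periodic
                    # state_dict snapshots are what they are intended for:
                    if (
                        save_dir
                        and processed_batches > 0
                        and processed_batches % save_frequency == 0
                    ):
                        Path(save_dir).mkdir(parents=True, exist_ok=True)
                        torch.save(
                            self.state_dict(),
                            Path(save_dir) / f"step_{processed_batches}.pt",
                        )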
                    if processed_batches >= stop_batch:
                        break
                    processed_batches += 1
                if processed_batches >= stop_batch:
                    # leave the epoch loop too, not just the batch loop
                    break
        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
        # Send a notification when training ends
        if monitor:
            monitor.finish()
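    # A minimal usage sketch for fit() (hypothetical names, not part of this
    # commit):
    #
    #   engine = InputMethodEngine(...)  # constructed as elsewhere in the repo
    #   engine.fit(
    #       train_dataloader,
    #       eval_dataloader=eval_dataloader,
    #       num_epochs=3,
    #       grad_accum_steps=4,  # effective batch = 4 x dataloader batch size
    #       lr=1e-4,
    #   )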
    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=self.device
        )
        self.load_state_dict(state_dict)
    def load_from_pretrained_base_model(
        self,
        BaseModel,
        snapshot_path: Union[str, Path],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        *args,
        **kwargs,
    ):
        base_model = BaseModel(*args, **kwargs)
        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
        self_state_dict = self.state_dict()
        pretrained_dict = base_model.state_dict()
        freeze_layers = []
        # Copy every tensor whose name and shape match, then freeze it
        for key in self_state_dict:
            if key in pretrained_dict:
                if self_state_dict[key].shape == pretrained_dict[key].shape:
                    self_state_dict[key] = pretrained_dict[key].to(self.device)
                    freeze_layers.append(key)
        self.load_state_dict(self_state_dict)
        for name, param in self.named_parameters():
            if name in freeze_layers:
                param.requires_grad = False
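    # Usage sketch (hypothetical BaseModel / path, not part of this commit):
    # copies every identically named, identically shaped tensor from the
    # snapshot into this model and freezes it, so only the new layers train:
    #
    #   engine.load_from_pretrained_base_model(
    #       PretrainedEncoder, "snapshots/base_model.pt"
    #   )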