feat(model): add fit method to support the model training workflow

songsenand 2026-03-24 00:38:02 +08:00
parent 1d2ae677f9
commit 917d5976a9
1 changed file with 192 additions and 1 deletion


@@ -1,7 +1,18 @@
import math
from pathlib import Path
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.amp as amp
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm  # used by fit()'s progress bar; assumed missing from the imports
from loguru import logger
from modelscope import AutoTokenizer
from .components import AttentionPooling, Expert  # , ResidualBlock  # assumed implemented
@@ -30,7 +41,6 @@ class InputMethodEngine(nn.Module):
        # Pretrained encoder
        self.encoder = pretrained_encoder
        # Other components are the same as before
        self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
        self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
@@ -234,3 +244,184 @@ class InputMethodEngine(nn.Module):
            raise NotImplementedError(
                "Inference mode should call forward_single_step directly"
            )

    def to(self, device):
        """Override to() so the module records its target device."""
        self.device = device
        return super().to(device)

def fit(
self,
train_dataloader,
eval_dataloader=None,
        monitor=None,  # optional TrainingMonitor-like object (type hint dropped: TrainingMonitor is not imported)
criterion=None,
optimizer=None,
num_epochs=1,
stop_batch=2e5,
eval_frequency=500,
grad_accum_steps=1,
clip_grad_norm=1.0,
loss_weight=None,
mixed_precision=True,
weight_decay=0.1,
warmup_ratio=0.1,
label_smoothing=0.15,
lr=1e-4,
lr_schedule=None,
save_dir=None,
save_frequency=1000,
):
"""
训练模型
"""
def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps):
if _processed_batches < _warmup_steps:
current_lr = _lr * (_processed_batches / _warmup_steps)
else:
progress = (_processed_batches - _warmup_steps) / (
_stop_batch - _warmup_steps
)
current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
return current_lr
"""训练函数,调整了输入参数"""
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(self.device)
if lr_schedule is None:
lr_schedule = default_lr_schedule
self.train()
if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
if criterion is None:
if loss_weight is not None:
criterion = nn.CrossEntropyLoss(
weight=loss_weight, label_smoothing=label_smoothing
)
else:
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
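        # GradScaler multiplies the loss before backward() so fp16 gradients do
        # not underflow; with enabled=False every call becomes a pass-through.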
scaler = amp.GradScaler(enabled=mixed_precision)
        total_steps = int(stop_batch)
warmup_steps = int(total_steps * warmup_ratio)
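        # Example with the defaults: stop_batch=2e5 and warmup_ratio=0.1
        # give warmup_steps = int(2e5 * 0.1) = 20000.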
logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
processed_batches = 0
batch_loss_sum = 0.0
optimizer.zero_grad()
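        # Gradients are cleared once up front; inside the loop they accumulate
        # across grad_accum_steps micro-batches between optimizer steps.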
try:
for epoch in range(num_epochs):
for batch_idx, batch in enumerate(
tqdm(train_dataloader, total=int(stop_batch))
):
# LR Schedule
                    current_lr = lr_schedule(
                        lr, processed_batches, stop_batch, warmup_steps
                    )
for param_group in optimizer.param_groups:
param_group["lr"] = current_lr
                    # Move tensors to the device; note that the batch now includes token_type_ids
input_ids = batch["hint"]["input_ids"].to(self.device)
attention_mask = batch["hint"]["attention_mask"].to(self.device)
token_type_ids = batch["hint"]["token_type_ids"].to(
self.device
) # 新增
pg = batch["pg"].to(self.device)
labels = batch["char_id"].to(self.device)
with torch.amp.autocast(
device_type=self.device.type, enabled=mixed_precision
):
logits = self(input_ids, attention_mask, token_type_ids, pg)
loss = criterion(logits, labels)
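                        # Average over micro-batches so the accumulated gradient
                        # matches that of one full effective batch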
loss = loss / grad_accum_steps
scaler.scale(loss).backward()
                    if (processed_batches + 1) % grad_accum_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
self.parameters(), clip_grad_norm
)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
batch_loss_sum += loss.item() * grad_accum_steps
if processed_batches % eval_frequency == 0:
if eval_dataloader:
self.eval()
acc, eval_loss = self.model_eval(
eval_dataloader, criterion
)
self.train()
if monitor:
# 使用 eval_loss 作为监控指标
monitor.add_step(
processed_batches,
{
"train_loss": batch_loss_sum
/ (
eval_frequency
if processed_batches > 0
else 1
),
"acc": acc,
"loss": eval_loss,
"lr": current_lr,
},
)
                            logger.info(
                                f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, "
                                f"batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
                                f"current_lr: {current_lr}"
                            )
                        else:
                            logger.info(
                                f"step: {processed_batches}, "
                                f"batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, "
                                f"current_lr: {current_lr}"
                            )
batch_loss_sum = 0.0
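                    # NOTE: save_dir / save_frequency are accepted by fit() but not
                    # used yet; periodic checkpointing would go here, e.g. (sketch):
                    # if save_dir and processed_batches % save_frequency == 0:
                    #     torch.save(self.state_dict(), Path(save_dir) / f"step_{processed_batches}.pt")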
                    if processed_batches >= stop_batch:
                        break
                    processed_batches += 1
                else:
                    continue  # inner loop exhausted; proceed to the next epoch
                break  # stop_batch reached in the inner loop: exit the epoch loop too
except KeyboardInterrupt:
logger.info("Training interrupted by user")
        # Send a notification when training finishes
if monitor:
monitor.finish()
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
state_dict = torch.load(
state_dict_path, weights_only=True, map_location=self.device
)
self.load_state_dict(state_dict)
def load_from_pretrained_base_model(
self,
BaseModel,
snapshot_path: Union[str, Path],
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
*args,
**kwargs,
):
base_model = BaseModel(*args, **kwargs)
base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
        self_state_dict = self.state_dict()
        pretrained_dict = base_model.state_dict()
        freeze_layers = []
        for key in self_state_dict.keys():
            if key in pretrained_dict:
                if self_state_dict[key].shape == pretrained_dict[key].shape:
                    # Copy weights whose name and shape match the base model,
                    # and remember them so they can be frozen below
                    self_state_dict[key] = pretrained_dict[key].to(self.device)
                    freeze_layers.append(key)
        self.load_state_dict(self_state_dict)
for name, param in self.named_parameters():
if name in freeze_layers:
param.requires_grad = False
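
For reference, a minimal usage sketch of the new API. Everything below is a hypothetical illustration: the encoder, BaseEncoder class, checkpoint path, and dataloaders are placeholders, not part of this commit.

# Hypothetical usage sketch (all names are assumptions):
# engine = InputMethodEngine(pretrained_encoder=encoder, ...)
# engine.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# engine.load_from_pretrained_base_model(BaseEncoder, "checkpoints/base.pt")
# engine.fit(
#     train_dataloader,
#     eval_dataloader=eval_dataloader,
#     num_epochs=3,
#     grad_accum_steps=4,
#     lr=1e-4,
# )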