feat(model): 添加 fit 方法支持模型训练流程
This commit is contained in:
parent
1d2ae677f9
commit
917d5976a9
|
|
@ -1,7 +1,18 @@
|
|||
import math
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.amp as amp
|
||||
import torch.optim as optim
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from loguru import logger
|
||||
from modelscope import AutoTokenizer
|
||||
|
||||
|
||||
from .components import AttentionPooling, Expert # , ResidualBlock # 假设已实现
|
||||
|
||||
|
||||
|
|
@ -30,7 +41,6 @@ class InputMethodEngine(nn.Module):
|
|||
# 预训练编码器
|
||||
self.encoder = pretrained_encoder
|
||||
|
||||
# 其他组件与之前相同
|
||||
self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
|
||||
self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
|
||||
|
||||
|
|
@ -234,3 +244,184 @@ class InputMethodEngine(nn.Module):
|
|||
raise NotImplementedError(
|
||||
"Inference mode should call forward_single_step directly"
|
||||
)
|
||||
|
||||
def to(self, device):
|
||||
"""重写 to 方法,记录设备"""
|
||||
self.device = device
|
||||
return super().to(device)
|
||||
|
||||
def fit(
    self,
    train_dataloader,
    eval_dataloader=None,
    criterion=None,
    optimizer=None,
    num_epochs=1,
    stop_batch=2e5,
    eval_frequency=500,
    grad_accum_steps=1,
    clip_grad_norm=1.0,
    loss_weight=None,
    mixed_precision=True,
    weight_decay=0.1,
    warmup_ratio=0.1,
    label_smoothing=0.15,
    lr=1e-4,
    lr_schedule=None,
    save_dir=None,
    save_frequency=1000,
    monitor=None,
):
    """
    Train the model.

    Each batch is expected to be a mapping with ``batch["hint"]``
    (holding ``input_ids``, ``attention_mask`` and ``token_type_ids``),
    ``batch["pg"]`` and ``batch["char_id"]`` (the labels).

    Args:
        train_dataloader: Iterable of training batches.
        eval_dataloader: Optional evaluation data. When given,
            ``self.model_eval(eval_dataloader, criterion)`` is called every
            ``eval_frequency`` batches and must return ``(acc, eval_loss)``.
        criterion: Loss callable ``criterion(logits, labels)``. Defaults to
            ``nn.CrossEntropyLoss`` built from ``loss_weight`` and
            ``label_smoothing``.
        optimizer: Optimizer. Defaults to ``AdamW`` over ``self.parameters()``
            with ``lr`` and ``weight_decay``.
        num_epochs: Maximum number of passes over ``train_dataloader``.
        stop_batch: Batch index at which training stops (across epochs).
        eval_frequency: Evaluate/log every this many batches.
        grad_accum_steps: Number of batches accumulated per optimizer step.
        clip_grad_norm: Max gradient norm; ``None`` disables clipping.
        loss_weight: Optional class weights for the default criterion.
        mixed_precision: Enables autocast and gradient scaling.
        weight_decay: Weight decay of the default optimizer.
        warmup_ratio: Fraction of ``stop_batch`` used for LR warmup.
        label_smoothing: Label smoothing of the default criterion.
        lr: Peak learning rate.
        lr_schedule: Callable invoked as
            ``lr_schedule(lr, stop_batch, processed_batches, warmup_steps)``
            returning the learning rate for the current batch. Defaults to
            linear warmup followed by cosine decay.
        save_dir: Accepted but not used by this method at present.
        save_frequency: Accepted but not used by this method at present.
        monitor: Optional object exposing ``add_step(step, metrics)`` and
            ``finish()``.

    Raises:
        ValueError: If ``grad_accum_steps`` or ``eval_frequency`` is < 1.
    """

    def default_lr_schedule(_lr, _stop_batch, _processed_batches, _warmup_steps):
        # Parameter order mirrors the call site below; linear warmup first.
        if _processed_batches < _warmup_steps:
            return _lr * (_processed_batches / _warmup_steps)
        decay_steps = _stop_batch - _warmup_steps
        if decay_steps <= 0:
            # Nothing left to decay over (warmup covers the whole run).
            return _lr
        # Clamp so the cosine never climbs back up past the planned end.
        progress = min(1.0, (_processed_batches - _warmup_steps) / decay_steps)
        return _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))

    if grad_accum_steps < 1:
        raise ValueError("grad_accum_steps must be >= 1")
    if eval_frequency < 1:
        raise ValueError("eval_frequency must be >= 1")

    # The progress bar is optional: fall back to plain iteration when tqdm
    # is not installed instead of failing on an undefined name.
    try:
        from tqdm import tqdm as progress_bar
    except ImportError:
        progress_bar = None

    if getattr(self, "device", None) is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
    elif not isinstance(self.device, torch.device):
        # ``self.device.type`` is read below, so a plain string is not enough.
        self.device = torch.device(self.device)

    if lr_schedule is None:
        lr_schedule = default_lr_schedule

    self.train()

    if optimizer is None:
        optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

    if criterion is None:
        if isinstance(loss_weight, torch.Tensor):
            # Class weights must live on the same device as the logits.
            loss_weight = loss_weight.to(self.device)
        criterion = nn.CrossEntropyLoss(
            weight=loss_weight, label_smoothing=label_smoothing
        )

    scaler = amp.GradScaler(enabled=mixed_precision)

    total_steps = stop_batch
    warmup_steps = int(total_steps * warmup_ratio)
    logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")

    processed_batches = 0  # zero-based index of the batch being processed
    batch_loss_sum = 0.0
    batches_since_log = 0
    reached_stop = False
    optimizer.zero_grad()
    try:
        for _epoch in range(num_epochs):
            batches = train_dataloader
            if progress_bar is not None:
                batches = progress_bar(train_dataloader, total=int(stop_batch))

            for batch in batches:
                # LR schedule
                current_lr = lr_schedule(
                    lr, stop_batch, processed_batches, warmup_steps
                )
                for param_group in optimizer.param_groups:
                    param_group["lr"] = current_lr

                # Move the batch to the training device.
                hint = batch["hint"]
                input_ids = hint["input_ids"].to(self.device)
                attention_mask = hint["attention_mask"].to(self.device)
                token_type_ids = hint["token_type_ids"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)

                with torch.amp.autocast(
                    device_type=self.device.type, enabled=mixed_precision
                ):
                    logits = self(input_ids, attention_mask, token_type_ids, pg)
                    loss = criterion(logits, labels)
                    # Scale so accumulated gradients average over micro-batches.
                    loss = loss / grad_accum_steps

                scaler.scale(loss).backward()

                # Track the unscaled per-batch loss for logging.
                batch_loss_sum += loss.item() * grad_accum_steps
                batches_since_log += 1

                # Step only after a full group of micro-batches has accumulated.
                if (processed_batches + 1) % grad_accum_steps == 0:
                    scaler.unscale_(optimizer)
                    if clip_grad_norm is not None:
                        torch.nn.utils.clip_grad_norm_(
                            self.parameters(), clip_grad_norm
                        )
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                if processed_batches % eval_frequency == 0:
                    mean_train_loss = batch_loss_sum / max(batches_since_log, 1)
                    # ``is not None`` avoids calling ``__len__`` on the loader.
                    if eval_dataloader is not None:
                        self.eval()
                        acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                        self.train()
                        if monitor is not None:
                            # eval_loss is the monitored metric
                            monitor.add_step(
                                processed_batches,
                                {
                                    "train_loss": mean_train_loss,
                                    "acc": acc,
                                    "loss": eval_loss,
                                    "lr": current_lr,
                                },
                            )
                        logger.info(
                            f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {mean_train_loss:.4f}, current_lr: {current_lr}"
                        )
                    else:
                        logger.info(
                            f"step: {processed_batches}, 'batch_loss_sum': {mean_train_loss:.4f}, current_lr: {current_lr}"
                        )
                    batch_loss_sum = 0.0
                    batches_since_log = 0

                if processed_batches >= stop_batch:
                    reached_stop = True
                    break
                processed_batches += 1

            # Leave the epoch loop too; otherwise each remaining epoch would
            # still train on one extra batch before hitting the inner break.
            if reached_stop:
                break
    except KeyboardInterrupt:
        logger.info("Training interrupted by user")

    # Notify the monitor that training has ended.
    if monitor is not None:
        monitor.finish()
|
||||
|
||||
def load_from_state_dict(self, state_dict_path: Union[str, Path]):
    """Load weights saved at ``state_dict_path`` into this module.

    The file is read with ``torch.load`` in ``weights_only`` mode, mapped
    onto ``self.device``, and then applied through ``load_state_dict``.

    Args:
        state_dict_path: Location of the serialized state dict.
    """
    weights = torch.load(
        state_dict_path,
        map_location=self.device,
        weights_only=True,
    )
    self.load_state_dict(weights)
|
||||
|
||||
def load_from_pretrained_base_model(
    self,
    BaseModel,
    snapshot_path: Union[str, Path],
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    *args,
    **kwargs,
):
    """Copy matching weights from a pretrained base model and freeze them.

    A ``BaseModel`` instance is built with ``*args`` / ``**kwargs`` and
    restored from ``snapshot_path``. Every entry of this module's state dict
    whose key also exists in the base model with an identical shape is
    overwritten by the base model's tensor. Parameters that were overwritten
    this way get ``requires_grad = False``.

    Args:
        BaseModel: Callable returning the base ``nn.Module``.
        snapshot_path: File holding the base model's state dict.
        device: ``map_location`` used when reading the snapshot.
        *args, **kwargs: Forwarded to ``BaseModel``.
    """
    base_model = BaseModel(*args, **kwargs)
    # weights_only=True matches load_from_state_dict and avoids unpickling
    # arbitrary objects from the snapshot file.
    base_model.load_state_dict(
        torch.load(snapshot_path, map_location=device, weights_only=True)
    )
    own_state = self.state_dict()
    pretrained_state = base_model.state_dict()

    transferred = set()
    for key, own_tensor in own_state.items():
        donor = pretrained_state.get(key)
        if donor is not None and donor.shape == own_tensor.shape:
            # Place the tensor where the destination already lives rather
            # than depending on ``self.device`` having been set.
            own_state[key] = donor.to(own_tensor.device)
            transferred.add(key)
    self.load_state_dict(own_state)

    # Freeze only what was actually taken from the base model.
    for name, param in self.named_parameters():
        if name in transferred:
            param.requires_grad = False
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue