feat(model): add fit method to support the model training workflow

parent 1d2ae677f9
commit 917d5976a9
@@ -1,7 +1,19 @@
+import math
+from pathlib import Path
+from typing import Optional, Union
+
 import torch
 import torch.nn as nn
+import torch.amp as amp
+import torch.optim as optim
+
 import torch.nn.functional as F
+
+from loguru import logger
+from modelscope import AutoTokenizer
+from tqdm import tqdm  # needed below: fit() wraps train_dataloader in tqdm
+
 
 from .components import AttentionPooling, Expert  # , ResidualBlock  # assumed implemented
 
 
@@ -30,7 +41,6 @@ class InputMethodEngine(nn.Module):
         # pretrained encoder
         self.encoder = pretrained_encoder
 
-        # remaining components are the same as before
         self.slot_embedding = nn.Embedding(output_vocab_size, hidden_size)
         self.slot_position_embedding = nn.Embedding(max_slot_steps + 1, hidden_size)
 
@@ -234,3 +244,184 @@ class InputMethodEngine(nn.Module):
         raise NotImplementedError(
             "Inference mode should call forward_single_step directly"
         )
+
+    def to(self, device):
+        """Override to() so the module records its device."""
+        self.device = device
+        return super().to(device)
+
+    def fit(
+        self,
+        train_dataloader,
+        eval_dataloader=None,
+        monitor=None,  # Optional[TrainingMonitor]; referenced below, so it must stay a real parameter
+        criterion=None,
+        optimizer=None,
+        num_epochs=1,
+        stop_batch=2e5,
+        eval_frequency=500,
+        grad_accum_steps=1,
+        clip_grad_norm=1.0,
+        loss_weight=None,
+        mixed_precision=True,
+        weight_decay=0.1,
+        warmup_ratio=0.1,
+        label_smoothing=0.15,
+        lr=1e-4,
+        lr_schedule=None,
+        save_dir=None,  # note: checkpoint saving is not wired up in this commit
+        save_frequency=1000,
+    ):
+        """Train the model (input arguments adjusted for this training workflow)."""
+
+        def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps):
+            # linear warmup, then cosine decay to zero
+            if _processed_batches < _warmup_steps:
+                current_lr = _lr * (_processed_batches / _warmup_steps)
+            else:
+                progress = (_processed_batches - _warmup_steps) / (
+                    _stop_batch - _warmup_steps
+                )
+                current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
+            return current_lr
+
+        if self.device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.to(self.device)
+        if lr_schedule is None:
+            lr_schedule = default_lr_schedule
+
+        self.train()
+
+        if optimizer is None:
+            optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
+
+        if criterion is None:
+            if loss_weight is not None:
+                criterion = nn.CrossEntropyLoss(
+                    weight=loss_weight, label_smoothing=label_smoothing
+                )
+            else:
+                criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+
+        scaler = amp.GradScaler(enabled=mixed_precision)
+
+        total_steps = stop_batch
+        warmup_steps = int(total_steps * warmup_ratio)
+        logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
+        processed_batches = 0
+        batch_loss_sum = 0.0
+        optimizer.zero_grad()
+        try:
+            for epoch in range(num_epochs):
+                for batch_idx, batch in enumerate(
+                    tqdm(train_dataloader, total=int(stop_batch))
+                ):
+                    # LR schedule (argument order matches default_lr_schedule)
+                    current_lr = lr_schedule(
+                        lr, processed_batches, stop_batch, warmup_steps
+                    )
+                    for param_group in optimizer.param_groups:
+                        param_group["lr"] = current_lr
+
+                    # move data to the device (note: batches now include token_type_ids)
+                    input_ids = batch["hint"]["input_ids"].to(self.device)
+                    attention_mask = batch["hint"]["attention_mask"].to(self.device)
+                    token_type_ids = batch["hint"]["token_type_ids"].to(
+                        self.device
+                    )  # new
+                    pg = batch["pg"].to(self.device)
+                    labels = batch["char_id"].to(self.device)
+
+                    with torch.amp.autocast(
+                        device_type=self.device.type, enabled=mixed_precision
+                    ):
+                        logits = self(input_ids, attention_mask, token_type_ids, pg)
+                        loss = criterion(logits, labels)
+                        loss = loss / grad_accum_steps
+
+                    scaler.scale(loss).backward()
+
+                    # +1 so the first optimizer step happens after a full
+                    # accumulation window, not on the very first micro-batch
+                    if (processed_batches + 1) % grad_accum_steps == 0:
+                        scaler.unscale_(optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.parameters(), clip_grad_norm
+                        )
+
+                        scaler.step(optimizer)
+                        scaler.update()
+                        optimizer.zero_grad()
+                    batch_loss_sum += loss.item() * grad_accum_steps
+                    if processed_batches % eval_frequency == 0:
+                        if eval_dataloader:
+                            self.eval()
+                            acc, eval_loss = self.model_eval(
+                                eval_dataloader, criterion
+                            )
+                            self.train()
+                            if monitor:
+                                # use eval_loss as the monitored metric
+                                monitor.add_step(
+                                    processed_batches,
+                                    {
+                                        "train_loss": batch_loss_sum
+                                        / (
+                                            eval_frequency
+                                            if processed_batches > 0
+                                            else 1
+                                        ),
+                                        "acc": acc,
+                                        "loss": eval_loss,
+                                        "lr": current_lr,
+                                    },
+                                )
+                            logger.info(
+                                f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}"
+                            )
+                        else:
+                            logger.info(
+                                f"step: {processed_batches}, batch_loss_sum: {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}"
+                            )
+                        batch_loss_sum = 0.0
+                    if processed_batches >= stop_batch:
+                        break
+                    processed_batches += 1
+        except KeyboardInterrupt:
+            logger.info("Training interrupted by user")
+
+        # notify when training finishes
+        if monitor:
+            monitor.finish()
+
+    def load_from_state_dict(self, state_dict_path: Union[str, Path]):
+        state_dict = torch.load(
+            state_dict_path, weights_only=True, map_location=self.device
+        )
+        self.load_state_dict(state_dict)
+
+    def load_from_pretrained_base_model(
+        self,
+        BaseModel,
+        snapshot_path: Union[str, Path],
+        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+        *args,
+        **kwargs,
+    ):
+        base_model = BaseModel(*args, **kwargs)
+        base_model.load_state_dict(torch.load(snapshot_path, map_location=device))
+        self_state_dict = self.state_dict()
+        pretrained_dict = base_model.state_dict()
+
+        freeze_layers = []
+
+        # copy every shape-matching tensor from the pretrained base model,
+        # then freeze those parameters so only the new components train
+        for key in self_state_dict.keys():
+            if key in pretrained_dict.keys():
+                if self_state_dict[key].shape == pretrained_dict[key].shape:
+                    self_state_dict[key] = pretrained_dict[key].to(self.device)
+                    freeze_layers.append(key)
+        self.load_state_dict(self_state_dict)
+        for name, param in self.named_parameters():
+            if name in freeze_layers:
+                param.requires_grad = False
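
Review notes on the new training entry point. A minimal smoke-test sketch of the fit() API follows; build_encoder() and the two dataloaders are hypothetical placeholders, and the constructor kwargs only mirror the fields visible in this diff (the real signature may take more arguments). The batch layout matches what fit() reads above.

```python
import torch

# Hypothetical wiring; only fit()'s keyword names are taken from this commit.
model = InputMethodEngine(
    pretrained_encoder=build_encoder(),  # placeholder encoder factory
    output_vocab_size=6000,              # illustrative value
    hidden_size=768,                     # illustrative value
    max_slot_steps=16,                   # illustrative value
)
model.fit(
    train_dataloader=train_loader,  # batches: {"hint": {"input_ids", "attention_mask", "token_type_ids"}, "pg", "char_id"}
    eval_dataloader=eval_loader,
    num_epochs=1,
    stop_batch=1_000,               # short run for a smoke test
    grad_accum_steps=4,
    mixed_precision=torch.cuda.is_available(),
)
```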
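The default schedule is linear warmup followed by half-cosine decay to zero. A standalone restatement of the same formula with spot-checked values, useful for verifying the argument order the call site in fit() must respect (lr, processed_batches, stop_batch, warmup_steps):

```python
import math

def warmup_cosine(lr, step, stop_batch, warmup_steps):
    # Same shape as default_lr_schedule: linear ramp, then cosine decay.
    if step < warmup_steps:
        return lr * (step / warmup_steps)
    progress = (step - warmup_steps) / (stop_batch - warmup_steps)
    return lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# With the fit() defaults lr=1e-4, stop_batch=2e5, warmup_ratio=0.1 -> warmup_steps=20_000:
assert warmup_cosine(1e-4, 10_000, 2e5, 20_000) == 5e-5              # halfway up the ramp
assert abs(warmup_cosine(1e-4, 110_000, 2e5, 20_000) - 5e-5) < 1e-9  # cosine midpoint
assert warmup_cosine(1e-4, 200_000, 2e5, 20_000) == 0.0              # zero at stop_batch
```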
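The update path in fit() follows the standard torch.amp recipe: the loss is pre-divided by grad_accum_steps, scaled gradients accumulate across micro-batches, and unscale_ runs before clipping so the norm is measured on true gradient magnitudes. A condensed sketch of just that pattern; model, loader, criterion, and use_amp are hypothetical stand-ins:

```python
import torch

scaler = torch.amp.GradScaler(enabled=use_amp)  # use_amp: hypothetical flag
optimizer.zero_grad()
for step, (inputs, targets) in enumerate(loader):
    with torch.amp.autocast(device_type="cuda", enabled=use_amp):
        loss = criterion(model(inputs), targets) / grad_accum_steps
    scaler.scale(loss).backward()        # gradients accumulate, still scaled
    if (step + 1) % grad_accum_steps == 0:
        scaler.unscale_(optimizer)       # back to true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)           # skipped internally if grads overflowed
        scaler.update()
        optimizer.zero_grad()
```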
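load_from_pretrained_base_model copies every tensor whose name and shape match the base checkpoint and then freezes those parameters, so only the newly added components receive updates. Note that fit() still hands all of self.parameters() to AdamW; frozen parameters simply never receive a .grad. An illustrative snippet (assuming a model instance already loaded this way) to inspect the split, or to build the optimizer over trainable parameters only:

```python
import torch

frozen = [n for n, p in model.named_parameters() if not p.requires_grad]
trainable = [p for p in model.parameters() if p.requires_grad]
print(f"frozen tensors: {len(frozen)}, trainable tensors: {len(trainable)}")

# Optional: restrict the optimizer to the trainable subset.
optimizer = torch.optim.AdamW(trainable, lr=1e-4, weight_decay=0.1)
```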