# SUimeModelTraner/src/model/trainer.py
# NOTE: GitHub page-scrape artifacts ("1815 lines", "68 KiB", "Raw Blame
# History", ambiguous-Unicode banner) removed — they were not part of the module.

import json
import math
import random
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import typer
from loguru import logger
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from torch import autocast
from torch.amp.grad_scaler import GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .dataset import PinyinInputDataset
# 导入模型和数据
from .model import InputMethodEngine
class Trainer:
    """
    Input-method model trainer.

    Trains an InputMethodEngine model with support for:
    - linear warmup + cosine-annealing learning-rate schedule
    - TensorBoard logging
    - AdamW optimizer (weight_decay=0.1)
    - mixed-precision (AMP) training
    - CrossEntropyLoss with optional class weights and label smoothing
    - Rich terminal output
    """

    # In-memory buffer of per-step metric records mirrored to the JSON status file.
    training_status_data: List[Dict[str, Any]]

    def __init__(
        self,
        model: nn.Module,
        train_dataloader: DataLoader,
        eval_dataloader: DataLoader,
        total_steps: int,
        output_dir: str = "./output",
        num_epochs: int = 10,
        learning_rate: float = 1e-4,
        min_learning_rate: float = 1e-6,
        weight_decay: float = 0.1,
        warmup_ratio: float = 0.1,
        label_smoothing: float = 0.15,
        loss_weight: Optional[torch.Tensor] = None,
        grad_accum_steps: int = 1,
        clip_grad_norm: float = 1.0,
        eval_frequency: int = 500,
        save_frequency: int = 10000,
        mixed_precision: bool = True,
        device: Optional[torch.device] = None,
        status_file: str = "training_status.json",
        use_tensorboard: bool = True,
    ):
        """
        Initialize the trainer.

        Args:
            model: InputMethodEngine model to train.
            train_dataloader: training data loader.
            eval_dataloader: evaluation data loader; it is materialized into a
                list once so every evaluation re-iterates the same batches.
            total_steps: total number of training steps.
            output_dir: directory for checkpoints and logs.
            num_epochs: number of training epochs.
            learning_rate: peak learning rate (after warmup).
            min_learning_rate: floor of the cosine-annealed learning rate.
            weight_decay: AdamW weight decay.
            warmup_ratio: fraction of total_steps used for linear warmup.
            label_smoothing: label-smoothing factor for CrossEntropyLoss.
            loss_weight: optional per-class weights for CrossEntropyLoss.
            grad_accum_steps: gradient-accumulation steps.
            clip_grad_norm: max gradient norm for clipping.
            eval_frequency: evaluate every N steps.
            save_frequency: save a periodic checkpoint every N steps.
            mixed_precision: enable mixed-precision training.
            device: training device; auto-selected when None.
            status_file: JSON status file name (created inside output_dir).
            use_tensorboard: enable TensorBoard logging.
        """
        self.model = model
        self.train_dataloader = train_dataloader
        # Materialize the eval loader so every evaluation sees identical batches
        # (fix: `list([i for i in ...])` was a redundant double wrapper).
        self.eval_dataloader = list(eval_dataloader)
        self.output_dir = Path(output_dir)
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.min_learning_rate = min_learning_rate
        self.weight_decay = weight_decay
        self.warmup_ratio = warmup_ratio
        self.label_smoothing = label_smoothing
        self.loss_weight = loss_weight
        self.grad_accum_steps = grad_accum_steps
        self.clip_grad_norm = clip_grad_norm
        self.eval_frequency = eval_frequency
        self.save_frequency = save_frequency
        self.mixed_precision = mixed_precision
        self.use_tensorboard = use_tensorboard
        # Select the training device.
        logger.info(f"GPU可用 {torch.cuda.is_available()}")
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        # Move the model to the device.
        self.model.to(self.device)
        # Create output directories.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir = self.output_dir / "checkpoints"
        self.checkpoint_dir.mkdir(exist_ok=True)
        # Step bookkeeping.
        self.total_steps = total_steps
        self.warmup_steps = int(self.total_steps * warmup_ratio)
        # Optimizer.
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8,
        )
        # Loss function (optional class weights + label smoothing).
        if loss_weight is not None:
            self.criterion = nn.CrossEntropyLoss(
                weight=loss_weight.to(self.device), label_smoothing=label_smoothing
            )
        else:
            self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        # AMP gradient scaler (a no-op when mixed_precision is False).
        device_type = "cuda" if "cuda" in str(self.device) else "cpu"
        self.scaler = GradScaler(device_type, enabled=mixed_precision)
        # TensorBoard writer.
        if use_tensorboard:
            self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
        else:
            self.writer = None
        # Status file; start with an empty history so a new run overwrites the
        # previous one. (Fix: removed the duplicate use_tensorboard assignment.)
        self.status_file = self.output_dir / status_file
        self.training_status_data = []
        # Rich console.
        self.console = Console()
        # Training state.
        self.current_step = 0
        self.current_epoch = 0
        self.best_eval_loss = float("inf")
        # Learning-rate schedule function.
        self.lr_scheduler = self._create_lr_scheduler()
        logger.info(f"Trainer initialized with device: {self.device}")
        logger.info(
            f"Total steps: {self.total_steps}, Warmup steps: {self.warmup_steps}"
        )
        logger.info(f"Learning rate: {learning_rate}, Weight decay: {weight_decay}")
def _create_lr_scheduler(self) -> Callable[[int], float]:
"""创建学习率调度函数(预热 + 余弦退火)"""
def lr_scheduler(step: int) -> float:
if step < self.warmup_steps:
# 线性预热
return self.learning_rate * (step / self.warmup_steps)
else:
# 余弦退火
progress = (step - self.warmup_steps) / (
self.total_steps - self.warmup_steps
)
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
decayed_lr = (
self.min_learning_rate
+ (self.learning_rate - self.min_learning_rate) * cosine_decay
)
return decayed_lr
return lr_scheduler
def _get_current_lr(self) -> float:
"""获取当前学习率"""
return self.lr_scheduler(self.current_step)
def _update_learning_rate(self):
"""更新优化器中的学习率"""
current_lr = self._get_current_lr()
for param_group in self.optimizer.param_groups:
param_group["lr"] = current_lr
return current_lr
    def train_step(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[float, Dict[str, Any]]:
        """
        Run a single training step (forward, loss, scaled backward).

        The optimizer step itself happens in the outer training loop, which
        handles gradient accumulation; this method only accumulates gradients
        via backward().

        Args:
            batch: dict of input tensors for one batch.

        Returns:
            loss: loss value for this batch (de-normalized by grad_accum_steps).
            metrics: dict with "loss", "lr" and "accuracy".
        """
        self.model.train()
        # Move tensors to the device (non_blocking enables async H2D copies
        # when the DataLoader uses pinned memory).
        input_ids = batch["input_ids"].to(self.device, non_blocking=True)
        token_type_ids = batch["token_type_ids"].to(self.device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
        history_slot_ids = batch["history_slot_ids"].to(self.device, non_blocking=True)
        pinyin_ids = batch["pinyin_ids"].to(self.device, non_blocking=True)
        labels = (
            batch["labels"].to(self.device, non_blocking=True).squeeze(-1)
        )  # [batch_size]
        # Mixed-precision forward pass.
        with autocast(device_type=self.device.type, enabled=self.mixed_precision):
            # Forward.
            logits = self.model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )
            # Loss, normalized by the accumulation factor so accumulated
            # gradients match one large batch.
            loss = self.criterion(logits, labels)
            loss = loss / self.grad_accum_steps
        # Backward with loss scaling.
        self.scaler.scale(loss).backward()
        metrics = {
            "loss": loss.item() * self.grad_accum_steps,
            "lr": self._get_current_lr(),
        }
        # Batch accuracy (no gradients needed).
        with torch.no_grad():
            preds = torch.argmax(logits, dim=-1)
            correct = (preds == labels).sum().item()
            total = labels.size(0)
            metrics["accuracy"] = correct / total if total > 0 else 0.0
        return loss.item() * self.grad_accum_steps, metrics
def evaluate(self) -> Dict[str, float]:
"""
在评估集上评估模型
Returns:
评估指标字典
"""
if self.eval_dataloader is None:
return {}
self.model.eval()
total_loss = 0.0
total_correct = 0
total_samples = 0
with torch.no_grad():
for batch in self.eval_dataloader:
# 移动数据到设备
input_ids = batch["input_ids"].to(self.device)
token_type_ids = batch["token_type_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
history_slot_ids = batch["history_slot_ids"].to(self.device)
pinyin_ids = batch["pinyin_ids"].to(self.device)
labels = batch["labels"].to(self.device).squeeze(-1)
# 前向传播
logits = self.model(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
pinyin_ids=pinyin_ids,
history_slot_ids=history_slot_ids,
)
# 计算损失
loss = self.criterion(logits, labels)
total_loss += loss.item() * labels.size(0)
# 计算准确率
preds = torch.argmax(logits, dim=-1)
correct = (preds == labels).sum().item()
total_correct += correct
total_samples += labels.size(0)
avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
accuracy = total_correct / total_samples if total_samples > 0 else 0.0
return {"eval_loss": avg_loss, "eval_accuracy": accuracy}
    def save_checkpoint(
        self, filename: str, is_best: bool = False, is_periodic: bool = False
    ):
        """
        Save a training checkpoint.

        Args:
            filename: checkpoint file name (relative to the checkpoint dir).
            is_best: also copy this checkpoint to best_model.pt.
            is_periodic: periodic save; always written to latest_checkpoint.pt
                so each periodic save overwrites the previous one.
        """
        # Periodic saves reuse one fixed filename to avoid accumulating files.
        if is_periodic:
            checkpoint_path = self.checkpoint_dir / "latest_checkpoint.pt"
        else:
            checkpoint_path = self.checkpoint_dir / filename
        # Bundle model/optimizer/scaler state plus enough config to resume.
        checkpoint = {
            "step": self.current_step,
            "epoch": self.current_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "best_eval_loss": self.best_eval_loss,
            "config": {
                "learning_rate": self.learning_rate,
                "weight_decay": self.weight_decay,
                "warmup_ratio": self.warmup_ratio,
                "label_smoothing": self.label_smoothing,
                "total_steps": self.total_steps,
            },
        }
        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Checkpoint saved to {checkpoint_path}")
        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            logger.info(f"Best model saved to {best_path}")
    def load_checkpoint(
        self, checkpoint_path: Union[str, Path], reset_training_state: bool = False
    ):
        """
        Load a checkpoint produced by save_checkpoint().

        Args:
            checkpoint_path: path of the checkpoint file.
            reset_training_state: when True, restore only the model weights
                and restart step/epoch counters from zero.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        if not reset_training_state:
            # Full resume: optimizer, scaler, and progress counters too.
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
            self.current_step = checkpoint["step"]
            self.current_epoch = checkpoint["epoch"]
            self.best_eval_loss = checkpoint["best_eval_loss"]
            logger.info(f"Checkpoint loaded from {checkpoint_path}")
            logger.info(
                f"Resuming from step {self.current_step}, epoch {self.current_epoch}"
            )
        else:
            # Weights only: reset all progress tracking.
            self.current_step = 0
            self.current_epoch = 0
            self.best_eval_loss = float("inf")
            logger.info(
                f"Checkpoint loaded from {checkpoint_path} (training state reset)"
            )
            logger.info("Training state reset: starting from step 0, epoch 0")
def _log_to_tensorboard(self, metrics: Dict[str, float], step: int):
"""将指标记录到TensorBoard和JSON状态文件"""
if self.writer is not None:
for key, value in metrics.items():
self.writer.add_scalar(key, value, step)
# 同时记录到JSON状态文件
self._write_training_status(metrics, step)
def _load_existing_status_data(self) -> List[Dict]:
"""从文件加载已有的训练状态数据"""
try:
if self.status_file.exists():
with open(self.status_file, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, list):
logger.info(
f"Loaded {len(data)} existing training status records from {self.status_file}"
)
return data
else:
logger.warning(
f"Status file {self.status_file} does not contain a list, starting fresh"
)
return []
else:
logger.info(
f"Status file {self.status_file} does not exist, starting fresh"
)
return []
except json.JSONDecodeError:
logger.warning(
f"Status file {self.status_file} has invalid JSON format, starting fresh"
)
return []
except Exception as e:
logger.error(
f"Failed to load existing status data from {self.status_file}: {e}"
)
return []
    def _write_training_status(self, metrics: Dict[str, float], step: int):
        """Append (or replace) one status record and persist the list as JSON.

        Failures are logged and swallowed so status reporting can never crash
        the training loop.
        """
        try:
            # Build the status record for this step.
            status_record = {
                "step": step,
                "epoch": self.current_epoch + 1,
                "timestamp": datetime.now().isoformat(),
            }
            # Copy in every metric as a plain float (JSON-serializable).
            for key, value in metrics.items():
                status_record[key] = float(value)
            # Replace any existing record(s) for the same step to avoid duplicates.
            existing_indices = [
                i
                for i, record in enumerate(self.training_status_data)
                if record.get("step") == step
            ]
            if existing_indices:
                # Overwrite the existing record(s).
                for idx in existing_indices:
                    self.training_status_data[idx] = status_record  # type: ignore
            else:
                # Append to the in-memory buffer.
                self.training_status_data.append(status_record)  # type: ignore
            # Bound the in-memory history to the 1000 most recent records.
            if len(self.training_status_data) > 1000:
                self.training_status_data = self.training_status_data[-1000:]
            # Defensive: coerce the buffer back to a list if it was corrupted.
            if not isinstance(self.training_status_data, list):
                logger.warning(
                    f"training_status_data is not a list (type: {type(self.training_status_data).__name__}), converting to list"
                )
                self.training_status_data = (
                    [self.training_status_data] if self.training_status_data else []
                )
            # Atomic write: dump to a temp file, then rename over the target so
            # readers never observe a partially written JSON file.
            temp_file = Path(f"{self.status_file}.tmp")
            with open(temp_file, "w", encoding="utf-8") as f:
                json.dump(self.training_status_data, f, indent=2, ensure_ascii=False)
            # rename() is atomic on POSIX filesystems.
            temp_file.rename(self.status_file)
        except Exception as e:
            logger.error(f"Failed to write training status: {e}")
    def _create_progress_bar(self) -> Progress:
        """Build the Rich progress bar used by the training loop."""
        return Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=self.console,
            expand=True,
        )
    def _print_training_info(self):
        """Print the training configuration as a Rich table."""
        info_table = Table(
            title="Training Configuration",
            show_header=True,
            header_style="bold magenta",
        )
        info_table.add_column("Parameter", style="cyan")
        info_table.add_column("Value", style="green")
        info_table.add_row("Device", str(self.device))
        info_table.add_row("Total Steps", str(self.total_steps))
        info_table.add_row("Warmup Steps", str(self.warmup_steps))
        info_table.add_row("Learning Rate", f"{self.learning_rate:.2e}")
        info_table.add_row("Min Learning Rate", f"{self.min_learning_rate:.2e}")
        info_table.add_row("Weight Decay", str(self.weight_decay))
        info_table.add_row("Label Smoothing", str(self.label_smoothing))
        info_table.add_row("Gradient Accumulation", str(self.grad_accum_steps))
        info_table.add_row("Mixed Precision", str(self.mixed_precision))
        self.console.print(info_table)
    def train(
        self, resume_from: Optional[str] = None, reset_training_state: bool = False
    ):
        """
        Main training loop.

        Runs up to num_epochs epochs or total_steps steps, whichever comes
        first, with periodic evaluation, logging, and checkpointing.

        Args:
            resume_from: optional checkpoint path to resume from.
            reset_training_state: load only the model weights and restart
                step/epoch counters from zero.
        """
        # Resume from a checkpoint when one is provided.
        if resume_from is not None:
            self.load_checkpoint(resume_from, reset_training_state=reset_training_state)
        # Show the configuration table.
        self._print_training_info()
        # Running totals between logging points.
        global_step = self.current_step
        accumulated_loss = 0.0
        accumulated_accuracy = 0.0
        accumulation_counter = 0
        # Progress bar scoped to the whole run.
        with self._create_progress_bar() as progress:
            epoch_task = progress.add_task(
                f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
                total=self.total_steps,
            )
            # Epoch loop (resumes at current_epoch after a checkpoint load).
            for epoch in range(self.current_epoch, self.num_epochs):
                self.current_epoch = epoch
                progress.update(
                    epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}"
                )
                for batch_idx, batch in enumerate(self.train_dataloader):
                    # Apply the scheduled learning rate for this step.
                    current_lr = self._update_learning_rate()
                    # Forward/backward for one batch.
                    loss, metrics = self.train_step(batch)
                    # Accumulate metrics until the next logging point.
                    accumulated_loss += loss
                    accumulated_accuracy += metrics.get("accuracy", 0.0)
                    accumulation_counter += 1
                    # Gradient accumulation: step the optimizer every
                    # grad_accum_steps batches.
                    if (global_step + 1) % self.grad_accum_steps == 0:
                        # Unscale before clipping so the norm is computed on
                        # real (unscaled) gradients.
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.clip_grad_norm
                        )
                        # Optimizer step + scaler update, then clear grads.
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        self.optimizer.zero_grad()
                    # Refresh the progress bar (one tick per batch).
                    progress.update(
                        epoch_task,
                        advance=1,
                        description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | "
                        f"Step {global_step}/{self.total_steps} | "
                        f"Loss: {loss:.4f} | "
                        f"LR: {current_lr:.2e}",
                    )
                    # Periodic evaluation and logging.
                    if (global_step + 1) % self.eval_frequency == 0:
                        # Average the metrics accumulated since the last log.
                        avg_loss = accumulated_loss / accumulation_counter
                        avg_accuracy = accumulated_accuracy / accumulation_counter
                        # Evaluate on the held-out set.
                        eval_metrics = self.evaluate()
                        # Assemble the metrics to log.
                        log_metrics = {
                            "train/loss": avg_loss,
                            "train/accuracy": avg_accuracy,
                            "train/learning_rate": current_lr,
                        }
                        if eval_metrics:
                            log_metrics.update(
                                {
                                    "eval/loss": eval_metrics["eval_loss"],
                                    "eval/accuracy": eval_metrics["eval_accuracy"],
                                }
                            )
                            # Track the best model by eval loss.
                            if eval_metrics["eval_loss"] < self.best_eval_loss:
                                self.best_eval_loss = eval_metrics["eval_loss"]
                                # Only refresh best_model.pt; no extra
                                # checkpoint files are created here.
                                self.save_checkpoint("best_model.pt", is_best=True)
                        # Write to TensorBoard / JSON status file.
                        self._log_to_tensorboard(log_metrics, global_step)
                        # Console log line.
                        log_text = (
                            f"[Epoch {epoch + 1}/{self.num_epochs}] "
                            f"[Step {global_step}/{self.total_steps}] "
                            f"Train Loss: {avg_loss:.4f} | "
                            f"Train Acc: {avg_accuracy:.4f} | "
                            f"LR: {current_lr:.2e}"
                        )
                        if eval_metrics:
                            log_text += (
                                f" | Eval Loss: {eval_metrics['eval_loss']:.4f} | "
                                f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
                            )
                        progress.console.log(log_text)
                        # Reset the running totals.
                        accumulated_loss = 0.0
                        accumulated_accuracy = 0.0
                        accumulation_counter = 0
                    # Periodic checkpoint (overwrites latest_checkpoint.pt).
                    if (global_step + 1) % self.save_frequency == 0:
                        self.save_checkpoint("latest_checkpoint.pt", is_periodic=True)
                    # Advance the global step counter.
                    global_step += 1
                    self.current_step = global_step
                    # Stop once the step budget is exhausted.
                    if global_step >= self.total_steps:
                        progress.update(epoch_task, completed=self.total_steps)
                        break
                # Reset the bar for the next epoch.
                progress.reset(epoch_task)
                # Save an end-of-epoch checkpoint.
                self.save_checkpoint(f"epoch_{epoch + 1}.pt")
                # Stop the epoch loop too once the step budget is exhausted.
                if global_step >= self.total_steps:
                    break
        # Training finished.
        logger.info("Training completed!")
        # Save the final model.
        self.save_checkpoint("final_model.pt")
        # Flush and close the TensorBoard writer.
        if self.writer is not None:
            self.writer.close()
def load_expanded_model(
    base_model_path: str,
    new_model_spec: str,
    device: torch.device,
    **model_kwargs,
) -> nn.Module:
    """
    Load a pretrained base model into an expanded model and freeze matches.

    Builds a new (larger) model from `new_model_spec`, copies every state-dict
    entry from the base checkpoint whose name AND shape match, and freezes
    those copied parameters so only the newly added capacity trains.

    Args:
        base_model_path: path to the pretrained base-model checkpoint.
        new_model_spec: new-model spec in the form "module:ClassName",
            e.g. "new_model:NewModel"; the module part may also be a file path.
        device: target device.
        **model_kwargs: keyword arguments forwarded to the new model's
            constructor. The class must subclass InputMethodEngine.

    Returns:
        The expanded model with matching layers frozen.
    """
    import importlib
    import sys

    # Parse the "module:ClassName" spec.
    if ":" not in new_model_spec:
        raise ValueError(
            f"Invalid model spec format: {new_model_spec}. Expected format: 'module:ClassName'"
        )
    module_name, class_name = new_model_spec.split(":", 1)
    # Import the module (arbitrary paths supported via fallbacks below).
    module = None
    try:
        # First try a regular dotted import.
        module = importlib.import_module(module_name)
    except ImportError:
        # Fall back to treating the spec as a file path.
        try:
            # Map a dotted module name onto a relative .py path.
            module_path = module_name.replace(".", "/") + ".py"
            import importlib.util

            spec = importlib.util.spec_from_file_location(module_name, module_path)
            if spec is None or spec.loader is None:
                raise ImportError(f"Cannot find module or loader: {module_name}")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)  # type: ignore
        except Exception as e:
            # Last resort: look for "<module_name>.py" in the current directory.
            import os

            if os.path.exists(module_name + ".py"):
                spec = importlib.util.spec_from_file_location(
                    module_name, module_name + ".py"
                )
                if spec is None or spec.loader is None:
                    raise ImportError(f"Cannot load module from file: {module_name}.py")
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)  # type: ignore
            else:
                raise ImportError(f"Failed to import module '{module_name}': {e}")
    if module is None:
        raise ImportError(f"Module '{module_name}' could not be imported")
    # Resolve the model class.
    model_class = getattr(module, class_name)
    # The custom class must subclass InputMethodEngine.
    from .model import InputMethodEngine

    if not issubclass(model_class, InputMethodEngine):
        raise TypeError(
            f"Model class {class_name} must be a subclass of InputMethodEngine. "
            f"Got {model_class.__name__} instead."
        )
    # Instantiate the expanded model.
    new_model = model_class(**model_kwargs)
    new_model.to(device)
    # Load the pretrained weights (either a Trainer checkpoint or a bare
    # state dict).
    checkpoint = torch.load(base_model_path, map_location=device)
    if "model_state_dict" in checkpoint:
        pretrained_state_dict = checkpoint["model_state_dict"]
    else:
        pretrained_state_dict = checkpoint
    # New model's state dict to be partially overwritten.
    new_state_dict = new_model.state_dict()
    # Collect entries whose name and shape both match; copy them over.
    frozen_layers = []
    for key in new_state_dict.keys():
        if key in pretrained_state_dict:
            if new_state_dict[key].shape == pretrained_state_dict[key].shape:
                new_state_dict[key] = pretrained_state_dict[key].to(device)
                frozen_layers.append(key)
    # Load the merged state dict.
    new_model.load_state_dict(new_state_dict)
    # Freeze the parameters copied from the base model. NOTE(review):
    # frozen_layers may also contain buffer names from state_dict(); those
    # simply never match named_parameters() here.
    for name, param in new_model.named_parameters():
        if name in frozen_layers:
            param.requires_grad = False
    logger.info(f"Loaded expanded model with {len(frozen_layers)} frozen layers")
    logger.info(
        f"Frozen layers: {frozen_layers[:10]}{'...' if len(frozen_layers) > 10 else ''}"
    )
    return new_model
def worker_init_fn(worker_id: int) -> None:
    """
    Seed the Python and NumPy RNGs inside each DataLoader worker.

    Derives the per-worker seed from torch.initial_seed() (reduced modulo
    2**32 to fit NumPy's seed range) so data augmentation in workers is
    reproducible.

    Args:
        worker_id: the DataLoader worker id (unused; seeding is driven
            entirely by torch's per-worker initial seed).
    """
    derived_seed = torch.initial_seed() % (2**32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Merge a list of samples into one batch dict.

    Tensor fields are stacked along a new batch dimension — some after
    dropping the redundant leading singleton dim each sample carries —
    while string fields are collected into plain lists.

    Args:
        batch: list of per-sample dicts.

    Returns:
        Batch dict: stacked tensors ("input_ids", "token_type_ids",
        "attention_mask", "labels", "history_slot_ids", "pinyin_ids")
        plus list-of-str fields ("prefix", "suffix", "pinyin").
    """

    def stacked(field: str, drop_leading_dim: bool) -> torch.Tensor:
        # Stack one tensor field across the batch, optionally squeezing the
        # extra leading dimension the dataset emits per sample.
        if drop_leading_dim:
            return torch.stack([sample[field].squeeze(0) for sample in batch])
        return torch.stack([sample[field] for sample in batch])

    return {
        "input_ids": stacked("input_ids", True),
        "token_type_ids": stacked("token_type_ids", True),
        "attention_mask": stacked("attention_mask", True),
        "labels": stacked("label", True),
        "history_slot_ids": stacked("history_slot_ids", False),
        "prefix": [sample["prefix"] for sample in batch],
        "suffix": [sample["suffix"] for sample in batch],
        "pinyin": [sample["pinyin"] for sample in batch],
        "pinyin_ids": stacked("pinyin_ids", False),
    }
def create_dataloader(
    dataset: PinyinInputDataset,
    batch_size: int,
    num_workers: int = 2,
    pin_memory: bool = True,
    shuffle: bool = False,
    max_iter_length: Optional[int] = None,
) -> Any:
    """
    Build a DataLoader tuned for the streaming PinyinInputDataset.

    Args:
        dataset: PinyinInputDataset instance.
        batch_size: batch size.
        num_workers: worker count (2 is recommended for streaming datasets;
            0 disables worker processes).
        pin_memory: pin host memory for faster async H2D transfer.
        shuffle: forwarded to DataLoader (the streaming dataset shuffles
            internally; must stay False for an IterableDataset).
        max_iter_length: kept for interface compatibility; unused here.

    Returns:
        A configured torch DataLoader.
    """
    logger.info(f"📊 使用标准DataLoaderworker数量: {num_workers}")
    kwargs: Dict[str, Any] = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        shuffle=shuffle,
    )
    # Fix: prefetch_factor / persistent_workers are only legal with
    # num_workers > 0 — passing them with num_workers == 0 makes DataLoader
    # raise ValueError.
    if num_workers > 0:
        kwargs["prefetch_factor"] = 2  # modest prefetch to limit memory use
        kwargs["persistent_workers"] = True
    return DataLoader(dataset, **kwargs)
# Typer CLI application entry point (sub-commands are registered below).
app = typer.Typer(help="输入法模型训练命令行工具", add_completion=False)
@app.command()
def train(
    # Data arguments
    train_data_path: str = typer.Option(
        ..., "--train-data-path", "-t", help="训练数据集路径"
    ),
    eval_data_path: str = typer.Option(
        ..., "--eval-data-path", "-e", help="评估数据集路径"
    ),
    output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
    # Model arguments
    vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
    pinyin_vocab_size: int = typer.Option(
        30, "--pinyin-vocab-size", help="拼音词汇表大小"
    ),
    max_iter_length: int = typer.Option(
        1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
    ),
    dim: int = typer.Option(512, "--dim", help="模型维度"),
    num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"),
    n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"),
    n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"),
    num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"),
    max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"),
    use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
    # Training arguments
    batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
    num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
    learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
    min_learning_rate: float = typer.Option(
        1e-9, "--min-learning-rate", help="最小学习率"
    ),
    weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
    warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
    label_smoothing: float = typer.Option(
        0.15, "--label-smoothing", help="标签平滑参数"
    ),
    grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
    clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
    eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
    save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"),
    # Misc arguments
    mixed_precision: bool = typer.Option(
        True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
    ),
    num_workers: int = typer.Option(
        2, "--num-workers", help="数据加载worker数量流式数据集建议为2"
    ),
    use_tensorboard: bool = typer.Option(
        True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
    ),
    resume_from: Optional[str] = typer.Option(
        None, "--resume-from", help="从检查点恢复训练"
    ),
    reset_training_state: bool = typer.Option(
        False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练"
    ),
    seed: int = typer.Option(42, "--seed", help="随机种子"),
    compile: bool = typer.Option(
        False,
        "--compile/--no-compile",
        help="是否开启 torch.compile 优化(需 PyTorch 2.0+",
    ),
):
    """
    训练输入法模型
    """
    # Avoid fd exhaustion with many DataLoader workers.
    torch.multiprocessing.set_sharing_strategy("file_system")
    # Enable TensorFloat32 matmul acceleration (silences the UserWarning and
    # improves throughput on supporting GPUs).
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")
    # Seed all RNGs for reproducibility.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    console = Console()
    # Print the configuration summary.
    console.print(
        Panel.fit("[bold cyan]输入法模型训练配置[/bold cyan]", border_style="cyan")
    )
    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")
    # Populate the configuration rows.
    config_table.add_row("数据", "训练数据路径", train_data_path)
    config_table.add_row("数据", "评估数据路径", eval_data_path)
    config_table.add_row("数据", "输出目录", output_dir)
    config_table.add_row("数据", "批次大小", str(batch_size))
    config_table.add_row("数据", "Worker数量", str(num_workers))
    config_table.add_row("模型", "词汇表大小", str(vocab_size))
    config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size))
    config_table.add_row("模型", "模型维度", str(dim))
    config_table.add_row("模型", "槽位数量", str(num_slots))
    config_table.add_row("模型", "Transformer层数", str(n_layers))
    config_table.add_row("模型", "注意力头数", str(n_heads))
    config_table.add_row("模型", "MoE专家数", str(num_experts))
    config_table.add_row("模型", "使用拼音", str(use_pinyin))
    config_table.add_row("模型", "编译优化", str(compile))
    config_table.add_row("训练", "训练轮数", str(num_epochs))
    config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
    config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(weight_decay))
    config_table.add_row("训练", "热身比例", str(warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(label_smoothing))
    config_table.add_row("训练", "梯度累积", str(grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(mixed_precision))
    console.print(config_table)
    # Create the output directory.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    # Persist the full run configuration for later reference.
    config = {
        "train_data_path": train_data_path,
        "eval_data_path": eval_data_path,
        "output_dir": output_dir,
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "use_pinyin": use_pinyin,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "num_epochs": num_epochs,
        "learning_rate": learning_rate,
        "min_learning_rate": min_learning_rate,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
        "label_smoothing": label_smoothing,
        "grad_accum_steps": grad_accum_steps,
        "clip_grad_norm": clip_grad_norm,
        "eval_frequency": eval_frequency,
        "save_frequency": save_frequency,
        "mixed_precision": mixed_precision,
        "use_tensorboard": use_tensorboard,
        "seed": seed,
        "max_iter_length": max_iter_length,
        "compile": compile,
    }
    config_file = output_path / "training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    logger.info(f"Configuration saved to {config_file}")
    # Build the data loaders.
    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
    # Training dataset.
    train_dataset = PinyinInputDataset(
        data_path=train_data_path,
        max_workers=-1,  # auto-select worker count
        max_iter_length=max_iter_length,
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=5000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    # Training dataloader.
    # NOTE: PinyinInputDataset is an IterableDataset, so DataLoader's shuffle
    # parameter must not be used; with multiple workers each worker consumes a
    # shard of the dataset (handled inside dataset.__iter__).
    train_dataloader = create_dataloader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        max_iter_length=max_iter_length,
    )
    # Evaluation dataset (same settings, smaller iteration budget).
    eval_dataset = PinyinInputDataset(
        data_path=eval_data_path,
        max_workers=-1,
        max_iter_length=batch_size * 64,  # keep the eval set small
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=50000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    eval_dataloader = create_dataloader(
        dataset=eval_dataset,
        batch_size=batch_size,
        num_workers=1,  # evaluation needs fewer workers
        pin_memory=torch.cuda.is_available(),
        max_iter_length=batch_size * 64,
    )
    console.print("[bold cyan]正在创建模型...[/bold cyan]")
    # NOTE(review): use_pinyin is recorded in the config table/file but not
    # passed to the model constructor — confirm whether that is intentional.
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=n_layers,
        n_heads=n_heads,
        num_experts=num_experts,
        max_seq_len=max_seq_len,
        compile=compile,
    )
    console.print(
        f"[green]✓ 模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
    )
    # Build the trainer.
    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=int(max_iter_length * num_epochs / batch_size),
        output_dir=output_dir,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        min_learning_rate=min_learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        label_smoothing=label_smoothing,
        grad_accum_steps=grad_accum_steps,
        clip_grad_norm=clip_grad_norm,
        eval_frequency=eval_frequency,
        save_frequency=save_frequency,
        mixed_precision=mixed_precision,
        use_tensorboard=use_tensorboard,
        status_file="training_status.json",
    )
    console.print("[green]✓ 训练器创建完成[/green]")
    # Start training.
    console.print("\n[bold cyan]开始训练...[/bold cyan]")
    console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    try:
        trainer.train(
            resume_from=resume_from, reset_training_state=reset_training_state
        )
    except KeyboardInterrupt:
        # Save an emergency checkpoint on Ctrl-C, then fall through to the
        # completion messages below.
        console.print("[bold green]训练被终止[/bold green]")
        trainer.save_checkpoint("interrupted_model.pt")
    console.print("[bold green]✓ 训练完成![/bold green]")
    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {output_dir}")
@app.command()
def evaluate(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    data_path: str = typer.Option(..., "--data-path", "-d", help="数据集路径"),
    batch_size: int = typer.Option(32, "--batch-size", "-b", help="批次大小"),
):
    """
    评估训练好的模型
    """
    console = Console()
    console.print(f"[bold cyan]评估模型: {checkpoint_path}[/bold cyan]")
    # Evaluation logic is still a stub; the intended flow is:
    # 1. load the checkpoint
    # 2. build the data loader
    # 3. run evaluation
    # NOTE(review): data_path and batch_size are currently unused (stub).
    console.print("[yellow]评估功能待实现[/yellow]")
@app.command()
def export(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    output_path: str = typer.Option(
        "./exported_model.onnx", "--output", "-o", help="输出路径"
    ),
):
    """
    导出模型为ONNX格式
    """
    console = Console()
    console.print(f"[bold cyan]导出模型到: {output_path}[/bold cyan]")
    # Export logic is still a stub; the intended flow is:
    # 1. load the checkpoint
    # 2. export to ONNX
    # NOTE(review): checkpoint_path is currently unused (stub).
    console.print("[yellow]导出功能待实现[/yellow]")
@app.command()
def expand_and_train(
    # Data arguments
    train_data_path: str = typer.Option(
        ..., "--train-data-path", "-t", help="训练数据集路径"
    ),
    eval_data_path: str = typer.Option(
        ..., "--eval-data-path", "-e", help="评估数据集路径"
    ),
    output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
    # Model arguments
    base_model_path: str = typer.Option(
        ..., "--base-model-path", help="预训练基础模型检查点路径"
    ),
    new_model_spec: str = typer.Option(
        ...,
        "--new-model-spec",
        "-m",
        help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类",
    ),
    vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
    pinyin_vocab_size: int = typer.Option(
        30, "--pinyin-vocab-size", help="拼音词汇表大小"
    ),
    max_iter_length: int = typer.Option(
        1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
    ),
    dim: int = typer.Option(512, "--dim", help="模型维度"),
    num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"),
    n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"),
    n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"),
    num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"),
    max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"),
    use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
    # Two-stage training arguments
    frozen_patience: int = typer.Option(
        10,
        "--frozen-patience",
        help="冻结阶段验证损失连续不下降的epoch数触发切换到全量微调",
    ),
    frozen_lr: float = typer.Option(1e-3, "--frozen-lr", help="冻结阶段学习率"),
    full_lr: float = typer.Option(1e-4, "--full-lr", help="全量微调阶段学习率"),
    frozen_scheduler: str = typer.Option(
        "cosine", "--frozen-scheduler", help="冻结阶段学习率调度器类型cosine或plateau"
    ),
    full_scheduler: str = typer.Option(
        "cosine",
        "--full-scheduler",
        help="全量微调阶段学习率调度器类型cosine或plateau",
    ),
    # Training arguments
    batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
    num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
    min_learning_rate: float = typer.Option(
        1e-9, "--min-learning-rate", help="最小学习率"
    ),
    weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
    warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
    label_smoothing: float = typer.Option(
        0.15, "--label-smoothing", help="标签平滑参数"
    ),
    grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
    clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
    eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
    save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"),
    # Misc arguments
    mixed_precision: bool = typer.Option(
        True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
    ),
    num_workers: int = typer.Option(
        2, "--num-workers", help="数据加载worker数量流式数据集建议为2"
    ),
    use_tensorboard: bool = typer.Option(
        True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
    ),
    resume_from: Optional[str] = typer.Option(
        None, "--resume-from", help="从检查点恢复训练"
    ),
    reset_training_state: bool = typer.Option(
        False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练"
    ),
    seed: int = typer.Option(42, "--seed", help="随机种子"),
    compile: bool = typer.Option(
        False,
        "--compile/--no-compile",
        help="是否开启 torch.compile 优化(需 PyTorch 2.0+",
    ),
):
    """Stage-1 model-expansion training.

    Loads a pretrained base model into an expanded architecture (via
    ``load_expanded_model``) and trains it with the frozen-stage learning
    rate ``frozen_lr``; parameters frozen by the loader stay frozen. On
    completion it writes ``expansion_info.json`` so that the
    ``expand-finetune`` command can run the stage-2 full fine-tune.

    NOTE(review): ``frozen_patience``, ``full_lr``, ``frozen_scheduler`` and
    ``full_scheduler`` are recorded in the saved config but not consumed by
    the plain ``Trainer`` used here — stage switching is manual by design
    (see the final console hint).
    """
    torch.multiprocessing.set_sharing_strategy("file_system")
    # Enable TensorFloat32 matmul (resolves the UserWarning and speeds up GPU matmul).
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")
    # Seed all RNGs for reproducibility.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    console = Console()
    # Print the run configuration.
    console.print(
        Panel.fit(
            "[bold cyan]模型扩容第一阶段训练配置[/bold cyan]", border_style="cyan"
        )
    )
    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")
    # Populate the configuration table.
    config_table.add_row("数据", "训练数据路径", train_data_path)
    config_table.add_row("数据", "评估数据路径", eval_data_path)
    config_table.add_row("数据", "输出目录", output_dir)
    config_table.add_row("数据", "批次大小", str(batch_size))
    config_table.add_row("数据", "Worker数量", str(num_workers))
    config_table.add_row("模型", "基础模型路径", base_model_path)
    config_table.add_row("模型", "新模型规格", new_model_spec)
    config_table.add_row("模型", "词汇表大小", str(vocab_size))
    config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size))
    config_table.add_row("模型", "模型维度", str(dim))
    config_table.add_row("模型", "槽位数量", str(num_slots))
    config_table.add_row("模型", "Transformer层数", str(n_layers))
    config_table.add_row("模型", "注意力头数", str(n_heads))
    config_table.add_row("模型", "MoE专家数", str(num_experts))
    config_table.add_row("模型", "使用拼音", str(use_pinyin))
    config_table.add_row("模型", "编译优化", str(compile))
    config_table.add_row("训练", "训练轮数", str(num_epochs))
    # BUGFIX: this command has no `learning_rate` parameter — the original
    # referenced an undefined name here (NameError). Stage 1 trains with the
    # frozen-stage learning rate, so `frozen_lr` is the intended value.
    config_table.add_row("训练", "学习率", f"{frozen_lr:.2e}")
    config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(weight_decay))
    config_table.add_row("训练", "热身比例", str(warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(label_smoothing))
    config_table.add_row("训练", "梯度累积", str(grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(mixed_precision))
    console.print(config_table)
    # Create the output directory.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    # Persist the full configuration for reproducibility.
    config = {
        "train_data_path": train_data_path,
        "eval_data_path": eval_data_path,
        "output_dir": output_dir,
        "base_model_path": base_model_path,
        "new_model_spec": new_model_spec,
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "use_pinyin": use_pinyin,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "num_epochs": num_epochs,
        # BUGFIX: `learning_rate` was undefined here; stage 1 uses frozen_lr.
        "learning_rate": frozen_lr,
        "min_learning_rate": min_learning_rate,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
        "label_smoothing": label_smoothing,
        "grad_accum_steps": grad_accum_steps,
        "clip_grad_norm": clip_grad_norm,
        "eval_frequency": eval_frequency,
        "save_frequency": save_frequency,
        "mixed_precision": mixed_precision,
        "use_tensorboard": use_tensorboard,
        "seed": seed,
        "max_iter_length": max_iter_length,
        "compile": compile,
        # Two-stage parameters recorded for reference (not consumed by the
        # plain Trainer used in stage 1).
        "frozen_patience": frozen_patience,
        "frozen_lr": frozen_lr,
        "full_lr": full_lr,
        "frozen_scheduler": frozen_scheduler,
        "full_scheduler": full_scheduler,
    }
    config_file = output_path / "expansion_training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    logger.info(f"Configuration saved to {config_file}")
    # Build the data loaders.
    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
    # Training dataset (streaming).
    train_dataset = PinyinInputDataset(
        data_path=train_data_path,
        max_workers=-1,  # auto-select worker count
        max_iter_length=max_iter_length,
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=5000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    train_dataloader = create_dataloader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        max_iter_length=max_iter_length,
    )
    # Evaluation dataset (smaller, fixed-size).
    eval_dataset = PinyinInputDataset(
        data_path=eval_data_path,
        max_workers=-1,
        max_iter_length=batch_size * 64,  # evaluation set is kept small
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=50000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    eval_dataloader = create_dataloader(
        dataset=eval_dataset,
        batch_size=batch_size,
        num_workers=1,  # evaluation uses fewer workers
        pin_memory=torch.cuda.is_available(),
        max_iter_length=batch_size * 64,
    )
    # Build the expanded model (the loader copies base weights and freezes them).
    console.print("[bold cyan]正在创建扩容模型...[/bold cyan]")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): `use_pinyin` is shown in the table but not passed into
    # model_kwargs — confirm the expanded model does not take it.
    model_kwargs = {
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "compile": compile,
    }
    model = load_expanded_model(
        base_model_path=base_model_path,
        new_model_spec=new_model_spec,
        device=device,
        **model_kwargs,
    )
    console.print(
        f"[green]✓ 扩容模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
    )
    # Report the frozen-parameter ratio.
    total_params = sum(p.numel() for p in model.parameters())
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    console.print(
        f"[green]✓ 冻结参数: {frozen_params:,}/{total_params:,} ({frozen_params / total_params * 100:.1f}%)[/green]"
    )
    # Build the trainer (plain Trainer: stage 1 = frozen training only).
    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=int(max_iter_length * num_epochs / batch_size),
        output_dir=output_dir,
        num_epochs=num_epochs,
        # BUGFIX: stage 1 uses the frozen-stage learning rate (was an
        # undefined `learning_rate` name).
        learning_rate=frozen_lr,
        min_learning_rate=min_learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        label_smoothing=label_smoothing,
        grad_accum_steps=grad_accum_steps,
        clip_grad_norm=clip_grad_norm,
        eval_frequency=eval_frequency,
        save_frequency=save_frequency,
        mixed_precision=mixed_precision,
        use_tensorboard=use_tensorboard,
        status_file="training_status.json",
    )
    console.print("[green]✓ 训练器创建完成[/green]")
    # Run training; on Ctrl-C save an interrupt checkpoint and continue to
    # write the expansion info so stage 2 can still be launched.
    console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]")
    console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    try:
        trainer.train(
            resume_from=resume_from, reset_training_state=reset_training_state
        )
    except KeyboardInterrupt:
        console.print("[bold green]训练被终止[/bold green]")
        trainer.save_checkpoint("interrupted_model.pt")
    # Save expansion info consumed by the stage-2 `expand-finetune` command.
    expansion_info = {
        "stage1_checkpoint_path": str(output_path / "checkpoints" / "best_model.pt"),
        "model_spec": new_model_spec,
        "model_kwargs": model_kwargs,
        "train_data_path": train_data_path,
        "eval_data_path": eval_data_path,
        "output_dir": output_dir,
        "batch_size": batch_size,
        "max_iter_length": max_iter_length,
        "max_seq_len": max_seq_len,
        "num_workers": num_workers,
        # FIX: stage 2 reads info.get("num_epochs", 10); record the actual value
        # instead of silently falling back to the default.
        "num_epochs": num_epochs,
    }
    expansion_info_file = output_path / "expansion_info.json"
    with open(expansion_info_file, "w", encoding="utf-8") as f:
        json.dump(expansion_info, f, indent=2, ensure_ascii=False)
    logger.info(f"Expansion info saved to {expansion_info_file}")
    console.print("[bold green]✓ 第一阶段训练完成![/bold green]")
    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {output_dir}")
    console.print(f"[bold cyan]扩容信息已保存到: {expansion_info_file}[/bold cyan]")
    console.print(
        "[yellow]请手动检查模型后,使用 expand-finetune 命令进行第二阶段全量微调[/yellow]"
    )
@app.command()
def expand_finetune(
    expand_config: str = typer.Option(
        ...,
        "--expand-config",
        "-c",
        help="新模型类规格,格式:模块名:类名,如 'big_expert:BigExpert'",
    ),
    stage1_info: str = typer.Option(
        ..., "--stage1-info", "-i", help="第一阶段保存的 expansion_info.json 路径"
    ),
    # Optional overrides: any value given on the CLI wins over the JSON file.
    checkpoint: Optional[str] = typer.Option(
        None, "--checkpoint", help="第一阶段模型检查点路径(覆盖 JSON 文件中的路径)"
    ),
    output_dir: Optional[str] = typer.Option(
        None, "--output-dir", "-o", help="输出目录(覆盖 JSON 文件中的目录)"
    ),
    train_data_path: Optional[str] = typer.Option(
        None, "--train-data-path", "-t", help="训练数据路径(覆盖 JSON 文件)"
    ),
    eval_data_path: Optional[str] = typer.Option(
        None, "--eval-data-path", "-e", help="评估数据路径(覆盖 JSON 文件)"
    ),
    batch_size: Optional[int] = typer.Option(
        None, "--batch-size", "-b", help="批次大小(覆盖 JSON 文件)"
    ),
    num_epochs: Optional[int] = typer.Option(
        None, "--num-epochs", help="训练轮数(覆盖 JSON 文件)"
    ),
    learning_rate: Optional[float] = typer.Option(
        None, "--learning-rate", "-lr", help="学习率"
    ),
    min_learning_rate: Optional[float] = typer.Option(
        None, "--min-learning-rate", help="最小学习率"
    ),
    weight_decay: Optional[float] = typer.Option(
        None, "--weight-decay", help="权重衰减"
    ),
    warmup_ratio: Optional[float] = typer.Option(
        None, "--warmup-ratio", help="热身步数比例"
    ),
    label_smoothing: Optional[float] = typer.Option(
        None, "--label-smoothing", help="标签平滑参数"
    ),
    grad_accum_steps: Optional[int] = typer.Option(
        None, "--grad-accum-steps", help="梯度累积步数"
    ),
    clip_grad_norm: Optional[float] = typer.Option(
        None, "--clip-grad-norm", help="梯度裁剪范数"
    ),
    eval_frequency: Optional[int] = typer.Option(
        None, "--eval-frequency", help="评估频率"
    ),
    save_frequency: Optional[int] = typer.Option(
        None, "--save-frequency", help="保存频率"
    ),
    max_iter_length: Optional[int] = typer.Option(
        None, "--max-iter-length", help="数据集大小(覆盖 JSON 文件)"
    ),
    max_seq_len: Optional[int] = typer.Option(
        None, "--max-seq-len", help="最大序列长度(覆盖 JSON 文件)"
    ),
    num_workers: Optional[int] = typer.Option(
        None, "--num-workers", help="数据加载worker数量"
    ),
    mixed_precision: bool = typer.Option(
        True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
    ),
    use_tensorboard: bool = typer.Option(
        True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
    ),
    resume_from: Optional[str] = typer.Option(
        None, "--resume-from", help="从检查点恢复训练"
    ),
    reset_training_state: bool = typer.Option(
        False, "--reset-training-state", help="重置训练状态"
    ),
    seed: int = typer.Option(42, "--seed", help="随机种子"),
    compile: Optional[bool] = typer.Option(
        None, "--compile/--no-compile", help="是否开启 torch.compile 优化"
    ),
):
    """
    Stage-2 model-expansion training: read the stage-1 expansion_info.json,
    load the expanded model, and run a full fine-tune (all parameters unfrozen).
    Command-line arguments take precedence over values from the JSON file.
    """
    torch.multiprocessing.set_sharing_strategy("file_system")
    # Enable TF32 matmul on GPU; seed all RNGs for reproducibility.
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    console = Console()
    # Load the stage-1 info file; abort with exit code 1 if it is missing.
    stage1_info_path = Path(stage1_info)
    if not stage1_info_path.exists():
        console.print(
            f"[bold red]错误: 找不到第一阶段信息文件 {stage1_info}[/bold red]"
        )
        raise typer.Exit(1)
    with open(stage1_info_path, "r", encoding="utf-8") as f:
        info = json.load(f)
    # Merge: CLI arguments take precedence over the JSON file contents.
    # NOTE(review): required keys (stage1_checkpoint_path, output_dir, ...)
    # raise KeyError if absent — assumes the file was written by expand_and_train.
    final_checkpoint = checkpoint or info["stage1_checkpoint_path"]
    final_output_dir = output_dir or info["output_dir"]
    final_train_data_path = train_data_path or info["train_data_path"]
    final_eval_data_path = eval_data_path or info["eval_data_path"]
    final_batch_size = batch_size if batch_size is not None else info["batch_size"]
    final_num_epochs = (
        num_epochs if num_epochs is not None else info.get("num_epochs", 10)
    )
    final_max_iter_length = (
        max_iter_length if max_iter_length is not None else info["max_iter_length"]
    )
    final_max_seq_len = max_seq_len if max_seq_len is not None else info["max_seq_len"]
    final_num_workers = (
        num_workers if num_workers is not None else info.get("num_workers", 2)
    )
    # Training hyperparameters: fall back to hard-coded defaults when not given.
    final_learning_rate = learning_rate if learning_rate is not None else 1e-4
    final_min_learning_rate = (
        min_learning_rate if min_learning_rate is not None else 1e-9
    )
    final_weight_decay = weight_decay if weight_decay is not None else 0.1
    final_warmup_ratio = warmup_ratio if warmup_ratio is not None else 0.1
    final_label_smoothing = label_smoothing if label_smoothing is not None else 0.15
    final_grad_accum_steps = grad_accum_steps if grad_accum_steps is not None else 1
    final_clip_grad_norm = clip_grad_norm if clip_grad_norm is not None else 1.0
    final_eval_frequency = eval_frequency if eval_frequency is not None else 500
    final_save_frequency = save_frequency if save_frequency is not None else 1000
    # Model hyperparameters come from the JSON; only `compile` can be overridden.
    model_kwargs = info["model_kwargs"]
    if compile is not None:
        model_kwargs["compile"] = compile
    console.print(
        Panel.fit(
            "[bold cyan]模型扩容第二阶段训练配置[/bold cyan]", border_style="cyan"
        )
    )
    # Render the merged configuration as a table.
    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")
    config_table.add_row("数据", "第一阶段信息文件", str(stage1_info_path))
    config_table.add_row("数据", "训练数据路径", final_train_data_path)
    config_table.add_row("数据", "评估数据路径", final_eval_data_path)
    config_table.add_row("数据", "输出目录", final_output_dir)
    config_table.add_row("数据", "批次大小", str(final_batch_size))
    config_table.add_row("数据", "Worker数量", str(final_num_workers))
    config_table.add_row("模型", "新模型规格", expand_config)
    config_table.add_row("模型", "检查点路径", final_checkpoint)
    for k, v in model_kwargs.items():
        config_table.add_row("模型", k, str(v))
    config_table.add_row("训练", "训练轮数", str(final_num_epochs))
    config_table.add_row("训练", "学习率", f"{final_learning_rate:.2e}")
    config_table.add_row("训练", "最小学习率", f"{final_min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(final_weight_decay))
    config_table.add_row("训练", "热身比例", str(final_warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(final_label_smoothing))
    config_table.add_row("训练", "梯度累积", str(final_grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(mixed_precision))
    console.print(config_table)
    output_path = Path(final_output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    # Build the streaming train/eval datasets and loaders.
    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
    train_dataset = PinyinInputDataset(
        data_path=final_train_data_path,
        max_workers=-1,
        max_iter_length=final_max_iter_length,
        max_seq_length=final_max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=5000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    train_dataloader = create_dataloader(
        dataset=train_dataset,
        batch_size=final_batch_size,
        num_workers=final_num_workers,
        pin_memory=torch.cuda.is_available(),
        max_iter_length=final_max_iter_length,
    )
    # Evaluation set is capped at 64 batches.
    eval_dataset = PinyinInputDataset(
        data_path=final_eval_data_path,
        max_workers=-1,
        max_iter_length=final_batch_size * 64,
        max_seq_length=final_max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=50000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    eval_dataloader = create_dataloader(
        dataset=eval_dataset,
        batch_size=final_batch_size,
        num_workers=1,
        pin_memory=torch.cuda.is_available(),
        max_iter_length=final_batch_size * 64,
    )
    console.print("[bold cyan]正在加载扩容模型(全量微调模式)...[/bold cyan]")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_expanded_model(
        base_model_path=final_checkpoint,
        new_model_spec=expand_config,
        device=device,
        **model_kwargs,
    )
    # Full fine-tune: unfreeze every parameter (the loader may have frozen some).
    for param in model.parameters():
        param.requires_grad = True
    console.print(
        f"[green]✓ 模型加载完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
    )
    console.print("[green]✓ 所有参数已解冻,进入全量微调模式[/green]")
    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    # Separate status file so stage-2 progress does not clobber stage-1 status.
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=int(final_max_iter_length * final_num_epochs / final_batch_size),
        output_dir=final_output_dir,
        num_epochs=final_num_epochs,
        learning_rate=final_learning_rate,
        min_learning_rate=final_min_learning_rate,
        weight_decay=final_weight_decay,
        warmup_ratio=final_warmup_ratio,
        label_smoothing=final_label_smoothing,
        grad_accum_steps=final_grad_accum_steps,
        clip_grad_norm=final_clip_grad_norm,
        eval_frequency=final_eval_frequency,
        save_frequency=final_save_frequency,
        mixed_precision=mixed_precision,
        use_tensorboard=use_tensorboard,
        status_file="training_status_finetune.json",
    )
    console.print("[green]✓ 训练器创建完成[/green]")
    console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]")
    console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    # Run training; Ctrl-C saves an interrupt checkpoint instead of crashing.
    try:
        trainer.train(
            resume_from=resume_from, reset_training_state=reset_training_state
        )
    except KeyboardInterrupt:
        console.print("[bold green]训练被终止[/bold green]")
        trainer.save_checkpoint("interrupted_model.pt")
    console.print("[bold green]✓ 第二阶段全量微调完成![/bold green]")
    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {final_output_dir}")
# Script entry point: dispatch to the typer CLI application defined above.
if __name__ == "__main__":
    app()