# train.py — input-method model trainer (original file: ~1120 lines, 41 KiB, Python)
import json
|
||
import math
|
||
import os
|
||
import random
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
import typer
|
||
from loguru import logger
|
||
from rich.console import Console
|
||
from rich.panel import Panel
|
||
from rich.progress import (
|
||
BarColumn,
|
||
Progress,
|
||
SpinnerColumn,
|
||
TextColumn,
|
||
TimeElapsedColumn,
|
||
TimeRemainingColumn,
|
||
)
|
||
from rich.table import Table
|
||
from torch import autocast
|
||
from torch.amp.grad_scaler import GradScaler
|
||
from torch.utils.data import DataLoader
|
||
from torch.utils.tensorboard import SummaryWriter
|
||
|
||
# Try to import DataLoader2 for better streaming dataset support
try:
    from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

    DATA_LOADER2_AVAILABLE = True
except ImportError:
    # torchdata is optional: fall back to the standard DataLoader when absent.
    # The names are bound to None so later `is not None` checks remain valid.
    DATA_LOADER2_AVAILABLE = False
    DataLoader2 = None
    MultiProcessingReadingService = None
|
||
|
||
from .dataset import PinyinInputDataset
|
||
|
||
# Import the model and dataset
|
||
from .model import InputMethodEngine
|
||
|
||
|
||
class Trainer:
    """
    Input-method model trainer.

    Trains an InputMethodEngine model with support for:
    - linear warmup + cosine-annealing learning-rate schedule
    - TensorBoard logging (plus a JSON status file)
    - AdamW optimizer (weight_decay=0.1 by default)
    - mixed-precision (AMP) training
    - CrossEntropyLoss with optional class weights and label smoothing
    - Rich-powered terminal output
    """

    def __init__(
        self,
        model: InputMethodEngine,
        train_dataloader: DataLoader,
        eval_dataloader: DataLoader,
        total_steps: int,
        output_dir: str = "./output",
        num_epochs: int = 10,
        learning_rate: float = 1e-4,
        min_learning_rate: float = 1e-6,
        weight_decay: float = 0.1,
        warmup_ratio: float = 0.1,
        label_smoothing: float = 0.15,
        loss_weight: Optional[torch.Tensor] = None,
        grad_accum_steps: int = 1,
        clip_grad_norm: float = 1.0,
        eval_frequency: int = 500,
        save_frequency: int = 10000,
        mixed_precision: bool = True,
        device: Optional[torch.device] = None,
        status_file: str = "training_status.json",
        use_tensorboard: bool = True,
    ):
        """
        Initialize the trainer.

        Args:
            model: the InputMethodEngine model to train
            train_dataloader: training data loader
            eval_dataloader: evaluation data loader (materialized into a list
                so a streaming dataset can be iterated repeatedly)
            total_steps: total number of training steps
            output_dir: directory for checkpoints and logs
            num_epochs: number of training epochs
            learning_rate: peak learning rate (reached after warmup)
            min_learning_rate: floor learning rate (after cosine annealing)
            weight_decay: AdamW weight decay
            warmup_ratio: fraction of total steps used for linear warmup
            label_smoothing: label-smoothing factor for CrossEntropyLoss
            loss_weight: optional per-class weights for CrossEntropyLoss
            grad_accum_steps: gradient-accumulation steps
            clip_grad_norm: max norm for gradient clipping
            eval_frequency: evaluate every N steps
            save_frequency: save a periodic checkpoint every N steps
            mixed_precision: enable AMP mixed-precision training
            device: training device; auto-selected when None
            status_file: JSON file name (under output_dir) for status records
            use_tensorboard: enable TensorBoard logging
        """
        self.model = model
        self.train_dataloader = train_dataloader
        # Materialize the eval loader so a streaming/iterable dataset can be
        # re-iterated on every evaluation.
        # NOTE(review): `list([i for i in ...])` is equivalent to
        # `list(eval_dataloader)`.
        self.eval_dataloader = list([i for i in eval_dataloader])
        self.output_dir = Path(output_dir)
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.min_learning_rate = min_learning_rate
        self.weight_decay = weight_decay
        self.warmup_ratio = warmup_ratio
        self.label_smoothing = label_smoothing
        self.loss_weight = loss_weight
        self.grad_accum_steps = grad_accum_steps
        self.clip_grad_norm = clip_grad_norm
        self.eval_frequency = eval_frequency
        self.save_frequency = save_frequency
        self.mixed_precision = mixed_precision
        self.use_tensorboard = use_tensorboard

        # Select the device (prefer CUDA when available).
        logger.info(f"GPU可用: {torch.cuda.is_available()}")
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        # Move the model to the selected device.
        self.model.to(self.device)

        # Create the output directories.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir = self.output_dir / "checkpoints"
        self.checkpoint_dir.mkdir(exist_ok=True)

        # Total training steps (supplied by the caller, not derived here).
        self.total_steps = total_steps

        self.warmup_steps = int(self.total_steps * warmup_ratio)

        # Initialize the optimizer.
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8,
        )

        # Initialize the loss function (with class weights when provided).
        if loss_weight is not None:
            self.criterion = nn.CrossEntropyLoss(
                weight=loss_weight.to(self.device), label_smoothing=label_smoothing
            )
        else:
            self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

        # Initialize the mixed-precision gradient scaler.
        device_type = "cuda" if "cuda" in str(self.device) else "cpu"
        self.scaler = GradScaler(device_type, enabled=mixed_precision)

        # Initialize TensorBoard.
        if use_tensorboard:
            self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
        else:
            self.writer = None

        # Set up the status file.
        # NOTE(review): duplicate assignment — self.use_tensorboard was
        # already set above.
        self.use_tensorboard = use_tensorboard
        self.status_file = self.output_dir / status_file
        # Load any pre-existing status records so resumed runs append to them.
        self.training_status_data = self._load_existing_status_data()

        # Initialize the Rich console.
        self.console = Console()

        # Training state.
        self.current_step = 0
        self.current_epoch = 0
        self.best_eval_loss = float("inf")

        # Learning-rate schedule function (step -> lr).
        self.lr_scheduler = self._create_lr_scheduler()

        logger.info(f"Trainer initialized with device: {self.device}")
        logger.info(
            f"Total steps: {self.total_steps}, Warmup steps: {self.warmup_steps}"
        )
        logger.info(f"Learning rate: {learning_rate}, Weight decay: {weight_decay}")

    def _create_lr_scheduler(self) -> Callable[[int], float]:
        """Create the learning-rate schedule function (warmup + cosine annealing)."""

        def lr_scheduler(step: int) -> float:
            if step < self.warmup_steps:
                # Linear warmup from 0 up to learning_rate.
                # NOTE(review): divides by warmup_steps — a warmup_ratio of 0
                # would make this branch unreachable, so no ZeroDivisionError
                # in practice; confirm if warmup_steps can legitimately be 0.
                return self.learning_rate * (step / self.warmup_steps)
            else:
                # Cosine annealing from learning_rate down to min_learning_rate.
                progress = (step - self.warmup_steps) / (
                    self.total_steps - self.warmup_steps
                )
                cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
                decayed_lr = (
                    self.min_learning_rate
                    + (self.learning_rate - self.min_learning_rate) * cosine_decay
                )
                return decayed_lr

        return lr_scheduler

    def _get_current_lr(self) -> float:
        """Return the scheduled learning rate for the current step."""
        return self.lr_scheduler(self.current_step)

    def _update_learning_rate(self) -> float:
        """Apply the scheduled learning rate to every optimizer param group."""
        current_lr = self._get_current_lr()
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = current_lr
        return current_lr

    def train_step(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[float, Dict[str, Any]]:
        """
        Run a single training step (forward + backward; no optimizer step).

        The optimizer step happens in train() so gradients can accumulate
        across grad_accum_steps batches.

        Args:
            batch: dict of input tensors for one batch

        Returns:
            loss: un-normalized loss value (scaled back up by grad_accum_steps)
            metrics: dict with "loss", "lr" and "accuracy"
        """
        self.model.train()

        # Move the batch tensors to the training device.
        input_ids = batch["input_ids"].to(self.device)
        token_type_ids = batch["token_type_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        history_slot_ids = batch["history_slot_ids"].to(self.device)
        pinyin_ids = batch["pinyin_ids"].to(self.device)
        labels = batch["labels"].to(self.device).squeeze(-1)  # [batch_size]

        # Mixed-precision forward pass.
        with autocast(device_type=self.device.type, enabled=self.mixed_precision):
            # Forward pass.
            logits = self.model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                pinyin_ids=pinyin_ids,
                history_slot_ids=history_slot_ids,
            )

            # Compute the loss, normalized for gradient accumulation.
            loss = self.criterion(logits, labels)
            loss = loss / self.grad_accum_steps

        # Backward pass (scaled for AMP).
        self.scaler.scale(loss).backward()

        metrics = {
            # Report the per-batch loss, undoing the accumulation scaling.
            "loss": loss.item() * self.grad_accum_steps,
            "lr": self._get_current_lr(),
        }

        # Compute batch accuracy (no gradients needed).
        with torch.no_grad():
            preds = torch.argmax(logits, dim=-1)
            correct = (preds == labels).sum().item()
            total = labels.size(0)
            metrics["accuracy"] = correct / total if total > 0 else 0.0

        return loss.item() * self.grad_accum_steps, metrics

    def evaluate(self) -> Dict[str, float]:
        """
        Evaluate the model on the held-out evaluation set.

        Returns:
            Dict with "eval_loss" and "eval_accuracy"; empty dict when no
            evaluation data loader is configured.
        """
        if self.eval_dataloader is None:
            return {}

        self.model.eval()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for batch in self.eval_dataloader:
                # Move the batch tensors to the device.
                input_ids = batch["input_ids"].to(self.device)
                token_type_ids = batch["token_type_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                history_slot_ids = batch["history_slot_ids"].to(self.device)
                pinyin_ids = batch["pinyin_ids"].to(self.device)
                labels = batch["labels"].to(self.device).squeeze(-1)

                # Forward pass.
                logits = self.model(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    pinyin_ids=pinyin_ids,
                    history_slot_ids=history_slot_ids,
                )

                # Accumulate the sample-weighted loss.
                loss = self.criterion(logits, labels)
                total_loss += loss.item() * labels.size(0)

                # Accumulate accuracy counts.
                preds = torch.argmax(logits, dim=-1)
                correct = (preds == labels).sum().item()
                total_correct += correct
                total_samples += labels.size(0)

        avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
        accuracy = total_correct / total_samples if total_samples > 0 else 0.0

        return {"eval_loss": avg_loss, "eval_accuracy": accuracy}

    def save_checkpoint(
        self, filename: str, is_best: bool = False, is_periodic: bool = False
    ):
        """
        Save a training checkpoint.

        Args:
            filename: checkpoint file name (under the checkpoints directory)
            is_best: also save a copy as best_model.pt
            is_periodic: periodic save — overwrites latest_checkpoint.pt
                instead of creating a new file per step
        """
        # Periodic saves reuse one fixed file name so they overwrite each other.
        if is_periodic:
            checkpoint_path = self.checkpoint_dir / "latest_checkpoint.pt"
        else:
            checkpoint_path = self.checkpoint_dir / filename

        checkpoint = {
            "step": self.current_step,
            "epoch": self.current_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "best_eval_loss": self.best_eval_loss,
            # Key hyper-parameters, stored for reproducibility/inspection.
            "config": {
                "learning_rate": self.learning_rate,
                "weight_decay": self.weight_decay,
                "warmup_ratio": self.warmup_ratio,
                "label_smoothing": self.label_smoothing,
                "total_steps": self.total_steps,
            },
        }

        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Checkpoint saved to {checkpoint_path}")

        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            logger.info(f"Best model saved to {best_path}")

    def load_checkpoint(
        self, checkpoint_path: Union[str, Path], reset_training_state: bool = False
    ):
        """
        Load a checkpoint.

        Args:
            checkpoint_path: path to the checkpoint file
            reset_training_state: when True, load only the model weights and
                restart training from step 0 / epoch 0
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])

        if not reset_training_state:
            # Restore optimizer/scaler state and the training counters.
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

            self.current_step = checkpoint["step"]
            self.current_epoch = checkpoint["epoch"]
            self.best_eval_loss = checkpoint["best_eval_loss"]

            logger.info(f"Checkpoint loaded from {checkpoint_path}")
            logger.info(
                f"Resuming from step {self.current_step}, epoch {self.current_epoch}"
            )
        else:
            # Weights-only load: reset the training state.
            self.current_step = 0
            self.current_epoch = 0
            self.best_eval_loss = float("inf")

            logger.info(
                f"Checkpoint loaded from {checkpoint_path} (training state reset)"
            )
            logger.info("Training state reset: starting from step 0, epoch 0")

    def _log_to_tensorboard(self, metrics: Dict[str, float], step: int):
        """Log metrics to TensorBoard and to the JSON status file."""
        if self.writer is not None:
            for key, value in metrics.items():
                self.writer.add_scalar(key, value, step)

        # Mirror the metrics into the JSON status file.
        self._write_training_status(metrics, step)

    def _load_existing_status_data(self) -> List[Dict]:
        """Load previously recorded training status data from the status file.

        Returns an empty list when the file is missing, malformed, or does
        not contain a JSON list, so a fresh run can always start.
        """
        try:
            if self.status_file.exists():
                with open(self.status_file, "r", encoding="utf-8") as f:
                    data = json.load(f)
                    if isinstance(data, list):
                        logger.info(
                            f"Loaded {len(data)} existing training status records from {self.status_file}"
                        )
                        return data
                    else:
                        logger.warning(
                            f"Status file {self.status_file} does not contain a list, starting fresh"
                        )
                        return []
            else:
                logger.info(
                    f"Status file {self.status_file} does not exist, starting fresh"
                )
                return []
        except json.JSONDecodeError:
            logger.warning(
                f"Status file {self.status_file} has invalid JSON format, starting fresh"
            )
            return []
        except Exception as e:
            logger.error(
                f"Failed to load existing status data from {self.status_file}: {e}"
            )
            return []

    def _write_training_status(self, metrics: Dict[str, float], step: int):
        """Append (or replace) one status record and rewrite the JSON file.

        Best-effort: any failure is logged and swallowed so a status-file
        problem never interrupts training.
        """
        try:
            # Build the status record for this step.
            status_record = {
                "step": step,
                "epoch": self.current_epoch + 1,
                "timestamp": datetime.now().isoformat(),
            }

            # Attach all metrics as plain floats (JSON-serializable).
            for key, value in metrics.items():
                status_record[key] = float(value)

            # Replace any existing record for the same step to avoid duplicates.
            existing_indices = [
                i
                for i, record in enumerate(self.training_status_data)
                if record.get("step") == step
            ]
            if existing_indices:
                # Replace the existing record(s).
                for idx in existing_indices:
                    self.training_status_data[idx] = status_record
            else:
                # Append to the in-memory cache.
                self.training_status_data.append(status_record)

            # Cap the in-memory cache at the most recent 1000 records.
            if len(self.training_status_data) > 1000:
                self.training_status_data = self.training_status_data[-1000:]

            # Defensive: ensure the data is a list before serializing.
            if not isinstance(self.training_status_data, list):
                logger.warning(
                    f"training_status_data is not a list (type: {type(self.training_status_data).__name__}), converting to list"
                )
                self.training_status_data = (
                    [self.training_status_data] if self.training_status_data else []
                )

            # Rewrite the whole file each time (small, bounded list).
            with open(self.status_file, "w", encoding="utf-8") as f:
                json.dump(self.training_status_data, f, indent=2, ensure_ascii=False)

        except Exception as e:
            logger.error(f"Failed to write training status: {e}")

    def _create_progress_bar(self) -> Progress:
        """Build the Rich progress bar used by the training loop."""
        return Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=self.console,
            expand=True,
        )

    def _print_training_info(self):
        """Print the training configuration as a Rich table."""
        info_table = Table(
            title="Training Configuration",
            show_header=True,
            header_style="bold magenta",
        )
        info_table.add_column("Parameter", style="cyan")
        info_table.add_column("Value", style="green")

        info_table.add_row("Device", str(self.device))
        info_table.add_row("Total Steps", str(self.total_steps))
        info_table.add_row("Warmup Steps", str(self.warmup_steps))
        info_table.add_row("Learning Rate", f"{self.learning_rate:.2e}")
        info_table.add_row("Min Learning Rate", f"{self.min_learning_rate:.2e}")
        info_table.add_row("Weight Decay", str(self.weight_decay))
        info_table.add_row("Label Smoothing", str(self.label_smoothing))
        info_table.add_row("Gradient Accumulation", str(self.grad_accum_steps))
        info_table.add_row("Mixed Precision", str(self.mixed_precision))

        self.console.print(info_table)

    def train(
        self, resume_from: Optional[str] = None, reset_training_state: bool = False
    ):
        """
        Main training loop.

        Args:
            resume_from: optional checkpoint path to resume from
            reset_training_state: load only weights and restart from step 0
        """
        # Resume from a checkpoint when one is given.
        if resume_from is not None:
            self.load_checkpoint(resume_from, reset_training_state=reset_training_state)

        # Print the training configuration.
        self._print_training_info()

        # Running totals, used for averaged logging between evaluations.
        global_step = self.current_step
        accumulated_loss = 0.0
        accumulated_accuracy = 0.0
        accumulation_counter = 0

        # Progress bar scoped to the whole run.
        with self._create_progress_bar() as progress:
            epoch_task = progress.add_task(
                f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
                total=self.total_steps,
            )

            # Training loop.
            for epoch in range(self.current_epoch, self.num_epochs):
                self.current_epoch = epoch
                progress.update(
                    epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}"
                )

                for batch_idx, batch in enumerate(self.train_dataloader):
                    # Apply the scheduled learning rate for this step.
                    current_lr = self._update_learning_rate()

                    # Forward/backward for one batch (gradients accumulate).
                    loss, metrics = self.train_step(batch)

                    # Accumulate metrics for averaged logging.
                    accumulated_loss += loss
                    accumulated_accuracy += metrics.get("accuracy", 0.0)
                    accumulation_counter += 1

                    # Gradient accumulation: step the optimizer every
                    # grad_accum_steps batches.
                    if (global_step + 1) % self.grad_accum_steps == 0:
                        # Unscale first so clipping sees the true gradients.
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.clip_grad_norm
                        )

                        # Optimizer step, scaler update, then clear gradients.
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        self.optimizer.zero_grad()

                    # Advance the progress bar (one unit per batch, matching
                    # global_step which also advances per batch).
                    progress.update(
                        epoch_task,
                        advance=1,
                        description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | "
                        f"Step {global_step}/{self.total_steps} | "
                        f"Loss: {loss:.4f} | "
                        f"LR: {current_lr:.2e}",
                    )

                    # Periodic evaluation and logging.
                    if (global_step + 1) % self.eval_frequency == 0:
                        # Average the accumulated training metrics.
                        avg_loss = accumulated_loss / accumulation_counter
                        avg_accuracy = accumulated_accuracy / accumulation_counter

                        # Evaluate on the held-out set.
                        eval_metrics = self.evaluate()

                        # Metrics to log.
                        log_metrics = {
                            "train/loss": avg_loss,
                            "train/accuracy": avg_accuracy,
                            "train/learning_rate": current_lr,
                        }

                        if eval_metrics:
                            log_metrics.update(
                                {
                                    "eval/loss": eval_metrics["eval_loss"],
                                    "eval/accuracy": eval_metrics["eval_accuracy"],
                                }
                            )

                            # Track the best model by eval loss.
                            # NOTE(review): kept inside `if eval_metrics:` —
                            # accessing eval_metrics["eval_loss"] would raise
                            # KeyError when evaluate() returns {}.
                            if eval_metrics["eval_loss"] < self.best_eval_loss:
                                self.best_eval_loss = eval_metrics["eval_loss"]
                                # Only save best_model; no extra checkpoint file.
                                self.save_checkpoint("best_model.pt", is_best=True)

                        # Log to TensorBoard (and the JSON status file).
                        self._log_to_tensorboard(log_metrics, global_step)

                        # Console log line.
                        log_text = (
                            f"[Epoch {epoch + 1}/{self.num_epochs}] "
                            f"[Step {global_step}/{self.total_steps}] "
                            f"Train Loss: {avg_loss:.4f} | "
                            f"Train Acc: {avg_accuracy:.4f} | "
                            f"LR: {current_lr:.2e}"
                        )

                        if eval_metrics:
                            log_text += (
                                f" | Eval Loss: {eval_metrics['eval_loss']:.4f} | "
                                f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
                            )

                        progress.console.log(log_text)

                        # Reset the running totals.
                        accumulated_loss = 0.0
                        accumulated_accuracy = 0.0
                        accumulation_counter = 0

                    # Periodic checkpoint (overwrites the previous periodic one).
                    if (global_step + 1) % self.save_frequency == 0:
                        self.save_checkpoint("latest_checkpoint.pt", is_periodic=True)

                    # Advance the global step counter.
                    global_step += 1
                    self.current_step = global_step

                    # Stop once the step budget is exhausted.
                    if global_step >= self.total_steps:
                        progress.update(epoch_task, completed=self.total_steps)
                        break

                # Reset the bar for the next epoch.
                progress.reset(epoch_task)

                # Save a checkpoint at the end of every epoch.
                self.save_checkpoint(f"epoch_{epoch + 1}.pt")

                # Stop once the step budget is exhausted.
                if global_step >= self.total_steps:
                    break

        # Training finished.
        logger.info("Training completed!")

        # Save the final model.
        self.save_checkpoint("final_model.pt")

        # Close the TensorBoard writer.
        if self.writer is not None:
            self.writer.close()
|
||
|
||
|
||
def worker_init_fn(worker_id: int) -> None:
    """
    Seed the RNGs of a DataLoader worker for reproducibility.

    Derives a 32-bit seed from torch's per-worker initial seed and applies
    it to both NumPy's and Python's global random generators.

    Args:
        worker_id: ID of the worker (seeding is derived from torch's state,
            so the ID itself is not used directly).
    """
    seed = torch.initial_seed() % (1 << 32)
    np.random.seed(seed)
    random.seed(seed)
|
||
|
||
|
||
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Collate a list of samples into a single batch.

    Tensor fields are stacked along a new batch dimension; the per-sample
    leading dimension is squeezed away first where one exists. String fields
    are gathered into plain Python lists.

    Args:
        batch: list of sample dicts.

    Returns:
        Dict with stacked tensors ("label" is renamed to "labels") and
        lists of strings.
    """

    def stack_field(key: str, drop_leading_dim: bool) -> torch.Tensor:
        # Stack one tensor field across the batch, optionally squeezing the
        # redundant per-sample batch dimension first.
        tensors = [sample[key] for sample in batch]
        if drop_leading_dim:
            tensors = [t.squeeze(0) for t in tensors]
        return torch.stack(tensors)

    def gather_field(key: str) -> List[Any]:
        # Collect one non-tensor field into a list.
        return [sample[key] for sample in batch]

    return {
        "input_ids": stack_field("input_ids", True),
        "token_type_ids": stack_field("token_type_ids", True),
        "attention_mask": stack_field("attention_mask", True),
        "labels": stack_field("label", True),
        "history_slot_ids": stack_field("history_slot_ids", False),
        "prefix": gather_field("prefix"),
        "suffix": gather_field("suffix"),
        "pinyin": gather_field("pinyin"),
        "pinyin_ids": stack_field("pinyin_ids", False),
    }
|
||
|
||
|
||
# Typer CLI application
|
||
def create_dataloader(
    dataset: PinyinInputDataset,
    batch_size: int,
    num_workers: int = 2,
    pin_memory: bool = True,
    shuffle: bool = False,
    max_iter_length: Optional[int] = None,
) -> Any:
    """
    Create a data loader, preferring DataLoader2 and falling back to the
    standard DataLoader when torchdata is unavailable or construction fails.
    Tuned for streaming datasets.

    Args:
        dataset: PinyinInputDataset instance
        batch_size: batch size
        num_workers: number of workers (2 recommended for streaming datasets)
        pin_memory: whether to pin host memory
        shuffle: whether to shuffle (streaming datasets shuffle internally)
        max_iter_length: maximum iteration length, intended for computing
            total steps. NOTE(review): currently unused inside this
            function — confirm whether it should be wired through.

    Returns:
        A data loader instance (DataLoader2 or DataLoader).
    """
    if (
        DATA_LOADER2_AVAILABLE
        and DataLoader2 is not None
        and MultiProcessingReadingService is not None
    ):
        try:
            # DataLoader2 configuration, tuned for streaming datasets.
            reading_service = MultiProcessingReadingService(
                num_workers=num_workers,
                prefetch_factor=2,  # keep prefetch small to limit memory use
                persistent_workers=True,
                pin_memory=pin_memory,
                worker_init_fn=worker_init_fn,
            )

            # NOTE(review): DataLoader2's constructor may not accept
            # batch_size/collate_fn/shuffle keyword arguments (batching is
            # normally applied on the datapipe); if it raises, the except
            # below falls back to the standard DataLoader.
            dataloader = DataLoader2(
                dataset,
                reading_service=reading_service,
                batch_size=batch_size,
                collate_fn=collate_fn,
                shuffle=shuffle,
            )
            logger.info(f"✅ 使用DataLoader2创建数据加载器,worker数量: {num_workers}")
            return dataloader
        except Exception as e:
            logger.warning(f"⚠️ DataLoader2创建失败: {e},回退到标准DataLoader")

    # Fall back to the standard DataLoader.
    logger.info(f"📊 使用标准DataLoader,worker数量: {num_workers}")
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=2,  # keep prefetch small to limit memory use
        persistent_workers=True,
        shuffle=shuffle,
    )
    return dataloader
|
||
|
||
|
||
app = typer.Typer(help="输入法模型训练命令行工具", add_completion=False)
|
||
|
||
|
||
@app.command()
def train(
    # Data arguments
    train_data_path: str = typer.Option(
        ..., "--train-data-path", "-t", help="训练数据集路径"
    ),
    eval_data_path: str = typer.Option(
        ..., "--eval-data-path", "-e", help="评估数据集路径"
    ),
    output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
    # Model arguments
    vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
    pinyin_vocab_size: int = typer.Option(
        30, "--pinyin-vocab-size", help="拼音词汇表大小"
    ),
    # NOTE(review): this option name uses underscores ("--max_iter_length"),
    # unlike the dash style of every other option; renaming would break
    # existing invocations, so it is left as-is.
    max_iter_length: int = typer.Option(
        1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
    ),
    dim: int = typer.Option(512, "--dim", help="模型维度"),
    num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"),
    n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"),
    n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"),
    num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"),
    max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"),
    use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
    # Training arguments
    batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
    num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
    learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
    min_learning_rate: float = typer.Option(
        1e-9, "--min-learning-rate", help="最小学习率"
    ),
    weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
    warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
    label_smoothing: float = typer.Option(
        0.15, "--label-smoothing", help="标签平滑参数"
    ),
    grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
    clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
    eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
    save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"),
    # Miscellaneous arguments
    mixed_precision: bool = typer.Option(
        True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
    ),
    num_workers: int = typer.Option(
        2, "--num-workers", help="数据加载worker数量(流式数据集建议为2)"
    ),
    use_tensorboard: bool = typer.Option(
        True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
    ),
    resume_from: Optional[str] = typer.Option(
        None, "--resume-from", help="从检查点恢复训练"
    ),
    reset_training_state: bool = typer.Option(
        False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练"
    ),
    seed: int = typer.Option(42, "--seed", help="随机种子"),
    # NOTE(review): parameter name shadows the builtin `compile`; renaming
    # would change the Typer binding, so it is left unchanged here.
    compile: bool = typer.Option(
        False,
        "--compile/--no-compile",
        help="是否开启 torch.compile 优化(需 PyTorch 2.0+)",
    ),
):
    """
    训练输入法模型

    End-to-end training entry point: builds the streaming train/eval
    datasets and loaders, constructs the model and Trainer, and runs the
    training loop. Configuration is persisted to
    <output_dir>/training_config.json.
    """
    # Use the file_system sharing strategy to avoid fd exhaustion with many
    # DataLoader workers.
    torch.multiprocessing.set_sharing_strategy("file_system")

    # Enable TensorFloat32 matmul (silences the UserWarning and speeds up GEMMs).
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")

    # Seed the RNGs.
    # NOTE(review): NumPy's global RNG is not seeded here (only per-worker
    # via worker_init_fn) — confirm whether that is intended.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    console = Console()

    # Print the configuration banner.
    console.print(
        Panel.fit("[bold cyan]输入法模型训练配置[/bold cyan]", border_style="cyan")
    )

    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")

    # Populate the configuration table.
    config_table.add_row("数据", "训练数据路径", train_data_path)
    config_table.add_row("数据", "评估数据路径", eval_data_path)
    config_table.add_row("数据", "输出目录", output_dir)
    config_table.add_row("数据", "批次大小", str(batch_size))
    config_table.add_row("数据", "Worker数量", str(num_workers))

    config_table.add_row("模型", "词汇表大小", str(vocab_size))
    config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size))
    config_table.add_row("模型", "模型维度", str(dim))
    config_table.add_row("模型", "槽位数量", str(num_slots))
    config_table.add_row("模型", "Transformer层数", str(n_layers))
    config_table.add_row("模型", "注意力头数", str(n_heads))
    config_table.add_row("模型", "MoE专家数", str(num_experts))
    config_table.add_row("模型", "使用拼音", str(use_pinyin))
    config_table.add_row("模型", "编译优化", str(compile))

    config_table.add_row("训练", "训练轮数", str(num_epochs))
    config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
    config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(weight_decay))
    config_table.add_row("训练", "热身比例", str(warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(label_smoothing))
    config_table.add_row("训练", "梯度累积", str(grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(mixed_precision))

    console.print(config_table)

    # Create the output directory.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Persist the run configuration for reproducibility.
    config = {
        "train_data_path": train_data_path,
        "eval_data_path": eval_data_path,
        "output_dir": output_dir,
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "use_pinyin": use_pinyin,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "num_epochs": num_epochs,
        "learning_rate": learning_rate,
        "min_learning_rate": min_learning_rate,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
        "label_smoothing": label_smoothing,
        "grad_accum_steps": grad_accum_steps,
        "clip_grad_norm": clip_grad_norm,
        "eval_frequency": eval_frequency,
        "save_frequency": save_frequency,
        "mixed_precision": mixed_precision,
        "use_tensorboard": use_tensorboard,
        "seed": seed,
        "max_iter_length": max_iter_length,
        "compile": compile,
    }

    config_file = output_path / "training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    logger.info(f"Configuration saved to {config_file}")

    # Build the data loaders.
    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")

    # Training dataset.
    train_dataset = PinyinInputDataset(
        data_path=train_data_path,
        max_workers=-1,  # auto-select worker count
        max_iter_length=max_iter_length,
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=5000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    # Training data loader.
    # Note: PinyinInputDataset is an IterableDataset, so the shuffle argument
    # cannot be used; shuffling happens inside the dataset.
    # Multi-worker setup: each worker consumes one shard of the dataset,
    # handled by the shard logic inside dataset.__iter__.
    train_dataloader = create_dataloader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        max_iter_length=max_iter_length,
    )

    # Evaluation dataset (same settings, smaller iteration budget).
    eval_dataset = PinyinInputDataset(
        data_path=eval_data_path,
        max_workers=-1,
        max_iter_length=batch_size * 64,  # smaller eval set
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=1000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    eval_dataloader = create_dataloader(
        dataset=eval_dataset,
        batch_size=batch_size,
        num_workers=1,  # fewer workers are enough for evaluation
        pin_memory=torch.cuda.is_available(),
        max_iter_length=batch_size * 64,
    )

    # Build the model.
    console.print("[bold cyan]正在创建模型...[/bold cyan]")
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=n_layers,
        n_heads=n_heads,
        num_experts=num_experts,
        max_seq_len=max_seq_len,
        compile=compile,
    )

    console.print(
        f"[green]✓ 模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
    )

    # Build the trainer.
    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        # Total steps derived from the dataset iteration budget.
        total_steps=int(max_iter_length / batch_size),
        output_dir=output_dir,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        min_learning_rate=min_learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        label_smoothing=label_smoothing,
        grad_accum_steps=grad_accum_steps,
        clip_grad_norm=clip_grad_norm,
        eval_frequency=eval_frequency,
        save_frequency=save_frequency,
        mixed_precision=mixed_precision,
        use_tensorboard=use_tensorboard,
        status_file="training_status.json",
    )

    console.print("[green]✓ 训练器创建完成[/green]")

    # Start training.
    console.print("\n[bold cyan]开始训练...[/bold cyan]")
    console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    try:
        trainer.train(
            resume_from=resume_from, reset_training_state=reset_training_state
        )
    except KeyboardInterrupt:
        # Graceful interrupt: save current weights before exiting.
        console.print("[bold green]训练被终止[/bold green]")
        trainer.save_checkpoint("interrupted_model.pt")

    console.print("[bold green]✓ 训练完成![/bold green]")
    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {output_dir}")
|
||
|
||
|
||
@app.command()
def evaluate(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    data_path: str = typer.Option(..., "--data-path", "-d", help="数据集路径"),
    batch_size: int = typer.Option(32, "--batch-size", "-b", help="批次大小"),
):
    """
    评估训练好的模型

    Placeholder command: prints which checkpoint would be evaluated; the
    actual evaluation pipeline is not implemented yet.
    """
    term = Console()
    term.print(f"[bold cyan]评估模型: {checkpoint_path}[/bold cyan]")

    # TODO: implement evaluation:
    #   1. load the checkpoint
    #   2. build the data loader
    #   3. run the evaluation

    term.print("[yellow]评估功能待实现[/yellow]")
|
||
|
||
|
||
@app.command()
def export(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    output_path: str = typer.Option(
        "./exported_model.onnx", "--output", "-o", help="输出路径"
    ),
):
    """
    导出模型为ONNX格式

    Placeholder command: prints the target export path; the actual ONNX
    export is not implemented yet.
    """
    term = Console()
    term.print(f"[bold cyan]导出模型到: {output_path}[/bold cyan]")

    # TODO: implement export:
    #   1. load the checkpoint
    #   2. export the model to ONNX

    term.print("[yellow]导出功能待实现[/yellow]")
|
||
|
||
|
||
if __name__ == "__main__":
    # Dispatch to the Typer CLI when executed as a script.
    app()
|