# SUimeModelTraner/src/model/trainer.py
#
# Training entry point for the input-method model: the Trainer class,
# DataLoader helper functions, and a Typer-based CLI (train / evaluate / export).

import json
import math
import os
import random
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import typer
from loguru import logger
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from torch import autocast
from torch.amp.grad_scaler import GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .dataset import PinyinInputDataset
# Import the model and the dataset
from .model import InputMethodEngine
class Trainer:
    """
    Input-method model trainer.

    Trains an InputMethodEngine with support for:
    - warmup + cosine-annealing learning-rate schedule
    - TensorBoard logging
    - AdamW optimizer (weight_decay=0.1)
    - mixed-precision (AMP) training
    - CrossEntropyLoss with optional class weights and label smoothing
    - Rich terminal output
    """
def __init__(
    self,
    model: InputMethodEngine,
    train_dataloader: DataLoader,
    eval_dataloader: DataLoader,
    total_steps: int,
    output_dir: str = "./output",
    num_epochs: int = 10,
    learning_rate: float = 1e-4,
    min_learning_rate: float = 1e-6,
    weight_decay: float = 0.1,
    warmup_ratio: float = 0.1,
    label_smoothing: float = 0.15,
    loss_weight: Optional[torch.Tensor] = None,
    grad_accum_steps: int = 1,
    clip_grad_norm: float = 1.0,
    eval_frequency: int = 500,
    save_frequency: int = 10000,
    mixed_precision: bool = True,
    device: Optional[torch.device] = None,
    use_tensorboard: bool = True,
):
    """
    Initialize the trainer.

    Args:
        model: InputMethodEngine model to train.
        train_dataloader: Dataloader yielding training batches.
        eval_dataloader: Dataloader yielding evaluation batches; it is
            materialized into a list once so an iterable-style dataset can
            be re-evaluated repeatedly.
        total_steps: Total number of training steps (batches).
        output_dir: Directory for checkpoints and logs.
        num_epochs: Number of training epochs.
        learning_rate: Peak learning rate (reached after warmup).
        min_learning_rate: Floor reached by cosine annealing.
        weight_decay: AdamW weight decay.
        warmup_ratio: Fraction of total_steps used for linear warmup.
        label_smoothing: Label smoothing for CrossEntropyLoss.
        loss_weight: Optional per-class weights for CrossEntropyLoss.
        grad_accum_steps: Number of batches per optimizer step.
        clip_grad_norm: Max gradient norm for clipping.
        eval_frequency: Evaluate every this many steps.
        save_frequency: Save a checkpoint every this many steps.
        mixed_precision: Enable AMP mixed-precision training.
        device: Training device; auto-selected when None.
        use_tensorboard: Enable TensorBoard logging.
    """
    self.model = model
    self.train_dataloader = train_dataloader
    # Materialize the eval data once so it can be iterated repeatedly.
    # (Was `list([i for i in eval_dataloader])` — a redundant double copy.)
    self.eval_dataloader = list(eval_dataloader)
    self.output_dir = Path(output_dir)
    self.num_epochs = num_epochs
    self.learning_rate = learning_rate
    self.min_learning_rate = min_learning_rate
    self.weight_decay = weight_decay
    self.warmup_ratio = warmup_ratio
    self.label_smoothing = label_smoothing
    self.loss_weight = loss_weight
    self.grad_accum_steps = grad_accum_steps
    self.clip_grad_norm = clip_grad_norm
    self.eval_frequency = eval_frequency
    self.save_frequency = save_frequency
    self.mixed_precision = mixed_precision
    self.use_tensorboard = use_tensorboard
    # Select the training device.
    logger.info(f"GPU可用 {torch.cuda.is_available()}")
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        self.device = device
    # Move the model onto the selected device.
    self.model.to(self.device)
    # Create output directories.
    self.output_dir.mkdir(parents=True, exist_ok=True)
    self.checkpoint_dir = self.output_dir / "checkpoints"
    self.checkpoint_dir.mkdir(exist_ok=True)
    # Derive the warmup length from the total step budget.
    self.total_steps = total_steps
    self.warmup_steps = int(self.total_steps * warmup_ratio)
    # Optimizer.
    self.optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
        betas=(0.9, 0.999),
        eps=1e-8,
    )
    # Loss function (optionally class-weighted).
    if loss_weight is not None:
        self.criterion = nn.CrossEntropyLoss(
            weight=loss_weight.to(self.device), label_smoothing=label_smoothing
        )
    else:
        self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    # Gradient scaler for mixed-precision training.
    device_type = "cuda" if "cuda" in str(self.device) else "cpu"
    self.scaler = GradScaler(device_type, enabled=mixed_precision)
    # TensorBoard writer (optional).
    if use_tensorboard:
        self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
    else:
        self.writer = None
    # Rich console for terminal output.
    self.console = Console()
    # Training state.
    self.current_step = 0
    self.current_epoch = 0
    self.best_eval_loss = float("inf")
    # Learning-rate schedule (warmup + cosine annealing).
    self.lr_scheduler = self._create_lr_scheduler()
    logger.info(f"Trainer initialized with device: {self.device}")
    logger.info(
        f"Total steps: {self.total_steps}, Warmup steps: {self.warmup_steps}"
    )
    logger.info(f"Learning rate: {learning_rate}, Weight decay: {weight_decay}")
def _create_lr_scheduler(self) -> Callable[[int], float]:
    """Build the LR schedule function: linear warmup, then cosine annealing.

    Returns:
        A callable mapping a step index to the learning rate for that step.
    """

    def lr_scheduler(step: int) -> float:
        if step < self.warmup_steps:
            # Linear warmup from 0 up to learning_rate.
            return self.learning_rate * (step / self.warmup_steps)
        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps <= 0:
            # Degenerate schedule (warmup covers the whole budget):
            # the original divided by zero here — hold the floor instead.
            return self.min_learning_rate
        # Clamp progress to 1.0 so steps beyond total_steps keep the
        # minimum LR instead of climbing back up the cosine curve.
        progress = min((step - self.warmup_steps) / decay_steps, 1.0)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (
            self.min_learning_rate
            + (self.learning_rate - self.min_learning_rate) * cosine_decay
        )

    return lr_scheduler
def _get_current_lr(self) -> float:
    """Return the learning rate the schedule prescribes for the current step."""
    step = self.current_step
    return self.lr_scheduler(step)
def _update_learning_rate(self):
    """Push the scheduled learning rate into every optimizer param group.

    Returns:
        The learning rate that was applied.
    """
    new_lr = self._get_current_lr()
    for group in self.optimizer.param_groups:
        group["lr"] = new_lr
    return new_lr
def train_step(
    self, batch: Dict[str, torch.Tensor]
) -> Tuple[float, Dict[str, Any]]:
    """
    Run forward + backward for one batch. No optimizer step happens here;
    the outer loop steps the optimizer every grad_accum_steps batches.

    Args:
        batch: Collated batch of tensors (see collate_fn).

    Returns:
        loss: Un-scaled loss value for this batch.
        metrics: Dict with "loss", "lr" and "accuracy".
    """
    self.model.train()
    # Move all inputs to the training device.
    input_ids = batch["input_ids"].to(self.device)
    token_type_ids = batch["token_type_ids"].to(self.device)
    attention_mask = batch["attention_mask"].to(self.device)
    history_slot_ids = batch["history_slot_ids"].to(self.device)
    pinyin_ids = batch["pinyin_ids"].to(self.device)
    labels = batch["labels"].to(self.device).squeeze(-1)  # [batch_size]
    # Mixed-precision forward pass.
    with autocast(device_type=self.device.type, enabled=self.mixed_precision):
        logits = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            pinyin_ids=pinyin_ids,
            history_slot_ids=history_slot_ids,
        )
        # Divide so gradients accumulated over grad_accum_steps batches
        # average out to one full-batch gradient.
        loss = self.criterion(logits, labels)
        loss = loss / self.grad_accum_steps
    # Backward through the GradScaler (outside autocast, per AMP guidance).
    self.scaler.scale(loss).backward()
    metrics = {
        # Multiply back to report the per-batch (un-divided) loss.
        "loss": loss.item() * self.grad_accum_steps,
        "lr": self._get_current_lr(),
    }
    # Batch accuracy: argmax prediction vs. target class.
    with torch.no_grad():
        preds = torch.argmax(logits, dim=-1)
        correct = (preds == labels).sum().item()
        total = labels.size(0)
        metrics["accuracy"] = correct / total if total > 0 else 0.0
    return loss.item() * self.grad_accum_steps, metrics
def evaluate(self) -> Dict[str, float]:
    """
    Run a full pass over the evaluation set.

    Returns:
        Dict with "eval_loss" and "eval_accuracy", or an empty dict when
        no evaluation data is available.
    """
    if self.eval_dataloader is None:
        return {}
    self.model.eval()
    loss_sum = 0.0
    correct_sum = 0
    sample_count = 0
    with torch.no_grad():
        for batch in self.eval_dataloader:
            # Move inputs to the device and run the forward pass.
            targets = batch["labels"].to(self.device).squeeze(-1)
            logits = self.model(
                input_ids=batch["input_ids"].to(self.device),
                token_type_ids=batch["token_type_ids"].to(self.device),
                attention_mask=batch["attention_mask"].to(self.device),
                pinyin_ids=batch["pinyin_ids"].to(self.device),
                history_slot_ids=batch["history_slot_ids"].to(self.device),
            )
            # Accumulate sample-weighted loss and correct-prediction counts.
            batch_count = targets.size(0)
            loss_sum += self.criterion(logits, targets).item() * batch_count
            correct_sum += (torch.argmax(logits, dim=-1) == targets).sum().item()
            sample_count += batch_count
    if sample_count == 0:
        return {"eval_loss": 0.0, "eval_accuracy": 0.0}
    return {
        "eval_loss": loss_sum / sample_count,
        "eval_accuracy": correct_sum / sample_count,
    }
def save_checkpoint(self, filename: str, is_best: bool = False):
    """
    Serialize training state into the checkpoint directory.

    Args:
        filename: Name of the checkpoint file.
        is_best: When True, also write a copy to "best_model.pt".
    """
    state = {
        "step": self.current_step,
        "epoch": self.current_epoch,
        "model_state_dict": self.model.state_dict(),
        "optimizer_state_dict": self.optimizer.state_dict(),
        "scaler_state_dict": self.scaler.state_dict(),
        "best_eval_loss": self.best_eval_loss,
        # A snapshot of the key hyperparameters, for bookkeeping.
        "config": {
            "learning_rate": self.learning_rate,
            "weight_decay": self.weight_decay,
            "warmup_ratio": self.warmup_ratio,
            "label_smoothing": self.label_smoothing,
            "total_steps": self.total_steps,
        },
    }
    target = self.checkpoint_dir / filename
    torch.save(state, target)
    logger.info(f"Checkpoint saved to {target}")
    if is_best:
        best_target = self.checkpoint_dir / "best_model.pt"
        torch.save(state, best_target)
        logger.info(f"Best model saved to {best_target}")
def load_checkpoint(self, checkpoint_path: Union[str, Path]):
    """
    Restore model/optimizer/scaler state and training counters from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint produced by save_checkpoint().
    """
    # NOTE(review): torch.load un-pickles arbitrary objects by default —
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=self.device)
    self.model.load_state_dict(checkpoint["model_state_dict"])
    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
    # Resume the step/epoch counters and best-loss tracking.
    self.current_step = checkpoint["step"]
    self.current_epoch = checkpoint["epoch"]
    self.best_eval_loss = checkpoint["best_eval_loss"]
    logger.info(f"Checkpoint loaded from {checkpoint_path}")
    logger.info(
        f"Resuming from step {self.current_step}, epoch {self.current_epoch}"
    )
def _log_to_tensorboard(self, metrics: Dict[str, float], step: int):
    """Write each metric as a TensorBoard scalar; no-op when logging is off."""
    if self.writer is None:
        return
    for tag, scalar in metrics.items():
        self.writer.add_scalar(tag, scalar, step)
def _create_progress_bar(self) -> Progress:
    """Build the Rich progress bar used by the training loop."""
    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
    )
    return Progress(*columns, console=self.console, expand=True)
def _print_training_info(self):
    """Render the training configuration as a Rich table."""
    table = Table(
        title="Training Configuration",
        show_header=True,
        header_style="bold magenta",
    )
    table.add_column("Parameter", style="cyan")
    table.add_column("Value", style="green")
    # (name, value) pairs, rendered in this order.
    rows = [
        ("Device", str(self.device)),
        ("Total Steps", str(self.total_steps)),
        ("Warmup Steps", str(self.warmup_steps)),
        ("Learning Rate", f"{self.learning_rate:.2e}"),
        ("Min Learning Rate", f"{self.min_learning_rate:.2e}"),
        ("Weight Decay", str(self.weight_decay)),
        ("Label Smoothing", str(self.label_smoothing)),
        ("Gradient Accumulation", str(self.grad_accum_steps)),
        ("Mixed Precision", str(self.mixed_precision)),
    ]
    for parameter, value in rows:
        table.add_row(parameter, value)
    self.console.print(table)
def train(self, resume_from: Optional[str] = None):
    """
    Main training loop.

    Args:
        resume_from: Optional checkpoint path to resume training from.
    """
    # Resume from a checkpoint if one was supplied.
    if resume_from is not None:
        self.load_checkpoint(resume_from)
    # Print the training configuration.
    self._print_training_info()
    # Initialize loop state.
    global_step = self.current_step
    accumulated_loss = 0.0
    accumulated_accuracy = 0.0
    accumulation_counter = 0
    # Create the progress bar.
    with self._create_progress_bar() as progress:
        epoch_task = progress.add_task(
            f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
            total=self.total_steps,
        )
        # Epoch loop.
        for epoch in range(self.current_epoch, self.num_epochs):
            self.current_epoch = epoch
            progress.update(
                epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}"
            )
            for batch_idx, batch in enumerate(self.train_dataloader):
                # Apply the scheduled learning rate for this step.
                current_lr = self._update_learning_rate()
                # Forward/backward for one batch.
                loss, metrics = self.train_step(batch)
                # Accumulate running metrics.
                accumulated_loss += loss
                accumulated_accuracy += metrics.get("accuracy", 0.0)
                accumulation_counter += 1
                # Gradient accumulation: step the optimizer every
                # grad_accum_steps batches.
                if (global_step + 1) % self.grad_accum_steps == 0:
                    # Unscale before clipping so the norm is computed on
                    # real (un-scaled) gradients.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.clip_grad_norm
                    )
                    # Parameter update.
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()
                # Update the progress bar.
                progress.update(
                    epoch_task,
                    advance=1,
                    description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | "
                    f"Step {global_step}/{self.total_steps} | "
                    f"Loss: {loss:.4f} | "
                    f"LR: {current_lr:.2e}",
                )
                # Periodic evaluation and logging.
                if (global_step + 1) % self.eval_frequency == 0:
                    # Average the metrics accumulated since the last log.
                    avg_loss = accumulated_loss / accumulation_counter
                    avg_accuracy = accumulated_accuracy / accumulation_counter
                    # Evaluate on the held-out set.
                    eval_metrics = self.evaluate()
                    # Metrics destined for TensorBoard.
                    log_metrics = {
                        "train/loss": avg_loss,
                        "train/accuracy": avg_accuracy,
                        "train/learning_rate": current_lr,
                    }
                    if eval_metrics:
                        log_metrics.update(
                            {
                                "eval/loss": eval_metrics["eval_loss"],
                                "eval/accuracy": eval_metrics["eval_accuracy"],
                            }
                        )
                        # Track the best model by eval loss.
                        if eval_metrics["eval_loss"] < self.best_eval_loss:
                            self.best_eval_loss = eval_metrics["eval_loss"]
                            self.save_checkpoint(
                                f"step_{global_step + 1}.pt", is_best=True
                            )
                    # Log to TensorBoard.
                    self._log_to_tensorboard(log_metrics, global_step)
                    # Console log line.
                    log_text = (
                        f"[Epoch {epoch + 1}/{self.num_epochs}] "
                        f"[Step {global_step}/{self.total_steps}] "
                        f"Train Loss: {avg_loss:.4f} | "
                        f"Train Acc: {avg_accuracy:.4f} | "
                        f"LR: {current_lr:.2e}"
                    )
                    if eval_metrics:
                        log_text += (
                            f" | Eval Loss: {eval_metrics['eval_loss']:.4f} | "
                            f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
                        )
                    progress.console.log(log_text)
                    # Reset the running metrics.
                    accumulated_loss = 0.0
                    accumulated_accuracy = 0.0
                    accumulation_counter = 0
                # Periodic checkpointing.
                # NOTE(review): the filename uses global_step here but
                # global_step + 1 in the best-model save above — confirm
                # the intended step-naming convention.
                if (global_step + 1) % self.save_frequency == 0:
                    self.save_checkpoint(f"step_{global_step}.pt")
                # Advance the step counters.
                global_step += 1
                self.current_step = global_step
                # Stop once the step budget is exhausted.
                if global_step >= self.total_steps:
                    progress.update(epoch_task, completed=self.total_steps)
                    break
            # Reset the bar for the next epoch.
            progress.reset(epoch_task)
            # Save a checkpoint at the end of every epoch.
            self.save_checkpoint(f"epoch_{epoch + 1}.pt")
            # Stop once the step budget is exhausted.
            if global_step >= self.total_steps:
                break
    # Training finished.
    logger.info("Training completed!")
    # Save the final model.
    self.save_checkpoint("final_model.pt")
    # Close the TensorBoard writer.
    if self.writer is not None:
        self.writer.close()
def worker_init_fn(worker_id: int) -> None:
    """
    Seed numpy's and the stdlib's RNGs inside a DataLoader worker.

    Derives the seed from torch's per-worker initial seed so every worker
    gets a distinct yet reproducible random stream.

    Args:
        worker_id: The worker's ID (unused; seeding comes from torch).
    """
    derived_seed = torch.initial_seed() % 2**32
    np.random.seed(derived_seed)
    random.seed(derived_seed)
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Collate a list of samples into one batch.

    Tensor fields are stacked (per-sample leading dims of size 1 are squeezed
    away first where the dataset emits them); string fields are gathered into
    plain lists.

    Args:
        batch: List of sample dicts.

    Returns:
        Dict with stacked tensor fields and list-valued string fields.
    """

    def stacked(key: str, squeeze: bool = True) -> torch.Tensor:
        # Stack one tensor field across the batch, optionally dropping the
        # extra leading batch dimension each sample carries.
        tensors = [sample[key] for sample in batch]
        if squeeze:
            tensors = [t.squeeze(0) for t in tensors]
        return torch.stack(tensors)

    return {
        "input_ids": stacked("input_ids"),
        "token_type_ids": stacked("token_type_ids"),
        "attention_mask": stacked("attention_mask"),
        "labels": stacked("label"),
        "history_slot_ids": stacked("history_slot_ids", squeeze=False),
        "prefix": [sample["prefix"] for sample in batch],
        "suffix": [sample["suffix"] for sample in batch],
        "pinyin": [sample["pinyin"] for sample in batch],
        "pinyin_ids": stacked("pinyin_ids", squeeze=False),
    }
# Typer CLI application entry point (commands: train / evaluate / export).
app = typer.Typer(help="输入法模型训练命令行工具", add_completion=False)
@app.command()
def train(
    # Data arguments
    train_data_path: str = typer.Option(
        ..., "--train-data-path", "-t", help="训练数据集路径"
    ),
    eval_data_path: str = typer.Option(
        ..., "--eval-data-path", "-e", help="评估数据集路径"
    ),
    output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
    # Model arguments
    vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
    pinyin_vocab_size: int = typer.Option(
        30, "--pinyin-vocab-size", help="拼音词汇表大小"
    ),
    max_iter_length: int = typer.Option(
        1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
    ),
    dim: int = typer.Option(512, "--dim", help="模型维度"),
    num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"),
    n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"),
    n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"),
    num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"),
    max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"),
    use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
    # Training arguments
    batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
    num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
    learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
    min_learning_rate: float = typer.Option(
        1e-9, "--min-learning-rate", help="最小学习率"
    ),
    weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
    warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
    label_smoothing: float = typer.Option(
        0.15, "--label-smoothing", help="标签平滑参数"
    ),
    grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
    clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
    eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
    save_frequency: int = typer.Option(10000, "--save-frequency", help="保存频率"),
    # Misc arguments
    mixed_precision: bool = typer.Option(
        True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
    ),
    use_tensorboard: bool = typer.Option(
        True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
    ),
    resume_from: Optional[str] = typer.Option(
        None, "--resume-from", help="从检查点恢复训练"
    ),
    seed: int = typer.Option(42, "--seed", help="随机种子"),
):
    """
    Train the input-method model.
    """
    torch.multiprocessing.set_sharing_strategy("file_system")
    # Seed all RNGs for reproducibility.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    console = Console()
    # Print the run configuration.
    console.print(
        Panel.fit("[bold cyan]输入法模型训练配置[/bold cyan]", border_style="cyan")
    )
    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")
    # Populate the configuration table.
    config_table.add_row("数据", "训练数据路径", train_data_path)
    config_table.add_row("数据", "评估数据路径", eval_data_path)
    config_table.add_row("数据", "输出目录", output_dir)
    config_table.add_row("数据", "批次大小", str(batch_size))
    config_table.add_row("模型", "词汇表大小", str(vocab_size))
    config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size))
    config_table.add_row("模型", "模型维度", str(dim))
    config_table.add_row("模型", "槽位数量", str(num_slots))
    config_table.add_row("模型", "Transformer层数", str(n_layers))
    config_table.add_row("模型", "注意力头数", str(n_heads))
    config_table.add_row("模型", "MoE专家数", str(num_experts))
    config_table.add_row("模型", "使用拼音", str(use_pinyin))
    config_table.add_row("训练", "训练轮数", str(num_epochs))
    config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
    config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(weight_decay))
    config_table.add_row("训练", "热身比例", str(warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(label_smoothing))
    config_table.add_row("训练", "梯度累积", str(grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(mixed_precision))
    console.print(config_table)
    # Create the output directory.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    # Persist the run configuration as JSON.
    config = {
        "train_data_path": train_data_path,
        "eval_data_path": eval_data_path,
        "output_dir": output_dir,
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "use_pinyin": use_pinyin,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "learning_rate": learning_rate,
        "min_learning_rate": min_learning_rate,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
        "label_smoothing": label_smoothing,
        "grad_accum_steps": grad_accum_steps,
        "clip_grad_norm": clip_grad_norm,
        "eval_frequency": eval_frequency,
        "save_frequency": save_frequency,
        "mixed_precision": mixed_precision,
        "use_tensorboard": use_tensorboard,
        "seed": seed,
        "max_iter_length": max_iter_length,
    }
    config_file = output_path / "training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    logger.info(f"Configuration saved to {config_file}")
    # Build the dataloaders.
    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
    # Training dataset.
    train_dataset = PinyinInputDataset(
        data_path=train_data_path,
        max_workers=-1,  # auto-select the worker count
        max_iter_length=max_iter_length,
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=5000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    # Training dataloader.
    # NOTE: PinyinInputDataset is an IterableDataset, so shuffle= is unavailable;
    # multi-worker sharding is handled inside dataset.__iter__.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=min(max(1, (os.cpu_count() or 1) - 1), 25),
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=64,  # each worker prefetches 64 batches (large-RAM setups)
        persistent_workers=True,  # keep workers alive to avoid respawn overhead
    )
    # Evaluation dataset (same settings, but a much smaller iteration budget).
    eval_dataset = PinyinInputDataset(
        data_path=eval_data_path,
        max_workers=-1,
        max_iter_length=1024,  # small evaluation set
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=1000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        num_workers=1,
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=64,  # each worker prefetches 64 batches
        persistent_workers=True,  # keep workers alive
    )
    console.print("[bold cyan]正在创建模型...[/bold cyan]")
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=n_layers,
        n_heads=n_heads,
        num_experts=num_experts,
        max_seq_len=max_seq_len,
    )
    console.print(
        f"[green]✓ 模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
    )
    # Build the trainer.
    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    # NOTE(review): total_steps assumes one step per batch over the dataset
    # once; it ignores num_epochs and grad_accum_steps — confirm intended.
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=int(max_iter_length / batch_size),
        output_dir=output_dir,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        min_learning_rate=min_learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        label_smoothing=label_smoothing,
        grad_accum_steps=grad_accum_steps,
        clip_grad_norm=clip_grad_norm,
        eval_frequency=eval_frequency,
        save_frequency=save_frequency,
        mixed_precision=mixed_precision,
        use_tensorboard=use_tensorboard,
    )
    console.print("[green]✓ 训练器创建完成[/green]")
    # Start training.
    console.print("\n[bold cyan]开始训练...[/bold cyan]")
    console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    trainer.train(resume_from=resume_from)
    console.print("[bold green]✓ 训练完成![/bold green]")
    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {output_dir}")
@app.command()
def evaluate(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    data_path: str = typer.Option(..., "--data-path", "-d", help="数据集路径"),
    batch_size: int = typer.Option(32, "--batch-size", "-b", help="批次大小"),
):
    """
    Evaluate a trained model (not implemented yet).
    """
    console = Console()
    console.print(f"[bold cyan]评估模型: {checkpoint_path}[/bold cyan]")
    # TODO: implement evaluation:
    # 1. load the checkpoint
    # 2. build the dataloader
    # 3. run evaluation
    console.print("[yellow]评估功能待实现[/yellow]")
@app.command()
def export(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    output_path: str = typer.Option(
        "./exported_model.onnx", "--output", "-o", help="输出路径"
    ),
):
    """
    Export the model to ONNX format (not implemented yet).
    """
    console = Console()
    console.print(f"[bold cyan]导出模型到: {output_path}[/bold cyan]")
    # TODO: implement export:
    # 1. load the checkpoint
    # 2. export to ONNX
    console.print("[yellow]导出功能待实现[/yellow]")
# Run the Typer CLI when executed as a script.
if __name__ == "__main__":
    app()