SUimeModelTraner/scripts/finetune_slots.py

410 lines
17 KiB
Python

#!/usr/bin/env python3
"""
临时迁移训练脚本:在预训练模型基础上重新训练,支持冻结 context_encoder。
与 src/model/trainer.py 的 train 命令行为完全一致,额外增加:
- --pretrained-checkpoint: 加载预训练权重(必需的迁移学习源)
- --freeze-context-encoder: 冻结 context_encoder 层(默认开启,可用 --no-freeze-context-encoder 关闭)
运行方式:
python scripts/finetune_slots.py \
--pretrained-checkpoint ./output/checkpoints/best_model.pt \
--train-data-path /path/to/train_data \
--eval-data-path /path/to/eval_data \
--output-dir ./finetune_output \
--freeze-context-encoder
"""
import argparse
import json
import os
import random
import sys
from datetime import datetime
from pathlib import Path
import numpy as np
import torch
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from model.model import InputMethodEngine
from model.trainer import (
Trainer,
create_dataloader,
worker_init_fn,
)
from model.dataset import PinyinInputDataset
from model.preprocessed_dataset import (
PreProcessedDataset,
is_preprocessed_data,
)
from loguru import logger
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
def _build_arg_parser():
    """Create the command-line parser for the fine-tuning script.

    The options follow the ``train`` command of ``src/model/trainer.py`` (as
    the module docstring states), plus the transfer-learning switches
    ``--pretrained-checkpoint`` and ``--freeze-context-encoder`` /
    ``--no-freeze-context-encoder``.
    """
    parser = argparse.ArgumentParser(
        description="迁移学习训练:加载预训练模型,冻结指定层后重新训练",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # --- data ---
    parser.add_argument("--train-data-path", "-t", required=True, help="训练数据集路径")
    parser.add_argument("--eval-data-path", "-e", required=True, help="评估数据集路径")
    parser.add_argument("--output-dir", "-o", default="./finetune_output", help="输出目录")
    parser.add_argument("--max-iter-length", type=int,
                        default=1024 * 1024 * 128, help="每个 epoch 最大样本数")
    # --- transfer learning ---
    parser.add_argument("--pretrained-checkpoint", "-c", required=True,
                        help="预训练模型检查点路径")
    # store_true with default=True: the positive flag is a no-op kept for
    # readability; the --no-... flag below is what actually turns it off.
    parser.add_argument("--freeze-context-encoder", action="store_true", default=True,
                        help="冻结 context_encoder 层 (默认开启)")
    parser.add_argument("--no-freeze-context-encoder", dest="freeze_context_encoder",
                        action="store_false",
                        help="不冻结 context_encoder")
    # --- optimisation ---
    parser.add_argument("--batch-size", "-b", type=int, default=128, help="批次大小")
    parser.add_argument("--num-epochs", type=int, default=10, help="训练轮数")
    parser.add_argument("--learning-rate", "-lr", type=float, default=2e-4,
                        help="学习率")
    parser.add_argument("--min-learning-rate", type=float, default=1e-9,
                        help="最小学习率")
    parser.add_argument("--weight-decay", type=float, default=0.05, help="权重衰减")
    parser.add_argument("--warmup-ratio", type=float, default=0.1, help="热身步数比例")
    parser.add_argument("--label-smoothing", type=float, default=0.1,
                        help="标签平滑参数")
    parser.add_argument("--grad-accum-steps", type=int, default=1,
                        help="梯度累积步数")
    parser.add_argument("--clip-grad-norm", type=float, default=1.0,
                        help="梯度裁剪范数")
    parser.add_argument("--eval-frequency", type=int, default=500, help="评估频率")
    parser.add_argument("--save-frequency", type=int, default=1000, help="保存频率")
    # --- misc ---
    parser.add_argument("--mixed-precision", action="store_true", default=True)
    parser.add_argument("--no-mixed-precision", dest="mixed_precision",
                        action="store_false", help="禁用混合精度")
    parser.add_argument("--num-workers", type=int, default=2, help="数据加载worker数")
    parser.add_argument("--tensorboard", action="store_true", default=True)
    parser.add_argument("--no-tensorboard", dest="tensorboard", action="store_false",
                        help="禁用 TensorBoard")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument("--compile", action="store_true", default=False,
                        help="使用 torch.compile 优化")
    parser.add_argument("--moe-mode", default="all",
                        choices=["all", "sparse", "sparse_allow_graph"],
                        help="MoE 计算策略")
    return parser


def _parse_args():
    """Parse the CLI and reject values that cannot produce a valid run.

    A non-positive ``--batch-size`` would raise ZeroDivisionError (or give a
    negative step count) when computing ``total_steps``. A non-positive
    ``--grad-accum-steps`` is presumably meaningless to the Trainer. Both are
    reported through ``parser.error``, which exits with status 2.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.batch_size <= 0:
        parser.error("--batch-size 必须为正整数")
    if args.grad_accum_steps <= 0:
        parser.error("--grad-accum-steps 必须为正整数")
    return args


def _seed_everything(seed):
    """Seed Python ``random``, NumPy and PyTorch (CPU and, if present, CUDA)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _build_config_table(args):
    """Return a rich Table pre-filled with the CLI configuration rows."""
    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("Category", style="cyan")
    table.add_column("Parameter", style="green")
    table.add_column("Value", style="yellow")
    rows = [
        ("迁移学习", "预训练检查点", args.pretrained_checkpoint),
        ("迁移学习", "冻结 context_encoder", str(args.freeze_context_encoder)),
        ("数据", "训练数据路径", args.train_data_path),
        ("数据", "评估数据路径", args.eval_data_path),
        ("数据", "输出目录", args.output_dir),
        ("数据", "批次大小", str(args.batch_size)),
        ("数据", "Worker数量", str(args.num_workers)),
        ("模型", "MoE策略", args.moe_mode),
        ("模型", "编译优化", str(args.compile)),
        ("训练", "训练轮数", str(args.num_epochs)),
        ("训练", "学习率", f"{args.learning_rate:.2e}"),
        ("训练", "最小学习率", f"{args.min_learning_rate:.2e}"),
        ("训练", "权重衰减", str(args.weight_decay)),
        ("训练", "热身比例", str(args.warmup_ratio)),
        ("训练", "标签平滑", str(args.label_smoothing)),
        ("训练", "梯度累积", str(args.grad_accum_steps)),
        ("训练", "梯度裁剪", str(args.clip_grad_norm)),
        ("训练", "混合精度", str(args.mixed_precision)),
    ]
    for category, parameter, value in rows:
        table.add_row(category, parameter, value)
    return table


def _make_streaming_dataset(data_path, max_iter_length, max_seq_len):
    """Build a streaming ``PinyinInputDataset`` with the shared sampling settings.

    Train and eval use identical ``py_style_weight`` / ``length_weights`` /
    shuffle-buffer settings; only the path and the sample cap differ.
    """
    return PinyinInputDataset(
        data_path=data_path,
        max_workers=-1,
        max_iter_length=max_iter_length,
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=2000000,
        # A fresh dict per call, so the two datasets never share mutable state.
        length_weights={1: 10, 2: 50, 3: 50, 4: 40,
                        5: 15, 6: 10, 7: 5, 8: 2},
    )


def _build_train_dataloader(args, max_seq_len, config_table):
    """Create the training dataloader and compute the total step count.

    Returns ``(dataloader, total_steps, is_preprocessed)``. It also appends the
    training-data type row to *config_table*.
    """
    is_preprocessed = is_preprocessed_data(args.train_data_path)
    if is_preprocessed:
        dataset = PreProcessedDataset(args.train_data_path, max_cache_shards=2)
        pre_shuffled = dataset.metadata.get("pre_shuffled", False)
        shuffle = not pre_shuffled
        # For preprocessed data a non-positive --max-iter-length means "no cap".
        if args.max_iter_length > 0:
            capped_samples = min(len(dataset), args.max_iter_length)
        else:
            capped_samples = len(dataset)
        # NOTE(review): the cap only affects total_steps; it is not passed to
        # create_dataloader, so an epoch may still iterate the full dataset.
        total_steps = (capped_samples // args.batch_size) * args.num_epochs
        num_workers = min(args.num_workers, 1)
        logger.info(
            f"Preprocessed dataset: {len(dataset):,} samples, "
            f"shuffle={shuffle}, pre_shuffled={pre_shuffled}, "
            f"workers={num_workers}, steps={total_steps:,}")
        dataloader = create_dataloader(
            dataset=dataset,
            batch_size=args.batch_size,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            shuffle=shuffle,
        )
        config_table.add_row("数据", "训练数据类型", "预处理数据")
    else:
        dataset = _make_streaming_dataset(
            args.train_data_path, args.max_iter_length, max_seq_len)
        total_steps = int(args.max_iter_length *
                          args.num_epochs / args.batch_size)
        dataloader = create_dataloader(
            dataset=dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=torch.cuda.is_available(),
            max_iter_length=args.max_iter_length,
        )
        config_table.add_row("数据", "训练数据类型", "流式数据")
    return dataloader, total_steps, is_preprocessed


def _build_eval_dataloader(args, max_seq_len, config_table):
    """Create the evaluation dataloader.

    Returns ``(dataloader, is_preprocessed)``. It also appends the eval-data
    type row to *config_table*. Streaming eval data is capped at
    ``batch_size * 64`` samples.
    """
    is_preprocessed = is_preprocessed_data(args.eval_data_path)
    if is_preprocessed:
        dataset = PreProcessedDataset(args.eval_data_path, max_cache_shards=1)
        dataloader = create_dataloader(
            dataset=dataset,
            batch_size=args.batch_size,
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
            shuffle=False,
        )
        config_table.add_row("数据", "评估数据类型", "预处理数据")
    else:
        eval_samples = args.batch_size * 64
        dataset = _make_streaming_dataset(
            args.eval_data_path, eval_samples, max_seq_len)
        dataloader = create_dataloader(
            dataset=dataset,
            batch_size=args.batch_size,
            num_workers=2,
            pin_memory=torch.cuda.is_available(),
            max_iter_length=eval_samples,
        )
        config_table.add_row("数据", "评估数据类型", "流式数据")
    return dataloader, is_preprocessed


def _preview_keys(keys, limit=5):
    """Return a short printable preview of *keys*; '...' is added only if truncated."""
    suffix = "..." if len(keys) > limit else ""
    return f"{keys[:limit]}{suffix}"


def _load_pretrained_weights(model, checkpoint_path, console):
    """Load pretrained weights from *checkpoint_path* into *model* (non-strict).

    Accepts either a dict with a ``model_state_dict`` entry or a bare state
    dict. Missing and unexpected keys are reported. If none of the model's
    keys are found in the checkpoint, the process exits with status 1, since
    training would otherwise silently start from random initialisation.
    """
    # Load to CPU first: the checkpoint may carry extra entries (e.g. optimizer
    # state) that need not occupy GPU memory. load_state_dict copies into the
    # model's existing tensors on whatever device they already live on.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        pretrained_weights = checkpoint["model_state_dict"]
    else:
        pretrained_weights = checkpoint
    missing_keys, unexpected_keys = model.load_state_dict(
        pretrained_weights, strict=False)
    # Release the host copy of the checkpoint before the dataloaders and
    # Trainer start allocating.
    del checkpoint, pretrained_weights

    if missing_keys:
        console.print(f"[yellow]⚠ 缺失的键 ({len(missing_keys)}): "
                      f"{_preview_keys(missing_keys)}[/yellow]")
    if unexpected_keys:
        console.print(f"[yellow]⚠ 多余的键 ({len(unexpected_keys)}): "
                      f"{_preview_keys(unexpected_keys)}[/yellow]")
    if len(missing_keys) >= len(model.state_dict()):
        console.print("[red]❌ 预训练检查点中没有任何参数与当前模型匹配,"
                      "请检查检查点格式或模型配置[/red]")
        sys.exit(1)
    console.print("[green]✓ 预训练权重加载完成[/green]")


def _freeze_context_encoder(model, console):
    """Set ``requires_grad=False`` on every parameter of ``model.context_encoder``.

    Only parameters named ``context_encoder.*`` match, so siblings whose names
    merely start with the same prefix (e.g. ``context_encoder_proj``) are not
    frozen by accident. A warning is printed if nothing matched.
    """
    console.print("[bold cyan]正在冻结 context_encoder...[/bold cyan]")
    frozen_count = 0
    trainable_count = 0
    for name, param in model.named_parameters():
        if name == "context_encoder" or name.startswith("context_encoder."):
            param.requires_grad = False
            frozen_count += param.numel()
        else:
            trainable_count += param.numel()
    if frozen_count == 0:
        console.print("[yellow]⚠ 未找到 context_encoder 参数,没有层被冻结[/yellow]")
        return
    # NOTE(review): this only stops gradients. If the Trainer calls
    # model.train(), dropout inside context_encoder stays active; confirm
    # whether that is intended.
    total_params = frozen_count + trainable_count
    console.print("[green]✓ context_encoder 已冻结[/green]")
    logger.info(
        f"冻结参数: {frozen_count:,} / {total_params:,} "
        f"({frozen_count / total_params * 100:.1f}%), "
        f"可训练参数: {trainable_count:,} / {total_params:,} "
        f"({trainable_count / total_params * 100:.1f}%)")


def _save_config(args, output_path, total_steps, model_hparams,
                 is_train_preprocessed, is_eval_preprocessed):
    """Write the full run configuration to ``<output_path>/training_config.json``."""
    config = {
        "pretrained_checkpoint": args.pretrained_checkpoint,
        "freeze_context_encoder": args.freeze_context_encoder,
        "train_data_path": args.train_data_path,
        "eval_data_path": args.eval_data_path,
        "output_dir": args.output_dir,
        "batch_size": args.batch_size,
        "num_epochs": args.num_epochs,
        "learning_rate": args.learning_rate,
        "min_learning_rate": args.min_learning_rate,
        "weight_decay": args.weight_decay,
        "warmup_ratio": args.warmup_ratio,
        "label_smoothing": args.label_smoothing,
        "grad_accum_steps": args.grad_accum_steps,
        "clip_grad_norm": args.clip_grad_norm,
        "eval_frequency": args.eval_frequency,
        "save_frequency": args.save_frequency,
        "mixed_precision": args.mixed_precision,
        "num_workers": args.num_workers,
        "use_tensorboard": args.tensorboard,
        "seed": args.seed,
        "compile": args.compile,
        "moe_mode": args.moe_mode,
        "total_steps": total_steps,
        **model_hparams,
        "is_train_preprocessed": is_train_preprocessed,
        "is_eval_preprocessed": is_eval_preprocessed,
    }
    config_file = output_path / "training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    logger.info(f"Configuration saved to {config_file}")


def main():
    """Fine-tune ``InputMethodEngine`` starting from a pretrained checkpoint.

    Steps:
      1. Parse the arguments and seed the RNGs.
      2. Build the train/eval dataloaders.
      3. Build the model and load the pretrained weights.
      4. Optionally freeze ``context_encoder``.
      5. Save the run configuration to ``<output-dir>/training_config.json``.
      6. Run ``Trainer.train`` with auto-resume disabled.

    Exits with status 1 if the checkpoint file is missing, none of its
    weights match the model, or the computed step count is not positive.
    """
    args = _parse_args()
    console = Console()

    # Fail fast, before the (possibly slow) dataset construction.
    if not Path(args.pretrained_checkpoint).exists():
        console.print(f"[red]❌ 预训练检查点不存在: {args.pretrained_checkpoint}[/red]")
        sys.exit(1)

    torch.multiprocessing.set_sharing_strategy("file_system")
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")
    _seed_everything(args.seed)

    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Model hyper-parameters; the original notes these must match trainer.py.
    model_hparams = {
        "vocab_size": 10019,
        "pinyin_vocab_size": 30,
        "dim": 512,
        "num_slots": 8,
        "n_layers": 4,
        "n_heads": 4,
        "num_experts": 10,
        "max_seq_len": 128,
    }
    max_seq_len = model_hparams["max_seq_len"]

    console.print(Panel.fit(
        "[bold cyan]迁移学习训练配置[/bold cyan]", border_style="cyan"))
    config_table = _build_config_table(args)

    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
    train_dataloader, total_steps, is_train_preprocessed = _build_train_dataloader(
        args, max_seq_len, config_table)
    eval_dataloader, is_eval_preprocessed = _build_eval_dataloader(
        args, max_seq_len, config_table)
    config_table.add_row("数据", "总步数", str(total_steps))
    console.print(config_table)
    if total_steps <= 0:
        console.print(f"[red]❌ 计算得到的总步数为 {total_steps},"
                      f"请检查数据量、--batch-size 与 --max-iter-length[/red]")
        sys.exit(1)

    console.print("[bold cyan]正在创建模型并加载预训练权重...[/bold cyan]")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = InputMethodEngine(
        **model_hparams,
        compile=args.compile,
        moe_mode=args.moe_mode,
    )
    model.to(device)
    _load_pretrained_weights(model, args.pretrained_checkpoint, console)

    # Freeze before the Trainer is built, so the freeze is in place before it
    # sets up its optimizer.
    if args.freeze_context_encoder:
        _freeze_context_encoder(model, console)
    else:
        logger.info("未冻结任何层,全模型参与训练")

    _save_config(args, output_path, total_steps, model_hparams,
                 is_train_preprocessed, is_eval_preprocessed)

    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=total_steps,
        output_dir=args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        min_learning_rate=args.min_learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        label_smoothing=args.label_smoothing,
        grad_accum_steps=args.grad_accum_steps,
        clip_grad_norm=args.clip_grad_norm,
        eval_frequency=args.eval_frequency,
        save_frequency=args.save_frequency,
        mixed_precision=args.mixed_precision,
        device=device,
        use_tensorboard=args.tensorboard,
        status_file="training_status.json",
    )
    console.print("[green]✓ 训练器创建完成[/green]")

    console.print("\n[bold cyan]开始训练...[/bold cyan]")
    console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    try:
        trainer.train(
            resume_from=None,
            reset_training_state=False,
            auto_resume=False,  # transfer learning starts fresh; never auto-resume
        )
    except KeyboardInterrupt:
        console.print("[bold green]训练被终止[/bold green]")
        trainer.save_checkpoint("interrupted_model.pt")
    else:
        # Only claim completion when training actually ran to the end.
        console.print("[bold green]✓ 训练完成![/bold green]")
    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {args.output_dir}")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()