410 lines
17 KiB
Python
410 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
临时迁移训练脚本:在预训练模型基础上重新训练,支持冻结 context_encoder。
|
|
|
|
与 src/model/trainer.py 的 train 命令行为完全一致,额外增加:
|
|
- --pretrained-checkpoint: 加载预训练权重(必需的迁移学习源)
|
|
- --freeze-context-encoder: 冻结 context_encoder 层(默认开启;使用 --no-freeze-context-encoder 关闭)
|
|
|
|
运行方式:
|
|
python scripts/finetune_slots.py \
|
|
--pretrained-checkpoint ./output/checkpoints/best_model.pt \
|
|
--train-data-path /path/to/train_data \
|
|
--eval-data-path /path/to/eval_data \
|
|
--output-dir ./finetune_output \
|
|
--freeze-context-encoder
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import random
|
|
import sys
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
|
|
|
from model.model import InputMethodEngine
|
|
from model.trainer import (
|
|
Trainer,
|
|
create_dataloader,
|
|
worker_init_fn,
|
|
)
|
|
from model.dataset import PinyinInputDataset
|
|
from model.preprocessed_dataset import (
|
|
PreProcessedDataset,
|
|
is_preprocessed_data,
|
|
)
|
|
|
|
from loguru import logger
|
|
from rich.console import Console
|
|
from rich.panel import Panel
|
|
from rich.table import Table
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the fine-tuning script.

    The options mirror the ``train`` command of ``src/model/trainer.py``.
    Two transfer-learning options are added: ``--pretrained-checkpoint``
    and ``--freeze-context-encoder`` / ``--no-freeze-context-encoder``.
    """
    parser = argparse.ArgumentParser(
        description="迁移学习训练:加载预训练模型,冻结指定层后重新训练",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # === Data ===
    parser.add_argument("--train-data-path", "-t", required=True, help="训练数据集路径")
    parser.add_argument("--eval-data-path", "-e", required=True, help="评估数据集路径")
    parser.add_argument("--output-dir", "-o", default="./finetune_output", help="输出目录")
    parser.add_argument("--max-iter-length", type=int,
                        default=1024 * 1024 * 128, help="每个 epoch 最大样本数")

    # === Transfer learning ===
    parser.add_argument("--pretrained-checkpoint", "-c", required=True,
                        help="预训练模型检查点路径")
    # Freezing defaults to on; the paired --no-... flag is the way to disable it.
    parser.add_argument("--freeze-context-encoder", action="store_true", default=True,
                        help="冻结 context_encoder 层 (默认开启)")
    parser.add_argument("--no-freeze-context-encoder", dest="freeze_context_encoder",
                        action="store_false",
                        help="不冻结 context_encoder")

    # === Training ===
    parser.add_argument("--batch-size", "-b", type=int, default=128, help="批次大小")
    parser.add_argument("--num-epochs", type=int, default=10, help="训练轮数")
    parser.add_argument("--learning-rate", "-lr", type=float, default=2e-4,
                        help="学习率")
    parser.add_argument("--min-learning-rate", type=float, default=1e-9,
                        help="最小学习率")
    parser.add_argument("--weight-decay", type=float, default=0.05, help="权重衰减")
    parser.add_argument("--warmup-ratio", type=float, default=0.1, help="热身步数比例")
    parser.add_argument("--label-smoothing", type=float, default=0.1,
                        help="标签平滑参数")
    parser.add_argument("--grad-accum-steps", type=int, default=1,
                        help="梯度累积步数")
    parser.add_argument("--clip-grad-norm", type=float, default=1.0,
                        help="梯度裁剪范数")
    parser.add_argument("--eval-frequency", type=int, default=500, help="评估频率")
    parser.add_argument("--save-frequency", type=int, default=1000, help="保存频率")

    # === Misc ===
    parser.add_argument("--mixed-precision", action="store_true", default=True)
    parser.add_argument("--no-mixed-precision", dest="mixed_precision",
                        action="store_false", help="禁用混合精度")
    parser.add_argument("--num-workers", type=int, default=2, help="数据加载worker数")
    parser.add_argument("--tensorboard", action="store_true", default=True)
    parser.add_argument("--no-tensorboard", dest="tensorboard", action="store_false",
                        help="禁用 TensorBoard")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument("--compile", action="store_true", default=False,
                        help="使用 torch.compile 优化")
    parser.add_argument("--moe-mode", default="all",
                        choices=["all", "sparse", "sparse_allow_graph"],
                        help="MoE 计算策略")
    return parser


def _seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    # np.random.seed only accepts values in [0, 2**32); fold any int into range.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _make_streaming_dataset(data_path, max_iter_length, max_seq_len):
    """Build a streaming ``PinyinInputDataset`` with the trainer.py sampling settings."""
    return PinyinInputDataset(
        data_path=data_path,
        max_workers=-1,
        max_iter_length=max_iter_length,
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=2000000,
        # Build a fresh dict on every call so train/eval datasets never share it.
        length_weights={1: 10, 2: 50, 3: 50, 4: 40,
                        5: 15, 6: 10, 7: 5, 8: 2},
    )


def _build_dataloaders(args, max_seq_len, config_table):
    """Create the train and eval dataloaders (same logic as the trainer.py CLI).

    Each split uses ``PreProcessedDataset`` when ``is_preprocessed_data``
    recognises its path. Otherwise it uses a streaming ``PinyinInputDataset``.
    The detected data type of each split is appended to ``config_table``.

    Returns:
        ``(train_dataloader, eval_dataloader, total_steps,
        is_train_preprocessed, is_eval_preprocessed)``.
    """
    pin_memory = torch.cuda.is_available()
    is_train_preprocessed = is_preprocessed_data(args.train_data_path)
    is_eval_preprocessed = is_preprocessed_data(args.eval_data_path)

    # ---- Training split ----
    if is_train_preprocessed:
        train_dataset = PreProcessedDataset(args.train_data_path,
                                            max_cache_shards=2)
        # Skip runtime shuffling when the metadata says shards were pre-shuffled.
        pre_shuffled = train_dataset.metadata.get("pre_shuffled", False)
        shuffle_train = not pre_shuffled
        # A non-positive --max-iter-length means "use the whole dataset".
        if args.max_iter_length > 0:
            capped_samples = min(len(train_dataset), args.max_iter_length)
        else:
            capped_samples = len(train_dataset)
        total_steps = (capped_samples // args.batch_size) * args.num_epochs
        # Capped at one worker for this path, matching the trainer.py CLI.
        train_num_workers = min(args.num_workers, 1)
        logger.info(
            f"Preprocessed dataset: {len(train_dataset):,} samples, "
            f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, "
            f"workers={train_num_workers}, steps={total_steps:,}")
        train_dataloader = create_dataloader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            num_workers=train_num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle_train,
        )
        config_table.add_row("数据", "训练数据类型", "预处理数据")
    else:
        train_dataset = _make_streaming_dataset(
            args.train_data_path, args.max_iter_length, max_seq_len)
        total_steps = int(args.max_iter_length *
                          args.num_epochs / args.batch_size)
        train_dataloader = create_dataloader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=pin_memory,
            max_iter_length=args.max_iter_length,
        )
        config_table.add_row("数据", "训练数据类型", "流式数据")

    # ---- Evaluation split ----
    if is_eval_preprocessed:
        eval_dataset = PreProcessedDataset(args.eval_data_path,
                                           max_cache_shards=1)
        eval_dataloader = create_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            num_workers=0,
            pin_memory=pin_memory,
            shuffle=False,
        )
        config_table.add_row("数据", "评估数据类型", "预处理数据")
    else:
        # Streaming eval is capped at 64 batches' worth of samples.
        eval_max_iter = args.batch_size * 64
        eval_dataset = _make_streaming_dataset(
            args.eval_data_path, eval_max_iter, max_seq_len)
        eval_dataloader = create_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            num_workers=2,
            pin_memory=pin_memory,
            max_iter_length=eval_max_iter,
        )
        config_table.add_row("数据", "评估数据类型", "流式数据")

    return (train_dataloader, eval_dataloader, total_steps,
            is_train_preprocessed, is_eval_preprocessed)


def _load_pretrained_weights(model, checkpoint_path, console) -> None:
    """Non-strictly load a pretrained checkpoint into ``model`` and report mismatches.

    Accepts either a raw state dict or a dict holding ``"model_state_dict"``.
    If none of the model's state-dict keys is present in the checkpoint, the
    process exits with status 1. Continuing would silently train from a
    random init.
    """
    # Deserialize on CPU: the file may also carry optimizer/scheduler state that
    # would otherwise be placed on the GPU. load_state_dict copies the tensors
    # into the model's parameters on whatever device they already live on.
    # NOTE: torch.load unpickles data; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        pretrained_weights = checkpoint["model_state_dict"]
    else:
        pretrained_weights = checkpoint

    missing_keys, unexpected_keys = model.load_state_dict(
        pretrained_weights, strict=False)
    # Drop the CPU copy (and any optimizer state) before training starts.
    del checkpoint, pretrained_weights

    if missing_keys:
        console.print(f"[yellow]⚠ 缺失的键 ({len(missing_keys)}): "
                      f"{missing_keys[:5]}...[/yellow]")
    if unexpected_keys:
        console.print(f"[yellow]⚠ 多余的键 ({len(unexpected_keys)}): "
                      f"{unexpected_keys[:5]}...[/yellow]")

    # missing_keys is drawn from the model's own state-dict keys, so equality
    # means nothing at all was loaded from the checkpoint.
    num_model_keys = len(model.state_dict())
    if num_model_keys and len(missing_keys) == num_model_keys:
        console.print("[red]❌ 预训练检查点中没有任何键与模型匹配 "
                      "(请检查键名前缀, 例如 torch.compile 产生的 '_orig_mod.')[/red]")
        sys.exit(1)

    console.print("[green]✓ 预训练权重加载完成[/green]")


def _freeze_context_encoder(model, console) -> None:
    """Set ``requires_grad=False`` on every parameter whose name starts with ``context_encoder``.

    Logs frozen vs. trainable parameter counts. If no parameter matched, it
    warns instead of reporting success.
    """
    console.print("[bold cyan]正在冻结 context_encoder...[/bold cyan]")
    frozen_count = 0
    trainable_count = 0
    for name, param in model.named_parameters():
        if name.startswith("context_encoder"):
            param.requires_grad = False
            frozen_count += param.numel()
        else:
            trainable_count += param.numel()

    if frozen_count == 0:
        console.print("[yellow]⚠ 未找到任何 context_encoder 参数, 没有层被冻结[/yellow]")
        return

    total_params = frozen_count + trainable_count
    console.print("[green]✓ context_encoder 已冻结[/green]")
    logger.info(
        f"冻结参数: {frozen_count:,} / {total_params:,} "
        f"({frozen_count / total_params * 100:.1f}%), "
        f"可训练参数: {trainable_count:,} / {total_params:,} "
        f"({trainable_count / total_params * 100:.1f}%)")


def main():
    """Fine-tune a pretrained ``InputMethodEngine``, optionally freezing ``context_encoder``.

    Steps:
      1. Parse CLI arguments and seed the RNGs.
      2. Build the train/eval dataloaders.
      3. Create the model and non-strictly load the pretrained checkpoint.
      4. Optionally freeze ``context_encoder``.
      5. Write ``training_config.json`` to the output directory.
      6. Run ``Trainer.train`` from scratch (auto-resume disabled).
    """
    args = _build_arg_parser().parse_args()

    # ================================================================
    # Initialisation
    # ================================================================
    torch.multiprocessing.set_sharing_strategy("file_system")
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")
    _seed_everything(args.seed)

    console = Console()
    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Fail fast on a bad checkpoint path, before dataloader construction
    # (which may be slow for large datasets).
    if not Path(args.pretrained_checkpoint).is_file():
        console.print(f"[red]❌ 预训练检查点不存在: {args.pretrained_checkpoint}[/red]")
        sys.exit(1)

    # ================================================================
    # Model hyper-parameters (must match trainer.py and the checkpoint)
    # ================================================================
    vocab_size = 10019
    pinyin_vocab_size = 30
    dim = 512
    num_slots = 8
    n_layers = 4
    n_heads = 4
    num_experts = 10
    max_seq_len = 128

    # ================================================================
    # Configuration summary
    # ================================================================
    console.print(Panel.fit(
        "[bold cyan]迁移学习训练配置[/bold cyan]", border_style="cyan"))
    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")
    for row in [
        ("迁移学习", "预训练检查点", args.pretrained_checkpoint),
        ("迁移学习", "冻结 context_encoder", str(args.freeze_context_encoder)),
        ("数据", "训练数据路径", args.train_data_path),
        ("数据", "评估数据路径", args.eval_data_path),
        ("数据", "输出目录", args.output_dir),
        ("数据", "批次大小", str(args.batch_size)),
        ("数据", "Worker数量", str(args.num_workers)),
        ("模型", "MoE策略", args.moe_mode),
        ("模型", "编译优化", str(args.compile)),
        ("训练", "训练轮数", str(args.num_epochs)),
        ("训练", "学习率", f"{args.learning_rate:.2e}"),
        ("训练", "最小学习率", f"{args.min_learning_rate:.2e}"),
        ("训练", "权重衰减", str(args.weight_decay)),
        ("训练", "热身比例", str(args.warmup_ratio)),
        ("训练", "标签平滑", str(args.label_smoothing)),
        ("训练", "梯度累积", str(args.grad_accum_steps)),
        ("训练", "梯度裁剪", str(args.clip_grad_norm)),
        ("训练", "混合精度", str(args.mixed_precision)),
    ]:
        config_table.add_row(*row)

    # ================================================================
    # Dataloaders (same logic as the trainer.py CLI)
    # ================================================================
    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
    (train_dataloader, eval_dataloader, total_steps,
     is_train_preprocessed, is_eval_preprocessed) = _build_dataloaders(
        args, max_seq_len, config_table)

    config_table.add_row("数据", "总步数", str(total_steps))
    console.print(config_table)

    # A zero/negative step count (dataset smaller than one batch, or a
    # non-positive --max-iter-length on streaming data) would break the
    # LR schedule; stop with a clear message instead.
    if total_steps <= 0:
        console.print(f"[red]❌ 总步数为 {total_steps}, 请检查数据集大小、"
                      f"--batch-size 与 --max-iter-length[/red]")
        sys.exit(1)

    # ================================================================
    # Model + pretrained weights
    # ================================================================
    console.print("[bold cyan]正在创建模型并加载预训练权重...[/bold cyan]")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=n_layers,
        n_heads=n_heads,
        num_experts=num_experts,
        max_seq_len=max_seq_len,
        compile=args.compile,
        moe_mode=args.moe_mode,
    )
    model.to(device)
    _load_pretrained_weights(model, args.pretrained_checkpoint, console)

    # ================================================================
    # Optional freezing of context_encoder
    # ================================================================
    if args.freeze_context_encoder:
        _freeze_context_encoder(model, console)
    else:
        logger.info("未冻结任何层,全模型参与训练")

    # ================================================================
    # Persist the run configuration
    # ================================================================
    config = {
        "pretrained_checkpoint": args.pretrained_checkpoint,
        "freeze_context_encoder": args.freeze_context_encoder,
        "train_data_path": args.train_data_path,
        "eval_data_path": args.eval_data_path,
        "output_dir": args.output_dir,
        "batch_size": args.batch_size,
        "num_epochs": args.num_epochs,
        "learning_rate": args.learning_rate,
        "min_learning_rate": args.min_learning_rate,
        "weight_decay": args.weight_decay,
        "warmup_ratio": args.warmup_ratio,
        "label_smoothing": args.label_smoothing,
        "grad_accum_steps": args.grad_accum_steps,
        "clip_grad_norm": args.clip_grad_norm,
        "eval_frequency": args.eval_frequency,
        "save_frequency": args.save_frequency,
        "mixed_precision": args.mixed_precision,
        "num_workers": args.num_workers,
        "use_tensorboard": args.tensorboard,
        "seed": args.seed,
        "compile": args.compile,
        "moe_mode": args.moe_mode,
        "total_steps": total_steps,
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "is_train_preprocessed": is_train_preprocessed,
        "is_eval_preprocessed": is_eval_preprocessed,
    }
    config_file = output_path / "training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    logger.info(f"Configuration saved to {config_file}")

    # ================================================================
    # Trainer + training loop
    # ================================================================
    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=total_steps,
        output_dir=args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        min_learning_rate=args.min_learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        label_smoothing=args.label_smoothing,
        grad_accum_steps=args.grad_accum_steps,
        clip_grad_norm=args.clip_grad_norm,
        eval_frequency=args.eval_frequency,
        save_frequency=args.save_frequency,
        mixed_precision=args.mixed_precision,
        device=device,
        use_tensorboard=args.tensorboard,
        status_file="training_status.json",
    )
    console.print("[green]✓ 训练器创建完成[/green]")

    console.print("\n[bold cyan]开始训练...[/bold cyan]")
    console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    try:
        trainer.train(
            resume_from=None,
            reset_training_state=False,
            auto_resume=False,  # fine-tuning starts fresh; never auto-resume
        )
    except KeyboardInterrupt:
        # Save what we have, but do not report the run as completed.
        console.print("[bold yellow]训练被终止[/bold yellow]")
        trainer.save_checkpoint("interrupted_model.pt")
    else:
        console.print("[bold green]✓ 训练完成![/bold green]")

    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {args.output_dir}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point only (not run on import); main()'s return value
    # (None -> status 0) becomes the process exit status.
    sys.exit(main())
|