#!/usr/bin/env python3
"""
Temporary transfer-learning training script: retrain on top of a pretrained
model, optionally freezing ``context_encoder``.

Intended to mirror the ``train`` command of src/model/trainer.py, with two
additions:
- --pretrained-checkpoint: load pretrained weights (required transfer source)
- --freeze-context-encoder: freeze the context_encoder layers (on by default;
  disable with --no-freeze-context-encoder)

Usage:
    python scripts/finetune_slots.py \\
        --pretrained-checkpoint ./output/checkpoints/best_model.pt \\
        --train-data-path /path/to/train_data \\
        --eval-data-path /path/to/eval_data \\
        --output-dir ./finetune_output \\
        --freeze-context-encoder
"""

import argparse
import json
import os
import random
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))

from model.model import InputMethodEngine
from model.trainer import (
    Trainer,
    create_dataloader,
    worker_init_fn,
)
from model.dataset import PinyinInputDataset
from model.preprocessed_dataset import (
    PreProcessedDataset,
    is_preprocessed_data,
)
from loguru import logger
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

# Parameter-name prefix used to select the layers that get frozen.
_CONTEXT_ENCODER_PREFIX = "context_encoder"


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="迁移学习训练:加载预训练模型,冻结指定层后重新训练",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # === Data ===
    parser.add_argument("--train-data-path", "-t", required=True, help="训练数据集路径")
    parser.add_argument("--eval-data-path", "-e", required=True, help="评估数据集路径")
    parser.add_argument("--output-dir", "-o", default="./finetune_output", help="输出目录")
    parser.add_argument("--max-iter-length", type=int, default=1024 * 1024 * 128,
                        help="每个 epoch 最大样本数")

    # === Transfer learning ===
    parser.add_argument("--pretrained-checkpoint", "-c", required=True, help="预训练模型检查点路径")
    # store_true with default=True: freezing is on unless --no-freeze-context-encoder is given.
    parser.add_argument("--freeze-context-encoder", action="store_true", default=True,
                        help="冻结 context_encoder 层 (默认开启)")
    parser.add_argument("--no-freeze-context-encoder", dest="freeze_context_encoder",
                        action="store_false", help="不冻结 context_encoder")

    # === Optimisation ===
    parser.add_argument("--batch-size", "-b", type=int, default=128, help="批次大小")
    parser.add_argument("--num-epochs", type=int, default=10, help="训练轮数")
    parser.add_argument("--learning-rate", "-lr", type=float, default=2e-4, help="学习率")
    parser.add_argument("--min-learning-rate", type=float, default=1e-9, help="最小学习率")
    parser.add_argument("--weight-decay", type=float, default=0.05, help="权重衰减")
    parser.add_argument("--warmup-ratio", type=float, default=0.1, help="热身步数比例")
    parser.add_argument("--label-smoothing", type=float, default=0.1, help="标签平滑参数")
    parser.add_argument("--grad-accum-steps", type=int, default=1, help="梯度累积步数")
    parser.add_argument("--clip-grad-norm", type=float, default=1.0, help="梯度裁剪范数")
    parser.add_argument("--eval-frequency", type=int, default=500, help="评估频率")
    parser.add_argument("--save-frequency", type=int, default=1000, help="保存频率")

    # === Misc ===
    parser.add_argument("--mixed-precision", action="store_true", default=True)
    parser.add_argument("--no-mixed-precision", dest="mixed_precision", action="store_false",
                        help="禁用混合精度")
    parser.add_argument("--num-workers", type=int, default=2, help="数据加载worker数")
    parser.add_argument("--tensorboard", action="store_true", default=True)
    parser.add_argument("--no-tensorboard", dest="tensorboard", action="store_false",
                        help="禁用 TensorBoard")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument("--compile", action="store_true", default=False,
                        help="使用 torch.compile 优化")
    parser.add_argument("--moe-mode", default="all",
                        choices=["all", "sparse", "sparse_allow_graph"], help="MoE 计算策略")
    return parser.parse_args()


def _seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs so runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _build_streaming_dataset(data_path, max_iter_length, max_seq_length):
    """Create a streaming ``PinyinInputDataset`` with the sampling config shared by train and eval.

    A fresh ``length_weights`` dict is built on every call so the two datasets
    never share a mutable object.
    """
    return PinyinInputDataset(
        data_path=data_path,
        max_workers=-1,
        max_iter_length=max_iter_length,
        max_seq_length=max_seq_length,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=2000000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )


def _preview_keys(keys, limit=5):
    """Format at most ``limit`` keys, appending an ellipsis only when the list was truncated."""
    suffix = "..." if len(keys) > limit else ""
    return f"{keys[:limit]}{suffix}"


def _load_pretrained_weights(model, checkpoint_path, console):
    """Load pretrained weights into ``model`` non-strictly and report key mismatches.

    Accepts either a bare state dict or a dict holding it under
    ``"model_state_dict"``.

    Returns:
        The list of model keys that were missing from the checkpoint.
    """
    # Load on CPU: load_state_dict copies into the model's already device-placed
    # tensors, so there is no need to materialise the full checkpoint (which may
    # also carry optimizer state) on the GPU. The checkpoint is released when
    # this helper returns.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if "model_state_dict" in checkpoint:
        pretrained_weights = checkpoint["model_state_dict"]
    else:
        pretrained_weights = checkpoint

    missing_keys, unexpected_keys = model.load_state_dict(pretrained_weights, strict=False)
    if missing_keys:
        console.print(f"[yellow]⚠ 缺失的键 ({len(missing_keys)}): "
                      f"{_preview_keys(missing_keys)}[/yellow]")
    if unexpected_keys:
        console.print(f"[yellow]⚠ 多余的键 ({len(unexpected_keys)}): "
                      f"{_preview_keys(unexpected_keys)}[/yellow]")
    console.print("[green]✓ 预训练权重加载完成[/green]")
    return list(missing_keys)


def _freeze_context_encoder(model, console):
    """Set ``requires_grad=False`` on every parameter whose name starts with ``context_encoder``.

    Warns instead of claiming success when no parameter matched, since that
    would otherwise silently train the full model.
    """
    console.print("[bold cyan]正在冻结 context_encoder...[/bold cyan]")
    frozen_count = 0
    trainable_count = 0
    for name, param in model.named_parameters():
        if name.startswith(_CONTEXT_ENCODER_PREFIX):
            param.requires_grad = False
            frozen_count += param.numel()
        else:
            trainable_count += param.numel()

    if frozen_count == 0:
        console.print("[yellow]⚠ 未找到任何 context_encoder 参数,没有层被冻结[/yellow]")
        logger.warning("No parameter matched prefix "
                       f"'{_CONTEXT_ENCODER_PREFIX}'; the full model will be trained")
        return

    total_params = frozen_count + trainable_count
    console.print("[green]✓ context_encoder 已冻结[/green]")
    logger.info(
        f"冻结参数: {frozen_count:,} / {total_params:,} "
        f"({frozen_count / total_params * 100:.1f}%), "
        f"可训练参数: {trainable_count:,} / {total_params:,} "
        f"({trainable_count / total_params * 100:.1f}%)")


def main():
    """Run fine-tuning end to end.

    Steps: parse the CLI, build the data loaders, build the model and load the
    pretrained weights, optionally freeze context_encoder, save the run config,
    then train with the shared ``Trainer``.
    """
    args = _parse_args()

    # ================================================================
    # Initialisation
    # ================================================================
    torch.multiprocessing.set_sharing_strategy("file_system")
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")
    _seed_everything(args.seed)

    console = Console()

    # Fail fast: checking the checkpoint before the (possibly expensive)
    # dataset setup avoids wasted work on a typo'd path.
    if not Path(args.pretrained_checkpoint).exists():
        console.print(f"[red]❌ 预训练检查点不存在: {args.pretrained_checkpoint}[/red]")
        sys.exit(1)

    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # ================================================================
    # Model constants (must match trainer.py so pretrained weights fit)
    # ================================================================
    vocab_size = 10019
    pinyin_vocab_size = 30
    dim = 512
    num_slots = 8
    n_layers = 4
    n_heads = 4
    num_experts = 10
    max_seq_len = 128

    # ================================================================
    # Print configuration
    # ================================================================
    console.print(Panel.fit("[bold cyan]迁移学习训练配置[/bold cyan]", border_style="cyan"))
    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")
    config_table.add_row("迁移学习", "预训练检查点", args.pretrained_checkpoint)
    config_table.add_row("迁移学习", "冻结 context_encoder", str(args.freeze_context_encoder))
    config_table.add_row("数据", "训练数据路径", args.train_data_path)
    config_table.add_row("数据", "评估数据路径", args.eval_data_path)
    config_table.add_row("数据", "输出目录", args.output_dir)
    config_table.add_row("数据", "批次大小", str(args.batch_size))
    config_table.add_row("数据", "Worker数量", str(args.num_workers))
    config_table.add_row("模型", "MoE策略", args.moe_mode)
    config_table.add_row("模型", "编译优化", str(args.compile))
    config_table.add_row("训练", "训练轮数", str(args.num_epochs))
    config_table.add_row("训练", "学习率", f"{args.learning_rate:.2e}")
    config_table.add_row("训练", "最小学习率", f"{args.min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(args.weight_decay))
    config_table.add_row("训练", "热身比例", str(args.warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(args.label_smoothing))
    config_table.add_row("训练", "梯度累积", str(args.grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(args.clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(args.mixed_precision))

    # ================================================================
    # Data loaders (same logic as the trainer.py CLI)
    # ================================================================
    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
    is_train_preprocessed = is_preprocessed_data(args.train_data_path)
    is_eval_preprocessed = is_preprocessed_data(args.eval_data_path)

    if is_train_preprocessed:
        train_dataset = PreProcessedDataset(args.train_data_path, max_cache_shards=2)
        # Skip loader-side shuffling when the data was shuffled during preprocessing.
        pre_shuffled = train_dataset.metadata.get("pre_shuffled", False)
        shuffle_train = not pre_shuffled
        if args.max_iter_length > 0:
            capped_samples = min(len(train_dataset), args.max_iter_length)
        else:
            capped_samples = len(train_dataset)
        total_steps = (capped_samples // args.batch_size) * args.num_epochs
        train_num_workers = min(args.num_workers, 1)
        logger.info(
            f"Preprocessed dataset: {len(train_dataset):,} samples, "
            f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, "
            f"workers={train_num_workers}, steps={total_steps:,}")
        train_dataloader = create_dataloader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            num_workers=train_num_workers,
            pin_memory=torch.cuda.is_available(),
            shuffle=shuffle_train,
        )
        config_table.add_row("数据", "训练数据类型", "预处理数据")
    else:
        train_dataset = _build_streaming_dataset(
            args.train_data_path, args.max_iter_length, max_seq_len)
        total_steps = int(args.max_iter_length * args.num_epochs / args.batch_size)
        train_dataloader = create_dataloader(
            dataset=train_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=torch.cuda.is_available(),
            max_iter_length=args.max_iter_length,
        )
        config_table.add_row("数据", "训练数据类型", "流式数据")

    if is_eval_preprocessed:
        eval_dataset = PreProcessedDataset(args.eval_data_path, max_cache_shards=1)
        eval_dataloader = create_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
            shuffle=False,
        )
        config_table.add_row("数据", "评估数据类型", "预处理数据")
    else:
        # Streaming eval is capped at 64 batches per evaluation.
        eval_dataset = _build_streaming_dataset(
            args.eval_data_path, args.batch_size * 64, max_seq_len)
        eval_dataloader = create_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            num_workers=2,
            pin_memory=torch.cuda.is_available(),
            max_iter_length=args.batch_size * 64,
        )
        config_table.add_row("数据", "评估数据类型", "流式数据")

    config_table.add_row("数据", "总步数", str(total_steps))
    console.print(config_table)

    # ================================================================
    # Build the model and load pretrained weights
    # ================================================================
    console.print("[bold cyan]正在创建模型并加载预训练权重...[/bold cyan]")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=n_layers,
        n_heads=n_heads,
        num_experts=num_experts,
        max_seq_len=max_seq_len,
        compile=args.compile,
        moe_mode=args.moe_mode,
    )
    model.to(device)

    missing_keys = _load_pretrained_weights(model, args.pretrained_checkpoint, console)

    # ================================================================
    # Freeze context_encoder
    # ================================================================
    if args.freeze_context_encoder:
        # Freezing weights the checkpoint did not provide would lock them at
        # their random initialisation, so call that out explicitly.
        missing_encoder_keys = [
            key for key in missing_keys if key.startswith(_CONTEXT_ENCODER_PREFIX)
        ]
        if missing_encoder_keys:
            console.print(
                f"[yellow]⚠ 检查点缺少 {len(missing_encoder_keys)} 个 context_encoder 参数,"
                f"冻结后这些参数将保持随机初始化[/yellow]")
        _freeze_context_encoder(model, console)
    else:
        logger.info("未冻结任何层,全模型参与训练")

    # ================================================================
    # Save the run configuration
    # ================================================================
    config = {
        "pretrained_checkpoint": args.pretrained_checkpoint,
        "freeze_context_encoder": args.freeze_context_encoder,
        "train_data_path": args.train_data_path,
        "eval_data_path": args.eval_data_path,
        "output_dir": args.output_dir,
        "batch_size": args.batch_size,
        "num_epochs": args.num_epochs,
        "learning_rate": args.learning_rate,
        "min_learning_rate": args.min_learning_rate,
        "weight_decay": args.weight_decay,
        "warmup_ratio": args.warmup_ratio,
        "label_smoothing": args.label_smoothing,
        "grad_accum_steps": args.grad_accum_steps,
        "clip_grad_norm": args.clip_grad_norm,
        "eval_frequency": args.eval_frequency,
        "save_frequency": args.save_frequency,
        "mixed_precision": args.mixed_precision,
        "num_workers": args.num_workers,
        "use_tensorboard": args.tensorboard,
        "seed": args.seed,
        "compile": args.compile,
        "moe_mode": args.moe_mode,
        "total_steps": total_steps,
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "is_train_preprocessed": is_train_preprocessed,
        "is_eval_preprocessed": is_eval_preprocessed,
    }
    config_file = output_path / "training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    logger.info(f"Configuration saved to {config_file}")

    # ================================================================
    # Create the Trainer and train
    # ================================================================
    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=total_steps,
        output_dir=args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        min_learning_rate=args.min_learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        label_smoothing=args.label_smoothing,
        grad_accum_steps=args.grad_accum_steps,
        clip_grad_norm=args.clip_grad_norm,
        eval_frequency=args.eval_frequency,
        save_frequency=args.save_frequency,
        mixed_precision=args.mixed_precision,
        device=device,
        use_tensorboard=args.tensorboard,
        status_file="training_status.json",
    )
    console.print("[green]✓ 训练器创建完成[/green]")

    console.print("\n[bold cyan]开始训练...[/bold cyan]")
    console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    interrupted = False
    try:
        trainer.train(
            resume_from=None,
            reset_training_state=False,
            auto_resume=False,  # transfer learning starts fresh; never auto-resume
        )
    except KeyboardInterrupt:
        interrupted = True
        console.print("[bold green]训练被终止[/bold green]")
        trainer.save_checkpoint("interrupted_model.pt")

    # Only claim completion when training actually ran to the end.
    if not interrupted:
        console.print("[bold green]✓ 训练完成![/bold green]")
    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {args.output_dir}")


if __name__ == "__main__":
    main()