#!/usr/bin/env python3
"""Weight-migration script: migrate old model weights to the new architecture.

Usage:
    python transfer_weights.py --old-checkpoint /path/to/best_model.pt --output ./migrated_model.pt

Steps:
    1. Load the old model checkpoint
    2. Build the new model architecture
    3. Directly transfer parameters whose names and shapes match
    4. Split-transfer k_proj/v_proj into the text/pinyin dual branches
    5. Save the migrated checkpoint
    6. Print a detailed migration report
"""

import argparse
import sys
from pathlib import Path

import torch

sys.path.append("src")

from src.model.model import InputMethodEngine

# Mapping from each new dual-branch projection key to the single old
# projection key it is cloned from. Both text and pinyin branches start
# from the same old weights; training then lets them diverge.
_SPLIT_SOURCE = {
    "cross_attn.k_text_proj.weight": "cross_attn.k_proj.weight",
    "cross_attn.k_pinyin_proj.weight": "cross_attn.k_proj.weight",
    "cross_attn.v_text_proj.weight": "cross_attn.v_proj.weight",
    "cross_attn.v_pinyin_proj.weight": "cross_attn.v_proj.weight",
}


def transfer_weights(
    old_checkpoint_path: str,
    new_model: InputMethodEngine,
    device: torch.device,
):
    """Migrate old model weights into the new architecture.

    Args:
        old_checkpoint_path: Path to the old checkpoint file.
        new_model: Instance of the new model; its state dict is updated
            in place via ``load_state_dict``.
        device: Device the checkpoint tensors are mapped onto.

    Returns:
        Tuple ``(directly_transferred, split_transferred, new_params,
        skipped)``:

        * ``directly_transferred`` -- keys copied one-to-one.
        * ``split_transferred`` -- ``(new_key, old_key)`` pairs cloned
          from a single old projection into a dual branch.
        * ``new_params`` -- keys with no usable source in the old model
          (left at their fresh initialization).
        * ``skipped`` -- ``(key, reason, detail)`` for shape mismatches.
    """
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources (weights_only=True would reject
    # the non-tensor metadata stored alongside the weights).
    old_checkpoint = torch.load(old_checkpoint_path, map_location=device)

    # A full training checkpoint nests the weights under
    # "model_state_dict"; a bare state dict is used as-is.
    if "model_state_dict" in old_checkpoint:
        old_state_dict = old_checkpoint["model_state_dict"]
    else:
        old_state_dict = old_checkpoint

    new_state_dict = new_model.state_dict()

    directly_transferred = []
    split_transferred = []
    new_params = []
    skipped = []

    for key, new_tensor in new_state_dict.items():
        if key in old_state_dict:
            old_tensor = old_state_dict[key]
            if new_tensor.shape == old_tensor.shape:
                new_state_dict[key] = old_tensor
                directly_transferred.append(key)
            else:
                skipped.append(
                    (
                        key,
                        "shape mismatch",
                        f"old={old_tensor.shape}, new={new_tensor.shape}",
                    )
                )
        elif key in _SPLIT_SOURCE:
            source_key = _SPLIT_SOURCE[key]
            if source_key in old_state_dict:
                # Clone so the two branches do not share storage.
                new_state_dict[key] = old_state_dict[source_key].clone()
                split_transferred.append((key, source_key))
            else:
                # BUGFIX: previously a split key whose source was missing
                # from the old checkpoint vanished from the report; count
                # it as a freshly initialized parameter instead.
                new_params.append(key)
        else:
            new_params.append(key)

    new_model.load_state_dict(new_state_dict)
    return directly_transferred, split_transferred, new_params, skipped


def print_report(
    directly_transferred, split_transferred, new_params, skipped, output_path
):
    """Print a human-readable migration report to stdout."""
    total_params = len(directly_transferred) + len(split_transferred) + len(new_params)
    # Fraction of new-model parameters that received old weights;
    # degenerate empty model counts as fully covered.
    coverage = (
        (len(directly_transferred) + len(split_transferred)) / total_params * 100
        if total_params > 0
        else 100
    )

    print("\n" + "=" * 60)
    print("📊 参数迁移报告")
    print("=" * 60)

    print(f"\n✅ 直接迁移 ({len(directly_transferred)} 层):")
    # Group transferred keys by their top-level module prefix for a
    # compact summary instead of one line per tensor.
    categories = {}
    for key in directly_transferred:
        category = key.split(".")[0]
        categories.setdefault(category, []).append(key)
    for cat, keys in sorted(categories.items()):
        print(f" - {cat}.* ({len(keys)} 个参数)")

    if split_transferred:
        print(f"\n✅ 拆分迁移 ({len(split_transferred)} 层):")
        for new_key, old_key in split_transferred:
            print(f" - {new_key} ← {old_key}")

    if new_params:
        print(f"\n⚠️ 新增参数 ({len(new_params)} 层):")
        for key in new_params[:10]:
            print(f" - {key}")
        if len(new_params) > 10:
            print(f" ... 等共 {len(new_params)} 个")

    if skipped:
        print(f"\n❌ 跳过 ({len(skipped)} 层):")
        for key, reason, detail in skipped[:5]:
            print(f" - {key}: {reason} ({detail})")
        if len(skipped) > 5:
            print(f" ... 等共 {len(skipped)} 个")

    print("\n" + "-" * 60)
    print(
        f"迁移覆盖率: {coverage:.1f}% ({len(directly_transferred) + len(split_transferred)}/{total_params})"
    )
    print("=" * 60)
    print(f"\n💾 已保存到: {output_path}")


def main():
    """CLI entry point: parse args, migrate weights, save and report."""
    parser = argparse.ArgumentParser(description="迁移旧模型权重到新架构")
    parser.add_argument(
        "--old-checkpoint",
        "-c",
        type=str,
        required=True,
        help="旧模型 checkpoint 路径",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default="./migrated_model.pt",
        help="迁移后的 checkpoint 输出路径",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="设备 (cuda/cpu)",
    )
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"📦 加载旧模型: {args.old_checkpoint}")
    print(f"🔧 使用设备: {device}")

    old_path = Path(args.old_checkpoint)
    if not old_path.exists():
        print(f"❌ 文件不存在: {args.old_checkpoint}")
        sys.exit(1)

    print("\n🏗️ 创建新模型架构...")
    # NOTE(review): hyperparameters are hard-coded here and duplicated in
    # the saved "config" below -- presumably they match the training
    # config; verify against the trainer before migrating.
    new_model = InputMethodEngine(
        vocab_size=10019,
        pinyin_vocab_size=30,
        dim=512,
        num_slots=8,
        n_layers=4,
        n_heads=4,
        num_experts=10,
        max_seq_len=128,
        compile=False,
    )
    new_model.to(device)
    print(f" 新模型参数量: {sum(p.numel() for p in new_model.parameters()):,}")

    print("\n🔄 开始迁移参数...")
    directly_transferred, split_transferred, new_params, skipped = transfer_weights(
        args.old_checkpoint, new_model, device
    )

    print("\n💾 保存迁移后的 checkpoint...")
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Fresh training bookkeeping (step/epoch/best loss reset) plus the
    # architecture config and provenance of the migration.
    checkpoint = {
        "step": 0,
        "epoch": 0,
        "model_state_dict": new_model.state_dict(),
        "best_eval_loss": float("inf"),
        "config": {
            "vocab_size": 10019,
            "pinyin_vocab_size": 30,
            "dim": 512,
            "num_slots": 8,
            "n_layers": 4,
            "n_heads": 4,
            "num_experts": 10,
            "max_seq_len": 128,
        },
        "migration_source": args.old_checkpoint,
    }
    torch.save(checkpoint, output_path)

    print_report(
        directly_transferred, split_transferred, new_params, skipped, output_path
    )

    print("\n✅ 迁移完成!使用方法:")
    print(
        f" python -m model.trainer train --resume-from {args.output} --reset-training-state"
    )


if __name__ == "__main__":
    main()