# SUimeModelTraner/transfer_weights.py
# (233 lines, 6.8 KiB, Python — viewer metadata retained from the original paste)
#!/usr/bin/env python3
"""Weight-migration script: transfer old model weights to the new architecture.

Usage:
    python transfer_weights.py --old-checkpoint /path/to/best_model.pt --output ./migrated_model.pt

Steps:
    1. Load the old model checkpoint.
    2. Build the new model architecture.
    3. Directly transfer parameters whose names and shapes match.
    4. Split-transfer k_proj/v_proj into the text/pinyin dual branches.
    5. Save the migrated checkpoint.
    6. Print a detailed migration report.
"""
import argparse
import sys
from pathlib import Path

import torch

# Allow running from the repository root without installing the package.
sys.path.append("src")

from src.model.model import InputMethodEngine
def transfer_weights(
    old_checkpoint_path: str, new_model: "InputMethodEngine", device: torch.device
):
    """Transfer weights from an old checkpoint into the new architecture.

    Parameters whose names and shapes match are copied directly. The new
    dual-branch cross-attention keys (``k_text_proj``/``k_pinyin_proj`` and
    ``v_text_proj``/``v_pinyin_proj``) are both initialized from the old
    shared ``k_proj``/``v_proj`` weights.

    Args:
        old_checkpoint_path: Path to the old checkpoint file.
        new_model: New model instance. Only ``state_dict()`` and
            ``load_state_dict()`` are used, so any ``nn.Module`` is accepted.
        device: Device the loaded tensors are mapped onto.

    Returns:
        Tuple ``(directly_transferred, split_transferred, new_params, skipped)``:
        - directly_transferred: keys copied as-is.
        - split_transferred: ``(new_key, old_source_key)`` pairs.
        - new_params: keys with no usable source in the old checkpoint.
        - skipped: ``(key, reason, detail)`` triples (shape mismatches).
    """
    old_checkpoint = torch.load(old_checkpoint_path, map_location=device)
    # Accept both a wrapped training checkpoint and a bare state dict.
    if "model_state_dict" in old_checkpoint:
        old_state_dict = old_checkpoint["model_state_dict"]
    else:
        old_state_dict = old_checkpoint

    # New split-branch keys and the shared old key each one is seeded from.
    split_sources = {
        "cross_attn.k_text_proj.weight": "cross_attn.k_proj.weight",
        "cross_attn.k_pinyin_proj.weight": "cross_attn.k_proj.weight",
        "cross_attn.v_text_proj.weight": "cross_attn.v_proj.weight",
        "cross_attn.v_pinyin_proj.weight": "cross_attn.v_proj.weight",
    }

    new_state_dict = new_model.state_dict()
    directly_transferred = []
    split_transferred = []
    new_params = []
    skipped = []
    for key in new_state_dict:
        if key in old_state_dict:
            if new_state_dict[key].shape == old_state_dict[key].shape:
                new_state_dict[key] = old_state_dict[key]
                directly_transferred.append(key)
            else:
                skipped.append(
                    (
                        key,
                        "shape mismatch",
                        f"old={old_state_dict[key].shape}, new={new_state_dict[key].shape}",
                    )
                )
        elif key in split_sources and split_sources[key] in old_state_dict:
            old_key = split_sources[key]
            # Clone so both branches get independent copies of the old tensor.
            new_state_dict[key] = old_state_dict[old_key].clone()
            split_transferred.append((key, old_key))
        else:
            # Bug fix: split-branch keys whose old source was missing used to
            # fall through uncounted, skewing the coverage report. They are now
            # reported as newly initialized parameters.
            new_params.append(key)

    new_model.load_state_dict(new_state_dict)
    return directly_transferred, split_transferred, new_params, skipped
def print_report(
directly_transferred, split_transferred, new_params, skipped, output_path
):
"""打印迁移报告"""
total_params = len(directly_transferred) + len(split_transferred) + len(new_params)
coverage = (
(len(directly_transferred) + len(split_transferred)) / total_params * 100
if total_params > 0
else 100
)
print("\n" + "=" * 60)
print("📊 参数迁移报告")
print("=" * 60)
print(f"\n✅ 直接迁移 ({len(directly_transferred)} 层):")
categories = {}
for key in directly_transferred:
category = key.split(".")[0]
if category not in categories:
categories[category] = []
categories[category].append(key)
for cat, keys in sorted(categories.items()):
print(f" - {cat}.* ({len(keys)} 个参数)")
if split_transferred:
print(f"\n✅ 拆分迁移 ({len(split_transferred)} 层):")
for new_key, old_key in split_transferred:
print(f" - {new_key}{old_key}")
if new_params:
print(f"\n⚠️ 新增参数 ({len(new_params)} 层):")
for key in new_params[:10]:
print(f" - {key}")
if len(new_params) > 10:
print(f" ... 等共 {len(new_params)}")
if skipped:
print(f"\n❌ 跳过 ({len(skipped)} 层):")
for key, reason, detail in skipped[:5]:
print(f" - {key}: {reason} ({detail})")
if len(skipped) > 5:
print(f" ... 等共 {len(skipped)}")
print("\n" + "-" * 60)
print(
f"迁移覆盖率: {coverage:.1f}% ({len(directly_transferred) + len(split_transferred)}/{total_params})"
)
print("=" * 60)
print(f"\n💾 已保存到: {output_path}")
def main():
    """CLI entry point: load the old checkpoint, migrate, save, and report."""
    parser = argparse.ArgumentParser(description="迁移旧模型权重到新架构")
    parser.add_argument(
        "--old-checkpoint",
        "-c",
        type=str,
        required=True,
        help="旧模型 checkpoint 路径",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default="./migrated_model.pt",
        help="迁移后的 checkpoint 输出路径",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="设备 (cuda/cpu)",
    )
    args = parser.parse_args()
    device = torch.device(args.device)

    print(f"📦 加载旧模型: {args.old_checkpoint}")
    print(f"🔧 使用设备: {device}")
    old_path = Path(args.old_checkpoint)
    if not old_path.exists():
        print(f"❌ 文件不存在: {args.old_checkpoint}")
        sys.exit(1)

    # Single source of truth for the new architecture's hyperparameters;
    # previously this dict was duplicated in the model constructor and the
    # saved checkpoint "config" and could silently drift apart.
    model_config = {
        "vocab_size": 10019,
        "pinyin_vocab_size": 30,
        "dim": 512,
        "num_slots": 8,
        "n_layers": 4,
        "n_heads": 4,
        "num_experts": 10,
        "max_seq_len": 128,
    }

    print("\n🏗️ 创建新模型架构...")
    new_model = InputMethodEngine(**model_config, compile=False)
    new_model.to(device)
    print(f" 新模型参数量: {sum(p.numel() for p in new_model.parameters()):,}")

    print("\n🔄 开始迁移参数...")
    directly_transferred, split_transferred, new_params, skipped = transfer_weights(
        args.old_checkpoint, new_model, device
    )

    print("\n💾 保存迁移后的 checkpoint...")
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    checkpoint = {
        # Fresh training state: the migrated model starts from step/epoch 0.
        "step": 0,
        "epoch": 0,
        "model_state_dict": new_model.state_dict(),
        "best_eval_loss": float("inf"),
        "config": model_config,
        # Provenance: which checkpoint these weights were migrated from.
        "migration_source": args.old_checkpoint,
    }
    torch.save(checkpoint, output_path)

    print_report(
        directly_transferred, split_transferred, new_params, skipped, output_path
    )
    print("\n✅ 迁移完成!使用方法:")
    print(
        f" python -m model.trainer train --resume-from {args.output} --reset-training-state"
    )
if __name__ == "__main__":
main()