#!/usr/bin/env python3
"""Weight-migration script: transfer old model weights to the new architecture.

Usage:
    python transfer_weights.py --old-checkpoint /path/to/best_model.pt --output ./migrated_model.pt

Steps:
    1. Load the old model checkpoint.
    2. Build the new model architecture.
    3. Directly transfer parameters whose names and shapes match.
    4. Split-transfer k_proj/v_proj into the text/pinyin dual branches.
    5. Save the migrated checkpoint.
    6. Print a detailed migration report.
"""

import argparse
import sys
from pathlib import Path

import torch

sys.path.append("src")

from src.model.model import InputMethodEngine


def transfer_weights(
    old_checkpoint_path: str, new_model: "InputMethodEngine", device: torch.device
):
    """Migrate old-model weights into a new-architecture model.

    Parameters are matched by name and shape. The old single-branch
    cross-attention projections (``k_proj`` / ``v_proj``) are additionally
    copied into both of the new text/pinyin branches.

    Args:
        old_checkpoint_path: Path to the old checkpoint file. May contain a
            full training checkpoint (with a ``model_state_dict`` entry) or a
            bare state dict.
        new_model: Freshly constructed new-architecture model; its weights
            are updated in place via ``load_state_dict``.
        device: Device the checkpoint tensors are mapped to.

    Returns:
        Tuple ``(directly_transferred, split_transferred, new_params, skipped)``:
        - directly_transferred: keys copied as-is (same name, same shape).
        - split_transferred: ``(new_key, old_key)`` pairs seeded from a
          split source.
        - new_params: keys that keep their fresh initialization.
        - skipped: ``(key, reason, detail)`` triples left untouched.
    """
    # NOTE(review): torch.load with default settings can unpickle arbitrary
    # objects on older torch versions — only load trusted checkpoints.
    old_checkpoint = torch.load(old_checkpoint_path, map_location=device)

    # Accept both a full training checkpoint and a bare state dict.
    if "model_state_dict" in old_checkpoint:
        old_state_dict = old_checkpoint["model_state_dict"]
    else:
        old_state_dict = old_checkpoint

    # New dual-branch keys mapped to the old single-branch key each is
    # seeded from.
    split_sources = {
        "cross_attn.k_text_proj.weight": "cross_attn.k_proj.weight",
        "cross_attn.k_pinyin_proj.weight": "cross_attn.k_proj.weight",
        "cross_attn.v_text_proj.weight": "cross_attn.v_proj.weight",
        "cross_attn.v_pinyin_proj.weight": "cross_attn.v_proj.weight",
    }

    new_state_dict = new_model.state_dict()

    directly_transferred = []
    split_transferred = []
    new_params = []
    skipped = []

    for key in new_state_dict.keys():
        if key in old_state_dict:
            if new_state_dict[key].shape == old_state_dict[key].shape:
                new_state_dict[key] = old_state_dict[key]
                directly_transferred.append(key)
            else:
                skipped.append(
                    (
                        key,
                        "shape mismatch",
                        f"old={old_state_dict[key].shape}, new={new_state_dict[key].shape}",
                    )
                )
        elif key in split_sources:
            old_key = split_sources[key]
            if old_key in old_state_dict:
                # clone() so the text and pinyin branches do not share storage.
                new_state_dict[key] = old_state_dict[old_key].clone()
                split_transferred.append((key, old_key))
            else:
                # BUGFIX: previously a split key whose source was missing
                # vanished from the report entirely; the fresh initialization
                # is kept, so record it as a new parameter.
                new_params.append(key)
        else:
            new_params.append(key)

    new_model.load_state_dict(new_state_dict)

    return directly_transferred, split_transferred, new_params, skipped


def print_report(
    directly_transferred, split_transferred, new_params, skipped, output_path
):
    """Print a human-readable summary of the weight migration.

    Args:
        directly_transferred: keys copied as-is.
        split_transferred: ``(new_key, old_key)`` split-migration pairs.
        new_params: keys that kept their fresh initialization.
        skipped: ``(key, reason, detail)`` triples that were not migrated.
        output_path: where the migrated checkpoint was written.
    """
    migrated = len(directly_transferred) + len(split_transferred)
    total_params = migrated + len(new_params)
    # Coverage counts migrated keys over all keys of the new model
    # (skipped keys are not in the denominator).
    coverage = migrated / total_params * 100 if total_params > 0 else 100

    rule = "=" * 60
    print("\n" + rule)
    print("📊 参数迁移报告")
    print(rule)

    print(f"\n✅ 直接迁移 ({len(directly_transferred)} 层):")
    # Group directly transferred keys by their top-level module name.
    by_module = {}
    for param_key in directly_transferred:
        by_module.setdefault(param_key.split(".")[0], []).append(param_key)
    for module_name, member_keys in sorted(by_module.items()):
        print(f" - {module_name}.* ({len(member_keys)} 个参数)")

    if split_transferred:
        print(f"\n✅ 拆分迁移 ({len(split_transferred)} 层):")
        for new_key, old_key in split_transferred:
            print(f" - {new_key} ← {old_key}")

    if new_params:
        print(f"\n⚠️ 新增参数 ({len(new_params)} 层):")
        for param_key in new_params[:10]:
            print(f" - {param_key}")
        if len(new_params) > 10:
            print(f" ... 等共 {len(new_params)} 个")

    if skipped:
        print(f"\n❌ 跳过 ({len(skipped)} 层):")
        for param_key, reason, detail in skipped[:5]:
            print(f" - {param_key}: {reason} ({detail})")
        if len(skipped) > 5:
            print(f" ... 等共 {len(skipped)} 个")

    print("\n" + "-" * 60)
    print(f"迁移覆盖率: {coverage:.1f}% ({migrated}/{total_params})")
    print(rule)
    print(f"\n💾 已保存到: {output_path}")


def main():
    """CLI entry point: parse args, migrate the checkpoint, save and report."""
    parser = argparse.ArgumentParser(description="迁移旧模型权重到新架构")
    # Table-driven argument registration keeps flags and options side by side.
    for flags, options in (
        (
            ("--old-checkpoint", "-c"),
            dict(type=str, required=True, help="旧模型 checkpoint 路径"),
        ),
        (
            ("--output", "-o"),
            dict(
                type=str,
                default="./migrated_model.pt",
                help="迁移后的 checkpoint 输出路径",
            ),
        ),
        (
            ("--device",),
            dict(
                type=str,
                default="cuda" if torch.cuda.is_available() else "cpu",
                help="设备 (cuda/cpu)",
            ),
        ),
    ):
        parser.add_argument(*flags, **options)

    args = parser.parse_args()
    device = torch.device(args.device)

    print(f"📦 加载旧模型: {args.old_checkpoint}")
    print(f"🔧 使用设备: {device}")

    if not Path(args.old_checkpoint).exists():
        print(f"❌ 文件不存在: {args.old_checkpoint}")
        sys.exit(1)

    print("\n🏗️ 创建新模型架构...")
    # Single source of truth for the architecture hyperparameters; the same
    # dict is persisted into the checkpoint below.
    arch = {
        "vocab_size": 10019,
        "pinyin_vocab_size": 30,
        "dim": 512,
        "num_slots": 8,
        "n_layers": 4,
        "n_heads": 4,
        "num_experts": 10,
        "max_seq_len": 128,
    }
    new_model = InputMethodEngine(**arch, compile=False)
    new_model.to(device)

    param_count = sum(p.numel() for p in new_model.parameters())
    print(f" 新模型参数量: {param_count:,}")

    print("\n🔄 开始迁移参数...")
    directly_transferred, split_transferred, new_params, skipped = transfer_weights(
        args.old_checkpoint, new_model, device
    )

    print("\n💾 保存迁移后的 checkpoint...")
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Fresh training bookkeeping (step/epoch/best loss reset) plus the
    # migrated weights and the source checkpoint for provenance.
    torch.save(
        {
            "step": 0,
            "epoch": 0,
            "model_state_dict": new_model.state_dict(),
            "best_eval_loss": float("inf"),
            "config": arch,
            "migration_source": args.old_checkpoint,
        },
        output_path,
    )

    print_report(
        directly_transferred, split_transferred, new_params, skipped, output_path
    )

    print("\n✅ 迁移完成!使用方法:")
    print(
        f" python -m model.trainer train --resume-from {args.output} --reset-training-state"
    )


# Script entry point.
if __name__ == "__main__":
    main()