SUimeModelTraner/check_weights.py

148 lines
5.0 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
快速检查模型权重加载情况的脚本
"""
from pathlib import Path
import numpy as np
import torch
# Component-name substrings looked for by default when no explicit list is given.
_DEFAULT_EXPECTED_COMPONENTS = (
    "context_encoder",
    "slot_memory",
    "cross_attn",
    "moe",
    "classifier",
)


def _select_state_dict(checkpoint):
    """Pick the parameter mapping out of a loaded checkpoint dict.

    Returns a ``(state_dict, label)`` pair where ``label`` describes where the
    mapping came from (used in the printed report).
    """
    # Common wrapper keys, in priority order; only dict values qualify so a
    # stray non-mapping entry under one of these names is not mistaken for weights.
    for wrapper_key in ("model_state_dict", "state_dict", "model"):
        candidate = checkpoint.get(wrapper_key)
        if isinstance(candidate, dict):
            return candidate, f"'{wrapper_key}'"
    # No wrapper found: treat the checkpoint itself as a bare state dict.
    return checkpoint, "直接状态字典"


def _describe_value(value):
    """Return the shape of a tensor entry, or the type name of any other value."""
    if isinstance(value, torch.Tensor):
        return str(value.shape)
    return type(value).__name__


def _report_classifier_weights(state_dict):
    """Print statistics for every entry whose key contains "classifier"."""
    classifier_keys = [key for key in state_dict if "classifier" in key]
    if not classifier_keys:
        return
    print(" 📊 分类头相关权重:")
    for key in classifier_keys:
        value = state_dict[key]
        if not isinstance(value, torch.Tensor):
            print(f" {key}: type={type(value).__name__}")
            continue
        print(f" {key}: shape={value.shape}")
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) have no
        # mean()/std(), and min()/max() raise on empty tensors: shape only.
        if not value.is_floating_point() or value.numel() == 0:
            continue
        std = value.std().item()
        print(f" 范围: [{value.min().item():.6f}, {value.max().item():.6f}]")
        print(f" 均值: {value.mean().item():.6f}")
        print(f" 标准差: {std:.6f}")
        # Heuristic: a very small spread suggests the head may not have been trained.
        if std < 0.01:
            print(" ⚠️ 警告: 权重标准差很小,可能未正确训练")


def _report_key_sample(state_dict):
    """Print the first 20 keys with their tensor shapes (or value types)."""
    print("\n 🔑 模型架构键名示例前20个:")
    for index, key in enumerate(list(state_dict)[:20], start=1):
        description = _describe_value(state_dict[key])
        print(f" {index:2d}. {key:40} shape={description:15}")


def _report_components(state_dict, expected_components):
    """Print which expected component names occur as substrings of any key."""
    found = [
        comp for comp in expected_components
        if any(comp in key for key in state_dict)
    ]
    print(f"\n 📋 找到的模型组件: {found}")
    # A list (not a set) keeps the report order stable across runs.
    missing = [comp for comp in expected_components if comp not in found]
    if missing:
        print(f" ❌ 缺失的组件: {missing}")


def analyze_checkpoint(checkpoint_path, *, expected_components=None, weights_only=None):
    """Load a checkpoint file and print a diagnostic report of its weights.

    The report covers: the checkpoint's top-level keys, which mapping is used as
    the state dict, statistics for "classifier" entries, the first 20 keys with
    shapes, and which expected component names appear in the keys.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load``.
        expected_components: Substrings to look for in the state-dict keys;
            defaults to ``_DEFAULT_EXPECTED_COMPONENTS``.
        weights_only: Forwarded to ``torch.load`` when not ``None``; otherwise
            torch's own default applies. ``False`` is needed to load whole-model
            pickles or checkpoints holding arbitrary Python objects.

    Returns:
        The selected state dict (parameter name -> value) when it is non-empty,
        otherwise ``None`` (file missing, load/analysis error, unsupported
        checkpoint type, or empty state dict). Problems are reported on stdout
        rather than raised.
    """
    if expected_components is None:
        expected_components = _DEFAULT_EXPECTED_COMPONENTS
    print(f"🔍 分析checkpoint: {checkpoint_path}")
    if not Path(checkpoint_path).exists():
        print("❌ 文件不存在")
        return None
    load_kwargs = {"map_location": "cpu"}
    if weights_only is not None:
        load_kwargs["weights_only"] = weights_only
    try:
        # SECURITY NOTE: torch.load unpickles the file (fully so when
        # weights_only=False); only inspect checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, **load_kwargs)
        print("✅ 加载成功")
        print(f" 类型: {type(checkpoint)}")
        if isinstance(checkpoint, torch.nn.Module):
            # Whole-model pickle (torch.save(model)): read its parameters directly.
            state_dict, source = checkpoint.state_dict(), "nn.Module.state_dict()"
        elif isinstance(checkpoint, dict):
            print(f" 键名: {list(checkpoint.keys())}")
            state_dict, source = _select_state_dict(checkpoint)
        else:
            print("❌ checkpoint不是字典类型")
            return None
        print(f" 🔍 使用{source}")
        if not state_dict:
            print(" ❌ 状态字典为空")
            return None
        print(f" 总权重数: {len(state_dict)}")
        _report_classifier_weights(state_dict)
        _report_key_sample(state_dict)
        _report_components(state_dict, expected_components)
        return state_dict
    except Exception as e:
        # Top-level boundary of a diagnostic tool: report and continue.
        print(f"❌ 加载失败: {e}")
        import traceback
        traceback.print_exc()
        return None
def check_weight_distribution(state_dict):
    """Compute and print summary statistics for the weight matrices of a state dict.

    An entry is included only when its key contains "weight" and its value is a
    non-empty floating-point tensor with at least two dimensions. 1-D tensors
    such as biases are excluded. Statistics for the first 10 included entries
    are printed.

    Args:
        state_dict: Mapping of parameter name to value (e.g. the result of
            ``analyze_checkpoint``). Non-tensor values and integer/quantized
            tensors are skipped instead of raising.

    Returns:
        A list of dicts with keys ``key``, ``shape``, ``min``, ``max``, ``mean``,
        ``std`` and ``abs_mean`` (Python floats), in state-dict order.
    """
    print("\n📊 权重分布统计:")
    weight_stats = []
    for key, weight in state_dict.items():
        if "weight" not in key:
            continue
        # Metadata such as a float "loss_weight" has no .shape, mean()/std()
        # raise on integer dtypes (e.g. packed quantized weights), and min()
        # raises on empty tensors -- skip all of these.
        if not isinstance(weight, torch.Tensor) or not weight.is_floating_point():
            continue
        if weight.dim() < 2 or weight.numel() == 0:
            continue
        weight_stats.append(
            {
                "key": key,
                "shape": weight.shape,
                "min": weight.min().item(),
                "max": weight.max().item(),
                "mean": weight.mean().item(),
                "std": weight.std().item(),
                "abs_mean": weight.abs().mean().item(),
            }
        )
    for i, stats in enumerate(weight_stats[:10]):
        print(f" {i + 1:2d}. {stats['key']:40}")
        print(f" 形状: {stats['shape']}")
        print(f" 范围: [{stats['min']:.6f}, {stats['max']:.6f}]")
        print(f" 均值: {stats['mean']:.6f} ± {stats['std']:.6f}")
        # Heuristic: a very small spread suggests the weights may be untrained.
        if stats["std"] < 0.01:
            print(" ⚠️ 警告: 标准差很小,可能未训练")
    return weight_stats
def main(argv=None):
    """Command-line entry point: analyze the checkpoint given as the first argument.

    Args:
        argv: Arguments excluding the program name. Defaults to
            ``sys.argv[1:]``; accepting it explicitly lets the function be
            driven without a shell.

    Exits with status 1 (via ``SystemExit``) when no checkpoint path is given
    or when ``analyze_checkpoint`` yields no state dict, so shell scripts and
    CI can detect the failure.
    """
    import sys

    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        print("使用方法: python check_weights.py <checkpoint_path>")
        print("示例: python check_weights.py ./output/checkpoints/best_model.pt")
        sys.exit(1)
    state_dict = analyze_checkpoint(args[0])
    # A falsy result means nothing usable was found; analyze_checkpoint
    # reports its own diagnostics on stdout.
    if not state_dict:
        sys.exit(1)
    check_weight_distribution(state_dict)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()