148 lines
5.0 KiB
Python
Executable File
148 lines
5.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
快速检查模型权重加载情况的脚本
|
||
"""
|
||
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import torch
|
||
|
||
|
||
def analyze_checkpoint(checkpoint_path):
    """Analyze a PyTorch checkpoint file and print a diagnostic report.

    Loads the checkpoint on CPU, locates the model state dict (under the
    "model_state_dict" or "state_dict" wrapper keys, or the top-level dict
    itself), prints statistics for classifier-head weights, shows the first
    20 parameter keys, and reports which expected model components are
    present or missing.

    Args:
        checkpoint_path: Path to the checkpoint file (e.g. ``.pt``/``.pth``).

    Returns:
        The extracted state dict on success, otherwise ``None``.
    """
    print(f"🔍 分析checkpoint: {checkpoint_path}")

    if not Path(checkpoint_path).exists():
        print("❌ 文件不存在")
        return None

    try:
        # NOTE(review): torch.load unpickles data; with older torch versions
        # (default weights_only=False) this can execute arbitrary code — only
        # analyze checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print("✅ 加载成功")
        print(f" 类型: {type(checkpoint)}")

        if isinstance(checkpoint, dict):
            print(f" 键名: {list(checkpoint.keys())}")

            # Locate the model state dict under the common wrapper keys.
            if "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
                print(" 🔍 使用'model_state_dict'键")
            elif "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
                print(" 🔍 使用'state_dict'键")
            else:
                # The checkpoint itself may already be a bare state dict.
                state_dict = checkpoint
                print(" 🔍 使用直接状态字典")

            if state_dict:
                print(f" 总权重数: {len(state_dict)}")

                # Summarize classifier-head weights.
                classifier_keys = [k for k in state_dict if "classifier" in k]

                if classifier_keys:
                    print(" 📊 分类头相关权重:")
                    for key in classifier_keys:
                        weight = state_dict[key]
                        print(f" {key}: shape={weight.shape}")
                        print(f" 范围: [{weight.min():.6f}, {weight.max():.6f}]")
                        print(f" 均值: {weight.mean():.6f}")
                        print(f" 标准差: {weight.std():.6f}")

                        # A near-zero std suggests the head is still at (or
                        # collapsed to) its initialization, i.e. untrained.
                        if weight.std() < 0.01:
                            print(" ⚠️ 警告: 权重标准差很小,可能未正确训练")

                # Show a sample of parameter keys for a quick architecture check.
                print("\n 🔑 模型架构键名示例(前20个):")
                for i, key in enumerate(list(state_dict.keys())[:20]):
                    weight = state_dict[key]
                    print(f" {i + 1:2d}. {key:40} shape={str(weight.shape):15}")

                # Verify that the expected model components are present.
                expected_components = [
                    "context_encoder",
                    "slot_memory",
                    "cross_attn",
                    "moe",
                    "classifier",
                ]
                found_components = [
                    comp
                    for comp in expected_components
                    if any(comp in key for key in state_dict)
                ]

                print(f"\n 📋 找到的模型组件: {found_components}")
                missing = set(expected_components) - set(found_components)
                if missing:
                    print(f" ❌ 缺失的组件: {missing}")

                return state_dict
        else:
            print("❌ checkpoint不是字典类型")

    except Exception as e:
        print(f"❌ 加载失败: {e}")
        import traceback

        traceback.print_exc()

    return None
|
||
|
||
|
||
def check_weight_distribution(state_dict):
    """Print distribution statistics for the 2D+ weight matrices in a state dict.

    Only entries whose key contains ``"weight"`` and whose tensor has at
    least two dimensions are considered (biases are skipped). Stats for the
    first ten qualifying tensors are printed; a warning is emitted when a
    tensor's standard deviation is suspiciously small.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        A list of per-tensor stat dicts (key, shape, min, max, mean, std,
        abs_mean) for every qualifying weight matrix.
    """
    print("\n📊 权重分布统计:")

    def _tensor_stats(name, tensor):
        # Collect scalar summary statistics for one weight tensor.
        return {
            "key": name,
            "shape": tensor.shape,
            "min": tensor.min().item(),
            "max": tensor.max().item(),
            "mean": tensor.mean().item(),
            "std": tensor.std().item(),
            "abs_mean": tensor.abs().mean().item(),
        }

    weight_stats = [
        _tensor_stats(name, tensor)
        for name, tensor in state_dict.items()
        if "weight" in name and len(tensor.shape) >= 2
    ]

    # Report only the first ten matrices to keep the output readable.
    for idx, entry in enumerate(weight_stats[:10], start=1):
        print(f" {idx:2d}. {entry['key']:40}")
        print(f" 形状: {entry['shape']}")
        print(f" 范围: [{entry['min']:.6f}, {entry['max']:.6f}]")
        print(f" 均值: {entry['mean']:.6f} ± {entry['std']:.6f}")

        # Near-zero spread often means the layer never moved from init.
        if entry["std"] < 0.01:
            print(" ⚠️ 警告: 标准差很小,可能未训练")

    return weight_stats
|
||
|
||
|
||
def main():
    """CLI entry point: analyze the checkpoint given as the first argument.

    Prints usage and exits when no checkpoint path is supplied; otherwise
    runs the checkpoint analysis and, if a state dict was recovered, the
    weight-distribution report.
    """
    import sys

    args = sys.argv
    if len(args) >= 2:
        state_dict = analyze_checkpoint(args[1])
        if state_dict:
            check_weight_distribution(state_dict)
        return

    print("使用方法: python check_weights.py <checkpoint_path>")
    print("示例: python check_weights.py ./output/checkpoints/best_model.pt")
|
||
|
||
|
||
# Run the CLI entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|