#!/usr/bin/env python3
"""Quick sanity check of the weights stored in a PyTorch checkpoint file.

Usage:
    python check_weights.py <checkpoint_path>
"""

import sys
import traceback
from pathlib import Path

import numpy as np
import torch

# Name fragments looked for in the state-dict keys by the component check.
_EXPECTED_COMPONENTS = (
    "context_encoder",
    "slot_memory",
    "cross_attn",
    "moe",
    "classifier",
)

# A tensor whose standard deviation is below this value triggers a warning.
_LOW_STD_THRESHOLD = 0.01

# Number of state-dict keys listed as a sample of the model architecture.
_KEY_SAMPLE_SIZE = 20

# Number of weight matrices printed by check_weight_distribution.
_STATS_PRINT_LIMIT = 10


def _supports_stats(value):
    """Return True if min/max/mean/std can all be computed on *value*.

    ``mean``/``std`` raise on integer tensors and ``min``/``max`` raise on
    empty tensors, so only non-empty floating-point tensors qualify.
    """
    return (
        isinstance(value, torch.Tensor)
        and value.is_floating_point()
        and value.numel() > 0
    )


def _shape_text(value):
    """Return a printable shape for *value*, or its type name if it has none."""
    shape = getattr(value, "shape", None)
    return type(value).__name__ if shape is None else str(shape)


def _report_classifier_weights(state_dict):
    """Print per-tensor statistics for every entry whose key contains 'classifier'."""
    classifier_keys = [key for key in state_dict if "classifier" in key]
    if not classifier_keys:
        return

    print(" 📊 分类头相关权重:")
    for key in classifier_keys:
        weight = state_dict[key]
        print(f" {key}: shape={_shape_text(weight)}")

        if not _supports_stats(weight):
            # Integer buffers, empty tensors and non-tensor entries have no
            # meaningful statistics; computing them would raise.
            print(" (非浮点或空张量,跳过统计)")
            continue

        std = weight.std().item()
        print(f" 范围: [{weight.min().item():.6f}, {weight.max().item():.6f}]")
        print(f" 均值: {weight.mean().item():.6f}")
        print(f" 标准差: {std:.6f}")

        # A very small spread is reported as a hint that training may not
        # have updated this tensor.
        if std < _LOW_STD_THRESHOLD:
            print(" ⚠️ 警告: 权重标准差很小,可能未正确训练")


def _report_key_samples(state_dict):
    """Print the first few state-dict keys together with their shapes."""
    print("\n 🔑 模型架构键名示例(前20个):")
    sample_keys = list(state_dict.keys())[:_KEY_SAMPLE_SIZE]
    for index, key in enumerate(sample_keys, start=1):
        shape_text = _shape_text(state_dict[key])
        print(f" {index:2d}. {key:40} shape={shape_text:15}")


def _report_components(state_dict):
    """Print which of the expected component names occur in the keys."""
    found_components = [
        comp
        for comp in _EXPECTED_COMPONENTS
        if any(comp in key for key in state_dict)
    ]
    print(f"\n 📋 找到的模型组件: {found_components}")

    missing = set(_EXPECTED_COMPONENTS) - set(found_components)
    if missing:
        print(f" ❌ 缺失的组件: {missing}")


def analyze_checkpoint(checkpoint_path):
    """Load a checkpoint and print a summary of its state dict.

    The state dict is taken from the ``"model_state_dict"`` key, else from
    the ``"state_dict"`` key, else the loaded dict itself is used.

    Args:
        checkpoint_path: Path of the checkpoint file to inspect.

    Returns:
        The state dict, or ``None`` when the file does not exist, loading or
        analysis fails, the checkpoint is not a dict, or the state dict is
        empty.
    """
    print(f"🔍 分析checkpoint: {checkpoint_path}")

    if not Path(checkpoint_path).exists():
        print("❌ 文件不存在")
        return None

    # Best-effort diagnostic: any failure is reported instead of propagated.
    try:
        # NOTE(security): torch.load unpickles the file, so only inspect
        # checkpoints from a trusted source (see the torch.load docs for the
        # ``weights_only`` option).
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print("✅ 加载成功")
        print(f" 类型: {type(checkpoint)}")

        if not isinstance(checkpoint, dict):
            print("❌ checkpoint不是字典类型")
            return None

        print(f" 键名: {list(checkpoint.keys())}")

        # Locate the model state dict.
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
            print(" 🔍 使用'model_state_dict'键")
        elif "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            print(" 🔍 使用'state_dict'键")
        else:
            # The checkpoint may itself be the state dict.
            state_dict = checkpoint
            print(" 🔍 使用直接状态字典")

        if not state_dict:
            return None

        print(f" 总权重数: {len(state_dict)}")
        _report_classifier_weights(state_dict)
        _report_key_samples(state_dict)
        _report_components(state_dict)
        return state_dict
    except Exception as e:
        print(f"❌ 加载失败: {e}")
        traceback.print_exc()
        return None


def check_weight_distribution(state_dict):
    """Collect and print statistics for the weight matrices of a state dict.

    Only entries whose key contains ``"weight"`` and whose value is a
    non-empty floating-point tensor with at least two dimensions are
    included; biases, integer tensors and non-tensor entries are skipped.

    Args:
        state_dict: Mapping from parameter name to value.

    Returns:
        A list with one dict per included tensor, holding the keys ``key``,
        ``shape``, ``min``, ``max``, ``mean``, ``std`` and ``abs_mean``.
        Statistics are printed for the first ten entries only.
    """
    print("\n📊 权重分布统计:")

    weight_stats = []
    for key, weight in state_dict.items():
        if "weight" not in key:
            continue
        # Weight matrices only: skip biases and anything without statistics.
        if not _supports_stats(weight) or weight.dim() < 2:
            continue
        weight_stats.append(
            {
                "key": key,
                "shape": weight.shape,
                "min": weight.min().item(),
                "max": weight.max().item(),
                "mean": weight.mean().item(),
                "std": weight.std().item(),
                "abs_mean": weight.abs().mean().item(),
            }
        )

    for index, stats in enumerate(weight_stats[:_STATS_PRINT_LIMIT], start=1):
        print(f" {index:2d}. {stats['key']:40}")
        print(f" 形状: {stats['shape']}")
        print(f" 范围: [{stats['min']:.6f}, {stats['max']:.6f}]")
        print(f" 均值: {stats['mean']:.6f} ± {stats['std']:.6f}")

        # A very small spread is reported as a hint of an untrained tensor.
        if stats["std"] < _LOW_STD_THRESHOLD:
            print(" ⚠️ 警告: 标准差很小,可能未训练")

    return weight_stats


def main():
    """Analyze the checkpoint whose path is given as the first CLI argument."""
    if len(sys.argv) < 2:
        print("使用方法: python check_weights.py <checkpoint_path>")
        print("示例: python check_weights.py ./output/checkpoints/best_model.pt")
        return

    checkpoint_path = sys.argv[1]
    state_dict = analyze_checkpoint(checkpoint_path)

    if state_dict:
        check_weight_distribution(state_dict)


if __name__ == "__main__":
    main()