feat(checkpoint): optimize the checkpoint saving strategy, keeping the last 3 epochs and saving disk space
parent 1cdef19153
commit 0fea985b45

@@ -0,0 +1,187 @@
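# Checkpoint Saving Strategy Update

## Overview

To work within limited disk space, the way model checkpoints are saved has changed. Epoch checkpoints are now written to a fixed set of file names that are **overwritten in place**, so only the checkpoints of the **last 3 epochs** are kept.
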
## Main Changes

### 1. Saving Strategy

#### Files kept
- ✅ `best_model.pt` - the best model (saved whenever eval_loss decreases)
- ✅ `latest_checkpoint.pt` - the periodic checkpoint (overwritten every `save_frequency` steps)
- ✅ `epoch_checkpoint_1.pt` - one of the 3 rotating files holding the last 3 epochs
- ✅ `epoch_checkpoint_2.pt` - one of the 3 rotating files holding the last 3 epochs
- ✅ `epoch_checkpoint_3.pt` - one of the 3 rotating files holding the last 3 epochs
- ✅ `epoch_checkpoints.json` - metadata file for the epoch checkpoints

#### Files removed
- ❌ `final_model.pt` - duplicated epoch_last, so it was removed
- ❌ `epoch_*.pt` - a separate file is no longer created for every epoch

### 2. Rotating Overwrite Mechanism

Three fixed file names are reused in rotation:
```
Epoch 1 -> epoch_checkpoint_1.pt
Epoch 2 -> epoch_checkpoint_2.pt
Epoch 3 -> epoch_checkpoint_3.pt
Epoch 4 -> epoch_checkpoint_1.pt (overwrites epoch 1)
Epoch 5 -> epoch_checkpoint_2.pt (overwrites epoch 2)
...
```

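The rotation can be written as a one-line mapping from the epoch number to a file name. This is only an illustrative sketch: `slot_filename` is a hypothetical helper, and it assumes epochs are numbered from 1 and that training started with empty metadata. The trainer in this commit keeps a persisted slot counter instead of computing the slot from the epoch.

```python
NUM_SLOTS = 3  # number of rotating epoch checkpoint files


def slot_filename(epoch: int, num_slots: int = NUM_SLOTS) -> str:
    """Map a 1-based epoch number to its rotating checkpoint file name."""
    slot = (epoch - 1) % num_slots  # zero-based slot index
    return f"epoch_checkpoint_{slot + 1}.pt"


for epoch in range(1, 6):
    print(f"Epoch {epoch} -> {slot_filename(epoch)}")
```
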
### 3. JSON Metadata Management

The `epoch_checkpoints.json` file records the following. `next_slot` is a zero-based index (0-2), so the value 1 below means the next epoch will overwrite `epoch_checkpoint_2.pt`. Entries are not guaranteed to be ordered by epoch, so sort on the `epoch` field when reading the file.
```json
{
  "checkpoints": [
    {
      "epoch": 8,
      "file": "epoch_checkpoint_2.pt",
      "path": "/path/to/checkpoints/epoch_checkpoint_2.pt",
      "saved_at": "2026-04-11T10:30:00",
      "step": 800
    },
    {
      "epoch": 9,
      "file": "epoch_checkpoint_3.pt",
      "path": "/path/to/checkpoints/epoch_checkpoint_3.pt",
      "saved_at": "2026-04-11T10:35:00",
      "step": 900
    },
    {
      "epoch": 10,
      "file": "epoch_checkpoint_1.pt",
      "path": "/path/to/checkpoints/epoch_checkpoint_1.pt",
      "saved_at": "2026-04-11T10:40:00",
      "step": 1000
    }
  ],
  "next_slot": 1,
  "total_epochs_completed": 10
}
```

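To see how the metadata evolves, the bookkeeping can be simulated without saving any tensors. The sketch below follows the slot logic of `save_epoch_checkpoint` in this commit in simplified form; `record_epoch` is a hypothetical helper name, and the directory and the 100-steps-per-epoch figure are placeholders taken from the example above.

```python
from datetime import datetime
from typing import Any, Dict

NUM_SLOTS = 3


def record_epoch(
    metadata: Dict[str, Any],
    epoch: int,
    step: int,
    checkpoint_dir: str = "/path/to/checkpoints",  # placeholder directory
) -> None:
    """Update the metadata dict as if an epoch checkpoint had just been saved."""
    slot = metadata["next_slot"]                  # zero-based slot index (0-2)
    filename = f"epoch_checkpoint_{slot + 1}.pt"  # file names are one-based
    info = {
        "epoch": epoch,
        "file": filename,
        "path": f"{checkpoint_dir}/{filename}",
        "saved_at": datetime.now().isoformat(),
        "step": step,
    }
    checkpoints = metadata["checkpoints"]
    if len(checkpoints) >= NUM_SLOTS:
        checkpoints[slot] = info   # replace the entry that used this slot
    else:
        checkpoints.append(info)   # the first three epochs fill the list
    metadata["next_slot"] = (slot + 1) % NUM_SLOTS
    metadata["total_epochs_completed"] = max(cp["epoch"] for cp in checkpoints)


metadata = {"checkpoints": [], "next_slot": 0, "total_epochs_completed": 0}
for epoch in range(1, 11):
    record_epoch(metadata, epoch, step=epoch * 100)

print(sorted(cp["epoch"] for cp in metadata["checkpoints"]), metadata["next_slot"])
```

Because entries are stored by slot, the list ends up in slot order rather than epoch order, which is why readers of the file should sort on `epoch`.
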
### 4. Automatic Resume

When training starts, the latest checkpoint is read from the JSON file and training resumes from it automatically:

```python
# Resume automatically when training starts
trainer.train(
    resume_from=None,  # if given, this checkpoint takes priority
    reset_training_state=False,
    auto_resume=True,  # resume from the latest epoch checkpoint automatically
)
```

**Resume priority:**
1. If `resume_from` is specified, that checkpoint is used
2. Otherwise, if `auto_resume` is enabled and epoch checkpoint metadata exists, training resumes from the latest epoch checkpoint
3. Otherwise, training starts from scratch

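The priority order can be sketched as a small standalone function. This is not the trainer's actual code: `choose_resume_path` is a hypothetical helper, and it assumes the metadata layout documented for `epoch_checkpoints.json` (a `checkpoints` list whose entries carry `epoch` and `path`).

```python
import json
from pathlib import Path
from typing import Optional


def choose_resume_path(
    metadata_file: Path,
    resume_from: Optional[str] = None,
    auto_resume: bool = True,
) -> Optional[str]:
    """Return the checkpoint path to resume from, or None to train from scratch."""
    # 1. An explicitly requested checkpoint always wins.
    if resume_from is not None:
        return resume_from

    # 2. Otherwise fall back to the newest epoch checkpoint listed in the metadata.
    if auto_resume and metadata_file.exists():
        with open(metadata_file, "r", encoding="utf-8") as f:
            checkpoints = json.load(f).get("checkpoints", [])
        if checkpoints:
            return max(checkpoints, key=lambda cp: cp["epoch"])["path"]

    # 3. Nothing to resume from.
    return None
```

Selecting with `max(..., key=epoch)` rather than taking the last list element matters because the entries are stored by slot, not by epoch.
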
## Usage

### Viewing the Saved Checkpoints

After training finishes, the saved checkpoints can be inspected through the JSON file:

```bash
# Show the metadata
cat ./output/checkpoints/epoch_checkpoints.json

# Or use Python
python -c "
import json
with open('./output/checkpoints/epoch_checkpoints.json') as f:
    data = json.load(f)
print('Saved epochs:', [cp['epoch'] for cp in data['checkpoints']])
print('Latest checkpoint:', max(data['checkpoints'], key=lambda x: x['epoch']))
"
```

### Manually Loading the Checkpoint of a Specific Epoch

```python
import json

import torch
from src.model.trainer import Trainer

# Load the metadata
with open('./output/checkpoints/epoch_checkpoints.json') as f:
    data = json.load(f)

# Sort by epoch
sorted_checkpoints = sorted(data['checkpoints'], key=lambda x: x['epoch'])
print('Available epochs:', [cp['epoch'] for cp in sorted_checkpoints])

# Pick the checkpoint of a specific epoch
target_epoch = 9
checkpoint_info = next(cp for cp in sorted_checkpoints if cp['epoch'] == target_epoch)
checkpoint_path = checkpoint_info['path']

# Load it with the trainer (`trainer` is an already constructed Trainer instance)
trainer.load_checkpoint(checkpoint_path)
```

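Since only three epochs are kept, the requested epoch may already have been overwritten, in which case a bare `next(...)` raises `StopIteration`. A lookup with a clearer error could look like the sketch below; `find_epoch_checkpoint` is a hypothetical helper, not part of the trainer.

```python
from typing import Any, Dict, List


def find_epoch_checkpoint(
    checkpoints: List[Dict[str, Any]], target_epoch: int
) -> Dict[str, Any]:
    """Return the metadata entry for target_epoch, or raise a descriptive KeyError."""
    for cp in checkpoints:
        if cp["epoch"] == target_epoch:
            return cp
    available = sorted(cp["epoch"] for cp in checkpoints)
    raise KeyError(
        f"epoch {target_epoch} is no longer kept; available epochs: {available}"
    )
```

The returned entry's `path` field can then be passed to `trainer.load_checkpoint`.
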
### Disabling Automatic Resume

To train from scratch, disable automatic resume:

```python
# Set this in code
trainer.train(auto_resume=False)
```

## Advantages

1. **Saves disk space** - only the last 3 epochs are kept, so disk usage no longer grows with training time
2. **Automatic management** - old checkpoints never have to be deleted by hand
3. **Clear ordering** - the JSON file shows exactly which epoch each checkpoint belongs to
4. **Automatic resume** - interrupted training can resume automatically from the most recent checkpoint
5. **Important checkpoints preserved** - best_model and latest_checkpoint are still kept

## Disk Space Comparison

### Before
```
Assuming 100 training epochs at 100MB per checkpoint:
- epoch_1.pt ~ epoch_100.pt: 100 * 100MB = 10GB
- best_model.pt: 100MB
- final_model.pt: 100MB
- latest_checkpoint.pt: 100MB
Total: ~10.3GB
```

### After
```
100 training epochs at 100MB per checkpoint:
- epoch_checkpoint_1~3.pt: 3 * 100MB = 300MB
- best_model.pt: 100MB
- latest_checkpoint.pt: 100MB
- epoch_checkpoints.json: <1KB
Total: ~500MB

Space saved: ~9.8GB (about 95%)
```

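The arithmetic can be rechecked with a few lines of Python. The figures below are the example's own assumptions (100 epochs, 100MB per checkpoint, decimal units with 1GB = 1000MB); real checkpoint sizes will differ.

```python
CHECKPOINT_MB = 100  # assumed size of one checkpoint
EPOCHS = 100         # assumed number of training epochs

# Before: one file per epoch, plus best_model, final_model and latest_checkpoint.
before_mb = EPOCHS * CHECKPOINT_MB + 3 * CHECKPOINT_MB

# After: 3 rotating epoch files, plus best_model and latest_checkpoint.
# The JSON metadata file is under 1KB and is ignored here.
after_mb = 3 * CHECKPOINT_MB + 2 * CHECKPOINT_MB

saved_mb = before_mb - after_mb
saved_pct = 100 * saved_mb / before_mb

print(f"before: {before_mb / 1000:.1f}GB, after: {after_mb}MB")
print(f"saved: {saved_mb / 1000:.1f}GB ({saved_pct:.0f}%)")
```
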
## Notes

1. **Interrupted training** - if training is interrupted midway, only the last 3 epochs completed before the interruption are kept
2. **JSON file** - do not edit `epoch_checkpoints.json` by hand, otherwise resuming may fail
3. **Compatibility** - old `epoch_*.pt` files are not deleted automatically; remove them manually if needed

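When cleaning up by hand, note that a plain `epoch_*.pt` glob also matches the new `epoch_checkpoint_N.pt` files. The sketch below assumes the legacy files are named `epoch_<number>.pt`, as in the disk space example; the helper names are hypothetical, and deletion only happens when `dry_run=False` is passed explicitly.

```python
import re
from pathlib import Path
from typing import List

# Assumption: legacy per-epoch files look like epoch_1.pt ... epoch_100.pt.
LEGACY_PATTERN = re.compile(r"^epoch_\d+\.pt$")


def find_legacy_epoch_files(checkpoint_dir: Path) -> List[Path]:
    """List legacy per-epoch files, skipping the rotating epoch_checkpoint_N.pt files."""
    return sorted(
        p for p in checkpoint_dir.glob("epoch_*.pt") if LEGACY_PATTERN.match(p.name)
    )


def remove_legacy_epoch_files(checkpoint_dir: Path, dry_run: bool = True) -> List[Path]:
    """Delete legacy files; with dry_run=True only report what would be removed."""
    legacy = find_legacy_epoch_files(checkpoint_dir)
    if not dry_run:
        for path in legacy:
            path.unlink()
    return legacy
```

A cautious workflow is to call `remove_legacy_epoch_files(Path("./output/checkpoints"))` first, review the returned list, and only then repeat the call with `dry_run=False`.
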
## Testing

Run the test script to verify the feature:
```bash
python test_epoch_checkpoint.py
```

## Related Files

- `src/model/trainer.py` - the main file modified
- `test_epoch_checkpoint.py` - functional test script

159
check_weights.py
|
|
@ -1,147 +1,22 @@
|
||||||
#!/usr/bin/env python3
|
from model.dataset import PinyinInputDataset
|
||||||
"""
|
from torch.utils.data import DataLoader
|
||||||
快速检查模型权重加载情况的脚本
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
from model.trainer import collate_fn, worker_init_fn
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def analyze_checkpoint(checkpoint_path):
|
data = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
|
||||||
"""分析checkpoint文件"""
|
|
||||||
print(f"🔍 分析checkpoint: {checkpoint_path}")
|
|
||||||
|
|
||||||
if not Path(checkpoint_path).exists():
|
dataloader = DataLoader(
|
||||||
print(f"❌ 文件不存在")
|
data,
|
||||||
return
|
batch_size=1024,
|
||||||
|
num_workers=2,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
prefetch_factor=2, # 减少预取以避免内存问题
|
||||||
|
persistent_workers=True,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
for i in dataloader:
|
||||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
print((i['labels'] == 1).sum())
|
||||||
print(f"✅ 加载成功")
|
break
|
||||||
print(f" 类型: {type(checkpoint)}")
|
|
||||||
|
|
||||||
if isinstance(checkpoint, dict):
|
|
||||||
print(f" 键名: {list(checkpoint.keys())}")
|
|
||||||
|
|
||||||
# 找到模型状态字典
|
|
||||||
state_dict = None
|
|
||||||
if "model_state_dict" in checkpoint:
|
|
||||||
state_dict = checkpoint["model_state_dict"]
|
|
||||||
print(f" 🔍 使用'model_state_dict'键")
|
|
||||||
elif "state_dict" in checkpoint:
|
|
||||||
state_dict = checkpoint["state_dict"]
|
|
||||||
print(f" 🔍 使用'state_dict'键")
|
|
||||||
else:
|
|
||||||
# 可能是直接的状态字典
|
|
||||||
state_dict = checkpoint
|
|
||||||
print(f" 🔍 使用直接状态字典")
|
|
||||||
|
|
||||||
if state_dict:
|
|
||||||
print(f" 总权重数: {len(state_dict)}")
|
|
||||||
|
|
||||||
# 分析分类头权重
|
|
||||||
classifier_keys = []
|
|
||||||
for key in state_dict.keys():
|
|
||||||
if "classifier" in key:
|
|
||||||
classifier_keys.append(key)
|
|
||||||
|
|
||||||
if classifier_keys:
|
|
||||||
print(f" 📊 分类头相关权重:")
|
|
||||||
for key in classifier_keys:
|
|
||||||
weight = state_dict[key]
|
|
||||||
print(f" {key}: shape={weight.shape}")
|
|
||||||
print(f" 范围: [{weight.min():.6f}, {weight.max():.6f}]")
|
|
||||||
print(f" 均值: {weight.mean():.6f}")
|
|
||||||
print(f" 标准差: {weight.std():.6f}")
|
|
||||||
|
|
||||||
# 检查权重是否接近随机初始化
|
|
||||||
if weight.std() < 0.01:
|
|
||||||
print(f" ⚠️ 警告: 权重标准差很小,可能未正确训练")
|
|
||||||
|
|
||||||
# 检查模型架构键名
|
|
||||||
print(f"\n 🔑 模型架构键名示例(前20个):")
|
|
||||||
for i, key in enumerate(list(state_dict.keys())[:20]):
|
|
||||||
weight = state_dict[key]
|
|
||||||
print(f" {i + 1:2d}. {key:40} shape={str(weight.shape):15}")
|
|
||||||
|
|
||||||
# 检查是否有预期的组件
|
|
||||||
expected_components = [
|
|
||||||
"context_encoder",
|
|
||||||
"slot_memory",
|
|
||||||
"cross_attn",
|
|
||||||
"moe",
|
|
||||||
"classifier",
|
|
||||||
]
|
|
||||||
found_components = []
|
|
||||||
for comp in expected_components:
|
|
||||||
found = any(comp in key for key in state_dict.keys())
|
|
||||||
if found:
|
|
||||||
found_components.append(comp)
|
|
||||||
|
|
||||||
print(f"\n 📋 找到的模型组件: {found_components}")
|
|
||||||
missing = set(expected_components) - set(found_components)
|
|
||||||
if missing:
|
|
||||||
print(f" ❌ 缺失的组件: {missing}")
|
|
||||||
|
|
||||||
return state_dict
|
|
||||||
else:
|
|
||||||
print(f"❌ checkpoint不是字典类型")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 加载失败: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
def check_weight_distribution(state_dict):
|
|
||||||
"""检查权重分布"""
|
|
||||||
print(f"\n📊 权重分布统计:")
|
|
||||||
|
|
||||||
weight_stats = []
|
|
||||||
for key, weight in state_dict.items():
|
|
||||||
if "weight" in key and len(weight.shape) >= 2: # 只检查权重矩阵,不包括偏置
|
|
||||||
stats = {
|
|
||||||
"key": key,
|
|
||||||
"shape": weight.shape,
|
|
||||||
"min": weight.min().item(),
|
|
||||||
"max": weight.max().item(),
|
|
||||||
"mean": weight.mean().item(),
|
|
||||||
"std": weight.std().item(),
|
|
||||||
"abs_mean": weight.abs().mean().item(),
|
|
||||||
}
|
|
||||||
weight_stats.append(stats)
|
|
||||||
|
|
||||||
# 打印前10个权重
|
|
||||||
for i, stats in enumerate(weight_stats[:10]):
|
|
||||||
print(f" {i + 1:2d}. {stats['key']:40}")
|
|
||||||
print(f" 形状: {stats['shape']}")
|
|
||||||
print(f" 范围: [{stats['min']:.6f}, {stats['max']:.6f}]")
|
|
||||||
print(f" 均值: {stats['mean']:.6f} ± {stats['std']:.6f}")
|
|
||||||
|
|
||||||
# 检查是否接近随机初始化
|
|
||||||
if stats["std"] < 0.01:
|
|
||||||
print(f" ⚠️ 警告: 标准差很小,可能未训练")
|
|
||||||
|
|
||||||
return weight_stats
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
import sys
|
|
||||||
|
|
||||||
if len(sys.argv) < 2:
|
|
||||||
print("使用方法: python check_weights.py <checkpoint_path>")
|
|
||||||
print("示例: python check_weights.py ./output/checkpoints/best_model.pt")
|
|
||||||
return
|
|
||||||
|
|
||||||
checkpoint_path = sys.argv[1]
|
|
||||||
state_dict = analyze_checkpoint(checkpoint_path)
|
|
||||||
|
|
||||||
if state_dict:
|
|
||||||
check_weight_distribution(state_dict)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@
|
||||||
"id": 0,
|
"id": 0,
|
||||||
"char": "",
|
"char": "",
|
||||||
"pinyin": "",
|
"pinyin": "",
|
||||||
"count": 434748360
|
"count": 11067734826
|
||||||
},
|
},
|
||||||
"1": {
|
"1": {
|
||||||
"id": 1,
|
"id": 1,
|
||||||
|
|
@ -142549,4 +142549,4 @@
|
||||||
"compressed": false,
|
"compressed": false,
|
||||||
"pair_count": 20646
|
"pair_count": 20646
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,526 @@
|
||||||
|
"""
|
||||||
|
交互式训练配置界面
|
||||||
|
|
||||||
|
提供终端风格的交互式配置界面,用于配置输入法模型训练参数。
|
||||||
|
使用Rich库创建美观的终端界面。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.prompt import Prompt, Confirm, IntPrompt, FloatPrompt
|
||||||
|
from rich.layout import Layout
|
||||||
|
from rich.columns import Columns
|
||||||
|
from rich.text import Text
|
||||||
|
from rich import box
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingConfigUI:
|
||||||
|
"""
|
||||||
|
训练配置交互式界面
|
||||||
|
|
||||||
|
提供终端风格的交互式配置界面,按必需程度分组参数:
|
||||||
|
1. 必需参数(必须提供)
|
||||||
|
2. 推荐参数(有合理默认值)
|
||||||
|
3. 可选参数(高级参数)
|
||||||
|
4. 恢复参数(训练恢复相关)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, console: Optional[Console] = None):
|
||||||
|
"""
|
||||||
|
初始化交互式界面
|
||||||
|
|
||||||
|
Args:
|
||||||
|
console: Rich控制台实例,如果为None则创建新的
|
||||||
|
"""
|
||||||
|
self.console = console or Console()
|
||||||
|
self.config = {}
|
||||||
|
|
||||||
|
# 参数定义
|
||||||
|
self.param_definitions = self._load_param_definitions()
|
||||||
|
|
||||||
|
# 颜色主题
|
||||||
|
self.colors = {
|
||||||
|
"primary": "cyan",
|
||||||
|
"secondary": "green",
|
||||||
|
"accent": "yellow",
|
||||||
|
"warning": "red",
|
||||||
|
"success": "green",
|
||||||
|
"info": "blue",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _load_param_definitions(self) -> Dict[str, List[Dict]]:
|
||||||
|
"""
|
||||||
|
加载参数定义
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
按组分类的参数定义
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"required": [
|
||||||
|
{
|
||||||
|
"name": "train_data_path",
|
||||||
|
"type": "path",
|
||||||
|
"prompt": "训练数据集路径",
|
||||||
|
"required": True,
|
||||||
|
"validate": self._validate_path,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "eval_data_path",
|
||||||
|
"type": "path",
|
||||||
|
"prompt": "评估数据集路径",
|
||||||
|
"required": True,
|
||||||
|
"validate": self._validate_path,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"recommended": [
|
||||||
|
{
|
||||||
|
"name": "output_dir",
|
||||||
|
"type": "path",
|
||||||
|
"prompt": "输出目录",
|
||||||
|
"default": "./output",
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "batch_size",
|
||||||
|
"type": "int",
|
||||||
|
"prompt": "批次大小",
|
||||||
|
"default": 128,
|
||||||
|
"required": False,
|
||||||
|
"validate": lambda x: x > 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "num_epochs",
|
||||||
|
"type": "int",
|
||||||
|
"prompt": "训练轮数",
|
||||||
|
"default": 10,
|
||||||
|
"required": False,
|
||||||
|
"validate": lambda x: x > 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "learning_rate",
|
||||||
|
"type": "float",
|
||||||
|
"prompt": "学习率",
|
||||||
|
"default": 1e-5,
|
||||||
|
"required": False,
|
||||||
|
"validate": lambda x: 1e-9 <= x <= 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "mixed_precision",
|
||||||
|
"type": "bool",
|
||||||
|
"prompt": "使用混合精度训练",
|
||||||
|
"default": True,
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "compile",
|
||||||
|
"type": "bool",
|
||||||
|
"prompt": "启用torch.compile优化",
|
||||||
|
"default": False,
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "use_tensorboard",
|
||||||
|
"type": "bool",
|
||||||
|
"prompt": "使用TensorBoard记录",
|
||||||
|
"default": True,
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"optional": [
|
||||||
|
{
|
||||||
|
"name": "min_learning_rate",
|
||||||
|
"type": "float",
|
||||||
|
"prompt": "最小学习率",
|
||||||
|
"default": 1e-9,
|
||||||
|
"required": False,
|
||||||
|
"validate": lambda x: 1e-12 <= x <= 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "warmup_ratio",
|
||||||
|
"type": "float",
|
||||||
|
"prompt": "热身步数比例",
|
||||||
|
"default": 0.1,
|
||||||
|
"required": False,
|
||||||
|
"validate": lambda x: 0 <= x <= 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "eval_frequency",
|
||||||
|
"type": "int",
|
||||||
|
"prompt": "评估频率(步数)",
|
||||||
|
"default": 500,
|
||||||
|
"required": False,
|
||||||
|
"validate": lambda x: x > 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "save_frequency",
|
||||||
|
"type": "int",
|
||||||
|
"prompt": "保存检查点频率(步数)",
|
||||||
|
"default": 1000,
|
||||||
|
"required": False,
|
||||||
|
"validate": lambda x: x > 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "max_iter_length",
|
||||||
|
"type": "int",
|
||||||
|
"prompt": "数据集大小",
|
||||||
|
"default": 1024 * 1024 * 128,
|
||||||
|
"required": False,
|
||||||
|
"validate": lambda x: x > 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "num_workers",
|
||||||
|
"type": "int",
|
||||||
|
"prompt": "数据加载worker数量",
|
||||||
|
"default": 2,
|
||||||
|
"required": False,
|
||||||
|
"validate": lambda x: x >= 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "seed",
|
||||||
|
"type": "int",
|
||||||
|
"prompt": "随机种子",
|
||||||
|
"default": 42,
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"recovery": [
|
||||||
|
{
|
||||||
|
"name": "auto_resume",
|
||||||
|
"type": "bool",
|
||||||
|
"prompt": "自动从最新checkpoint恢复",
|
||||||
|
"default": True,
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "resume_from",
|
||||||
|
"type": "path",
|
||||||
|
"prompt": "从指定checkpoint恢复(可选)",
|
||||||
|
"default": None,
|
||||||
|
"required": False,
|
||||||
|
"validate": self._validate_optional_path,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "reset_training_state",
|
||||||
|
"type": "bool",
|
||||||
|
"prompt": "重置训练状态(只加载权重)",
|
||||||
|
"default": False,
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _validate_path(self, path: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证路径
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 路径字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否有效
|
||||||
|
"""
|
||||||
|
if not path:
|
||||||
|
return False
|
||||||
|
path_obj = Path(path)
|
||||||
|
return path_obj.exists()
|
||||||
|
|
||||||
|
def _validate_optional_path(self, path: Optional[str]) -> bool:
|
||||||
|
"""
|
||||||
|
验证可选路径(可以为None)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 路径字符串或None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否有效
|
||||||
|
"""
|
||||||
|
if path is None or path == "":
|
||||||
|
return True
|
||||||
|
return self._validate_path(path)
|
||||||
|
|
||||||
|
def show_welcome(self):
|
||||||
|
"""
|
||||||
|
显示欢迎界面
|
||||||
|
"""
|
||||||
|
self.console.clear()
|
||||||
|
|
||||||
|
welcome_text = Text()
|
||||||
|
welcome_text.append("🚀 ", style="bold yellow")
|
||||||
|
welcome_text.append("输入法模型训练系统", style="bold cyan")
|
||||||
|
welcome_text.append("\n")
|
||||||
|
welcome_text.append("=" * 50, style="dim")
|
||||||
|
welcome_text.append("\n\n")
|
||||||
|
welcome_text.append("欢迎使用交互式训练配置界面!\n", style="bold green")
|
||||||
|
welcome_text.append("请按照提示配置训练参数。\n", style="green")
|
||||||
|
welcome_text.append("\n")
|
||||||
|
welcome_text.append("按 [Enter] 使用默认值\n", style="dim")
|
||||||
|
welcome_text.append("按 [Ctrl+C] 退出\n", style="dim")
|
||||||
|
|
||||||
|
panel = Panel(
|
||||||
|
welcome_text,
|
||||||
|
title="🎯 开始配置",
|
||||||
|
border_style=self.colors["primary"],
|
||||||
|
padding=(1, 2),
|
||||||
|
expand=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.console.print(panel)
|
||||||
|
self.console.print()
|
||||||
|
|
||||||
|
def ask_param_group(self, group_name: str, group_title: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
询问一组参数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_name: 参数组名称
|
||||||
|
group_title: 显示标题
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
该组的配置字典
|
||||||
|
"""
|
||||||
|
group_config = {}
|
||||||
|
params = self.param_definitions.get(group_name, [])
|
||||||
|
|
||||||
|
if not params:
|
||||||
|
return group_config
|
||||||
|
|
||||||
|
# 显示组标题
|
||||||
|
self.console.print()
|
||||||
|
self.console.print(f"[bold {self.colors['primary']}]{group_title}[/bold {self.colors['primary']}]")
|
||||||
|
self.console.print(f"[{self.colors['secondary']}]{'=' * 40}[/{self.colors['secondary']}]")
|
||||||
|
|
||||||
|
for param in params:
|
||||||
|
name = param["name"]
|
||||||
|
prompt = param["prompt"]
|
||||||
|
param_type = param["type"]
|
||||||
|
default = param.get("default")
|
||||||
|
required = param.get("required", False)
|
||||||
|
validate_func = param.get("validate")
|
||||||
|
|
||||||
|
# 如果已经有值(从命令行传入),跳过
|
||||||
|
if name in self.config and self.config[name] is not None:
|
||||||
|
group_config[name] = self.config[name]
|
||||||
|
self.console.print(f" {prompt}: [green]{self.config[name]}[/green] (已提供)")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 询问参数
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
if param_type == "bool":
|
||||||
|
value = Confirm.ask(
|
||||||
|
f" {prompt}",
|
||||||
|
default=default if default is not None else False,
|
||||||
|
show_default=True,
|
||||||
|
)
|
||||||
|
elif param_type == "int":
|
||||||
|
value = IntPrompt.ask(
|
||||||
|
f" {prompt}",
|
||||||
|
default=default if default is not None else 0,
|
||||||
|
show_default=True,
|
||||||
|
)
|
||||||
|
elif param_type == "float":
|
||||||
|
value = FloatPrompt.ask(
|
||||||
|
f" {prompt}",
|
||||||
|
default=default if default is not None else 0.0,
|
||||||
|
show_default=True,
|
||||||
|
)
|
||||||
|
else: # "path" or other string types
|
||||||
|
default_str = str(default) if default is not None else ""
|
||||||
|
value = Prompt.ask(
|
||||||
|
f" {prompt}",
|
||||||
|
default=default_str,
|
||||||
|
show_default=True,
|
||||||
|
)
|
||||||
|
# 处理空字符串
|
||||||
|
if value == "":
|
||||||
|
value = None
|
||||||
|
|
||||||
|
# 验证
|
||||||
|
if validate_func and value is not None:
|
||||||
|
if not validate_func(value):
|
||||||
|
self.console.print(f" [red]无效值,请重新输入[/red]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查必需参数
|
||||||
|
if required and (value is None or value == ""):
|
||||||
|
self.console.print(f" [red]此参数为必需参数[/red]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
group_config[name] = value
|
||||||
|
break
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
self.console.print("\n[yellow]已取消[/yellow]")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.console.print(f" [red]输入错误: {e}[/red]")
|
||||||
|
if not required:
|
||||||
|
# 宽松处理:使用默认值
|
||||||
|
group_config[name] = default
|
||||||
|
self.console.print(f" [yellow]使用默认值: {default}[/yellow]")
|
||||||
|
break
|
||||||
|
|
||||||
|
return group_config
|
||||||
|
|
||||||
|
def show_config_summary(self, config: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
显示配置摘要
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: 完整的配置字典
|
||||||
|
"""
|
||||||
|
self.console.print()
|
||||||
|
self.console.print(f"[bold {self.colors['primary']}]📋 配置摘要[/bold {self.colors['primary']}]")
|
||||||
|
self.console.print(f"[{self.colors['secondary']}]{'=' * 40}[/{self.colors['secondary']}]")
|
||||||
|
|
||||||
|
# 创建表格
|
||||||
|
table = Table(show_header=True, header_style=f"bold {self.colors['accent']}", box=box.ROUNDED)
|
||||||
|
table.add_column("参数", style=self.colors["primary"])
|
||||||
|
table.add_column("值", style=self.colors["success"])
|
||||||
|
table.add_column("类型", style=self.colors["info"])
|
||||||
|
|
||||||
|
# 按组添加参数
|
||||||
|
groups = [
|
||||||
|
("必需参数", "required"),
|
||||||
|
("推荐参数", "recommended"),
|
||||||
|
("可选参数", "optional"),
|
||||||
|
("恢复参数", "recovery"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for group_title, group_name in groups:
|
||||||
|
params = self.param_definitions.get(group_name, [])
|
||||||
|
if params:
|
||||||
|
table.add_row(f"[bold]{group_title}[/bold]", "", "", style="bold")
|
||||||
|
|
||||||
|
for param in params:
|
||||||
|
name = param["name"]
|
||||||
|
if name in config:
|
||||||
|
value = config[name]
|
||||||
|
param_type = param["type"]
|
||||||
|
|
||||||
|
# 格式化值
|
||||||
|
if value is None:
|
||||||
|
value_str = "[dim]None[/dim]"
|
||||||
|
elif isinstance(value, bool):
|
||||||
|
value_str = "是" if value else "否"
|
||||||
|
else:
|
||||||
|
value_str = str(value)
|
||||||
|
|
||||||
|
table.add_row(f" {param['prompt']}", value_str, param_type)
|
||||||
|
|
||||||
|
self.console.print(table)
|
||||||
|
self.console.print()
|
||||||
|
|
||||||
|
def run(self, initial_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
运行完整的交互式配置流程
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_config: 初始配置(如从命令行传入的参数)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的配置字典
|
||||||
|
"""
|
||||||
|
# 初始化配置
|
||||||
|
self.config = initial_config or {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 显示欢迎界面
|
||||||
|
self.show_welcome()
|
||||||
|
|
||||||
|
# 询问必需参数
|
||||||
|
self.console.print()
|
||||||
|
if Confirm.ask("[bold cyan]是否配置必需参数?[/bold cyan]", default=True):
|
||||||
|
required_config = self.ask_param_group("required", "📁 必需参数")
|
||||||
|
self.config.update(required_config)
|
||||||
|
|
||||||
|
# 检查必需参数是否已提供
|
||||||
|
required_params = self.param_definitions["required"]
|
||||||
|
for param in required_params:
|
||||||
|
if param["name"] not in self.config or self.config[param["name"]] is None:
|
||||||
|
self.console.print(f"[red]错误: 必需参数 '{param['prompt']}' 未提供[/red]")
|
||||||
|
self.console.print("[yellow]请重新配置必需参数[/yellow]")
|
||||||
|
required_config = self.ask_param_group("required", "📁 必需参数")
|
||||||
|
self.config.update(required_config)
|
||||||
|
|
||||||
|
# 询问推荐参数
|
||||||
|
self.console.print()
|
||||||
|
if Confirm.ask("[bold cyan]是否配置推荐参数?[/bold cyan]", default=True):
|
||||||
|
recommended_config = self.ask_param_group("recommended", "⚙️ 推荐参数")
|
||||||
|
self.config.update(recommended_config)
|
||||||
|
|
||||||
|
# 询问可选参数
|
||||||
|
self.console.print()
|
||||||
|
if Confirm.ask("[bold cyan]是否配置可选参数(高级)?[/bold cyan]", default=False):
|
||||||
|
optional_config = self.ask_param_group("optional", "🎛️ 可选参数")
|
||||||
|
self.config.update(optional_config)
|
||||||
|
|
||||||
|
# 询问恢复参数
|
||||||
|
self.console.print()
|
||||||
|
if Confirm.ask("[bold cyan]是否配置恢复参数?[/bold cyan]", default=True):
|
||||||
|
recovery_config = self.ask_param_group("recovery", "🔄 恢复参数")
|
||||||
|
self.config.update(recovery_config)
|
||||||
|
|
||||||
|
# 显示配置摘要
|
||||||
|
self.show_config_summary(self.config)
|
||||||
|
|
||||||
|
# 确认配置
|
||||||
|
self.console.print()
|
||||||
|
if not Confirm.ask("[bold green]是否使用此配置开始训练?[/bold green]", default=True):
|
||||||
|
self.console.print("[yellow]配置已取消[/yellow]")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return self.config
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
self.console.print("\n[yellow]配置已取消[/yellow]")
|
||||||
|
return {}
|
||||||
|
except Exception as e:
|
||||||
|
self.console.print(f"[red]配置过程中出现错误: {e}[/red]")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_interactive_config(provided_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取交互式配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provided_params: 已提供的参数(如从命令行传入)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的配置字典
|
||||||
|
"""
|
||||||
|
console = Console()
|
||||||
|
ui = TrainingConfigUI(console)
|
||||||
|
|
||||||
|
# 过滤掉None值
|
||||||
|
if provided_params:
|
||||||
|
provided_params = {k: v for k, v in provided_params.items() if v is not None}
|
||||||
|
|
||||||
|
config = ui.run(provided_params)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试交互式界面
|
||||||
|
console = Console()
|
||||||
|
console.print("[bold]测试交互式配置界面[/bold]")
|
||||||
|
|
||||||
|
# 模拟已提供的参数
|
||||||
|
test_params = {
|
||||||
|
"train_data_path": "./data/train.txt",
|
||||||
|
"eval_data_path": "./data/eval.txt",
|
||||||
|
}
|
||||||
|
|
||||||
|
config = get_interactive_config(test_params)
|
||||||
|
|
||||||
|
if config:
|
||||||
|
console.print("\n[green]✓ 配置完成[/green]")
|
||||||
|
console.print(f"配置参数: {config}")
|
||||||
|
else:
|
||||||
|
console.print("\n[yellow]配置已取消[/yellow]")
|
||||||
|
|
@ -166,7 +166,13 @@ class Trainer:
|
||||||
# 不加载历史数据,直接初始化为空列表以覆盖原有数据
|
# 不加载历史数据,直接初始化为空列表以覆盖原有数据
|
||||||
self.training_status_data = []
|
self.training_status_data = []
|
||||||
|
|
||||||
# 初始化Rich控制台
|
# 初始化 epoch checkpoint 元数据
|
||||||
|
self.epoch_metadata_file = self.checkpoint_dir / "epoch_checkpoints.json"
|
||||||
|
self.epoch_checkpoints = [] # 最多保留 3 个
|
||||||
|
self.next_epoch_slot = 0 # 下一个要覆盖的位置 (0-2)
|
||||||
|
self._load_epoch_metadata()
|
||||||
|
|
||||||
|
# 初始化 Rich 控制台
|
||||||
self.console = Console()
|
self.console = Console()
|
||||||
|
|
||||||
# 训练状态
|
# 训练状态
|
||||||
|
|
@ -368,6 +374,122 @@ class Trainer:
|
||||||
else:
|
else:
|
||||||
logger.info(f"Checkpoint saved to {checkpoint_path}")
|
logger.info(f"Checkpoint saved to {checkpoint_path}")
|
||||||
|
|
||||||
|
def _load_epoch_metadata(self):
|
||||||
|
"""加载 epoch checkpoint 元数据"""
|
||||||
|
if self.epoch_metadata_file.exists():
|
||||||
|
try:
|
||||||
|
with open(self.epoch_metadata_file, "r", encoding="utf-8") as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
self.epoch_checkpoints = metadata.get("checkpoints", [])
|
||||||
|
self.next_epoch_slot = metadata.get("next_slot", 0)
|
||||||
|
logger.info(
|
||||||
|
f"Loaded epoch checkpoint metadata: {len(self.epoch_checkpoints)} checkpoints"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load epoch metadata: {e}")
|
||||||
|
self.epoch_checkpoints = []
|
||||||
|
self.next_epoch_slot = 0
|
||||||
|
else:
|
||||||
|
self.epoch_checkpoints = []
|
||||||
|
self.next_epoch_slot = 0
|
||||||
|
|
||||||
|
def _save_epoch_metadata(self):
|
||||||
|
"""保存 epoch checkpoint 元数据"""
|
||||||
|
metadata = {
|
||||||
|
"checkpoints": self.epoch_checkpoints,
|
||||||
|
"next_slot": self.next_epoch_slot,
|
||||||
|
"total_epochs_completed": max(
|
||||||
|
(cp["epoch"] for cp in self.epoch_checkpoints), default=0
|
||||||
|
),
|
||||||
|
}
|
||||||
|
with open(self.epoch_metadata_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
def save_epoch_checkpoint(self, epoch: int):
|
||||||
|
"""
|
||||||
|
        Save an epoch checkpoint (rotating overwrite; only the last 3 are kept).

        Args:
            epoch: Current epoch number (1-based).
        """
        # Pick the file name (3 fixed file names are reused in rotation)
        slot = self.next_epoch_slot
        filename = f"epoch_checkpoint_{slot + 1}.pt"
        checkpoint_path = self.checkpoint_dir / filename

        # Save the checkpoint
        checkpoint = {
            "step": self.current_step,
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "best_eval_loss": self.best_eval_loss,
            "config": {
                "learning_rate": self.learning_rate,
                "weight_decay": self.weight_decay,
                "warmup_ratio": self.warmup_ratio,
                "label_smoothing": self.label_smoothing,
                "total_steps": self.total_steps,
            },
        }

        torch.save(checkpoint, checkpoint_path)

        # Update the metadata
        checkpoint_info = {
            "epoch": epoch,
            "file": filename,
            "path": str(checkpoint_path),
            "saved_at": datetime.now().isoformat(),
            "step": self.current_step,
        }

        # Once 3 entries exist, replace the one in this slot; otherwise append
        if len(self.epoch_checkpoints) >= 3:
            self.epoch_checkpoints[slot] = checkpoint_info
        else:
            self.epoch_checkpoints.append(checkpoint_info)

        # Advance to the next slot
        self.next_epoch_slot = (self.next_epoch_slot + 1) % 3

        # Persist the metadata
        self._save_epoch_metadata()

        # Sort by epoch so the log lists the retained epochs in ascending order
        sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])

        logger.info(
            f"Epoch {epoch} checkpoint saved to {filename} "
            f"(keeping last {len(self.epoch_checkpoints)} epochs: "
            f"{[cp['epoch'] for cp in sorted_checkpoints]})"
        )
|
||||||
|
|
||||||
|
    def get_latest_epoch_checkpoint(self) -> Optional[Dict]:
        """
        Return the metadata of the most recent epoch checkpoint.

        Returns:
            The metadata dict of the latest checkpoint, or None if none exists.
        """
        if not self.epoch_checkpoints:
            return None

        # Sort by epoch and return the newest entry
        sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
        return sorted_checkpoints[-1]

    def get_epoch_checkpoints(self) -> List[Dict]:
        """
        Return the metadata of all saved epoch checkpoints, sorted by epoch.

        Returns:
            List of checkpoint metadata dicts in ascending epoch order.
        """
        return sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
|
||||||
|
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
self, checkpoint_path: Union[str, Path], reset_training_state: bool = False
|
self, checkpoint_path: Union[str, Path], reset_training_state: bool = False
|
||||||
):
|
):
|
||||||
|
|
@ -534,7 +656,10 @@ class Trainer:
|
||||||
self.console.print(info_table)
|
self.console.print(info_table)
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self, resume_from: Optional[str] = None, reset_training_state: bool = False
|
self,
|
||||||
|
resume_from: Optional[str] = None,
|
||||||
|
reset_training_state: bool = False,
|
||||||
|
auto_resume: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
主训练循环
|
主训练循环
|
||||||
|
|
@ -542,10 +667,23 @@ class Trainer:
|
||||||
Args:
|
Args:
|
||||||
resume_from: 从哪个检查点恢复训练(可选)
|
resume_from: 从哪个检查点恢复训练(可选)
|
||||||
reset_training_state: 是否重置训练状态(只加载模型权重,从头开始训练)
|
reset_training_state: 是否重置训练状态(只加载模型权重,从头开始训练)
|
||||||
|
auto_resume: 是否自动从最新的 epoch checkpoint 恢复(如果存在)
|
||||||
"""
|
"""
|
||||||
# 如果提供了检查点,则恢复训练
|
# 如果提供了检查点,则优先使用提供的检查点恢复训练
|
||||||
if resume_from is not None:
|
if resume_from is not None:
|
||||||
self.load_checkpoint(resume_from, reset_training_state=reset_training_state)
|
self.load_checkpoint(resume_from, reset_training_state=reset_training_state)
|
||||||
|
elif auto_resume and self.epoch_checkpoints:
|
||||||
|
# 自动从最新的 epoch checkpoint 恢复
|
||||||
|
latest_checkpoint = self.get_latest_epoch_checkpoint()
|
||||||
|
if latest_checkpoint:
|
||||||
|
checkpoint_path = latest_checkpoint["path"]
|
||||||
|
logger.info(
|
||||||
|
f"Auto-resuming from latest epoch checkpoint: {checkpoint_path} "
|
||||||
|
f"(epoch {latest_checkpoint['epoch']})"
|
||||||
|
)
|
||||||
|
self.load_checkpoint(
|
||||||
|
checkpoint_path, reset_training_state=reset_training_state
|
||||||
|
)
|
||||||
|
|
||||||
# 打印训练信息
|
# 打印训练信息
|
||||||
self._print_training_info()
|
self._print_training_info()
|
||||||
|
|
@ -675,8 +813,8 @@ class Trainer:
|
||||||
|
|
||||||
# 进度条不重置,显示整体训练进度
|
# 进度条不重置,显示整体训练进度
|
||||||
|
|
||||||
# 每个epoch结束后保存检查点
|
# 每个 epoch 结束后保存检查点(循环覆盖,只保留最后 3 个)
|
||||||
self.save_checkpoint(f"epoch_{epoch + 1}.pt")
|
self.save_epoch_checkpoint(epoch + 1)
|
||||||
|
|
||||||
# 检查是否达到总步数
|
# 检查是否达到总步数
|
||||||
if global_step >= self.total_steps:
|
if global_step >= self.total_steps:
|
||||||
|
|
@ -685,10 +823,18 @@ class Trainer:
|
||||||
# 训练完成
|
# 训练完成
|
||||||
logger.info("Training completed!")
|
logger.info("Training completed!")
|
||||||
|
|
||||||
# 保存最终模型
|
# 显示保存的 epoch checkpoint 信息
|
||||||
self.save_checkpoint("final_model.pt")
|
if self.epoch_checkpoints:
|
||||||
|
sorted_checkpoints = self.get_epoch_checkpoints()
|
||||||
|
logger.info(
|
||||||
|
f"Saved epoch checkpoints: {[cp['epoch'] for cp in sorted_checkpoints]}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Latest checkpoint: epoch {sorted_checkpoints[-1]['epoch']} "
|
||||||
|
f"({sorted_checkpoints[-1]['file']})"
|
||||||
|
)
|
||||||
|
|
||||||
# 关闭TensorBoard写入器
|
# 关闭 TensorBoard 写入器
|
||||||
if self.writer is not None:
|
if self.writer is not None:
|
||||||
self.writer.close()
|
self.writer.close()
|
||||||
|
|
||||||
|
|
@ -907,21 +1053,10 @@ def train(
|
||||||
..., "--eval-data-path", "-e", help="评估数据集路径"
|
..., "--eval-data-path", "-e", help="评估数据集路径"
|
||||||
),
|
),
|
||||||
output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
|
output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
|
||||||
# 模型参数
|
# 数据大小
|
||||||
vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
|
|
||||||
pinyin_vocab_size: int = typer.Option(
|
|
||||||
30, "--pinyin-vocab-size", help="拼音词汇表大小"
|
|
||||||
),
|
|
||||||
max_iter_length: int = typer.Option(
|
max_iter_length: int = typer.Option(
|
||||||
1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
|
1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
|
||||||
),
|
),
|
||||||
dim: int = typer.Option(512, "--dim", help="模型维度"),
|
|
||||||
num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"),
|
|
||||||
n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"),
|
|
||||||
n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"),
|
|
||||||
num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"),
|
|
||||||
max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"),
|
|
||||||
use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
|
|
||||||
# 训练参数
|
# 训练参数
|
||||||
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
|
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
|
||||||
num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
|
num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
|
||||||
|
|
@ -954,6 +1089,9 @@ def train(
|
||||||
reset_training_state: bool = typer.Option(
|
reset_training_state: bool = typer.Option(
|
||||||
False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练"
|
False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练"
|
||||||
),
|
),
|
||||||
|
auto_resume: bool = typer.Option(
|
||||||
|
True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练"
|
||||||
|
),
|
||||||
seed: int = typer.Option(42, "--seed", help="随机种子"),
|
seed: int = typer.Option(42, "--seed", help="随机种子"),
|
||||||
compile: bool = typer.Option(
|
compile: bool = typer.Option(
|
||||||
False,
|
False,
|
||||||
|
|
@ -975,6 +1113,17 @@ def train(
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
# 硬编码模型参数
|
||||||
|
vocab_size = 10019
|
||||||
|
pinyin_vocab_size = 30 # 根据 dataset.CHAR_TO_ID 映射
|
||||||
|
dim = 512
|
||||||
|
num_slots = 8
|
||||||
|
n_layers = 4
|
||||||
|
n_heads = 4
|
||||||
|
num_experts = 20
|
||||||
|
max_seq_len = 128
|
||||||
|
use_pinyin = True # 始终使用拼音
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
# 打印配置信息
|
# 打印配置信息
|
||||||
|
|
@ -1014,6 +1163,7 @@ def train(
|
||||||
config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
|
config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
|
||||||
config_table.add_row("训练", "混合精度", str(mixed_precision))
|
config_table.add_row("训练", "混合精度", str(mixed_precision))
|
||||||
|
|
||||||
|
config_table.add_row("其他", "自动恢复", str(auto_resume))
|
||||||
console.print(config_table)
|
console.print(config_table)
|
||||||
|
|
||||||
# 创建输出目录
|
# 创建输出目录
|
||||||
|
|
@ -1049,6 +1199,7 @@ def train(
|
||||||
"mixed_precision": mixed_precision,
|
"mixed_precision": mixed_precision,
|
||||||
"use_tensorboard": use_tensorboard,
|
"use_tensorboard": use_tensorboard,
|
||||||
"seed": seed,
|
"seed": seed,
|
||||||
|
"auto_resume": auto_resume,
|
||||||
"max_iter_length": max_iter_length,
|
"max_iter_length": max_iter_length,
|
||||||
"compile": compile,
|
"compile": compile,
|
||||||
}
|
}
|
||||||
|
|
@ -1149,10 +1300,12 @@ def train(
|
||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
console.print("\n[bold cyan]开始训练...[/bold cyan]")
|
console.print("\n[bold cyan]开始训练...[/bold cyan]")
|
||||||
console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
try:
|
try:
|
||||||
trainer.train(
|
trainer.train(
|
||||||
resume_from=resume_from, reset_training_state=reset_training_state
|
resume_from=resume_from,
|
||||||
|
reset_training_state=reset_training_state,
|
||||||
|
auto_resume=auto_resume,
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("[bold green]训练被终止[/bold green]")
|
console.print("[bold green]训练被终止[/bold green]")
|
||||||
|
|
@ -1223,20 +1376,10 @@ def expand_and_train(
|
||||||
"-m",
|
"-m",
|
||||||
help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类",
|
help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类",
|
||||||
),
|
),
|
||||||
vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
|
# 数据大小
|
||||||
pinyin_vocab_size: int = typer.Option(
|
|
||||||
30, "--pinyin-vocab-size", help="拼音词汇表大小"
|
|
||||||
),
|
|
||||||
max_iter_length: int = typer.Option(
|
max_iter_length: int = typer.Option(
|
||||||
1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
|
1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
|
||||||
),
|
),
|
||||||
dim: int = typer.Option(512, "--dim", help="模型维度"),
|
|
||||||
num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"),
|
|
||||||
n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"),
|
|
||||||
n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"),
|
|
||||||
num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"),
|
|
||||||
max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"),
|
|
||||||
use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
|
|
||||||
# 训练参数
|
# 训练参数
|
||||||
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
|
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
|
||||||
num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
|
num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
|
||||||
|
|
@ -1269,6 +1412,9 @@ def expand_and_train(
|
||||||
reset_training_state: bool = typer.Option(
|
reset_training_state: bool = typer.Option(
|
||||||
False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练"
|
False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练"
|
||||||
),
|
),
|
||||||
|
auto_resume: bool = typer.Option(
|
||||||
|
True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练"
|
||||||
|
),
|
||||||
seed: int = typer.Option(42, "--seed", help="随机种子"),
|
seed: int = typer.Option(42, "--seed", help="随机种子"),
|
||||||
compile: bool = typer.Option(
|
compile: bool = typer.Option(
|
||||||
False,
|
False,
|
||||||
|
|
@ -1287,6 +1433,16 @@ def expand_and_train(
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
# 硬编码模型参数
|
||||||
|
vocab_size = 10019
|
||||||
|
pinyin_vocab_size = 30 # 根据 dataset.CHAR_TO_ID 映射
|
||||||
|
dim = 512
|
||||||
|
num_slots = 8
|
||||||
|
n_layers = 4
|
||||||
|
n_heads = 4
|
||||||
|
num_experts = 20
|
||||||
|
max_seq_len = 128
|
||||||
|
use_pinyin = True # 始终使用拼音
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
# 打印配置信息
|
# 打印配置信息
|
||||||
|
|
@ -1330,6 +1486,7 @@ def expand_and_train(
|
||||||
config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
|
config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
|
||||||
config_table.add_row("训练", "混合精度", str(mixed_precision))
|
config_table.add_row("训练", "混合精度", str(mixed_precision))
|
||||||
|
|
||||||
|
config_table.add_row("其他", "自动恢复", str(auto_resume))
|
||||||
console.print(config_table)
|
console.print(config_table)
|
||||||
|
|
||||||
# 创建输出目录
|
# 创建输出目录
|
||||||
|
|
@ -1367,6 +1524,7 @@ def expand_and_train(
|
||||||
"mixed_precision": mixed_precision,
|
"mixed_precision": mixed_precision,
|
||||||
"use_tensorboard": use_tensorboard,
|
"use_tensorboard": use_tensorboard,
|
||||||
"seed": seed,
|
"seed": seed,
|
||||||
|
"auto_resume": auto_resume,
|
||||||
"max_iter_length": max_iter_length,
|
"max_iter_length": max_iter_length,
|
||||||
"compile": compile,
|
"compile": compile,
|
||||||
}
|
}
|
||||||
|
|
@ -1482,10 +1640,12 @@ def expand_and_train(
|
||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]")
|
console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]")
|
||||||
console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
try:
|
try:
|
||||||
trainer.train(
|
trainer.train(
|
||||||
resume_from=resume_from, reset_training_state=reset_training_state
|
resume_from=resume_from,
|
||||||
|
reset_training_state=reset_training_state,
|
||||||
|
auto_resume=auto_resume,
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("[bold green]训练被终止[/bold green]")
|
console.print("[bold green]训练被终止[/bold green]")
|
||||||
|
|
@ -1598,6 +1758,9 @@ def expand_finetune(
|
||||||
reset_training_state: bool = typer.Option(
|
reset_training_state: bool = typer.Option(
|
||||||
False, "--reset-training-state", help="重置训练状态"
|
False, "--reset-training-state", help="重置训练状态"
|
||||||
),
|
),
|
||||||
|
auto_resume: bool = typer.Option(
|
||||||
|
True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练"
|
||||||
|
),
|
||||||
seed: int = typer.Option(42, "--seed", help="随机种子"),
|
seed: int = typer.Option(42, "--seed", help="随机种子"),
|
||||||
compile: Optional[bool] = typer.Option(
|
compile: Optional[bool] = typer.Option(
|
||||||
None, "--compile/--no-compile", help="是否开启 torch.compile 优化"
|
None, "--compile/--no-compile", help="是否开启 torch.compile 优化"
|
||||||
|
|
@ -1697,6 +1860,7 @@ def expand_finetune(
|
||||||
config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm))
|
config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm))
|
||||||
config_table.add_row("训练", "混合精度", str(mixed_precision))
|
config_table.add_row("训练", "混合精度", str(mixed_precision))
|
||||||
|
|
||||||
|
config_table.add_row("其他", "自动恢复", str(auto_resume))
|
||||||
console.print(config_table)
|
console.print(config_table)
|
||||||
|
|
||||||
output_path = Path(final_output_dir)
|
output_path = Path(final_output_dir)
|
||||||
|
|
@ -1786,10 +1950,12 @@ def expand_finetune(
|
||||||
console.print("[green]✓ 训练器创建完成[/green]")
|
console.print("[green]✓ 训练器创建完成[/green]")
|
||||||
|
|
||||||
console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]")
|
console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]")
|
||||||
console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
try:
|
try:
|
||||||
trainer.train(
|
trainer.train(
|
||||||
resume_from=resume_from, reset_training_state=reset_training_state
|
resume_from=resume_from,
|
||||||
|
reset_training_state=reset_training_state,
|
||||||
|
auto_resume=auto_resume,
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("[bold green]训练被终止[/bold green]")
|
console.print("[bold green]训练被终止[/bold green]")
|
||||||
|
|
|
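The slot rotation performed by `save_epoch_checkpoint` in the diff above can be isolated as a small pure function. The following is a minimal sketch rather than the commit's code: `record_epoch` and `MAX_EPOCH_CHECKPOINTS` are hypothetical names introduced for illustration, and no model file is written.

```python
from typing import Dict, List

# Assumption: mirrors the literal 3 that the diff hard-codes.
MAX_EPOCH_CHECKPOINTS = 3


def record_epoch(entries: List[Dict], next_slot: int, epoch: int) -> int:
    """Record metadata for `epoch` in the current slot and return the next slot."""
    info = {"epoch": epoch, "file": f"epoch_checkpoint_{next_slot + 1}.pt"}
    if len(entries) >= MAX_EPOCH_CHECKPOINTS:
        # All slots are in use: overwrite the entry that owns this file name.
        entries[next_slot] = info
    else:
        entries.append(info)
    return (next_slot + 1) % MAX_EPOCH_CHECKPOINTS


entries: List[Dict] = []
slot = 0
for epoch in range(1, 6):
    slot = record_epoch(entries, slot, epoch)

# Epochs retained after five saves, in ascending order.
kept = sorted(e["epoch"] for e in entries)
```

After five epochs the list holds epochs 3, 4 and 5, and epoch 4 is stored under `epoch_checkpoint_1.pt` because slot 0 was reused. This is why the callers sort by `epoch` instead of relying on list order.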
||||||
8
test.py
8
test.py
|
|
@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
||||||
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
||||||
|
|
||||||
|
|
||||||
part1 = "明明是国庆节,可是因为月底要结账,财务部所有人都"
|
part1 = "招财猫背部或底部的太阳能板会持续将环境光(无论是阳光还是室内灯光)转化为"
|
||||||
part2 = "bxu"
|
part2 = "weiruo"
|
||||||
pinyin_ids = text_to_pinyin_ids(part2)
|
pinyin_ids = text_to_pinyin_ids(part2)
|
||||||
len_py = len(pinyin_ids)
|
len_py = len(pinyin_ids)
|
||||||
if len_py < 24:
|
if len_py < 24:
|
||||||
|
|
@ -56,7 +56,7 @@ if len_py < 24:
|
||||||
else:
|
else:
|
||||||
pinyin_ids = pinyin_ids[:24]
|
pinyin_ids = pinyin_ids[:24]
|
||||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
||||||
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
|
masked_labels = [649, 925, 0, 0, 0, 0, 0, 0]
|
||||||
part3 = ""
|
part3 = ""
|
||||||
part4 = "可行|特别|伤害"
|
part4 = "可行|特别|伤害"
|
||||||
|
|
||||||
|
|
@ -83,7 +83,7 @@ sample = {
|
||||||
|
|
||||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||||
|
|
||||||
checkpoint = torch.load("/home/songsenand/下载/20260411(acc34)final_model.ptrom", map_location="cpu")
|
checkpoint = torch.load("/home/songsenand/下载/20260411acc37final-model.pt", map_location="cpu")
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
|
||||||
input_ids = sample["input_ids"]
|
input_ids = sample["input_ids"]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,210 @@
|
||||||
|
"""
|
||||||
|
Test the rotating epoch checkpoint save logic.
"""

import json
import tempfile
from pathlib import Path
from datetime import datetime


# Mock of the Trainer class's epoch checkpoint management logic
class EpochCheckpointManager:
    def __init__(self, checkpoint_dir: Path):
        self.checkpoint_dir = checkpoint_dir
        self.epoch_metadata_file = self.checkpoint_dir / "epoch_checkpoints.json"
        self.epoch_checkpoints = []
        self.next_epoch_slot = 0

    def _load_epoch_metadata(self):
        if self.epoch_metadata_file.exists():
            with open(self.epoch_metadata_file, "r", encoding="utf-8") as f:
                metadata = json.load(f)
                self.epoch_checkpoints = metadata.get("checkpoints", [])
                self.next_epoch_slot = metadata.get("next_slot", 0)

    def _save_epoch_metadata(self):
        metadata = {
            "checkpoints": self.epoch_checkpoints,
            "next_slot": self.next_epoch_slot,
            "total_epochs_completed": max(
                (cp["epoch"] for cp in self.epoch_checkpoints), default=0
            ),
        }
        with open(self.epoch_metadata_file, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

    def save_epoch_checkpoint(self, epoch: int):
        """Simulate saving an epoch checkpoint (no model file is actually written)."""
        slot = self.next_epoch_slot
        filename = f"epoch_checkpoint_{slot + 1}.pt"
        checkpoint_path = self.checkpoint_dir / filename

        checkpoint_info = {
            "epoch": epoch,
            "file": filename,
            "path": str(checkpoint_path),
            "saved_at": datetime.now().isoformat(),
            "step": epoch * 100,  # simulated step
        }

        if len(self.epoch_checkpoints) >= 3:
            self.epoch_checkpoints[slot] = checkpoint_info
        else:
            self.epoch_checkpoints.append(checkpoint_info)

        self.next_epoch_slot = (self.next_epoch_slot + 1) % 3
        self._save_epoch_metadata()

        sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
        print(
            f"✓ Epoch {epoch:2d} saved -> {filename:20s} "
            f"(keeping epochs: {[cp['epoch'] for cp in sorted_checkpoints]})"
        )

    def get_latest_epoch_checkpoint(self):
        if not self.epoch_checkpoints:
            return None
        sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
        return sorted_checkpoints[-1]

    def get_epoch_checkpoints(self):
        return sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_circular_save():
    """Test the rotating save logic."""
    print("=" * 70)
    print("Test: rotating save of the last 3 epoch checkpoints")
    print("=" * 70)

    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_dir = Path(tmpdir)
        manager = EpochCheckpointManager(checkpoint_dir)

        # Simulate training for 10 epochs
        print("\nSimulating 10 training epochs:")
        print("-" * 70)
        for epoch in range(1, 11):
            manager.save_epoch_checkpoint(epoch)

        # Verify the final state
        print("\n" + "=" * 70)
        print("Final state check:")
        print("=" * 70)

        # Inspect the JSON file
        metadata_file = checkpoint_dir / "epoch_checkpoints.json"
        with open(metadata_file, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        print(f"\n✓ JSON metadata file: {metadata_file.name}")
        print(f" next_slot: {metadata['next_slot']}")
        print(f" total_epochs_completed: {metadata['total_epochs_completed']}")

        print(f"\n✓ Number of saved checkpoints: {len(metadata['checkpoints'])}")
        print(" expected: 3")

        sorted_checkpoints = sorted(metadata["checkpoints"], key=lambda x: x["epoch"])
        print(f"\n✓ Saved epochs: {[cp['epoch'] for cp in sorted_checkpoints]}")
        print(" expected: [8, 9, 10]")

        # Check which checkpoint is reported as the latest
        latest = manager.get_latest_epoch_checkpoint()
        print(f"\n✓ Latest checkpoint: epoch {latest['epoch']} ({latest['file']})")
        print(" expected: epoch 10")

        # Show the epoch-to-file mapping
        print("\n✓ Epoch-to-file mapping:")
        for cp in sorted_checkpoints:
            print(f" epoch {cp['epoch']:2d} -> {cp['file']}")

        # Assertions
        assert len(metadata["checkpoints"]) == 3, "exactly 3 checkpoints should be kept"
        assert [cp["epoch"] for cp in sorted_checkpoints] == [8, 9, 10], (
            "the last 3 epochs should be kept"
        )
        assert latest["epoch"] == 10, "the latest should be epoch 10"

        print("\n" + "=" * 70)
        print("✓✓✓ All tests passed!")
        print("=" * 70)
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_resume():
    """Test the auto-resume logic."""
    print("\n" + "=" * 70)
    print("Test: restoring checkpoint order from JSON")
    print("=" * 70)

    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_dir = Path(tmpdir)

        # Create a manager and save a few checkpoints
        manager1 = EpochCheckpointManager(checkpoint_dir)
        print("\nFirst training run (saving epochs 1-5):")
        print("-" * 70)
        for epoch in range(1, 6):
            manager1.save_epoch_checkpoint(epoch)

        # Simulate a new training session (restored from JSON)
        print("\n\nSecond training run (restored from JSON):")
        print("-" * 70)
        manager2 = EpochCheckpointManager(checkpoint_dir)
        manager2._load_epoch_metadata()

        print(f"\n✓ Restored checkpoint count: {len(manager2.epoch_checkpoints)}")
        print(f"✓ Restored epochs: {[cp['epoch'] for cp in manager2.epoch_checkpoints]}")
        print(f"✓ next_slot: {manager2.next_epoch_slot}")

        # Continue saving
        print("\nContinuing with epochs 6-10:")
        print("-" * 70)
        for epoch in range(6, 11):
            manager2.save_epoch_checkpoint(epoch)

        # Verify the final state
        print("\n" + "=" * 70)
        print("Final state check:")
        print("=" * 70)

        sorted_checkpoints = manager2.get_epoch_checkpoints()
        print(f"\n✓ Final saved epochs: {[cp['epoch'] for cp in sorted_checkpoints]}")
        print(" expected: [8, 9, 10]")

        latest = manager2.get_latest_epoch_checkpoint()
        print(f"✓ Latest checkpoint: epoch {latest['epoch']}")
        print(" expected: epoch 10")

        # Assertions
        assert [cp["epoch"] for cp in sorted_checkpoints] == [8, 9, 10], (
            "the last 3 epochs should be kept"
        )
        assert latest["epoch"] == 10, "the latest should be epoch 10"

        print("\n" + "=" * 70)
        print("✓✓✓ Auto-resume test passed!")
        print("=" * 70)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    test_circular_save()
    test_auto_resume()

    print("\n" + "=" * 70)
    print("All tests finished!")
    print("=" * 70)
    print("\nFeature summary:")
    print(" 1. ✓ Rotating overwrite keeps only the last 3 epoch checkpoints")
    print(" 2. ✓ Checkpoint metadata is recorded in a JSON file")
    print(" 3. ✓ Checkpoint order can be restored from the JSON file")
    print(" 4. ✓ The latest checkpoint is identified correctly")
    print("\nFiles kept:")
    print(" - best_model.pt (best model)")
    print(" - latest_checkpoint.pt (periodic save)")
    print(" - epoch_checkpoint_1.pt, epoch_checkpoint_2.pt, epoch_checkpoint_3.pt")
    print(" - epoch_checkpoints.json (metadata)")
    print("\nFiles removed:")
    print(" - final_model.pt (duplicates epoch_last)")
    print(" - epoch_*.pt (a separate file is no longer created per epoch)")
    print("=" * 70)
|
||||||
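The mock tests above cover the rotation and the metadata reload, but not the precedence between `resume_from` and `auto_resume` that this commit adds to `Trainer.train`. That precedence could be sketched as a pure function. This is an illustration under assumptions, not the trainer's code: `pick_resume_path` is a hypothetical helper, the paths are made up, and it assumes the trainer has already loaded `epoch_checkpoints.json` into a list of dicts.

```python
from typing import Dict, List, Optional


def pick_resume_path(
    resume_from: Optional[str],
    auto_resume: bool,
    epoch_checkpoints: List[Dict],
) -> Optional[str]:
    """Return the checkpoint path to load, or None to start from scratch."""
    # An explicitly supplied checkpoint always takes priority.
    if resume_from is not None:
        return resume_from
    # Otherwise fall back to the recorded checkpoint with the highest epoch.
    if auto_resume and epoch_checkpoints:
        latest = max(epoch_checkpoints, key=lambda cp: cp["epoch"])
        return latest["path"]
    return None


# Illustrative metadata: list order follows slots, not epochs.
recorded = [
    {"epoch": 10, "path": "ckpt/epoch_checkpoint_1.pt"},
    {"epoch": 8, "path": "ckpt/epoch_checkpoint_2.pt"},
    {"epoch": 9, "path": "ckpt/epoch_checkpoint_3.pt"},
]

explicit = pick_resume_path("ckpt/best_model.pt", True, recorded)
automatic = pick_resume_path(None, True, recorded)
disabled = pick_resume_path(None, False, recorded)
```

Since the `--auto-resume/--no-auto-resume` option defaults to `True` in the diff, rerunning in an output directory that already holds epoch metadata would resume from the latest epoch unless `--no-auto-resume` or an explicit `resume_from` is given.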