188 lines
5.1 KiB
Markdown
188 lines
5.1 KiB
Markdown
# Checkpoint 保存策略更新说明
|
||
|
||
## 更新概述
|
||
|
||
为了解决硬盘空间限制问题,修改了模型 checkpoint 的保存方式,采用**同名覆盖**的方式仅保留**最后 3 个 epoch**的 checkpoint。
|
||
|
||
## 主要变更
|
||
|
||
### 1. 保存策略
|
||
|
||
#### 保留的文件
|
||
- ✅ `best_model.pt` - 最佳模型(当 eval_loss 降低时保存)
|
||
- ✅ `latest_checkpoint.pt` - 定期保存的检查点(每 save_frequency 步覆盖保存)
|
||
- ✅ `epoch_checkpoint_1.pt` - 循环保存的最后 3 个 epoch 之一
|
||
- ✅ `epoch_checkpoint_2.pt` - 循环保存的最后 3 个 epoch 之一
|
||
- ✅ `epoch_checkpoint_3.pt` - 循环保存的最后 3 个 epoch 之一
|
||
- ✅ `epoch_checkpoints.json` - epoch checkpoint 元数据文件
|
||
|
||
#### 移除的文件
|
||
- ❌ `final_model.pt` - 与 epoch_last 重复,已移除
|
||
- ❌ `epoch_*.pt` - 不再为每个 epoch 创建独立文件
|
||
|
||
### 2. 循环覆盖机制
|
||
|
||
使用 3 个固定文件名循环覆盖:
|
||
```
|
||
Epoch 1 -> epoch_checkpoint_1.pt
|
||
Epoch 2 -> epoch_checkpoint_2.pt
|
||
Epoch 3 -> epoch_checkpoint_3.pt
|
||
Epoch 4 -> epoch_checkpoint_1.pt (覆盖 epoch 1)
|
||
Epoch 5 -> epoch_checkpoint_2.pt (覆盖 epoch 2)
|
||
...
|
||
```
|
||
|
||
### 3. JSON 元数据管理
|
||
|
||
`epoch_checkpoints.json` 文件记录:
|
||
```json
|
||
{
|
||
"checkpoints": [
|
||
{
|
||
"epoch": 8,
|
||
"file": "epoch_checkpoint_2.pt",
|
||
"path": "/path/to/checkpoints/epoch_checkpoint_2.pt",
|
||
"saved_at": "2026-04-11T10:30:00",
|
||
"step": 800
|
||
},
|
||
{
|
||
"epoch": 9,
|
||
"file": "epoch_checkpoint_3.pt",
|
||
"path": "/path/to/checkpoints/epoch_checkpoint_3.pt",
|
||
"saved_at": "2026-04-11T10:35:00",
|
||
"step": 900
|
||
},
|
||
{
|
||
"epoch": 10,
|
||
"file": "epoch_checkpoint_1.pt",
|
||
"path": "/path/to/checkpoints/epoch_checkpoint_1.pt",
|
||
"saved_at": "2026-04-11T10:40:00",
|
||
"step": 1000
|
||
}
|
||
],
|
||
"next_slot": 1,
|
||
"total_epochs_completed": 10
|
||
}
|
||
```
|
||
|
||
### 4. 自动恢复功能
|
||
|
||
训练时会自动从 JSON 文件中读取最新的 checkpoint 并恢复训练:
|
||
|
||
```python
|
||
# 训练时自动恢复
|
||
trainer.train(
|
||
resume_from=None, # 如果指定,优先使用指定的 checkpoint
|
||
reset_training_state=False,
|
||
auto_resume=True # 自动从最新的 epoch checkpoint 恢复
|
||
)
|
||
```
|
||
|
||
**恢复优先级:**
|
||
1. 如果指定了 `resume_from`,使用指定的 checkpoint
|
||
2. 否则,如果存在 epoch checkpoint 元数据,自动从最新的 epoch checkpoint 恢复
|
||
3. 否则,从头开始训练
|
||
|
||
## 使用方法
|
||
|
||
### 查看保存的 checkpoint
|
||
|
||
训练结束后,可以通过 JSON 文件查看保存的 checkpoint:
|
||
|
||
```bash
|
||
# 查看元数据
|
||
cat ./output/checkpoints/epoch_checkpoints.json
|
||
|
||
# 或者使用 Python
|
||
python -c "
|
||
import json
|
||
with open('./output/checkpoints/epoch_checkpoints.json') as f:
|
||
data = json.load(f)
|
||
print('保存的 epochs:', [cp['epoch'] for cp in data['checkpoints']])
|
||
print('最新的 checkpoint:', max(data['checkpoints'], key=lambda x: x['epoch']))
|
||
"
|
||
```
|
||
|
||
### 手动加载特定 epoch 的 checkpoint
|
||
|
||
```python
|
||
import torch
|
||
from src.model.trainer import Trainer
|
||
|
||
# 加载元数据
|
||
import json
|
||
with open('./output/checkpoints/epoch_checkpoints.json') as f:
|
||
data = json.load(f)
|
||
|
||
# 按 epoch 排序
|
||
sorted_checkpoints = sorted(data['checkpoints'], key=lambda x: x['epoch'])
|
||
print('可用的 epochs:', [cp['epoch'] for cp in sorted_checkpoints])
|
||
|
||
# 加载特定 epoch 的 checkpoint
|
||
target_epoch = 9
|
||
checkpoint_info = next(cp for cp in sorted_checkpoints if cp['epoch'] == target_epoch)
|
||
checkpoint_path = checkpoint_info['path']
|
||
|
||
# 使用 trainer 加载
|
||
trainer.load_checkpoint(checkpoint_path)
|
||
```
|
||
|
||
### 禁用自动恢复
|
||
|
||
如果需要从头开始训练,可以禁用自动恢复:
|
||
|
||
```bash
|
||
# 在代码中设置
|
||
trainer.train(auto_resume=False)
|
||
```
|
||
|
||
## 优势
|
||
|
||
1. **节省磁盘空间** - 只保留最后 3 个 epoch,不会随训练时间增长而占用更多空间
|
||
2. **自动管理** - 无需手动删除旧的 checkpoint
|
||
3. **顺序清晰** - 通过 JSON 文件可以清楚知道每个 checkpoint 对应的 epoch
|
||
4. **自动恢复** - 训练中断后可以从最近的 checkpoint 自动恢复
|
||
5. **保留重要 checkpoint** - best_model 和 latest_checkpoint 仍然保留
|
||
|
||
## 磁盘空间对比
|
||
|
||
### 修改前
|
||
```
|
||
假设训练 100 个 epoch,每个 checkpoint 100MB:
|
||
- epoch_1.pt ~ epoch_100.pt: 100 * 100MB = 10GB
|
||
- best_model.pt: 100MB
|
||
- final_model.pt: 100MB
|
||
- latest_checkpoint.pt: 100MB
|
||
总计:~10.3GB
|
||
```
|
||
|
||
### 修改后
|
||
```
|
||
训练 100 个 epoch,每个 checkpoint 100MB:
|
||
- epoch_checkpoint_1~3.pt: 3 * 100MB = 300MB
|
||
- best_model.pt: 100MB
|
||
- latest_checkpoint.pt: 100MB
|
||
- epoch_checkpoints.json: <1KB
|
||
总计:~400MB
|
||
|
||
节省空间:~9.9GB (96% 空间节省)
|
||
```
|
||
|
||
## 注意事项
|
||
|
||
1. **训练中断** - 如果训练在中途被中断,只会保留中断前的最后 3 个 epoch
|
||
2. **JSON 文件** - 不要手动修改 `epoch_checkpoints.json` 文件,否则可能导致恢复失败
|
||
3. **兼容性** - 旧的 `epoch_*.pt` 文件不会自动删除,如果需要可以手动清理
|
||
|
||
## 测试
|
||
|
||
运行测试脚本验证功能:
|
||
```bash
|
||
python test_epoch_checkpoint.py
|
||
```
|
||
|
||
## 相关文件
|
||
|
||
- `src/model/trainer.py` - 主要修改文件
|
||
- `test_epoch_checkpoint.py` - 功能测试脚本
|