SUimeModelTraner/CHECKPOINT_CHANGES.md

188 lines
5.1 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Checkpoint 保存策略更新说明
## 更新概述
为了解决硬盘空间限制问题,修改了模型 checkpoint 的保存方式,采用**同名覆盖**的方式仅保留**最后 3 个 epoch**的 checkpoint。
## 主要变更
### 1. 保存策略
#### 保留的文件
-`best_model.pt` - 最佳模型(当 eval_loss 降低时保存)
-`latest_checkpoint.pt` - 定期保存的检查点(每 save_frequency 步覆盖保存)
-`epoch_checkpoint_1.pt` - 循环保存的最后 3 个 epoch 之一
-`epoch_checkpoint_2.pt` - 循环保存的最后 3 个 epoch 之一
-`epoch_checkpoint_3.pt` - 循环保存的最后 3 个 epoch 之一
-`epoch_checkpoints.json` - epoch checkpoint 元数据文件
#### 移除的文件
-`final_model.pt` - 与 epoch_last 重复,已移除
-`epoch_*.pt` - 不再为每个 epoch 创建独立文件
### 2. 循环覆盖机制
使用 3 个固定文件名循环覆盖:
```
Epoch 1 -> epoch_checkpoint_1.pt
Epoch 2 -> epoch_checkpoint_2.pt
Epoch 3 -> epoch_checkpoint_3.pt
Epoch 4 -> epoch_checkpoint_1.pt (覆盖 epoch 1)
Epoch 5 -> epoch_checkpoint_2.pt (覆盖 epoch 2)
...
```
### 3. JSON 元数据管理
`epoch_checkpoints.json` 文件记录:
```json
{
"checkpoints": [
{
"epoch": 8,
"file": "epoch_checkpoint_2.pt",
"path": "/path/to/checkpoints/epoch_checkpoint_2.pt",
"saved_at": "2026-04-11T10:30:00",
"step": 800
},
{
"epoch": 9,
"file": "epoch_checkpoint_3.pt",
"path": "/path/to/checkpoints/epoch_checkpoint_3.pt",
"saved_at": "2026-04-11T10:35:00",
"step": 900
},
{
"epoch": 10,
"file": "epoch_checkpoint_1.pt",
"path": "/path/to/checkpoints/epoch_checkpoint_1.pt",
"saved_at": "2026-04-11T10:40:00",
"step": 1000
}
],
"next_slot": 1,
"total_epochs_completed": 10
}
```
### 4. 自动恢复功能
训练时会自动从 JSON 文件中读取最新的 checkpoint 并恢复训练:
```python
# 训练时自动恢复
trainer.train(
resume_from=None, # 如果指定,优先使用指定的 checkpoint
reset_training_state=False,
auto_resume=True # 自动从最新的 epoch checkpoint 恢复
)
```
**恢复优先级:**
1. 如果指定了 `resume_from`,使用指定的 checkpoint
2. 否则,如果存在 epoch checkpoint 元数据,自动从最新的 epoch checkpoint 恢复
3. 否则,从头开始训练
## 使用方法
### 查看保存的 checkpoint
训练结束后,可以通过 JSON 文件查看保存的 checkpoint
```bash
# 查看元数据
cat ./output/checkpoints/epoch_checkpoints.json
# 或者使用 Python
python -c "
import json
with open('./output/checkpoints/epoch_checkpoints.json') as f:
data = json.load(f)
print('保存的 epochs:', [cp['epoch'] for cp in data['checkpoints']])
print('最新的 checkpoint:', max(data['checkpoints'], key=lambda x: x['epoch']))
"
```
### 手动加载特定 epoch 的 checkpoint
```python
import torch
from src.model.trainer import Trainer
# 加载元数据
import json
with open('./output/checkpoints/epoch_checkpoints.json') as f:
data = json.load(f)
# 按 epoch 排序
sorted_checkpoints = sorted(data['checkpoints'], key=lambda x: x['epoch'])
print('可用的 epochs:', [cp['epoch'] for cp in sorted_checkpoints])
# 加载特定 epoch 的 checkpoint
target_epoch = 9
checkpoint_info = next(cp for cp in sorted_checkpoints if cp['epoch'] == target_epoch)
checkpoint_path = checkpoint_info['path']
# 使用 trainer 加载
trainer.load_checkpoint(checkpoint_path)
```
### 禁用自动恢复
如果需要从头开始训练,可以禁用自动恢复:
```bash
# 在代码中设置
trainer.train(auto_resume=False)
```
## 优势
1. **节省磁盘空间** - 只保留最后 3 个 epoch不会随训练时间增长而占用更多空间
2. **自动管理** - 无需手动删除旧的 checkpoint
3. **顺序清晰** - 通过 JSON 文件可以清楚知道每个 checkpoint 对应的 epoch
4. **自动恢复** - 训练中断后可以从最近的 checkpoint 自动恢复
5. **保留重要 checkpoint** - best_model 和 latest_checkpoint 仍然保留
## 磁盘空间对比
### 修改前
```
假设训练 100 个 epoch每个 checkpoint 100MB:
- epoch_1.pt ~ epoch_100.pt: 100 * 100MB = 10GB
- best_model.pt: 100MB
- final_model.pt: 100MB
- latest_checkpoint.pt: 100MB
总计:~10.3GB
```
### 修改后
```
训练 100 个 epoch每个 checkpoint 100MB:
- epoch_checkpoint_1~3.pt: 3 * 100MB = 300MB
- best_model.pt: 100MB
- latest_checkpoint.pt: 100MB
- epoch_checkpoints.json: <1KB
总计:~400MB
节省空间:~9.9GB (96% 空间节省)
```
## 注意事项
1. **训练中断** - 如果训练在中途被中断,只会保留中断前的最后 3 个 epoch
2. **JSON 文件** - 不要手动修改 `epoch_checkpoints.json` 文件,否则可能导致恢复失败
3. **兼容性** - 旧的 `epoch_*.pt` 文件不会自动删除,如果需要可以手动清理
## 测试
运行测试脚本验证功能:
```bash
python test_epoch_checkpoint.py
```
## 相关文件
- `src/model/trainer.py` - 主要修改文件
- `test_epoch_checkpoint.py` - 功能测试脚本