feat(checkpoint): optimize the checkpoint saving strategy to keep the last 3 epochs and save disk space

songsenand 2026-04-11 13:19:49 +08:00
parent 1cdef19153
commit 0fea985b45
7 changed files with 1149 additions and 185 deletions

187
CHECKPOINT_CHANGES.md Normal file

@ -0,0 +1,187 @@
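# Checkpoint Saving Strategy Update Notes
## Overview
To work around limited disk space, model checkpoints are now saved by **overwriting fixed file names**, so that only the checkpoints of the **last 3 epochs** are kept.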
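## Main Changes
### 1. Saving strategy
#### Files kept
- ✅ `best_model.pt` - the best model (saved whenever eval_loss decreases)
- ✅ `latest_checkpoint.pt` - the periodic checkpoint (overwritten every save_frequency steps)
- ✅ `epoch_checkpoint_1.pt` - one of the last 3 epochs, saved in rotation
- ✅ `epoch_checkpoint_2.pt` - one of the last 3 epochs, saved in rotation
- ✅ `epoch_checkpoint_3.pt` - one of the last 3 epochs, saved in rotation
- ✅ `epoch_checkpoints.json` - metadata file for the epoch checkpoints
#### Files removed
- ❌ `final_model.pt` - redundant with the last epoch checkpoint (epoch_last), so it was removed
- ❌ `epoch_*.pt` - a separate file is no longer created for every epoch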
### 2. Rotating overwrite mechanism
Three fixed file names are overwritten in rotation:
```
Epoch 1 -> epoch_checkpoint_1.pt
Epoch 2 -> epoch_checkpoint_2.pt
Epoch 3 -> epoch_checkpoint_3.pt
Epoch 4 -> epoch_checkpoint_1.pt (overwrites epoch 1)
Epoch 5 -> epoch_checkpoint_2.pt (overwrites epoch 2)
...
```
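The rotation reduces to modular arithmetic on the epoch number. The following is a minimal sketch, assuming an uninterrupted run that starts at epoch 1 with no existing metadata. The names `slot_filename` and `surviving_epochs` are illustrative and are not part of `trainer.py`, which tracks the position with a persisted `next_slot` counter instead.

```python
from typing import List

NUM_SLOTS = 3  # the strategy keeps the last 3 epochs


def slot_filename(epoch: int, num_slots: int = NUM_SLOTS) -> str:
    """Return the file a 1-based epoch is written to (hypothetical helper)."""
    if epoch < 1:
        raise ValueError("epoch numbers start at 1")
    slot = (epoch - 1) % num_slots  # 0-based slot index
    return f"epoch_checkpoint_{slot + 1}.pt"


def surviving_epochs(last_epoch: int, num_slots: int = NUM_SLOTS) -> List[int]:
    """Return the epochs still on disk after `last_epoch` has been saved."""
    first = max(1, last_epoch - num_slots + 1)
    return list(range(first, last_epoch + 1))


for epoch in range(1, 6):
    print(f"Epoch {epoch} -> {slot_filename(epoch)}")
```

If a run is resumed from a metadata file, the slot should be read from `next_slot` rather than recomputed from the epoch number, since the two only agree under the assumption stated above.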
### 3. JSON metadata management
The `epoch_checkpoints.json` file records:
```json
{
  "checkpoints": [
    {
      "epoch": 8,
      "file": "epoch_checkpoint_2.pt",
      "path": "/path/to/checkpoints/epoch_checkpoint_2.pt",
      "saved_at": "2026-04-11T10:30:00",
      "step": 800
    },
    {
      "epoch": 9,
      "file": "epoch_checkpoint_3.pt",
      "path": "/path/to/checkpoints/epoch_checkpoint_3.pt",
      "saved_at": "2026-04-11T10:35:00",
      "step": 900
    },
    {
      "epoch": 10,
      "file": "epoch_checkpoint_1.pt",
      "path": "/path/to/checkpoints/epoch_checkpoint_1.pt",
      "saved_at": "2026-04-11T10:40:00",
      "step": 1000
    }
  ],
  "next_slot": 1,
  "total_epochs_completed": 10
}
```
The entries are shown sorted by epoch for readability. `save_epoch_checkpoint` stores each entry at its slot position, so the order in the actual file can differ; sort by `epoch` when reading.
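Because resuming depends on this file, a quick structural check can catch accidental damage before training starts. The sketch below is an assumption about what such a check could look like: `validate_metadata` and `REQUIRED_KEYS` are hypothetical names, and nothing like this exists in `trainer.py`.

```python
from typing import Any, Dict, List

# Keys every checkpoint entry carries in the example metadata
REQUIRED_KEYS = {"epoch", "file", "path", "saved_at", "step"}


def validate_metadata(metadata: Dict[str, Any], num_slots: int = 3) -> List[str]:
    """Return a list of problems; an empty list means none were found."""
    checkpoints = metadata.get("checkpoints")
    if not isinstance(checkpoints, list):
        return ["'checkpoints' must be a list"]

    problems = []
    if len(checkpoints) > num_slots:
        problems.append(
            f"expected at most {num_slots} entries, found {len(checkpoints)}"
        )
    for index, entry in enumerate(checkpoints):
        if not isinstance(entry, dict):
            problems.append(f"entry {index} is not an object")
            continue
        missing = REQUIRED_KEYS - set(entry)
        if missing:
            problems.append(f"entry {index} is missing {sorted(missing)}")

    next_slot = metadata.get("next_slot")
    if not isinstance(next_slot, int) or not 0 <= next_slot < num_slots:
        problems.append(f"'next_slot' must be an integer in [0, {num_slots})")
    return problems
```

Calling `validate_metadata(json.load(f))` on the file and refusing to auto-resume when the list is non-empty would be one conservative way to use it.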
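### 4. Automatic resume
When training starts, the latest checkpoint is read from the JSON file and training resumes from it automatically:
```python
# Resume automatically when training starts
trainer.train(
    resume_from=None,  # if set, this checkpoint takes priority
    reset_training_state=False,
    auto_resume=True,  # resume from the latest epoch checkpoint automatically
)
```
**Resume priority:**
1. If `resume_from` is given, that checkpoint is used
2. Otherwise, if `auto_resume` is enabled and epoch checkpoint metadata exists, training resumes from the latest epoch checkpoint
3. Otherwise, training starts from scratch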
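The priority order can be expressed as a small pure function. This is a sketch for illustration only: `resolve_resume_path` is a hypothetical name, and in the commit the same decision is made inline at the top of `Trainer.train`.

```python
from typing import Dict, List, Optional


def resolve_resume_path(
    resume_from: Optional[str],
    auto_resume: bool,
    epoch_checkpoints: List[Dict],
) -> Optional[str]:
    """Return the checkpoint path to load, or None to start from scratch."""
    # 1. An explicitly requested checkpoint always wins
    if resume_from is not None:
        return resume_from
    # 2. Otherwise fall back to the newest epoch checkpoint, if allowed
    if auto_resume and epoch_checkpoints:
        latest = max(epoch_checkpoints, key=lambda cp: cp["epoch"])
        return latest["path"]
    # 3. Nothing to resume from
    return None
```

Selecting by the largest `epoch` rather than by list position matters because entries are stored by slot, so the newest checkpoint is not necessarily the last element.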
## Usage
### Inspecting the saved checkpoints
After training finishes, the saved checkpoints can be inspected through the JSON file:
```bash
# Inspect the metadata
cat ./output/checkpoints/epoch_checkpoints.json
# Or use Python
python -c "
import json
with open('./output/checkpoints/epoch_checkpoints.json') as f:
    data = json.load(f)
print('Saved epochs:', [cp['epoch'] for cp in data['checkpoints']])
print('Latest checkpoint:', max(data['checkpoints'], key=lambda x: x['epoch']))
"
```
### Manually loading the checkpoint of a specific epoch
```python
import json

from src.model.trainer import Trainer

# Load the metadata
with open('./output/checkpoints/epoch_checkpoints.json') as f:
    data = json.load(f)
# Sort by epoch
sorted_checkpoints = sorted(data['checkpoints'], key=lambda x: x['epoch'])
print('Available epochs:', [cp['epoch'] for cp in sorted_checkpoints])
# Pick the checkpoint of a specific epoch
target_epoch = 9
checkpoint_info = next(
    (cp for cp in sorted_checkpoints if cp['epoch'] == target_epoch), None
)
if checkpoint_info is None:
    raise ValueError(f'No checkpoint saved for epoch {target_epoch}')
checkpoint_path = checkpoint_info['path']
# Load it with the trainer (`trainer` is an already constructed Trainer instance)
trainer.load_checkpoint(checkpoint_path)
```
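### Disabling automatic resume
To train from scratch, disable automatic resume:
```python
# Set it in code
trainer.train(auto_resume=False)
```
On the command line, the same behavior is selected with the `--no-auto-resume` flag.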
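## Advantages
1. **Saves disk space** - only the last 3 epochs are kept, so usage does not grow with training time
2. **Automatic management** - old checkpoints no longer need to be deleted by hand
3. **Clear ordering** - the JSON file shows which epoch each checkpoint belongs to
4. **Automatic resume** - after an interruption, training can resume from the most recent checkpoint
5. **Important checkpoints are kept** - best_model and latest_checkpoint are still saved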
## Disk Space Comparison
### Before
```
Assume 100 training epochs at 100MB per checkpoint:
- epoch_1.pt ~ epoch_100.pt: 100 * 100MB = 10GB
- best_model.pt: 100MB
- final_model.pt: 100MB
- latest_checkpoint.pt: 100MB
Total: ~10.3GB
```
### After
```
100 training epochs at 100MB per checkpoint:
- epoch_checkpoint_1~3.pt: 3 * 100MB = 300MB
- best_model.pt: 100MB
- latest_checkpoint.pt: 100MB
- epoch_checkpoints.json: <1KB
Total: ~500MB
Space saved: ~9.8GB (about 95%)
```
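The arithmetic behind the comparison can be checked in a few lines. This sketch uses the same assumed figures (100 epochs, 100MB per checkpoint) and ignores the metadata file, which is under 1KB. After the change there are five 100MB files: three rotating epoch files plus `best_model.pt` and `latest_checkpoint.pt`.

```python
CHECKPOINT_MB = 100  # assumed size of one checkpoint
EPOCHS = 100         # assumed number of training epochs

# Before: one file per epoch, plus best_model, final_model and latest_checkpoint
before_mb = (EPOCHS + 3) * CHECKPOINT_MB
# After: 3 rotating epoch files, plus best_model and latest_checkpoint
after_mb = (3 + 2) * CHECKPOINT_MB

saved_mb = before_mb - after_mb
saved_percent = 100 * saved_mb / before_mb

print(
    f"before: {before_mb} MB, after: {after_mb} MB, "
    f"saved: {saved_mb} MB ({saved_percent:.0f}%)"
)
# before: 10300 MB, after: 500 MB, saved: 9800 MB (95%)
```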
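## Notes
1. **Interrupted training** - if training is interrupted partway, only the last 3 epochs before the interruption are kept
2. **JSON file** - do not edit `epoch_checkpoints.json` by hand, otherwise resuming may fail
3. **Compatibility** - old `epoch_*.pt` files are not deleted automatically; clean them up manually if needed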
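Manual cleanup needs some care, because the glob `epoch_*.pt` also matches the new `epoch_checkpoint_N.pt` files. The sketch below filters with a stricter pattern and defaults to a dry run. The helper names are hypothetical, and the legacy naming `epoch_<number>.pt` is taken from the previous `save_checkpoint(f"epoch_{epoch + 1}.pt")` call.

```python
import re
from pathlib import Path
from typing import List

# Matches epoch_1.pt, epoch_42.pt, ... but not epoch_checkpoint_1.pt
LEGACY_PATTERN = re.compile(r"^epoch_\d+\.pt$")


def find_legacy_checkpoints(checkpoint_dir: Path) -> List[Path]:
    """List per-epoch files written by the old saving strategy."""
    matches = (
        path
        for path in checkpoint_dir.glob("epoch_*.pt")
        if LEGACY_PATTERN.match(path.name)
    )
    return sorted(matches, key=lambda path: path.name)


def remove_legacy_checkpoints(checkpoint_dir: Path, dry_run: bool = True) -> List[Path]:
    """Delete legacy files; with dry_run=True only report what would go."""
    targets = find_legacy_checkpoints(checkpoint_dir)
    if not dry_run:
        for path in targets:
            path.unlink()
    return targets
```

Running `remove_legacy_checkpoints(Path("./output/checkpoints"))` first and reviewing the returned list before passing `dry_run=False` keeps the operation reversible until the last step.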
## Testing
Run the test script to verify the feature:
```bash
python test_epoch_checkpoint.py
```
## Related Files
- `src/model/trainer.py` - the main modified file
- `test_epoch_checkpoint.py` - functional test script


@ -1,147 +1,22 @@
#!/usr/bin/env python3
"""
快速检查模型权重加载情况的脚本
"""
from model.dataset import PinyinInputDataset
from torch.utils.data import DataLoader
from pathlib import Path
import numpy as np
import torch
from model.trainer import collate_fn, worker_init_fn
def analyze_checkpoint(checkpoint_path):
"""分析checkpoint文件"""
print(f"🔍 分析checkpoint: {checkpoint_path}")
data = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
if not Path(checkpoint_path).exists():
print(f"❌ 文件不存在")
return
dataloader = DataLoader(
data,
batch_size=1024,
num_workers=2,
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=2, # 减少预取以避免内存问题
persistent_workers=True,
shuffle=False,
)
try:
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(f"✅ 加载成功")
print(f" 类型: {type(checkpoint)}")
if isinstance(checkpoint, dict):
print(f" 键名: {list(checkpoint.keys())}")
# 找到模型状态字典
state_dict = None
if "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
print(f" 🔍 使用'model_state_dict'")
elif "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
print(f" 🔍 使用'state_dict'")
else:
# 可能是直接的状态字典
state_dict = checkpoint
print(f" 🔍 使用直接状态字典")
if state_dict:
print(f" 总权重数: {len(state_dict)}")
# 分析分类头权重
classifier_keys = []
for key in state_dict.keys():
if "classifier" in key:
classifier_keys.append(key)
if classifier_keys:
print(f" 📊 分类头相关权重:")
for key in classifier_keys:
weight = state_dict[key]
print(f" {key}: shape={weight.shape}")
print(f" 范围: [{weight.min():.6f}, {weight.max():.6f}]")
print(f" 均值: {weight.mean():.6f}")
print(f" 标准差: {weight.std():.6f}")
# 检查权重是否接近随机初始化
if weight.std() < 0.01:
print(f" ⚠️ 警告: 权重标准差很小,可能未正确训练")
# 检查模型架构键名
print(f"\n 🔑 模型架构键名示例前20个:")
for i, key in enumerate(list(state_dict.keys())[:20]):
weight = state_dict[key]
print(f" {i + 1:2d}. {key:40} shape={str(weight.shape):15}")
# 检查是否有预期的组件
expected_components = [
"context_encoder",
"slot_memory",
"cross_attn",
"moe",
"classifier",
]
found_components = []
for comp in expected_components:
found = any(comp in key for key in state_dict.keys())
if found:
found_components.append(comp)
print(f"\n 📋 找到的模型组件: {found_components}")
missing = set(expected_components) - set(found_components)
if missing:
print(f" ❌ 缺失的组件: {missing}")
return state_dict
else:
print(f"❌ checkpoint不是字典类型")
except Exception as e:
print(f"❌ 加载失败: {e}")
import traceback
traceback.print_exc()
def check_weight_distribution(state_dict):
"""检查权重分布"""
print(f"\n📊 权重分布统计:")
weight_stats = []
for key, weight in state_dict.items():
if "weight" in key and len(weight.shape) >= 2: # 只检查权重矩阵,不包括偏置
stats = {
"key": key,
"shape": weight.shape,
"min": weight.min().item(),
"max": weight.max().item(),
"mean": weight.mean().item(),
"std": weight.std().item(),
"abs_mean": weight.abs().mean().item(),
}
weight_stats.append(stats)
# 打印前10个权重
for i, stats in enumerate(weight_stats[:10]):
print(f" {i + 1:2d}. {stats['key']:40}")
print(f" 形状: {stats['shape']}")
print(f" 范围: [{stats['min']:.6f}, {stats['max']:.6f}]")
print(f" 均值: {stats['mean']:.6f} ± {stats['std']:.6f}")
# 检查是否接近随机初始化
if stats["std"] < 0.01:
print(f" ⚠️ 警告: 标准差很小,可能未训练")
return weight_stats
def main():
import sys
if len(sys.argv) < 2:
print("使用方法: python check_weights.py <checkpoint_path>")
print("示例: python check_weights.py ./output/checkpoints/best_model.pt")
return
checkpoint_path = sys.argv[1]
state_dict = analyze_checkpoint(checkpoint_path)
if state_dict:
check_weight_distribution(state_dict)
if __name__ == "__main__":
main()
for i in dataloader:
print((i['labels'] == 1).sum())
break


@ -8,7 +8,7 @@
"id": 0,
"char": "",
"pinyin": "",
"count": 434748360
"count": 11067734826
},
"1": {
"id": 1,
@ -142549,4 +142549,4 @@
"compressed": false,
"pair_count": 20646
}
}
}

526
src/model/interactive_ui.py Normal file

@ -0,0 +1,526 @@
"""
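Interactive training configuration UI.
Provides a terminal-style interactive interface for configuring the training parameters of the input method model.
The terminal interface is built with the Rich library.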
"""
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.prompt import Prompt, Confirm, IntPrompt, FloatPrompt
from rich.layout import Layout
from rich.columns import Columns
from rich.text import Text
from rich import box
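class TrainingConfigUI:
    """
    Interactive training configuration UI.

    Provides a terminal-style interactive interface with parameters grouped by how essential they are:
    1. Required parameters (must be provided)
    2. Recommended parameters (sensible defaults)
    3. Optional parameters (advanced)
    4. Recovery parameters (related to resuming training)
    """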
def __init__(self, console: Optional[Console] = None):
"""
初始化交互式界面
Args:
console: Rich控制台实例如果为None则创建新的
"""
self.console = console or Console()
self.config = {}
# 参数定义
self.param_definitions = self._load_param_definitions()
# 颜色主题
self.colors = {
"primary": "cyan",
"secondary": "green",
"accent": "yellow",
"warning": "red",
"success": "green",
"info": "blue",
}
def _load_param_definitions(self) -> Dict[str, List[Dict]]:
"""
加载参数定义
Returns:
按组分类的参数定义
"""
return {
"required": [
{
"name": "train_data_path",
"type": "path",
"prompt": "训练数据集路径",
"required": True,
"validate": self._validate_path,
},
{
"name": "eval_data_path",
"type": "path",
"prompt": "评估数据集路径",
"required": True,
"validate": self._validate_path,
},
],
"recommended": [
{
"name": "output_dir",
"type": "path",
"prompt": "输出目录",
"default": "./output",
"required": False,
},
{
"name": "batch_size",
"type": "int",
"prompt": "批次大小",
"default": 128,
"required": False,
"validate": lambda x: x > 0,
},
{
"name": "num_epochs",
"type": "int",
"prompt": "训练轮数",
"default": 10,
"required": False,
"validate": lambda x: x > 0,
},
{
"name": "learning_rate",
"type": "float",
"prompt": "学习率",
"default": 1e-5,
"required": False,
"validate": lambda x: 1e-9 <= x <= 1,
},
{
"name": "mixed_precision",
"type": "bool",
"prompt": "使用混合精度训练",
"default": True,
"required": False,
},
{
"name": "compile",
"type": "bool",
"prompt": "启用torch.compile优化",
"default": False,
"required": False,
},
{
"name": "use_tensorboard",
"type": "bool",
"prompt": "使用TensorBoard记录",
"default": True,
"required": False,
},
],
"optional": [
{
"name": "min_learning_rate",
"type": "float",
"prompt": "最小学习率",
"default": 1e-9,
"required": False,
"validate": lambda x: 1e-12 <= x <= 1,
},
{
"name": "warmup_ratio",
"type": "float",
"prompt": "热身步数比例",
"default": 0.1,
"required": False,
"validate": lambda x: 0 <= x <= 1,
},
{
"name": "eval_frequency",
"type": "int",
"prompt": "评估频率(步数)",
"default": 500,
"required": False,
"validate": lambda x: x > 0,
},
{
"name": "save_frequency",
"type": "int",
"prompt": "保存检查点频率(步数)",
"default": 1000,
"required": False,
"validate": lambda x: x > 0,
},
{
"name": "max_iter_length",
"type": "int",
"prompt": "数据集大小",
"default": 1024 * 1024 * 128,
"required": False,
"validate": lambda x: x > 0,
},
{
"name": "num_workers",
"type": "int",
"prompt": "数据加载worker数量",
"default": 2,
"required": False,
"validate": lambda x: x >= 0,
},
{
"name": "seed",
"type": "int",
"prompt": "随机种子",
"default": 42,
"required": False,
},
],
"recovery": [
{
"name": "auto_resume",
"type": "bool",
"prompt": "自动从最新checkpoint恢复",
"default": True,
"required": False,
},
{
"name": "resume_from",
"type": "path",
"prompt": "从指定checkpoint恢复可选",
"default": None,
"required": False,
"validate": self._validate_optional_path,
},
{
"name": "reset_training_state",
"type": "bool",
"prompt": "重置训练状态(只加载权重)",
"default": False,
"required": False,
},
],
}
def _validate_path(self, path: str) -> bool:
"""
验证路径
Args:
path: 路径字符串
Returns:
是否有效
"""
if not path:
return False
path_obj = Path(path)
return path_obj.exists()
def _validate_optional_path(self, path: Optional[str]) -> bool:
"""
验证可选路径可以为None
Args:
path: 路径字符串或None
Returns:
是否有效
"""
if path is None or path == "":
return True
return self._validate_path(path)
def show_welcome(self):
"""
显示欢迎界面
"""
self.console.clear()
welcome_text = Text()
welcome_text.append("🚀 ", style="bold yellow")
welcome_text.append("输入法模型训练系统", style="bold cyan")
welcome_text.append("\n")
welcome_text.append("=" * 50, style="dim")
welcome_text.append("\n\n")
welcome_text.append("欢迎使用交互式训练配置界面!\n", style="bold green")
welcome_text.append("请按照提示配置训练参数。\n", style="green")
welcome_text.append("\n")
welcome_text.append("按 [Enter] 使用默认值\n", style="dim")
welcome_text.append("按 [Ctrl+C] 退出\n", style="dim")
panel = Panel(
welcome_text,
title="🎯 开始配置",
border_style=self.colors["primary"],
padding=(1, 2),
expand=False,
)
self.console.print(panel)
self.console.print()
def ask_param_group(self, group_name: str, group_title: str) -> Dict[str, Any]:
"""
询问一组参数
Args:
group_name: 参数组名称
group_title: 显示标题
Returns:
该组的配置字典
"""
group_config = {}
params = self.param_definitions.get(group_name, [])
if not params:
return group_config
# 显示组标题
self.console.print()
self.console.print(f"[bold {self.colors['primary']}]{group_title}[/bold {self.colors['primary']}]")
self.console.print(f"[{self.colors['secondary']}]{'=' * 40}[/{self.colors['secondary']}]")
for param in params:
name = param["name"]
prompt = param["prompt"]
param_type = param["type"]
default = param.get("default")
required = param.get("required", False)
validate_func = param.get("validate")
# 如果已经有值(从命令行传入),跳过
if name in self.config and self.config[name] is not None:
group_config[name] = self.config[name]
self.console.print(f" {prompt}: [green]{self.config[name]}[/green] (已提供)")
continue
# 询问参数
while True:
try:
if param_type == "bool":
value = Confirm.ask(
f" {prompt}",
default=default if default is not None else False,
show_default=True,
)
elif param_type == "int":
value = IntPrompt.ask(
f" {prompt}",
default=default if default is not None else 0,
show_default=True,
)
elif param_type == "float":
value = FloatPrompt.ask(
f" {prompt}",
default=default if default is not None else 0.0,
show_default=True,
)
else: # "path" or other string types
default_str = str(default) if default is not None else ""
value = Prompt.ask(
f" {prompt}",
default=default_str,
show_default=True,
)
# 处理空字符串
if value == "":
value = None
# 验证
if validate_func and value is not None:
if not validate_func(value):
self.console.print(f" [red]无效值,请重新输入[/red]")
continue
# 检查必需参数
if required and (value is None or value == ""):
self.console.print(f" [red]此参数为必需参数[/red]")
continue
group_config[name] = value
break
except KeyboardInterrupt:
self.console.print("\n[yellow]已取消[/yellow]")
raise
except Exception as e:
self.console.print(f" [red]输入错误: {e}[/red]")
if not required:
# 宽松处理:使用默认值
group_config[name] = default
self.console.print(f" [yellow]使用默认值: {default}[/yellow]")
break
return group_config
    def show_config_summary(self, config: Dict[str, Any]):
        """
        Show a summary of the configuration.

        Args:
            config: the complete configuration dictionary
        """
        self.console.print()
        self.console.print(f"[bold {self.colors['primary']}]📋 Configuration summary[/bold {self.colors['primary']}]")
        self.console.print(f"[{self.colors['secondary']}]{'=' * 40}[/{self.colors['secondary']}]")
        # Build the table
        table = Table(show_header=True, header_style=f"bold {self.colors['accent']}", box=box.ROUNDED)
        table.add_column("Parameter", style=self.colors["primary"])
        table.add_column("Value", style=self.colors["success"])
        table.add_column("Type", style=self.colors["info"])
        # Add the parameters group by group
        groups = [
            ("Required parameters", "required"),
            ("Recommended parameters", "recommended"),
            ("Optional parameters", "optional"),
            ("Recovery parameters", "recovery"),
        ]
        for group_title, group_name in groups:
            params = self.param_definitions.get(group_name, [])
            if params:
                table.add_row(f"[bold]{group_title}[/bold]", "", "", style="bold")
                for param in params:
                    name = param["name"]
                    if name in config:
                        value = config[name]
                        param_type = param["type"]
                        # Format the value
                        if value is None:
                            value_str = "[dim]None[/dim]"
                        elif isinstance(value, bool):
                            # Assumed labels for True/False
                            value_str = "Yes" if value else "No"
                        else:
                            value_str = str(value)
                        table.add_row(f"  {param['prompt']}", value_str, param_type)
        self.console.print(table)
        self.console.print()
def run(self, initial_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
运行完整的交互式配置流程
Args:
initial_config: 初始配置如从命令行传入的参数
Returns:
完整的配置字典
"""
# 初始化配置
self.config = initial_config or {}
try:
# 显示欢迎界面
self.show_welcome()
# 询问必需参数
self.console.print()
if Confirm.ask("[bold cyan]是否配置必需参数?[/bold cyan]", default=True):
required_config = self.ask_param_group("required", "📁 必需参数")
self.config.update(required_config)
# 检查必需参数是否已提供
required_params = self.param_definitions["required"]
for param in required_params:
if param["name"] not in self.config or self.config[param["name"]] is None:
self.console.print(f"[red]错误: 必需参数 '{param['prompt']}' 未提供[/red]")
self.console.print("[yellow]请重新配置必需参数[/yellow]")
required_config = self.ask_param_group("required", "📁 必需参数")
self.config.update(required_config)
# 询问推荐参数
self.console.print()
if Confirm.ask("[bold cyan]是否配置推荐参数?[/bold cyan]", default=True):
recommended_config = self.ask_param_group("recommended", "⚙️ 推荐参数")
self.config.update(recommended_config)
# 询问可选参数
self.console.print()
if Confirm.ask("[bold cyan]是否配置可选参数(高级)?[/bold cyan]", default=False):
optional_config = self.ask_param_group("optional", "🎛️ 可选参数")
self.config.update(optional_config)
# 询问恢复参数
self.console.print()
if Confirm.ask("[bold cyan]是否配置恢复参数?[/bold cyan]", default=True):
recovery_config = self.ask_param_group("recovery", "🔄 恢复参数")
self.config.update(recovery_config)
# 显示配置摘要
self.show_config_summary(self.config)
# 确认配置
self.console.print()
if not Confirm.ask("[bold green]是否使用此配置开始训练?[/bold green]", default=True):
self.console.print("[yellow]配置已取消[/yellow]")
return {}
return self.config
except KeyboardInterrupt:
self.console.print("\n[yellow]配置已取消[/yellow]")
return {}
except Exception as e:
self.console.print(f"[red]配置过程中出现错误: {e}[/red]")
return {}
def get_interactive_config(provided_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
获取交互式配置
Args:
provided_params: 已提供的参数如从命令行传入
Returns:
完整的配置字典
"""
console = Console()
ui = TrainingConfigUI(console)
# 过滤掉None值
if provided_params:
provided_params = {k: v for k, v in provided_params.items() if v is not None}
config = ui.run(provided_params)
return config
if __name__ == "__main__":
# 测试交互式界面
console = Console()
console.print("[bold]测试交互式配置界面[/bold]")
# 模拟已提供的参数
test_params = {
"train_data_path": "./data/train.txt",
"eval_data_path": "./data/eval.txt",
}
config = get_interactive_config(test_params)
if config:
console.print("\n[green]✓ 配置完成[/green]")
console.print(f"配置参数: {config}")
else:
console.print("\n[yellow]配置已取消[/yellow]")


@ -166,7 +166,13 @@ class Trainer:
# 不加载历史数据,直接初始化为空列表以覆盖原有数据
self.training_status_data = []
# 初始化Rich控制台
# 初始化 epoch checkpoint 元数据
self.epoch_metadata_file = self.checkpoint_dir / "epoch_checkpoints.json"
self.epoch_checkpoints = [] # 最多保留 3 个
self.next_epoch_slot = 0 # 下一个要覆盖的位置 (0-2)
self._load_epoch_metadata()
# 初始化 Rich 控制台
self.console = Console()
# 训练状态
@ -368,6 +374,122 @@ class Trainer:
else:
logger.info(f"Checkpoint saved to {checkpoint_path}")
    def _load_epoch_metadata(self):
        """Load the epoch checkpoint metadata."""
        if self.epoch_metadata_file.exists():
            try:
                with open(self.epoch_metadata_file, "r", encoding="utf-8") as f:
                    metadata = json.load(f)
                self.epoch_checkpoints = metadata.get("checkpoints", [])
                self.next_epoch_slot = metadata.get("next_slot", 0)
                logger.info(
                    f"Loaded epoch checkpoint metadata: {len(self.epoch_checkpoints)} checkpoints"
                )
            except Exception as e:
                # Fall back to an empty state if the file cannot be read or parsed
                logger.warning(f"Failed to load epoch metadata: {e}")
                self.epoch_checkpoints = []
                self.next_epoch_slot = 0
        else:
            self.epoch_checkpoints = []
            self.next_epoch_slot = 0
    def _save_epoch_metadata(self):
        """Save the epoch checkpoint metadata."""
        metadata = {
            "checkpoints": self.epoch_checkpoints,
            "next_slot": self.next_epoch_slot,
            "total_epochs_completed": max(
                (cp["epoch"] for cp in self.epoch_checkpoints), default=0
            ),
        }
        with open(self.epoch_metadata_file, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)
    def save_epoch_checkpoint(self, epoch: int):
        """
        Save an epoch checkpoint (rotating overwrite, keeping only the last 3).

        Args:
            epoch: current epoch number (starting from 1)
        """
        # Pick the file name (3 fixed file names are reused in rotation)
        slot = self.next_epoch_slot
        filename = f"epoch_checkpoint_{slot + 1}.pt"
        checkpoint_path = self.checkpoint_dir / filename
        # Save the checkpoint
        checkpoint = {
            "step": self.current_step,
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "best_eval_loss": self.best_eval_loss,
            "config": {
                "learning_rate": self.learning_rate,
                "weight_decay": self.weight_decay,
                "warmup_ratio": self.warmup_ratio,
                "label_smoothing": self.label_smoothing,
                "total_steps": self.total_steps,
            },
        }
        torch.save(checkpoint, checkpoint_path)
        # Update the metadata
        checkpoint_info = {
            "epoch": epoch,
            "file": filename,
            "path": str(checkpoint_path),
            "saved_at": datetime.now().isoformat(),
            "step": self.current_step,
        }
        # Once 3 entries exist, replace the one in this slot; otherwise append
        if len(self.epoch_checkpoints) >= 3:
            self.epoch_checkpoints[slot] = checkpoint_info
        else:
            self.epoch_checkpoints.append(checkpoint_info)
        # Advance to the next slot
        self.next_epoch_slot = (self.next_epoch_slot + 1) % 3
        # Persist the metadata
        self._save_epoch_metadata()
        # Sort by epoch for the log message
        sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
        logger.info(
            f"Epoch {epoch} checkpoint saved to {filename} "
            f"(keeping last {len(self.epoch_checkpoints)} epochs: "
            f"{[cp['epoch'] for cp in sorted_checkpoints]})"
        )
    def get_latest_epoch_checkpoint(self) -> Optional[Dict]:
        """
        Get the info of the latest epoch checkpoint.

        Returns:
            The info dictionary of the latest checkpoint, or None if there is none
        """
        if not self.epoch_checkpoints:
            return None
        # Sort by epoch and return the latest
        sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
        return sorted_checkpoints[-1]

    def get_epoch_checkpoints(self) -> List[Dict]:
        """
        Get the info of all saved epoch checkpoints (sorted by epoch).

        Returns:
            List of checkpoint info dictionaries in ascending epoch order
        """
        return sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
def load_checkpoint(
self, checkpoint_path: Union[str, Path], reset_training_state: bool = False
):
@ -534,7 +656,10 @@ class Trainer:
self.console.print(info_table)
def train(
self, resume_from: Optional[str] = None, reset_training_state: bool = False
self,
resume_from: Optional[str] = None,
reset_training_state: bool = False,
auto_resume: bool = True,
):
"""
主训练循环
@ -542,10 +667,23 @@ class Trainer:
Args:
resume_from: 从哪个检查点恢复训练可选
reset_training_state: 是否重置训练状态只加载模型权重从头开始训练
auto_resume: 是否自动从最新的 epoch checkpoint 恢复如果存在
"""
# 如果提供了检查点,则恢复训练
# 如果提供了检查点,则优先使用提供的检查点恢复训练
if resume_from is not None:
self.load_checkpoint(resume_from, reset_training_state=reset_training_state)
elif auto_resume and self.epoch_checkpoints:
# 自动从最新的 epoch checkpoint 恢复
latest_checkpoint = self.get_latest_epoch_checkpoint()
if latest_checkpoint:
checkpoint_path = latest_checkpoint["path"]
logger.info(
f"Auto-resuming from latest epoch checkpoint: {checkpoint_path} "
f"(epoch {latest_checkpoint['epoch']})"
)
self.load_checkpoint(
checkpoint_path, reset_training_state=reset_training_state
)
# 打印训练信息
self._print_training_info()
@ -675,8 +813,8 @@ class Trainer:
# 进度条不重置,显示整体训练进度
# 每个epoch结束后保存检查点
self.save_checkpoint(f"epoch_{epoch + 1}.pt")
# 每个 epoch 结束后保存检查点(循环覆盖,只保留最后 3 个)
self.save_epoch_checkpoint(epoch + 1)
# 检查是否达到总步数
if global_step >= self.total_steps:
@ -685,10 +823,18 @@ class Trainer:
# 训练完成
logger.info("Training completed!")
# 保存最终模型
self.save_checkpoint("final_model.pt")
# 显示保存的 epoch checkpoint 信息
if self.epoch_checkpoints:
sorted_checkpoints = self.get_epoch_checkpoints()
logger.info(
f"Saved epoch checkpoints: {[cp['epoch'] for cp in sorted_checkpoints]}"
)
logger.info(
f"Latest checkpoint: epoch {sorted_checkpoints[-1]['epoch']} "
f"({sorted_checkpoints[-1]['file']})"
)
# 关闭TensorBoard写入器
# 关闭 TensorBoard 写入器
if self.writer is not None:
self.writer.close()
@ -907,21 +1053,10 @@ def train(
..., "--eval-data-path", "-e", help="评估数据集路径"
),
output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
# 模型参数
vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
pinyin_vocab_size: int = typer.Option(
30, "--pinyin-vocab-size", help="拼音词汇表大小"
),
# 数据大小
max_iter_length: int = typer.Option(
1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
),
dim: int = typer.Option(512, "--dim", help="模型维度"),
num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"),
n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"),
n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"),
num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"),
max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"),
use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
# 训练参数
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
@ -954,6 +1089,9 @@ def train(
reset_training_state: bool = typer.Option(
False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练"
),
auto_resume: bool = typer.Option(
True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练"
),
seed: int = typer.Option(42, "--seed", help="随机种子"),
compile: bool = typer.Option(
False,
@ -975,6 +1113,17 @@ def train(
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# 硬编码模型参数
vocab_size = 10019
pinyin_vocab_size = 30 # 根据 dataset.CHAR_TO_ID 映射
dim = 512
num_slots = 8
n_layers = 4
n_heads = 4
num_experts = 20
max_seq_len = 128
use_pinyin = True # 始终使用拼音
console = Console()
# 打印配置信息
@ -1014,6 +1163,7 @@ def train(
config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
config_table.add_row("训练", "混合精度", str(mixed_precision))
config_table.add_row("其他", "自动恢复", str(auto_resume))
console.print(config_table)
# 创建输出目录
@ -1049,6 +1199,7 @@ def train(
"mixed_precision": mixed_precision,
"use_tensorboard": use_tensorboard,
"seed": seed,
"auto_resume": auto_resume,
"max_iter_length": max_iter_length,
"compile": compile,
}
@ -1149,10 +1300,12 @@ def train(
# 开始训练
console.print("\n[bold cyan]开始训练...[/bold cyan]")
console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
console.print(f"开始时间{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try:
trainer.train(
resume_from=resume_from, reset_training_state=reset_training_state
resume_from=resume_from,
reset_training_state=reset_training_state,
auto_resume=auto_resume,
)
except KeyboardInterrupt:
console.print("[bold green]训练被终止[/bold green]")
@ -1223,20 +1376,10 @@ def expand_and_train(
"-m",
help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类",
),
vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
pinyin_vocab_size: int = typer.Option(
30, "--pinyin-vocab-size", help="拼音词汇表大小"
),
# 数据大小
max_iter_length: int = typer.Option(
1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
),
dim: int = typer.Option(512, "--dim", help="模型维度"),
num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"),
n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"),
n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"),
num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"),
max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"),
use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
# 训练参数
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
@ -1269,6 +1412,9 @@ def expand_and_train(
reset_training_state: bool = typer.Option(
False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练"
),
auto_resume: bool = typer.Option(
True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练"
),
seed: int = typer.Option(42, "--seed", help="随机种子"),
compile: bool = typer.Option(
False,
@ -1287,6 +1433,16 @@ def expand_and_train(
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# 硬编码模型参数
vocab_size = 10019
pinyin_vocab_size = 30 # 根据 dataset.CHAR_TO_ID 映射
dim = 512
num_slots = 8
n_layers = 4
n_heads = 4
num_experts = 20
max_seq_len = 128
use_pinyin = True # 始终使用拼音
console = Console()
# 打印配置信息
@ -1330,6 +1486,7 @@ def expand_and_train(
config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
config_table.add_row("训练", "混合精度", str(mixed_precision))
config_table.add_row("其他", "自动恢复", str(auto_resume))
console.print(config_table)
# 创建输出目录
@ -1367,6 +1524,7 @@ def expand_and_train(
"mixed_precision": mixed_precision,
"use_tensorboard": use_tensorboard,
"seed": seed,
"auto_resume": auto_resume,
"max_iter_length": max_iter_length,
"compile": compile,
}
@ -1482,10 +1640,12 @@ def expand_and_train(
# 开始训练
console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]")
console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
console.print(f"开始时间{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try:
trainer.train(
resume_from=resume_from, reset_training_state=reset_training_state
resume_from=resume_from,
reset_training_state=reset_training_state,
auto_resume=auto_resume,
)
except KeyboardInterrupt:
console.print("[bold green]训练被终止[/bold green]")
@ -1598,6 +1758,9 @@ def expand_finetune(
reset_training_state: bool = typer.Option(
False, "--reset-training-state", help="重置训练状态"
),
auto_resume: bool = typer.Option(
True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练"
),
seed: int = typer.Option(42, "--seed", help="随机种子"),
compile: Optional[bool] = typer.Option(
None, "--compile/--no-compile", help="是否开启 torch.compile 优化"
@ -1697,6 +1860,7 @@ def expand_finetune(
config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm))
config_table.add_row("训练", "混合精度", str(mixed_precision))
config_table.add_row("其他", "自动恢复", str(auto_resume))
console.print(config_table)
output_path = Path(final_output_dir)
@ -1786,10 +1950,12 @@ def expand_finetune(
console.print("[green]✓ 训练器创建完成[/green]")
console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]")
console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
console.print(f"开始时间{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try:
trainer.train(
resume_from=resume_from, reset_training_state=reset_training_state
resume_from=resume_from,
reset_training_state=reset_training_state,
auto_resume=auto_resume,
)
except KeyboardInterrupt:
console.print("[bold green]训练被终止[/bold green]")


@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
part1 = "明明是国庆节,可是因为月底要结账,财务部所有人都"
part2 = "bxu"
part1 = "招财猫背部或底部的太阳能板会持续将环境光(无论是阳光还是室内灯光)转化为"
part2 = "weiruo"
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
@ -56,7 +56,7 @@ if len_py < 24:
else:
pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
masked_labels = [649, 925, 0, 0, 0, 0, 0, 0]
part3 = ""
part4 = "可行|特别|伤害"
@ -83,7 +83,7 @@ sample = {
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/20260411(acc34)final_model.ptrom", map_location="cpu")
checkpoint = torch.load("/home/songsenand/下载/20260411acc37final-model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
input_ids = sample["input_ids"]

210
test_epoch_checkpoint.py Normal file

@ -0,0 +1,210 @@
"""
Test the rotating save of epoch checkpoints.
"""
import json
import tempfile
from pathlib import Path
from datetime import datetime
# Mirrors the epoch checkpoint management logic of the Trainer class
class EpochCheckpointManager:
    def __init__(self, checkpoint_dir: Path):
        self.checkpoint_dir = checkpoint_dir
        self.epoch_metadata_file = self.checkpoint_dir / "epoch_checkpoints.json"
        self.epoch_checkpoints = []
        self.next_epoch_slot = 0

    def _load_epoch_metadata(self):
        if self.epoch_metadata_file.exists():
            with open(self.epoch_metadata_file, "r", encoding="utf-8") as f:
                metadata = json.load(f)
            self.epoch_checkpoints = metadata.get("checkpoints", [])
            self.next_epoch_slot = metadata.get("next_slot", 0)

    def _save_epoch_metadata(self):
        metadata = {
            "checkpoints": self.epoch_checkpoints,
            "next_slot": self.next_epoch_slot,
            "total_epochs_completed": max(
                (cp["epoch"] for cp in self.epoch_checkpoints), default=0
            ),
        }
        with open(self.epoch_metadata_file, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

    def save_epoch_checkpoint(self, epoch: int):
        """Simulate saving an epoch checkpoint (no model file is actually written)."""
        slot = self.next_epoch_slot
        filename = f"epoch_checkpoint_{slot + 1}.pt"
        checkpoint_path = self.checkpoint_dir / filename
        checkpoint_info = {
            "epoch": epoch,
            "file": filename,
            "path": str(checkpoint_path),
            "saved_at": datetime.now().isoformat(),
            "step": epoch * 100,  # simulated step
        }
        if len(self.epoch_checkpoints) >= 3:
            self.epoch_checkpoints[slot] = checkpoint_info
        else:
            self.epoch_checkpoints.append(checkpoint_info)
        self.next_epoch_slot = (self.next_epoch_slot + 1) % 3
        self._save_epoch_metadata()
        sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
        print(
            f"✓ Epoch {epoch:2d} saved -> {filename:20s} "
            f"(keeping epochs: {[cp['epoch'] for cp in sorted_checkpoints]})"
        )

    def get_latest_epoch_checkpoint(self):
        if not self.epoch_checkpoints:
            return None
        sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
        return sorted_checkpoints[-1]

    def get_epoch_checkpoints(self):
        return sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])


def test_circular_save():
    """Test the rotating save behaviour."""
    print("=" * 70)
    print("Test: rotating save of the last 3 epoch checkpoints")
    print("=" * 70)
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_dir = Path(tmpdir)
        manager = EpochCheckpointManager(checkpoint_dir)

        # Simulate 10 training epochs.
        print("\nSimulating 10 training epochs:")
        print("-" * 70)
        for epoch in range(1, 11):
            manager.save_epoch_checkpoint(epoch)

        # Verify the final state.
        print("\n" + "=" * 70)
        print("Final state verification:")
        print("=" * 70)

        # Inspect the JSON file.
        metadata_file = checkpoint_dir / "epoch_checkpoints.json"
        with open(metadata_file, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        print(f"\n✓ JSON metadata file: {metadata_file.name}")
        print(f" next_slot: {metadata['next_slot']}")
        print(f" total_epochs_completed: {metadata['total_epochs_completed']}")
        print(f"\n✓ Number of saved checkpoints: {len(metadata['checkpoints'])}")
        print(" expected: 3")
        sorted_checkpoints = sorted(metadata["checkpoints"], key=lambda x: x["epoch"])
        print(f"\n✓ Saved epochs: {[cp['epoch'] for cp in sorted_checkpoints]}")
        print(" expected: [8, 9, 10]")

        # Check which checkpoint is reported as the latest.
        latest = manager.get_latest_epoch_checkpoint()
        print(f"\n✓ Latest checkpoint: epoch {latest['epoch']} ({latest['file']})")
        print(" expected: epoch 10")

        # Show the epoch-to-file mapping.
        print("\n✓ Epoch-to-file mapping:")
        for cp in sorted_checkpoints:
            print(f" epoch {cp['epoch']:2d} -> {cp['file']}")

        # Assertions.
        assert len(metadata["checkpoints"]) == 3, "only 3 checkpoints should be kept"
        assert [cp["epoch"] for cp in sorted_checkpoints] == [8, 9, 10], (
            "the last 3 epochs should be kept"
        )
        assert latest["epoch"] == 10, "the latest checkpoint should be epoch 10"
        print("\n" + "=" * 70)
        print("✓✓✓ All tests passed!")
        print("=" * 70)


def test_auto_resume():
    """Test restoring the checkpoint rotation state from the JSON file."""
    print("\n" + "=" * 70)
    print("Test: restoring checkpoint order from JSON")
    print("=" * 70)
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_dir = Path(tmpdir)

        # Create a manager and save a few checkpoints.
        manager1 = EpochCheckpointManager(checkpoint_dir)
        print("\nFirst training session (saving epochs 1-5):")
        print("-" * 70)
        for epoch in range(1, 6):
            manager1.save_epoch_checkpoint(epoch)

        # Simulate a new training session that restores its state from JSON.
        print("\n\nSecond training session (restored from JSON):")
        print("-" * 70)
        manager2 = EpochCheckpointManager(checkpoint_dir)
        manager2._load_epoch_metadata()
        print(f"\n✓ Restored checkpoint count: {len(manager2.epoch_checkpoints)}")
        print(f"✓ Restored epochs: {[cp['epoch'] for cp in manager2.epoch_checkpoints]}")
        print(f"✓ next_slot: {manager2.next_epoch_slot}")

        # Continue saving.
        print("\nContinuing with epochs 6-10:")
        print("-" * 70)
        for epoch in range(6, 11):
            manager2.save_epoch_checkpoint(epoch)

        # Verify the final state.
        print("\n" + "=" * 70)
        print("Final state verification:")
        print("=" * 70)
        sorted_checkpoints = manager2.get_epoch_checkpoints()
        print(f"\n✓ Final saved epochs: {[cp['epoch'] for cp in sorted_checkpoints]}")
        print(" expected: [8, 9, 10]")
        latest = manager2.get_latest_epoch_checkpoint()
        print(f"✓ Latest checkpoint: epoch {latest['epoch']}")
        print(" expected: epoch 10")

        # Assertions.
        assert [cp["epoch"] for cp in sorted_checkpoints] == [8, 9, 10], (
            "the last 3 epochs should be kept"
        )
        assert latest["epoch"] == 10, "the latest checkpoint should be epoch 10"
        print("\n" + "=" * 70)
        print("✓✓✓ Auto-resume test passed!")
        print("=" * 70)


if __name__ == "__main__":
    test_circular_save()
    test_auto_resume()
    print("\n" + "=" * 70)
    print("All tests finished!")
    print("=" * 70)
    print("\nFeature summary:")
    print(" 1. ✓ Rotating overwrite keeps only the last 3 epoch checkpoints")
    print(" 2. ✓ Checkpoint metadata is recorded in a JSON file")
    print(" 3. ✓ Checkpoint order can be restored from the JSON file")
    print(" 4. ✓ The latest checkpoint is identified correctly")
    print("\nFiles kept:")
    print(" - best_model.pt (best model)")
    print(" - latest_checkpoint.pt (periodic save)")
    print(" - epoch_checkpoint_1.pt, epoch_checkpoint_2.pt, epoch_checkpoint_3.pt")
    print(" - epoch_checkpoints.json (metadata)")
    print("\nFiles removed:")
    print(" - final_model.pt (duplicates epoch_last)")
    print(" - epoch_*.pt (a separate file is no longer created for every epoch)")
    print("=" * 70)
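
The slot rotation exercised by the tests above can also be checked in isolation. The following is a minimal sketch, not part of the commit: `simulate_rotation` is a hypothetical helper that replays only the `next_slot` arithmetic from `save_epoch_checkpoint` and reports which epoch each fixed filename would hold after a given number of saves.

```python
def simulate_rotation(num_epochs: int, num_slots: int = 3) -> dict:
    """Return {filename: epoch} for the files left after num_epochs saves.

    Hypothetical helper for illustration; it mirrors the modulo-based slot
    selection of the test's EpochCheckpointManager and writes no files.
    """
    files = {}
    next_slot = 0
    for epoch in range(1, num_epochs + 1):
        filename = f"epoch_checkpoint_{next_slot + 1}.pt"
        files[filename] = epoch  # a later epoch overwrites the same slot
        next_slot = (next_slot + 1) % num_slots
    return files


if __name__ == "__main__":
    for name, epoch in simulate_rotation(10).items():
        print(f"{name} holds epoch {epoch}")
```

After 10 epochs the three files hold epochs 10, 8 and 9 respectively, so the newest checkpoint is not necessarily in the highest-numbered file. This is why the latest checkpoint is looked up by sorting on `epoch` rather than by filename.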