diff --git a/CHECKPOINT_CHANGES.md b/CHECKPOINT_CHANGES.md new file mode 100644 index 0000000..f73c209 --- /dev/null +++ b/CHECKPOINT_CHANGES.md @@ -0,0 +1,187 @@ +# Checkpoint 保存策略更新说明 + +## 更新概述 + +为了解决硬盘空间限制问题,修改了模型 checkpoint 的保存方式,采用**同名覆盖**的方式仅保留**最后 3 个 epoch**的 checkpoint。 + +## 主要变更 + +### 1. 保存策略 + +#### 保留的文件 +- ✅ `best_model.pt` - 最佳模型(当 eval_loss 降低时保存) +- ✅ `latest_checkpoint.pt` - 定期保存的检查点(每 save_frequency 步覆盖保存) +- ✅ `epoch_checkpoint_1.pt` - 循环保存的最后 3 个 epoch 之一 +- ✅ `epoch_checkpoint_2.pt` - 循环保存的最后 3 个 epoch 之一 +- ✅ `epoch_checkpoint_3.pt` - 循环保存的最后 3 个 epoch 之一 +- ✅ `epoch_checkpoints.json` - epoch checkpoint 元数据文件 + +#### 移除的文件 +- ❌ `final_model.pt` - 与 epoch_last 重复,已移除 +- ❌ `epoch_*.pt` - 不再为每个 epoch 创建独立文件 + +### 2. 循环覆盖机制 + +使用 3 个固定文件名循环覆盖: +``` +Epoch 1 -> epoch_checkpoint_1.pt +Epoch 2 -> epoch_checkpoint_2.pt +Epoch 3 -> epoch_checkpoint_3.pt +Epoch 4 -> epoch_checkpoint_1.pt (覆盖 epoch 1) +Epoch 5 -> epoch_checkpoint_2.pt (覆盖 epoch 2) +... +``` + +### 3. JSON 元数据管理 + +`epoch_checkpoints.json` 文件记录: +```json +{ + "checkpoints": [ + { + "epoch": 8, + "file": "epoch_checkpoint_2.pt", + "path": "/path/to/checkpoints/epoch_checkpoint_2.pt", + "saved_at": "2026-04-11T10:30:00", + "step": 800 + }, + { + "epoch": 9, + "file": "epoch_checkpoint_3.pt", + "path": "/path/to/checkpoints/epoch_checkpoint_3.pt", + "saved_at": "2026-04-11T10:35:00", + "step": 900 + }, + { + "epoch": 10, + "file": "epoch_checkpoint_1.pt", + "path": "/path/to/checkpoints/epoch_checkpoint_1.pt", + "saved_at": "2026-04-11T10:40:00", + "step": 1000 + } + ], + "next_slot": 1, + "total_epochs_completed": 10 +} +``` + +### 4. 自动恢复功能 + +训练时会自动从 JSON 文件中读取最新的 checkpoint 并恢复训练: + +```python +# 训练时自动恢复 +trainer.train( + resume_from=None, # 如果指定,优先使用指定的 checkpoint + reset_training_state=False, + auto_resume=True # 自动从最新的 epoch checkpoint 恢复 +) +``` + +**恢复优先级:** +1. 如果指定了 `resume_from`,使用指定的 checkpoint +2. 否则,如果存在 epoch checkpoint 元数据,自动从最新的 epoch checkpoint 恢复 +3. 否则,从头开始训练 + +## 使用方法 + +### 查看保存的 checkpoint + +训练结束后,可以通过 JSON 文件查看保存的 checkpoint: + +```bash +# 查看元数据 +cat ./output/checkpoints/epoch_checkpoints.json + +# 或者使用 Python +python -c " +import json +with open('./output/checkpoints/epoch_checkpoints.json') as f: + data = json.load(f) + print('保存的 epochs:', [cp['epoch'] for cp in data['checkpoints']]) + print('最新的 checkpoint:', max(data['checkpoints'], key=lambda x: x['epoch'])) +" +``` + +### 手动加载特定 epoch 的 checkpoint + +```python +import torch +from src.model.trainer import Trainer + +# 加载元数据 +import json +with open('./output/checkpoints/epoch_checkpoints.json') as f: + data = json.load(f) + +# 按 epoch 排序 +sorted_checkpoints = sorted(data['checkpoints'], key=lambda x: x['epoch']) +print('可用的 epochs:', [cp['epoch'] for cp in sorted_checkpoints]) + +# 加载特定 epoch 的 checkpoint +target_epoch = 9 +checkpoint_info = next(cp for cp in sorted_checkpoints if cp['epoch'] == target_epoch) +checkpoint_path = checkpoint_info['path'] + +# 使用 trainer 加载 +trainer.load_checkpoint(checkpoint_path) +``` + +### 禁用自动恢复 + +如果需要从头开始训练,可以禁用自动恢复: + +```bash +# 在代码中设置 +trainer.train(auto_resume=False) +``` + +## 优势 + +1. **节省磁盘空间** - 只保留最后 3 个 epoch,不会随训练时间增长而占用更多空间 +2. **自动管理** - 无需手动删除旧的 checkpoint +3. **顺序清晰** - 通过 JSON 文件可以清楚知道每个 checkpoint 对应的 epoch +4. **自动恢复** - 训练中断后可以从最近的 checkpoint 自动恢复 +5. **保留重要 checkpoint** - best_model 和 latest_checkpoint 仍然保留 + +## 磁盘空间对比 + +### 修改前 +``` +假设训练 100 个 epoch,每个 checkpoint 100MB: +- epoch_1.pt ~ epoch_100.pt: 100 * 100MB = 10GB +- best_model.pt: 100MB +- final_model.pt: 100MB +- latest_checkpoint.pt: 100MB +总计:~10.3GB +``` + +### 修改后 +``` +训练 100 个 epoch,每个 checkpoint 100MB: +- epoch_checkpoint_1~3.pt: 3 * 100MB = 300MB +- best_model.pt: 100MB +- latest_checkpoint.pt: 100MB +- epoch_checkpoints.json: <1KB +总计:~400MB + +节省空间:~9.9GB (96% 空间节省) +``` + +## 注意事项 + +1. **训练中断** - 如果训练在中途被中断,只会保留中断前的最后 3 个 epoch +2. **JSON 文件** - 不要手动修改 `epoch_checkpoints.json` 文件,否则可能导致恢复失败 +3. **兼容性** - 旧的 `epoch_*.pt` 文件不会自动删除,如果需要可以手动清理 + +## 测试 + +运行测试脚本验证功能: +```bash +python test_epoch_checkpoint.py +``` + +## 相关文件 + +- `src/model/trainer.py` - 主要修改文件 +- `test_epoch_checkpoint.py` - 功能测试脚本 diff --git a/check_weights.py b/check_weights.py index f69d629..d77e015 100755 --- a/check_weights.py +++ b/check_weights.py @@ -1,147 +1,22 @@ -#!/usr/bin/env python3 -""" -快速检查模型权重加载情况的脚本 -""" +from model.dataset import PinyinInputDataset +from torch.utils.data import DataLoader -from pathlib import Path - -import numpy as np -import torch +from model.trainer import collate_fn, worker_init_fn -def analyze_checkpoint(checkpoint_path): - """分析checkpoint文件""" - print(f"🔍 分析checkpoint: {checkpoint_path}") +data = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/') - if not Path(checkpoint_path).exists(): - print(f"❌ 文件不存在") - return +dataloader = DataLoader( + data, + batch_size=1024, + num_workers=2, + worker_init_fn=worker_init_fn, + collate_fn=collate_fn, + prefetch_factor=2, # 减少预取以避免内存问题 + persistent_workers=True, + shuffle=False, +) - try: - checkpoint = torch.load(checkpoint_path, map_location="cpu") - print(f"✅ 加载成功") - print(f" 类型: {type(checkpoint)}") - - if isinstance(checkpoint, dict): - print(f" 键名: {list(checkpoint.keys())}") - - # 找到模型状态字典 - state_dict = None - if "model_state_dict" in checkpoint: - state_dict = checkpoint["model_state_dict"] - print(f" 🔍 使用'model_state_dict'键") - elif "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - print(f" 🔍 使用'state_dict'键") - else: - # 可能是直接的状态字典 - state_dict = checkpoint - print(f" 🔍 使用直接状态字典") - - if state_dict: - print(f" 总权重数: {len(state_dict)}") - - # 分析分类头权重 - classifier_keys = [] - for key in state_dict.keys(): - if "classifier" in key: - classifier_keys.append(key) - - if classifier_keys: - print(f" 📊 分类头相关权重:") - for key in classifier_keys: - weight = state_dict[key] - print(f" {key}: shape={weight.shape}") - print(f" 范围: [{weight.min():.6f}, {weight.max():.6f}]") - print(f" 均值: {weight.mean():.6f}") - print(f" 标准差: {weight.std():.6f}") - - # 检查权重是否接近随机初始化 - if weight.std() < 0.01: - print(f" ⚠️ 警告: 权重标准差很小,可能未正确训练") - - # 检查模型架构键名 - print(f"\n 🔑 模型架构键名示例(前20个):") - for i, key in enumerate(list(state_dict.keys())[:20]): - weight = state_dict[key] - print(f" {i + 1:2d}. {key:40} shape={str(weight.shape):15}") - - # 检查是否有预期的组件 - expected_components = [ - "context_encoder", - "slot_memory", - "cross_attn", - "moe", - "classifier", - ] - found_components = [] - for comp in expected_components: - found = any(comp in key for key in state_dict.keys()) - if found: - found_components.append(comp) - - print(f"\n 📋 找到的模型组件: {found_components}") - missing = set(expected_components) - set(found_components) - if missing: - print(f" ❌ 缺失的组件: {missing}") - - return state_dict - else: - print(f"❌ checkpoint不是字典类型") - - except Exception as e: - print(f"❌ 加载失败: {e}") - import traceback - - traceback.print_exc() - - -def check_weight_distribution(state_dict): - """检查权重分布""" - print(f"\n📊 权重分布统计:") - - weight_stats = [] - for key, weight in state_dict.items(): - if "weight" in key and len(weight.shape) >= 2: # 只检查权重矩阵,不包括偏置 - stats = { - "key": key, - "shape": weight.shape, - "min": weight.min().item(), - "max": weight.max().item(), - "mean": weight.mean().item(), - "std": weight.std().item(), - "abs_mean": weight.abs().mean().item(), - } - weight_stats.append(stats) - - # 打印前10个权重 - for i, stats in enumerate(weight_stats[:10]): - print(f" {i + 1:2d}. {stats['key']:40}") - print(f" 形状: {stats['shape']}") - print(f" 范围: [{stats['min']:.6f}, {stats['max']:.6f}]") - print(f" 均值: {stats['mean']:.6f} ± {stats['std']:.6f}") - - # 检查是否接近随机初始化 - if stats["std"] < 0.01: - print(f" ⚠️ 警告: 标准差很小,可能未训练") - - return weight_stats - - -def main(): - import sys - - if len(sys.argv) < 2: - print("使用方法: python check_weights.py ") - print("示例: python check_weights.py ./output/checkpoints/best_model.pt") - return - - checkpoint_path = sys.argv[1] - state_dict = analyze_checkpoint(checkpoint_path) - - if state_dict: - check_weight_distribution(state_dict) - - -if __name__ == "__main__": - main() +for i in dataloader: + print((i['labels'] == 1).sum()) + break diff --git a/src/model/assets/pinyin_char_statistics.json b/src/model/assets/pinyin_char_statistics.json index 30d55f2..66ad104 100644 --- a/src/model/assets/pinyin_char_statistics.json +++ b/src/model/assets/pinyin_char_statistics.json @@ -8,7 +8,7 @@ "id": 0, "char": "", "pinyin": "", - "count": 434748360 + "count": 11067734826 }, "1": { "id": 1, @@ -142549,4 +142549,4 @@ "compressed": false, "pair_count": 20646 } -} \ No newline at end of file +} diff --git a/src/model/interactive_ui.py b/src/model/interactive_ui.py new file mode 100644 index 0000000..2b66b95 --- /dev/null +++ b/src/model/interactive_ui.py @@ -0,0 +1,526 @@ +""" +交互式训练配置界面 + +提供终端风格的交互式配置界面,用于配置输入法模型训练参数。 +使用Rich库创建美观的终端界面。 +""" + +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from rich.console import Console +from rich.panel import Panel +from rich.table import Table +from rich.prompt import Prompt, Confirm, IntPrompt, FloatPrompt +from rich.layout import Layout +from rich.columns import Columns +from rich.text import Text +from rich import box + + +class TrainingConfigUI: + """ + 训练配置交互式界面 + + 提供终端风格的交互式配置界面,按必需程度分组参数: + 1. 必需参数(必须提供) + 2. 推荐参数(有合理默认值) + 3. 可选参数(高级参数) + 4. 恢复参数(训练恢复相关) + """ + + def __init__(self, console: Optional[Console] = None): + """ + 初始化交互式界面 + + Args: + console: Rich控制台实例,如果为None则创建新的 + """ + self.console = console or Console() + self.config = {} + + # 参数定义 + self.param_definitions = self._load_param_definitions() + + # 颜色主题 + self.colors = { + "primary": "cyan", + "secondary": "green", + "accent": "yellow", + "warning": "red", + "success": "green", + "info": "blue", + } + + def _load_param_definitions(self) -> Dict[str, List[Dict]]: + """ + 加载参数定义 + + Returns: + 按组分类的参数定义 + """ + return { + "required": [ + { + "name": "train_data_path", + "type": "path", + "prompt": "训练数据集路径", + "required": True, + "validate": self._validate_path, + }, + { + "name": "eval_data_path", + "type": "path", + "prompt": "评估数据集路径", + "required": True, + "validate": self._validate_path, + }, + ], + "recommended": [ + { + "name": "output_dir", + "type": "path", + "prompt": "输出目录", + "default": "./output", + "required": False, + }, + { + "name": "batch_size", + "type": "int", + "prompt": "批次大小", + "default": 128, + "required": False, + "validate": lambda x: x > 0, + }, + { + "name": "num_epochs", + "type": "int", + "prompt": "训练轮数", + "default": 10, + "required": False, + "validate": lambda x: x > 0, + }, + { + "name": "learning_rate", + "type": "float", + "prompt": "学习率", + "default": 1e-5, + "required": False, + "validate": lambda x: 1e-9 <= x <= 1, + }, + { + "name": "mixed_precision", + "type": "bool", + "prompt": "使用混合精度训练", + "default": True, + "required": False, + }, + { + "name": "compile", + "type": "bool", + "prompt": "启用torch.compile优化", + "default": False, + "required": False, + }, + { + "name": "use_tensorboard", + "type": "bool", + "prompt": "使用TensorBoard记录", + "default": True, + "required": False, + }, + ], + "optional": [ + { + "name": "min_learning_rate", + "type": "float", + "prompt": "最小学习率", + "default": 1e-9, + "required": False, + "validate": lambda x: 1e-12 <= x <= 1, + }, + { + "name": "warmup_ratio", + "type": "float", + "prompt": "热身步数比例", + "default": 0.1, + "required": False, + "validate": lambda x: 0 <= x <= 1, + }, + { + "name": "eval_frequency", + "type": "int", + "prompt": "评估频率(步数)", + "default": 500, + "required": False, + "validate": lambda x: x > 0, + }, + { + "name": "save_frequency", + "type": "int", + "prompt": "保存检查点频率(步数)", + "default": 1000, + "required": False, + "validate": lambda x: x > 0, + }, + { + "name": "max_iter_length", + "type": "int", + "prompt": "数据集大小", + "default": 1024 * 1024 * 128, + "required": False, + "validate": lambda x: x > 0, + }, + { + "name": "num_workers", + "type": "int", + "prompt": "数据加载worker数量", + "default": 2, + "required": False, + "validate": lambda x: x >= 0, + }, + { + "name": "seed", + "type": "int", + "prompt": "随机种子", + "default": 42, + "required": False, + }, + ], + "recovery": [ + { + "name": "auto_resume", + "type": "bool", + "prompt": "自动从最新checkpoint恢复", + "default": True, + "required": False, + }, + { + "name": "resume_from", + "type": "path", + "prompt": "从指定checkpoint恢复(可选)", + "default": None, + "required": False, + "validate": self._validate_optional_path, + }, + { + "name": "reset_training_state", + "type": "bool", + "prompt": "重置训练状态(只加载权重)", + "default": False, + "required": False, + }, + ], + } + + def _validate_path(self, path: str) -> bool: + """ + 验证路径 + + Args: + path: 路径字符串 + + Returns: + 是否有效 + """ + if not path: + return False + path_obj = Path(path) + return path_obj.exists() + + def _validate_optional_path(self, path: Optional[str]) -> bool: + """ + 验证可选路径(可以为None) + + Args: + path: 路径字符串或None + + Returns: + 是否有效 + """ + if path is None or path == "": + return True + return self._validate_path(path) + + def show_welcome(self): + """ + 显示欢迎界面 + """ + self.console.clear() + + welcome_text = Text() + welcome_text.append("🚀 ", style="bold yellow") + welcome_text.append("输入法模型训练系统", style="bold cyan") + welcome_text.append("\n") + welcome_text.append("=" * 50, style="dim") + welcome_text.append("\n\n") + welcome_text.append("欢迎使用交互式训练配置界面!\n", style="bold green") + welcome_text.append("请按照提示配置训练参数。\n", style="green") + welcome_text.append("\n") + welcome_text.append("按 [Enter] 使用默认值\n", style="dim") + welcome_text.append("按 [Ctrl+C] 退出\n", style="dim") + + panel = Panel( + welcome_text, + title="🎯 开始配置", + border_style=self.colors["primary"], + padding=(1, 2), + expand=False, + ) + + self.console.print(panel) + self.console.print() + + def ask_param_group(self, group_name: str, group_title: str) -> Dict[str, Any]: + """ + 询问一组参数 + + Args: + group_name: 参数组名称 + group_title: 显示标题 + + Returns: + 该组的配置字典 + """ + group_config = {} + params = self.param_definitions.get(group_name, []) + + if not params: + return group_config + + # 显示组标题 + self.console.print() + self.console.print(f"[bold {self.colors['primary']}]{group_title}[/bold {self.colors['primary']}]") + self.console.print(f"[{self.colors['secondary']}]{'=' * 40}[/{self.colors['secondary']}]") + + for param in params: + name = param["name"] + prompt = param["prompt"] + param_type = param["type"] + default = param.get("default") + required = param.get("required", False) + validate_func = param.get("validate") + + # 如果已经有值(从命令行传入),跳过 + if name in self.config and self.config[name] is not None: + group_config[name] = self.config[name] + self.console.print(f" {prompt}: [green]{self.config[name]}[/green] (已提供)") + continue + + # 询问参数 + while True: + try: + if param_type == "bool": + value = Confirm.ask( + f" {prompt}", + default=default if default is not None else False, + show_default=True, + ) + elif param_type == "int": + value = IntPrompt.ask( + f" {prompt}", + default=default if default is not None else 0, + show_default=True, + ) + elif param_type == "float": + value = FloatPrompt.ask( + f" {prompt}", + default=default if default is not None else 0.0, + show_default=True, + ) + else: # "path" or other string types + default_str = str(default) if default is not None else "" + value = Prompt.ask( + f" {prompt}", + default=default_str, + show_default=True, + ) + # 处理空字符串 + if value == "": + value = None + + # 验证 + if validate_func and value is not None: + if not validate_func(value): + self.console.print(f" [red]无效值,请重新输入[/red]") + continue + + # 检查必需参数 + if required and (value is None or value == ""): + self.console.print(f" [red]此参数为必需参数[/red]") + continue + + group_config[name] = value + break + + except KeyboardInterrupt: + self.console.print("\n[yellow]已取消[/yellow]") + raise + except Exception as e: + self.console.print(f" [red]输入错误: {e}[/red]") + if not required: + # 宽松处理:使用默认值 + group_config[name] = default + self.console.print(f" [yellow]使用默认值: {default}[/yellow]") + break + + return group_config + + def show_config_summary(self, config: Dict[str, Any]): + """ + 显示配置摘要 + + Args: + config: 完整的配置字典 + """ + self.console.print() + self.console.print(f"[bold {self.colors['primary']]}📋 配置摘要[/bold {self.colors['primary']}]") + self.console.print(f"[{self.colors['secondary']}]{'=' * 40}[/{self.colors['secondary']}]") + + # 创建表格 + table = Table(show_header=True, header_style=f"bold {self.colors['accent']}", box=box.ROUNDED) + table.add_column("参数", style=self.colors["primary"]) + table.add_column("值", style=self.colors["success"]) + table.add_column("类型", style=self.colors["info"]) + + # 按组添加参数 + groups = [ + ("必需参数", "required"), + ("推荐参数", "recommended"), + ("可选参数", "optional"), + ("恢复参数", "recovery"), + ] + + for group_title, group_name in groups: + params = self.param_definitions.get(group_name, []) + if params: + table.add_row(f"[bold]{group_title}[/bold]", "", "", style="bold") + + for param in params: + name = param["name"] + if name in config: + value = config[name] + param_type = param["type"] + + # 格式化值 + if value is None: + value_str = "[dim]None[/dim]" + elif isinstance(value, bool): + value_str = "是" if value else "否" + else: + value_str = str(value) + + table.add_row(f" {param['prompt']}", value_str, param_type) + + self.console.print(table) + self.console.print() + + def run(self, initial_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """ + 运行完整的交互式配置流程 + + Args: + initial_config: 初始配置(如从命令行传入的参数) + + Returns: + 完整的配置字典 + """ + # 初始化配置 + self.config = initial_config or {} + + try: + # 显示欢迎界面 + self.show_welcome() + + # 询问必需参数 + self.console.print() + if Confirm.ask("[bold cyan]是否配置必需参数?[/bold cyan]", default=True): + required_config = self.ask_param_group("required", "📁 必需参数") + self.config.update(required_config) + + # 检查必需参数是否已提供 + required_params = self.param_definitions["required"] + for param in required_params: + if param["name"] not in self.config or self.config[param["name"]] is None: + self.console.print(f"[red]错误: 必需参数 '{param['prompt']}' 未提供[/red]") + self.console.print("[yellow]请重新配置必需参数[/yellow]") + required_config = self.ask_param_group("required", "📁 必需参数") + self.config.update(required_config) + + # 询问推荐参数 + self.console.print() + if Confirm.ask("[bold cyan]是否配置推荐参数?[/bold cyan]", default=True): + recommended_config = self.ask_param_group("recommended", "⚙️ 推荐参数") + self.config.update(recommended_config) + + # 询问可选参数 + self.console.print() + if Confirm.ask("[bold cyan]是否配置可选参数(高级)?[/bold cyan]", default=False): + optional_config = self.ask_param_group("optional", "🎛️ 可选参数") + self.config.update(optional_config) + + # 询问恢复参数 + self.console.print() + if Confirm.ask("[bold cyan]是否配置恢复参数?[/bold cyan]", default=True): + recovery_config = self.ask_param_group("recovery", "🔄 恢复参数") + self.config.update(recovery_config) + + # 显示配置摘要 + self.show_config_summary(self.config) + + # 确认配置 + self.console.print() + if not Confirm.ask("[bold green]是否使用此配置开始训练?[/bold green]", default=True): + self.console.print("[yellow]配置已取消[/yellow]") + return {} + + return self.config + + except KeyboardInterrupt: + self.console.print("\n[yellow]配置已取消[/yellow]") + return {} + except Exception as e: + self.console.print(f"[red]配置过程中出现错误: {e}[/red]") + return {} + + +def get_interactive_config(provided_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """ + 获取交互式配置 + + Args: + provided_params: 已提供的参数(如从命令行传入) + + Returns: + 完整的配置字典 + """ + console = Console() + ui = TrainingConfigUI(console) + + # 过滤掉None值 + if provided_params: + provided_params = {k: v for k, v in provided_params.items() if v is not None} + + config = ui.run(provided_params) + return config + + +if __name__ == "__main__": + # 测试交互式界面 + console = Console() + console.print("[bold]测试交互式配置界面[/bold]") + + # 模拟已提供的参数 + test_params = { + "train_data_path": "./data/train.txt", + "eval_data_path": "./data/eval.txt", + } + + config = get_interactive_config(test_params) + + if config: + console.print("\n[green]✓ 配置完成[/green]") + console.print(f"配置参数: {config}") + else: + console.print("\n[yellow]配置已取消[/yellow]") \ No newline at end of file diff --git a/src/model/trainer.py b/src/model/trainer.py index 4eafa52..0e497be 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -166,7 +166,13 @@ class Trainer: # 不加载历史数据,直接初始化为空列表以覆盖原有数据 self.training_status_data = [] - # 初始化Rich控制台 + # 初始化 epoch checkpoint 元数据 + self.epoch_metadata_file = self.checkpoint_dir / "epoch_checkpoints.json" + self.epoch_checkpoints = [] # 最多保留 3 个 + self.next_epoch_slot = 0 # 下一个要覆盖的位置 (0-2) + self._load_epoch_metadata() + + # 初始化 Rich 控制台 self.console = Console() # 训练状态 @@ -368,6 +374,122 @@ class Trainer: else: logger.info(f"Checkpoint saved to {checkpoint_path}") + def _load_epoch_metadata(self): + """加载 epoch checkpoint 元数据""" + if self.epoch_metadata_file.exists(): + try: + with open(self.epoch_metadata_file, "r", encoding="utf-8") as f: + metadata = json.load(f) + self.epoch_checkpoints = metadata.get("checkpoints", []) + self.next_epoch_slot = metadata.get("next_slot", 0) + logger.info( + f"Loaded epoch checkpoint metadata: {len(self.epoch_checkpoints)} checkpoints" + ) + except Exception as e: + logger.warning(f"Failed to load epoch metadata: {e}") + self.epoch_checkpoints = [] + self.next_epoch_slot = 0 + else: + self.epoch_checkpoints = [] + self.next_epoch_slot = 0 + + def _save_epoch_metadata(self): + """保存 epoch checkpoint 元数据""" + metadata = { + "checkpoints": self.epoch_checkpoints, + "next_slot": self.next_epoch_slot, + "total_epochs_completed": max( + (cp["epoch"] for cp in self.epoch_checkpoints), default=0 + ), + } + with open(self.epoch_metadata_file, "w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + + def save_epoch_checkpoint(self, epoch: int): + """ + 保存 epoch checkpoint(循环覆盖,只保留最后 3 个) + + Args: + epoch: 当前 epoch 编号(从 1 开始) + """ + # 确定文件名(循环使用 3 个固定文件名) + slot = self.next_epoch_slot + filename = f"epoch_checkpoint_{slot + 1}.pt" + checkpoint_path = self.checkpoint_dir / filename + + # 保存 checkpoint + checkpoint = { + "step": self.current_step, + "epoch": epoch, + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + "scaler_state_dict": self.scaler.state_dict(), + "best_eval_loss": self.best_eval_loss, + "config": { + "learning_rate": self.learning_rate, + "weight_decay": self.weight_decay, + "warmup_ratio": self.warmup_ratio, + "label_smoothing": self.label_smoothing, + "total_steps": self.total_steps, + }, + } + + torch.save(checkpoint, checkpoint_path) + + # 更新元数据 + checkpoint_info = { + "epoch": epoch, + "file": filename, + "path": str(checkpoint_path), + "saved_at": datetime.now().isoformat(), + "step": self.current_step, + } + + # 如果已经有 3 个,替换对应位置的;否则添加 + if len(self.epoch_checkpoints) >= 3: + self.epoch_checkpoints[slot] = checkpoint_info + else: + self.epoch_checkpoints.append(checkpoint_info) + + # 更新下一个槽位 + self.next_epoch_slot = (self.next_epoch_slot + 1) % 3 + + # 保存元数据 + self._save_epoch_metadata() + + # 按 epoch 排序后获取最新的 epoch + sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"]) + latest_epoch = sorted_checkpoints[-1]["epoch"] if sorted_checkpoints else epoch + + logger.info( + f"Epoch {epoch} checkpoint saved to {filename} " + f"(keeping last {len(self.epoch_checkpoints)} epochs: " + f"{[cp['epoch'] for cp in sorted_checkpoints]})" + ) + + def get_latest_epoch_checkpoint(self) -> Optional[Dict]: + """ + 获取最新的 epoch checkpoint 信息 + + Returns: + 最新的 checkpoint 信息字典,如果没有则返回 None + """ + if not self.epoch_checkpoints: + return None + + # 按 epoch 排序,返回最新的 + sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"]) + return sorted_checkpoints[-1] + + def get_epoch_checkpoints(self) -> List[Dict]: + """ + 获取所有保存的 epoch checkpoint 信息(按 epoch 排序) + + Returns: + checkpoint 信息列表,按 epoch 升序排列 + """ + return sorted(self.epoch_checkpoints, key=lambda x: x["epoch"]) + def load_checkpoint( self, checkpoint_path: Union[str, Path], reset_training_state: bool = False ): @@ -534,7 +656,10 @@ class Trainer: self.console.print(info_table) def train( - self, resume_from: Optional[str] = None, reset_training_state: bool = False + self, + resume_from: Optional[str] = None, + reset_training_state: bool = False, + auto_resume: bool = True, ): """ 主训练循环 @@ -542,10 +667,23 @@ class Trainer: Args: resume_from: 从哪个检查点恢复训练(可选) reset_training_state: 是否重置训练状态(只加载模型权重,从头开始训练) + auto_resume: 是否自动从最新的 epoch checkpoint 恢复(如果存在) """ - # 如果提供了检查点,则恢复训练 + # 如果提供了检查点,则优先使用提供的检查点恢复训练 if resume_from is not None: self.load_checkpoint(resume_from, reset_training_state=reset_training_state) + elif auto_resume and self.epoch_checkpoints: + # 自动从最新的 epoch checkpoint 恢复 + latest_checkpoint = self.get_latest_epoch_checkpoint() + if latest_checkpoint: + checkpoint_path = latest_checkpoint["path"] + logger.info( + f"Auto-resuming from latest epoch checkpoint: {checkpoint_path} " + f"(epoch {latest_checkpoint['epoch']})" + ) + self.load_checkpoint( + checkpoint_path, reset_training_state=reset_training_state + ) # 打印训练信息 self._print_training_info() @@ -675,8 +813,8 @@ class Trainer: # 进度条不重置,显示整体训练进度 - # 每个epoch结束后保存检查点 - self.save_checkpoint(f"epoch_{epoch + 1}.pt") + # 每个 epoch 结束后保存检查点(循环覆盖,只保留最后 3 个) + self.save_epoch_checkpoint(epoch + 1) # 检查是否达到总步数 if global_step >= self.total_steps: @@ -685,10 +823,18 @@ class Trainer: # 训练完成 logger.info("Training completed!") - # 保存最终模型 - self.save_checkpoint("final_model.pt") + # 显示保存的 epoch checkpoint 信息 + if self.epoch_checkpoints: + sorted_checkpoints = self.get_epoch_checkpoints() + logger.info( + f"Saved epoch checkpoints: {[cp['epoch'] for cp in sorted_checkpoints]}" + ) + logger.info( + f"Latest checkpoint: epoch {sorted_checkpoints[-1]['epoch']} " + f"({sorted_checkpoints[-1]['file']})" + ) - # 关闭TensorBoard写入器 + # 关闭 TensorBoard 写入器 if self.writer is not None: self.writer.close() @@ -907,21 +1053,10 @@ def train( ..., "--eval-data-path", "-e", help="评估数据集路径" ), output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"), - # 模型参数 - vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"), - pinyin_vocab_size: int = typer.Option( - 30, "--pinyin-vocab-size", help="拼音词汇表大小" - ), + # 数据大小 max_iter_length: int = typer.Option( 1024 * 1024 * 128, "--max_iter_length", help="数据集大小" ), - dim: int = typer.Option(512, "--dim", help="模型维度"), - num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"), - n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"), - n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"), - num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"), - max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"), - use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"), # 训练参数 batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"), num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"), @@ -954,6 +1089,9 @@ def train( reset_training_state: bool = typer.Option( False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练" ), + auto_resume: bool = typer.Option( + True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练" + ), seed: int = typer.Option(42, "--seed", help="随机种子"), compile: bool = typer.Option( False, @@ -975,6 +1113,17 @@ def train( if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + # 硬编码模型参数 + vocab_size = 10019 + pinyin_vocab_size = 30 # 根据 dataset.CHAR_TO_ID 映射 + dim = 512 + num_slots = 8 + n_layers = 4 + n_heads = 4 + num_experts = 20 + max_seq_len = 128 + use_pinyin = True # 始终使用拼音 + console = Console() # 打印配置信息 @@ -1014,6 +1163,7 @@ def train( config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm)) config_table.add_row("训练", "混合精度", str(mixed_precision)) + config_table.add_row("其他", "自动恢复", str(auto_resume)) console.print(config_table) # 创建输出目录 @@ -1049,6 +1199,7 @@ def train( "mixed_precision": mixed_precision, "use_tensorboard": use_tensorboard, "seed": seed, + "auto_resume": auto_resume, "max_iter_length": max_iter_length, "compile": compile, } @@ -1149,10 +1300,12 @@ def train( # 开始训练 console.print("\n[bold cyan]开始训练...[/bold cyan]") - console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") try: trainer.train( - resume_from=resume_from, reset_training_state=reset_training_state + resume_from=resume_from, + reset_training_state=reset_training_state, + auto_resume=auto_resume, ) except KeyboardInterrupt: console.print("[bold green]训练被终止[/bold green]") @@ -1223,20 +1376,10 @@ def expand_and_train( "-m", help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类", ), - vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"), - pinyin_vocab_size: int = typer.Option( - 30, "--pinyin-vocab-size", help="拼音词汇表大小" - ), + # 数据大小 max_iter_length: int = typer.Option( 1024 * 1024 * 128, "--max_iter_length", help="数据集大小" ), - dim: int = typer.Option(512, "--dim", help="模型维度"), - num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"), - n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"), - n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"), - num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"), - max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"), - use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"), # 训练参数 batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"), num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"), @@ -1269,6 +1412,9 @@ def expand_and_train( reset_training_state: bool = typer.Option( False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练" ), + auto_resume: bool = typer.Option( + True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练" + ), seed: int = typer.Option(42, "--seed", help="随机种子"), compile: bool = typer.Option( False, @@ -1287,6 +1433,16 @@ def expand_and_train( if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + # 硬编码模型参数 + vocab_size = 10019 + pinyin_vocab_size = 30 # 根据 dataset.CHAR_TO_ID 映射 + dim = 512 + num_slots = 8 + n_layers = 4 + n_heads = 4 + num_experts = 20 + max_seq_len = 128 + use_pinyin = True # 始终使用拼音 console = Console() # 打印配置信息 @@ -1330,6 +1486,7 @@ def expand_and_train( config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm)) config_table.add_row("训练", "混合精度", str(mixed_precision)) + config_table.add_row("其他", "自动恢复", str(auto_resume)) console.print(config_table) # 创建输出目录 @@ -1367,6 +1524,7 @@ def expand_and_train( "mixed_precision": mixed_precision, "use_tensorboard": use_tensorboard, "seed": seed, + "auto_resume": auto_resume, "max_iter_length": max_iter_length, "compile": compile, } @@ -1482,10 +1640,12 @@ def expand_and_train( # 开始训练 console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]") - console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") try: trainer.train( - resume_from=resume_from, reset_training_state=reset_training_state + resume_from=resume_from, + reset_training_state=reset_training_state, + auto_resume=auto_resume, ) except KeyboardInterrupt: console.print("[bold green]训练被终止[/bold green]") @@ -1598,6 +1758,9 @@ def expand_finetune( reset_training_state: bool = typer.Option( False, "--reset-training-state", help="重置训练状态" ), + auto_resume: bool = typer.Option( + True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练" + ), seed: int = typer.Option(42, "--seed", help="随机种子"), compile: Optional[bool] = typer.Option( None, "--compile/--no-compile", help="是否开启 torch.compile 优化" @@ -1697,6 +1860,7 @@ def expand_finetune( config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm)) config_table.add_row("训练", "混合精度", str(mixed_precision)) + config_table.add_row("其他", "自动恢复", str(auto_resume)) console.print(config_table) output_path = Path(final_output_dir) @@ -1786,10 +1950,12 @@ def expand_finetune( console.print("[green]✓ 训练器创建完成[/green]") console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]") - console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") try: trainer.train( - resume_from=resume_from, reset_training_state=reset_training_state + resume_from=resume_from, + reset_training_state=reset_training_state, + auto_resume=auto_resume, ) except KeyboardInterrupt: console.print("[bold green]训练被终止[/bold green]") diff --git a/test.py b/test.py index 9416ac0..db5aa0d 100644 --- a/test.py +++ b/test.py @@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]: return [CHAR_TO_ID.get(c, 0) for c in pinyin_str] -part1 = "明明是国庆节,可是因为月底要结账,财务部所有人都" -part2 = "bxu" +part1 = "招财猫背部或底部的太阳能板会持续将环境光(无论是阳光还是室内灯光)转化为" +part2 = "weiruo" pinyin_ids = text_to_pinyin_ids(part2) len_py = len(pinyin_ids) if len_py < 24: @@ -56,7 +56,7 @@ if len_py < 24: else: pinyin_ids = pinyin_ids[:24] pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0) -masked_labels = [0, 0, 0, 0, 0, 0, 0, 0] +masked_labels = [649, 925, 0, 0, 0, 0, 0, 0] part3 = "" part4 = "可行|特别|伤害" @@ -83,7 +83,7 @@ sample = { model = InputMethodEngine(pinyin_vocab_size=30, compile=False) -checkpoint = torch.load("/home/songsenand/下载/20260411(acc34)final_model.ptrom", map_location="cpu") +checkpoint = torch.load("/home/songsenand/下载/20260411acc37final-model.pt", map_location="cpu") model.load_state_dict(checkpoint["model_state_dict"]) input_ids = sample["input_ids"] diff --git a/test_epoch_checkpoint.py b/test_epoch_checkpoint.py new file mode 100644 index 0000000..0bc8938 --- /dev/null +++ b/test_epoch_checkpoint.py @@ -0,0 +1,210 @@ +""" +测试 epoch checkpoint 循环保存功能 +""" + +import json +import tempfile +from pathlib import Path +from datetime import datetime + + +# 模拟 Trainer 类的 epoch checkpoint 管理逻辑 +class EpochCheckpointManager: + def __init__(self, checkpoint_dir: Path): + self.checkpoint_dir = checkpoint_dir + self.epoch_metadata_file = self.checkpoint_dir / "epoch_checkpoints.json" + self.epoch_checkpoints = [] + self.next_epoch_slot = 0 + + def _load_epoch_metadata(self): + if self.epoch_metadata_file.exists(): + with open(self.epoch_metadata_file, "r", encoding="utf-8") as f: + metadata = json.load(f) + self.epoch_checkpoints = metadata.get("checkpoints", []) + self.next_epoch_slot = metadata.get("next_slot", 0) + + def _save_epoch_metadata(self): + metadata = { + "checkpoints": self.epoch_checkpoints, + "next_slot": self.next_epoch_slot, + "total_epochs_completed": max( + (cp["epoch"] for cp in self.epoch_checkpoints), default=0 + ), + } + with open(self.epoch_metadata_file, "w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + + def save_epoch_checkpoint(self, epoch: int): + """模拟保存 epoch checkpoint(不实际保存模型文件)""" + slot = self.next_epoch_slot + filename = f"epoch_checkpoint_{slot + 1}.pt" + checkpoint_path = self.checkpoint_dir / filename + + checkpoint_info = { + "epoch": epoch, + "file": filename, + "path": str(checkpoint_path), + "saved_at": datetime.now().isoformat(), + "step": epoch * 100, # 模拟 step + } + + if len(self.epoch_checkpoints) >= 3: + self.epoch_checkpoints[slot] = checkpoint_info + else: + self.epoch_checkpoints.append(checkpoint_info) + + self.next_epoch_slot = (self.next_epoch_slot + 1) % 3 + self._save_epoch_metadata() + + sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"]) + print( + f"✓ Epoch {epoch:2d} saved -> {filename:20s} " + f"(keeping epochs: {[cp['epoch'] for cp in sorted_checkpoints]})" + ) + + def get_latest_epoch_checkpoint(self): + if not self.epoch_checkpoints: + return None + sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"]) + return sorted_checkpoints[-1] + + def get_epoch_checkpoints(self): + return sorted(self.epoch_checkpoints, key=lambda x: x["epoch"]) + + +def test_circular_save(): + """测试循环保存功能""" + print("=" * 70) + print("测试:循环保存最后 3 个 epoch checkpoint") + print("=" * 70) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = Path(tmpdir) + manager = EpochCheckpointManager(checkpoint_dir) + + # 模拟训练 10 个 epoch + print("\n模拟训练 10 个 epoch:") + print("-" * 70) + for epoch in range(1, 11): + manager.save_epoch_checkpoint(epoch) + + # 验证最终状态 + print("\n" + "=" * 70) + print("最终状态验证:") + print("=" * 70) + + # 检查 JSON 文件 + metadata_file = checkpoint_dir / "epoch_checkpoints.json" + with open(metadata_file, "r", encoding="utf-8") as f: + metadata = json.load(f) + + print(f"\n✓ JSON 元数据文件:{metadata_file.name}") + print(f" next_slot: {metadata['next_slot']}") + print(f" total_epochs_completed: {metadata['total_epochs_completed']}") + + print(f"\n✓ 保存的 checkpoint 数量:{len(metadata['checkpoints'])}") + print(f" 期望:3 个") + + sorted_checkpoints = sorted(metadata["checkpoints"], key=lambda x: x["epoch"]) + print(f"\n✓ 保存的 epochs: {[cp['epoch'] for cp in sorted_checkpoints]}") + print(f" 期望:[8, 9, 10]") + + # 验证顺序 + latest = manager.get_latest_epoch_checkpoint() + print(f"\n✓ 最新的 checkpoint: epoch {latest['epoch']} ({latest['file']})") + print(f" 期望:epoch 10") + + # 验证文件对应关系 + print(f"\n✓ 文件对应关系:") + for cp in sorted_checkpoints: + print(f" epoch {cp['epoch']:2d} -> {cp['file']}") + + # 验证 + assert len(metadata["checkpoints"]) == 3, "应该只保留 3 个 checkpoint" + assert [cp["epoch"] for cp in sorted_checkpoints] == [8, 9, 10], ( + "应该保留最后 3 个 epoch" + ) + assert latest["epoch"] == 10, "最新的应该是 epoch 10" + + print("\n" + "=" * 70) + print("✓✓✓ 所有测试通过!") + print("=" * 70) + + +def test_auto_resume(): + """测试自动恢复功能""" + print("\n" + "=" * 70) + print("测试:从 JSON 自动恢复 checkpoint 顺序") + print("=" * 70) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_dir = Path(tmpdir) + + # 创建 manager 并保存一些 checkpoint + manager1 = EpochCheckpointManager(checkpoint_dir) + print("\n第一轮训练(保存 epoch 1-5):") + print("-" * 70) + for epoch in range(1, 6): + manager1.save_epoch_checkpoint(epoch) + + # 模拟新的训练会话(从 JSON 恢复) + print("\n\n第二轮训练(从 JSON 恢复):") + print("-" * 70) + manager2 = EpochCheckpointManager(checkpoint_dir) + manager2._load_epoch_metadata() + + print(f"\n✓ 恢复的 checkpoint 数量:{len(manager2.epoch_checkpoints)}") + print(f"✓ 恢复的 epochs: {[cp['epoch'] for cp in manager2.epoch_checkpoints]}") + print(f"✓ next_slot: {manager2.next_epoch_slot}") + + # 继续保存 + print("\n继续保存 epoch 6-10:") + print("-" * 70) + for epoch in range(6, 11): + manager2.save_epoch_checkpoint(epoch) + + # 验证最终状态 + print("\n" + "=" * 70) + print("最终状态验证:") + print("=" * 70) + + sorted_checkpoints = manager2.get_epoch_checkpoints() + print(f"\n✓ 最终保存的 epochs: {[cp['epoch'] for cp in sorted_checkpoints]}") + print(f" 期望:[8, 9, 10]") + + latest = manager2.get_latest_epoch_checkpoint() + print(f"✓ 最新的 checkpoint: epoch {latest['epoch']}") + print(f" 期望:epoch 10") + + # 验证 + assert [cp["epoch"] for cp in sorted_checkpoints] == [8, 9, 10], ( + "应该保留最后 3 个 epoch" + ) + assert latest["epoch"] == 10, "最新的应该是 epoch 10" + + print("\n" + "=" * 70) + print("✓✓✓ 自动恢复测试通过!") + print("=" * 70) + + +if __name__ == "__main__": + test_circular_save() + test_auto_resume() + + print("\n" + "=" * 70) + print("所有测试完成!") + print("=" * 70) + print("\n功能总结:") + print(" 1. ✓ 循环覆盖保存,只保留最后 3 个 epoch checkpoint") + print(" 2. ✓ 使用 JSON 文件记录 checkpoint 元数据") + print(" 3. ✓ 支持从 JSON 自动恢复 checkpoint 顺序") + print(" 4. ✓ 可以正确识别最新的 checkpoint") + print("\n保留的文件:") + print(" - best_model.pt (最佳模型)") + print(" - latest_checkpoint.pt (定期保存)") + print(" - epoch_checkpoint_1.pt, epoch_checkpoint_2.pt, epoch_checkpoint_3.pt") + print(" - epoch_checkpoints.json (元数据)") + print("\n移除的文件:") + print(" - final_model.pt (与 epoch_last 重复)") + print(" - epoch_*.pt (不再为每个 epoch 创建独立文件)") + print("=" * 70)