"""Test the circular (ring-buffer) saving of epoch checkpoints."""

import json
import tempfile
from pathlib import Path
from datetime import datetime


# Simulates the epoch-checkpoint bookkeeping of the Trainer class; no model
# weights are ever written, only the JSON metadata.
class EpochCheckpointManager:
    """Keep the most recent ``max_checkpoints`` epoch checkpoints in a ring of slots.

    Slot ``i`` maps to the file name ``epoch_checkpoint_{i + 1}.pt``; once every
    slot has been used, the slot written longest ago is overwritten next. The
    bookkeeping is persisted to ``epoch_checkpoints.json`` inside
    ``checkpoint_dir`` after every save. The constructor does NOT read existing
    metadata; call ``_load_epoch_metadata()`` to resume.
    """

    def __init__(self, checkpoint_dir: Path, max_checkpoints: int = 3):
        """
        Args:
            checkpoint_dir: Directory that holds the metadata JSON. It must
                already exist, since saving opens a file inside it for writing.
            max_checkpoints: Number of ring slots (checkpoints retained).
                Defaults to 3.

        Raises:
            ValueError: if ``max_checkpoints`` is less than 1.
        """
        if max_checkpoints < 1:
            raise ValueError(f"max_checkpoints must be >= 1, got {max_checkpoints}")
        self.checkpoint_dir = checkpoint_dir
        self.max_checkpoints = max_checkpoints
        self.epoch_metadata_file = self.checkpoint_dir / "epoch_checkpoints.json"
        # Checkpoint info dicts; list index == ring slot the entry occupies.
        self.epoch_checkpoints = []
        # Slot the next save_epoch_checkpoint() call will write into.
        self.next_epoch_slot = 0

    def _load_epoch_metadata(self):
        """Restore ring state from the metadata JSON if that file exists.

        When the file is missing, the current in-memory state is left untouched.
        """
        if not self.epoch_metadata_file.exists():
            return
        with open(self.epoch_metadata_file, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        self.epoch_checkpoints = metadata.get("checkpoints", [])
        self.next_epoch_slot = metadata.get("next_slot", 0)

    def _save_epoch_metadata(self):
        """Write the checkpoint list, next slot and highest retained epoch to JSON."""
        metadata = {
            "checkpoints": self.epoch_checkpoints,
            "next_slot": self.next_epoch_slot,
            # Highest epoch still retained (0 when nothing has been saved yet).
            "total_epochs_completed": max(
                (cp["epoch"] for cp in self.epoch_checkpoints), default=0
            ),
        }
        with open(self.epoch_metadata_file, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

    def save_epoch_checkpoint(self, epoch: int):
        """Record a simulated checkpoint for ``epoch`` in the next ring slot.

        No model file is written; only the metadata JSON is updated, and a
        progress line listing the retained epochs is printed.
        """
        slot = self.next_epoch_slot
        filename = f"epoch_checkpoint_{slot + 1}.pt"
        checkpoint_path = self.checkpoint_dir / filename
        checkpoint_info = {
            "epoch": epoch,
            "file": filename,
            "path": str(checkpoint_path),
            "saved_at": datetime.now().isoformat(),
            "step": epoch * 100,  # simulated global step
        }
        # Overwrite whatever already occupies this slot. Testing the slot index
        # (rather than "is the ring full?") keeps list index == slot even when
        # resumed metadata is partially filled with next_slot < len, so two
        # entries can never claim the same checkpoint file.
        if slot < len(self.epoch_checkpoints):
            self.epoch_checkpoints[slot] = checkpoint_info
        else:
            self.epoch_checkpoints.append(checkpoint_info)
        self.next_epoch_slot = (slot + 1) % self.max_checkpoints
        self._save_epoch_metadata()

        kept_epochs = [cp["epoch"] for cp in self.get_epoch_checkpoints()]
        print(
            f"✓ Epoch {epoch:2d} saved -> {filename:20s} "
            f"(keeping epochs: {kept_epochs})"
        )

    def get_latest_epoch_checkpoint(self):
        """Return the retained checkpoint with the highest epoch, or None if empty."""
        if not self.epoch_checkpoints:
            return None
        return self.get_epoch_checkpoints()[-1]

    def get_epoch_checkpoints(self):
        """Return the retained checkpoints as a new list sorted by ascending epoch."""
        return sorted(self.epoch_checkpoints, key=lambda cp: cp["epoch"])


def test_circular_save():
    """Ten epochs through a 3-slot ring must leave exactly epochs 8, 9 and 10."""
    print("=" * 70)
    print("测试:循环保存最后 3 个 epoch checkpoint")
    print("=" * 70)

    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_dir = Path(tmpdir)
        manager = EpochCheckpointManager(checkpoint_dir)

        # Simulate training for 10 epochs.
        print("\n模拟训练 10 个 epoch:")
        print("-" * 70)
        for epoch in range(1, 11):
            manager.save_epoch_checkpoint(epoch)

        # Verify the final state.
        print("\n" + "=" * 70)
        print("最终状态验证:")
        print("=" * 70)

        # Inspect the persisted JSON metadata.
        metadata_file = checkpoint_dir / "epoch_checkpoints.json"
        with open(metadata_file, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        print(f"\n✓ JSON 元数据文件:{metadata_file.name}")
        print(f" next_slot: {metadata['next_slot']}")
        print(f" total_epochs_completed: {metadata['total_epochs_completed']}")

        print(f"\n✓ 保存的 checkpoint 数量:{len(metadata['checkpoints'])}")
        print(" 期望:3 个")

        sorted_checkpoints = sorted(metadata["checkpoints"], key=lambda x: x["epoch"])
        print(f"\n✓ 保存的 epochs: {[cp['epoch'] for cp in sorted_checkpoints]}")
        print(" 期望:[8, 9, 10]")

        # The manager must report the highest epoch as latest.
        latest = manager.get_latest_epoch_checkpoint()
        print(f"\n✓ 最新的 checkpoint: epoch {latest['epoch']} ({latest['file']})")
        print(" 期望:epoch 10")

        # Show the epoch -> file mapping.
        print("\n✓ 文件对应关系:")
        for cp in sorted_checkpoints:
            print(f" epoch {cp['epoch']:2d} -> {cp['file']}")

        assert len(metadata["checkpoints"]) == 3, "应该只保留 3 个 checkpoint"
        assert [cp["epoch"] for cp in sorted_checkpoints] == [8, 9, 10], (
            "应该保留最后 3 个 epoch"
        )
        assert latest["epoch"] == 10, "最新的应该是 epoch 10"

    print("\n" + "=" * 70)
    print("✓✓✓ 所有测试通过!")
    print("=" * 70)


def test_auto_resume():
    """A fresh manager resumed from the JSON must continue the ring seamlessly."""
    print("\n" + "=" * 70)
    print("测试:从 JSON 自动恢复 checkpoint 顺序")
    print("=" * 70)

    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_dir = Path(tmpdir)

        # First session: save a few checkpoints.
        manager1 = EpochCheckpointManager(checkpoint_dir)
        print("\n第一轮训练(保存 epoch 1-5):")
        print("-" * 70)
        for epoch in range(1, 6):
            manager1.save_epoch_checkpoint(epoch)

        # Simulate a new training session that resumes from the JSON.
        print("\n\n第二轮训练(从 JSON 恢复):")
        print("-" * 70)
        manager2 = EpochCheckpointManager(checkpoint_dir)
        manager2._load_epoch_metadata()

        print(f"\n✓ 恢复的 checkpoint 数量:{len(manager2.epoch_checkpoints)}")
        print(f"✓ 恢复的 epochs: {[cp['epoch'] for cp in manager2.epoch_checkpoints]}")
        print(f"✓ next_slot: {manager2.next_epoch_slot}")

        # Keep saving in the resumed session.
        print("\n继续保存 epoch 6-10:")
        print("-" * 70)
        for epoch in range(6, 11):
            manager2.save_epoch_checkpoint(epoch)

        # Verify the final state.
        print("\n" + "=" * 70)
        print("最终状态验证:")
        print("=" * 70)

        sorted_checkpoints = manager2.get_epoch_checkpoints()
        print(f"\n✓ 最终保存的 epochs: {[cp['epoch'] for cp in sorted_checkpoints]}")
        print(" 期望:[8, 9, 10]")

        latest = manager2.get_latest_epoch_checkpoint()
        print(f"✓ 最新的 checkpoint: epoch {latest['epoch']}")
        print(" 期望:epoch 10")

        assert [cp["epoch"] for cp in sorted_checkpoints] == [8, 9, 10], (
            "应该保留最后 3 个 epoch"
        )
        assert latest["epoch"] == 10, "最新的应该是 epoch 10"

    print("\n" + "=" * 70)
    print("✓✓✓ 自动恢复测试通过!")
    print("=" * 70)


if __name__ == "__main__":
    test_circular_save()
    test_auto_resume()

    print("\n" + "=" * 70)
    print("所有测试完成!")
    print("=" * 70)
    print("\n功能总结:")
    print(" 1. ✓ 循环覆盖保存,只保留最后 3 个 epoch checkpoint")
    print(" 2. ✓ 使用 JSON 文件记录 checkpoint 元数据")
    print(" 3. ✓ 支持从 JSON 自动恢复 checkpoint 顺序")
    print(" 4. ✓ 可以正确识别最新的 checkpoint")
    print("\n保留的文件:")
    print(" - best_model.pt (最佳模型)")
    print(" - latest_checkpoint.pt (定期保存)")
    print(" - epoch_checkpoint_1.pt, epoch_checkpoint_2.pt, epoch_checkpoint_3.pt")
    print(" - epoch_checkpoints.json (元数据)")
    print("\n移除的文件:")
    print(" - final_model.pt (与 epoch_last 重复)")
    print(" - epoch_*.pt (不再为每个 epoch 创建独立文件)")
    print("=" * 70)