"""
Tests for the rotating (circular) epoch-checkpoint save feature.
"""
|
||
|
||
import json
|
||
import tempfile
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
|
||
|
||
# 模拟 Trainer 类的 epoch checkpoint 管理逻辑
|
||
class EpochCheckpointManager:
    """Bookkeeping for a fixed-size ring of per-epoch checkpoints.

    Simulates the trainer's epoch-checkpoint management without writing any
    model files. Each save claims the next slot, whose file name is
    ``epoch_checkpoint_<slot + 1>.pt``, records metadata for it, and persists
    the whole table to ``epoch_checkpoints.json`` inside ``checkpoint_dir``.
    Once every slot has been used, the slots are reused in order, so only the
    most recent ``max_checkpoints`` epochs stay recorded.
    """

    def __init__(self, checkpoint_dir: Path, max_checkpoints: int = 3):
        """Create an empty manager; nothing is read from or written to disk.

        Args:
            checkpoint_dir: Directory that holds the checkpoints and the JSON
                metadata file. A plain string is accepted as well.
            max_checkpoints: Number of slots in the ring (default 3).

        Raises:
            ValueError: If ``max_checkpoints`` is smaller than 1.
        """
        if max_checkpoints < 1:
            raise ValueError("max_checkpoints must be at least 1")
        self.checkpoint_dir = Path(checkpoint_dir)
        self.max_checkpoints = max_checkpoints
        self.epoch_metadata_file = self.checkpoint_dir / "epoch_checkpoints.json"
        # One metadata dict per occupied slot, in slot order during normal use.
        self.epoch_checkpoints = []
        # Zero-based index of the slot the next save will write to.
        self.next_epoch_slot = 0

    def _load_epoch_metadata(self):
        """Restore the checkpoint table and next slot from the JSON file.

        Does nothing when the metadata file does not exist. Missing keys fall
        back to an empty table and slot 0.
        """
        if not self.epoch_metadata_file.exists():
            return
        with open(self.epoch_metadata_file, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        self.epoch_checkpoints = metadata.get("checkpoints", [])
        # Wrap the stored slot into range so a file written with a different
        # ring size (or edited by hand) cannot point past the last slot.
        self.next_epoch_slot = int(metadata.get("next_slot", 0)) % self.max_checkpoints

    def _save_epoch_metadata(self):
        """Write the checkpoint table, next slot and highest epoch to JSON."""
        metadata = {
            "checkpoints": self.epoch_checkpoints,
            "next_slot": self.next_epoch_slot,
            "total_epochs_completed": max(
                (cp["epoch"] for cp in self.epoch_checkpoints), default=0
            ),
        }
        # The directory may not exist yet on the very first save.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        with open(self.epoch_metadata_file, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

    def save_epoch_checkpoint(self, epoch: int):
        """Record a checkpoint for ``epoch`` in the next slot (no model file).

        The entry that previously used the same file name, if any, is replaced
        in place; otherwise the new entry is appended. The metadata file is
        rewritten and a one-line summary is printed.

        Args:
            epoch: Epoch number being checkpointed.
        """
        slot = self.next_epoch_slot
        filename = f"epoch_checkpoint_{slot + 1}.pt"
        checkpoint_path = self.checkpoint_dir / filename

        checkpoint_info = {
            "epoch": epoch,
            "file": filename,
            "path": str(checkpoint_path),
            "saved_at": datetime.now().isoformat(),
            "step": epoch * 100,  # simulated step counter
        }

        # Match on the file name rather than on list position: after loading
        # metadata whose list order and slot counter are out of step, indexing
        # by slot could overwrite the wrong entry, raise IndexError, or leave
        # two entries that claim the same file.
        for index, existing in enumerate(self.epoch_checkpoints):
            if existing.get("file") == filename:
                self.epoch_checkpoints[index] = checkpoint_info
                break
        else:
            self.epoch_checkpoints.append(checkpoint_info)

        self.next_epoch_slot = (slot + 1) % self.max_checkpoints
        self._save_epoch_metadata()

        sorted_checkpoints = sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
        print(
            f"✓ Epoch {epoch:2d} saved -> {filename:20s} "
            f"(keeping epochs: {[cp['epoch'] for cp in sorted_checkpoints]})"
        )

    def get_latest_epoch_checkpoint(self):
        """Return the entry with the highest epoch, or None when empty."""
        if not self.epoch_checkpoints:
            return None
        return self.get_epoch_checkpoints()[-1]

    def get_epoch_checkpoints(self):
        """Return a new list of the recorded entries sorted by epoch."""
        return sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
|
||
|
||
|
||
def test_circular_save():
    """Check that ten saves leave only the three newest epoch checkpoints.

    Saves epochs 1-10 through ``EpochCheckpointManager`` in a temporary
    directory, then checks the JSON metadata on disk and the manager's own
    view: entry count, surviving epochs, latest entry, next slot, and which
    slot file each surviving epoch landed in.
    """
    print("=" * 70)
    print("测试:循环保存最后 3 个 epoch checkpoint")
    print("=" * 70)

    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_dir = Path(tmpdir)
        manager = EpochCheckpointManager(checkpoint_dir)

        # Simulate ten training epochs.
        print("\n模拟训练 10 个 epoch:")
        print("-" * 70)
        for epoch in range(1, 11):
            manager.save_epoch_checkpoint(epoch)

        # Report the final state.
        print("\n" + "=" * 70)
        print("最终状态验证:")
        print("=" * 70)

        # Read the metadata back from disk rather than trusting memory.
        metadata_file = checkpoint_dir / "epoch_checkpoints.json"
        with open(metadata_file, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        print(f"\n✓ JSON 元数据文件:{metadata_file.name}")
        print(f" next_slot: {metadata['next_slot']}")
        print(f" total_epochs_completed: {metadata['total_epochs_completed']}")

        print(f"\n✓ 保存的 checkpoint 数量:{len(metadata['checkpoints'])}")
        print(" 期望:3 个")

        sorted_checkpoints = sorted(metadata["checkpoints"], key=lambda x: x["epoch"])
        print(f"\n✓ 保存的 epochs: {[cp['epoch'] for cp in sorted_checkpoints]}")
        print(" 期望:[8, 9, 10]")

        # Fail with a clear message instead of a TypeError in the print below.
        latest = manager.get_latest_epoch_checkpoint()
        assert latest is not None, "应该存在最新的 checkpoint"
        print(f"\n✓ 最新的 checkpoint: epoch {latest['epoch']} ({latest['file']})")
        print(" 期望:epoch 10")

        # Show which slot file each surviving epoch occupies.
        print("\n✓ 文件对应关系:")
        for cp in sorted_checkpoints:
            print(f" epoch {cp['epoch']:2d} -> {cp['file']}")

        # Assertions.
        assert len(metadata["checkpoints"]) == 3, "应该只保留 3 个 checkpoint"
        assert [cp["epoch"] for cp in sorted_checkpoints] == [8, 9, 10], (
            "应该保留最后 3 个 epoch"
        )
        assert latest["epoch"] == 10, "最新的应该是 epoch 10"
        # Epoch e is written to slot (e - 1) % 3, so the mapping is fixed.
        assert {cp["epoch"]: cp["file"] for cp in sorted_checkpoints} == {
            8: "epoch_checkpoint_2.pt",
            9: "epoch_checkpoint_3.pt",
            10: "epoch_checkpoint_1.pt",
        }, "epoch 与文件的对应关系不正确"
        assert metadata["next_slot"] == 1, "next_slot 应该是 1"
        assert metadata["total_epochs_completed"] == 10, (
            "total_epochs_completed 应该是 10"
        )

        print("\n" + "=" * 70)
        print("✓✓✓ 所有测试通过!")
        print("=" * 70)
|
||
|
||
|
||
def test_auto_resume():
    """Check that a new manager resumes the slot rotation from the JSON file.

    A first manager saves epochs 1-5. A second manager on the same directory
    loads the metadata and saves epochs 6-10. The restored state is asserted
    right after loading, because the final epoch list alone would look the
    same even if nothing had been restored.
    """
    print("\n" + "=" * 70)
    print("测试:从 JSON 自动恢复 checkpoint 顺序")
    print("=" * 70)

    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_dir = Path(tmpdir)

        # First session: create a manager and save a few checkpoints.
        manager1 = EpochCheckpointManager(checkpoint_dir)
        print("\n第一轮训练(保存 epoch 1-5):")
        print("-" * 70)
        for epoch in range(1, 6):
            manager1.save_epoch_checkpoint(epoch)

        # Second session: a fresh manager restores its state from JSON.
        print("\n\n第二轮训练(从 JSON 恢复):")
        print("-" * 70)
        manager2 = EpochCheckpointManager(checkpoint_dir)
        manager2._load_epoch_metadata()

        print(f"\n✓ 恢复的 checkpoint 数量:{len(manager2.epoch_checkpoints)}")
        print(f"✓ 恢复的 epochs: {[cp['epoch'] for cp in manager2.epoch_checkpoints]}")
        print(f"✓ next_slot: {manager2.next_epoch_slot}")

        # Without these checks a load that restored nothing would go unnoticed.
        assert len(manager2.epoch_checkpoints) == 3, "应该恢复 3 个 checkpoint"
        assert sorted(cp["epoch"] for cp in manager2.epoch_checkpoints) == [3, 4, 5], (
            "应该恢复 epoch 3-5"
        )
        assert manager2.next_epoch_slot == 2, "恢复后的 next_slot 应该是 2"

        # Keep saving with the restored manager.
        print("\n继续保存 epoch 6-10:")
        print("-" * 70)
        for epoch in range(6, 11):
            manager2.save_epoch_checkpoint(epoch)

        # Report the final state.
        print("\n" + "=" * 70)
        print("最终状态验证:")
        print("=" * 70)

        sorted_checkpoints = manager2.get_epoch_checkpoints()
        print(f"\n✓ 最终保存的 epochs: {[cp['epoch'] for cp in sorted_checkpoints]}")
        print(" 期望:[8, 9, 10]")

        latest = manager2.get_latest_epoch_checkpoint()
        assert latest is not None, "应该存在最新的 checkpoint"
        print(f"✓ 最新的 checkpoint: epoch {latest['epoch']}")
        print(" 期望:epoch 10")

        # Assertions.
        assert [cp["epoch"] for cp in sorted_checkpoints] == [8, 9, 10], (
            "应该保留最后 3 个 epoch"
        )
        assert latest["epoch"] == 10, "最新的应该是 epoch 10"
        # This mapping only holds when the rotation continued from slot 2;
        # a manager that started over from slot 0 would put epoch 8 in file 3.
        assert {cp["epoch"]: cp["file"] for cp in sorted_checkpoints} == {
            8: "epoch_checkpoint_2.pt",
            9: "epoch_checkpoint_3.pt",
            10: "epoch_checkpoint_1.pt",
        }, "恢复后应该延续原来的 slot 顺序"

        print("\n" + "=" * 70)
        print("✓✓✓ 自动恢复测试通过!")
        print("=" * 70)
|
||
|
||
|
||
if __name__ == "__main__":
    # Run both scenarios, then print a recap of the checkpoint file layout.
    test_circular_save()
    test_auto_resume()

    divider = "=" * 70
    recap_lines = (
        "\n" + divider,
        "所有测试完成!",
        divider,
        "\n功能总结:",
        " 1. ✓ 循环覆盖保存,只保留最后 3 个 epoch checkpoint",
        " 2. ✓ 使用 JSON 文件记录 checkpoint 元数据",
        " 3. ✓ 支持从 JSON 自动恢复 checkpoint 顺序",
        " 4. ✓ 可以正确识别最新的 checkpoint",
        "\n保留的文件:",
        " - best_model.pt (最佳模型)",
        " - latest_checkpoint.pt (定期保存)",
        " - epoch_checkpoint_1.pt, epoch_checkpoint_2.pt, epoch_checkpoint_3.pt",
        " - epoch_checkpoints.json (元数据)",
        "\n移除的文件:",
        " - final_model.pt (与 epoch_last 重复)",
        " - epoch_*.pt (不再为每个 epoch 创建独立文件)",
        divider,
    )
    # One print call per entry, matching the original line-by-line output.
    for recap_line in recap_lines:
        print(recap_line)
|