SUimeModelTraner/test_epoch_checkpoint.py

211 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试 epoch checkpoint 循环保存功能
"""
import json
import tempfile
from pathlib import Path
from datetime import datetime
# 模拟 Trainer 类的 epoch checkpoint 管理逻辑
class EpochCheckpointManager:
    """Simulate the Trainer's rolling (round-robin) epoch-checkpoint bookkeeping.

    At most ``max_checkpoints`` entries are kept. Each save goes into the next
    slot, overwriting whatever that slot previously held, and the full state is
    written to ``epoch_checkpoints.json`` inside ``checkpoint_dir``. Only the
    metadata file is written; no model files are created.
    """

    def __init__(self, checkpoint_dir: Path, max_checkpoints: int = 3):
        """Set up an empty manager for *checkpoint_dir*.

        Args:
            checkpoint_dir: Directory that holds ``epoch_checkpoints.json``.
            max_checkpoints: Number of rolling slots to keep. Defaults to 3,
                the value this class previously hard-coded.

        Raises:
            ValueError: If ``max_checkpoints`` is smaller than 1.
        """
        if max_checkpoints < 1:
            raise ValueError(
                f"max_checkpoints must be at least 1, got {max_checkpoints}"
            )
        self.checkpoint_dir = checkpoint_dir
        self.max_checkpoints = max_checkpoints
        self.epoch_metadata_file = self.checkpoint_dir / "epoch_checkpoints.json"
        # Stored in slot order: index i describes epoch_checkpoint_{i + 1}.pt.
        self.epoch_checkpoints = []
        # Slot index the next save_epoch_checkpoint() call will write to.
        self.next_epoch_slot = 0

    def _load_epoch_metadata(self):
        """Restore the checkpoint list and next slot from the JSON file, if it exists."""
        if self.epoch_metadata_file.exists():
            with open(self.epoch_metadata_file, "r", encoding="utf-8") as f:
                metadata = json.load(f)
            self.epoch_checkpoints = metadata.get("checkpoints", [])
            self.next_epoch_slot = metadata.get("next_slot", 0)

    def _save_epoch_metadata(self):
        """Write the current checkpoint list and next slot to the JSON file."""
        metadata = {
            "checkpoints": self.epoch_checkpoints,
            "next_slot": self.next_epoch_slot,
            # Highest epoch among the kept checkpoints (0 when none are kept).
            "total_epochs_completed": max(
                (cp["epoch"] for cp in self.epoch_checkpoints), default=0
            ),
        }
        with open(self.epoch_metadata_file, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

    def save_epoch_checkpoint(self, epoch: int):
        """Record a checkpoint for *epoch* in the next rolling slot.

        This is a simulation: only the metadata entry is created and persisted;
        no model file is written. The ``step`` value is a fake ``epoch * 100``.
        """
        slot = self.next_epoch_slot
        filename = f"epoch_checkpoint_{slot + 1}.pt"
        checkpoint_path = self.checkpoint_dir / filename
        checkpoint_info = {
            "epoch": epoch,
            "file": filename,
            "path": str(checkpoint_path),
            "saved_at": datetime.now().isoformat(),
            "step": epoch * 100,  # simulated step count
        }
        # Overwrite the slot if it already has an entry, otherwise grow the list.
        # Keying on the slot index (not on the list length) keeps list position i
        # aligned with epoch_checkpoint_{i + 1}.pt even if loaded metadata has
        # fewer entries than slots already cycled through.
        if slot < len(self.epoch_checkpoints):
            self.epoch_checkpoints[slot] = checkpoint_info
        else:
            self.epoch_checkpoints.append(checkpoint_info)
        self.next_epoch_slot = (slot + 1) % self.max_checkpoints
        self._save_epoch_metadata()
        sorted_checkpoints = self.get_epoch_checkpoints()
        print(
            f"✓ Epoch {epoch:2d} saved -> {filename:20s} "
            f"(keeping epochs: {[cp['epoch'] for cp in sorted_checkpoints]})"
        )

    def get_latest_epoch_checkpoint(self):
        """Return the entry with the highest epoch, or ``None`` if nothing is saved."""
        checkpoints = self.get_epoch_checkpoints()
        return checkpoints[-1] if checkpoints else None

    def get_epoch_checkpoints(self):
        """Return the kept entries as a new list sorted by ascending epoch."""
        return sorted(self.epoch_checkpoints, key=lambda x: x["epoch"])
def test_circular_save():
    """Save ten epochs and check that only the three newest checkpoints survive.

    The metadata is read back from ``epoch_checkpoints.json``. The test asserts
    that exactly three entries remain, that they cover epochs 8-10, and that
    the manager reports epoch 10 as the latest.
    """
    banner = "=" * 70
    print(banner)
    print("测试:循环保存最后 3 个 epoch checkpoint")
    print(banner)

    with tempfile.TemporaryDirectory() as tmp_name:
        workdir = Path(tmp_name)
        mgr = EpochCheckpointManager(workdir)

        # Simulate ten training epochs, one checkpoint per epoch.
        print("\n模拟训练 10 个 epoch:")
        print("-" * 70)
        for ep in range(1, 11):
            mgr.save_epoch_checkpoint(ep)

        # Final-state verification.
        print("\n" + banner)
        print("最终状态验证:")
        print(banner)

        # Read the metadata back from disk rather than trusting in-memory state.
        json_path = workdir / "epoch_checkpoints.json"
        meta = json.loads(json_path.read_text(encoding="utf-8"))
        print(f"\n✓ JSON 元数据文件:{json_path.name}")
        print(f" next_slot: {meta['next_slot']}")
        print(f" total_epochs_completed: {meta['total_epochs_completed']}")
        entries = meta["checkpoints"]
        print(f"\n✓ 保存的 checkpoint 数量:{len(entries)}")
        print(" 期望3 个")

        ordered = sorted(entries, key=lambda item: item["epoch"])
        kept_epochs = [item["epoch"] for item in ordered]
        print(f"\n✓ 保存的 epochs: {kept_epochs}")
        print(" 期望:[8, 9, 10]")

        # The manager's notion of "latest" must match the highest saved epoch.
        newest = mgr.get_latest_epoch_checkpoint()
        print(f"\n✓ 最新的 checkpoint: epoch {newest['epoch']} ({newest['file']})")
        print(" 期望epoch 10")

        # Show which rolling file each surviving epoch lives in.
        print("\n✓ 文件对应关系:")
        for item in ordered:
            print(f" epoch {item['epoch']:2d} -> {item['file']}")

        assert len(entries) == 3, "应该只保留 3 个 checkpoint"
        assert kept_epochs == [8, 9, 10], "应该保留最后 3 个 epoch"
        assert newest["epoch"] == 10, "最新的应该是 epoch 10"

    print("\n" + banner)
    print("✓✓✓ 所有测试通过!")
    print(banner)
def test_auto_resume():
    """Check that a new manager resumes the rolling slot order from the JSON file.

    A first manager saves epochs 1-5. A second manager on the same directory
    restores state via ``_load_epoch_metadata`` and then saves epochs 6-10.
    The restored state is asserted directly. Without those checks, a resume
    that silently started from scratch would still end with epochs [8, 9, 10]
    and latest epoch 10, so the final assertions alone cannot detect a
    broken resume.
    """
    print("\n" + "=" * 70)
    print("测试:从 JSON 自动恢复 checkpoint 顺序")
    print("=" * 70)
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_dir = Path(tmpdir)
        # First session: save a few checkpoints.
        manager1 = EpochCheckpointManager(checkpoint_dir)
        print("\n第一轮训练(保存 epoch 1-5:")
        print("-" * 70)
        for epoch in range(1, 6):
            manager1.save_epoch_checkpoint(epoch)

        # Simulate a new training session that restores from the JSON file.
        print("\n\n第二轮训练(从 JSON 恢复):")
        print("-" * 70)
        manager2 = EpochCheckpointManager(checkpoint_dir)
        manager2._load_epoch_metadata()
        print(f"\n✓ 恢复的 checkpoint 数量:{len(manager2.epoch_checkpoints)}")
        print(f"✓ 恢复的 epochs: {[cp['epoch'] for cp in manager2.epoch_checkpoints]}")
        print(f"✓ next_slot: {manager2.next_epoch_slot}")

        # Verify the restore itself. Epochs 1-5 fill slots 0, 1, 2, 0, 1, so the
        # kept epochs are 3-5 and the next write must go to slot 2.
        restored_epochs = sorted(cp["epoch"] for cp in manager2.epoch_checkpoints)
        assert len(manager2.epoch_checkpoints) == 3, "应该从 JSON 恢复 3 个 checkpoint"
        assert restored_epochs == [3, 4, 5], "应该恢复 epoch 3-5"
        assert manager2.next_epoch_slot == 2, "next_slot 应该恢复为 2"

        # Continue saving from the restored slot.
        print("\n继续保存 epoch 6-10:")
        print("-" * 70)
        for epoch in range(6, 11):
            manager2.save_epoch_checkpoint(epoch)

        # Final-state verification.
        print("\n" + "=" * 70)
        print("最终状态验证:")
        print("=" * 70)
        sorted_checkpoints = manager2.get_epoch_checkpoints()
        print(f"\n✓ 最终保存的 epochs: {[cp['epoch'] for cp in sorted_checkpoints]}")
        print(f" 期望:[8, 9, 10]")
        latest = manager2.get_latest_epoch_checkpoint()
        print(f"✓ 最新的 checkpoint: epoch {latest['epoch']}")
        print(f" 期望epoch 10")

        assert [cp["epoch"] for cp in sorted_checkpoints] == [8, 9, 10], (
            "应该保留最后 3 个 epoch"
        )
        assert latest["epoch"] == 10, "最新的应该是 epoch 10"
        # Resuming at slot 2, epochs 6-10 land in slots 2, 0, 1, 2, 0, so epoch 10
        # must be in slot 0. A fresh (non-resumed) start would put it in slot 1.
        assert latest["file"] == "epoch_checkpoint_1.pt", (
            "epoch 10 应该写入 epoch_checkpoint_1.pt"
        )
    print("\n" + "=" * 70)
    print("✓✓✓ 自动恢复测试通过!")
    print("=" * 70)
if __name__ == "__main__":
    # Run both scenario tests in order; each prints its own report.
    for run_scenario in (test_circular_save, test_auto_resume):
        run_scenario()

    # Closing summary of the checkpoint scheme being exercised.
    divider = "=" * 70
    summary_lines = (
        "\n" + divider,
        "所有测试完成!",
        divider,
        "\n功能总结:",
        " 1. ✓ 循环覆盖保存,只保留最后 3 个 epoch checkpoint",
        " 2. ✓ 使用 JSON 文件记录 checkpoint 元数据",
        " 3. ✓ 支持从 JSON 自动恢复 checkpoint 顺序",
        " 4. ✓ 可以正确识别最新的 checkpoint",
        "\n保留的文件:",
        " - best_model.pt (最佳模型)",
        " - latest_checkpoint.pt (定期保存)",
        " - epoch_checkpoint_1.pt, epoch_checkpoint_2.pt, epoch_checkpoint_3.pt",
        " - epoch_checkpoints.json (元数据)",
        "\n移除的文件:",
        " - final_model.pt (与 epoch_last 重复)",
        " - epoch_*.pt (不再为每个 epoch 创建独立文件)",
        divider,
    )
    for line in summary_lines:
        print(line)