fix(trainer): 优化检查点保存逻辑避免重复写入和进度条重置问题

This commit is contained in:
songsenand 2026-04-09 22:46:05 +08:00
parent 569eeb4d12
commit 919d0972e2
2 changed files with 14 additions and 34 deletions

View File

@ -355,12 +355,18 @@ class Trainer:
}
torch.save(checkpoint, checkpoint_path)
logger.info(f"Checkpoint saved to {checkpoint_path}")
if is_best:
best_path = self.checkpoint_dir / "best_model.pt"
# 如果已经保存到best_model.pt则不再重复保存
if checkpoint_path != best_path:
logger.info(f"Checkpoint saved to {checkpoint_path}")
torch.save(checkpoint, best_path)
logger.info(f"Best model saved to {best_path}")
else:
logger.info(f"Best model saved to {checkpoint_path}")
else:
logger.info(f"Checkpoint saved to {checkpoint_path}")
def load_checkpoint(
self, checkpoint_path: Union[str, Path], reset_training_state: bool = False
@ -667,8 +673,7 @@ class Trainer:
progress.update(epoch_task, completed=self.total_steps)
break
# 重置进度条
progress.reset(epoch_task)
# 进度条不重置,显示整体训练进度
# 每个epoch结束后保存检查点
self.save_checkpoint(f"epoch_{epoch + 1}.pt")
@ -1797,3 +1802,4 @@ def expand_finetune(
if __name__ == "__main__":
app()

32
test.py
View File

@ -2,6 +2,7 @@ import sys
sys.path.append("src")
import time
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
@ -95,7 +96,9 @@ for k, v in sample.items():
if isinstance(v, str):
print(f"{k}: {v}")
start = time.time()
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
sort_res = sorted(
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
)
@ -125,32 +128,3 @@ for i in range(20):
print(
f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
)
print("\n" + "=" * 60)
print("测试 history_slot_ids 全零情况")
print("=" * 60)
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
history_slot_ids = torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0)
sample["history_slot_ids"] = history_slot_ids
res_zero = model(
input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids
)
probs_zero = F.softmax(res_zero, dim=-1)
print(f"\n📊 概率分布分析 (全零历史):")
print(f" 形状: {probs_zero.shape}")
print(f" 总概率和: {probs_zero.sum().item():.6f}")
print(f" 最大概率: {probs_zero.max().item():.6f}")
print(f" 最小概率: {probs_zero.min().item():.6f}")
print(f" 平均概率: {probs_zero.mean().item():.6f}")
top_probs_zero, top_indices_zero = torch.topk(probs_zero, k=20)
print(f"\n🏆 Top-20预测 (全零历史):")
for i in range(20):
idx = top_indices_zero[0, i].item()
prob = top_probs_zero[0, i].item()
print(
f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
)