fix(trainer): 优化检查点保存逻辑避免重复写入和进度条重置问题
This commit is contained in:
parent
569eeb4d12
commit
919d0972e2
|
|
@ -355,12 +355,18 @@ class Trainer:
|
|||
}
|
||||
|
||||
torch.save(checkpoint, checkpoint_path)
|
||||
logger.info(f"Checkpoint saved to {checkpoint_path}")
|
||||
|
||||
if is_best:
|
||||
best_path = self.checkpoint_dir / "best_model.pt"
|
||||
# 如果已经保存到best_model.pt,则不再重复保存
|
||||
if checkpoint_path != best_path:
|
||||
logger.info(f"Checkpoint saved to {checkpoint_path}")
|
||||
torch.save(checkpoint, best_path)
|
||||
logger.info(f"Best model saved to {best_path}")
|
||||
else:
|
||||
logger.info(f"Best model saved to {checkpoint_path}")
|
||||
else:
|
||||
logger.info(f"Checkpoint saved to {checkpoint_path}")
|
||||
|
||||
def load_checkpoint(
|
||||
self, checkpoint_path: Union[str, Path], reset_training_state: bool = False
|
||||
|
|
@ -667,8 +673,7 @@ class Trainer:
|
|||
progress.update(epoch_task, completed=self.total_steps)
|
||||
break
|
||||
|
||||
# 重置进度条
|
||||
progress.reset(epoch_task)
|
||||
# 进度条不重置,显示整体训练进度
|
||||
|
||||
# 每个epoch结束后保存检查点
|
||||
self.save_checkpoint(f"epoch_{epoch + 1}.pt")
|
||||
|
|
@ -1797,3 +1802,4 @@ def expand_finetune(
|
|||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
||||
|
|
|
|||
32
test.py
32
test.py
|
|
@ -2,6 +2,7 @@ import sys
|
|||
|
||||
sys.path.append("src")
|
||||
|
||||
import time
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
|
@ -95,7 +96,9 @@ for k, v in sample.items():
|
|||
if isinstance(v, str):
|
||||
print(f"{k}: {v}")
|
||||
|
||||
start = time.time()
|
||||
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
||||
print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
|
||||
sort_res = sorted(
|
||||
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
|
||||
)
|
||||
|
|
@ -125,32 +128,3 @@ for i in range(20):
|
|||
print(
|
||||
f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
|
||||
)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试 history_slot_ids 全零情况")
|
||||
print("=" * 60)
|
||||
|
||||
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
|
||||
history_slot_ids = torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0)
|
||||
sample["history_slot_ids"] = history_slot_ids
|
||||
|
||||
res_zero = model(
|
||||
input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids
|
||||
)
|
||||
probs_zero = F.softmax(res_zero, dim=-1)
|
||||
|
||||
print(f"\n📊 概率分布分析 (全零历史):")
|
||||
print(f" 形状: {probs_zero.shape}")
|
||||
print(f" 总概率和: {probs_zero.sum().item():.6f}")
|
||||
print(f" 最大概率: {probs_zero.max().item():.6f}")
|
||||
print(f" 最小概率: {probs_zero.min().item():.6f}")
|
||||
print(f" 平均概率: {probs_zero.mean().item():.6f}")
|
||||
|
||||
top_probs_zero, top_indices_zero = torch.topk(probs_zero, k=20)
|
||||
print(f"\n🏆 Top-20预测 (全零历史):")
|
||||
for i in range(20):
|
||||
idx = top_indices_zero[0, i].item()
|
||||
prob = top_probs_zero[0, i].item()
|
||||
print(
|
||||
f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}"
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue