diff --git a/src/model/trainer.py b/src/model/trainer.py index 3c9a4f5..c3e2b46 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -355,12 +355,18 @@ class Trainer: } torch.save(checkpoint, checkpoint_path) - logger.info(f"Checkpoint saved to {checkpoint_path}") if is_best: best_path = self.checkpoint_dir / "best_model.pt" - torch.save(checkpoint, best_path) - logger.info(f"Best model saved to {best_path}") + # 如果已经保存到best_model.pt,则不再重复保存 + if checkpoint_path != best_path: + logger.info(f"Checkpoint saved to {checkpoint_path}") + torch.save(checkpoint, best_path) + logger.info(f"Best model saved to {best_path}") + else: + logger.info(f"Best model saved to {checkpoint_path}") + else: + logger.info(f"Checkpoint saved to {checkpoint_path}") def load_checkpoint( self, checkpoint_path: Union[str, Path], reset_training_state: bool = False @@ -667,8 +673,7 @@ class Trainer: progress.update(epoch_task, completed=self.total_steps) break - # 重置进度条 - progress.reset(epoch_task) + # 进度条不重置,显示整体训练进度 # 每个epoch结束后保存检查点 self.save_checkpoint(f"epoch_{epoch + 1}.pt") @@ -1797,3 +1802,4 @@ def expand_finetune( if __name__ == "__main__": app() + diff --git a/test.py b/test.py index a1aebfc..b2ec26f 100644 --- a/test.py +++ b/test.py @@ -2,6 +2,7 @@ import sys sys.path.append("src") +import time import torch from torch.utils.data import DataLoader from tqdm import tqdm @@ -95,7 +96,9 @@ for k, v in sample.items(): if isinstance(v, str): print(f"{k}: {v}") +start = time.time() res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids) +print(f'计算时长: {(time.time() - start) * 1000:4f}ms') sort_res = sorted( [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True ) @@ -125,32 +128,3 @@ for i in range(20): print( f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}" ) - -print("\n" + "=" * 60) -print("测试 history_slot_ids 全零情况") -print("=" * 60) - -masked_labels = [0, 0, 0, 0, 0, 0, 0, 0] -history_slot_ids = torch.tensor(masked_labels, dtype=torch.long).unsqueeze(0) -sample["history_slot_ids"] = history_slot_ids - -res_zero = model( - input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids -) -probs_zero = F.softmax(res_zero, dim=-1) - -print(f"\n📊 概率分布分析 (全零历史):") -print(f" 形状: {probs_zero.shape}") -print(f" 总概率和: {probs_zero.sum().item():.6f}") -print(f" 最大概率: {probs_zero.max().item():.6f}") -print(f" 最小概率: {probs_zero.min().item():.6f}") -print(f" 平均概率: {probs_zero.mean().item():.6f}") - -top_probs_zero, top_indices_zero = torch.topk(probs_zero, k=20) -print(f"\n🏆 Top-20预测 (全零历史):") -for i in range(20): - idx = top_indices_zero[0, i].item() - prob = top_probs_zero[0, i].item() - print( - f" {i + 1:2d}. ID:{idx}\t字符: {query_engine.query_by_id(idx).char}\t概率: {prob:.6f}" - )