From dc718cde5bd8205844f62a3501282daff3dacb81 Mon Sep 17 00:00:00 2001 From: songsenand Date: Thu, 26 Feb 2026 00:58:05 +0800 Subject: [PATCH] =?UTF-8?q?fix(dataset):=20=E6=B7=BB=E5=8A=A0=20token=5Fty?= =?UTF-8?q?pe=5Fids=20=E5=88=B0=20collate=20=E5=87=BD=E6=95=B0=E7=9A=84=20?= =?UTF-8?q?hint=20=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/suinput/dataset.py | 1 + src/trainer/model.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index a605f8c..43930cb 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -620,6 +620,7 @@ def custom_collate(batch): "hint": { "input_ids": torch.cat([h["input_ids"] for h in hints]), "attention_mask": torch.cat([h["attention_mask"] for h in hints]), + "token_type_ids": torch.cat([h["token_type_ids"] for h in hints]), }, "char_id": torch.cat([item["char_id"] for item in batch]), "pg": torch.cat([item["pg"] for item in batch]), diff --git a/src/trainer/model.py b/src/trainer/model.py index a435e7c..49fb424 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -522,7 +522,8 @@ class MoEModel(nn.Module): logger.info( f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}" ) - logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}") + else: + logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}") batch_loss_sum = 0.0 if processed_batches >= stop_batch: break