diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py
index a605f8c..43930cb 100644
--- a/src/suinput/dataset.py
+++ b/src/suinput/dataset.py
@@ -620,6 +620,7 @@ def custom_collate(batch):
         "hint": {
             "input_ids": torch.cat([h["input_ids"] for h in hints]),
             "attention_mask": torch.cat([h["attention_mask"] for h in hints]),
+            "token_type_ids": torch.cat([h["token_type_ids"] for h in hints]),
         },
         "char_id": torch.cat([item["char_id"] for item in batch]),
         "pg": torch.cat([item["pg"] for item in batch]),
diff --git a/src/trainer/model.py b/src/trainer/model.py
index a435e7c..49fb424 100644
--- a/src/trainer/model.py
+++ b/src/trainer/model.py
@@ -522,7 +522,8 @@ class MoEModel(nn.Module):
                 logger.info(
                     f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
                 )
-                logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}")
+            else:
+                logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}")
             batch_loss_sum = 0.0
             if processed_batches >= stop_batch:
                 break