fix(dataset): 添加 token_type_ids 到 collate 函数的 hint 字段
This commit is contained in:
parent
7c90633ebc
commit
dc718cde5b
|
|
@ -620,6 +620,7 @@ def custom_collate(batch):
|
|||
"hint": {
|
||||
"input_ids": torch.cat([h["input_ids"] for h in hints]),
|
||||
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
|
||||
"token_type_ids": torch.cat([h["token_type_ids"] for h in hints]),
|
||||
},
|
||||
"char_id": torch.cat([item["char_id"] for item in batch]),
|
||||
"pg": torch.cat([item["pg"] for item in batch]),
|
||||
|
|
|
|||
|
|
@ -522,7 +522,8 @@ class MoEModel(nn.Module):
|
|||
logger.info(
|
||||
f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
|
||||
)
|
||||
logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}")
|
||||
else:
|
||||
logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}")
|
||||
batch_loss_sum = 0.0
|
||||
if processed_batches >= stop_batch:
|
||||
break
|
||||
|
|
|
|||
Loading…
Reference in New Issue