fix(dataset): 添加 token_type_ids 到 collate 函数的 hint 字段

This commit is contained in:
songsenand 2026-02-26 00:58:05 +08:00
parent 7c90633ebc
commit dc718cde5b
2 changed files with 3 additions and 1 deletion

View File

@@ -620,6 +620,7 @@ def custom_collate(batch):
"hint": { "hint": {
"input_ids": torch.cat([h["input_ids"] for h in hints]), "input_ids": torch.cat([h["input_ids"] for h in hints]),
"attention_mask": torch.cat([h["attention_mask"] for h in hints]), "attention_mask": torch.cat([h["attention_mask"] for h in hints]),
"token_type_ids": torch.cat([h["token_type_ids"] for h in hints]),
}, },
"char_id": torch.cat([item["char_id"] for item in batch]), "char_id": torch.cat([item["char_id"] for item in batch]),
"pg": torch.cat([item["pg"] for item in batch]), "pg": torch.cat([item["pg"] for item in batch]),

View File

@@ -522,6 +522,7 @@ class MoEModel(nn.Module):
logger.info( logger.info(
f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}" f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
) )
else:
logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}") logger.info(f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}")
batch_loss_sum = 0.0 batch_loss_sum = 0.0
if processed_batches >= stop_batch: if processed_batches >= stop_batch: