Compare commits

..

No commits in common. "d2d65c7efabe49573339a3790839311ab2b4f392" and "7eb00c6207b7fab71b8d4ce2c6715085a59134a3" have entirely different histories.

4 changed files with 63 additions and 755 deletions

View File

@ -1,6 +1,6 @@
import os import os
import random import random
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Tuple, Optional
import numpy as np import numpy as np
import torch import torch
@ -97,28 +97,28 @@ class PinyinInputDataset(IterableDataset):
# 加载拼音分组 # 加载拼音分组
self.pg_groups = { self.pg_groups = {
"y": 0, "y": 0,
"z": 1, "k": 0,
"j": 2, "e": 0,
"l": 3, "l": 1,
"s": 4, "w": 1,
"x": 5, "f": 1,
"c": 6, "q": 2,
"a": 2,
"s": 2,
"x": 3,
"b": 3,
"r": 3,
"o": 4,
"m": 4,
"z": 4,
"g": 5,
"n": 5,
"c": 5,
"t": 6,
"p": 6,
"d": 6,
"j": 7,
"h": 7, "h": 7,
"d": 8,
"b": 9,
"q": 10,
"g": 11,
"t": 12,
"m": 13,
"p": 14,
"w": 15,
"f": 16,
"k": 17,
"n": 18,
"r": 19,
"a": 19,
"e": 18,
"o": 17,
} }
def get_next_chinese_chars( def get_next_chinese_chars(
@ -440,7 +440,9 @@ class PinyinInputDataset(IterableDataset):
"char_id": torch.tensor([char_info["id"]]), "char_id": torch.tensor([char_info["id"]]),
"char": char, "char": char,
"freq": char_info["freq"], "freq": char_info["freq"],
"pg": torch.tensor([self.pg_groups[char_info.pinyin[0]]]), "pg": torch.tensor(
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
),
} }
# 根据调整因子重复样本 # 根据调整因子重复样本

View File

@ -1,14 +1,14 @@
from tqdm import tqdm
from loguru import logger
import torch
from torch.utils.data import DataLoader
import pickle import pickle
from pathlib import Path from pathlib import Path
import torch from suinput.dataset import PinyinInputDataset, worker_init_fn, custom_collate_with_txt
from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm
from suinput.dataset import PinyinInputDataset, custom_collate_with_txt, worker_init_fn
from suinput.query import QueryEngine from suinput.query import QueryEngine
# 使用示例 # 使用示例
if __name__ == "__main__": if __name__ == "__main__":
# 初始化查询引擎 # 初始化查询引擎
@ -42,13 +42,7 @@ if __name__ == "__main__":
for i, sample in tqdm(enumerate(dataloader), total=5): for i, sample in tqdm(enumerate(dataloader), total=5):
if i >= total: if i >= total:
break break
# print(sample) print(sample)
pickle.dump( # pickle.dump(sample, open(f"{str(Path(__file__).parent.parent / 'trainer' / 'eval_dataset')}/sample_{i}.pkl", "wb"))
sample,
open(
f"{str(Path(__file__).parent.parent / 'trainer' / 'eval_dataset')}/sample_{i}.pkl",
"wb",
),
)
except StopIteration: except StopIteration:
print("数据集为空") print("数据集为空")

View File

@ -1,7 +1,5 @@
import pickle import pickle
from importlib.resources import files from importlib.resources import files
from pathlib import Path
from typing import Optional, Union
import torch import torch
import torch.amp as amp import torch.amp as amp
@ -14,46 +12,10 @@ from tqdm import tqdm
from .monitor import TrainingMonitor from .monitor import TrainingMonitor
EVAL_DATALOADER = [
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")): pickle.load(file.open("rb"))
return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")] for file in (files(__package__) / "eval_dataset").glob("*.pkl")
]
def round_to_power_of_two(x):
if x < 1:
return 0
n = x.bit_length()
n = min(max(7, n), 9)
lower = 1 << (n) # 小于等于x的最大2的幂次
upper = lower << 1 # 大于x的最小2的幂次
if x - lower < upper - x:
return lower
else:
return upper
EXPORT_HIDE_DIM = {
0: 1024,
1: 1024,
2: 1024,
3: 512,
4: 512,
5: 512,
6: 512,
7: 512,
8: 512,
9: 512,
10: 512,
11: 512,
12: 512,
13: 512,
14: 512,
15: 512,
16: 512,
17: 512,
18: 512,
19: 256,
}
# ---------------------------- 残差块 ---------------------------- # ---------------------------- 残差块 ----------------------------
@ -130,8 +92,8 @@ class MoEModel(nn.Module):
output_multiplier=2, output_multiplier=2,
d_model=768, d_model=768,
num_resblocks=4, num_resblocks=4,
num_domain_experts=20, num_domain_experts=8,
experts_dim=EXPORT_HIDE_DIM, num_shared_experts=1,
): ):
super().__init__() super().__init__()
self.output_multiplier = output_multiplier self.output_multiplier = output_multiplier
@ -142,35 +104,38 @@ class MoEModel(nn.Module):
self.bert_config = bert.config self.bert_config = bert.config
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度 self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
self.device = None # 将在 to() 调用时设置 self.device = None # 将在 to() 调用时设置
self.experts_dim = experts_dim
# 2. 4 层标准 Transformer Encoder从 config 读取参数) # 2. 4 层标准 Transformer Encoder从 config 读取参数)
encoder_layer = nn.TransformerEncoderLayer( encoder_layer = nn.TransformerEncoderLayer(
d_model=self.hidden_size, d_model=self.hidden_size,
nhead=8, nhead=self.bert_config.num_attention_heads,
dim_feedforward=self.bert_config.intermediate_size, dim_feedforward=self.bert_config.intermediate_size,
dropout=self.bert_config.hidden_dropout_prob, dropout=self.bert_config.hidden_dropout_prob,
activation="gelu", activation="gelu",
batch_first=True, batch_first=True,
norm_first=True, # Pre-LN与预训练一致
) )
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
self.pooler = nn.AdaptiveAvgPool1d(1) self.pooler = nn.AdaptiveAvgPool1d(1)
self.total_experts = 20 # 3. 专家层8个领域专家 + 1个共享专家
total_experts = num_domain_experts + num_shared_experts
self.experts = nn.ModuleList() self.experts = nn.ModuleList()
for i in range(self.total_experts): for i in range(total_experts):
# 领域专家 dropout=0.1,共享专家 dropout=0.2(您指定的更强正则)
dropout_prob = 0.1 if i < num_domain_experts else 0.2
expert = Expert( expert = Expert(
input_dim=self.hidden_size, input_dim=self.hidden_size,
d_model=self.experts_dim[i], d_model=d_model,
num_resblocks=num_resblocks, num_resblocks=num_resblocks,
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size
dropout_prob=0.1, dropout_prob=dropout_prob,
) )
self.experts.append(expert) self.experts.append(expert)
self.expert_bias = nn.Embedding( self.expert_bias = nn.Embedding(
self.total_experts, self.output_multiplier * self.hidden_size total_experts, self.output_multiplier * self.hidden_size
) )
# 4. 分类头 # 4. 分类头
@ -181,6 +146,11 @@ class MoEModel(nn.Module):
self.output_multiplier * self.hidden_size, self.output_multiplier * self.hidden_size,
), ),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Linear(
self.output_multiplier * self.hidden_size,
self.output_multiplier * self.hidden_size,
),
nn.ReLU(inplace=True),
nn.Linear( nn.Linear(
self.output_multiplier * self.hidden_size, self.output_multiplier * self.hidden_size,
self.output_multiplier * self.hidden_size * 2, self.output_multiplier * self.hidden_size * 2,
@ -253,54 +223,11 @@ class MoEModel(nn.Module):
expert_out = self.experts[7](pooled) + self.expert_bias( expert_out = self.experts[7](pooled) + self.expert_bias(
torch.tensor(7, device=pooled.device) torch.tensor(7, device=pooled.device)
) )
elif group_id == 8: # group_id == 8 else: # group_id == 8
expert_out = self.experts[8](pooled) + self.expert_bias( expert_out = self.experts[8](pooled) + self.expert_bias(
torch.tensor(8, device=pooled.device) torch.tensor(8, device=pooled.device)
) )
elif group_id == 9: # group_id == 9
expert_out = self.experts[9](pooled) + self.expert_bias(
torch.tensor(9, device=pooled.device)
)
elif group_id == 10: # group_id == 10
expert_out = self.experts[10](pooled) + self.expert_bias(
torch.tensor(10, device=pooled.device)
)
elif group_id == 11: # group_id == 11
expert_out = self.experts[11](pooled) + self.expert_bias(
torch.tensor(11, device=pooled.device)
)
elif group_id == 12: # group_id == 12
expert_out = self.experts[12](pooled) + self.expert_bias(
torch.tensor(12, device=pooled.device)
)
elif group_id == 13: # group_id == 13
expert_out = self.experts[13](pooled) + self.expert_bias(
torch.tensor(13, device=pooled.device)
)
elif group_id == 14: # group_id == 14
expert_out = self.experts[14](pooled) + self.expert_bias(
torch.tensor(14, device=pooled.device)
)
elif group_id == 15: # group_id == 15
expert_out = self.experts[15](pooled) + self.expert_bias(
torch.tensor(15, device=pooled.device)
)
elif group_id == 16: # group_id == 16
expert_out = self.experts[16](pooled) + self.expert_bias(
torch.tensor(16, device=pooled.device)
)
elif group_id == 17: # group_id == 17
expert_out = self.experts[17](pooled) + self.expert_bias(
torch.tensor(17, device=pooled.device)
)
elif group_id == 18: # group_id == 18
expert_out = self.experts[18](pooled) + self.expert_bias(
torch.tensor(18, device=pooled.device)
)
else: # group_id == 19
expert_out = self.experts[19](pooled) + self.expert_bias(
torch.tensor(19, device=pooled.device)
)
else: else:
batch_size = pooled.size(0) batch_size = pooled.size(0)
# 并行计算所有专家输出 # 并行计算所有专家输出
@ -460,7 +387,7 @@ class MoEModel(nn.Module):
self, self,
train_dataloader, train_dataloader,
eval_dataloader=None, eval_dataloader=None,
monitor: Optional[TrainingMonitor] = None, monitor: TrainingMonitor = None,
criterion=nn.CrossEntropyLoss(), criterion=nn.CrossEntropyLoss(),
optimizer=None, optimizer=None,
num_epochs=1, num_epochs=1,
@ -475,31 +402,11 @@ class MoEModel(nn.Module):
训练模型支持混合精度梯度累积学习率调度实时监控 训练模型支持混合精度梯度累积学习率调度实时监控
参数 参数
train_dataloader: DataLoader ... 原有参数 ...
训练数据加载器
eval_dataloader: DataLoader, optional
评估数据加载器
monitor: TrainingMonitor, optional
训练监控器
criterion: nn.Module, optional
损失函数
optimizer: optim.Optimizer, optional
优化器
num_epochs: int, optional
训练轮数
eval_frequency: int, optional
评估频率
grad_accum_steps: int, optional
梯度累积步数
clip_grad_norm: float, optional
梯度裁剪范数
mixed_precision: bool, optional
是否使用混合精度
lr: float, optional
初始学习率
lr_schedule : callable, optional lr_schedule : callable, optional
自定义学习率调度函数接收参数 (processed_batches, optimizer) 自定义学习率调度函数接收参数 (processed_batches, optimizer)
可在内部直接修改 optimizer.param_groups 中的学习率 可在内部直接修改 optimizer.param_groups 中的学习率
若为 None则启用内置的固定阈值调度前1000批 1e-4之后 6e-6
""" """
# 确保模型在正确的设备上 # 确保模型在正确的设备上
if self.device is None: if self.device is None:
@ -564,7 +471,7 @@ class MoEModel(nn.Module):
and global_step % eval_frequency == 0 and global_step % eval_frequency == 0
): ):
avg_loss = batch_loss_sum / eval_frequency avg_loss = batch_loss_sum / eval_frequency
acc, eval_loss = self.model_eval(eval_dataloader, criterion) acc, _ = self.model_eval(eval_dataloader, criterion)
super().train() super().train()
if monitor is not None: if monitor is not None:
monitor.add_step( monitor.add_step(
@ -572,7 +479,7 @@ class MoEModel(nn.Module):
{"loss": avg_loss, "acc": acc}, {"loss": avg_loss, "acc": acc},
) )
logger.info( logger.info(
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}" f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}"
) )
batch_loss_sum = 0.0 batch_loss_sum = 0.0

View File

@ -1,595 +0,0 @@
import pickle
from importlib.resources import files
from pathlib import Path
from typing import Optional, Union
import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
from modelscope import AutoModel
from tqdm import tqdm
from .monitor import TrainingMonitor
def eval_dataloader(path: Optional[Union[str, Path]] = None):
    """Load every pickled evaluation batch found directly inside *path*.

    Args:
        path: Directory containing ``*.pkl`` files. When ``None`` (the
            default) the ``eval_dataset`` directory inside this package is
            used; it is resolved lazily so importing the module does not
            touch package resources.

    Returns:
        A list with one unpickled object per ``*.pkl`` file, ordered by file
        path so repeated calls yield the same order. Empty if no file matches.
    """
    if path is None:
        path = files(__package__) / "eval_dataset"

    samples = []
    for pkl_path in sorted(Path(path).glob("*.pkl")):
        # NOTE(security): pickle executes arbitrary code on load; only point
        # this at directories whose contents are trusted.
        with pkl_path.open("rb") as handle:  # close the handle deterministically
            samples.append(pickle.load(handle))
    return samples
# ---------------------------- 残差块 ----------------------------
class ResidualBlock(nn.Module):
    """Two-layer feed-forward block with layer norms and a skip connection.

    The block keeps the feature width unchanged: input and output both have
    ``dim`` features on the last axis.
    """

    def __init__(self, dim, dropout_prob=0.1):
        """Create the block.

        Args:
            dim: Feature width of the input and of both linear layers.
            dropout_prob: Dropout probability applied just before the skip
                connection is added.
        """
        super().__init__()
        # First sub-layer: linear -> ReLU -> LayerNorm.
        self.linear1 = nn.Linear(dim, dim)
        self.ln1 = nn.LayerNorm(dim)
        # Second sub-layer: linear -> LayerNorm -> dropout.
        self.linear2 = nn.Linear(dim, dim)
        self.ln2 = nn.LayerNorm(dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Return ``relu(x + dropout(ln2(linear2(ln1(relu(linear1(x)))))))``."""
        hidden = self.ln1(self.relu(self.linear1(x)))
        hidden = self.dropout(self.ln2(self.linear2(hidden)))
        # Dropout is applied to the branch only; the skip path is untouched.
        return self.relu(hidden + x)
# ---------------------------- 专家网络 ----------------------------
class Expert(nn.Module):
    """Single expert network: input projection, residual stack, output MLP."""

    def __init__(
        self,
        input_dim,
        d_model=1024,
        num_resblocks=4,
        output_multiplier=2,
        dropout_prob=0.1,
    ):
        """Build the expert.

        Args:
            input_dim: Width of the incoming feature vector.
            d_model: Internal width of the expert (default 1024).
            num_resblocks: Number of stacked ``ResidualBlock`` modules.
            output_multiplier: The output width is
                ``input_dim * output_multiplier``.
            dropout_prob: Dropout probability used inside each residual block.
        """
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.output_dim = input_dim * output_multiplier

        # Project the input up/down to the expert's internal width.
        self.linear_in = nn.Linear(input_dim, d_model)

        # Stack of width-preserving residual blocks.
        self.res_blocks = nn.ModuleList(
            ResidualBlock(d_model, dropout_prob) for _ in range(num_resblocks)
        )

        # Two-layer head mapping the internal width to the output width.
        head_hidden = nn.Linear(d_model, d_model)
        head_out = nn.Linear(d_model, self.output_dim)
        self.output = nn.Sequential(head_hidden, nn.ReLU(inplace=True), head_out)

    def forward(self, x):
        """Map ``x`` (last axis ``input_dim``) to a tensor with last axis ``output_dim``."""
        hidden = self.linear_in(x)
        for res_block in self.res_blocks:
            hidden = res_block(hidden)
        return self.output(hidden)
# ---------------------------- 主模型MoE + 硬路由)------------------------
class MoEModel(nn.Module):
    """Hard-routed mixture-of-experts classifier on pretrained BERT embeddings.

    Pipeline: pretrained embedding layer -> 4-layer Transformer encoder ->
    average pooling over the sequence axis -> one expert selected by the
    group id ``pg`` (plus a learned per-expert bias) -> MLP classifier.
    """

    def __init__(
        self,
        pretrained_model_name="iic/nlp_structbert_backbone_tiny_std",
        num_classes=10018,
        output_multiplier=2,
        d_model=768,
        num_resblocks=4,
        num_domain_experts=23,
    ):
        """Build the model.

        Args:
            pretrained_model_name: Model id passed to ``AutoModel.from_pretrained``;
                only its embedding layer and config are kept.
            num_classes: Number of output classes.
            output_multiplier: Expert output width is
                ``output_multiplier * hidden_size``.
            d_model: Internal width of every expert.
            num_resblocks: Residual blocks per expert.
            num_domain_experts: Number of experts, i.e. number of valid
                ``pg`` group ids.
        """
        super().__init__()
        self.output_multiplier = output_multiplier

        # 1. Pretrained BERT: keep only the embeddings and the config.
        bert = AutoModel.from_pretrained(pretrained_model_name)
        self.embedding = bert.embeddings
        self.bert_config = bert.config
        self.hidden_size = self.bert_config.hidden_size
        self.device = None  # recorded by to() / fit()
        # Not referenced anywhere else in this class; kept so the set of
        # registered parameters is unchanged.
        self.linear = nn.Linear(256, d_model)

        # 2. Four Transformer encoder layers configured from the BERT config.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=self.bert_config.num_attention_heads,
            dim_feedforward=self.bert_config.intermediate_size,
            dropout=self.bert_config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
            norm_first=True,  # pre-LayerNorm
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.pooler = nn.AdaptiveAvgPool1d(1)

        # 3. Experts. The expert count follows the constructor argument
        # instead of a hard-coded constant.
        self.total_experts = num_domain_experts
        feature_dim = self.output_multiplier * self.hidden_size
        self.experts = nn.ModuleList()
        for _ in range(self.total_experts):
            self.experts.append(
                Expert(
                    # The pooled encoder output has hidden_size features, and
                    # the expert output (input_dim * output_multiplier) must
                    # equal feature_dim used by the bias and the classifier.
                    input_dim=self.hidden_size,
                    d_model=d_model,
                    num_resblocks=num_resblocks,
                    output_multiplier=self.output_multiplier,
                    dropout_prob=0.1,
                )
            )
        self.expert_bias = nn.Embedding(self.total_experts, feature_dim)

        # 4. Classification head.
        self.classifier = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(feature_dim * 2, num_classes),
        )

    def to(self, device):
        """Move the model to *device* and remember it in ``self.device``.

        A device given as a string is converted to ``torch.device`` so that
        later code can rely on attributes such as ``self.device.type``.
        """
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device
        return super().to(device)

    def _logits(self, input_ids, attention_mask, pg):
        """Compute raw (pre-softmax) class scores; see ``forward`` for arguments."""
        # 1. Embeddings: [B, S, H].
        embeddings = self.embedding(input_ids)

        # 2. Transformer encoder; True in the mask means "ignore this position".
        padding_mask = attention_mask == 0
        encoded = self.encoder(embeddings, src_key_padding_mask=padding_mask)

        # 3. Average over the sequence axis -> [B, H]. The pooler averages all
        # positions of the encoder output, padded ones included.
        pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)

        # 4. Hard routing to one expert per sample.
        if torch.jit.is_tracing():
            # Export path: pg is a single group id, converted to a Python int
            # so that only the selected expert is evaluated.
            group_id = int(pg.item()) if torch.is_tensor(pg) else int(pg)
            if not 0 <= group_id < len(self.experts):
                raise ValueError(
                    f"group id {group_id} out of range [0, {len(self.experts)})"
                )
            expert_out = self.experts[group_id](pooled) + self.expert_bias(
                torch.tensor(group_id, device=pooled.device)
            )
        else:
            batch_size = pooled.size(0)
            # Evaluate every expert, then pick the routed one per sample.
            expert_outputs = torch.stack(
                [expert(pooled) for expert in self.experts], dim=0
            )  # [E, B, D]
            sample_index = torch.arange(batch_size, device=pooled.device)
            expert_out = expert_outputs[pg, sample_index]  # [B, D]
            expert_out = expert_out + self.expert_bias(pg)

        # 5. Classification head -> [B, num_classes].
        return self.classifier(expert_out)

    def forward(self, input_ids, attention_mask, pg):
        """Run the model.

        Args:
            input_ids: ``[batch, seq_len]`` token ids.
            attention_mask: ``[batch, seq_len]``; 1 marks valid tokens, 0 padding.
            pg: Group ids. A ``[batch]`` LongTensor normally; a scalar tensor
                (or int) while tracing for export.

        Returns:
            Raw logits in training mode, softmax probabilities otherwise;
            shape ``[batch, num_classes]``.

        Raises:
            ValueError: While tracing, if the group id is not a valid expert index.
        """
        logits = self._logits(input_ids, attention_mask, pg)
        if not self.training:
            return torch.softmax(logits, dim=-1)
        return logits

    def model_eval(self, eval_dataloader, criterion=None):
        """Evaluate on a validation set.

        Args:
            eval_dataloader: Iterable of batches providing
                ``batch["hint"]["input_ids"]``,
                ``batch["hint"]["attention_mask"]``, ``batch["pg"]`` and
                ``batch["char_id"]``.
            criterion: Loss function; defaults to ``nn.CrossEntropyLoss()``.

        Returns:
            ``(accuracy, avg_loss)`` as floats; both are 0.0 for an empty loader.
        """
        if criterion is None:
            criterion = nn.CrossEntropyLoss()

        self.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in eval_dataloader:
                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)

                # Use raw logits: in eval mode forward() returns softmax
                # probabilities, and feeding those to a logit-based criterion
                # would apply softmax twice and distort the loss.
                logits = self._logits(input_ids, attention_mask, pg)
                loss = criterion(logits, labels)
                total_loss += loss.item() * labels.size(0)

                preds = logits.argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        avg_loss = total_loss / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0
        return accuracy, avg_loss

    def predict(self, sample, debug=False):
        """Predict class ids for a single sample or a batch.

        Args:
            sample: Dict with ``'input_ids'`` (``[batch, seq_len]`` or
                ``[seq_len]``), ``'attention_mask'`` (same shape) and ``'pg'``
                (``[batch]`` or scalar). With ``debug=True`` it must also hold
                ``'char_id'`` (true labels), ``'txt'``, ``'char'`` and ``'py'``
                (each a list of strings or a single string).
            debug: When True, print details of every mispredicted sample.

        Returns:
            Predicted class ids, ``[batch]``; a scalar tensor when the input
            had no batch dimension.

        Raises:
            ValueError: If ``debug`` is True and a required key is missing.
        """
        self.eval()

        # 1. Normalise inputs to batched form.
        input_ids = sample["input_ids"]
        attention_mask = sample["attention_mask"]
        pg = sample["pg"]

        has_batch_dim = input_ids.dim() > 1
        if not has_batch_dim:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        if pg.dim() == 0:
            pg = pg.unsqueeze(0).expand(input_ids.size(0))

        # 2. Move to the model device.
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        pg = pg.to(self.device)

        # 3. Inference. forward() already returns probabilities in eval mode,
        # so the argmax is taken directly.
        with torch.no_grad():
            probs = self(input_ids, attention_mask, pg)
        preds = probs.argmax(dim=-1)  # [batch]

        # 4. Optional report of mispredicted samples.
        if debug:
            required_keys = ["char_id", "txt", "char", "py"]
            missing = [k for k in required_keys if k not in sample]
            if missing:
                raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}")

            true_labels = sample["char_id"]
            if true_labels.dim() == 0:
                true_labels = true_labels.unsqueeze(0)
            true_labels = true_labels.to(self.device)

            incorrect_indices = torch.where(preds != true_labels)[0]
            if len(incorrect_indices) > 0:
                print("\n=== 预测错误样本 ===")
                txts = sample["txt"]
                chars = sample["char"]
                pys = sample["py"]
                # A single string means a single sample; wrap for indexing.
                if isinstance(txts, str):
                    txts = [txts]
                    chars = [chars]
                    pys = [pys]
                for idx in incorrect_indices.cpu().tolist():
                    print(f"样本索引 {idx}:")
                    print(f"  Text  : {txts[idx]}")
                    print(f"  Char  : {chars[idx]}")
                    print(f"  Pinyin: {pys[idx]}")
                    print(
                        f"  预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}"
                    )
                print("===================\n")

        # 5. Mirror the input's dimensionality.
        if not has_batch_dim:
            return preds.squeeze(0)
        return preds

    def fit(
        self,
        train_dataloader,
        eval_dataloader=None,
        monitor: Optional["TrainingMonitor"] = None,
        criterion=nn.CrossEntropyLoss(),
        optimizer=None,
        num_epochs=1,
        eval_frequency=500,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        mixed_precision=False,
        lr=1e-4,
        lr_schedule=None,
    ):
        """Train the model with optional AMP, gradient accumulation and monitoring.

        Args:
            train_dataloader: Iterable of training batches (same layout as in
                ``model_eval``).
            eval_dataloader: Optional evaluation batches; enables periodic
                evaluation.
            monitor: Optional ``TrainingMonitor`` receiving ``add_step`` calls.
            criterion: Loss function.
            optimizer: Optimizer; defaults to ``AdamW(self.parameters(), lr=lr)``.
            num_epochs: Number of passes over ``train_dataloader``.
            eval_frequency: Evaluate every this many optimizer steps.
            grad_accum_steps: Batches accumulated per optimizer step.
            clip_grad_norm: Max gradient norm used for clipping.
            mixed_precision: Enable autocast and gradient scaling.
            lr: Learning rate of the default optimizer.
            lr_schedule: Optional callable ``(processed_batches, optimizer)``
                invoked before every batch; it may edit
                ``optimizer.param_groups`` in place.
        """
        # Make sure the model sits on a known device.
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

        super().train()

        if optimizer is None:
            optimizer = optim.AdamW(self.parameters(), lr=lr)

        scaler = amp.GradScaler(enabled=mixed_precision)

        global_step = 0  # optimizer steps taken
        processed_batches = 0  # batches seen
        batch_loss_sum = 0.0  # un-scaled loss summed since the last evaluation
        batches_in_window = 0  # batches contributing to batch_loss_sum
        optimizer.zero_grad()

        for epoch in range(num_epochs):
            for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)):
                processed_batches += 1

                # User-supplied learning-rate schedule, if any.
                if lr_schedule is not None:
                    lr_schedule(processed_batches, optimizer)

                input_ids = batch["hint"]["input_ids"].to(self.device)
                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                pg = batch["pg"].to(self.device)
                labels = batch["char_id"].to(self.device)

                with amp.autocast(
                    device_type=self.device.type, enabled=mixed_precision
                ):
                    logits = self(input_ids, attention_mask, pg)
                    loss = criterion(logits, labels)
                    # Scale so accumulated gradients average over the window.
                    loss = loss / grad_accum_steps

                scaler.scale(loss).backward()

                stepped = (batch_idx + 1) % grad_accum_steps == 0
                if stepped:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    global_step += 1

                # Undo the accumulation scaling for reporting.
                batch_loss_sum += loss.item() * grad_accum_steps
                batches_in_window += 1

                # Evaluate only right after an optimizer step, so a step count
                # that is a multiple of eval_frequency is evaluated once rather
                # than on every accumulation micro-batch.
                if (
                    stepped
                    and eval_dataloader is not None
                    and global_step % eval_frequency == 0
                ):
                    # Average over the batches actually summed; this equals
                    # eval_frequency when grad_accum_steps == 1.
                    avg_loss = batch_loss_sum / batches_in_window
                    acc, eval_loss = self.model_eval(eval_dataloader, criterion)
                    super().train()  # model_eval switched to eval mode
                    if monitor is not None:
                        monitor.add_step(
                            global_step,
                            {"loss": avg_loss, "acc": acc},
                        )
                    logger.info(
                        f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
                    )
                    batch_loss_sum = 0.0
                    batches_in_window = 0
# ============================ 使用示例 ============================
if __name__ == "__main__":
    # Demo: export the model to ONNX with batch size 1, then run one
    # training-mode forward pass with batch size 4.

    # 1. Build the model and switch to inference mode for the export.
    model = MoEModel()
    model.eval()

    # 2. Dummy export inputs: one sequence of 64 tokens and a scalar group id.
    export_ids = torch.randint(0, 100, (1, 64))
    export_mask = torch.ones_like(export_ids)
    export_group = torch.tensor(3, dtype=torch.long)
    export_args = (export_ids, export_mask, export_group)

    # 3. ONNX export; while tracing, only the selected expert is evaluated.
    batch_axis = {0: "batch"}
    torch.onnx.export(
        model,
        export_args,
        "moe_cpu.onnx",
        input_names=["input_ids", "attention_mask", "pg"],
        output_names=["logits"],
        dynamic_axes={"input_ids": batch_axis, "attention_mask": batch_axis},
        opset_version=12,
        do_constant_folding=True,
    )
    print("ONNX 导出成功!")

    # 4. Training-mode smoke test with four samples routed to different groups.
    model.train()
    train_ids = torch.randint(0, 100, (4, 64))
    train_mask = torch.ones_like(train_ids)
    train_groups = torch.tensor([0, 3, 8, 1], dtype=torch.long)
    logits = model(train_ids, train_mask, train_groups)
    print("训练模式输出形状:", logits.shape)