feat(model): Add MoEModelWithNeck class and attention-pooling module
commit d88a68e421
parent 4031a668da
@@ -215,3 +215,6 @@ uv.lock
 *.log
 marimo/
 __marimo__/
+
+# ---> Others
+/model/*
@@ -5,6 +5,7 @@ description = "Add your description here"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
+    "amd-quark>=0.1.0",
     "bokeh>=3.8.2",
     "datasets>=4.5.0",
     "ipykernel>=7.2.0",
@@ -13,6 +14,7 @@ dependencies = [
     "modelscope>=1.34.0",
     "msgpack>=1.1.2",
     "numpy>=2.4.2",
+    "onnxruntime>=1.24.2",
     "pandas>=3.0.0",
     "pypinyin>=0.55.0",
     "requests>=2.32.5",
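The new onnxruntime dependency pairs with the ONNX export path that already exists in the model module (see the "ONNX model exported to ..." hunk further down). A minimal loading sketch follows; the file path is an assumption, since this commit does not pin an export location, and the real input names should be read from the exported graph rather than guessed.

# Minimal sketch: open an exported model with onnxruntime and inspect its inputs.
# "model/moe.onnx" is an assumed path, not taken from this commit.
import onnxruntime as ort

session = ort.InferenceSession("model/moe.onnx", providers=["CPUExecutionProvider"])

# Discover the real input names and shapes before building a feed dict.
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)

# outputs = session.run(None, {"input_ids": ..., "attention_mask": ..., ...})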
@@ -437,11 +437,24 @@ class MoEModel(nn.Module):
         warmup_ratio=0.1,
         label_smoothing=0.15,
         lr=1e-4,
+        lr_schedule=None,
     ):
+        def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps):
+            if _processed_batches < _warmup_steps:
+                current_lr = _lr * (_processed_batches / _warmup_steps)
+            else:
+                progress = (_processed_batches - _warmup_steps) / (
+                    _stop_batch - _warmup_steps
+                )
+                current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
+            return current_lr
+
         """Training function; input parameters adjusted."""
         if self.device is None:
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.to(self.device)
+        if lr_schedule is None:
+            lr_schedule = default_lr_schedule
 
         self.train()
 
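The default schedule added above is linear warmup followed by cosine decay to zero. A standalone sketch of the same formula, handy for eyeballing the curve before a run; the step counts and base learning rate below are illustrative only.

import math

def default_lr_schedule(lr, processed_batches, stop_batch, warmup_steps):
    # Linear warmup, then cosine decay from lr down to 0 (same formula as the diff above).
    if processed_batches < warmup_steps:
        return lr * (processed_batches / warmup_steps)
    progress = (processed_batches - warmup_steps) / (stop_batch - warmup_steps)
    return lr * (0.5 * (1.0 + math.cos(math.pi * progress)))

# Illustrative numbers: 1000 total steps, 10% warmup, base lr 1e-4.
for step in (0, 50, 100, 500, 1000):
    print(step, f"{default_lr_schedule(1e-4, step, 1000, 100):.2e}")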
@@ -462,7 +475,6 @@ class MoEModel(nn.Module):
         warmup_steps = int(total_steps * warmup_ratio)
         logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}")
         processed_batches = 0
-        global_step = 0  # initialize
         batch_loss_sum = 0.0
         optimizer.zero_grad()
         try:
@@ -471,13 +483,9 @@ class MoEModel(nn.Module):
                 tqdm(train_dataloader, total=int(stop_batch))
             ):
                 # LR Schedule
-                if processed_batches < warmup_steps:
-                    current_lr = lr * (processed_batches / warmup_steps)
-                else:
-                    progress = (processed_batches - warmup_steps) / (
-                        total_steps - warmup_steps
-                    )
-                    current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress)))
+                current_lr = lr_schedule(
+                    lr, processed_batches, stop_batch, warmup_steps
+                )
                 for param_group in optimizer.param_groups:
                     param_group["lr"] = current_lr
 
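With this hook in place, any callable taking (lr, processed_batches, stop_batch, warmup_steps) can replace the default. A hedged sketch of an alternative, linear warmup followed by a constant rate, which could be passed through the new lr_schedule parameter; the function name is hypothetical.

def warmup_then_constant(lr, processed_batches, stop_batch, warmup_steps):
    # Linear warmup, then hold the base learning rate for the rest of training.
    # stop_batch is unused here but kept so the signature matches the hook.
    if warmup_steps and processed_batches < warmup_steps:
        return lr * (processed_batches / warmup_steps)
    return lr

# e.g. passed in place of the default: lr_schedule=warmup_then_constant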
@@ -509,7 +517,7 @@ class MoEModel(nn.Module):
                         scaler.update()
                         optimizer.zero_grad()
                 batch_loss_sum += loss.item() * grad_accum_steps
-                if global_step % eval_frequency == 0:
+                if processed_batches % eval_frequency == 0:
                     if eval_dataloader:
                         self.eval()
                         acc, eval_loss = self.model_eval(
@@ -519,29 +527,30 @@ class MoEModel(nn.Module):
                         if monitor:
                             # use eval_loss as the monitored metric
                             monitor.add_step(
-                                global_step,
+                                processed_batches,
                                 {
                                     "train_loss": batch_loss_sum
                                     / (
-                                        eval_frequency if global_step > 0 else 1
+                                        eval_frequency
+                                        if processed_batches > 0
+                                        else 1
                                     ),
                                     "acc": acc,
                                     "loss": eval_loss,
-                                    "lr": current_lr
+                                    "lr": current_lr,
                                 },
                             )
                             logger.info(
-                                f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
+                                f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}"
                             )
                         else:
                             logger.info(
-                                f"step: {global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}"
+                                f"step: {processed_batches}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}"
                             )
                     batch_loss_sum = 0.0
                 if processed_batches >= stop_batch:
                     break
                 processed_batches += 1
-                global_step += 1
         except KeyboardInterrupt:
             logger.info("Training interrupted by user")
 
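Metrics are reported through monitor.add_step(step, metrics_dict); the monitor object itself is not part of this diff. A minimal, hypothetical stand-in that satisfies the same call pattern might look like this.

# Hypothetical stand-in for the `monitor` argument; only the add_step(step, metrics)
# call pattern is taken from the diff above, everything else is an assumption.
class DictMonitor:
    def __init__(self):
        self.history = {}  # step -> metrics dict

    def add_step(self, step, metrics):
        self.history[step] = dict(metrics)

monitor = DictMonitor()
monitor.add_step(0, {"train_loss": 0.0, "acc": 0.0, "loss": 0.0, "lr": 1e-4})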
@@ -607,3 +616,122 @@ class MoEModel(nn.Module):
             do_constant_folding=True,
         )
         logger.info(f"ONNX model exported to {output_path}")
+
+
+class MoEModelWithNeck(MoEModel):
+    def __init__(
+        self,
+        pretrained_model_name="iic/nlp_structbert_backbone_lite_std",
+        num_classes=10018,
+        output_multiplier=2,
+        d_model=768,
+        num_resblocks=4,
+        num_domain_experts=12,
+        num_shared_experts=1,
+    ):
+        super().__init__()
+        self.output_multiplier = output_multiplier
+
+        # 1. Load the pretrained embeddings (Lite: hidden_size=512)
+        logger.info(f"Loading backbone: {pretrained_model_name}")
+        bert = AutoModel.from_pretrained(pretrained_model_name)
+        self.embedding = bert.embeddings
+        self.bert_config = bert.config
+        self.hidden_size = self.bert_config.hidden_size  # 512
+        self.device = None
+
+        # 2. Transformer encoder
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=self.hidden_size,
+            nhead=8,
+            dim_feedforward=self.bert_config.intermediate_size,
+            dropout=self.bert_config.hidden_dropout_prob,
+            activation="gelu",
+            batch_first=True,
+        )
+        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
+
+        # 3. Attention pooling (new)
+        self.attn_pool = AttentionPooling(self.hidden_size)
+
+        self.neck = nn.ModuleList(
+            [ResidualBlock(self.hidden_size) for _ in range(num_resblocks)]
+        )
+
+        # 4. Expert system
+        total_experts = num_domain_experts + num_shared_experts
+        self.experts = nn.ModuleList()
+        for i in range(total_experts):
+            dropout = 0.3 if i < num_domain_experts else 0.4
+            self.experts.append(
+                Expert(
+                    input_dim=self.hidden_size,
+                    d_model=d_model,
+                    num_resblocks=num_resblocks,
+                    output_multiplier=self.output_multiplier,
+                    dropout_prob=dropout,
+                )
+            )
+
+        self.expert_bias = nn.Embedding(
+            total_experts, self.hidden_size * self.output_multiplier
+        )
+
+        # 5. Classification head
+        self.classifier = nn.Sequential(
+            nn.LayerNorm(self.hidden_size * self.output_multiplier),
+            nn.Dropout(0.4),
+            nn.Linear(
+                self.hidden_size * self.output_multiplier,
+                self.hidden_size * self.output_multiplier * 2,
+            ),
+            nn.GELU(),
+            nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes),
+        )
+
+    def forward(self, input_ids, attention_mask, token_type_ids, pg):
+        """
+        New forward pass: p_start is no longer needed; token_type_ids is used instead.
+        Args:
+            input_ids: [B, L]
+            attention_mask: [B, L]
+            token_type_ids: [B, L] (0 = text, 1 = pinyin)
+            pg: [B] pinyin-group IDs
+        """
+        # ----- 1. Embeddings -----
+        # Note: bert.embeddings usually contains token_type_embeddings, so the most
+        # general approach is to pass token_type_ids straight through.
+        # If the embedding layer does not accept token_type_ids, add them manually:
+        #   embeddings = self.embedding(input_ids) + self.embedding.token_type_embeddings(token_type_ids)
+        embeddings = self.embedding(input_ids, token_type_ids=token_type_ids)
+
+        # ----- 2. Transformer encoder -----
+        padding_mask = attention_mask == 0
+        encoded = self.encoder(
+            embeddings, src_key_padding_mask=padding_mask
+        )  # [B, S, H]
+
+        # ----- 3. Attention pooling (replaces the previous span pooling) -----
+        # use attention_mask to ignore padding positions
+        pooled = self.attn_pool(encoded, attention_mask, token_type_ids)  # [B, H]
+
+        for block in self.neck:
+            pooled = block(pooled)
+
+        # ----- 4. Expert routing (hard routing) -----
+        # Stack all expert outputs into [batch, num_experts, hidden * multiplier]
+        expert_outputs = torch.stack([e(pooled) for e in self.experts], dim=1)  # [B, E, D]
+
+        # pg: [B] -> expand to [B, 1, D] as the gather index
+        index = pg.view(-1, 1, 1).expand(-1, 1, expert_outputs.size(-1))  # [B, 1, D]
+        expert_out = torch.gather(expert_outputs, 1, index).squeeze(1)  # [B, D]
+
+        # Add the per-expert bias
+        bias = self.expert_bias(pg)  # [B, D]
+        expert_out = expert_out + bias
+
+        # ----- 5. Classification head -----
+        return self.classifier(expert_out)
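AttentionPooling is used above as self.attn_pool(encoded, attention_mask, token_type_ids), but its definition is outside this diff. Below is a minimal sketch of a masked attention-pooling layer compatible with that call signature; it is an assumption about the intended behaviour, and the committed module may well weight the text and pinyin segments differently.

import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    """Masked attention pooling: [B, S, H] -> [B, H] (sketch, not the committed module)."""

    def __init__(self, hidden_size):
        super().__init__()
        self.score = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask, token_type_ids=None):
        # token_type_ids is accepted to match the call site; this sketch ignores it.
        # One scalar score per token; padding positions are masked out before softmax.
        scores = self.score(hidden_states).squeeze(-1)             # [B, S]
        scores = scores.masked_fill(attention_mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)      # [B, S, 1]
        return (weights * hidden_states).sum(dim=1)                # [B, H]

# Shape check with dummy tensors.
pool = AttentionPooling(512)
h = torch.randn(2, 16, 512)
mask = torch.ones(2, 16, dtype=torch.long)
print(pool(h, mask).shape)  # torch.Size([2, 512])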