From d88a68e421f9f233d842e2b919a3672bffb2691d Mon Sep 17 00:00:00 2001 From: songsenand Date: Sat, 28 Feb 2026 09:42:13 +0800 Subject: [PATCH] =?UTF-8?q?feat(model):=20=E6=B7=BB=E5=8A=A0=20MoEModelWit?= =?UTF-8?q?hNeck=20=E7=B1=BB=E5=8F=8A=E6=B3=A8=E6=84=8F=E5=8A=9B=E6=B1=A0?= =?UTF-8?q?=E5=8C=96=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 +- pyproject.toml | 2 + src/trainer/model.py | 158 +++++++++++++++++++++++++++++++++++++++---- 3 files changed, 149 insertions(+), 16 deletions(-) diff --git a/.gitignore b/.gitignore index 6b6ce0a..52c0719 100644 --- a/.gitignore +++ b/.gitignore @@ -214,4 +214,7 @@ uv.lock *.log marimo/ -__marimo__/ \ No newline at end of file +__marimo__/ + +# ---> Others +/model/* \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0883e1d..476ed50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.12" dependencies = [ + "amd-quark>=0.1.0", "bokeh>=3.8.2", "datasets>=4.5.0", "ipykernel>=7.2.0", @@ -13,6 +14,7 @@ dependencies = [ "modelscope>=1.34.0", "msgpack>=1.1.2", "numpy>=2.4.2", + "onnxruntime>=1.24.2", "pandas>=3.0.0", "pypinyin>=0.55.0", "requests>=2.32.5", diff --git a/src/trainer/model.py b/src/trainer/model.py index 7cd5f3c..b6e72a9 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -437,11 +437,24 @@ class MoEModel(nn.Module): warmup_ratio=0.1, label_smoothing=0.15, lr=1e-4, + lr_schedule=None, ): + def default_lr_schedule(_lr, _processed_batches, _stop_batch, _warmup_steps): + if _processed_batches < _warmup_steps: + current_lr = _lr * (_processed_batches / _warmup_steps) + else: + progress = (_processed_batches - _warmup_steps) / ( + _stop_batch - _warmup_steps + ) + current_lr = _lr * (0.5 * (1.0 + math.cos(math.pi * progress))) + return current_lr + """训练函数,调整了输入参数""" if self.device is None: self.device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") self.to(self.device) + if lr_schedule is None: + lr_schedule = default_lr_schedule self.train() @@ -462,7 +475,6 @@ class MoEModel(nn.Module): warmup_steps = int(total_steps * warmup_ratio) logger.info(f"Training Start: Steps={total_steps}, Warmup={warmup_steps}") processed_batches = 0 - global_step = 0 # 初始化 batch_loss_sum = 0.0 optimizer.zero_grad() try: @@ -471,13 +483,9 @@ class MoEModel(nn.Module): tqdm(train_dataloader, total=int(stop_batch)) ): # LR Schedule - if processed_batches < warmup_steps: - current_lr = lr * (processed_batches / warmup_steps) - else: - progress = (processed_batches - warmup_steps) / ( - total_steps - warmup_steps - ) - current_lr = lr * (0.5 * (1.0 + math.cos(math.pi * progress))) + current_lr = lr_schedule( + lr, processed_batches, stop_batch, warmup_steps + ) for param_group in optimizer.param_groups: param_group["lr"] = current_lr @@ -509,7 +517,7 @@ class MoEModel(nn.Module): scaler.update() optimizer.zero_grad() batch_loss_sum += loss.item() * grad_accum_steps - if global_step % eval_frequency == 0: + if processed_batches % eval_frequency == 0: if eval_dataloader: self.eval() acc, eval_loss = self.model_eval( @@ -519,29 +527,30 @@ class MoEModel(nn.Module): if monitor: # 使用 eval_loss 作为监控指标 monitor.add_step( - global_step, + processed_batches, { "train_loss": batch_loss_sum / ( - eval_frequency if global_step > 0 else 1 + eval_frequency + if processed_batches > 0 + else 1 ), "acc": acc, "loss": eval_loss, - "lr": current_lr + "lr": current_lr, }, ) logger.info( - f"step: {global_step}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}" + f"step: {processed_batches}, eval_loss: {eval_loss:.4f}, acc: {acc:.4f}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}" ) else: logger.info( - f"step: 
{global_step}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if global_step > 0 else 1):.4f}, current_lr: {current_lr}" + f"step: {processed_batches}, 'batch_loss_sum': {batch_loss_sum / (eval_frequency if processed_batches > 0 else 1):.4f}, current_lr: {current_lr}" ) batch_loss_sum = 0.0 if processed_batches >= stop_batch: break processed_batches += 1 - global_step += 1 except KeyboardInterrupt: logger.info("Training interrupted by user") @@ -607,3 +616,122 @@ class MoEModel(nn.Module): do_constant_folding=True, ) logger.info(f"ONNX model exported to {output_path}") + + +class MoEModelWithNeck(MoEModel): + def __init__( + self, + pretrained_model_name="iic/nlp_structbert_backbone_lite_std", + num_classes=10018, + output_multiplier=2, + d_model=768, + num_resblocks=4, + num_domain_experts=12, + num_shared_experts=1, + ): + super().__init__() + self.output_multiplier = output_multiplier + + # 1. 加载预训练 Embedding (Lite: hidden_size=512) + logger.info(f"Loading backbone: {pretrained_model_name}") + bert = AutoModel.from_pretrained(pretrained_model_name) + self.embedding = bert.embeddings + self.bert_config = bert.config + self.hidden_size = self.bert_config.hidden_size # 512 + self.device = None + + # 2. Transformer Encoder + encoder_layer = nn.TransformerEncoderLayer( + d_model=self.hidden_size, + nhead=8, + dim_feedforward=self.bert_config.intermediate_size, + dropout=self.bert_config.hidden_dropout_prob, + activation="gelu", + batch_first=True, + ) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) + + # 3. 注意力池化(新增) + self.attn_pool = AttentionPooling(self.hidden_size) + + self.neck = nn.ModuleList( + [ResidualBlock(self.hidden_size) for _ in range(num_resblocks)] + ) + + # 4. 
专家系统 + total_experts = num_domain_experts + num_shared_experts + self.experts = nn.ModuleList() + for i in range(total_experts): + dropout = 0.3 if i < num_domain_experts else 0.4 + self.experts.append( + Expert( + input_dim=self.hidden_size, + d_model=d_model, + num_resblocks=num_resblocks, + output_multiplier=self.output_multiplier, + dropout_prob=dropout, + ) + ) + + self.expert_bias = nn.Embedding( + total_experts, self.hidden_size * self.output_multiplier + ) + + # 5. 分类头 + self.classifier = nn.Sequential( + nn.LayerNorm(self.hidden_size * self.output_multiplier), + nn.Dropout(0.4), + nn.Linear( + self.hidden_size * self.output_multiplier, + self.hidden_size * self.output_multiplier * 2, + ), + nn.GELU(), + nn.Linear(self.hidden_size * self.output_multiplier * 2, num_classes), + ) + + def forward(self, input_ids, attention_mask, token_type_ids, pg): + """ + 新版 Forward 函数,不再需要 p_start,改用 token_type_ids。 + Args: + input_ids: [B, L] + attention_mask: [B, L] + token_type_ids: [B, L] (0=文本, 1=拼音) + pg: [B] 拼音组 ID + """ + # ----- 1. Embeddings ----- + # 注意:预训练的 embedding 层本身可能已经包含了 token_type_ids 的处理, + # 但这里我们直接使用它的 embedding,并手动将 token_type_ids 的嵌入加到上面。 + # 由于 bert.embeddings 通常包含 token_type_embeddings,我们可以利用它。 + # 但为简化,我们直接使用 bert.embeddings(input_ids, token_type_ids=token_type_ids) + # 如果当前 embedding 不支持传入 token_type_ids,可以手动相加: + # embeddings = self.embedding(input_ids) + self.embedding.token_type_embeddings(token_type_ids) + # 这里采用更通用的方式:假设 self.embedding 有 token_type_ids 参数 + embeddings = self.embedding(input_ids, token_type_ids=token_type_ids) + + # ----- 2. Transformer Encoder ----- + padding_mask = attention_mask == 0 + encoded = self.encoder( + embeddings, src_key_padding_mask=padding_mask + ) # [B, S, H] + + # ----- 3. 注意力池化(代替原来的 Span Pooling)----- + # 使用 attention_mask 忽略 padding 位置 + pooled = self.attn_pool(encoded, attention_mask, token_type_ids) # [B, H] + + for block in self.neck: + pooled = block(pooled) + + # ----- 4. 
专家路由(硬路由)----- + # 将所有专家的输出堆叠为 [batch, num_experts, hidden*multiplier] + expert_outputs = torch.stack([e(pooled) for e in self.experts], dim=1) # [B, E, D] + + # pg: [B] -> 扩展为 [B, 1, D] 作为 gather 的索引 + index = pg.view(-1, 1, 1).expand(-1, 1, expert_outputs.size(-1)) # [B, 1, D] + expert_out = torch.gather(expert_outputs, 1, index).squeeze(1) # [B, D] + + # 加上专家偏置 + bias = self.expert_bias(pg) # [B, D] + expert_out = expert_out + bias + + # ----- 5. 分类头 ----- + return self.classifier(expert_out) \ No newline at end of file