重构代码结构并优化注释格式

This commit is contained in:
songsenand 2026-02-22 12:16:22 +08:00
parent 3bb44f1d73
commit 5857c90be7
1 changed file with 12 additions and 12 deletions

View File

@ -10,7 +10,7 @@ import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from loguru import logger from loguru import logger
from modelscope import AutoModel, AutoTokenizer from modelscope import AutoModel, AutoTokenizer
from tqdm import tqdm from tqdm.autonotebook import tqdm
from .monitor import TrainingMonitor from .monitor import TrainingMonitor
from suinput.dataset import PG from suinput.dataset import PG
@ -121,14 +121,13 @@ class MoEModel(nn.Module):
) )
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
self.shared_resblocks = nn.ModuleList( # self.shared_resblocks = nn.ModuleList(
[ResidualBlock(self.hidden_size, 0.1) for _ in range(6)] # [ResidualBlock(self.hidden_size, 0.1) for _ in range(4)]
) # )
self.pooler = nn.AdaptiveAvgPool1d(1) self.pooler = nn.AdaptiveAvgPool1d(1)
# self.linear = nn.Linear(self.hidden_size, self.hidden_size) # self.linear = nn.Linear(self.hidden_size, self.hidden_size)
# 3. 专家层8个领域专家 + 1个共享专家 # 3. 专家层8个领域专家 + 1个共享专家
total_experts = num_domain_experts + num_shared_experts total_experts = num_domain_experts + num_shared_experts
self.experts = nn.ModuleList() self.experts = nn.ModuleList()
@ -140,12 +139,11 @@ class MoEModel(nn.Module):
input_dim=self.hidden_size, input_dim=self.hidden_size,
d_model=d_model, d_model=d_model,
num_resblocks=num_resblocks, num_resblocks=num_resblocks,
output_multiplier=self.output_multiplier, # 输出维度 = 2 * hidden_size output_multiplier=self.output_multiplier,
dropout_prob=dropout_prob, dropout_prob=dropout_prob,
) )
self.experts.append(expert) self.experts.append(expert)
self.expert_bias = nn.Embedding( self.expert_bias = nn.Embedding(
total_experts, self.output_multiplier * self.hidden_size total_experts, self.output_multiplier * self.hidden_size
) )
@ -195,8 +193,8 @@ class MoEModel(nn.Module):
) # [B, S, H] ) # [B, S, H]
# ----- 3. 池化量 ----- # ----- 3. 池化量 -----
for block in self.shared_resblocks: # for block in self.shared_resblocks:
encoded = block(encoded) # encoded = block(encoded)
pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1) pooled = self.pooler(encoded.transpose(1, 2)).squeeze(-1)
# pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2] # pooled = self.pooler(encoded.transpose(1, 2)) # [B, H, 2]
# pooled = pooled.flatten(1) # [B, H*2] # pooled = pooled.flatten(1) # [B, H*2]
@ -321,11 +319,11 @@ class MoEModel(nn.Module):
return_tensors="pt", return_tensors="pt",
) )
sample = {} sample = {}
sample['hint'] = { sample["hint"] = {
"input_ids": hint["input_ids"], "input_ids": hint["input_ids"],
"attention_mask": hint["attention_mask"], "attention_mask": hint["attention_mask"],
} }
sample['pg'] = torch.tensor([PG[py[0]]]) sample["pg"] = torch.tensor([PG[py[0]]])
return sample return sample
def predict(self, text, py, tokenizer=None): def predict(self, text, py, tokenizer=None):
@ -500,7 +498,7 @@ class MoEModel(nn.Module):
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}" f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
) )
batch_loss_sum = 0.0 batch_loss_sum = 0.0
if processed_batches >= stop_batch: if processed_batches + 1 >= stop_batch:
break break
global_step += 1 global_step += 1
@ -534,3 +532,5 @@ class MoEModel(nn.Module):
for name, param in self.named_parameters(): for name, param in self.named_parameters():
if name in freeze_layers: if name in freeze_layers:
param.requires_grad = False param.requires_grad = False