feat: 优化数据加载器配置并新增模型评估与预测功能

This commit is contained in:
songsenand 2026-02-13 00:57:21 +08:00
parent 834872dc0b
commit c3c6f69532
9 changed files with 292 additions and 34 deletions

View File

@ -27,10 +27,10 @@ if __name__ == "__main__":
dataloader = DataLoader(
dataset,
batch_size=1024,
num_workers=15,
num_workers=1,
worker_init_fn=worker_init_fn,
pin_memory=True if torch.cuda.is_available() else False,
collate_fn=custom_collate,
collate_fn=custom_collate_with_txt,
prefetch_factor=8,
persistent_workers=True,
shuffle=False, # 数据集内部已实现打乱
@ -51,11 +51,12 @@ if __name__ == "__main__":
# 测试数据集
try:
logger.info("测试数据集")
total = 3000
for i, sample in tqdm(enumerate(dataloader), total=total):
total = 20
for i, sample in tqdm(enumerate(dataloader), total=20):
if i >= total:
break
#print(f"Sample {i+1}: {sample['txt'][0:10]}")
print(f"Sample {i+1}: {sample['txt'][0:10]}")
"""
print(f"Sample {i+1}:")
print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")

View File

@ -7,6 +7,8 @@ requires-python = ">=3.13"
dependencies = [
"bokeh>=3.8.2",
"datasets>=4.5.0",
"ipykernel>=7.2.0",
"ipython>=9.10.0",
"loguru>=0.7.3",
"modelscope>=1.34.0",
"msgpack>=1.1.2",

View File

@ -0,0 +1,48 @@
from tqdm import tqdm
from loguru import logger
import torch
from torch.utils.data import DataLoader
import pickle
from pathlib import Path
from suinput.dataset import PinyinInputDataset, worker_init_fn, custom_collate_with_txt
from suinput.query import QueryEngine
# 使用示例
if __name__ == "__main__":
# 初始化查询引擎
query_engine = QueryEngine()
query_engine.load()
# 创建数据集
dataset = PinyinInputDataset(
data_dir="/home/songsenand/DataSet/data",
query_engine=query_engine,
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
max_len=88,
batch_query_size=300,
shuffle=True,
shuffle_buffer_size=4000,
)
logger.info("数据集初始化")
dataloader = DataLoader(
dataset,
batch_size=2,
num_workers=1,
worker_init_fn=worker_init_fn,
pin_memory=True if torch.cuda.is_available() else False,
collate_fn=custom_collate_with_txt,
prefetch_factor=8,
persistent_workers=True,
shuffle=False, # 数据集内部已实现打乱
)
try:
total = 5
for i, sample in tqdm(enumerate(dataloader), total=5):
if i >= total:
break
print(sample)
# pickle.dump(sample, open(f"{str(Path(__file__).parent.parent / 'trainer' / 'eval_dataset')}/sample_{i}.pkl", "wb"))
except StopIteration:
print("数据集为空")

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -3,17 +3,19 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.amp as amp
from transformers import AutoModel
from modelscope import AutoModel
import pickle
from importlib.resources import files
from tqdm import tqdm
from .monitor import TrainingMonitor
# ---------------------------- 工具函数 ----------------------------
def round_to_power_of_two(x):
    """Nominal power-of-two rounding, intentionally pinned to 1024.

    NOTE(review): despite the name, the argument is ignored and the
    function always returns 1024 — the fixed ``d_model`` chosen for this
    project per the original author's note.
    """
    return 1024
# Lazily unpickle every cached evaluation batch shipped in this package's
# "eval_dataset" directory.
# NOTE(review): this is a generator expression — it is exhausted after a
# single pass, so it cannot serve repeated evaluation rounds; confirm
# callers iterate it at most once (or wrap it in list()).
# NOTE(review): the file handles opened via file.open('rb') are never
# explicitly closed; presumably acceptable for a few small fixtures.
# SECURITY: pickle.load executes arbitrary code from the .pkl files —
# only trusted, repository-shipped fixtures must live here.
EVAL_DATALOADER = (
    pickle.load(file.open('rb'))
    for file in (files(__package__) / "eval_dataset").glob("*.pkl")
)
# ---------------------------- 残差块 ----------------------------
@ -98,6 +100,7 @@ class MoEModel(nn.Module):
self.embedding = bert.embeddings
self.bert_config = bert.config
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
self.device = None # 将在 to() 调用时设置
# 2. 4 层标准 Transformer Encoder从 config 读取参数)
encoder_layer = nn.TransformerEncoderLayer(
@ -136,6 +139,11 @@ class MoEModel(nn.Module):
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
def to(self, device):
"""重写 to 方法,记录设备"""
self.device = device
return super().to(device)
def forward(self, input_ids, attention_mask, pg):
"""
input_ids : [batch, seq_len]
@ -200,35 +208,234 @@ class MoEModel(nn.Module):
return probs
return logits
def predict(self, inputs, labels):
pass
def model_eval(self, eval_dataloader, criterion=None):
"""
在验证集上评估模型返回准确率和平均损失
def train(
self,
dataloader,
monitor: TrainingMonitor,
criterion = nn.CrossEntropyLoss(),
device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
optimizer=None,
sample_frequency=1000,
):
self.train()
if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=6e-6)
参数
eval_dataloader: DataLoader提供 'input_ids', 'attention_mask', 'pg', 'char_id'
criterion: 损失函数默认为 CrossEntropyLoss()
返回
accuracy: float, 准确率
avg_loss: float, 平均损失
"""
if criterion is None:
criterion = nn.CrossEntropyLoss()
for i, sample in tqdm(enumerate(dataloader), total=1e7):
optimizer.zero_grad()
with amp.autocast():
char_id = sample.pop("char_id").to(device)
input_ids = sample.pop("input_ids").to(device)
attention_mask = sample.pop("attention_mask").to(device)
pg = sample.pop("pg").to(device)
self.eval()
total_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for batch in eval_dataloader:
# 移动数据到模型设备
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
pg = batch["pg"].to(self.device)
labels = batch["char_id"].to(self.device)
# 前向传播
logits = self(input_ids, attention_mask, pg)
loss = criterion(logits, char_id)
loss.backward()
optimizer.step()
loss = criterion(logits, labels)
total_loss += loss.item() * labels.size(0)
# 计算准确率
preds = logits.argmax(dim=-1)
correct += (preds == labels).sum().item()
total += labels.size(0)
avg_loss = total_loss / total if total > 0 else 0.0
accuracy = correct / total if total > 0 else 0.0
return accuracy, avg_loss
    def predict(self, sample, debug=False):
        """Predict class labels from a ``sample`` dict; supports batched and
        single-sample input, with optional debug printing of wrong predictions.

        Args:
            sample: dict with required fields
                - 'input_ids'      : [batch, seq_len] or [seq_len] (single sample)
                - 'attention_mask' : same shape as 'input_ids'
                - 'pg'             : [batch] or scalar
                When ``debug=True`` it must additionally contain
                - 'char_id' : [batch] or scalar true labels
                - 'txt', 'char', 'py' : string lists (batch) or single strings
            debug: if True, print details of mis-predicted samples; raises
                ValueError when the debug fields above are missing.

        Returns:
            torch.Tensor: [batch] predicted labels, or a 0-dim tensor when
            the input had no batch dimension.
        """
        self.eval()
        # ------------------ 1. Extract and normalize inputs ------------------
        # A missing batch dimension on input_ids marks a single sample.
        input_ids = sample["input_ids"]
        attention_mask = sample["attention_mask"]
        pg = sample["pg"]
        has_batch_dim = input_ids.dim() > 1
        if not has_batch_dim:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        # NOTE(review): indentation was lost in this paste — presumably this
        # scalar-pg expansion sits at this outer level (it is also correct
        # for the single-sample case, where batch size is 1); confirm
        # against the original file.
        if pg.dim() == 0:
            pg = pg.unsqueeze(0).expand(input_ids.size(0))
        # ------------------ 2. Move to the model's device ------------------
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        pg = pg.to(self.device)
        # ------------------ 3. Inference ------------------
        with torch.no_grad():
            logits = self(input_ids, attention_mask, pg)
            # softmax before argmax is redundant (monotonic) but harmless.
            preds = torch.softmax(logits, dim=-1).argmax(dim=-1)  # [batch]
        # ------------------ 4. Debug printing (wrong samples) ------------------
        if debug:
            # Validate the fields debug mode depends on.
            required_keys = ["char_id", "txt", "char", "py"]
            missing = [k for k in required_keys if k not in sample]
            if missing:
                raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}")
            # Normalize true labels to a 1-D tensor on the model device.
            true_labels = sample["char_id"]
            if true_labels.dim() == 0:
                true_labels = true_labels.unsqueeze(0)
            true_labels = true_labels.to(self.device)
            # Indices of mis-predicted samples.
            incorrect_mask = preds != true_labels
            incorrect_indices = torch.where(incorrect_mask)[0]
            if len(incorrect_indices) > 0:
                print("\n=== 预测错误样本 ===")
                # Debug fields may be lists (batch) or single strings.
                txts = sample["txt"]
                chars = sample["char"]
                pys = sample["py"]
                # Wrap single strings into one-element lists for uniform indexing.
                if isinstance(txts, str):
                    txts = [txts]
                    chars = [chars]
                    pys = [pys]
                for idx in incorrect_indices.cpu().numpy():
                    print(f"样本索引 {idx}:")
                    print(f"  Text  : {txts[idx]}")
                    print(f"  Char  : {chars[idx]}")
                    print(f"  Pinyin: {pys[idx]}")
                    print(
                        f"  预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}"
                    )
                print("===================\n")
        # ------------------ 5. Return (match the input's dimensionality) ------------------
        if not has_batch_dim:
            return preds.squeeze(0)  # 0-dim tensor for single-sample input
        return preds
def fit(
self,
train_dataloader,
eval_dataloader=None,
monitor: TrainingMonitor = None,
criterion=nn.CrossEntropyLoss(),
optimizer=None,
scheduler=None,
num_epochs=1,
eval_frequency=1000,
grad_accum_steps=1,
clip_grad_norm=1.0,
mixed_precision=False,
):
"""
训练模型支持混合精度梯度累积学习率调度实时监控
参数
train_dataloader: DataLoader训练数据
eval_dataloader: DataLoader验证数据可选
monitor: TrainingMonitor 实例用于实时绘图
criterion: 损失函数
optimizer: 优化器默认 AdamW(lr=6e-6)
scheduler: 学习率调度器
num_epochs: 训练轮数
eval_frequency: 评估间隔步数
grad_accum_steps: 梯度累积步数
clip_grad_norm: 梯度裁剪范数
mixed_precision: 是否启用混合精度
"""
# 确保模型在正确的设备上
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(self.device)
# 切换到训练模式(调用父类方法)
super().train()
# 默认优化器
if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=1e-4)
# 混合精度缩放器
scaler = amp.GradScaler(enabled=mixed_precision)
global_step = 0
optimizer.zero_grad()
for epoch in range(num_epochs):
for batch_idx, batch in enumerate(
tqdm(train_dataloader, total=1e6)
):
# 移动数据
input_ids = batch['hint']["input_ids"].to(self.device)
attention_mask = batch['hint']["attention_mask"].to(self.device)
pg = batch["pg"].to(self.device)
labels = batch["char_id"].to(self.device)
# 混合精度前向
with amp.autocast(device_type=self.device.type,enabled=mixed_precision):
logits = self(input_ids, attention_mask, pg)
loss = criterion(logits, labels)
loss = loss / grad_accum_steps # 梯度累积归一化
# 反向传播(缩放)
scaler.scale(loss).backward()
# 梯度累积:每 grad_accum_steps 步更新一次
if (batch_idx + 1) % grad_accum_steps == 0:
scaler.unscale_(optimizer) # 用于梯度裁剪
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
global_step += 1
# 周期性评估
if (
eval_dataloader is not None
and global_step % eval_frequency == 0
):
acc, _ = self.model_eval(eval_dataloader, criterion)
if monitor is not None:
monitor.add_step(
global_step,
{"loss": loss.item() * grad_accum_steps, "acc": acc},
)
elif monitor is not None:
# 仅记录训练损失
monitor.add_step(
global_step, {"loss": loss.item() * grad_accum_steps}
)
# ============================ 使用示例 ============================