feat: 优化数据加载器配置并新增模型评估与预测功能
This commit is contained in:
parent
834872dc0b
commit
c3c6f69532
11
example.py
11
example.py
|
|
@ -27,10 +27,10 @@ if __name__ == "__main__":
|
|||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1024,
|
||||
num_workers=15,
|
||||
num_workers=1,
|
||||
worker_init_fn=worker_init_fn,
|
||||
pin_memory=True if torch.cuda.is_available() else False,
|
||||
collate_fn=custom_collate,
|
||||
collate_fn=custom_collate_with_txt,
|
||||
prefetch_factor=8,
|
||||
persistent_workers=True,
|
||||
shuffle=False, # 数据集内部已实现打乱
|
||||
|
|
@ -51,11 +51,12 @@ if __name__ == "__main__":
|
|||
# 测试数据集
|
||||
try:
|
||||
logger.info("测试数据集")
|
||||
total = 3000
|
||||
for i, sample in tqdm(enumerate(dataloader), total=total):
|
||||
total = 20
|
||||
for i, sample in tqdm(enumerate(dataloader), total=20):
|
||||
if i >= total:
|
||||
break
|
||||
#print(f"Sample {i+1}: {sample['txt'][0:10]}")
|
||||
|
||||
print(f"Sample {i+1}: {sample['txt'][0:10]}")
|
||||
"""
|
||||
print(f"Sample {i+1}:")
|
||||
print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ requires-python = ">=3.13"
|
|||
dependencies = [
|
||||
"bokeh>=3.8.2",
|
||||
"datasets>=4.5.0",
|
||||
"ipykernel>=7.2.0",
|
||||
"ipython>=9.10.0",
|
||||
"loguru>=0.7.3",
|
||||
"modelscope>=1.34.0",
|
||||
"msgpack>=1.1.2",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
from suinput.dataset import PinyinInputDataset, worker_init_fn, custom_collate_with_txt
|
||||
from suinput.query import QueryEngine
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
    # Initialise the query engine that is handed to the dataset.
    query_engine = QueryEngine()
    query_engine.load()

    # Create the dataset (shuffling is requested from the dataset itself).
    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/DataSet/data",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
        max_len=88,
        batch_query_size=300,
        shuffle=True,
        shuffle_buffer_size=4000,
    )
    logger.info("数据集初始化")

    dataloader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=1,
        worker_init_fn=worker_init_fn,
        pin_memory=torch.cuda.is_available(),
        collate_fn=custom_collate_with_txt,
        prefetch_factor=8,
        persistent_workers=True,
        shuffle=False,  # the dataset already shuffles internally
    )

    # Print the first `total` batches.  A `for` loop never lets StopIteration
    # escape, so an empty dataset is detected by counting the batches actually
    # seen instead of catching that exception.
    total = 5
    seen = 0
    for i, sample in tqdm(enumerate(dataloader), total=total):
        if i >= total:
            break
        seen += 1
        print(sample)
        # Uncomment to dump the batch as a pickle under trainer/eval_dataset:
        # pickle.dump(sample, open(f"{str(Path(__file__).parent.parent / 'trainer' / 'eval_dataset')}/sample_{i}.pkl", "wb"))

    if seen == 0:
        print("数据集为空")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -3,17 +3,19 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torch.amp as amp
|
||||
from transformers import AutoModel
|
||||
from modelscope import AutoModel
|
||||
import pickle
|
||||
from importlib.resources import files
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from .monitor import TrainingMonitor
|
||||
|
||||
|
||||
# ---------------------------- 工具函数 ----------------------------
|
||||
def round_to_power_of_two(x):
    """Return the fixed model width of 1024.

    Despite the name, no rounding is performed: the argument is ignored and
    d_model is pinned to 1024, so every input yields the same value.
    """
    fixed_d_model = 1024
    return fixed_d_model
|
||||
class _PickledEvalBatches:
    """Lazy, re-iterable view over the ``*.pkl`` files of one directory.

    Every iteration walks the directory again and unpickles each file on
    demand, so the object can be consumed by several evaluation passes (a
    plain generator expression would be empty after the first pass).  Files
    are visited in name order and each handle is closed right after loading.
    """

    def __init__(self, directory):
        # `directory` only needs is_dir()/iterdir() and entries with
        # .name/.open(), i.e. the importlib.resources Traversable interface.
        self._directory = directory

    def __iter__(self):
        # A missing directory yields nothing rather than raising.
        if not self._directory.is_dir():
            return
        entries = sorted(
            (
                entry
                for entry in self._directory.iterdir()
                if entry.name.endswith(".pkl")
            ),
            key=lambda entry: entry.name,
        )
        for entry in entries:
            with entry.open("rb") as handle:
                # NOTE(security): unpickling runs arbitrary code; only files
                # shipped inside this package's eval_dataset directory should
                # ever be placed here.
                yield pickle.load(handle)


# Evaluation batches stored as pickles in the package's eval_dataset directory.
EVAL_DATALOADER = _PickledEvalBatches(files(__package__) / "eval_dataset")
|
||||
|
||||
|
||||
# ---------------------------- 残差块 ----------------------------
|
||||
|
|
@ -98,6 +100,7 @@ class MoEModel(nn.Module):
|
|||
self.embedding = bert.embeddings
|
||||
self.bert_config = bert.config
|
||||
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
||||
self.device = None # 将在 to() 调用时设置
|
||||
|
||||
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
|
|
@ -136,6 +139,11 @@ class MoEModel(nn.Module):
|
|||
|
||||
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
||||
|
||||
def to(self, *args, **kwargs):
    """Move/cast the module exactly like ``nn.Module.to`` and record its device.

    Accepts every call form of ``nn.Module.to`` (``to(device)``,
    ``to(device=...)``, ``to(dtype)``, ``to(tensor)``, ``non_blocking=...``)
    instead of a single positional ``device``.

    After the move, ``self.device`` is set to the device of the first
    parameter, so it is always a ``torch.device`` (never a string or a dtype).
    If the module has no parameters, an explicitly passed device is recorded
    instead.

    Returns:
        The module itself, as returned by ``nn.Module.to``.
    """
    moved = super().to(*args, **kwargs)

    first_param = next(self.parameters(), None)
    if first_param is not None:
        self.device = first_param.device
    else:
        # No parameter to inspect: fall back to an explicitly given device.
        explicit = kwargs.get("device", args[0] if args else None)
        if isinstance(explicit, (str, torch.device)):
            self.device = torch.device(explicit)
    return moved
|
||||
|
||||
def forward(self, input_ids, attention_mask, pg):
|
||||
"""
|
||||
input_ids : [batch, seq_len]
|
||||
|
|
@ -200,35 +208,234 @@ class MoEModel(nn.Module):
|
|||
return probs
|
||||
return logits
|
||||
|
||||
def predict(self, inputs, labels):
|
||||
pass
|
||||
def model_eval(self, eval_dataloader, criterion=None):
    """Evaluate the model and return ``(accuracy, avg_loss)``.

    Args:
        eval_dataloader: iterable of batch dicts providing 'pg' and 'char_id'
            plus 'input_ids' / 'attention_mask', the latter two either at the
            top level or nested under a 'hint' entry (the layout read by
            ``fit``).
        criterion: loss function; defaults to ``nn.CrossEntropyLoss()``.

    Returns:
        accuracy: float, fraction of samples whose argmax prediction equals
            the label.
        avg_loss: float, loss averaged over samples (each batch loss is
            weighted by its batch size).
        Both are 0.0 when the dataloader yields no samples.

    The module is switched to eval mode for the pass and its previous
    train/eval mode is restored afterwards, so calling this in the middle of
    training does not leave the model in eval mode.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    # self.device stays None until to() has been called; fall back to the
    # device the parameters actually live on.
    device = self.device
    if device is None:
        device = next(self.parameters()).device

    total_loss = 0.0
    correct = 0
    total = 0

    was_training = self.training
    self.eval()
    try:
        with torch.no_grad():
            for batch in eval_dataloader:
                # Accept both batch layouts used in this file.
                tokens = batch["hint"] if "hint" in batch else batch
                input_ids = tokens["input_ids"].to(device)
                attention_mask = tokens["attention_mask"].to(device)
                pg = batch["pg"].to(device)
                labels = batch["char_id"].to(device)

                # Forward pass
                logits = self(input_ids, attention_mask, pg)
                batch_size = labels.size(0)
                total_loss += criterion(logits, labels).item() * batch_size

                # Accuracy bookkeeping
                preds = logits.argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += batch_size
    finally:
        self.train(was_training)

    if total == 0:
        return 0.0, 0.0
    return correct / total, total_loss / total
|
||||
|
||||
def predict(self, sample, debug=False):
    """Predict class labels for a sample dict (a batch or a single sample).

    Args:
        sample: dict with
            - 'input_ids'     : [batch, seq_len], or [seq_len] for one sample
            - 'attention_mask': same shape as 'input_ids'
            - 'pg'            : [batch], or a scalar tensor
            'input_ids' / 'attention_mask' may also be nested under a 'hint'
            entry (the layout read by ``fit``).
            When ``debug`` is true it must additionally contain
            - 'char_id': [batch] or scalar tensor with the true labels
            - 'txt', 'char', 'py': list of strings (batch) or a single string
        debug: when true, print every sample whose prediction differs from
            'char_id'.

    Returns:
        torch.Tensor with the predicted labels, shape [batch]; a 0-dim tensor
        when 'input_ids' had no batch dimension.

    Raises:
        ValueError: ``debug`` is true and one of 'char_id', 'txt', 'char',
            'py' is missing.  This is checked before running the model.

    The module runs in eval mode for the prediction and its previous
    train/eval mode is restored afterwards.
    """
    # ------------------ 0. Validate debug fields up front ------------------
    if debug:
        required_keys = ["char_id", "txt", "char", "py"]
        missing = [k for k in required_keys if k not in sample]
        if missing:
            raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}")

    # ------------------ 1. Extract and normalise inputs ------------------
    tokens = sample["hint"] if "hint" in sample else sample
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    pg = sample["pg"]

    # A single sample is recognised by input_ids lacking a batch dimension.
    has_batch_dim = input_ids.dim() > 1
    if not has_batch_dim:
        input_ids = input_ids.unsqueeze(0)
        attention_mask = attention_mask.unsqueeze(0)
    if pg.dim() == 0:
        # A scalar pg is broadcast to one entry per row of input_ids.
        pg = pg.unsqueeze(0).expand(input_ids.size(0))

    # ------------------ 2. Resolve the device ------------------
    # self.device stays None until to() has been called.
    device = self.device
    if device is None:
        device = next(self.parameters()).device

    # ------------------ 3. Inference ------------------
    was_training = self.training
    self.eval()
    try:
        with torch.no_grad():
            logits = self(
                input_ids.to(device), attention_mask.to(device), pg.to(device)
            )
    finally:
        self.train(was_training)
    # softmax is monotonic, so argmax over the raw scores is equivalent.
    preds = logits.argmax(dim=-1)  # [batch]

    # ------------------ 4. Debug report of wrong predictions ------------------
    if debug:
        true_labels = sample["char_id"]
        if true_labels.dim() == 0:
            true_labels = true_labels.unsqueeze(0)
        true_labels = true_labels.to(preds.device)

        incorrect_indices = torch.where(preds != true_labels)[0].tolist()
        if incorrect_indices:
            print("\n=== 预测错误样本 ===")
            # Each debug field may be a list (batch) or one bare string.
            txts, chars, pys = (
                [value] if isinstance(value, str) else value
                for value in (sample["txt"], sample["char"], sample["py"])
            )
            for idx in incorrect_indices:
                print(f"样本索引 {idx}:")
                print(f" Text : {txts[idx]}")
                print(f" Char : {chars[idx]}")
                print(f" Pinyin: {pys[idx]}")
                print(
                    f" 预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}"
                )
            print("===================\n")

    # ------------------ 5. Match the input's dimensionality ------------------
    if not has_batch_dim:
        return preds.squeeze(0)  # 0-dim tensor
    return preds
|
||||
|
||||
def fit(
    self,
    train_dataloader,
    eval_dataloader=None,
    monitor: "TrainingMonitor | None" = None,
    criterion=None,
    optimizer=None,
    scheduler=None,
    num_epochs=1,
    eval_frequency=1000,
    grad_accum_steps=1,
    clip_grad_norm=1.0,
    mixed_precision=False,
):
    """Train the model with optional AMP, gradient accumulation and monitoring.

    Args:
        train_dataloader: iterable of batch dicts providing 'pg' and
            'char_id' plus 'input_ids' / 'attention_mask', the latter two
            either nested under a 'hint' entry or at the top level.
        eval_dataloader: optional validation data passed to ``model_eval``.
        monitor: optional TrainingMonitor; receives
            ``add_step(global_step, metrics)`` after every optimizer update.
        criterion: loss function; defaults to ``nn.CrossEntropyLoss()``.
        optimizer: defaults to ``AdamW(self.parameters(), lr=1e-4)``.
        scheduler: optional learning-rate scheduler; ``scheduler.step()`` is
            called once after every optimizer update.
        num_epochs: number of passes over ``train_dataloader``.
        eval_frequency: evaluate every this many optimizer updates.
        grad_accum_steps: number of batches accumulated per optimizer update.
        clip_grad_norm: max gradient norm for clipping; ``None`` disables it.
        mixed_precision: enable autocast and gradient scaling.
    """
    # Make sure the model lives on a known device.
    if self.device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
    # Tolerate a device that was recorded as a plain string.
    device = torch.device(self.device)

    # Switch to training mode through the parent-class implementation.
    super().train()

    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = optim.AdamW(self.parameters(), lr=1e-4)

    # Gradient scaler for mixed precision (a no-op when disabled).
    scaler = amp.GradScaler(enabled=mixed_precision)

    global_step = 0
    optimizer.zero_grad()

    for _ in range(num_epochs):
        # tqdm derives the total from the dataloader when it has a length.
        for batch_idx, batch in enumerate(tqdm(train_dataloader)):
            # Move the batch to the model's device.
            tokens = batch["hint"] if "hint" in batch else batch
            input_ids = tokens["input_ids"].to(device)
            attention_mask = tokens["attention_mask"].to(device)
            pg = batch["pg"].to(device)
            labels = batch["char_id"].to(device)

            # Forward pass under (optional) autocast.
            with amp.autocast(device_type=device.type, enabled=mixed_precision):
                logits = self(input_ids, attention_mask, pg)
                loss = criterion(logits, labels)
                loss = loss / grad_accum_steps  # normalise for accumulation

            # Backward pass on the scaled loss.
            scaler.scale(loss).backward()

            # Only every grad_accum_steps-th batch triggers an update.
            if (batch_idx + 1) % grad_accum_steps != 0:
                continue

            scaler.unscale_(optimizer)  # gradients must be unscaled to clip
            if clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()
            global_step += 1

            # Undo the accumulation scaling so the reported loss is per batch.
            metrics = {"loss": loss.item() * grad_accum_steps}

            # Periodic evaluation.
            if eval_dataloader is not None and global_step % eval_frequency == 0:
                acc, _ = self.model_eval(eval_dataloader, criterion)
                # model_eval puts the module in eval mode; resume training mode.
                super().train()
                metrics["acc"] = acc

            if monitor is not None:
                monitor.add_step(global_step, metrics)
|
||||
|
||||
|
||||
# ============================ 使用示例 ============================
|
||||
|
|
|
|||
Loading…
Reference in New Issue