feat: 优化数据加载器配置并新增模型评估与预测功能
This commit is contained in:
parent
834872dc0b
commit
c3c6f69532
11
example.py
11
example.py
|
|
@ -27,10 +27,10 @@ if __name__ == "__main__":
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=1024,
|
batch_size=1024,
|
||||||
num_workers=15,
|
num_workers=1,
|
||||||
worker_init_fn=worker_init_fn,
|
worker_init_fn=worker_init_fn,
|
||||||
pin_memory=True if torch.cuda.is_available() else False,
|
pin_memory=True if torch.cuda.is_available() else False,
|
||||||
collate_fn=custom_collate,
|
collate_fn=custom_collate_with_txt,
|
||||||
prefetch_factor=8,
|
prefetch_factor=8,
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
shuffle=False, # 数据集内部已实现打乱
|
shuffle=False, # 数据集内部已实现打乱
|
||||||
|
|
@ -51,11 +51,12 @@ if __name__ == "__main__":
|
||||||
# 测试数据集
|
# 测试数据集
|
||||||
try:
|
try:
|
||||||
logger.info("测试数据集")
|
logger.info("测试数据集")
|
||||||
total = 3000
|
total = 20
|
||||||
for i, sample in tqdm(enumerate(dataloader), total=total):
|
for i, sample in tqdm(enumerate(dataloader), total=20):
|
||||||
if i >= total:
|
if i >= total:
|
||||||
break
|
break
|
||||||
#print(f"Sample {i+1}: {sample['txt'][0:10]}")
|
|
||||||
|
print(f"Sample {i+1}: {sample['txt'][0:10]}")
|
||||||
"""
|
"""
|
||||||
print(f"Sample {i+1}:")
|
print(f"Sample {i+1}:")
|
||||||
print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")
|
print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bokeh>=3.8.2",
|
"bokeh>=3.8.2",
|
||||||
"datasets>=4.5.0",
|
"datasets>=4.5.0",
|
||||||
|
"ipykernel>=7.2.0",
|
||||||
|
"ipython>=9.10.0",
|
||||||
"loguru>=0.7.3",
|
"loguru>=0.7.3",
|
||||||
"modelscope>=1.34.0",
|
"modelscope>=1.34.0",
|
||||||
"msgpack>=1.1.2",
|
"msgpack>=1.1.2",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
from tqdm import tqdm
|
||||||
|
from loguru import logger
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from suinput.dataset import PinyinInputDataset, worker_init_fn, custom_collate_with_txt
|
||||||
|
from suinput.query import QueryEngine
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 初始化查询引擎
|
||||||
|
query_engine = QueryEngine()
|
||||||
|
query_engine.load()
|
||||||
|
|
||||||
|
# 创建数据集
|
||||||
|
dataset = PinyinInputDataset(
|
||||||
|
data_dir="/home/songsenand/DataSet/data",
|
||||||
|
query_engine=query_engine,
|
||||||
|
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
|
||||||
|
max_len=88,
|
||||||
|
batch_query_size=300,
|
||||||
|
shuffle=True,
|
||||||
|
shuffle_buffer_size=4000,
|
||||||
|
)
|
||||||
|
logger.info("数据集初始化")
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=2,
|
||||||
|
num_workers=1,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
pin_memory=True if torch.cuda.is_available() else False,
|
||||||
|
collate_fn=custom_collate_with_txt,
|
||||||
|
prefetch_factor=8,
|
||||||
|
persistent_workers=True,
|
||||||
|
shuffle=False, # 数据集内部已实现打乱
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
total = 5
|
||||||
|
for i, sample in tqdm(enumerate(dataloader), total=5):
|
||||||
|
if i >= total:
|
||||||
|
break
|
||||||
|
print(sample)
|
||||||
|
# pickle.dump(sample, open(f"{str(Path(__file__).parent.parent / 'trainer' / 'eval_dataset')}/sample_{i}.pkl", "wb"))
|
||||||
|
except StopIteration:
|
||||||
|
print("数据集为空")
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -3,17 +3,19 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.amp as amp
|
import torch.amp as amp
|
||||||
from transformers import AutoModel
|
from modelscope import AutoModel
|
||||||
|
import pickle
|
||||||
|
from importlib.resources import files
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .monitor import TrainingMonitor
|
from .monitor import TrainingMonitor
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------- 工具函数 ----------------------------
|
EVAL_DATALOADER = (
|
||||||
def round_to_power_of_two(x):
|
pickle.load(file.open('rb'))
|
||||||
"""将数字向上取整为2的幂(此处固定返回1024)"""
|
for file in (files(__package__) / "eval_dataset").glob("*.pkl")
|
||||||
return 1024 # 根据您的说明固定 d_model = 1024
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------- 残差块 ----------------------------
|
# ---------------------------- 残差块 ----------------------------
|
||||||
|
|
@ -98,6 +100,7 @@ class MoEModel(nn.Module):
|
||||||
self.embedding = bert.embeddings
|
self.embedding = bert.embeddings
|
||||||
self.bert_config = bert.config
|
self.bert_config = bert.config
|
||||||
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度
|
||||||
|
self.device = None # 将在 to() 调用时设置
|
||||||
|
|
||||||
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
# 2. 4 层标准 Transformer Encoder(从 config 读取参数)
|
||||||
encoder_layer = nn.TransformerEncoderLayer(
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
|
@ -136,6 +139,11 @@ class MoEModel(nn.Module):
|
||||||
|
|
||||||
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
# 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理)
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
"""重写 to 方法,记录设备"""
|
||||||
|
self.device = device
|
||||||
|
return super().to(device)
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask, pg):
|
def forward(self, input_ids, attention_mask, pg):
|
||||||
"""
|
"""
|
||||||
input_ids : [batch, seq_len]
|
input_ids : [batch, seq_len]
|
||||||
|
|
@ -200,35 +208,234 @@ class MoEModel(nn.Module):
|
||||||
return probs
|
return probs
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def predict(self, inputs, labels):
|
def model_eval(self, eval_dataloader, criterion=None):
|
||||||
pass
|
"""
|
||||||
|
在验证集上评估模型,返回准确率和平均损失。
|
||||||
|
|
||||||
def train(
|
参数:
|
||||||
self,
|
eval_dataloader: DataLoader,提供 'input_ids', 'attention_mask', 'pg', 'char_id'
|
||||||
dataloader,
|
criterion: 损失函数,默认为 CrossEntropyLoss()
|
||||||
monitor: TrainingMonitor,
|
返回:
|
||||||
criterion = nn.CrossEntropyLoss(),
|
accuracy: float, 准确率
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
avg_loss: float, 平均损失
|
||||||
optimizer=None,
|
"""
|
||||||
sample_frequency=1000,
|
if criterion is None:
|
||||||
):
|
criterion = nn.CrossEntropyLoss()
|
||||||
self.train()
|
|
||||||
if optimizer is None:
|
|
||||||
optimizer = optim.AdamW(self.parameters(), lr=6e-6)
|
|
||||||
|
|
||||||
for i, sample in tqdm(enumerate(dataloader), total=1e7):
|
self.eval()
|
||||||
optimizer.zero_grad()
|
total_loss = 0.0
|
||||||
with amp.autocast():
|
correct = 0
|
||||||
char_id = sample.pop("char_id").to(device)
|
total = 0
|
||||||
input_ids = sample.pop("input_ids").to(device)
|
|
||||||
attention_mask = sample.pop("attention_mask").to(device)
|
with torch.no_grad():
|
||||||
pg = sample.pop("pg").to(device)
|
for batch in eval_dataloader:
|
||||||
|
# 移动数据到模型设备
|
||||||
|
input_ids = batch["input_ids"].to(self.device)
|
||||||
|
attention_mask = batch["attention_mask"].to(self.device)
|
||||||
|
pg = batch["pg"].to(self.device)
|
||||||
|
labels = batch["char_id"].to(self.device)
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
logits = self(input_ids, attention_mask, pg)
|
logits = self(input_ids, attention_mask, pg)
|
||||||
loss = criterion(logits, char_id)
|
loss = criterion(logits, labels)
|
||||||
loss.backward()
|
total_loss += loss.item() * labels.size(0)
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
|
# 计算准确率
|
||||||
|
preds = logits.argmax(dim=-1)
|
||||||
|
correct += (preds == labels).sum().item()
|
||||||
|
total += labels.size(0)
|
||||||
|
|
||||||
|
avg_loss = total_loss / total if total > 0 else 0.0
|
||||||
|
accuracy = correct / total if total > 0 else 0.0
|
||||||
|
return accuracy, avg_loss
|
||||||
|
|
||||||
|
def predict(self, sample, debug=False):
|
||||||
|
"""
|
||||||
|
基于 sample 字典进行预测,支持批量/单样本,可选调试打印错误样本信息。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
sample : dict
|
||||||
|
必须包含字段:
|
||||||
|
- 'input_ids' : [batch, seq_len] 或 [seq_len] (单样本)
|
||||||
|
- 'attention_mask': 同上
|
||||||
|
- 'pg' : [batch] 或标量
|
||||||
|
- 'char_id' : [batch] 或标量,真实标签(当 debug=True 时必须提供)
|
||||||
|
调试时(debug=True)必须包含字段:
|
||||||
|
- 'txt' : 字符串列表(batch)或单个字符串
|
||||||
|
- 'char' : 字符串列表(batch)或单个字符串
|
||||||
|
- 'py' : 字符串列表(batch)或单个字符串
|
||||||
|
debug : bool
|
||||||
|
是否打印预测错误的样本信息。若为 True 但 sample 缺少 char_id/txt/char/py,抛出 ValueError。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
preds : torch.Tensor
|
||||||
|
[batch] 预测类别标签(若输入为单样本且无 batch 维度,则返回标量)
|
||||||
|
"""
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
# ------------------ 1. 提取并规范化输入 ------------------
|
||||||
|
# 判断是否为单样本(input_ids 无 batch 维度)
|
||||||
|
input_ids = sample["input_ids"]
|
||||||
|
attention_mask = sample["attention_mask"]
|
||||||
|
pg = sample["pg"]
|
||||||
|
has_batch_dim = input_ids.dim() > 1
|
||||||
|
|
||||||
|
if not has_batch_dim:
|
||||||
|
input_ids = input_ids.unsqueeze(0)
|
||||||
|
attention_mask = attention_mask.unsqueeze(0)
|
||||||
|
if pg.dim() == 0:
|
||||||
|
pg = pg.unsqueeze(0).expand(input_ids.size(0))
|
||||||
|
|
||||||
|
# ------------------ 2. 移动设备 ------------------
|
||||||
|
input_ids = input_ids.to(self.device)
|
||||||
|
attention_mask = attention_mask.to(self.device)
|
||||||
|
pg = pg.to(self.device)
|
||||||
|
|
||||||
|
# ------------------ 3. 推理 ------------------
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = self(input_ids, attention_mask, pg)
|
||||||
|
preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch]
|
||||||
|
|
||||||
|
# ------------------ 4. 调试打印(错误样本) ------------------
|
||||||
|
if debug:
|
||||||
|
# 检查必需字段
|
||||||
|
required_keys = ["char_id", "txt", "char", "py"]
|
||||||
|
missing = [k for k in required_keys if k not in sample]
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}")
|
||||||
|
|
||||||
|
# 提取真实标签
|
||||||
|
true_labels = sample["char_id"]
|
||||||
|
if true_labels.dim() == 0:
|
||||||
|
true_labels = true_labels.unsqueeze(0)
|
||||||
|
|
||||||
|
# 移动真实标签到相同设备
|
||||||
|
true_labels = true_labels.to(self.device)
|
||||||
|
|
||||||
|
# 找出预测错误的索引
|
||||||
|
incorrect_mask = preds != true_labels
|
||||||
|
incorrect_indices = torch.where(incorrect_mask)[0]
|
||||||
|
|
||||||
|
if len(incorrect_indices) > 0:
|
||||||
|
print("\n=== 预测错误样本 ===")
|
||||||
|
# 获取调试字段(可能是列表或单个字符串)
|
||||||
|
txts = sample["txt"]
|
||||||
|
chars = sample["char"]
|
||||||
|
pys = sample["py"]
|
||||||
|
|
||||||
|
# 统一转换为列表(如果输入是单个字符串)
|
||||||
|
if isinstance(txts, str):
|
||||||
|
txts = [txts]
|
||||||
|
chars = [chars]
|
||||||
|
pys = [pys]
|
||||||
|
|
||||||
|
for idx in incorrect_indices.cpu().numpy():
|
||||||
|
print(f"样本索引 {idx}:")
|
||||||
|
print(f" Text : {txts[idx]}")
|
||||||
|
print(f" Char : {chars[idx]}")
|
||||||
|
print(f" Pinyin: {pys[idx]}")
|
||||||
|
print(
|
||||||
|
f" 预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}"
|
||||||
|
)
|
||||||
|
print("===================\n")
|
||||||
|
|
||||||
|
# ------------------ 5. 返回结果(保持与输入维度一致) ------------------
|
||||||
|
if not has_batch_dim:
|
||||||
|
return preds.squeeze(0) # 返回标量
|
||||||
|
return preds
|
||||||
|
|
||||||
|
def fit(
|
||||||
|
self,
|
||||||
|
train_dataloader,
|
||||||
|
eval_dataloader=None,
|
||||||
|
monitor: TrainingMonitor = None,
|
||||||
|
criterion=nn.CrossEntropyLoss(),
|
||||||
|
optimizer=None,
|
||||||
|
scheduler=None,
|
||||||
|
num_epochs=1,
|
||||||
|
eval_frequency=1000,
|
||||||
|
grad_accum_steps=1,
|
||||||
|
clip_grad_norm=1.0,
|
||||||
|
mixed_precision=False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
训练模型,支持混合精度、梯度累积、学习率调度、实时监控。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
train_dataloader: DataLoader,训练数据
|
||||||
|
eval_dataloader: DataLoader,验证数据(可选)
|
||||||
|
monitor: TrainingMonitor 实例,用于实时绘图
|
||||||
|
criterion: 损失函数
|
||||||
|
optimizer: 优化器,默认 AdamW(lr=6e-6)
|
||||||
|
scheduler: 学习率调度器
|
||||||
|
num_epochs: 训练轮数
|
||||||
|
eval_frequency: 评估间隔(步数)
|
||||||
|
grad_accum_steps: 梯度累积步数
|
||||||
|
clip_grad_norm: 梯度裁剪范数
|
||||||
|
mixed_precision: 是否启用混合精度
|
||||||
|
"""
|
||||||
|
# 确保模型在正确的设备上
|
||||||
|
if self.device is None:
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.to(self.device)
|
||||||
|
|
||||||
|
# 切换到训练模式(调用父类方法)
|
||||||
|
super().train()
|
||||||
|
|
||||||
|
# 默认优化器
|
||||||
|
if optimizer is None:
|
||||||
|
optimizer = optim.AdamW(self.parameters(), lr=1e-4)
|
||||||
|
|
||||||
|
# 混合精度缩放器
|
||||||
|
scaler = amp.GradScaler(enabled=mixed_precision)
|
||||||
|
|
||||||
|
global_step = 0
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
for batch_idx, batch in enumerate(
|
||||||
|
tqdm(train_dataloader, total=1e6)
|
||||||
|
):
|
||||||
|
# 移动数据
|
||||||
|
input_ids = batch['hint']["input_ids"].to(self.device)
|
||||||
|
attention_mask = batch['hint']["attention_mask"].to(self.device)
|
||||||
|
pg = batch["pg"].to(self.device)
|
||||||
|
labels = batch["char_id"].to(self.device)
|
||||||
|
|
||||||
|
# 混合精度前向
|
||||||
|
with amp.autocast(device_type=self.device.type,enabled=mixed_precision):
|
||||||
|
logits = self(input_ids, attention_mask, pg)
|
||||||
|
loss = criterion(logits, labels)
|
||||||
|
loss = loss / grad_accum_steps # 梯度累积归一化
|
||||||
|
|
||||||
|
# 反向传播(缩放)
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# 梯度累积:每 grad_accum_steps 步更新一次
|
||||||
|
if (batch_idx + 1) % grad_accum_steps == 0:
|
||||||
|
scaler.unscale_(optimizer) # 用于梯度裁剪
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
|
# 周期性评估
|
||||||
|
if (
|
||||||
|
eval_dataloader is not None
|
||||||
|
and global_step % eval_frequency == 0
|
||||||
|
):
|
||||||
|
acc, _ = self.model_eval(eval_dataloader, criterion)
|
||||||
|
if monitor is not None:
|
||||||
|
monitor.add_step(
|
||||||
|
global_step,
|
||||||
|
{"loss": loss.item() * grad_accum_steps, "acc": acc},
|
||||||
|
)
|
||||||
|
elif monitor is not None:
|
||||||
|
# 仅记录训练损失
|
||||||
|
monitor.add_step(
|
||||||
|
global_step, {"loss": loss.item() * grad_accum_steps}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ============================ 使用示例 ============================
|
# ============================ 使用示例 ============================
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue