feat: optimize data loading and training logic; add support for custom learning-rate schedules

songsenand 2026-02-13 10:48:17 +08:00
parent 982d0521d5
commit 54ac5af876
2 changed files with 139 additions and 62 deletions

View File

@ -1,6 +1,9 @@
import json
import os
import random
from typing import Any, Dict, List, Tuple, Optional
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
@ -36,7 +39,7 @@ class PinyinInputDataset(IterableDataset):
batch_query_size: int = 1000,
# Shuffle parameters
shuffle: bool = True,
shuffle_buffer_size: int = 10000,
shuffle_buffer_size: int = 100,
# Frequency-balancing ("peak shaving, valley filling") parameters
max_freq: int = 434748359, # frequency of "的"
min_freq: int = 109, # frequency of "蓚"
@ -44,6 +47,7 @@ class PinyinInputDataset(IterableDataset):
repeat_end_freq: int = 10000, # threshold below which repetition kicks in
max_drop_prob: float = 0.8, # maximum drop probability
max_repeat_expect: float = 50.0, # maximum expected repetition count
py_group_json_file: Optional[Dict[str, int]] = None,
):
"""
Initialize the dataset.
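The adjust_frequency method applied in the next hunk is not among the lines shown. A minimal sketch of what the peak-shaving parameters above suggest, assuming log-linear interpolation between min_freq and max_freq (hypothetical; the committed implementation may differ):

import math
import random

def adjust_frequency_sketch(
    freq: int,
    max_freq: int = 434748359,
    min_freq: int = 109,
    repeat_end_freq: int = 10000,
    max_drop_prob: float = 0.8,
    max_repeat_expect: float = 50.0,
) -> float:
    """Hypothetical: return an expected sample count for a character.

    A return value <= 0 means "drop this occurrence"; > 1 means "repeat it".
    """
    if freq >= repeat_end_freq:
        # High-frequency side: drop probability grows log-linearly up to max_drop_prob.
        t = (math.log(freq) - math.log(repeat_end_freq)) / (
            math.log(max_freq) - math.log(repeat_end_freq)
        )
        drop_prob = max_drop_prob * min(max(t, 0.0), 1.0)
        return 0.0 if random.random() < drop_prob else 1.0
    # Low-frequency side: expected repetitions grow log-linearly up to max_repeat_expect.
    t = (math.log(repeat_end_freq) - math.log(freq)) / (
        math.log(repeat_end_freq) - math.log(min_freq)
    )
    return 1.0 + (max_repeat_expect - 1.0) * min(max(t, 0.0), 1.0)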
@ -411,6 +415,7 @@ class PinyinInputDataset(IterableDataset):
if not char_info:
continue
logger.info(f"获取字符信息: {char_info}")
# 削峰填谷调整
adjust_factor = self.adjust_frequency(char_info["freq"])
if adjust_factor <= 0:
@ -441,7 +446,7 @@ class PinyinInputDataset(IterableDataset):
"char": char,
"freq": char_info["freq"],
"pg": torch.tensor(
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
),
}
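# Note: without the list brackets, torch.tensor(...) above produces a
# zero-dimensional tensor; torch.cat cannot concatenate 0-D tensors, so a
# collate function that keeps this field must torch.stack it instead.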
@ -481,7 +486,6 @@ class PinyinInputDataset(IterableDataset):
random.seed(seed % (2**32))
np.random.seed(seed % (2**32))
batch_samples = []
for item in self.dataset:
text = item.get(self.text_field, "")
if not text:
@ -527,18 +531,15 @@ class PinyinInputDataset(IterableDataset):
# Process once the query batch is full
if len(char_pinyin_batch) >= self.batch_query_size:
batch_samples += self._process_batch(
batch_samples = self._process_batch(
char_pinyin_batch, char_positions, text
)
yield from self._shuffle_and_yield(batch_samples)
char_pinyin_batch = []
char_positions = []
if len(batch_samples) >= self.shuffle_buffer_size:
# logger.info(f"批量处理完成,开始打乱数据并生成样本, len(batch_samples): {len(batch_samples)}")
yield from self._shuffle_and_yield(batch_samples)
batch_samples = []
# Process any remaining characters
if char_pinyin_batch:
batch_samples += self._process_batch(
batch_samples = self._process_batch(
char_pinyin_batch, char_positions, text
)
yield from self._shuffle_and_yield(batch_samples)
@ -581,7 +582,6 @@ def custom_collate_with_txt(batch):
"char": [item["char"] for item in batch],
"txt": [item["txt"] for item in batch],
"py": [item["py"] for item in batch],
"pg": torch.cat([item["pg"] for item in batch]),
}
return result
@ -602,7 +602,71 @@ def custom_collate(batch):
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
},
"char_id": torch.cat([item["char_id"] for item in batch]),
"pg": torch.cat([item["pg"] for item in batch]),
"py": [item["py"] for item in batch],
# "py_group_id": [item["py"] for item in batch],
}
return result
# Usage example
if __name__ == "__main__":
from query import QueryEngine
from tqdm import tqdm
# Initialize the query engine
query_engine = QueryEngine()
query_engine.load()
# Build the dataset
dataset = PinyinInputDataset(
data_dir="/home/songsenand/Data/corpus/CCI-Data/",
query_engine=query_engine,
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
max_len=88,
batch_query_size=300,
shuffle=True,
shuffle_buffer_size=4000,
)
logger.info("数据集初始化")
dataloader = DataLoader(
dataset,
batch_size=1024,
num_workers=15,
worker_init_fn=worker_init_fn,
pin_memory=torch.cuda.is_available(),
collate_fn=custom_collate_with_txt,
prefetch_factor=8,
persistent_workers=True,
shuffle=False, # shuffling is handled inside the dataset
)
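# prefetch_factor counts batches per worker, so num_workers=15 with
# prefetch_factor=8 keeps up to 120 batches queued ahead of the training loop.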
"""import cProfile
def profile_func(dataloader):
for i, sample in tqdm(enumerate(dataloader), total=3000):
if i >= 3000:
break
return
cProfile.run('profile_func(dataloader)')
"""
# Smoke-test the dataset
try:
logger.info("Testing the dataset")
for i, sample in tqdm(enumerate(dataloader), total=3000):
if i >= 3000:
break
"""
print(f"Sample {i+1}:")
print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")
print(f" Pinyin: {sample['py']}")
print(f" Context length: {len(sample['txt'])}")
print(f" Hint shape: {sample['hint']['input_ids'].shape}")
print()
"""
except StopIteration:
print("数据集为空")

View File

@ -1,23 +1,21 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.amp as amp
from modelscope import AutoModel
import pickle
from importlib.resources import files
import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
from modelscope import AutoModel
from tqdm import tqdm
from .monitor import TrainingMonitor
EVAL_DATALOADER = (
pickle.load(file.open('rb'))
EVAL_DATALOADER = [
pickle.load(file.open("rb"))
for file in (files(__package__) / "eval_dataset").glob("*.pkl")
)
]
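# A list (rather than a generator) lets EVAL_DATALOADER be iterated on every
# evaluation pass; a generator would be exhausted after the first model_eval call.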
# ---------------------------- Residual block ----------------------------
@ -232,8 +230,8 @@ class MoEModel(nn.Module):
with torch.no_grad():
for batch in eval_dataloader:
# Move data to the model's device
input_ids = batch['hint']["input_ids"].to(self.device)
attention_mask = batch['hint']["attention_mask"].to(self.device)
input_ids = batch["hint"]["input_ids"].to(self.device)
attention_mask = batch["hint"]["attention_mask"].to(self.device)
pg = batch["pg"].to(self.device)
labels = batch["char_id"].to(self.device)
@ -355,93 +353,108 @@ class MoEModel(nn.Module):
optimizer=None,
scheduler=None,
num_epochs=1,
eval_frequency=1000,
eval_frequency=500,
grad_accum_steps=1,
clip_grad_norm=1.0,
mixed_precision=False,
lr_schedule=None, # new: optional custom learning-rate schedule function
):
"""
训练模型支持混合精度梯度累积学习率调度实时监控
参数
train_dataloader: DataLoader训练数据
eval_dataloader: DataLoader验证数据可选
monitor: TrainingMonitor 实例用于实时绘图
criterion: 损失函数
optimizer: 优化器默认 AdamW(lr=6e-6)
scheduler: 学习率调度器
num_epochs: 训练轮数
eval_frequency: 评估间隔步数
grad_accum_steps: 梯度累积步数
clip_grad_norm: 梯度裁剪范数
mixed_precision: 是否启用混合精度
... 原有参数 ...
lr_schedule : callable, optional
自定义学习率调度函数接收参数 (processed_batches, optimizer)
可在内部直接修改 optimizer.param_groups 中的学习率
若为 None则启用内置的固定阈值调度前1000批 1e-4之后 6e-6
"""
# Make sure the model is on the right device
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(self.device)
# Switch to training mode (via the parent-class method)
# Switch to training mode
super().train()
# Default optimizer
if optimizer is None:
optimizer = optim.AdamW(self.parameters(), lr=1e-4)
optimizer = optim.AdamW(self.parameters(), lr=1e-4) # initial learning rate 1e-4
created_optimizer = True
else:
created_optimizer = False # user supplied the optimizer; do not override its learning rate
# Mixed-precision gradient scaler
scaler = amp.GradScaler(enabled=mixed_precision)
global_step = 0
processed_batches = 0 # new: counts batches actually processed
batch_loss_sum = 0.0
optimizer.zero_grad()
for epoch in range(num_epochs):
for batch_idx, batch in enumerate(
tqdm(train_dataloader, total=1e6)
):
# Move data
input_ids = batch['hint']["input_ids"].to(self.device)
attention_mask = batch['hint']["attention_mask"].to(self.device)
for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)):
# ---------- Update the batch counter ----------
processed_batches += 1
# ---------- Learning-rate scheduling (only with the default optimizer and no custom schedule) ----------
if created_optimizer and lr_schedule is None:
if processed_batches <= 1000:
new_lr = 1e-4
else:
new_lr = 6e-6
# Apply the same learning rate to every parameter group
for param_group in optimizer.param_groups:
param_group["lr"] = new_lr
elif lr_schedule is not None:
# Call the user-supplied schedule function
lr_schedule(processed_batches, optimizer)
# ---------- Move data ----------
input_ids = batch["hint"]["input_ids"].to(self.device)
attention_mask = batch["hint"]["attention_mask"].to(self.device)
pg = batch["pg"].to(self.device)
labels = batch["char_id"].to(self.device)
# Mixed-precision forward pass
with amp.autocast(device_type=self.device.type,enabled=mixed_precision):
with amp.autocast(
device_type=self.device.type, enabled=mixed_precision
):
logits = self(input_ids, attention_mask, pg)
loss = criterion(logits, labels)
loss = loss / grad_accum_steps # normalize for gradient accumulation
loss = loss / grad_accum_steps
# Backward pass (with loss scaling)
# Backward pass
scaler.scale(loss).backward()
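# With grad_accum_steps = k, each of the k scaled backward passes contributes
# loss / k, so the accumulated gradient matches one batch k times larger.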
# Gradient accumulation: step the optimizer every grad_accum_steps batches
# Gradient accumulation
if (batch_idx + 1) % grad_accum_steps == 0:
scaler.unscale_(optimizer) # unscale before gradient clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
global_step += 1
original_loss = loss.item() * grad_accum_steps
batch_loss_sum += original_loss
# Periodic evaluation
# Periodic evaluation (unchanged from before)
if (
eval_dataloader is not None
and global_step % eval_frequency == 0
):
avg_loss = batch_loss_sum / eval_frequency # mean training loss over the steps since the last reset
acc, _ = self.model_eval(eval_dataloader, criterion)
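# Restore training mode after evaluation.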
super().train()
if monitor is not None:
monitor.add_step(
global_step,
{"loss": loss.item() * grad_accum_steps, "acc": acc},
{"loss": avg_loss, "acc": acc},
)
logger.info({"loss": loss.item() * grad_accum_steps, "acc": acc})
elif monitor is not None:
# Log only the training loss
monitor.add_step(
global_step, {"loss": loss.item() * grad_accum_steps}
logger.info(
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc}"
)
logger.info({"loss": loss.item() * grad_accum_steps})
batch_loss_sum = 0.0
# ============================ Usage example ============================
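The diff ends at the usage-example banner. As an illustration of the new lr_schedule hook, a warmup-then-cosine-decay function matching the documented (processed_batches, optimizer) signature might look like this (illustrative values, not code from this commit):

import math

def warmup_cosine_lr(processed_batches: int, optimizer) -> None:
    # Hypothetical schedule: linear warmup to peak_lr, then cosine decay to min_lr.
    warmup_batches, total_batches = 1000, 100_000
    peak_lr, min_lr = 1e-4, 6e-6
    if processed_batches <= warmup_batches:
        lr = peak_lr * processed_batches / warmup_batches
    else:
        t = min((processed_batches - warmup_batches) / (total_batches - warmup_batches), 1.0)
        lr = min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * t))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

# model.train(train_dataloader, eval_dataloader, lr_schedule=warmup_cosine_lr)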