feat: 优化数据加载与训练逻辑,增加自定义学习率调度支持
This commit is contained in:
parent
982d0521d5
commit
54ac5af876
|
|
@ -1,6 +1,9 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Dict, List, Tuple, Optional
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -36,7 +39,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
batch_query_size: int = 1000,
|
||||
# 打乱参数
|
||||
shuffle: bool = True,
|
||||
shuffle_buffer_size: int = 10000,
|
||||
shuffle_buffer_size: int = 100,
|
||||
# 削峰填谷参数
|
||||
max_freq: int = 434748359, # "的"的频率
|
||||
min_freq: int = 109, # "蓚"的频率
|
||||
|
|
@ -44,6 +47,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||
py_group_json_file: Optional[Dict[str, int]] = None,
|
||||
):
|
||||
"""
|
||||
初始化数据集
|
||||
|
|
@ -411,6 +415,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
if not char_info:
|
||||
continue
|
||||
|
||||
logger.info(f"获取字符信息: {char_info}")
|
||||
# 削峰填谷调整
|
||||
adjust_factor = self.adjust_frequency(char_info["freq"])
|
||||
if adjust_factor <= 0:
|
||||
|
|
@ -441,7 +446,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
"char": char,
|
||||
"freq": char_info["freq"],
|
||||
"pg": torch.tensor(
|
||||
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
|
||||
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -480,8 +485,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
seed = base_seed + worker_id
|
||||
random.seed(seed % (2**32))
|
||||
np.random.seed(seed % (2**32))
|
||||
|
||||
batch_samples = []
|
||||
|
||||
for item in self.dataset:
|
||||
text = item.get(self.text_field, "")
|
||||
if not text:
|
||||
|
|
@ -527,21 +531,18 @@ class PinyinInputDataset(IterableDataset):
|
|||
|
||||
# 达到批量大小时处理
|
||||
if len(char_pinyin_batch) >= self.batch_query_size:
|
||||
batch_samples += self._process_batch(
|
||||
batch_samples = self._process_batch(
|
||||
char_pinyin_batch, char_positions, text
|
||||
)
|
||||
yield from self._shuffle_and_yield(batch_samples)
|
||||
char_pinyin_batch = []
|
||||
char_positions = []
|
||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||
# logger.info(f"批量处理完成,开始打乱数据并生成样本, len(batch_samples): {len(batch_samples)}")
|
||||
yield from self._shuffle_and_yield(batch_samples)
|
||||
batch_samples = []
|
||||
# 处理剩余的字符
|
||||
if char_pinyin_batch:
|
||||
batch_samples += self._process_batch(
|
||||
batch_samples = self._process_batch(
|
||||
char_pinyin_batch, char_positions, text
|
||||
)
|
||||
yield from self._shuffle_and_yield(batch_samples)
|
||||
yield from self._shuffle_and_yield(batch_samples)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
|
|
@ -581,7 +582,6 @@ def custom_collate_with_txt(batch):
|
|||
"char": [item["char"] for item in batch],
|
||||
"txt": [item["txt"] for item in batch],
|
||||
"py": [item["py"] for item in batch],
|
||||
"pg": torch.cat([item["pg"] for item in batch]),
|
||||
}
|
||||
|
||||
return result
|
||||
|
|
@ -602,7 +602,71 @@ def custom_collate(batch):
|
|||
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
|
||||
},
|
||||
"char_id": torch.cat([item["char_id"] for item in batch]),
|
||||
"pg": torch.cat([item["pg"] for item in batch]),
|
||||
"py": [item["py"] for item in batch],
|
||||
# "py_group_id": [item["py"] for item in batch],
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
from query import QueryEngine
|
||||
from tqdm import tqdm
|
||||
|
||||
# 初始化查询引擎
|
||||
query_engine = QueryEngine()
|
||||
query_engine.load()
|
||||
|
||||
# 创建数据集
|
||||
dataset = PinyinInputDataset(
|
||||
data_dir="/home/songsenand/Data/corpus/CCI-Data/",
|
||||
query_engine=query_engine,
|
||||
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
|
||||
max_len=88,
|
||||
batch_query_size=300,
|
||||
shuffle=True,
|
||||
shuffle_buffer_size=4000,
|
||||
)
|
||||
|
||||
logger.info("数据集初始化")
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1024,
|
||||
num_workers=15,
|
||||
worker_init_fn=worker_init_fn,
|
||||
pin_memory=True if torch.cuda.is_available() else False,
|
||||
collate_fn=custom_collate_with_txt,
|
||||
prefetch_factor=8,
|
||||
persistent_workers=True,
|
||||
shuffle=False, # 数据集内部已实现打乱
|
||||
)
|
||||
|
||||
"""import cProfile
|
||||
|
||||
def profile_func(dataloader):
|
||||
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
||||
if i >= 3000:
|
||||
break
|
||||
return
|
||||
|
||||
|
||||
cProfile.run('profile_func(dataloader)')
|
||||
|
||||
"""
|
||||
# 测试数据集
|
||||
try:
|
||||
logger.info("测试数据集")
|
||||
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
||||
if i >= 3000:
|
||||
break
|
||||
"""
|
||||
print(f"Sample {i+1}:")
|
||||
print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")
|
||||
print(f" Pinyin: {sample['py']}")
|
||||
print(f" Context length: {len(sample['txt'])}")
|
||||
print(f" Hint shape: {sample['hint']['input_ids'].shape}")
|
||||
print()
|
||||
"""
|
||||
except StopIteration:
|
||||
print("数据集为空")
|
||||
|
|
|
|||
|
|
@ -1,23 +1,21 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torch.amp as amp
|
||||
from modelscope import AutoModel
|
||||
import pickle
|
||||
from importlib.resources import files
|
||||
|
||||
import torch
|
||||
import torch.amp as amp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from loguru import logger
|
||||
|
||||
from modelscope import AutoModel
|
||||
from tqdm import tqdm
|
||||
|
||||
from .monitor import TrainingMonitor
|
||||
|
||||
|
||||
EVAL_DATALOADER = (
|
||||
pickle.load(file.open('rb'))
|
||||
EVAL_DATALOADER = [
|
||||
pickle.load(file.open("rb"))
|
||||
for file in (files(__package__) / "eval_dataset").glob("*.pkl")
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------- 残差块 ----------------------------
|
||||
|
|
@ -232,8 +230,8 @@ class MoEModel(nn.Module):
|
|||
with torch.no_grad():
|
||||
for batch in eval_dataloader:
|
||||
# 移动数据到模型设备
|
||||
input_ids = batch['hint']["input_ids"].to(self.device)
|
||||
attention_mask = batch['hint']["attention_mask"].to(self.device)
|
||||
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||||
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||||
pg = batch["pg"].to(self.device)
|
||||
labels = batch["char_id"].to(self.device)
|
||||
|
||||
|
|
@ -355,93 +353,108 @@ class MoEModel(nn.Module):
|
|||
optimizer=None,
|
||||
scheduler=None,
|
||||
num_epochs=1,
|
||||
eval_frequency=1000,
|
||||
eval_frequency=500,
|
||||
grad_accum_steps=1,
|
||||
clip_grad_norm=1.0,
|
||||
mixed_precision=False,
|
||||
lr_schedule=None, # 新增:可选的自定义学习率调度函数
|
||||
):
|
||||
"""
|
||||
训练模型,支持混合精度、梯度累积、学习率调度、实时监控。
|
||||
|
||||
参数:
|
||||
train_dataloader: DataLoader,训练数据
|
||||
eval_dataloader: DataLoader,验证数据(可选)
|
||||
monitor: TrainingMonitor 实例,用于实时绘图
|
||||
criterion: 损失函数
|
||||
optimizer: 优化器,默认 AdamW(lr=6e-6)
|
||||
scheduler: 学习率调度器
|
||||
num_epochs: 训练轮数
|
||||
eval_frequency: 评估间隔(步数)
|
||||
grad_accum_steps: 梯度累积步数
|
||||
clip_grad_norm: 梯度裁剪范数
|
||||
mixed_precision: 是否启用混合精度
|
||||
... 原有参数 ...
|
||||
lr_schedule : callable, optional
|
||||
自定义学习率调度函数,接收参数 (processed_batches, optimizer),
|
||||
可在内部直接修改 optimizer.param_groups 中的学习率。
|
||||
若为 None,则启用内置的固定阈值调度(前1000批 1e-4,之后 6e-6)。
|
||||
"""
|
||||
# 确保模型在正确的设备上
|
||||
if self.device is None:
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.to(self.device)
|
||||
|
||||
# 切换到训练模式(调用父类方法)
|
||||
# 切换到训练模式
|
||||
super().train()
|
||||
|
||||
# 默认优化器
|
||||
if optimizer is None:
|
||||
optimizer = optim.AdamW(self.parameters(), lr=1e-4)
|
||||
optimizer = optim.AdamW(self.parameters(), lr=1e-4) # 初始学习率 1e-4
|
||||
created_optimizer = True
|
||||
else:
|
||||
created_optimizer = False # 用户传入优化器,不自动覆盖学习率
|
||||
|
||||
# 混合精度缩放器
|
||||
scaler = amp.GradScaler(enabled=mixed_precision)
|
||||
|
||||
global_step = 0
|
||||
processed_batches = 0 # 新增:实际处理的 batch 数量计数器
|
||||
batch_loss_sum = 0.0
|
||||
optimizer.zero_grad()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for batch_idx, batch in enumerate(
|
||||
tqdm(train_dataloader, total=1e6)
|
||||
):
|
||||
# 移动数据
|
||||
input_ids = batch['hint']["input_ids"].to(self.device)
|
||||
attention_mask = batch['hint']["attention_mask"].to(self.device)
|
||||
for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)):
|
||||
# ---------- 更新 batch 计数器 ----------
|
||||
processed_batches += 1
|
||||
|
||||
# ---------- 学习率调度(仅当使用默认优化器且未传入自定义调度函数时)----------
|
||||
if created_optimizer and lr_schedule is None:
|
||||
if processed_batches <= 1000:
|
||||
new_lr = 1e-4
|
||||
else:
|
||||
new_lr = 6e-6
|
||||
# 为所有参数组统一设置学习率
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = new_lr
|
||||
elif lr_schedule is not None:
|
||||
# 调用用户自定义的调度函数
|
||||
lr_schedule(processed_batches, optimizer)
|
||||
|
||||
# ---------- 移动数据 ----------
|
||||
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||||
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||||
pg = batch["pg"].to(self.device)
|
||||
labels = batch["char_id"].to(self.device)
|
||||
|
||||
# 混合精度前向
|
||||
with amp.autocast(device_type=self.device.type,enabled=mixed_precision):
|
||||
with amp.autocast(
|
||||
device_type=self.device.type, enabled=mixed_precision
|
||||
):
|
||||
logits = self(input_ids, attention_mask, pg)
|
||||
loss = criterion(logits, labels)
|
||||
loss = loss / grad_accum_steps # 梯度累积归一化
|
||||
loss = loss / grad_accum_steps
|
||||
|
||||
# 反向传播(缩放)
|
||||
# 反向传播
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# 梯度累积:每 grad_accum_steps 步更新一次
|
||||
# 梯度累积
|
||||
if (batch_idx + 1) % grad_accum_steps == 0:
|
||||
scaler.unscale_(optimizer) # 用于梯度裁剪
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
original_loss = loss.item() * grad_accum_steps
|
||||
batch_loss_sum += original_loss
|
||||
|
||||
# 周期性评估
|
||||
# 周期性评估(与原代码相同)
|
||||
if (
|
||||
eval_dataloader is not None
|
||||
and global_step % eval_frequency == 0
|
||||
):
|
||||
avg_loss = batch_loss_sum / global_step
|
||||
acc, _ = self.model_eval(eval_dataloader, criterion)
|
||||
super().train()
|
||||
if monitor is not None:
|
||||
monitor.add_step(
|
||||
global_step,
|
||||
{"loss": loss.item() * grad_accum_steps, "acc": acc},
|
||||
{"loss": avg_loss, "acc": acc},
|
||||
)
|
||||
logger.info({"loss": loss.item() * grad_accum_steps, "acc": acc})
|
||||
|
||||
elif monitor is not None:
|
||||
# 仅记录训练损失
|
||||
monitor.add_step(
|
||||
global_step, {"loss": loss.item() * grad_accum_steps}
|
||||
)
|
||||
logger.info({"loss": loss.item() * grad_accum_steps})
|
||||
logger.info(
|
||||
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc}"
|
||||
)
|
||||
batch_loss_sum = 0.0
|
||||
|
||||
|
||||
# ============================ 使用示例 ============================
|
||||
|
|
|
|||
Loading…
Reference in New Issue