feat: 优化数据加载与训练逻辑,增加自定义学习率调度支持
This commit is contained in:
parent
982d0521d5
commit
54ac5af876
|
|
@ -1,6 +1,9 @@
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import Any, Dict, List, Tuple, Optional
|
from importlib.resources import files
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -36,7 +39,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
batch_query_size: int = 1000,
|
batch_query_size: int = 1000,
|
||||||
# 打乱参数
|
# 打乱参数
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
shuffle_buffer_size: int = 10000,
|
shuffle_buffer_size: int = 100,
|
||||||
# 削峰填谷参数
|
# 削峰填谷参数
|
||||||
max_freq: int = 434748359, # "的"的频率
|
max_freq: int = 434748359, # "的"的频率
|
||||||
min_freq: int = 109, # "蓚"的频率
|
min_freq: int = 109, # "蓚"的频率
|
||||||
|
|
@ -44,6 +47,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||||
|
py_group_json_file: Optional[Dict[str, int]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化数据集
|
初始化数据集
|
||||||
|
|
@ -411,6 +415,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if not char_info:
|
if not char_info:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
logger.info(f"获取字符信息: {char_info}")
|
||||||
# 削峰填谷调整
|
# 削峰填谷调整
|
||||||
adjust_factor = self.adjust_frequency(char_info["freq"])
|
adjust_factor = self.adjust_frequency(char_info["freq"])
|
||||||
if adjust_factor <= 0:
|
if adjust_factor <= 0:
|
||||||
|
|
@ -441,7 +446,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
"char": char,
|
"char": char,
|
||||||
"freq": char_info["freq"],
|
"freq": char_info["freq"],
|
||||||
"pg": torch.tensor(
|
"pg": torch.tensor(
|
||||||
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
|
self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -481,7 +486,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
random.seed(seed % (2**32))
|
random.seed(seed % (2**32))
|
||||||
np.random.seed(seed % (2**32))
|
np.random.seed(seed % (2**32))
|
||||||
|
|
||||||
batch_samples = []
|
|
||||||
for item in self.dataset:
|
for item in self.dataset:
|
||||||
text = item.get(self.text_field, "")
|
text = item.get(self.text_field, "")
|
||||||
if not text:
|
if not text:
|
||||||
|
|
@ -527,21 +531,18 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
# 达到批量大小时处理
|
# 达到批量大小时处理
|
||||||
if len(char_pinyin_batch) >= self.batch_query_size:
|
if len(char_pinyin_batch) >= self.batch_query_size:
|
||||||
batch_samples += self._process_batch(
|
batch_samples = self._process_batch(
|
||||||
char_pinyin_batch, char_positions, text
|
char_pinyin_batch, char_positions, text
|
||||||
)
|
)
|
||||||
|
yield from self._shuffle_and_yield(batch_samples)
|
||||||
char_pinyin_batch = []
|
char_pinyin_batch = []
|
||||||
char_positions = []
|
char_positions = []
|
||||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
|
||||||
# logger.info(f"批量处理完成,开始打乱数据并生成样本, len(batch_samples): {len(batch_samples)}")
|
|
||||||
yield from self._shuffle_and_yield(batch_samples)
|
|
||||||
batch_samples = []
|
|
||||||
# 处理剩余的字符
|
# 处理剩余的字符
|
||||||
if char_pinyin_batch:
|
if char_pinyin_batch:
|
||||||
batch_samples += self._process_batch(
|
batch_samples = self._process_batch(
|
||||||
char_pinyin_batch, char_positions, text
|
char_pinyin_batch, char_positions, text
|
||||||
)
|
)
|
||||||
yield from self._shuffle_and_yield(batch_samples)
|
yield from self._shuffle_and_yield(batch_samples)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -581,7 +582,6 @@ def custom_collate_with_txt(batch):
|
||||||
"char": [item["char"] for item in batch],
|
"char": [item["char"] for item in batch],
|
||||||
"txt": [item["txt"] for item in batch],
|
"txt": [item["txt"] for item in batch],
|
||||||
"py": [item["py"] for item in batch],
|
"py": [item["py"] for item in batch],
|
||||||
"pg": torch.cat([item["pg"] for item in batch]),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
@ -602,7 +602,71 @@ def custom_collate(batch):
|
||||||
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
|
"attention_mask": torch.cat([h["attention_mask"] for h in hints]),
|
||||||
},
|
},
|
||||||
"char_id": torch.cat([item["char_id"] for item in batch]),
|
"char_id": torch.cat([item["char_id"] for item in batch]),
|
||||||
"pg": torch.cat([item["pg"] for item in batch]),
|
"py": [item["py"] for item in batch],
|
||||||
|
# "py_group_id": [item["py"] for item in batch],
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from query import QueryEngine
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# 初始化查询引擎
|
||||||
|
query_engine = QueryEngine()
|
||||||
|
query_engine.load()
|
||||||
|
|
||||||
|
# 创建数据集
|
||||||
|
dataset = PinyinInputDataset(
|
||||||
|
data_dir="/home/songsenand/Data/corpus/CCI-Data/",
|
||||||
|
query_engine=query_engine,
|
||||||
|
tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
|
||||||
|
max_len=88,
|
||||||
|
batch_query_size=300,
|
||||||
|
shuffle=True,
|
||||||
|
shuffle_buffer_size=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("数据集初始化")
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=1024,
|
||||||
|
num_workers=15,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
pin_memory=True if torch.cuda.is_available() else False,
|
||||||
|
collate_fn=custom_collate_with_txt,
|
||||||
|
prefetch_factor=8,
|
||||||
|
persistent_workers=True,
|
||||||
|
shuffle=False, # 数据集内部已实现打乱
|
||||||
|
)
|
||||||
|
|
||||||
|
"""import cProfile
|
||||||
|
|
||||||
|
def profile_func(dataloader):
|
||||||
|
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
||||||
|
if i >= 3000:
|
||||||
|
break
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
cProfile.run('profile_func(dataloader)')
|
||||||
|
|
||||||
|
"""
|
||||||
|
# 测试数据集
|
||||||
|
try:
|
||||||
|
logger.info("测试数据集")
|
||||||
|
for i, sample in tqdm(enumerate(dataloader), total=3000):
|
||||||
|
if i >= 3000:
|
||||||
|
break
|
||||||
|
"""
|
||||||
|
print(f"Sample {i+1}:")
|
||||||
|
print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")
|
||||||
|
print(f" Pinyin: {sample['py']}")
|
||||||
|
print(f" Context length: {len(sample['txt'])}")
|
||||||
|
print(f" Hint shape: {sample['hint']['input_ids'].shape}")
|
||||||
|
print()
|
||||||
|
"""
|
||||||
|
except StopIteration:
|
||||||
|
print("数据集为空")
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,21 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.optim as optim
|
|
||||||
import torch.amp as amp
|
|
||||||
from modelscope import AutoModel
|
|
||||||
import pickle
|
import pickle
|
||||||
from importlib.resources import files
|
from importlib.resources import files
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.amp as amp
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.optim as optim
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from modelscope import AutoModel
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .monitor import TrainingMonitor
|
from .monitor import TrainingMonitor
|
||||||
|
|
||||||
|
EVAL_DATALOADER = [
|
||||||
EVAL_DATALOADER = (
|
pickle.load(file.open("rb"))
|
||||||
pickle.load(file.open('rb'))
|
|
||||||
for file in (files(__package__) / "eval_dataset").glob("*.pkl")
|
for file in (files(__package__) / "eval_dataset").glob("*.pkl")
|
||||||
)
|
]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------- 残差块 ----------------------------
|
# ---------------------------- 残差块 ----------------------------
|
||||||
|
|
@ -232,8 +230,8 @@ class MoEModel(nn.Module):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in eval_dataloader:
|
for batch in eval_dataloader:
|
||||||
# 移动数据到模型设备
|
# 移动数据到模型设备
|
||||||
input_ids = batch['hint']["input_ids"].to(self.device)
|
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||||||
attention_mask = batch['hint']["attention_mask"].to(self.device)
|
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||||||
pg = batch["pg"].to(self.device)
|
pg = batch["pg"].to(self.device)
|
||||||
labels = batch["char_id"].to(self.device)
|
labels = batch["char_id"].to(self.device)
|
||||||
|
|
||||||
|
|
@ -355,93 +353,108 @@ class MoEModel(nn.Module):
|
||||||
optimizer=None,
|
optimizer=None,
|
||||||
scheduler=None,
|
scheduler=None,
|
||||||
num_epochs=1,
|
num_epochs=1,
|
||||||
eval_frequency=1000,
|
eval_frequency=500,
|
||||||
grad_accum_steps=1,
|
grad_accum_steps=1,
|
||||||
clip_grad_norm=1.0,
|
clip_grad_norm=1.0,
|
||||||
mixed_precision=False,
|
mixed_precision=False,
|
||||||
|
lr_schedule=None, # 新增:可选的自定义学习率调度函数
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
训练模型,支持混合精度、梯度累积、学习率调度、实时监控。
|
训练模型,支持混合精度、梯度累积、学习率调度、实时监控。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
train_dataloader: DataLoader,训练数据
|
... 原有参数 ...
|
||||||
eval_dataloader: DataLoader,验证数据(可选)
|
lr_schedule : callable, optional
|
||||||
monitor: TrainingMonitor 实例,用于实时绘图
|
自定义学习率调度函数,接收参数 (processed_batches, optimizer),
|
||||||
criterion: 损失函数
|
可在内部直接修改 optimizer.param_groups 中的学习率。
|
||||||
optimizer: 优化器,默认 AdamW(lr=6e-6)
|
若为 None,则启用内置的固定阈值调度(前1000批 1e-4,之后 6e-6)。
|
||||||
scheduler: 学习率调度器
|
|
||||||
num_epochs: 训练轮数
|
|
||||||
eval_frequency: 评估间隔(步数)
|
|
||||||
grad_accum_steps: 梯度累积步数
|
|
||||||
clip_grad_norm: 梯度裁剪范数
|
|
||||||
mixed_precision: 是否启用混合精度
|
|
||||||
"""
|
"""
|
||||||
# 确保模型在正确的设备上
|
# 确保模型在正确的设备上
|
||||||
if self.device is None:
|
if self.device is None:
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
# 切换到训练模式(调用父类方法)
|
# 切换到训练模式
|
||||||
super().train()
|
super().train()
|
||||||
|
|
||||||
# 默认优化器
|
# 默认优化器
|
||||||
if optimizer is None:
|
if optimizer is None:
|
||||||
optimizer = optim.AdamW(self.parameters(), lr=1e-4)
|
optimizer = optim.AdamW(self.parameters(), lr=1e-4) # 初始学习率 1e-4
|
||||||
|
created_optimizer = True
|
||||||
|
else:
|
||||||
|
created_optimizer = False # 用户传入优化器,不自动覆盖学习率
|
||||||
|
|
||||||
# 混合精度缩放器
|
# 混合精度缩放器
|
||||||
scaler = amp.GradScaler(enabled=mixed_precision)
|
scaler = amp.GradScaler(enabled=mixed_precision)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
processed_batches = 0 # 新增:实际处理的 batch 数量计数器
|
||||||
|
batch_loss_sum = 0.0
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for batch_idx, batch in enumerate(
|
for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)):
|
||||||
tqdm(train_dataloader, total=1e6)
|
# ---------- 更新 batch 计数器 ----------
|
||||||
):
|
processed_batches += 1
|
||||||
# 移动数据
|
|
||||||
input_ids = batch['hint']["input_ids"].to(self.device)
|
# ---------- 学习率调度(仅当使用默认优化器且未传入自定义调度函数时)----------
|
||||||
attention_mask = batch['hint']["attention_mask"].to(self.device)
|
if created_optimizer and lr_schedule is None:
|
||||||
|
if processed_batches <= 1000:
|
||||||
|
new_lr = 1e-4
|
||||||
|
else:
|
||||||
|
new_lr = 6e-6
|
||||||
|
# 为所有参数组统一设置学习率
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group["lr"] = new_lr
|
||||||
|
elif lr_schedule is not None:
|
||||||
|
# 调用用户自定义的调度函数
|
||||||
|
lr_schedule(processed_batches, optimizer)
|
||||||
|
|
||||||
|
# ---------- 移动数据 ----------
|
||||||
|
input_ids = batch["hint"]["input_ids"].to(self.device)
|
||||||
|
attention_mask = batch["hint"]["attention_mask"].to(self.device)
|
||||||
pg = batch["pg"].to(self.device)
|
pg = batch["pg"].to(self.device)
|
||||||
labels = batch["char_id"].to(self.device)
|
labels = batch["char_id"].to(self.device)
|
||||||
|
|
||||||
# 混合精度前向
|
# 混合精度前向
|
||||||
with amp.autocast(device_type=self.device.type,enabled=mixed_precision):
|
with amp.autocast(
|
||||||
|
device_type=self.device.type, enabled=mixed_precision
|
||||||
|
):
|
||||||
logits = self(input_ids, attention_mask, pg)
|
logits = self(input_ids, attention_mask, pg)
|
||||||
loss = criterion(logits, labels)
|
loss = criterion(logits, labels)
|
||||||
loss = loss / grad_accum_steps # 梯度累积归一化
|
loss = loss / grad_accum_steps
|
||||||
|
|
||||||
# 反向传播(缩放)
|
# 反向传播
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
# 梯度累积:每 grad_accum_steps 步更新一次
|
# 梯度累积
|
||||||
if (batch_idx + 1) % grad_accum_steps == 0:
|
if (batch_idx + 1) % grad_accum_steps == 0:
|
||||||
scaler.unscale_(optimizer) # 用于梯度裁剪
|
scaler.unscale_(optimizer)
|
||||||
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
|
torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
original_loss = loss.item() * grad_accum_steps
|
||||||
|
batch_loss_sum += original_loss
|
||||||
|
|
||||||
# 周期性评估
|
# 周期性评估(与原代码相同)
|
||||||
if (
|
if (
|
||||||
eval_dataloader is not None
|
eval_dataloader is not None
|
||||||
and global_step % eval_frequency == 0
|
and global_step % eval_frequency == 0
|
||||||
):
|
):
|
||||||
|
avg_loss = batch_loss_sum / global_step
|
||||||
acc, _ = self.model_eval(eval_dataloader, criterion)
|
acc, _ = self.model_eval(eval_dataloader, criterion)
|
||||||
super().train()
|
super().train()
|
||||||
if monitor is not None:
|
if monitor is not None:
|
||||||
monitor.add_step(
|
monitor.add_step(
|
||||||
global_step,
|
global_step,
|
||||||
{"loss": loss.item() * grad_accum_steps, "acc": acc},
|
{"loss": avg_loss, "acc": acc},
|
||||||
)
|
)
|
||||||
logger.info({"loss": loss.item() * grad_accum_steps, "acc": acc})
|
logger.info(
|
||||||
|
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc}"
|
||||||
elif monitor is not None:
|
)
|
||||||
# 仅记录训练损失
|
batch_loss_sum = 0.0
|
||||||
monitor.add_step(
|
|
||||||
global_step, {"loss": loss.item() * grad_accum_steps}
|
|
||||||
)
|
|
||||||
logger.info({"loss": loss.item() * grad_accum_steps})
|
|
||||||
|
|
||||||
|
|
||||||
# ============================ 使用示例 ============================
|
# ============================ 使用示例 ============================
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue