From 54ac5af876a94da0f1b463c78f761d4033d0697a Mon Sep 17 00:00:00 2001
From: songsenand
Date: Fri, 13 Feb 2026 10:48:17 +0800
Subject: [PATCH] feat: optimize data loading and training logic; add support
 for custom learning-rate schedules
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/suinput/dataset.py |  92 ++++++++++++++++++++++++++++------
 src/trainer/model.py   | 109 +++++++++++++++++++++++------------------
 2 files changed, 139 insertions(+), 62 deletions(-)

diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py
index 0b86f4e..a71f8fb 100644
--- a/src/suinput/dataset.py
+++ b/src/suinput/dataset.py
@@ -1,6 +1,9 @@
+import json
 import os
 import random
-from typing import Any, Dict, List, Tuple, Optional
+from importlib.resources import files
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -36,7 +39,7 @@ class PinyinInputDataset(IterableDataset):
         batch_query_size: int = 1000,
         # shuffling parameters
         shuffle: bool = True,
-        shuffle_buffer_size: int = 10000,
+        shuffle_buffer_size: int = 100,
         # peak-clipping / valley-filling (frequency rebalancing) parameters
         max_freq: int = 434748359,  # frequency of "的"
         min_freq: int = 109,  # frequency of "蓚"
@@ -44,6 +47,7 @@ class PinyinInputDataset(IterableDataset):
         repeat_end_freq: int = 10000,  # threshold below which repetition starts
         max_drop_prob: float = 0.8,  # maximum drop probability
         max_repeat_expect: float = 50.0,  # maximum expected repetition count
+        py_group_json_file: Optional[Dict[str, int]] = None,
     ):
         """
         Initialize the dataset
@@ -411,6 +415,7 @@ class PinyinInputDataset(IterableDataset):
             if not char_info:
                 continue
 
+            logger.debug(f"fetched char info: {char_info}")
             # peak-clipping / valley-filling adjustment
             adjust_factor = self.adjust_frequency(char_info["freq"])
             if adjust_factor <= 0:
@@ -441,7 +446,7 @@ class PinyinInputDataset(IterableDataset):
                 "char": char,
                 "freq": char_info["freq"],
                 "pg": torch.tensor(
-                    [self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
+                    self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
                 ),
             }
 
@@ -480,8 +485,7 @@ class PinyinInputDataset(IterableDataset):
             seed = base_seed + worker_id
             random.seed(seed % (2**32))
             np.random.seed(seed % (2**32))
-
-        batch_samples = []
+
         for item in self.dataset:
             text = item.get(self.text_field, "")
             if not text:
@@ -527,21 +531,18 @@ class PinyinInputDataset(IterableDataset):
                 # process once the query batch is full
                 if len(char_pinyin_batch) >= self.batch_query_size:
-                    batch_samples += self._process_batch(
+                    batch_samples = self._process_batch(
                         char_pinyin_batch, char_positions, text
                     )
+                    yield from self._shuffle_and_yield(batch_samples)
                     char_pinyin_batch = []
                     char_positions = []
-                    if len(batch_samples) >= self.shuffle_buffer_size:
-                        # logger.info(f"batch processed, shuffling and yielding samples, len(batch_samples): {len(batch_samples)}")
-                        yield from self._shuffle_and_yield(batch_samples)
-                        batch_samples = []
 
             # process any remaining characters
             if char_pinyin_batch:
-                batch_samples = self._process_batch(
+                batch_samples = self._process_batch(
                     char_pinyin_batch, char_positions, text
                 )
-            yield from self._shuffle_and_yield(batch_samples)
+                yield from self._shuffle_and_yield(batch_samples)
 
     def __len__(self):
         """
@@ -581,7 +582,6 @@ def custom_collate_with_txt(batch):
         "char": [item["char"] for item in batch],
         "txt": [item["txt"] for item in batch],
         "py": [item["py"] for item in batch],
-        "pg": torch.cat([item["pg"] for item in batch]),
+        "pg": torch.stack([item["pg"] for item in batch]),  # pg is now 0-dim per sample: stack, not cat
     }
 
     return result
@@ -602,7 +602,71 @@ def custom_collate(batch):
             "attention_mask": torch.cat([h["attention_mask"] for h in hints]),
         },
         "char_id": torch.cat([item["char_id"] for item in batch]),
-        "pg": torch.cat([item["pg"] for item in batch]),
+        "pg": torch.stack([item["pg"] for item in batch]),  # keep "pg": model.py still consumes it
+        "py": [item["py"] for item in batch],
     }
 
     return result
+
+
+# Usage example
+if __name__ == "__main__":
+    from query import QueryEngine
+    from tqdm import tqdm
+
+    # initialize the query engine
+    query_engine = QueryEngine()
+    query_engine.load()
+
+    # create the dataset
+    dataset = PinyinInputDataset(
+        data_dir="/home/songsenand/Data/corpus/CCI-Data/",
+        query_engine=query_engine,
+        tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
+        max_len=88,
+        batch_query_size=300,
+        shuffle=True,
+        shuffle_buffer_size=4000,
+    )
+
+    logger.info("dataset initialized")
+    dataloader = DataLoader(
+        dataset,
+        batch_size=1024,
+        num_workers=15,
+        worker_init_fn=worker_init_fn,
+        pin_memory=torch.cuda.is_available(),
+        collate_fn=custom_collate_with_txt,
+        prefetch_factor=8,
+        persistent_workers=True,
+        shuffle=False,  # shuffling is handled inside the dataset
+    )
+
+    # smoke-test the dataset; an empty dataset simply yields no batches
+    logger.info("testing the dataset")
+    for i, sample in tqdm(enumerate(dataloader), total=3000):
+        if i >= 3000:
+            break
+        # each batch exposes: sample["char"], sample["char_id"], sample["py"],
+        # sample["txt"], sample["pg"], and sample["hint"]["input_ids"] /
+        # sample["hint"]["attention_mask"]
diff --git a/src/trainer/model.py b/src/trainer/model.py
index 0852548..e78c4ec 100644
--- a/src/trainer/model.py
+++ b/src/trainer/model.py
@@ -1,23 +1,21 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import torch.amp as amp
-from modelscope import AutoModel
 import pickle
 from importlib.resources import files
 
+import torch
+import torch.amp as amp
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
 from loguru import logger
-
+from modelscope import AutoModel
 from tqdm import tqdm
 
 from .monitor import TrainingMonitor
 
-EVAL_DATALOADER = (
-    pickle.load(file.open('rb'))
+# a list, not a generator: the eval loaders are iterated on every evaluation
+EVAL_DATALOADER = [
+    pickle.load(file.open("rb"))
     for file in (files(__package__) / "eval_dataset").glob("*.pkl")
-)
+]
 
 
 # ---------------------------- residual block ----------------------------
@@ -232,8 +230,8 @@ class MoEModel(nn.Module):
         with torch.no_grad():
             for batch in eval_dataloader:
                 # move data to the model's device
-                input_ids = batch['hint']["input_ids"].to(self.device)
-                attention_mask = batch['hint']["attention_mask"].to(self.device)
+                input_ids = batch["hint"]["input_ids"].to(self.device)
+                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                 pg = batch["pg"].to(self.device)
                 labels = batch["char_id"].to(self.device)
 
@@ -355,93 +353,108 @@ class MoEModel(nn.Module):
         optimizer=None,
         scheduler=None,
         num_epochs=1,
-        eval_frequency=1000,
+        eval_frequency=500,
         grad_accum_steps=1,
         clip_grad_norm=1.0,
         mixed_precision=False,
+        lr_schedule=None,  # new: optional custom learning-rate schedule function
     ):
         """
         Train the model, with support for mixed precision, gradient accumulation, learning-rate scheduling, and live monitoring.
 
         Parameters:
             train_dataloader: DataLoader with the training data
             eval_dataloader: DataLoader with the validation data (optional)
             monitor: TrainingMonitor instance for live plotting
             criterion: loss function
             optimizer: optimizer; defaults to AdamW driven by the built-in schedule below
             scheduler: learning-rate scheduler
             num_epochs: number of training epochs
             eval_frequency: evaluation interval (in optimizer steps)
             grad_accum_steps: number of gradient-accumulation steps
             clip_grad_norm: gradient-clipping norm
             mixed_precision: whether to enable mixed precision
+            lr_schedule: callable, optional
+                Custom learning-rate schedule. It is called as
+                lr_schedule(processed_batches, optimizer) and may modify the
+                learning rates in optimizer.param_groups directly. If None,
+                the built-in fixed-threshold schedule is used (1e-4 for the
+                first 1000 batches, 6e-6 afterwards).
         """
         # ensure the model is on the right device
         if self.device is None:
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             self.to(self.device)
 
-        # switch to training mode (calls the parent method)
+        # switch to training mode
         super().train()
 
         # default optimizer
         if optimizer is None:
-            optimizer = optim.AdamW(self.parameters(), lr=1e-4)
+            optimizer = optim.AdamW(self.parameters(), lr=1e-4)  # initial learning rate 1e-4
+            created_optimizer = True
+        else:
+            created_optimizer = False  # user-supplied optimizer: don't override its learning rate
 
         # mixed-precision gradient scaler
         scaler = amp.GradScaler(enabled=mixed_precision)
 
         global_step = 0
+        processed_batches = 0  # counts the batches actually consumed
+        batch_loss_sum = 0.0
         optimizer.zero_grad()
 
         for epoch in range(num_epochs):
-            for batch_idx, batch in enumerate(
-                tqdm(train_dataloader, total=1e6)
-            ):
-                # move data
-                input_ids = batch['hint']["input_ids"].to(self.device)
-                attention_mask = batch['hint']["attention_mask"].to(self.device)
+            for batch_idx, batch in enumerate(tqdm(train_dataloader, total=1e6)):
+                # ---------- update the batch counter ----------
+                processed_batches += 1
+
+                # ---------- learning-rate scheduling (built-in rule applies only
+                # when we created the optimizer and no custom schedule is given) ----------
+                if created_optimizer and lr_schedule is None:
+                    new_lr = 1e-4 if processed_batches <= 1000 else 6e-6
+                    # apply the same learning rate to every parameter group
+                    for param_group in optimizer.param_groups:
+                        param_group["lr"] = new_lr
+                elif lr_schedule is not None:
+                    # delegate to the user-supplied schedule
+                    lr_schedule(processed_batches, optimizer)
+
+                # ---------- move data ----------
+                input_ids = batch["hint"]["input_ids"].to(self.device)
+                attention_mask = batch["hint"]["attention_mask"].to(self.device)
                 pg = batch["pg"].to(self.device)
                 labels = batch["char_id"].to(self.device)
 
                 # mixed-precision forward pass
-                with amp.autocast(device_type=self.device.type,enabled=mixed_precision):
+                with amp.autocast(
+                    device_type=self.device.type, enabled=mixed_precision
+                ):
                     logits = self(input_ids, attention_mask, pg)
                     loss = criterion(logits, labels)
-                    loss = loss / grad_accum_steps  # normalize for gradient accumulation
+                    loss = loss / grad_accum_steps
 
-                # backward pass (scaled)
+                # backward pass
                 scaler.scale(loss).backward()
 
-                # gradient accumulation: step the optimizer every grad_accum_steps batches
+                # gradient accumulation
                 if (batch_idx + 1) % grad_accum_steps == 0:
-                    scaler.unscale_(optimizer)  # needed before gradient clipping
+                    scaler.unscale_(optimizer)
                     torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
                     scaler.step(optimizer)
                     scaler.update()
                     optimizer.zero_grad()
                     global_step += 1
+                    original_loss = loss.item() * grad_accum_steps
+                    batch_loss_sum += original_loss
 
-                    # periodic evaluation
+                    # periodic evaluation
                     if (
                         eval_dataloader is not None
                         and global_step % eval_frequency == 0
                     ):
+                        # average loss over the steps since the last evaluation
+                        avg_loss = batch_loss_sum / eval_frequency
                         acc, _ = self.model_eval(eval_dataloader, criterion)
                         super().train()
                         if monitor is not None:
                             monitor.add_step(
                                 global_step,
-                                {"loss": loss.item() * grad_accum_steps, "acc": acc},
+                                {"loss": avg_loss, "acc": acc},
                             )
-                        logger.info({"loss": loss.item() * grad_accum_steps, "acc": acc})
-
-                    elif monitor is not None:
-                        # only record the training loss
-                        monitor.add_step(
-                            global_step, {"loss": loss.item() * grad_accum_steps}
-                        )
-                        logger.info({"loss": loss.item() * grad_accum_steps})
+                        logger.info(
+                            f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc}"
+                        )
+                        batch_loss_sum = 0.0
 
 
 # ============================ Usage example ============================
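
A note on the "pg" collate change above: once the brackets are removed in
_process_batch, each sample's "pg" is a 0-dim tensor, and torch.cat rejects
0-dim inputs, which is why the collate functions switch to torch.stack. A
minimal self-contained check:

    import torch

    pgs = [torch.tensor(3), torch.tensor(8)]  # 0-dim tensors, as _process_batch now emits
    # torch.cat(pgs)       # RuntimeError: zero-dimensional tensor(s) cannot be concatenated
    print(torch.stack(pgs))  # tensor([3, 8]) -- one 1-D tensor per batch, shape (batch,)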
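On the iterator change: samples are now shuffled and yielded per query batch
instead of accumulating up to shuffle_buffer_size, so the mixing window shrinks
to roughly one _process_batch output (consistent with the default dropping from
10000 to 100). _shuffle_and_yield itself is not part of this patch; the sketch
below is a hypothetical shape for orientation only — the real method may differ:

    import random

    def _shuffle_and_yield(self, samples):
        # hypothetical: shuffle one in-memory list of samples, then yield them;
        # only the call sites of this method appear in the patch
        if self.shuffle:
            random.shuffle(samples)
        yield from samples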
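The peak-clipping / valley-filling parameters (max_freq, min_freq,
repeat_end_freq, max_drop_prob, max_repeat_expect) describe frequency-based
down/up-sampling: very frequent characters are probabilistically dropped, very
rare ones repeated, and adjust_frequency returns 0 to drop a character (the
call site checks adjust_factor <= 0). adjust_frequency itself is outside this
patch; the version below is purely illustrative, and drop_start_freq is an
assumed name for the drop threshold, which the patch does not show:

    import math
    import random

    def adjust_frequency(freq, *, max_freq=434748359, min_freq=109,
                         drop_start_freq=100000, repeat_end_freq=10000,
                         max_drop_prob=0.8, max_repeat_expect=50.0):
        # hypothetical log-interpolated rebalancing, not the repo's implementation
        if freq > drop_start_freq:
            # drop probability grows log-linearly up to max_drop_prob
            t = math.log(freq / drop_start_freq) / math.log(max_freq / drop_start_freq)
            return 0 if random.random() < min(t, 1.0) * max_drop_prob else 1
        if freq < repeat_end_freq:
            # expected repetition grows log-linearly up to max_repeat_expect
            t = math.log(repeat_end_freq / freq) / math.log(repeat_end_freq / min_freq)
            expect = 1.0 + min(t, 1.0) * (max_repeat_expect - 1.0)
            return int(expect) + (1 if random.random() < expect % 1 else 0)
        return 1  # mid-frequency characters pass through unchanged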
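The EVAL_DATALOADER change from a generator expression to a list comprehension
matters because a generator is exhausted after one pass, so every evaluation
after the first would silently iterate over nothing:

    gen = (x for x in range(3))
    print(list(gen))  # [0, 1, 2]
    print(list(gen))  # [] -- exhausted; hence EVAL_DATALOADER is now a list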
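Finally, a minimal sketch of a custom lr_schedule for the new hook. The warmup
shape and the commented call are illustrative only — warmup_then_decay, eval_dl,
and the criterion choice are assumptions, not part of this patch:

    def warmup_then_decay(processed_batches, optimizer):
        # linear warmup to 1e-4 over the first 1000 batches, then a flat 6e-6,
        # mirroring the shape of the built-in threshold rule
        lr = 1e-4 * processed_batches / 1000 if processed_batches <= 1000 else 6e-6
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

    # model.train(train_dataloader, eval_dataloader=eval_dl,
    #             criterion=nn.CrossEntropyLoss(), lr_schedule=warmup_then_decay)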