diff --git a/example.py b/example.py index 3481fb8..ef93d69 100644 --- a/example.py +++ b/example.py @@ -27,10 +27,10 @@ if __name__ == "__main__": dataloader = DataLoader( dataset, batch_size=1024, - num_workers=15, + num_workers=1, worker_init_fn=worker_init_fn, pin_memory=True if torch.cuda.is_available() else False, - collate_fn=custom_collate, + collate_fn=custom_collate_with_txt, prefetch_factor=8, persistent_workers=True, shuffle=False, # 数据集内部已实现打乱 @@ -51,11 +51,12 @@ if __name__ == "__main__": # 测试数据集 try: logger.info("测试数据集") - total = 3000 - for i, sample in tqdm(enumerate(dataloader), total=total): + total = 20 + for i, sample in tqdm(enumerate(dataloader), total=20): if i >= total: break - #print(f"Sample {i+1}: {sample['txt'][0:10]}") + + print(f"Sample {i+1}: {sample['txt'][0:10]}") """ print(f"Sample {i+1}:") print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}") diff --git a/pyproject.toml b/pyproject.toml index 01dbaf9..b7980b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,8 @@ requires-python = ">=3.13" dependencies = [ "bokeh>=3.8.2", "datasets>=4.5.0", + "ipykernel>=7.2.0", + "ipython>=9.10.0", "loguru>=0.7.3", "modelscope>=1.34.0", "msgpack>=1.1.2", diff --git a/src/tmp_utils/gen_eval_dataset.py b/src/tmp_utils/gen_eval_dataset.py new file mode 100644 index 0000000..24a1852 --- /dev/null +++ b/src/tmp_utils/gen_eval_dataset.py @@ -0,0 +1,48 @@ +from tqdm import tqdm +from loguru import logger +import torch +from torch.utils.data import DataLoader +import pickle +from pathlib import Path + +from suinput.dataset import PinyinInputDataset, worker_init_fn, custom_collate_with_txt +from suinput.query import QueryEngine + + +# 使用示例 +if __name__ == "__main__": + # 初始化查询引擎 + query_engine = QueryEngine() + query_engine.load() + + # 创建数据集 + dataset = PinyinInputDataset( + data_dir="/home/songsenand/DataSet/data", + query_engine=query_engine, + tokenizer_name="iic/nlp_structbert_backbone_tiny_std", + max_len=88, + batch_query_size=300, + shuffle=True, + shuffle_buffer_size=4000, + ) + logger.info("数据集初始化") + dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=1, + worker_init_fn=worker_init_fn, + pin_memory=True if torch.cuda.is_available() else False, + collate_fn=custom_collate_with_txt, + prefetch_factor=8, + persistent_workers=True, + shuffle=False, # 数据集内部已实现打乱 + ) + try: + total = 5 + for i, sample in tqdm(enumerate(dataloader), total=5): + if i >= total: + break + print(sample) + # pickle.dump(sample, open(f"{str(Path(__file__).parent.parent / 'trainer' / 'eval_dataset')}/sample_{i}.pkl", "wb")) + except StopIteration: + print("数据集为空") diff --git a/src/trainer/eval_dataset/sample_0.pkl b/src/trainer/eval_dataset/sample_0.pkl new file mode 100644 index 0000000..db6d947 Binary files /dev/null and b/src/trainer/eval_dataset/sample_0.pkl differ diff --git a/src/trainer/eval_dataset/sample_1.pkl b/src/trainer/eval_dataset/sample_1.pkl new file mode 100644 index 0000000..c93a719 Binary files /dev/null and b/src/trainer/eval_dataset/sample_1.pkl differ diff --git a/src/trainer/eval_dataset/sample_2.pkl b/src/trainer/eval_dataset/sample_2.pkl new file mode 100644 index 0000000..1dd4bb6 Binary files /dev/null and b/src/trainer/eval_dataset/sample_2.pkl differ diff --git a/src/trainer/eval_dataset/sample_3.pkl b/src/trainer/eval_dataset/sample_3.pkl new file mode 100644 index 0000000..eddc692 Binary files /dev/null and b/src/trainer/eval_dataset/sample_3.pkl differ diff --git a/src/trainer/eval_dataset/sample_4.pkl b/src/trainer/eval_dataset/sample_4.pkl new file mode 100644 index 0000000..6a1d02e Binary files /dev/null and b/src/trainer/eval_dataset/sample_4.pkl differ diff --git a/src/trainer/model.py b/src/trainer/model.py index 6650a19..723954b 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -3,17 +3,19 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torch.amp as amp -from transformers import AutoModel +from modelscope import AutoModel +import pickle +from importlib.resources import files from tqdm import tqdm from .monitor import TrainingMonitor -# ---------------------------- 工具函数 ---------------------------- -def round_to_power_of_two(x): - """将数字向上取整为2的幂(此处固定返回1024)""" - return 1024 # 根据您的说明固定 d_model = 1024 +EVAL_DATALOADER = ( + pickle.load(file.open('rb')) + for file in (files(__package__) / "eval_dataset").glob("*.pkl") +) # ---------------------------- 残差块 ---------------------------- @@ -98,6 +100,7 @@ class MoEModel(nn.Module): self.embedding = bert.embeddings self.bert_config = bert.config self.hidden_size = self.bert_config.hidden_size # BERT 隐层维度 + self.device = None # 将在 to() 调用时设置 # 2. 4 层标准 Transformer Encoder(从 config 读取参数) encoder_layer = nn.TransformerEncoderLayer( @@ -136,6 +139,11 @@ class MoEModel(nn.Module): # 可选:为领域专家和共享专家设置不同权重衰减(通过优化器实现,此处不处理) + def to(self, device): + """重写 to 方法,记录设备""" + self.device = device + return super().to(device) + def forward(self, input_ids, attention_mask, pg): """ input_ids : [batch, seq_len] @@ -200,35 +208,234 @@ class MoEModel(nn.Module): return probs return logits - def predict(self, inputs, labels): - pass + def model_eval(self, eval_dataloader, criterion=None): + """ + 在验证集上评估模型,返回准确率和平均损失。 - def train( - self, - dataloader, - monitor: TrainingMonitor, - criterion = nn.CrossEntropyLoss(), - device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), - optimizer=None, - sample_frequency=1000, - ): - self.train() - if optimizer is None: - optimizer = optim.AdamW(self.parameters(), lr=6e-6) + 参数: + eval_dataloader: DataLoader,提供 'input_ids', 'attention_mask', 'pg', 'char_id' + criterion: 损失函数,默认为 CrossEntropyLoss() + 返回: + accuracy: float, 准确率 + avg_loss: float, 平均损失 + """ + if criterion is None: + criterion = nn.CrossEntropyLoss() - for i, sample in tqdm(enumerate(dataloader), total=1e7): - optimizer.zero_grad() - with amp.autocast(): - char_id = sample.pop("char_id").to(device) - input_ids = sample.pop("input_ids").to(device) - attention_mask = sample.pop("attention_mask").to(device) - pg = sample.pop("pg").to(device) + self.eval() + total_loss = 0.0 + correct = 0 + total = 0 + + with torch.no_grad(): + for batch in eval_dataloader: + # 移动数据到模型设备 + input_ids = batch["input_ids"].to(self.device) + attention_mask = batch["attention_mask"].to(self.device) + pg = batch["pg"].to(self.device) + labels = batch["char_id"].to(self.device) + + # 前向传播 logits = self(input_ids, attention_mask, pg) - loss = criterion(logits, char_id) - loss.backward() - optimizer.step() + loss = criterion(logits, labels) + total_loss += loss.item() * labels.size(0) + # 计算准确率 + preds = logits.argmax(dim=-1) + correct += (preds == labels).sum().item() + total += labels.size(0) + avg_loss = total_loss / total if total > 0 else 0.0 + accuracy = correct / total if total > 0 else 0.0 + return accuracy, avg_loss + + def predict(self, sample, debug=False): + """ + 基于 sample 字典进行预测,支持批量/单样本,可选调试打印错误样本信息。 + + 参数: + sample : dict + 必须包含字段: + - 'input_ids' : [batch, seq_len] 或 [seq_len] (单样本) + - 'attention_mask': 同上 + - 'pg' : [batch] 或标量 + - 'char_id' : [batch] 或标量,真实标签(当 debug=True 时必须提供) + 调试时(debug=True)必须包含字段: + - 'txt' : 字符串列表(batch)或单个字符串 + - 'char' : 字符串列表(batch)或单个字符串 + - 'py' : 字符串列表(batch)或单个字符串 + debug : bool + 是否打印预测错误的样本信息。若为 True 但 sample 缺少 char_id/txt/char/py,抛出 ValueError。 + + 返回: + preds : torch.Tensor + [batch] 预测类别标签(若输入为单样本且无 batch 维度,则返回标量) + """ + self.eval() + + # ------------------ 1. 提取并规范化输入 ------------------ + # 判断是否为单样本(input_ids 无 batch 维度) + input_ids = sample["input_ids"] + attention_mask = sample["attention_mask"] + pg = sample["pg"] + has_batch_dim = input_ids.dim() > 1 + + if not has_batch_dim: + input_ids = input_ids.unsqueeze(0) + attention_mask = attention_mask.unsqueeze(0) + if pg.dim() == 0: + pg = pg.unsqueeze(0).expand(input_ids.size(0)) + + # ------------------ 2. 移动设备 ------------------ + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + pg = pg.to(self.device) + + # ------------------ 3. 推理 ------------------ + with torch.no_grad(): + logits = self(input_ids, attention_mask, pg) + preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch] + + # ------------------ 4. 调试打印(错误样本) ------------------ + if debug: + # 检查必需字段 + required_keys = ["char_id", "txt", "char", "py"] + missing = [k for k in required_keys if k not in sample] + if missing: + raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}") + + # 提取真实标签 + true_labels = sample["char_id"] + if true_labels.dim() == 0: + true_labels = true_labels.unsqueeze(0) + + # 移动真实标签到相同设备 + true_labels = true_labels.to(self.device) + + # 找出预测错误的索引 + incorrect_mask = preds != true_labels + incorrect_indices = torch.where(incorrect_mask)[0] + + if len(incorrect_indices) > 0: + print("\n=== 预测错误样本 ===") + # 获取调试字段(可能是列表或单个字符串) + txts = sample["txt"] + chars = sample["char"] + pys = sample["py"] + + # 统一转换为列表(如果输入是单个字符串) + if isinstance(txts, str): + txts = [txts] + chars = [chars] + pys = [pys] + + for idx in incorrect_indices.cpu().numpy(): + print(f"样本索引 {idx}:") + print(f" Text : {txts[idx]}") + print(f" Char : {chars[idx]}") + print(f" Pinyin: {pys[idx]}") + print( + f" 预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}" + ) + print("===================\n") + + # ------------------ 5. 返回结果(保持与输入维度一致) ------------------ + if not has_batch_dim: + return preds.squeeze(0) # 返回标量 + return preds + + def fit( + self, + train_dataloader, + eval_dataloader=None, + monitor: TrainingMonitor = None, + criterion=nn.CrossEntropyLoss(), + optimizer=None, + scheduler=None, + num_epochs=1, + eval_frequency=1000, + grad_accum_steps=1, + clip_grad_norm=1.0, + mixed_precision=False, + ): + """ + 训练模型,支持混合精度、梯度累积、学习率调度、实时监控。 + + 参数: + train_dataloader: DataLoader,训练数据 + eval_dataloader: DataLoader,验证数据(可选) + monitor: TrainingMonitor 实例,用于实时绘图 + criterion: 损失函数 + optimizer: 优化器,默认 AdamW(lr=6e-6) + scheduler: 学习率调度器 + num_epochs: 训练轮数 + eval_frequency: 评估间隔(步数) + grad_accum_steps: 梯度累积步数 + clip_grad_norm: 梯度裁剪范数 + mixed_precision: 是否启用混合精度 + """ + # 确保模型在正确的设备上 + if self.device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.to(self.device) + + # 切换到训练模式(调用父类方法) + super().train() + + # 默认优化器 + if optimizer is None: + optimizer = optim.AdamW(self.parameters(), lr=1e-4) + + # 混合精度缩放器 + scaler = amp.GradScaler(enabled=mixed_precision) + + global_step = 0 + optimizer.zero_grad() + + for epoch in range(num_epochs): + for batch_idx, batch in enumerate( + tqdm(train_dataloader, total=1e6) + ): + # 移动数据 + input_ids = batch['hint']["input_ids"].to(self.device) + attention_mask = batch['hint']["attention_mask"].to(self.device) + pg = batch["pg"].to(self.device) + labels = batch["char_id"].to(self.device) + + # 混合精度前向 + with amp.autocast(device_type=self.device.type,enabled=mixed_precision): + logits = self(input_ids, attention_mask, pg) + loss = criterion(logits, labels) + loss = loss / grad_accum_steps # 梯度累积归一化 + + # 反向传播(缩放) + scaler.scale(loss).backward() + + # 梯度累积:每 grad_accum_steps 步更新一次 + if (batch_idx + 1) % grad_accum_steps == 0: + scaler.unscale_(optimizer) # 用于梯度裁剪 + torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + global_step += 1 + + # 周期性评估 + if ( + eval_dataloader is not None + and global_step % eval_frequency == 0 + ): + acc, _ = self.model_eval(eval_dataloader, criterion) + if monitor is not None: + monitor.add_step( + global_step, + {"loss": loss.item() * grad_accum_steps, "acc": acc}, + ) + elif monitor is not None: + # 仅记录训练损失 + monitor.add_step( + global_step, {"loss": loss.item() * grad_accum_steps} + ) # ============================ 使用示例 ============================