feat(train): 添加训练脚本并重构模型输入处理逻辑

This commit is contained in:
songsenand 2026-04-05 00:08:29 +08:00
parent 1af85a36bc
commit 69349a88a6
8 changed files with 3304 additions and 1865 deletions

View File

@ -25,6 +25,9 @@ dependencies = [
"typer>=0.21.1",
]
[project.scripts]
train-model = "model.trainer:app"
[tool.uv]
# 设置当前项目的默认索引源
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from modelscope import AutoModel
# ---------------------------- 注意力池化模块----------------------------
@ -25,9 +25,8 @@ class AttentionPooling(nn.Module):
pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1)
return pooled
# ---------------------------- 残差块 ----------------------------
# ---------------------------- 残差块 ----------------------------
class ResidualBlock(nn.Module):
def __init__(self, dim, dropout_prob=0.3):
super().__init__()
@ -92,9 +91,12 @@ class ContextEncoder(nn.Module):
self.dim = dim
# Embeddings
self.text_emb = nn.Embedding(vocab_size, dim)
self.text_emb = AutoModel.from_pretrained(
"iic/nlp_structbert_backbone_lite_std"
).embeddings
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
self.pos_emb = nn.Embedding(max_len, dim)
self.pinyin_pooling = AttentionPooling(dim)
# Transformer Encoder (4 layers, 4 heads) [1]
encoder_layer = nn.TransformerEncoderLayer(
@ -118,9 +120,22 @@ class ContextEncoder(nn.Module):
Returns:
H: [batch, seq_len, 512] Context representation [1]
"""
# 1. Embedding Fusion: Text + Pinyin + Position
# 1. Embed text
text_emb = self.text_emb(text_ids) # [B, 128, dim]
# 2. Embed and pool pinyin to global feature
pinyin_emb = self.pinyin_emb(pinyin_ids) # [B, 24, dim]
# 方式1Attention Pooling推荐
pinyin_global = self.pinyin_pooling(
pinyin_emb, mask=None
) # [B, dim] # 1. Embedding Fusion: Text + Pinyin + Position
# Broadcast pinyin to all text positions
pinyin_global = pinyin_global.unsqueeze(1) # [B, 1, dim]
pinyin_broadcast = pinyin_global.expand_as(text_emb) # [B, 128, dim]
# 策略:拼音作为增强特征叠加到文本上,符合轻量级设计
x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids)
x = text_emb + pinyin_broadcast
seq_len = x.size(1)
pos_ids = (

View File

@ -1,25 +1,42 @@
import random
import re
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import Style, lazy_pinyin
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data import IterableDataset
from .query import QueryEngine
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
CHAR_TO_ID["`"] = 27 # 显式添加反引号
CHAR_TO_ID["'"] = 28 # 显式添加引号
CHAR_TO_ID["-"] = 29 # 显式添加短横
def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
"""
将拼音字符串转换为 ID 列表
支持 a-z `
未知字符映射为 0 (PAD/UNK)
"""
# 使用 dict.get(key, default) 处理未知字符,默认返回 0
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
class PinyinInputDataset(IterableDataset):
def __init__(
self,
data_path: str,
max_workes: int = -1,
max_workers: int = -1,
max_iter_length=1e6,
max_seq_length=128,
text_field: str = "text",
@ -35,7 +52,7 @@ class PinyinInputDataset(IterableDataset):
self.min_freq = 109
self.tokenizer = AutoTokenizer.from_pretrained(
Path(files(__package__) / "assets" / "tokenizer")
Path(str(files(__package__))) / "assets" / "tokenizer"
)
self.data_path = data_path
@ -43,7 +60,7 @@ class PinyinInputDataset(IterableDataset):
self.max_seq_length = max_seq_length
self.text_field = text_field
self.dataset = load_dataset(data_path, split="train", streaming=True)
self.max_workers = max_workes
self.max_workers = max_workers
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
self.shuffle_buffer_size = shuffle_buffer_size
self.possible_lengths = list(length_weights.keys())
@ -98,7 +115,56 @@ class PinyinInputDataset(IterableDataset):
# 生成对应文本的拼音
def generate_pinyin(self, text: str) -> List[str]:
return lazy_pinyin(text, errors=lambda x: [c for c in x])
"""
流式处理单条文本转换为拼音列表
特性
1. 严格一一对应len(result) == len(text)
2. 高多音字准确率利用 pypinyin 内部的词语分词能力
3. 高性能预分配内存无多余对象创建
Args:
text: 输入字符串
Returns:
List[str]: 拼音或非汉字字符的列表
"""
if not text:
return []
text_len = len(text)
# 2. 预分配结果列表,初始化占位符。
# 使用 None 或空字符串均可,这里用空字符串方便后续判断
result: List[str] = [""] * text_len
# 3. 遍历所有连续汉字片段
for match in _HANZI_RE.finditer(text):
start_idx = match.start()
hanzi_segment = match.group()
# 4. 核心转换:利用 pypinyin 的分词能力处理该片段
# style=Style.NORMAL 获取不带声调的拼音
pinyin_list = lazy_pinyin(hanzi_segment)
# 5. 健壮性兜底:
# 正常情况下pypinyin 返回的拼音数应等于汉字数。
# 若不等(极罕见,如遇到特殊 Unicode 标点被误判为汉字),降级为单字转换
if len(pinyin_list) != len(hanzi_segment):
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
# 6. 直接通过索引填充到预分配的位置
# 这比 list slicing assignment (result[start:end] = pinyin_list) 略快且更直观
for i, py in enumerate(pinyin_list):
result[start_idx + i] = py
# 7. 填充非汉字字符
# 遍历原文,如果 result 对应位置为空,则填入原字符
# 注意:对于纯汉字文本,这一步很快;对于混合文本,这是必要的
for i, char in enumerate(text):
if not result[i]:
result[i] = char
return result
# 生成需要预测汉字对应的拼音,并进行加强
def get_mask_pinyin(
@ -107,7 +173,7 @@ class PinyinInputDataset(IterableDataset):
mask_pinyin = []
for i in range(len(text)):
if not self.query_engine.is_chinese_char(text[i]):
return i, mask_pinyin
break
else:
py = np.random.choice(
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
@ -116,7 +182,7 @@ class PinyinInputDataset(IterableDataset):
if py == "":
py = pinyin_list[i][0]
mask_pinyin.append(py)
return len(text) - 1, mask_pinyin
return len(mask_pinyin), mask_pinyin
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
@ -165,7 +231,7 @@ class PinyinInputDataset(IterableDataset):
if current_iter_index >= worker_quota:
break
labels = [] # 添加起始符
labels = []
# 如果text[i]不在字符库中,则跳过
# 当i小于48时候则将part1取text[0:i]
# 当i大于48时候则将part1取text[i-48:i]
@ -186,7 +252,19 @@ class PinyinInputDataset(IterableDataset):
pinyin_len, part2 = self.get_mask_pinyin(
text[i:py_end], pinyin_list[i:py_end]
)
part2 = "".join(part2)
split_char = np.random.choice(
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
)
part2 = split_char.join(part2)
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
pinyin_ids.extend([0] * (24 - len_py))
else:
pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
# part3为文本大概率0.70)为空
# 不为空则是i+pinyin_len所指向的字符以及所指向字符后x个字符
@ -219,7 +297,6 @@ class PinyinInputDataset(IterableDataset):
string_list.append(text[start_pos:end_pos])
# 用|连接所有字符串
part4 = "|".join(string_list)
labels = [
self.query_engine.get_char_info_by_char_pinyin(c, p).id
for c, p in zip(
@ -240,9 +317,9 @@ class PinyinInputDataset(IterableDataset):
samples = []
for i, label in enumerate(labels):
repeats = self.adjust_frequency(label)
l = labels[:i]
len_l = len(l)
l.extend([0] * (8 - len_l))
masked_labels = labels[:i]
len_l = len(masked_labels)
masked_labels.extend([0] * (8 - len_l))
samples.extend(
[
@ -252,11 +329,12 @@ class PinyinInputDataset(IterableDataset):
"attention_mask": encoded["attention_mask"],
"label": torch.tensor([label], dtype=torch.long),
"history_slot_ids": torch.tensor(
l, dtype=torch.long
masked_labels, dtype=torch.long
),
"prefix": f"{part4}^{part1}",
"suffix": part3,
"pinyin": part2,
"pinyin_ids": pinyin_ids,
}
]
* repeats

View File

@ -6,7 +6,6 @@ import torch.nn.functional as F
# 导入 components.py 中的组件
from .components import (
AttentionPooling, # 可选,暂不使用
ContextEncoder,
CrossAttentionFusion,
MoELayer,
@ -41,19 +40,17 @@ class InputMethodEngine(nn.Module):
n_heads: int = 4, # 注意力头数
num_experts: int = 20, # MoE 专家数量
max_seq_len: int = 128, # 最大上下文长度
use_pinyin: bool = False, # 是否使用拼音特征(若为 False拼音嵌入恒为零
):
super().__init__()
self.dim = dim
self.num_slots = num_slots
self.use_pinyin = use_pinyin
self.vocab_size = vocab_size
# 1. 上下文编码器 (ContextEncoder)
# 若 use_pinyin=False则传入 pinyin_vocab_size=1 并固定嵌入为零
self.context_encoder = ContextEncoder(
vocab_size=vocab_size,
pinyin_vocab_size=pinyin_vocab_size if use_pinyin else 1,
pinyin_vocab_size=pinyin_vocab_size,
dim=dim,
n_layers=n_layers,
n_heads=n_heads,
@ -79,16 +76,12 @@ class InputMethodEngine(nn.Module):
# 5. 分类头
self.classifier = nn.Linear(dim, vocab_size)
# 可选:如果不需要拼音,将拼音嵌入矩阵固定为零
if not use_pinyin and hasattr(self.context_encoder, "pinyin_emb"):
# 将拼音嵌入权重置零,确保对输出无影响
nn.init.zeros_(self.context_encoder.pinyin_emb.weight)
def forward(
self,
input_ids: torch.Tensor,
token_type_ids: torch.Tensor,
attention_mask: torch.Tensor,
pinyin_ids: torch.Tensor,
history_slot_ids: torch.Tensor,
) -> torch.Tensor:
"""
@ -100,28 +93,20 @@ class InputMethodEngine(nn.Module):
if history_slot_ids.dim() == 1:
history_slot_ids = history_slot_ids.unsqueeze(0).expand(batch_size, -1)
# 1. 构造拼音输入(如果 use_pinyin=False则使用全零占位符
if self.use_pinyin:
# 注意:这里需要真实的拼音 ids但当前输入未提供故用零占位实际应用中应从外部获取
# 为简化演示,此处使用全零张量,并假设拼音词汇表中 0 为 padding。
pinyin_ids = torch.zeros_like(input_ids)
else:
pinyin_ids = torch.zeros_like(input_ids)
# 2. 上下文编码 -> H [batch, seq_len, dim]
# 1. 上下文编码 -> H [batch, seq_len, dim]
# 注意ContextEncoder.forward 接受 text_ids, pinyin_ids, mask
H = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
# 3. 槽位记忆编码 -> S [batch, num_slots, dim]
# 2. 槽位记忆编码 -> S [batch, num_slots, dim]
S = self.slot_memory(history_slot_ids) # history_slot_ids: [batch, num_slots]
# 4. 交叉注意力融合 (使用 CrossAttentionFusion)
# 3. 交叉注意力融合 (使用 CrossAttentionFusion)
fused = self.cross_attn(S, H, context_mask=attention_mask)
# 5. MoE 处理 -> [batch, num_slots, dim]
# 4. MoE 处理 -> [batch, num_slots, dim]
moe_out = self.moe(fused)
# 6. 池化与分类:对槽位维度求平均(或使用 mask 池化)
# 5. 池化与分类:对槽位维度求平均(或使用 mask 池化)
# 这里简单平均,若需要忽略 padding 槽位,可根据 history_slot_ids 是否为 0 构造 mask
slot_mask = (history_slot_ids != 0).float() # [batch, num_slots]
slot_mask = slot_mask.unsqueeze(-1) # [batch, num_slots, 1]

913
src/model/trainer.py Normal file
View File

@ -0,0 +1,913 @@
import json
import math
import os
import random
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import typer
from loguru import logger
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from torch import autocast
from torch.amp.grad_scaler import GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .dataset import PinyinInputDataset
# 导入模型和数据
from .model import InputMethodEngine
class Trainer:
"""
输入法模型训练器
实现训练InputMethodEngine模型支持
- 预热+余弦退火学习率调度
- TensorBoard日志记录
- AdamW优化器weight_decay=0.1
- 混合精度训练
- CrossEntropyLoss损失函数支持weight和label_smoothing
- Rich终端美化输出
"""
def __init__(
self,
model: InputMethodEngine,
train_dataloader: DataLoader,
eval_dataloader: DataLoader,
total_steps: int,
output_dir: str = "./output",
num_epochs: int = 10,
learning_rate: float = 1e-4,
min_learning_rate: float = 1e-6,
weight_decay: float = 0.1,
warmup_ratio: float = 0.1,
label_smoothing: float = 0.15,
loss_weight: Optional[torch.Tensor] = None,
grad_accum_steps: int = 1,
clip_grad_norm: float = 1.0,
eval_frequency: int = 500,
save_frequency: int = 1000,
mixed_precision: bool = True,
device: Optional[torch.device] = None,
use_tensorboard: bool = True,
):
"""
初始化训练器
Args:
model: 要训练的InputMethodEngine模型
train_dataloader: 训练数据加载器
eval_dataloader: 评估数据加载器可选
output_dir: 输出目录用于保存模型和日志
num_epochs: 训练轮数
total_steps: 总训练步数如果为None则根据epochs计算
learning_rate: 最大学习率预热后
min_learning_rate: 最小学习率余弦退火后的最低值
weight_decay: AdamW优化器的权重衰减
warmup_ratio: 热身步数占总步数的比例
label_smoothing: CrossEntropyLoss的标签平滑参数
loss_weight: CrossEntropyLoss的类别权重
grad_accum_steps: 梯度累积步数
clip_grad_norm: 梯度裁剪的最大范数
eval_frequency: 评估频率步数
save_frequency: 保存检查点频率步数
mixed_precision: 是否使用混合精度训练
device: 训练设备如果为None则自动选择
use_tensorboard: 是否使用TensorBoard记录
"""
self.model = model
self.train_dataloader = train_dataloader
self.eval_dataloader = list([i for i in eval_dataloader])
self.output_dir = Path(output_dir)
self.num_epochs = num_epochs
self.learning_rate = learning_rate
self.min_learning_rate = min_learning_rate
self.weight_decay = weight_decay
self.warmup_ratio = warmup_ratio
self.label_smoothing = label_smoothing
self.loss_weight = loss_weight
self.grad_accum_steps = grad_accum_steps
self.clip_grad_norm = clip_grad_norm
self.eval_frequency = eval_frequency
self.save_frequency = save_frequency
self.mixed_precision = mixed_precision
self.use_tensorboard = use_tensorboard
# 设置设备
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device
# 移动模型到设备
self.model.to(self.device)
# 创建输出目录
self.output_dir.mkdir(parents=True, exist_ok=True)
self.checkpoint_dir = self.output_dir / "checkpoints"
self.checkpoint_dir.mkdir(exist_ok=True)
# 计算总步数
self.total_steps = total_steps
self.warmup_steps = int(self.total_steps * warmup_ratio)
# 初始化优化器
self.optimizer = optim.AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.999),
eps=1e-8,
)
# 初始化损失函数
if loss_weight is not None:
self.criterion = nn.CrossEntropyLoss(
weight=loss_weight.to(self.device), label_smoothing=label_smoothing
)
else:
self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
# 初始化混合精度训练器
device_type = "cuda" if "cuda" in str(self.device) else "cpu"
self.scaler = GradScaler(device_type, enabled=mixed_precision)
# 初始化TensorBoard
if use_tensorboard:
self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
else:
self.writer = None
# 初始化Rich控制台
self.console = Console()
# 训练状态
self.current_step = 0
self.current_epoch = 0
self.best_eval_loss = float("inf")
# 学习率调度函数
self.lr_scheduler = self._create_lr_scheduler()
logger.info(f"Trainer initialized with device: {self.device}")
logger.info(
f"Total steps: {self.total_steps}, Warmup steps: {self.warmup_steps}"
)
logger.info(f"Learning rate: {learning_rate}, Weight decay: {weight_decay}")
def _create_lr_scheduler(self) -> Callable[[int], float]:
"""创建学习率调度函数(预热 + 余弦退火)"""
def lr_scheduler(step: int) -> float:
if step < self.warmup_steps:
# 线性预热
return self.learning_rate * (step / self.warmup_steps)
else:
# 余弦退火
progress = (step - self.warmup_steps) / (
self.total_steps - self.warmup_steps
)
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
decayed_lr = (
self.min_learning_rate
+ (self.learning_rate - self.min_learning_rate) * cosine_decay
)
return decayed_lr
return lr_scheduler
def _get_current_lr(self) -> float:
"""获取当前学习率"""
return self.lr_scheduler(self.current_step)
def _update_learning_rate(self):
"""更新优化器中的学习率"""
current_lr = self._get_current_lr()
for param_group in self.optimizer.param_groups:
param_group["lr"] = current_lr
return current_lr
def train_step(
self, batch: Dict[str, torch.Tensor]
) -> Tuple[float, Dict[str, Any]]:
"""
执行单个训练步骤
Args:
batch: 包含输入数据的批次
Returns:
loss: 损失值
metrics: 训练指标字典
"""
self.model.train()
# 移动数据到设备
input_ids = batch["input_ids"].to(self.device)
token_type_ids = batch["token_type_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
history_slot_ids = batch["history_slot_ids"].to(self.device)
pinyin_ids = batch["pinyin_ids"].to(self.device).squeeze(-1)
labels = batch["labels"].to(self.device).squeeze(-1) # [batch_size]
# 混合精度训练
with autocast(device_type=self.device.type, enabled=self.mixed_precision):
# 前向传播
logits = self.model(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
pinyin_ids=pinyin_ids,
history_slot_ids=history_slot_ids,
)
# 计算损失
loss = self.criterion(logits, labels)
loss = loss / self.grad_accum_steps
# 反向传播
self.scaler.scale(loss).backward()
metrics = {
"loss": loss.item() * self.grad_accum_steps,
"lr": self._get_current_lr(),
}
# 计算准确率
with torch.no_grad():
preds = torch.argmax(logits, dim=-1)
correct = (preds == labels).sum().item()
total = labels.size(0)
metrics["accuracy"] = correct / total if total > 0 else 0.0
return loss.item() * self.grad_accum_steps, metrics
def evaluate(self) -> Dict[str, float]:
"""
在评估集上评估模型
Returns:
评估指标字典
"""
if self.eval_dataloader is None:
return {}
self.model.eval()
total_loss = 0.0
total_correct = 0
total_samples = 0
with torch.no_grad():
for batch in self.eval_dataloader:
# 移动数据到设备
input_ids = batch["input_ids"].to(self.device)
token_type_ids = batch["token_type_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
history_slot_ids = batch["history_slot_ids"].to(self.device)
labels = batch["label"].to(self.device).squeeze(1)
# 前向传播
logits = self.model(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
history_slot_ids=history_slot_ids,
)
# 计算损失
loss = self.criterion(logits, labels)
total_loss += loss.item() * labels.size(0)
# 计算准确率
preds = torch.argmax(logits, dim=-1)
correct = (preds == labels).sum().item()
total_correct += correct
total_samples += labels.size(0)
avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
accuracy = total_correct / total_samples if total_samples > 0 else 0.0
return {"eval_loss": avg_loss, "eval_accuracy": accuracy}
def save_checkpoint(self, filename: str, is_best: bool = False):
"""
保存检查点
Args:
filename: 检查点文件名
is_best: 是否是最佳模型
"""
checkpoint_path = self.checkpoint_dir / filename
checkpoint = {
"step": self.current_step,
"epoch": self.current_epoch,
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scaler_state_dict": self.scaler.state_dict(),
"best_eval_loss": self.best_eval_loss,
"config": {
"learning_rate": self.learning_rate,
"weight_decay": self.weight_decay,
"warmup_ratio": self.warmup_ratio,
"label_smoothing": self.label_smoothing,
"total_steps": self.total_steps,
},
}
torch.save(checkpoint, checkpoint_path)
logger.info(f"Checkpoint saved to {checkpoint_path}")
if is_best:
best_path = self.checkpoint_dir / "best_model.pt"
torch.save(checkpoint, best_path)
logger.info(f"Best model saved to {best_path}")
def load_checkpoint(self, checkpoint_path: Union[str, Path]):
"""
加载检查点
Args:
checkpoint_path: 检查点文件路径
"""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
self.current_step = checkpoint["step"]
self.current_epoch = checkpoint["epoch"]
self.best_eval_loss = checkpoint["best_eval_loss"]
logger.info(f"Checkpoint loaded from {checkpoint_path}")
logger.info(
f"Resuming from step {self.current_step}, epoch {self.current_epoch}"
)
def _log_to_tensorboard(self, metrics: Dict[str, float], step: int):
"""将指标记录到TensorBoard"""
if self.writer is None:
return
for key, value in metrics.items():
self.writer.add_scalar(key, value, step)
def _create_progress_bar(self) -> Progress:
"""创建Rich进度条"""
return Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
TimeRemainingColumn(),
console=self.console,
expand=True,
)
def _print_training_info(self):
"""打印训练信息"""
info_table = Table(
title="Training Configuration",
show_header=True,
header_style="bold magenta",
)
info_table.add_column("Parameter", style="cyan")
info_table.add_column("Value", style="green")
info_table.add_row("Device", str(self.device))
info_table.add_row("Total Steps", str(self.total_steps))
info_table.add_row("Warmup Steps", str(self.warmup_steps))
info_table.add_row("Learning Rate", f"{self.learning_rate:.2e}")
info_table.add_row("Min Learning Rate", f"{self.min_learning_rate:.2e}")
info_table.add_row("Weight Decay", str(self.weight_decay))
info_table.add_row("Label Smoothing", str(self.label_smoothing))
info_table.add_row("Gradient Accumulation", str(self.grad_accum_steps))
info_table.add_row("Mixed Precision", str(self.mixed_precision))
self.console.print(info_table)
def train(self, resume_from: Optional[str] = None):
"""
主训练循环
Args:
resume_from: 从哪个检查点恢复训练可选
"""
# 如果提供了检查点,则恢复训练
if resume_from is not None:
self.load_checkpoint(resume_from)
# 打印训练信息
self._print_training_info()
# 初始化训练状态
global_step = self.current_step
accumulated_loss = 0.0
accumulated_accuracy = 0.0
accumulation_counter = 0
# 创建进度条
with self._create_progress_bar() as progress:
epoch_task = progress.add_task(
f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
total=self.total_steps,
)
# 训练循环
for epoch in range(self.current_epoch, self.num_epochs):
self.current_epoch = epoch
progress.update(
epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}"
)
for batch_idx, batch in enumerate(self.train_dataloader):
# 更新学习率
current_lr = self._update_learning_rate()
# 训练步骤
loss, metrics = self.train_step(batch)
# 累积指标
accumulated_loss += loss
accumulated_accuracy += metrics.get("accuracy", 0.0)
accumulation_counter += 1
# 梯度累积每grad_accum_steps步更新一次参数
if (global_step + 1) % self.grad_accum_steps == 0:
# 梯度裁剪
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.clip_grad_norm
)
# 更新参数
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
# 更新进度条
progress.update(
epoch_task,
advance=1,
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | "
f"Step {global_step}/{self.total_steps} | "
f"Loss: {loss:.4f} | "
f"LR: {current_lr:.2e}",
)
# 定期评估和记录
if (global_step + 1) % self.eval_frequency == 0:
# 计算平均指标
avg_loss = accumulated_loss / accumulation_counter
avg_accuracy = accumulated_accuracy / accumulation_counter
# 评估模型
eval_metrics = self.evaluate()
# 准备日志指标
log_metrics = {
"train/loss": avg_loss,
"train/accuracy": avg_accuracy,
"train/learning_rate": current_lr,
}
if eval_metrics:
log_metrics.update(
{
"eval/loss": eval_metrics["eval_loss"],
"eval/accuracy": eval_metrics["eval_accuracy"],
}
)
# 更新最佳模型
if eval_metrics["eval_loss"] < self.best_eval_loss:
self.best_eval_loss = eval_metrics["eval_loss"]
self.save_checkpoint(
f"step_{global_step}.pt", is_best=True
)
# 记录到TensorBoard
self._log_to_tensorboard(log_metrics, global_step)
# 打印日志
log_text = (
f"[Epoch {epoch + 1}/{self.num_epochs}] "
f"[Step {global_step}/{self.total_steps}] "
f"Train Loss: {avg_loss:.4f} | "
f"Train Acc: {avg_accuracy:.4f} | "
f"LR: {current_lr:.2e}"
)
if eval_metrics:
log_text += (
f" | Eval Loss: {eval_metrics['eval_loss']:.4f} | "
f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
)
progress.console.log(log_text)
# 重置累积指标
accumulated_loss = 0.0
accumulated_accuracy = 0.0
accumulation_counter = 0
# 定期保存检查点
if (global_step + 1) % self.save_frequency == 0:
self.save_checkpoint(f"step_{global_step}.pt")
# 更新步数
global_step += 1
self.current_step = global_step
# 检查是否达到总步数
if global_step >= self.total_steps:
progress.update(epoch_task, completed=self.total_steps)
break
# 重置进度条
progress.reset(epoch_task)
# 每个epoch结束后保存检查点
self.save_checkpoint(f"epoch_{epoch + 1}.pt")
# 检查是否达到总步数
if global_step >= self.total_steps:
break
# 训练完成
logger.info("Training completed!")
# 保存最终模型
self.save_checkpoint("final_model.pt")
# 关闭TensorBoard写入器
if self.writer is not None:
self.writer.close()
def worker_init_fn(worker_id: int) -> None:
    """Seed numpy and random inside each DataLoader worker.

    Torch derives a distinct base seed per worker; propagating it to the
    other RNGs keeps augmentation reproducible across runs.

    Args:
        worker_id: index of the worker process (unused; kept for the
            DataLoader callback signature).
    """
    seed = torch.initial_seed() % (2**32)
    for seeder in (np.random.seed, random.seed):
        seeder(seed)
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine per-sample dicts into one batch dict.

    Tensor fields that carry a leading dim of size 1 (tokenizer outputs and
    the per-sample label) are squeezed before stacking; string fields are
    collected into plain lists.

    Args:
        batch: list of sample dicts produced by the dataset.

    Returns:
        Batch dict. Note the per-sample "label" key is renamed "labels".
    """

    def _stack_squeezed(key: str) -> torch.Tensor:
        # Drop the extra leading dimension, then stack along a new batch dim.
        return torch.stack([sample[key].squeeze(0) for sample in batch])

    def _stack(key: str) -> torch.Tensor:
        return torch.stack([sample[key] for sample in batch])

    def _gather(key: str) -> List[Any]:
        return [sample[key] for sample in batch]

    return {
        "input_ids": _stack_squeezed("input_ids"),
        "token_type_ids": _stack_squeezed("token_type_ids"),
        "attention_mask": _stack_squeezed("attention_mask"),
        "labels": _stack_squeezed("label"),
        "history_slot_ids": _stack("history_slot_ids"),
        "prefix": _gather("prefix"),
        "suffix": _gather("suffix"),
        "pinyin": _gather("pinyin"),
        "pinyin_ids": _stack("pinyin_ids"),
    }
# Typer CLI application entry point (referenced by the `train-model`
# console script); add_completion=False hides shell-completion options
# from --help.
app = typer.Typer(help="输入法模型训练命令行工具", add_completion=False)
@app.command()
def train(
    # Data arguments
    train_data_path: str = typer.Option(
        ..., "--train-data-path", "-t", help="训练数据集路径"
    ),
    eval_data_path: str = typer.Option(
        ..., "--eval-data-path", "-e", help="评估数据集路径"
    ),
    output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
    # Model arguments
    vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
    pinyin_vocab_size: int = typer.Option(
        30, "--pinyin-vocab-size", help="拼音词汇表大小"
    ),
    # NOTE(review): this option name uses an underscore ("--max_iter_length")
    # while every other option uses dashes — confirm before changing, since
    # renaming it would break existing invocations.
    max_iter_length: int = typer.Option(
        1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
    ),
    dim: int = typer.Option(512, "--dim", help="模型维度"),
    num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"),
    n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"),
    n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"),
    num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"),
    max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"),
    # NOTE(review): use_pinyin is only echoed/saved below; it is no longer
    # forwarded to InputMethodEngine — confirm it is intentionally
    # config-only now that the model always takes pinyin_ids.
    use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
    # Training arguments
    batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
    num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
    learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
    min_learning_rate: float = typer.Option(
        1e-9, "--min-learning-rate", help="最小学习率"
    ),
    weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
    warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
    label_smoothing: float = typer.Option(
        0.15, "--label-smoothing", help="标签平滑参数"
    ),
    grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
    clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
    eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
    save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"),
    # Other arguments
    mixed_precision: bool = typer.Option(
        True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
    ),
    use_tensorboard: bool = typer.Option(
        True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
    ),
    resume_from: Optional[str] = typer.Option(
        None, "--resume-from", help="从检查点恢复训练"
    ),
    seed: int = typer.Option(42, "--seed", help="随机种子"),
):
    """Train the input-method model end to end.

    Builds streaming train/eval dataloaders over the given dataset paths,
    constructs an InputMethodEngine, saves the run configuration to
    <output_dir>/training_config.json, and runs the Trainer loop.
    """
    # "file_system" sharing avoids exhausting file descriptors when many
    # DataLoader workers pass tensors between processes.
    torch.multiprocessing.set_sharing_strategy("file_system")
    # Seed the RNGs for reproducibility.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    console = Console()
    # Echo the full configuration to the console.
    console.print(
        Panel.fit("[bold cyan]输入法模型训练配置[/bold cyan]", border_style="cyan")
    )
    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")
    # Populate the configuration table.
    config_table.add_row("数据", "训练数据路径", train_data_path)
    config_table.add_row("数据", "评估数据路径", eval_data_path)
    config_table.add_row("数据", "输出目录", output_dir)
    config_table.add_row("数据", "批次大小", str(batch_size))
    config_table.add_row("模型", "词汇表大小", str(vocab_size))
    config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size))
    config_table.add_row("模型", "模型维度", str(dim))
    config_table.add_row("模型", "槽位数量", str(num_slots))
    config_table.add_row("模型", "Transformer层数", str(n_layers))
    config_table.add_row("模型", "注意力头数", str(n_heads))
    config_table.add_row("模型", "MoE专家数", str(num_experts))
    config_table.add_row("模型", "使用拼音", str(use_pinyin))
    config_table.add_row("训练", "训练轮数", str(num_epochs))
    config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
    config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(weight_decay))
    config_table.add_row("训练", "热身比例", str(warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(label_smoothing))
    config_table.add_row("训练", "梯度累积", str(grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(mixed_precision))
    console.print(config_table)
    # Create the output directory.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    # Persist the run configuration as JSON for later reference.
    config = {
        "train_data_path": train_data_path,
        "eval_data_path": eval_data_path,
        "output_dir": output_dir,
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "use_pinyin": use_pinyin,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "learning_rate": learning_rate,
        "min_learning_rate": min_learning_rate,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
        "label_smoothing": label_smoothing,
        "grad_accum_steps": grad_accum_steps,
        "clip_grad_norm": clip_grad_norm,
        "eval_frequency": eval_frequency,
        "save_frequency": save_frequency,
        "mixed_precision": mixed_precision,
        "use_tensorboard": use_tensorboard,
        "seed": seed,
        "max_iter_length": max_iter_length,
    }
    config_file = output_path / "training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    logger.info(f"Configuration saved to {config_file}")
    # Build the dataloaders.
    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
    # Training dataset (streaming).
    train_dataset = PinyinInputDataset(
        data_path=train_data_path,
        max_workers=-1,  # auto-select the worker count
        max_iter_length=max_iter_length,
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=5000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    # Training dataloader.
    # NOTE: PinyinInputDataset is an IterableDataset, so `shuffle` is not
    # applicable; multi-worker sharding happens inside dataset.__iter__.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=max(1, (os.cpu_count() or 1) - 1),
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=64,  # each worker prefetches 64 batches (large-RAM setups)
        persistent_workers=True,  # keep workers alive to avoid respawn cost
    )
    # Evaluation dataset (same settings, smaller cap).
    eval_dataset = PinyinInputDataset(
        data_path=eval_data_path,
        max_workers=-1,
        max_iter_length=1024,  # the eval set is intentionally small
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=1000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        num_workers=1,
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=64,  # each worker prefetches 64 batches
        persistent_workers=True,  # keep the worker alive
    )
    console.print("[green]✓ 数据加载器创建完成[/green]")
    console.print(f"  训练批次大小: {batch_size}")
    console.print(f"  评估批次大小: {batch_size}")
    # Build the model.
    console.print("[bold cyan]正在创建模型...[/bold cyan]")
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=n_layers,
        n_heads=n_heads,
        num_experts=num_experts,
        max_seq_len=max_seq_len,
    )
    console.print(
        f"[green]✓ 模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
    )
    # Build the trainer.
    # NOTE(review): total_steps is derived solely from
    # max_iter_length / batch_size, independent of num_epochs — presumably
    # one epoch covers the whole streamed corpus; verify against the
    # dataset's per-epoch yield.
    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        total_steps=int(max_iter_length / batch_size),
        output_dir=output_dir,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        min_learning_rate=min_learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        label_smoothing=label_smoothing,
        grad_accum_steps=grad_accum_steps,
        clip_grad_norm=clip_grad_norm,
        eval_frequency=eval_frequency,
        save_frequency=save_frequency,
        mixed_precision=mixed_precision,
        use_tensorboard=use_tensorboard,
    )
    console.print("[green]✓ 训练器创建完成[/green]")
    # Start training.
    console.print("\n[bold cyan]开始训练...[/bold cyan]")
    console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    trainer.train(resume_from=resume_from)
    console.print("[bold green]✓ 训练完成![/bold green]")
    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {output_dir}")
@app.command()
def evaluate(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    data_path: str = typer.Option(..., "--data-path", "-d", help="数据集路径"),
    batch_size: int = typer.Option(32, "--batch-size", "-b", help="批次大小"),
):
    """
    评估训练好的模型
    """
    out = Console()
    out.print(f"[bold cyan]评估模型: {checkpoint_path}[/bold cyan]")
    # TODO: evaluation pipeline is not implemented yet:
    #   1. load the checkpoint from checkpoint_path
    #   2. build an eval dataloader from data_path / batch_size
    #   3. run the model over the eval set and report metrics
    out.print("[yellow]评估功能待实现[/yellow]")
@app.command()
def export(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    output_path: str = typer.Option(
        "./exported_model.onnx", "--output", "-o", help="输出路径"
    ),
):
    """
    导出模型为ONNX格式
    """
    out = Console()
    out.print(f"[bold cyan]导出模型到: {output_path}[/bold cyan]")
    # TODO: export pipeline is not implemented yet:
    #   1. load the checkpoint from checkpoint_path
    #   2. export the model graph to ONNX at output_path
    out.print("[yellow]导出功能待实现[/yellow]")
# CLI entry point: dispatch to the typer application when run as a script.
if __name__ == "__main__":
    app()

24
test.py
View File

@ -1,17 +1,29 @@
"""Smoke test for PinyinInputDataset: iterate raw samples, then batch via DataLoader."""
import sys

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.dataset import PinyinInputDataset
from model.trainer import collate_fn, worker_init_fn

# Pick a platform-specific corpus location.
if sys.platform == "win32":
    dataset_path = "data"
else:
    dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"

# 1) Iterate a handful of raw samples and print each tensor field's shape.
#    fix: keyword was misspelled `max_workes`; the dataset takes `max_workers`
#    (same keyword the trainer passes elsewhere in this repo).
dataset = PinyinInputDataset(dataset_path, max_iter_length=20, max_workers=3)
for line in dataset:  # index was unused, so no enumerate
    for k, v in line.items():
        if isinstance(v, str):
            continue
        print(k, v.shape)

# 2) Batch the dataset through a DataLoader and inspect per-batch shapes.
dataset = PinyinInputDataset(dataset_path, max_iter_length=128 * 128)
dataloader = DataLoader(
    dataset,
    batch_size=128,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    worker_init_fn=worker_init_fn,
    collate_fn=collate_fn,
    prefetch_factor=64,  # each worker prefetches 64 batches (large-memory hosts)
    persistent_workers=True,  # keep worker processes alive to avoid respawn cost
)
batches = list(dataloader)  # fix: was list([i for i in dataloader]) — redundant wrapper
print(len(batches[0]["labels"]))
for batch in tqdm(batches, total=128):
    print(batch["pinyin_ids"].squeeze(-1).shape)

451
test_trainer.py Normal file
View File

@ -0,0 +1,451 @@
import sys
from pathlib import Path
import torch
import torch.nn as nn
from rich.console import Console
from torch.utils.data import DataLoader, Dataset
# 添加src目录到路径
sys.path.insert(0, str(Path(__file__).parent))
from src.model.model import InputMethodEngine
from src.model.trainer import Trainer
console = Console()
class MockDataset(Dataset):
    """Synthetic dataset emitting randomly generated samples for Trainer tests."""

    def __init__(self, num_samples=100, vocab_size=100, seq_len=128, num_slots=8):
        self.num_samples = num_samples
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_slots = num_slots

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Draw fresh random tensors on every access; `idx` is intentionally unused.
        vocab, seq, slots = self.vocab_size, self.seq_len, self.num_slots
        sample = {
            "input_ids": torch.randint(0, vocab, (seq,)),
            "token_type_ids": torch.randint(0, 2, (seq,)),
            "attention_mask": torch.ones(seq, dtype=torch.long),
            "history_slot_ids": torch.randint(0, vocab, (slots,)),
            "label": torch.randint(0, vocab, (1,)),
        }
        return sample
def test_dataset_creation():
    """Smoke-test MockDataset through a DataLoader and verify batch shapes."""
    console.print("[bold cyan]测试数据集创建...[/bold cyan]")
    dataset = MockDataset(num_samples=10)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
    batch = next(iter(dataloader))
    # fix: was an f-string with no placeholders (ruff F541)
    console.print("批处理形状:")
    console.print(f" input_ids: {batch['input_ids'].shape}")
    console.print(f" token_type_ids: {batch['token_type_ids'].shape}")
    console.print(f" attention_mask: {batch['attention_mask'].shape}")
    console.print(f" history_slot_ids: {batch['history_slot_ids'].shape}")
    console.print(f" label: {batch['label'].shape}")
    # Each field should be batched to (batch_size, ...) by the default collate.
    assert batch["input_ids"].shape == (2, 128), "input_ids形状不正确"
    assert batch["token_type_ids"].shape == (2, 128), "token_type_ids形状不正确"
    assert batch["attention_mask"].shape == (2, 128), "attention_mask形状不正确"
    assert batch["history_slot_ids"].shape == (2, 8), "history_slot_ids形状不正确"
    assert batch["label"].shape == (2, 1), "label形状不正确"
    console.print("[green]✓ 数据集测试通过[/green]\n")
    return dataloader
def test_model_creation():
    """Build a small InputMethodEngine and check one forward pass end to end."""
    console.print("[bold cyan]测试模型创建...[/bold cyan]")
    engine = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,  # small width keeps the test fast
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )
    # Fabricate one batch of inputs and run inference without gradients.
    batch_size = 2
    input_ids = torch.randint(0, 100, (batch_size, 128))
    token_type_ids = torch.randint(0, 2, (batch_size, 128))
    attention_mask = torch.ones(batch_size, 128, dtype=torch.long)
    history_slot_ids = torch.randint(0, 100, (batch_size, 8))
    with torch.no_grad():
        logits = engine(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            history_slot_ids=history_slot_ids,
        )
    console.print(f"模型输出形状: {logits.shape}")
    console.print(f"模型参数量: {sum(p.numel() for p in engine.parameters()):,}")
    assert logits.shape == (batch_size, 100), "模型输出形状不正确"
    console.print("[green]✓ 模型测试通过[/green]\n")
    return engine
def test_trainer_initialization():
    """Construct a Trainer around a tiny model and verify its configuration."""
    console.print("[bold cyan]测试训练器初始化...[/bold cyan]")
    engine = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )
    # Small mock loaders: 50 training samples, 10 eval samples.
    train_loader = DataLoader(MockDataset(num_samples=50), batch_size=4, shuffle=False)
    eval_loader = DataLoader(MockDataset(num_samples=10), batch_size=4, shuffle=False)
    trainer = Trainer(
        model=engine,
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=10,  # cap the run length for the test
        learning_rate=1e-4,
        min_learning_rate=1e-6,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.1,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        eval_frequency=5,
        save_frequency=10,
        mixed_precision=False,  # disable AMP so the test is CPU-friendly
        use_tensorboard=False,  # no TensorBoard logging in tests
    )
    console.print(f"训练器设备: {trainer.device}")
    console.print(f"总步数: {trainer.total_steps}")
    console.print(f"热身步数: {trainer.warmup_steps}")
    console.print(f"优化器类型: {type(trainer.optimizer)}")
    console.print(f"损失函数类型: {type(trainer.criterion)}")
    assert trainer.device.type in ["cpu", "cuda"], "设备类型不正确"
    assert trainer.total_steps == 10, "总步数不正确"
    assert trainer.warmup_steps == 1, "热身步数不正确"  # 10 * 0.1 = 1
    assert isinstance(trainer.optimizer, torch.optim.AdamW), "优化器类型不正确"
    assert isinstance(trainer.criterion, nn.CrossEntropyLoss), "损失函数类型不正确"
    console.print("[green]✓ 训练器初始化测试通过[/green]\n")
    return trainer
def test_training_step():
    """Run a single optimization step and sanity-check the reported metrics."""
    console.print("[bold cyan]测试训练步骤...[/bold cyan]")
    loader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)
    trainer = Trainer(
        model=InputMethodEngine(
            vocab_size=100,
            pinyin_vocab_size=28,
            dim=64,
            num_slots=8,
            n_layers=2,
            n_heads=2,
            num_experts=4,
            max_seq_len=128,
            use_pinyin=False,
        ),
        train_dataloader=loader,
        eval_dataloader=None,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=5,
        learning_rate=1e-4,
        mixed_precision=False,
        use_tensorboard=False,
    )
    # Feed the first batch through exactly one training step.
    first_batch = next(iter(loader))
    loss, metrics = trainer.train_step(first_batch)
    console.print(f"训练步骤损失: {loss:.4f}")
    console.print(f"训练步骤指标: {metrics}")
    assert isinstance(loss, float), "损失值类型不正确"
    assert loss >= 0, "损失值应为非负数"
    assert "lr" in metrics, "指标中缺少学习率"
    assert "accuracy" in metrics, "指标中缺少准确率"
    assert 0 <= metrics["accuracy"] <= 1, "准确率应在0-1之间"
    console.print("[green]✓ 训练步骤测试通过[/green]\n")
def test_evaluation():
    """Exercise Trainer.evaluate() and validate the returned metrics."""
    console.print("[bold cyan]测试评估功能...[/bold cyan]")
    engine = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )
    train_loader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)
    eval_loader = DataLoader(MockDataset(num_samples=5), batch_size=2, shuffle=False)
    trainer = Trainer(
        model=engine,
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=5,
        learning_rate=1e-4,
        mixed_precision=False,
        use_tensorboard=False,
    )
    # Run one full evaluation pass over the mock eval set.
    eval_metrics = trainer.evaluate()
    console.print(f"评估指标: {eval_metrics}")
    assert "eval_loss" in eval_metrics, "评估指标中缺少eval_loss"
    assert "eval_accuracy" in eval_metrics, "评估指标中缺少eval_accuracy"
    assert eval_metrics["eval_loss"] >= 0, "评估损失应为非负数"
    assert 0 <= eval_metrics["eval_accuracy"] <= 1, "评估准确率应在0-1之间"
    console.print("[green]✓ 评估功能测试通过[/green]\n")
def test_lr_scheduler():
    """Probe the warmup + cosine LR schedule at several step counts."""
    console.print("[bold cyan]测试学习率调度器...[/bold cyan]")
    trainer = Trainer(
        model=InputMethodEngine(
            vocab_size=100,
            pinyin_vocab_size=28,
            dim=64,
            num_slots=8,
            n_layers=2,
            n_heads=2,
            num_experts=4,
            max_seq_len=128,
            use_pinyin=False,
        ),
        train_dataloader=DataLoader(
            MockDataset(num_samples=10), batch_size=2, shuffle=False
        ),
        eval_dataloader=None,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=100,
        learning_rate=1e-3,
        min_learning_rate=1e-5,
        warmup_ratio=0.2,  # 20% of the steps spent warming up
        mixed_precision=False,
        use_tensorboard=False,
    )
    # Sample the schedule at representative points of the 100-step run.
    lr_values = []
    for step in (0, 10, 20, 50, 99):
        trainer.current_step = step
        lr = trainer._get_current_lr()
        lr_values.append(lr)
        console.print(f"步数 {step}: 学习率 = {lr:.2e}")
    # The curve should rise through warmup, then decay, never undershooting the floor.
    assert lr_values[0] == 0.0, "第0步学习率应为0"
    assert lr_values[1] > lr_values[0], "热身阶段学习率应增加"
    assert lr_values[4] < lr_values[2], "余弦退火阶段学习率应下降"
    assert lr_values[4] >= 1e-5, "最终学习率不应低于最小值"
    console.print("[green]✓ 学习率调度器测试通过[/green]\n")
def test_checkpoint_saving():
    """Save a checkpoint, reload it into a fresh Trainer, and compare state."""
    console.print("[bold cyan]测试检查点保存...[/bold cyan]")
    import shutil
    import tempfile

    loader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)

    def build_trainer(out_dir):
        # Fresh tiny model + trainer bound to the temporary output directory.
        return Trainer(
            model=InputMethodEngine(
                vocab_size=100,
                pinyin_vocab_size=28,
                dim=64,
                num_slots=8,
                n_layers=2,
                n_heads=2,
                num_experts=4,
                max_seq_len=128,
                use_pinyin=False,
            ),
            train_dataloader=loader,
            eval_dataloader=None,
            output_dir=out_dir,
            num_epochs=1,
            total_steps=5,
            learning_rate=1e-4,
            mixed_precision=False,
            use_tensorboard=False,
        )

    temp_dir = tempfile.mkdtemp()
    try:
        trainer = build_trainer(temp_dir)
        # Save, then confirm the file landed under <temp_dir>/checkpoints/.
        checkpoint_path = Path(temp_dir) / "checkpoints" / "test_checkpoint.pt"
        trainer.save_checkpoint("test_checkpoint.pt")
        console.print(f"检查点保存路径: {checkpoint_path}")
        assert checkpoint_path.exists(), "检查点文件未创建"
        # Load into a second, freshly built trainer and compare restored counters.
        trainer2 = build_trainer(temp_dir)
        trainer2.load_checkpoint(checkpoint_path)
        console.print(f"加载后的步数: {trainer2.current_step}")
        console.print(f"加载后的epoch: {trainer2.current_epoch}")
        assert trainer2.current_step == trainer.current_step, "步数未正确恢复"
        assert trainer2.current_epoch == trainer.current_epoch, "epoch未正确恢复"
        console.print("[green]✓ 检查点保存测试通过[/green]\n")
    finally:
        # Always remove the temp directory, pass or fail.
        shutil.rmtree(temp_dir)
def main():
    """Run every Trainer test in order; return 0 on success, 1 on failure."""
    console.print("[bold blue]开始测试Trainer类...[/bold blue]\n")
    checks = (
        test_dataset_creation,  # 1: dataset / dataloader shapes
        test_model_creation,  # 2: model forward pass
        test_trainer_initialization,  # 3: trainer construction
        test_training_step,  # 4: single optimization step
        test_evaluation,  # 5: evaluation loop
        test_lr_scheduler,  # 6: LR schedule
        test_checkpoint_saving,  # 7: checkpoint round-trip
    )
    try:
        for check in checks:
            check()
        console.print("[bold green]所有测试通过! ✅[/bold green]")
    except Exception as e:
        console.print(f"[bold red]测试失败: {e}[/bold red]")
        import traceback

        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # Remove any output left over from a previous run so tests start clean.
    import shutil
    test_output_dir = Path("./test_output")
    if test_output_dir.exists():
        shutil.rmtree(test_output_dir)
    # Run the full test suite and propagate its exit code to the shell.
    exit_code = main()
    # Clean up whatever the tests wrote before exiting.
    if test_output_dir.exists():
        shutil.rmtree(test_output_dir)
    sys.exit(exit_code)

3610
uv.lock

File diff suppressed because it is too large Load Diff