# SUimeModelTrainer/test_trainer.py
import sys
from pathlib import Path
import torch
import torch.nn as nn
from rich.console import Console
from torch.utils.data import DataLoader, Dataset
# 添加src目录到路径
sys.path.insert(0, str(Path(__file__).parent))
from src.model.model import InputMethodEngine
from src.model.trainer import Trainer
console = Console()
class MockDataset(Dataset):
    """Synthetic dataset that yields random tensors shaped like real training samples.

    Used only by the trainer smoke tests; every item is freshly randomized,
    so the index argument is intentionally ignored.
    """

    def __init__(self, num_samples=100, vocab_size=100, seq_len=128, num_slots=8):
        self.num_samples = num_samples
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_slots = num_slots

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Random ids in [0, vocab_size); the attention mask is all ones
        # so every position is attended to.
        return {
            "input_ids": torch.randint(0, self.vocab_size, (self.seq_len,)),
            "token_type_ids": torch.randint(0, 2, (self.seq_len,)),
            "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
            "history_slot_ids": torch.randint(0, self.vocab_size, (self.num_slots,)),
            "label": torch.randint(0, self.vocab_size, (1,)),
        }
def test_dataset_creation():
    """Smoke-test MockDataset through a DataLoader and verify batch tensor shapes."""
    console.print("[bold cyan]测试数据集创建...[/bold cyan]")
    dataloader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)
    batch = next(iter(dataloader))

    console.print("批处理形状:")
    # Expected (batch, ...) shape per key, paired with its failure message.
    expected_shapes = {
        "input_ids": ((2, 128), "input_ids形状不正确"),
        "token_type_ids": ((2, 128), "token_type_ids形状不正确"),
        "attention_mask": ((2, 128), "attention_mask形状不正确"),
        "history_slot_ids": ((2, 8), "history_slot_ids形状不正确"),
        "label": ((2, 1), "label形状不正确"),
    }
    for key in expected_shapes:
        console.print(f" {key}: {batch[key].shape}")
    for key, (shape, message) in expected_shapes.items():
        assert batch[key].shape == shape, message

    console.print("[green]✓ 数据集测试通过[/green]\n")
    return dataloader
def test_model_creation():
    """Build a small InputMethodEngine and check one forward pass's output shape."""
    console.print("[bold cyan]测试模型创建...[/bold cyan]")
    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,  # small width keeps the test fast
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )

    # Run a single no-grad forward pass on random inputs.
    bs, seq_len = 2, 128
    with torch.no_grad():
        logits = model(
            input_ids=torch.randint(0, 100, (bs, seq_len)),
            token_type_ids=torch.randint(0, 2, (bs, seq_len)),
            attention_mask=torch.ones(bs, seq_len, dtype=torch.long),
            history_slot_ids=torch.randint(0, 100, (bs, 8)),
        )

    console.print(f"模型输出形状: {logits.shape}")
    console.print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
    assert logits.shape == (bs, 100), "模型输出形状不正确"
    console.print("[green]✓ 模型测试通过[/green]\n")
    return model
def test_trainer_initialization():
    """Construct a Trainer and verify its device, schedule, optimizer, and loss setup."""
    console.print("[bold cyan]测试训练器初始化...[/bold cyan]")

    # Small model plus mock train/eval loaders.
    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )
    train_loader = DataLoader(MockDataset(num_samples=50), batch_size=4, shuffle=False)
    eval_loader = DataLoader(MockDataset(num_samples=10), batch_size=4, shuffle=False)

    trainer = Trainer(
        model=model,
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=10,  # cap total steps for the test
        learning_rate=1e-4,
        min_learning_rate=1e-6,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.1,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        eval_frequency=5,
        save_frequency=10,
        mixed_precision=False,  # disabled for deterministic testing
        use_tensorboard=False,  # no logging side effects in tests
    )

    console.print(f"训练器设备: {trainer.device}")
    console.print(f"总步数: {trainer.total_steps}")
    console.print(f"热身步数: {trainer.warmup_steps}")
    console.print(f"优化器类型: {type(trainer.optimizer)}")
    console.print(f"损失函数类型: {type(trainer.criterion)}")

    assert trainer.device.type in ["cpu", "cuda"], "设备类型不正确"
    assert trainer.total_steps == 10, "总步数不正确"
    assert trainer.warmup_steps == 1, "热身步数不正确"  # 10 * 0.1 = 1
    assert isinstance(trainer.optimizer, torch.optim.AdamW), "优化器类型不正确"
    assert isinstance(trainer.criterion, nn.CrossEntropyLoss), "损失函数类型不正确"

    console.print("[green]✓ 训练器初始化测试通过[/green]\n")
    return trainer
def test_training_step():
    """Run one Trainer.train_step on a mock batch and sanity-check loss and metrics."""
    console.print("[bold cyan]测试训练步骤...[/bold cyan]")

    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )
    loader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)
    trainer = Trainer(
        model=model,
        train_dataloader=loader,
        eval_dataloader=None,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=5,
        learning_rate=1e-4,
        mixed_precision=False,
        use_tensorboard=False,
    )

    # Execute a single training step on the first batch.
    loss, metrics = trainer.train_step(next(iter(loader)))

    console.print(f"训练步骤损失: {loss:.4f}")
    console.print(f"训练步骤指标: {metrics}")
    assert isinstance(loss, float), "损失值类型不正确"
    assert loss >= 0, "损失值应为非负数"
    assert "lr" in metrics, "指标中缺少学习率"
    assert "accuracy" in metrics, "指标中缺少准确率"
    assert 0 <= metrics["accuracy"] <= 1, "准确率应在0-1之间"
    console.print("[green]✓ 训练步骤测试通过[/green]\n")
def test_evaluation():
    """Run Trainer.evaluate on a mock eval set and sanity-check the returned metrics."""
    console.print("[bold cyan]测试评估功能...[/bold cyan]")

    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )
    train_loader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)
    eval_loader = DataLoader(MockDataset(num_samples=5), batch_size=2, shuffle=False)
    trainer = Trainer(
        model=model,
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=5,
        learning_rate=1e-4,
        mixed_precision=False,
        use_tensorboard=False,
    )

    # Run one full evaluation pass.
    eval_metrics = trainer.evaluate()

    console.print(f"评估指标: {eval_metrics}")
    assert "eval_loss" in eval_metrics, "评估指标中缺少eval_loss"
    assert "eval_accuracy" in eval_metrics, "评估指标中缺少eval_accuracy"
    assert eval_metrics["eval_loss"] >= 0, "评估损失应为非负数"
    assert 0 <= eval_metrics["eval_accuracy"] <= 1, "评估准确率应在0-1之间"
    console.print("[green]✓ 评估功能测试通过[/green]\n")
def test_lr_scheduler():
    """Probe the warmup + cosine-decay schedule at several steps and check its shape."""
    console.print("[bold cyan]测试学习率调度器...[/bold cyan]")

    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )
    loader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)
    trainer = Trainer(
        model=model,
        train_dataloader=loader,
        eval_dataloader=None,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=100,
        learning_rate=1e-3,
        min_learning_rate=1e-5,
        warmup_ratio=0.2,  # 20% of total steps spent warming up
        mixed_precision=False,
        use_tensorboard=False,
    )

    # Sample the schedule at steps spanning warmup and cosine decay.
    probe_steps = [0, 10, 20, 50, 99]
    lr_values = []
    for step in probe_steps:
        trainer.current_step = step
        lr = trainer._get_current_lr()
        lr_values.append(lr)
        console.print(f"步数 {step}: 学习率 = {lr:.2e}")

    # Warmup starts at zero, rises, then cosine decay brings it back down
    # without dropping below the configured minimum.
    assert lr_values[0] == 0.0, "第0步学习率应为0"
    assert lr_values[1] > lr_values[0], "热身阶段学习率应增加"
    assert lr_values[4] < lr_values[2], "余弦退火阶段学习率应下降"
    assert lr_values[4] >= 1e-5, "最终学习率不应低于最小值"
    console.print("[green]✓ 学习率调度器测试通过[/green]\n")
def test_checkpoint_saving():
    """Save a checkpoint, reload it into a fresh trainer, and verify state restoration."""
    console.print("[bold cyan]测试检查点保存...[/bold cyan]")
    import shutil
    import tempfile

    # Work inside a throwaway directory so nothing leaks between runs.
    temp_dir = tempfile.mkdtemp()
    try:
        loader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)

        def build_trainer():
            # Fresh model + trainer with the same small config each call.
            return Trainer(
                model=InputMethodEngine(
                    vocab_size=100,
                    pinyin_vocab_size=28,
                    dim=64,
                    num_slots=8,
                    n_layers=2,
                    n_heads=2,
                    num_experts=4,
                    max_seq_len=128,
                    use_pinyin=False,
                ),
                train_dataloader=loader,
                eval_dataloader=None,
                output_dir=temp_dir,
                num_epochs=1,
                total_steps=5,
                learning_rate=1e-4,
                mixed_precision=False,
                use_tensorboard=False,
            )

        trainer = build_trainer()

        # Save a checkpoint and confirm the file appears on disk.
        checkpoint_path = Path(temp_dir) / "checkpoints" / "test_checkpoint.pt"
        trainer.save_checkpoint("test_checkpoint.pt")
        console.print(f"检查点保存路径: {checkpoint_path}")
        assert checkpoint_path.exists(), "检查点文件未创建"

        # Load it into a second, independent trainer and compare restored state.
        trainer2 = build_trainer()
        trainer2.load_checkpoint(checkpoint_path)
        console.print(f"加载后的步数: {trainer2.current_step}")
        console.print(f"加载后的epoch: {trainer2.current_epoch}")
        assert trainer2.current_step == trainer.current_step, "步数未正确恢复"
        assert trainer2.current_epoch == trainer.current_epoch, "epoch未正确恢复"
        console.print("[green]✓ 检查点保存测试通过[/green]\n")
    finally:
        # Always remove the temp directory, even on failure.
        shutil.rmtree(temp_dir)
def main():
    """Run every trainer test in order; return 0 on success, 1 on any failure."""
    console.print("[bold blue]开始测试Trainer类...[/bold blue]\n")
    # Order matters: later tests assume earlier components work.
    test_suite = (
        test_dataset_creation,
        test_model_creation,
        test_trainer_initialization,
        test_training_step,
        test_evaluation,
        test_lr_scheduler,
        test_checkpoint_saving,
    )
    try:
        for run_test in test_suite:
            run_test()
        console.print("[bold green]所有测试通过! ✅[/bold green]")
    except Exception as e:
        # Report the failure and the full traceback, then signal failure to the caller.
        console.print(f"[bold red]测试失败: {e}[/bold red]")
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    import shutil

    # Remove output left over from a previous run before testing.
    test_output_dir = Path("./test_output")
    if test_output_dir.exists():
        shutil.rmtree(test_output_dir)

    exit_code = main()

    # Clean up whatever the tests wrote, then propagate the exit status.
    if test_output_dir.exists():
        shutil.rmtree(test_output_dir)
    sys.exit(exit_code)