"""Smoke tests for the input-method Trainer: dataset, model, training loop, scheduler, and checkpointing."""
import sys
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from rich.console import Console
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
# 添加src目录到路径
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
from src.model.model import InputMethodEngine
|
|
from src.model.trainer import Trainer
|
|
|
|
console = Console()
|
|
|
|
|
|
class MockDataset(Dataset):
    """Synthetic dataset yielding random tensors shaped like real training samples.

    Used only for smoke tests; `idx` is ignored and every access returns
    freshly randomized data.
    """

    def __init__(self, num_samples=100, vocab_size=100, seq_len=128, num_slots=8):
        self.num_samples = num_samples
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_slots = num_slots

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate one random sample on demand (no caching, idx unused).
        seq_shape = (self.seq_len,)
        return {
            "input_ids": torch.randint(0, self.vocab_size, seq_shape),
            "token_type_ids": torch.randint(0, 2, seq_shape),
            "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
            "history_slot_ids": torch.randint(0, self.vocab_size, (self.num_slots,)),
            "label": torch.randint(0, self.vocab_size, (1,)),
        }
|
|
|
|
|
|
def test_dataset_creation():
    """Check that the mock dataset batches into the expected tensor shapes."""
    console.print("[bold cyan]测试数据集创建...[/bold cyan]")

    dataset = MockDataset(num_samples=10)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
    batch = next(iter(dataloader))

    console.print(f"批处理形状:")
    # Expected (batch, ...) shapes for a batch of 2 with default dataset params.
    expected_shapes = {
        "input_ids": (2, 128),
        "token_type_ids": (2, 128),
        "attention_mask": (2, 128),
        "history_slot_ids": (2, 8),
        "label": (2, 1),
    }
    for key, shape in expected_shapes.items():
        console.print(f"  {key}: {batch[key].shape}")
        assert batch[key].shape == shape, f"{key}形状不正确"

    console.print("[green]✓ 数据集测试通过[/green]\n")
    return dataloader
|
|
|
|
|
|
def test_model_creation():
    """Build a tiny InputMethodEngine and verify one forward pass."""
    console.print("[bold cyan]测试模型创建...[/bold cyan]")

    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,  # small width keeps the smoke test fast
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )

    # Random inputs matching the model's expected keyword arguments.
    batch_size, seq_len = 2, 128
    inputs = {
        "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
        "token_type_ids": torch.randint(0, 2, (batch_size, seq_len)),
        "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
        "history_slot_ids": torch.randint(0, 100, (batch_size, 8)),
    }

    with torch.no_grad():
        logits = model(**inputs)

    console.print(f"模型输出形状: {logits.shape}")
    console.print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

    # Output is one score per vocabulary entry for each batch element.
    assert logits.shape == (batch_size, 100), "模型输出形状不正确"

    console.print("[green]✓ 模型测试通过[/green]\n")
    return model
|
|
|
|
|
|
def test_trainer_initialization():
    """Construct a fully-configured Trainer and sanity-check its state."""
    console.print("[bold cyan]测试训练器初始化...[/bold cyan]")

    # Small model so initialization is quick.
    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )

    train_dataloader = DataLoader(MockDataset(num_samples=50), batch_size=4, shuffle=False)
    eval_dataloader = DataLoader(MockDataset(num_samples=10), batch_size=4, shuffle=False)

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=10,  # short schedule for the test
        learning_rate=1e-4,
        min_learning_rate=1e-6,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.1,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        eval_frequency=5,
        save_frequency=10,
        mixed_precision=False,  # disabled during tests
        use_tensorboard=False,  # no logging side effects during tests
    )

    console.print(f"训练器设备: {trainer.device}")
    console.print(f"总步数: {trainer.total_steps}")
    console.print(f"热身步数: {trainer.warmup_steps}")
    console.print(f"优化器类型: {type(trainer.optimizer)}")
    console.print(f"损失函数类型: {type(trainer.criterion)}")

    assert trainer.device.type in ("cpu", "cuda"), "设备类型不正确"
    assert trainer.total_steps == 10, "总步数不正确"
    # warmup_steps = total_steps * warmup_ratio = 10 * 0.1 = 1
    assert trainer.warmup_steps == 1, "热身步数不正确"
    assert isinstance(trainer.optimizer, torch.optim.AdamW), "优化器类型不正确"
    assert isinstance(trainer.criterion, nn.CrossEntropyLoss), "损失函数类型不正确"

    console.print("[green]✓ 训练器初始化测试通过[/green]\n")
    return trainer
|
|
|
|
|
|
def test_training_step():
    """Run a single optimization step and validate the returned loss/metrics."""
    console.print("[bold cyan]测试训练步骤...[/bold cyan]")

    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )

    train_dataloader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=None,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=5,
        learning_rate=1e-4,
        mixed_precision=False,
        use_tensorboard=False,
    )

    # One real batch through one training step.
    batch = next(iter(train_dataloader))
    loss, metrics = trainer.train_step(batch)

    console.print(f"训练步骤损失: {loss:.4f}")
    console.print(f"训练步骤指标: {metrics}")

    assert isinstance(loss, float), "损失值类型不正确"
    assert loss >= 0, "损失值应为非负数"
    assert "lr" in metrics, "指标中缺少学习率"
    assert "accuracy" in metrics, "指标中缺少准确率"
    assert 0 <= metrics["accuracy"] <= 1, "准确率应在0-1之间"

    console.print("[green]✓ 训练步骤测试通过[/green]\n")
|
|
|
|
|
|
def test_evaluation():
    """Run Trainer.evaluate on a small eval set and validate its metrics."""
    console.print("[bold cyan]测试评估功能...[/bold cyan]")

    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )

    train_dataloader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)
    eval_dataloader = DataLoader(MockDataset(num_samples=5), batch_size=2, shuffle=False)

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=5,
        learning_rate=1e-4,
        mixed_precision=False,
        use_tensorboard=False,
    )

    eval_metrics = trainer.evaluate()

    console.print(f"评估指标: {eval_metrics}")

    # Metrics must include a non-negative loss and an accuracy in [0, 1].
    assert "eval_loss" in eval_metrics, "评估指标中缺少eval_loss"
    assert "eval_accuracy" in eval_metrics, "评估指标中缺少eval_accuracy"
    assert eval_metrics["eval_loss"] >= 0, "评估损失应为非负数"
    assert 0 <= eval_metrics["eval_accuracy"] <= 1, "评估准确率应在0-1之间"

    console.print("[green]✓ 评估功能测试通过[/green]\n")
|
|
|
|
|
|
def test_lr_scheduler():
    """Probe the warmup + cosine-decay schedule at several step counts."""
    console.print("[bold cyan]测试学习率调度器...[/bold cyan]")

    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )

    train_dataloader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=None,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=100,
        learning_rate=1e-3,
        min_learning_rate=1e-5,
        warmup_ratio=0.2,  # first 20% of steps are warmup
        mixed_precision=False,
        use_tensorboard=False,
    )

    # Sample the learning rate at representative points of the schedule.
    lr_values = []
    for step in (0, 10, 20, 50, 99):
        trainer.current_step = step
        lr = trainer._get_current_lr()
        lr_values.append(lr)
        console.print(f"步数 {step}: 学习率 = {lr:.2e}")

    # Ramp up from zero during warmup, then decay but never below the floor.
    assert lr_values[0] == 0.0, "第0步学习率应为0"
    assert lr_values[1] > lr_values[0], "热身阶段学习率应增加"
    assert lr_values[4] < lr_values[2], "余弦退火阶段学习率应下降"
    assert lr_values[4] >= 1e-5, "最终学习率不应低于最小值"

    console.print("[green]✓ 学习率调度器测试通过[/green]\n")
|
|
|
|
|
|
def _build_tiny_engine():
    """Build the small InputMethodEngine configuration used by checkpoint tests."""
    return InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )


def test_checkpoint_saving():
    """Save a checkpoint, reload it into a fresh Trainer, and verify state restore.

    Uses tempfile.TemporaryDirectory as a context manager so the temporary
    directory is cleaned up even if an assertion fails (replaces the manual
    mkdtemp + try/finally + shutil.rmtree pattern).
    """
    console.print("[bold cyan]测试检查点保存...[/bold cyan]")

    import tempfile

    def _make_trainer(train_dataloader, output_dir):
        # Minimal Trainer wired to a fresh tiny model; both trainers in this
        # test must be configured identically for the comparison to be fair.
        return Trainer(
            model=_build_tiny_engine(),
            train_dataloader=train_dataloader,
            eval_dataloader=None,
            output_dir=output_dir,
            num_epochs=1,
            total_steps=5,
            learning_rate=1e-4,
            mixed_precision=False,
            use_tensorboard=False,
        )

    with tempfile.TemporaryDirectory() as temp_dir:
        train_dataloader = DataLoader(
            MockDataset(num_samples=10), batch_size=2, shuffle=False
        )

        trainer = _make_trainer(train_dataloader, temp_dir)

        # Save and confirm the file lands under <output_dir>/checkpoints.
        checkpoint_path = Path(temp_dir) / "checkpoints" / "test_checkpoint.pt"
        trainer.save_checkpoint("test_checkpoint.pt")

        console.print(f"检查点保存路径: {checkpoint_path}")
        assert checkpoint_path.exists(), "检查点文件未创建"

        # Reload into a second, independent trainer and compare restored state.
        trainer2 = _make_trainer(train_dataloader, temp_dir)
        trainer2.load_checkpoint(checkpoint_path)

        console.print(f"加载后的步数: {trainer2.current_step}")
        console.print(f"加载后的epoch: {trainer2.current_epoch}")

        assert trainer2.current_step == trainer.current_step, "步数未正确恢复"
        assert trainer2.current_epoch == trainer.current_epoch, "epoch未正确恢复"

        console.print("[green]✓ 检查点保存测试通过[/green]\n")
|
|
|
|
|
|
def main():
    """Run the full Trainer test suite; return 0 on success, 1 on failure."""
    console.print("[bold blue]开始测试Trainer类...[/bold blue]\n")

    # Tests run in dependency order: data -> model -> trainer -> features.
    test_sequence = (
        test_dataset_creation,
        test_model_creation,
        test_trainer_initialization,
        test_training_step,
        test_evaluation,
        test_lr_scheduler,
        test_checkpoint_saving,
    )

    try:
        for run_test in test_sequence:
            run_test()
        console.print("[bold green]所有测试通过! ✅[/bold green]")
    except Exception as e:
        console.print(f"[bold red]测试失败: {e}[/bold red]")
        import traceback

        traceback.print_exc()
        return 1

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    import shutil

    # Remove any stale artifacts so each run starts from a clean slate.
    test_output_dir = Path("./test_output")
    if test_output_dir.exists():
        shutil.rmtree(test_output_dir)

    exit_code = main()

    # Clean up whatever the tests wrote before exiting.
    if test_output_dir.exists():
        shutil.rmtree(test_output_dir)

    sys.exit(exit_code)
|