"""Smoke tests for the Trainer class, run as a standalone script.

Each ``test_*`` function exercises one aspect of the training stack
(dataset batching, model forward pass, trainer initialization, a single
train step, evaluation, the LR schedule, and checkpoint save/load) using
a small randomly-generated :class:`MockDataset` and a deliberately tiny
``InputMethodEngine`` so the whole suite runs quickly on CPU.
"""

import shutil
import sys
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
from rich.console import Console
from torch.utils.data import DataLoader, Dataset

# Make the local ``src`` package importable when running this file directly.
sys.path.insert(0, str(Path(__file__).parent))

from src.model.model import InputMethodEngine
from src.model.trainer import Trainer

console = Console()


class MockDataset(Dataset):
    """Synthetic dataset of random tensors used for testing.

    Each item mimics the real training sample layout: token ids, token
    type ids, an all-ones attention mask, history slot ids, and a single
    target label id.
    """

    def __init__(self, num_samples=100, vocab_size=100, seq_len=128, num_slots=8):
        self.num_samples = num_samples
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_slots = num_slots

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate a fresh random sample; content is irrelevant, only
        # shapes and dtypes matter for these tests.
        return {
            "input_ids": torch.randint(0, self.vocab_size, (self.seq_len,)),
            "token_type_ids": torch.randint(0, 2, (self.seq_len,)),
            "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
            "history_slot_ids": torch.randint(0, self.vocab_size, (self.num_slots,)),
            "label": torch.randint(0, self.vocab_size, (1,)),
        }


def _build_test_model():
    """Build the small InputMethodEngine configuration shared by all tests.

    Centralizes the previously-duplicated constructor call; dim/layers/heads
    are kept deliberately small so tests run fast.
    """
    return InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,  # small width to speed up the tests
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )


def test_dataset_creation():
    """Check that MockDataset batches have the expected shapes."""
    console.print("[bold cyan]测试数据集创建...[/bold cyan]")

    dataset = MockDataset(num_samples=10)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

    batch = next(iter(dataloader))
    console.print("批处理形状:")
    console.print(f"  input_ids: {batch['input_ids'].shape}")
    console.print(f"  token_type_ids: {batch['token_type_ids'].shape}")
    console.print(f"  attention_mask: {batch['attention_mask'].shape}")
    console.print(f"  history_slot_ids: {batch['history_slot_ids'].shape}")
    console.print(f"  label: {batch['label'].shape}")

    assert batch["input_ids"].shape == (2, 128), "input_ids形状不正确"
    assert batch["token_type_ids"].shape == (2, 128), "token_type_ids形状不正确"
    assert batch["attention_mask"].shape == (2, 128), "attention_mask形状不正确"
    assert batch["history_slot_ids"].shape == (2, 8), "history_slot_ids形状不正确"
    assert batch["label"].shape == (2, 1), "label形状不正确"

    console.print("[green]✓ 数据集测试通过[/green]\n")
    return dataloader


def test_model_creation():
    """Check model construction and the shape of a forward pass."""
    console.print("[bold cyan]测试模型创建...[/bold cyan]")

    model = _build_test_model()

    # Forward-pass sanity check with random inputs.
    batch_size = 2
    input_ids = torch.randint(0, 100, (batch_size, 128))
    token_type_ids = torch.randint(0, 2, (batch_size, 128))
    attention_mask = torch.ones(batch_size, 128, dtype=torch.long)
    history_slot_ids = torch.randint(0, 100, (batch_size, 8))

    with torch.no_grad():
        logits = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            history_slot_ids=history_slot_ids,
        )

    console.print(f"模型输出形状: {logits.shape}")
    console.print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

    assert logits.shape == (batch_size, 100), "模型输出形状不正确"

    console.print("[green]✓ 模型测试通过[/green]\n")
    return model


def test_trainer_initialization():
    """Check that Trainer wires up device, steps, optimizer, and loss."""
    console.print("[bold cyan]测试训练器初始化...[/bold cyan]")

    model = _build_test_model()

    train_dataset = MockDataset(num_samples=50)
    eval_dataset = MockDataset(num_samples=10)
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=False)
    eval_dataloader = DataLoader(eval_dataset, batch_size=4, shuffle=False)

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=10,  # cap the step count for the test
        learning_rate=1e-4,
        min_learning_rate=1e-6,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.1,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        eval_frequency=5,
        save_frequency=10,
        mixed_precision=False,  # disable AMP during tests
        use_tensorboard=False,  # disable TensorBoard during tests
    )

    console.print(f"训练器设备: {trainer.device}")
    console.print(f"总步数: {trainer.total_steps}")
    console.print(f"热身步数: {trainer.warmup_steps}")
    console.print(f"优化器类型: {type(trainer.optimizer)}")
    console.print(f"损失函数类型: {type(trainer.criterion)}")

    assert trainer.device.type in ["cpu", "cuda"], "设备类型不正确"
    assert trainer.total_steps == 10, "总步数不正确"
    assert trainer.warmup_steps == 1, "热身步数不正确"  # 10 * 0.1 = 1
    assert isinstance(trainer.optimizer, torch.optim.AdamW), "优化器类型不正确"
    assert isinstance(trainer.criterion, nn.CrossEntropyLoss), "损失函数类型不正确"

    console.print("[green]✓ 训练器初始化测试通过[/green]\n")
    return trainer


def test_training_step():
    """Run a single train_step and validate the returned loss/metrics."""
    console.print("[bold cyan]测试训练步骤...[/bold cyan]")

    model = _build_test_model()
    train_dataset = MockDataset(num_samples=10)
    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=False)

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=None,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=5,
        learning_rate=1e-4,
        mixed_precision=False,
        use_tensorboard=False,
    )

    # Execute one training step on the first batch.
    batch = next(iter(train_dataloader))
    loss, metrics = trainer.train_step(batch)

    console.print(f"训练步骤损失: {loss:.4f}")
    console.print(f"训练步骤指标: {metrics}")

    assert isinstance(loss, float), "损失值类型不正确"
    assert loss >= 0, "损失值应为非负数"
    assert "lr" in metrics, "指标中缺少学习率"
    assert "accuracy" in metrics, "指标中缺少准确率"
    assert 0 <= metrics["accuracy"] <= 1, "准确率应在0-1之间"

    console.print("[green]✓ 训练步骤测试通过[/green]\n")


def test_evaluation():
    """Run evaluate() and validate the returned metric dictionary."""
    console.print("[bold cyan]测试评估功能...[/bold cyan]")

    model = _build_test_model()
    train_dataset = MockDataset(num_samples=10)
    eval_dataset = MockDataset(num_samples=5)
    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=False)
    eval_dataloader = DataLoader(eval_dataset, batch_size=2, shuffle=False)

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=5,
        learning_rate=1e-4,
        mixed_precision=False,
        use_tensorboard=False,
    )

    eval_metrics = trainer.evaluate()

    console.print(f"评估指标: {eval_metrics}")

    assert "eval_loss" in eval_metrics, "评估指标中缺少eval_loss"
    assert "eval_accuracy" in eval_metrics, "评估指标中缺少eval_accuracy"
    assert eval_metrics["eval_loss"] >= 0, "评估损失应为非负数"
    assert 0 <= eval_metrics["eval_accuracy"] <= 1, "评估准确率应在0-1之间"

    console.print("[green]✓ 评估功能测试通过[/green]\n")


def test_lr_scheduler():
    """Check the warmup + cosine-decay learning-rate schedule shape."""
    console.print("[bold cyan]测试学习率调度器...[/bold cyan]")

    model = _build_test_model()
    train_dataset = MockDataset(num_samples=10)
    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=False)

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=None,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=100,
        learning_rate=1e-3,
        min_learning_rate=1e-5,
        warmup_ratio=0.2,  # 20% warmup
        mixed_precision=False,
        use_tensorboard=False,
    )

    # Sample the LR at a few points across the schedule.
    test_steps = [0, 10, 20, 50, 99]
    lr_values = []

    for step in test_steps:
        trainer.current_step = step
        lr = trainer._get_current_lr()
        lr_values.append(lr)
        console.print(f"步数 {step}: 学习率 = {lr:.2e}")

    # Verify the overall trend: warmup rises, cosine phase decays,
    # and the final LR never drops below the configured minimum.
    assert lr_values[0] == 0.0, "第0步学习率应为0"
    assert lr_values[1] > lr_values[0], "热身阶段学习率应增加"
    assert lr_values[4] < lr_values[2], "余弦退火阶段学习率应下降"
    assert lr_values[4] >= 1e-5, "最终学习率不应低于最小值"

    console.print("[green]✓ 学习率调度器测试通过[/green]\n")


def test_checkpoint_saving():
    """Save a checkpoint, load it into a fresh trainer, compare state."""
    console.print("[bold cyan]测试检查点保存...[/bold cyan]")

    # TemporaryDirectory cleans itself up even if an assertion fails.
    with tempfile.TemporaryDirectory() as temp_dir:
        model = _build_test_model()
        train_dataset = MockDataset(num_samples=10)
        train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=False)

        trainer = Trainer(
            model=model,
            train_dataloader=train_dataloader,
            eval_dataloader=None,
            output_dir=temp_dir,
            num_epochs=1,
            total_steps=5,
            learning_rate=1e-4,
            mixed_precision=False,
            use_tensorboard=False,
        )

        # Save a checkpoint and verify the file appears where expected.
        checkpoint_path = Path(temp_dir) / "checkpoints" / "test_checkpoint.pt"
        trainer.save_checkpoint("test_checkpoint.pt")

        console.print(f"检查点保存路径: {checkpoint_path}")
        assert checkpoint_path.exists(), "检查点文件未创建"

        # Load the checkpoint into a second, freshly-built trainer.
        trainer2 = Trainer(
            model=_build_test_model(),
            train_dataloader=train_dataloader,
            eval_dataloader=None,
            output_dir=temp_dir,
            num_epochs=1,
            total_steps=5,
            learning_rate=1e-4,
            mixed_precision=False,
            use_tensorboard=False,
        )

        trainer2.load_checkpoint(checkpoint_path)

        console.print(f"加载后的步数: {trainer2.current_step}")
        console.print(f"加载后的epoch: {trainer2.current_epoch}")

        assert trainer2.current_step == trainer.current_step, "步数未正确恢复"
        assert trainer2.current_epoch == trainer.current_epoch, "epoch未正确恢复"

        console.print("[green]✓ 检查点保存测试通过[/green]\n")


def main():
    """Run every test in order; return 0 on success, 1 on failure."""
    console.print("[bold blue]开始测试Trainer类...[/bold blue]\n")

    try:
        test_dataset_creation()
        test_model_creation()
        test_trainer_initialization()
        test_training_step()
        test_evaluation()
        test_lr_scheduler()
        test_checkpoint_saving()

        console.print("[bold green]所有测试通过! ✅[/bold green]")
    except Exception as e:
        console.print(f"[bold red]测试失败: {e}[/bold red]")
        import traceback

        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    # Remove any leftover output from a previous run before testing.
    test_output_dir = Path("./test_output")
    if test_output_dir.exists():
        shutil.rmtree(test_output_dir)

    exit_code = main()

    # Clean up test output after the run.
    if test_output_dir.exists():
        shutil.rmtree(test_output_dir)

    sys.exit(exit_code)