diff --git a/README.md.bak b/README.md.bak deleted file mode 100644 index 1fb563e..0000000 --- a/README.md.bak +++ /dev/null @@ -1,916 +0,0 @@ -# 输入法预测模型架构设计 (Input Method Prediction Model) - -## 1. 概述 -本项目旨在构建一个轻量级、高精度的中文输入法预测模型。核心设计理念是通过**结构化槽位记忆**与**交叉注意力机制**,将当前语境(光标前后文本+拼音)与历史输入习惯深度融合。为了在有限的计算资源下保持高表达能力,模型引入了**混合专家网络 (MoE)** 模块。 - -## 2. 核心架构流程 -数据流遵循以下路径: -`输入编码` → `Transformer 上下文编码` → `槽位记忆嵌入` → `交叉注意力融合` → `门控+专家混合 (MoE)` → `分类预测` → `束搜索解码` - -### 2.1 输入层设计 -模型接收三类输入,分别处理以保持语义清晰: -1. **当前文本上下文**:包含光标前文本(Prefix)和光标后文本(Suffix)。 -2. **拼音序列**:与当前文本对应的拼音信息,作为增强特征融入文本编码。 -3. **历史槽位序列**:最近 N 个历史输入词汇,作为结构化记忆输入。 - -### 2.2 模块详解 - -#### A. Transformer 编码器 (Context Encoder) -负责提取当前语境的深层语义表示。 -* **输入处理**:将 Prefix、Suffix 及拼音通过 Embedding 层映射。拼音采用**特征叠加**或**独立 Token** 方式融入,避免双流架构的复杂性。 -* **骨干网络**:使用标准的 Transformer Encoder。 - * **隐藏层维度**:512 [1] - * **Transformer 层数**:4 层(轻量级设计,从头训练) [1] - * **注意力头数**:4 头 [1] -* **输出**:上下文表示 $H$,形状为 `[batch, L, 512]` [1]。 - -#### B. 槽位记忆模块 (Slot Memory) -负责将非结构化的历史输入转化为结构化的记忆向量。 -* **嵌入方式**:历史词汇通过独立的 `Slot Embedding` 查找表映射。 -* **位置编码**:添加可学习的 `Positional Embedding` 以保留历史输入的时间顺序信息。 -* **输出**:槽位序列 $S$,形状为 `[batch, Num_Slots, 512]`。 - -#### C. 交叉注意力融合 (Cross-Attention Fusion) -这是模型的核心创新点,用于动态关联"历史记忆"与"当前语境"。 -* **Query (Q)**:当前步的槽位序列 $S$(经过位置编码后)。 -* **Key/Value (K/V)**:Transformer 编码器输出的上下文表示 $H$ [1]。 -* **机制**:让历史槽位主动关注当前文本语境,捕捉如"在'班级第一名'语境下,'王次香'比'王慈祥'更相关"的逻辑。 -* **输出**:融合后的特征序列,形状为 `[batch, Num_Slots, 512]`。 - -#### D. 门控与专家混合 (Gating + MoE) -实际测试表明,移除 MoE 会导致模型性能显著下降,因此该模块对于捕捉复杂分布至关重要。 -* **专家数量**:20 个专家 [1]。 -* **门控机制**:根据输入特征动态选择激活部分专家,实现稀疏激活,在增加模型容量的同时控制计算成本。 -* **输出**:经过专家网络增强后的特征向量。 - -#### E. 分类头与解码 -* **分类预测**:MoE 输出的特征向量通过全连接层映射到词表空间,输出下一个字/词的概率分布。 -* **解码策略**:推理阶段使用**束搜索 (Beam Search)**,束宽设为 5 [1]。 - -## 3. 关键超参数配置 - -为确保模型性能与效率的平衡,建议采用以下超参数 [1]: - -| 参数项 | 推荐值 | 说明 | -| :--- | :--- | :--- | -| **序列长度 (L)** | 128 | 上下文窗口大小 [1] | -| **隐藏层维度** | 512 | Embedding 及 Transformer 内部维度 [1] | -| **Transformer 层数** | 4 | 轻量级骨干,降低延迟 [1] | -| **注意力头数** | 4 | 适配 512 维度的高效配置 [1] | -| **专家数量** | 20 | MoE 层中的专家总数,对性能至关重要 [1] | -| **束宽 (Beam Width)** | 5 | 推理时平衡速度与准确率 [1] | -| **学习率** | 1e-4 ~ 5e-4 | 建议配合 Warmup 策略 [1] | - -## 4. 训练策略 - -本模型采用标准的**序列到序列(Seq2Seq)监督学习**范式,直接对目标槽位序列进行逐步预测。 - -### 4.1 数据构造与标签 -* **输入三元组**:训练数据由 `(上下文, 拼音, 目标槽位序列)` 构成 [1]。 - * **上下文**:光标前后的文本片段。 - * **拼音**:当前待输入字的拼音序列。 - * **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。 -* **标签处理**:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。 - -### 4.2 损失函数与优化 -* **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。 - * **掩码机制**:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。 -* **优化器**:采用 **AdamW** 进行参数更新 [1]。 - -### 4.3 训练流程细节 -1. **前向传播**: - * 模型接收上下文和拼音,通过 Transformer 编码得到语境表示。 - * 结合历史槽位记忆,通过交叉注意力和 MoE 模块融合特征。 - * 分类头输出当前步所有候选字的概率分布。 -2. **Teacher Forcing**: - * 在训练过程中,**强制使用真实的上一槽位输出**作为下一步的输入条件。这意味着模型在训练时始终基于"正确的历史"进行预测,从而快速收敛。 -3. **反向传播**: - * 根据 CrossEntropyLoss [1] 计算梯度,并通过 AdamW [1] 更新模型权重。 - -### 4.4 推理与训练的差异 -* **训练时**:使用 Ground Truth(真实标签)作为槽位输入,确保模型学习到最优的条件概率分布。 -* **推理时**:由于无法获取真实标签,模型采用**束搜索(Beam Search)** [1]。 - * **束宽**:默认为 5 [1]。 - * **候选维护**:每个候选路径独立维护其历史槽位序列及累计概率 [1]。 - * **终止条件**:当所有槽位填满(如 8×3=24 步)或所有候选分支的最高概率词均为终止符时退出 [1]。 - -## 5. Jupyter Lab 训练示例 - -以下是在 Jupyter Lab 环境中使用 `trainer.Trainer` 类训练输入法模型的完整示例: - -```python -# %% [markdown] -# # 输入法模型训练示例 -# 本笔记本展示如何使用 trainer.Trainer 类训练输入法模型 - -# %% [code] -# 1. 导入必要的库 -import sys -import os -from pathlib import Path -from datetime import datetime - -import torch -from torch.utils.data import DataLoader - -# 添加项目路径(适应不同的Jupyter Lab运行位置) -project_root = Path.cwd() -# 检查当前目录是否包含src目录,如果不包含则使用父目录 -if not (project_root / "src").exists(): - project_root = project_root.parent -sys.path.insert(0, str(project_root)) # 优先搜索项目目录 - -# 导入项目模块 -from src.model.model import InputMethodEngine -from src.model.dataset import PinyinInputDataset -from src.model.trainer import Trainer, worker_init_fn, collate_fn - -# %% [code] -# 2. 配置训练参数 -config = { - # 数据参数 - "train_data_path": "/path/to/your/train/dataset", # 替换为训练数据集路径 - "eval_data_path": "/path/to/your/eval/dataset", # 替换为评估数据集路径 - "output_dir": "./training_output", - - # 模型参数 - "vocab_size": 10019, - "pinyin_vocab_size": 30, - "dim": 512, - "num_slots": 8, - "n_layers": 4, - "n_heads": 4, - "num_experts": 20, - "max_seq_len": 128, - - # 训练参数 - "batch_size": 64, # 根据GPU内存调整 - "num_epochs": 10, - "learning_rate": 3e-4, - "min_learning_rate": 1e-9, - "weight_decay": 0.1, - "warmup_ratio": 0.1, - "label_smoothing": 0.15, - "grad_accum_steps": 2, # 梯度累积,模拟更大batch size - "clip_grad_norm": 1.0, - "eval_frequency": 500, # 每500步评估一次 - "save_frequency": 2000, # 每2000步保存检查点 - - # 高级选项 - "mixed_precision": True, - "use_tensorboard": True, - "seed": 42, - "max_iter_length": 1024 * 1024 * 128, # 最大迭代长度 -} - -# %% [code] -# 3. 设置随机种子和设备 -torch.manual_seed(config["seed"]) -if torch.cuda.is_available(): - torch.cuda.manual_seed_all(config["seed"]) - device = torch.device("cuda") - print(f"✅ 使用 GPU: {torch.cuda.get_device_name(0)}") -else: - device = torch.device("cpu") - print("⚠️ 使用 CPU 进行训练(建议使用 GPU 以获得更好性能)") - -# %% [code] -# 4. 创建数据集和数据加载器 -print("📊 创建数据集和数据加载器...") - -# 训练数据集 -train_dataset = PinyinInputDataset( - data_path=config["train_data_path"], - max_workers=-1, # 自动选择worker数量 - max_iter_length=config["max_iter_length"], - max_seq_length=config["max_seq_len"], - text_field="text", - py_style_weight=(9, 2, 1), - shuffle_buffer_size=5000, - length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, -) - -# 训练数据加载器 -train_dataloader = DataLoader( - train_dataset, - batch_size=config["batch_size"], - num_workers=min(max(1, (os.cpu_count() or 1) - 1), 8), # 合理数量的worker - pin_memory=torch.cuda.is_available(), - worker_init_fn=worker_init_fn, - collate_fn=collate_fn, - prefetch_factor=32, - persistent_workers=True, -) - -# 评估数据集 -eval_dataset = PinyinInputDataset( - data_path=config["eval_data_path"], - max_workers=-1, - max_iter_length=1024, # 评估集较小 - max_seq_length=config["max_seq_len"], - text_field="text", - py_style_weight=(9, 2, 1), - shuffle_buffer_size=1000, - length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, -) - -eval_dataloader = DataLoader( - eval_dataset, - batch_size=config["batch_size"], - num_workers=1, - pin_memory=torch.cuda.is_available(), - worker_init_fn=worker_init_fn, - collate_fn=collate_fn, - prefetch_factor=32, - persistent_workers=True, -) - -print(f"✅ 数据加载器创建完成") -print(f" 训练批次大小: {config['batch_size']}") -print(f" 预估训练步数: {config['max_iter_length'] // config['batch_size']}") - -# %% [code] -# 5. 创建模型 -print("🧠 创建输入法模型...") - -model = InputMethodEngine( - vocab_size=config["vocab_size"], - pinyin_vocab_size=config["pinyin_vocab_size"], - dim=config["dim"], - num_slots=config["num_slots"], - n_layers=config["n_layers"], - n_heads=config["n_heads"], - num_experts=config["num_experts"], - max_seq_len=config["max_seq_len"], -) - -# 将模型移动到设备 -model.to(device) - -# 计算参数量 -total_params = sum(p.numel() for p in model.parameters()) -trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - -print(f"✅ 模型创建完成") -print(f" 总参数量: {total_params:,}") -print(f" 可训练参数量: {trainable_params:,}") -print(f" 模型架构: {config['n_layers']}层Transformer, {config['dim']}维度, {config['num_experts']}个MoE专家") - -# %% [code] -# 6. 创建训练器 -print("⚙️ 创建训练器...") - -# 计算总训练步数 -total_steps = int(config["max_iter_length"] / config["batch_size"]) - -trainer = Trainer( - model=model, - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - total_steps=total_steps, - output_dir=config["output_dir"], - num_epochs=config["num_epochs"], - learning_rate=config["learning_rate"], - min_learning_rate=config["min_learning_rate"], - weight_decay=config["weight_decay"], - warmup_ratio=config["warmup_ratio"], - label_smoothing=config["label_smoothing"], - grad_accum_steps=config["grad_accum_steps"], - clip_grad_norm=config["clip_grad_norm"], - eval_frequency=config["eval_frequency"], - save_frequency=config["save_frequency"], - mixed_precision=config["mixed_precision"], - use_tensorboard=config["use_tensorboard"], -) - -print(f"✅ 训练器创建完成") -print(f" 总训练步数: {total_steps:,}") -print(f" 学习率: {config['learning_rate']:.2e} -> {config['min_learning_rate']:.2e}") -print(f" 输出目录: {config['output_dir']}") - -# %% [code] -# 7. 开始训练 -print("🚀 开始训练...") -print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - -try: - # 开始训练(可以从检查点恢复训练) - trainer.train(resume_from=None) # 设置检查点路径以恢复训练 - - print("✅ 训练完成!") - print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - print(f"模型和日志保存在: {config['output_dir']}") - -except KeyboardInterrupt: - print("⏹️ 训练被用户中断") - print("💾 保存当前检查点...") - trainer.save_checkpoint("interrupted") - print(f"检查点已保存到: {config['output_dir']}/checkpoint_interrupted.pt") - -except Exception as e: - print(f"❌ 训练过程中出现错误: {e}") - import traceback - traceback.print_exc() - -# %% [code] -# 8. 监控训练进度(如果使用TensorBoard) -if config["use_tensorboard"]: - print("📈 TensorBoard日志已记录在:") - print(f" {config['output_dir']}/tensorboard") - print("\n启动TensorBoard查看训练进度:") - print(" tensorboard --logdir ./training_output/tensorboard") - print("然后在浏览器中打开: http://localhost:6006") - -# %% [code] -# 9. 加载训练好的模型进行推理(示例) -def load_trained_model(checkpoint_path): - """加载训练好的模型进行检查点""" - print(f"📥 加载检查点: {checkpoint_path}") - - # 创建与训练时相同配置的模型 - loaded_model = InputMethodEngine( - vocab_size=config["vocab_size"], - pinyin_vocab_size=config["pinyin_vocab_size"], - dim=config["dim"], - num_slots=config["num_slots"], - n_layers=config["n_layers"], - n_heads=config["n_heads"], - num_experts=config["num_experts"], - max_seq_len=config["max_seq_len"], - ) - - # 加载检查点 - checkpoint = torch.load(checkpoint_path, map_location=device) - loaded_model.load_state_dict(checkpoint["model_state_dict"]) - loaded_model.to(device) - loaded_model.eval() - - print(f"✅ 模型加载完成,训练步数: {checkpoint.get('global_step', 'N/A')}") - print(f" 训练损失: {checkpoint.get('train_loss', 'N/A'):.4f}") - - return loaded_model - -# 使用示例(取消注释以使用) -# trained_model = load_trained_model("./training_output/checkpoint_final.pt") -``` - -### 关键说明 - -1. **环境要求**: - - Python 3.12+ - - PyTorch 2.10+ - - 建议使用GPU进行训练 - - 安装项目依赖:`pip install -e .` - -2. **数据集格式**: - - 使用Hugging Face `datasets`格式 - - 必须包含`text`字段 - - 支持流式读取(streaming=True) - -3. **训练监控**: - - 控制台输出训练进度和指标 - - TensorBoard记录损失、准确率、学习率等 - - 定期保存模型检查点 - -4. **可调整参数**: - - `batch_size`: 根据GPU内存调整 - - `learning_rate`: 建议在1e-4到5e-4之间 - - `grad_accum_steps`: 模拟更大batch size - - `num_epochs`: 根据数据集大小调整 - -5. **故障排除**: - - GPU内存不足:减小`batch_size`或增加`grad_accum_steps` - - 训练不稳定:降低`learning_rate`或增加`warmup_ratio` - - 过拟合:增加`label_smoothing`或使用更大数据集 - -## 6. 使用指南 - -本项目的训练功能通过命令行工具 `train-model` 提供,支持训练、评估和导出模型。 - -### 6.1 安装与准备 - -#### 使用 uv(推荐) -本项目使用 [`uv`](https://github.com/astral-sh/uv) 作为Python包管理器,它比传统的 pip 更快且更可靠。 - -1. **安装 uv**(如果尚未安装): - ```bash - # Linux/macOS - curl -LsSf https://astral.sh/uv/install.sh | sh - - # 或使用 pipx - pipx install uv - - # Windows (PowerShell) - powershell -c "irm https://astral.sh/uv/install.ps1 | iex" - ``` - -2. **安装项目依赖**: - ```bash - uv pip install -e . - ``` - -#### 使用传统 pip -如果不使用 uv,也可以用标准的 pip 安装: - -```bash -# 创建并激活虚拟环境(推荐) -python -m venv .venv -source .venv/bin/activate # Linux/macOS -# .venv\Scripts\activate # Windows - -# 安装依赖 -pip install -e . -``` - -#### 验证安装 -安装完成后,可通过以下命令验证: -```bash -train-model --help -``` - -### 6.2 数据格式 - -训练数据应为Hugging Face数据集格式,支持本地文件或远程数据集仓库。数据集需包含 `text` 字段,并支持流式读取(streaming=True)。 - -#### 本地数据集示例 -```python -# dataset.py -from datasets import Dataset - -data = { - "text": ["这是第一个样本文本。", "这是第二个样本,用于训练输入法模型。"] -} -dataset = Dataset.from_dict(data) -dataset.save_to_disk("./local_dataset") -``` - -#### 远程数据集示例 -支持Hugging Face Hub或ModelScope上的数据集: -- `huggingface.co/datasets/username/dataset_name` -- `modelscope.cn/datasets/username/dataset_name` - -#### 数据格式要求 -- **必需字段**: `text`(字符串类型,包含中文文本) -- **流式读取**: 数据集必须支持 `streaming=True` 参数 -- **数据量**: 建议至少数百万条文本以获得良好效果 - -#### 数据预处理 -数据集会自动进行以下处理: -1. 文本分词和编码 -2. 拼音转换和编码 -3. 上下文窗口滑动生成训练样本 -4. 频率调整(削峰填谷)以平衡高频/低频字词 - -### 6.3 基本训练命令 - -使用 `train-model train` 命令开始训练: - -```bash -train-model train \ - --train-data-path "path/to/train/dataset" \ - --eval-data-path "path/to/eval/dataset" \ - --output-dir "./output" \ - --batch-size 128 \ - --num-epochs 10 \ - --learning-rate 1e-5 -``` - -#### 检查点恢复训练 - -要从检查点恢复训练(保持原有的训练状态): - -```bash -train-model train \ - --train-data-path "path/to/train/dataset" \ - --eval-data-path "path/to/eval/dataset" \ - --resume-from "./output/checkpoints/latest_checkpoint.pt" -``` - -#### 重置训练状态 - -如果只想加载模型权重,从头开始训练(学习率、epoch等都重新开始): - -```bash -train-model train \ - --train-data-path "path/to/train/dataset" \ - --eval-data-path "path/to/eval/dataset" \ - --resume-from "./output/checkpoints/best_model.pt" \ - --reset-training-state -``` - -这个功能在以下场景非常有用: -- 想要用预训练权重初始化模型,但用新的训练计划重新训练 -- 需要调整学习率策略或训练时长 -- 在现有模型基础上进行迁移学习 - -#### 学习率建议 -根据模型架构和超参数配置(4层Transformer,512维度),推荐使用以下学习率范围: -- **标准范围**: 1e-4 ~ 5e-4 -- **配合Warmup策略**:在训练初期逐步提高学习率 -- **余弦退火**:使用最小学习率 1e-9 进行细调 - -### 6.4 参数详解 - -#### 数据参数 -- `--train-data-path`, `-t`: 训练数据集路径(必需) -- `--eval-data-path`, `-e`: 评估数据集路径(必需) -- `--output-dir`, `-o`: 输出目录(默认:`./output`) -- `--max_iter_length`: 最大迭代长度,控制每次训练迭代处理的数据量(默认:134217728) - -#### 模型参数 -- `--vocab-size`: 词汇表大小(默认:10019) -- `--pinyin-vocab-size`: 拼音词汇表大小(默认:30) -- `--dim`: 模型维度(默认:512) -- `--num-slots`: 历史槽位数量(默认:8) -- `--n-layers`: Transformer层数(默认:4) -- `--n-heads`: 注意力头数(默认:4) -- `--num-experts`: MoE专家数量(默认:20) -- `--max-seq-len`: 最大序列长度(默认:128) -- `--use-pinyin`: 是否使用拼音特征(默认:False) - -#### 训练参数 -- `--batch-size`, `-b`: 批次大小(默认:128) -- `--num-epochs`: 训练轮数(默认:10) -- `--learning-rate`, `-lr`: 学习率(默认:1e-5) -- `--min-learning-rate`: 最小学习率(默认:1e-9) -- `--weight-decay`: 权重衰减(默认:0.1) -- `--warmup-ratio`: 热身步数比例(默认:0.1) -- `--label-smoothing`: 标签平滑参数(默认:0.15) -- `--grad-accum-steps`: 梯度累积步数(默认:1) -- `--clip-grad-norm`: 梯度裁剪范数(默认:1.0) -- `--eval-frequency`: 评估频率(默认:500步) -- `--save-frequency`: 保存频率(默认:10000步) - -#### 高级选项 -- `--mixed-precision/--no-mixed-precision`: 是否使用混合精度训练(默认:启用) -- `--tensorboard/--no-tensorboard`: 是否使用TensorBoard(默认:启用) -- `--resume-from`: 从检查点恢复训练(可选) -- `--reset-training-state`: 重置训练状态,只加载模型权重从头开始训练(默认:False) -- `--seed`: 随机种子(默认:42) - -### 6.5 监控训练进度 - -训练过程中会显示: -- 当前训练步数/总步数 -- 损失值和准确率 -- 学习率变化 -- 内存使用情况 - -启用TensorBoard后,可使用以下命令查看可视化结果: - -```bash -tensorboard --logdir ./output/tensorboard -``` - -### 6.6 基于JSON旁路记录法的移动端监控方案 - -为了提供移动端友好的训练监控体验,我们实现了基于JSON旁路记录法的监控方案。该方案在保持TensorBoard记录的同时,额外写入一份JSON状态文件,并通过Streamlit提供移动端友好的Web界面。 - -#### 方案特点 - -**📱 移动端体验** -- Streamlit自动生成响应式界面,完美适配手机屏幕 -- 图表支持双指缩放和滑动操作 -- 大字体显示核心指标,触控操作便捷 - -**🚀 低耦合架构** -- 训练和监控通过文件系统解耦 -- 监控服务重启不影响训练进程 -- 训练脚本只需几行代码修改即可支持 - -**🔒 安全稳定** -- 纯文本JSON文件,无文件锁冲突问题 -- 读写速度快,稳定性高 -- 不会影响原有的TensorBoard记录流程 - -**📊 实时监控** -- 默认每5秒自动刷新数据 -- 实时显示训练进度和指标趋势 -- 数据新鲜度状态指示(实时/较新/较旧/陈旧) - -#### 使用方法 - -**启动监控服务** -```bash -# 启动监控服务(默认端口8501) -monitor-training monitor - -# 指定状态文件路径和端口 -monitor-training monitor --status-file ./output/training_status.json --port 8080 - -# 不自动打开浏览器 -monitor-training monitor --no-browser - -# 指定自定义Streamlit脚本 -monitor-training monitor --streamlit-script ./custom_monitor.py -``` - -**查看训练状态** -```bash -# 查看最近10条训练记录 -monitor-training view - -# 查看最近50条记录(原始JSON格式) -monitor-training view --limit 50 --raw - -# 查看指定状态文件 -monitor-training view /path/to/status.json -``` - -**检查状态文件** -```bash -# 检查状态文件状态 -monitor-training check - -# 检查指定文件 -monitor-training check ./output/training_status.json -``` - -**启动HTTP静态文件服务** -```bash -# 启动HTTP静态文件服务(默认端口8080) -monitor-training serve - -# 指定状态文件路径和端口 -monitor-training serve --status-file ./output/training_status.json --port 8080 - -# 禁用CORS支持(默认启用) -monitor-training serve --no-cors - -# 指定主机地址 -monitor-training serve --host 0.0.0.0 --port 8080 -``` - -#### 监控界面功能 - -**📊 核心指标看板** -- 当前步数、轮次、训练损失、准确率 -- 评估损失和准确率 -- 当前学习率、最后更新时间 - -**📈 趋势图表** -- 损失曲线(训练损失 + 评估损失) -- 准确率曲线(训练准确率 + 评估准确率) -- 学习率变化图(对数坐标) - -**📋 数据详情** -- 完整的训练记录表格 -- 数据统计信息(总数据点、训练时长、总步数) -- 训练进度条 - -**⚙️ 配置选项** -- 状态文件路径(支持环境变量 `TRAINING_STATUS_FILE`) -- 自动刷新间隔(1-30秒可调) -- 显示数据点数量(10-1000条可调) - -#### 技术实现 - -**训练端改造** -- 在 `Trainer.__init__` 中添加 `status_file` 参数 -- 实现 `_write_training_status()` 方法,在每次评估时写入JSON文件 -- 支持从现有状态文件恢复,避免数据丢失 - -**监控端搭建** -- 使用Streamlit构建移动端友好的Web界面 -- 采用Plotly图表库,支持触控交互 -- 自动刷新机制,实时更新训练状态 - -**命令行工具** -- 提供 `monitor`、`view`、`check` 三个子命令 -- 自动检测Streamlit可用性 -- 支持环境变量传递 - -#### 访问方式 - -**本地访问** -```bash -# 启动监控服务后,通过浏览器访问 -http://localhost:8501 -``` - -**局域网访问** -```bash -# 启动服务时指定主机地址 -monitor-training monitor --host 0.0.0.0 --port 8080 - -# 手机浏览器访问(同一局域网) -http://192.168.1.100:8080 -``` - -**公网访问**(需端口转发) -```bash -# 确保服务器防火墙开放对应端口 -# 通过域名或公网IP访问 -http://your-server.com:8501 -``` - -**远程HTTP监控** -```bash -# GPU服务器启动HTTP服务 -monitor-training serve --port 8080 --host 0.0.0.0 - -# 本地运行Streamlit监控,从HTTP URL读取数据 -monitor-training monitor --host 127.0.0.1 --port 8501 - -# 在Streamlit界面输入远程URL: -http://:8080/training_status.json -``` - -#### 状态文件格式 - -状态文件 `training_status.json` 位于训练输出目录,格式如下: -```json -[ - { - "step": 100, - "epoch": 1, - "timestamp": "2024-01-01T12:00:00", - "train/loss": 2.345, - "train/accuracy": 0.456, - "eval/loss": 2.123, - "eval/accuracy": 0.512, - "train/learning_rate": 0.0001 - }, - ... -] -``` - -#### HTTP静态文件服务与远程监控 - -针对GPU服务器只支持HTTP协议(不支持WebSockets)的环境,我们提供了HTTP静态文件服务方案,实现远程训练监控。 - -**🔧 技术特点** -- 纯HTTP协议,无需WebSockets支持 -- 原子写入机制,避免读取不完整JSON数据 -- 自动重试和JSON验证,确保数据完整性 -- CORS支持,方便跨域访问 -- 轻量级设计,不影响训练性能 - -**🚀 工作原理** -1. GPU服务器:训练进程通过原子写入机制更新`training_status.json`文件 -2. GPU服务器:运行`monitor-training serve`提供HTTP静态文件服务 -3. 本地机器:运行`monitor-training monitor`启动Streamlit监控界面 -4. 本地机器:在Streamlit界面输入HTTP URL访问远程数据 -5. Streamlit:通过HTTP轮询获取实时训练数据并展示 - -**🛡️ 数据安全** -- 原子写入:先写入临时文件,然后原子重命名,避免读取中断 -- JSON验证:HTTP服务端验证JSON格式后才返回数据 -- 临时文件处理:智能识别和读取`.tmp`临时文件 -- 重试机制:JSON解析失败时自动重试读取 - -**🌐 网络要求** -- GPU服务器:需要开放HTTP端口(默认8080) -- 本地机器:需要能访问GPU服务器的HTTP端口 -- 网络协议:纯HTTP,兼容防火墙和代理 - -#### 注意事项 -1. 首次监控时如果状态文件不存在,会自动创建空文件 -2. 需要安装 `plotly` 依赖用于图表绘制:`pip install plotly>=5.0.0` -3. 从检查点恢复训练时会自动加载已有的状态数据 -4. 建议将监控服务与训练服务部署在同一服务器,避免网络延迟 -5. HTTP服务支持原子写入,避免训练进程写入时读取不完整JSON -6. 远程监控需要确保GPU服务器防火墙开放对应HTTP端口 -7. 建议使用`--host 0.0.0.0`参数使HTTP服务可被远程访问 - - -### 6.7 评估模型(开发中) - -当前评估功能尚在开发中: - -```bash -train-model evaluate \ - --checkpoint "./output/checkpoint_final.pt" \ - --data-path "path/to/eval/dataset" \ - --batch-size 32 -``` - -命令将显示"评估功能待实现"的提示信息。该功能计划用于: -- 加载训练好的模型检查点 -- 在评估数据集上计算准确率、困惑度等指标 -- 生成详细的性能报告 - -### 6.8 模型扩容两阶段训练 - -当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。 - -#### 训练策略 - -1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等) -2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练 - -#### 基础用法 - -```bash -train-model expand-and-train \ - --train-data-path "path/to/train/dataset" \ - --eval-data-path "path/to/eval/dataset" \ - --base-model-path "./pretrained/model.pt" \ - --new-model-spec "model:InputMethodEngine" \ - --num-experts 40 \ - --frozen-lr 2e-3 \ - --full-lr 5e-5 \ - --frozen-patience 8 -``` - -#### 完整参数示例 - -```bash -train-model expand-and-train \ - --train-data-path "path/to/train/dataset" \ - --eval-data-path "path/to/eval/dataset" \ - --output-dir "./expansion_output" \ - --base-model-path "./pretrained/model.pt" \ - --new-model-spec "custom_model:ExpandedModel" \ - --vocab-size 10019 \ - --dim 512 \ - --num-experts 40 \ - --frozen-patience 10 \ - --frozen-lr 1e-3 \ - --full-lr 1e-4 \ - --frozen-scheduler cosine \ - --full-scheduler cosine \ - --batch-size 128 \ - --num-epochs 20 \ - --compile -``` - -#### 参数详解 - -**模型扩容参数** -- `--base-model-path`: 预训练基础模型检查点路径(必需) -- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需) - - 支持任意路径的模块导入,模块文件需包含自定义的模型类 - - 自定义模型类必须是 `InputMethodEngine` 的子类 - - 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类 - -**两阶段训练参数** -- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10) -- `--frozen-lr`: 冻结阶段学习率(默认:1e-3) -- `--full-lr`: 全量微调阶段学习率(默认:1e-4) -- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`) -- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`) - -**其他参数** -- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等) -- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等 - -#### 使用场景 - -1. **增加专家数量**(20→40) - - 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等) - - 新增参数:新专家网络、gate层 - -2. **增加top_k值**(2→3) - - 冻结效果:100% 参数可冻结(仅逻辑变化) - - 新增参数:无 - -3. **修改专家内部结构**(如增加resblocks) - - 冻结效果:~50% 参数可冻结(linear_in/output可冻结) - - 新增参数:新增的resblocks层 - -4. **增加Transformer层数**(4→5) - - 冻结效果:~80% 参数可冻结(前4层可冻结) - - 新增参数:新增的第5层 - -#### 自定义模型类示例 - -```python -# my_model.py -from model.model import InputMethodEngine - -class MyExpandedModel(InputMethodEngine): - def __init__(self, num_experts=40, **kwargs): - # 调用父类构造函数,覆盖num_experts参数 - super().__init__(num_experts=num_experts, **kwargs) - # 可以在这里添加额外的层或修改现有层 - -# 使用命令 -# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ... -``` - -#### 注意事项 - -1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类 -2. **冻结条件**:只有权重形状完全匹配的层才会被冻结 -3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能 -4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency` -5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载 - -### 6.9 导出模型(开发中) - -当前导出功能尚在开发中: - -```bash -train-model export \ - --checkpoint "./output/checkpoint_final.pt" \ - --output "./exported_model.onnx" -``` - -命令将显示"导出功能待实现"的提示信息。该功能计划用于: -- 将PyTorch模型转换为ONNX格式 -- 支持在不同推理引擎上部署 -- 提供优化后的推理模型 - -## 7. 总结 -本方案通过**单流 Transformer 编码**结合**结构化槽位交叉注意力**,并引入**20个专家的 MoE 模块** [1],在保证模型轻量(4层 Transformer)的同时,有效利用了历史输入习惯并提升了模型表达上限。相比暴力拼接或双流架构,该设计在工程实现上更优雅,在推理效率上更高效,是轻量级输入法模型的局部最优解。 \ No newline at end of file diff --git a/eval.py b/eval.py index 0eb5a55..4429e83 100644 --- a/eval.py +++ b/eval.py @@ -274,8 +274,43 @@ class TextEvaluator: part1 = text[i - 48 : i] # part2: 拼音输入(随机长度1-8,高斯分布) + # 当强制指定历史槽位时,需要确保拼音长度足够 pinyin_len_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04] - pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs) + min_pinyin_len = (force_slot_count + 1) if force_slot_count is not None else 1 + min_pinyin_len = max(1, min(8, min_pinyin_len)) + + # 检查剩余可用字符数量 + remaining_chars = len(text) - i + max_valid_len = 0 + for j in range(i, min(i + 8, len(text))): + if self.query_engine and self.query_engine.is_chinese_char(text[j]): + max_valid_len += 1 + else: + break + + # 调整拼音长度:必须至少有 min_pinyin_len 个有效字符 + if max_valid_len < min_pinyin_len: + # 重新选择位置,确保有足够字符 + raise ValueError( + f"位置 {i} 后只有 {max_valid_len} 个有效汉字,不足以测试历史槽位 {force_slot_count}" + ) + + # 随机选择拼音长度(至少 min_pinyin_len) + if force_slot_count is not None: + # 强制模式:随机选择比 min_pinyin_len 大的长度 + valid_lengths = [ + l for l in range(min_pinyin_len, min(max_valid_len + 1, 9)) + ] + if not valid_lengths: + valid_lengths = [min_pinyin_len] + # 使用原来的概率分布,但只选择有效长度 + adjusted_probs = [pinyin_len_probs[l - 1] for l in valid_lengths] + adjusted_probs = np.array(adjusted_probs) / sum(adjusted_probs) + pinyin_len = np.random.choice(valid_lengths, p=adjusted_probs) + else: + pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs) + pinyin_len = min(pinyin_len, max_valid_len) + py_end = min(i + pinyin_len, len(text)) # 获取拼音 @@ -317,7 +352,7 @@ class TextEvaluator: string_list.append(text[start_pos:end_pos]) part4 = "|".join(string_list) - # 标签:预测字符的ID + # 标签:预测字符的ID(整个拼音序列的所有字符) labels = [] try: labels = [ @@ -328,27 +363,38 @@ class TextEvaluator: ) ] except (AttributeError, TypeError) as e: - # 如果查询失败,使用默认ID print(f"⚠️ 获取标签失败: {e}") labels = [0] * pinyin_len_actual - # 历史槽位:模拟用户逐步确认过程 - # 对于多字符预测(pinyin_len_actual > 1),需要模拟用户逐步选择 - # 注意:评估时不知道正确答案,只能基于模型预测来构建历史 + # 历史槽位:与训练数据生成逻辑一致 + # 当 force_slot_count=k 时,历史槽位为 labels[:k],预测目标为 labels[k] history_slot_ids = [] + predict_idx = 0 # 当前预测的标签索引 - # 如果强制指定槽位数量,使用模拟的历史(与训练数据分布一致) if force_slot_count is not None: force_slot_count = max(0, min(8, force_slot_count)) - # 创建模拟的历史槽位:前force_slot_count个为有效槽位,其余为0 - # 这里使用简单的模拟:有效槽位用1-1000的随机ID(与训练时随机填充逻辑一致) - for _ in range(force_slot_count): - history_slot_ids.append(random.randint(1, 1000)) + # 检查 labels 长度是否足够 + if len(labels) <= force_slot_count: + raise ValueError( + f"标签数量 {len(labels)} 不足以支持历史槽位 {force_slot_count}" + ) + predict_idx = force_slot_count + history_slot_ids = labels[:force_slot_count] else: - # 正常评估:历史槽位为空(模拟从零开始输入) - # 注意:实际使用时,历史槽位应该来自之前的用户选择 - # 但为了评估公平,我们从空历史开始 - pass + # 正常评估:随机选择历史长度,模拟训练数据分布 + history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0] + max_history = min(len(labels) - 1, 8) # 至少留一个字符预测 + if max_history < 0: + max_history = 0 + valid_history_lens = list(range(0, max_history + 1)) + if valid_history_lens: + adjusted_probs = [history_weights[h] for h in valid_history_lens] + adjusted_probs = np.array(adjusted_probs) / sum(adjusted_probs) + history_len = np.random.choice(valid_history_lens, p=adjusted_probs) + else: + history_len = 0 + predict_idx = history_len + history_slot_ids = labels[:history_len] # 填充到8个槽位 if len(history_slot_ids) < 8: @@ -373,6 +419,12 @@ class TextEvaluator: # 拼音输入长度(字符数) pinyin_input_length = len(part2) + # 当前预测的字符和标签 + current_target_char = ( + text[i + predict_idx] if (i + predict_idx) < len(text) else "" + ) + current_target_label = labels[predict_idx] if predict_idx < len(labels) else 0 + # 构建样本 sample = { "input_ids": encoded["input_ids"], @@ -385,11 +437,13 @@ class TextEvaluator: "suffix": part3, "pinyin": part2, "pinyin_ids": pinyin_ids_tensor.unsqueeze(0), - "true_labels": labels, # 真实标签(多个字符) + "true_labels": labels, # 完整标签列表(所有字符) + "predict_idx": predict_idx, # 当前预测的标签索引 + "current_target_label": current_target_label, # 当前预测的标签 "position": i, - "target_char": text[i] if i < len(text) else "", - "valid_slot_count": valid_slot_count, # 有效槽位数量 - "pinyin_input_length": pinyin_input_length, # 拼音输入长度 + "target_char": current_target_char, # 当前预测的字符 + "valid_slot_count": valid_slot_count, + "pinyin_input_length": pinyin_input_length, } return sample @@ -483,56 +537,50 @@ class TextEvaluator: return analysis - def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict: + def evaluate_sample_single_char( + self, sample: Dict, print_details: bool = True + ) -> Dict: """ - 评估单个样本。 + 单字符评估(与验证集 trainer.py 评估逻辑一致)。 + + 使用 argmax 选择预测字符,与训练时验证逻辑完全一致。 + 根据 predict_idx 选择正确的预测目标。 Returns: 评估结果字典 """ - # 执行推理 logits, probs = self.inference(sample) - - # 分析概率分布 analysis = self.analyze_probability_distribution(probs) - # 获取真实标签 true_labels = sample.get("true_labels", []) + predict_idx = sample.get("predict_idx", 0) + current_target_label = sample.get( + "current_target_label", true_labels[predict_idx] if true_labels else 0 + ) target_char = sample.get("target_char", "") - # 检查预测是否正确(只检查第一个字符) - correct = False - pred_top_idx = analysis["top_indices"][0] - if true_labels and len(true_labels) > 0: - correct = pred_top_idx == true_labels[0] + pred_idx = analysis["top_indices"][0] + correct = pred_idx == current_target_label - # 打印结果 if print_details: - # 获取槽位信息 valid_slot_count = sample.get("valid_slot_count", 0) pinyin_input_length = sample.get("pinyin_input_length", 0) - # 简化输出格式(在同一行继续输出) print(f" {target_char}", end="") - if true_labels: - print(f"(ID:{true_labels[0]})", end="") + print(f"(ID:{current_target_label})", end="") print( f" | 拼音:{sample.get('pinyin', '')[:10]}{'...' if len(sample.get('pinyin', '')) > 10 else ''}", end="", ) print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_input_length}", end="") - - # 关键概率信息 print(f" | 最高概率:{analysis['max_prob']:.3f}", end="") if analysis["low_variance"]: print(f" ⚠️", end="") - # Top-2 预测 top_strs = [] for j in range(min(2, len(analysis["top_indices"]))): idx = analysis["top_indices"][j] prob = analysis["top_probs"][j] - # 处理特殊ID if idx == 0: char = "[空]" else: @@ -542,34 +590,43 @@ class TextEvaluator: else None ) char = char_info.char if char_info else f"[{idx}]" - is_correct = ( - (true_labels and idx == true_labels[0]) if true_labels else False - ) + is_correct = idx == current_target_label correct_mark = "✓" if is_correct else "" top_strs.append(f"{char}{correct_mark}({prob:.2f})") print(f" | Top-2: {' '.join(top_strs)}", end="") - print(f" | 结果:{'✓' if correct else '✗'}") - # 返回评估结果 - result = { + return { "sample": sample, "analysis": analysis, "correct": correct, "target_char": target_char, - "true_label": true_labels[0] if true_labels else None, - "predicted_label": pred_top_idx, + "true_label": current_target_label, + "predicted_label": pred_idx, + "predict_idx": predict_idx, } - return result + def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict: + """ + 评估单个样本(兼容旧接口,调用单字符评估)。 + + Returns: + 评估结果字典 + """ + return self.evaluate_sample_single_char(sample, print_details) def evaluate_sample_with_sequential_confirmation( - self, sample: Dict, print_details: bool = True + self, sample: Dict, print_details: bool = True, use_oracle_history: bool = True ) -> Dict: """ 评估样本并模拟用户逐步确认过程(用于多字符预测)。 - 对于拼音长度>1的情况,模拟用户逐步选择Top-1预测作为已确认字符。 + Args: + sample: 样本字典 + print_details: 是否打印详细信息 + use_oracle_history: True=使用真实标签作为历史(与训练/验证一致),False=使用预测结果 + + 对于拼音长度>1的情况,模拟用户逐步选择过程。 Returns: 评估结果字典 @@ -583,7 +640,7 @@ class TextEvaluator: return self.evaluate_sample(sample, print_details) # 多字符预测:模拟逐步确认 - confirmed_history = [] # 用户已确认的字符ID + confirmed_history = [] # 已确认的字符ID all_predictions = [] # 每一步的预测结果 all_correct = [] # 每一步是否正确 @@ -591,8 +648,16 @@ class TextEvaluator: sample_copy = sample.copy() for step in range(pinyin_len): - # 更新历史槽位:使用已确认的字符 - history_slot_ids = confirmed_history[:] + # 更新历史槽位: + # - oracle模式:使用真实标签作为历史(与训练一致) + # - user模式:使用模型预测作为历史(模拟实际使用) + if use_oracle_history: + # 使用真实标签作为历史(与训练数据生成逻辑一致) + history_slot_ids = true_labels[:step] + else: + # 使用预测结果作为历史 + history_slot_ids = confirmed_history[:] + if len(history_slot_ids) < 8: history_slot_ids.extend([0] * (8 - len(history_slot_ids))) else: @@ -626,9 +691,8 @@ class TextEvaluator: ) all_correct.append(correct) - # 模拟用户选择:将Top-1预测添加到已确认历史 - # 注意:这里假设用户总是选择Top-1结果 - if pred_top_idx > 0: # 只添加有效ID + # 更新已确认历史(用于下一轮) + if pred_top_idx > 0: confirmed_history.append(pred_top_idx) # 计算整体准确率:所有步骤都正确才算正确 @@ -643,6 +707,8 @@ class TextEvaluator: print(f" {target_char}", end="") print(f" | 拼音:{pinyin[:10]}{'...' if len(pinyin) > 10 else ''}", end="") print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_len}", end="") + mode_str = "oracle" if use_oracle_history else "user" + print(f" | [{mode_str}]", end="") # 显示逐步预测结果 step_results = [] @@ -672,13 +738,16 @@ class TextEvaluator: return result - def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]: + def evaluate_text( + self, text: str, num_samples: int = 10, use_sequential: bool = False + ) -> List[Dict]: """ 评估给定文本,生成多个样本进行评估。 Args: text: 输入文本 num_samples: 生成的样本数量 + use_sequential: True=使用逐步确认评估,False=使用单字符评估(与验证集一致) Returns: 评估结果列表 @@ -699,10 +768,15 @@ class TextEvaluator: try: sample = self.create_sample_from_text(text) - # 使用逐步确认评估 - result = self.evaluate_sample_with_sequential_confirmation( - sample, print_details=True - ) + # 默认使用单字符评估(与验证集一致) + if use_sequential: + result = self.evaluate_sample_with_sequential_confirmation( + sample, print_details=True, use_oracle_history=True + ) + else: + result = self.evaluate_sample_single_char( + sample, print_details=True + ) results.append(result) except Exception as e: @@ -718,7 +792,7 @@ class TextEvaluator: return results def evaluate_by_slot_count( - self, text: str, samples_per_slot: int = 100 + self, text: str, samples_per_slot: int = 100, use_sequential: bool = False ) -> Dict[int, Dict]: """ 按槽位数量分别评估,每个槽位数量测试 samples_per_slot 个样本。 @@ -726,6 +800,7 @@ class TextEvaluator: Args: text: 输入文本 samples_per_slot: 每个槽位数量的测试样本数 + use_sequential: True=使用逐步确认评估,False=使用单字符评估(与验证集一致) Returns: 按槽位数量分组的评估结果 {slot_count: {results, accuracy, ...}} @@ -735,7 +810,8 @@ class TextEvaluator: text = text[start : start + 300] print(f"\n{'=' * 60}") - print(f"按槽位数量评估 | 每组 {samples_per_slot} 个样本") + mode_str = "逐步确认" if use_sequential else "单字符(与验证集一致)" + print(f"按槽位数量评估 [{mode_str}] | 每组 {samples_per_slot} 个样本") print(f"{'=' * 60}") slot_results = {} @@ -750,12 +826,19 @@ class TextEvaluator: sample = self.create_sample_from_text( text, force_slot_count=slot_count ) - # 使用逐步确认评估(更符合实际使用场景) - result = self.evaluate_sample_with_sequential_confirmation( - sample, print_details=False - ) + # 默认使用单字符评估(与验证集一致) + if use_sequential: + result = self.evaluate_sample_with_sequential_confirmation( + sample, print_details=False, use_oracle_history=True + ) + else: + result = self.evaluate_sample_single_char( + sample, print_details=False + ) results.append(result) - # 注意:逐步确认评估没有low_variance分析 + # 统计低方差样本 + if "analysis" in result and result["analysis"].get("low_variance"): + low_var_count += 1 except Exception: failed += 1 continue @@ -764,14 +847,15 @@ class TextEvaluator: correct_count = sum(1 for r in results if r["correct"]) accuracy = correct_count / len(results) - # 对于逐步确认评估,计算平均步数准确率 + # 计算平均步数准确率(仅用于逐步确认模式) step_accuracies = [] - for r in results: - if "step_correct" in r: - step_correct = r["step_correct"] - if step_correct: - step_acc = sum(step_correct) / len(step_correct) - step_accuracies.append(step_acc) + if use_sequential: + for r in results: + if "step_correct" in r: + step_correct = r["step_correct"] + if step_correct: + step_acc = sum(step_correct) / len(step_correct) + step_accuracies.append(step_acc) mean_step_accuracy = np.mean(step_accuracies) if step_accuracies else 0 @@ -781,17 +865,25 @@ class TextEvaluator: "correct": correct_count, "total": len(results), "mean_step_accuracy": mean_step_accuracy, + "low_var_count": low_var_count, "failed": failed, } bar_len = 30 filled = int(bar_len * accuracy) bar = "█" * filled + "░" * (bar_len - filled) - print( - f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} " - f"({correct_count:2d}/{len(results):2d}) " - f"| 平均步数准确率: {mean_step_accuracy:.1%}" - ) + if use_sequential: + print( + f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} " + f"({correct_count:2d}/{len(results):2d}) " + f"| 平均步数准确率: {mean_step_accuracy:.1%}" + ) + else: + print( + f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} " + f"({correct_count:2d}/{len(results):2d}) " + f"| 低方差: {low_var_count}" + ) else: slot_results[slot_count] = { "results": [], @@ -799,6 +891,7 @@ class TextEvaluator: "correct": 0, "total": 0, "mean_step_accuracy": 0, + "low_var_count": 0, "failed": failed, } print(f" 槽位 {slot_count}/8 | 全部样本生成失败") @@ -851,11 +944,11 @@ def main(): parser.add_argument( "--checkpoint", type=str, - default="/home/songsenand/下载/best_model.pt", - help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)", + default="/home/songsenand/下载/best_model.ptrom", + help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.ptrom)", ) parser.add_argument( - "--num-samples", type=int, default=100, help="评估样本数量 (默认: 20)" + "--num-samples", type=int, default=100, help="评估样本数量 (默认: 100)" ) parser.add_argument( "--device", @@ -864,6 +957,11 @@ def main(): choices=["auto", "cpu", "cuda"], help="推理设备 (默认: auto)", ) + parser.add_argument( + "--use-sequential", + action="store_true", + help="使用逐步确认评估模式(默认使用单字符评估,与验证集一致)", + ) args = parser.parse_args() @@ -890,6 +988,9 @@ def main(): print(f"🔧 使用设备: {device}") print(f"📦 模型checkpoint: {args.checkpoint}") print(f"🔬 评估样本数: {args.num_samples}") + print( + f"📊 评估模式: {'逐步确认' if args.use_sequential else '单字符(与验证集一致)'}" + ) # 初始化评估器 try: @@ -902,7 +1003,9 @@ def main(): sys.exit(1) # 执行评估 - evaluator.evaluate_by_slot_count(text, samples_per_slot=args.num_samples) + evaluator.evaluate_by_slot_count( + text, samples_per_slot=args.num_samples, use_sequential=args.use_sequential + ) if __name__ == "__main__": diff --git a/src/model/components.py b/src/model/components.py index f66c83e..1bb6173 100644 --- a/src/model/components.py +++ b/src/model/components.py @@ -28,6 +28,11 @@ class AttentionPooling(nn.Module): # ---------------------------- 拼音LSTM编码器 ---------------------------- class PinyinLSTMEncoder(nn.Module): + """ + 拼音序列编码器,返回每位置的拼音编码。 + 使用双向LSTM,每个位置都能看到前后文信息。 + """ + def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2): super().__init__() self.input_dim = input_dim @@ -35,7 +40,6 @@ class PinyinLSTMEncoder(nn.Module): self.num_layers = num_layers self.dropout = dropout - # Bidirectional LSTM self.lstm = nn.LSTM( input_size=input_dim, hidden_size=self.hidden_dim, @@ -45,7 +49,6 @@ class PinyinLSTMEncoder(nn.Module): dropout=dropout if num_layers > 1 else 0.0, ) - # Project concatenated hidden states to input_dim self.proj = nn.Linear(self.hidden_dim * 2, input_dim) self.layer_norm = nn.LayerNorm(input_dim) @@ -53,32 +56,25 @@ class PinyinLSTMEncoder(nn.Module): """ Args: x: [batch, seq_len, input_dim] pinyin embeddings - mask: [batch, seq_len] optional padding mask (0 for padding) + mask: [batch, seq_len] optional padding mask (True for valid, False for padding) Returns: - pooled: [batch, input_dim] global pinyin representation + output: [batch, seq_len, input_dim] 每位置的拼音编码 """ + total_len = x.size(1) + if mask is not None: - # lengths for pack_padded_sequence - lengths = mask.sum(dim=1).cpu() - # pack sequence + lengths = mask.sum(dim=1).cpu().clamp(min=1) packed = nn.utils.rnn.pack_padded_sequence( x, lengths, batch_first=True, enforce_sorted=False ) packed_out, (hidden, cell) = self.lstm(packed) - # hidden shape: [num_layers * 2, batch, hidden_dim] - # Take last layer's forward and backward hidden states - forward_hidden = hidden[-2, :, :] # last layer forward - backward_hidden = hidden[-1, :, :] # last layer backward - hidden_concat = torch.cat([forward_hidden, backward_hidden], dim=1) + output, _ = nn.utils.rnn.pad_packed_sequence( + packed_out, batch_first=True, total_length=total_len + ) else: - # No mask, assume all sequences same length output, (hidden, cell) = self.lstm(x) - # hidden shape: [num_layers * 2, batch, hidden_dim] - forward_hidden = hidden[-2, :, :] - backward_hidden = hidden[-1, :, :] - hidden_concat = torch.cat([forward_hidden, backward_hidden], dim=1) - projected = self.proj(hidden_concat) + projected = self.proj(output) return self.layer_norm(projected) @@ -170,46 +166,32 @@ class ContextEncoder(nn.Module): def forward(self, text_ids, pinyin_ids, mask=None): """ Args: - text_ids: [batch, seq_len] - pinyin_ids: [batch, seq_len] (假设已对齐,若不对齐需预处理) - mask: [batch, seq_len] optional padding mask + text_ids: [batch, seq_len] 文本 token ids + pinyin_ids: [batch, pinyin_len] 拼音 token ids + mask: [batch, seq_len] optional padding mask (1 for valid, 0 for padding) Returns: - H: [batch, seq_len, 512] Context representation [1] + H: [batch, seq_len, dim] 文本上下文编码 + P: [batch, pinyin_len, dim] 拼音序列编码(每位置) """ - # 1. Embed text - text_emb = self.text_emb(text_ids) # [B, 128, dim] + text_emb = self.text_emb(text_ids) # [B, seq_len, dim] - # 2. Embed and pool pinyin to global feature - pinyin_emb = self.pinyin_emb(pinyin_ids) # [B, 24, dim] - # LSTM encoder with masking for padding - pinyin_mask = pinyin_ids != 0 - pinyin_global = self.pinyin_pooling( - pinyin_emb, mask=pinyin_mask - ) # [B, dim] # 1. Embedding Fusion: Text + Pinyin + Position + seq_len = text_emb.size(1) + pos_ids = torch.arange(seq_len, device=text_ids.device).unsqueeze(0) + x = text_emb + self.pos_emb(pos_ids) - # Broadcast pinyin to all text positions - pinyin_global = pinyin_global.unsqueeze(1) # [B, 1, dim] - pinyin_broadcast = pinyin_global.expand_as(text_emb) # [B, 128, dim] - - # 策略:拼音作为增强特征叠加到文本上,符合轻量级设计 - x = text_emb + pinyin_broadcast - - seq_len = x.size(1) - pos_ids = ( - torch.arange(seq_len, device=x.device).unsqueeze(0).expand_as(text_ids) - ) - x += self.pos_emb(pos_ids) - - # 2. Transformer Encoding - # src_key_padding_mask expects True for padding positions if mask is not None: - # Convert 0/1 mask to bool mask where True is padding src_mask = mask == 0 else: src_mask = None H = self.transformer(x, src_key_padding_mask=src_mask) - return self.ln(H) + H = self.ln(H) + + pinyin_emb = self.pinyin_emb(pinyin_ids) # [B, pinyin_len, dim] + pinyin_mask = pinyin_ids != 0 + P = self.pinyin_pooling(pinyin_emb, mask=pinyin_mask) # [B, pinyin_len, dim] + + return H, P # ------------------------------------------------------------------ @@ -256,76 +238,100 @@ class SlotMemory(nn.Module): # ------------------------------------------------------------------ # 3. 交叉注意力融合 (Cross-Attention Fusion) -# 对应 README: Query=Slots, Key/Value=Context H [1] +# 槽位同时查询文本上下文和拼音序列 # ------------------------------------------------------------------ class CrossAttentionFusion(nn.Module): + """ + 双路交叉注意力: + - 槽位查询文本上下文 (H) + - 槽位查询拼音序列 (P) + 通过注意力机制,start_emb 自动学会关注"当前应该预测的拼音位置" + """ + def __init__(self, dim=512, n_heads=4): super().__init__() self.dim = dim self.n_heads = n_heads self.head_dim = dim // n_heads - assert self.head_dim * n_heads == dim, "dim must be divisible by n_heads" + assert self.head_dim * n_heads == dim - # Linear projections for Q, K, V self.q_proj = nn.Linear(dim, dim, bias=False) - self.k_proj = nn.Linear(dim, dim, bias=False) - self.v_proj = nn.Linear(dim, dim, bias=False) + self.k_text_proj = nn.Linear(dim, dim, bias=False) + self.v_text_proj = nn.Linear(dim, dim, bias=False) + self.k_pinyin_proj = nn.Linear(dim, dim, bias=False) + self.v_pinyin_proj = nn.Linear(dim, dim, bias=False) + self.out_proj = nn.Linear(dim, dim, bias=False) self.ln = nn.LayerNorm(dim) - def forward(self, slots_S, context_H, slot_mask=None, context_mask=None): + def forward( + self, slots_S, context_H, pinyin_P, context_mask=None, pinyin_mask=None + ): """ Args: - slots_S: [batch, num_slots_steps, dim] Query - context_H: [batch, ctx_len, dim] Key/Value - slot_mask: [batch, num_slots_steps] Optional (not used in scaled_dot_product_attention) - context_mask: [batch, ctx_len] Optional padding mask + slots_S: [batch, num_slots, dim] 槽位编码 + context_H: [batch, ctx_len, dim] 文本上下文编码 + pinyin_P: [batch, pinyin_len, dim] 拼音序列编码 + context_mask: [batch, ctx_len] 文本 padding mask (True for padding) + pinyin_mask: [batch, pinyin_len] 拼音 padding mask (True for padding) Returns: - Fused: [batch, num_slots_steps, dim] + fused: [batch, num_slots, dim] """ batch_size, num_slots, _ = slots_S.shape - _, ctx_len, _ = context_H.shape + ctx_len = context_H.size(1) + pinyin_len = pinyin_P.size(1) - # Project queries, keys, values - Q = self.q_proj(slots_S) # [batch, num_slots, dim] - K = self.k_proj(context_H) # [batch, ctx_len, dim] - V = self.v_proj(context_H) # [batch, ctx_len, dim] + Q = self.q_proj(slots_S) + K_text = self.k_text_proj(context_H) + V_text = self.v_text_proj(context_H) + K_pinyin = self.k_pinyin_proj(pinyin_P) + V_pinyin = self.v_pinyin_proj(pinyin_P) - # Reshape for multi-head attention: [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim] Q = Q.view(batch_size, num_slots, self.n_heads, self.head_dim).transpose(1, 2) - K = K.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2) - V = V.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2) + K_text = K_text.view( + batch_size, ctx_len, self.n_heads, self.head_dim + ).transpose(1, 2) + V_text = V_text.view( + batch_size, ctx_len, self.n_heads, self.head_dim + ).transpose(1, 2) + K_pinyin = K_pinyin.view( + batch_size, pinyin_len, self.n_heads, self.head_dim + ).transpose(1, 2) + V_pinyin = V_pinyin.view( + batch_size, pinyin_len, self.n_heads, self.head_dim + ).transpose(1, 2) - # Prepare attention mask if context_mask is provided - attn_mask = None + text_attn_mask = None if context_mask is not None: - # context_mask: [batch, ctx_len] where 0 means padding - # Convert to bool mask and reshape for broadcasting - bool_mask = context_mask == 0 # [batch, ctx_len] - bool_mask = bool_mask[:, None, None, :] # [batch, 1, 1, ctx_len] - # Convert to float mask where True (padding) becomes -inf - attn_mask = bool_mask.float().masked_fill(bool_mask, -1e9) + text_attn_mask = ( + context_mask[:, None, None, :] + .float() + .masked_fill(context_mask[:, None, None, :], -1e9) + ) - # Scaled dot-product attention - attn_output = F.scaled_dot_product_attention( - Q, - K, - V, - attn_mask=attn_mask, - dropout_p=0.0, # no dropout + pinyin_attn_mask = None + if pinyin_mask is not None: + pinyin_attn_mask = ( + pinyin_mask[:, None, None, :] + .float() + .masked_fill(pinyin_mask[:, None, None, :], -1e9) + ) + + text_attn = F.scaled_dot_product_attention( + Q, K_text, V_text, attn_mask=text_attn_mask + ) + pinyin_attn = F.scaled_dot_product_attention( + Q, K_pinyin, V_pinyin, attn_mask=pinyin_attn_mask ) - # Reshape back: [batch, n_heads, num_slots, head_dim] -> [batch, num_slots, dim] - attn_output = ( - attn_output.transpose(1, 2) + combined_attn = text_attn + pinyin_attn + + combined_attn = ( + combined_attn.transpose(1, 2) .contiguous() .view(batch_size, num_slots, self.dim) ) - - # Project back - fused = self.out_proj(attn_output) - - # Residual connection and layer norm + fused = self.out_proj(combined_attn) fused = self.ln(fused + slots_S) return fused diff --git a/src/model/model.py b/src/model/model.py index 67035c4..d20f371 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -113,31 +113,24 @@ class InputMethodEngine(nn.Module): """ batch_size = input_ids.size(0) - # 处理 history_slot_ids:确保为 [batch_size, num_slots] - # 使用 view 替代 if 判断,避免 torch.compile 图断开 history_slot_ids = history_slot_ids.view(-1, self.num_slots) - # 1. 上下文编码 -> H [batch, seq_len, dim] - # 注意:ContextEncoder.forward 接受 text_ids, pinyin_ids, mask - H = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask) + H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask) - # 2. 槽位记忆编码 -> S [batch, num_slots, dim] - S = self.slot_memory(history_slot_ids) # history_slot_ids: [batch, num_slots] + S = self.slot_memory(history_slot_ids) - # 3. 交叉注意力融合 (使用 CrossAttentionFusion) - fused = self.cross_attn(S, H, context_mask=attention_mask) + context_mask = attention_mask == 0 + pinyin_mask = pinyin_ids == 0 + fused = self.cross_attn( + S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask + ) - # 4. MoE 处理 -> [batch, num_slots, dim] moe_out = self.moe(fused) - # 5. 槽位注意力池化 batch_size = input_ids.size(0) - # 计算注意力分数 [batch, num_slots, 1] -> [batch, num_slots] slot_scores = self.slot_attention(moe_out).squeeze(-1) - # 应用softmax获取注意力权重 - slot_weights = torch.softmax(slot_scores, dim=1) # [batch, num_slots] - # 加权求和得到池化表示 - pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) # [batch, dim] + slot_weights = torch.softmax(slot_scores, dim=1) + pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) - logits = self.classifier(pooled) # [batch, vocab_size] + logits = self.classifier(pooled) return logits diff --git a/src/model/trainer.py b/src/model/trainer.py index 38c2f47..f63309d 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -1244,14 +1244,14 @@ def train( max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), - shuffle_buffer_size=100000, + shuffle_buffer_size=2000000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) eval_dataloader = create_dataloader( dataset=eval_dataset, batch_size=batch_size, - num_workers=1, # 评估使用较少的worker + num_workers=2, # 评估使用较少的worker pin_memory=torch.cuda.is_available(), max_iter_length=batch_size * 64, ) @@ -1356,615 +1356,5 @@ def export( console.print("[yellow]导出功能待实现[/yellow]") -@app.command() -def expand_and_train( - # 数据参数 - train_data_path: str = typer.Option( - ..., "--train-data-path", "-t", help="训练数据集路径" - ), - eval_data_path: str = typer.Option( - ..., "--eval-data-path", "-e", help="评估数据集路径" - ), - output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"), - # 模型参数 - base_model_path: str = typer.Option( - ..., "--base-model-path", help="预训练基础模型检查点路径" - ), - new_model_spec: str = typer.Option( - ..., - "--new-model-spec", - "-m", - help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类", - ), - # 数据大小 - max_iter_length: int = typer.Option( - 1024 * 1024 * 128, "--max_iter_length", help="数据集大小" - ), - # 训练参数 - batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"), - num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"), - learning_rate: float = typer.Option(2e-4, "--learning-rate", "-lr", help="学习率"), - min_learning_rate: float = typer.Option( - 1e-9, "--min-learning-rate", help="最小学习率" - ), - weight_decay: float = typer.Option(0.05, "--weight-decay", help="权重衰减"), - warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"), - label_smoothing: float = typer.Option( - 0.1, "--label-smoothing", help="标签平滑参数" - ), - grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"), - clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"), - eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"), - save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"), - # 其他参数 - mixed_precision: bool = typer.Option( - True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练" - ), - num_workers: int = typer.Option( - 2, "--num-workers", help="数据加载worker数量(流式数据集建议为2)" - ), - use_tensorboard: bool = typer.Option( - True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard" - ), - resume_from: Optional[str] = typer.Option( - None, "--resume-from", help="从检查点恢复训练" - ), - reset_training_state: bool = typer.Option( - False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练" - ), - auto_resume: bool = typer.Option( - True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练" - ), - seed: int = typer.Option(42, "--seed", help="随机种子"), - compile: bool = typer.Option( - False, - "--compile/--no-compile", - help="是否开启 torch.compile 优化(需 PyTorch 2.0+)", - ), -): - torch.multiprocessing.set_sharing_strategy("file_system") - - # 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能) - if torch.cuda.is_available(): - torch.set_float32_matmul_precision("high") - - # 设置随机种子 - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - # 硬编码模型参数 - vocab_size = 10019 - pinyin_vocab_size = 30 # 根据 dataset.CHAR_TO_ID 映射 - dim = 512 - num_slots = 8 - n_layers = 4 - n_heads = 4 - num_experts = 10 - max_seq_len = 128 - use_pinyin = True # 始终使用拼音 - console = Console() - - # 打印配置信息 - console.print( - Panel.fit( - "[bold cyan]模型扩容第一阶段训练配置[/bold cyan]", border_style="cyan" - ) - ) - - config_table = Table(show_header=True, header_style="bold magenta") - config_table.add_column("Category", style="cyan") - config_table.add_column("Parameter", style="green") - config_table.add_column("Value", style="yellow") - - # 添加配置信息 - config_table.add_row("数据", "训练数据路径", train_data_path) - config_table.add_row("数据", "评估数据路径", eval_data_path) - config_table.add_row("数据", "输出目录", output_dir) - config_table.add_row("数据", "批次大小", str(batch_size)) - config_table.add_row("数据", "Worker数量", str(num_workers)) - - config_table.add_row("模型", "基础模型路径", base_model_path) - config_table.add_row("模型", "新模型规格", new_model_spec) - config_table.add_row("模型", "词汇表大小", str(vocab_size)) - config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size)) - config_table.add_row("模型", "模型维度", str(dim)) - config_table.add_row("模型", "槽位数量", str(num_slots)) - config_table.add_row("模型", "Transformer层数", str(n_layers)) - config_table.add_row("模型", "注意力头数", str(n_heads)) - config_table.add_row("模型", "MoE专家数", str(num_experts)) - config_table.add_row("模型", "使用拼音", str(use_pinyin)) - config_table.add_row("模型", "编译优化", str(compile)) - - config_table.add_row("训练", "训练轮数", str(num_epochs)) - config_table.add_row("训练", "学习率", f"{learning_rate:.2e}") - config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}") - config_table.add_row("训练", "权重衰减", str(weight_decay)) - config_table.add_row("训练", "热身比例", str(warmup_ratio)) - config_table.add_row("训练", "标签平滑", str(label_smoothing)) - config_table.add_row("训练", "梯度累积", str(grad_accum_steps)) - config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm)) - config_table.add_row("训练", "混合精度", str(mixed_precision)) - - config_table.add_row("其他", "自动恢复", str(auto_resume)) - console.print(config_table) - - # 创建输出目录 - output_path = Path(output_dir) - output_path.mkdir(parents=True, exist_ok=True) - - # 保存配置 - config = { - "train_data_path": train_data_path, - "eval_data_path": eval_data_path, - "output_dir": output_dir, - "base_model_path": base_model_path, - "new_model_spec": new_model_spec, - "vocab_size": vocab_size, - "pinyin_vocab_size": pinyin_vocab_size, - "dim": dim, - "num_slots": num_slots, - "n_layers": n_layers, - "n_heads": n_heads, - "num_experts": num_experts, - "max_seq_len": max_seq_len, - "use_pinyin": use_pinyin, - "batch_size": batch_size, - "num_workers": num_workers, - "num_epochs": num_epochs, - "learning_rate": learning_rate, - "min_learning_rate": min_learning_rate, - "weight_decay": weight_decay, - "warmup_ratio": warmup_ratio, - "label_smoothing": label_smoothing, - "grad_accum_steps": grad_accum_steps, - "clip_grad_norm": clip_grad_norm, - "eval_frequency": eval_frequency, - "save_frequency": save_frequency, - "mixed_precision": mixed_precision, - "use_tensorboard": use_tensorboard, - "seed": seed, - "auto_resume": auto_resume, - "max_iter_length": max_iter_length, - "compile": compile, - } - - config_file = output_path / "expansion_training_config.json" - with open(config_file, "w", encoding="utf-8") as f: - json.dump(config, f, indent=2, ensure_ascii=False) - - logger.info(f"Configuration saved to {config_file}") - - # 创建数据加载器 - console.print("[bold cyan]正在创建数据加载器...[/bold cyan]") - - # 训练数据集 - train_dataset = PinyinInputDataset( - data_path=train_data_path, - max_workers=-1, # 自动选择worker数量 - max_iter_length=max_iter_length, - max_seq_length=max_seq_len, - text_field="text", - py_style_weight=(9, 2, 1), - shuffle_buffer_size=100000, - length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, - ) - - # 训练数据加载器 - train_dataloader = create_dataloader( - dataset=train_dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=torch.cuda.is_available(), - max_iter_length=max_iter_length, - ) - - # 评估数据集 - eval_dataset = PinyinInputDataset( - data_path=eval_data_path, - max_workers=-1, - max_iter_length=batch_size * 64, # 评估集较小 - max_seq_length=max_seq_len, - text_field="text", - py_style_weight=(9, 2, 1), - shuffle_buffer_size=100000, - length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, - ) - - eval_dataloader = create_dataloader( - dataset=eval_dataset, - batch_size=batch_size, - num_workers=1, # 评估使用较少的worker - pin_memory=torch.cuda.is_available(), - max_iter_length=batch_size * 64, - ) - - # 创建扩容模型 - console.print("[bold cyan]正在创建扩容模型...[/bold cyan]") - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - model_kwargs = { - "vocab_size": vocab_size, - "pinyin_vocab_size": pinyin_vocab_size, - "dim": dim, - "num_slots": num_slots, - "n_layers": n_layers, - "n_heads": n_heads, - "num_experts": num_experts, - "max_seq_len": max_seq_len, - "compile": compile, - } - - model = load_expanded_model( - base_model_path=base_model_path, - new_model_spec=new_model_spec, - device=device, - **model_kwargs, - ) - - console.print( - f"[green]✓ 扩容模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]" - ) - - # 统计冻结参数比例 - total_params = sum(p.numel() for p in model.parameters()) - frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad) - console.print( - f"[green]✓ 冻结参数: {frozen_params:,}/{total_params:,} ({frozen_params / total_params * 100:.1f}%)[/green]" - ) - - # 创建训练器(使用普通 Trainer,只进行第一阶段冻结训练) - console.print("[bold cyan]正在创建训练器...[/bold cyan]") - trainer = Trainer( - model=model, - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - total_steps=int(max_iter_length * num_epochs / batch_size), - output_dir=output_dir, - num_epochs=num_epochs, - learning_rate=learning_rate, - min_learning_rate=min_learning_rate, - weight_decay=weight_decay, - warmup_ratio=warmup_ratio, - label_smoothing=label_smoothing, - grad_accum_steps=grad_accum_steps, - clip_grad_norm=clip_grad_norm, - eval_frequency=eval_frequency, - save_frequency=save_frequency, - mixed_precision=mixed_precision, - use_tensorboard=use_tensorboard, - status_file="training_status.json", - ) - - console.print("[green]✓ 训练器创建完成[/green]") - - # 开始训练 - console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]") - console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - try: - trainer.train( - resume_from=resume_from, - reset_training_state=reset_training_state, - auto_resume=auto_resume, - ) - except KeyboardInterrupt: - console.print("[bold green]训练被终止[/bold green]") - trainer.save_checkpoint("interrupted_model.pt") - - # 保存扩容信息供第二阶段使用 - expansion_info = { - "stage1_checkpoint_path": str(output_path / "checkpoints" / "best_model.pt"), - "model_spec": new_model_spec, - "model_kwargs": model_kwargs, - "train_data_path": train_data_path, - "eval_data_path": eval_data_path, - "output_dir": output_dir, - "batch_size": batch_size, - "max_iter_length": max_iter_length, - "max_seq_len": max_seq_len, - "num_workers": num_workers, - } - - expansion_info_file = output_path / "expansion_info.json" - with open(expansion_info_file, "w", encoding="utf-8") as f: - json.dump(expansion_info, f, indent=2, ensure_ascii=False) - - logger.info(f"Expansion info saved to {expansion_info_file}") - - console.print("[bold green]✓ 第一阶段训练完成![/bold green]") - console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - console.print(f"模型和日志保存在: {output_dir}") - console.print(f"[bold cyan]扩容信息已保存到: {expansion_info_file}[/bold cyan]") - console.print( - "[yellow]请手动检查模型后,使用 expand-finetune 命令进行第二阶段全量微调[/yellow]" - ) - - -@app.command() -def expand_finetune( - expand_config: str = typer.Option( - ..., - "--expand-config", - "-c", - help="新模型类规格,格式:模块名:类名,如 'big_expert:BigExpert'", - ), - stage1_info: str = typer.Option( - ..., "--stage1-info", "-i", help="第一阶段保存的 expansion_info.json 路径" - ), - # 可选覆盖参数 - checkpoint: Optional[str] = typer.Option( - None, "--checkpoint", help="第一阶段模型检查点路径(覆盖 JSON 文件中的路径)" - ), - output_dir: Optional[str] = typer.Option( - None, "--output-dir", "-o", help="输出目录(覆盖 JSON 文件中的目录)" - ), - train_data_path: Optional[str] = typer.Option( - None, "--train-data-path", "-t", help="训练数据路径(覆盖 JSON 文件)" - ), - eval_data_path: Optional[str] = typer.Option( - None, "--eval-data-path", "-e", help="评估数据路径(覆盖 JSON 文件)" - ), - batch_size: Optional[int] = typer.Option( - None, "--batch-size", "-b", help="批次大小(覆盖 JSON 文件)" - ), - num_epochs: Optional[int] = typer.Option( - None, "--num-epochs", help="训练轮数(覆盖 JSON 文件)" - ), - learning_rate: Optional[float] = typer.Option( - None, "--learning-rate", "-lr", help="学习率" - ), - min_learning_rate: Optional[float] = typer.Option( - None, "--min-learning-rate", help="最小学习率" - ), - weight_decay: Optional[float] = typer.Option( - None, "--weight-decay", help="权重衰减" - ), - warmup_ratio: Optional[float] = typer.Option( - None, "--warmup-ratio", help="热身步数比例" - ), - label_smoothing: Optional[float] = typer.Option( - None, "--label-smoothing", help="标签平滑参数" - ), - grad_accum_steps: Optional[int] = typer.Option( - None, "--grad-accum-steps", help="梯度累积步数" - ), - clip_grad_norm: Optional[float] = typer.Option( - None, "--clip-grad-norm", help="梯度裁剪范数" - ), - eval_frequency: Optional[int] = typer.Option( - None, "--eval-frequency", help="评估频率" - ), - save_frequency: Optional[int] = typer.Option( - None, "--save-frequency", help="保存频率" - ), - max_iter_length: Optional[int] = typer.Option( - None, "--max-iter-length", help="数据集大小(覆盖 JSON 文件)" - ), - max_seq_len: Optional[int] = typer.Option( - None, "--max-seq-len", help="最大序列长度(覆盖 JSON 文件)" - ), - num_workers: Optional[int] = typer.Option( - None, "--num-workers", help="数据加载worker数量" - ), - mixed_precision: bool = typer.Option( - True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练" - ), - use_tensorboard: bool = typer.Option( - True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard" - ), - resume_from: Optional[str] = typer.Option( - None, "--resume-from", help="从检查点恢复训练" - ), - reset_training_state: bool = typer.Option( - False, "--reset-training-state", help="重置训练状态" - ), - auto_resume: bool = typer.Option( - True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练" - ), - seed: int = typer.Option(42, "--seed", help="随机种子"), - compile: Optional[bool] = typer.Option( - None, "--compile/--no-compile", help="是否开启 torch.compile 优化" - ), -): - """ - 模型扩容第二阶段训练:读取第一阶段的 expansion_info.json,加载扩容模型进行全量微调。 - 命令行参数优先级高于 JSON 文件中的配置。 - """ - torch.multiprocessing.set_sharing_strategy("file_system") - - if torch.cuda.is_available(): - torch.set_float32_matmul_precision("high") - - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - console = Console() - - # 加载第一阶段信息 - stage1_info_path = Path(stage1_info) - if not stage1_info_path.exists(): - console.print( - f"[bold red]错误: 找不到第一阶段信息文件 {stage1_info}[/bold red]" - ) - raise typer.Exit(1) - - with open(stage1_info_path, "r", encoding="utf-8") as f: - info = json.load(f) - - # 命令行参数优先级高于 JSON 文件 - final_checkpoint = checkpoint or info["stage1_checkpoint_path"] - final_output_dir = output_dir or info["output_dir"] - final_train_data_path = train_data_path or info["train_data_path"] - final_eval_data_path = eval_data_path or info["eval_data_path"] - final_batch_size = batch_size if batch_size is not None else info["batch_size"] - final_num_epochs = ( - num_epochs if num_epochs is not None else info.get("num_epochs", 10) - ) - final_max_iter_length = ( - max_iter_length if max_iter_length is not None else info["max_iter_length"] - ) - final_max_seq_len = max_seq_len if max_seq_len is not None else info["max_seq_len"] - final_num_workers = ( - num_workers if num_workers is not None else info.get("num_workers", 2) - ) - - # 训练参数(有默认值,不覆盖则使用默认) - final_learning_rate = learning_rate if learning_rate is not None else 1e-4 - final_min_learning_rate = ( - min_learning_rate if min_learning_rate is not None else 1e-9 - ) - final_weight_decay = weight_decay if weight_decay is not None else 0.1 - final_warmup_ratio = warmup_ratio if warmup_ratio is not None else 0.1 - final_label_smoothing = label_smoothing if label_smoothing is not None else 0.15 - final_grad_accum_steps = grad_accum_steps if grad_accum_steps is not None else 1 - final_clip_grad_norm = clip_grad_norm if clip_grad_norm is not None else 1.0 - final_eval_frequency = eval_frequency if eval_frequency is not None else 500 - final_save_frequency = save_frequency if save_frequency is not None else 1000 - - # 模型参数从 JSON 获取 - model_kwargs = info["model_kwargs"] - if compile is not None: - model_kwargs["compile"] = compile - - console.print( - Panel.fit( - "[bold cyan]模型扩容第二阶段训练配置[/bold cyan]", border_style="cyan" - ) - ) - - config_table = Table(show_header=True, header_style="bold magenta") - config_table.add_column("Category", style="cyan") - config_table.add_column("Parameter", style="green") - config_table.add_column("Value", style="yellow") - - config_table.add_row("数据", "第一阶段信息文件", str(stage1_info_path)) - config_table.add_row("数据", "训练数据路径", final_train_data_path) - config_table.add_row("数据", "评估数据路径", final_eval_data_path) - config_table.add_row("数据", "输出目录", final_output_dir) - config_table.add_row("数据", "批次大小", str(final_batch_size)) - config_table.add_row("数据", "Worker数量", str(final_num_workers)) - - config_table.add_row("模型", "新模型规格", expand_config) - config_table.add_row("模型", "检查点路径", final_checkpoint) - for k, v in model_kwargs.items(): - config_table.add_row("模型", k, str(v)) - - config_table.add_row("训练", "训练轮数", str(final_num_epochs)) - config_table.add_row("训练", "学习率", f"{final_learning_rate:.2e}") - config_table.add_row("训练", "最小学习率", f"{final_min_learning_rate:.2e}") - config_table.add_row("训练", "权重衰减", str(final_weight_decay)) - config_table.add_row("训练", "热身比例", str(final_warmup_ratio)) - config_table.add_row("训练", "标签平滑", str(final_label_smoothing)) - config_table.add_row("训练", "梯度累积", str(final_grad_accum_steps)) - config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm)) - config_table.add_row("训练", "混合精度", str(mixed_precision)) - - config_table.add_row("其他", "自动恢复", str(auto_resume)) - console.print(config_table) - - output_path = Path(final_output_dir) - output_path.mkdir(parents=True, exist_ok=True) - - console.print("[bold cyan]正在创建数据加载器...[/bold cyan]") - - train_dataset = PinyinInputDataset( - data_path=final_train_data_path, - max_workers=-1, - max_iter_length=final_max_iter_length, - max_seq_length=final_max_seq_len, - text_field="text", - py_style_weight=(9, 2, 1), - shuffle_buffer_size=100000, - length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, - ) - - train_dataloader = create_dataloader( - dataset=train_dataset, - batch_size=final_batch_size, - num_workers=final_num_workers, - pin_memory=torch.cuda.is_available(), - max_iter_length=final_max_iter_length, - ) - - eval_dataset = PinyinInputDataset( - data_path=final_eval_data_path, - max_workers=-1, - max_iter_length=final_batch_size * 64, - max_seq_length=final_max_seq_len, - text_field="text", - py_style_weight=(9, 2, 1), - shuffle_buffer_size=100000, - length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, - ) - - eval_dataloader = create_dataloader( - dataset=eval_dataset, - batch_size=final_batch_size, - num_workers=1, - pin_memory=torch.cuda.is_available(), - max_iter_length=final_batch_size * 64, - ) - - console.print("[bold cyan]正在加载扩容模型(全量微调模式)...[/bold cyan]") - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - model = load_expanded_model( - base_model_path=final_checkpoint, - new_model_spec=expand_config, - device=device, - **model_kwargs, - ) - - # 全量微调:解冻所有参数 - for param in model.parameters(): - param.requires_grad = True - - console.print( - f"[green]✓ 模型加载完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]" - ) - console.print("[green]✓ 所有参数已解冻,进入全量微调模式[/green]") - - console.print("[bold cyan]正在创建训练器...[/bold cyan]") - trainer = Trainer( - model=model, - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - total_steps=int(final_max_iter_length * final_num_epochs / final_batch_size), - output_dir=final_output_dir, - num_epochs=final_num_epochs, - learning_rate=final_learning_rate, - min_learning_rate=final_min_learning_rate, - weight_decay=final_weight_decay, - warmup_ratio=final_warmup_ratio, - label_smoothing=final_label_smoothing, - grad_accum_steps=final_grad_accum_steps, - clip_grad_norm=final_clip_grad_norm, - eval_frequency=final_eval_frequency, - save_frequency=final_save_frequency, - mixed_precision=mixed_precision, - use_tensorboard=use_tensorboard, - status_file="training_status_finetune.json", - ) - - console.print("[green]✓ 训练器创建完成[/green]") - - console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]") - console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - try: - trainer.train( - resume_from=resume_from, - reset_training_state=reset_training_state, - auto_resume=auto_resume, - ) - except KeyboardInterrupt: - console.print("[bold green]训练被终止[/bold green]") - trainer.save_checkpoint("interrupted_model.pt") - - console.print("[bold green]✓ 第二阶段全量微调完成![/bold green]") - console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - console.print(f"模型和日志保存在: {final_output_dir}") - - if __name__ == "__main__": app() diff --git a/test.py b/test.py index db5aa0d..e0d7880 100644 --- a/test.py +++ b/test.py @@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]: return [CHAR_TO_ID.get(c, 0) for c in pinyin_str] -part1 = "招财猫背部或底部的太阳能板会持续将环境光(无论是阳光还是室内灯光)转化为" -part2 = "weiruo" +part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常" +part2 = "piaoliang" pinyin_ids = text_to_pinyin_ids(part2) len_py = len(pinyin_ids) if len_py < 24: @@ -56,7 +56,7 @@ if len_py < 24: else: pinyin_ids = pinyin_ids[:24] pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0) -masked_labels = [649, 925, 0, 0, 0, 0, 0, 0] +masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0] part3 = "" part4 = "可行|特别|伤害" @@ -83,7 +83,7 @@ sample = { model = InputMethodEngine(pinyin_vocab_size=30, compile=False) -checkpoint = torch.load("/home/songsenand/下载/20260411acc37final-model.pt", map_location="cpu") +checkpoint = torch.load("/home/songsenand/下载/best_model.ptrom", map_location="cpu") model.load_state_dict(checkpoint["model_state_dict"]) input_ids = sample["input_ids"] diff --git a/test.py.backup b/test.py.backup deleted file mode 100644 index 67d0f69..0000000 --- a/test.py.backup +++ /dev/null @@ -1,31 +0,0 @@ -import sys - -import torch -from torch.utils.data import DataLoader -from tqdm import tqdm - -from model.dataset import PinyinInputDataset -from model.trainer import collate_fn, worker_init_fn - -max_iter_length = 128 * 128 -batch_size = 1024 - -if sys.platform == "win32": - dataset_path = "data" -else: - dataset_path = "/home/songsenand/Data/corpus/CCI-Data/" - -dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length) -dataloader = DataLoader( - dataset, - batch_size=batch_size, - num_workers=2, - pin_memory=torch.cuda.is_available(), - worker_init_fn=worker_init_fn, - collate_fn=collate_fn, - prefetch_factor=64, # 每个worker预取64个batch,适合大内存场景 - persistent_workers=True, # 保持worker进程存活,避免重建开销 -) -dataloader = list([i for i in dataloader]) -for i, line in tqdm(enumerate(dataloader), total=max_iter_length / batch_size): - print((line["labels"] == 0).sum()) diff --git a/transfer_weights.py b/transfer_weights.py new file mode 100644 index 0000000..ebd4df8 --- /dev/null +++ b/transfer_weights.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +""" +参数迁移脚本:将旧模型权重迁移到新架构 + +使用方法: + python transfer_weights.py --old-checkpoint /path/to/best_model.pt --output ./migrated_model.pt + +功能: + 1. 加载旧模型 checkpoint + 2. 创建新模型架构 + 3. 直接迁移匹配的参数 + 4. 拆分迁移 k_proj/v_proj 到 text/pinyin 双分支 + 5. 保存迁移后的 checkpoint + 6. 打印详细迁移报告 +""" + +import argparse +import sys +from pathlib import Path + +import torch + +sys.path.append("src") +from src.model.model import InputMethodEngine + + +def transfer_weights( + old_checkpoint_path: str, new_model: InputMethodEngine, device: torch.device +): + """ + 迁移旧模型权重到新架构 + + Args: + old_checkpoint_path: 旧 checkpoint 路径 + new_model: 新模型实例 + device: 设备 + + Returns: + (directly_transferred, split_transferred, new_params, skipped) 迁移报告 + """ + old_checkpoint = torch.load(old_checkpoint_path, map_location=device) + + if "model_state_dict" in old_checkpoint: + old_state_dict = old_checkpoint["model_state_dict"] + else: + old_state_dict = old_checkpoint + + new_state_dict = new_model.state_dict() + + directly_transferred = [] + split_transferred = [] + new_params = [] + skipped = [] + + for key in new_state_dict.keys(): + if key in old_state_dict: + if new_state_dict[key].shape == old_state_dict[key].shape: + new_state_dict[key] = old_state_dict[key] + directly_transferred.append(key) + else: + skipped.append( + ( + key, + "shape mismatch", + f"old={old_state_dict[key].shape}, new={new_state_dict[key].shape}", + ) + ) + elif key in [ + "cross_attn.k_text_proj.weight", + "cross_attn.k_pinyin_proj.weight", + ]: + if "cross_attn.k_proj.weight" in old_state_dict: + new_state_dict[key] = old_state_dict["cross_attn.k_proj.weight"].clone() + split_transferred.append((key, "cross_attn.k_proj.weight")) + elif key in [ + "cross_attn.v_text_proj.weight", + "cross_attn.v_pinyin_proj.weight", + ]: + if "cross_attn.v_proj.weight" in old_state_dict: + new_state_dict[key] = old_state_dict["cross_attn.v_proj.weight"].clone() + split_transferred.append((key, "cross_attn.v_proj.weight")) + else: + new_params.append(key) + + new_model.load_state_dict(new_state_dict) + + return directly_transferred, split_transferred, new_params, skipped + + +def print_report( + directly_transferred, split_transferred, new_params, skipped, output_path +): + """打印迁移报告""" + total_params = len(directly_transferred) + len(split_transferred) + len(new_params) + coverage = ( + (len(directly_transferred) + len(split_transferred)) / total_params * 100 + if total_params > 0 + else 100 + ) + + print("\n" + "=" * 60) + print("📊 参数迁移报告") + print("=" * 60) + + print(f"\n✅ 直接迁移 ({len(directly_transferred)} 层):") + categories = {} + for key in directly_transferred: + category = key.split(".")[0] + if category not in categories: + categories[category] = [] + categories[category].append(key) + for cat, keys in sorted(categories.items()): + print(f" - {cat}.* ({len(keys)} 个参数)") + + if split_transferred: + print(f"\n✅ 拆分迁移 ({len(split_transferred)} 层):") + for new_key, old_key in split_transferred: + print(f" - {new_key} ← {old_key}") + + if new_params: + print(f"\n⚠️ 新增参数 ({len(new_params)} 层):") + for key in new_params[:10]: + print(f" - {key}") + if len(new_params) > 10: + print(f" ... 等共 {len(new_params)} 个") + + if skipped: + print(f"\n❌ 跳过 ({len(skipped)} 层):") + for key, reason, detail in skipped[:5]: + print(f" - {key}: {reason} ({detail})") + if len(skipped) > 5: + print(f" ... 等共 {len(skipped)} 个") + + print("\n" + "-" * 60) + print( + f"迁移覆盖率: {coverage:.1f}% ({len(directly_transferred) + len(split_transferred)}/{total_params})" + ) + print("=" * 60) + print(f"\n💾 已保存到: {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="迁移旧模型权重到新架构") + parser.add_argument( + "--old-checkpoint", + "-c", + type=str, + required=True, + help="旧模型 checkpoint 路径", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default="./migrated_model.pt", + help="迁移后的 checkpoint 输出路径", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="设备 (cuda/cpu)", + ) + + args = parser.parse_args() + device = torch.device(args.device) + + print(f"📦 加载旧模型: {args.old_checkpoint}") + print(f"🔧 使用设备: {device}") + + old_path = Path(args.old_checkpoint) + if not old_path.exists(): + print(f"❌ 文件不存在: {args.old_checkpoint}") + sys.exit(1) + + print("\n🏗️ 创建新模型架构...") + new_model = InputMethodEngine( + vocab_size=10019, + pinyin_vocab_size=30, + dim=512, + num_slots=8, + n_layers=4, + n_heads=4, + num_experts=10, + max_seq_len=128, + compile=False, + ) + new_model.to(device) + + print(f" 新模型参数量: {sum(p.numel() for p in new_model.parameters()):,}") + + print("\n🔄 开始迁移参数...") + directly_transferred, split_transferred, new_params, skipped = transfer_weights( + args.old_checkpoint, new_model, device + ) + + print("\n💾 保存迁移后的 checkpoint...") + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + checkpoint = { + "step": 0, + "epoch": 0, + "model_state_dict": new_model.state_dict(), + "best_eval_loss": float("inf"), + "config": { + "vocab_size": 10019, + "pinyin_vocab_size": 30, + "dim": 512, + "num_slots": 8, + "n_layers": 4, + "n_heads": 4, + "num_experts": 10, + "max_seq_len": 128, + }, + "migration_source": args.old_checkpoint, + } + + torch.save(checkpoint, output_path) + + print_report( + directly_transferred, split_transferred, new_params, skipped, output_path + ) + + print("\n✅ 迁移完成!使用方法:") + print( + f" python -m model.trainer train --resume-from {args.output} --reset-training-state" + ) + + +if __name__ == "__main__": + main()