607 lines
19 KiB
Markdown
607 lines
19 KiB
Markdown
# 训练指南
|
||
|
||
## Jupyter Lab 训练示例
|
||
|
||
以下是在 Jupyter Lab 环境中使用 `trainer.Trainer` 类训练输入法模型的完整示例:
|
||
|
||
```python
|
||
# %% [markdown]
|
||
# # 输入法模型训练示例
|
||
# 本笔记本展示如何使用 trainer.Trainer 类训练输入法模型
|
||
|
||
# %% [code]
|
||
# 1. 导入必要的库
|
||
import sys
|
||
import os
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
|
||
# 添加项目路径(适应不同的Jupyter Lab运行位置)
|
||
project_root = Path.cwd()
|
||
# 检查当前目录是否包含src目录,如果不包含则使用父目录
|
||
if not (project_root / "src").exists():
|
||
project_root = project_root.parent
|
||
sys.path.insert(0, str(project_root)) # 优先搜索项目目录
|
||
|
||
# 导入项目模块
|
||
from src.model.model import InputMethodEngine
|
||
from src.model.dataset import PinyinInputDataset
|
||
from src.model.trainer import Trainer, worker_init_fn, collate_fn
|
||
|
||
# %% [code]
|
||
# 2. 配置训练参数
|
||
config = {
|
||
# 数据参数
|
||
"train_data_path": "/path/to/your/train/dataset", # 替换为训练数据集路径
|
||
"eval_data_path": "/path/to/your/eval/dataset", # 替换为评估数据集路径
|
||
"output_dir": "./training_output",
|
||
|
||
# 模型参数
|
||
"vocab_size": 10019,
|
||
"pinyin_vocab_size": 30,
|
||
"dim": 512,
|
||
"num_slots": 8,
|
||
"n_layers": 4,
|
||
"n_heads": 4,
|
||
"num_experts": 20,
|
||
"max_seq_len": 128,
|
||
|
||
# 训练参数
|
||
"batch_size": 64, # 根据GPU内存调整
|
||
"num_epochs": 10,
|
||
"learning_rate": 3e-4,
|
||
"min_learning_rate": 1e-9,
|
||
"weight_decay": 0.1,
|
||
"warmup_ratio": 0.1,
|
||
"label_smoothing": 0.15,
|
||
"grad_accum_steps": 2, # 梯度累积,模拟更大batch size
|
||
"clip_grad_norm": 1.0,
|
||
"eval_frequency": 500, # 每500步评估一次
|
||
"save_frequency": 2000, # 每2000步保存检查点
|
||
|
||
# 高级选项
|
||
"mixed_precision": True,
|
||
"use_tensorboard": True,
|
||
"seed": 42,
|
||
"max_iter_length": 1024 * 1024 * 128, # 最大迭代长度
|
||
}
|
||
|
||
# %% [code]
|
||
# 3. 设置随机种子和设备
|
||
torch.manual_seed(config["seed"])
|
||
if torch.cuda.is_available():
|
||
torch.cuda.manual_seed_all(config["seed"])
|
||
device = torch.device("cuda")
|
||
print(f"✅ 使用 GPU: {torch.cuda.get_device_name(0)}")
|
||
else:
|
||
device = torch.device("cpu")
|
||
print("⚠️ 使用 CPU 进行训练(建议使用 GPU 以获得更好性能)")
|
||
|
||
# %% [code]
|
||
# 4. 创建数据集和数据加载器
|
||
print("📊 创建数据集和数据加载器...")
|
||
|
||
# 训练数据集
|
||
train_dataset = PinyinInputDataset(
|
||
data_path=config["train_data_path"],
|
||
max_workers=-1, # 自动选择worker数量
|
||
max_iter_length=config["max_iter_length"],
|
||
max_seq_length=config["max_seq_len"],
|
||
text_field="text",
|
||
py_style_weight=(9, 2, 1),
|
||
shuffle_buffer_size=5000,
|
||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||
)
|
||
|
||
# 训练数据加载器
|
||
train_dataloader = DataLoader(
|
||
train_dataset,
|
||
batch_size=config["batch_size"],
|
||
num_workers=min(max(1, (os.cpu_count() or 1) - 1), 8), # 合理数量的worker
|
||
pin_memory=torch.cuda.is_available(),
|
||
worker_init_fn=worker_init_fn,
|
||
collate_fn=collate_fn,
|
||
prefetch_factor=32,
|
||
persistent_workers=True,
|
||
)
|
||
|
||
# 评估数据集
|
||
eval_dataset = PinyinInputDataset(
|
||
data_path=config["eval_data_path"],
|
||
max_workers=-1,
|
||
max_iter_length=1024, # 评估集较小
|
||
max_seq_length=config["max_seq_len"],
|
||
text_field="text",
|
||
py_style_weight=(9, 2, 1),
|
||
shuffle_buffer_size=1000,
|
||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||
)
|
||
|
||
eval_dataloader = DataLoader(
|
||
eval_dataset,
|
||
batch_size=config["batch_size"],
|
||
num_workers=1,
|
||
pin_memory=torch.cuda.is_available(),
|
||
worker_init_fn=worker_init_fn,
|
||
collate_fn=collate_fn,
|
||
prefetch_factor=32,
|
||
persistent_workers=True,
|
||
)
|
||
|
||
print(f"✅ 数据加载器创建完成")
|
||
print(f" 训练批次大小: {config['batch_size']}")
|
||
print(f" 预估训练步数: {config['max_iter_length'] // config['batch_size']}")
|
||
|
||
# %% [code]
|
||
# 5. 创建模型
|
||
print("🧠 创建输入法模型...")
|
||
|
||
model = InputMethodEngine(
|
||
vocab_size=config["vocab_size"],
|
||
pinyin_vocab_size=config["pinyin_vocab_size"],
|
||
dim=config["dim"],
|
||
num_slots=config["num_slots"],
|
||
n_layers=config["n_layers"],
|
||
n_heads=config["n_heads"],
|
||
num_experts=config["num_experts"],
|
||
max_seq_len=config["max_seq_len"],
|
||
)
|
||
|
||
# 将模型移动到设备
|
||
model.to(device)
|
||
|
||
# 计算参数量
|
||
total_params = sum(p.numel() for p in model.parameters())
|
||
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||
|
||
print(f"✅ 模型创建完成")
|
||
print(f" 总参数量: {total_params:,}")
|
||
print(f" 可训练参数量: {trainable_params:,}")
|
||
print(f" 模型架构: {config['n_layers']}层Transformer, {config['dim']}维度, {config['num_experts']}个MoE专家")
|
||
|
||
# %% [code]
|
||
# 6. 创建训练器
|
||
print("⚙️ 创建训练器...")
|
||
|
||
# 计算总训练步数
|
||
total_steps = int(config["max_iter_length"] / config["batch_size"])
|
||
|
||
trainer = Trainer(
|
||
model=model,
|
||
train_dataloader=train_dataloader,
|
||
eval_dataloader=eval_dataloader,
|
||
total_steps=total_steps,
|
||
output_dir=config["output_dir"],
|
||
num_epochs=config["num_epochs"],
|
||
learning_rate=config["learning_rate"],
|
||
min_learning_rate=config["min_learning_rate"],
|
||
weight_decay=config["weight_decay"],
|
||
warmup_ratio=config["warmup_ratio"],
|
||
label_smoothing=config["label_smoothing"],
|
||
grad_accum_steps=config["grad_accum_steps"],
|
||
clip_grad_norm=config["clip_grad_norm"],
|
||
eval_frequency=config["eval_frequency"],
|
||
save_frequency=config["save_frequency"],
|
||
mixed_precision=config["mixed_precision"],
|
||
use_tensorboard=config["use_tensorboard"],
|
||
)
|
||
|
||
print(f"✅ 训练器创建完成")
|
||
print(f" 总训练步数: {total_steps:,}")
|
||
print(f" 学习率: {config['learning_rate']:.2e} -> {config['min_learning_rate']:.2e}")
|
||
print(f" 输出目录: {config['output_dir']}")
|
||
|
||
# %% [code]
|
||
# 7. 开始训练
|
||
print("🚀 开始训练...")
|
||
print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||
|
||
try:
|
||
# 开始训练(可以从检查点恢复训练)
|
||
trainer.train(resume_from=None) # 设置检查点路径以恢复训练
|
||
|
||
print("✅ 训练完成!")
|
||
print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||
print(f"模型和日志保存在: {config['output_dir']}")
|
||
|
||
except KeyboardInterrupt:
|
||
print("⏹️ 训练被用户中断")
|
||
print("💾 保存当前检查点...")
|
||
trainer.save_checkpoint("interrupted")
|
||
print(f"检查点已保存到: {config['output_dir']}/checkpoint_interrupted.pt")
|
||
|
||
except Exception as e:
|
||
print(f"❌ 训练过程中出现错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
# %% [code]
|
||
# 8. 监控训练进度(如果使用TensorBoard)
|
||
if config["use_tensorboard"]:
|
||
print("📈 TensorBoard日志已记录在:")
|
||
print(f" {config['output_dir']}/tensorboard")
|
||
print("\n启动TensorBoard查看训练进度:")
|
||
print(" tensorboard --logdir ./training_output/tensorboard")
|
||
print("然后在浏览器中打开: http://localhost:6006")
|
||
|
||
# %% [code]
|
||
# 9. 加载训练好的模型进行推理(示例)
|
||
def load_trained_model(checkpoint_path):
|
||
"""加载训练好的模型进行检查点"""
|
||
print(f"📥 加载检查点: {checkpoint_path}")
|
||
|
||
# 创建与训练时相同配置的模型
|
||
loaded_model = InputMethodEngine(
|
||
vocab_size=config["vocab_size"],
|
||
pinyin_vocab_size=config["pinyin_vocab_size"],
|
||
dim=config["dim"],
|
||
num_slots=config["num_slots"],
|
||
n_layers=config["n_layers"],
|
||
n_heads=config["n_heads"],
|
||
num_experts=config["num_experts"],
|
||
max_seq_len=config["max_seq_len"],
|
||
)
|
||
|
||
# 加载检查点
|
||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||
loaded_model.load_state_dict(checkpoint["model_state_dict"])
|
||
loaded_model.to(device)
|
||
loaded_model.eval()
|
||
|
||
print(f"✅ 模型加载完成,训练步数: {checkpoint.get('global_step', 'N/A')}")
|
||
print(f" 训练损失: {checkpoint.get('train_loss', 'N/A'):.4f}")
|
||
|
||
return loaded_model
|
||
|
||
# 使用示例(取消注释以使用)
|
||
# trained_model = load_trained_model("./training_output/checkpoint_final.pt")
|
||
```
|
||
|
||
### 关键说明
|
||
|
||
1. **环境要求**:
|
||
- Python 3.12+
|
||
- PyTorch 2.10+
|
||
- 建议使用GPU进行训练
|
||
- 安装项目依赖:`pip install -e .`
|
||
|
||
2. **数据集格式**:
|
||
- 使用Hugging Face `datasets`格式
|
||
- 必须包含`text`字段
|
||
- 支持流式读取(streaming=True)
|
||
|
||
3. **训练监控**:
|
||
- 控制台输出训练进度和指标
|
||
- TensorBoard记录损失、准确率、学习率等
|
||
- 定期保存模型检查点
|
||
|
||
4. **可调整参数**:
|
||
- `batch_size`: 根据GPU内存调整
|
||
- `learning_rate`: 建议在1e-4到5e-4之间
|
||
- `grad_accum_steps`: 模拟更大batch size
|
||
- `num_epochs`: 根据数据集大小调整
|
||
|
||
5. **故障排除**:
|
||
- GPU内存不足:减小`batch_size`或增加`grad_accum_steps`
|
||
- 训练不稳定:降低`learning_rate`或增加`warmup_ratio`
|
||
- 过拟合:增加`label_smoothing`或使用更大数据集
|
||
|
||
## 使用指南
|
||
|
||
本项目的训练功能通过命令行工具 `train-model` 提供,支持训练、评估和导出模型。
|
||
|
||
### 安装与准备
|
||
|
||
#### 使用 uv(推荐)
|
||
本项目使用 [`uv`](https://github.com/astral-sh/uv) 作为Python包管理器,它比传统的 pip 更快且更可靠。
|
||
|
||
1. **安装 uv**(如果尚未安装):
|
||
```bash
|
||
# Linux/macOS
|
||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||
|
||
# 或使用 pipx
|
||
pipx install uv
|
||
|
||
# Windows (PowerShell)
|
||
powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
|
||
```
|
||
|
||
2. **安装项目依赖**:
|
||
```bash
|
||
uv pip install -e .
|
||
```
|
||
|
||
#### 使用传统 pip
|
||
如果不使用 uv,也可以用标准的 pip 安装:
|
||
|
||
```bash
|
||
# 创建并激活虚拟环境(推荐)
|
||
python -m venv .venv
|
||
source .venv/bin/activate # Linux/macOS
|
||
# .venv\Scripts\activate # Windows
|
||
|
||
# 安装依赖
|
||
pip install -e .
|
||
```
|
||
|
||
#### 验证安装
|
||
安装完成后,可通过以下命令验证:
|
||
```bash
|
||
train-model --help
|
||
```
|
||
|
||
### 数据格式
|
||
|
||
训练数据应为Hugging Face数据集格式,支持本地文件或远程数据集仓库。数据集需包含 `text` 字段,并支持流式读取(streaming=True)。
|
||
|
||
#### 本地数据集示例
|
||
```python
|
||
# dataset.py
|
||
from datasets import Dataset
|
||
|
||
data = {
|
||
"text": ["这是第一个样本文本。", "这是第二个样本,用于训练输入法模型。"]
|
||
}
|
||
dataset = Dataset.from_dict(data)
|
||
dataset.save_to_disk("./local_dataset")
|
||
```
|
||
|
||
#### 远程数据集示例
|
||
支持Hugging Face Hub或ModelScope上的数据集:
|
||
- `huggingface.co/datasets/username/dataset_name`
|
||
- `modelscope.cn/datasets/username/dataset_name`
|
||
|
||
#### 数据格式要求
|
||
- **必需字段**: `text`(字符串类型,包含中文文本)
|
||
- **流式读取**: 数据集必须支持 `streaming=True` 参数
|
||
- **数据量**: 建议至少数百万条文本以获得良好效果
|
||
|
||
#### 数据预处理
|
||
数据集会自动进行以下处理:
|
||
1. 文本分词和编码
|
||
2. 拼音转换和编码
|
||
3. 上下文窗口滑动生成训练样本
|
||
4. 频率调整(削峰填谷)以平衡高频/低频字词
|
||
|
||
### 基本训练命令
|
||
|
||
使用 `train-model train` 命令开始训练:
|
||
|
||
```bash
|
||
train-model train \
|
||
--train-data-path "path/to/train/dataset" \
|
||
--eval-data-path "path/to/eval/dataset" \
|
||
--output-dir "./output" \
|
||
--batch-size 128 \
|
||
--num-epochs 10 \
|
||
--learning-rate 1e-5
|
||
```
|
||
|
||
#### 检查点恢复训练
|
||
|
||
要从检查点恢复训练(保持原有的训练状态):
|
||
|
||
```bash
|
||
train-model train \
|
||
--train-data-path "path/to/train/dataset" \
|
||
--eval-data-path "path/to/eval/dataset" \
|
||
--resume-from "./output/checkpoints/latest_checkpoint.pt"
|
||
```
|
||
|
||
#### 重置训练状态
|
||
|
||
如果只想加载模型权重,从头开始训练(学习率、epoch等都重新开始):
|
||
|
||
```bash
|
||
train-model train \
|
||
--train-data-path "path/to/train/dataset" \
|
||
--eval-data-path "path/to/eval/dataset" \
|
||
--resume-from "./output/checkpoints/best_model.pt" \
|
||
--reset-training-state
|
||
```
|
||
|
||
这个功能在以下场景非常有用:
|
||
- 想要用预训练权重初始化模型,但用新的训练计划重新训练
|
||
- 需要调整学习率策略或训练时长
|
||
- 在现有模型基础上进行迁移学习
|
||
|
||
#### 学习率建议
|
||
根据模型架构和超参数配置(4层Transformer,512维度),推荐使用以下学习率范围:
|
||
- **标准范围**: 1e-4 ~ 5e-4
|
||
- **配合Warmup策略**:在训练初期逐步提高学习率
|
||
- **余弦退火**:使用最小学习率 1e-9 进行细调
|
||
|
||
### 参数详解
|
||
|
||
#### 数据参数
|
||
- `--train-data-path`, `-t`: 训练数据集路径(必需)
|
||
- `--eval-data-path`, ` -e`: 评估数据集路径(必需)
|
||
- `--output-dir`, `-o`: 输出目录(默认:`./output`)
|
||
- `--max_iter_length`: 最大迭代长度,控制每次训练迭代处理的数据量(默认:134217728)
|
||
|
||
#### 模型参数
|
||
- `--vocab-size`: 词汇表大小(默认:10019)
|
||
- `--pinyin-vocab-size`: 拼音词汇表大小(默认:30)
|
||
- `--dim`: 模型维度(默认:512)
|
||
- `--num-slots`: 历史槽位数量(默认:8)
|
||
- `--n-layers`: Transformer层数(默认:4)
|
||
- `--n-heads`: 注意力头数(默认:4)
|
||
- `--num-experts`: MoE专家数量(默认:20)
|
||
- `--max-seq-len`: 最大序列长度(默认:128)
|
||
- `--use-pinyin`: 是否使用拼音特征(默认:False)
|
||
|
||
#### 训练参数
|
||
- `--batch-size`, `-b`: 批次大小(默认:128)
|
||
- `--num-epochs`: 训练轮数(默认:10)
|
||
- `--learning-rate`, `-lr`: 学习率(默认:1e-5)
|
||
- `--min-learning-rate`: 最小学习率(默认:1e-9)
|
||
- `--weight-decay`: 权重衰减(默认:0.1)
|
||
- `--warmup-ratio`: 热身步数比例(默认:0.1)
|
||
- `--label-smoothing`: 标签平滑参数(默认:0.15)
|
||
- `--grad-accum-steps`: 梯度累积步数(默认:1)
|
||
- `--clip-grad-norm`: 梯度裁剪范数(默认:1.0)
|
||
- `--eval-frequency`: 评估频率(默认:500步)
|
||
- `--save-frequency`: 保存频率(默认:10000步)
|
||
|
||
#### 高级选项
|
||
- `--mixed-precision/--no-mixed-precision`: 是否使用混合精度训练(默认:启用)
|
||
- `--tensorboard/--no-tensorboard`: 是否使用TensorBoard(默认:启用)
|
||
- `--resume-from`: 从检查点恢复训练(可选)
|
||
- `--reset-training-state`: 重置训练状态,只加载模型权重从头开始训练(默认:False)
|
||
- `--seed`: 随机种子(默认:42)
|
||
|
||
### 监控训练进度
|
||
|
||
训练过程中会显示:
|
||
- 当前训练步数/总步数
|
||
- 损失值和准确率
|
||
- 学习率变化
|
||
- 内存使用情况
|
||
|
||
启用TensorBoard后,可使用以下命令查看可视化结果:
|
||
|
||
```bash
|
||
tensorboard --logdir ./output/tensorboard
|
||
```
|
||
|
||
### 评估模型(开发中)
|
||
|
||
当前评估功能尚在开发中:
|
||
|
||
```bash
|
||
train-model evaluate \
|
||
--checkpoint "./output/checkpoint_final.pt" \
|
||
--data-path "path/to/eval/dataset" \
|
||
--batch-size 32
|
||
```
|
||
|
||
命令将显示"评估功能待实现"的提示信息。该功能计划用于:
|
||
- 加载训练好的模型检查点
|
||
- 在评估数据集上计算准确率、困惑度等指标
|
||
- 生成详细的性能报告
|
||
|
||
### 模型扩容两阶段训练
|
||
|
||
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
|
||
|
||
#### 训练策略
|
||
|
||
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
|
||
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
|
||
|
||
#### 基础用法
|
||
|
||
```bash
|
||
train-model expand-and-train \
|
||
--train-data-path "path/to/train/dataset" \
|
||
--eval-data-path "path/to/eval/dataset" \
|
||
--base-model-path "./pretrained/model.pt" \
|
||
--new-model-spec "model:InputMethodEngine" \
|
||
--num-experts 40 \
|
||
--frozen-lr 2e-3 \
|
||
--full-lr 5e-5 \
|
||
--frozen-patience 8
|
||
```
|
||
|
||
#### 完整参数示例
|
||
|
||
```bash
|
||
train-model expand-and-train \
|
||
--train-data-path "path/to/train/dataset" \
|
||
--eval-data-path "path/to/eval/dataset" \
|
||
--output-dir "./expansion_output" \
|
||
--base-model-path "./pretrained/model.pt" \
|
||
--new-model-spec "custom_model:ExpandedModel" \
|
||
--vocab-size 10019 \
|
||
--dim 512 \
|
||
--num-experts 40 \
|
||
--frozen-patience 10 \
|
||
--frozen-lr 1e-3 \
|
||
--full-lr 1e-4 \
|
||
--frozen-scheduler cosine \
|
||
--full-scheduler cosine \
|
||
--batch-size 128 \
|
||
--num-epochs 20 \
|
||
--compile
|
||
```
|
||
|
||
#### 参数详解
|
||
|
||
**模型扩容参数**
|
||
- `--base-model-path`: 预训练基础模型检查点路径(必需)
|
||
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
|
||
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
|
||
- 自定义模型类必须是 `InputMethodEngine` 的子类
|
||
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类
|
||
|
||
**两阶段训练参数**
|
||
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10)
|
||
- `--frozen-lr`: 冻结阶段学习率(默认:1e-3)
|
||
- `--full-lr`: 全量微调阶段学习率(默认:1e-4)
|
||
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||
|
||
**其他参数**
|
||
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
|
||
- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等
|
||
|
||
#### 使用场景
|
||
|
||
1. **增加专家数量**(20→40)
|
||
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
|
||
- 新增参数:新专家网络、gate层
|
||
|
||
2. **增加top_k值**(2→3)
|
||
- 冻结效果:100% 参数可冻结(仅逻辑变化)
|
||
- 新增参数:无
|
||
|
||
3. **修改专家内部结构**(如增加resblocks)
|
||
- 冻结效果:~50% 参数可冻结(linear_in/output可冻结)
|
||
- 新增参数:新增的resblocks层
|
||
|
||
4. **增加Transformer层数**(4→5)
|
||
- 冻结效果:~80% 参数可冻结(前4层可冻结)
|
||
- 新增参数:新增的第5层
|
||
|
||
#### 自定义模型类示例
|
||
|
||
```python
|
||
# my_model.py
|
||
from model.model import InputMethodEngine
|
||
|
||
class MyExpandedModel(InputMethodEngine):
|
||
def __init__(self, num_experts=40, **kwargs):
|
||
# 调用父类构造函数,覆盖num_experts参数
|
||
super().__init__(num_experts=num_experts, **kwargs)
|
||
# 可以在这里添加额外的层或修改现有层
|
||
|
||
# 使用命令
|
||
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
|
||
```
|
||
|
||
#### 注意事项
|
||
|
||
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
|
||
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
|
||
3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
|
||
4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency`
|
||
5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载
|
||
|
||
### 导出模型(开发中)
|
||
|
||
当前导出功能尚在开发中:
|
||
|
||
```bash
|
||
train-model export \
|
||
--checkpoint "./output/checkpoint_final.pt" \
|
||
--output "./exported_model.onnx"
|
||
```
|
||
|
||
命令将显示"导出功能待实现"的提示信息。该功能计划用于:
|
||
- 将PyTorch模型转换为ONNX格式
|
||
- 支持在不同推理引擎上部署
|
||
- 提供优化后的推理模型 |