docs: remove the README.md.bak file

This commit is contained in:
songsenand 2026-04-12 12:13:36 +08:00
parent bb78e0afa0
commit 33f56f709b
8 changed files with 531 additions and 1754 deletions


@@ -1,916 +0,0 @@
# Input Method Prediction Model - Architecture Design
## 1. Overview
This project builds a lightweight, high-accuracy prediction model for a Chinese input method. The core design idea is to deeply fuse the current context (text around the cursor plus pinyin) with the user's historical input habits through **structured slot memory** and a **cross-attention mechanism**. To keep expressive power high under a limited compute budget, the model adds a **Mixture of Experts (MoE)** module.
## 2. Core Architecture Flow
Data flows along the following path:
`Input Encoding` → `Transformer Context Encoding` → `Slot Memory Embedding` → `Cross-Attention Fusion` → `Gating + Mixture of Experts (MoE)` → `Classification Prediction` → `Beam Search Decoding`
### 2.1 Input Layer Design
The model takes three kinds of input, processed separately to keep the semantics clear:
1. **Current text context**: the text before the cursor (Prefix) and after the cursor (Suffix).
2. **Pinyin sequence**: the pinyin corresponding to the current text, folded into the text encoding as an auxiliary feature.
3. **History slot sequence**: the most recent N input words, fed in as structured memory.
### 2.2 Module Details
#### A. Transformer Encoder (Context Encoder)
Extracts a deep semantic representation of the current context.
* **Input processing**: Prefix, Suffix, and pinyin are mapped through an Embedding layer. Pinyin is folded in by **feature addition** or as **separate tokens**, avoiding the complexity of a two-stream architecture.
* **Backbone**: a standard Transformer Encoder.
* **Hidden dimension**: 512 [1]
* **Transformer layers**: 4 (lightweight design, trained from scratch) [1]
* **Attention heads**: 4 [1]
* **Output**: the context representation $H$ with shape `[batch, L, 512]` [1].
#### B. Slot Memory Module
Turns the unstructured input history into structured memory vectors.
* **Embedding**: history words are mapped through a dedicated `Slot Embedding` lookup table.
* **Positional encoding**: a learnable `Positional Embedding` is added to preserve the temporal order of the input history.
* **Output**: the slot sequence $S$ with shape `[batch, Num_Slots, 512]`.
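The module above can be sketched in a few lines of PyTorch. The class name `SlotMemory` and its defaults are illustrative, not the project's actual implementation in `src/model/model.py`:

```python
import torch
import torch.nn as nn

class SlotMemory(nn.Module):
    """Embed the last N confirmed words into structured memory vectors."""
    def __init__(self, vocab_size: int = 10019, num_slots: int = 8, dim: int = 512):
        super().__init__()
        # Dedicated lookup table for history words
        self.slot_embedding = nn.Embedding(vocab_size, dim)
        # Learnable positional embedding preserves the temporal order of the history
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_slots, dim))

    def forward(self, slot_ids: torch.Tensor) -> torch.Tensor:
        # slot_ids: [batch, num_slots], where 0 marks an empty slot
        return self.slot_embedding(slot_ids) + self.pos_embedding

slots = SlotMemory()(torch.randint(0, 10019, (2, 8)))
print(slots.shape)  # torch.Size([2, 8, 512])
```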
#### C. Cross-Attention Fusion
This is the model's core innovation: it dynamically links "historical memory" with the "current context".
* **Query (Q)**: the slot sequence $S$ for the current step (after positional encoding).
* **Key/Value (K/V)**: the context representation $H$ produced by the Transformer encoder [1].
* **Mechanism**: the history slots actively attend to the current textual context, capturing logic such as "in the context '班级第一名', '王次香' is more relevant than '王慈祥'".
* **Output**: the fused feature sequence with shape `[batch, Num_Slots, 512]`.
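This fusion step can be sketched with PyTorch's built-in multi-head attention, used here as an illustrative stand-in for the project's own layer:

```python
import torch
import torch.nn as nn

batch, L, num_slots, dim = 2, 128, 8, 512
H = torch.randn(batch, L, dim)          # context representation from the encoder
S = torch.randn(batch, num_slots, dim)  # slot sequence, already position-encoded

# History slots query the current context: Q = S, K = V = H
cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True)
fused, attn_weights = cross_attn(query=S, key=H, value=H)

print(fused.shape)         # torch.Size([2, 8, 512])
print(attn_weights.shape)  # torch.Size([2, 8, 128]): each slot's attention over the context
```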
#### D. Gating + Mixture of Experts (MoE)
Empirically, removing the MoE degrades model performance significantly, so this module is essential for capturing complex distributions.
* **Number of experts**: 20 [1].
* **Gating**: a gate dynamically activates a subset of experts per input, giving sparse activation that increases model capacity while keeping compute in check.
* **Output**: the feature vector enhanced by the expert networks.
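A minimal sketch of top-k gating, following the "compute all experts + Top-K selection" scheme mentioned later in section 6.8. The class name, expert MLP shape, and `top_k=2` default are illustrative assumptions:

```python
import torch
import torch.nn as nn

class TopKMoE(nn.Module):
    """Compute every expert, then keep only the top-k gate weights (sparse mixture)."""
    def __init__(self, dim: int = 512, num_experts: int = 20, top_k: int = 2):
        super().__init__()
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim * 2), nn.GELU(), nn.Linear(dim * 2, dim))
            for _ in range(num_experts)
        )
        self.gate = nn.Linear(dim, num_experts)
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, slots, dim]
        gate_logits = self.gate(x)                        # [batch, slots, num_experts]
        topk_vals, topk_idx = gate_logits.topk(self.top_k, dim=-1)
        weights = torch.zeros_like(gate_logits)
        # Softmax over the selected logits only; all other experts get weight 0
        weights.scatter_(-1, topk_idx, topk_vals.softmax(dim=-1))
        expert_out = torch.stack([e(x) for e in self.experts], dim=-2)  # [..., experts, dim]
        return (weights.unsqueeze(-1) * expert_out).sum(dim=-2)

out = TopKMoE()(torch.randn(2, 8, 512))
print(out.shape)  # torch.Size([2, 8, 512])
```

Computing all experts and masking via the gate trades FLOPs for a static compute graph, which is friendlier to `torch.compile` than data-dependent expert dispatch.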
#### E. Classification Head and Decoding
* **Classification**: the MoE output is projected to the vocabulary space by a fully connected layer, yielding a probability distribution over the next character/word.
* **Decoding**: inference uses **Beam Search** with a beam width of 5 [1].
## 3. Key Hyperparameters
To balance quality and efficiency, the following hyperparameters are recommended [1]:
| Parameter | Recommended value | Notes |
| :--- | :--- | :--- |
| **Sequence length (L)** | 128 | Context window size [1] |
| **Hidden dimension** | 512 | Embedding and internal Transformer dimension [1] |
| **Transformer layers** | 4 | Lightweight backbone, low latency [1] |
| **Attention heads** | 4 | Efficient configuration for a 512-dim model [1] |
| **Number of experts** | 20 | Total experts in the MoE layer; critical for quality [1] |
| **Beam width** | 5 | Balances speed and accuracy at inference [1] |
| **Learning rate** | 1e-4 ~ 5e-4 | Use together with a warmup schedule [1] |
## 4. Training Strategy
The model is trained with standard **sequence-to-sequence (Seq2Seq) supervised learning**, predicting the target slot sequence step by step.
### 4.1 Data Construction and Labels
* **Input triple**: each training example is a `(context, pinyin, target slot sequence)` triple [1].
  * **Context**: text fragments before and after the cursor.
  * **Pinyin**: the pinyin sequence of the characters being typed.
  * **Target slot sequence**: the ID sequence of the characters the user actually typed, used as the supervision signal [1].
* **Label handling**: at every slot step, the model must predict the ground-truth character ID for that step [1].
### 4.2 Loss and Optimization
* **Loss**: **CrossEntropyLoss** between the per-step predictions and the ground-truth labels [1].
* **Masking**: the loss is computed only at non-padding positions; invalid time steps are ignored [1].
* **Optimizer**: **AdamW** [1].
### 4.3 Training Loop Details
1. **Forward pass**:
   * The model encodes the context and pinyin with the Transformer to obtain the context representation.
   * Combined with the history slot memory, features are fused through cross-attention and the MoE module.
   * The classification head outputs a probability distribution over all candidate characters for the current step.
2. **Teacher Forcing**:
   * During training, the **ground-truth output of the previous slot is forced** as the conditioning input for the next step. The model always predicts from a "correct history", which speeds up convergence.
3. **Backward pass**:
   * Gradients are computed from the CrossEntropyLoss [1] and the weights are updated with AdamW [1].
### 4.4 Training vs. Inference
* **Training**: Ground Truth labels feed the slot inputs, so the model learns the optimal conditional distribution.
* **Inference**: since ground truth is unavailable, the model uses **Beam Search** [1].
  * **Beam width**: 5 by default [1].
  * **Candidate bookkeeping**: each candidate path maintains its own history slot sequence and cumulative probability [1].
  * **Termination**: decoding stops when all slots are filled (e.g. 8×3 = 24 steps) or when the top token of every candidate branch is the end symbol [1].
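The bookkeeping described above can be sketched independently of the model. The `step_fn` below is a toy stand-in for one decoder forward pass, and the candidate counts are illustrative:

```python
import math

def beam_search(step_fn, beam_width=5, max_steps=24, end_id=0):
    """Each candidate keeps its own token history and cumulative log-probability."""
    beams = [([], 0.0)]  # (history token ids, cumulative log-prob)
    for _ in range(max_steps):
        candidates = []
        for history, score in beams:
            # step_fn returns (token_id, prob) pairs for one decoding step
            for token_id, prob in step_fn(history):
                candidates.append((history + [token_id], score + math.log(prob)))
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
        # Stop when every surviving branch's best token is the end symbol
        if all(history[-1] == end_id for history, _ in beams):
            break
    return beams

# Toy distribution: three candidates per step; the end symbol dominates after 3 tokens
def toy_step(history):
    if len(history) >= 3:
        return [(0, 0.9), (1, 0.05), (2, 0.05)]
    return [(1, 0.5), (2, 0.3), (0, 0.2)]

best_history, best_score = beam_search(toy_step)[0]
print(best_history)  # [1, 1, 1, 0]
```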
## 5. Jupyter Lab Training Example
The following is a complete example of training the input method model in Jupyter Lab with the `trainer.Trainer` class:
```python
# %% [markdown]
# # Input Method Model Training Example
# This notebook shows how to train the input method model with the trainer.Trainer class

# %% [code]
# 1. Import the required libraries
import sys
import os
from pathlib import Path
from datetime import datetime
import torch
from torch.utils.data import DataLoader

# Add the project root to sys.path (works regardless of where Jupyter Lab runs)
project_root = Path.cwd()
# If the current directory has no src directory, fall back to the parent directory
if not (project_root / "src").exists():
    project_root = project_root.parent
sys.path.insert(0, str(project_root))  # search the project directory first

# Import the project modules
from src.model.model import InputMethodEngine
from src.model.dataset import PinyinInputDataset
from src.model.trainer import Trainer, worker_init_fn, collate_fn

# %% [code]
# 2. Configure the training parameters
config = {
    # Data parameters
    "train_data_path": "/path/to/your/train/dataset",  # replace with your training dataset path
    "eval_data_path": "/path/to/your/eval/dataset",  # replace with your evaluation dataset path
    "output_dir": "./training_output",
    # Model parameters
    "vocab_size": 10019,
    "pinyin_vocab_size": 30,
    "dim": 512,
    "num_slots": 8,
    "n_layers": 4,
    "n_heads": 4,
    "num_experts": 20,
    "max_seq_len": 128,
    # Training parameters
    "batch_size": 64,  # adjust to your GPU memory
    "num_epochs": 10,
    "learning_rate": 3e-4,
    "min_learning_rate": 1e-9,
    "weight_decay": 0.1,
    "warmup_ratio": 0.1,
    "label_smoothing": 0.15,
    "grad_accum_steps": 2,  # gradient accumulation simulates a larger batch size
    "clip_grad_norm": 1.0,
    "eval_frequency": 500,  # evaluate every 500 steps
    "save_frequency": 2000,  # save a checkpoint every 2000 steps
    # Advanced options
    "mixed_precision": True,
    "use_tensorboard": True,
    "seed": 42,
    "max_iter_length": 1024 * 1024 * 128,  # maximum iteration length
}

# %% [code]
# 3. Set the random seed and device
torch.manual_seed(config["seed"])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(config["seed"])
    device = torch.device("cuda")
    print(f"✅ Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("⚠️ Training on CPU (a GPU is recommended for better performance)")

# %% [code]
# 4. Create the datasets and data loaders
print("📊 Creating datasets and data loaders...")
# Training dataset
train_dataset = PinyinInputDataset(
    data_path=config["train_data_path"],
    max_workers=-1,  # pick the number of workers automatically
    max_iter_length=config["max_iter_length"],
    max_seq_length=config["max_seq_len"],
    text_field="text",
    py_style_weight=(9, 2, 1),
    shuffle_buffer_size=5000,
    length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
# Training data loader
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    num_workers=min(max(1, (os.cpu_count() or 1) - 1), 8),  # a reasonable number of workers
    pin_memory=torch.cuda.is_available(),
    worker_init_fn=worker_init_fn,
    collate_fn=collate_fn,
    prefetch_factor=32,
    persistent_workers=True,
)
# Evaluation dataset
eval_dataset = PinyinInputDataset(
    data_path=config["eval_data_path"],
    max_workers=-1,
    max_iter_length=1024,  # the evaluation set is small
    max_seq_length=config["max_seq_len"],
    text_field="text",
    py_style_weight=(9, 2, 1),
    shuffle_buffer_size=1000,
    length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=config["batch_size"],
    num_workers=1,
    pin_memory=torch.cuda.is_available(),
    worker_init_fn=worker_init_fn,
    collate_fn=collate_fn,
    prefetch_factor=32,
    persistent_workers=True,
)
print(f"✅ Data loaders created")
print(f"   Training batch size: {config['batch_size']}")
print(f"   Estimated training steps: {config['max_iter_length'] // config['batch_size']}")

# %% [code]
# 5. Create the model
print("🧠 Creating the input method model...")
model = InputMethodEngine(
    vocab_size=config["vocab_size"],
    pinyin_vocab_size=config["pinyin_vocab_size"],
    dim=config["dim"],
    num_slots=config["num_slots"],
    n_layers=config["n_layers"],
    n_heads=config["n_heads"],
    num_experts=config["num_experts"],
    max_seq_len=config["max_seq_len"],
)
# Move the model to the device
model.to(device)
# Count the parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"✅ Model created")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Architecture: {config['n_layers']}-layer Transformer, dim {config['dim']}, {config['num_experts']} MoE experts")

# %% [code]
# 6. Create the trainer
print("⚙️ Creating the trainer...")
# Compute the total number of training steps
total_steps = int(config["max_iter_length"] / config["batch_size"])
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    total_steps=total_steps,
    output_dir=config["output_dir"],
    num_epochs=config["num_epochs"],
    learning_rate=config["learning_rate"],
    min_learning_rate=config["min_learning_rate"],
    weight_decay=config["weight_decay"],
    warmup_ratio=config["warmup_ratio"],
    label_smoothing=config["label_smoothing"],
    grad_accum_steps=config["grad_accum_steps"],
    clip_grad_norm=config["clip_grad_norm"],
    eval_frequency=config["eval_frequency"],
    save_frequency=config["save_frequency"],
    mixed_precision=config["mixed_precision"],
    use_tensorboard=config["use_tensorboard"],
)
print(f"✅ Trainer created")
print(f"   Total training steps: {total_steps:,}")
print(f"   Learning rate: {config['learning_rate']:.2e} -> {config['min_learning_rate']:.2e}")
print(f"   Output directory: {config['output_dir']}")

# %% [code]
# 7. Start training
print("🚀 Starting training...")
print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try:
    # Start training (can also resume from a checkpoint)
    trainer.train(resume_from=None)  # set a checkpoint path here to resume training
    print("✅ Training finished!")
    print(f"End time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Model and logs saved to: {config['output_dir']}")
except KeyboardInterrupt:
    print("⏹️ Training interrupted by the user")
    print("💾 Saving the current checkpoint...")
    trainer.save_checkpoint("interrupted")
    print(f"Checkpoint saved to: {config['output_dir']}/checkpoint_interrupted.pt")
except Exception as e:
    print(f"❌ Error during training: {e}")
    import traceback
    traceback.print_exc()

# %% [code]
# 8. Monitor training progress (when TensorBoard is enabled)
if config["use_tensorboard"]:
    print("📈 TensorBoard logs are written to:")
    print(f"   {config['output_dir']}/tensorboard")
    print("\nStart TensorBoard to watch the training progress:")
    print("   tensorboard --logdir ./training_output/tensorboard")
    print("Then open http://localhost:6006 in your browser")

# %% [code]
# 9. Load the trained model for inference (example)
def load_trained_model(checkpoint_path):
    """Load a trained model from a checkpoint"""
    print(f"📥 Loading checkpoint: {checkpoint_path}")
    # Build a model with the same configuration as during training
    loaded_model = InputMethodEngine(
        vocab_size=config["vocab_size"],
        pinyin_vocab_size=config["pinyin_vocab_size"],
        dim=config["dim"],
        num_slots=config["num_slots"],
        n_layers=config["n_layers"],
        n_heads=config["n_heads"],
        num_experts=config["num_experts"],
        max_seq_len=config["max_seq_len"],
    )
    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    loaded_model.load_state_dict(checkpoint["model_state_dict"])
    loaded_model.to(device)
    loaded_model.eval()
    print(f"✅ Model loaded, trained for {checkpoint.get('global_step', 'N/A')} steps")
    print(f"   Training loss: {checkpoint.get('train_loss', 'N/A'):.4f}")
    return loaded_model

# Usage example (uncomment to use)
# trained_model = load_trained_model("./training_output/checkpoint_final.pt")
```
### Key Notes
1. **Environment requirements**:
   - Python 3.12+
   - PyTorch 2.10+
   - A GPU is recommended for training
   - Install the project dependencies: `pip install -e .`
2. **Dataset format**:
   - Hugging Face `datasets` format
   - Must contain a `text` field
   - Must support streaming reads (streaming=True)
3. **Training monitoring**:
   - Progress and metrics are printed to the console
   - TensorBoard records loss, accuracy, learning rate, etc.
   - Checkpoints are saved periodically
4. **Tunable parameters**:
   - `batch_size`: adjust to GPU memory
   - `learning_rate`: 1e-4 to 5e-4 is recommended
   - `grad_accum_steps`: simulates a larger batch size
   - `num_epochs`: adjust to the dataset size
5. **Troubleshooting**:
   - Out of GPU memory: reduce `batch_size` or increase `grad_accum_steps`
   - Unstable training: lower `learning_rate` or increase `warmup_ratio`
   - Overfitting: increase `label_smoothing` or use a larger dataset
## 6. Usage Guide
Training is exposed through the command line tool `train-model`, which supports training, evaluating, and exporting models.
### 6.1 Installation
#### With uv (recommended)
The project uses [`uv`](https://github.com/astral-sh/uv) as its Python package manager; it is faster and more reliable than plain pip.
1. **Install uv** (if not installed yet):
```bash
# Linux/macOS
curl -LsSf https://astral.sh/uv/install.sh | sh
# or via pipx
pipx install uv
# Windows (PowerShell)
powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
```
2. **Install the project dependencies**:
```bash
uv pip install -e .
```
#### With plain pip
Without uv, the standard pip workflow also works:
```bash
# Create and activate a virtual environment (recommended)
python -m venv .venv
source .venv/bin/activate  # Linux/macOS
# .venv\Scripts\activate   # Windows
# Install the dependencies
pip install -e .
```
#### Verify the installation
After installation, verify it with:
```bash
train-model --help
```
### 6.2 Data Format
Training data must be in Hugging Face datasets format, either local files or a remote dataset repository. The dataset needs a `text` field and must support streaming reads (streaming=True).
#### Local dataset example
```python
# dataset.py
from datasets import Dataset
data = {
    "text": ["这是第一个样本文本。", "这是第二个样本,用于训练输入法模型。"]
}
dataset = Dataset.from_dict(data)
dataset.save_to_disk("./local_dataset")
```
#### Remote dataset example
Datasets on the Hugging Face Hub or ModelScope are supported:
- `huggingface.co/datasets/username/dataset_name`
- `modelscope.cn/datasets/username/dataset_name`
#### Data format requirements
- **Required field**: `text` (a string containing Chinese text)
- **Streaming**: the dataset must support the `streaming=True` parameter
- **Data volume**: at least several million texts are recommended for good results
#### Data preprocessing
The dataset is automatically processed as follows:
1. Text tokenization and encoding
2. Pinyin conversion and encoding
3. Context-window sliding to generate training samples
4. Frequency adjustment ("peak shaving and valley filling") to balance high- and low-frequency words
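Step 3 above can be illustrated with a dependency-free toy sketch; the function and field names are illustrative, not the project's actual preprocessing code:

```python
def sliding_samples(text: str, prefix_len: int = 4, target_len: int = 2):
    """Slide a cursor through the text: everything before it is context,
    the next target_len characters are the prediction targets."""
    samples = []
    for i in range(1, len(text) - target_len + 1):
        samples.append({
            "prefix": text[max(0, i - prefix_len):i],  # text before the cursor
            "suffix": text[i + target_len:],           # text after the target
            "target": text[i:i + target_len],          # characters to predict
        })
    return samples

for s in sliding_samples("模型预测下一字", prefix_len=3):
    print(s["prefix"], "->", s["target"])
```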
### 6.3 Basic Training Command
Start training with the `train-model train` command:
```bash
train-model train \
  --train-data-path "path/to/train/dataset" \
  --eval-data-path "path/to/eval/dataset" \
  --output-dir "./output" \
  --batch-size 128 \
  --num-epochs 10 \
  --learning-rate 1e-5
```
#### Resuming from a checkpoint
To resume training from a checkpoint (keeping the existing training state):
```bash
train-model train \
  --train-data-path "path/to/train/dataset" \
  --eval-data-path "path/to/eval/dataset" \
  --resume-from "./output/checkpoints/latest_checkpoint.pt"
```
#### Resetting the training state
To load only the model weights and train from scratch (learning rate, epoch, etc. restart):
```bash
train-model train \
  --train-data-path "path/to/train/dataset" \
  --eval-data-path "path/to/eval/dataset" \
  --resume-from "./output/checkpoints/best_model.pt" \
  --reset-training-state
```
This is useful when you want to:
- Initialize the model from pretrained weights but retrain with a new training plan
- Change the learning rate schedule or training duration
- Do transfer learning on top of an existing model
#### Learning rate recommendations
For this architecture and hyperparameter configuration (4-layer Transformer, dim 512), the following learning rate setup is recommended:
- **Standard range**: 1e-4 ~ 5e-4
- **With warmup**: ramp the learning rate up at the start of training
- **Cosine annealing**: decay down to a minimum learning rate of 1e-9 for fine adjustment
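The warmup-then-cosine schedule recommended above can be sketched as a plain function. The values are illustrative; the project's `Trainer` implements its own schedule:

```python
import math

def lr_at(step, total_steps, peak_lr=3e-4, min_lr=1e-9, warmup_ratio=0.1):
    """Linear warmup to peak_lr, then cosine decay down to min_lr."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

total = 10_000
print(f"{lr_at(0, total):.2e}")       # early warmup: tiny learning rate
print(f"{lr_at(999, total):.2e}")     # end of warmup: the 3e-4 peak
print(f"{lr_at(10_000, total):.2e}")  # end of training: down to 1e-9
```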
### 6.4 Parameter Reference
#### Data parameters
- `--train-data-path`, `-t`: training dataset path (required)
- `--eval-data-path`, `-e`: evaluation dataset path (required)
- `--output-dir`, `-o`: output directory (default: `./output`)
- `--max_iter_length`: maximum iteration length, i.e. how much data each training iteration processes (default: 134217728)
#### Model parameters
- `--vocab-size`: vocabulary size (default: 10019)
- `--pinyin-vocab-size`: pinyin vocabulary size (default: 30)
- `--dim`: model dimension (default: 512)
- `--num-slots`: number of history slots (default: 8)
- `--n-layers`: number of Transformer layers (default: 4)
- `--n-heads`: number of attention heads (default: 4)
- `--num-experts`: number of MoE experts (default: 20)
- `--max-seq-len`: maximum sequence length (default: 128)
- `--use-pinyin`: whether to use pinyin features (default: False)
#### Training parameters
- `--batch-size`, `-b`: batch size (default: 128)
- `--num-epochs`: number of epochs (default: 10)
- `--learning-rate`, `-lr`: learning rate (default: 1e-5)
- `--min-learning-rate`: minimum learning rate (default: 1e-9)
- `--weight-decay`: weight decay (default: 0.1)
- `--warmup-ratio`: warmup step ratio (default: 0.1)
- `--label-smoothing`: label smoothing (default: 0.15)
- `--grad-accum-steps`: gradient accumulation steps (default: 1)
- `--clip-grad-norm`: gradient clipping norm (default: 1.0)
- `--eval-frequency`: evaluation frequency (default: every 500 steps)
- `--save-frequency`: save frequency (default: every 10000 steps)
#### Advanced options
- `--mixed-precision/--no-mixed-precision`: mixed precision training (default: enabled)
- `--tensorboard/--no-tensorboard`: TensorBoard logging (default: enabled)
- `--resume-from`: resume training from a checkpoint (optional)
- `--reset-training-state`: reset the training state, loading only the model weights (default: False)
- `--seed`: random seed (default: 42)
### 6.5 Monitoring Training Progress
During training the console shows:
- Current step / total steps
- Loss and accuracy
- Learning rate changes
- Memory usage
With TensorBoard enabled, view the visualizations with:
```bash
tensorboard --logdir ./output/tensorboard
```
### 6.6 Mobile-Friendly Monitoring via a JSON Side-Channel
For a mobile-friendly monitoring experience, the project implements a JSON side-channel scheme: while TensorBoard logging continues as usual, the trainer also writes a JSON status file, and a Streamlit app serves a mobile-friendly web UI on top of it.
#### Highlights
**📱 Mobile experience**
- Streamlit renders a responsive UI that fits phone screens
- Charts support pinch-zoom and swipe gestures
- Core metrics in large type, easy to operate by touch
**🚀 Loose coupling**
- Training and monitoring are decoupled through the file system
- Restarting the monitor does not affect the training process
- The training script needs only a few lines of changes
**🔒 Safe and stable**
- Plain-text JSON file, no file-lock contention
- Fast reads and writes, high stability
- Does not interfere with the existing TensorBoard logging
**📊 Real-time monitoring**
- Data refreshes automatically every 5 seconds by default
- Live display of training progress and metric trends
- Data freshness indicator (live / recent / old / stale)
#### Usage
**Start the monitoring service**
```bash
# Start the monitor (default port 8501)
monitor-training monitor
# Specify the status file path and port
monitor-training monitor --status-file ./output/training_status.json --port 8080
# Do not open the browser automatically
monitor-training monitor --no-browser
# Use a custom Streamlit script
monitor-training monitor --streamlit-script ./custom_monitor.py
```
**View the training status**
```bash
# Show the 10 most recent training records
monitor-training view
# Show the 50 most recent records as raw JSON
monitor-training view --limit 50 --raw
# View a specific status file
monitor-training view /path/to/status.json
```
**Check the status file**
```bash
# Check the status file
monitor-training check
# Check a specific file
monitor-training check ./output/training_status.json
```
**Start the static HTTP file server**
```bash
# Start the static HTTP file server (default port 8080)
monitor-training serve
# Specify the status file path and port
monitor-training serve --status-file ./output/training_status.json --port 8080
# Disable CORS support (enabled by default)
monitor-training serve --no-cors
# Specify the host address
monitor-training serve --host 0.0.0.0 --port 8080
```
#### Monitoring UI Features
**📊 Core metrics dashboard**
- Current step, epoch, training loss, accuracy
- Evaluation loss and accuracy
- Current learning rate, last update time
**📈 Trend charts**
- Loss curves (training loss + evaluation loss)
- Accuracy curves (training accuracy + evaluation accuracy)
- Learning rate chart (log scale)
**📋 Data details**
- Full training record table
- Statistics (total data points, training duration, total steps)
- Training progress bar
**⚙️ Configuration**
- Status file path (supports the `TRAINING_STATUS_FILE` environment variable)
- Auto-refresh interval (adjustable, 1-30 seconds)
- Number of data points shown (adjustable, 10-1000)
#### Implementation
**Training side**
- A `status_file` parameter added to `Trainer.__init__`
- A `_write_training_status()` method writes the JSON file at every evaluation
- Supports resuming from an existing status file so no data is lost
**Monitoring side**
- Streamlit builds the mobile-friendly web UI
- Plotly charts with touch interaction
- Auto-refresh keeps the training status current
**Command line tool**
- Provides the `monitor`, `view`, and `check` subcommands
- Detects whether Streamlit is available
- Supports passing settings via environment variables
#### Access
**Local**
```bash
# After starting the monitor, open in a browser:
http://localhost:8501
```
**LAN**
```bash
# Start the service bound to all interfaces
monitor-training monitor --host 0.0.0.0 --port 8080
# Open from a phone browser on the same LAN
http://192.168.1.100:8080
```
**Public internet** (requires port forwarding)
```bash
# Make sure the server firewall opens the corresponding port
# Access via domain name or public IP
http://your-server.com:8501
```
**Remote HTTP monitoring**
```bash
# On the GPU server, start the HTTP service
monitor-training serve --port 8080 --host 0.0.0.0
# Locally, run the Streamlit monitor and read the data from an HTTP URL
monitor-training monitor --host 127.0.0.1 --port 8501
# In the Streamlit UI, enter the remote URL:
http://<gpu-server-ip>:8080/training_status.json
```
#### Status File Format
The status file `training_status.json` lives in the training output directory and looks like this:
```json
[
  {
    "step": 100,
    "epoch": 1,
    "timestamp": "2024-01-01T12:00:00",
    "train/loss": 2.345,
    "train/accuracy": 0.456,
    "eval/loss": 2.123,
    "eval/accuracy": 0.512,
    "train/learning_rate": 0.0001
  },
  ...
]
```
#### Static HTTP File Serving and Remote Monitoring
For GPU servers that only allow plain HTTP (no WebSockets), a static HTTP file-serving scheme enables remote training monitoring.
**🔧 Technical highlights**
- Pure HTTP, no WebSockets required
- Atomic writes avoid reading half-written JSON
- Automatic retries and JSON validation guarantee data integrity
- CORS support for easy cross-origin access
- Lightweight design that does not affect training performance
**🚀 How it works**
1. GPU server: the training process updates `training_status.json` with atomic writes
2. GPU server: `monitor-training serve` provides static HTTP file serving
3. Local machine: `monitor-training monitor` starts the Streamlit UI
4. Local machine: enter the HTTP URL in the Streamlit UI to reach the remote data
5. Streamlit polls over HTTP for live training data and renders it
**🛡️ Data safety**
- Atomic writes: write to a temporary file, then atomically rename, so reads are never interrupted
- JSON validation: the HTTP server validates the JSON format before returning data
- Temporary file handling: `.tmp` files are recognized and read sensibly
- Retry mechanism: reads are retried automatically when JSON parsing fails
**🌐 Network requirements**
- The GPU server must open the HTTP port (default 8080)
- The local machine must be able to reach the GPU server's HTTP port
- Protocol: plain HTTP, friendly to firewalls and proxies
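The atomic-write pattern described above can be sketched in a few lines using the standard library. `os.replace` performs an atomic rename on both POSIX and Windows; the function name is illustrative:

```python
import json
import os

def write_status_atomically(path: str, records: list) -> None:
    """Readers never observe a half-written file: they see the old or the new version."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)
        f.flush()
        os.fsync(f.fileno())  # make sure the bytes hit the disk before the rename
    os.replace(tmp_path, path)  # atomic rename over the old file

write_status_atomically(
    "training_status.json",
    [{"step": 100, "train/loss": 2.345, "train/learning_rate": 1e-4}],
)
with open("training_status.json", encoding="utf-8") as f:
    print(json.load(f)[0]["step"])  # 100
```

Writing the temporary file in the same directory as the target matters: `os.replace` is only atomic within a single filesystem.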
#### Caveats
1. If the status file does not exist when monitoring starts, an empty file is created automatically
2. The `plotly` dependency is required for the charts: `pip install plotly>=5.0.0`
3. Resuming training from a checkpoint automatically reloads the existing status data
4. Deploying the monitor on the same server as training is recommended to avoid network latency
5. The HTTP service relies on atomic writes to avoid reading half-written JSON while the training process writes
6. Remote monitoring requires the GPU server firewall to open the corresponding HTTP port
7. Use `--host 0.0.0.0` so the HTTP service is reachable remotely
### 6.7 Evaluating a Model (in development)
The evaluation feature is still under development:
```bash
train-model evaluate \
  --checkpoint "./output/checkpoint_final.pt" \
  --data-path "path/to/eval/dataset" \
  --batch-size 32
```
The command currently prints an "evaluation not yet implemented" notice. The feature is planned to:
- Load a trained model checkpoint
- Compute accuracy, perplexity, and other metrics on an evaluation dataset
- Generate a detailed performance report
### 6.8 Model Expansion (Two-Stage Training)
When the model needs more capacity (more experts, a changed layer structure, etc.), use the `expand-and-train` command for two-stage training: first freeze the matching layers and train only the new parameters, then fine-tune everything.
#### Training strategy
1. **Frozen stage**: train only the new parameters whose shapes do not match the base checkpoint (new experts, expanded layers, etc.)
2. **Full fine-tuning stage**: when the validation loss fails to improve for `--frozen-patience` consecutive evaluations, all layers are automatically unfrozen for full training
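The frozen stage can be sketched as follows: copy over every weight whose shape matches the base checkpoint and freeze it, leaving the expansion trainable. The function name and toy models are illustrative, not the project's actual implementation:

```python
import torch
import torch.nn as nn

def load_and_freeze_matching(new_model: nn.Module, base_state: dict) -> list:
    """Stage 1: copy shape-matching weights from the base checkpoint and freeze them;
    everything else (the expansion) keeps its fresh init and stays trainable."""
    own_state = new_model.state_dict()
    matching = {
        name: tensor
        for name, tensor in base_state.items()
        if name in own_state and own_state[name].shape == tensor.shape
    }
    new_model.load_state_dict(matching, strict=False)  # strict=False tolerates new params
    frozen = []
    for name, param in new_model.named_parameters():
        if name in matching:
            param.requires_grad = False  # frozen until stage 2 unfreezes everything
            frozen.append(name)
    return frozen

# Toy example: "expand" a model by widening its second layer
base = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))
expanded = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 8))  # second layer changed
frozen = load_and_freeze_matching(expanded, base.state_dict())
print(frozen)  # ['0.weight', '0.bias']: only the unchanged first layer is frozen
```

Stage 2 then simply sets `requires_grad = True` on every parameter and continues training at the lower `--full-lr`.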
#### Basic usage
```bash
train-model expand-and-train \
  --train-data-path "path/to/train/dataset" \
  --eval-data-path "path/to/eval/dataset" \
  --base-model-path "./pretrained/model.pt" \
  --new-model-spec "model:InputMethodEngine" \
  --num-experts 40 \
  --frozen-lr 2e-3 \
  --full-lr 5e-5 \
  --frozen-patience 8
```
#### Full parameter example
```bash
train-model expand-and-train \
  --train-data-path "path/to/train/dataset" \
  --eval-data-path "path/to/eval/dataset" \
  --output-dir "./expansion_output" \
  --base-model-path "./pretrained/model.pt" \
  --new-model-spec "custom_model:ExpandedModel" \
  --vocab-size 10019 \
  --dim 512 \
  --num-experts 40 \
  --frozen-patience 10 \
  --frozen-lr 1e-3 \
  --full-lr 1e-4 \
  --frozen-scheduler cosine \
  --full-scheduler cosine \
  --batch-size 128 \
  --num-epochs 20 \
  --compile
```
#### Parameter details
**Model expansion parameters**
- `--base-model-path`: path to the pretrained base model checkpoint (required)
- `--new-model-spec`: new model spec in `module:ClassName` format, e.g. `model:InputMethodEngine` (required)
  - Modules can be imported from arbitrary paths; the module file must contain the custom model class
  - The custom model class must be a subclass of `InputMethodEngine`
  - Example: `my_model:MyExpandedModel` refers to `MyExpandedModel` in `my_model.py`
**Two-stage training parameters**
- `--frozen-patience`: number of consecutive evaluations without validation-loss improvement in the frozen stage that triggers the switch to full fine-tuning (default: 10)
- `--frozen-lr`: learning rate for the frozen stage (default: 1e-3)
- `--full-lr`: learning rate for the full fine-tuning stage (default: 1e-4)
- `--frozen-scheduler`: LR scheduler for the frozen stage, `cosine` or `plateau` (default: `cosine`)
- `--full-scheduler`: LR scheduler for the full fine-tuning stage, `cosine` or `plateau` (default: `cosine`)
**Other parameters**
- All common parameters of the `train` subcommand are supported (data, model, and training parameters)
- The existing training infrastructure is inherited: mixed precision, TensorBoard logs, checkpoint saving, etc.
#### Use cases
1. **More experts** (20 → 40):
   - Frozen fraction: ~70% of parameters (existing expert weights, attention layers, etc.)
   - New parameters: the new expert networks and the gate layer
2. **Larger top_k** (2 → 3):
   - Frozen fraction: 100% of parameters (logic-only change)
   - New parameters: none
3. **Changed expert internals** (e.g. extra resblocks):
   - Frozen fraction: ~50% of parameters (linear_in/output can be frozen)
   - New parameters: the added resblocks
4. **More Transformer layers** (4 → 5):
   - Frozen fraction: ~80% of parameters (the first 4 layers can be frozen)
   - New parameters: the added 5th layer
#### Custom model class example
```python
# my_model.py
from model.model import InputMethodEngine

class MyExpandedModel(InputMethodEngine):
    def __init__(self, num_experts=40, **kwargs):
        # Call the parent constructor, overriding num_experts
        super().__init__(num_experts=num_experts, **kwargs)
        # Extra layers or modifications to existing layers go here

# Usage:
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
```
#### Caveats
1. **Model class requirement**: the custom model class must be a subclass of `InputMethodEngine`
2. **Freezing condition**: only layers whose weight shapes match exactly are frozen
3. **Performance**: the MoE layer keeps the "compute all experts + Top-K selection" scheme to stay fast under `torch.compile`
4. **Stage switching**: based on evaluation frequency rather than epochs; consider raising `--eval-frequency` accordingly
5. **Module imports**: modules from arbitrary paths are loaded through Python's standard import machinery
### 6.9 Exporting a Model (in development)
The export feature is still under development:
```bash
train-model export \
  --checkpoint "./output/checkpoint_final.pt" \
  --output "./exported_model.onnx"
```
The command currently prints an "export not yet implemented" notice. The feature is planned to:
- Convert the PyTorch model to ONNX format
- Enable deployment on different inference engines
- Provide an optimized inference model
## 7. Summary
This design combines **single-stream Transformer encoding** with **structured-slot cross-attention** and adds a **20-expert MoE module** [1]. It keeps the model lightweight (a 4-layer Transformer) while exploiting the user's input history and raising the model's expressive ceiling. Compared with brute-force concatenation or a two-stream architecture, the design is cleaner to implement and more efficient at inference, making it a strong local optimum for a lightweight input-method model.

eval.py

@@ -274,8 +274,43 @@ class TextEvaluator:
         part1 = text[i - 48 : i]
         # part2: pinyin input, random length 1-8, Gaussian-like distribution
+        # When history slots are forced, make sure the pinyin is long enough
         pinyin_len_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
+        min_pinyin_len = (force_slot_count + 1) if force_slot_count is not None else 1
+        min_pinyin_len = max(1, min(8, min_pinyin_len))
+        # Check how many usable characters remain
+        remaining_chars = len(text) - i
+        max_valid_len = 0
+        for j in range(i, min(i + 8, len(text))):
+            if self.query_engine and self.query_engine.is_chinese_char(text[j]):
+                max_valid_len += 1
+            else:
+                break
+        # Adjust the pinyin length: at least min_pinyin_len valid characters are required
+        if max_valid_len < min_pinyin_len:
+            # The caller must re-pick a position with enough characters
+            raise ValueError(
+                f"Position {i} is followed by only {max_valid_len} valid Chinese characters, "
+                f"not enough to test history slot count {force_slot_count}"
+            )
+        # Randomly pick the pinyin length (at least min_pinyin_len)
+        if force_slot_count is not None:
+            # Forced mode: pick a length no shorter than min_pinyin_len
+            valid_lengths = [
+                l for l in range(min_pinyin_len, min(max_valid_len + 1, 9))
+            ]
+            if not valid_lengths:
+                valid_lengths = [min_pinyin_len]
+            # Keep the original probability distribution, restricted to the valid lengths
+            adjusted_probs = [pinyin_len_probs[l - 1] for l in valid_lengths]
+            adjusted_probs = np.array(adjusted_probs) / sum(adjusted_probs)
+            pinyin_len = np.random.choice(valid_lengths, p=adjusted_probs)
+        else:
-        pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
+            pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
+            pinyin_len = min(pinyin_len, max_valid_len)
         py_end = min(i + pinyin_len, len(text))
         # Look up the pinyin
@@ -317,7 +352,7 @@ class TextEvaluator:
         string_list.append(text[start_pos:end_pos])
         part4 = "|".join(string_list)
-        # Labels: IDs of the characters to predict
+        # Labels: IDs of the characters to predict (all characters of the pinyin sequence)
         labels = []
         try:
             labels = [
@@ -328,27 +363,38 @@ class TextEvaluator:
                 )
             ]
         except (AttributeError, TypeError) as e:
-            # If the lookup fails, fall back to default IDs
             print(f"⚠️ Failed to fetch labels: {e}")
             labels = [0] * pinyin_len_actual
-        # History slots: simulate the user's step-by-step confirmation
-        # For multi-character prediction (pinyin_len_actual > 1), simulate step-by-step selection
-        # Note: at evaluation time the answer is unknown; history can only come from predictions
+        # History slots: consistent with the training-data generation logic
+        # With force_slot_count=k, the history slots are labels[:k] and the target is labels[k]
         history_slot_ids = []
+        predict_idx = 0  # index of the label currently being predicted
-        # If a slot count is forced, use simulated history (matching the training distribution)
         if force_slot_count is not None:
             force_slot_count = max(0, min(8, force_slot_count))
-            # Build simulated history slots: the first force_slot_count are valid, the rest 0
-            # Simple simulation: valid slots use random IDs 1-1000 (matching training's random fill)
-            for _ in range(force_slot_count):
-                history_slot_ids.append(random.randint(1, 1000))
+            # Check that there are enough labels
+            if len(labels) <= force_slot_count:
+                raise ValueError(
+                    f"Only {len(labels)} labels, not enough for history slot count {force_slot_count}"
+                )
+            predict_idx = force_slot_count
+            history_slot_ids = labels[:force_slot_count]
         else:
-            # Normal evaluation: empty history (simulating typing from scratch)
-            # Note: in real use, the history slots come from earlier user choices,
-            # but for a fair evaluation we start from an empty history
-            pass
+            # Normal evaluation: sample a history length matching the training-data distribution
+            history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
+            max_history = min(len(labels) - 1, 8)  # keep at least one character to predict
+            if max_history < 0:
+                max_history = 0
+            valid_history_lens = list(range(0, max_history + 1))
+            if valid_history_lens:
+                adjusted_probs = [history_weights[h] for h in valid_history_lens]
+                adjusted_probs = np.array(adjusted_probs) / sum(adjusted_probs)
+                history_len = np.random.choice(valid_history_lens, p=adjusted_probs)
+            else:
+                history_len = 0
+            predict_idx = history_len
+            history_slot_ids = labels[:history_len]
         # Pad to 8 slots
         if len(history_slot_ids) < 8:
@@ -373,6 +419,12 @@ class TextEvaluator:
         # Pinyin input length (in characters)
         pinyin_input_length = len(part2)
+        # The character and label currently being predicted
+        current_target_char = (
+            text[i + predict_idx] if (i + predict_idx) < len(text) else ""
+        )
+        current_target_label = labels[predict_idx] if predict_idx < len(labels) else 0
         # Build the sample
         sample = {
             "input_ids": encoded["input_ids"],
@@ -385,11 +437,13 @@ class TextEvaluator:
             "suffix": part3,
             "pinyin": part2,
             "pinyin_ids": pinyin_ids_tensor.unsqueeze(0),
-            "true_labels": labels,  # ground-truth labels (multiple characters)
+            "true_labels": labels,  # full label list (all characters)
+            "predict_idx": predict_idx,  # index of the label currently being predicted
+            "current_target_label": current_target_label,  # label currently being predicted
             "position": i,
-            "target_char": text[i] if i < len(text) else "",
-            "valid_slot_count": valid_slot_count,  # number of valid slots
-            "pinyin_input_length": pinyin_input_length,  # pinyin input length
+            "target_char": current_target_char,  # character currently being predicted
+            "valid_slot_count": valid_slot_count,
+            "pinyin_input_length": pinyin_input_length,
         }
         return sample
@@ -483,56 +537,50 @@ class TextEvaluator:
         return analysis
-    def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict:
+    def evaluate_sample_single_char(
+        self, sample: Dict, print_details: bool = True
+    ) -> Dict:
         """
-        Evaluate a single sample
+        Single-character evaluation (matches the validation logic in trainer.py)
+        Picks the predicted character with argmax, exactly like validation during training
+        Selects the prediction target according to predict_idx
         Returns:
             Evaluation result dict
         """
-        # Run inference
         logits, probs = self.inference(sample)
-        # Analyze the probability distribution
         analysis = self.analyze_probability_distribution(probs)
-        # Fetch the ground-truth labels
         true_labels = sample.get("true_labels", [])
+        predict_idx = sample.get("predict_idx", 0)
+        current_target_label = sample.get(
+            "current_target_label", true_labels[predict_idx] if true_labels else 0
+        )
         target_char = sample.get("target_char", "")
-        # Check whether the prediction is correct (first character only)
-        correct = False
-        pred_top_idx = analysis["top_indices"][0]
-        if true_labels and len(true_labels) > 0:
-            correct = pred_top_idx == true_labels[0]
-        # Print the results
+        pred_idx = analysis["top_indices"][0]
+        correct = pred_idx == current_target_label
         if print_details:
-            # Slot information
             valid_slot_count = sample.get("valid_slot_count", 0)
             pinyin_input_length = sample.get("pinyin_input_length", 0)
-            # Compact output format (continue on the same line)
             print(f" {target_char}", end="")
-            if true_labels:
-                print(f"(ID:{true_labels[0]})", end="")
+            print(f"(ID:{current_target_label})", end="")
             print(
                 f" | pinyin: {sample.get('pinyin', '')[:10]}{'...' if len(sample.get('pinyin', '')) > 10 else ''}",
                 end="",
            )
             print(f" | slots: {valid_slot_count}/8 pinyin len: {pinyin_input_length}", end="")
-            # Key probability information
             print(f" | top prob: {analysis['max_prob']:.3f}", end="")
             if analysis["low_variance"]:
                 print(f" ⚠️", end="")
-            # Top-2 predictions
             top_strs = []
             for j in range(min(2, len(analysis["top_indices"]))):
                 idx = analysis["top_indices"][j]
                 prob = analysis["top_probs"][j]
-                # Handle special IDs
                 if idx == 0:
                     char = "[empty]"
                 else:
@@ -542,34 +590,43 @@ class TextEvaluator:
                         else None
                     )
                     char = char_info.char if char_info else f"[{idx}]"
-                is_correct = (
-                    (true_labels and idx == true_labels[0]) if true_labels else False
-                )
+                is_correct = idx == current_target_label
                 correct_mark = "✓" if is_correct else ""
                 top_strs.append(f"{char}{correct_mark}({prob:.2f})")
             print(f" | Top-2: {' '.join(top_strs)}", end="")
             print(f" | result: {'✓' if correct else '✗'}")
-        # Return the evaluation result
-        result = {
+        return {
             "sample": sample,
             "analysis": analysis,
             "correct": correct,
             "target_char": target_char,
-            "true_label": true_labels[0] if true_labels else None,
-            "predicted_label": pred_top_idx,
+            "true_label": current_target_label,
+            "predicted_label": pred_idx,
+            "predict_idx": predict_idx,
         }
-        return result
+    def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict:
+        """
+        Evaluate a single sample (kept for the old interface; delegates to single-character evaluation)
+        Returns:
+            Evaluation result dict
+        """
+        return self.evaluate_sample_single_char(sample, print_details)
     def evaluate_sample_with_sequential_confirmation(
-        self, sample: Dict, print_details: bool = True
+        self, sample: Dict, print_details: bool = True, use_oracle_history: bool = True
     ) -> Dict:
         """
         Evaluate a sample while simulating step-by-step user confirmation (for multi-character prediction)
-        For pinyin length > 1, simulates the user repeatedly confirming the Top-1 prediction
+        Args:
+            sample: sample dict
+            print_details: whether to print details
+            use_oracle_history: True = use ground-truth labels as history (consistent with training/validation); False = use model predictions
+        For pinyin length > 1, simulates the step-by-step selection process
         Returns:
             Evaluation result dict
@@ -583,7 +640,7 @@ class TextEvaluator:
             return self.evaluate_sample(sample, print_details)
         # Multi-character prediction: simulate step-by-step confirmation
-        confirmed_history = []  # character IDs the user has confirmed
+        confirmed_history = []  # confirmed character IDs
         all_predictions = []  # predictions at each step
         all_correct = []  # whether each step was correct
@@ -591,8 +648,16 @@ class TextEvaluator:
         sample_copy = sample.copy()
         for step in range(pinyin_len):
-            # Update the history slots: use the confirmed characters
+            # Update the history slots:
+            # - oracle mode: use ground-truth labels as history (consistent with training)
+            # - user mode: use model predictions as history (simulates real usage)
+            if use_oracle_history:
+                # Use ground-truth labels as history (matches the training-data generation logic)
+                history_slot_ids = true_labels[:step]
+            else:
+                # Use predictions as history
-            history_slot_ids = confirmed_history[:]
+                history_slot_ids = confirmed_history[:]
             if len(history_slot_ids) < 8:
                 history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
             else:
@@ -626,9 +691,8 @@ class TextEvaluator:
             )
             all_correct.append(correct)
-            # Simulate the user's choice: append the Top-1 prediction to the confirmed history
-            # Note: this assumes the user always picks the Top-1 result
-            if pred_top_idx > 0:  # only append valid IDs
+            # Update the confirmed history (used in the next step)
+            if pred_top_idx > 0:
                 confirmed_history.append(pred_top_idx)
         # Overall accuracy: a sample counts as correct only if every step is correct
@@ -643,6 +707,8 @@ class TextEvaluator:
             print(f" {target_char}", end="")
             print(f" | pinyin: {pinyin[:10]}{'...' if len(pinyin) > 10 else ''}", end="")
             print(f" | slots: {valid_slot_count}/8 pinyin len: {pinyin_len}", end="")
+            mode_str = "oracle" if use_oracle_history else "user"
+            print(f" | [{mode_str}]", end="")
             # Show the step-by-step predictions
             step_results = []
@@ -672,13 +738,16 @@ class TextEvaluator:
         return result
-    def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]:
+    def evaluate_text(
+        self, text: str, num_samples: int = 10, use_sequential: bool = False
+    ) -> List[Dict]:
         """
         Evaluate the given text by generating multiple samples
         Args:
             text: input text
             num_samples: number of samples to generate
+            use_sequential: True = step-by-step confirmation; False = single-character evaluation (matches the validation set)
         Returns:
             List of evaluation results
@@ -699,8 +768,13 @@ class TextEvaluator:
             try:
                 sample = self.create_sample_from_text(text)
-                # Use step-by-step confirmation evaluation
-                result = self.evaluate_sample_with_sequential_confirmation(
-                    sample, print_details=True
-                )
+                # Default to single-character evaluation (matches the validation set)
+                if use_sequential:
+                    result = self.evaluate_sample_with_sequential_confirmation(
+                        sample, print_details=True, use_oracle_history=True
+                    )
+                else:
+                    result = self.evaluate_sample_single_char(
+                        sample, print_details=True
+                    )
                 results.append(result)
@@ -718,7 +792,7 @@ class TextEvaluator:
         return results
     def evaluate_by_slot_count(
-        self, text: str, samples_per_slot: int = 100
+        self, text: str, samples_per_slot: int = 100, use_sequential: bool = False
     ) -> Dict[int, Dict]:
         """
         Evaluate per slot count: test samples_per_slot samples for each slot count
@ -726,6 +800,7 @@ class TextEvaluator:
Args: Args:
text: 输入文本 text: 输入文本
samples_per_slot: 每个槽位数量的测试样本数 samples_per_slot: 每个槽位数量的测试样本数
use_sequential: True=使用逐步确认评估False=使用单字符评估与验证集一致
Returns: Returns:
按槽位数量分组的评估结果 {slot_count: {results, accuracy, ...}} 按槽位数量分组的评估结果 {slot_count: {results, accuracy, ...}}
@ -735,7 +810,8 @@ class TextEvaluator:
text = text[start : start + 300] text = text[start : start + 300]
print(f"\n{'=' * 60}") print(f"\n{'=' * 60}")
print(f"按槽位数量评估 | 每组 {samples_per_slot} 个样本") mode_str = "逐步确认" if use_sequential else "单字符(与验证集一致)"
print(f"按槽位数量评估 [{mode_str}] | 每组 {samples_per_slot} 个样本")
print(f"{'=' * 60}") print(f"{'=' * 60}")
slot_results = {} slot_results = {}
@ -750,12 +826,19 @@ class TextEvaluator:
sample = self.create_sample_from_text( sample = self.create_sample_from_text(
text, force_slot_count=slot_count text, force_slot_count=slot_count
) )
# 使用逐步确认评估(更符合实际使用场景) # 默认使用单字符评估(与验证集一致)
if use_sequential:
result = self.evaluate_sample_with_sequential_confirmation( result = self.evaluate_sample_with_sequential_confirmation(
sample, print_details=False, use_oracle_history=True
)
else:
result = self.evaluate_sample_single_char(
sample, print_details=False sample, print_details=False
) )
results.append(result) results.append(result)
# 注意逐步确认评估没有low_variance分析 # 统计低方差样本
if "analysis" in result and result["analysis"].get("low_variance"):
low_var_count += 1
except Exception: except Exception:
failed += 1 failed += 1
continue continue
@ -764,8 +847,9 @@ class TextEvaluator:
correct_count = sum(1 for r in results if r["correct"]) correct_count = sum(1 for r in results if r["correct"])
accuracy = correct_count / len(results) accuracy = correct_count / len(results)
# 对于逐步确认评估,计算平均步数准确率 # 计算平均步数准确率(仅用于逐步确认模式)
step_accuracies = [] step_accuracies = []
if use_sequential:
for r in results: for r in results:
if "step_correct" in r: if "step_correct" in r:
step_correct = r["step_correct"] step_correct = r["step_correct"]
@ -781,17 +865,25 @@ class TextEvaluator:
"correct": correct_count, "correct": correct_count,
"total": len(results), "total": len(results),
"mean_step_accuracy": mean_step_accuracy, "mean_step_accuracy": mean_step_accuracy,
"low_var_count": low_var_count,
"failed": failed, "failed": failed,
} }
bar_len = 30 bar_len = 30
filled = int(bar_len * accuracy) filled = int(bar_len * accuracy)
bar = "" * filled + "" * (bar_len - filled) bar = "" * filled + "" * (bar_len - filled)
if use_sequential:
print( print(
f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} " f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} "
f"({correct_count:2d}/{len(results):2d}) " f"({correct_count:2d}/{len(results):2d}) "
f"| 平均步数准确率: {mean_step_accuracy:.1%}" f"| 平均步数准确率: {mean_step_accuracy:.1%}"
) )
else:
print(
f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} "
f"({correct_count:2d}/{len(results):2d}) "
f"| 低方差: {low_var_count}"
)
else: else:
slot_results[slot_count] = { slot_results[slot_count] = {
"results": [], "results": [],
@ -799,6 +891,7 @@ class TextEvaluator:
"correct": 0, "correct": 0,
"total": 0, "total": 0,
"mean_step_accuracy": 0, "mean_step_accuracy": 0,
"low_var_count": 0,
"failed": failed, "failed": failed,
} }
print(f" 槽位 {slot_count}/8 | 全部样本生成失败") print(f" 槽位 {slot_count}/8 | 全部样本生成失败")
@ -851,11 +944,11 @@ def main():
parser.add_argument( parser.add_argument(
"--checkpoint", "--checkpoint",
type=str, type=str,
default="/home/songsenand/下载/best_model.pt", default="/home/songsenand/下载/best_model.ptrom",
help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)", help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.ptrom)",
) )
parser.add_argument( parser.add_argument(
"--num-samples", type=int, default=100, help="评估样本数量 (默认: 20)" "--num-samples", type=int, default=100, help="评估样本数量 (默认: 100)"
) )
parser.add_argument( parser.add_argument(
"--device", "--device",
@ -864,6 +957,11 @@ def main():
choices=["auto", "cpu", "cuda"], choices=["auto", "cpu", "cuda"],
help="推理设备 (默认: auto)", help="推理设备 (默认: auto)",
) )
parser.add_argument(
"--use-sequential",
action="store_true",
help="使用逐步确认评估模式(默认使用单字符评估,与验证集一致)",
)
args = parser.parse_args() args = parser.parse_args()
@ -890,6 +988,9 @@ def main():
print(f"🔧 使用设备: {device}") print(f"🔧 使用设备: {device}")
print(f"📦 模型checkpoint: {args.checkpoint}") print(f"📦 模型checkpoint: {args.checkpoint}")
print(f"🔬 评估样本数: {args.num_samples}") print(f"🔬 评估样本数: {args.num_samples}")
print(
f"📊 评估模式: {'逐步确认' if args.use_sequential else '单字符(与验证集一致)'}"
)
# 初始化评估器 # 初始化评估器
try: try:
@ -902,7 +1003,9 @@ def main():
sys.exit(1) sys.exit(1)
# 执行评估 # 执行评估
evaluator.evaluate_by_slot_count(text, samples_per_slot=args.num_samples) evaluator.evaluate_by_slot_count(
text, samples_per_slot=args.num_samples, use_sequential=args.use_sequential
)
if __name__ == "__main__": if __name__ == "__main__":
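The hunks above switch the evaluator between an "oracle" mode (the confirmed history is refilled with ground truth) and a "user" mode (the model's own Top-1 pick is fed back). The feedback loop can be sketched in plain Python; `sequential_eval` and `toy_predict` below are hypothetical stand-ins for illustration, not functions from this repo:

```python
def sequential_eval(targets, predict, use_oracle_history=True):
    """Step-by-step confirmation evaluation.

    targets: sequence of target ids to predict, one per step
    predict: callable (history) -> predicted top-1 id (a stand-in here)
    use_oracle_history: True feeds ground truth back into the history (oracle),
                        False feeds back the model's own Top-1 pick (user)
    """
    confirmed_history = []
    step_correct = []
    for target in targets:
        pred_top_idx = predict(confirmed_history)
        step_correct.append(pred_top_idx == target)
        # oracle mode confirms the true target; user mode confirms the prediction
        fed_back = target if use_oracle_history else pred_top_idx
        if fed_back > 0:  # only valid (non-padding) ids enter the history
            confirmed_history.append(fed_back)
    # overall correctness requires every step to be correct
    return all(step_correct), step_correct


# toy model that always predicts "last confirmed id + 1" (demo only)
toy_predict = lambda hist: (hist[-1] + 1) if hist else 1

overall, steps = sequential_eval([1, 2, 3], toy_predict, use_oracle_history=True)
```

In user mode a single wrong pick can poison the history and cascade into later steps, which is why the diff keeps the oracle option for isolating per-step model quality.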
View File
@@ -28,6 +28,11 @@ class AttentionPooling(nn.Module):

# ---------------------------- 拼音LSTM编码器 ----------------------------
class PinyinLSTMEncoder(nn.Module):
+    """
+    拼音序列编码器:返回每位置的拼音编码
+    使用双向LSTM每个位置都能看到前后文信息
+    """
+
    def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
        super().__init__()
        self.input_dim = input_dim
@@ -35,7 +40,6 @@ class PinyinLSTMEncoder(nn.Module):
        self.num_layers = num_layers
        self.dropout = dropout

-        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=self.hidden_dim,
@@ -45,7 +49,6 @@ class PinyinLSTMEncoder(nn.Module):
            dropout=dropout if num_layers > 1 else 0.0,
        )

-        # Project concatenated hidden states to input_dim
        self.proj = nn.Linear(self.hidden_dim * 2, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)
@@ -53,32 +56,25 @@ class PinyinLSTMEncoder(nn.Module):
        """
        Args:
            x: [batch, seq_len, input_dim] pinyin embeddings
-            mask: [batch, seq_len] optional padding mask (0 for padding)
+            mask: [batch, seq_len] optional padding mask (True for valid, False for padding)

        Returns:
-            pooled: [batch, input_dim] global pinyin representation
+            output: [batch, seq_len, input_dim] 每位置的拼音编码
        """
+        total_len = x.size(1)
        if mask is not None:
-            # lengths for pack_padded_sequence
-            lengths = mask.sum(dim=1).cpu()
-            # pack sequence
+            lengths = mask.sum(dim=1).cpu().clamp(min=1)
            packed = nn.utils.rnn.pack_padded_sequence(
                x, lengths, batch_first=True, enforce_sorted=False
            )
            packed_out, (hidden, cell) = self.lstm(packed)
-            # hidden shape: [num_layers * 2, batch, hidden_dim]
-            # Take last layer's forward and backward hidden states
-            forward_hidden = hidden[-2, :, :]  # last layer forward
-            backward_hidden = hidden[-1, :, :]  # last layer backward
-            hidden_concat = torch.cat([forward_hidden, backward_hidden], dim=1)
+            output, _ = nn.utils.rnn.pad_packed_sequence(
+                packed_out, batch_first=True, total_length=total_len
+            )
        else:
-            # No mask, assume all sequences same length
            output, (hidden, cell) = self.lstm(x)
-            # hidden shape: [num_layers * 2, batch, hidden_dim]
-            forward_hidden = hidden[-2, :, :]
-            backward_hidden = hidden[-1, :, :]
-            hidden_concat = torch.cat([forward_hidden, backward_hidden], dim=1)

-        projected = self.proj(hidden_concat)
+        projected = self.proj(output)
        return self.layer_norm(projected)
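The rewritten forward keeps per-position LSTM outputs instead of pooling the final hidden states, and the load-bearing detail is `pad_packed_sequence(..., total_length=...)`, which restores the original padded length so the output stays aligned with the text side. A minimal standalone sketch of that pack/pad round-trip, with made-up dimensions:

```python
import torch
import torch.nn as nn

# bidirectional LSTM over padded batches: pack by valid length so padding
# never pollutes the recurrent state, then pad back to the original length
lstm = nn.LSTM(input_size=8, hidden_size=4, batch_first=True, bidirectional=True)

x = torch.randn(2, 5, 8)                       # [batch=2, seq_len=5, dim=8]
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]]).bool()  # True = valid position
lengths = mask.sum(dim=1).cpu().clamp(min=1)   # lengths must live on CPU

packed = nn.utils.rnn.pack_padded_sequence(
    x, lengths, batch_first=True, enforce_sorted=False
)
packed_out, _ = lstm(packed)
# total_length guarantees [batch, 5, 2*hidden] even though the longest
# packed sequence could be shorter than the padded batch
output, out_lengths = nn.utils.rnn.pad_packed_sequence(
    packed_out, batch_first=True, total_length=x.size(1)
)
```

Padding positions come back zero-filled, so a downstream per-position projection (like `self.proj` above) sees clean zeros rather than stale LSTM state.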
@@ -170,46 +166,32 @@ class ContextEncoder(nn.Module):
    def forward(self, text_ids, pinyin_ids, mask=None):
        """
        Args:
-            text_ids: [batch, seq_len]
-            pinyin_ids: [batch, seq_len] (假设已对齐,若不对齐需预处理)
-            mask: [batch, seq_len] optional padding mask
+            text_ids: [batch, seq_len] 文本 token ids
+            pinyin_ids: [batch, pinyin_len] 拼音 token ids
+            mask: [batch, seq_len] optional padding mask (1 for valid, 0 for padding)

        Returns:
-            H: [batch, seq_len, 512] Context representation [1]
+            H: [batch, seq_len, dim] 文本上下文编码
+            P: [batch, pinyin_len, dim] 拼音序列编码(每位置)
        """
-        # 1. Embed text
-        text_emb = self.text_emb(text_ids)  # [B, 128, dim]
-        # 2. Embed and pool pinyin to global feature
-        pinyin_emb = self.pinyin_emb(pinyin_ids)  # [B, 24, dim]
-        # LSTM encoder with masking for padding
-        pinyin_mask = pinyin_ids != 0
-        pinyin_global = self.pinyin_pooling(
-            pinyin_emb, mask=pinyin_mask
-        )  # [B, dim]
-        # Broadcast pinyin to all text positions
-        pinyin_global = pinyin_global.unsqueeze(1)  # [B, 1, dim]
-        pinyin_broadcast = pinyin_global.expand_as(text_emb)  # [B, 128, dim]
-        # 策略:拼音作为增强特征叠加到文本上,符合轻量级设计
-        x = text_emb + pinyin_broadcast
-        seq_len = x.size(1)
-        pos_ids = (
-            torch.arange(seq_len, device=x.device).unsqueeze(0).expand_as(text_ids)
-        )
-        x += self.pos_emb(pos_ids)
-        # 2. Transformer Encoding
-        # src_key_padding_mask expects True for padding positions
+        # 1. Embedding Fusion: Text + Pinyin + Position
+        text_emb = self.text_emb(text_ids)  # [B, seq_len, dim]
+        seq_len = text_emb.size(1)
+        pos_ids = torch.arange(seq_len, device=text_ids.device).unsqueeze(0)
+        x = text_emb + self.pos_emb(pos_ids)

        if mask is not None:
-            # Convert 0/1 mask to bool mask where True is padding
            src_mask = mask == 0
        else:
            src_mask = None

        H = self.transformer(x, src_key_padding_mask=src_mask)
-        return self.ln(H)
+        H = self.ln(H)
+
+        pinyin_emb = self.pinyin_emb(pinyin_ids)  # [B, pinyin_len, dim]
+        pinyin_mask = pinyin_ids != 0
+        P = self.pinyin_pooling(pinyin_emb, mask=pinyin_mask)  # [B, pinyin_len, dim]
+        return H, P
# ------------------------------------------------------------------
@@ -256,76 +238,100 @@
# ------------------------------------------------------------------
# 3. 交叉注意力融合 (Cross-Attention Fusion)
-# 对应 README: Query=Slots, Key/Value=Context H [1]
+# 槽位同时查询文本上下文和拼音序列
# ------------------------------------------------------------------
class CrossAttentionFusion(nn.Module):
+    """
+    双路交叉注意力:
+    - 槽位查询文本上下文 (H)
+    - 槽位查询拼音序列 (P)
+    通过注意力机制start_emb 自动学会关注"当前应该预测的拼音位置"
+    """
+
    def __init__(self, dim=512, n_heads=4):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
-        assert self.head_dim * n_heads == dim, "dim must be divisible by n_heads"
+        assert self.head_dim * n_heads == dim

-        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(dim, dim, bias=False)
-        self.k_proj = nn.Linear(dim, dim, bias=False)
-        self.v_proj = nn.Linear(dim, dim, bias=False)
+        self.k_text_proj = nn.Linear(dim, dim, bias=False)
+        self.v_text_proj = nn.Linear(dim, dim, bias=False)
+        self.k_pinyin_proj = nn.Linear(dim, dim, bias=False)
+        self.v_pinyin_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.ln = nn.LayerNorm(dim)

-    def forward(self, slots_S, context_H, slot_mask=None, context_mask=None):
+    def forward(
+        self, slots_S, context_H, pinyin_P, context_mask=None, pinyin_mask=None
+    ):
        """
        Args:
-            slots_S: [batch, num_slots_steps, dim] Query
-            context_H: [batch, ctx_len, dim] Key/Value
-            slot_mask: [batch, num_slots_steps] Optional (not used in scaled_dot_product_attention)
-            context_mask: [batch, ctx_len] Optional padding mask
+            slots_S: [batch, num_slots, dim] 槽位编码
+            context_H: [batch, ctx_len, dim] 文本上下文编码
+            pinyin_P: [batch, pinyin_len, dim] 拼音序列编码
+            context_mask: [batch, ctx_len] 文本 padding mask (True for padding)
+            pinyin_mask: [batch, pinyin_len] 拼音 padding mask (True for padding)

        Returns:
-            Fused: [batch, num_slots_steps, dim]
+            fused: [batch, num_slots, dim]
        """
        batch_size, num_slots, _ = slots_S.shape
-        _, ctx_len, _ = context_H.shape
+        ctx_len = context_H.size(1)
+        pinyin_len = pinyin_P.size(1)

-        # Project queries, keys, values
-        Q = self.q_proj(slots_S)  # [batch, num_slots, dim]
-        K = self.k_proj(context_H)  # [batch, ctx_len, dim]
-        V = self.v_proj(context_H)  # [batch, ctx_len, dim]
+        Q = self.q_proj(slots_S)
+        K_text = self.k_text_proj(context_H)
+        V_text = self.v_text_proj(context_H)
+        K_pinyin = self.k_pinyin_proj(pinyin_P)
+        V_pinyin = self.v_pinyin_proj(pinyin_P)

-        # Reshape for multi-head attention: [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
        Q = Q.view(batch_size, num_slots, self.n_heads, self.head_dim).transpose(1, 2)
-        K = K.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)
-        V = V.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)
+        K_text = K_text.view(
+            batch_size, ctx_len, self.n_heads, self.head_dim
+        ).transpose(1, 2)
+        V_text = V_text.view(
+            batch_size, ctx_len, self.n_heads, self.head_dim
+        ).transpose(1, 2)
+        K_pinyin = K_pinyin.view(
+            batch_size, pinyin_len, self.n_heads, self.head_dim
+        ).transpose(1, 2)
+        V_pinyin = V_pinyin.view(
+            batch_size, pinyin_len, self.n_heads, self.head_dim
+        ).transpose(1, 2)

-        # Prepare attention mask if context_mask is provided
-        attn_mask = None
+        text_attn_mask = None
        if context_mask is not None:
-            # context_mask: [batch, ctx_len] where 0 means padding
-            # Convert to bool mask and reshape for broadcasting
-            bool_mask = context_mask == 0  # [batch, ctx_len]
-            bool_mask = bool_mask[:, None, None, :]  # [batch, 1, 1, ctx_len]
-            # Convert to float mask where True (padding) becomes -inf
-            attn_mask = bool_mask.float().masked_fill(bool_mask, -1e9)
+            text_attn_mask = (
+                context_mask[:, None, None, :]
+                .float()
+                .masked_fill(context_mask[:, None, None, :], -1e9)
+            )

-        # Scaled dot-product attention
-        attn_output = F.scaled_dot_product_attention(
-            Q,
-            K,
-            V,
-            attn_mask=attn_mask,
-            dropout_p=0.0,  # no dropout
-        )
+        pinyin_attn_mask = None
+        if pinyin_mask is not None:
+            pinyin_attn_mask = (
+                pinyin_mask[:, None, None, :]
+                .float()
+                .masked_fill(pinyin_mask[:, None, None, :], -1e9)
+            )

-        # Reshape back: [batch, n_heads, num_slots, head_dim] -> [batch, num_slots, dim]
-        attn_output = (
-            attn_output.transpose(1, 2)
-            .contiguous()
-            .view(batch_size, num_slots, self.dim)
-        )
+        text_attn = F.scaled_dot_product_attention(
+            Q, K_text, V_text, attn_mask=text_attn_mask
+        )
+        pinyin_attn = F.scaled_dot_product_attention(
+            Q, K_pinyin, V_pinyin, attn_mask=pinyin_attn_mask
+        )

-        # Project back
-        fused = self.out_proj(attn_output)
+        combined_attn = text_attn + pinyin_attn
+        combined_attn = (
+            combined_attn.transpose(1, 2)
+            .contiguous()
+            .view(batch_size, num_slots, self.dim)
+        )
+        fused = self.out_proj(combined_attn)

-        # Residual connection and layer norm
        fused = self.ln(fused + slots_S)
        return fused
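The new forward runs two `F.scaled_dot_product_attention` calls with the same slot queries and sums the two results; padding on each key side is handled by an additive float mask that puts -1e9 on padded positions, which softmax turns into near-zero weights. A self-contained sketch of that pattern under assumed toy shapes (all tensor sizes here are invented for illustration):

```python
import torch
import torch.nn.functional as F

batch, n_heads, num_slots, head_dim = 2, 2, 3, 4
ctx_len, pinyin_len = 6, 5

Q = torch.randn(batch, n_heads, num_slots, head_dim)
K_text = torch.randn(batch, n_heads, ctx_len, head_dim)
V_text = torch.randn(batch, n_heads, ctx_len, head_dim)
K_py = torch.randn(batch, n_heads, pinyin_len, head_dim)
V_py = torch.randn(batch, n_heads, pinyin_len, head_dim)

# padding mask: True = padding position to suppress
pad = torch.zeros(batch, ctx_len, dtype=torch.bool)
pad[:, -2:] = True  # pretend the last two context positions are padding
# broadcast to [batch, 1, 1, ctx_len] and make it an additive float mask
text_mask = pad[:, None, None, :].float().masked_fill(pad[:, None, None, :], -1e9)

text_attn = F.scaled_dot_product_attention(Q, K_text, V_text, attn_mask=text_mask)
py_attn = F.scaled_dot_product_attention(Q, K_py, V_py)  # pinyin side unmasked here
combined = text_attn + py_attn  # dual-path fusion by elementwise sum
```

Summing the two attention outputs keeps the fused tensor at `[batch, n_heads, num_slots, head_dim]`, so the downstream reshape and output projection are identical to the single-source version.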
View File
@@ -113,31 +113,24 @@ class InputMethodEngine(nn.Module):
        """
        batch_size = input_ids.size(0)

-        # 处理 history_slot_ids确保为 [batch_size, num_slots]
-        # 使用 view 替代 if 判断,避免 torch.compile 图断开
        history_slot_ids = history_slot_ids.view(-1, self.num_slots)

-        # 1. 上下文编码 -> H [batch, seq_len, dim]
-        # 注意ContextEncoder.forward 接受 text_ids, pinyin_ids, mask
-        H = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
+        H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)

-        # 2. 槽位记忆编码 -> S [batch, num_slots, dim]
-        S = self.slot_memory(history_slot_ids)  # history_slot_ids: [batch, num_slots]
+        S = self.slot_memory(history_slot_ids)

-        # 3. 交叉注意力融合 (使用 CrossAttentionFusion)
-        fused = self.cross_attn(S, H, context_mask=attention_mask)
+        context_mask = attention_mask == 0
+        pinyin_mask = pinyin_ids == 0
+        fused = self.cross_attn(
+            S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask
+        )

-        # 4. MoE 处理 -> [batch, num_slots, dim]
        moe_out = self.moe(fused)

-        # 5. 槽位注意力池化
        batch_size = input_ids.size(0)
-        # 计算注意力分数 [batch, num_slots, 1] -> [batch, num_slots]
        slot_scores = self.slot_attention(moe_out).squeeze(-1)
-        # 应用softmax获取注意力权重
-        slot_weights = torch.softmax(slot_scores, dim=1)  # [batch, num_slots]
-        # 加权求和得到池化表示
-        pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)  # [batch, dim]
+        slot_weights = torch.softmax(slot_scores, dim=1)
+        pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)

-        logits = self.classifier(pooled)  # [batch, vocab_size]
+        logits = self.classifier(pooled)
        return logits
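The engine collapses the per-slot MoE output `[batch, num_slots, dim]` into a single `[batch, dim]` vector with learned attention pooling: score each slot with a linear layer, softmax over slots, then take the weighted sum. A minimal sketch with invented sizes; `scorer` stands in for the model's `self.slot_attention`:

```python
import torch
import torch.nn as nn

batch, num_slots, dim = 2, 8, 16
moe_out = torch.randn(batch, num_slots, dim)

# one scalar score per slot (plays the role of self.slot_attention)
scorer = nn.Linear(dim, 1)

slot_scores = scorer(moe_out).squeeze(-1)           # [batch, num_slots]
slot_weights = torch.softmax(slot_scores, dim=1)    # each row sums to 1
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)  # [batch, dim]
```

Unlike mean pooling, the softmax lets the model down-weight empty or irrelevant history slots before the classifier sees the pooled vector.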
View File
@@ -1244,14 +1244,14 @@ def train(
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=100000,
+        shuffle_buffer_size=2000000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    eval_dataloader = create_dataloader(
        dataset=eval_dataset,
        batch_size=batch_size,
-        num_workers=1,  # 评估使用较少的worker
+        num_workers=2,  # 评估使用较少的worker
        pin_memory=torch.cuda.is_available(),
        max_iter_length=batch_size * 64,
    )
@@ -1356,615 +1356,5 @@ def export(
    console.print("[yellow]导出功能待实现[/yellow]")
@app.command()
def expand_and_train(
# 数据参数
train_data_path: str = typer.Option(
..., "--train-data-path", "-t", help="训练数据集路径"
),
eval_data_path: str = typer.Option(
..., "--eval-data-path", "-e", help="评估数据集路径"
),
output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
# 模型参数
base_model_path: str = typer.Option(
..., "--base-model-path", help="预训练基础模型检查点路径"
),
new_model_spec: str = typer.Option(
...,
"--new-model-spec",
"-m",
help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类",
),
# 数据大小
max_iter_length: int = typer.Option(
1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
),
# 训练参数
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
learning_rate: float = typer.Option(2e-4, "--learning-rate", "-lr", help="学习率"),
min_learning_rate: float = typer.Option(
1e-9, "--min-learning-rate", help="最小学习率"
),
weight_decay: float = typer.Option(0.05, "--weight-decay", help="权重衰减"),
warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
label_smoothing: float = typer.Option(
0.1, "--label-smoothing", help="标签平滑参数"
),
grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"),
# 其他参数
mixed_precision: bool = typer.Option(
True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
),
num_workers: int = typer.Option(
2, "--num-workers", help="数据加载worker数量流式数据集建议为2"
),
use_tensorboard: bool = typer.Option(
True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
),
resume_from: Optional[str] = typer.Option(
None, "--resume-from", help="从检查点恢复训练"
),
reset_training_state: bool = typer.Option(
False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练"
),
auto_resume: bool = typer.Option(
True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练"
),
seed: int = typer.Option(42, "--seed", help="随机种子"),
compile: bool = typer.Option(
False,
"--compile/--no-compile",
help="是否开启 torch.compile 优化(需 PyTorch 2.0+",
),
):
torch.multiprocessing.set_sharing_strategy("file_system")
# 启用 TensorFloat32 加速矩阵乘法 (解决 UserWarning 并提升性能)
if torch.cuda.is_available():
torch.set_float32_matmul_precision("high")
# 设置随机种子
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# 硬编码模型参数
vocab_size = 10019
pinyin_vocab_size = 30 # 根据 dataset.CHAR_TO_ID 映射
dim = 512
num_slots = 8
n_layers = 4
n_heads = 4
num_experts = 10
max_seq_len = 128
use_pinyin = True # 始终使用拼音
console = Console()
# 打印配置信息
console.print(
Panel.fit(
"[bold cyan]模型扩容第一阶段训练配置[/bold cyan]", border_style="cyan"
)
)
config_table = Table(show_header=True, header_style="bold magenta")
config_table.add_column("Category", style="cyan")
config_table.add_column("Parameter", style="green")
config_table.add_column("Value", style="yellow")
# 添加配置信息
config_table.add_row("数据", "训练数据路径", train_data_path)
config_table.add_row("数据", "评估数据路径", eval_data_path)
config_table.add_row("数据", "输出目录", output_dir)
config_table.add_row("数据", "批次大小", str(batch_size))
config_table.add_row("数据", "Worker数量", str(num_workers))
config_table.add_row("模型", "基础模型路径", base_model_path)
config_table.add_row("模型", "新模型规格", new_model_spec)
config_table.add_row("模型", "词汇表大小", str(vocab_size))
config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size))
config_table.add_row("模型", "模型维度", str(dim))
config_table.add_row("模型", "槽位数量", str(num_slots))
config_table.add_row("模型", "Transformer层数", str(n_layers))
config_table.add_row("模型", "注意力头数", str(n_heads))
config_table.add_row("模型", "MoE专家数", str(num_experts))
config_table.add_row("模型", "使用拼音", str(use_pinyin))
config_table.add_row("模型", "编译优化", str(compile))
config_table.add_row("训练", "训练轮数", str(num_epochs))
config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}")
config_table.add_row("训练", "权重衰减", str(weight_decay))
config_table.add_row("训练", "热身比例", str(warmup_ratio))
config_table.add_row("训练", "标签平滑", str(label_smoothing))
config_table.add_row("训练", "梯度累积", str(grad_accum_steps))
config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
config_table.add_row("训练", "混合精度", str(mixed_precision))
config_table.add_row("其他", "自动恢复", str(auto_resume))
console.print(config_table)
# 创建输出目录
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# 保存配置
config = {
"train_data_path": train_data_path,
"eval_data_path": eval_data_path,
"output_dir": output_dir,
"base_model_path": base_model_path,
"new_model_spec": new_model_spec,
"vocab_size": vocab_size,
"pinyin_vocab_size": pinyin_vocab_size,
"dim": dim,
"num_slots": num_slots,
"n_layers": n_layers,
"n_heads": n_heads,
"num_experts": num_experts,
"max_seq_len": max_seq_len,
"use_pinyin": use_pinyin,
"batch_size": batch_size,
"num_workers": num_workers,
"num_epochs": num_epochs,
"learning_rate": learning_rate,
"min_learning_rate": min_learning_rate,
"weight_decay": weight_decay,
"warmup_ratio": warmup_ratio,
"label_smoothing": label_smoothing,
"grad_accum_steps": grad_accum_steps,
"clip_grad_norm": clip_grad_norm,
"eval_frequency": eval_frequency,
"save_frequency": save_frequency,
"mixed_precision": mixed_precision,
"use_tensorboard": use_tensorboard,
"seed": seed,
"auto_resume": auto_resume,
"max_iter_length": max_iter_length,
"compile": compile,
}
config_file = output_path / "expansion_training_config.json"
with open(config_file, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
logger.info(f"Configuration saved to {config_file}")
# 创建数据加载器
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
# 训练数据集
train_dataset = PinyinInputDataset(
data_path=train_data_path,
max_workers=-1, # 自动选择worker数量
max_iter_length=max_iter_length,
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=100000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
# 训练数据加载器
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
max_iter_length=max_iter_length,
)
# 评估数据集
eval_dataset = PinyinInputDataset(
data_path=eval_data_path,
max_workers=-1,
max_iter_length=batch_size * 64, # 评估集较小
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=100000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=batch_size,
num_workers=1, # 评估使用较少的worker
pin_memory=torch.cuda.is_available(),
max_iter_length=batch_size * 64,
)
# 创建扩容模型
console.print("[bold cyan]正在创建扩容模型...[/bold cyan]")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_kwargs = {
"vocab_size": vocab_size,
"pinyin_vocab_size": pinyin_vocab_size,
"dim": dim,
"num_slots": num_slots,
"n_layers": n_layers,
"n_heads": n_heads,
"num_experts": num_experts,
"max_seq_len": max_seq_len,
"compile": compile,
}
model = load_expanded_model(
base_model_path=base_model_path,
new_model_spec=new_model_spec,
device=device,
**model_kwargs,
)
console.print(
f"[green]✓ 扩容模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
)
# 统计冻结参数比例
total_params = sum(p.numel() for p in model.parameters())
frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
console.print(
f"[green]✓ 冻结参数: {frozen_params:,}/{total_params:,} ({frozen_params / total_params * 100:.1f}%)[/green]"
)
# 创建训练器(使用普通 Trainer只进行第一阶段冻结训练
console.print("[bold cyan]正在创建训练器...[/bold cyan]")
trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
total_steps=int(max_iter_length * num_epochs / batch_size),
output_dir=output_dir,
num_epochs=num_epochs,
learning_rate=learning_rate,
min_learning_rate=min_learning_rate,
weight_decay=weight_decay,
warmup_ratio=warmup_ratio,
label_smoothing=label_smoothing,
grad_accum_steps=grad_accum_steps,
clip_grad_norm=clip_grad_norm,
eval_frequency=eval_frequency,
save_frequency=save_frequency,
mixed_precision=mixed_precision,
use_tensorboard=use_tensorboard,
status_file="training_status.json",
)
console.print("[green]✓ 训练器创建完成[/green]")
# 开始训练
console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]")
console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try:
trainer.train(
resume_from=resume_from,
reset_training_state=reset_training_state,
auto_resume=auto_resume,
)
except KeyboardInterrupt:
console.print("[bold green]训练被终止[/bold green]")
trainer.save_checkpoint("interrupted_model.pt")
# 保存扩容信息供第二阶段使用
expansion_info = {
"stage1_checkpoint_path": str(output_path / "checkpoints" / "best_model.pt"),
"model_spec": new_model_spec,
"model_kwargs": model_kwargs,
"train_data_path": train_data_path,
"eval_data_path": eval_data_path,
"output_dir": output_dir,
"batch_size": batch_size,
"max_iter_length": max_iter_length,
"max_seq_len": max_seq_len,
"num_workers": num_workers,
}
expansion_info_file = output_path / "expansion_info.json"
with open(expansion_info_file, "w", encoding="utf-8") as f:
json.dump(expansion_info, f, indent=2, ensure_ascii=False)
logger.info(f"Expansion info saved to {expansion_info_file}")
console.print("[bold green]✓ 第一阶段训练完成![/bold green]")
console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
console.print(f"模型和日志保存在: {output_dir}")
console.print(f"[bold cyan]扩容信息已保存到: {expansion_info_file}[/bold cyan]")
console.print(
"[yellow]请手动检查模型后,使用 expand-finetune 命令进行第二阶段全量微调[/yellow]"
)
@app.command()
def expand_finetune(
expand_config: str = typer.Option(
...,
"--expand-config",
"-c",
help="新模型类规格,格式:模块名:类名,如 'big_expert:BigExpert'",
),
stage1_info: str = typer.Option(
..., "--stage1-info", "-i", help="第一阶段保存的 expansion_info.json 路径"
),
# 可选覆盖参数
checkpoint: Optional[str] = typer.Option(
None, "--checkpoint", help="第一阶段模型检查点路径(覆盖 JSON 文件中的路径)"
),
output_dir: Optional[str] = typer.Option(
None, "--output-dir", "-o", help="输出目录(覆盖 JSON 文件中的目录)"
),
train_data_path: Optional[str] = typer.Option(
None, "--train-data-path", "-t", help="训练数据路径(覆盖 JSON 文件)"
),
eval_data_path: Optional[str] = typer.Option(
None, "--eval-data-path", "-e", help="评估数据路径(覆盖 JSON 文件)"
),
batch_size: Optional[int] = typer.Option(
None, "--batch-size", "-b", help="批次大小(覆盖 JSON 文件)"
),
num_epochs: Optional[int] = typer.Option(
None, "--num-epochs", help="训练轮数(覆盖 JSON 文件)"
),
learning_rate: Optional[float] = typer.Option(
None, "--learning-rate", "-lr", help="学习率"
),
min_learning_rate: Optional[float] = typer.Option(
None, "--min-learning-rate", help="最小学习率"
),
weight_decay: Optional[float] = typer.Option(
None, "--weight-decay", help="权重衰减"
),
warmup_ratio: Optional[float] = typer.Option(
None, "--warmup-ratio", help="热身步数比例"
),
label_smoothing: Optional[float] = typer.Option(
None, "--label-smoothing", help="标签平滑参数"
),
grad_accum_steps: Optional[int] = typer.Option(
None, "--grad-accum-steps", help="梯度累积步数"
),
clip_grad_norm: Optional[float] = typer.Option(
None, "--clip-grad-norm", help="梯度裁剪范数"
),
eval_frequency: Optional[int] = typer.Option(
None, "--eval-frequency", help="评估频率"
),
save_frequency: Optional[int] = typer.Option(
None, "--save-frequency", help="保存频率"
),
max_iter_length: Optional[int] = typer.Option(
None, "--max-iter-length", help="数据集大小(覆盖 JSON 文件)"
),
max_seq_len: Optional[int] = typer.Option(
None, "--max-seq-len", help="最大序列长度(覆盖 JSON 文件)"
),
num_workers: Optional[int] = typer.Option(
None, "--num-workers", help="数据加载worker数量"
),
mixed_precision: bool = typer.Option(
True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
),
use_tensorboard: bool = typer.Option(
True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
),
resume_from: Optional[str] = typer.Option(
None, "--resume-from", help="从检查点恢复训练"
),
reset_training_state: bool = typer.Option(
False, "--reset-training-state", help="重置训练状态"
),
auto_resume: bool = typer.Option(
True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练"
),
seed: int = typer.Option(42, "--seed", help="随机种子"),
compile: Optional[bool] = typer.Option(
None, "--compile/--no-compile", help="是否开启 torch.compile 优化"
),
):
"""
模型扩容第二阶段训练读取第一阶段的 expansion_info.json加载扩容模型进行全量微调
命令行参数优先级高于 JSON 文件中的配置
"""
torch.multiprocessing.set_sharing_strategy("file_system")
if torch.cuda.is_available():
torch.set_float32_matmul_precision("high")
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
console = Console()
# 加载第一阶段信息
stage1_info_path = Path(stage1_info)
if not stage1_info_path.exists():
console.print(
f"[bold red]错误: 找不到第一阶段信息文件 {stage1_info}[/bold red]"
)
raise typer.Exit(1)
with open(stage1_info_path, "r", encoding="utf-8") as f:
info = json.load(f)
# 命令行参数优先级高于 JSON 文件
final_checkpoint = checkpoint or info["stage1_checkpoint_path"]
final_output_dir = output_dir or info["output_dir"]
final_train_data_path = train_data_path or info["train_data_path"]
final_eval_data_path = eval_data_path or info["eval_data_path"]
final_batch_size = batch_size if batch_size is not None else info["batch_size"]
final_num_epochs = (
num_epochs if num_epochs is not None else info.get("num_epochs", 10)
)
final_max_iter_length = (
max_iter_length if max_iter_length is not None else info["max_iter_length"]
)
final_max_seq_len = max_seq_len if max_seq_len is not None else info["max_seq_len"]
final_num_workers = (
num_workers if num_workers is not None else info.get("num_workers", 2)
)
# 训练参数(有默认值,不覆盖则使用默认)
final_learning_rate = learning_rate if learning_rate is not None else 1e-4
final_min_learning_rate = (
min_learning_rate if min_learning_rate is not None else 1e-9
)
final_weight_decay = weight_decay if weight_decay is not None else 0.1
final_warmup_ratio = warmup_ratio if warmup_ratio is not None else 0.1
final_label_smoothing = label_smoothing if label_smoothing is not None else 0.15
final_grad_accum_steps = grad_accum_steps if grad_accum_steps is not None else 1
final_clip_grad_norm = clip_grad_norm if clip_grad_norm is not None else 1.0
final_eval_frequency = eval_frequency if eval_frequency is not None else 500
final_save_frequency = save_frequency if save_frequency is not None else 1000
# 模型参数从 JSON 获取
model_kwargs = info["model_kwargs"]
if compile is not None:
model_kwargs["compile"] = compile
console.print(
Panel.fit(
"[bold cyan]模型扩容第二阶段训练配置[/bold cyan]", border_style="cyan"
)
)
config_table = Table(show_header=True, header_style="bold magenta")
config_table.add_column("Category", style="cyan")
config_table.add_column("Parameter", style="green")
config_table.add_column("Value", style="yellow")
config_table.add_row("数据", "第一阶段信息文件", str(stage1_info_path))
config_table.add_row("数据", "训练数据路径", final_train_data_path)
config_table.add_row("数据", "评估数据路径", final_eval_data_path)
config_table.add_row("数据", "输出目录", final_output_dir)
config_table.add_row("数据", "批次大小", str(final_batch_size))
config_table.add_row("数据", "Worker数量", str(final_num_workers))
config_table.add_row("模型", "新模型规格", expand_config)
config_table.add_row("模型", "检查点路径", final_checkpoint)
for k, v in model_kwargs.items():
config_table.add_row("模型", k, str(v))
config_table.add_row("训练", "训练轮数", str(final_num_epochs))
config_table.add_row("训练", "学习率", f"{final_learning_rate:.2e}")
config_table.add_row("训练", "最小学习率", f"{final_min_learning_rate:.2e}")
config_table.add_row("训练", "权重衰减", str(final_weight_decay))
config_table.add_row("训练", "热身比例", str(final_warmup_ratio))
config_table.add_row("训练", "标签平滑", str(final_label_smoothing))
config_table.add_row("训练", "梯度累积", str(final_grad_accum_steps))
config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm))
config_table.add_row("训练", "混合精度", str(mixed_precision))
config_table.add_row("其他", "自动恢复", str(auto_resume))
console.print(config_table)
output_path = Path(final_output_dir)
output_path.mkdir(parents=True, exist_ok=True)
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
train_dataset = PinyinInputDataset(
data_path=final_train_data_path,
max_workers=-1,
max_iter_length=final_max_iter_length,
max_seq_length=final_max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=100000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=final_batch_size,
num_workers=final_num_workers,
pin_memory=torch.cuda.is_available(),
max_iter_length=final_max_iter_length,
)
eval_dataset = PinyinInputDataset(
data_path=final_eval_data_path,
max_workers=-1,
max_iter_length=final_batch_size * 64,
max_seq_length=final_max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=100000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=final_batch_size,
num_workers=1,
pin_memory=torch.cuda.is_available(),
max_iter_length=final_batch_size * 64,
)
console.print("[bold cyan]正在加载扩容模型(全量微调模式)...[/bold cyan]")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_expanded_model(
base_model_path=final_checkpoint,
new_model_spec=expand_config,
device=device,
**model_kwargs,
)
# 全量微调:解冻所有参数
for param in model.parameters():
param.requires_grad = True
console.print(
f"[green]✓ 模型加载完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
)
console.print("[green]✓ 所有参数已解冻,进入全量微调模式[/green]")
console.print("[bold cyan]正在创建训练器...[/bold cyan]")
trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
total_steps=int(final_max_iter_length * final_num_epochs / final_batch_size),
output_dir=final_output_dir,
num_epochs=final_num_epochs,
learning_rate=final_learning_rate,
min_learning_rate=final_min_learning_rate,
weight_decay=final_weight_decay,
warmup_ratio=final_warmup_ratio,
label_smoothing=final_label_smoothing,
grad_accum_steps=final_grad_accum_steps,
clip_grad_norm=final_clip_grad_norm,
eval_frequency=final_eval_frequency,
save_frequency=final_save_frequency,
mixed_precision=mixed_precision,
use_tensorboard=use_tensorboard,
status_file="training_status_finetune.json",
)
console.print("[green]✓ 训练器创建完成[/green]")
console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]")
console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try:
trainer.train(
resume_from=resume_from,
reset_training_state=reset_training_state,
auto_resume=auto_resume,
)
except KeyboardInterrupt:
console.print("[bold green]训练被终止[/bold green]")
trainer.save_checkpoint("interrupted_model.pt")
console.print("[bold green]✓ 第二阶段全量微调完成![/bold green]")
console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
console.print(f"模型和日志保存在: {final_output_dir}")
if __name__ == "__main__":
app()
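The stage-2 command above resolves every option left unset on the CLI from the stage-1 info file. A minimal sketch of an `expansion_info.json` carrying exactly the keys the command reads (`stage1_checkpoint_path`, `model_kwargs`, etc. come from the `info[...]` lookups above; every concrete value here is a hypothetical placeholder):

```python
import json

# All values are illustrative placeholders; only the key names are taken
# from the info[...] lookups in the stage-2 command.
stage1_info = {
    "stage1_checkpoint_path": "out/stage1/best_model.pt",
    "output_dir": "out/stage2",
    "train_data_path": "data/train",
    "eval_data_path": "data/eval",
    "batch_size": 256,
    "num_epochs": 10,        # optional; the command falls back to 10
    "max_iter_length": 128 * 128,
    "max_seq_len": 128,
    "num_workers": 2,        # optional; the command falls back to 2
    "model_kwargs": {"compile": False},
}

with open("expansion_info.json", "w", encoding="utf-8") as f:
    json.dump(stage1_info, f, ensure_ascii=False, indent=2)
```

Any key passed on the command line (e.g. `--batch-size`) overrides the JSON value, per the `x if x is not None else info[...]` pattern above.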

View File

@@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
-part1 = "招财猫背部或底部的太阳能板会持续将环境光(无论是阳光还是室内灯光)转化为"
-part2 = "weiruo"
+part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常"
+part2 = "piaoliang"
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
@@ -56,7 +56,7 @@ if len_py < 24:
else:
pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
-masked_labels = [649, 925, 0, 0, 0, 0, 0, 0]
+masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0]
part3 = ""
part4 = "可行|特别|伤害"
@@ -83,7 +83,7 @@ sample = {
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
-checkpoint = torch.load("/home/songsenand/下载/20260411acc37final-model.pt", map_location="cpu")
+checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
input_ids = sample["input_ids"]
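The diff above pads or clips the pinyin id sequence to a fixed length of 24. A standalone sketch of that step; the padding branch itself is elided between the hunks, so the pad id 0 is an assumption, chosen to match the `CHAR_TO_ID.get(c, 0)` fallback:

```python
MAX_PY_LEN = 24
PAD_ID = 0  # assumed pad id, consistent with the CHAR_TO_ID.get(c, 0) default


def pad_or_truncate(ids, max_len=MAX_PY_LEN, pad_id=PAD_ID):
    """Pad short id sequences with pad_id; clip long ones to max_len."""
    if len(ids) < max_len:
        return ids + [pad_id] * (max_len - len(ids))
    return ids[:max_len]
```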

View File

@@ -1,31 +0,0 @@
import sys
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.dataset import PinyinInputDataset
from model.trainer import collate_fn, worker_init_fn
max_iter_length = 128 * 128
batch_size = 1024
if sys.platform == "win32":
dataset_path = "data"
else:
dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"
dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=2,
pin_memory=torch.cuda.is_available(),
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=64, # 每个worker预取64个batch适合大内存场景
persistent_workers=True, # 保持worker进程存活避免重建开销
)
dataloader = list([i for i in dataloader])
for i, line in tqdm(enumerate(dataloader), total=max_iter_length / batch_size):
print((line["labels"] == 0).sum())

transfer_weights.py (232 lines) Normal file
View File

@@ -0,0 +1,232 @@
#!/usr/bin/env python3
"""
参数迁移脚本:将旧模型权重迁移到新架构。
使用方法:
python transfer_weights.py --old-checkpoint /path/to/best_model.pt --output ./migrated_model.pt
功能:
1. 加载旧模型 checkpoint
2. 创建新模型架构
3. 直接迁移匹配的参数
4. 拆分迁移 k_proj/v_proj 到 text/pinyin 双分支
5. 保存迁移后的 checkpoint
6. 打印详细迁移报告
"""
import argparse
import sys
from pathlib import Path
import torch
sys.path.append(str(Path(__file__).resolve().parent))  # resolve the src package from the repo root
from src.model.model import InputMethodEngine
def transfer_weights(
old_checkpoint_path: str, new_model: InputMethodEngine, device: torch.device
):
"""
迁移旧模型权重到新架构
Args:
old_checkpoint_path: checkpoint 路径
new_model: 新模型实例
device: 设备
Returns:
(directly_transferred, split_transferred, new_params, skipped) 迁移报告
"""
old_checkpoint = torch.load(old_checkpoint_path, map_location=device)
if "model_state_dict" in old_checkpoint:
old_state_dict = old_checkpoint["model_state_dict"]
else:
old_state_dict = old_checkpoint
new_state_dict = new_model.state_dict()
directly_transferred = []
split_transferred = []
new_params = []
skipped = []
for key in new_state_dict.keys():
if key in old_state_dict:
if new_state_dict[key].shape == old_state_dict[key].shape:
new_state_dict[key] = old_state_dict[key]
directly_transferred.append(key)
else:
skipped.append(
(
key,
"shape mismatch",
f"old={old_state_dict[key].shape}, new={new_state_dict[key].shape}",
)
)
elif key in (
"cross_attn.k_text_proj.weight",
"cross_attn.k_pinyin_proj.weight",
) and "cross_attn.k_proj.weight" in old_state_dict:
new_state_dict[key] = old_state_dict["cross_attn.k_proj.weight"].clone()
split_transferred.append((key, "cross_attn.k_proj.weight"))
elif key in (
"cross_attn.v_text_proj.weight",
"cross_attn.v_pinyin_proj.weight",
) and "cross_attn.v_proj.weight" in old_state_dict:
new_state_dict[key] = old_state_dict["cross_attn.v_proj.weight"].clone()
split_transferred.append((key, "cross_attn.v_proj.weight"))
else:
# fall back to "new parameter" when no source weight exists,
# instead of silently leaving the key unrecorded
new_params.append(key)
new_model.load_state_dict(new_state_dict)
return directly_transferred, split_transferred, new_params, skipped
def print_report(
directly_transferred, split_transferred, new_params, skipped, output_path
):
"""打印迁移报告"""
# skipped keys belong to the new model too, so they count in the denominator
total_params = (
len(directly_transferred) + len(split_transferred) + len(new_params) + len(skipped)
)
coverage = (
(len(directly_transferred) + len(split_transferred)) / total_params * 100
if total_params > 0
else 100
)
print("\n" + "=" * 60)
print("📊 参数迁移报告")
print("=" * 60)
print(f"\n✅ 直接迁移 ({len(directly_transferred)} 层):")
categories = {}
for key in directly_transferred:
category = key.split(".")[0]
if category not in categories:
categories[category] = []
categories[category].append(key)
for cat, keys in sorted(categories.items()):
print(f" - {cat}.* ({len(keys)} 个参数)")
if split_transferred:
print(f"\n✅ 拆分迁移 ({len(split_transferred)} 层):")
for new_key, old_key in split_transferred:
print(f" - {new_key}{old_key}")
if new_params:
print(f"\n⚠️ 新增参数 ({len(new_params)} 层):")
for key in new_params[:10]:
print(f" - {key}")
if len(new_params) > 10:
print(f" ... 等共 {len(new_params)}")
if skipped:
print(f"\n❌ 跳过 ({len(skipped)} 层):")
for key, reason, detail in skipped[:5]:
print(f" - {key}: {reason} ({detail})")
if len(skipped) > 5:
print(f" ... 等共 {len(skipped)}")
print("\n" + "-" * 60)
print(
f"迁移覆盖率: {coverage:.1f}% ({len(directly_transferred) + len(split_transferred)}/{total_params})"
)
print("=" * 60)
print(f"\n💾 已保存到: {output_path}")
def main():
parser = argparse.ArgumentParser(description="迁移旧模型权重到新架构")
parser.add_argument(
"--old-checkpoint",
"-c",
type=str,
required=True,
help="旧模型 checkpoint 路径",
)
parser.add_argument(
"--output",
"-o",
type=str,
default="./migrated_model.pt",
help="迁移后的 checkpoint 输出路径",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="设备 (cuda/cpu)",
)
args = parser.parse_args()
device = torch.device(args.device)
print(f"📦 加载旧模型: {args.old_checkpoint}")
print(f"🔧 使用设备: {device}")
old_path = Path(args.old_checkpoint)
if not old_path.exists():
print(f"❌ 文件不存在: {args.old_checkpoint}")
sys.exit(1)
print("\n🏗️ 创建新模型架构...")
new_model = InputMethodEngine(
vocab_size=10019,
pinyin_vocab_size=30,
dim=512,
num_slots=8,
n_layers=4,
n_heads=4,
num_experts=10,
max_seq_len=128,
compile=False,
)
new_model.to(device)
print(f" 新模型参数量: {sum(p.numel() for p in new_model.parameters()):,}")
print("\n🔄 开始迁移参数...")
directly_transferred, split_transferred, new_params, skipped = transfer_weights(
args.old_checkpoint, new_model, device
)
print("\n💾 保存迁移后的 checkpoint...")
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
checkpoint = {
"step": 0,
"epoch": 0,
"model_state_dict": new_model.state_dict(),
"best_eval_loss": float("inf"),
"config": {
"vocab_size": 10019,
"pinyin_vocab_size": 30,
"dim": 512,
"num_slots": 8,
"n_layers": 4,
"n_heads": 4,
"num_experts": 10,
"max_seq_len": 128,
},
"migration_source": args.old_checkpoint,
}
torch.save(checkpoint, output_path)
print_report(
directly_transferred, split_transferred, new_params, skipped, output_path
)
print("\n✅ 迁移完成!使用方法:")
print(
f" python -m model.trainer train --resume-from {args.output} --reset-training-state"
)
if __name__ == "__main__":
main()
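The transfer rules in `transfer_weights` reduce to three cases: copy on an exact key-and-shape match, duplicate the old fused `k_proj`/`v_proj` weight into both new text/pinyin branches, and otherwise treat the key as a new parameter. A torch-free sketch of that classification over `{key: shape}` dicts — the `cross_attn.*` key names mirror the script above, everything else is illustrative:

```python
def classify_keys(old_shapes, new_shapes):
    """Mimic the matching logic of transfer_weights on {key: shape} dicts."""
    direct, split, new, skipped = [], [], [], []
    # new fused-branch keys -> the old key they are duplicated from
    split_map = {
        "cross_attn.k_text_proj.weight": "cross_attn.k_proj.weight",
        "cross_attn.k_pinyin_proj.weight": "cross_attn.k_proj.weight",
        "cross_attn.v_text_proj.weight": "cross_attn.v_proj.weight",
        "cross_attn.v_pinyin_proj.weight": "cross_attn.v_proj.weight",
    }
    for key, shape in new_shapes.items():
        if key in old_shapes:
            if shape == old_shapes[key]:
                direct.append(key)       # exact key + shape match
            else:
                skipped.append(key)      # shape mismatch
        elif key in split_map and split_map[key] in old_shapes:
            split.append(key)            # duplicated from the fused projection
        else:
            new.append(key)              # no source weight: new parameter
    return direct, split, new, skipped


old = {"embed.weight": (100, 512), "cross_attn.k_proj.weight": (512, 512)}
new = {
    "embed.weight": (100, 512),
    "cross_attn.k_text_proj.weight": (512, 512),
    "cross_attn.k_pinyin_proj.weight": (512, 512),
    "moe.gate.weight": (20, 512),
}
direct, split, new_params, skipped = classify_keys(old, new)
```

Here `embed.weight` is copied directly, both `k_*_proj` branches are filled from the old `k_proj`, and the gate is reported as new.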