# Input Method Prediction Model: Architecture Design

## 1. Overview

This project builds a lightweight, high-accuracy Chinese input-method prediction model. The core design idea is to deeply fuse the current context (the text before and after the cursor plus the pinyin) with the user's historical input habits through **structured slot memory** and a **cross-attention** mechanism. To retain strong expressive power under a limited compute budget, the model adds a **Mixture-of-Experts (MoE)** module.

## 2. Core Architecture Pipeline

Data flows along the following path:

`Input Encoding` → `Transformer Context Encoding` → `Slot Memory Embedding` → `Cross-Attention Fusion` → `Gating + Mixture of Experts (MoE)` → `Classification` → `Beam Search Decoding`

### 2.1 Input Layer Design

The model receives three kinds of input, processed separately to keep their semantics distinct:

1. **Current text context**: the text before the cursor (Prefix) and after the cursor (Suffix).
2. **Pinyin sequence**: the pinyin corresponding to the current text, fused into the text encoding as an auxiliary feature.
3. **Historical slot sequence**: the most recent N input words, serving as structured memory.

### 2.2 Module Details

#### A. Transformer Encoder (Context Encoder)

Extracts a deep semantic representation of the current context.

* **Input processing**: Prefix, Suffix, and pinyin are mapped through an Embedding layer. Pinyin is fused either by **feature addition** or as **separate tokens**, avoiding the complexity of a two-stream architecture.
* **Backbone**: a standard Transformer encoder.
* **Hidden dimension**: 512 [1]
* **Transformer layers**: 4 (lightweight design, trained from scratch) [1]
* **Attention heads**: 4 [1]
* **Output**: context representation $H$ with shape `[batch, L, 512]` [1].

#### B. Slot Memory

Turns the unstructured input history into structured memory vectors.

* **Embedding**: historical words are mapped through a dedicated `Slot Embedding` lookup table.
* **Positional encoding**: a learnable `Positional Embedding` preserves the temporal order of past inputs.
* **Output**: slot sequence $S$ with shape `[batch, Num_Slots, 512]`.

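The slot embedding plus learnable positional embedding described above can be sketched in PyTorch as follows (shapes and names are illustrative assumptions, not the project's actual code):

```python
import torch
import torch.nn as nn

# Word lookup for the last Num_Slots inputs, plus a learnable positional
# embedding over the history positions (illustrative sketch).
vocab_size, num_slots, dim = 10019, 8, 512

slot_embedding = nn.Embedding(vocab_size, dim)
pos_embedding = nn.Parameter(torch.zeros(1, num_slots, dim))

history_ids = torch.randint(0, vocab_size, (2, num_slots))  # last 8 inputs
S = slot_embedding(history_ids) + pos_embedding             # [batch, 8, 512]
print(S.shape)  # torch.Size([2, 8, 512])
```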
#### C. Cross-Attention Fusion

This is the model's core innovation: it dynamically relates "historical memory" to the "current context".

* **Query (Q)**: the current step's slot sequence $S$ (after positional encoding).
* **Key/Value (K/V)**: the context representation $H$ output by the Transformer encoder [1].
* **Mechanism**: the history slots actively attend to the current textual context, capturing relationships such as "given the context '班级第一名' (top of the class), the candidate '王次香' is more relevant than its homophone '王慈祥'".
* **Output**: fused feature sequence with shape `[batch, Num_Slots, 512]`.

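As a minimal sketch of this fusion step, with `nn.MultiheadAttention` standing in for the project's own fusion layer (shapes follow the text above):

```python
import torch
import torch.nn as nn

# Q comes from the slot memory S, K/V from the encoder output H.
batch, num_slots, seq_len, dim = 2, 8, 128, 512

cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True)

S = torch.randn(batch, num_slots, dim)  # slot memory, after positional embedding
H = torch.randn(batch, seq_len, dim)    # context representation from the encoder

fused, attn_weights = cross_attn(query=S, key=H, value=H)
print(fused.shape)         # torch.Size([2, 8, 512])
print(attn_weights.shape)  # torch.Size([2, 8, 128])
```

Each of the 8 slots attends over all 128 context positions, so the attention map is `[batch, Num_Slots, L]`.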
#### D. Gating + Mixture of Experts (MoE)

Ablation tests showed that removing the MoE causes a significant drop in model quality, so this module is essential for capturing complex distributions.

* **Number of experts**: 20 [1].
* **Gating**: a gate dynamically activates a subset of experts per input (sparse activation), increasing model capacity while keeping compute bounded.
* **Output**: the feature vector enhanced by the expert networks.

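A compact sketch of such a sparsely gated layer (the class name, expert width, and top-k value here are illustrative assumptions, not the project's exact implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoE(nn.Module):
    """Sketch of a sparsely gated Mixture-of-Experts layer."""
    def __init__(self, dim=512, num_experts=20, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim * 2), nn.GELU(), nn.Linear(dim * 2, dim))
             for _ in range(num_experts)]
        )
        self.gate = nn.Linear(dim, num_experts)
        self.top_k = top_k

    def forward(self, x):                     # x: [batch, slots, dim]
        logits = self.gate(x)                 # [batch, slots, num_experts]
        weights, idx = logits.topk(self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)  # renormalise over the chosen experts
        # Compute all experts, then select the top-k per position
        # (the "compute all + Top-K" scheme mentioned in section 6.8).
        all_out = torch.stack([e(x) for e in self.experts], dim=-2)  # [b, s, E, dim]
        idx_exp = idx.unsqueeze(-1).expand(*idx.shape, x.size(-1))
        chosen = all_out.gather(-2, idx_exp)  # [b, s, k, dim]
        return (weights.unsqueeze(-1) * chosen).sum(dim=-2)

moe = TopKMoE()
out = moe(torch.randn(2, 8, 512))
print(out.shape)  # torch.Size([2, 8, 512])
```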
#### E. Classification Head and Decoding

* **Classification**: the MoE output is projected by a fully connected layer onto the vocabulary, yielding a probability distribution over the next character/word.
* **Decoding**: inference uses **beam search** with a beam width of 5 [1].

## 3. Key Hyperparameters

To balance quality and efficiency, the following hyperparameters are recommended [1]:

| Parameter | Recommended value | Notes |
| :--- | :--- | :--- |
| **Sequence length (L)** | 128 | Context window size [1] |
| **Hidden dimension** | 512 | Embedding and internal Transformer dimension [1] |
| **Transformer layers** | 4 | Lightweight backbone, lower latency [1] |
| **Attention heads** | 4 | Efficient configuration for a 512-dim model [1] |
| **Number of experts** | 20 | Total experts in the MoE layer; critical for quality [1] |
| **Beam width** | 5 | Balances speed and accuracy at inference time [1] |
| **Learning rate** | 1e-4 ~ 5e-4 | Use together with a warmup schedule [1] |

## 4. Training Strategy

The model is trained with standard **sequence-to-sequence (Seq2Seq) supervised learning**, predicting the target slot sequence step by step.

### 4.1 Data Construction and Labels

* **Input triple**: each training sample consists of `(context, pinyin, target slot sequence)` [1].
  * **Context**: the text fragments before and after the cursor.
  * **Pinyin**: the pinyin sequence of the characters being typed.
  * **Target slot sequence**: the character-ID sequence the user actually entered, used as the supervision signal [1].
* **Labels**: at every slot step, the model must predict that step's ground-truth character ID [1].

### 4.2 Loss and Optimization

* **Loss function**: **CrossEntropyLoss** between each step's prediction and its ground-truth label [1].
* **Masking**: the loss is computed only over non-padding positions; invalid time steps are ignored [1].
* **Optimizer**: **AdamW** [1].

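The masked loss can be sketched with `ignore_index` (the pad ID of 0 is an assumption for illustration; shapes follow the configuration above):

```python
import torch
import torch.nn.functional as F

# Logits over the vocabulary at each slot step; padding steps carry the
# (assumed) pad ID 0 and are excluded from the loss via ignore_index.
vocab_size, pad_id = 10019, 0
logits = torch.randn(2, 8, vocab_size)  # [batch, num_slots, vocab]
targets = torch.tensor([[5, 9, 2, 0, 0, 0, 0, 0],
                        [7, 3, 0, 0, 0, 0, 0, 0]])

loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),  # flatten the slot steps
    targets.reshape(-1),
    ignore_index=pad_id,             # mask out padding positions
    label_smoothing=0.15,            # matches the recommended config
)
print(torch.isfinite(loss).item())  # True
```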
### 4.3 Training Loop Details

1. **Forward pass**:
   * The model encodes the context and pinyin with the Transformer to obtain the context representation.
   * The historical slot memory is fused in via cross-attention and the MoE module.
   * The classification head outputs a probability distribution over all candidate characters for the current step.
2. **Teacher forcing**:
   * During training, the **ground-truth output of the previous slot is forced** as the conditioning input for the next step, so the model always predicts from a "correct history" and converges quickly.
3. **Backward pass**:
   * Gradients are computed from the CrossEntropyLoss [1] and the weights are updated with AdamW [1].

### 4.4 Training vs. Inference

* **Training**: the ground-truth labels feed the slots, so the model learns the optimal conditional probability distribution.
* **Inference**: since the ground truth is unavailable, the model uses **beam search** [1].
  * **Beam width**: 5 by default [1].
  * **Candidate bookkeeping**: each candidate path maintains its own slot history and cumulative probability [1].
  * **Termination**: decoding stops once all slots are filled (e.g. 8×3 = 24 steps) or the top prediction of every beam is the end-of-sequence token [1].

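The decoding loop above can be sketched generically; the `step_fn`, the EOS token ID, and the toy model here are assumptions for illustration, not the project's decoder:

```python
import math

def beam_search(step_fn, beam_width=5, max_steps=24, eos=0):
    """Generic beam search sketch (beam width 5, as in the config above).

    step_fn(history) returns a list of (token_id, prob) candidates
    for the next slot given the slot history so far."""
    beams = [([], 0.0)]  # (slot history, cumulative log-probability)
    for _ in range(max_steps):
        candidates = []
        for hist, score in beams:
            if hist and hist[-1] == eos:      # finished beam carries over unchanged
                candidates.append((hist, score))
                continue
            for tok, p in step_fn(hist):
                candidates.append((hist + [tok], score + math.log(p)))
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
        if all(h and h[-1] == eos for h, _ in beams):
            break                             # every beam ended with the EOS token
    return beams

# Toy model: first emits token 1 with high probability, then prefers EOS.
def toy_step(hist):
    return [(0, 0.6), (1, 0.3), (2, 0.1)] if hist else [(1, 0.9), (2, 0.1)]

best = beam_search(toy_step)[0]
print(best[0])  # [1, 0]: token 1, then the end-of-sequence token
```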
## 5. Jupyter Lab Training Example

The following is a complete example of training the input-method model in Jupyter Lab with the `trainer.Trainer` class:

```python
# %% [markdown]
# # Input Method Model Training Example
# This notebook shows how to train the input method model with the trainer.Trainer class

# %% [code]
# 1. Import the required libraries
import sys
import os
from pathlib import Path
from datetime import datetime

import torch
from torch.utils.data import DataLoader

# Add the project root to the path (works from different Jupyter Lab working directories)
project_root = Path.cwd()
# If the current directory has no src/ subdirectory, fall back to the parent directory
if not (project_root / "src").exists():
    project_root = project_root.parent
sys.path.insert(0, str(project_root))  # search the project directory first

# Import the project modules
from src.model.model import InputMethodEngine
from src.model.dataset import PinyinInputDataset
from src.model.trainer import Trainer, worker_init_fn, collate_fn

# %% [code]
# 2. Configure the training parameters
config = {
    # Data parameters
    "train_data_path": "/path/to/your/train/dataset",  # replace with your training dataset path
    "eval_data_path": "/path/to/your/eval/dataset",    # replace with your evaluation dataset path
    "output_dir": "./training_output",

    # Model parameters
    "vocab_size": 10019,
    "pinyin_vocab_size": 30,
    "dim": 512,
    "num_slots": 8,
    "n_layers": 4,
    "n_heads": 4,
    "num_experts": 20,
    "max_seq_len": 128,

    # Training parameters
    "batch_size": 64,          # adjust to your GPU memory
    "num_epochs": 10,
    "learning_rate": 3e-4,
    "min_learning_rate": 1e-9,
    "weight_decay": 0.1,
    "warmup_ratio": 0.1,
    "label_smoothing": 0.15,
    "grad_accum_steps": 2,     # gradient accumulation to simulate a larger batch size
    "clip_grad_norm": 1.0,
    "eval_frequency": 500,     # evaluate every 500 steps
    "save_frequency": 2000,    # save a checkpoint every 2000 steps

    # Advanced options
    "mixed_precision": True,
    "use_tensorboard": True,
    "seed": 42,
    "max_iter_length": 1024 * 1024 * 128,  # maximum iteration length
}

# %% [code]
# 3. Set the random seed and device
torch.manual_seed(config["seed"])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(config["seed"])
    device = torch.device("cuda")
    print(f"✅ Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("⚠️ Training on CPU (a GPU is recommended for better performance)")

# %% [code]
# 4. Create the datasets and data loaders
print("📊 Creating datasets and data loaders...")

# Training dataset
train_dataset = PinyinInputDataset(
    data_path=config["train_data_path"],
    max_workers=-1,  # choose the number of workers automatically
    max_iter_length=config["max_iter_length"],
    max_seq_length=config["max_seq_len"],
    text_field="text",
    py_style_weight=(9, 2, 1),
    shuffle_buffer_size=5000,
    length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)

# Training data loader
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    num_workers=min(max(1, (os.cpu_count() or 1) - 1), 8),  # a reasonable number of workers
    pin_memory=torch.cuda.is_available(),
    worker_init_fn=worker_init_fn,
    collate_fn=collate_fn,
    prefetch_factor=32,
    persistent_workers=True,
)

# Evaluation dataset
eval_dataset = PinyinInputDataset(
    data_path=config["eval_data_path"],
    max_workers=-1,
    max_iter_length=1024,  # the evaluation set is small
    max_seq_length=config["max_seq_len"],
    text_field="text",
    py_style_weight=(9, 2, 1),
    shuffle_buffer_size=1000,
    length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)

eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=config["batch_size"],
    num_workers=1,
    pin_memory=torch.cuda.is_available(),
    worker_init_fn=worker_init_fn,
    collate_fn=collate_fn,
    prefetch_factor=32,
    persistent_workers=True,
)

print(f"✅ Data loaders ready")
print(f"   Training batch size: {config['batch_size']}")
print(f"   Estimated training steps: {config['max_iter_length'] // config['batch_size']}")

# %% [code]
# 5. Create the model
print("🧠 Creating the input method model...")

model = InputMethodEngine(
    vocab_size=config["vocab_size"],
    pinyin_vocab_size=config["pinyin_vocab_size"],
    dim=config["dim"],
    num_slots=config["num_slots"],
    n_layers=config["n_layers"],
    n_heads=config["n_heads"],
    num_experts=config["num_experts"],
    max_seq_len=config["max_seq_len"],
)

# Move the model to the device
model.to(device)

# Count the parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✅ Model created")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Architecture: {config['n_layers']}-layer Transformer, dim {config['dim']}, {config['num_experts']} MoE experts")

# %% [code]
# 6. Create the trainer
print("⚙️ Creating the trainer...")

# Compute the total number of training steps
total_steps = int(config["max_iter_length"] / config["batch_size"])

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    total_steps=total_steps,
    output_dir=config["output_dir"],
    num_epochs=config["num_epochs"],
    learning_rate=config["learning_rate"],
    min_learning_rate=config["min_learning_rate"],
    weight_decay=config["weight_decay"],
    warmup_ratio=config["warmup_ratio"],
    label_smoothing=config["label_smoothing"],
    grad_accum_steps=config["grad_accum_steps"],
    clip_grad_norm=config["clip_grad_norm"],
    eval_frequency=config["eval_frequency"],
    save_frequency=config["save_frequency"],
    mixed_precision=config["mixed_precision"],
    use_tensorboard=config["use_tensorboard"],
)

print(f"✅ Trainer ready")
print(f"   Total training steps: {total_steps:,}")
print(f"   Learning rate: {config['learning_rate']:.2e} -> {config['min_learning_rate']:.2e}")
print(f"   Output directory: {config['output_dir']}")

# %% [code]
# 7. Start training
print("🚀 Starting training...")
print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

try:
    # Start training (can resume from a checkpoint)
    trainer.train(resume_from=None)  # set a checkpoint path to resume training

    print("✅ Training finished!")
    print(f"End time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Model and logs saved to: {config['output_dir']}")

except KeyboardInterrupt:
    print("⏹️ Training interrupted by the user")
    print("💾 Saving the current checkpoint...")
    trainer.save_checkpoint("interrupted")
    print(f"Checkpoint saved to: {config['output_dir']}/checkpoint_interrupted.pt")

except Exception as e:
    print(f"❌ Error during training: {e}")
    import traceback
    traceback.print_exc()

# %% [code]
# 8. Monitor training progress (when TensorBoard is enabled)
if config["use_tensorboard"]:
    print("📈 TensorBoard logs are written to:")
    print(f"   {config['output_dir']}/tensorboard")
    print("\nStart TensorBoard to follow the training progress:")
    print("   tensorboard --logdir ./training_output/tensorboard")
    print("Then open http://localhost:6006 in your browser")

# %% [code]
# 9. Load a trained model for inference (example)
def load_trained_model(checkpoint_path):
    """Load a trained model from a checkpoint"""
    print(f"📥 Loading checkpoint: {checkpoint_path}")

    # Build a model with the same configuration as during training
    loaded_model = InputMethodEngine(
        vocab_size=config["vocab_size"],
        pinyin_vocab_size=config["pinyin_vocab_size"],
        dim=config["dim"],
        num_slots=config["num_slots"],
        n_layers=config["n_layers"],
        n_heads=config["n_heads"],
        num_experts=config["num_experts"],
        max_seq_len=config["max_seq_len"],
    )

    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    loaded_model.load_state_dict(checkpoint["model_state_dict"])
    loaded_model.to(device)
    loaded_model.eval()

    print(f"✅ Model loaded, trained for {checkpoint.get('global_step', 'N/A')} steps")
    print(f"   Training loss: {checkpoint.get('train_loss', 'N/A'):.4f}")

    return loaded_model

# Usage example (uncomment to use)
# trained_model = load_trained_model("./training_output/checkpoint_final.pt")
```

### Key Notes

1. **Environment requirements**:
   - Python 3.12+
   - PyTorch 2.10+
   - A GPU is recommended for training
   - Install the project dependencies: `pip install -e .`

2. **Dataset format**:
   - Uses the Hugging Face `datasets` format
   - Must contain a `text` field
   - Must support streaming reads (streaming=True)

3. **Training monitoring**:
   - The console prints training progress and metrics
   - TensorBoard records loss, accuracy, learning rate, etc.
   - Checkpoints are saved periodically

4. **Tunable parameters**:
   - `batch_size`: adjust to your GPU memory
   - `learning_rate`: 1e-4 to 5e-4 is recommended
   - `grad_accum_steps`: simulates a larger batch size
   - `num_epochs`: adjust to your dataset size

5. **Troubleshooting**:
   - Out of GPU memory: reduce `batch_size` or increase `grad_accum_steps`
   - Unstable training: lower `learning_rate` or increase `warmup_ratio`
   - Overfitting: increase `label_smoothing` or use a larger dataset

## 6. Usage Guide

The training functionality of this project is exposed through the command-line tool `train-model`, which supports training, evaluating, and exporting models.

### 6.1 Installation and Setup

#### Using uv (recommended)
This project uses [`uv`](https://github.com/astral-sh/uv) as its Python package manager; it is faster and more reliable than plain pip.

1. **Install uv** (if not installed yet):
```bash
# Linux/macOS
curl -LsSf https://astral.sh/uv/install.sh | sh

# or with pipx
pipx install uv

# Windows (PowerShell)
powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
```

2. **Install the project dependencies**:
```bash
uv pip install -e .
```

#### Using plain pip
If you do not use uv, install with standard pip:

```bash
# Create and activate a virtual environment (recommended)
python -m venv .venv
source .venv/bin/activate  # Linux/macOS
# .venv\Scripts\activate   # Windows

# Install the dependencies
pip install -e .
```

#### Verify the installation
After installing, verify with:
```bash
train-model --help
```

### 6.2 Data Format

Training data should be in the Hugging Face datasets format; both local files and remote dataset repositories are supported. The dataset must contain a `text` field and support streaming reads (streaming=True).

#### Local dataset example
```python
# dataset.py
from datasets import Dataset

data = {
    "text": ["这是第一个样本文本。", "这是第二个样本,用于训练输入法模型。"]
}
dataset = Dataset.from_dict(data)
dataset.save_to_disk("./local_dataset")
```

#### Remote dataset example
Datasets on the Hugging Face Hub or ModelScope are supported:
- `huggingface.co/datasets/username/dataset_name`
- `modelscope.cn/datasets/username/dataset_name`

#### Data format requirements
- **Required field**: `text` (a string containing Chinese text)
- **Streaming**: the dataset must support the `streaming=True` parameter
- **Volume**: at least several million text samples are recommended for good results

#### Data preprocessing
The dataset is automatically processed as follows:
1. Text tokenization and encoding
2. Pinyin conversion and encoding
3. Sliding context windows to generate training samples
4. Frequency rebalancing ("peak shaving, valley filling") to balance high- and low-frequency words

### 6.3 Basic Training Commands

Start training with the `train-model train` command:

```bash
train-model train \
    --train-data-path "path/to/train/dataset" \
    --eval-data-path "path/to/eval/dataset" \
    --output-dir "./output" \
    --batch-size 128 \
    --num-epochs 10 \
    --learning-rate 1e-5
```

#### Resuming from a checkpoint

To resume training from a checkpoint (keeping the existing training state):

```bash
train-model train \
    --train-data-path "path/to/train/dataset" \
    --eval-data-path "path/to/eval/dataset" \
    --resume-from "./output/checkpoints/latest_checkpoint.pt"
```

#### Resetting the training state

To load only the model weights and train from scratch (learning rate, epochs, etc. start over):

```bash
train-model train \
    --train-data-path "path/to/train/dataset" \
    --eval-data-path "path/to/eval/dataset" \
    --resume-from "./output/checkpoints/best_model.pt" \
    --reset-training-state
```

This is useful when you want to:
- Initialize the model from pretrained weights but retrain with a new schedule
- Change the learning-rate policy or training duration
- Do transfer learning on top of an existing model

#### Learning-rate recommendations
Given the architecture and hyperparameters (4-layer Transformer, dim 512), the following learning-rate settings are recommended:
- **Standard range**: 1e-4 ~ 5e-4
- **With warmup**: ramp the learning rate up gradually early in training
- **Cosine annealing**: fine-tune down to a minimum learning rate of 1e-9

### 6.4 Parameter Reference

#### Data parameters
- `--train-data-path`, `-t`: training dataset path (required)
- `--eval-data-path`, `-e`: evaluation dataset path (required)
- `--output-dir`, `-o`: output directory (default: `./output`)
- `--max_iter_length`: maximum iteration length, controlling how much data each training iteration processes (default: 134217728)

#### Model parameters
- `--vocab-size`: vocabulary size (default: 10019)
- `--pinyin-vocab-size`: pinyin vocabulary size (default: 30)
- `--dim`: model dimension (default: 512)
- `--num-slots`: number of history slots (default: 8)
- `--n-layers`: number of Transformer layers (default: 4)
- `--n-heads`: number of attention heads (default: 4)
- `--num-experts`: number of MoE experts (default: 20)
- `--max-seq-len`: maximum sequence length (default: 128)
- `--use-pinyin`: whether to use pinyin features (default: False)

#### Training parameters
- `--batch-size`, `-b`: batch size (default: 128)
- `--num-epochs`: number of training epochs (default: 10)
- `--learning-rate`, `-lr`: learning rate (default: 1e-5)
- `--min-learning-rate`: minimum learning rate (default: 1e-9)
- `--weight-decay`: weight decay (default: 0.1)
- `--warmup-ratio`: warmup step ratio (default: 0.1)
- `--label-smoothing`: label-smoothing factor (default: 0.15)
- `--grad-accum-steps`: gradient accumulation steps (default: 1)
- `--clip-grad-norm`: gradient clipping norm (default: 1.0)
- `--eval-frequency`: evaluation frequency (default: every 500 steps)
- `--save-frequency`: checkpoint frequency (default: every 10000 steps)

#### Advanced options
- `--mixed-precision/--no-mixed-precision`: enable mixed-precision training (default: enabled)
- `--tensorboard/--no-tensorboard`: enable TensorBoard (default: enabled)
- `--resume-from`: resume training from a checkpoint (optional)
- `--reset-training-state`: reset the training state and load only the model weights (default: False)
- `--seed`: random seed (default: 42)

### 6.5 Monitoring Training Progress

During training the following are shown:
- Current step / total steps
- Loss and accuracy
- Learning-rate changes
- Memory usage

With TensorBoard enabled, view the visualizations with:

```bash
tensorboard --logdir ./output/tensorboard
```

### 6.6 JSON-Sidecar Mobile Monitoring

To make training easy to follow from a phone, the project implements a JSON-sidecar monitoring scheme. While TensorBoard logging continues as before, an additional JSON status file is written, and a Streamlit app provides a mobile-friendly web interface.

#### Highlights

**📱 Mobile experience**
- Streamlit generates a responsive interface that fits phone screens
- Charts support pinch-zoom and swipe gestures
- Core metrics shown in large type, easy to operate by touch

**🚀 Loosely coupled architecture**
- Training and monitoring are decoupled through the file system
- Restarting the monitoring service does not affect the training process
- Only a few lines of changes are needed in a training script to support it

**🔒 Safe and stable**
- Plain-text JSON files, no file-lock contention
- Fast reads and writes, high stability
- Does not interfere with the existing TensorBoard logging flow

**📊 Real-time monitoring**
- Data auto-refreshes every 5 seconds by default
- Live display of training progress and metric trends
- Data-freshness indicator (live / recent / older / stale)

#### Usage

**Start the monitoring service**
```bash
# Start the monitoring service (default port 8501)
monitor-training monitor

# Specify the status file path and port
monitor-training monitor --status-file ./output/training_status.json --port 8080

# Do not open a browser automatically
monitor-training monitor --no-browser

# Use a custom Streamlit script
monitor-training monitor --streamlit-script ./custom_monitor.py
```

**View the training status**
```bash
# Show the 10 most recent training records
monitor-training view

# Show the 50 most recent records (raw JSON)
monitor-training view --limit 50 --raw

# Inspect a specific status file
monitor-training view /path/to/status.json
```

**Check the status file**
```bash
# Check the status file
monitor-training check

# Check a specific file
monitor-training check ./output/training_status.json
```

**Start the HTTP static file service**
```bash
# Start the HTTP static file service (default port 8080)
monitor-training serve

# Specify the status file path and port
monitor-training serve --status-file ./output/training_status.json --port 8080

# Disable CORS support (enabled by default)
monitor-training serve --no-cors

# Specify the host address
monitor-training serve --host 0.0.0.0 --port 8080
```

#### Monitoring Interface Features

**📊 Core metrics board**
- Current step, epoch, training loss, accuracy
- Evaluation loss and accuracy
- Current learning rate, time of last update

**📈 Trend charts**
- Loss curves (training loss + evaluation loss)
- Accuracy curves (training accuracy + evaluation accuracy)
- Learning-rate curve (log scale)

**📋 Data details**
- Full table of training records
- Statistics (total data points, training duration, total steps)
- Training progress bar

**⚙️ Configuration options**
- Status file path (supports the `TRAINING_STATUS_FILE` environment variable)
- Auto-refresh interval (adjustable, 1-30 seconds)
- Number of data points shown (adjustable, 10-1000)

#### Technical Implementation

**Training side**
- A `status_file` parameter is added to `Trainer.__init__`
- A `_write_training_status()` method writes the JSON file at every evaluation
- Resuming from an existing status file is supported, so no data is lost

**Monitoring side**
- A Streamlit web interface built for mobile use
- Plotly charts with touch interaction
- An auto-refresh mechanism keeps the training status current

**Command-line tool**
- Provides the `monitor`, `view`, and `check` subcommands
- Automatically detects whether Streamlit is available
- Supports passing settings through environment variables

#### Access Methods

**Local access**
```bash
# After starting the monitoring service, open in a browser:
http://localhost:8501
```

**LAN access**
```bash
# Specify the host address when starting the service
monitor-training monitor --host 0.0.0.0 --port 8080

# Open from a phone browser (same LAN)
http://192.168.1.100:8080
```

**Public access** (requires port forwarding)
```bash
# Make sure the server firewall allows the port
# Access via domain name or public IP
http://your-server.com:8501
```

**Remote HTTP monitoring**
```bash
# On the GPU server, start the HTTP service
monitor-training serve --port 8080 --host 0.0.0.0

# Locally, run the Streamlit monitor and read the data from the HTTP URL
monitor-training monitor --host 127.0.0.1 --port 8501

# Enter the remote URL in the Streamlit interface:
http://<gpu服务器IP>:8080/training_status.json
```

#### Status File Format

The status file `training_status.json` lives in the training output directory, with the following format:
```json
[
  {
    "step": 100,
    "epoch": 1,
    "timestamp": "2024-01-01T12:00:00",
    "train/loss": 2.345,
    "train/accuracy": 0.456,
    "eval/loss": 2.123,
    "eval/accuracy": 0.512,
    "train/learning_rate": 0.0001
  },
  ...
]
```

#### HTTP Static File Service and Remote Monitoring

For environments where the GPU server only supports plain HTTP (no WebSockets), an HTTP static-file service enables remote training monitoring.

**🔧 Technical notes**
- Pure HTTP, no WebSockets required
- Atomic writes prevent reading half-written JSON data
- Automatic retries and JSON validation ensure data integrity
- CORS support for easy cross-origin access
- Lightweight design that does not affect training performance

**🚀 How it works**
1. GPU server: the training process updates `training_status.json` via atomic writes
2. GPU server: `monitor-training serve` provides the HTTP static file service
3. Local machine: `monitor-training monitor` starts the Streamlit monitoring interface
4. Local machine: enter the HTTP URL in the Streamlit interface to access the remote data
5. Streamlit: polls over HTTP for live training data and displays it

**🛡️ Data safety**
- Atomic writes: write to a temporary file first, then rename atomically, so reads are never interrupted mid-write
- JSON validation: the HTTP server validates the JSON format before returning data
- Temporary file handling: `.tmp` temporary files are recognized and read appropriately
- Retry mechanism: reads are automatically retried when JSON parsing fails

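The atomic-write pattern can be sketched as follows (the helper name `write_status_atomic` is an assumption for illustration; the project's actual writer lives inside `Trainer`):

```python
import json
import os
import tempfile

def write_status_atomic(path, records):
    """Write to a temp file in the target directory, then atomically rename
    over the target, so an HTTP reader never sees a half-written JSON file."""
    dirname = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=dirname, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(records, f, ensure_ascii=False)
        os.replace(tmp_path, path)  # atomic rename on POSIX and Windows
    except BaseException:
        os.unlink(tmp_path)  # clean up the temp file on failure
        raise

write_status_atomic("training_status.json", [{"step": 100, "train/loss": 2.345}])
print(json.load(open("training_status.json"))[0]["step"])  # 100
```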
**🌐 Network requirements**
- GPU server: must open an HTTP port (default 8080)
- Local machine: must be able to reach the GPU server's HTTP port
- Protocol: pure HTTP, friendly to firewalls and proxies

#### Notes
1. If the status file does not exist when monitoring starts, an empty file is created automatically
2. The `plotly` dependency is needed for the charts: `pip install plotly>=5.0.0`
3. When resuming training from a checkpoint, existing status data is loaded automatically
4. Deploying the monitoring service on the same server as training is recommended, to avoid network latency
5. The HTTP service relies on atomic writes, so readers never see half-written JSON while the training process writes
6. For remote monitoring, make sure the GPU server's firewall opens the HTTP port
7. Use `--host 0.0.0.0` to make the HTTP service reachable remotely

### 6.7 Evaluating a Model (in development)

The evaluation feature is still under development:

```bash
train-model evaluate \
    --checkpoint "./output/checkpoint_final.pt" \
    --data-path "path/to/eval/dataset" \
    --batch-size 32
```

The command currently prints an "evaluation not yet implemented" notice. The feature is planned to:
- Load a trained model checkpoint
- Compute accuracy, perplexity, and other metrics on the evaluation set
- Generate a detailed performance report

### 6.8 Two-Stage Training for Model Expansion

When the model's capacity needs to grow (more experts, a modified layer structure, etc.), the `expand-and-train` command runs a two-stage schedule: first freeze the matching layers and train only the newly added parameters, then fine-tune everything.

#### Training strategy

1. **Frozen stage**: only the newly added parameters with mismatched shapes (new experts, widened layers, etc.) are trained
2. **Full fine-tuning stage**: when the validation loss fails to improve for `--frozen-patience` consecutive evaluations, all layers are automatically unfrozen for full training

#### Basic usage

```bash
train-model expand-and-train \
    --train-data-path "path/to/train/dataset" \
    --eval-data-path "path/to/eval/dataset" \
    --base-model-path "./pretrained/model.pt" \
    --new-model-spec "model:InputMethodEngine" \
    --num-experts 40 \
    --frozen-lr 2e-3 \
    --full-lr 5e-5 \
    --frozen-patience 8
```

#### Full parameter example

```bash
train-model expand-and-train \
    --train-data-path "path/to/train/dataset" \
    --eval-data-path "path/to/eval/dataset" \
    --output-dir "./expansion_output" \
    --base-model-path "./pretrained/model.pt" \
    --new-model-spec "custom_model:ExpandedModel" \
    --vocab-size 10019 \
    --dim 512 \
    --num-experts 40 \
    --frozen-patience 10 \
    --frozen-lr 1e-3 \
    --full-lr 1e-4 \
    --frozen-scheduler cosine \
    --full-scheduler cosine \
    --batch-size 128 \
    --num-epochs 20 \
    --compile
```

#### Parameter details

**Model expansion parameters**
- `--base-model-path`: checkpoint path of the pretrained base model (required)
- `--new-model-spec`: new model spec in `module:ClassName` form, e.g. `model:InputMethodEngine` (required)
  - Modules can be imported from any path; the module file must contain the custom model class
  - The custom model class must be a subclass of `InputMethodEngine`
  - Example: `my_model:MyExpandedModel` refers to the `MyExpandedModel` class in `my_model.py`

**Two-stage training parameters**
- `--frozen-patience`: number of consecutive evaluations without validation-loss improvement in the frozen stage before switching to full fine-tuning (default: 10)
- `--frozen-lr`: learning rate in the frozen stage (default: 1e-3)
- `--full-lr`: learning rate in the full fine-tuning stage (default: 1e-4)
- `--frozen-scheduler`: learning-rate scheduler for the frozen stage, `cosine` or `plateau` (default: `cosine`)
- `--full-scheduler`: learning-rate scheduler for the full fine-tuning stage, `cosine` or `plateau` (default: `cosine`)

**Other parameters**
- All general parameters of the `train` subcommand are supported (data, model, and training parameters)
- The existing training infrastructure is inherited: mixed-precision training, TensorBoard logging, checkpoint saving, etc.

#### Use cases

1. **More experts** (20→40)
   - Freezing: ~70% of parameters can be frozen (existing expert weights, attention layers, etc.)
   - New parameters: new expert networks, gate layer

2. **Larger top_k** (2→3)
   - Freezing: 100% of parameters can be frozen (logic change only)
   - New parameters: none

3. **Changed expert internals** (e.g. additional resblocks)
   - Freezing: ~50% of parameters can be frozen (linear_in/output can be frozen)
   - New parameters: the added resblocks

4. **More Transformer layers** (4→5)
   - Freezing: ~80% of parameters can be frozen (the first 4 layers)
   - New parameters: the new 5th layer

#### Custom model class example

```python
# my_model.py
from model.model import InputMethodEngine

class MyExpandedModel(InputMethodEngine):
    def __init__(self, num_experts=40, **kwargs):
        # Call the parent constructor, overriding num_experts
        super().__init__(num_experts=num_experts, **kwargs)
        # Extra layers can be added or existing ones modified here

# Command:
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
```

#### Notes

1. **Model class requirement**: the custom model class must be a subclass of `InputMethodEngine`
2. **Freezing condition**: only layers whose weight shapes match exactly are frozen
3. **Performance**: the MoE layer keeps the "compute all experts + Top-K selection" scheme for best performance under `torch.compile`
4. **Stage switching**: based on evaluation frequency rather than epochs; consider raising `--eval-frequency` accordingly
5. **Module imports**: modules on any path are loaded through the standard Python import mechanism

### 6.9 Exporting a Model (in development)

The export feature is still under development:

```bash
train-model export \
    --checkpoint "./output/checkpoint_final.pt" \
    --output "./exported_model.onnx"
```

The command currently prints an "export not yet implemented" notice. The feature is planned to:
- Convert the PyTorch model to ONNX format
- Support deployment on different inference engines
- Provide an optimized inference model

## 7. Summary

By combining **single-stream Transformer encoding** with **structured slot cross-attention** and a **20-expert MoE module** [1], this design stays lightweight (a 4-layer Transformer) while effectively exploiting the user's input history and raising the model's expressive ceiling. Compared with brute-force concatenation or a two-stream architecture, it is cleaner to implement and more efficient at inference time, making it a strong local optimum for a lightweight input-method model.
233
eval.py
233
eval.py
|
|
@@ -274,8 +274,43 @@ class TextEvaluator:
         part1 = text[i - 48 : i]

         # part2: pinyin input (random length 1-8, roughly Gaussian distribution)
+        # When a history slot count is forced, the pinyin span must be long enough
         pinyin_len_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
+        min_pinyin_len = (force_slot_count + 1) if force_slot_count is not None else 1
+        min_pinyin_len = max(1, min(8, min_pinyin_len))
+
+        # Count the usable characters remaining after position i
+        remaining_chars = len(text) - i
+        max_valid_len = 0
+        for j in range(i, min(i + 8, len(text))):
+            if self.query_engine and self.query_engine.is_chinese_char(text[j]):
+                max_valid_len += 1
+            else:
+                break
+
+        # Adjust the pinyin length: at least min_pinyin_len valid characters are needed
+        if max_valid_len < min_pinyin_len:
+            # Caller should re-pick the position to guarantee enough characters
+            raise ValueError(
+                f"Only {max_valid_len} valid Chinese characters after position {i}; not enough to test history slot count {force_slot_count}"
+            )
+
+        # Randomly choose a pinyin length (at least min_pinyin_len)
+        if force_slot_count is not None:
+            # Forced mode: sample only lengths of at least min_pinyin_len
+            valid_lengths = [
+                l for l in range(min_pinyin_len, min(max_valid_len + 1, 9))
+            ]
+            if not valid_lengths:
+                valid_lengths = [min_pinyin_len]
+            # Reuse the base distribution, restricted to the valid lengths
+            adjusted_probs = [pinyin_len_probs[l - 1] for l in valid_lengths]
+            adjusted_probs = np.array(adjusted_probs) / sum(adjusted_probs)
+            pinyin_len = np.random.choice(valid_lengths, p=adjusted_probs)
+        else:
-        pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
+            pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
+            pinyin_len = min(pinyin_len, max_valid_len)

         py_end = min(i + pinyin_len, len(text))

         # Look up the pinyin
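The forced-mode branch in the hunk above restricts the base length distribution `pinyin_len_probs` to the lengths that remain feasible and renormalizes before sampling. That sampling step in isolation (the helper name is hypothetical):

```python
import numpy as np

# Base distribution over pinyin lengths 1-8, as in the diff above
pinyin_len_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]

def sample_length(min_len, max_len, rng):
    """Sample a length from the base distribution, restricted to [min_len, max_len]."""
    valid = list(range(min_len, max_len + 1))
    probs = np.array([pinyin_len_probs[l - 1] for l in valid])
    probs = probs / probs.sum()  # renormalize over the surviving lengths
    return int(rng.choice(valid, p=probs))

rng = np.random.default_rng(0)
lengths = [sample_length(3, 8, rng) for _ in range(1000)]
```

Renormalizing keeps the relative preferences of the original distribution while guaranteeing every drawn length is feasible, which is exactly what `adjusted_probs` does in the patch.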
@@ -317,7 +352,7 @@ class TextEvaluator:
             string_list.append(text[start_pos:end_pos])
         part4 = "|".join(string_list)

-        # Labels: IDs of the characters to predict
+        # Labels: IDs of the characters to predict (every character in the pinyin span)
         labels = []
         try:
             labels = [
@@ -328,27 +363,38 @@ class TextEvaluator:
                 )
             ]
         except (AttributeError, TypeError) as e:
-            # If the lookup fails, fall back to default IDs
             print(f"⚠️ Failed to look up labels: {e}")
             labels = [0] * pinyin_len_actual

-        # History slots: simulate the user's step-by-step confirmation
-        # For multi-character prediction (pinyin_len_actual > 1), simulate step-by-step selection
-        # Note: at evaluation time the answer is unknown; history can only come from predictions
+        # History slots: consistent with the training-data generation logic
+        # With force_slot_count=k, the history slots are labels[:k] and the target is labels[k]
        history_slot_ids = []
+        predict_idx = 0  # index of the label currently being predicted

-        # When a slot count is forced, use a simulated history (matching the training distribution)
         if force_slot_count is not None:
             force_slot_count = max(0, min(8, force_slot_count))
-            # Build a simulated history: the first force_slot_count slots are valid, the rest 0
-            # Simple simulation: valid slots get random IDs in 1-1000 (matches training-time filling)
-            for _ in range(force_slot_count):
-                history_slot_ids.append(random.randint(1, 1000))
+            # Check that there are enough labels
+            if len(labels) <= force_slot_count:
+                raise ValueError(
+                    f"{len(labels)} labels are not enough for history slot count {force_slot_count}"
+                )
+            predict_idx = force_slot_count
+            history_slot_ids = labels[:force_slot_count]
         else:
-            # Normal evaluation: empty history (simulates typing from scratch)
-            # Note: in real use the history slots come from earlier user choices,
-            # but for a fair evaluation we start from an empty history
-            pass
+            # Normal evaluation: sample a history length matching the training distribution
+            history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
+            max_history = min(len(labels) - 1, 8)  # keep at least one character to predict
+            if max_history < 0:
+                max_history = 0
+            valid_history_lens = list(range(0, max_history + 1))
+            if valid_history_lens:
+                adjusted_probs = [history_weights[h] for h in valid_history_lens]
+                adjusted_probs = np.array(adjusted_probs) / sum(adjusted_probs)
+                history_len = np.random.choice(valid_history_lens, p=adjusted_probs)
+            else:
+                history_len = 0
+            predict_idx = history_len
+            history_slot_ids = labels[:history_len]

         # Pad to 8 slots
         if len(history_slot_ids) < 8:
@@ -373,6 +419,12 @@ class TextEvaluator:
         # Pinyin input length (in characters)
         pinyin_input_length = len(part2)

+        # The character and label currently being predicted
+        current_target_char = (
+            text[i + predict_idx] if (i + predict_idx) < len(text) else ""
+        )
+        current_target_label = labels[predict_idx] if predict_idx < len(labels) else 0
+
         # Build the sample
         sample = {
             "input_ids": encoded["input_ids"],
@@ -385,11 +437,13 @@ class TextEvaluator:
             "suffix": part3,
             "pinyin": part2,
             "pinyin_ids": pinyin_ids_tensor.unsqueeze(0),
-            "true_labels": labels,  # ground-truth labels (multiple characters)
+            "true_labels": labels,  # full label list (all characters)
+            "predict_idx": predict_idx,  # index of the label currently predicted
+            "current_target_label": current_target_label,  # label currently predicted
             "position": i,
-            "target_char": text[i] if i < len(text) else "",
+            "target_char": current_target_char,  # character currently predicted
-            "valid_slot_count": valid_slot_count,  # number of valid slots
+            "valid_slot_count": valid_slot_count,
-            "pinyin_input_length": pinyin_input_length,  # pinyin input length
+            "pinyin_input_length": pinyin_input_length,
         }

         return sample
@@ -483,56 +537,50 @@ class TextEvaluator:

         return analysis

-    def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict:
+    def evaluate_sample_single_char(
+        self, sample: Dict, print_details: bool = True
+    ) -> Dict:
         """
-        Evaluate a single sample.
+        Single-character evaluation (matches the validation logic in trainer.py).
+
+        Picks the predicted character by argmax, exactly as during training-time
+        validation, and selects the prediction target according to predict_idx.

         Returns:
             The evaluation result dict
         """
-        # Run inference
         logits, probs = self.inference(sample)

-        # Analyze the probability distribution
         analysis = self.analyze_probability_distribution(probs)

-        # Fetch the ground-truth labels
         true_labels = sample.get("true_labels", [])
+        predict_idx = sample.get("predict_idx", 0)
+        current_target_label = sample.get(
+            "current_target_label", true_labels[predict_idx] if true_labels else 0
+        )
         target_char = sample.get("target_char", "")

-        # Check whether the prediction is correct (first character only)
-        correct = False
-        pred_top_idx = analysis["top_indices"][0]
-        if true_labels and len(true_labels) > 0:
-            correct = pred_top_idx == true_labels[0]
+        pred_idx = analysis["top_indices"][0]
+        correct = pred_idx == current_target_label

-        # Print the result
         if print_details:
-            # Fetch the slot information
             valid_slot_count = sample.get("valid_slot_count", 0)
             pinyin_input_length = sample.get("pinyin_input_length", 0)

-            # Compact output format (continues on the same line)
             print(f" {target_char}", end="")
-            if true_labels:
-                print(f"(ID:{true_labels[0]})", end="")
+            print(f"(ID:{current_target_label})", end="")
             print(
                 f" | pinyin:{sample.get('pinyin', '')[:10]}{'...' if len(sample.get('pinyin', '')) > 10 else ''}",
                 end="",
             )
             print(f" | slots:{valid_slot_count}/8 pinyin len:{pinyin_input_length}", end="")

-            # Key probability information
             print(f" | max prob:{analysis['max_prob']:.3f}", end="")
             if analysis["low_variance"]:
                 print(f" ⚠️", end="")

-            # Top-2 predictions
             top_strs = []
             for j in range(min(2, len(analysis["top_indices"]))):
                 idx = analysis["top_indices"][j]
                 prob = analysis["top_probs"][j]
-                # Handle special IDs
                 if idx == 0:
                     char = "[empty]"
                 else:
@@ -542,34 +590,43 @@ class TextEvaluator:
                         else None
                     )
                     char = char_info.char if char_info else f"[{idx}]"
-                is_correct = (
-                    (true_labels and idx == true_labels[0]) if true_labels else False
-                )
+                is_correct = idx == current_target_label
                 correct_mark = "✓" if is_correct else ""
                 top_strs.append(f"{char}{correct_mark}({prob:.2f})")
             print(f" | Top-2: {' '.join(top_strs)}", end="")

             print(f" | result:{'✓' if correct else '✗'}")

-        # Return the evaluation result
-        result = {
+        return {
             "sample": sample,
             "analysis": analysis,
             "correct": correct,
             "target_char": target_char,
-            "true_label": true_labels[0] if true_labels else None,
-            "predicted_label": pred_top_idx,
+            "true_label": current_target_label,
+            "predicted_label": pred_idx,
+            "predict_idx": predict_idx,
         }

-        return result
+    def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict:
+        """
+        Evaluate a single sample (legacy interface; delegates to single-character evaluation).
+
+        Returns:
+            The evaluation result dict
+        """
+        return self.evaluate_sample_single_char(sample, print_details)

     def evaluate_sample_with_sequential_confirmation(
-        self, sample: Dict, print_details: bool = True
+        self, sample: Dict, print_details: bool = True, use_oracle_history: bool = True
     ) -> Dict:
         """
         Evaluate a sample while simulating step-by-step user confirmation (for multi-character prediction).

-        For pinyin lengths > 1, simulate the user repeatedly taking the Top-1 prediction as the confirmed character.
+        Args:
+            sample: the sample dict
+            print_details: whether to print details
+            use_oracle_history: True = use ground-truth labels as history (matches training/validation),
+                False = use the model's own predictions
+
+        For pinyin lengths > 1, simulates the user's step-by-step selection.

         Returns:
             The evaluation result dict
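The oracle/user distinction introduced in this signature reduces to how the history slots are built at each step: either from ground-truth labels (matching training) or from the model's own confirmed predictions. A pure-Python sketch (hypothetical helper, mirroring the pad-to-8 logic used elsewhere in the patch):

```python
def build_history(step, true_labels, confirmed, use_oracle_history, num_slots=8):
    """History slots for prediction step `step`.

    oracle mode: history = ground-truth labels, as during training/validation.
    user mode:   history = the model's own confirmed predictions.
    """
    history = list(true_labels[:step]) if use_oracle_history else list(confirmed[:step])
    history.extend([0] * (num_slots - len(history)))  # pad empty slots with ID 0
    return history[:num_slots]

oracle = build_history(2, [11, 22, 33], [11, 99, 33], use_oracle_history=True)
user = build_history(2, [11, 22, 33], [11, 99, 33], use_oracle_history=False)
# oracle == [11, 22, 0, 0, 0, 0, 0, 0]; user == [11, 99, 0, 0, 0, 0, 0, 0]
```

The two modes diverge exactly when an earlier prediction was wrong (99 vs 22 above): oracle mode measures per-step accuracy in isolation, user mode additionally exposes error propagation.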
@@ -583,7 +640,7 @@ class TextEvaluator:
             return self.evaluate_sample(sample, print_details)

         # Multi-character prediction: simulate step-by-step confirmation
-        confirmed_history = []  # character IDs the user has confirmed
+        confirmed_history = []  # confirmed character IDs
         all_predictions = []  # prediction result of each step
         all_correct = []  # whether each step was correct
@@ -591,8 +648,16 @@ class TextEvaluator:
         sample_copy = sample.copy()

         for step in range(pinyin_len):
-            # Update the history slots with the confirmed characters
+            # Update the history slots:
+            # - oracle mode: use ground-truth labels as history (matches training)
+            # - user mode: use model predictions as history (simulates real use)
+            if use_oracle_history:
+                # Ground-truth labels as history (matches the training-data generation logic)
+                history_slot_ids = true_labels[:step]
+            else:
+                # Model predictions as history
-            history_slot_ids = confirmed_history[:]
+                history_slot_ids = confirmed_history[:]

             if len(history_slot_ids) < 8:
                 history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
             else:
@@ -626,9 +691,8 @@ class TextEvaluator:
             )
             all_correct.append(correct)

-            # Simulate the user's choice: append the Top-1 prediction to the confirmed history
-            # Note: this assumes the user always picks the Top-1 result
-            if pred_top_idx > 0:  # only append valid IDs
+            # Update the confirmed history (used in the next step)
+            if pred_top_idx > 0:
                 confirmed_history.append(pred_top_idx)

         # Overall accuracy: correct only if every step is correct
@@ -643,6 +707,8 @@ class TextEvaluator:
             print(f" {target_char}", end="")
             print(f" | pinyin:{pinyin[:10]}{'...' if len(pinyin) > 10 else ''}", end="")
             print(f" | slots:{valid_slot_count}/8 pinyin len:{pinyin_len}", end="")
+            mode_str = "oracle" if use_oracle_history else "user"
+            print(f" | [{mode_str}]", end="")

             # Show the step-by-step prediction results
             step_results = []
@@ -672,13 +738,16 @@ class TextEvaluator:

         return result

-    def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]:
+    def evaluate_text(
+        self, text: str, num_samples: int = 10, use_sequential: bool = False
+    ) -> List[Dict]:
         """
         Evaluate the given text by generating several samples.

         Args:
             text: the input text
             num_samples: how many samples to generate
+            use_sequential: True = step-by-step confirmation, False = single-character evaluation (matches the validation set)

         Returns:
             The list of evaluation results
@@ -699,8 +768,13 @@ class TextEvaluator:

             try:
                 sample = self.create_sample_from_text(text)
-                # Use step-by-step confirmation evaluation
-                result = self.evaluate_sample_with_sequential_confirmation(
-                    sample, print_details=True
-                )
+                # Default: single-character evaluation (matches the validation set)
+                if use_sequential:
+                    result = self.evaluate_sample_with_sequential_confirmation(
+                        sample, print_details=True, use_oracle_history=True
+                    )
+                else:
+                    result = self.evaluate_sample_single_char(
+                        sample, print_details=True
+                    )
                 results.append(result)
@@ -718,7 +792,7 @@ class TextEvaluator:
         return results

     def evaluate_by_slot_count(
-        self, text: str, samples_per_slot: int = 100
+        self, text: str, samples_per_slot: int = 100, use_sequential: bool = False
     ) -> Dict[int, Dict]:
         """
         Evaluate per slot count, testing samples_per_slot samples for each slot count.
@@ -726,6 +800,7 @@ class TextEvaluator:
         Args:
             text: the input text
             samples_per_slot: number of test samples per slot count
+            use_sequential: True = step-by-step confirmation, False = single-character evaluation (matches the validation set)

         Returns:
             Evaluation results grouped by slot count {slot_count: {results, accuracy, ...}}
@@ -735,7 +810,8 @@ class TextEvaluator:
         text = text[start : start + 300]

         print(f"\n{'=' * 60}")
-        print(f"Evaluation by slot count | {samples_per_slot} samples per group")
+        mode_str = "sequential confirmation" if use_sequential else "single-character (matches validation)"
+        print(f"Evaluation by slot count [{mode_str}] | {samples_per_slot} samples per group")
         print(f"{'=' * 60}")

         slot_results = {}
@@ -750,12 +826,19 @@ class TextEvaluator:
                     sample = self.create_sample_from_text(
                         text, force_slot_count=slot_count
                     )
-                    # Step-by-step confirmation evaluation (closer to real use)
-                    result = self.evaluate_sample_with_sequential_confirmation(
-                        sample, print_details=False
-                    )
+                    # Default: single-character evaluation (matches the validation set)
+                    if use_sequential:
+                        result = self.evaluate_sample_with_sequential_confirmation(
+                            sample, print_details=False, use_oracle_history=True
+                        )
+                    else:
+                        result = self.evaluate_sample_single_char(
+                            sample, print_details=False
+                        )
                     results.append(result)
-                    # Note: step-by-step evaluation has no low_variance analysis
+                    # Count low-variance samples
+                    if "analysis" in result and result["analysis"].get("low_variance"):
+                        low_var_count += 1
                 except Exception:
                     failed += 1
                     continue
@@ -764,8 +847,9 @@ class TextEvaluator:
                 correct_count = sum(1 for r in results if r["correct"])
                 accuracy = correct_count / len(results)

-                # For step-by-step evaluation, compute the mean per-step accuracy
+                # Mean per-step accuracy (only for step-by-step confirmation mode)
                 step_accuracies = []
+                if use_sequential:
-                for r in results:
-                    if "step_correct" in r:
-                        step_correct = r["step_correct"]
+                    for r in results:
+                        if "step_correct" in r:
+                            step_correct = r["step_correct"]
@@ -781,17 +865,25 @@ class TextEvaluator:
                     "correct": correct_count,
                     "total": len(results),
                     "mean_step_accuracy": mean_step_accuracy,
+                    "low_var_count": low_var_count,
                     "failed": failed,
                 }

                 bar_len = 30
                 filled = int(bar_len * accuracy)
                 bar = "█" * filled + "░" * (bar_len - filled)
+                if use_sequential:
-                print(
-                    f"  slot {slot_count}/8 | {bar} {accuracy:6.1%} "
-                    f"({correct_count:2d}/{len(results):2d}) "
-                    f"| mean step accuracy: {mean_step_accuracy:.1%}"
-                )
+                    print(
+                        f"  slot {slot_count}/8 | {bar} {accuracy:6.1%} "
+                        f"({correct_count:2d}/{len(results):2d}) "
+                        f"| mean step accuracy: {mean_step_accuracy:.1%}"
+                    )
+                else:
+                    print(
+                        f"  slot {slot_count}/8 | {bar} {accuracy:6.1%} "
+                        f"({correct_count:2d}/{len(results):2d}) "
+                        f"| low variance: {low_var_count}"
+                    )
             else:
                 slot_results[slot_count] = {
                     "results": [],
@@ -799,6 +891,7 @@ class TextEvaluator:
                     "correct": 0,
                     "total": 0,
                     "mean_step_accuracy": 0,
+                    "low_var_count": 0,
                     "failed": failed,
                 }
                 print(f"  slot {slot_count}/8 | all samples failed to generate")
@@ -851,11 +944,11 @@ def main():
     parser.add_argument(
         "--checkpoint",
         type=str,
-        default="/home/songsenand/下载/best_model.pt",
+        default="/home/songsenand/下载/best_model.ptrom",
-        help="model checkpoint path (default: /home/songsenand/下载/best_model.pt)",
+        help="model checkpoint path (default: /home/songsenand/下载/best_model.ptrom)",
     )
     parser.add_argument(
-        "--num-samples", type=int, default=100, help="number of evaluation samples (default: 20)"
+        "--num-samples", type=int, default=100, help="number of evaluation samples (default: 100)"
     )
     parser.add_argument(
         "--device",
@@ -864,6 +957,11 @@ def main():
         choices=["auto", "cpu", "cuda"],
         help="inference device (default: auto)",
     )
+    parser.add_argument(
+        "--use-sequential",
+        action="store_true",
+        help="use step-by-step confirmation evaluation (default: single-character evaluation, matches the validation set)",
+    )

     args = parser.parse_args()
@@ -890,6 +988,9 @@ def main():
     print(f"🔧 Device: {device}")
     print(f"📦 Model checkpoint: {args.checkpoint}")
     print(f"🔬 Evaluation samples: {args.num_samples}")
+    print(
+        f"📊 Evaluation mode: {'sequential confirmation' if args.use_sequential else 'single-character (matches validation)'}"
+    )

     # Initialize the evaluator
     try:
@@ -902,7 +1003,9 @@ def main():
         sys.exit(1)

     # Run the evaluation
-    evaluator.evaluate_by_slot_count(text, samples_per_slot=args.num_samples)
+    evaluator.evaluate_by_slot_count(
+        text, samples_per_slot=args.num_samples, use_sequential=args.use_sequential
+    )


 if __name__ == "__main__":
@@ -28,6 +28,11 @@ class AttentionPooling(nn.Module):

 # ---------------------------- Pinyin LSTM encoder ----------------------------
 class PinyinLSTMEncoder(nn.Module):
+    """
+    Pinyin sequence encoder; returns a per-position pinyin encoding.
+    Uses a bidirectional LSTM so every position sees both left and right context.
+    """

     def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
         super().__init__()
         self.input_dim = input_dim
@@ -35,7 +40,6 @@ class PinyinLSTMEncoder(nn.Module):
         self.num_layers = num_layers
         self.dropout = dropout

-        # Bidirectional LSTM
         self.lstm = nn.LSTM(
             input_size=input_dim,
             hidden_size=self.hidden_dim,
@@ -45,7 +49,6 @@ class PinyinLSTMEncoder(nn.Module):
             dropout=dropout if num_layers > 1 else 0.0,
         )

-        # Project concatenated hidden states to input_dim
         self.proj = nn.Linear(self.hidden_dim * 2, input_dim)
         self.layer_norm = nn.LayerNorm(input_dim)
@@ -53,32 +56,25 @@ class PinyinLSTMEncoder(nn.Module):
         """
         Args:
             x: [batch, seq_len, input_dim] pinyin embeddings
-            mask: [batch, seq_len] optional padding mask (0 for padding)
+            mask: [batch, seq_len] optional padding mask (True for valid, False for padding)
         Returns:
-            pooled: [batch, input_dim] global pinyin representation
+            output: [batch, seq_len, input_dim] per-position pinyin encoding
         """
+        total_len = x.size(1)
+
         if mask is not None:
-            # lengths for pack_padded_sequence
-            lengths = mask.sum(dim=1).cpu()
-            # pack sequence
+            lengths = mask.sum(dim=1).cpu().clamp(min=1)
             packed = nn.utils.rnn.pack_padded_sequence(
                 x, lengths, batch_first=True, enforce_sorted=False
             )
             packed_out, (hidden, cell) = self.lstm(packed)
-            # hidden shape: [num_layers * 2, batch, hidden_dim]
-            # Take last layer's forward and backward hidden states
-            forward_hidden = hidden[-2, :, :]  # last layer forward
-            backward_hidden = hidden[-1, :, :]  # last layer backward
-            hidden_concat = torch.cat([forward_hidden, backward_hidden], dim=1)
+            output, _ = nn.utils.rnn.pad_packed_sequence(
+                packed_out, batch_first=True, total_length=total_len
+            )
         else:
-            # No mask, assume all sequences same length
             output, (hidden, cell) = self.lstm(x)
-            # hidden shape: [num_layers * 2, batch, hidden_dim]
-            forward_hidden = hidden[-2, :, :]
-            backward_hidden = hidden[-1, :, :]
-            hidden_concat = torch.cat([forward_hidden, backward_hidden], dim=1)

-        projected = self.proj(hidden_concat)
+        projected = self.proj(output)
         return self.layer_norm(projected)
@@ -170,46 +166,32 @@ class ContextEncoder(nn.Module):
     def forward(self, text_ids, pinyin_ids, mask=None):
         """
         Args:
-            text_ids: [batch, seq_len]
-            pinyin_ids: [batch, seq_len] (assumed aligned; preprocess if not)
-            mask: [batch, seq_len] optional padding mask
+            text_ids: [batch, seq_len] text token ids
+            pinyin_ids: [batch, pinyin_len] pinyin token ids
+            mask: [batch, seq_len] optional padding mask (1 for valid, 0 for padding)
         Returns:
-            H: [batch, seq_len, 512] Context representation [1]
+            H: [batch, seq_len, dim] text context encoding
+            P: [batch, pinyin_len, dim] per-position pinyin encoding
         """
-        # 1. Embed text
-        text_emb = self.text_emb(text_ids)  # [B, 128, dim]
+        text_emb = self.text_emb(text_ids)  # [B, seq_len, dim]

-        # 2. Embed and pool pinyin to global feature
-        pinyin_emb = self.pinyin_emb(pinyin_ids)  # [B, 24, dim]
-        # LSTM encoder with masking for padding
-        pinyin_mask = pinyin_ids != 0
-        pinyin_global = self.pinyin_pooling(
-            pinyin_emb, mask=pinyin_mask
-        )  # [B, dim]
-        # 1. Embedding Fusion: Text + Pinyin + Position
-
-        # Broadcast pinyin to all text positions
-        pinyin_global = pinyin_global.unsqueeze(1)  # [B, 1, dim]
-        pinyin_broadcast = pinyin_global.expand_as(text_emb)  # [B, 128, dim]
-
-        # Strategy: overlay pinyin on the text as an enhancement feature (lightweight design)
-        x = text_emb + pinyin_broadcast
-
-        seq_len = x.size(1)
-        pos_ids = (
-            torch.arange(seq_len, device=x.device).unsqueeze(0).expand_as(text_ids)
-        )
-        x += self.pos_emb(pos_ids)
+        seq_len = text_emb.size(1)
+        pos_ids = torch.arange(seq_len, device=text_ids.device).unsqueeze(0)
+        x = text_emb + self.pos_emb(pos_ids)

-        # 2. Transformer Encoding
-        # src_key_padding_mask expects True for padding positions
         if mask is not None:
-            # Convert 0/1 mask to bool mask where True is padding
             src_mask = mask == 0
         else:
             src_mask = None

         H = self.transformer(x, src_key_padding_mask=src_mask)
-        return self.ln(H)
+        H = self.ln(H)
+
+        pinyin_emb = self.pinyin_emb(pinyin_ids)  # [B, pinyin_len, dim]
+        pinyin_mask = pinyin_ids != 0
+        P = self.pinyin_pooling(pinyin_emb, mask=pinyin_mask)  # [B, pinyin_len, dim]
+
+        return H, P


 # ------------------------------------------------------------------
@ -256,76 +238,100 @@ class SlotMemory(nn.Module):
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 3. 交叉注意力融合 (Cross-Attention Fusion)
|
# 3. 交叉注意力融合 (Cross-Attention Fusion)
|
||||||
# 对应 README: Query=Slots, Key/Value=Context H [1]
|
# 槽位同时查询文本上下文和拼音序列
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
class CrossAttentionFusion(nn.Module):
|
class CrossAttentionFusion(nn.Module):
|
||||||
|
"""
|
||||||
|
双路交叉注意力:
|
||||||
|
- 槽位查询文本上下文 (H)
|
||||||
|
- 槽位查询拼音序列 (P)
|
||||||
|
通过注意力机制,start_emb 自动学会关注"当前应该预测的拼音位置"
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, dim=512, n_heads=4):
|
def __init__(self, dim=512, n_heads=4):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.n_heads = n_heads
|
self.n_heads = n_heads
|
||||||
self.head_dim = dim // n_heads
|
self.head_dim = dim // n_heads
|
||||||
assert self.head_dim * n_heads == dim, "dim must be divisible by n_heads"
|
assert self.head_dim * n_heads == dim
|
||||||
|
|
||||||
# Linear projections for Q, K, V
|
|
||||||
self.q_proj = nn.Linear(dim, dim, bias=False)
|
self.q_proj = nn.Linear(dim, dim, bias=False)
|
||||||
self.k_proj = nn.Linear(dim, dim, bias=False)
|
self.k_text_proj = nn.Linear(dim, dim, bias=False)
|
||||||
self.v_proj = nn.Linear(dim, dim, bias=False)
|
self.v_text_proj = nn.Linear(dim, dim, bias=False)
|
||||||
|
self.k_pinyin_proj = nn.Linear(dim, dim, bias=False)
|
||||||
|
self.v_pinyin_proj = nn.Linear(dim, dim, bias=False)
|
||||||
|
|
||||||
self.out_proj = nn.Linear(dim, dim, bias=False)
|
self.out_proj = nn.Linear(dim, dim, bias=False)
|
||||||
self.ln = nn.LayerNorm(dim)
|
self.ln = nn.LayerNorm(dim)
|
||||||
|
|
||||||
def forward(self, slots_S, context_H, slot_mask=None, context_mask=None):
|
def forward(
|
||||||
|
self, slots_S, context_H, pinyin_P, context_mask=None, pinyin_mask=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
slots_S: [batch, num_slots_steps, dim] Query
|
slots_S: [batch, num_slots, dim] 槽位编码
|
||||||
context_H: [batch, ctx_len, dim] Key/Value
|
context_H: [batch, ctx_len, dim] 文本上下文编码
|
||||||
slot_mask: [batch, num_slots_steps] Optional (not used in scaled_dot_product_attention)
|
pinyin_P: [batch, pinyin_len, dim] 拼音序列编码
|
||||||
context_mask: [batch, ctx_len] Optional padding mask
|
context_mask: [batch, ctx_len] 文本 padding mask (True for padding)
|
||||||
|
pinyin_mask: [batch, pinyin_len] 拼音 padding mask (True for padding)
|
||||||
Returns:
|
Returns:
|
||||||
Fused: [batch, num_slots_steps, dim]
|
fused: [batch, num_slots, dim]
|
||||||
"""
|
"""
|
||||||
batch_size, num_slots, _ = slots_S.shape
|
batch_size, num_slots, _ = slots_S.shape
|
||||||
_, ctx_len, _ = context_H.shape
|
ctx_len = context_H.size(1)
|
||||||
|
pinyin_len = pinyin_P.size(1)
|
||||||
|
|
||||||
         # Project queries, keys, values
-        Q = self.q_proj(slots_S)  # [batch, num_slots, dim]
-        K = self.k_proj(context_H)  # [batch, ctx_len, dim]
-        V = self.v_proj(context_H)  # [batch, ctx_len, dim]
+        Q = self.q_proj(slots_S)
+        K_text = self.k_text_proj(context_H)
+        V_text = self.v_text_proj(context_H)
+        K_pinyin = self.k_pinyin_proj(pinyin_P)
+        V_pinyin = self.v_pinyin_proj(pinyin_P)

         # Reshape for multi-head attention: [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
         Q = Q.view(batch_size, num_slots, self.n_heads, self.head_dim).transpose(1, 2)
-        K = K.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)
-        V = V.view(batch_size, ctx_len, self.n_heads, self.head_dim).transpose(1, 2)
+        K_text = K_text.view(
+            batch_size, ctx_len, self.n_heads, self.head_dim
+        ).transpose(1, 2)
+        V_text = V_text.view(
+            batch_size, ctx_len, self.n_heads, self.head_dim
+        ).transpose(1, 2)
+        K_pinyin = K_pinyin.view(
+            batch_size, pinyin_len, self.n_heads, self.head_dim
+        ).transpose(1, 2)
+        V_pinyin = V_pinyin.view(
+            batch_size, pinyin_len, self.n_heads, self.head_dim
+        ).transpose(1, 2)

-        # Prepare attention mask if context_mask is provided
-        attn_mask = None
+        text_attn_mask = None
         if context_mask is not None:
-            # context_mask: [batch, ctx_len] where 0 means padding
-            # Convert to bool mask and reshape for broadcasting
-            bool_mask = context_mask == 0  # [batch, ctx_len]
-            bool_mask = bool_mask[:, None, None, :]  # [batch, 1, 1, ctx_len]
-            # Convert to float mask where True (padding) becomes -inf
-            attn_mask = bool_mask.float().masked_fill(bool_mask, -1e9)
+            text_attn_mask = (
+                context_mask[:, None, None, :]
+                .float()
+                .masked_fill(context_mask[:, None, None, :], -1e9)
+            )

-        # Scaled dot-product attention
-        attn_output = F.scaled_dot_product_attention(
-            Q,
-            K,
-            V,
-            attn_mask=attn_mask,
-            dropout_p=0.0,  # no dropout
-        )
+        pinyin_attn_mask = None
+        if pinyin_mask is not None:
+            pinyin_attn_mask = (
+                pinyin_mask[:, None, None, :]
+                .float()
+                .masked_fill(pinyin_mask[:, None, None, :], -1e9)
+            )

-        # Reshape back: [batch, n_heads, num_slots, head_dim] -> [batch, num_slots, dim]
-        attn_output = (
-            attn_output.transpose(1, 2)
+        text_attn = F.scaled_dot_product_attention(
+            Q, K_text, V_text, attn_mask=text_attn_mask
+        )
+        pinyin_attn = F.scaled_dot_product_attention(
+            Q, K_pinyin, V_pinyin, attn_mask=pinyin_attn_mask
+        )
+
+        combined_attn = text_attn + pinyin_attn
+
+        combined_attn = (
+            combined_attn.transpose(1, 2)
             .contiguous()
             .view(batch_size, num_slots, self.dim)
         )

-        # Project back
-        fused = self.out_proj(attn_output)
+        fused = self.out_proj(combined_attn)

         # Residual connection and layer norm
         fused = self.ln(fused + slots_S)

         return fused
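The added branch is two standard cross-attentions that share one query: the slots attend once over the text context and once over the pinyin context, padded positions are pushed to -1e9 before softmax, and the two context vectors are summed. A framework-free, single-head sketch of that idea (list-based vectors; names are illustrative, not the module's API):

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def attend(q, keys, values, pad):
    # Additive mask: padded positions get -1e9, so their softmax weight is ~0.
    scores = [
        sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(len(q))
        + (-1e9 if is_pad else 0.0)
        for k, is_pad in zip(keys, pad)
    ]
    w = softmax(scores)
    dim = len(values[0])
    return [sum(wi * v[d] for wi, v in zip(w, values)) for d in range(dim)]

def dual_branch_attend(q, text_kv, text_pad, pinyin_kv, pinyin_pad):
    # One query, two key/value sources, outputs summed (as in the new forward).
    text_ctx = attend(q, text_kv, text_kv, text_pad)
    pinyin_ctx = attend(q, pinyin_kv, pinyin_kv, pinyin_pad)
    return [t + p for t, p in zip(text_ctx, pinyin_ctx)]
```

The real module does this per head with projected K/V; the sketch only shows why the padded text position contributes nothing to the fused output.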
@@ -113,31 +113,24 @@ class InputMethodEngine(nn.Module):
         """
         batch_size = input_ids.size(0)

-        # Normalize history_slot_ids to [batch_size, num_slots];
-        # use view instead of an if branch to avoid torch.compile graph breaks
         history_slot_ids = history_slot_ids.view(-1, self.num_slots)

-        # 1. Context encoding -> H [batch, seq_len, dim]
-        # Note: ContextEncoder.forward takes text_ids, pinyin_ids, mask
-        H = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
+        H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)

-        # 2. Slot memory encoding -> S [batch, num_slots, dim]
-        S = self.slot_memory(history_slot_ids)  # history_slot_ids: [batch, num_slots]
+        S = self.slot_memory(history_slot_ids)

-        # 3. Cross-attention fusion (CrossAttentionFusion)
-        fused = self.cross_attn(S, H, context_mask=attention_mask)
+        context_mask = attention_mask == 0
+        pinyin_mask = pinyin_ids == 0
+        fused = self.cross_attn(
+            S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask
+        )

-        # 4. MoE -> [batch, num_slots, dim]
         moe_out = self.moe(fused)

-        # 5. Slot attention pooling
         batch_size = input_ids.size(0)
-        # Attention scores: [batch, num_slots, 1] -> [batch, num_slots]
         slot_scores = self.slot_attention(moe_out).squeeze(-1)
-        # Softmax over slots gives the attention weights
-        slot_weights = torch.softmax(slot_scores, dim=1)  # [batch, num_slots]
-        # Weighted sum gives the pooled representation
-        pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)  # [batch, dim]
+        slot_weights = torch.softmax(slot_scores, dim=1)
+        pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)

-        logits = self.classifier(pooled)  # [batch, vocab_size]
+        logits = self.classifier(pooled)
         return logits
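The slot-attention pooling kept by this hunk (score each slot, softmax over slots, weighted sum) can be sketched without the framework; shapes and names here are illustrative, with scalar scores standing in for `self.slot_attention`'s output:

```python
import math

def slot_pool(slot_feats, slot_scores):
    """Pool [num_slots, dim] features into one [dim] vector by attention."""
    # Softmax over the slot dimension.
    m = max(slot_scores)
    es = [math.exp(s - m) for s in slot_scores]
    total = sum(es)
    weights = [e / total for e in es]
    # Weighted sum over slots gives the pooled representation.
    dim = len(slot_feats[0])
    pooled = [sum(w * f[d] for w, f in zip(weights, slot_feats)) for d in range(dim)]
    return pooled, weights
```

Equal scores reduce to mean pooling; a dominant score makes the pooled vector track that single slot.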
@@ -1244,14 +1244,14 @@ def train(
         max_seq_length=max_seq_len,
         text_field="text",
         py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=100000,
+        shuffle_buffer_size=2000000,
         length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
     )

     eval_dataloader = create_dataloader(
         dataset=eval_dataset,
         batch_size=batch_size,
-        num_workers=1,  # fewer workers for evaluation
+        num_workers=2,  # fewer workers for evaluation
         pin_memory=torch.cuda.is_available(),
         max_iter_length=batch_size * 64,
     )
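`shuffle_buffer_size` controls an approximate shuffle for a streaming dataset: a bounded buffer is kept, and a random element is emitted as each new one arrives, so a larger buffer gets closer to a full shuffle. A minimal sketch of that standard buffer-shuffle idea (this assumes nothing about `PinyinInputDataset`'s actual implementation):

```python
import random

def buffer_shuffle(stream, buffer_size, rng=None):
    """Yield items of `stream` in approximately shuffled order."""
    rng = rng or random.Random(0)
    buf = []
    for item in stream:
        buf.append(item)
        if len(buf) > buffer_size:
            # Swap a random element to the end and emit it.
            idx = rng.randrange(len(buf))
            buf[idx], buf[-1] = buf[-1], buf[idx]
            yield buf.pop()
    # Drain the remaining buffer fully shuffled.
    rng.shuffle(buf)
    yield from buf
```

Every input item is emitted exactly once; only the order is randomized, within a window set by `buffer_size`.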
@@ -1356,615 +1356,5 @@ def export(
     console.print("[yellow]导出功能待实现[/yellow]")

-@app.command()
-def expand_and_train(
-    # Data parameters
-    train_data_path: str = typer.Option(..., "--train-data-path", "-t", help="训练数据集路径"),
-    eval_data_path: str = typer.Option(..., "--eval-data-path", "-e", help="评估数据集路径"),
-    output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
-    # Model parameters
-    base_model_path: str = typer.Option(..., "--base-model-path", help="预训练基础模型检查点路径"),
-    new_model_spec: str = typer.Option(
-        ...,
-        "--new-model-spec",
-        "-m",
-        help="新模型规格,格式:模块名:类名,如 'model:InputMethodEngine'。支持任意路径,自定义模型类必须是 InputMethodEngine 的子类",
-    ),
-    # Dataset size
-    max_iter_length: int = typer.Option(1024 * 1024 * 128, "--max_iter_length", help="数据集大小"),
-    # Training parameters
-    batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
-    num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
-    learning_rate: float = typer.Option(2e-4, "--learning-rate", "-lr", help="学习率"),
-    min_learning_rate: float = typer.Option(1e-9, "--min-learning-rate", help="最小学习率"),
-    weight_decay: float = typer.Option(0.05, "--weight-decay", help="权重衰减"),
-    warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
-    label_smoothing: float = typer.Option(0.1, "--label-smoothing", help="标签平滑参数"),
-    grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
-    clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
-    eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
-    save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"),
-    # Other parameters
-    mixed_precision: bool = typer.Option(True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"),
-    num_workers: int = typer.Option(2, "--num-workers", help="数据加载worker数量(流式数据集建议为2)"),
-    use_tensorboard: bool = typer.Option(True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"),
-    resume_from: Optional[str] = typer.Option(None, "--resume-from", help="从检查点恢复训练"),
-    reset_training_state: bool = typer.Option(False, "--reset-training-state", help="重置训练状态,只加载模型权重从头开始训练"),
-    auto_resume: bool = typer.Option(True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练"),
-    seed: int = typer.Option(42, "--seed", help="随机种子"),
-    compile: bool = typer.Option(False, "--compile/--no-compile", help="是否开启 torch.compile 优化(需 PyTorch 2.0+)"),
-):
-    torch.multiprocessing.set_sharing_strategy("file_system")
-
-    # Enable TensorFloat32 matmul (clears the UserWarning and speeds things up)
-    if torch.cuda.is_available():
-        torch.set_float32_matmul_precision("high")
-
-    # Set random seeds
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
-
-    # Hard-coded model hyperparameters
-    vocab_size = 10019
-    pinyin_vocab_size = 30  # per the dataset.CHAR_TO_ID mapping
-    dim = 512
-    num_slots = 8
-    n_layers = 4
-    n_heads = 4
-    num_experts = 10
-    max_seq_len = 128
-    use_pinyin = True  # pinyin is always used
-    console = Console()
-
-    # Print the configuration
-    console.print(Panel.fit("[bold cyan]模型扩容第一阶段训练配置[/bold cyan]", border_style="cyan"))
-
-    config_table = Table(show_header=True, header_style="bold magenta")
-    config_table.add_column("Category", style="cyan")
-    config_table.add_column("Parameter", style="green")
-    config_table.add_column("Value", style="yellow")
-
-    # Populate the configuration table
-    config_table.add_row("数据", "训练数据路径", train_data_path)
-    config_table.add_row("数据", "评估数据路径", eval_data_path)
-    config_table.add_row("数据", "输出目录", output_dir)
-    config_table.add_row("数据", "批次大小", str(batch_size))
-    config_table.add_row("数据", "Worker数量", str(num_workers))
-
-    config_table.add_row("模型", "基础模型路径", base_model_path)
-    config_table.add_row("模型", "新模型规格", new_model_spec)
-    config_table.add_row("模型", "词汇表大小", str(vocab_size))
-    config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size))
-    config_table.add_row("模型", "模型维度", str(dim))
-    config_table.add_row("模型", "槽位数量", str(num_slots))
-    config_table.add_row("模型", "Transformer层数", str(n_layers))
-    config_table.add_row("模型", "注意力头数", str(n_heads))
-    config_table.add_row("模型", "MoE专家数", str(num_experts))
-    config_table.add_row("模型", "使用拼音", str(use_pinyin))
-    config_table.add_row("模型", "编译优化", str(compile))
-
-    config_table.add_row("训练", "训练轮数", str(num_epochs))
-    config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
-    config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}")
-    config_table.add_row("训练", "权重衰减", str(weight_decay))
-    config_table.add_row("训练", "热身比例", str(warmup_ratio))
-    config_table.add_row("训练", "标签平滑", str(label_smoothing))
-    config_table.add_row("训练", "梯度累积", str(grad_accum_steps))
-    config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
-    config_table.add_row("训练", "混合精度", str(mixed_precision))
-
-    config_table.add_row("其他", "自动恢复", str(auto_resume))
-    console.print(config_table)
-
-    # Create the output directory
-    output_path = Path(output_dir)
-    output_path.mkdir(parents=True, exist_ok=True)
-
-    # Save the configuration
-    config = {
-        "train_data_path": train_data_path,
-        "eval_data_path": eval_data_path,
-        "output_dir": output_dir,
-        "base_model_path": base_model_path,
-        "new_model_spec": new_model_spec,
-        "vocab_size": vocab_size,
-        "pinyin_vocab_size": pinyin_vocab_size,
-        "dim": dim,
-        "num_slots": num_slots,
-        "n_layers": n_layers,
-        "n_heads": n_heads,
-        "num_experts": num_experts,
-        "max_seq_len": max_seq_len,
-        "use_pinyin": use_pinyin,
-        "batch_size": batch_size,
-        "num_workers": num_workers,
-        "num_epochs": num_epochs,
-        "learning_rate": learning_rate,
-        "min_learning_rate": min_learning_rate,
-        "weight_decay": weight_decay,
-        "warmup_ratio": warmup_ratio,
-        "label_smoothing": label_smoothing,
-        "grad_accum_steps": grad_accum_steps,
-        "clip_grad_norm": clip_grad_norm,
-        "eval_frequency": eval_frequency,
-        "save_frequency": save_frequency,
-        "mixed_precision": mixed_precision,
-        "use_tensorboard": use_tensorboard,
-        "seed": seed,
-        "auto_resume": auto_resume,
-        "max_iter_length": max_iter_length,
-        "compile": compile,
-    }
-
-    config_file = output_path / "expansion_training_config.json"
-    with open(config_file, "w", encoding="utf-8") as f:
-        json.dump(config, f, indent=2, ensure_ascii=False)
-
-    logger.info(f"Configuration saved to {config_file}")
-
-    # Create the data loaders
-    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
-
-    # Training dataset
-    train_dataset = PinyinInputDataset(
-        data_path=train_data_path,
-        max_workers=-1,  # pick the worker count automatically
-        max_iter_length=max_iter_length,
-        max_seq_length=max_seq_len,
-        text_field="text",
-        py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=100000,
-        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
-    )
-
-    # Training data loader
-    train_dataloader = create_dataloader(
-        dataset=train_dataset,
-        batch_size=batch_size,
-        num_workers=num_workers,
-        pin_memory=torch.cuda.is_available(),
-        max_iter_length=max_iter_length,
-    )
-
-    # Evaluation dataset
-    eval_dataset = PinyinInputDataset(
-        data_path=eval_data_path,
-        max_workers=-1,
-        max_iter_length=batch_size * 64,  # smaller evaluation set
-        max_seq_length=max_seq_len,
-        text_field="text",
-        py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=100000,
-        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
-    )
-
-    eval_dataloader = create_dataloader(
-        dataset=eval_dataset,
-        batch_size=batch_size,
-        num_workers=1,  # fewer workers for evaluation
-        pin_memory=torch.cuda.is_available(),
-        max_iter_length=batch_size * 64,
-    )
-
-    # Create the expanded model
-    console.print("[bold cyan]正在创建扩容模型...[/bold cyan]")
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    model_kwargs = {
-        "vocab_size": vocab_size,
-        "pinyin_vocab_size": pinyin_vocab_size,
-        "dim": dim,
-        "num_slots": num_slots,
-        "n_layers": n_layers,
-        "n_heads": n_heads,
-        "num_experts": num_experts,
-        "max_seq_len": max_seq_len,
-        "compile": compile,
-    }
-
-    model = load_expanded_model(
-        base_model_path=base_model_path,
-        new_model_spec=new_model_spec,
-        device=device,
-        **model_kwargs,
-    )
-
-    console.print(f"[green]✓ 扩容模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]")
-
-    # Report the frozen-parameter ratio
-    total_params = sum(p.numel() for p in model.parameters())
-    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
-    console.print(f"[green]✓ 冻结参数: {frozen_params:,}/{total_params:,} ({frozen_params / total_params * 100:.1f}%)[/green]")
-
-    # Create the trainer (plain Trainer; stage one only trains with frozen base weights)
-    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
-    trainer = Trainer(
-        model=model,
-        train_dataloader=train_dataloader,
-        eval_dataloader=eval_dataloader,
-        total_steps=int(max_iter_length * num_epochs / batch_size),
-        output_dir=output_dir,
-        num_epochs=num_epochs,
-        learning_rate=learning_rate,
-        min_learning_rate=min_learning_rate,
-        weight_decay=weight_decay,
-        warmup_ratio=warmup_ratio,
-        label_smoothing=label_smoothing,
-        grad_accum_steps=grad_accum_steps,
-        clip_grad_norm=clip_grad_norm,
-        eval_frequency=eval_frequency,
-        save_frequency=save_frequency,
-        mixed_precision=mixed_precision,
-        use_tensorboard=use_tensorboard,
-        status_file="training_status.json",
-    )
-
-    console.print("[green]✓ 训练器创建完成[/green]")
-
-    # Start training
-    console.print("\n[bold cyan]开始扩容模型第一阶段训练...[/bold cyan]")
-    console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
-    try:
-        trainer.train(
-            resume_from=resume_from,
-            reset_training_state=reset_training_state,
-            auto_resume=auto_resume,
-        )
-    except KeyboardInterrupt:
-        console.print("[bold green]训练被终止[/bold green]")
-        trainer.save_checkpoint("interrupted_model.pt")
-
-    # Save the expansion info for stage two
-    expansion_info = {
-        "stage1_checkpoint_path": str(output_path / "checkpoints" / "best_model.pt"),
-        "model_spec": new_model_spec,
-        "model_kwargs": model_kwargs,
-        "train_data_path": train_data_path,
-        "eval_data_path": eval_data_path,
-        "output_dir": output_dir,
-        "batch_size": batch_size,
-        "max_iter_length": max_iter_length,
-        "max_seq_len": max_seq_len,
-        "num_workers": num_workers,
-    }
-
-    expansion_info_file = output_path / "expansion_info.json"
-    with open(expansion_info_file, "w", encoding="utf-8") as f:
-        json.dump(expansion_info, f, indent=2, ensure_ascii=False)
-
-    logger.info(f"Expansion info saved to {expansion_info_file}")
-
-    console.print("[bold green]✓ 第一阶段训练完成![/bold green]")
-    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
-    console.print(f"模型和日志保存在: {output_dir}")
-    console.print(f"[bold cyan]扩容信息已保存到: {expansion_info_file}[/bold cyan]")
-    console.print("[yellow]请手动检查模型后,使用 expand-finetune 命令进行第二阶段全量微调[/yellow]")
-@app.command()
-def expand_finetune(
-    expand_config: str = typer.Option(..., "--expand-config", "-c", help="新模型类规格,格式:模块名:类名,如 'big_expert:BigExpert'"),
-    stage1_info: str = typer.Option(..., "--stage1-info", "-i", help="第一阶段保存的 expansion_info.json 路径"),
-    # Optional overrides
-    checkpoint: Optional[str] = typer.Option(None, "--checkpoint", help="第一阶段模型检查点路径(覆盖 JSON 文件中的路径)"),
-    output_dir: Optional[str] = typer.Option(None, "--output-dir", "-o", help="输出目录(覆盖 JSON 文件中的目录)"),
-    train_data_path: Optional[str] = typer.Option(None, "--train-data-path", "-t", help="训练数据路径(覆盖 JSON 文件)"),
-    eval_data_path: Optional[str] = typer.Option(None, "--eval-data-path", "-e", help="评估数据路径(覆盖 JSON 文件)"),
-    batch_size: Optional[int] = typer.Option(None, "--batch-size", "-b", help="批次大小(覆盖 JSON 文件)"),
-    num_epochs: Optional[int] = typer.Option(None, "--num-epochs", help="训练轮数(覆盖 JSON 文件)"),
-    learning_rate: Optional[float] = typer.Option(None, "--learning-rate", "-lr", help="学习率"),
-    min_learning_rate: Optional[float] = typer.Option(None, "--min-learning-rate", help="最小学习率"),
-    weight_decay: Optional[float] = typer.Option(None, "--weight-decay", help="权重衰减"),
-    warmup_ratio: Optional[float] = typer.Option(None, "--warmup-ratio", help="热身步数比例"),
-    label_smoothing: Optional[float] = typer.Option(None, "--label-smoothing", help="标签平滑参数"),
-    grad_accum_steps: Optional[int] = typer.Option(None, "--grad-accum-steps", help="梯度累积步数"),
-    clip_grad_norm: Optional[float] = typer.Option(None, "--clip-grad-norm", help="梯度裁剪范数"),
-    eval_frequency: Optional[int] = typer.Option(None, "--eval-frequency", help="评估频率"),
-    save_frequency: Optional[int] = typer.Option(None, "--save-frequency", help="保存频率"),
-    max_iter_length: Optional[int] = typer.Option(None, "--max-iter-length", help="数据集大小(覆盖 JSON 文件)"),
-    max_seq_len: Optional[int] = typer.Option(None, "--max-seq-len", help="最大序列长度(覆盖 JSON 文件)"),
-    num_workers: Optional[int] = typer.Option(None, "--num-workers", help="数据加载worker数量"),
-    mixed_precision: bool = typer.Option(True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"),
-    use_tensorboard: bool = typer.Option(True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"),
-    resume_from: Optional[str] = typer.Option(None, "--resume-from", help="从检查点恢复训练"),
-    reset_training_state: bool = typer.Option(False, "--reset-training-state", help="重置训练状态"),
-    auto_resume: bool = typer.Option(True, "--auto-resume/--no-auto-resume", help="是否自动恢复训练"),
-    seed: int = typer.Option(42, "--seed", help="随机种子"),
-    compile: Optional[bool] = typer.Option(None, "--compile/--no-compile", help="是否开启 torch.compile 优化"),
-):
-    """
-    Stage-two expansion training: read the stage-one expansion_info.json and
-    load the expanded model for full fine-tuning.
-    Command-line arguments take precedence over the JSON configuration.
-    """
-    torch.multiprocessing.set_sharing_strategy("file_system")
-
-    if torch.cuda.is_available():
-        torch.set_float32_matmul_precision("high")
-
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
-
-    console = Console()
-
-    # Load the stage-one info file
-    stage1_info_path = Path(stage1_info)
-    if not stage1_info_path.exists():
-        console.print(f"[bold red]错误: 找不到第一阶段信息文件 {stage1_info}[/bold red]")
-        raise typer.Exit(1)
-
-    with open(stage1_info_path, "r", encoding="utf-8") as f:
-        info = json.load(f)
-
-    # Command-line arguments take precedence over the JSON file
-    final_checkpoint = checkpoint or info["stage1_checkpoint_path"]
-    final_output_dir = output_dir or info["output_dir"]
-    final_train_data_path = train_data_path or info["train_data_path"]
-    final_eval_data_path = eval_data_path or info["eval_data_path"]
-    final_batch_size = batch_size if batch_size is not None else info["batch_size"]
-    final_num_epochs = num_epochs if num_epochs is not None else info.get("num_epochs", 10)
-    final_max_iter_length = max_iter_length if max_iter_length is not None else info["max_iter_length"]
-    final_max_seq_len = max_seq_len if max_seq_len is not None else info["max_seq_len"]
-    final_num_workers = num_workers if num_workers is not None else info.get("num_workers", 2)
-
-    # Training parameters have defaults; use them when not overridden
-    final_learning_rate = learning_rate if learning_rate is not None else 1e-4
-    final_min_learning_rate = min_learning_rate if min_learning_rate is not None else 1e-9
-    final_weight_decay = weight_decay if weight_decay is not None else 0.1
-    final_warmup_ratio = warmup_ratio if warmup_ratio is not None else 0.1
-    final_label_smoothing = label_smoothing if label_smoothing is not None else 0.15
-    final_grad_accum_steps = grad_accum_steps if grad_accum_steps is not None else 1
-    final_clip_grad_norm = clip_grad_norm if clip_grad_norm is not None else 1.0
-    final_eval_frequency = eval_frequency if eval_frequency is not None else 500
-    final_save_frequency = save_frequency if save_frequency is not None else 1000
-
-    # Model parameters come from the JSON file
-    model_kwargs = info["model_kwargs"]
-    if compile is not None:
-        model_kwargs["compile"] = compile
-
-    console.print(Panel.fit("[bold cyan]模型扩容第二阶段训练配置[/bold cyan]", border_style="cyan"))
-
-    config_table = Table(show_header=True, header_style="bold magenta")
-    config_table.add_column("Category", style="cyan")
-    config_table.add_column("Parameter", style="green")
-    config_table.add_column("Value", style="yellow")
-
-    config_table.add_row("数据", "第一阶段信息文件", str(stage1_info_path))
-    config_table.add_row("数据", "训练数据路径", final_train_data_path)
-    config_table.add_row("数据", "评估数据路径", final_eval_data_path)
-    config_table.add_row("数据", "输出目录", final_output_dir)
-    config_table.add_row("数据", "批次大小", str(final_batch_size))
-    config_table.add_row("数据", "Worker数量", str(final_num_workers))
-
-    config_table.add_row("模型", "新模型规格", expand_config)
-    config_table.add_row("模型", "检查点路径", final_checkpoint)
-    for k, v in model_kwargs.items():
-        config_table.add_row("模型", k, str(v))
-
-    config_table.add_row("训练", "训练轮数", str(final_num_epochs))
-    config_table.add_row("训练", "学习率", f"{final_learning_rate:.2e}")
-    config_table.add_row("训练", "最小学习率", f"{final_min_learning_rate:.2e}")
-    config_table.add_row("训练", "权重衰减", str(final_weight_decay))
-    config_table.add_row("训练", "热身比例", str(final_warmup_ratio))
-    config_table.add_row("训练", "标签平滑", str(final_label_smoothing))
-    config_table.add_row("训练", "梯度累积", str(final_grad_accum_steps))
-    config_table.add_row("训练", "梯度裁剪", str(final_clip_grad_norm))
-    config_table.add_row("训练", "混合精度", str(mixed_precision))
-
-    config_table.add_row("其他", "自动恢复", str(auto_resume))
-    console.print(config_table)
-
-    output_path = Path(final_output_dir)
-    output_path.mkdir(parents=True, exist_ok=True)
-
-    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
-
-    train_dataset = PinyinInputDataset(
-        data_path=final_train_data_path,
-        max_workers=-1,
-        max_iter_length=final_max_iter_length,
-        max_seq_length=final_max_seq_len,
-        text_field="text",
-        py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=100000,
-        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
-    )
-
-    train_dataloader = create_dataloader(
-        dataset=train_dataset,
-        batch_size=final_batch_size,
-        num_workers=final_num_workers,
-        pin_memory=torch.cuda.is_available(),
-        max_iter_length=final_max_iter_length,
-    )
-
-    eval_dataset = PinyinInputDataset(
-        data_path=final_eval_data_path,
-        max_workers=-1,
-        max_iter_length=final_batch_size * 64,
-        max_seq_length=final_max_seq_len,
-        text_field="text",
-        py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=100000,
-        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
-    )
-
-    eval_dataloader = create_dataloader(
-        dataset=eval_dataset,
-        batch_size=final_batch_size,
-        num_workers=1,
-        pin_memory=torch.cuda.is_available(),
-        max_iter_length=final_batch_size * 64,
-    )
-
-    console.print("[bold cyan]正在加载扩容模型(全量微调模式)...[/bold cyan]")
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    model = load_expanded_model(
-        base_model_path=final_checkpoint,
-        new_model_spec=expand_config,
-        device=device,
-        **model_kwargs,
-    )
-
-    # Full fine-tuning: unfreeze every parameter
-    for param in model.parameters():
-        param.requires_grad = True
-
-    console.print(f"[green]✓ 模型加载完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]")
-    console.print("[green]✓ 所有参数已解冻,进入全量微调模式[/green]")
-
-    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
-    trainer = Trainer(
-        model=model,
-        train_dataloader=train_dataloader,
-        eval_dataloader=eval_dataloader,
-        total_steps=int(final_max_iter_length * final_num_epochs / final_batch_size),
-        output_dir=final_output_dir,
-        num_epochs=final_num_epochs,
-        learning_rate=final_learning_rate,
-        min_learning_rate=final_min_learning_rate,
-        weight_decay=final_weight_decay,
-        warmup_ratio=final_warmup_ratio,
-        label_smoothing=final_label_smoothing,
-        grad_accum_steps=final_grad_accum_steps,
-        clip_grad_norm=final_clip_grad_norm,
-        eval_frequency=final_eval_frequency,
-        save_frequency=final_save_frequency,
-        mixed_precision=mixed_precision,
-        use_tensorboard=use_tensorboard,
-        status_file="training_status_finetune.json",
-    )
-
-    console.print("[green]✓ 训练器创建完成[/green]")
-
-    console.print("\n[bold cyan]开始第二阶段全量微调...[/bold cyan]")
-    console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
-    try:
-        trainer.train(
-            resume_from=resume_from,
-            reset_training_state=reset_training_state,
-            auto_resume=auto_resume,
-        )
-    except KeyboardInterrupt:
-        console.print("[bold green]训练被终止[/bold green]")
-        trainer.save_checkpoint("interrupted_model.pt")
-
-    console.print("[bold green]✓ 第二阶段全量微调完成![/bold green]")
-    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
-    console.print(f"模型和日志保存在: {final_output_dir}")
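The removed command resolved each setting in a fixed order: explicit CLI flag first, then the stage-one JSON, then a built-in default, using `is not None` rather than `or` so that legitimate falsy overrides (0, False) survive. A hypothetical sketch of that precedence rule, not the command's actual code:

```python
def merge(cli: dict, info: dict, defaults: dict) -> dict:
    """Resolve settings with precedence CLI > stage-one JSON > defaults."""
    merged = {}
    for key, default in defaults.items():
        if cli.get(key) is not None:
            merged[key] = cli[key]      # explicit CLI override wins
        elif key in info:
            merged[key] = info[key]     # then the stage-one JSON
        else:
            merged[key] = default       # finally the built-in default
    return merged
```

Note the original used plain `or` only for path-like strings (where empty strings are never valid), and `is not None` everywhere a numeric 0 could be meaningful.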
 if __name__ == "__main__":
     app()
test.py
@@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
||||||
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
||||||
|
|
||||||
|
|
||||||
part1 = "招财猫背部或底部的太阳能板会持续将环境光(无论是阳光还是室内灯光)转化为"
|
part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常"
|
||||||
part2 = "weiruo"
|
part2 = "piaoliang"
|
||||||
pinyin_ids = text_to_pinyin_ids(part2)
|
pinyin_ids = text_to_pinyin_ids(part2)
|
||||||
len_py = len(pinyin_ids)
|
len_py = len(pinyin_ids)
|
||||||
if len_py < 24:
|
if len_py < 24:
|
||||||
|
|
@ -56,7 +56,7 @@ if len_py < 24:
|
||||||
else:
|
else:
|
||||||
pinyin_ids = pinyin_ids[:24]
|
pinyin_ids = pinyin_ids[:24]
|
||||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
||||||
masked_labels = [649, 925, 0, 0, 0, 0, 0, 0]
|
masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0]
|
||||||
part3 = ""
|
part3 = ""
|
||||||
part4 = "可行|特别|伤害"
|
part4 = "可行|特别|伤害"
|
||||||
|
|
||||||
|
|
@ -83,7 +83,7 @@ sample = {
|
||||||
|
|
||||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||||
|
|
||||||
checkpoint = torch.load("/home/songsenand/下载/20260411acc37final-model.pt", map_location="cpu")
|
checkpoint = torch.load("/home/songsenand/下载/best_model.ptrom", map_location="cpu")
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
|
||||||
input_ids = sample["input_ids"]
|
input_ids = sample["input_ids"]
|
||||||
|
|
|
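The test script above pads or truncates the pinyin id sequence to a fixed length of 24 before tensorizing it. A minimal sketch of that logic (the `CHAR_TO_ID` table here is an assumed a–z mapping with 0 as padding/unknown, not the project's real vocabulary):

```python
# Assumed mapping: 'a'..'z' -> 1..26, with 0 reserved for unknown/padding.
CHAR_TO_ID = {chr(ord("a") + i): i + 1 for i in range(26)}
MAX_PINYIN_LEN = 24


def text_to_pinyin_ids(pinyin_str: str) -> list:
    # Unknown characters fall back to id 0, same as padding.
    return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]


def pad_or_truncate(ids: list, max_len: int = MAX_PINYIN_LEN) -> list:
    if len(ids) < max_len:
        return ids + [0] * (max_len - len(ids))  # right-pad with 0
    return ids[:max_len]  # truncate overly long sequences


ids = pad_or_truncate(text_to_pinyin_ids("piaoliang"))
print(len(ids))   # 24
print(ids[:9])    # [16, 9, 1, 15, 12, 9, 1, 14, 7] — ids of p,i,a,o,l,i,a,n,g
```

After this step a `torch.tensor(ids).unsqueeze(0)` yields the fixed-shape `[1, 24]` input the model expects, as in the diff above.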
@@ -1,31 +0,0 @@
-import sys
-
-import torch
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-
-from model.dataset import PinyinInputDataset
-from model.trainer import collate_fn, worker_init_fn
-
-max_iter_length = 128 * 128
-batch_size = 1024
-
-if sys.platform == "win32":
-    dataset_path = "data"
-else:
-    dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"
-
-dataset = PinyinInputDataset(dataset_path, max_iter_length=max_iter_length)
-dataloader = DataLoader(
-    dataset,
-    batch_size=batch_size,
-    num_workers=2,
-    pin_memory=torch.cuda.is_available(),
-    worker_init_fn=worker_init_fn,
-    collate_fn=collate_fn,
-    prefetch_factor=64,  # each worker prefetches 64 batches; suits large-memory hosts
-    persistent_workers=True,  # keep worker processes alive to avoid respawn overhead
-)
-dataloader = list([i for i in dataloader])
-for i, line in tqdm(enumerate(dataloader), total=max_iter_length / batch_size):
-    print((line["labels"] == 0).sum())
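The deleted script above materializes every batch and counts zero (padding) labels per batch, a quick sanity check on label sparsity. A compact sketch of the same check with a dummy dataset (`DummyLabels` is a hypothetical stand-in; the real script used `PinyinInputDataset` with a custom `collate_fn`):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DummyLabels(Dataset):
    """Hypothetical stand-in: each item carries an 8-slot label vector."""

    def __len__(self):
        return 16

    def __getitem__(self, i):
        labels = torch.zeros(8, dtype=torch.long)
        labels[: i % 8] = 1  # first (i % 8) slots are real labels, rest 0-padding
        return {"labels": labels}


# num_workers=0 keeps the sketch single-process; note that prefetch_factor and
# persistent_workers (used in the deleted script) require num_workers > 0.
loader = DataLoader(DummyLabels(), batch_size=16)
for batch in loader:
    zero_count = (batch["labels"] == 0).sum().item()
    print(zero_count)  # 72: 16*8 = 128 label slots minus 56 real labels
```

The default collate function stacks the `{"labels": ...}` dicts into a `[batch, 8]` tensor, so the elementwise `== 0` comparison counts padded slots across the whole batch.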
@@ -0,0 +1,232 @@
+#!/usr/bin/env python3
+"""
+Parameter migration script: migrate old model weights to the new architecture.
+
+Usage:
+    python transfer_weights.py --old-checkpoint /path/to/best_model.pt --output ./migrated_model.pt
+
+What it does:
+1. Load the old model checkpoint
+2. Build the new model architecture
+3. Directly transfer parameters whose names and shapes match
+4. Split-transfer k_proj/v_proj into the text/pinyin dual branches
+5. Save the migrated checkpoint
+6. Print a detailed migration report
+"""
+
+import argparse
+import sys
+from pathlib import Path
+
+import torch
+
+sys.path.append("src")
+from src.model.model import InputMethodEngine
+
+
+def transfer_weights(
+    old_checkpoint_path: str, new_model: InputMethodEngine, device: torch.device
+):
+    """
+    Migrate old model weights into the new architecture.
+
+    Args:
+        old_checkpoint_path: path to the old checkpoint
+        new_model: instance of the new model
+        device: target device
+
+    Returns:
+        (directly_transferred, split_transferred, new_params, skipped) migration report
+    """
+    old_checkpoint = torch.load(old_checkpoint_path, map_location=device)
+
+    if "model_state_dict" in old_checkpoint:
+        old_state_dict = old_checkpoint["model_state_dict"]
+    else:
+        old_state_dict = old_checkpoint
+
+    new_state_dict = new_model.state_dict()
+
+    directly_transferred = []
+    split_transferred = []
+    new_params = []
+    skipped = []
+
+    for key in new_state_dict.keys():
+        if key in old_state_dict:
+            if new_state_dict[key].shape == old_state_dict[key].shape:
+                new_state_dict[key] = old_state_dict[key]
+                directly_transferred.append(key)
+            else:
+                skipped.append(
+                    (
+                        key,
+                        "shape mismatch",
+                        f"old={old_state_dict[key].shape}, new={new_state_dict[key].shape}",
+                    )
+                )
+        elif key in [
+            "cross_attn.k_text_proj.weight",
+            "cross_attn.k_pinyin_proj.weight",
+        ]:
+            if "cross_attn.k_proj.weight" in old_state_dict:
+                new_state_dict[key] = old_state_dict["cross_attn.k_proj.weight"].clone()
+                split_transferred.append((key, "cross_attn.k_proj.weight"))
+        elif key in [
+            "cross_attn.v_text_proj.weight",
+            "cross_attn.v_pinyin_proj.weight",
+        ]:
+            if "cross_attn.v_proj.weight" in old_state_dict:
+                new_state_dict[key] = old_state_dict["cross_attn.v_proj.weight"].clone()
+                split_transferred.append((key, "cross_attn.v_proj.weight"))
+        else:
+            new_params.append(key)
+
+    new_model.load_state_dict(new_state_dict)
+
+    return directly_transferred, split_transferred, new_params, skipped
+
+
+def print_report(
+    directly_transferred, split_transferred, new_params, skipped, output_path
+):
+    """Print the migration report."""
+    total_params = len(directly_transferred) + len(split_transferred) + len(new_params)
+    coverage = (
+        (len(directly_transferred) + len(split_transferred)) / total_params * 100
+        if total_params > 0
+        else 100
+    )
+
+    print("\n" + "=" * 60)
+    print("📊 Parameter migration report")
+    print("=" * 60)
+
+    print(f"\n✅ Directly transferred ({len(directly_transferred)} layers):")
+    categories = {}
+    for key in directly_transferred:
+        category = key.split(".")[0]
+        if category not in categories:
+            categories[category] = []
+        categories[category].append(key)
+    for cat, keys in sorted(categories.items()):
+        print(f"  - {cat}.* ({len(keys)} parameters)")
+
+    if split_transferred:
+        print(f"\n✅ Split-transferred ({len(split_transferred)} layers):")
+        for new_key, old_key in split_transferred:
+            print(f"  - {new_key} ← {old_key}")
+
+    if new_params:
+        print(f"\n⚠️ New parameters ({len(new_params)} layers):")
+        for key in new_params[:10]:
+            print(f"  - {key}")
+        if len(new_params) > 10:
+            print(f"  ... ({len(new_params)} in total)")
+
+    if skipped:
+        print(f"\n❌ Skipped ({len(skipped)} layers):")
+        for key, reason, detail in skipped[:5]:
+            print(f"  - {key}: {reason} ({detail})")
+        if len(skipped) > 5:
+            print(f"  ... ({len(skipped)} in total)")
+
+    print("\n" + "-" * 60)
+    print(
+        f"Migration coverage: {coverage:.1f}% ({len(directly_transferred) + len(split_transferred)}/{total_params})"
+    )
+    print("=" * 60)
+    print(f"\n💾 Saved to: {output_path}")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Migrate old model weights to the new architecture"
+    )
+    parser.add_argument(
+        "--old-checkpoint",
+        "-c",
+        type=str,
+        required=True,
+        help="path to the old model checkpoint",
+    )
+    parser.add_argument(
+        "--output",
+        "-o",
+        type=str,
+        default="./migrated_model.pt",
+        help="output path for the migrated checkpoint",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="device (cuda/cpu)",
+    )
+
+    args = parser.parse_args()
+    device = torch.device(args.device)
+
+    print(f"📦 Loading old model: {args.old_checkpoint}")
+    print(f"🔧 Using device: {device}")
+
+    old_path = Path(args.old_checkpoint)
+    if not old_path.exists():
+        print(f"❌ File not found: {args.old_checkpoint}")
+        sys.exit(1)
+
+    print("\n🏗️ Building the new model architecture...")
+    new_model = InputMethodEngine(
+        vocab_size=10019,
+        pinyin_vocab_size=30,
+        dim=512,
+        num_slots=8,
+        n_layers=4,
+        n_heads=4,
+        num_experts=10,
+        max_seq_len=128,
+        compile=False,
+    )
+    new_model.to(device)
+
+    print(f"  New model parameter count: {sum(p.numel() for p in new_model.parameters()):,}")
+
+    print("\n🔄 Transferring parameters...")
+    directly_transferred, split_transferred, new_params, skipped = transfer_weights(
+        args.old_checkpoint, new_model, device
+    )
+
+    print("\n💾 Saving migrated checkpoint...")
+    output_path = Path(args.output)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    checkpoint = {
+        "step": 0,
+        "epoch": 0,
+        "model_state_dict": new_model.state_dict(),
+        "best_eval_loss": float("inf"),
+        "config": {
+            "vocab_size": 10019,
+            "pinyin_vocab_size": 30,
+            "dim": 512,
+            "num_slots": 8,
+            "n_layers": 4,
+            "n_heads": 4,
+            "num_experts": 10,
+            "max_seq_len": 128,
+        },
+        "migration_source": args.old_checkpoint,
+    }
+
+    torch.save(checkpoint, output_path)
+
+    print_report(
+        directly_transferred, split_transferred, new_params, skipped, output_path
+    )
+
+    print("\n✅ Migration complete! Next step:")
+    print(
+        f"  python -m model.trainer train --resume-from {args.output} --reset-training-state"
+    )


+if __name__ == "__main__":
+    main()
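Step 4 of the migration script seeds both new text/pinyin projection branches from the single old cross-attention projection, so both branches start from the same learned weights and only diverge once fine-tuning resumes. A minimal state-dict-level sketch of that split (key names and the 512 dimension mirror the script; the tensors are random stand-ins, not real checkpoint weights):

```python
import torch

# Old checkpoint had one shared K projection; the new architecture has two.
old_sd = {"cross_attn.k_proj.weight": torch.randn(512, 512)}
new_sd = {
    "cross_attn.k_text_proj.weight": torch.empty(512, 512),
    "cross_attn.k_pinyin_proj.weight": torch.empty(512, 512),
}

# Split migration: clone the shared weight into each new branch so neither
# branch aliases the other's storage.
for key in new_sd:
    new_sd[key] = old_sd["cross_attn.k_proj.weight"].clone()

# Both branches start identical, matching the old shared projection.
print(torch.equal(new_sd["cross_attn.k_text_proj.weight"],
                  new_sd["cross_attn.k_pinyin_proj.weight"]))  # True
```

The `.clone()` matters: without it both new keys would point at the same tensor, and an in-place update to one branch during training would silently modify the other.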