docs: 移除模型扩容两阶段训练文档并更新相关用法说明

This commit is contained in:
songsenand 2026-04-13 14:09:14 +08:00
parent 33f56f709b
commit 3175ace9c5
5 changed files with 139 additions and 238 deletions

105
README.md
View File

@ -790,112 +790,7 @@ train-model evaluate \
- 在评估数据集上计算准确率、困惑度等指标
- 生成详细的性能报告
### 6.8 模型扩容两阶段训练
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
#### 训练策略
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
#### 基础用法
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "model:InputMethodEngine" \
--num-experts 40 \
--frozen-lr 2e-3 \
--full-lr 5e-5 \
--frozen-patience 8
```
#### 完整参数示例
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--output-dir "./expansion_output" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "custom_model:ExpandedModel" \
--vocab-size 10019 \
--dim 512 \
--num-experts 40 \
--frozen-patience 10 \
--frozen-lr 1e-3 \
--full-lr 1e-4 \
--frozen-scheduler cosine \
--full-scheduler cosine \
--batch-size 128 \
--num-epochs 20 \
--compile
```
#### 参数详解
**模型扩容参数**
- `--base-model-path`: 预训练基础模型检查点路径(必需)
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
- 自定义模型类必须是 `InputMethodEngine` 的子类
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel`
**两阶段训练参数**
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数触发切换到全量微调默认10
- `--frozen-lr`: 冻结阶段学习率默认1e-3
- `--full-lr`: 全量微调阶段学习率默认1e-4
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
**其他参数**
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
- 继承现有的训练基础设施混合精度训练、TensorBoard日志、checkpoint保存等
#### 使用场景
1. **增加专家数量**20→40
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
- 新增参数新专家网络、gate层
2. **增加top_k值**2→3
- 冻结效果100% 参数可冻结(仅逻辑变化)
- 新增参数:无
3. **修改专家内部结构**如增加resblocks
- 冻结效果:~50% 参数可冻结linear_in/output可冻结
- 新增参数新增的resblocks层
4. **增加Transformer层数**4→5
- 冻结效果:~80% 参数可冻结前4层可冻结
- 新增参数新增的第5层
#### 自定义模型类示例
```python
# my_model.py
from model.model import InputMethodEngine
class MyExpandedModel(InputMethodEngine):
def __init__(self, num_experts=40, **kwargs):
# 调用父类构造函数覆盖num_experts参数
super().__init__(num_experts=num_experts, **kwargs)
# 可以在这里添加额外的层或修改现有层
# 使用命令
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
```
#### 注意事项
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
3. **性能保持**MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
4. **阶段切换**基于评估频率而非epoch建议适当调高 `--eval-frequency`
5. **模块导入**支持任意路径的模块通过Python标准导入机制加载
### 6.9 导出模型(开发中)

View File

@ -484,113 +484,6 @@ train-model evaluate \
- 在评估数据集上计算准确率、困惑度等指标
- 生成详细的性能报告
### 模型扩容两阶段训练
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
#### 训练策略
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
#### 基础用法
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "model:InputMethodEngine" \
--num-experts 40 \
--frozen-lr 2e-3 \
--full-lr 5e-5 \
--frozen-patience 8
```
#### 完整参数示例
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--output-dir "./expansion_output" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "custom_model:ExpandedModel" \
--vocab-size 10019 \
--dim 512 \
--num-experts 40 \
--frozen-patience 10 \
--frozen-lr 1e-3 \
--full-lr 1e-4 \
--frozen-scheduler cosine \
--full-scheduler cosine \
--batch-size 128 \
--num-epochs 20 \
--compile
```
#### 参数详解
**模型扩容参数**
- `--base-model-path`: 预训练基础模型检查点路径(必需)
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
- 自定义模型类必须是 `InputMethodEngine` 的子类
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel`
**两阶段训练参数**
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数触发切换到全量微调默认10
- `--frozen-lr`: 冻结阶段学习率默认1e-3
- `--full-lr`: 全量微调阶段学习率默认1e-4
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
**其他参数**
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
- 继承现有的训练基础设施混合精度训练、TensorBoard日志、checkpoint保存等
#### 使用场景
1. **增加专家数量**20→40
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
- 新增参数新专家网络、gate层
2. **增加top_k值**2→3
- 冻结效果100% 参数可冻结(仅逻辑变化)
- 新增参数:无
3. **修改专家内部结构**如增加resblocks
- 冻结效果:~50% 参数可冻结linear_in/output可冻结
- 新增参数新增的resblocks层
4. **增加Transformer层数**4→5
- 冻结效果:~80% 参数可冻结前4层可冻结
- 新增参数新增的第5层
#### 自定义模型类示例
```python
# my_model.py
from model.model import InputMethodEngine
class MyExpandedModel(InputMethodEngine):
def __init__(self, num_experts=40, **kwargs):
# 调用父类构造函数覆盖num_experts参数
super().__init__(num_experts=num_experts, **kwargs)
# 可以在这里添加额外的层或修改现有层
# 使用命令
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
```
#### 注意事项
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
3. **性能保持**MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
4. **阶段切换**基于评估频率而非epoch建议适当调高 `--eval-frequency`
5. **模块导入**支持任意路径的模块通过Python标准导入机制加载
### 导出模型(开发中)
当前导出功能尚在开发中:

View File

@ -22,6 +22,8 @@
"""
import argparse
import os
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple
@ -60,6 +62,98 @@ class InputMethodInference:
print(f"✅ 推理器初始化完成 (设备: {self.device})")
# 尝试启用readline以获得更好的行编辑功能
try:
import readline
# 设置readline使用UTF-8编码
readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
print("📝 readline已启用支持更好的行编辑功能")
except ImportError:
print("📝 readline不可用使用标准输入")
def _safe_input(self, prompt: str, default: str = "") -> str:
"""
安全的输入函数尝试正确处理UTF-8字符和退格键
Args:
prompt: 提示文本
default: 默认值
Returns:
用户输入的字符串
"""
try:
# 显示提示和默认值
if default:
full_prompt = f"{prompt} [{default}]: "
else:
full_prompt = f"{prompt}: "
# 使用标准input
result = input(full_prompt)
# 如果用户直接回车且存在默认值,则返回默认值
if not result and default:
return default
return result.strip()
except (EOFError, KeyboardInterrupt):
# 用户按Ctrl+D或Ctrl+C
print()
return ""
except Exception as e:
# 其他错误
print(f"\n⚠️ 输入错误: {e}")
return ""
def _clean_pinyin_input(self, pinyin: str) -> str:
"""
清理拼音输入字符串处理退格键等特殊字符
拼音只允许: a-z, `, ', -
中文字符和其他字符会被忽略
Args:
pinyin: 原始拼音输入字符串
Returns:
清理后的拼音字符串
"""
if not pinyin:
return ""
result = []
for c in pinyin:
# 检查是否为合法拼音字符 (a-z, `, ', -)
# 注意: 中文字符的isalpha()也返回True所以需要额外检查
is_valid_pinyin_char = (
("a" <= c <= "z")
or ("A" <= c <= "Z") # 允许大写字母,转换为小写
or c in ["`", "'", "-"]
)
if is_valid_pinyin_char:
# 合法拼音字符,转换为小写
result.append(c.lower())
elif c == " ":
# 空格忽略
continue
elif c == "\b" or c == "\x7f" or c == "\x08":
# 退格键、删除键:删除前一个字符
if result:
result.pop()
elif c == "\x1b":
# ESC键清空所有输入
result.clear()
else:
# 其他字符(包括中文字符)忽略
# 注意这里不添加到result所以退格键无法删除它们
# 但用户可能在拼音输入中误输入中文字符,应该忽略
pass
return "".join(result)
def load_model(self):
"""加载训练好的模型"""
# 创建模型实例(不编译)
@ -184,7 +278,7 @@ class InputMethodInference:
"""
# 1. 构建tokenizer输入
# 根据dataset.py格式为: "part4|part1" 和 part3
# 根据test.py和dataset.py格式为: "part4|part1" 和 part3
# part4: 上下文提示(专有词汇、姓名等,模型不掌握)
# part1: text_before
# part3: text_after
@ -192,13 +286,14 @@ class InputMethodInference:
# 处理上下文提示
context_text = "|".join(context_prompts) if context_prompts else ""
# 构建输入文本
# 构建输入文本 - 与test.py保持一致
# test.py: f"{part4}|{part1}" 作为第一个参数part3作为第二个参数
if context_text:
input_text = f"{context_text}|{text_before}"
else:
input_text = text_before
# 2. Tokenize
# 2. Tokenize - 与test.py保持一致
encoded = self.tokenizer(
input_text,
text_after,
@ -209,8 +304,11 @@ class InputMethodInference:
return_token_type_ids=True,
)
# 3. 处理拼音输入
pinyin_ids = text_to_pinyin_ids(pinyin)
# 3. 处理拼音输入 - 与test.py保持一致
# 首先清理拼音字符串,处理退格键等特殊字符
cleaned_pinyin = self._clean_pinyin_input(pinyin)
pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
if len(pinyin_ids) < 24:
pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
else:
@ -396,7 +494,16 @@ class InputMethodInference:
print("\n" + "=" * 60)
print("输入法模型推理 - 交互模式")
print("=" * 60)
print("说明:")
# 检查终端编码
encoding = sys.stdout.encoding or "unknown"
print(f"终端编码: {encoding}")
if encoding.lower() not in ["utf-8", "utf8"]:
print("⚠️ 警告: 终端编码不是UTF-8中文输入可能有问题")
print(" 建议设置: export LANG=en_US.UTF-8")
print(" 或设置: export LC_ALL=en_US.UTF-8")
print("\n说明:")
print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
print(" - 光标前文本: 光标前的连续文本")
print(" - 光标后文本: 光标后的连续文本")
@ -411,7 +518,7 @@ class InputMethodInference:
print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
print("格式: 用逗号分隔多个词汇,可为空")
print("示例: 张三,李四,北京大学")
context_input = input("请输入上下文提示(直接回车跳过): ").strip()
context_input = self._safe_input("请输入上下文提示(直接回车跳过)")
if context_input.lower() in ["quit", "exit", "q"]:
print("退出交互模式")
@ -434,7 +541,7 @@ class InputMethodInference:
print("第2步: 光标前文本")
print("说明: 光标前的连续文本内容")
print("示例: 今天天气很好")
text_before = input("请输入光标前文本: ").strip()
text_before = self._safe_input("请输入光标前文本")
if text_before.lower() in ["quit", "exit", "q"]:
print("退出交互模式")
@ -446,7 +553,7 @@ class InputMethodInference:
print("第3步: 光标后文本")
print("说明: 光标后的连续文本内容")
print("示例: 我们去公园玩")
text_after = input("请输入光标后文本: ").strip()
text_after = self._safe_input("请输入光标后文本")
if text_after.lower() in ["quit", "exit", "q"]:
print("退出交互模式")
@ -458,7 +565,7 @@ class InputMethodInference:
print("第4步: 拼音输入")
print("说明: 当前正在输入的拼音")
print("示例: tian, shang, hao")
pinyin = input("请输入拼音: ").strip()
pinyin = self._safe_input("请输入拼音")
if pinyin.lower() in ["quit", "exit", "q"]:
print("退出交互模式")
@ -471,7 +578,7 @@ class InputMethodInference:
print("说明: 用户已确认的输入历史,用逗号分隔")
print("示例: 上 (表示输入'shanghai'已确认''")
print(" 今天,天气 (表示已确认两个词)")
slot_input = input("请输入槽位历史(直接回车表示无): ").strip()
slot_input = self._safe_input("请输入槽位历史(直接回车表示无)")
if slot_input.lower() in ["quit", "exit", "q"]:
print("退出交互模式")
@ -524,7 +631,9 @@ class InputMethodInference:
# 询问是否继续
print("\n" + "-" * 40)
continue_input = input("是否继续推理?(y/n): ").strip().lower()
continue_input = (
self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
)
if continue_input not in ["y", "yes", ""]:
print("退出交互模式")
break
@ -539,7 +648,9 @@ class InputMethodInference:
traceback.print_exc()
# 询问是否继续
continue_input = input("\n是否继续?(y/n): ").strip().lower()
continue_input = (
self._safe_input("\n是否继续?(y/n)", "y").strip().lower()
)
if continue_input not in ["y", "yes", ""]:
print("退出交互模式")
break

View File

@ -72,7 +72,7 @@ class InputMethodEngine(nn.Module):
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
# 4. 混合专家层 (MoE)
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8)
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=12)
# 5. 槽位注意力池化
self.slot_attention = nn.Linear(dim, 1)

12
test.py
View File

@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常"
part2 = "piaoliang"
part1 = "从中国回家后,他觉得世界上最好的城市就是"
part2 = "shanghai"
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
@ -56,9 +56,9 @@ if len_py < 24:
else:
pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0]
masked_labels = [22, 0, 0, 0, 0, 0, 0, 0]
part3 = ""
part4 = "可行|特别|伤害"
part4 = ""
encoded = tokenizer(
f"{part4}|{part1}",
@ -83,8 +83,9 @@ sample = {
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/best_model.ptrom", map_location="cpu")
checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"]
@ -97,6 +98,7 @@ for k, v in sample.items():
print(f"{k}: {v}")
start = time.time()
with torch.no_grad():
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
sort_res = sorted(