docs: 移除模型扩容两阶段训练文档并更新相关用法说明

This commit is contained in:
songsenand 2026-04-13 14:09:14 +08:00
parent 33f56f709b
commit 3175ace9c5
5 changed files with 139 additions and 238 deletions

105
README.md
View File

@ -790,112 +790,7 @@ train-model evaluate \
- 在评估数据集上计算准确率、困惑度等指标 - 在评估数据集上计算准确率、困惑度等指标
- 生成详细的性能报告 - 生成详细的性能报告
### 6.8 模型扩容两阶段训练
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
#### 训练策略
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
#### 基础用法
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "model:InputMethodEngine" \
--num-experts 40 \
--frozen-lr 2e-3 \
--full-lr 5e-5 \
--frozen-patience 8
```
#### 完整参数示例
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--output-dir "./expansion_output" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "custom_model:ExpandedModel" \
--vocab-size 10019 \
--dim 512 \
--num-experts 40 \
--frozen-patience 10 \
--frozen-lr 1e-3 \
--full-lr 1e-4 \
--frozen-scheduler cosine \
--full-scheduler cosine \
--batch-size 128 \
--num-epochs 20 \
--compile
```
#### 参数详解
**模型扩容参数**
- `--base-model-path`: 预训练基础模型检查点路径(必需)
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
- 自定义模型类必须是 `InputMethodEngine` 的子类
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel`
**两阶段训练参数**
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数触发切换到全量微调默认10
- `--frozen-lr`: 冻结阶段学习率默认1e-3
- `--full-lr`: 全量微调阶段学习率默认1e-4
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
**其他参数**
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
- 继承现有的训练基础设施混合精度训练、TensorBoard日志、checkpoint保存等
#### 使用场景
1. **增加专家数量**20→40
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
- 新增参数新专家网络、gate层
2. **增加top_k值**2→3
- 冻结效果100% 参数可冻结(仅逻辑变化)
- 新增参数:无
3. **修改专家内部结构**如增加resblocks
- 冻结效果:~50% 参数可冻结linear_in/output可冻结
- 新增参数新增的resblocks层
4. **增加Transformer层数**4→5
- 冻结效果:~80% 参数可冻结前4层可冻结
- 新增参数新增的第5层
#### 自定义模型类示例
```python
# my_model.py
from model.model import InputMethodEngine
class MyExpandedModel(InputMethodEngine):
def __init__(self, num_experts=40, **kwargs):
# 调用父类构造函数覆盖num_experts参数
super().__init__(num_experts=num_experts, **kwargs)
# 可以在这里添加额外的层或修改现有层
# 使用命令
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
```
#### 注意事项
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
3. **性能保持**MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
4. **阶段切换**基于评估频率而非epoch建议适当调高 `--eval-frequency`
5. **模块导入**支持任意路径的模块通过Python标准导入机制加载
### 6.9 导出模型(开发中) ### 6.9 导出模型(开发中)

View File

@ -479,118 +479,11 @@ train-model evaluate \
--batch-size 32 --batch-size 32
``` ```
命令将显示"评估功能待实现"的提示信息。该功能计划用于: 命令将显示"评估功能待实现"的提示信息。该功能计划用于:
- 加载训练好的模型检查点 - 加载训练好的模型检查点
- 在评估数据集上计算准确率、困惑度等指标 - 在评估数据集上计算准确率、困惑度等指标
- 生成详细的性能报告 - 生成详细的性能报告
### 模型扩容两阶段训练
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
#### 训练策略
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
#### 基础用法
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "model:InputMethodEngine" \
--num-experts 40 \
--frozen-lr 2e-3 \
--full-lr 5e-5 \
--frozen-patience 8
```
#### 完整参数示例
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--output-dir "./expansion_output" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "custom_model:ExpandedModel" \
--vocab-size 10019 \
--dim 512 \
--num-experts 40 \
--frozen-patience 10 \
--frozen-lr 1e-3 \
--full-lr 1e-4 \
--frozen-scheduler cosine \
--full-scheduler cosine \
--batch-size 128 \
--num-epochs 20 \
--compile
```
#### 参数详解
**模型扩容参数**
- `--base-model-path`: 预训练基础模型检查点路径(必需)
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
- 自定义模型类必须是 `InputMethodEngine` 的子类
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel`
**两阶段训练参数**
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数触发切换到全量微调默认10
- `--frozen-lr`: 冻结阶段学习率默认1e-3
- `--full-lr`: 全量微调阶段学习率默认1e-4
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
**其他参数**
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
- 继承现有的训练基础设施混合精度训练、TensorBoard日志、checkpoint保存等
#### 使用场景
1. **增加专家数量**20→40
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
- 新增参数新专家网络、gate层
2. **增加top_k值**2→3
- 冻结效果100% 参数可冻结(仅逻辑变化)
- 新增参数:无
3. **修改专家内部结构**如增加resblocks
- 冻结效果:~50% 参数可冻结linear_in/output可冻结
- 新增参数新增的resblocks层
4. **增加Transformer层数**4→5
- 冻结效果:~80% 参数可冻结前4层可冻结
- 新增参数新增的第5层
#### 自定义模型类示例
```python
# my_model.py
from model.model import InputMethodEngine
class MyExpandedModel(InputMethodEngine):
def __init__(self, num_experts=40, **kwargs):
# 调用父类构造函数覆盖num_experts参数
super().__init__(num_experts=num_experts, **kwargs)
# 可以在这里添加额外的层或修改现有层
# 使用命令
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
```
#### 注意事项
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
3. **性能保持**MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
4. **阶段切换**基于评估频率而非epoch建议适当调高 `--eval-frequency`
5. **模块导入**支持任意路径的模块通过Python标准导入机制加载
### 导出模型(开发中) ### 导出模型(开发中)
当前导出功能尚在开发中: 当前导出功能尚在开发中:

View File

@ -22,6 +22,8 @@
""" """
import argparse import argparse
import os
import sys
import time import time
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@ -60,6 +62,98 @@ class InputMethodInference:
print(f"✅ 推理器初始化完成 (设备: {self.device})") print(f"✅ 推理器初始化完成 (设备: {self.device})")
# 尝试启用readline以获得更好的行编辑功能
try:
import readline
# 设置readline使用UTF-8编码
readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
print("📝 readline已启用支持更好的行编辑功能")
except ImportError:
print("📝 readline不可用使用标准输入")
def _safe_input(self, prompt: str, default: str = "") -> str:
"""
安全的输入函数尝试正确处理UTF-8字符和退格键
Args:
prompt: 提示文本
default: 默认值
Returns:
用户输入的字符串
"""
try:
# 显示提示和默认值
if default:
full_prompt = f"{prompt} [{default}]: "
else:
full_prompt = f"{prompt}: "
# 使用标准input
result = input(full_prompt)
# 如果用户直接回车且存在默认值,则返回默认值
if not result and default:
return default
return result.strip()
except (EOFError, KeyboardInterrupt):
# 用户按Ctrl+D或Ctrl+C
print()
return ""
except Exception as e:
# 其他错误
print(f"\n⚠️ 输入错误: {e}")
return ""
def _clean_pinyin_input(self, pinyin: str) -> str:
"""
清理拼音输入字符串处理退格键等特殊字符
拼音只允许: a-z, `, ', -
中文字符和其他字符会被忽略
Args:
pinyin: 原始拼音输入字符串
Returns:
清理后的拼音字符串
"""
if not pinyin:
return ""
result = []
for c in pinyin:
# 检查是否为合法拼音字符 (a-z, `, ', -)
# 注意: 中文字符的isalpha()也返回True所以需要额外检查
is_valid_pinyin_char = (
("a" <= c <= "z")
or ("A" <= c <= "Z") # 允许大写字母,转换为小写
or c in ["`", "'", "-"]
)
if is_valid_pinyin_char:
# 合法拼音字符,转换为小写
result.append(c.lower())
elif c == " ":
# 空格忽略
continue
elif c == "\b" or c == "\x7f" or c == "\x08":
# 退格键、删除键:删除前一个字符
if result:
result.pop()
elif c == "\x1b":
# ESC键清空所有输入
result.clear()
else:
# 其他字符(包括中文字符)忽略
# 注意这里不添加到result所以退格键无法删除它们
# 但用户可能在拼音输入中误输入中文字符,应该忽略
pass
return "".join(result)
def load_model(self): def load_model(self):
"""加载训练好的模型""" """加载训练好的模型"""
# 创建模型实例(不编译) # 创建模型实例(不编译)
@ -184,7 +278,7 @@ class InputMethodInference:
""" """
# 1. 构建tokenizer输入 # 1. 构建tokenizer输入
# 根据dataset.py格式为: "part4|part1" 和 part3 # 根据test.py和dataset.py格式为: "part4|part1" 和 part3
# part4: 上下文提示(专有词汇、姓名等,模型不掌握) # part4: 上下文提示(专有词汇、姓名等,模型不掌握)
# part1: text_before # part1: text_before
# part3: text_after # part3: text_after
@ -192,13 +286,14 @@ class InputMethodInference:
# 处理上下文提示 # 处理上下文提示
context_text = "|".join(context_prompts) if context_prompts else "" context_text = "|".join(context_prompts) if context_prompts else ""
# 构建输入文本 # 构建输入文本 - 与test.py保持一致
# test.py: f"{part4}|{part1}" 作为第一个参数part3作为第二个参数
if context_text: if context_text:
input_text = f"{context_text}|{text_before}" input_text = f"{context_text}|{text_before}"
else: else:
input_text = text_before input_text = text_before
# 2. Tokenize # 2. Tokenize - 与test.py保持一致
encoded = self.tokenizer( encoded = self.tokenizer(
input_text, input_text,
text_after, text_after,
@ -209,8 +304,11 @@ class InputMethodInference:
return_token_type_ids=True, return_token_type_ids=True,
) )
# 3. 处理拼音输入 # 3. 处理拼音输入 - 与test.py保持一致
pinyin_ids = text_to_pinyin_ids(pinyin) # 首先清理拼音字符串,处理退格键等特殊字符
cleaned_pinyin = self._clean_pinyin_input(pinyin)
pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
if len(pinyin_ids) < 24: if len(pinyin_ids) < 24:
pinyin_ids.extend([0] * (24 - len(pinyin_ids))) pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
else: else:
@ -396,7 +494,16 @@ class InputMethodInference:
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("输入法模型推理 - 交互模式") print("输入法模型推理 - 交互模式")
print("=" * 60) print("=" * 60)
print("说明:")
# 检查终端编码
encoding = sys.stdout.encoding or "unknown"
print(f"终端编码: {encoding}")
if encoding.lower() not in ["utf-8", "utf8"]:
print("⚠️ 警告: 终端编码不是UTF-8中文输入可能有问题")
print(" 建议设置: export LANG=en_US.UTF-8")
print(" 或设置: export LC_ALL=en_US.UTF-8")
print("\n说明:")
print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)") print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
print(" - 光标前文本: 光标前的连续文本") print(" - 光标前文本: 光标前的连续文本")
print(" - 光标后文本: 光标后的连续文本") print(" - 光标后文本: 光标后的连续文本")
@ -411,7 +518,7 @@ class InputMethodInference:
print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)") print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
print("格式: 用逗号分隔多个词汇,可为空") print("格式: 用逗号分隔多个词汇,可为空")
print("示例: 张三,李四,北京大学") print("示例: 张三,李四,北京大学")
context_input = input("请输入上下文提示(直接回车跳过): ").strip() context_input = self._safe_input("请输入上下文提示(直接回车跳过)")
if context_input.lower() in ["quit", "exit", "q"]: if context_input.lower() in ["quit", "exit", "q"]:
print("退出交互模式") print("退出交互模式")
@ -434,7 +541,7 @@ class InputMethodInference:
print("第2步: 光标前文本") print("第2步: 光标前文本")
print("说明: 光标前的连续文本内容") print("说明: 光标前的连续文本内容")
print("示例: 今天天气很好") print("示例: 今天天气很好")
text_before = input("请输入光标前文本: ").strip() text_before = self._safe_input("请输入光标前文本")
if text_before.lower() in ["quit", "exit", "q"]: if text_before.lower() in ["quit", "exit", "q"]:
print("退出交互模式") print("退出交互模式")
@ -446,7 +553,7 @@ class InputMethodInference:
print("第3步: 光标后文本") print("第3步: 光标后文本")
print("说明: 光标后的连续文本内容") print("说明: 光标后的连续文本内容")
print("示例: 我们去公园玩") print("示例: 我们去公园玩")
text_after = input("请输入光标后文本: ").strip() text_after = self._safe_input("请输入光标后文本")
if text_after.lower() in ["quit", "exit", "q"]: if text_after.lower() in ["quit", "exit", "q"]:
print("退出交互模式") print("退出交互模式")
@ -458,7 +565,7 @@ class InputMethodInference:
print("第4步: 拼音输入") print("第4步: 拼音输入")
print("说明: 当前正在输入的拼音") print("说明: 当前正在输入的拼音")
print("示例: tian, shang, hao") print("示例: tian, shang, hao")
pinyin = input("请输入拼音: ").strip() pinyin = self._safe_input("请输入拼音")
if pinyin.lower() in ["quit", "exit", "q"]: if pinyin.lower() in ["quit", "exit", "q"]:
print("退出交互模式") print("退出交互模式")
@ -471,7 +578,7 @@ class InputMethodInference:
print("说明: 用户已确认的输入历史,用逗号分隔") print("说明: 用户已确认的输入历史,用逗号分隔")
print("示例: 上 (表示输入'shanghai'已确认''") print("示例: 上 (表示输入'shanghai'已确认''")
print(" 今天,天气 (表示已确认两个词)") print(" 今天,天气 (表示已确认两个词)")
slot_input = input("请输入槽位历史(直接回车表示无): ").strip() slot_input = self._safe_input("请输入槽位历史(直接回车表示无)")
if slot_input.lower() in ["quit", "exit", "q"]: if slot_input.lower() in ["quit", "exit", "q"]:
print("退出交互模式") print("退出交互模式")
@ -524,7 +631,9 @@ class InputMethodInference:
# 询问是否继续 # 询问是否继续
print("\n" + "-" * 40) print("\n" + "-" * 40)
continue_input = input("是否继续推理?(y/n): ").strip().lower() continue_input = (
self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
)
if continue_input not in ["y", "yes", ""]: if continue_input not in ["y", "yes", ""]:
print("退出交互模式") print("退出交互模式")
break break
@ -539,7 +648,9 @@ class InputMethodInference:
traceback.print_exc() traceback.print_exc()
# 询问是否继续 # 询问是否继续
continue_input = input("\n是否继续?(y/n): ").strip().lower() continue_input = (
self._safe_input("\n是否继续?(y/n)", "y").strip().lower()
)
if continue_input not in ["y", "yes", ""]: if continue_input not in ["y", "yes", ""]:
print("退出交互模式") print("退出交互模式")
break break

View File

@ -72,7 +72,7 @@ class InputMethodEngine(nn.Module):
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads) self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
# 4. 混合专家层 (MoE) # 4. 混合专家层 (MoE)
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8) self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=12)
# 5. 槽位注意力池化 # 5. 槽位注意力池化
self.slot_attention = nn.Linear(dim, 1) self.slot_attention = nn.Linear(dim, 1)

22
test.py
View File

@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str] return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常" part1 = "从中国回家后,他觉得世界上最好的城市就是"
part2 = "piaoliang" part2 = "shanghai"
pinyin_ids = text_to_pinyin_ids(part2) pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids) len_py = len(pinyin_ids)
if len_py < 24: if len_py < 24:
@ -56,9 +56,9 @@ if len_py < 24:
else: else:
pinyin_ids = pinyin_ids[:24] pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0) pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0] masked_labels = [22, 0, 0, 0, 0, 0, 0, 0]
part3 = "" part3 = ""
part4 = "可行|特别|伤害" part4 = ""
encoded = tokenizer( encoded = tokenizer(
f"{part4}|{part1}", f"{part4}|{part1}",
@ -83,8 +83,9 @@ sample = {
model = InputMethodEngine(pinyin_vocab_size=30, compile=False) model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/best_model.ptrom", map_location="cpu") checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"]) model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
input_ids = sample["input_ids"] input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"] token_type_ids = sample["token_type_ids"]
@ -97,12 +98,13 @@ for k, v in sample.items():
print(f"{k}: {v}") print(f"{k}: {v}")
start = time.time() start = time.time()
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids) with torch.no_grad():
print(f'计算时长: {(time.time() - start) * 1000:4f}ms') res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
sort_res = sorted( print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
sort_res = sorted(
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
) )
print(sort_res[0:5]) print(sort_res[0:5])
query_engine = QueryEngine() query_engine = QueryEngine()
query_engine.load() query_engine.load()