docs: 移除模型扩容两阶段训练文档并更新相关用法说明
This commit is contained in:
parent
33f56f709b
commit
3175ace9c5
105
README.md
105
README.md
|
|
@ -790,112 +790,7 @@ train-model evaluate \
|
||||||
- 在评估数据集上计算准确率、困惑度等指标
|
- 在评估数据集上计算准确率、困惑度等指标
|
||||||
- 生成详细的性能报告
|
- 生成详细的性能报告
|
||||||
|
|
||||||
### 6.8 模型扩容两阶段训练
|
|
||||||
|
|
||||||
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
|
|
||||||
|
|
||||||
#### 训练策略
|
|
||||||
|
|
||||||
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
|
|
||||||
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
|
|
||||||
|
|
||||||
#### 基础用法
|
|
||||||
|
|
||||||
```bash
|
|
||||||
train-model expand-and-train \
|
|
||||||
--train-data-path "path/to/train/dataset" \
|
|
||||||
--eval-data-path "path/to/eval/dataset" \
|
|
||||||
--base-model-path "./pretrained/model.pt" \
|
|
||||||
--new-model-spec "model:InputMethodEngine" \
|
|
||||||
--num-experts 40 \
|
|
||||||
--frozen-lr 2e-3 \
|
|
||||||
--full-lr 5e-5 \
|
|
||||||
--frozen-patience 8
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 完整参数示例
|
|
||||||
|
|
||||||
```bash
|
|
||||||
train-model expand-and-train \
|
|
||||||
--train-data-path "path/to/train/dataset" \
|
|
||||||
--eval-data-path "path/to/eval/dataset" \
|
|
||||||
--output-dir "./expansion_output" \
|
|
||||||
--base-model-path "./pretrained/model.pt" \
|
|
||||||
--new-model-spec "custom_model:ExpandedModel" \
|
|
||||||
--vocab-size 10019 \
|
|
||||||
--dim 512 \
|
|
||||||
--num-experts 40 \
|
|
||||||
--frozen-patience 10 \
|
|
||||||
--frozen-lr 1e-3 \
|
|
||||||
--full-lr 1e-4 \
|
|
||||||
--frozen-scheduler cosine \
|
|
||||||
--full-scheduler cosine \
|
|
||||||
--batch-size 128 \
|
|
||||||
--num-epochs 20 \
|
|
||||||
--compile
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 参数详解
|
|
||||||
|
|
||||||
**模型扩容参数**
|
|
||||||
- `--base-model-path`: 预训练基础模型检查点路径(必需)
|
|
||||||
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
|
|
||||||
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
|
|
||||||
- 自定义模型类必须是 `InputMethodEngine` 的子类
|
|
||||||
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类
|
|
||||||
|
|
||||||
**两阶段训练参数**
|
|
||||||
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10)
|
|
||||||
- `--frozen-lr`: 冻结阶段学习率(默认:1e-3)
|
|
||||||
- `--full-lr`: 全量微调阶段学习率(默认:1e-4)
|
|
||||||
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
|
||||||
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
|
||||||
|
|
||||||
**其他参数**
|
|
||||||
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
|
|
||||||
- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等
|
|
||||||
|
|
||||||
#### 使用场景
|
|
||||||
|
|
||||||
1. **增加专家数量**(20→40)
|
|
||||||
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
|
|
||||||
- 新增参数:新专家网络、gate层
|
|
||||||
|
|
||||||
2. **增加top_k值**(2→3)
|
|
||||||
- 冻结效果:100% 参数可冻结(仅逻辑变化)
|
|
||||||
- 新增参数:无
|
|
||||||
|
|
||||||
3. **修改专家内部结构**(如增加resblocks)
|
|
||||||
- 冻结效果:~50% 参数可冻结(linear_in/output可冻结)
|
|
||||||
- 新增参数:新增的resblocks层
|
|
||||||
|
|
||||||
4. **增加Transformer层数**(4→5)
|
|
||||||
- 冻结效果:~80% 参数可冻结(前4层可冻结)
|
|
||||||
- 新增参数:新增的第5层
|
|
||||||
|
|
||||||
#### 自定义模型类示例
|
|
||||||
|
|
||||||
```python
|
|
||||||
# my_model.py
|
|
||||||
from model.model import InputMethodEngine
|
|
||||||
|
|
||||||
class MyExpandedModel(InputMethodEngine):
|
|
||||||
def __init__(self, num_experts=40, **kwargs):
|
|
||||||
# 调用父类构造函数,覆盖num_experts参数
|
|
||||||
super().__init__(num_experts=num_experts, **kwargs)
|
|
||||||
# 可以在这里添加额外的层或修改现有层
|
|
||||||
|
|
||||||
# 使用命令
|
|
||||||
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 注意事项
|
|
||||||
|
|
||||||
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
|
|
||||||
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
|
|
||||||
3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
|
|
||||||
4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency`
|
|
||||||
5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载
|
|
||||||
|
|
||||||
### 6.9 导出模型(开发中)
|
### 6.9 导出模型(开发中)
|
||||||
|
|
||||||
|
|
|
||||||
109
docs/TRAINING.md
109
docs/TRAINING.md
|
|
@ -479,118 +479,11 @@ train-model evaluate \
|
||||||
--batch-size 32
|
--batch-size 32
|
||||||
```
|
```
|
||||||
|
|
||||||
命令将显示"评估功能待实现"的提示信息。该功能计划用于:
|
命令将显示"评估功能待实现"的提示信息。该功能计划用于:
|
||||||
- 加载训练好的模型检查点
|
- 加载训练好的模型检查点
|
||||||
- 在评估数据集上计算准确率、困惑度等指标
|
- 在评估数据集上计算准确率、困惑度等指标
|
||||||
- 生成详细的性能报告
|
- 生成详细的性能报告
|
||||||
|
|
||||||
### 模型扩容两阶段训练
|
|
||||||
|
|
||||||
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
|
|
||||||
|
|
||||||
#### 训练策略
|
|
||||||
|
|
||||||
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
|
|
||||||
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
|
|
||||||
|
|
||||||
#### 基础用法
|
|
||||||
|
|
||||||
```bash
|
|
||||||
train-model expand-and-train \
|
|
||||||
--train-data-path "path/to/train/dataset" \
|
|
||||||
--eval-data-path "path/to/eval/dataset" \
|
|
||||||
--base-model-path "./pretrained/model.pt" \
|
|
||||||
--new-model-spec "model:InputMethodEngine" \
|
|
||||||
--num-experts 40 \
|
|
||||||
--frozen-lr 2e-3 \
|
|
||||||
--full-lr 5e-5 \
|
|
||||||
--frozen-patience 8
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 完整参数示例
|
|
||||||
|
|
||||||
```bash
|
|
||||||
train-model expand-and-train \
|
|
||||||
--train-data-path "path/to/train/dataset" \
|
|
||||||
--eval-data-path "path/to/eval/dataset" \
|
|
||||||
--output-dir "./expansion_output" \
|
|
||||||
--base-model-path "./pretrained/model.pt" \
|
|
||||||
--new-model-spec "custom_model:ExpandedModel" \
|
|
||||||
--vocab-size 10019 \
|
|
||||||
--dim 512 \
|
|
||||||
--num-experts 40 \
|
|
||||||
--frozen-patience 10 \
|
|
||||||
--frozen-lr 1e-3 \
|
|
||||||
--full-lr 1e-4 \
|
|
||||||
--frozen-scheduler cosine \
|
|
||||||
--full-scheduler cosine \
|
|
||||||
--batch-size 128 \
|
|
||||||
--num-epochs 20 \
|
|
||||||
--compile
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 参数详解
|
|
||||||
|
|
||||||
**模型扩容参数**
|
|
||||||
- `--base-model-path`: 预训练基础模型检查点路径(必需)
|
|
||||||
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
|
|
||||||
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
|
|
||||||
- 自定义模型类必须是 `InputMethodEngine` 的子类
|
|
||||||
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类
|
|
||||||
|
|
||||||
**两阶段训练参数**
|
|
||||||
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10)
|
|
||||||
- `--frozen-lr`: 冻结阶段学习率(默认:1e-3)
|
|
||||||
- `--full-lr`: 全量微调阶段学习率(默认:1e-4)
|
|
||||||
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
|
||||||
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
|
||||||
|
|
||||||
**其他参数**
|
|
||||||
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
|
|
||||||
- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等
|
|
||||||
|
|
||||||
#### 使用场景
|
|
||||||
|
|
||||||
1. **增加专家数量**(20→40)
|
|
||||||
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
|
|
||||||
- 新增参数:新专家网络、gate层
|
|
||||||
|
|
||||||
2. **增加top_k值**(2→3)
|
|
||||||
- 冻结效果:100% 参数可冻结(仅逻辑变化)
|
|
||||||
- 新增参数:无
|
|
||||||
|
|
||||||
3. **修改专家内部结构**(如增加resblocks)
|
|
||||||
- 冻结效果:~50% 参数可冻结(linear_in/output可冻结)
|
|
||||||
- 新增参数:新增的resblocks层
|
|
||||||
|
|
||||||
4. **增加Transformer层数**(4→5)
|
|
||||||
- 冻结效果:~80% 参数可冻结(前4层可冻结)
|
|
||||||
- 新增参数:新增的第5层
|
|
||||||
|
|
||||||
#### 自定义模型类示例
|
|
||||||
|
|
||||||
```python
|
|
||||||
# my_model.py
|
|
||||||
from model.model import InputMethodEngine
|
|
||||||
|
|
||||||
class MyExpandedModel(InputMethodEngine):
|
|
||||||
def __init__(self, num_experts=40, **kwargs):
|
|
||||||
# 调用父类构造函数,覆盖num_experts参数
|
|
||||||
super().__init__(num_experts=num_experts, **kwargs)
|
|
||||||
# 可以在这里添加额外的层或修改现有层
|
|
||||||
|
|
||||||
# 使用命令
|
|
||||||
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 注意事项
|
|
||||||
|
|
||||||
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
|
|
||||||
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
|
|
||||||
3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
|
|
||||||
4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency`
|
|
||||||
5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载
|
|
||||||
|
|
||||||
### 导出模型(开发中)
|
### 导出模型(开发中)
|
||||||
|
|
||||||
当前导出功能尚在开发中:
|
当前导出功能尚在开发中:
|
||||||
|
|
|
||||||
137
inference.py
137
inference.py
|
|
@ -22,6 +22,8 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
@ -60,6 +62,98 @@ class InputMethodInference:
|
||||||
|
|
||||||
print(f"✅ 推理器初始化完成 (设备: {self.device})")
|
print(f"✅ 推理器初始化完成 (设备: {self.device})")
|
||||||
|
|
||||||
|
# 尝试启用readline以获得更好的行编辑功能
|
||||||
|
try:
|
||||||
|
import readline
|
||||||
|
|
||||||
|
# 设置readline使用UTF-8编码
|
||||||
|
readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
|
||||||
|
print("📝 readline已启用,支持更好的行编辑功能")
|
||||||
|
except ImportError:
|
||||||
|
print("📝 readline不可用,使用标准输入")
|
||||||
|
|
||||||
|
def _safe_input(self, prompt: str, default: str = "") -> str:
|
||||||
|
"""
|
||||||
|
安全的输入函数,尝试正确处理UTF-8字符和退格键
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 提示文本
|
||||||
|
default: 默认值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
用户输入的字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 显示提示和默认值
|
||||||
|
if default:
|
||||||
|
full_prompt = f"{prompt} [{default}]: "
|
||||||
|
else:
|
||||||
|
full_prompt = f"{prompt}: "
|
||||||
|
|
||||||
|
# 使用标准input
|
||||||
|
result = input(full_prompt)
|
||||||
|
|
||||||
|
# 如果用户直接回车且存在默认值,则返回默认值
|
||||||
|
if not result and default:
|
||||||
|
return default
|
||||||
|
|
||||||
|
return result.strip()
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
# 用户按Ctrl+D或Ctrl+C
|
||||||
|
print()
|
||||||
|
return ""
|
||||||
|
except Exception as e:
|
||||||
|
# 其他错误
|
||||||
|
print(f"\n⚠️ 输入错误: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _clean_pinyin_input(self, pinyin: str) -> str:
|
||||||
|
"""
|
||||||
|
清理拼音输入字符串,处理退格键等特殊字符
|
||||||
|
|
||||||
|
拼音只允许: a-z, `, ', -
|
||||||
|
中文字符和其他字符会被忽略
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pinyin: 原始拼音输入字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清理后的拼音字符串
|
||||||
|
"""
|
||||||
|
if not pinyin:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for c in pinyin:
|
||||||
|
# 检查是否为合法拼音字符 (a-z, `, ', -)
|
||||||
|
# 注意: 中文字符的isalpha()也返回True,所以需要额外检查
|
||||||
|
is_valid_pinyin_char = (
|
||||||
|
("a" <= c <= "z")
|
||||||
|
or ("A" <= c <= "Z") # 允许大写字母,转换为小写
|
||||||
|
or c in ["`", "'", "-"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_valid_pinyin_char:
|
||||||
|
# 合法拼音字符,转换为小写
|
||||||
|
result.append(c.lower())
|
||||||
|
elif c == " ":
|
||||||
|
# 空格忽略
|
||||||
|
continue
|
||||||
|
elif c == "\b" or c == "\x7f" or c == "\x08":
|
||||||
|
# 退格键、删除键:删除前一个字符
|
||||||
|
if result:
|
||||||
|
result.pop()
|
||||||
|
elif c == "\x1b":
|
||||||
|
# ESC键:清空所有输入
|
||||||
|
result.clear()
|
||||||
|
else:
|
||||||
|
# 其他字符(包括中文字符)忽略
|
||||||
|
# 注意:这里不添加到result,所以退格键无法删除它们
|
||||||
|
# 但用户可能在拼音输入中误输入中文字符,应该忽略
|
||||||
|
pass
|
||||||
|
|
||||||
|
return "".join(result)
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
"""加载训练好的模型"""
|
"""加载训练好的模型"""
|
||||||
# 创建模型实例(不编译)
|
# 创建模型实例(不编译)
|
||||||
|
|
@ -184,7 +278,7 @@ class InputMethodInference:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. 构建tokenizer输入
|
# 1. 构建tokenizer输入
|
||||||
# 根据dataset.py,格式为: "part4|part1" 和 part3
|
# 根据test.py和dataset.py,格式为: "part4|part1" 和 part3
|
||||||
# part4: 上下文提示(专有词汇、姓名等,模型不掌握)
|
# part4: 上下文提示(专有词汇、姓名等,模型不掌握)
|
||||||
# part1: text_before
|
# part1: text_before
|
||||||
# part3: text_after
|
# part3: text_after
|
||||||
|
|
@ -192,13 +286,14 @@ class InputMethodInference:
|
||||||
# 处理上下文提示
|
# 处理上下文提示
|
||||||
context_text = "|".join(context_prompts) if context_prompts else ""
|
context_text = "|".join(context_prompts) if context_prompts else ""
|
||||||
|
|
||||||
# 构建输入文本
|
# 构建输入文本 - 与test.py保持一致
|
||||||
|
# test.py: f"{part4}|{part1}" 作为第一个参数,part3作为第二个参数
|
||||||
if context_text:
|
if context_text:
|
||||||
input_text = f"{context_text}|{text_before}"
|
input_text = f"{context_text}|{text_before}"
|
||||||
else:
|
else:
|
||||||
input_text = text_before
|
input_text = text_before
|
||||||
|
|
||||||
# 2. Tokenize
|
# 2. Tokenize - 与test.py保持一致
|
||||||
encoded = self.tokenizer(
|
encoded = self.tokenizer(
|
||||||
input_text,
|
input_text,
|
||||||
text_after,
|
text_after,
|
||||||
|
|
@ -209,8 +304,11 @@ class InputMethodInference:
|
||||||
return_token_type_ids=True,
|
return_token_type_ids=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 处理拼音输入
|
# 3. 处理拼音输入 - 与test.py保持一致
|
||||||
pinyin_ids = text_to_pinyin_ids(pinyin)
|
# 首先清理拼音字符串,处理退格键等特殊字符
|
||||||
|
cleaned_pinyin = self._clean_pinyin_input(pinyin)
|
||||||
|
|
||||||
|
pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
|
||||||
if len(pinyin_ids) < 24:
|
if len(pinyin_ids) < 24:
|
||||||
pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
|
pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
|
||||||
else:
|
else:
|
||||||
|
|
@ -396,7 +494,16 @@ class InputMethodInference:
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("输入法模型推理 - 交互模式")
|
print("输入法模型推理 - 交互模式")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("说明:")
|
|
||||||
|
# 检查终端编码
|
||||||
|
encoding = sys.stdout.encoding or "unknown"
|
||||||
|
print(f"终端编码: {encoding}")
|
||||||
|
if encoding.lower() not in ["utf-8", "utf8"]:
|
||||||
|
print("⚠️ 警告: 终端编码不是UTF-8,中文输入可能有问题")
|
||||||
|
print(" 建议设置: export LANG=en_US.UTF-8")
|
||||||
|
print(" 或设置: export LC_ALL=en_US.UTF-8")
|
||||||
|
|
||||||
|
print("\n说明:")
|
||||||
print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
|
print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
|
||||||
print(" - 光标前文本: 光标前的连续文本")
|
print(" - 光标前文本: 光标前的连续文本")
|
||||||
print(" - 光标后文本: 光标后的连续文本")
|
print(" - 光标后文本: 光标后的连续文本")
|
||||||
|
|
@ -411,7 +518,7 @@ class InputMethodInference:
|
||||||
print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
|
print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
|
||||||
print("格式: 用逗号分隔多个词汇,可为空")
|
print("格式: 用逗号分隔多个词汇,可为空")
|
||||||
print("示例: 张三,李四,北京大学")
|
print("示例: 张三,李四,北京大学")
|
||||||
context_input = input("请输入上下文提示(直接回车跳过): ").strip()
|
context_input = self._safe_input("请输入上下文提示(直接回车跳过)")
|
||||||
|
|
||||||
if context_input.lower() in ["quit", "exit", "q"]:
|
if context_input.lower() in ["quit", "exit", "q"]:
|
||||||
print("退出交互模式")
|
print("退出交互模式")
|
||||||
|
|
@ -434,7 +541,7 @@ class InputMethodInference:
|
||||||
print("第2步: 光标前文本")
|
print("第2步: 光标前文本")
|
||||||
print("说明: 光标前的连续文本内容")
|
print("说明: 光标前的连续文本内容")
|
||||||
print("示例: 今天天气很好")
|
print("示例: 今天天气很好")
|
||||||
text_before = input("请输入光标前文本: ").strip()
|
text_before = self._safe_input("请输入光标前文本")
|
||||||
|
|
||||||
if text_before.lower() in ["quit", "exit", "q"]:
|
if text_before.lower() in ["quit", "exit", "q"]:
|
||||||
print("退出交互模式")
|
print("退出交互模式")
|
||||||
|
|
@ -446,7 +553,7 @@ class InputMethodInference:
|
||||||
print("第3步: 光标后文本")
|
print("第3步: 光标后文本")
|
||||||
print("说明: 光标后的连续文本内容")
|
print("说明: 光标后的连续文本内容")
|
||||||
print("示例: 我们去公园玩")
|
print("示例: 我们去公园玩")
|
||||||
text_after = input("请输入光标后文本: ").strip()
|
text_after = self._safe_input("请输入光标后文本")
|
||||||
|
|
||||||
if text_after.lower() in ["quit", "exit", "q"]:
|
if text_after.lower() in ["quit", "exit", "q"]:
|
||||||
print("退出交互模式")
|
print("退出交互模式")
|
||||||
|
|
@ -458,7 +565,7 @@ class InputMethodInference:
|
||||||
print("第4步: 拼音输入")
|
print("第4步: 拼音输入")
|
||||||
print("说明: 当前正在输入的拼音")
|
print("说明: 当前正在输入的拼音")
|
||||||
print("示例: tian, shang, hao")
|
print("示例: tian, shang, hao")
|
||||||
pinyin = input("请输入拼音: ").strip()
|
pinyin = self._safe_input("请输入拼音")
|
||||||
|
|
||||||
if pinyin.lower() in ["quit", "exit", "q"]:
|
if pinyin.lower() in ["quit", "exit", "q"]:
|
||||||
print("退出交互模式")
|
print("退出交互模式")
|
||||||
|
|
@ -471,7 +578,7 @@ class InputMethodInference:
|
||||||
print("说明: 用户已确认的输入历史,用逗号分隔")
|
print("说明: 用户已确认的输入历史,用逗号分隔")
|
||||||
print("示例: 上 (表示输入'shanghai'已确认'上')")
|
print("示例: 上 (表示输入'shanghai'已确认'上')")
|
||||||
print(" 今天,天气 (表示已确认两个词)")
|
print(" 今天,天气 (表示已确认两个词)")
|
||||||
slot_input = input("请输入槽位历史(直接回车表示无): ").strip()
|
slot_input = self._safe_input("请输入槽位历史(直接回车表示无)")
|
||||||
|
|
||||||
if slot_input.lower() in ["quit", "exit", "q"]:
|
if slot_input.lower() in ["quit", "exit", "q"]:
|
||||||
print("退出交互模式")
|
print("退出交互模式")
|
||||||
|
|
@ -524,7 +631,9 @@ class InputMethodInference:
|
||||||
|
|
||||||
# 询问是否继续
|
# 询问是否继续
|
||||||
print("\n" + "-" * 40)
|
print("\n" + "-" * 40)
|
||||||
continue_input = input("是否继续推理?(y/n): ").strip().lower()
|
continue_input = (
|
||||||
|
self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
|
||||||
|
)
|
||||||
if continue_input not in ["y", "yes", ""]:
|
if continue_input not in ["y", "yes", ""]:
|
||||||
print("退出交互模式")
|
print("退出交互模式")
|
||||||
break
|
break
|
||||||
|
|
@ -539,7 +648,9 @@ class InputMethodInference:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
# 询问是否继续
|
# 询问是否继续
|
||||||
continue_input = input("\n是否继续?(y/n): ").strip().lower()
|
continue_input = (
|
||||||
|
self._safe_input("\n是否继续?(y/n)", "y").strip().lower()
|
||||||
|
)
|
||||||
if continue_input not in ["y", "yes", ""]:
|
if continue_input not in ["y", "yes", ""]:
|
||||||
print("退出交互模式")
|
print("退出交互模式")
|
||||||
break
|
break
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,7 @@ class InputMethodEngine(nn.Module):
|
||||||
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
|
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
|
||||||
|
|
||||||
# 4. 混合专家层 (MoE)
|
# 4. 混合专家层 (MoE)
|
||||||
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8)
|
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=12)
|
||||||
|
|
||||||
# 5. 槽位注意力池化
|
# 5. 槽位注意力池化
|
||||||
self.slot_attention = nn.Linear(dim, 1)
|
self.slot_attention = nn.Linear(dim, 1)
|
||||||
|
|
|
||||||
22
test.py
22
test.py
|
|
@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
||||||
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
||||||
|
|
||||||
|
|
||||||
part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常"
|
part1 = "从中国回家后,他觉得世界上最好的城市就是"
|
||||||
part2 = "piaoliang"
|
part2 = "shanghai"
|
||||||
pinyin_ids = text_to_pinyin_ids(part2)
|
pinyin_ids = text_to_pinyin_ids(part2)
|
||||||
len_py = len(pinyin_ids)
|
len_py = len(pinyin_ids)
|
||||||
if len_py < 24:
|
if len_py < 24:
|
||||||
|
|
@ -56,9 +56,9 @@ if len_py < 24:
|
||||||
else:
|
else:
|
||||||
pinyin_ids = pinyin_ids[:24]
|
pinyin_ids = pinyin_ids[:24]
|
||||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
||||||
masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0]
|
masked_labels = [22, 0, 0, 0, 0, 0, 0, 0]
|
||||||
part3 = ""
|
part3 = ""
|
||||||
part4 = "可行|特别|伤害"
|
part4 = ""
|
||||||
|
|
||||||
encoded = tokenizer(
|
encoded = tokenizer(
|
||||||
f"{part4}|{part1}",
|
f"{part4}|{part1}",
|
||||||
|
|
@ -83,8 +83,9 @@ sample = {
|
||||||
|
|
||||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||||
|
|
||||||
checkpoint = torch.load("/home/songsenand/下载/best_model.ptrom", map_location="cpu")
|
checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu")
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
model.eval()
|
||||||
|
|
||||||
input_ids = sample["input_ids"]
|
input_ids = sample["input_ids"]
|
||||||
token_type_ids = sample["token_type_ids"]
|
token_type_ids = sample["token_type_ids"]
|
||||||
|
|
@ -97,12 +98,13 @@ for k, v in sample.items():
|
||||||
print(f"{k}: {v}")
|
print(f"{k}: {v}")
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
with torch.no_grad():
|
||||||
print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
|
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
||||||
sort_res = sorted(
|
print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
|
||||||
|
sort_res = sorted(
|
||||||
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
|
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
|
||||||
)
|
)
|
||||||
print(sort_res[0:5])
|
print(sort_res[0:5])
|
||||||
|
|
||||||
query_engine = QueryEngine()
|
query_engine = QueryEngine()
|
||||||
query_engine.load()
|
query_engine.load()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue