docs: remove the two-stage model-expansion training docs and update related usage notes

This commit is contained in:
parent 33f56f709b
commit 3175ace9c5

README.md (105 lines changed)

@@ -790,112 +790,7 @@ train-model evaluate \
- Compute metrics such as accuracy and perplexity on the evaluation dataset
- Generate a detailed performance report

### 6.8 Two-Stage Training for Model Expansion

When you need to increase model capacity (for example, adding experts or changing the layer structure), use the `expand-and-train` command for two-stage training: first freeze the matching layers and train only the new parameters, then fine-tune the whole model.

#### Training Strategy

1. **Frozen stage**: train only the newly added parameters whose shapes do not match the base model (e.g. new experts, expanded layers)
2. **Full fine-tuning stage**: once the validation loss fails to improve for `--frozen-patience` consecutive evaluations, all layers are automatically unfrozen for full training
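The schedule above can be sketched in plain Python. This is an illustrative sketch only, not the tool's actual implementation: `split_params` and `should_unfreeze` are hypothetical helper names, and parameter shapes are modeled as simple dicts.

```python
# Sketch of the two-stage schedule: parameters whose shapes match the base
# checkpoint are frozen (their weights can be copied over), the rest are
# trained; everything is unfrozen once the validation loss has stalled for
# `frozen_patience` consecutive evaluations.

def split_params(base_shapes, new_shapes):
    """Partition the new model's parameters into frozen and trainable names."""
    frozen, trainable = [], []
    for name, shape in new_shapes.items():
        if base_shapes.get(name) == shape:
            frozen.append(name)      # shape matches: copy from base and freeze
        else:
            trainable.append(name)   # newly added or reshaped parameter
    return frozen, trainable


def should_unfreeze(eval_losses, frozen_patience):
    """True once the last `frozen_patience` evaluations failed to improve."""
    if len(eval_losses) <= frozen_patience:
        return False
    best_before = min(eval_losses[:-frozen_patience])
    return min(eval_losses[-frozen_patience:]) >= best_before


base = {"attn.weight": (512, 512), "expert.0.weight": (512, 2048)}
new = {"attn.weight": (512, 512), "expert.0.weight": (512, 2048),
       "expert.1.weight": (512, 2048)}  # one newly added expert
frozen, trainable = split_params(base, new)
print(frozen)     # ['attn.weight', 'expert.0.weight']
print(trainable)  # ['expert.1.weight']
```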
#### Basic Usage

```bash
train-model expand-and-train \
    --train-data-path "path/to/train/dataset" \
    --eval-data-path "path/to/eval/dataset" \
    --base-model-path "./pretrained/model.pt" \
    --new-model-spec "model:InputMethodEngine" \
    --num-experts 40 \
    --frozen-lr 2e-3 \
    --full-lr 5e-5 \
    --frozen-patience 8
```
#### Full Parameter Example

```bash
train-model expand-and-train \
    --train-data-path "path/to/train/dataset" \
    --eval-data-path "path/to/eval/dataset" \
    --output-dir "./expansion_output" \
    --base-model-path "./pretrained/model.pt" \
    --new-model-spec "custom_model:ExpandedModel" \
    --vocab-size 10019 \
    --dim 512 \
    --num-experts 40 \
    --frozen-patience 10 \
    --frozen-lr 1e-3 \
    --full-lr 1e-4 \
    --frozen-scheduler cosine \
    --full-scheduler cosine \
    --batch-size 128 \
    --num-epochs 20 \
    --compile
```
#### Parameter Details

**Model expansion parameters**
- `--base-model-path`: path to the pretrained base model checkpoint (required)
- `--new-model-spec`: new model spec in `module:Class` form, e.g. `model:InputMethodEngine` (required)
  - Modules can be imported from arbitrary paths; the module file must contain the custom model class
  - The custom model class must be a subclass of `InputMethodEngine`
  - Example: `my_model:MyExpandedModel` refers to the `MyExpandedModel` class in `my_model.py`

**Two-stage training parameters**
- `--frozen-patience`: number of consecutive evaluations without validation-loss improvement in the frozen stage before switching to full fine-tuning (default: 10)
- `--frozen-lr`: learning rate for the frozen stage (default: 1e-3)
- `--full-lr`: learning rate for the full fine-tuning stage (default: 1e-4)
- `--frozen-scheduler`: learning-rate scheduler for the frozen stage, `cosine` or `plateau` (default: `cosine`)
- `--full-scheduler`: learning-rate scheduler for the full fine-tuning stage, `cosine` or `plateau` (default: `cosine`)

**Other parameters**
- All common parameters of the `train` subcommand are supported (data, model, and training parameters)
- The existing training infrastructure is inherited: mixed-precision training, TensorBoard logging, checkpoint saving, etc.
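Resolving a `module:Class` spec maps directly onto Python's standard import machinery. The sketch below shows one way this could work; `resolve_model_spec` is an illustrative helper, not the trainer's actual loader (which may, for instance, also add the module's directory to `sys.path`). The `json:JSONDecoder` example uses a stdlib class purely for demonstration.

```python
# Sketch of resolving a "module:Class" spec like the --new-model-spec value.
import importlib


def resolve_model_spec(spec: str):
    """Split 'module:Class' and return the class object."""
    module_name, _, class_name = spec.partition(":")
    if not module_name or not class_name:
        raise ValueError(f"expected 'module:Class', got {spec!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Example with a stdlib class: "json:JSONDecoder"
cls = resolve_model_spec("json:JSONDecoder")
print(cls.__name__)  # JSONDecoder
```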
#### Use Cases

1. **Increasing the number of experts** (20→40)
   - Freezing effect: ~70% of parameters can be frozen (existing expert weights, attention layers, etc.)
   - New parameters: new expert networks and the gate layer

2. **Increasing top_k** (2→3)
   - Freezing effect: 100% of parameters can be frozen (the change is purely logical)
   - New parameters: none

3. **Changing the internal expert structure** (e.g. adding resblocks)
   - Freezing effect: ~50% of parameters can be frozen (linear_in/output can be frozen)
   - New parameters: the added resblocks

4. **Increasing the number of Transformer layers** (4→5)
   - Freezing effect: ~80% of parameters can be frozen (the first 4 layers)
   - New parameters: the new 5th layer
#### Custom Model Class Example

```python
# my_model.py
from model.model import InputMethodEngine


class MyExpandedModel(InputMethodEngine):
    def __init__(self, num_experts=40, **kwargs):
        # Call the parent constructor, overriding num_experts
        super().__init__(num_experts=num_experts, **kwargs)
        # Extra layers can be added here, or existing ones modified


# Usage:
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
```
#### Notes

1. **Model class requirement**: the custom model class must be a subclass of `InputMethodEngine`
2. **Freezing condition**: only layers whose weight shapes match exactly are frozen
3. **Performance**: the MoE layer keeps the "compute all experts + top-k selection" scheme, ensuring the best performance under `torch.compile`
4. **Stage switching**: driven by evaluation count rather than epochs; consider raising `--eval-frequency` accordingly
5. **Module import**: modules at arbitrary paths are supported, loaded via Python's standard import machinery
### 6.9 Export Model (Under Development)

docs/TRAINING.md (109 lines changed)

@@ -479,118 +479,11 @@ train-model evaluate \
    --batch-size 32
```

The command will print an "evaluation not yet implemented" notice. This feature is planned to:
- Load a trained model checkpoint
- Compute metrics such as accuracy and perplexity on the evaluation dataset
- Generate a detailed performance report

### Export Model (Under Development)

The export feature is still under development:

inference.py (137 lines changed)

@@ -22,6 +22,8 @@
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
|
@@ -60,6 +62,98 @@ class InputMethodInference:

         print(f"✅ 推理器初始化完成 (设备: {self.device})")

+        # Try to enable readline for better line-editing support
+        try:
+            import readline
+
+            # Configure readline's completer delimiters
+            readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
+            print("📝 readline已启用,支持更好的行编辑功能")
+        except ImportError:
+            print("📝 readline不可用,使用标准输入")
+
+    def _safe_input(self, prompt: str, default: str = "") -> str:
+        """
+        Safe input helper that handles UTF-8 characters and backspace correctly.
+
+        Args:
+            prompt: prompt text
+            default: default value
+
+        Returns:
+            the string entered by the user
+        """
+        try:
+            # Show the prompt together with the default value, if any
+            if default:
+                full_prompt = f"{prompt} [{default}]: "
+            else:
+                full_prompt = f"{prompt}: "
+
+            # Use the standard input()
+            result = input(full_prompt)
+
+            # Return the default when the user just presses Enter
+            if not result and default:
+                return default
+
+            return result.strip()
+        except (EOFError, KeyboardInterrupt):
+            # User pressed Ctrl+D or Ctrl+C
+            print()
+            return ""
+        except Exception as e:
+            # Any other error
+            print(f"\n⚠️ 输入错误: {e}")
+            return ""
+
+    def _clean_pinyin_input(self, pinyin: str) -> str:
+        """
+        Clean a raw pinyin input string, handling backspace and other control characters.
+
+        Pinyin only allows: a-z, `, ', -
+        Chinese characters and anything else are ignored.
+
+        Args:
+            pinyin: raw pinyin input string
+
+        Returns:
+            the cleaned pinyin string
+        """
+        if not pinyin:
+            return ""
+
+        result = []
+        for c in pinyin:
+            # Check for a valid pinyin character (a-z, `, ', -).
+            # Note: isalpha() is also True for Chinese characters, so check explicitly.
+            is_valid_pinyin_char = (
+                ("a" <= c <= "z")
+                or ("A" <= c <= "Z")  # allow uppercase, converted to lowercase
+                or c in ["`", "'", "-"]
+            )
+
+            if is_valid_pinyin_char:
+                # Valid pinyin character: lowercase and keep it
+                result.append(c.lower())
+            elif c == " ":
+                # Ignore spaces
+                continue
+            elif c == "\b" or c == "\x7f" or c == "\x08":
+                # Backspace/delete: drop the previous character
+                if result:
+                    result.pop()
+            elif c == "\x1b":
+                # ESC: clear all input
+                result.clear()
+            else:
+                # Ignore everything else, including Chinese characters; they are
+                # never appended, so backspace cannot delete them, and stray
+                # Chinese characters typed into the pinyin field are dropped.
+                pass
+
+        return "".join(result)
+
     def load_model(self):
         """Load the trained model"""
         # Create the model instance (not compiled)
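The cleaning rules added above can be exercised in isolation. The sketch below re-implements them as a standalone function for illustration (`clean_pinyin` is a hypothetical name; the real code lives in `_clean_pinyin_input` on the inference class):

```python
# Standalone sketch of the pinyin-cleaning rules: keep a-z (lowercasing A-Z)
# plus ` ' -, let backspace (\b, \x7f, \x08) delete the previous kept
# character, let ESC (\x1b) clear everything, and drop anything else
# (spaces, Chinese characters, ...).

def clean_pinyin(raw: str) -> str:
    result = []
    for c in raw:
        if "a" <= c <= "z" or "A" <= c <= "Z" or c in "`'-":
            result.append(c.lower())
        elif c in "\b\x7f\x08":
            if result:
                result.pop()
        elif c == "\x1b":
            result.clear()
        # spaces, Chinese characters and anything else are ignored
    return "".join(result)


print(clean_pinyin("Shang hai"))     # shanghai
print(clean_pinyin("shangx\b hai"))  # shanghai  (backspace removed the x)
print(clean_pinyin("拼音shanghai"))   # shanghai  (Chinese characters dropped)
```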
@@ -184,7 +278,7 @@ class InputMethodInference:
         """

         # 1. Build the tokenizer input
-        # According to dataset.py, the format is: "part4|part1" and part3
+        # According to test.py and dataset.py, the format is: "part4|part1" and part3
         # part4: context hints (proper nouns, names, etc. the model does not know)
         # part1: text_before
         # part3: text_after

@@ -192,13 +286,14 @@ class InputMethodInference:
         # Handle the context hints
         context_text = "|".join(context_prompts) if context_prompts else ""

-        # Build the input text
+        # Build the input text - consistent with test.py
+        # test.py: f"{part4}|{part1}" as the first argument, part3 as the second
         if context_text:
             input_text = f"{context_text}|{text_before}"
         else:
             input_text = text_before

-        # 2. Tokenize
+        # 2. Tokenize - consistent with test.py
         encoded = self.tokenizer(
             input_text,
             text_after,

@@ -209,8 +304,11 @@ class InputMethodInference:
             return_token_type_ids=True,
         )

-        # 3. Handle the pinyin input
-        pinyin_ids = text_to_pinyin_ids(pinyin)
+        # 3. Handle the pinyin input - consistent with test.py
+        # First clean the pinyin string, handling backspace and other control characters
+        cleaned_pinyin = self._clean_pinyin_input(pinyin)
+
+        pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
         if len(pinyin_ids) < 24:
             pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
         else:
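The padding step at the end of the hunk above fixes the pinyin id sequence at 24 entries. A minimal standalone sketch of that step (`pad_to_fixed` is an illustrative name; the real code pads in place with `extend` and truncates with a slice):

```python
# Pad with 0 (or truncate) so the pinyin id sequence is exactly `length` long
# before it is turned into a tensor and batched.

def pad_to_fixed(ids, length=24, pad_id=0):
    if len(ids) < length:
        return ids + [pad_id] * (length - len(ids))
    return ids[:length]


print(len(pad_to_fixed([5, 3, 9])))       # 24
print(pad_to_fixed(list(range(30)))[-1])  # 23 (truncated to the first 24 ids)
```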
@@ -396,7 +494,16 @@ class InputMethodInference:
         print("\n" + "=" * 60)
         print("输入法模型推理 - 交互模式")
         print("=" * 60)
-        print("说明:")
+
+        # Check the terminal encoding
+        encoding = sys.stdout.encoding or "unknown"
+        print(f"终端编码: {encoding}")
+        if encoding.lower() not in ["utf-8", "utf8"]:
+            print("⚠️ 警告: 终端编码不是UTF-8,中文输入可能有问题")
+            print("   建议设置: export LANG=en_US.UTF-8")
+            print("   或设置: export LC_ALL=en_US.UTF-8")
+
+        print("\n说明:")
         print("  - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
         print("  - 光标前文本: 光标前的连续文本")
         print("  - 光标后文本: 光标后的连续文本")

@@ -411,7 +518,7 @@ class InputMethodInference:
             print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
             print("格式: 用逗号分隔多个词汇,可为空")
             print("示例: 张三,李四,北京大学")
-            context_input = input("请输入上下文提示(直接回车跳过): ").strip()
+            context_input = self._safe_input("请输入上下文提示(直接回车跳过)")

             if context_input.lower() in ["quit", "exit", "q"]:
                 print("退出交互模式")

@@ -434,7 +541,7 @@ class InputMethodInference:
             print("第2步: 光标前文本")
             print("说明: 光标前的连续文本内容")
             print("示例: 今天天气很好")
-            text_before = input("请输入光标前文本: ").strip()
+            text_before = self._safe_input("请输入光标前文本")

             if text_before.lower() in ["quit", "exit", "q"]:
                 print("退出交互模式")

@@ -446,7 +553,7 @@ class InputMethodInference:
             print("第3步: 光标后文本")
             print("说明: 光标后的连续文本内容")
             print("示例: 我们去公园玩")
-            text_after = input("请输入光标后文本: ").strip()
+            text_after = self._safe_input("请输入光标后文本")

             if text_after.lower() in ["quit", "exit", "q"]:
                 print("退出交互模式")

@@ -458,7 +565,7 @@ class InputMethodInference:
             print("第4步: 拼音输入")
             print("说明: 当前正在输入的拼音")
             print("示例: tian, shang, hao")
-            pinyin = input("请输入拼音: ").strip()
+            pinyin = self._safe_input("请输入拼音")

             if pinyin.lower() in ["quit", "exit", "q"]:
                 print("退出交互模式")

@@ -471,7 +578,7 @@ class InputMethodInference:
             print("说明: 用户已确认的输入历史,用逗号分隔")
             print("示例: 上 (表示输入'shanghai'已确认'上')")
             print("      今天,天气 (表示已确认两个词)")
-            slot_input = input("请输入槽位历史(直接回车表示无): ").strip()
+            slot_input = self._safe_input("请输入槽位历史(直接回车表示无)")

             if slot_input.lower() in ["quit", "exit", "q"]:
                 print("退出交互模式")

@@ -524,7 +631,9 @@ class InputMethodInference:

             # Ask whether to continue
             print("\n" + "-" * 40)
-            continue_input = input("是否继续推理?(y/n): ").strip().lower()
+            continue_input = (
+                self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
+            )
             if continue_input not in ["y", "yes", ""]:
                 print("退出交互模式")
                 break

@@ -539,7 +648,9 @@ class InputMethodInference:
             traceback.print_exc()

             # Ask whether to continue
-            continue_input = input("\n是否继续?(y/n): ").strip().lower()
+            continue_input = (
+                self._safe_input("\n是否继续?(y/n)", "y").strip().lower()
+            )
             if continue_input not in ["y", "yes", ""]:
                 print("退出交互模式")
                 break
@@ -72,7 +72,7 @@ class InputMethodEngine(nn.Module):
         self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)

         # 4. Mixture-of-Experts (MoE) layer
-        self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8)
+        self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=12)

         # 5. Slot attention pooling
         self.slot_attention = nn.Linear(dim, 1)
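The `MoELayer` changed above follows the "compute all experts + top-k selection" pattern mentioned in the removed docs. As an illustrative sketch of that pattern only (plain scalars instead of tensors; `moe_combine` is a hypothetical name, not the real `MoELayer` API): every expert produces an output, then the `top_k` gate scores are kept and renormalized to weight the combination.

```python
# Sketch of top-k gating: all expert outputs are available, but only the
# top_k highest-scoring experts contribute, with their gate scores
# renormalized to sum to 1.

def moe_combine(expert_outputs, gate_scores, top_k=3):
    """Weighted sum of the top_k experts by gate score."""
    ranked = sorted(range(len(gate_scores)),
                    key=lambda i: gate_scores[i], reverse=True)
    chosen = ranked[:top_k]
    total = sum(gate_scores[i] for i in chosen)
    weights = {i: gate_scores[i] / total for i in chosen}
    return sum(expert_outputs[i] * weights[i] for i in chosen)


outputs = [1.0, 2.0, 3.0, 4.0]  # one scalar output per expert, for illustration
gates = [0.1, 0.4, 0.2, 0.3]
print(moe_combine(outputs, gates, top_k=2))
```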
test.py (24 lines changed)

@@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
     return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]


-part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常"
-part2 = "piaoliang"
+part1 = "从中国回家后,他觉得世界上最好的城市就是"
+part2 = "shanghai"
 pinyin_ids = text_to_pinyin_ids(part2)
 len_py = len(pinyin_ids)
 if len_py < 24:
@@ -56,9 +56,9 @@ if len_py < 24:
 else:
     pinyin_ids = pinyin_ids[:24]
 pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
-masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0]
+masked_labels = [22, 0, 0, 0, 0, 0, 0, 0]
 part3 = ""
-part4 = "可行|特别|伤害"
+part4 = ""

 encoded = tokenizer(
     f"{part4}|{part1}",
@@ -83,8 +83,9 @@ sample = {

 model = InputMethodEngine(pinyin_vocab_size=30, compile=False)

-checkpoint = torch.load("/home/songsenand/下载/best_model.ptrom", map_location="cpu")
+checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu")
 model.load_state_dict(checkpoint["model_state_dict"])
 model.eval()

 input_ids = sample["input_ids"]
 token_type_ids = sample["token_type_ids"]
@@ -97,12 +98,13 @@ for k, v in sample.items():
     print(f"{k}: {v}")

 start = time.time()
-res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
-print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
-sort_res = sorted(
-    [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
-)
-print(sort_res[0:5])
+with torch.no_grad():
+    res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
+    print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
+    sort_res = sorted(
+        [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
+    )
+    print(sort_res[0:5])

 query_engine = QueryEngine()
 query_engine.load()