diff --git a/README.md b/README.md index 41010e0..4bcaee6 100644 --- a/README.md +++ b/README.md @@ -790,112 +790,7 @@ train-model evaluate \ - 在评估数据集上计算准确率、困惑度等指标 - 生成详细的性能报告 -### 6.8 模型扩容两阶段训练 -当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。 - -#### 训练策略 - -1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等) -2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练 - -#### 基础用法 - -```bash -train-model expand-and-train \ - --train-data-path "path/to/train/dataset" \ - --eval-data-path "path/to/eval/dataset" \ - --base-model-path "./pretrained/model.pt" \ - --new-model-spec "model:InputMethodEngine" \ - --num-experts 40 \ - --frozen-lr 2e-3 \ - --full-lr 5e-5 \ - --frozen-patience 8 -``` - -#### 完整参数示例 - -```bash -train-model expand-and-train \ - --train-data-path "path/to/train/dataset" \ - --eval-data-path "path/to/eval/dataset" \ - --output-dir "./expansion_output" \ - --base-model-path "./pretrained/model.pt" \ - --new-model-spec "custom_model:ExpandedModel" \ - --vocab-size 10019 \ - --dim 512 \ - --num-experts 40 \ - --frozen-patience 10 \ - --frozen-lr 1e-3 \ - --full-lr 1e-4 \ - --frozen-scheduler cosine \ - --full-scheduler cosine \ - --batch-size 128 \ - --num-epochs 20 \ - --compile -``` - -#### 参数详解 - -**模型扩容参数** -- `--base-model-path`: 预训练基础模型检查点路径(必需) -- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需) - - 支持任意路径的模块导入,模块文件需包含自定义的模型类 - - 自定义模型类必须是 `InputMethodEngine` 的子类 - - 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类 - -**两阶段训练参数** -- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10) -- `--frozen-lr`: 冻结阶段学习率(默认:1e-3) -- `--full-lr`: 全量微调阶段学习率(默认:1e-4) -- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`) -- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`) - -**其他参数** -- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等) -- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等 - -#### 使用场景 - -1. **增加专家数量**(20→40) - - 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等) - - 新增参数:新专家网络、gate层 - -2. **增加top_k值**(2→3) - - 冻结效果:100% 参数可冻结(仅逻辑变化) - - 新增参数:无 - -3. **修改专家内部结构**(如增加resblocks) - - 冻结效果:~50% 参数可冻结(linear_in/output可冻结) - - 新增参数:新增的resblocks层 - -4. **增加Transformer层数**(4→5) - - 冻结效果:~80% 参数可冻结(前4层可冻结) - - 新增参数:新增的第5层 - -#### 自定义模型类示例 - -```python -# my_model.py -from model.model import InputMethodEngine - -class MyExpandedModel(InputMethodEngine): - def __init__(self, num_experts=40, **kwargs): - # 调用父类构造函数,覆盖num_experts参数 - super().__init__(num_experts=num_experts, **kwargs) - # 可以在这里添加额外的层或修改现有层 - -# 使用命令 -# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ... -``` - -#### 注意事项 - -1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类 -2. **冻结条件**:只有权重形状完全匹配的层才会被冻结 -3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能 -4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency` -5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载 ### 6.9 导出模型(开发中) diff --git a/docs/TRAINING.md b/docs/TRAINING.md index 0d48dd4..e2ecb73 100644 --- a/docs/TRAINING.md +++ b/docs/TRAINING.md @@ -479,118 +479,11 @@ train-model evaluate \ --batch-size 32 ``` -命令将显示"评估功能待实现"的提示信息。该功能计划用于: + 命令将显示"评估功能待实现"的提示信息。该功能计划用于: - 加载训练好的模型检查点 - 在评估数据集上计算准确率、困惑度等指标 - 生成详细的性能报告 -### 模型扩容两阶段训练 - -当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。 - -#### 训练策略 - -1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等) -2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练 - -#### 基础用法 - -```bash -train-model expand-and-train \ - --train-data-path "path/to/train/dataset" \ - --eval-data-path "path/to/eval/dataset" \ - --base-model-path "./pretrained/model.pt" \ - --new-model-spec "model:InputMethodEngine" \ - --num-experts 40 \ - --frozen-lr 2e-3 \ - --full-lr 5e-5 \ - --frozen-patience 8 -``` - -#### 完整参数示例 - -```bash -train-model expand-and-train \ - --train-data-path "path/to/train/dataset" \ - --eval-data-path "path/to/eval/dataset" \ - --output-dir "./expansion_output" \ - --base-model-path "./pretrained/model.pt" \ - --new-model-spec "custom_model:ExpandedModel" \ - --vocab-size 10019 \ - --dim 512 \ - --num-experts 40 \ - --frozen-patience 10 \ - --frozen-lr 1e-3 \ - --full-lr 1e-4 \ - --frozen-scheduler cosine \ - --full-scheduler cosine \ - --batch-size 128 \ - --num-epochs 20 \ - --compile -``` - -#### 参数详解 - -**模型扩容参数** -- `--base-model-path`: 预训练基础模型检查点路径(必需) -- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需) - - 支持任意路径的模块导入,模块文件需包含自定义的模型类 - - 自定义模型类必须是 `InputMethodEngine` 的子类 - - 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类 - -**两阶段训练参数** -- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10) -- `--frozen-lr`: 冻结阶段学习率(默认:1e-3) -- `--full-lr`: 全量微调阶段学习率(默认:1e-4) -- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`) -- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`) - -**其他参数** -- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等) -- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等 - -#### 使用场景 - -1. **增加专家数量**(20→40) - - 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等) - - 新增参数:新专家网络、gate层 - -2. **增加top_k值**(2→3) - - 冻结效果:100% 参数可冻结(仅逻辑变化) - - 新增参数:无 - -3. **修改专家内部结构**(如增加resblocks) - - 冻结效果:~50% 参数可冻结(linear_in/output可冻结) - - 新增参数:新增的resblocks层 - -4. **增加Transformer层数**(4→5) - - 冻结效果:~80% 参数可冻结(前4层可冻结) - - 新增参数:新增的第5层 - -#### 自定义模型类示例 - -```python -# my_model.py -from model.model import InputMethodEngine - -class MyExpandedModel(InputMethodEngine): - def __init__(self, num_experts=40, **kwargs): - # 调用父类构造函数,覆盖num_experts参数 - super().__init__(num_experts=num_experts, **kwargs) - # 可以在这里添加额外的层或修改现有层 - -# 使用命令 -# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ... -``` - -#### 注意事项 - -1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类 -2. **冻结条件**:只有权重形状完全匹配的层才会被冻结 -3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能 -4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency` -5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载 - ### 导出模型(开发中) 当前导出功能尚在开发中: diff --git a/inference.py b/inference.py index f97ebb2..e9559c1 100644 --- a/inference.py +++ b/inference.py @@ -22,6 +22,8 @@ """ import argparse +import os +import sys import time from pathlib import Path from typing import List, Optional, Tuple @@ -60,6 +62,98 @@ class InputMethodInference: print(f"✅ 推理器初始化完成 (设备: {self.device})") + # 尝试启用readline以获得更好的行编辑功能 + try: + import readline + + # 设置readline使用UTF-8编码 + readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?") + print("📝 readline已启用,支持更好的行编辑功能") + except ImportError: + print("📝 readline不可用,使用标准输入") + + def _safe_input(self, prompt: str, default: str = "") -> str: + """ + 安全的输入函数,尝试正确处理UTF-8字符和退格键 + + Args: + prompt: 提示文本 + default: 默认值 + + Returns: + 用户输入的字符串 + """ + try: + # 显示提示和默认值 + if default: + full_prompt = f"{prompt} [{default}]: " + else: + full_prompt = f"{prompt}: " + + # 使用标准input + result = input(full_prompt) + + # 如果用户直接回车且存在默认值,则返回默认值 + if not result and default: + return default + + return result.strip() + except (EOFError, KeyboardInterrupt): + # 用户按Ctrl+D或Ctrl+C + print() + return "" + except Exception as e: + # 其他错误 + print(f"\n⚠️ 输入错误: {e}") + return "" + + def _clean_pinyin_input(self, pinyin: str) -> str: + """ + 清理拼音输入字符串,处理退格键等特殊字符 + + 拼音只允许: a-z, `, ', - + 中文字符和其他字符会被忽略 + + Args: + pinyin: 原始拼音输入字符串 + + Returns: + 清理后的拼音字符串 + """ + if not pinyin: + return "" + + result = [] + for c in pinyin: + # 检查是否为合法拼音字符 (a-z, `, ', -) + # 注意: 中文字符的isalpha()也返回True,所以需要额外检查 + is_valid_pinyin_char = ( + ("a" <= c <= "z") + or ("A" <= c <= "Z") # 允许大写字母,转换为小写 + or c in ["`", "'", "-"] + ) + + if is_valid_pinyin_char: + # 合法拼音字符,转换为小写 + result.append(c.lower()) + elif c == " ": + # 空格忽略 + continue + elif c == "\b" or c == "\x7f" or c == "\x08": + # 退格键、删除键:删除前一个字符 + if result: + result.pop() + elif c == "\x1b": + # ESC键:清空所有输入 + result.clear() + else: + # 其他字符(包括中文字符)忽略 + # 注意:这里不添加到result,所以退格键无法删除它们 + # 但用户可能在拼音输入中误输入中文字符,应该忽略 + pass + + return "".join(result) + def load_model(self): """加载训练好的模型""" # 创建模型实例(不编译) @@ -184,7 +278,7 @@ class InputMethodInference: """ # 1. 构建tokenizer输入 - # 根据dataset.py,格式为: "part4|part1" 和 part3 + # 根据test.py和dataset.py,格式为: "part4|part1" 和 part3 # part4: 上下文提示(专有词汇、姓名等,模型不掌握) # part1: text_before # part3: text_after @@ -192,13 +286,14 @@ class InputMethodInference: # 处理上下文提示 context_text = "|".join(context_prompts) if context_prompts else "" - # 构建输入文本 + # 构建输入文本 - 与test.py保持一致 + # test.py: f"{part4}|{part1}" 作为第一个参数,part3作为第二个参数 if context_text: input_text = f"{context_text}|{text_before}" else: input_text = text_before - # 2. Tokenize + # 2. Tokenize - 与test.py保持一致 encoded = self.tokenizer( input_text, text_after, @@ -209,8 +304,11 @@ class InputMethodInference: return_token_type_ids=True, ) - # 3. 处理拼音输入 - pinyin_ids = text_to_pinyin_ids(pinyin) + # 3. 处理拼音输入 - 与test.py保持一致 + # 首先清理拼音字符串,处理退格键等特殊字符 + cleaned_pinyin = self._clean_pinyin_input(pinyin) + + pinyin_ids = text_to_pinyin_ids(cleaned_pinyin) if len(pinyin_ids) < 24: pinyin_ids.extend([0] * (24 - len(pinyin_ids))) else: @@ -396,7 +494,16 @@ class InputMethodInference: print("\n" + "=" * 60) print("输入法模型推理 - 交互模式") print("=" * 60) - print("说明:") + + # 检查终端编码 + encoding = sys.stdout.encoding or "unknown" + print(f"终端编码: {encoding}") + if encoding.lower() not in ["utf-8", "utf8"]: + print("⚠️ 警告: 终端编码不是UTF-8,中文输入可能有问题") + print(" 建议设置: export LANG=en_US.UTF-8") + print(" 或设置: export LC_ALL=en_US.UTF-8") + + print("\n说明:") print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)") print(" - 光标前文本: 光标前的连续文本") print(" - 光标后文本: 光标后的连续文本") @@ -411,7 +518,7 @@ class InputMethodInference: print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)") print("格式: 用逗号分隔多个词汇,可为空") print("示例: 张三,李四,北京大学") - context_input = input("请输入上下文提示(直接回车跳过): ").strip() + context_input = self._safe_input("请输入上下文提示(直接回车跳过)") if context_input.lower() in ["quit", "exit", "q"]: print("退出交互模式") @@ -434,7 +541,7 @@ class InputMethodInference: print("第2步: 光标前文本") print("说明: 光标前的连续文本内容") print("示例: 今天天气很好") - text_before = input("请输入光标前文本: ").strip() + text_before = self._safe_input("请输入光标前文本") if text_before.lower() in ["quit", "exit", "q"]: print("退出交互模式") @@ -446,7 +553,7 @@ class InputMethodInference: print("第3步: 光标后文本") print("说明: 光标后的连续文本内容") print("示例: 我们去公园玩") - text_after = input("请输入光标后文本: ").strip() + text_after = self._safe_input("请输入光标后文本") if text_after.lower() in ["quit", "exit", "q"]: print("退出交互模式") @@ -458,7 +565,7 @@ class InputMethodInference: print("第4步: 拼音输入") print("说明: 当前正在输入的拼音") print("示例: tian, shang, hao") - pinyin = input("请输入拼音: ").strip() + pinyin = self._safe_input("请输入拼音") if pinyin.lower() in ["quit", "exit", "q"]: print("退出交互模式") @@ -471,7 +578,7 @@ class InputMethodInference: print("说明: 用户已确认的输入历史,用逗号分隔") print("示例: 上 (表示输入'shanghai'已确认'上')") print(" 今天,天气 (表示已确认两个词)") - slot_input = input("请输入槽位历史(直接回车表示无): ").strip() + slot_input = self._safe_input("请输入槽位历史(直接回车表示无)") if slot_input.lower() in ["quit", "exit", "q"]: print("退出交互模式") @@ -524,7 +631,9 @@ class InputMethodInference: # 询问是否继续 print("\n" + "-" * 40) - continue_input = input("是否继续推理?(y/n): ").strip().lower() + continue_input = ( + self._safe_input("是否继续推理?(y/n)", "y").strip().lower() + ) if continue_input not in ["y", "yes", ""]: print("退出交互模式") break @@ -539,7 +648,9 @@ class InputMethodInference: traceback.print_exc() # 询问是否继续 - continue_input = input("\n是否继续?(y/n): ").strip().lower() + continue_input = ( + self._safe_input("\n是否继续?(y/n)", "y").strip().lower() + ) if continue_input not in ["y", "yes", ""]: print("退出交互模式") break diff --git a/src/model/model.py b/src/model/model.py index d20f371..6591170 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -72,7 +72,7 @@ class InputMethodEngine(nn.Module): self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads) # 4. 混合专家层 (MoE) - self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8) + self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=12) # 5. 槽位注意力池化 self.slot_attention = nn.Linear(dim, 1) diff --git a/test.py b/test.py index e0d7880..8ff8453 100644 --- a/test.py +++ b/test.py @@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]: return [CHAR_TO_ID.get(c, 0) for c in pinyin_str] -part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常" -part2 = "piaoliang" +part1 = "从中国回家后,他觉得世界上最好的城市就是" +part2 = "shanghai" pinyin_ids = text_to_pinyin_ids(part2) len_py = len(pinyin_ids) if len_py < 24: @@ -56,9 +56,9 @@ if len_py < 24: else: pinyin_ids = pinyin_ids[:24] pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0) -masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0] +masked_labels = [22, 0, 0, 0, 0, 0, 0, 0] part3 = "" -part4 = "可行|特别|伤害" +part4 = "" encoded = tokenizer( f"{part4}|{part1}", @@ -83,8 +83,9 @@ sample = { model = InputMethodEngine(pinyin_vocab_size=30, compile=False) -checkpoint = torch.load("/home/songsenand/下载/best_model.ptrom", map_location="cpu") +checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu") model.load_state_dict(checkpoint["model_state_dict"]) +model.eval() input_ids = sample["input_ids"] token_type_ids = sample["token_type_ids"] @@ -97,12 +98,13 @@ for k, v in sample.items(): print(f"{k}: {v}") start = time.time() -res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids) -print(f'计算时长: {(time.time() - start) * 1000:4f}ms') -sort_res = sorted( - [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True -) -print(sort_res[0:5]) +with torch.no_grad(): + res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids) + print(f'计算时长: {(time.time() - start) * 1000:4f}ms') + sort_res = sorted( + [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True + ) + print(sort_res[0:5]) query_engine = QueryEngine() query_engine.load()