Compare commits

..

15 Commits

Author SHA1 Message Date
songsenand 88955bcfdd refactor(model): 优化slot权重计算逻辑以提升稳定性 2026-05-23 13:42:44 +08:00
songsenand 53f244de2f docs: 详细描述history_slot_ids的设计策略与使用场景 2026-05-23 13:25:32 +08:00
songsenand 71ef54e3d4 fix(trainer): 使用固定最大序列长度的collate函数以避免内存问题 2026-05-15 14:47:31 +08:00
songsenand 722912f296 feat(data-preprocess): 预处理数据预打乱以提升训练效率 2026-05-15 13:49:41 +08:00
songsenand 0862b5b8fc fix(PreProcessedDataset): 修复数据类型转换,避免内存复制 2026-05-11 22:05:05 +08:00
songsenand 27beb7f0b1 refactor(trainer): 优化进度条逻辑与训练循环结构 2026-05-11 00:14:17 +08:00
songsenand d0f1534086 fix(dataset): 修复分片数据集时未正确计算样本数的问题 2026-05-10 23:34:08 +08:00
songsenand 483e4d4f98 fix(model): 移除 torch.compile 的注释和未使用配置 2026-05-10 10:38:14 +08:00
songsenand 432132a108 feat(MoELayer): 添加 moe_mode 支持稀疏和图内计算策略 2026-05-10 10:26:44 +08:00
songsenand e8eab1f260 refactor(generate_pinyin): 优化拼音生成逻辑,利用 pypinyin 分词能力处理多音字 2026-05-09 13:36:48 +08:00
songsenand 8b41bcdc6f feat(dataset): 引入幂律平滑方案优化频率调整逻辑 2026-04-30 08:10:34 +08:00
songsenand 4ded2d656f feat(analyze_frequency): 添加拼音字符频率分析脚本 2026-04-22 22:05:15 +08:00
songsenand 1b7da9ddd4 feat: 添加束搜索演示脚本及ONNX模型文件忽略规则 2026-04-20 11:49:25 +08:00
songsenand 710cfe7fc2 fix(dataset, trainer): 调整数据集和训练参数以提高模型效果 2026-04-16 22:35:59 +08:00
songsenand 3175ace9c5 docs: 移除模型扩容两阶段训练文档并更新相关用法说明 2026-04-13 14:09:14 +08:00
50 changed files with 50140 additions and 687 deletions

6
.gitignore vendored
View File

@ -177,3 +177,9 @@ cython_debug/
uv.lock
data/*
**/*.onnx
**/*.data
**/*.npz
**/*.pt

View File

@ -0,0 +1,229 @@
# 预处理数据预打乱方案
## 目标
用 CPU 机时预打乱数据,让训练时直接用 `shuffle=False` 顺序读取,消除跨分片缓存抖动和 CPU 利用率低的问题。
## 改动清单
### 1. `src/model/subsample.py` — 默认开启输出打乱
**函数签名变更:**
```python
def pass2_subsample(
...,
shuffle: bool = True, # 新增
seed: int = 42, # 新增
) -> Tuple[int, int]:
```
**改动点:**
- `rng = np.random.RandomState()``rng = np.random.RandomState(seed)`
- 新增 `shuffle_rng = np.random.RandomState(seed + 1)` 用于输出打乱
- 在两次 `np.savez_compressed` 写入前line ~251 和 line ~273插入
```python
if shuffle and train_buf_count > 1:
perm = shuffle_rng.permutation(train_buf_count)
for f in FIELDS:
merged[f] = merged[f][perm]
```
- main() 中新增参数:
```python
parser.add_argument("--no-shuffle", action="store_false", dest="shuffle",
help="禁用输出分片内部打乱")
parser.add_argument("--seed", type=int, default=42,
help="随机种子(用于选择+打乱)")
```
- 调用 `pass2_subsample` 时传入 `shuffle=args.shuffle, seed=args.seed`
**metadata.json 新增字段:**
```json
"pre_shuffled": true,
"seed": 42
```
---
### 2. 新建 `src/model/shuffle_npz.py` — 平衡+打乱已有数据
处理流程:
```
Phase 1: npz → 逐字段打乱 → .npymmap 友好)
输入: 19个不平衡 .npz 分片100M样本~10GB 压缩)
输出: 114个临时 .npy 文件6字段 × 19分片~144GB 未压缩)
内存峰值: ~10GB单字段 int16 最大 ~5GB + permuted copy ~5GB
耗时: ~15-20分钟解压+写入)
Phase 2: .npy → 平衡分配 → .npz
输入: Phase 1 的 .npy 文件mmap 模式)
输出: 100个平衡 .npz 分片每片100万样本~100MB/片压缩后)
每个输出分片 = 从19个源各取比例份额 → concatenate → shuffle → save
内存峰值: ~3GB1个输出缓冲 + mmap pages
耗时: ~10-15分钟mmap读取+压缩写入)
总计: ~30-40分钟峰值内存 ~10GB
```
**磁盘需求:**
- 临时 .npy 文件Phase 1→2 中间产物):~144GB
- 最终输出 .npz~10GB
- 临时文件在 Phase 2 完成后自动删除
**用法:**
```bash
python -m src.model.shuffle_npz \
--input-dir /home/songsenand/DataSet/SubPro \
--output-dir /home/songsenand/DataSet/SubPro_Shuffled \
--shard-size 1000000 \
--seed 42
```
**关键实现:**
Phase 1 — 逐字段加载+打乱:
```python
for src_idx in range(num_shards):
data = np.load(shard_path) # lazy NpzFile
n = shard_sizes[src_idx]
perm = rng.permutation(n)
for field in FIELDS:
arr = data[field].copy() # ~5GB peak (input_ids)
shuffled = arr[perm] # ~5GB temp
np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", shuffled)
del arr, shuffled
gc.collect()
data.close()
```
Phase 2 — 平衡分配:
```python
# 打开所有源 mmap读模式零内存
src_mmaps = [
{f: np.load(temp_dir / f / f"shard_{i:06d}.npy", mmap_mode='r') for f in FIELDS}
for i in range(num_shards)
]
for out_j in range(num_output_shards):
buffers = {f: [] for f in FIELDS}
for src_i in range(num_shards):
s = shard_sizes[src_i]
start = (out_j * s) // num_output_shards
end = ((out_j + 1) * s) // num_output_shards
if start >= end:
continue
for f in FIELDS:
chunk = src_mmaps[src_i][f][start:end].copy()
buffers[f].append(chunk)
output = {f: np.concatenate(buffers[f]) for f in FIELDS}
# 额外打乱(跨源混合)
perm = rng.permutation(len(output[FIELDS[0]]))
for f in FIELDS:
output[f] = output[f][perm]
np.savez_compressed(output_dir / f"shard_{out_j:06d}.npz", **output)
del output, buffers, perm
gc.collect()
```
**输出 metadata.json**
```json
{
"num_samples": 99998406,
"max_seq_length": 128,
"dtype": "int16",
"fields": [...],
"shard_size": 1000000,
"num_shards": 100,
"shard_sizes": [1000000, ..., 998406],
"pre_shuffled": true,
"seed": 42
}
```
**eval 目录处理:** 如果 `--input-dir/eval/` 存在,直接复制到 `--output-dir/eval/`eval 数据量小,不需要打乱)
---
### 3. `src/model/trainer.py` — 预处理数据禁用 shuffle
**改动点train 函数line ~1258-1272**
```python
if is_train_preprocessed:
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1)
# pre_shuffled 数据不需要 DataLoader 的 RandomSampler
shuffle_train = not train_dataset.metadata.get("pre_shuffled", False)
total_steps = (len(train_dataset) // batch_size) * num_epochs
# 支持 max_iter_length 限制总步数
if max_iter_length > 0:
max_steps_per_epoch = max_iter_length // batch_size
total_steps = min(total_steps, max_steps_per_epoch * num_epochs)
train_num_workers = min(num_workers, 1)
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=train_num_workers,
pin_memory=torch.cuda.is_available(),
shuffle=shuffle_train, # 预打乱数据不 shuffle
)
```
**eval DataLoaderline ~1295-1303**
```python
if is_eval_preprocessed:
eval_dataset = PreProcessedDataset(eval_data_path, max_cache_shards=1)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=batch_size,
num_workers=0, # eval 数据小,单进程足够
pin_memory=torch.cuda.is_available(),
shuffle=False, # eval 不需要打乱
)
```
**`create_dataloader` 函数line ~1076-1114** 无需改动,`shuffle` 参数已透传。
---
### 4. `src/model/preprocessed_dataset.py`
现有代码无需修改。`PreProcessedDataset` 已经可以正确处理 `shuffle=False` 的情况PyTorch 的 `SequentialSampler` 会按 0..N-1 顺序读取)。
`metadata["pre_shuffled"]` 字段由 subsample.py 和 shuffle_npz.py 在写入 metadata.json 时添加,训练代码读取判断即可。
---
## 执行顺序
```bash
# Step 1: 打乱并平衡已有的 100M 数据集
python -m src.model.shuffle_npz \
--input-dir /home/songsenand/DataSet/SubPro \
--output-dir /home/songsenand/DataSet/SubPro_Shuffled \
--shard-size 1000000 \
--seed 42
# Step 2: 用新数据训练(数据已打乱,顺序读取即可)
uv run train-model train \
--train-data-path /home/songsenand/DataSet/SubPro_Shuffled/train \
--eval-data-path /home/songsenand/DataSet/SubPro_Shuffled/eval \
-b 16 \
-o ~/tmp \
--eval-frequency 20 \
--save-frequency 40
```
## 预期效果
| 改动前 | 改动后 |
|---|---|
| 每 batch 跨 13 个 shard | 顺序读1 个 shard 在缓存中 |
| 每 batch 数据加载 2-3 分钟 | ~0.1-0.5 秒(纯 mmap/memory |
| CPU 利用率 10% | 正常(训练计算是瓶颈) |
| 内存 40GB+ | <20GB shard 1M 样本 1.4GB |

135
README.md
View File

@ -72,6 +72,36 @@
* **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
* **标签处理**在每一个槽位步Step模型需要预测该步对应的真实文字 ID [1]。
#### 4.1.1 历史槽位history_slot_ids的设计策略
`history_slot_ids` 的语义是**当前拼音组内已经确认的字符**,用于模拟真实的输入法输入场景。
**设计原则:**
1. **拼音组隔离**每个拼音组jieba 分词的一个词或合并后的短词组)对应一次独立的拼音输入会话。拼音组开始时 `history_slot_ids` 为空(全 0组内逐字累积。
```
输入: 特种兵 → 拼音组 "tezhongbing"
Step 1: pinyin="tezhongbing", history=[] → 预测 特
Step 2: pinyin="tezhongbing", history=[特] → 预测 种
Step 3: pinyin="tezhongbing", history=[特, 种] → 预测 兵
输入: 十分倾佩 → 拼音组 "shifen" + 拼音组 "qing" + 拼音组 "pei"
拼音组 "qing": pinyin="qing", history=[] → 预测 倾
拼音组 "pei": pinyin="pei", history=[] → 预测 佩
```
2. **上下文承载前缀字**:当拼音组切换时,前一个拼音组已确认的字符(如"特")已写入文本,通过 BERT 编码的 `input_ids`(而非 `history_slot_ids`)传递给模型。这模拟了真实输入法环境——用户打完"te"确认"特"后,文本中已有"特",再输入"zhongbing"时,模型从上下文(`input_ids`)中能看到"特"。
3. **破词续接**:当一个词被按一定概率(~10%)拆分时(`should_break=True`Phase 1 处理前缀(如"特"Phase 2 处理续接(如"种兵"。Phase 2 的 `cont_processed_history = []` 是**正确的设计**——前缀已确认为文本的一部分续接是一个新的拼音会话history 从空开始。
```
"特种兵" 断在"特"后:
Phase 1: pinyin="te", history=[] → 预测 特 → 确认为文本上下文
Phase 2: pinyin="zhongbing", history=[] → 预测 种
history=[种] → 预测 兵
```
Phase 2 中"特"已通过 `part1`(光标前文本)进入 BERT 上下文中,不在 `history_slot_ids` 中。
4. **短词合并**相邻单字词≤2 字符)有 50% 概率合并为一个拼音组。合并后组内字符共享 history 累积未合并时各自为独立拼音组history 为空)。两种方式都对应真实输入法场景——用户可能一次性输入多个字,也可能逐字输入。
### 4.2 损失函数与优化
* **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。
* **掩码机制**仅计算非填充位置Non-padding positions的损失忽略无效的时间步 [1]。
@ -790,112 +820,7 @@ train-model evaluate \
- 在评估数据集上计算准确率、困惑度等指标
- 生成详细的性能报告
### 6.8 模型扩容两阶段训练
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
#### 训练策略
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
#### 基础用法
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "model:InputMethodEngine" \
--num-experts 40 \
--frozen-lr 2e-3 \
--full-lr 5e-5 \
--frozen-patience 8
```
#### 完整参数示例
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--output-dir "./expansion_output" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "custom_model:ExpandedModel" \
--vocab-size 10019 \
--dim 512 \
--num-experts 40 \
--frozen-patience 10 \
--frozen-lr 1e-3 \
--full-lr 1e-4 \
--frozen-scheduler cosine \
--full-scheduler cosine \
--batch-size 128 \
--num-epochs 20 \
--compile
```
#### 参数详解
**模型扩容参数**
- `--base-model-path`: 预训练基础模型检查点路径(必需)
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
- 自定义模型类必须是 `InputMethodEngine` 的子类
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel`
**两阶段训练参数**
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数触发切换到全量微调默认10
- `--frozen-lr`: 冻结阶段学习率默认1e-3
- `--full-lr`: 全量微调阶段学习率默认1e-4
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
**其他参数**
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
- 继承现有的训练基础设施混合精度训练、TensorBoard日志、checkpoint保存等
#### 使用场景
1. **增加专家数量**20→40
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
- 新增参数新专家网络、gate层
2. **增加top_k值**2→3
- 冻结效果100% 参数可冻结(仅逻辑变化)
- 新增参数:无
3. **修改专家内部结构**如增加resblocks
- 冻结效果:~50% 参数可冻结linear_in/output可冻结
- 新增参数新增的resblocks层
4. **增加Transformer层数**4→5
- 冻结效果:~80% 参数可冻结前4层可冻结
- 新增参数新增的第5层
#### 自定义模型类示例
```python
# my_model.py
from model.model import InputMethodEngine
class MyExpandedModel(InputMethodEngine):
def __init__(self, num_experts=40, **kwargs):
# 调用父类构造函数覆盖num_experts参数
super().__init__(num_experts=num_experts, **kwargs)
# 可以在这里添加额外的层或修改现有层
# 使用命令
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
```
#### 注意事项
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
3. **性能保持**MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
4. **阶段切换**基于评估频率而非epoch建议适当调高 `--eval-frequency`
5. **模块导入**支持任意路径的模块通过Python标准导入机制加载
### 6.9 导出模型(开发中)

436
beam_search_demo.py Normal file
View File

@ -0,0 +1,436 @@
#!/usr/bin/env python3
"""
束搜索算法演示
展示如何使用导出的两个ONNX模型进行束搜索推理
模拟输入法场景给定上下文拼音和已确认字符生成候选汉字序列
"""
import argparse
import os
import sys
from pathlib import Path
from typing import List, Tuple, Dict, Any
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
# 添加src目录到路径
sys.path.insert(0, str(Path(__file__).parent))
class ONNXBeamSearch:
"""
基于ONNX模型的束搜索解码器
"""
def __init__(
self, context_encoder_path: str, decoder_path: str, device: str = "cpu"
):
"""
初始化
Args:
context_encoder_path: 上下文编码器ONNX模型路径
decoder_path: 解码器ONNX模型路径
device: 推理设备cpu或cuda
"""
self.device = device
# 配置ONNX Runtime提供程序
if device == "cuda":
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
# 创建ONNX Runtime会话
self.context_session = ort.InferenceSession(
context_encoder_path, providers=providers
)
self.decoder_session = ort.InferenceSession(decoder_path, providers=providers)
# 获取输入输出信息
self.context_input_info = {
input.name: input for input in self.context_session.get_inputs()
}
self.context_output_info = {
output.name: output for output in self.context_session.get_outputs()
}
self.decoder_input_info = {
input.name: input for input in self.decoder_session.get_inputs()
}
self.decoder_output_info = {
output.name: output for output in self.decoder_session.get_outputs()
}
print("📦 ONNX模型加载完成")
print(f" 上下文编码器输入: {list(self.context_input_info.keys())}")
print(f" 上下文编码器输出: {list(self.context_output_info.keys())}")
print(f" 解码器输入: {list(self.decoder_input_info.keys())}")
print(f" 解码器输出: {list(self.decoder_output_info.keys())}")
def prepare_inputs(
self,
text_before: str,
text_after: str,
pinyin: str,
context_prompts: List[str] = None,
tokenizer=None,
query_engine=None,
max_seq_len: int = 128,
) -> Dict[str, np.ndarray]:
"""
准备模型输入
注意: 这是一个简化的版本实际应用中需要实现完整的预处理逻辑
这里使用随机数据模拟实际应使用tokenizer和查询引擎
Args:
text_before: 光标前文本
text_after: 光标后文本
pinyin: 拼音输入
context_prompts: 上下文提示
tokenizer: 分词器未实现
query_engine: 查询引擎未实现
max_seq_len: 最大序列长度
Returns:
输入字典包含input_ids, pinyin_ids, attention_mask
"""
# 在实际应用中,这里应该实现:
# 1. 使用tokenizer将文本转换为input_ids
# 2. 使用text_to_pinyin_ids将拼音转换为pinyin_ids
# 3. 构建attention_mask
# 简化为随机数据(用于演示)
batch_size = 1
seq_len = min(max_seq_len, 64) # 简化长度
# 模拟输入
input_ids = np.random.randint(0, 1000, (batch_size, seq_len), dtype=np.int64)
pinyin_ids = np.random.randint(
0, 30, (batch_size, 24), dtype=np.int64
) # 固定24长度
attention_mask = np.ones((batch_size, seq_len), dtype=np.int64)
# 在实际应用中,应该这样处理拼音:
# from src.model.dataset import text_to_pinyin_ids
# pinyin_ids_list = text_to_pinyin_ids(pinyin)
# pinyin_ids_array = np.array(pinyin_ids_list, dtype=np.int64).reshape(1, -1)
return {
"input_ids": input_ids,
"pinyin_ids": pinyin_ids,
"attention_mask": attention_mask,
}
def run_context_encoder(
self, inputs: Dict[str, np.ndarray]
) -> Tuple[np.ndarray, ...]:
"""
运行上下文编码器
Args:
inputs: 输入字典
Returns:
上下文编码器输出: (context_H, pinyin_P, context_mask, pinyin_mask)
"""
# 准备ONNX输入确保顺序正确
onnx_inputs = {}
for input_name in self.context_input_info.keys():
if input_name in inputs:
onnx_inputs[input_name] = inputs[input_name]
else:
# 对于缺失的输入,使用默认值
shape = self.context_input_info[input_name].shape
dtype = self.context_input_info[input_name].type
# 简化处理:创建零数组
if "int" in dtype:
onnx_inputs[input_name] = np.zeros(shape, dtype=np.int64)
else:
onnx_inputs[input_name] = np.zeros(shape, dtype=np.float32)
# 运行推理
outputs = self.context_session.run(None, onnx_inputs)
return tuple(outputs)
def run_decoder(
self,
context_H: np.ndarray,
pinyin_P: np.ndarray,
history_slot_ids: np.ndarray,
context_mask: np.ndarray,
pinyin_mask: np.ndarray,
) -> np.ndarray:
"""
运行解码器
Args:
context_H: 上下文编码 [batch, seq_len, 512]
pinyin_P: 拼音编码 [batch, 24, 512]
history_slot_ids: 历史槽位IDs [batch, 8]
context_mask: 上下文掩码 [batch, seq_len]
pinyin_mask: 拼音掩码 [batch, 24]
Returns:
logits [batch, vocab_size]
"""
inputs = {
"context_H": context_H,
"pinyin_P": pinyin_P,
"history_slot_ids": history_slot_ids,
"context_mask": context_mask,
"pinyin_mask": pinyin_mask,
}
# 运行推理
outputs = self.decoder_session.run(None, inputs)
return outputs[0]
def beam_search(
self,
context_H: np.ndarray,
pinyin_P: np.ndarray,
context_mask: np.ndarray,
pinyin_mask: np.ndarray,
beam_size: int = 5,
max_length: int = 8,
vocab_size: int = 10019,
temperature: float = 1.0,
length_penalty: float = 0.6,
) -> List[Tuple[List[int], float]]:
"""
束搜索算法
Args:
context_H: 上下文编码
pinyin_P: 拼音编码
context_mask: 上下文掩码
pinyin_mask: 拼音掩码
beam_size: 束大小
max_length: 最大生成长度
vocab_size: 词汇表大小
temperature: 温度参数控制随机性
length_penalty: 长度惩罚系数
Returns:
排序后的候选序列列表每个元素为(序列, 分数)
"""
# 初始束空序列对数概率为0
beams = [([], 0.0)] # (序列, 累计对数概率)
for step in range(max_length):
new_beams = []
for seq, score in beams:
# 构建history_slot_ids
if len(seq) < 8:
# 填充到8个槽位
history = seq + [0] * (8 - len(seq))
else:
# 只保留最近8个字符
history = seq[-8:]
history_array = np.array([history], dtype=np.int64)
# 运行解码器
logits = self.run_decoder(
context_H, pinyin_P, history_array, context_mask, pinyin_mask
)
# 应用温度
logits = logits / temperature
# 转换为概率
probs = F.softmax(torch.from_numpy(logits[0]), dim=-1).numpy()
# 获取top-k候选k = beam_size * 2 以增加多样性)
top_k = min(beam_size * 2, vocab_size)
top_indices = np.argsort(probs)[-top_k:][::-1]
top_probs = probs[top_indices]
# 扩展束
for idx, prob in zip(top_indices, top_probs):
new_seq = seq + [int(idx)]
# 更新分数:累计对数概率 + log(prob)
new_score = score + np.log(prob + 1e-10)
# 应用长度惩罚
length_penalized_score = new_score / (
(5 + len(new_seq)) ** length_penalty
)
new_beams.append((new_seq, length_penalized_score, new_score))
# 剪枝保留beam_size个最佳候选
new_beams.sort(key=lambda x: x[1], reverse=True) # 按惩罚后分数排序
beams = [(seq, orig_score) for seq, _, orig_score in new_beams[:beam_size]]
# 检查是否所有序列都已结束结束符ID为0
all_ended = all(seq[-1] == 0 for seq, _ in beams if len(seq) > 0)
if all_ended:
break
# 返回原始分数(未应用长度惩罚)
return beams
def interactive_beam_search(self):
"""交互式束搜索演示"""
print("\n" + "=" * 60)
print("束搜索交互演示")
print("=" * 60)
print("\n📝 输入法场景模拟:")
print(" 假设用户正在输入拼音 'nihao',已确认第一个字 ''")
print(" 上下文: 光标前文本 '今天天气很好',光标后文本 '我们去公园玩'")
print(" 上下文提示: '张三,李四'(模型不掌握的专有名词)")
print("-" * 60)
# 模拟输入(实际应用中应从用户获取)
text_before = "今天天气很好"
text_after = "我们去公园玩"
pinyin = "hao" # 继续输入"hao"
slot_chars = [""] # 已确认的字符
context_prompts = ["张三", "李四"]
print(f"\n📋 输入参数:")
print(f" 光标前文本: '{text_before}'")
print(f" 光标后文本: '{text_after}'")
print(f" 拼音: '{pinyin}'")
print(f" 槽位历史: {slot_chars}")
print(f" 上下文提示: {context_prompts}")
# 准备输入(简化版,使用随机数据)
print(f"\n🔧 准备模型输入...")
inputs = self.prepare_inputs(
text_before=text_before,
text_after=text_after,
pinyin=pinyin,
context_prompts=context_prompts,
)
# 运行上下文编码器
print(f"🧠 运行上下文编码器...")
context_outputs = self.run_context_encoder(inputs)
context_H, pinyin_P, context_mask, pinyin_mask = context_outputs
print(f"✅ 上下文编码完成")
print(f" context_H形状: {context_H.shape}")
print(f" pinyin_P形状: {pinyin_P.shape}")
# 将槽位历史转换为ID简化使用随机ID
# 实际应用中应使用query_engine将汉字转换为ID
slot_ids = [42] if slot_chars else [0] # 假设'你'的ID是42
if len(slot_ids) < 8:
slot_ids = slot_ids + [0] * (8 - len(slot_ids))
# 运行束搜索
print(f"\n🔍 运行束搜索 (beam_size=3, max_length=4)...")
beams = self.beam_search(
context_H,
pinyin_P,
context_mask,
pinyin_mask,
beam_size=3,
max_length=4,
vocab_size=10019,
)
# 显示结果
print(f"\n🏆 束搜索结果:")
print("-" * 50)
for i, (seq, score) in enumerate(beams):
# 将ID序列转换为汉字简化显示ID
seq_str = " ".join([f"ID:{id}" if id != 0 else "END" for id in seq])
print(f"{i + 1}. 序列: [{seq_str}]")
print(f" 对数概率: {score:.4f}")
print(f" 概率: {np.exp(score):.6f}")
print()
print("📝 说明:")
print(" - 'END' 表示结束符 (ID: 0)")
print(" - 实际应用中应将ID转换为汉字")
print(" - 最高分数的序列作为最终预测")
return beams
def main():
parser = argparse.ArgumentParser(description="束搜索算法演示")
parser.add_argument(
"--context-encoder",
type=str,
default="./exported_models/context_encoder.onnx",
help="上下文编码器ONNX路径默认: ./exported_models/context_encoder.onnx",
)
parser.add_argument(
"--decoder",
type=str,
default="./exported_models/decoder.onnx",
help="解码器ONNX路径默认: ./exported_models/decoder.onnx",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="推理设备(默认: cpu",
)
parser.add_argument("--beam-size", type=int, default=3, help="束大小(默认: 3")
parser.add_argument(
"--max-length", type=int, default=4, help="最大生成长度(默认: 4"
)
parser.add_argument(
"--interactive",
action="store_true",
default=True,
help="交互模式(默认: True",
)
args = parser.parse_args()
# 检查模型文件是否存在
if not os.path.exists(args.context_encoder):
print(f"❌ 上下文编码器文件不存在: {args.context_encoder}")
print(f" 请先运行 export_onnx.py 导出模型")
return
if not os.path.exists(args.decoder):
print(f"❌ 解码器文件不存在: {args.decoder}")
print(f" 请先运行 export_onnx.py 导出模型")
return
print("🚀 束搜索算法演示")
print("=" * 60)
print(f"上下文编码器: {args.context_encoder}")
print(f"解码器: {args.decoder}")
print(f"设备: {args.device}")
print(f"束大小: {args.beam_size}")
print(f"最大长度: {args.max_length}")
# 初始化束搜索器
beam_searcher = ONNXBeamSearch(
context_encoder_path=args.context_encoder,
decoder_path=args.decoder,
device=args.device,
)
if args.interactive:
# 运行交互演示
beam_searcher.interactive_beam_search()
print("\n" + "=" * 60)
print("🎉 演示完成")
print("\n💡 下一步:")
print(" 1. 实现完整的输入预处理tokenizer和拼音转换")
print(" 2. 集成查询引擎以将ID转换为汉字")
print(" 3. 根据实际场景调整束搜索参数")
print(" 4. 性能优化:批量处理、缓存等")
if __name__ == "__main__":
main()

View File

@ -479,118 +479,11 @@ train-model evaluate \
--batch-size 32
```
命令将显示"评估功能待实现"的提示信息。该功能计划用于:
命令将显示"评估功能待实现"的提示信息。该功能计划用于:
- 加载训练好的模型检查点
- 在评估数据集上计算准确率、困惑度等指标
- 生成详细的性能报告
### 模型扩容两阶段训练
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
#### 训练策略
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
#### 基础用法
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "model:InputMethodEngine" \
--num-experts 40 \
--frozen-lr 2e-3 \
--full-lr 5e-5 \
--frozen-patience 8
```
#### 完整参数示例
```bash
train-model expand-and-train \
--train-data-path "path/to/train/dataset" \
--eval-data-path "path/to/eval/dataset" \
--output-dir "./expansion_output" \
--base-model-path "./pretrained/model.pt" \
--new-model-spec "custom_model:ExpandedModel" \
--vocab-size 10019 \
--dim 512 \
--num-experts 40 \
--frozen-patience 10 \
--frozen-lr 1e-3 \
--full-lr 1e-4 \
--frozen-scheduler cosine \
--full-scheduler cosine \
--batch-size 128 \
--num-epochs 20 \
--compile
```
#### 参数详解
**模型扩容参数**
- `--base-model-path`: 预训练基础模型检查点路径(必需)
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
- 自定义模型类必须是 `InputMethodEngine` 的子类
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel`
**两阶段训练参数**
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数触发切换到全量微调默认10
- `--frozen-lr`: 冻结阶段学习率默认1e-3
- `--full-lr`: 全量微调阶段学习率默认1e-4
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine``plateau`(默认:`cosine`
**其他参数**
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
- 继承现有的训练基础设施混合精度训练、TensorBoard日志、checkpoint保存等
#### 使用场景
1. **增加专家数量**20→40
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
- 新增参数新专家网络、gate层
2. **增加top_k值**2→3
- 冻结效果100% 参数可冻结(仅逻辑变化)
- 新增参数:无
3. **修改专家内部结构**如增加resblocks
- 冻结效果:~50% 参数可冻结linear_in/output可冻结
- 新增参数新增的resblocks层
4. **增加Transformer层数**4→5
- 冻结效果:~80% 参数可冻结前4层可冻结
- 新增参数新增的第5层
#### 自定义模型类示例
```python
# my_model.py
from model.model import InputMethodEngine
class MyExpandedModel(InputMethodEngine):
def __init__(self, num_experts=40, **kwargs):
# 调用父类构造函数覆盖num_experts参数
super().__init__(num_experts=num_experts, **kwargs)
# 可以在这里添加额外的层或修改现有层
# 使用命令
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
```
#### 注意事项
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
3. **性能保持**MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
4. **阶段切换**基于评估频率而非epoch建议适当调高 `--eval-frequency`
5. **模块导入**支持任意路径的模块通过Python标准导入机制加载
### 导出模型(开发中)
当前导出功能尚在开发中:

168
docs/WORD_BREAK_DESIGN.md Normal file
View File

@ -0,0 +1,168 @@
# 破词训练设计文档
## 背景
输入法用户在实际使用中通常是**逐词输入**的,而非逐字输入。例如输入"那边的特别漂亮的女孩是我的表姐"时,用户可能分词为:
```
那边 / 的 / 特别 / 漂亮 / 的 / 女孩 / 是 / 我 / 的 / 表姐
```
但为了增强模型的泛化能力,需要模拟用户**从词中间断开**的情况。例如用户可能只输入了"漂"就开始选字"亮"。
## 破词概念
### 术语定义
| 术语 | 说明 |
|------|------|
| 整词输入 | 用户输入完整词的拼音,如"piaoliang" |
| 破词输入 | 用户只输入词的部分拼音,如"piao" |
| 前缀 | 光标前的已确认文本 |
| 拼音 | 当前待选字的拼音(可能不完整) |
| 后缀 | 光标后的原文内容 |
### 场景示例
以词"漂亮"为例:
**整词模式90%概率):**
```
光标前: 那边的特别
拼音: piaoliang
预测: 漂 → 亮
```
**破词模式10%概率):**
```
光标前: 那边的特别漂
拼音: liang
预测: 亮
```
## 实现方案
### 分词策略
使用 jieba 分词器进行词语边界识别:
```python
import jieba
words = list(jieba.cut(text, HMM=False))
# "那边的特别漂亮的女孩是我的表姐。"
# → ['那边', '的', '特别', '漂亮', '的', '女孩', '是', '我', '的', '表姐', '。']
```
### 两阶段样本生成
每个词生成样本时分为两个阶段:
#### Phase 1前缀/整词阶段
- **整词90%**`prefix_positions = 整个词的所有字符`
- **破词前缀10%**`prefix_positions = 词的前 break_pos 个字符`
```python
if should_break:
break_pos = random.randint(1, word_len_chars - 1) # 随机破开位置
else:
break_pos = word_len_chars # 整词
```
#### Phase 2破词续接阶段仅当破词时
当破词发生时,从断点位置开始继续采样:
```python
if should_break and break_pos < word_len_chars:
cont_start = char_positions[break_pos]
# 从断点开始采样后续字符
target_len = random_choice(1-8) # 采样长度
cont_positions = [cont_start, ...] # 后续字符位置
```
### 样本结构
每个字符生成一个训练样本,包含:
| 字段 | 说明 | 示例 |
|------|------|------|
| `part1` (prefix) | 光标前文本 | "那边的特别漂" |
| `part2` (pinyin) | 当前字拼音 | "liang" |
| `part3` (suffix) | 光标后文本 | "亮的女孩是我的表姐" |
| `part4` | 专有词提示 | "漂亮\|特别" |
| `label` | 目标汉字ID | 1234 |
| `history_slot_ids` | 历史已确认字 | [0, 0, 0, 0, 0, 0, 0, 0] |
### 拼音增强策略
根据 `py_style_weight` 参数,拼音有以下三种形式:
| 形式 | 概率 | 示例 |
|------|------|------|
| 完整拼音 | 75% (9/12) | "piaoliang" |
| 仅声母 | 16.7% (2/12) | "pl" (通过 to_initials) |
| 仅首字母 | 8.3% (1/12) | "p" |
参数配置:`py_style_weight=(9, 2, 1)`
## 破词概率控制
### word_break_prob 参数
控制每个词从中间断开的概率,默认为 **10%**
```python
self.word_break_prob = 0.10 # 10%概率从词中间破开
```
### 破词位置分布
对于长度为 N 的词,破开位置 `break_pos` 的分布:
```python
break_pos = random.randint(1, N - 1)
```
- 2字词break_pos = 1100%在第1字后破开
- 3字词break_pos = 1 或 2各50%
- 4字词break_pos = 1, 2, 或 3各33%
## 数据分布预期
### 理想分布
| 类别 | 预期比例 |
|------|----------|
| 单字样本 | ~15% |
| 2字词整词 | ~30% |
| 3字词整词 | ~20% |
| 破词样本 | ~10% |
| 其他 | ~25% |
### 拼音不完整率
由于 `py_style_weight=(9, 2, 1)`
- 声母initials~16.7%
- 首字母:~8.3%
- **总计不完整**~25%
## 代码实现位置
主要实现文件:`src/model/dataset.py`
| 函数/类 | 行号 | 功能 |
|---------|------|------|
| `segment_text()` | ~30 | jieba分词 |
| `build_word_boundaries()` | ~35 | 建立词边界映射 |
| `PinyinInputDataset.__iter__()` | ~280 | 核心迭代逻辑 |
| `get_mask_pinyin()` | ~215 | 拼音加强处理 |
| `_add_word_samples()` | ~240 | 样本构建 |
## 注意事项
1. **破词仅针对多字词**:单字词(如"的"、“是”)不会破词
2. **破词保持语义完整**:破词后仍能根据上下文预测正确汉字
3. **历史槽位模拟逐步确认**:同一词内已确认的字会填入 `history_slot_ids`
4. **10% EOS标记**词尾有10%概率追加ID=0表示句子结束

33
eval.py
View File

@ -14,7 +14,6 @@ eval.py - 评估模型在给定文本上的表现
import argparse
import random
import re
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Optional
@ -33,8 +32,6 @@ from src.model.model import InputMethodEngine
from src.model.query import QueryEngine
from src.model.dataset import text_to_pinyin_ids
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
class TextEvaluator:
def __init__(
@ -171,34 +168,20 @@ class TextEvaluator:
def generate_pinyin(self, text: str) -> List[str]:
"""
流式处理单条文本转换为拼音列表
参考dataset.py中的generate_pinyin方法
将文本转换为拼音列表对整段文本调用 lazy_pinyin
利用 pypinyin 内部的分词能力处理多音字
参考 dataset.py 中的 generate_pinyin 方法
"""
if not text:
return []
text_len = len(text)
result: List[str] = [""] * text_len
pinyin_list = lazy_pinyin(text)
# 遍历所有连续汉字片段
for match in _HANZI_RE.finditer(text):
start_idx = match.start()
hanzi_segment = match.group()
# 健壮性兜底:若长度不匹配(极罕见),降级为逐字转换
if len(pinyin_list) != len(text):
pinyin_list = [lazy_pinyin(c)[0] for c in text]
pinyin_list = lazy_pinyin(hanzi_segment)
if len(pinyin_list) != len(hanzi_segment):
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
for i, py in enumerate(pinyin_list):
result[start_idx + i] = py
# 填充非汉字字符
for i, char in enumerate(text):
if not result[i]:
result[i] = char
return result
return pinyin_list
def get_mask_pinyin(
self, text: str, pinyin_list: List[str]

112
export.record Normal file

File diff suppressed because one or more lines are too long

86
export_onnx.py Normal file
View File

@ -0,0 +1,86 @@
#!/usr/bin/env python3
"""
输入法模型ONNX导出脚本
将模型导出为两个ONNX部分
1. context_encoder.onnx - 上下文编码器
2. decoder.onnx - 解码器
使用方法:
python export_onnx.py --checkpoint ./output/checkpoints/best_model.pt --output-dir ./exported_models
依赖安装:
pip install onnx onnxruntime
"""
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from src.model.onnx_export import check_onnx_available, run_full_export
def main():
parser = argparse.ArgumentParser(description="输入法模型ONNX导出")
parser.add_argument(
"--checkpoint", "-c", type=str, required=True, help="模型checkpoint路径"
)
parser.add_argument(
"--output-dir",
"-o",
type=str,
default="./exported_models",
help="输出目录(默认: ./exported_models",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="导出设备(默认: cpu",
)
parser.add_argument(
"--skip-verification", action="store_true", help="跳过ONNX模型验证"
)
args = parser.parse_args()
if not check_onnx_available():
sys.exit(1)
print(f"输出目录: {Path(args.output_dir).absolute()}")
context_encoder_path, decoder_path, config = run_full_export(
checkpoint_path=args.checkpoint,
output_dir=args.output_dir,
device=args.device,
skip_verification=args.skip_verification,
)
output_dir = Path(args.output_dir)
print()
print("=" * 60)
print("ONNX导出完成")
print("=" * 60)
print("生成的模型文件:")
print(f" - {context_encoder_path}")
print(f" - {decoder_path}")
print(f" - {output_dir / 'example_inputs.npz'}")
print(f" - {output_dir / 'example_inputs.pt'}")
print(f" - {output_dir / 'inference_example.py'}")
print()
print("使用方法:")
print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}")
print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py")
print(f" 3. 集成到您的应用: 参考 inference_example.py 中的 ONNXInference 类")
print()
print("注意:")
print(" - 请确保安装了 onnxruntime: pip install onnxruntime")
print(" - GPU推理需要 onnxruntime-gpu: pip install onnxruntime-gpu")
print(" - MoE 层当前使用 'all' 模式(全量计算),稀疏化优化可后续迭代")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,127 @@
#!/usr/bin/env python3
"""
ONNX模型推理示例
展示如何使用导出的两个ONNX模型进行推理
包括束搜索beam search算法
"""
import os
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple
class ONNXInference:
"""ONNX模型推理器"""
def __init__(self, context_encoder_path, decoder_path):
self.context_encoder_session = ort.InferenceSession(
context_encoder_path,
providers=['CPUExecutionProvider']
)
self.decoder_session = ort.InferenceSession(
decoder_path,
providers=['CPUExecutionProvider']
)
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
print(f"上下文编码器输入: {self.context_input_names}")
print(f"上下文编码器输出: {self.context_output_names}")
print(f"解码器输入: {self.decoder_input_names}")
print(f"解码器输出: {self.decoder_output_names}")
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
raise NotImplementedError("请实现实际的输入预处理")
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
inputs = {
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
}
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
context_H, pinyin_P, context_mask, pinyin_mask = outputs
return (
torch.from_numpy(context_H),
torch.from_numpy(pinyin_P),
torch.from_numpy(context_mask),
torch.from_numpy(pinyin_mask),
)
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
inputs = {
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
}
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
logits = outputs[0]
return torch.from_numpy(logits)
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
beam_size=5, max_length=10, vocab_size=10019):
beams = [([], 0.0)]
for step in range(max_length):
new_beams = []
for seq, score in beams:
if len(seq) < 8:
history = seq + [0] * (8 - len(seq))
else:
history = seq[-8:]
history_tensor = torch.tensor([history], dtype=torch.long)
logits = self.run_decoder(
context_H, pinyin_P, history_tensor,
context_mask, pinyin_mask
)
probs = F.softmax(logits[0], dim=-1)
top_probs, top_indices = torch.topk(probs, beam_size)
for prob, idx in zip(top_probs, top_indices):
new_seq = seq + [idx.item()]
new_score = score + torch.log(prob).item()
new_beams.append((new_seq, new_score))
new_beams.sort(key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_size]
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
if all_ended:
break
return beams
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
input_ids, pinyin_ids, attention_mask
)
logits = self.run_decoder(
context_H, pinyin_P, history_slot_ids,
context_mask, pinyin_mask
)
return logits
def main():
"""示例主函数"""
print("ONNX模型推理示例")
print("=" * 60)
context_encoder_path = "context_encoder.onnx"
decoder_path = "decoder.onnx"
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
print("错误: 找不到ONNX模型文件")
print("请先运行 train-model export 导出模型")
return
inference = ONNXInference(context_encoder_path, decoder_path)
print("\u2705 ONNX推理器初始化完成")
print("请参考此示例实现完整的输入法推理流程")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,230 @@
#!/usr/bin/env python3
"""
ONNX模型推理示例
展示如何使用导出的两个ONNX模型进行推理
包括束搜索beam search算法
"""
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple
class ONNXInference:
"""ONNX模型推理器"""
def __init__(self, context_encoder_path, decoder_path):
"""
初始化ONNX推理器
Args:
context_encoder_path: 上下文编码器ONNX模型路径
decoder_path: 解码器ONNX模型路径
"""
# 创建ONNX Runtime会话
self.context_encoder_session = ort.InferenceSession(
context_encoder_path,
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
)
self.decoder_session = ort.InferenceSession(
decoder_path,
providers=['CPUExecutionProvider']
)
# 获取输入输出名称
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
print(f"上下文编码器输入: {self.context_input_names}")
print(f"上下文编码器输出: {self.context_output_names}")
print(f"解码器输入: {self.decoder_input_names}")
print(f"解码器输出: {self.decoder_output_names}")
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
"""
准备模型输入与原始推理脚本保持一致
注意: 这里需要实现文本到token的转换
为了简化示例假设已经实现了相关函数
"""
# 这里应该调用实际的预处理函数
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
raise NotImplementedError("请实现实际的输入预处理")
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
"""
运行上下文编码器
Args:
input_ids: [batch, seq_len]
pinyin_ids: [batch, 24]
attention_mask: [batch, seq_len]
Returns:
context_H, pinyin_P, context_mask, pinyin_mask
"""
# 准备输入
inputs = {
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
}
# 运行推理
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
# 解包输出
context_H, pinyin_P, context_mask, pinyin_mask = outputs
return (
torch.from_numpy(context_H),
torch.from_numpy(pinyin_P),
torch.from_numpy(context_mask),
torch.from_numpy(pinyin_mask),
)
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
"""
运行解码器
Args:
context_H: [batch, seq_len, 512]
pinyin_P: [batch, 24, 512]
history_slot_ids: [batch, 8]
context_mask: [batch, seq_len]
pinyin_mask: [batch, 24]
Returns:
logits: [batch, vocab_size]
"""
# 准备输入
inputs = {
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
}
# 运行推理
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
# 解包输出
logits = outputs[0]
return torch.from_numpy(logits)
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
beam_size=5, max_length=10, vocab_size=10019):
"""
束搜索算法示例
Args:
context_H: 上下文编码
pinyin_P: 拼音编码
context_mask: 上下文掩码
pinyin_mask: 拼音掩码
beam_size: 束大小
max_length: 最大生成长度
vocab_size: 词汇表大小
Returns:
最佳序列列表
"""
# 初始束空序列分数为0
beams = [([], 0.0)] # (序列, 对数概率)
for step in range(max_length):
new_beams = []
for seq, score in beams:
# 构建history_slot_ids已确认的字符ID
if len(seq) < 8:
history = seq + [0] * (8 - len(seq))
else:
history = seq[-8:] # 只保留最近8个
history_tensor = torch.tensor([history], dtype=torch.long)
# 运行解码器
logits = self.run_decoder(
context_H, pinyin_P, history_tensor,
context_mask, pinyin_mask
)
# 获取概率
probs = F.softmax(logits[0], dim=-1)
# 获取top-k候选
top_probs, top_indices = torch.topk(probs, beam_size)
# 扩展束
for prob, idx in zip(top_probs, top_indices):
new_seq = seq + [idx.item()]
new_score = score + torch.log(prob).item()
new_beams.append((new_seq, new_score))
# 剪枝保留beam_size个最佳候选
new_beams.sort(key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_size]
# 检查是否所有序列都已结束以结束符0结尾
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
if all_ended:
break
return beams
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
"""
单步预测
Args:
input_ids: 输入token IDs
pinyin_ids: 拼音IDs
attention_mask: 注意力掩码
history_slot_ids: 历史槽位IDs
Returns:
预测logits
"""
# 1. 运行上下文编码器
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
input_ids, pinyin_ids, attention_mask
)
# 2. 运行解码器
logits = self.run_decoder(
context_H, pinyin_P, history_slot_ids,
context_mask, pinyin_mask
)
return logits
def main():
"""示例主函数"""
print("ONNX模型推理示例")
print("=" * 60)
# 初始化推理器
context_encoder_path = "context_encoder.onnx"
decoder_path = "decoder.onnx"
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
print("错误: 找不到ONNX模型文件")
print("请先运行export_onnx.py导出模型")
return
inference = ONNXInference(context_encoder_path, decoder_path)
print("✅ ONNX推理器初始化完成")
print("请参考此示例实现完整的输入法推理流程")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,230 @@
#!/usr/bin/env python3
"""
ONNX模型推理示例
展示如何使用导出的两个ONNX模型进行推理
包括束搜索beam search算法
"""
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple
class ONNXInference:
"""ONNX模型推理器"""
def __init__(self, context_encoder_path, decoder_path):
"""
初始化ONNX推理器
Args:
context_encoder_path: 上下文编码器ONNX模型路径
decoder_path: 解码器ONNX模型路径
"""
# 创建ONNX Runtime会话
self.context_encoder_session = ort.InferenceSession(
context_encoder_path,
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
)
self.decoder_session = ort.InferenceSession(
decoder_path,
providers=['CPUExecutionProvider']
)
# 获取输入输出名称
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
print(f"上下文编码器输入: {self.context_input_names}")
print(f"上下文编码器输出: {self.context_output_names}")
print(f"解码器输入: {self.decoder_input_names}")
print(f"解码器输出: {self.decoder_output_names}")
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
"""
准备模型输入与原始推理脚本保持一致
注意: 这里需要实现文本到token的转换
为了简化示例假设已经实现了相关函数
"""
# 这里应该调用实际的预处理函数
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
raise NotImplementedError("请实现实际的输入预处理")
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
"""
运行上下文编码器
Args:
input_ids: [batch, seq_len]
pinyin_ids: [batch, 24]
attention_mask: [batch, seq_len]
Returns:
context_H, pinyin_P, context_mask, pinyin_mask
"""
# 准备输入
inputs = {
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
}
# 运行推理
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
# 解包输出
context_H, pinyin_P, context_mask, pinyin_mask = outputs
return (
torch.from_numpy(context_H),
torch.from_numpy(pinyin_P),
torch.from_numpy(context_mask),
torch.from_numpy(pinyin_mask),
)
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
"""
运行解码器
Args:
context_H: [batch, seq_len, 512]
pinyin_P: [batch, 24, 512]
history_slot_ids: [batch, 8]
context_mask: [batch, seq_len]
pinyin_mask: [batch, 24]
Returns:
logits: [batch, vocab_size]
"""
# 准备输入
inputs = {
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
}
# 运行推理
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
# 解包输出
logits = outputs[0]
return torch.from_numpy(logits)
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
beam_size=5, max_length=10, vocab_size=10019):
"""
束搜索算法示例
Args:
context_H: 上下文编码
pinyin_P: 拼音编码
context_mask: 上下文掩码
pinyin_mask: 拼音掩码
beam_size: 束大小
max_length: 最大生成长度
vocab_size: 词汇表大小
Returns:
最佳序列列表
"""
# 初始束空序列分数为0
beams = [([], 0.0)] # (序列, 对数概率)
for step in range(max_length):
new_beams = []
for seq, score in beams:
# 构建history_slot_ids已确认的字符ID
if len(seq) < 8:
history = seq + [0] * (8 - len(seq))
else:
history = seq[-8:] # 只保留最近8个
history_tensor = torch.tensor([history], dtype=torch.long)
# 运行解码器
logits = self.run_decoder(
context_H, pinyin_P, history_tensor,
context_mask, pinyin_mask
)
# 获取概率
probs = F.softmax(logits[0], dim=-1)
# 获取top-k候选
top_probs, top_indices = torch.topk(probs, beam_size)
# 扩展束
for prob, idx in zip(top_probs, top_indices):
new_seq = seq + [idx.item()]
new_score = score + torch.log(prob).item()
new_beams.append((new_seq, new_score))
# 剪枝保留beam_size个最佳候选
new_beams.sort(key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_size]
# 检查是否所有序列都已结束以结束符0结尾
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
if all_ended:
break
return beams
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
"""
单步预测
Args:
input_ids: 输入token IDs
pinyin_ids: 拼音IDs
attention_mask: 注意力掩码
history_slot_ids: 历史槽位IDs
Returns:
预测logits
"""
# 1. 运行上下文编码器
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
input_ids, pinyin_ids, attention_mask
)
# 2. 运行解码器
logits = self.run_decoder(
context_H, pinyin_P, history_slot_ids,
context_mask, pinyin_mask
)
return logits
def main():
"""示例主函数"""
print("ONNX模型推理示例")
print("=" * 60)
# 初始化推理器
context_encoder_path = "context_encoder.onnx"
decoder_path = "decoder.onnx"
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
print("错误: 找不到ONNX模型文件")
print("请先运行export_onnx.py导出模型")
return
inference = ONNXInference(context_encoder_path, decoder_path)
print("✅ ONNX推理器初始化完成")
print("请参考此示例实现完整的输入法推理流程")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,230 @@
#!/usr/bin/env python3
"""
ONNX模型推理示例
展示如何使用导出的两个ONNX模型进行推理
包括束搜索beam search算法
"""
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple
class ONNXInference:
"""ONNX模型推理器"""
def __init__(self, context_encoder_path, decoder_path):
"""
初始化ONNX推理器
Args:
context_encoder_path: 上下文编码器ONNX模型路径
decoder_path: 解码器ONNX模型路径
"""
# 创建ONNX Runtime会话
self.context_encoder_session = ort.InferenceSession(
context_encoder_path,
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
)
self.decoder_session = ort.InferenceSession(
decoder_path,
providers=['CPUExecutionProvider']
)
# 获取输入输出名称
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
print(f"上下文编码器输入: {self.context_input_names}")
print(f"上下文编码器输出: {self.context_output_names}")
print(f"解码器输入: {self.decoder_input_names}")
print(f"解码器输出: {self.decoder_output_names}")
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
"""
准备模型输入与原始推理脚本保持一致
注意: 这里需要实现文本到token的转换
为了简化示例假设已经实现了相关函数
"""
# 这里应该调用实际的预处理函数
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
raise NotImplementedError("请实现实际的输入预处理")
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
"""
运行上下文编码器
Args:
input_ids: [batch, seq_len]
pinyin_ids: [batch, 24]
attention_mask: [batch, seq_len]
Returns:
context_H, pinyin_P, context_mask, pinyin_mask
"""
# 准备输入
inputs = {
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
}
# 运行推理
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
# 解包输出
context_H, pinyin_P, context_mask, pinyin_mask = outputs
return (
torch.from_numpy(context_H),
torch.from_numpy(pinyin_P),
torch.from_numpy(context_mask),
torch.from_numpy(pinyin_mask),
)
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
"""
运行解码器
Args:
context_H: [batch, seq_len, 512]
pinyin_P: [batch, 24, 512]
history_slot_ids: [batch, 8]
context_mask: [batch, seq_len]
pinyin_mask: [batch, 24]
Returns:
logits: [batch, vocab_size]
"""
# 准备输入
inputs = {
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
}
# 运行推理
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
# 解包输出
logits = outputs[0]
return torch.from_numpy(logits)
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
beam_size=5, max_length=10, vocab_size=10019):
"""
束搜索算法示例
Args:
context_H: 上下文编码
pinyin_P: 拼音编码
context_mask: 上下文掩码
pinyin_mask: 拼音掩码
beam_size: 束大小
max_length: 最大生成长度
vocab_size: 词汇表大小
Returns:
最佳序列列表
"""
# 初始束空序列分数为0
beams = [([], 0.0)] # (序列, 对数概率)
for step in range(max_length):
new_beams = []
for seq, score in beams:
# 构建history_slot_ids已确认的字符ID
if len(seq) < 8:
history = seq + [0] * (8 - len(seq))
else:
history = seq[-8:] # 只保留最近8个
history_tensor = torch.tensor([history], dtype=torch.long)
# 运行解码器
logits = self.run_decoder(
context_H, pinyin_P, history_tensor,
context_mask, pinyin_mask
)
# 获取概率
probs = F.softmax(logits[0], dim=-1)
# 获取top-k候选
top_probs, top_indices = torch.topk(probs, beam_size)
# 扩展束
for prob, idx in zip(top_probs, top_indices):
new_seq = seq + [idx.item()]
new_score = score + torch.log(prob).item()
new_beams.append((new_seq, new_score))
# 剪枝保留beam_size个最佳候选
new_beams.sort(key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_size]
# 检查是否所有序列都已结束以结束符0结尾
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
if all_ended:
break
return beams
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
"""
单步预测
Args:
input_ids: 输入token IDs
pinyin_ids: 拼音IDs
attention_mask: 注意力掩码
history_slot_ids: 历史槽位IDs
Returns:
预测logits
"""
# 1. 运行上下文编码器
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
input_ids, pinyin_ids, attention_mask
)
# 2. 运行解码器
logits = self.run_decoder(
context_H, pinyin_P, history_slot_ids,
context_mask, pinyin_mask
)
return logits
def main():
"""示例主函数"""
print("ONNX模型推理示例")
print("=" * 60)
# 初始化推理器
context_encoder_path = "context_encoder.onnx"
decoder_path = "decoder.onnx"
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
print("错误: 找不到ONNX模型文件")
print("请先运行export_onnx.py导出模型")
return
inference = ONNXInference(context_encoder_path, decoder_path)
print("✅ ONNX推理器初始化完成")
print("请参考此示例实现完整的输入法推理流程")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,11 @@
Frequency Analysis Results
==================================================
Min frequency: 1
Max frequency: 494748360
Mean frequency: 560007.90
Standard deviation: 5730144.34
10th percentile: 3
50th percentile: 93
90th percentile: 331538
IDs in range 5000-5500 min: 5594
IDs in range 5000-5500 max: 9569

20648
id_vs_freq.csv Normal file

File diff suppressed because it is too large Load Diff

View File

@ -22,6 +22,8 @@
"""
import argparse
import os
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple
@ -60,6 +62,98 @@ class InputMethodInference:
print(f"✅ 推理器初始化完成 (设备: {self.device})")
# 尝试启用readline以获得更好的行编辑功能
try:
import readline
# 设置readline使用UTF-8编码
readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
print("📝 readline已启用支持更好的行编辑功能")
except ImportError:
print("📝 readline不可用使用标准输入")
def _safe_input(self, prompt: str, default: str = "") -> str:
"""
安全的输入函数尝试正确处理UTF-8字符和退格键
Args:
prompt: 提示文本
default: 默认值
Returns:
用户输入的字符串
"""
try:
# 显示提示和默认值
if default:
full_prompt = f"{prompt} [{default}]: "
else:
full_prompt = f"{prompt}: "
# 使用标准input
result = input(full_prompt)
# 如果用户直接回车且存在默认值,则返回默认值
if not result and default:
return default
return result.strip()
except (EOFError, KeyboardInterrupt):
# 用户按Ctrl+D或Ctrl+C
print()
return ""
except Exception as e:
# 其他错误
print(f"\n⚠️ 输入错误: {e}")
return ""
def _clean_pinyin_input(self, pinyin: str) -> str:
"""
清理拼音输入字符串处理退格键等特殊字符
拼音只允许: a-z, `, ', -
中文字符和其他字符会被忽略
Args:
pinyin: 原始拼音输入字符串
Returns:
清理后的拼音字符串
"""
if not pinyin:
return ""
result = []
for c in pinyin:
# 检查是否为合法拼音字符 (a-z, `, ', -)
# 注意: 中文字符的isalpha()也返回True所以需要额外检查
is_valid_pinyin_char = (
("a" <= c <= "z")
or ("A" <= c <= "Z") # 允许大写字母,转换为小写
or c in ["`", "'", "-"]
)
if is_valid_pinyin_char:
# 合法拼音字符,转换为小写
result.append(c.lower())
elif c == " ":
# 空格忽略
continue
elif c == "\b" or c == "\x7f" or c == "\x08":
# 退格键、删除键:删除前一个字符
if result:
result.pop()
elif c == "\x1b":
# ESC键清空所有输入
result.clear()
else:
# 其他字符(包括中文字符)忽略
# 注意这里不添加到result所以退格键无法删除它们
# 但用户可能在拼音输入中误输入中文字符,应该忽略
pass
return "".join(result)
def load_model(self):
"""加载训练好的模型"""
# 创建模型实例(不编译)
@ -184,7 +278,7 @@ class InputMethodInference:
"""
# 1. 构建tokenizer输入
# 根据dataset.py格式为: "part4|part1" 和 part3
# 根据test.py和dataset.py格式为: "part4|part1" 和 part3
# part4: 上下文提示(专有词汇、姓名等,模型不掌握)
# part1: text_before
# part3: text_after
@ -192,13 +286,14 @@ class InputMethodInference:
# 处理上下文提示
context_text = "|".join(context_prompts) if context_prompts else ""
# 构建输入文本
# 构建输入文本 - 与test.py保持一致
# test.py: f"{part4}|{part1}" 作为第一个参数part3作为第二个参数
if context_text:
input_text = f"{context_text}|{text_before}"
else:
input_text = text_before
# 2. Tokenize
# 2. Tokenize - 与test.py保持一致
encoded = self.tokenizer(
input_text,
text_after,
@ -209,8 +304,11 @@ class InputMethodInference:
return_token_type_ids=True,
)
# 3. 处理拼音输入
pinyin_ids = text_to_pinyin_ids(pinyin)
# 3. 处理拼音输入 - 与test.py保持一致
# 首先清理拼音字符串,处理退格键等特殊字符
cleaned_pinyin = self._clean_pinyin_input(pinyin)
pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
if len(pinyin_ids) < 24:
pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
else:
@ -396,7 +494,16 @@ class InputMethodInference:
print("\n" + "=" * 60)
print("输入法模型推理 - 交互模式")
print("=" * 60)
print("说明:")
# 检查终端编码
encoding = sys.stdout.encoding or "unknown"
print(f"终端编码: {encoding}")
if encoding.lower() not in ["utf-8", "utf8"]:
print("⚠️ 警告: 终端编码不是UTF-8中文输入可能有问题")
print(" 建议设置: export LANG=en_US.UTF-8")
print(" 或设置: export LC_ALL=en_US.UTF-8")
print("\n说明:")
print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
print(" - 光标前文本: 光标前的连续文本")
print(" - 光标后文本: 光标后的连续文本")
@ -411,7 +518,7 @@ class InputMethodInference:
print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
print("格式: 用逗号分隔多个词汇,可为空")
print("示例: 张三,李四,北京大学")
context_input = input("请输入上下文提示(直接回车跳过): ").strip()
context_input = self._safe_input("请输入上下文提示(直接回车跳过)")
if context_input.lower() in ["quit", "exit", "q"]:
print("退出交互模式")
@ -434,7 +541,7 @@ class InputMethodInference:
print("第2步: 光标前文本")
print("说明: 光标前的连续文本内容")
print("示例: 今天天气很好")
text_before = input("请输入光标前文本: ").strip()
text_before = self._safe_input("请输入光标前文本")
if text_before.lower() in ["quit", "exit", "q"]:
print("退出交互模式")
@ -446,7 +553,7 @@ class InputMethodInference:
print("第3步: 光标后文本")
print("说明: 光标后的连续文本内容")
print("示例: 我们去公园玩")
text_after = input("请输入光标后文本: ").strip()
text_after = self._safe_input("请输入光标后文本")
if text_after.lower() in ["quit", "exit", "q"]:
print("退出交互模式")
@ -458,7 +565,7 @@ class InputMethodInference:
print("第4步: 拼音输入")
print("说明: 当前正在输入的拼音")
print("示例: tian, shang, hao")
pinyin = input("请输入拼音: ").strip()
pinyin = self._safe_input("请输入拼音")
if pinyin.lower() in ["quit", "exit", "q"]:
print("退出交互模式")
@ -471,7 +578,7 @@ class InputMethodInference:
print("说明: 用户已确认的输入历史,用逗号分隔")
print("示例: 上 (表示输入'shanghai'已确认''")
print(" 今天,天气 (表示已确认两个词)")
slot_input = input("请输入槽位历史(直接回车表示无): ").strip()
slot_input = self._safe_input("请输入槽位历史(直接回车表示无)")
if slot_input.lower() in ["quit", "exit", "q"]:
print("退出交互模式")
@ -524,7 +631,9 @@ class InputMethodInference:
# 询问是否继续
print("\n" + "-" * 40)
continue_input = input("是否继续推理?(y/n): ").strip().lower()
continue_input = (
self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
)
if continue_input not in ["y", "yes", ""]:
print("退出交互模式")
break
@ -539,7 +648,9 @@ class InputMethodInference:
traceback.print_exc()
# 询问是否继续
continue_input = input("\n是否继续?(y/n): ").strip().lower()
continue_input = (
self._safe_input("\n是否继续?(y/n)", "y").strip().lower()
)
if continue_input not in ["y", "yes", ""]:
print("退出交互模式")
break

763
onnx_inference.py Normal file
View File

@ -0,0 +1,763 @@
#!/usr/bin/env python3
"""
ONNX输入法模型推理脚本
使用ONNX Runtime进行推理测量每个阶段的执行时长
使用方法:
python onnx_inference.py --context-encoder exported_models/context_encoder.onnx --decoder exported_models/decoder.onnx
交互模式: 分步询问输入
1. 上下文提示: 模型不掌握的专有词汇姓名等可为空
2. 光标前文本: 光标前的连续文本
3. 光标后文本: 光标后的连续文本
4. 拼音: 当前输入的拼音
5. 槽位历史: 用户已确认的输入历史如输入shanghai已确认""
"""
import argparse
import os
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from modelscope import AutoTokenizer
from src.model.dataset import text_to_pinyin_ids
from src.model.query import QueryEngine
class ONNXInference:
"""ONNX输入法模型推理器"""
def __init__(
self,
context_encoder_path: str,
decoder_path: str,
vocab_size: int = 10019,
device: str = "cpu",
use_beam_search: bool = False,
beam_size: int = 5,
):
self.vocab_size = vocab_size
self.device = device
self.use_beam_search = use_beam_search
self.beam_size = beam_size
# 加载组件
print(f"正在加载上下文编码器: {context_encoder_path}")
load_start = time.perf_counter()
self.load_context_encoder(context_encoder_path)
self.context_encoder_load_time = (time.perf_counter() - load_start) * 1000
print(f" ✅ 上下文编码器加载完成 ({self.context_encoder_load_time:.2f}ms)")
print(f"正在加载解码器: {decoder_path}")
load_start = time.perf_counter()
self.load_decoder(decoder_path)
self.decoder_load_time = (time.perf_counter() - load_start) * 1000
print(f" ✅ 解码器加载完成 ({self.decoder_load_time:.2f}ms)")
# 加载tokenizer
print("正在加载tokenizer...")
load_start = time.perf_counter()
self.load_tokenizer()
self.tokenizer_load_time = (time.perf_counter() - load_start) * 1000
print(f" ✅ Tokenizer加载完成 ({self.tokenizer_load_time:.2f}ms)")
# 加载查询引擎
print("正在加载查询引擎...")
load_start = time.perf_counter()
self.load_query_engine()
self.query_engine_load_time = (time.perf_counter() - load_start) * 1000
print(f" ✅ 查询引擎加载完成 ({self.query_engine_load_time:.2f}ms)")
total_load_time = (
self.context_encoder_load_time
+ self.decoder_load_time
+ self.tokenizer_load_time
+ self.query_engine_load_time
)
print(f"\n✅ 推理器初始化完成 (设备: {device})")
print(f" 总加载时间: {total_load_time:.2f}ms")
# 尝试启用readline
try:
import readline
readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
except ImportError:
pass
def load_context_encoder(self, model_path: str):
"""加载上下文编码器ONNX模型"""
providers = (
["CUDAExecutionProvider", "CPUExecutionProvider"]
if self.device == "cuda"
else ["CPUExecutionProvider"]
)
self.context_encoder_session = ort.InferenceSession(
model_path, providers=providers
)
self.context_input_names = [
inp.name for inp in self.context_encoder_session.get_inputs()
]
self.context_output_names = [
out.name for out in self.context_encoder_session.get_outputs()
]
def load_decoder(self, model_path: str):
"""加载解码器ONNX模型"""
providers = (
["CUDAExecutionProvider", "CPUExecutionProvider"]
if self.device == "cuda"
else ["CPUExecutionProvider"]
)
self.decoder_session = ort.InferenceSession(model_path, providers=providers)
self.decoder_input_names = [
inp.name for inp in self.decoder_session.get_inputs()
]
self.decoder_output_names = [
out.name for out in self.decoder_session.get_outputs()
]
def load_tokenizer(self):
"""加载tokenizer"""
try:
tokenizer_path = (
Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
)
self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
except Exception:
print(" ⚠️ 无法加载自定义tokenizer使用bert-base-chinese")
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
def load_query_engine(self):
"""加载查询引擎"""
try:
self.query_engine = QueryEngine()
stats_path = (
Path(__file__).parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
if stats_path.exists():
self.query_engine.load(stats_path)
except Exception:
self.query_engine = None
def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int:
"""将汉字转换为ID"""
if char == "//":
return 0
if self.query_engine is None:
return ord(char) if len(char) == 1 else 0
try:
if pinyin is not None:
info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin)
if info:
return info.id
results = self.query_engine.query_by_char(char, limit=1)
if results:
return results[0][0]
return 0
except:
return 0
def id_to_char(self, id: int) -> str:
"""将ID转换为汉字"""
if id == 0:
return "//"
if self.query_engine is None:
return chr(id) if id < 0x110000 else f"[ID:{id}]"
try:
info = self.query_engine.query_by_id(id)
return info.char if info else f"[ID:{id}]"
except:
return f"[ID:{id}]"
def _clean_pinyin_input(self, pinyin: str) -> str:
"""清理拼音输入字符串"""
if not pinyin:
return ""
result = []
for c in pinyin:
is_valid = ("a" <= c <= "z") or ("A" <= c <= "Z") or c in ["`", "'", "-"]
if is_valid:
result.append(c.lower())
elif c == " ":
continue
elif c in ["\b", "\x7f", "\x08"]:
if result:
result.pop()
elif c == "\x1b":
result.clear()
return "".join(result)
def _safe_input(self, prompt: str, default: str = "") -> str:
"""安全的输入函数"""
try:
full_prompt = f"{prompt} [{default}]: " if default else f"{prompt}: "
result = input(full_prompt)
if not result and default:
return default
return result.strip()
except (EOFError, KeyboardInterrupt):
print()
return ""
except Exception as e:
print(f"\n⚠️ 输入错误: {e}")
return ""
def prepare_inputs(
self,
context_prompts: List[str],
text_before: str,
text_after: str,
pinyin: str,
slot_chars: List[str],
max_seq_len: int = 128,
) -> dict:
"""
准备模型输入
Returns:
dict: {
'preprocess_time': float, # 预处理时间(ms)
'input_ids': numpy array,
'attention_mask': numpy array,
'pinyin_ids': numpy array,
'history_slot_ids': numpy array,
}
"""
preprocess_start = time.perf_counter()
# 1. 构建tokenizer输入
context_text = "|".join(context_prompts) if context_prompts else ""
if context_text:
input_text = f"{context_text}|{text_before}"
else:
input_text = text_before
# 2. Tokenize
encoded = self.tokenizer(
input_text,
text_after,
max_length=max_seq_len,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=True,
)
input_ids = encoded["input_ids"].numpy()
attention_mask = encoded["attention_mask"].numpy()
# 3. 处理拼音输入
cleaned_pinyin = self._clean_pinyin_input(pinyin)
pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
if len(pinyin_ids) < 24:
pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
else:
pinyin_ids = pinyin_ids[:24]
pinyin_ids = np.array([pinyin_ids], dtype=np.int64)
# 4. 处理历史槽位
history_slot_ids = []
for char in slot_chars:
char_id = self.char_to_id(char)
history_slot_ids.append(char_id)
if len(history_slot_ids) < 8:
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
else:
history_slot_ids = history_slot_ids[:8]
history_slot_ids = np.array([history_slot_ids], dtype=np.int64)
preprocess_time = (time.perf_counter() - preprocess_start) * 1000
return {
"preprocess_time": preprocess_time,
"input_ids": input_ids,
"attention_mask": attention_mask,
"pinyin_ids": pinyin_ids,
"history_slot_ids": history_slot_ids,
}
def run_context_encoder(
self, input_ids: np.ndarray, pinyin_ids: np.ndarray, attention_mask: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
运行上下文编码器
Returns:
context_H, pinyin_P, context_mask, pinyin_mask
"""
context_start = time.perf_counter()
inputs = {
"input_ids": input_ids,
"pinyin_ids": pinyin_ids,
"attention_mask": attention_mask,
}
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
context_H, pinyin_P, context_mask, pinyin_mask = outputs
self.last_context_encoder_time = (time.perf_counter() - context_start) * 1000
return context_H, pinyin_P, context_mask, pinyin_mask
def run_decoder(
self,
context_H: np.ndarray,
pinyin_P: np.ndarray,
history_slot_ids: np.ndarray,
context_mask: np.ndarray,
pinyin_mask: np.ndarray,
) -> np.ndarray:
"""
运行解码器
Returns:
logits: [batch, vocab_size]
"""
decoder_start = time.perf_counter()
inputs = {
"context_H": context_H,
"pinyin_P": pinyin_P,
"history_slot_ids": history_slot_ids,
"context_mask": context_mask,
"pinyin_mask": pinyin_mask,
}
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
self.last_decoder_time = (time.perf_counter() - decoder_start) * 1000
return outputs[0]
def predict(
self,
context_prompts: List[str],
text_before: str,
text_after: str,
pinyin: str,
slot_chars: List[str],
top_k: int = 20,
use_beam_search: bool = False,
beam_size: int = 5,
max_length: int = 10,
) -> Tuple[List[Tuple[str, float, int]], dict]:
"""
执行推理
Args:
context_prompts: 上下文提示
text_before: 光标前文本
text_after: 光标后文本
pinyin: 当前输入的拼音
slot_chars: 槽位内的汉字列表
top_k: 返回top-k个预测结果
use_beam_search: 是否使用束搜索
beam_size: 束大小
max_length: 最大生成长度
Returns:
(predictions, timing_info)
predictions: List[Tuple[char, prob, id]]
timing_info: 各阶段耗时字典
"""
total_start = time.perf_counter()
# 阶段1: 预处理
prep_start = time.perf_counter()
inputs = self.prepare_inputs(
context_prompts, text_before, text_after, pinyin, slot_chars
)
preprocess_time = inputs["preprocess_time"]
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
pinyin_ids = inputs["pinyin_ids"]
history_slot_ids = inputs["history_slot_ids"]
prep_time = (time.perf_counter() - prep_start) * 1000
# 阶段2: 上下文编码
context_start = time.perf_counter()
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
input_ids, pinyin_ids, attention_mask
)
context_encoder_time = self.last_context_encoder_time
if use_beam_search:
# 阶段3: 束搜索解码
decode_start = time.perf_counter()
predictions, beam_decode_time = self._beam_search_decode(
context_H,
pinyin_P,
context_mask,
pinyin_mask,
beam_size,
max_length,
top_k,
)
decoder_time = beam_decode_time
else:
# 阶段3: 单步解码
decode_start = time.perf_counter()
logits = self.run_decoder(
context_H,
pinyin_P,
history_slot_ids,
context_mask,
pinyin_mask,
)
# 阶段4: 后处理
postprocess_start = time.perf_counter()
probs = self._softmax(logits)
top_indices, top_probs = self._topk(probs, top_k)
predictions = []
for i in range(top_k):
idx = int(top_indices[0, i])
prob = float(top_probs[0, i])
char = self.id_to_char(idx)
predictions.append((char, prob, idx))
postprocess_time = (time.perf_counter() - postprocess_start) * 1000
decoder_time = self.last_decoder_time
total_time = (time.perf_counter() - total_start) * 1000
timing_info = {
"预处理": prep_time,
"上下文编码": context_encoder_time,
"解码": decoder_time,
"后处理": postprocess_time if not use_beam_search else 0,
"总耗时": total_time,
}
return predictions, timing_info
def _softmax(self, logits: np.ndarray) -> np.ndarray:
"""计算softmax"""
exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
def _topk(self, probs: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
"""获取top-k"""
topk_indices = np.argsort(probs, axis=-1)[:, -k:][:, ::-1]
topk_probs = np.take_along_axis(probs, topk_indices, axis=-1)
return topk_indices, topk_probs
def _beam_search_decode(
self,
context_H: np.ndarray,
pinyin_P: np.ndarray,
context_mask: np.ndarray,
pinyin_mask: np.ndarray,
beam_size: int,
max_length: int,
top_k: int,
) -> Tuple[List[Tuple[str, float, int]], float]:
"""束搜索解码"""
beams = [([], 0.0)] # (序列, 对数概率)
for step in range(max_length):
new_beams = []
for seq, score in beams:
if len(seq) < 8:
history = seq + [0] * (8 - len(seq))
else:
history = seq[-8:]
history_tensor = np.array([history], dtype=np.int64)
logits = self.run_decoder(
context_H,
pinyin_P,
history_tensor,
context_mask,
pinyin_mask,
)
probs = self._softmax(logits)[0]
topk_indices = np.argsort(probs)[-beam_size:][::-1]
topk_probs = probs[topk_indices]
for idx, prob in zip(topk_indices, topk_probs):
new_seq = seq + [int(idx)]
new_score = score + np.log(prob + 1e-10)
new_beams.append((new_seq, new_score))
new_beams.sort(key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_size]
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
if all_ended:
break
# 返回top-k个候选
predictions = []
for seq, score in beams[:top_k]:
if seq:
char = self.id_to_char(seq[-1])
prob = np.exp(score / max(len(seq), 1))
else:
char = self.id_to_char(0)
prob = 0.0
predictions.append((char, prob, seq[-1] if seq else 0))
decode_time = self.last_decoder_time # 只记录最后一次解码的时间
return predictions, decode_time
def interactive_mode(self):
"""交互式推理模式"""
print("\n" + "=" * 60)
print("ONNX输入法模型推理 - 交互模式")
print("=" * 60)
encoding = sys.stdout.encoding or "unknown"
print(f"终端编码: {encoding}")
print("\n说明:")
print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
print(" - 光标前文本: 光标前的连续文本")
print(" - 光标后文本: 光标后的连续文本")
print(" - 拼音: 当前输入的拼音")
print(" - 槽位历史: 用户已确认的输入历史")
if self.use_beam_search:
print(f" - 解码模式: 束搜索 (beam_size={self.beam_size})")
else:
print(" - 解码模式: 单步解码 (使用 --beam 启用束搜索)")
print("提示: 输入 'quit''exit''q' 可随时退出")
print("-" * 60)
while True:
try:
print("\n" + "=" * 60)
context_input = self._safe_input("第1步: 上下文提示(直接回车跳过)")
if context_input.lower() in ["quit", "exit", "q"]:
break
context_prompts = [
item.strip() for item in context_input.split(",") if item.strip()
]
print("\n" + "-" * 40)
text_before = self._safe_input("第2步: 光标前文本")
if text_before.lower() in ["quit", "exit", "q"]:
break
print("\n" + "-" * 40)
text_after = self._safe_input("第3步: 光标后文本")
if text_after.lower() in ["quit", "exit", "q"]:
break
print("\n" + "-" * 40)
pinyin = self._safe_input("第4步: 拼音输入")
if pinyin.lower() in ["quit", "exit", "q"]:
break
print("\n" + "-" * 40)
slot_input = self._safe_input("第5步: 槽位历史(直接回车表示无)")
if slot_input.lower() in ["quit", "exit", "q"]:
break
slot_chars = [
char.strip() for char in slot_input.split(",") if char.strip()
]
print("\n" + "=" * 60)
print("📝 输入汇总:")
print(f" 上下文提示: {context_prompts if context_prompts else ''}")
print(f" 光标前文本: '{text_before}'")
print(f" 光标后文本: '{text_after}'")
print(f" 拼音: '{pinyin}'")
print(f" 槽位历史: {slot_chars if slot_chars else ''}")
print("\n🔮 推理中...")
predictions, timing_info = self.predict(
context_prompts,
text_before,
text_after,
pinyin,
slot_chars,
top_k=20,
use_beam_search=self.use_beam_search,
beam_size=self.beam_size,
)
# 显示时间统计
print(f"\n⏱️ 执行时间统计:")
print("-" * 40)
for stage, duration in timing_info.items():
if duration > 0:
print(f" {stage:<12}: {duration:>8.2f} ms")
print("-" * 40)
# 显示结果
print("\n🏆 Top-20 预测结果:")
print("-" * 50)
for i, (char, prob, idx) in enumerate(predictions):
if char == "//":
print(f"{i + 1:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}")
else:
print(
f"{i + 1:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}"
)
# 显示拼音参考
if pinyin and self.query_engine:
print(f"\n📖 拼音 '{pinyin}' 的常见汉字:")
pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10)
if pinyin_results:
for j, (pid, char, count) in enumerate(pinyin_results):
print(f" {char} (频次: {count:,})")
print("\n" + "-" * 40)
continue_input = (
self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
)
if continue_input not in ["y", "yes", ""]:
break
except KeyboardInterrupt:
print("\n\n退出交互模式")
break
except Exception as e:
print(f"\n❌ 推理出错: {e}")
import traceback
traceback.print_exc()
def main():
parser = argparse.ArgumentParser(description="ONNX输入法模型推理")
parser.add_argument(
"--context-encoder",
"-c",
type=str,
required=True,
help="上下文编码器ONNX模型路径",
)
parser.add_argument(
"--decoder",
"-d",
type=str,
required=True,
help="解码器ONNX模型路径",
)
parser.add_argument(
"--vocab-size",
type=int,
default=10019,
help="词汇表大小 (默认: 10019)",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="推理设备 (默认: cpu)",
)
parser.add_argument(
"--interactive",
action="store_true",
default=True,
help="交互模式 (默认: True)",
)
parser.add_argument("--test", action="store_true", help="运行测试推理")
parser.add_argument(
"--beam",
action="store_true",
help="使用束搜索解码 (默认: 单步解码)",
)
parser.add_argument(
"--beam-size",
type=int,
default=5,
help="束大小 (默认: 5)",
)
args = parser.parse_args()
# 检查文件是否存在
if not os.path.exists(args.context_encoder):
print(f"❌ 错误: 上下文编码器文件不存在: {args.context_encoder}")
sys.exit(1)
if not os.path.exists(args.decoder):
print(f"❌ 错误: 解码器文件不存在: {args.decoder}")
sys.exit(1)
# 初始化推理器
inference = ONNXInference(
context_encoder_path=args.context_encoder,
decoder_path=args.decoder,
vocab_size=args.vocab_size,
device=args.device,
use_beam_search=args.beam,
beam_size=args.beam_size,
)
# 测试推理
if args.test:
print("\n🧪 运行测试推理...")
print("测试场景: 输入'shanghai',已确认第一个字'',继续输入'tian'")
print("上下文提示: 张三、李四(模型不掌握的专有名词)")
predictions, timing_info = inference.predict(
context_prompts=["张三", "李四"],
text_before="今天天气",
text_after="很好",
pinyin="tian",
slot_chars=[""],
use_beam_search=args.beam,
beam_size=args.beam_size,
)
print(f"\n⏱️ 执行时间统计:")
print("-" * 40)
for stage, duration in timing_info.items():
if duration > 0:
print(f" {stage:<12}: {duration:>8.2f} ms")
print("-" * 40)
print(f"\nTop-5 结果:")
for i, (char, prob, idx) in enumerate(predictions[:5]):
if char == "//":
print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}")
else:
print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}")
# 交互模式
if args.interactive:
inference.interactive_mode()
if __name__ == "__main__":
main()

View File

@ -14,6 +14,7 @@ dependencies = [
"onnxruntime>=1.24.2",
"pandas>=3.0.0",
"plotly>=5.0.0",
"jieba>=0.42.1",
"pypinyin>=0.55.0",
"requests>=2.32.5",
"rich>=14.3.1",
@ -23,11 +24,14 @@ dependencies = [
"transformers==5.1.0",
"typer>=0.21.1",
"waitress>=3.0.2",
"onnx>=1.21.0",
]
[project.scripts]
train-model = "model.trainer:app"
monitor-training = "model.monitor:app"
preprocess-model = "model.preprocess:main"
inspect-preprocessed = "model.inspect_preprocessed:main"
[tool.uv]
# 设置当前项目的默认索引源

20648
rank_freq.csv Normal file

File diff suppressed because it is too large Load Diff

0
samples.csv Normal file
View File

View File

@ -0,0 +1,266 @@
#!/usr/bin/env python3
"""
Analyze frequency distribution in pinyin_char_statistics.json
"""
import json
import sys
import os
import math
from collections import defaultdict
from pathlib import Path
def main():
# Path to the JSON file
json_path = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
if not json_path.exists():
print(f"Error: File not found: {json_path}")
sys.exit(1)
print(f"Loading {json_path}...")
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
print(f"Timestamp: {data.get('timestamp')}")
print(f"Total characters: {data.get('total_characters')}")
print(f"Total pinyins: {data.get('total_pinyins')}")
print(f"Valid input character count: {data.get('valid_input_character_count')}")
pairs = data.get("pairs", {})
print(f"Number of pairs: {len(pairs)}")
# Extract counts and IDs
counts = []
id_to_count = {}
char_to_count = {}
for key, pair in pairs.items():
try:
char_id = pair.get("id")
count = pair.get("count")
char = pair.get("char", "")
if char_id is not None and count is not None:
counts.append(count)
id_to_count[char_id] = count
if char:
char_to_count[char] = count
except (ValueError, TypeError) as e:
print(f"Warning: Could not parse pair {key}: {e}")
continue
if not counts:
print("No valid count data found.")
return
# Basic statistics
min_count = min(counts)
max_count = max(counts)
total_count = sum(counts)
mean_count = total_count / len(counts)
# Sort counts for percentiles
sorted_counts = sorted(counts)
n = len(sorted_counts)
# Percentiles
p10 = sorted_counts[int(0.1 * n)]
p25 = sorted_counts[int(0.25 * n)]
p50 = sorted_counts[int(0.5 * n)]
p75 = sorted_counts[int(0.75 * n)]
p90 = sorted_counts[int(0.9 * n)]
p99 = sorted_counts[int(0.99 * n)]
# Variance and std dev
variance = sum((x - mean_count) ** 2 for x in counts) / n
std_dev = math.sqrt(variance)
print("\n=== BASIC STATISTICS ===")
print(f"Min frequency: {min_count}")
print(f"Max frequency: {max_count}")
print(f"Mean frequency: {mean_count:.2f}")
print(f"Standard deviation: {std_dev:.2f}")
print(f"Total frequency sum: {total_count}")
print(f"Number of entries: {n}")
print("\n=== PERCENTILES ===")
print(f"10th percentile: {p10}")
print(f"25th percentile: {p25}")
print(f"50th percentile (median): {p50}")
print(f"75th percentile: {p75}")
print(f"90th percentile: {p90}")
print(f"99th percentile: {p99}")
# Find IDs with min and max counts
min_ids = [id for id, count in id_to_count.items() if count == min_count]
max_ids = [id for id, count in id_to_count.items() if count == max_count]
print(f"\nIDs with min frequency ({min_count}): {min_ids}")
print(f"IDs with max frequency ({max_count}): {max_ids}")
# Check if IDs are assigned in frequency order
# Compute correlation between ID and count
ids = list(id_to_count.keys())
id_counts = [id_to_count[id] for id in ids]
# Sort by ID and check if counts are decreasing
sorted_by_id = sorted(ids)
counts_by_id = [id_to_count[id] for id in sorted_by_id]
# Calculate monotonicity: count of times count decreases as ID increases
decreases = 0
increases = 0
for i in range(1, len(counts_by_id)):
if counts_by_id[i] < counts_by_id[i - 1]:
decreases += 1
elif counts_by_id[i] > counts_by_id[i - 1]:
increases += 1
print(f"\n=== ID ORDER ANALYSIS ===")
print(f"Total pairs: {len(counts_by_id)}")
print(f"Decreases as ID increases: {decreases} times")
print(f"Increases as ID increases: {increases} times")
print(f"Percentage decreasing: {decreases / (len(counts_by_id) - 1) * 100:.2f}%")
# Check if IDs are roughly sorted by frequency
# Compute Spearman rank correlation (simplified)
sorted_by_count = sorted(ids, key=lambda x: id_to_count[x], reverse=True)
rank_by_id = {id: i for i, id in enumerate(sorted_by_id)}
rank_by_count = {id: i for i, id in enumerate(sorted_by_count)}
# Average rank difference
rank_diffs = [abs(rank_by_id[id] - rank_by_count[id]) for id in ids]
avg_rank_diff = sum(rank_diffs) / len(rank_diffs)
max_rank_diff = max(rank_diffs)
print(
f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}"
)
print(f"Maximum rank difference: {max_rank_diff}")
# Analyze specific ID range 5000-5500
print("\n=== ANALYSIS OF ID RANGE 5000-5500 ===")
range_counts = []
range_ids = []
for id in range(5000, 5501):
if id in id_to_count:
range_counts.append(id_to_count[id])
range_ids.append(id)
if range_counts:
range_min = min(range_counts)
range_max = max(range_counts)
range_mean = sum(range_counts) / len(range_counts)
range_sorted = sorted(range_counts)
range_n = len(range_counts)
range_p10 = range_sorted[int(0.1 * range_n)] if range_n > 0 else 0
range_p50 = range_sorted[int(0.5 * range_n)] if range_n > 0 else 0
range_p90 = range_sorted[int(0.9 * range_n)] if range_n > 0 else 0
print(f"IDs in range 5000-5500: {len(range_counts)}")
print(f"Min frequency in range: {range_min}")
print(f"Max frequency in range: {range_max}")
print(f"Mean frequency in range: {range_mean:.2f}")
print(f"10th percentile in range: {range_p10}")
print(f"50th percentile in range: {range_p50}")
print(f"90th percentile in range: {range_p90}")
# Find IDs with min frequency in this range
min_in_range_ids = [id for id in range_ids if id_to_count[id] == range_min]
print(
f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}"
)
else:
print("No IDs found in range 5000-5500")
# Histogram of frequencies (log bins)
print("\n=== FREQUENCY DISTRIBUTION (LOG BINS) ===")
if max_count > 0:
log_min = math.log10(min_count) if min_count > 0 else 0
log_max = math.log10(max_count)
num_bins = 20
bin_edges = [
10 ** (log_min + i * (log_max - log_min) / num_bins)
for i in range(num_bins + 1)
]
hist = [0] * num_bins
for count in counts:
if count > 0:
log_val = math.log10(count)
bin_idx = min(
int((log_val - log_min) / (log_max - log_min) * num_bins),
num_bins - 1,
)
hist[bin_idx] += 1
print("Log-scale histogram (count range -> frequency count):")
for i in range(num_bins):
if hist[i] > 0:
lower = bin_edges[i]
upper = bin_edges[i + 1]
print(f" {lower:.2e} - {upper:.2e}: {hist[i]} entries")
# Check for zero or near-zero frequencies
zero_count = sum(1 for c in counts if c == 0)
low_count = sum(1 for c in counts if 0 < c <= 10)
very_low_count = sum(1 for c in counts if 0 < c <= 100)
print(f"\n=== LOW FREQUENCY ANALYSIS ===")
print(f"Entries with zero frequency: {zero_count}")
print(f"Entries with frequency <= 10: {low_count}")
print(f"Entries with frequency <= 100: {very_low_count}")
# Find the actual min frequency (excluding zeros if any)
non_zero_counts = [c for c in counts if c > 0]
if non_zero_counts:
actual_min = min(non_zero_counts)
print(f"Actual min frequency (non-zero): {actual_min}")
actual_min_ids = [
id for id, count in id_to_count.items() if count == actual_min
]
print(
f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}"
)
# Summary for smoothing algorithm design
print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===")
print(
f"Frequency range spans {max_count / min_count if min_count > 0 else 'inf'}:1 ratio"
)
print(f"Most entries ({p50}) have frequency around {p50}")
print(f"Top 10% of entries have frequency > {p90}")
print(f"Bottom 10% of entries have frequency < {p10}")
print(
f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency"
)
# Save detailed data for further analysis
output_file = "frequency_analysis_results.txt"
with open(output_file, "w", encoding="utf-8") as f:
f.write("Frequency Analysis Results\n")
f.write("=" * 50 + "\n")
f.write(f"Min frequency: {min_count}\n")
f.write(f"Max frequency: {max_count}\n")
f.write(f"Mean frequency: {mean_count:.2f}\n")
f.write(f"Standard deviation: {std_dev:.2f}\n")
f.write(f"10th percentile: {p10}\n")
f.write(f"50th percentile: {p50}\n")
f.write(f"90th percentile: {p90}\n")
f.write(
f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n"
)
f.write(
f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n"
)
print(f"\nDetailed results saved to {output_file}")
if __name__ == "__main__":
main()

135
scripts/analyze_range.py Normal file
View File

@ -0,0 +1,135 @@
#!/usr/bin/env python3
"""
Analyze specific ID ranges in pinyin_char_statistics.json
"""
import json
import sys
from pathlib import Path
def main():
json_path = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
pairs = data.get("pairs", {})
# Build ID to count mapping
id_to_count = {}
for key, pair in pairs.items():
char_id = pair.get("id")
count = pair.get("count")
if char_id is not None and count is not None:
id_to_count[char_id] = count
# Analyze range 5000-5500 in detail
print("ID range 5000-5500 detailed analysis:")
print("ID\tCount\tChar\tPinyin")
range_data = []
for id in range(5000, 5501):
if id in id_to_count:
# Find the pair to get char and pinyin
for key, pair in pairs.items():
if pair.get("id") == id:
char = pair.get("char", "")
pinyin = pair.get("pinyin", "")
count = pair.get("count", 0)
range_data.append((id, count, char, pinyin))
if id % 100 == 0: # Print every 100th for overview
print(f"{id}\t{count}\t{char}\t{pinyin}")
break
# Print min and max in range
if range_data:
min_item = min(range_data, key=lambda x: x[1])
max_item = max(range_data, key=lambda x: x[1])
print(
f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'"
)
print(
f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'"
)
# Check if frequencies are monotonic in this range
counts = [item[1] for item in range_data]
increasing = all(counts[i] <= counts[i + 1] for i in range(len(counts) - 1))
decreasing = all(counts[i] >= counts[i + 1] for i in range(len(counts) - 1))
print(f"Monotonic in range: increasing={increasing}, decreasing={decreasing}")
# Check for frequency plateaus
from collections import Counter
freq_count = Counter(counts)
most_common = freq_count.most_common(5)
print(f"Most common frequencies in range: {most_common}")
# Analyze the tail (IDs with frequency 1)
print("\n\nAnalysis of frequency=1 entries:")
freq_one_ids = [id for id, count in id_to_count.items() if count == 1]
print(f"Number of entries with frequency=1: {len(freq_one_ids)}")
if freq_one_ids:
print(f"ID range of frequency=1: {min(freq_one_ids)} to {max(freq_one_ids)}")
print(f"First 10 IDs: {freq_one_ids[:10]}")
print(f"Last 10 IDs: {freq_one_ids[-10:]}")
# Check if they're contiguous
sorted_ids = sorted(freq_one_ids)
contiguous = all(
sorted_ids[i] + 1 == sorted_ids[i + 1] for i in range(len(sorted_ids) - 1)
)
print(f"Are they contiguous IDs? {contiguous}")
# Sample some characters
print("\nSample characters with frequency=1:")
sample_count = 0
for key, pair in pairs.items():
if pair.get("count") == 1 and sample_count < 10:
print(
f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'"
)
sample_count += 1
# Check overall ID-frequency ordering
print("\n\nOverall ID-frequency ordering analysis:")
all_ids = sorted(id_to_count.keys())
all_counts = [id_to_count[id] for id in all_ids]
# Count monotonic segments
non_increasing_segments = 0
current_segment_length = 1
for i in range(1, len(all_counts)):
if all_counts[i] <= all_counts[i - 1]:
current_segment_length += 1
else:
if current_segment_length > 1:
non_increasing_segments += 1
current_segment_length = 1
if current_segment_length > 1:
non_increasing_segments += 1
print(f"Total IDs: {len(all_ids)}")
print(f"Non-increasing segments: {non_increasing_segments}")
# Check for frequency plateaus overall
from collections import Counter
overall_freq_count = Counter(all_counts)
plateaus = [
(freq, count) for freq, count in overall_freq_count.items() if count > 1
]
plateaus_sorted = sorted(plateaus, key=lambda x: x[1], reverse=True)[:10]
print(f"Top 10 frequency plateaus (freq: count of IDs sharing that freq):")
for freq, count in plateaus_sorted:
print(f" {freq}: {count} IDs")
if __name__ == "__main__":
main()

View File

@ -1,10 +1,15 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from model.dataset import PinyinInputDataset
from torch.utils.data import DataLoader
from model.trainer import collate_fn, worker_init_fn
data = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
data = PinyinInputDataset("/home/songsenand/Data/corpus/CCI-Data/")
dataloader = DataLoader(
data,
@ -18,5 +23,5 @@ dataloader = DataLoader(
)
for i in dataloader:
print((i['labels'] == 1).sum())
print((i["labels"] == 1).sum())
break

View File

@ -0,0 +1,227 @@
#!/usr/bin/env python3
"""
Comprehensive frequency distribution analysis
"""
import json
import sys
import math
from collections import Counter
from pathlib import Path
def main():
json_path = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
pairs = data.get("pairs", {})
# Extract counts
counts = []
for key, pair in pairs.items():
count = pair.get("count")
if count is not None:
counts.append(count)
n = len(counts)
print(f"Total entries: {n}")
# Sort descending for rank-frequency analysis
counts_sorted_desc = sorted(counts, reverse=True)
# Basic statistics
min_count = min(counts)
max_count = max(counts)
mean_count = sum(counts) / n
# Percentiles
percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
print("\n=== PERCENTILE DISTRIBUTION ===")
for p in percentiles:
idx = int(p * n)
value = counts_sorted_desc[idx]
print(f"{p * 100:5.1f}%: {value:>12} (rank ~{idx})")
# Cumulative distribution
print("\n=== CUMULATIVE DISTRIBUTION ===")
thresholds = [
1,
2,
3,
5,
10,
20,
50,
100,
200,
500,
1000,
2000,
5000,
10000,
20000,
50000,
100000,
200000,
500000,
1000000,
5000000,
10000000,
50000000,
100000000,
500000000,
]
for thresh in thresholds:
if thresh > max_count:
break
below = sum(1 for c in counts if c <= thresh)
above = sum(1 for c in counts if c >= thresh)
print(f"Count <= {thresh:10}: {below:6} entries ({below / n * 100:5.1f}%)")
# print(f"Count >= {thresh:10}: {above:6} entries ({above/n*100:5.1f}%)")
# Check min_count=109 parameter
print("\n=== ANALYSIS OF THRESHOLD 109 ===")
below_109 = sum(1 for c in counts if c < 109)
at_or_above_109 = sum(1 for c in counts if c >= 109)
print(f"Entries with count < 109: {below_109} ({below_109 / n * 100:.1f}%)")
print(
f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109 / n * 100:.1f}%)"
)
# If 109 is a threshold, what's the actual min among those >= 109?
counts_ge_109 = [c for c in counts if c >= 109]
if counts_ge_109:
actual_min_ge_109 = min(counts_ge_109)
print(f"Actual min frequency among those >= 109: {actual_min_ge_109}")
# Rank-frequency analysis (Zipf's law)
print("\n=== RANK-FREQUENCY ANALYSIS (Top 100) ===")
print("Rank\tFrequency\tlog(rank)\tlog(freq)")
for rank in range(1, 101):
freq = counts_sorted_desc[rank - 1]
print(f"{rank}\t{freq}\t{math.log(rank):.3f}\t{math.log(freq):.3f}")
# Frequency spectrum (how many distinct frequencies)
freq_counter = Counter(counts)
print(f"\n=== FREQUENCY SPECTRUM ===")
print(f"Distinct frequency values: {len(freq_counter)}")
# Most common frequencies
print("\nTop 20 most common frequencies (plateau sizes):")
for freq, freq_count in freq_counter.most_common(20):
print(f" Frequency {freq}: {freq_count} entries")
# Analyze ID ranges
print("\n=== ID RANGE ANALYSIS ===")
# Build ID to count mapping
id_to_count = {}
for key, pair in pairs.items():
char_id = pair.get("id")
count = pair.get("count")
if char_id is not None and count is not None:
id_to_count[char_id] = count
ranges = [
(0, 100, "Top 100 IDs"),
(100, 500, "IDs 100-500"),
(500, 1000, "IDs 500-1000"),
(1000, 2000, "IDs 1000-2000"),
(2000, 5000, "IDs 2000-5000"),
(5000, 5500, "IDs 5000-5500 (user mentioned)"),
(5500, 6000, "IDs 5500-6000"),
(10000, 10500, "IDs 10000-10500"),
(15000, 15500, "IDs 15000-15500"),
(19000, 19500, "IDs 19000-19500 (before freq=1)"),
(19499, 20647, "IDs with freq=1"),
]
for start, end, label in ranges:
range_counts = [
id_to_count[id] for id in range(start, end) if id in id_to_count
]
if range_counts:
min_c = min(range_counts)
max_c = max(range_counts)
mean_c = sum(range_counts) / len(range_counts)
median_c = sorted(range_counts)[len(range_counts) // 2]
print(
f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}"
)
# Check if IDs are perfectly sorted by frequency
print("\n=== ID ORDER VERIFICATION ===")
all_ids = sorted(id_to_count.keys())
all_counts = [id_to_count[id] for id in all_ids]
# Check for any violations of non-increasing order
violations = 0
for i in range(1, len(all_counts)):
if all_counts[i] > all_counts[i - 1]:
violations += 1
if violations <= 5:
print(
f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i - 1]} (ID {all_ids[i - 1]})"
)
print(f"Total violations of non-increasing order: {violations}")
# Check if equal frequencies are grouped together
print("\n=== FREQUENCY GROUPING ANALYSIS ===")
current_freq = None
group_start = None
group_sizes = []
for i, (id, count) in enumerate(zip(all_ids, all_counts)):
if count != current_freq:
if current_freq is not None:
group_sizes.append(
(current_freq, group_start, all_ids[i - 1], i - group_start)
)
current_freq = count
group_start = i
# Last group
if current_freq is not None:
group_sizes.append(
(current_freq, group_start, all_ids[-1], len(all_ids) - group_start)
)
# Sort groups by size
group_sizes.sort(key=lambda x: x[3], reverse=True)
print("Top 10 largest frequency groups (plateaus):")
for freq, start_id_idx, end_id, size in group_sizes[:10]:
start_id = all_ids[start_id_idx]
print(f" Frequency {freq}: IDs {start_id}-{end_id} ({size} entries)")
# Summary for smoothing algorithm
print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===")
print("1. IDs are perfectly sorted by frequency (non-increasing).")
print(
f"2. Frequency range: {min_count} to {max_count} (ratio {max_count / min_count:.1e}:1)."
)
print(f"3. {below_109} entries ({below_109 / n * 100:.1f}%) have frequency < 109.")
print(f"4. Median frequency: {counts_sorted_desc[n // 2]}.")
print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9 * n)]}.")
print(
f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01 * n)]}."
)
print("7. Large frequency plateaus exist (many IDs share same frequency).")
print("8. Smoothing should handle extreme frequency ratios (1:5e8).")
# Save data for plotting
with open("rank_freq.csv", "w") as f:
f.write("rank,frequency\n")
for rank, freq in enumerate(counts_sorted_desc, 1):
f.write(f"{rank},{freq}\n")
print("\nRank-frequency data saved to rank_freq.csv")
if __name__ == "__main__":
main()

409
scripts/finetune_slots.py Normal file
View File

@ -0,0 +1,409 @@
#!/usr/bin/env python3
"""
临时迁移训练脚本在预训练模型基础上重新训练支持冻结 context_encoder
src/model/trainer.py train 命令行为完全一致额外增加
- --pretrained-checkpoint: 加载预训练权重必需的迁移学习源
- --freeze-context-encoder: 冻结 context_encoder 默认开启
运行方式:
python scripts/finetune_slots.py \
--pretrained-checkpoint ./output/checkpoints/best_model.pt \
--train-data-path /path/to/train_data \
--eval-data-path /path/to/eval_data \
--output-dir ./finetune_output \
--freeze-context-encoder
"""
import argparse
import json
import os
import random
import sys
from datetime import datetime
from pathlib import Path
import numpy as np
import torch
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from model.model import InputMethodEngine
from model.trainer import (
Trainer,
create_dataloader,
worker_init_fn,
)
from model.dataset import PinyinInputDataset
from model.preprocessed_dataset import (
PreProcessedDataset,
is_preprocessed_data,
)
from loguru import logger
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
def main():
parser = argparse.ArgumentParser(
description="迁移学习训练:加载预训练模型,冻结指定层后重新训练",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# === 数据参数 ===
parser.add_argument("--train-data-path", "-t", required=True, help="训练数据集路径")
parser.add_argument("--eval-data-path", "-e", required=True, help="评估数据集路径")
parser.add_argument("--output-dir", "-o", default="./finetune_output", help="输出目录")
parser.add_argument("--max-iter-length", type=int,
default=1024 * 1024 * 128, help="每个 epoch 最大样本数")
# === 迁移学习参数 ===
parser.add_argument("--pretrained-checkpoint", "-c", required=True,
help="预训练模型检查点路径")
parser.add_argument("--freeze-context-encoder", action="store_true", default=True,
help="冻结 context_encoder 层 (默认开启)")
parser.add_argument("--no-freeze-context-encoder", dest="freeze_context_encoder",
action="store_false",
help="不冻结 context_encoder")
# === 训练参数 ===
parser.add_argument("--batch-size", "-b", type=int, default=128, help="批次大小")
parser.add_argument("--num-epochs", type=int, default=10, help="训练轮数")
parser.add_argument("--learning-rate", "-lr", type=float, default=2e-4,
help="学习率")
parser.add_argument("--min-learning-rate", type=float, default=1e-9,
help="最小学习率")
parser.add_argument("--weight-decay", type=float, default=0.05, help="权重衰减")
parser.add_argument("--warmup-ratio", type=float, default=0.1, help="热身步数比例")
parser.add_argument("--label-smoothing", type=float, default=0.1,
help="标签平滑参数")
parser.add_argument("--grad-accum-steps", type=int, default=1,
help="梯度累积步数")
parser.add_argument("--clip-grad-norm", type=float, default=1.0,
help="梯度裁剪范数")
parser.add_argument("--eval-frequency", type=int, default=500, help="评估频率")
parser.add_argument("--save-frequency", type=int, default=1000, help="保存频率")
# === 其他参数 ===
parser.add_argument("--mixed-precision", action="store_true", default=True)
parser.add_argument("--no-mixed-precision", dest="mixed_precision",
action="store_false", help="禁用混合精度")
parser.add_argument("--num-workers", type=int, default=2, help="数据加载worker数")
parser.add_argument("--tensorboard", action="store_true", default=True)
parser.add_argument("--no-tensorboard", dest="tensorboard", action="store_false",
help="禁用 TensorBoard")
parser.add_argument("--seed", type=int, default=42, help="随机种子")
parser.add_argument("--compile", action="store_true", default=False,
help="使用 torch.compile 优化")
parser.add_argument("--moe-mode", default="all",
choices=["all", "sparse", "sparse_allow_graph"],
help="MoE 计算策略")
args = parser.parse_args()
# ================================================================
# 初始化
# ================================================================
torch.multiprocessing.set_sharing_strategy("file_system")
if torch.cuda.is_available():
torch.set_float32_matmul_precision("high")
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
console = Console()
output_path = Path(args.output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# ================================================================
# 模型常量 (与 trainer.py 保持一致)
# ================================================================
vocab_size = 10019
pinyin_vocab_size = 30
dim = 512
num_slots = 8
n_layers = 4
n_heads = 4
num_experts = 10
max_seq_len = 128
# ================================================================
# 打印配置
# ================================================================
console.print(Panel.fit(
"[bold cyan]迁移学习训练配置[/bold cyan]", border_style="cyan"))
config_table = Table(show_header=True, header_style="bold magenta")
config_table.add_column("Category", style="cyan")
config_table.add_column("Parameter", style="green")
config_table.add_column("Value", style="yellow")
config_table.add_row("迁移学习", "预训练检查点", args.pretrained_checkpoint)
config_table.add_row("迁移学习", "冻结 context_encoder",
str(args.freeze_context_encoder))
config_table.add_row("数据", "训练数据路径", args.train_data_path)
config_table.add_row("数据", "评估数据路径", args.eval_data_path)
config_table.add_row("数据", "输出目录", args.output_dir)
config_table.add_row("数据", "批次大小", str(args.batch_size))
config_table.add_row("数据", "Worker数量", str(args.num_workers))
config_table.add_row("模型", "MoE策略", args.moe_mode)
config_table.add_row("模型", "编译优化", str(args.compile))
config_table.add_row("训练", "训练轮数", str(args.num_epochs))
config_table.add_row("训练", "学习率", f"{args.learning_rate:.2e}")
config_table.add_row("训练", "最小学习率", f"{args.min_learning_rate:.2e}")
config_table.add_row("训练", "权重衰减", str(args.weight_decay))
config_table.add_row("训练", "热身比例", str(args.warmup_ratio))
config_table.add_row("训练", "标签平滑", str(args.label_smoothing))
config_table.add_row("训练", "梯度累积", str(args.grad_accum_steps))
config_table.add_row("训练", "梯度裁剪", str(args.clip_grad_norm))
config_table.add_row("训练", "混合精度", str(args.mixed_precision))
# ================================================================
# 创建数据加载器 (逻辑与 trainer.py CLI 完全一致)
# ================================================================
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
is_train_preprocessed = is_preprocessed_data(args.train_data_path)
is_eval_preprocessed = is_preprocessed_data(args.eval_data_path)
if is_train_preprocessed:
train_dataset = PreProcessedDataset(args.train_data_path,
max_cache_shards=2)
pre_shuffled = train_dataset.metadata.get("pre_shuffled", False)
shuffle_train = not pre_shuffled
if args.max_iter_length > 0:
capped_samples = min(len(train_dataset), args.max_iter_length)
else:
capped_samples = len(train_dataset)
total_steps = (capped_samples // args.batch_size) * args.num_epochs
train_num_workers = min(args.num_workers, 1)
logger.info(
f"Preprocessed dataset: {len(train_dataset):,} samples, "
f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, "
f"workers={train_num_workers}, steps={total_steps:,}")
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=args.batch_size,
num_workers=train_num_workers,
pin_memory=torch.cuda.is_available(),
shuffle=shuffle_train,
)
config_table.add_row("数据", "训练数据类型", "预处理数据")
else:
train_dataset = PinyinInputDataset(
data_path=args.train_data_path,
max_workers=-1,
max_iter_length=args.max_iter_length,
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=2000000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40,
5: 15, 6: 10, 7: 5, 8: 2},
)
total_steps = int(args.max_iter_length *
args.num_epochs / args.batch_size)
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=torch.cuda.is_available(),
max_iter_length=args.max_iter_length,
)
config_table.add_row("数据", "训练数据类型", "流式数据")
if is_eval_preprocessed:
eval_dataset = PreProcessedDataset(args.eval_data_path,
max_cache_shards=1)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
num_workers=0,
pin_memory=torch.cuda.is_available(),
shuffle=False,
)
config_table.add_row("数据", "评估数据类型", "预处理数据")
else:
eval_dataset = PinyinInputDataset(
data_path=args.eval_data_path,
max_workers=-1,
max_iter_length=args.batch_size * 64,
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=2000000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40,
5: 15, 6: 10, 7: 5, 8: 2},
)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
num_workers=2,
pin_memory=torch.cuda.is_available(),
max_iter_length=args.batch_size * 64,
)
config_table.add_row("数据", "评估数据类型", "流式数据")
config_table.add_row("数据", "总步数", str(total_steps))
console.print(config_table)
# ================================================================
# 创建模型并加载预训练权重
# ================================================================
console.print("[bold cyan]正在创建模型并加载预训练权重...[/bold cyan]")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = InputMethodEngine(
vocab_size=vocab_size,
pinyin_vocab_size=pinyin_vocab_size,
dim=dim,
num_slots=num_slots,
n_layers=n_layers,
n_heads=n_heads,
num_experts=num_experts,
max_seq_len=max_seq_len,
compile=args.compile,
moe_mode=args.moe_mode,
)
model.to(device)
# 加载预训练权重
pretrained_path = Path(args.pretrained_checkpoint)
if not pretrained_path.exists():
console.print(f"[red]❌ 预训练检查点不存在: {args.pretrained_checkpoint}[/red]")
sys.exit(1)
checkpoint = torch.load(args.pretrained_checkpoint, map_location=device)
if "model_state_dict" in checkpoint:
pretrained_weights = checkpoint["model_state_dict"]
else:
pretrained_weights = checkpoint
missing_keys, unexpected_keys = model.load_state_dict(
pretrained_weights, strict=False)
if missing_keys:
console.print(f"[yellow]⚠ 缺失的键 ({len(missing_keys)}): "
f"{missing_keys[:5]}...[/yellow]")
if unexpected_keys:
console.print(f"[yellow]⚠ 多余的键 ({len(unexpected_keys)}): "
f"{unexpected_keys[:5]}...[/yellow]")
console.print(f"[green]✓ 预训练权重加载完成[/green]")
# ================================================================
# 冻结 context_encoder
# ================================================================
if args.freeze_context_encoder:
console.print("[bold cyan]正在冻结 context_encoder...[/bold cyan]")
frozen_count = 0
trainable_count = 0
for name, param in model.named_parameters():
if name.startswith("context_encoder"):
param.requires_grad = False
frozen_count += param.numel()
else:
trainable_count += param.numel()
total_params = frozen_count + trainable_count
console.print(f"[green]✓ context_encoder 已冻结[/green]")
logger.info(
f"冻结参数: {frozen_count:,} / {total_params:,} "
f"({frozen_count / total_params * 100:.1f}%), "
f"可训练参数: {trainable_count:,} / {total_params:,} "
f"({trainable_count / total_params * 100:.1f}%)")
else:
logger.info("未冻结任何层,全模型参与训练")
# ================================================================
# 保存配置
# ================================================================
config = {
"pretrained_checkpoint": args.pretrained_checkpoint,
"freeze_context_encoder": args.freeze_context_encoder,
"train_data_path": args.train_data_path,
"eval_data_path": args.eval_data_path,
"output_dir": args.output_dir,
"batch_size": args.batch_size,
"num_epochs": args.num_epochs,
"learning_rate": args.learning_rate,
"min_learning_rate": args.min_learning_rate,
"weight_decay": args.weight_decay,
"warmup_ratio": args.warmup_ratio,
"label_smoothing": args.label_smoothing,
"grad_accum_steps": args.grad_accum_steps,
"clip_grad_norm": args.clip_grad_norm,
"eval_frequency": args.eval_frequency,
"save_frequency": args.save_frequency,
"mixed_precision": args.mixed_precision,
"num_workers": args.num_workers,
"use_tensorboard": args.tensorboard,
"seed": args.seed,
"compile": args.compile,
"moe_mode": args.moe_mode,
"total_steps": total_steps,
"vocab_size": vocab_size,
"pinyin_vocab_size": pinyin_vocab_size,
"dim": dim,
"num_slots": num_slots,
"n_layers": n_layers,
"n_heads": n_heads,
"num_experts": num_experts,
"max_seq_len": max_seq_len,
"is_train_preprocessed": is_train_preprocessed,
"is_eval_preprocessed": is_eval_preprocessed,
}
config_file = output_path / "training_config.json"
with open(config_file, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
logger.info(f"Configuration saved to {config_file}")
# ================================================================
# 创建 Trainer 并开始训练
# ================================================================
console.print("[bold cyan]正在创建训练器...[/bold cyan]")
trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
total_steps=total_steps,
output_dir=args.output_dir,
num_epochs=args.num_epochs,
learning_rate=args.learning_rate,
min_learning_rate=args.min_learning_rate,
weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio,
label_smoothing=args.label_smoothing,
grad_accum_steps=args.grad_accum_steps,
clip_grad_norm=args.clip_grad_norm,
eval_frequency=args.eval_frequency,
save_frequency=args.save_frequency,
mixed_precision=args.mixed_precision,
device=device,
use_tensorboard=args.tensorboard,
status_file="training_status.json",
)
console.print("[green]✓ 训练器创建完成[/green]")
console.print("\n[bold cyan]开始训练...[/bold cyan]")
console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
try:
trainer.train(
resume_from=None,
reset_training_state=False,
auto_resume=False, # 迁移学习从头开始,不自动恢复
)
except KeyboardInterrupt:
console.print("[bold green]训练被终止[/bold green]")
trainer.save_checkpoint("interrupted_model.pt")
console.print("[bold green]✓ 训练完成![/bold green]")
console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
console.print(f"模型和日志保存在: {args.output_dir}")
if __name__ == "__main__":
main()

View File

@ -3,6 +3,7 @@ from pathlib import Path
from typing import Dict, Any
import sys
def modify_pinyin_statistics(file_path: Path) -> None:
"""
一次性修改拼音统计JSON文件
@ -13,7 +14,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
"""
# 1. 加载原数据
try:
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
data: Dict[str, Any] = json.load(f)
except FileNotFoundError:
print(f"错误:文件不存在 {file_path}", file=sys.stderr)
@ -34,7 +35,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
"id": 0,
"char": "",
"pinyin": "",
"count": original_zero_count + 1 # 原count + 1
"count": original_zero_count + 1, # 原count + 1
}
# 3.2 处理其他所有记录键和id都+1
@ -54,15 +55,16 @@ def modify_pinyin_statistics(file_path: Path) -> None:
# 这里保持原时间戳不变,因为是一次性修改
# 写回文件,保持可读格式
backup_path = file_path.with_suffix('.json.bak')
backup_path = file_path.with_suffix(".json.bak")
try:
# 先备份原文件
import shutil
shutil.copy2(file_path, backup_path)
print(f"已创建备份: {backup_path}")
# 写入新数据
with open(file_path, 'w', encoding='utf-8') as f:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
print(f"修改完成!")
@ -78,15 +80,21 @@ def modify_pinyin_statistics(file_path: Path) -> None:
# 使用示例
if __name__ == "__main__":
# 假设你的JSON文件在当前目录
json_file = Path("./src/model/assets/pinyin_char_statistics.json")
# JSON文件相对于项目根目录
json_file = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
# 执行修改
modify_pinyin_statistics(json_file)
# 验证修改:读取并显示前几条记录
print("\n验证前5条记录:")
with open(json_file, 'r', encoding='utf-8') as f:
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
for i in range(5):

View File

@ -0,0 +1,164 @@
#!/usr/bin/env python3
"""
Visualize frequency distribution with ASCII plots
"""
import json
import math
import sys
from pathlib import Path
def ascii_histogram(data, bins=20, width=60):
"""Create ASCII histogram"""
if not data:
return ""
min_val = min(data)
max_val = max(data)
# Use log bins for wide range
if max_val / min_val > 1000:
log_min = math.log10(min_val) if min_val > 0 else 0
log_max = math.log10(max_val)
bin_edges = [
10 ** (log_min + i * (log_max - log_min) / bins) for i in range(bins + 1)
]
hist = [0] * bins
for val in data:
if val > 0:
log_val = math.log10(val)
bin_idx = min(
int((log_val - log_min) / (log_max - log_min) * bins), bins - 1
)
hist[bin_idx] += 1
bin_labels = [f"{bin_edges[i]:.1e}-{bin_edges[i + 1]:.1e}" for i in range(bins)]
else:
bin_width = (max_val - min_val) / bins
bin_edges = [min_val + i * bin_width for i in range(bins + 1)]
hist = [0] * bins
for val in data:
bin_idx = min(int((val - min_val) / (max_val - min_val) * bins), bins - 1)
hist[bin_idx] += 1
bin_labels = [f"{bin_edges[i]:.1f}-{bin_edges[i + 1]:.1f}" for i in range(bins)]
max_count = max(hist)
result = []
for i in range(bins):
if hist[i] == 0:
continue
bar = "#" * int(hist[i] / max_count * width)
result.append(f"{bin_labels[i]:20} | {bar} {hist[i]}")
return "\n".join(result)
def main():
json_path = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
pairs = data.get("pairs", {})
counts = [
pair.get("count", 0) for pair in pairs.values() if pair.get("count") is not None
]
print("FREQUENCY DISTRIBUTION ANALYSIS")
print("=" * 60)
print("\n1. ASCII Histogram (log bins):")
print(ascii_histogram(counts, bins=20, width=60))
# Rank-frequency plot in ASCII
print("\n2. Rank-Frequency Relationship (Top 50):")
counts_sorted_desc = sorted(counts, reverse=True)
max_freq = counts_sorted_desc[0]
max_rank = 50
for rank in range(1, max_rank + 1):
freq = counts_sorted_desc[rank - 1]
bar_length = int(math.log(freq) / math.log(max_freq) * 40)
bar = "#" * bar_length
print(f"Rank {rank:3}: {freq:12} {bar}")
# ID vs Frequency plot (sampled)
print("\n3. ID vs Frequency (sampled every 500 IDs):")
# Build ID to count mapping
id_to_count = {}
for key, pair in pairs.items():
char_id = pair.get("id")
count = pair.get("count")
if char_id is not None and count is not None:
id_to_count[char_id] = count
all_ids = sorted(id_to_count.keys())
max_id = all_ids[-1]
print("ID Frequency log10(freq)")
for id in range(0, max_id + 1, 500):
if id in id_to_count:
freq = id_to_count[id]
log_freq = math.log10(freq) if freq > 0 else 0
bar = "#" * int(log_freq / math.log10(max_freq) * 40)
print(f"{id:6} {freq:10} {log_freq:6.2f} {bar}")
# Zipf's law fit
print("\n4. Zipf's Law Analysis:")
print(" Rank * Frequency ≈ constant for Zipf's law")
print(" Top 10 ranks:")
for rank in range(1, 11):
freq = counts_sorted_desc[rank - 1]
product = rank * freq
print(f" Rank {rank}: {freq:12} rank*freq = {product:.3e}")
# Check if product is roughly constant
products = [(rank + 1) * counts_sorted_desc[rank] for rank in range(10)]
avg_product = sum(products) / len(products)
std_product = math.sqrt(
sum((p - avg_product) ** 2 for p in products) / len(products)
)
print(f" Average product (ranks 2-11): {avg_product:.3e} ± {std_product:.3e}")
print(f" Coefficient of variation: {std_product / avg_product * 100:.1f}%")
# Frequency spectrum
from collections import Counter
freq_counter = Counter(counts)
print("\n5. Frequency Spectrum (how many entries have each frequency):")
print(" Frequency Count Cumulative")
cum = 0
for freq in sorted(freq_counter.keys())[:20]:
count = freq_counter[freq]
cum += count
print(f" {freq:10} {count:6} {cum:6}")
# Summary statistics
print("\n6. Key Statistics:")
n = len(counts)
print(f" Total entries: {n}")
print(f" Min frequency: {min(counts)}")
print(f" Max frequency: {max(counts)}")
print(f" Ratio max/min: {max(counts) / min(counts):.2e}")
percentiles = [0.01, 0.1, 0.5, 0.9, 0.99]
for p in percentiles:
idx = int(p * n)
value = counts_sorted_desc[idx]
print(f" {p * 100:5.1f}th percentile: {value:12} (rank ~{idx})")
# Save data for external plotting
with open("id_vs_freq.csv", "w") as f:
f.write("id,frequency\n")
for id in sorted(id_to_count.keys()):
f.write(f"{id},{id_to_count[id]}\n")
print("\nData saved to id_vs_freq.csv for external plotting")
if __name__ == "__main__":
main()

189
src/analyze_data.py Normal file
View File

@ -0,0 +1,189 @@
#!/usr/bin/env python3
"""
数据分布查看脚本
主要关注 labels 0-10, 200-210, 2900-2910, 5100-5110, 9520-9530 的分布情况
使用方法:
uv run python src/analyze_data.py --data_path ./data/train.txt
注意:
data_path 可以是:
- 本地文本文件 (每行一个文本)
- HuggingFace 数据集路径
- 目录路径 (该目录下需要有 dataset_info.json 或类似配置文件)
"""
import os
import random
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import DataLoader
from model.dataset import PinyinInputDataset
from model.query import QueryEngine
_id2char_cache = {}
def get_char_by_id(query_engine: QueryEngine, char_id: int) -> str:
if char_id == 0:
return "<EOS>"
if char_id not in _id2char_cache:
info = query_engine.query_by_id(char_id)
_id2char_cache[char_id] = info.char if info else f"<ID:{char_id}>"
return _id2char_cache[char_id]
def analyze_label_distribution(
dataset: PinyinInputDataset,
sample_size: int = 10000,
query_engine: QueryEngine = None,
):
"""分析 label 在指定区间的分布"""
target_ranges = [
(0, 10),
(200, 210),
(2900, 2910),
(5100, 5110),
(9520, 9530),
]
id_counts = defaultdict(int)
all_examples = []
dataloader = DataLoader(dataset, batch_size=1, num_workers=16)
total_count = 0
sample_collected = 0
for batch in dataloader:
if sample_collected >= sample_size:
break
label = batch["label"].item()
prefix = batch["prefix"][0]
suffix = batch["suffix"][0]
pinyin = batch["pinyin"][0]
history = batch["history_slot_ids"][0].tolist()
part4 = prefix.split("^")[0] if "^" in prefix else ""
sample_collected += 1
total_count += 1
in_target_range = False
for start, end in target_ranges:
if start <= label <= end:
id_counts[label] += 1
in_target_range = True
if len(all_examples) < 200:
label_char = (
get_char_by_id(query_engine, label) if query_engine else f"<ID:{label}>"
)
history_chars = (
[get_char_by_id(query_engine, hid) for hid in history]
if query_engine
else history
)
all_examples.append(
{
"label": label,
"label_char": label_char,
"prefix": prefix,
"suffix": suffix,
"pinyin": pinyin,
"history": history,
"history_chars": history_chars,
"part4": part4,
}
)
print(f"\n{'=' * 60}")
print(f"采样总数: {total_count}")
print(f"{'=' * 60}\n")
print("Label ID 分布统计:")
print("-" * 50)
for start, end in target_ranges:
print(f"\n区间 [{start:5d} - {end:5d}]:")
for label_id in range(start, end + 1):
count = id_counts[label_id]
percentage = (count / total_count * 100) if total_count > 0 else 0
bar = "" * min(count, 50)
print(f" ID {label_id:5d}: {count:6d} ({percentage:.3f}%) {bar}")
print("\n")
print("=" * 80)
print("随机抽取 20 个样本详情:")
print("=" * 80)
random.shuffle(all_examples)
for idx, ex in enumerate(all_examples[:20], 1):
print(f"\n样本 {idx}:")
print(f" Label: {ex['label']} ({ex['label_char']})")
print(f" Part4: {ex['part4']}")
print(f" 光标前: {ex['prefix']}")
print(f" 光标后: {ex['suffix']}")
print(f" 拼音: {ex['pinyin']}")
print(f" 历史槽位: {ex['history']}")
print(f" 历史汉字: {ex['history_chars']}")
def main():
import argparse
default_data_path = os.path.expanduser("~/Data/corpus/CCI-Data/")
parser = argparse.ArgumentParser(
description="分析数据集 label 分布",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=f"""
默认数据集路径: {default_data_path}
示例:
uv run python src/analyze_data.py --data_path {default_data_path} --sample_size 10000
uv run python src/analyze_data.py --data_path ./data/eval.txt --sample_size 5000
""",
)
parser.add_argument(
"--data_path",
type=str,
default=default_data_path,
help="数据集路径 (本地文件或HuggingFace路径)",
)
parser.add_argument("--sample_size", type=int, default=10000, help="采样大小")
parser.add_argument(
"--max_workers", type=int, default=-1, help="DataLoader workers"
)
args = parser.parse_args()
print(f"加载数据集: {args.data_path}")
print()
try:
dataset = PinyinInputDataset(
data_path=args.data_path,
max_workers=args.max_workers,
max_iter_length=args.sample_size,
)
except Exception as e:
print(f"\n错误: 无法加载数据集 '{args.data_path}'")
print(f"原因: {e}")
print("\n请确保:")
print(" 1. 数据路径存在且正确")
print(" 2. 如果是本地文本文件,每行应该是一个 JSON 对象或纯文本")
print(" 3. 如果是 HuggingFace 数据集,路径应该正确")
return
query_engine = QueryEngine()
query_engine.load()
analyze_label_distribution(
dataset, sample_size=args.sample_size, query_engine=query_engine
)
if __name__ == "__main__":
main()

View File

@ -8,7 +8,7 @@
"id": 0,
"char": "",
"pinyin": "",
"count": 11067734826
"count": 494748360
},
"1": {
"id": 1,

View File

@ -53,27 +53,7 @@ class PinyinLSTMEncoder(nn.Module):
self.layer_norm = nn.LayerNorm(input_dim)
def forward(self, x, mask=None):
"""
Args:
x: [batch, seq_len, input_dim] pinyin embeddings
mask: [batch, seq_len] optional padding mask (True for valid, False for padding)
Returns:
output: [batch, seq_len, input_dim] 每位置的拼音编码
"""
total_len = x.size(1)
if mask is not None:
lengths = mask.sum(dim=1).cpu().clamp(min=1)
packed = nn.utils.rnn.pack_padded_sequence(
x, lengths, batch_first=True, enforce_sorted=False
)
packed_out, (hidden, cell) = self.lstm(packed)
output, _ = nn.utils.rnn.pad_packed_sequence(
packed_out, batch_first=True, total_length=total_len
)
else:
output, (hidden, cell) = self.lstm(x)
output, _ = self.lstm(x)
projected = self.proj(output)
return self.layer_norm(projected)
@ -341,15 +321,37 @@ class CrossAttentionFusion(nn.Module):
# 4. 专家混合层 (MoE Layer)
# 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类
# ------------------------------------------------------------------
@torch.compiler.allow_in_graph
def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_experts):
output = torch.zeros_like(x_flat)
for e in range(num_experts):
mask = topk_indices == e
idx, k_idx = mask.nonzero(as_tuple=True)
if idx.numel() > 0:
w = topk_weights[idx, k_idx].unsqueeze(-1)
output.index_add_(0, idx, (w * experts[e](x_flat[idx])).to(output.dtype))
return output
class MoELayer(nn.Module):
def __init__(self, dim=512, num_experts=10, top_k=3, num_resblocks=8):
"""
moe_mode 支持三种策略:
- "all": 计算全部专家torch.compile 不断裂 (当前默认)
- "sparse": 只计算被路由到的专家 (产生 graph break)
- "sparse_allow_graph": 稀疏 MoE通过 allow_in_graph 避免 graph break
"""
def __init__(
self, dim=512, num_experts=10, top_k=3, num_resblocks=8, moe_mode="all"
):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.dim = dim
self.moe_mode = moe_mode
# Import Expert from your existing components
# Assuming Expert class is defined as in components.py [2]
self.experts = nn.ModuleList(
[
Expert(
@ -362,48 +364,45 @@ class MoELayer(nn.Module):
]
)
# Gating Network [2]
self.gate = nn.Linear(dim, num_experts)
def forward(self, x):
"""
并行化 MoE 前向传播完全兼容 torch.compile AMP
Args:
x: [batch, seq_len, dim]
Returns:
out: [batch, seq_len, dim]
"""
B, L, D = x.shape
num_tokens = B * L
x_flat = x.view(num_tokens, D)
# 展平输入以便处理
x_flat = x.view(num_tokens, D) # [B*L, D]
gates = self.gate(x_flat)
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)
topk_weights = F.softmax(topk_weights, dim=-1)
# 1. 计算门控分数
gates = self.gate(x_flat) # [B*L, num_experts]
if self.moe_mode == "all":
expert_outputs = torch.stack(
[expert(x_flat) for expert in self.experts], dim=1
)
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D)
selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)
out_flat = weighted_outputs.sum(dim=1)
# 2. 选择 Top-K 专家
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K]
elif self.moe_mode == "sparse":
out_flat = torch.zeros_like(x_flat)
for e in range(self.num_experts):
mask = topk_indices == e
idx, k_idx = mask.nonzero(as_tuple=True)
if idx.numel() > 0:
w = topk_weights[idx, k_idx].unsqueeze(-1)
out_flat.index_add_(
0,
idx,
(w * self.experts[e](x_flat[idx])).to(out_flat.dtype),
)
# 归一化权重
topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K]
elif self.moe_mode == "sparse_allow_graph":
out_flat = _sparse_moe_dispatch(
x_flat, self.experts, topk_indices, topk_weights, self.num_experts
)
# 3. 并行计算所有专家(消除 Python 循环中的动态控制流)
# torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量
expert_outputs = torch.stack(
[expert(x_flat) for expert in self.experts], dim=1
) # [B*L, num_experts, D]
else:
raise ValueError(f"Unknown moe_mode: {self.moe_mode}")
# 4. 使用 gather 选择对应专家的输出
# 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D]
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D]
selected_outputs = torch.gather(
expert_outputs, 1, indices_expanded
) # [B*L, K, D]
# 5. 加权求和
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D]
out_flat = weighted_outputs.sum(dim=1) # [B*L, D]
# 恢复原始形状
return out_flat.view(B, L, D)

View File

@ -1,8 +1,13 @@
import warnings
warnings.filterwarnings("ignore", message=".*pkg_resources.*")
import jieba
import math
import random
import re
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Set, Tuple
import numpy as np
import torch
@ -15,7 +20,6 @@ from torch.utils.data import IterableDataset
from .query import QueryEngine
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
CHAR_TO_ID["`"] = 27 # 显式添加反引号
@ -23,6 +27,26 @@ CHAR_TO_ID["'"] = 28 # 显式添加引号
CHAR_TO_ID["-"] = 29 # 显式添加短横
jieba.setLogLevel(jieba.logging.INFO)
def segment_text(text: str) -> List[str]:
"""使用 jieba 分词,返回词列表"""
return list(jieba.cut(text, HMM=False))
def build_word_boundaries(words: List[str]) -> List[Tuple[int, int]]:
"""建立词边界列表 [(start, end), ...],基于顺序位置累加"""
result = []
pos = 0
for word in words:
start = pos
end = pos + len(word)
result.append((start, end))
pos = end
return result
def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
"""
将拼音字符串转换为 ID 列表
@ -43,15 +67,31 @@ class PinyinInputDataset(IterableDataset):
text_field: str = "text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size: int = 100000,
retention_ratio: float = 0.5,
retention_ratio: float = 0.8,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
merge_short_words_prob: float = 0.5,
merge_max_short_words: int = 3,
merge_max_total_chars: int = 6,
low_freq_repeat: float = 50.0,
high_freq_repeat: float = 0.1,
data_kwargs: Optional[Dict] = None,
target_labels: Optional[Set[int]] = None,
):
# 频率调整参数 (可根据需要调整)
self.drop_start_freq = 30_000_000
self.max_drop_prob = 0.8
self.repeat_end_freq = 10_000
self.max_repeat_expect = 50
# 频率调整参数 - 幂律平滑方案
self.min_freq = 109
self.low_freq_repeat = low_freq_repeat
self.high_freq_repeat = high_freq_repeat
self.word_break_prob = 0.10
self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
self.merge_short_words_prob = merge_short_words_prob
self.merge_max_short_words = merge_max_short_words
self.merge_max_total_chars = merge_max_total_chars
self.data_kwargs = data_kwargs or {}
self.target_labels = target_labels
jieba.initialize()
self.tokenizer = AutoTokenizer.from_pretrained(
Path(str(files(__package__))) / "assets" / "tokenizer"
@ -61,7 +101,9 @@ class PinyinInputDataset(IterableDataset):
self.max_iter_length = max_iter_length
self.max_seq_length = max_seq_length
self.text_field = text_field
self.dataset = load_dataset(data_path, split="train", streaming=True)
load_kwargs = {"split": "train", "streaming": True}
load_kwargs.update(self.data_kwargs)
self.dataset = load_dataset(data_path, **load_kwargs)
self.max_workers = max_workers
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
self.shuffle_buffer_size = shuffle_buffer_size
@ -83,57 +125,48 @@ class PinyinInputDataset(IterableDataset):
# 提取每个样本的目标字符及其频率
self.sample_freqs = self.query_engine.get_all_weights()
self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0
# 计算幂律平滑参数
if self.max_freq > self.min_freq:
self.alpha = math.log(
self.low_freq_repeat / self.high_freq_repeat
) / math.log(self.max_freq / self.min_freq)
self.C = self.low_freq_repeat * (self.min_freq**self.alpha)
else:
self.alpha = 0.0
self.C = 1.0
def adjust_frequency(self, freq: int) -> int:
"""削峰填谷 - 根据频率调整采样次数0表示丢弃"""
# 1. 削峰处理(高频字)
if freq >= self.drop_start_freq:
# 线性丢弃概率计算
max_freq = max(self.sample_freqs) # 或使用预定义的全局最大值
if max_freq == self.drop_start_freq:
drop_prob = 0.0
else:
drop_prob = (
self.max_drop_prob
* (freq - self.drop_start_freq)
/ (max_freq - self.drop_start_freq)
)
if random.random() < drop_prob:
return 0
else:
return 1
"""削峰填谷 - 根据频率调整采样次数0表示丢弃
使用幂律平滑方案E(freq) = C × freq^(-α)
保持频率排序关系单个连续函数
"""
if freq <= 0:
return 0
# 2. 填谷处理(低频字)
elif freq <= self.repeat_end_freq:
# 线性重复期望计算
if freq <= self.min_freq:
repeat_expect = self.max_repeat_expect
else:
if self.repeat_end_freq == self.min_freq:
repeat_expect = 0
else:
repeat_expect = (
self.max_repeat_expect
* (self.repeat_end_freq - freq)
/ (self.repeat_end_freq - self.min_freq)
)
# 使用泊松分布实现随机重复
repeat_count = np.random.poisson(repeat_expect)
# 计算期望采样次数
expected = self.C * (freq ** (-self.alpha))
# 采样策略
if expected >= 1.0:
# 泊松分布重复
repeat_count = np.random.poisson(expected)
return max(1, repeat_count)
# 3. 中间频率字
else:
return 1
# 伯努利采样以概率expected返回1否则返回0
return 1 if random.random() < expected else 0
# 生成对应文本的拼音
def generate_pinyin(self, text: str) -> List[str]:
"""
流式处理单条文本转换为拼音列表
将文本转换为拼音列表对整段文本调用 lazy_pinyin
利用 errors 回调确保一一对应对生僻字从 QueryEngine 回退
特性
1. 严格一一对应len(result) == len(text)
2. 高多音字准确率利用 pypinyin 内部的词语分词能力
3. 高性能预分配内存无多余对象创建
2. pypinyin 不认识的生僻字回退到 QueryEngine 最高频读音
3. 非汉字字符原样占位
Args:
text: 输入字符串
@ -144,58 +177,130 @@ class PinyinInputDataset(IterableDataset):
if not text:
return []
text_len = len(text)
# 2. 预分配结果列表,初始化占位符。
# 使用 None 或空字符串均可,这里用空字符串方便后续判断
result: List[str] = [""] * text_len
def _fallback(chars):
# lazy_pinyin 会把连续无拼音的字符聚合成一个字符串传入,
# 必须逐字符处理,确保返回列表长度与输入字符数一致。
result = []
for char in chars:
if self.query_engine.is_chinese_char(char):
ids = self.query_engine.query_by_char(char, limit=1)
if ids:
result.append(ids[0][1])
else:
result.append(char)
else:
result.append(char)
return result
# 3. 遍历所有连续汉字片段
for match in _HANZI_RE.finditer(text):
start_idx = match.start()
hanzi_segment = match.group()
pinyin_list = lazy_pinyin(text, errors=_fallback)
# 4. 核心转换:利用 pypinyin 的分词能力处理该片段
# style=Style.NORMAL 获取不带声调的拼音
pinyin_list = lazy_pinyin(hanzi_segment)
# 防御性校验:若长度仍不匹配(极罕见),逐字回退
if len(pinyin_list) != len(text):
logger.warning(
f"pinyin length mismatch: text_len={len(text)}, "
f"pinyin_len={len(pinyin_list)}, text={text[:50]!r}"
)
pinyin_list = []
for c in text:
result = lazy_pinyin(c, errors=_fallback)
pinyin_list.append(result[0] if result else c)
# 5. 健壮性兜底:
# 正常情况下pypinyin 返回的拼音数应等于汉字数。
# 若不等(极罕见,如遇到特殊 Unicode 标点被误判为汉字),降级为单字转换
if len(pinyin_list) != len(hanzi_segment):
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
return pinyin_list
# 6. 直接通过索引填充到预分配的位置
# 这比 list slicing assignment (result[start:end] = pinyin_list) 略快且更直观
for i, py in enumerate(pinyin_list):
result[start_idx + i] = py
# 7. 填充非汉字字符
# 遍历原文,如果 result 对应位置为空,则填入原字符
# 注意:对于纯汉字文本,这一步很快;对于混合文本,这是必要的
for i, char in enumerate(text):
if not result[i]:
result[i] = char
return result
# 生成需要预测汉字对应的拼音,并进行加强
def get_mask_pinyin(
self, text: str, pinyin_list: List[str]
) -> Tuple[int, List[str]]:
# 整词统一拼音风格,避免多字词完整拼音概率指数衰减
style = random.random()
cumulative = 0.0
style_idx = 0
for i, w in enumerate(self.py_style_weight):
cumulative += w
if style < cumulative:
style_idx = i
break
mask_pinyin = []
for i in range(len(text)):
if not self.query_engine.is_chinese_char(text[i]):
break
else:
py = np.random.choice(
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
p=self.py_style_weight,
)
full_py = pinyin_list[i]
if style_idx == 0:
py = full_py
elif style_idx == 1:
py = to_initials(full_py)
if py == "":
py = pinyin_list[i][0]
mask_pinyin.append(py)
py = full_py[0]
else:
py = full_py[0]
mask_pinyin.append(py)
return len(mask_pinyin), mask_pinyin
def _compute_pinyin_ids(self, pinyin_str: str) -> torch.Tensor:
pinyin_ids = text_to_pinyin_ids(pinyin_str)
len_py = len(pinyin_ids)
if len_py < 24:
pinyin_ids.extend([0] * (24 - len_py))
else:
pinyin_ids = pinyin_ids[:24]
return torch.tensor(pinyin_ids, dtype=torch.long)
def _build_single_sample(
self,
label: int,
history: list,
text: str,
word_start: int,
word_end: int,
part2: str,
pinyin_ids: torch.Tensor,
words: list,
) -> dict:
"""构造单条样本,每次调用都会重新随机采样上下文"""
# part1 长度:高斯分布 N(36, 6^2),截断 [0, min(48, word_start)]
part1_len = min(max(int(random.gauss(36, 6)), 0), 48, word_start)
part1 = text[word_start - part1_len : word_start]
# part3每次重新 roll
part3 = ""
if random.random() > 0.7:
part3 = text[word_end : word_end + random.randint(1, 16)]
# part4每次重新 roll
part4 = ""
if random.random() > 0.7 and words:
num_words = random.randint(1, 3)
selected_words = random.sample(words, min(num_words, len(words)))
part4 = "|".join(selected_words)
encoded = self.tokenizer(
f"{part4}|{part1}",
part3,
max_length=self.max_seq_length,
truncation=True,
return_token_type_ids=True,
)
# 确保 history 长度为 8
hist = list(history)
if len(hist) > 8:
hist = hist[-8:]
while len(hist) < 8:
hist.append(0)
return {
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
"token_type_ids": torch.tensor(encoded["token_type_ids"], dtype=torch.long),
"attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
"label": torch.tensor([label], dtype=torch.long),
"history_slot_ids": torch.tensor(hist, dtype=torch.long),
"prefix": f"{part4}^{part1}",
"suffix": part3,
"pinyin": part2,
"pinyin_ids": pinyin_ids,
}
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
@ -208,219 +313,295 @@ class PinyinInputDataset(IterableDataset):
random.seed(seed % (2**32))
np.random.seed(seed % (2**32))
# 安全检查如果worker_id >= num_workers则该worker不应该工作
# 这可能发生在self.max_workers小于实际worker数量时
if worker_id >= num_workers:
return # 产生空迭代器
return
# 使用局部变量存储分片数据集,避免竞争条件
worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
try:
worker_dataset = self.dataset.shard(
num_shards=num_workers, index=worker_id
)
except (IndexError, ValueError):
worker_dataset = self.dataset
# 计算每个worker的配额
# 将 max_iter_length 转换为整数以确保整数除法
total_quota = int(self.max_iter_length)
base_quota = total_quota // num_workers
remainder = total_quota % num_workers
# 最后一个worker处理剩余的样本如果有余数
if worker_id == num_workers - 1:
worker_quota = base_quota + remainder
else:
worker_quota = base_quota
else:
# 单worker情况使用全部配额
worker_quota = int(self.max_iter_length)
num_workers = 1
worker_dataset = self.dataset # 不使用分片
worker_dataset = self.dataset
# 每个worker有自己的迭代计数器
current_iter_index = 0
batch_samples = []
for sample in worker_dataset:
# 检查是否达到最大迭代次数
if current_iter_index >= worker_quota:
break
text = sample.get(self.text_field, "")
if text:
pinyin_list = self.generate_pinyin(text)
for i in range(len(text)):
# 在开始处理每个字符前检查配额
if current_iter_index >= worker_quota:
break
if not text:
continue
labels = []
# 如果text[i]不在字符库中,则跳过
# 当i小于48时候则将part1取text[0:i]
# 当i大于48时候则将part1取text[i-48:i]
if not self.query_engine.is_chinese_char(text[i]):
continue
if i < 48:
part1 = text[0:i]
else:
part1 = text[i - 48 : i]
words = segment_text(text)
word_boundaries = build_word_boundaries(words)
pinyin_list = self.generate_pinyin(text)
# 方案C提前检查从位置i开始连续有多少个字符在词库中
max_valid_len = 0
for j in range(i, min(i + 8, len(text))):
if self.query_engine.is_chinese_char(text[j]):
max_valid_len += 1
else:
idx = 0
while idx < len(word_boundaries):
word_start, word_end = word_boundaries[idx]
char_positions = []
for i in range(word_start, word_end):
if self.query_engine.is_chinese_char(text[i]):
char_positions.append(i)
if not char_positions:
idx += 1
continue
word_len_chars = len(char_positions)
merge_end_idx = idx + 1
if word_len_chars <= 2:
accumulated_positions = list(char_positions)
accumulated_count = 1
next_idx = idx + 1
while next_idx < len(word_boundaries):
ns, ne = word_boundaries[next_idx]
next_positions = []
for i in range(ns, ne):
if self.query_engine.is_chinese_char(text[i]):
next_positions.append(i)
next_len = len(next_positions)
if next_len == 0 or next_len > 2:
break
if (
len(accumulated_positions) + next_len
> self.merge_max_total_chars
):
break
if accumulated_count + 1 > self.merge_max_short_words:
break
if random.random() > self.merge_short_words_prob:
break
# 如果没有可用字符,跳过
if max_valid_len == 0:
accumulated_positions.extend(next_positions)
accumulated_count += 1
next_idx += 1
if accumulated_count > 1:
char_positions = accumulated_positions
word_len_chars = len(char_positions)
merge_end_idx = next_idx
word_start = word_boundaries[idx][0]
word_end = word_boundaries[next_idx - 1][1]
should_break = (
word_len_chars > 1 and random.random() < self.word_break_prob
)
if should_break:
break_pos = random.randint(1, word_len_chars - 1)
else:
break_pos = word_len_chars
# ========== Phase 1: 前缀/整词 ==========
prefix_positions = char_positions[:break_pos]
prefix_text = "".join(text[i] for i in prefix_positions)
prefix_pinyin = [pinyin_list[i] for i in prefix_positions]
_, mask_pinyin = self.get_mask_pinyin(prefix_text, prefix_pinyin)
r = random.random()
if r < 0.9:
split_char = ""
elif r < 0.94:
split_char = "`"
elif r < 0.98:
split_char = "'"
else:
split_char = "-"
part2 = split_char.join(mask_pinyin)
pinyin_ids = self._compute_pinyin_ids(part2)
try:
labels = [
self.query_engine.get_char_info_by_char_pinyin(
text[i], pinyin_list[i]
).id
for i in prefix_positions
]
except AttributeError as e:
logger.error(
f"e: {e}, (text, pinyin): {prefix_text} - {prefix_pinyin}"
)
idx = merge_end_idx
continue
# 整词末尾 10% 概率追加 EOS破词前缀不加
if not should_break and random.random() <= 0.1:
labels.append(0)
# 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
processed_history = []
for label_idx, label in enumerate(labels):
base_repeats = self.adjust_frequency(
self.sample_freqs.get(label, 0)
)
if base_repeats == 0:
processed_history.append(label)
continue
if (
self.target_labels is not None
and label not in self.target_labels
):
processed_history.append(label)
continue
# 首先取随机值pinyin_len1-8pinyin_len取值呈高斯分布最大概率取3
# 获取text[i + pinyin_len]字符如果无法获取所指向的后如果pinyin_len
# part2的长度为x取pinyin_list[i:i+pinyin_len]为part2
# 但是需要注意边界条件
target_len = np.random.choice(
range(1, 9), p=[0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
weight = (
self._history_weights[label_idx]
if label_idx < len(self._history_weights)
else 3.0
)
# 根据实际可用长度调整
pinyin_len = min(target_len, max_valid_len)
repeats = max(1, int(base_repeats * weight))
py_end = min(i + pinyin_len, len(text))
pinyin_len, part2 = self.get_mask_pinyin(
text[i:py_end], pinyin_list[i:py_end]
)
for _ in range(repeats):
sample = self._build_single_sample(
label=label,
history=processed_history,
text=text,
word_start=word_start,
word_end=word_end,
part2=part2,
pinyin_ids=pinyin_ids,
words=words,
)
batch_samples.append(sample)
split_char = np.random.choice(
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
)
processed_history.append(label)
part2 = split_char.join(part2)
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
pinyin_ids.extend([0] * (24 - len_py))
# ========== Phase 2: 破词续接 ==========
if should_break and break_pos < word_len_chars:
cont_start = char_positions[break_pos]
# 续接目标:从断点开始,可延伸到后续词,遇到非汉字停止
cont_r = random.random()
cont_probs = self.cont_length_probs
cont_cumulative = 0.0
target_len = 4
for cont_len, cont_p in enumerate(cont_probs):
cont_cumulative += cont_p
if cont_r < cont_cumulative:
target_len = cont_len + 1
break
cont_positions = []
pos = cont_start
while len(cont_positions) < target_len and pos < len(text):
if self.query_engine.is_chinese_char(text[pos]):
cont_positions.append(pos)
else:
break
pos += 1
if not cont_positions:
continue
cont_text = "".join(text[i] for i in cont_positions)
cont_pinyin = [pinyin_list[i] for i in cont_positions]
_, mask_pinyin_cont = self.get_mask_pinyin(cont_text, cont_pinyin)
r2 = random.random()
if r2 < 0.9:
split_char_cont = ""
elif r2 < 0.94:
split_char_cont = "`"
elif r2 < 0.98:
split_char_cont = "'"
else:
pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
split_char_cont = "-"
part2_cont = split_char_cont.join(mask_pinyin_cont)
pinyin_ids_cont = self._compute_pinyin_ids(part2_cont)
# part3为文本大概率0.70)为空
# 不为空则是i+pinyin_len所指向的字符以及所指向字符后x个字符
# x为1-16中的任意整数取值平均分布
part3 = ""
if random.random() > 0.7:
part3 = text[
i + pinyin_len : i
+ pinyin_len
+ np.random.choice(range(1, 17))
]
# part4为文本0.30的概率为空
# 不为空则为1-5个连续字符串
# 连续字符串的取值方法为随机从字符库中取一个字符以及该字符后x个字符
# x为2-6中的任意整数取值平均分布
# 使用|将part4中的字符串连接起来
part4 = ""
if random.random() > 0.7:
# 生成1-5个连续字符串
num_strings = random.randint(1, 5)
string_list = []
for _ in range(num_strings):
# 随机选择起始位置
start_pos = random.randint(0, len(text) - 1)
# 随机选择x的值(2-6)
x = random.randint(2, 6)
# 获取连续字符串
end_pos = min(start_pos + x + 1, len(text))
string_list.append(text[start_pos:end_pos])
# 用|连接所有字符串
part4 = "|".join(string_list)
try:
labels = [
self.query_engine.get_char_info_by_char_pinyin(c, p).id
for c, p in zip(
text[i : i + pinyin_len],
pinyin_list[i : i + pinyin_len],
)
cont_labels = [
self.query_engine.get_char_info_by_char_pinyin(
text[i], pinyin_list[i]
).id
for i in cont_positions
]
except AttributeError as e:
logger.error(
f"e: {e}, (text, pinyin): {text[i : i + pinyin_len]} - {pinyin_list[i : i + pinyin_len]}"
f"e: {e}, (text, pinyin): {cont_text} - {cont_pinyin}"
)
idx = merge_end_idx
continue
# 续接末尾 10% 概率追加 EOS
if random.random() <= 0.1:
labels.append(0)
cont_labels.append(0)
encoded = self.tokenizer(
f"{part4}|{part1}",
part3,
max_length=self.max_seq_length,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=True,
)
samples = []
# 历史槽位长度权重:增加长历史采样比例
# 目标分布: H=0-2占45%, H=3-8占55%
history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
# 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
cont_processed_history = []
cont_end = cont_positions[-1] + 1
for label_idx, label in enumerate(cont_labels):
base_repeats = self.adjust_frequency(
self.sample_freqs.get(label, 0)
)
if base_repeats == 0:
cont_processed_history.append(label)
continue
if (
self.target_labels is not None
and label not in self.target_labels
):
cont_processed_history.append(label)
continue
# 修复变量名冲突将内层循环变量i重命名为label_idx
for label_idx, label in enumerate(labels):
base_repeats = self.adjust_frequency(label)
# 根据历史槽位长度调整采样次数
weight = (
history_weights[label_idx]
if label_idx < len(history_weights)
self._history_weights[label_idx]
if label_idx < len(self._history_weights)
else 3.0
)
repeats = max(1, int(base_repeats * weight))
# 历史槽位:同一拼音序列中已确认的字符(模拟用户逐步确认过程)
masked_labels = labels[:label_idx]
len_l = len(masked_labels)
masked_labels.extend([0] * (8 - len_l))
for _ in range(repeats):
sample = self._build_single_sample(
label=label,
history=cont_processed_history,
text=text,
word_start=cont_start,
word_end=cont_end,
part2=part2_cont,
pinyin_ids=pinyin_ids_cont,
words=words,
)
batch_samples.append(sample)
samples.extend(
[
{
"input_ids": encoded["input_ids"],
"token_type_ids": encoded["token_type_ids"],
"attention_mask": encoded["attention_mask"],
"label": torch.tensor([label], dtype=torch.long),
"history_slot_ids": torch.tensor(
masked_labels, dtype=torch.long
),
"prefix": f"{part4}^{part1}",
"suffix": part3,
"pinyin": part2,
"pinyin_ids": pinyin_ids,
}
]
* repeats
)
cont_processed_history.append(label)
# 添加到缓冲区
batch_samples.extend(samples)
idx = merge_end_idx
# 处理shuffle buffer - 单缓冲区半保留方案
if len(batch_samples) >= self.shuffle_buffer_size:
# 全量打乱缓冲区
indices = np.random.permutation(len(batch_samples))
# 计算实际保留大小(不超过缓冲区大小)
actual_retention = min(self.retention_size, len(batch_samples))
# 计算输出数量
output_count = len(batch_samples) - actual_retention
# 输出前output_count个样本
for i in range(output_count):
if current_iter_index >= worker_quota:
# 配额用完,清空缓冲区并返回
batch_samples = []
return
yield batch_samples[indices[i]]
current_iter_index += 1
# 保留后actual_retention个样本不清空缓冲区
retained_samples = [
batch_samples[idx] for idx in indices[output_count:]
]

View File

@ -0,0 +1,169 @@
"""
ONNX导出专用组件
为了支持ONNX导出对原始组件进行修改
1. 移除packed sequence操作
2. 处理动态形状问题
3. 确保所有操作符都ONNX兼容
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class ExportPinyinLSTMEncoder(nn.Module):
"""
ONNX兼容的拼音LSTM编码器
简化版本不使用packed sequence
"""
def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim if hidden_dim is not None else input_dim // 2
self.num_layers = num_layers
self.dropout = dropout
self.lstm = nn.LSTM(
input_size=input_dim,
hidden_size=self.hidden_dim,
num_layers=num_layers,
bidirectional=True,
batch_first=True,
dropout=dropout if num_layers > 1 else 0.0,
)
self.proj = nn.Linear(self.hidden_dim * 2, input_dim)
self.layer_norm = nn.LayerNorm(input_dim)
def forward(self, x, mask=None):
"""
ONNX兼容的前向传播
不使用packed sequence改为使用masked计算
"""
# 简化直接使用LSTM不处理padding
# 在ONNX中处理可变长度序列较复杂
# 对于输入法场景拼音长度固定为24所以可以这样处理
output, (hidden, cell) = self.lstm(x)
projected = self.proj(output)
return self.layer_norm(projected)
class ExportContextEncoder(nn.Module):
"""
ONNX兼容的上下文编码器
"""
def __init__(
self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128
):
super().__init__()
self.dim = dim
self.max_len = max_len
# 使用原始text_emb但需要确保它支持ONNX
from modelscope import AutoModel
self.text_emb = AutoModel.from_pretrained(
"iic/nlp_structbert_backbone_lite_std"
).embeddings
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
self.pos_emb = nn.Embedding(max_len, dim)
self.pinyin_pooling = ExportPinyinLSTMEncoder(dim)
# Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=dim,
nhead=n_heads,
dim_feedforward=dim * 4,
dropout=0.1,
batch_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
self.ln = nn.LayerNorm(dim)
def forward(self, text_ids, pinyin_ids, mask=None):
"""
ONNX兼容的前向传播
"""
# 文本嵌入
text_emb = self.text_emb(text_ids) # [B, seq_len, dim]
# 位置编码 - 使用预计算的pos_ids
seq_len = text_emb.size(1)
# 使用静态最大长度避免动态arange
if seq_len > self.max_len:
seq_len = self.max_len
# 创建位置ID确保在导出时是静态的
# 使用torch.full而不是torch.arange因为arange在动态形状下可能有问题
pos_ids = torch.arange(
seq_len, device=text_ids.device, dtype=torch.long
).unsqueeze(0)
# 如果实际序列长度小于最大长度,截取位置嵌入
if seq_len < self.max_len:
x = text_emb[:, :seq_len, :] + self.pos_emb(pos_ids)
else:
x = text_emb + self.pos_emb(pos_ids)
# 处理mask
if mask is not None:
# 确保mask是bool类型
src_mask = (mask == 0).to(torch.bool)
else:
src_mask = None
# Transformer
H = self.transformer(x, src_key_padding_mask=src_mask)
H = self.ln(H)
# 拼音编码
pinyin_emb = self.pinyin_emb(pinyin_ids) # [B, pinyin_len, dim]
# 简化不传递mask给LSTM
P = self.pinyin_pooling(pinyin_emb) # [B, pinyin_len, dim]
return H, P
def create_export_context_encoder(original_context_encoder):
"""
从原始ContextEncoder创建导出版本
"""
# 获取原始配置
config = {
"vocab_size": getattr(original_context_encoder, "vocab_size", 10019),
"pinyin_vocab_size": original_context_encoder.pinyin_emb.num_embeddings,
"dim": original_context_encoder.dim,
"n_layers": original_context_encoder.transformer.num_layers,
"n_heads": original_context_encoder.transformer.layers[0].self_attn.num_heads,
"max_len": original_context_encoder.pos_emb.num_embeddings,
}
# 创建导出版本
export_encoder = ExportContextEncoder(**config)
# 复制权重
# 复制text_emb权重从原始AutoModel embeddings
# 注意这里假设原始text_emb的结构
state_dict = original_context_encoder.state_dict()
export_state_dict = export_encoder.state_dict()
# 复制匹配的权重
for key in export_state_dict:
if key in state_dict:
export_state_dict[key] = state_dict[key]
# 特殊处理复制position embeddings
if "pos_emb.weight" in export_state_dict and "pos_emb.weight" in state_dict:
# 确保大小匹配
orig_pos_emb = state_dict["pos_emb.weight"]
export_pos_emb = export_state_dict["pos_emb.weight"]
min_len = min(orig_pos_emb.size(0), export_pos_emb.size(0))
export_state_dict["pos_emb.weight"][:min_len] = orig_pos_emb[:min_len]
export_encoder.load_state_dict(export_state_dict)
return export_encoder

291
src/model/export_models.py Normal file
View File

@ -0,0 +1,291 @@
"""
ONNX导出模型定义
定义两个子模型用于ONNX导出
1. ContextEncoderExport: 输入文本和拼音输出上下文编码
2. DecoderExport: 输入上下文编码拼音编码和槽位历史输出logits
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# 修补modelscope以支持AutoModel导入
import sys
import modelscope
if not hasattr(modelscope, "AutoModel"):
from modelscope import Model
modelscope.AutoModel = Model
sys.modules["modelscope"].AutoModel = Model
from .components import (
ContextEncoder,
CrossAttentionFusion,
MoELayer,
SlotMemory,
PinyinLSTMEncoder,
)
class ContextEncoderExport(nn.Module):
"""
上下文编码器导出模型
输入: input_ids, pinyin_ids, attention_mask
输出: context_H, pinyin_P, context_mask, pinyin_mask
"""
def __init__(self, context_encoder: ContextEncoder):
super().__init__()
self.context_encoder = context_encoder
self.dim = context_encoder.dim
# 创建ONNX兼容的拼音LSTM编码器简化版不使用packed sequence
# 复制原始LSTM的参数
original_pooling = context_encoder.pinyin_pooling
self.pinyin_lstm = nn.LSTM(
input_size=original_pooling.input_dim,
hidden_size=original_pooling.hidden_dim,
num_layers=original_pooling.num_layers,
bidirectional=True,
batch_first=True,
dropout=original_pooling.dropout
if original_pooling.num_layers > 1
else 0.0,
)
self.pinyin_proj = nn.Linear(
original_pooling.hidden_dim * 2, original_pooling.input_dim
)
self.pinyin_ln = nn.LayerNorm(original_pooling.input_dim)
# 复制权重
self.pinyin_lstm.load_state_dict(original_pooling.lstm.state_dict())
self.pinyin_proj.load_state_dict(original_pooling.proj.state_dict())
self.pinyin_ln.load_state_dict(original_pooling.layer_norm.state_dict())
def forward(self, input_ids, pinyin_ids, attention_mask):
"""
Args:
input_ids: [batch_size, seq_len]
pinyin_ids: [batch_size, pinyin_len] (pinyin_len固定为24)
attention_mask: [batch_size, seq_len]
Returns:
context_H: [batch_size, seq_len, dim]
pinyin_P: [batch_size, pinyin_len, dim]
context_mask: [batch_size, seq_len] (bool -> int32)
pinyin_mask: [batch_size, pinyin_len] (bool -> int32)
"""
# 获取原始context_encoder的组件
text_emb = self.context_encoder.text_emb
pinyin_emb = self.context_encoder.pinyin_emb
pos_emb = self.context_encoder.pos_emb
transformer = self.context_encoder.transformer
ln = self.context_encoder.ln
# 文本嵌入
text_emb_out = text_emb(input_ids) # [B, seq_len, dim]
# 位置编码 - 修复torch.arange动态形状问题
seq_len = text_emb_out.size(1)
# 使用预计算的位置ID确保静态形状
if seq_len > pos_emb.num_embeddings:
seq_len = pos_emb.num_embeddings
# 创建位置ID - 使用torch.arange但确保是整数张量
pos_ids = torch.arange(
seq_len, device=input_ids.device, dtype=torch.long
).unsqueeze(0)
# 如果实际序列长度小于最大长度,截取
if seq_len < text_emb_out.size(1):
x = text_emb_out[:, :seq_len, :] + pos_emb(pos_ids)
else:
x = text_emb_out + pos_emb(pos_ids)
# 处理mask - 确保bool类型
if attention_mask is not None:
src_mask = (attention_mask == 0).to(torch.bool)
else:
src_mask = None
# Transformer
H = transformer(x, src_key_padding_mask=src_mask)
H = ln(H)
# 恢复原始序列长度(如果需要)
if seq_len < text_emb_out.size(1):
# 填充H到原始长度
original_seq_len = text_emb_out.size(1)
padding = torch.zeros(
H.size(0),
original_seq_len - seq_len,
H.size(2),
device=H.device,
dtype=H.dtype,
)
H = torch.cat([H, padding], dim=1)
# 拼音编码 - 使用ONNX兼容版本
pinyin_emb_out = pinyin_emb(pinyin_ids) # [B, pinyin_len, dim]
# 简化的LSTM不使用packed sequence
pinyin_lstm_out, _ = self.pinyin_lstm(pinyin_emb_out)
pinyin_proj_out = self.pinyin_proj(pinyin_lstm_out)
P = self.pinyin_ln(pinyin_proj_out)
# 生成mask转换为int32以便ONNX支持
context_mask = (attention_mask == 0).to(torch.int32) # 1表示padding0表示有效
pinyin_mask = (pinyin_ids == 0).to(torch.int32) # 1表示padding0表示有效
return H, P, context_mask, pinyin_mask
class DecoderExport(nn.Module):
"""
解码器导出模型
输入: context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
输出: logits
"""
def __init__(
self,
slot_memory: SlotMemory,
cross_attn: CrossAttentionFusion,
moe: MoELayer,
slot_attention: nn.Module,
classifier: nn.Module,
num_slots: int = 8,
dim: int = 512,
):
super().__init__()
self.slot_memory = slot_memory
self.cross_attn = cross_attn
self.moe = moe
self.slot_attention = slot_attention
self.classifier = classifier
self.num_slots = num_slots
self.dim = dim
def forward(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
"""
Args:
context_H: [batch_size, seq_len, dim]
pinyin_P: [batch_size, pinyin_len, dim]
history_slot_ids: [batch_size, num_slots]
context_mask: [batch_size, seq_len] (int32, 1表示padding)
pinyin_mask: [batch_size, pinyin_len] (int32, 1表示padding)
Returns:
logits: [batch_size, vocab_size]
"""
batch_size = context_H.size(0)
# 确保history_slot_ids形状正确
history_slot_ids = history_slot_ids.view(batch_size, self.num_slots)
# 1. 槽位记忆
S = self.slot_memory(history_slot_ids) # [batch_size, num_slots, dim]
# 2. 交叉注意力融合
# 转换mask: int32 -> bool (ONNX中需要bool类型)
context_mask_bool = context_mask.to(torch.bool)
pinyin_mask_bool = pinyin_mask.to(torch.bool)
fused = self.cross_attn(
S,
context_H,
pinyin_P,
context_mask=context_mask_bool,
pinyin_mask=pinyin_mask_bool,
) # [batch_size, num_slots, dim]
# 3. MoE层
moe_out = self.moe(fused) # [batch_size, num_slots, dim]
# 4. 槽位注意力池化
slot_scores = self.slot_attention(moe_out).squeeze(
-1
) # [batch_size, num_slots]
slot_weights = torch.softmax(slot_scores, dim=1) # [batch_size, num_slots]
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) # [batch_size, dim]
# 5. 分类头
logits = self.classifier(pooled) # [batch_size, vocab_size]
return logits
def create_export_models_from_checkpoint(checkpoint_path, device="cpu"):
"""
从checkpoint创建导出模型
Args:
checkpoint_path: 模型checkpoint路径
device: 加载设备
Returns:
context_encoder_export: ContextEncoderExport实例
decoder_export: DecoderExport实例
model_config: 模型配置字典
"""
# 加载原始模型配置
checkpoint = torch.load(checkpoint_path, map_location=device)
# 提取模型配置从checkpoint或使用默认值
if "config" in checkpoint:
config = checkpoint["config"]
else:
# 使用模型默认配置
config = {
"vocab_size": 10019,
"pinyin_vocab_size": 30,
"dim": 512,
"num_slots": 8,
"n_layers": 4,
"n_heads": 4,
"num_experts": 10,
"max_seq_len": 128,
}
# 创建原始模型
from .model import InputMethodEngine
model = InputMethodEngine(
vocab_size=config.get("vocab_size", 10019),
pinyin_vocab_size=config.get("pinyin_vocab_size", 30),
dim=config.get("dim", 512),
num_slots=config.get("num_slots", 8),
n_layers=config.get("n_layers", 4),
n_heads=config.get("n_heads", 4),
num_experts=config.get("num_experts", 10),
max_seq_len=config.get("max_seq_len", 128),
compile=False,
)
# 加载权重
if "model_state_dict" in checkpoint:
model.load_state_dict(checkpoint["model_state_dict"])
else:
model.load_state_dict(checkpoint)
model.eval()
model.to(device)
# 创建导出模型
context_encoder_export = ContextEncoderExport(model.context_encoder)
decoder_export = DecoderExport(
slot_memory=model.slot_memory,
cross_attn=model.cross_attn,
moe=model.moe,
slot_attention=model.slot_attention,
classifier=model.classifier,
num_slots=model.num_slots,
dim=model.dim,
)
# 设置为评估模式
context_encoder_export.eval()
decoder_export.eval()
return context_encoder_export, decoder_export, config

View File

@ -0,0 +1,468 @@
#!/usr/bin/env python3
"""
预处理数据质量分析脚本
功能
1. 统计 labels 的分布出现次数比例最大/最小未出现标签数
2. 统计 history_slot_ids 的有效长度分布
3. 抽样还原为人类可读文本导出为 CSV 文件
用法
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train --num-samples 50 --output samples.csv
"""
import argparse
import csv
import json
from collections import Counter
from pathlib import Path
import numpy as np
from loguru import logger
from rich.console import Console
from rich.table import Table
from tqdm import tqdm
from .char_info import CharInfo
from .dataset import CHAR_TO_ID
from .preprocessed_dataset import PreProcessedDataset
from .query import QueryEngine
ID_TO_CHAR = {v: k for k, v in CHAR_TO_ID.items()}
def decode_pinyin_ids(pinyin_ids: list) -> str:
"""将 pinyin_ids 还原为拼音字符串"""
chars = []
for pid in pinyin_ids:
if pid == 0:
break
chars.append(ID_TO_CHAR.get(pid, "?"))
return "".join(chars)
def decode_history(history_ids: list, query_engine: QueryEngine) -> str:
"""将 history_slot_ids 还原为文字"""
parts = []
for hid in history_ids:
if hid == 0:
parts.append("<PAD>")
else:
info = query_engine.query_by_id(hid)
if info is not None:
parts.append(f"{info.char}({info.pinyin})")
else:
parts.append(f"<ID:{hid}>")
return " | ".join(parts)
def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0):
"""统计 labels 分布,带进度条"""
logger.info("正在统计 labels 分布...")
counter = Counter()
total = 0
num_shards = dataset._num_shards if dataset._is_sharded else 1
effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards
pbar = tqdm(range(effective_shards), desc="统计 labels", unit="shard")
for shard_idx in pbar:
if dataset._is_sharded:
shard_data = dict(np.load(dataset.data_dir / f"shard_{shard_idx:06d}.npz"))
labels = shard_data["labels"].astype(np.int64)
else:
labels = dataset.labels[:].astype(np.int64)
unique, counts = np.unique(labels, return_counts=True)
for uid, cnt in zip(unique, counts):
counter[int(uid)] += cnt
total += len(labels)
if dataset._is_sharded:
del shard_data
return counter, total
def analyze_history_slots(dataset: PreProcessedDataset, max_shards: int = 0):
"""统计 history_slot_ids 的有效长度分布(非零元素个数)"""
logger.info("正在统计 history_slot_ids 长度分布...")
counter = Counter()
total = 0
num_shards = dataset._num_shards if dataset._is_sharded else 1
effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards
pbar = tqdm(range(effective_shards), desc="统计 history slots", unit="shard")
for shard_idx in pbar:
if dataset._is_sharded:
shard_data = dict(np.load(dataset.data_dir / f"shard_{shard_idx:06d}.npz"))
history_slots = shard_data["history_slot_ids"].astype(np.int64)
else:
history_slots = dataset.history_slot_ids[:].astype(np.int64)
lengths = np.count_nonzero(history_slots, axis=1)
unique, counts = np.unique(lengths, return_counts=True)
for uid, cnt in zip(unique, counts):
counter[int(uid)] += cnt
total += len(history_slots)
if dataset._is_sharded:
del shard_data
return counter, total
def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
"""将一个样本还原为人类可读格式"""
input_ids = (
sample["input_ids"].tolist()
if hasattr(sample["input_ids"], "tolist")
else sample["input_ids"]
)
token_type_ids = (
sample["token_type_ids"].tolist()
if hasattr(sample["token_type_ids"], "tolist")
else sample["token_type_ids"]
)
labels = (
sample["labels"].item()
if hasattr(sample["labels"], "item")
else sample["labels"]
)
history_ids = (
sample["history_slot_ids"].tolist()
if hasattr(sample["history_slot_ids"], "tolist")
else sample["history_slot_ids"]
)
pinyin_ids = (
sample["pinyin_ids"].tolist()
if hasattr(sample["pinyin_ids"], "tolist")
else sample["pinyin_ids"]
)
# 还原 token 文本
token_text = tokenizer.decode(input_ids, skip_special_tokens=False)
# 找到 token_type_ids 切换点,分离 sentence A 和 sentence B
sep_positions = [i for i, tid in enumerate(token_type_ids) if tid == 1]
if sep_positions:
sep_start = sep_positions[0]
sent_a_ids = [
tid
for tid, tt in zip(input_ids[:sep_start], token_type_ids[:sep_start])
if tt == 0
]
sent_b_ids = [tid for tid, tt in zip(input_ids, token_type_ids) if tt == 1]
else:
sent_a_ids = input_ids
sent_b_ids = []
context_text = tokenizer.decode(sent_a_ids, skip_special_tokens=True)
suffix_text = (
tokenizer.decode(sent_b_ids, skip_special_tokens=True) if sent_b_ids else ""
)
pinyin_str = decode_pinyin_ids(pinyin_ids)
label_info = query_engine.query_by_id(labels)
history_str = decode_history(history_ids, query_engine)
history_slot_length = sum(1 for h in history_ids if h != 0)
return {
"context": context_text,
"suffix": suffix_text,
"pinyin": pinyin_str,
"label_id": labels,
"label_char": f"{label_info.char}({label_info.pinyin})"
if label_info
else f"<ID:{labels}>",
"label_count": label_info.count if label_info else 0,
"history": history_str,
"history_slot_length": history_slot_length,
"full_tokens": token_text,
}
def main():
console = Console()
parser = argparse.ArgumentParser(description="预处理数据质量分析")
parser.add_argument(
"--data-dir",
type=str,
required=True,
help="预处理数据目录train/ 或 eval/",
)
parser.add_argument(
"--num-samples",
type=int,
default=50,
help="抽样的样本数量默认50取前 N 个样本)",
)
parser.add_argument(
"--output",
type=str,
default=None,
help="CSV 输出文件路径(默认: <data-dir>/samples.csv",
)
parser.add_argument(
"--max-shards",
type=int,
default=0,
help="统计 labels 时最多读取的分片数0=全部)",
)
parser.add_argument(
"--top-k",
type=int,
default=30,
help="显示出现次数最多和最少的标签数量",
)
args = parser.parse_args()
if args.output is None:
args.output = str(Path(args.data_dir) / "samples.csv")
# 加载数据集
logger.info(f"加载数据集: {args.data_dir}")
dataset = PreProcessedDataset(args.data_dir)
console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
if dataset._is_sharded:
console.print(
f" 分片数: {dataset._num_shards}, 每分片: {min(dataset._shard_sizes):,} - {max(dataset._shard_sizes):,} 样本"
)
console.print()
# 加载 QueryEngine
logger.info("加载 QueryEngine...")
query_engine = QueryEngine()
query_engine.load()
# 加载 Tokenizer
logger.info("加载 Tokenizer...")
from importlib.resources import files as pkg_files
from modelscope import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
Path(str(pkg_files(__package__))) / "assets" / "tokenizer"
)
# ====== 1. Labels 分布分析 ======
console.print("[bold yellow]====== Labels 分布分析 ======[/bold yellow]")
counter, total = analyze_labels(dataset, max_shards=args.max_shards)
# 获取词表总大小_id_to_info 已包含 EOS id=0
vocab_size = len(query_engine._id_to_info)
appeared_ids = set(counter.keys())
all_ids = set(query_engine._id_to_info.keys())
missing_ids = all_ids - appeared_ids
console.print(f"\n总样本数: {total:,}")
console.print(f"词表大小: {vocab_size:,} (含 EOS)")
console.print(f"唯一标签数: {len(counter):,}")
console.print(
f"EOS (id=0) 出现次数: {counter.get(0, 0):,} ({counter.get(0, 0) / total * 100:.2f}%)"
)
console.print(
f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size:,} ({len(missing_ids) / vocab_size * 100:.2f}%)[/bold red]"
)
most_common = counter.most_common(args.top_k)
least_common = (
counter.most_common()[: -args.top_k - 1 : -1]
if len(counter) > args.top_k
else counter.most_common()
)
# 最多标签表
table_top = Table(
title=f"出现次数最多的 {args.top_k} 个标签",
show_header=True,
header_style="bold magenta",
)
table_top.add_column("排名", style="cyan", width=6)
table_top.add_column("ID", style="green", width=8)
table_top.add_column("字符(拼音)", style="yellow", width=20)
table_top.add_column("频次", style="red", width=12)
table_top.add_column("占比", style="blue", width=10)
for rank, (label_id, count) in enumerate(most_common, 1):
info = query_engine.query_by_id(label_id)
label_str = f"{info.char}({info.pinyin})" if info else f"<ID:{label_id}>"
pct = count / total * 100
table_top.add_row(
str(rank), str(label_id), label_str, f"{count:,}", f"{pct:.3f}%"
)
console.print(table_top)
# 最少标签表
table_bottom = Table(
title=f"出现次数最少的 {min(args.top_k, len(counter))} 个标签",
show_header=True,
header_style="bold magenta",
)
table_bottom.add_column("排名", style="cyan", width=6)
table_bottom.add_column("ID", style="green", width=8)
table_bottom.add_column("字符(拼音)", style="yellow", width=20)
table_bottom.add_column("频次", style="red", width=12)
table_bottom.add_column("占比", style="blue", width=10)
for rank, (label_id, count) in enumerate(least_common, 1):
info = query_engine.query_by_id(label_id)
label_str = f"{info.char}({info.pinyin})" if info else f"<ID:{label_id}>"
pct = count / total * 100
table_bottom.add_row(
str(rank), str(label_id), label_str, f"{count:,}", f"{pct:.6f}%"
)
console.print(table_bottom)
# 频次分布概览
table_dist = Table(
title="频次分布概览", show_header=True, header_style="bold magenta"
)
table_dist.add_column("频次区间", style="cyan")
table_dist.add_column("标签数", style="green")
table_dist.add_column("占总标签数比例", style="yellow")
bins = [
(1, 10),
(11, 100),
(101, 1000),
(1001, 10000),
(10001, 100000),
(100001, 1000000),
(1000001, float("inf")),
]
for lo, hi in bins:
count_in_bin = sum(1 for c in counter.values() if lo <= c <= hi)
if count_in_bin > 0:
hi_str = str(int(hi)) if hi != float("inf") else ""
table_dist.add_row(
f"{lo}-{hi_str}",
f"{count_in_bin:,}",
f"{count_in_bin / len(counter) * 100:.1f}%",
)
# 未出现
if len(missing_ids) > 0:
table_dist.add_row(
"未出现",
f"{len(missing_ids):,}",
f"{len(missing_ids) / vocab_size * 100:.1f}%",
)
console.print(table_dist)
# ====== 2. 历史槽位长度分析 ======
console.print("\n[bold yellow]====== 历史槽位长度分析 ======[/bold yellow]")
history_counter, history_total = analyze_history_slots(
dataset, max_shards=args.max_shards
)
if not history_counter:
console.print("[yellow] 无历史槽位数据[/yellow]")
else:
sorted_items = sorted(history_counter.items())
lengths_arr = [l for l, _ in sorted_items]
counts_arr = [c for _, c in sorted_items]
weighted_sum = sum(l * c for l, c in zip(lengths_arr, counts_arr))
avg_length = weighted_sum / history_total if history_total > 0 else 0
cumsum = 0
median_length = 0
for length, count in sorted_items:
cumsum += count
if cumsum >= history_total / 2:
median_length = length
break
console.print(f"\n总样本数: {history_total:,}")
console.print(f"最大历史槽位数: {max(lengths_arr)}")
console.print(f"最小历史槽位数: {min(lengths_arr)}")
console.print(f"平均历史槽位数: {avg_length:.2f}")
console.print(f"中位数历史槽位数: {median_length}")
hist_table = Table(
title="历史槽位有效长度分布",
show_header=True,
header_style="bold magenta",
)
hist_table.add_column("槽位数", style="cyan", width=10)
hist_table.add_column("样本数", style="green", width=12)
hist_table.add_column("占比", style="yellow", width=10)
for length, count in sorted_items:
pct = count / history_total * 100
hist_table.add_row(str(length), f"{count:,}", f"{pct:.2f}%")
console.print(hist_table)
# ====== 3. 抽样还原 → CSV ======
num_samples = min(args.num_samples, len(dataset))
console.print(
f"\n[bold yellow]====== 抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
)
indices = list(range(num_samples))
csv_path = Path(args.output)
csv_path.parent.mkdir(parents=True, exist_ok=True)
csv_headers = [
"index",
"pinyin",
"label_char",
"label_id",
"label_count",
"context",
"suffix",
"history",
"history_slot_length",
"full_tokens",
]
with open(csv_path, "w", encoding="utf-8", newline="") as f:
writer = csv.writer(f)
writer.writerow(csv_headers)
for i, idx in enumerate(tqdm(indices, desc="解码样本", unit="sample")):
sample = dataset[idx]
decoded = decode_sample(sample, tokenizer, query_engine)
writer.writerow(
[
idx,
decoded["pinyin"],
decoded["label_char"],
decoded["label_id"],
decoded["label_count"],
decoded["context"],
decoded["suffix"],
decoded["history"],
decoded["history_slot_length"],
decoded["full_tokens"],
]
)
console.print(
f"[bold green]✓ 已导出 {num_samples} 个样本到 {csv_path}[/bold green]"
)
# 打印前 5 个样本的概要
console.print(f"\n[bold cyan]前 {min(5, num_samples)} 个样本概览:[/bold cyan]")
with open(csv_path, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for i, row in enumerate(reader):
if i >= 5:
break
console.print(
f" [{i + 1}] 拼音={row['pinyin']} 目标={row['label_char']} "
f"上下文={row['context'][:50]}..."
)
console.print("\n[bold green]分析完成[/bold green]")
app = main
if __name__ == "__main__":
main()

87
src/model/merge_shards.py Normal file
View File

@ -0,0 +1,87 @@
#!/usr/bin/env python3
"""将分片 .npz 合并为单个 .npz 文件,合并后删除原分片"""
import argparse
import gc
import json
from pathlib import Path
import numpy as np
from tqdm import tqdm
FIELDS = [
"input_ids",
"token_type_ids",
"attention_mask",
"labels",
"history_slot_ids",
"pinyin_ids",
]
def merge_split(data_dir: Path, output_name: str = "shard_000000.npz"):
with open(data_dir / "metadata.json") as f:
meta = json.load(f)
num_shards = meta["num_shards"]
if num_shards <= 1:
print(f"{data_dir}: already single shard, skip")
return
merged = {}
for field in FIELDS:
pieces = []
for i in tqdm(
range(num_shards), desc=f"Merging {data_dir.name}/{field}", unit="shard"
):
data = np.load(data_dir / f"shard_{i:06d}.npz")
pieces.append(data[field])
data.close()
merged[field] = np.concatenate(pieces, axis=0)
del pieces
gc.collect()
total = len(merged["labels"])
# 先删原分片(避免与新文件名冲突)
for i in range(num_shards):
(data_dir / f"shard_{i:06d}.npz").unlink()
# 再写合并文件
np.savez_compressed(data_dir / output_name, **merged)
del merged
gc.collect()
# 更新 metadata
meta["num_shards"] = 1
meta["shard_size"] = total
with open(data_dir / "metadata.json", "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2, ensure_ascii=False)
print(f"{data_dir.name}: {num_shards} shards → 1 shard, {total:,} samples")
def main():
parser = argparse.ArgumentParser(description="合并分片 .npz 为单个文件")
parser.add_argument(
"--input-dir", type=str, required=True, help="数据集目录(含 train/ 和 eval/"
)
parser.add_argument(
"--output-name", type=str, default="shard_000000.npz", help="合并后文件名"
)
parser.add_argument(
"--train-only", action="store_true", help="仅合并 train跳过 eval"
)
args = parser.parse_args()
root = Path(args.input_dir)
for split in ["train", "eval"]:
split_dir = root / split
if not split_dir.exists():
print(f"{split_dir}: not found, skip")
continue
if args.train_only and split == "eval":
continue
merge_split(split_dir, args.output_name)
if __name__ == "__main__":
main()

View File

@ -35,12 +35,13 @@ class InputMethodEngine(nn.Module):
vocab_size: int = 10019,
pinyin_vocab_size: int = 28,
dim: int = 512,
num_slots: int = 8, # 历史槽位数量 (对应 README 中的 8 个槽位)
n_layers: int = 4, # Transformer 层数
n_heads: int = 4, # 注意力头数
num_experts: int = 10, # MoE 专家数量
max_seq_len: int = 128, # 最大上下文长度
compile: bool = False, # 是否开启 torch.compile 优化
num_slots: int = 8,
n_layers: int = 4,
n_heads: int = 4,
num_experts: int = 10,
max_seq_len: int = 128,
compile: bool = False,
moe_mode: str = "all", # "all" / "sparse" / "sparse_allow_graph"
):
super().__init__()
self.dim = dim
@ -72,7 +73,13 @@ class InputMethodEngine(nn.Module):
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
# 4. 混合专家层 (MoE)
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8)
self.moe = MoELayer(
dim=dim,
num_experts=num_experts,
top_k=3,
num_resblocks=12,
moe_mode=moe_mode,
)
# 5. 槽位注意力池化
self.slot_attention = nn.Linear(dim, 1)
@ -80,24 +87,12 @@ class InputMethodEngine(nn.Module):
# 6. 分类头
self.classifier = nn.Linear(dim, vocab_size)
# 开启 torch.compile 优化 (如果请求)
# 在模型编译时添加优化选项
if compile:
from torch._inductor.select_algorithm import TritonTemplate
TritonTemplate.all_templates.clear()
self.forward = torch.compile(
self.forward,
mode="reduce-overhead",
fullgraph=False,
dynamic=False,
# options={
# "epilogue_fusion": True,
# "max_autotune": True,
# "triton.cudagraphs": True,
# "reorder_for_compute_comm_overlap": False,
# },
)
def forward(
@ -115,6 +110,10 @@ class InputMethodEngine(nn.Module):
history_slot_ids = history_slot_ids.view(-1, self.num_slots)
is_zero = (history_slot_ids == 0)
cumsum_zero = is_zero.cumsum(dim=1)
slot_mask = (cumsum_zero <= 1).to(torch.get_default_dtype())
H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
S = self.slot_memory(history_slot_ids)
@ -125,11 +124,15 @@ class InputMethodEngine(nn.Module):
S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask
)
fused = fused * slot_mask.unsqueeze(-1)
moe_out = self.moe(fused)
batch_size = input_ids.size(0)
slot_scores = self.slot_attention(moe_out).squeeze(-1)
slot_weights = torch.softmax(slot_scores, dim=1)
slot_weights = slot_weights * slot_mask
slot_weights = slot_weights / (slot_weights.sum(dim=1, keepdim=True) + 1e-12)
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)
logits = self.classifier(pooled)

366
src/model/onnx_export.py Normal file
View File

@ -0,0 +1,366 @@
"""
ONNX模型导出核心逻辑
InputMethodEngine 模型导出为两个 ONNX 文件
1. context_encoder.onnx - 上下文编码器可复用
2. decoder.onnx - 解码器
共用此模块的入口
- CLI: train-model export
- 脚本: python export_onnx.py薄壳
"""
import os
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple
import numpy as np
import torch
from .export_models import create_export_models_from_checkpoint
def check_onnx_available() -> bool:
try:
import onnx # noqa: F401
import onnxruntime as ort # noqa: F401
return True
except ImportError:
print("错误: ONNX导出需要以下依赖:")
print(" pip install onnx onnxruntime")
print("请安装后重试")
return False
def export_context_encoder(
model,
output_path: str,
config: Dict,
skip_verification: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = 2
seq_len = config.get("max_seq_len", 128)
pinyin_len = 24
input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"pinyin_ids": {0: "batch_size"},
"attention_mask": {0: "batch_size", 1: "seq_len"},
"context_H": {0: "batch_size", 1: "seq_len"},
"pinyin_P": {0: "batch_size"},
"context_mask": {0: "batch_size", 1: "seq_len"},
"pinyin_mask": {0: "batch_size"},
}
torch.onnx.export(
model,
(input_ids, pinyin_ids, attention_mask),
output_path,
input_names=["input_ids", "pinyin_ids", "attention_mask"],
output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
dynamic_axes=dynamic_axes,
opset_version=18,
do_constant_folding=True,
verbose=False,
)
print(f" 上下文编码器导出完成 -> {output_path}")
if not skip_verification:
try:
import onnx
onnx.checker.check_model(onnx.load(output_path))
print(" ONNX 模型验证通过")
except Exception as e:
print(f" ONNX 模型验证警告: {e}")
return input_ids, pinyin_ids, attention_mask
def export_decoder(
model,
output_path: str,
config: Dict,
example_inputs: Optional[Tuple] = None,
skip_verification: bool = False,
) -> Tuple[torch.Tensor, ...]:
batch_size = 2
seq_len = config.get("max_seq_len", 128)
pinyin_len = 24
dim = config.get("dim", 512)
num_slots = config.get("num_slots", 8)
if example_inputs is not None:
context_H, pinyin_P, context_mask, pinyin_mask = example_inputs
batch_size = context_H.size(0)
history_slot_ids = torch.randint(
0, 100, (batch_size, num_slots), dtype=torch.long
)
else:
context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
history_slot_ids = torch.randint(
0, 100, (batch_size, num_slots), dtype=torch.long
)
dynamic_axes = {
"context_H": {0: "batch_size", 1: "seq_len"},
"pinyin_P": {0: "batch_size"},
"history_slot_ids": {0: "batch_size"},
"context_mask": {0: "batch_size", 1: "seq_len"},
"pinyin_mask": {0: "batch_size"},
"logits": {0: "batch_size"},
}
torch.onnx.export(
model,
(context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
output_path,
input_names=[
"context_H",
"pinyin_P",
"history_slot_ids",
"context_mask",
"pinyin_mask",
],
output_names=["logits"],
dynamic_axes=dynamic_axes,
opset_version=18,
do_constant_folding=True,
verbose=False,
)
print(f" 解码器导出完成 -> {output_path}")
if not skip_verification:
try:
import onnx
onnx.checker.check_model(onnx.load(output_path))
print(" ONNX 模型验证通过")
except Exception as e:
print(f" ONNX 模型验证警告: {e}")
return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
def save_example_inputs(output_dir: str, example_inputs_dict: Dict) -> None:
npz_path = os.path.join(output_dir, "example_inputs.npz")
np_data = {}
for key, tensor in example_inputs_dict.items():
if isinstance(tensor, torch.Tensor):
np_data[key] = tensor.cpu().numpy()
elif isinstance(tensor, tuple):
for i, t in enumerate(tensor):
if isinstance(t, torch.Tensor):
np_data[f"{key}_{i}"] = t.cpu().numpy()
np.savez(npz_path, **np_data)
print(f" 示例输入已保存到: {npz_path}")
torch_path = os.path.join(output_dir, "example_inputs.pt")
torch.save(example_inputs_dict, torch_path)
print(f" PyTorch 示例输入已保存到: {torch_path}")
_INFERENCE_EXAMPLE_TEMPLATE = '''#!/usr/bin/env python3
"""
ONNX模型推理示例
展示如何使用导出的两个ONNX模型进行推理
包括束搜索beam search算法
"""
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from typing import List, Tuple
class ONNXInference:
"""ONNX模型推理器"""
def __init__(self, context_encoder_path, decoder_path):
self.context_encoder_session = ort.InferenceSession(
context_encoder_path,
providers=['CPUExecutionProvider']
)
self.decoder_session = ort.InferenceSession(
decoder_path,
providers=['CPUExecutionProvider']
)
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
print(f"上下文编码器输入: {self.context_input_names}")
print(f"上下文编码器输出: {self.context_output_names}")
print(f"解码器输入: {self.decoder_input_names}")
print(f"解码器输出: {self.decoder_output_names}")
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
raise NotImplementedError("请实现实际的输入预处理")
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
inputs = {
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
}
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
context_H, pinyin_P, context_mask, pinyin_mask = outputs
return (
torch.from_numpy(context_H),
torch.from_numpy(pinyin_P),
torch.from_numpy(context_mask),
torch.from_numpy(pinyin_mask),
)
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
inputs = {
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
}
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
logits = outputs[0]
return torch.from_numpy(logits)
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
beam_size=5, max_length=10, vocab_size=10019):
beams = [([], 0.0)]
for step in range(max_length):
new_beams = []
for seq, score in beams:
if len(seq) < 8:
history = seq + [0] * (8 - len(seq))
else:
history = seq[-8:]
history_tensor = torch.tensor([history], dtype=torch.long)
logits = self.run_decoder(
context_H, pinyin_P, history_tensor,
context_mask, pinyin_mask
)
probs = F.softmax(logits[0], dim=-1)
top_probs, top_indices = torch.topk(probs, beam_size)
for prob, idx in zip(top_probs, top_indices):
new_seq = seq + [idx.item()]
new_score = score + torch.log(prob).item()
new_beams.append((new_seq, new_score))
new_beams.sort(key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_size]
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
if all_ended:
break
return beams
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
input_ids, pinyin_ids, attention_mask
)
logits = self.run_decoder(
context_H, pinyin_P, history_slot_ids,
context_mask, pinyin_mask
)
return logits
def main():
"""示例主函数"""
print("ONNX模型推理示例")
print("=" * 60)
context_encoder_path = "context_encoder.onnx"
decoder_path = "decoder.onnx"
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
print("错误: 找不到ONNX模型文件")
print("请先运行 train-model export 导出模型")
return
inference = ONNXInference(context_encoder_path, decoder_path)
print("\\u2705 ONNX推理器初始化完成")
print("请参考此示例实现完整的输入法推理流程")
if __name__ == "__main__":
main()
'''
def create_inference_example(output_dir: str, config: Dict) -> None:
example_path = os.path.join(output_dir, "inference_example.py")
with open(example_path, "w", encoding="utf-8") as f:
f.write(_INFERENCE_EXAMPLE_TEMPLATE)
print(f" 推理示例脚本已保存到: {example_path}")
def run_full_export(
checkpoint_path: str,
output_dir: str,
device: str = "cpu",
skip_verification: bool = False,
) -> Tuple[str, str, Dict]:
"""
完整的 ONNX 导出流程
Returns:
(context_encoder_path, decoder_path, config)
"""
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
print(f"加载 checkpoint: {checkpoint_path}")
context_encoder_export, decoder_export, config = create_export_models_from_checkpoint(
checkpoint_path, device
)
print(f"模型配置: vocab_size={config.get('vocab_size')}, "
f"dim={config.get('dim')}, "
f"num_slots={config.get('num_slots')}, "
f"num_experts={config.get('num_experts')}, "
f"moe_mode={config.get('moe_mode', 'all')}")
print("正在导出模型...")
context_encoder_path = str(output_path / "context_encoder.onnx")
example_inputs = export_context_encoder(
context_encoder_export, context_encoder_path, config,
skip_verification=skip_verification,
)
with torch.no_grad():
context_H, pinyin_P, context_mask, pinyin_mask = context_encoder_export(
*example_inputs
)
decoder_path = str(output_path / "decoder.onnx")
export_decoder(
decoder_export, decoder_path, config,
example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask),
skip_verification=skip_verification,
)
example_inputs_dict = {
"input_ids": example_inputs[0],
"pinyin_ids": example_inputs[1],
"attention_mask": example_inputs[2],
"context_H": context_H,
"pinyin_P": pinyin_P,
"context_mask": context_mask,
"pinyin_mask": pinyin_mask,
}
save_example_inputs(output_dir, example_inputs_dict)
create_inference_example(output_dir, config)
return context_encoder_path, decoder_path, config

433
src/model/preprocess.py Normal file
View File

@ -0,0 +1,433 @@
#!/usr/bin/env python3
"""
预处理脚本 PinyinInputDataset 的输出转换为分片压缩 .npz 文件
采用分片流式写入内存峰值固定为 shard_size 级别不随总样本数增长
每个分片使用 np.savez_compressed 保存zlib 压缩GPU 服务器无需解压到硬盘
用法
python -m model.preprocess \
--train-data-path "some/hf_dataset" \
--eval-data-path "some/hf_dataset" \
--output-dir ./preprocessed \
--num-train-samples 5000000 \
--num-eval-samples 8192
生成目录结构
output_dir/
train/
metadata.json
shard_000.npz (5M样本, 6个字段, zlib压缩)
shard_001.npz
...
eval/
metadata.json
shard_000.npz
...
"""
import argparse
import gc
import json
import struct
import time
import zipfile
from pathlib import Path
from typing import Dict, List
import numpy as np
import torch
from loguru import logger
from rich.console import Console
from torch.utils.data import DataLoader
from tqdm import tqdm
from .dataset import PinyinInputDataset
from .trainer import preprocess_collate_fn, worker_init_fn
FIELDS = [
"input_ids",
"token_type_ids",
"attention_mask",
"labels",
"history_slot_ids",
"pinyin_ids",
]
def _extract_batch(batch: dict, take: int) -> Dict[str, np.ndarray]:
"""从 DataLoader batch 中提取指定数量的样本,转为 int16 numpy 数组"""
result = {}
for f in FIELDS:
tensor = batch[f][:take]
arr = tensor.numpy().astype(np.int16)
if f == "labels" and arr.ndim > 1 and arr.shape[-1] == 1:
arr = arr.squeeze(-1)
result[f] = arr
return result
def collect_samples(
dataloader: DataLoader,
num_samples: int,
output_dir: Path,
split_name: str,
max_seq_length: int = 128,
shard_size: int = 5_000_000,
) -> int:
"""
分片流式收集样本每累积 shard_size 个样本保存为一个压缩 .npz 分片
内存峰值 = shard_size × 每样本字节数约578字节 @ shard_size=5M 约2.9GB
"""
split_dir = output_dir / split_name
split_dir.mkdir(parents=True, exist_ok=True)
shard_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
shard_count = 0
shard_idx = 0
total = 0
pbar = tqdm(total=num_samples, desc=f"Processing {split_name}", unit="samples")
for batch in dataloader:
batch_size = batch["input_ids"].size(0)
remaining = num_samples - total
if remaining <= 0:
break
take = min(batch_size, remaining)
extracted = _extract_batch(batch, take)
for f in FIELDS:
shard_buffers[f].append(extracted[f])
shard_count += take
total += take
pbar.update(take)
if shard_count >= shard_size:
merged = {}
for f in FIELDS:
merged[f] = np.concatenate(shard_buffers[f], axis=0)
np.savez_compressed(split_dir / f"shard_{shard_idx:06d}.npz", **merged)
logger.debug(f"Saved {split_name} shard {shard_idx}: {shard_count} samples")
shard_idx += 1
shard_buffers = {f: [] for f in FIELDS}
shard_count = 0
del merged
gc.collect()
if total >= num_samples:
break
# 写入最后一个不满的分片
if shard_count > 0:
merged = {}
for f in FIELDS:
merged[f] = np.concatenate(shard_buffers[f], axis=0)
np.savez_compressed(split_dir / f"shard_{shard_idx:06d}.npz", **merged)
logger.debug(f"Saved {split_name} shard {shard_idx}: {shard_count} samples")
shard_idx += 1
pbar.close()
actual_count = min(total, num_samples)
num_shards = shard_idx
shard_sizes = [shard_size] * (num_shards - 1)
if num_shards > 0:
shard_sizes.append(actual_count - sum(shard_sizes))
metadata = {
"num_samples": actual_count,
"max_seq_length": max_seq_length,
"dtype": "int16",
"fields": FIELDS,
"shard_size": shard_size,
"num_shards": num_shards,
"shard_sizes": shard_sizes,
}
with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
json.dump(metadata, fp, indent=2, ensure_ascii=False)
total_size = sum(
f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
)
logger.info(
f"{split_name}: {actual_count} samples in {num_shards} shards, "
f"{total_size / (1024**3):.2f} GB (compressed)"
)
return actual_count
def main():
console = Console()
parser = argparse.ArgumentParser(description="预处理数据集为分片压缩npz文件")
parser.add_argument(
"--train-data-path",
type=str,
required=True,
help="训练数据集路径HuggingFace格式",
)
parser.add_argument(
"--eval-data-path",
type=str,
required=True,
help="评估数据集路径HuggingFace格式",
)
parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
parser.add_argument(
"--num-train-samples", type=int, required=True, help="训练集样本数量"
)
parser.add_argument(
"--num-eval-samples", type=int, required=True, help="评估集样本数量"
)
parser.add_argument("--batch-size", type=int, default=128, help="批大小")
parser.add_argument(
"--num-workers", type=int, default=2, help="DataLoader worker数量"
)
parser.add_argument("--max-seq-length", type=int, default=128, help="最大序列长度")
parser.add_argument("--seed", type=int, default=42, help="随机种子")
parser.add_argument(
"--shard-size",
type=int,
default=5_000_000,
help="分片大小样本数控制内存峰值默认500万约2.9GB/分片未压缩)",
)
parser.add_argument(
"--py-style-weight",
type=str,
default="9,2,1",
help="拼音风格权重(逗号分隔)",
)
parser.add_argument(
"--shuffle-buffer-size",
type=int,
default=2000000,
help="数据集shuffle缓冲区大小",
)
parser.add_argument(
"--length-weights",
type=str,
default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
help="词长权重",
)
args = parser.parse_args()
torch.manual_seed(args.seed)
np.random.seed(args.seed)
py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
length_weights = {
int(k): int(v)
for k, v in (item.split(":") for item in args.length_weights.split(","))
}
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
train_max_iter = args.num_train_samples * 5
eval_max_iter = args.num_eval_samples * 5
shard_mem_gb = args.shard_size * 578 / (1024**3)
console.print("[bold cyan]=== 数据预处理 ===[/bold cyan]")
console.print(f"训练集目标: {args.num_train_samples:,} 样本")
console.print(f"评估集目标: {args.num_eval_samples:,} 样本")
console.print(f"输出目录: {output_dir}")
console.print(f"数据类型: int16")
console.print(
f"分片大小: {args.shard_size:,} 样本 (约 {shard_mem_gb:.1f} GB/分片 未压缩)"
)
console.print()
num_train_workers = args.num_workers
num_eval_workers = max(1, args.num_workers // 2)
console.print("[bold cyan]创建训练数据集...[/bold cyan]")
train_dataset = PinyinInputDataset(
data_path=args.train_data_path,
max_workers=num_train_workers,
max_iter_length=train_max_iter,
max_seq_length=args.max_seq_length,
text_field="text",
py_style_weight=py_style_weight,
shuffle_buffer_size=100,
length_weights=length_weights,
)
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
num_workers=num_train_workers,
pin_memory=False,
worker_init_fn=worker_init_fn,
collate_fn=preprocess_collate_fn(args.max_seq_length),
prefetch_factor=2,
persistent_workers=True if num_train_workers > 0 else False,
)
console.print("[bold cyan]创建评估数据集...[/bold cyan]")
eval_dataset = PinyinInputDataset(
data_path=args.eval_data_path,
max_workers=num_eval_workers,
max_iter_length=eval_max_iter,
max_seq_length=args.max_seq_length,
text_field="text",
py_style_weight=py_style_weight,
shuffle_buffer_size=100,
length_weights=length_weights,
)
eval_dataloader = DataLoader(
eval_dataset,
batch_size=args.batch_size,
num_workers=num_eval_workers,
pin_memory=False,
worker_init_fn=worker_init_fn,
collate_fn=preprocess_collate_fn(args.max_seq_length),
prefetch_factor=2,
persistent_workers=True if num_eval_workers > 0 else False,
)
logger.info("开始收集训练数据...")
train_count = collect_samples(
train_dataloader,
args.num_train_samples,
output_dir,
"train",
args.max_seq_length,
args.shard_size,
)
if train_count < args.num_train_samples:
logger.warning(
f"训练集样本不足: 目标 {args.num_train_samples}, 实际 {train_count}"
)
logger.info("开始收集评估数据...")
eval_count = collect_samples(
eval_dataloader,
args.num_eval_samples,
output_dir,
"eval",
args.max_seq_length,
args.shard_size,
)
if eval_count < args.num_eval_samples:
logger.warning(
f"评估集样本不足: 目标 {args.num_eval_samples}, 实际 {eval_count}"
)
console.print("\n[bold green]=== 预处理完成 ===[/bold green]")
console.print(f"训练集: {train_count:,} 样本")
console.print(f"评估集: {eval_count:,} 样本")
console.print(f"输出目录: {output_dir}")
for split in ["train", "eval"]:
split_dir = output_dir / split
if split_dir.exists():
total_size = sum(
f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
)
console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")
def resplit_shards(input_dir: str, output_dir: str, target_size: int = 1_000_000):
"""
将过大的 .npz 分片拆分到目标大小
内存峰值 = 一个分片全部字段 int16 (~21GB for 20M) + 一个 chunk (~1.5GB)
建议在内存充裕的 CPU 机器上运行拆分后将小分片拷贝至 GPU 服务器
"""
import gc
import time
from tqdm import tqdm
input_path = Path(input_dir)
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
metadata_path = input_path / "metadata.json"
if not metadata_path.exists():
raise FileNotFoundError(f"metadata.json not found in {input_dir}")
with open(metadata_path) as f:
metadata = json.load(f)
shard_files = sorted(input_path.glob("shard_*.npz"))
console = Console()
console.print(
f"[bold]Resplit {len(shard_files)} shards → {target_size:,} samples/chunk[/bold]"
)
global_shard_idx = 0
all_shard_sizes: List[int] = []
total_samples = 0
for sf in tqdm(shard_files, desc="Shards"):
t0 = time.time()
data = np.load(str(sf))
n_samples = data[list(data.keys())[0]].shape[0]
for start in range(0, n_samples, target_size):
end = min(start + target_size, n_samples)
chunk_data: Dict[str, np.ndarray] = {}
for field in data.keys():
chunk_data[field] = data[field][start:end].astype(np.int16).copy()
out_file = output_path / f"shard_{global_shard_idx:06d}.npz"
np.savez_compressed(out_file, **chunk_data)
chunk_size = end - start
all_shard_sizes.append(chunk_size)
total_samples += chunk_size
del chunk_data
gc.collect()
global_shard_idx += 1
data.close()
del data
gc.collect()
with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
metadata["num_samples"] = total_samples
metadata["num_shards"] = global_shard_idx
metadata["shard_sizes"] = all_shard_sizes
metadata.pop("shard_size", None)
json.dump(metadata, f, indent=2, ensure_ascii=False)
console.print(
f"[green]Done: {total_samples:,} samples in {global_shard_idx} shards[/green]"
)
def _dispatch():
import sys
if len(sys.argv) > 1 and sys.argv[1] == "resplit":
import argparse as ap
p = ap.ArgumentParser(description="拆分过大的 .npz 分片")
p.add_argument(
"--input-dir", required=True, help="输入目录含 metadata.json 和 shard_*.npz"
)
p.add_argument("--output-dir", required=True, help="输出目录")
p.add_argument(
"--target-size", type=int, default=1_000_000, help="目标分片大小(样本数)"
)
args = p.parse_args(sys.argv[2:])
resplit_shards(args.input_dir, args.output_dir, args.target_size)
else:
main()
app = _dispatch
if __name__ == "__main__":
_dispatch()

View File

@ -0,0 +1,243 @@
"""
预处理数据集加载器
支持两种格式
1. 分片压缩格式.npz从压缩文件中按需加载分片LRU 缓存最多持有 max_cache_shards 个分片
2. 单体格式.npy向后兼容使用 mmap 零拷贝加载
GPU 服务器仅需存放压缩后的 .npz 文件无需解压到硬盘
"""
import time
import gc
import json
import struct
import zipfile
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import torch
from loguru import logger
from torch.utils.data import Dataset
FIELDS = [
"input_ids",
"token_type_ids",
"attention_mask",
"labels",
"history_slot_ids",
"pinyin_ids",
]
def _read_shard_size(npz_path: Path) -> int:
with zipfile.ZipFile(npz_path, "r") as z:
first = sorted(z.namelist())[0]
with z.open(first) as f:
magic = f.read(6)
if magic != b"\x93NUMPY":
raise ValueError(f"Not a numpy file: {npz_path}")
ver = struct.unpack("<BB", f.read(2))
if ver[0] >= 2:
header_len = struct.unpack("<I", f.read(4))[0]
else:
header_len = struct.unpack("<H", f.read(2))[0]
header = eval(f.read(header_len))
return header["shape"][0]
def _build_offsets(sizes: List[int]) -> List[int]:
offsets = [0]
for s in sizes:
offsets.append(offsets[-1] + s)
return offsets
def is_preprocessed_data(path: str) -> bool:
"""判断路径是否为预处理数据目录"""
p = Path(path)
return p.is_dir() and (p / "metadata.json").exists()
class _ShardCache:
"""LRU 缓存,管理按需加载的 .npz 分片,最多持有 max_size 个分片"""
def __init__(self, max_size: int = 2):
self.max_size = max_size
self._cache: OrderedDict[int, Dict[str, np.ndarray]] = OrderedDict()
def get(self, shard_idx: int, loader_fn) -> Dict[str, np.ndarray]:
if shard_idx in self._cache:
self._cache.move_to_end(shard_idx)
return self._cache[shard_idx]
data = loader_fn(shard_idx)
self._cache[shard_idx] = data
self._cache.move_to_end(shard_idx)
while len(self._cache) > self.max_size:
evicted_key, evicted_data = self._cache.popitem(last=False)
del evicted_data
gc.collect()
return data
def clear(self):
self._cache.clear()
gc.collect()
class PreProcessedDataset(Dataset):
"""
预处理数据集加载器自动检测数据格式
- 分片压缩格式metadata.json shard_size 字段
.npz 分片按需加载LRU 缓存控制内存
- 单体格式向后兼容
mmap 零拷贝加载 .npy 文件
所有数据以 int16 存储读取时转为 torch.long (int64)
"""
def __init__(self, data_dir: str, max_cache_shards: int = 1):
t_start = time.perf_counter()
self.data_dir = Path(data_dir)
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
self.metadata = json.load(f)
self.max_seq_length = self.metadata["max_seq_length"]
shard_files = sorted(self.data_dir.glob("shard_*.npz"))
if shard_files:
self._num_shards = len(shard_files)
if (
"shard_sizes" in self.metadata
and len(self.metadata["shard_sizes"]) == self._num_shards
):
self._shard_sizes = self.metadata["shard_sizes"]
logger.info(
f"Using shard_sizes from metadata.json ({self._num_shards} shards)"
)
else:
t_scan = time.perf_counter()
self._shard_sizes = [_read_shard_size(sf) for sf in shard_files]
logger.info(
f"Scanned {self._num_shards} shard headers in "
f"{time.perf_counter() - t_scan:.2f}s"
)
self._shard_offsets = _build_offsets(self._shard_sizes)
self.num_samples = self._shard_offsets[-1]
self._is_sharded = True
self._cache = _ShardCache(max_size=max_cache_shards)
logger.info(
f"Loaded sharded dataset: {self.num_samples:,} samples, "
f"{self._num_shards} shards, "
f"init in {time.perf_counter() - t_start:.2f}s"
)
else:
self.num_samples = self.metadata["num_samples"]
self._is_sharded = False
self._load_single_files()
logger.info(
f"Loaded single-file dataset: {self.num_samples:,} samples (mmap), "
f"init in {time.perf_counter() - t_start:.2f}s"
)
def _load_single_files(self):
"""向后兼容:加载单体 .npy 文件mmap 模式)"""
self.input_ids = np.load(self.data_dir / "input_ids.npy", mmap_mode="r")
self.token_type_ids = np.load(
self.data_dir / "token_type_ids.npy", mmap_mode="r"
)
self.attention_mask = np.load(
self.data_dir / "attention_mask.npy", mmap_mode="r"
)
self.labels = np.load(self.data_dir / "labels.npy", mmap_mode="r")
self.history_slot_ids = np.load(
self.data_dir / "history_slot_ids.npy", mmap_mode="r"
)
self.pinyin_ids = np.load(self.data_dir / "pinyin_ids.npy", mmap_mode="r")
def _load_shard(self, shard_idx: int) -> Dict[str, np.ndarray]:
"""加载一个 .npz 分片到内存(保持原始 int16不转换"""
shard_path = self.data_dir / f"shard_{shard_idx:06d}.npz"
return dict(np.load(shard_path))
def __len__(self) -> int:
return self.num_samples
def __getitem__(self, idx: int) -> dict:
if not 0 <= idx < self.num_samples:
raise IndexError(
f"Index {idx} out of range for dataset with {self.num_samples} samples"
)
if not hasattr(self, "_first_access_logged"):
self._first_access_logged = True
logger.info("First __getitem__ call (initial shard load may be slow)")
if self._is_sharded:
lo, hi = 0, len(self._shard_offsets) - 1
while lo < hi:
mid = (lo + hi) // 2
if self._shard_offsets[mid] <= idx:
lo = mid + 1
else:
hi = mid
shard_idx = lo - 1
local_idx = idx - self._shard_offsets[shard_idx]
shard_data = self._cache.get(shard_idx, self._load_shard)
result = {
"input_ids": torch.from_numpy(
shard_data["input_ids"][local_idx].astype(np.int64)
),
"token_type_ids": torch.from_numpy(
shard_data["token_type_ids"][local_idx].astype(np.int64)
),
"attention_mask": torch.from_numpy(
shard_data["attention_mask"][local_idx].astype(np.int64)
),
"labels": torch.tensor(
shard_data["labels"][local_idx], dtype=torch.long
),
"history_slot_ids": torch.from_numpy(
shard_data["history_slot_ids"][local_idx].astype(np.int64)
),
"pinyin_ids": torch.from_numpy(
shard_data["pinyin_ids"][local_idx].astype(np.int64)
),
}
else:
result = {
"input_ids": torch.from_numpy(self.input_ids[idx].astype(np.int64)),
"token_type_ids": torch.from_numpy(
self.token_type_ids[idx].astype(np.int64)
),
"attention_mask": torch.from_numpy(
self.attention_mask[idx].astype(np.int64)
),
"labels": torch.tensor(self.labels[idx], dtype=torch.long),
"history_slot_ids": torch.from_numpy(
self.history_slot_ids[idx].astype(np.int64)
),
"pinyin_ids": torch.from_numpy(self.pinyin_ids[idx].astype(np.int64)),
}
return result
def preprocessed_collate_fn(batch):
"""
预处理数据的 collate 函数
不含 string 字段prefix/suffix/pinyin仅处理 tensor 字段
"""
return {
"input_ids": torch.stack([item["input_ids"] for item in batch]),
"token_type_ids": torch.stack([item["token_type_ids"] for item in batch]),
"attention_mask": torch.stack([item["attention_mask"] for item in batch]),
"labels": torch.stack([item["labels"] for item in batch]),
"history_slot_ids": torch.stack([item["history_slot_ids"] for item in batch]),
"pinyin_ids": torch.stack([item["pinyin_ids"] for item in batch]),
}

318
src/model/shuffle_npz.py Normal file
View File

@ -0,0 +1,318 @@
#!/usr/bin/env python3
"""
打乱并平衡预处理 .npz 分片
两阶段处理
Phase 1: 逐分片内部打乱 写入临时 .npymmap 友好
Phase 2: 从临时 .npy 按比例分配到平衡的输出 .npz 分片
用法
python -m src.model.shuffle_npz \
--input-dir /path/to/subsampled \
--output-dir /path/to/shuffled \
--shard-size 1000000 \
--seed 42
"""
import argparse
import gc
import json
import shutil
from pathlib import Path
from typing import Dict, List
import numpy as np
from loguru import logger
from rich.console import Console
from tqdm import tqdm
FIELDS = [
"input_ids",
"token_type_ids",
"attention_mask",
"labels",
"history_slot_ids",
"pinyin_ids",
]
def _phase1_shuffle_to_npy(
input_dir: Path,
split: str,
temp_dir: Path,
shard_sizes: List[int],
seed: int,
):
"""
Phase 1: 逐分片内部打乱写入 .npy 文件
每个分片的每个字段写入一个独立 .npy路径格式
temp_dir/<field>/shard_<idx>.npy
峰值内存 ~10GB单字段 int16 5GB + permuted copy 5GB
"""
metadata_path = input_dir / split / "metadata.json"
with open(metadata_path) as f:
metadata = json.load(f)
num_shards = metadata["num_shards"]
rng = np.random.RandomState(seed)
for field in FIELDS:
(temp_dir / field).mkdir(parents=True, exist_ok=True)
pbar = tqdm(
total=num_shards,
desc=f"Phase 1: shuffling {split}",
unit="shard",
)
for src_idx in range(num_shards):
shard_path = input_dir / split / f"shard_{src_idx:06d}.npz"
data = np.load(shard_path)
n = shard_sizes[src_idx]
perm = rng.permutation(n)
for field in FIELDS:
arr = data[field].copy()
shuffled = arr[perm]
np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", shuffled)
del arr, shuffled
gc.collect()
data.close()
pbar.update(1)
pbar.close()
def _phase2_rebalance_to_npz(
temp_dir: Path,
output_dir: Path,
split: str,
shard_sizes: List[int],
target_shard_size: int,
max_seq_length: int,
seed: int,
) -> List[int]:
"""
Phase 2: mmap .npy 文件中按比例分配样本到平衡的 .npz 输出分片
每个输出分片从所有源分片各取 proportional chunk concat shuffle save
内存峰值 ~3GB一个输出缓冲 + mmap pages
"""
output_split_dir = output_dir / split
output_split_dir.mkdir(parents=True, exist_ok=True)
total_samples = sum(shard_sizes)
num_output_shards = (total_samples + target_shard_size - 1) // target_shard_size
rng = np.random.RandomState(seed + 2)
logger.info(
f"Phase 2: distributing {total_samples:,} samples into "
f"{num_output_shards} output shards (~{target_shard_size:,} each)"
)
# 打开所有源 .npy 的 mmap读模式零 RAM 开销)
src_mmaps: List[Dict[str, np.ndarray]] = []
for src_idx in range(len(shard_sizes)):
shard_mmap = {}
for field in FIELDS:
npy_path = temp_dir / field / f"shard_{src_idx:06d}.npy"
shard_mmap[field] = np.load(npy_path, mmap_mode="r")
src_mmaps.append(shard_mmap)
output_shard_sizes: List[int] = []
pbar = tqdm(
total=num_output_shards,
desc=f"Phase 2: writing {split}",
unit="shard",
)
for out_j in range(num_output_shards):
buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
for src_i in range(len(shard_sizes)):
s = shard_sizes[src_i]
start = (out_j * s) // num_output_shards
end = ((out_j + 1) * s) // num_output_shards
if start >= end:
continue
for field in FIELDS:
chunk = src_mmaps[src_i][field][start:end].copy()
buffers[field].append(chunk)
output = {}
for field in FIELDS:
output[field] = (
np.concatenate(buffers[field])
if len(buffers[field]) > 1
else buffers[field][0]
)
out_count = len(output[FIELDS[0]])
if out_count > 1:
perm = rng.permutation(out_count)
for field in FIELDS:
output[field] = output[field][perm]
np.savez_compressed(output_split_dir / f"shard_{out_j:06d}.npz", **output)
output_shard_sizes.append(out_count)
del output, buffers
gc.collect()
pbar.update(1)
pbar.close()
# 关闭所有 mmap
src_mmaps.clear()
gc.collect()
return output_shard_sizes
def _copy_eval(input_dir: Path, output_dir: Path):
"""复制 eval 目录(数据量小,无需打乱)。"""
src_eval = input_dir / "eval"
dst_eval = output_dir / "eval"
if not src_eval.exists():
return
logger.info(f"Copying eval data from {src_eval}")
if dst_eval.exists():
shutil.rmtree(dst_eval)
shutil.copytree(src_eval, dst_eval)
def main():
console = Console()
parser = argparse.ArgumentParser(description="打乱并平衡预处理 .npz 分片")
parser.add_argument(
"--input-dir", type=str, required=True, help="输入目录(含 train/ 和 eval/"
)
parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
parser.add_argument(
"--shard-size",
type=int,
default=1_000_000,
help="输出分片大小(样本数),默认 100 万",
)
parser.add_argument("--seed", type=int, default=42, help="随机种子")
args = parser.parse_args()
input_dir = Path(args.input_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
train_meta_path = input_dir / "train" / "metadata.json"
if not train_meta_path.exists():
console.print(f"[red]错误: {train_meta_path} 不存在[/red]")
return
with open(train_meta_path) as f:
train_meta = json.load(f)
max_seq_length = train_meta["max_seq_length"]
shard_sizes: List[int] = train_meta["shard_sizes"]
total_samples = sum(shard_sizes)
console.print("[bold cyan]=== 打乱并平衡预处理数据 ===[/bold cyan]")
console.print(f"输入目录: {input_dir}")
console.print(f"输出目录: {output_dir}")
console.print(f"总样本数: {total_samples:,}")
console.print(f"源分片数: {len(shard_sizes)}")
console.print(
f"目标分片大小: {args.shard_size:,} "
f"(约 {(total_samples + args.shard_size - 1) // args.shard_size} 个分片)"
)
console.print(f"随机种子: {args.seed}")
console.print()
# 临时目录
temp_dir = Path("/home/songsenand/tmp/shuffle_npz_temp")
if temp_dir.exists():
shutil.rmtree(temp_dir)
temp_dir.mkdir(parents=True, exist_ok=True)
temp_train_dir = temp_dir / "train"
temp_train_dir.mkdir(parents=True, exist_ok=True)
try:
# ── Phase 1 ──
console.print("[bold]Phase 1: 逐分片打乱 → 临时 .npy[/bold]")
_phase1_shuffle_to_npy(
input_dir=input_dir,
split="train",
temp_dir=temp_train_dir,
shard_sizes=shard_sizes,
seed=args.seed,
)
# ── Phase 2 ──
console.print("[bold]Phase 2: 按比例分配到平衡输出分片[/bold]")
output_shard_sizes = _phase2_rebalance_to_npz(
temp_dir=temp_train_dir,
output_dir=output_dir,
split="train",
shard_sizes=shard_sizes,
target_shard_size=args.shard_size,
max_seq_length=max_seq_length,
seed=args.seed,
)
# ── 写 metadata ──
train_output_meta = {
"num_samples": sum(output_shard_sizes),
"max_seq_length": max_seq_length,
"dtype": "int16",
"fields": FIELDS,
"shard_size": args.shard_size,
"num_shards": len(output_shard_sizes),
"shard_sizes": output_shard_sizes,
"pre_shuffled": True,
"seed": args.seed,
}
out_train_dir = output_dir / "train"
with open(out_train_dir / "metadata.json", "w", encoding="utf-8") as f:
json.dump(train_output_meta, f, indent=2, ensure_ascii=False)
# ── 复制 eval ──
_copy_eval(input_dir, output_dir)
finally:
# 清理临时文件
console.print("[dim]清理临时文件...[/dim]")
if temp_dir.exists():
shutil.rmtree(temp_dir)
# ── 总结 ──
console.print()
console.print("[bold green]=== 完成 ===[/bold green]")
console.print(
f"train/: {sum(output_shard_sizes):,} 样本, "
f"{len(output_shard_sizes)} 分片, "
f"pre_shuffled=True"
)
for sdir_name in ["train", "eval"]:
sdir = output_dir / sdir_name
if sdir.exists():
total_size = sum(
f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz"
)
meta_path = sdir / "metadata.json"
if meta_path.exists():
with open(meta_path) as mf:
meta = json.load(mf)
else:
meta = {}
console.print(
f" {sdir_name}/: {total_size / (1024**3):.2f} GB, "
f"{meta.get('num_shards', '?')} shards"
)
if __name__ == "__main__":
main()

464
src/model/subsample.py Normal file
View File

@ -0,0 +1,464 @@
#!/usr/bin/env python3
"""
子采样脚本从预处理的 .npz 分片中抽取子集
策略
1. 第1遍扫描只读 labels 字段统计每个 label ID 的样本数记录分片大小
2. 中间计算按每 ID 最多 N 个样本封顶硬封顶不足目标总量则从超额池补足
3. 第2遍扫描读取全部字段按精确保留配额抽取训练样本同时抽取评估集
用法
python -m model.subsample \
--input-dir ./preprocessed \
--output-dir ./subsampled \
--cap-per-label 300000 \
--target-total 100000000 \
--num-eval 2560
"""
import argparse
import gc
import json
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
from loguru import logger
from rich.console import Console
from tqdm import tqdm
FIELDS = [
"input_ids",
"token_type_ids",
"attention_mask",
"labels",
"history_slot_ids",
"pinyin_ids",
]
def _global_to_shard(
positions: np.ndarray, shard_sizes: List[int]
) -> Dict[int, List[int]]:
"""将全局位置映射为 {shard_idx: [local_indices]}"""
cumsum = np.cumsum([0] + shard_sizes)
mapping: Dict[int, List[int]] = {}
for pos in positions:
shard_idx = int(np.searchsorted(cumsum, pos, side="right") - 1)
local_idx = pos - cumsum[shard_idx]
mapping.setdefault(shard_idx, []).append(local_idx)
return mapping
def pass1_count(input_dir: Path, split: str) -> Tuple[Dict[int, int], List[int]]:
"""
第1遍扫描只读 labels统计每个 label ID 的总样本数记录各分片大小
返回 (label_counts, shard_sizes)
"""
metadata_path = input_dir / split / "metadata.json"
with open(metadata_path) as f:
metadata = json.load(f)
num_shards = metadata["num_shards"]
label_counts: Dict[int, int] = {}
shard_sizes: List[int] = []
pbar = tqdm(total=num_shards, desc="Pass 1: counting labels", unit="shard")
for shard_idx in range(num_shards):
shard_path = input_dir / split / f"shard_{shard_idx:06d}.npz"
data = np.load(shard_path)
labels = data["labels"]
n = len(labels)
shard_sizes.append(n)
unique, counts = np.unique(labels, return_counts=True)
for uid, cnt in zip(unique, counts):
uid = int(uid)
label_counts[uid] = label_counts.get(uid, 0) + int(cnt)
data.close()
pbar.update(1)
pbar.close()
return label_counts, shard_sizes
def compute_quotas(
label_counts: Dict[int, int],
cap: int = 300_000,
target_total: int = 100_000_000,
) -> Dict[int, int]:
"""
计算每个 label 的精确保留配额硬封顶
策略
- ID 保留 min(count, cap)
- 若封顶后总量 >= target_total直接用封顶策略
- 若封顶后总量 < target_total从超额池count > cap 的部分等比抽取补足
"""
capped_total = sum(min(cnt, cap) for cnt in label_counts.values())
if capped_total >= target_total:
return {label: min(cnt, cap) for label, cnt in label_counts.items()}
# 封顶不足,从超额池补足
deficit = target_total - capped_total
excess_per_label = {lbl: max(0, cnt - cap) for lbl, cnt in label_counts.items()}
excess_total = sum(excess_per_label.values())
quotas: Dict[int, int] = {}
if excess_total > 0:
ratio = min(1.0, deficit / excess_total)
for label, cnt in label_counts.items():
base = min(cnt, cap)
extra = int(excess_per_label[label] * ratio)
quotas[label] = base + extra
else:
quotas = {label: min(cnt, cap) for label, cnt in label_counts.items()}
return quotas
def pass2_subsample(
input_dir: Path,
output_dir: Path,
split: str,
quotas: Dict[int, int],
eval_map: Dict[int, List[int]],
shard_sizes: List[int],
shard_size: int = 5_000_000,
shuffle: bool = True,
seed: int = 42,
) -> Tuple[int, int]:
"""
第2遍扫描读取全部字段抽取评估样本 + 按精确保留配额子采样训练样本
quotas: {label_id: exact_number_to_keep}
返回 (train_kept, eval_kept)
"""
metadata_path = input_dir / split / "metadata.json"
with open(metadata_path) as f:
metadata = json.load(f)
num_shards = metadata["num_shards"]
max_seq_length = metadata["max_seq_length"]
output_train_dir = output_dir / "train"
output_eval_dir = output_dir / "eval"
output_train_dir.mkdir(parents=True, exist_ok=True)
output_eval_dir.mkdir(parents=True, exist_ok=True)
# 清理旧输出分片(避免残留文件污染)
for old_shard in sorted(output_train_dir.glob("shard_*.npz")):
old_shard.unlink()
logger.debug(f"Removed old shard: {old_shard.name}")
rng = np.random.RandomState(seed)
shuffle_rng = np.random.RandomState(seed + 1)
remaining = dict(quotas) # {label_id: remaining_to_keep}
train_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
train_buf_count = 0
train_shard_idx = 0
train_shard_sizes: List[int] = []
total_train_kept = 0
eval_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
total_eval_kept = 0
pbar = tqdm(total=num_shards, desc="Pass 2: subsampling", unit="shard")
for src_shard_idx in range(num_shards):
shard_path = input_dir / split / f"shard_{src_shard_idx:06d}.npz"
data = np.load(shard_path)
n = shard_sizes[src_shard_idx]
eval_local = np.array(eval_map.get(src_shard_idx, []), dtype=np.int64)
is_eval = np.zeros(n, dtype=bool)
if len(eval_local) > 0:
is_eval[eval_local] = True
# ── 第一步:只加载 labels计算掩码 ──
labels = data["labels"]
# 抽取评估样本的 labels
if len(eval_local) > 0:
eval_buffers["labels"].append(labels[eval_local].copy())
# 计算训练保留位置
train_candidates = labels[~is_eval]
train_n = len(train_candidates)
n_kept = 0
keep_original_indices = np.array([], dtype=np.int64)
if train_n > 0:
sort_idx = np.argsort(train_candidates, kind="stable")
sorted_labels = train_candidates[sort_idx]
unique_vals, starts = np.unique(sorted_labels, return_index=True)
ends = np.append(starts[1:], train_n)
train_keep_mask = np.zeros(train_n, dtype=bool)
for i in range(len(unique_vals)):
label_val = int(unique_vals[i])
start, end = starts[i], ends[i]
cnt_in_shard = end - start
need = remaining.get(label_val, 0)
select = min(cnt_in_shard, need)
if select >= cnt_in_shard:
train_keep_mask[sort_idx[start:end]] = True
elif select > 0:
chosen = rng.choice(cnt_in_shard, size=select, replace=False)
train_keep_mask[sort_idx[start + chosen]] = True
remaining[label_val] = max(0, remaining.get(label_val, 0) - select)
original_train_indices = np.where(~is_eval)[0]
keep_original_indices = original_train_indices[train_keep_mask]
n_kept = len(keep_original_indices)
# 释放标签相关的临时数组
del train_candidates, sort_idx, sorted_labels
del unique_vals, starts, ends, train_keep_mask
del original_train_indices
gc.collect()
del labels, is_eval
gc.collect()
# ── 第二步:逐字段加载,抽取评估 + 训练 ──
# 评估:跳过 labels已在第一步抽取
if len(eval_local) > 0:
for f in FIELDS:
if f == "labels":
continue
arr = data[f]
eval_buffers[f].append(arr[eval_local].copy())
del arr
gc.collect()
total_eval_kept += len(eval_local)
# 训练:逐字段加载,立即删除
if n_kept > 0:
for f in FIELDS:
arr = data[f]
train_buffers[f].append(arr[keep_original_indices].copy())
del arr
gc.collect()
train_buf_count += n_kept
total_train_kept += n_kept
del keep_original_indices
data.close()
gc.collect()
if train_buf_count >= shard_size:
merged = {}
for f in FIELDS:
merged[f] = np.concatenate(train_buffers[f], axis=0)
if shuffle and train_buf_count > 1:
perm = shuffle_rng.permutation(train_buf_count)
for f in FIELDS:
merged[f] = merged[f][perm]
np.savez_compressed(
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
)
train_shard_sizes.append(train_buf_count)
logger.debug(
f"Saved train shard {train_shard_idx}: {train_buf_count} samples"
)
train_shard_idx += 1
train_buffers = {f: [] for f in FIELDS}
train_buf_count = 0
del merged
gc.collect()
pbar.update(1)
pbar.close()
# 剩余缓冲
if train_buf_count > 0:
merged = {}
for f in FIELDS:
merged[f] = np.concatenate(train_buffers[f], axis=0)
if shuffle and train_buf_count > 1:
perm = shuffle_rng.permutation(train_buf_count)
for f in FIELDS:
merged[f] = merged[f][perm]
np.savez_compressed(
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
)
train_shard_sizes.append(train_buf_count)
logger.debug(f"Saved train shard {train_shard_idx}: {train_buf_count} samples")
train_shard_idx += 1
# 评估数据
if total_eval_kept > 0:
eval_merged = {}
for f in FIELDS:
eval_merged[f] = np.concatenate(eval_buffers[f], axis=0)
np.savez_compressed(output_eval_dir / "shard_000000.npz", **eval_merged)
else:
logger.warning("No eval samples extracted!")
# 写 metadata
train_metadata = {
"num_samples": total_train_kept,
"max_seq_length": max_seq_length,
"dtype": "int16",
"fields": FIELDS,
"shard_size": shard_size,
"num_shards": train_shard_idx,
"shard_sizes": train_shard_sizes,
"pre_shuffled": shuffle,
"seed": seed,
}
with open(output_train_dir / "metadata.json", "w", encoding="utf-8") as f:
json.dump(train_metadata, f, indent=2, ensure_ascii=False)
eval_metadata = {
"num_samples": total_eval_kept,
"max_seq_length": max_seq_length,
"dtype": "int16",
"fields": FIELDS,
"shard_size": total_eval_kept,
"num_shards": 1,
}
with open(output_eval_dir / "metadata.json", "w", encoding="utf-8") as f:
json.dump(eval_metadata, f, indent=2, ensure_ascii=False)
return total_train_kept, total_eval_kept
def main():
console = Console()
parser = argparse.ArgumentParser(description="从预处理 .npz 分片中抽取子集")
parser.add_argument("--input-dir", type=str, required=True, help="输入预处理目录")
parser.add_argument("--output-dir", type=str, required=True, help="输出子采样目录")
parser.add_argument(
"--cap-per-label",
type=int,
default=300_000,
help="每个 label ID 最大保留样本数",
)
parser.add_argument(
"--target-total",
type=int,
default=100_000_000,
help="目标训练集样本总数",
)
parser.add_argument(
"--shard-size",
type=int,
default=5_000_000,
help="输出分片大小(样本数)",
)
parser.add_argument("--num-eval", type=int, default=2560, help="评估集样本数")
parser.add_argument(
"--no-shuffle",
action="store_false",
dest="shuffle",
help="禁用输出分片内部打乱",
)
parser.add_argument(
"--seed", type=int, default=42, help="随机种子(用于标签选择 + 输出打乱)"
)
args = parser.parse_args()
input_dir = Path(args.input_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
console.print("[bold cyan]=== 子采样预处理数据 ===[/bold cyan]")
console.print(f"输入目录: {input_dir}")
console.print(f"输出目录: {output_dir}")
console.print(f"每 ID 封顶: {args.cap_per_label:,}")
console.print(f"目标训练集: {args.target_total:,}")
console.print(f"评估集: {args.num_eval}")
console.print(f"输出分片打乱: {'' if args.shuffle else ''}")
console.print(f"随机种子: {args.seed}")
console.print()
# ── 第 1 遍:统计 ──
console.print("[bold]第 1 遍扫描:统计标签分布...[/bold]")
label_counts, shard_sizes = pass1_count(input_dir, "train")
total_samples = sum(shard_sizes)
num_labels = len(label_counts)
capped_total = sum(min(cnt, args.cap_per_label) for cnt in label_counts.values())
console.print(f"分片数: {len(shard_sizes)}")
console.print(f"总样本数: {total_samples:,}")
console.print(f"标签种类: {num_labels}")
console.print(f"封顶后样本数 (每ID≤{args.cap_per_label:,}): {capped_total:,}")
console.print()
# ── 计算保留配额 ──
quotas = compute_quotas(
label_counts,
cap=args.cap_per_label,
target_total=args.target_total,
)
expected_train = sum(quotas.values())
console.print(f"期望训练集大小: {expected_train:,}")
# 标签分布统计
stats_n_capped = sum(1 for cnt in label_counts.values() if cnt > args.cap_per_label)
stats_n_below = num_labels - stats_n_capped
stats_n_quota_full = sum(1 for lbl, q in quotas.items() if q >= args.cap_per_label)
console.print(
f"超过封顶的标签数: {stats_n_capped}, 不足封顶的标签数: {stats_n_below}, "
f"配额打满(≥{args.cap_per_label // 10000}万): {stats_n_quota_full}"
)
console.print()
# ── 抽取评估集位置 ──
eval_rng = np.random.RandomState(args.seed + 100)
eval_positions = eval_rng.choice(total_samples, size=args.num_eval, replace=False)
eval_positions.sort()
eval_map = _global_to_shard(eval_positions, shard_sizes)
console.print(f"评估集: {args.num_eval} 个位置已分配到 {len(eval_map)} 个分片中")
console.print()
# ── 第 2 遍:子采样 ──
console.print("[bold]第 2 遍扫描:子采样 + 抽取评估集...[/bold]")
train_count, eval_count = pass2_subsample(
input_dir,
output_dir,
"train",
quotas,
eval_map,
shard_sizes,
args.shard_size,
shuffle=args.shuffle,
seed=args.seed,
)
# ── 输出总结 ──
console.print()
console.print("[bold green]=== 完成 ===[/bold green]")
console.print(f"训练集: {train_count:,} 样本")
console.print(f"评估集: {eval_count:,} 样本")
for split in ["train", "eval"]:
sdir = output_dir / split
if sdir.exists():
total_size = sum(
f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz"
)
meta_path = sdir / "metadata.json"
if meta_path.exists():
with open(meta_path) as mf:
meta = json.load(mf)
else:
meta = {}
console.print(
f" {split}/: {total_size / (1024**3):.2f} GB, "
f"{meta.get('num_shards', '?')} shards"
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,443 @@
#!/usr/bin/env python3
"""
缺失字符补充工具
步骤 1: find-missing 扫描已预处理数据找出从未出现的 label ID输出 JSON
步骤 2: generate-template 根据 JSON 生成 JSONL 占位文件供用户手动填入包含缺失字的真实文本
步骤 3: preprocess-supplement 将填好的 JSONL 文本预处理为 .npz 分片输出到独立目录
用法
python -m model.supplement_missing find-missing \
--preprocessed-dir ./preprocessed/train \
--output missing_chars.json
python -m model.supplement_missing generate-template \
--missing-chars missing_chars.json \
--output supplement_texts.jsonl
python -m model.supplement_missing preprocess-supplement \
--missing-chars missing_chars.json \
--supplement-texts supplement_texts.jsonl \
--output-dir ./preprocessed/supplement \
--num-samples 100000
"""
import argparse
import json
from pathlib import Path
from typing import Set
import numpy as np
import torch
from loguru import logger
from rich.console import Console
from rich.table import Table
from torch.utils.data import DataLoader
from tqdm import tqdm
from .dataset import PinyinInputDataset
from .preprocess import collect_samples
from .query import QueryEngine
from .trainer import preprocess_collate_fn, worker_init_fn
def scan_labels(preprocessed_dir: Path) -> Set[int]:
"""扫描预处理目录中所有 .npz 分片,收集所有出现过的 label ID"""
appeared: Set[int] = set()
shard_files = sorted(preprocessed_dir.glob("shard_*.npz"))
if not shard_files:
logger.warning(f"未找到 .npz 分片文件: {preprocessed_dir}")
return appeared
for shard_path in tqdm(shard_files, desc="扫描分片", unit="shard"):
data = np.load(shard_path)
labels = data["labels"].astype(np.int64)
if labels.ndim > 1 and labels.shape[-1] == 1:
labels = labels.squeeze(-1)
unique_ids = np.unique(labels)
appeared.update(int(uid) for uid in unique_ids)
del data
return appeared
def cmd_find_missing(args):
console = Console()
preprocessed_dir = Path(args.preprocessed_dir)
if not preprocessed_dir.exists():
console.print(f"[bold red]目录不存在: {preprocessed_dir}[/bold red]")
return
metadata_path = preprocessed_dir / "metadata.json"
if not metadata_path.exists():
console.print(f"[bold red]未找到 metadata.json: {metadata_path}[/bold red]")
return
with open(metadata_path, "r", encoding="utf-8") as f:
metadata = json.load(f)
console.print(
f"[bold cyan]预处理数据: {metadata['num_samples']:,} 样本, {metadata['num_shards']} 分片[/bold cyan]"
)
console.print("[bold cyan]扫描 labels...[/bold cyan]")
appeared = scan_labels(preprocessed_dir)
console.print("[bold cyan]加载 QueryEngine...[/bold cyan]")
query_engine = QueryEngine()
query_engine.load()
all_ids = set(query_engine._id_to_info.keys())
missing_ids = all_ids - appeared
missing_chars = []
for mid in sorted(missing_ids):
if mid == 0:
continue
info = query_engine.query_by_id(mid)
if info is not None:
missing_chars.append(
{
"id": info.id,
"char": info.char,
"pinyin": info.pinyin,
"count": info.count,
}
)
result = {
"missing_count": len(missing_chars),
"missing_chars": missing_chars,
}
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
console.print(f"\n[bold green]=== 扫描完成 ===[/bold green]")
console.print(f"词表大小: {len(all_ids):,} (含 EOS)")
console.print(f"已出现标签: {len(appeared):,}")
console.print(
f"[bold red]缺失标签: {len(missing_ids):,}[/bold red] (其中非 EOS: {len(missing_chars)})"
)
if missing_chars:
table = Table(
title=f"缺失字符 (共 {len(missing_chars)} 个)",
show_header=True,
header_style="bold magenta",
)
table.add_column("ID", style="cyan", width=8)
table.add_column("字符", style="yellow", width=6)
table.add_column("拼音", style="green", width=12)
table.add_column("语料频次", style="red", width=12)
for entry in missing_chars:
table.add_row(
str(entry["id"]),
entry["char"],
entry["pinyin"],
f"{entry['count']:,}",
)
console.print(table)
console.print(f"\n已输出到: {output_path}")
def cmd_generate_template(args):
console = Console()
missing_path = Path(args.missing_chars)
if not missing_path.exists():
console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
return
with open(missing_path, "r", encoding="utf-8") as f:
data = json.load(f)
missing_chars = data.get("missing_chars", [])
if not missing_chars:
console.print("[bold green]没有缺失字符,无需生成模板[/bold green]")
return
num_entries = args.num_entries
total_lines = len(missing_chars) * num_entries
console.print(f"[bold cyan]缺失字符数: {len(missing_chars)}[/bold cyan]")
console.print(f"[bold cyan]每字符模板数: {num_entries}[/bold cyan]")
console.print(f"[bold cyan]总模板行数: {total_lines}[/bold cyan]")
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
for entry in missing_chars:
for i in range(num_entries):
line = json.dumps(
{"text": f"请在这里输入包含「{entry['char']}」字的第{i + 1}条文本"},
ensure_ascii=False,
)
f.write(line + "\n")
console.print(f"[bold green]模板已生成: {output_path}[/bold green]")
console.print(
f"{total_lines} 条({len(missing_chars)} 字符 × {num_entries} 条/字符),"
f"请手动编辑该文件,将占位文本替换为包含对应字符的真实文本。"
)
def cmd_preprocess_supplement(args):
console = Console()
# 加载缺失字符
missing_path = Path(args.missing_chars)
if not missing_path.exists():
console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
return
with open(missing_path, "r", encoding="utf-8") as f:
data = json.load(f)
missing_chars = data.get("missing_chars", [])
if not missing_chars:
console.print("[bold green]没有缺失字符,无需处理[/bold green]")
return
target_labels = {entry["id"] for entry in missing_chars}
target_labels.add(0) # 包含 EOS
# 解析参数
py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
length_weights = {
int(k): int(v)
for k, v in (item.split(":") for item in args.length_weights.split(","))
}
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
max_iter = args.num_samples * 5
num_workers = args.num_workers
console.print("[bold cyan]=== 补充数据预处理 ===[/bold cyan]")
console.print(f"补充文本: {args.supplement_texts}")
console.print(f"缺失字符数: {len(missing_chars)}")
console.print(f"目标样本: {args.num_samples:,}")
console.print(f"输出目录: {output_dir}")
console.print(f"Worker 数: {num_workers}")
console.print()
torch.manual_seed(args.seed)
np.random.seed(args.seed)
console.print("[bold cyan]创建补充数据集...[/bold cyan]")
dataset = PinyinInputDataset(
data_path="json",
max_workers=num_workers,
max_iter_length=max_iter,
max_seq_length=args.max_seq_length,
text_field="text",
py_style_weight=py_style_weight,
shuffle_buffer_size=100,
length_weights=length_weights,
data_kwargs={
"data_files": args.supplement_texts,
"streaming": False,
},
target_labels=target_labels,
)
dataloader_kwargs = {
"batch_size": args.batch_size,
"num_workers": num_workers,
"pin_memory": False,
"worker_init_fn": worker_init_fn,
"collate_fn": preprocess_collate_fn(args.max_seq_length),
}
if num_workers > 0:
dataloader_kwargs["prefetch_factor"] = 2
dataloader_kwargs["persistent_workers"] = True
dataloader = DataLoader(dataset, **dataloader_kwargs)
logger.info("开始收集补充数据...")
count = collect_samples(
dataloader,
args.num_samples,
output_dir,
"supplement",
args.max_seq_length,
args.shard_size,
)
if count < args.num_samples:
logger.warning(f"补充样本不足: 目标 {args.num_samples}, 实际 {count}")
console.print("\n[bold green]=== 补充预处理完成 ===[/bold green]")
console.print(f"生成样本: {count:,}")
console.print(f"输出目录: {output_dir}")
total_size = sum(
f.stat().st_size for f in output_dir.iterdir() if f.suffix == ".npz"
)
console.print(f"总大小: {total_size / (1024**3):.2f} GB (compressed)")
console.print()
console.print(
"[bold yellow]提示[/bold yellow]: 请检查补充数据质量,清洗无误后手动将 shard_*.npz 合并到 train/ 目录并更新 metadata.json"
)
def main():
parser = argparse.ArgumentParser(
description="缺失字符补充工具",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
子命令
find-missing 扫描已预处理数据找出从未出现的 label ID
generate-template 根据缺失字符 JSON 生成 JSONL 占位文件
preprocess-supplement 将填好的 JSONL 预处理为 .npz 分片独立目录
示例
python -m model.supplement_missing find-missing \\
--preprocessed-dir ./preprocessed/train \\
--output missing_chars.json
python -m model.supplement_missing generate-template \\
--missing-chars missing_chars.json \\
--output supplement_texts.jsonl
python -m model.supplement_missing preprocess-supplement \\
--missing-chars missing_chars.json \\
--supplement-texts supplement_texts.jsonl \\
--output-dir ./preprocessed/supplement \\
--num-samples 100000
""",
)
subparsers = parser.add_subparsers(dest="command", help="子命令")
# find-missing
p_find = subparsers.add_parser("find-missing", help="扫描预处理数据,找出缺失标签")
p_find.add_argument(
"--preprocessed-dir",
type=str,
required=True,
help="预处理数据目录(包含 shard_*.npz 和 metadata.json",
)
p_find.add_argument(
"--output",
type=str,
default="missing_chars.json",
help="输出 JSON 文件路径(默认: missing_chars.json",
)
# generate-template
p_gen = subparsers.add_parser("generate-template", help="生成补充文本模板")
p_gen.add_argument(
"--missing-chars",
type=str,
required=True,
help="缺失字符 JSON 文件路径(由 find-missing 生成)",
)
p_gen.add_argument(
"--output",
type=str,
default="supplement_texts.jsonl",
help="输出 JSONL 文件路径(默认: supplement_texts.jsonl",
)
p_gen.add_argument(
"--num-entries",
type=int,
default=3,
help="每个缺失字符生成的模板条数(默认: 3",
)
# preprocess-supplement
p_pre = subparsers.add_parser(
"preprocess-supplement", help="将 JSONL 预处理为 .npz 分片"
)
p_pre.add_argument(
"--missing-chars",
type=str,
required=True,
help="缺失字符 JSON 文件路径(由 find-missing 生成)",
)
p_pre.add_argument(
"--supplement-texts",
type=str,
required=True,
help="已填写的补充文本 JSONL 文件路径",
)
p_pre.add_argument(
"--output-dir",
type=str,
required=True,
help="输出目录(独立目录,不会覆盖已有数据)",
)
p_pre.add_argument(
"--num-samples",
type=int,
required=True,
help="目标样本数量",
)
p_pre.add_argument(
"--batch-size",
type=int,
default=128,
help="批大小(默认: 128",
)
p_pre.add_argument(
"--num-workers",
type=int,
default=0,
help="DataLoader worker 数量。本地 JSONL 小文件建议 0默认: 0",
)
p_pre.add_argument(
"--max-seq-length",
type=int,
default=128,
help="最大序列长度(默认: 128",
)
p_pre.add_argument(
"--seed",
type=int,
default=42,
help="随机种子(默认: 42",
)
p_pre.add_argument(
"--shard-size",
type=int,
default=5_000_000,
help="分片大小(样本数),控制内存峰值(默认: 500万",
)
p_pre.add_argument(
"--py-style-weight",
type=str,
default="9,2,1",
help="拼音风格权重(逗号分隔,默认: 9,2,1",
)
p_pre.add_argument(
"--length-weights",
type=str,
default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
help="词长权重(默认: 1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
)
args = parser.parse_args()
if args.command is None:
parser.print_help()
return
if args.command == "find-missing":
cmd_find_missing(args)
elif args.command == "generate-template":
cmd_generate_template(args)
elif args.command == "preprocess-supplement":
cmd_preprocess_supplement(args)
app = main
if __name__ == "__main__":
main()

View File

@ -28,6 +28,11 @@ from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .dataset import PinyinInputDataset
from .preprocessed_dataset import (
PreProcessedDataset,
is_preprocessed_data,
preprocessed_collate_fn,
)
# 导入模型和数据
from .model import InputMethodEngine
@ -97,7 +102,12 @@ class Trainer:
"""
self.model = model
self.train_dataloader = train_dataloader
self.eval_dataloader = list([i for i in eval_dataloader])
if isinstance(eval_dataloader, DataLoader) and not isinstance(
eval_dataloader.dataset, torch.utils.data.IterableDataset
):
self.eval_dataloader = eval_dataloader
else:
self.eval_dataloader = list([i for i in eval_dataloader])
self.output_dir = Path(output_dir)
self.num_epochs = num_epochs
self.learning_rate = learning_rate
@ -620,14 +630,14 @@ class Trainer:
except Exception as e:
logger.error(f"Failed to write training status: {e}")
def _create_progress_bar(self) -> Progress:
"""创建Rich进度条"""
def _create_progress(self) -> Progress:
return Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TextColumn("|"),
TimeElapsedColumn(),
TextColumn("|"),
TimeRemainingColumn(),
console=self.console,
expand=True,
@ -694,20 +704,30 @@ class Trainer:
accumulated_accuracy = 0.0
accumulation_counter = 0
# 创建进度条
with self._create_progress_bar() as progress:
# 计算每个 epoch 的步数
steps_per_epoch = max(1, self.total_steps // self.num_epochs)
# 创建双进度条epoch 级 + epoch 内 batch 级
with self._create_progress() as progress:
epoch_task = progress.add_task(
f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
total=self.total_steps,
f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
total=self.num_epochs,
completed=self.current_epoch,
)
batch_task = progress.add_task(
"[green]Batch",
total=steps_per_epoch,
)
# 训练循环
for epoch in range(self.current_epoch, self.num_epochs):
self.current_epoch = epoch
progress.update(
epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}"
progress.reset(
batch_task,
total=steps_per_epoch,
description=f"[green]Epoch {epoch + 1} Step 0/{steps_per_epoch}",
)
epoch_step = 0
for batch_idx, batch in enumerate(self.train_dataloader):
# 更新学习率
current_lr = self._update_learning_rate()
@ -733,16 +753,18 @@ class Trainer:
self.scaler.update()
self.optimizer.zero_grad()
# 更新进度条
# 更新 batch 进度条 (每 step 推进一次)
progress.update(
epoch_task,
batch_task,
advance=1,
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | "
f"Step {global_step}/{self.total_steps} | "
f"Loss: {loss:.4f} | "
f"LR: {current_lr:.2e}",
description=f"[green]Epoch {epoch + 1} "
f"Step {epoch_step + 1}/{steps_per_epoch}"
f" | Loss: {loss:.4f}"
f" | LR: {current_lr:.2e}",
)
epoch_step += 1
# 定期评估和记录
if (global_step + 1) % self.eval_frequency == 0:
# 计算平均指标
@ -770,7 +792,6 @@ class Trainer:
# 更新最佳模型
if eval_metrics["eval_loss"] < self.best_eval_loss:
self.best_eval_loss = eval_metrics["eval_loss"]
# 只保存best_model不创建额外的checkpoint文件
self.save_checkpoint("best_model.pt", is_best=True)
# 记录到TensorBoard
@ -791,7 +812,7 @@ class Trainer:
f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
)
progress.console.log(log_text)
progress.log(log_text)
# 重置累积指标
accumulated_loss = 0.0
@ -808,12 +829,30 @@ class Trainer:
# 检查是否达到总步数
if global_step >= self.total_steps:
progress.update(epoch_task, completed=self.total_steps)
break
# 进度条不重置,显示整体训练进度
# epoch 内步数检查
if epoch_step >= steps_per_epoch:
break
# 每个 epoch 结束后保存检查点(循环覆盖,只保留最后 3 个)
# epoch 完成
progress.update(epoch_task, advance=1)
# 每个 epoch 结束后评估并保存 best model
epoch_eval_metrics = self.evaluate()
if epoch_eval_metrics:
epoch_log_metrics = {
"epoch/eval_loss": epoch_eval_metrics["eval_loss"],
"epoch/eval_accuracy": epoch_eval_metrics["eval_accuracy"],
}
if self.writer is not None:
for key, value in epoch_log_metrics.items():
self.writer.add_scalar(key, value, global_step)
if epoch_eval_metrics["eval_loss"] < self.best_eval_loss:
self.best_eval_loss = epoch_eval_metrics["eval_loss"]
self.save_checkpoint("best_model.pt", is_best=True)
# 每个 epoch 结束后保存检查点
self.save_epoch_checkpoint(epoch + 1)
# 检查是否达到总步数
@ -823,6 +862,34 @@ class Trainer:
# 训练完成
logger.info("Training completed!")
# 最终评估并保存 best model
logger.info("Running final evaluation...")
final_eval_metrics = self.evaluate()
if final_eval_metrics:
final_log_metrics = {
"train/loss": accumulated_loss / accumulation_counter
if accumulation_counter > 0
else 0.0,
"train/accuracy": accumulated_accuracy / accumulation_counter
if accumulation_counter > 0
else 0.0,
"train/learning_rate": self._get_current_lr(),
"eval/loss": final_eval_metrics["eval_loss"],
"eval/accuracy": final_eval_metrics["eval_accuracy"],
}
self._log_to_tensorboard(final_log_metrics, global_step)
logger.info(
f"Final eval loss: {final_eval_metrics['eval_loss']:.4f}, "
f"Final eval accuracy: {final_eval_metrics['eval_accuracy']:.4f}"
)
if final_eval_metrics["eval_loss"] < self.best_eval_loss:
self.best_eval_loss = final_eval_metrics["eval_loss"]
self.save_checkpoint("best_model.pt", is_best=True)
# 保存最终模型
self.save_checkpoint("final_model.pt")
logger.info("Final model saved.")
# 显示保存的 epoch checkpoint 信息
if self.epoch_checkpoints:
sorted_checkpoints = self.get_epoch_checkpoints()
@ -965,25 +1032,62 @@ def worker_init_fn(worker_id: int) -> None:
random.seed(worker_seed)
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
def collate_fn(batch: List[Dict[str, Any]], max_seq_length: int = 0) -> Dict[str, Any]:
"""
自定义批处理函数将多个样本组合成一个batch
自定义批处理函数将多个样本组合成一个batch
支持动态填充根据batch内最大序列长度进行padding而非固定max_length
max_seq_length > 0 pad到指定长度用于预处理
Args:
batch: 样本列表每个样本是一个字典
max_seq_length: 目标序列长度0表示动态padding
Returns:
批处理后的字典tensor字段已stack字符串字段保持为列表
"""
# 处理tensor字段 - 使用squeeze去除多余的batch维度
input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
input_ids_list = [item["input_ids"] for item in batch]
token_type_ids_list = [item["token_type_ids"] for item in batch]
attention_mask_list = [item["attention_mask"] for item in batch]
if max_seq_length > 0:
target_len = max_seq_length
else:
target_len = max(ids.shape[0] for ids in input_ids_list)
padded_input_ids = []
padded_token_type_ids = []
padded_attention_mask = []
for ids, tt_ids, mask in zip(
input_ids_list, token_type_ids_list, attention_mask_list
):
seq_len = ids.shape[0]
if seq_len < target_len:
pad_len = target_len - seq_len
padded_input_ids.append(
torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)])
)
padded_token_type_ids.append(
torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)])
)
padded_attention_mask.append(
torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])
)
elif seq_len > target_len:
padded_input_ids.append(ids[:target_len])
padded_token_type_ids.append(tt_ids[:target_len])
padded_attention_mask.append(mask[:target_len])
else:
padded_input_ids.append(ids)
padded_token_type_ids.append(tt_ids)
padded_attention_mask.append(mask)
input_ids = torch.stack(padded_input_ids)
token_type_ids = torch.stack(padded_token_type_ids)
attention_mask = torch.stack(padded_attention_mask)
labels = torch.stack([item["label"].squeeze(0) for item in batch])
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
# 字符串字段保持为列表
prefixes = [item["prefix"] for item in batch]
suffixes = [item["suffix"] for item in batch]
pinyins = [item["pinyin"] for item in batch]
@ -1001,9 +1105,18 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
}
def preprocess_collate_fn(max_seq_length: int):
"""创建用于预处理的collate_fn始终pad到max_seq_length"""
def _collate(batch):
return collate_fn(batch, max_seq_length=max_seq_length)
return _collate
# Typer CLI应用
def create_dataloader(
dataset: PinyinInputDataset,
dataset,
batch_size: int,
num_workers: int = 2,
pin_memory: bool = True,
@ -1011,29 +1124,35 @@ def create_dataloader(
max_iter_length: Optional[int] = None,
) -> Any:
"""
创建数据加载器优先使用DataLoader2如果不可用则回退到DataLoader
专门针对流式数据集优化
创建数据加载器自动识别数据集类型
Args:
dataset: PinyinInputDataset实例
batch_size: 批次大小
num_workers: worker数量对于流式数据集建议为2
pin_memory: 是否固定内存
shuffle: 是否打乱流式数据集内部处理打乱
max_iter_length: 最大迭代长度用于计算总步数
Returns:
数据加载器实例
- PinyinInputDatasetIterableDataset使用流式加载
- PreProcessedDatasetmap-style使用标准加载支持 shuffle
"""
if isinstance(dataset, PreProcessedDataset):
logger.info(f"📊 使用预处理数据集,样本数: {len(dataset)}")
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=pin_memory,
collate_fn=preprocessed_collate_fn,
drop_last=True,
persistent_workers=True if num_workers > 0 else False,
)
logger.info(f"📊 使用标准DataLoaderworker数量: {num_workers}")
fixed_max_seq_length = getattr(dataset, "max_seq_length", 128)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=2, # 减少预取以避免内存问题
collate_fn=preprocess_collate_fn(fixed_max_seq_length),
drop_last=True,
prefetch_factor=2,
persistent_workers=True,
shuffle=shuffle,
)
@ -1098,6 +1217,11 @@ def train(
"--compile/--no-compile",
help="是否开启 torch.compile 优化(需 PyTorch 2.0+",
),
moe_mode: str = typer.Option(
"all",
"--moe-mode",
help="MoE 计算策略: all全量计算, sparse稀疏计算, sparse_allow_graph稀疏+allow_in_graph",
),
):
"""
训练输入法模型
@ -1152,6 +1276,7 @@ def train(
config_table.add_row("模型", "MoE专家数", str(num_experts))
config_table.add_row("模型", "使用拼音", str(use_pinyin))
config_table.add_row("模型", "编译优化", str(compile))
config_table.add_row("模型", "MoE策略", moe_mode)
config_table.add_row("训练", "训练轮数", str(num_epochs))
config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
@ -1164,12 +1289,97 @@ def train(
config_table.add_row("训练", "混合精度", str(mixed_precision))
config_table.add_row("其他", "自动恢复", str(auto_resume))
console.print(config_table)
# 创建输出目录
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# 检测数据类型并创建数据加载器
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
is_train_preprocessed = is_preprocessed_data(train_data_path)
is_eval_preprocessed = is_preprocessed_data(eval_data_path)
if is_train_preprocessed:
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=2)
pre_shuffled = train_dataset.metadata.get("pre_shuffled", False)
# 预打乱数据不需要 DataLoader 的 RandomSampler避免跨分片解压抖动
shuffle_train = not pre_shuffled
if max_iter_length > 0:
max_samples_per_epoch = max_iter_length
capped_samples = min(len(train_dataset), max_samples_per_epoch)
else:
capped_samples = len(train_dataset)
total_steps = (capped_samples // batch_size) * num_epochs
train_num_workers = min(num_workers, 1)
logger.info(
f"Preprocessed dataset: {len(train_dataset):,} samples, "
f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, "
f"workers={train_num_workers}, steps={total_steps:,}"
)
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=train_num_workers,
pin_memory=torch.cuda.is_available(),
shuffle=shuffle_train,
)
config_table.add_row("数据", "训练数据类型", "预处理数据")
config_table.add_row("数据", "预打乱", str(pre_shuffled))
else:
train_dataset = PinyinInputDataset(
data_path=train_data_path,
max_workers=-1,
max_iter_length=max_iter_length,
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=2000000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
total_steps = int(max_iter_length * num_epochs / batch_size)
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
max_iter_length=max_iter_length,
)
config_table.add_row("数据", "训练数据类型", "流式数据")
if is_eval_preprocessed:
eval_dataset = PreProcessedDataset(eval_data_path, max_cache_shards=1)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=batch_size,
num_workers=0,
pin_memory=torch.cuda.is_available(),
shuffle=False,
)
config_table.add_row("数据", "评估数据类型", "预处理数据")
else:
eval_dataset = PinyinInputDataset(
data_path=eval_data_path,
max_workers=-1,
max_iter_length=batch_size * 64,
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=2000000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=batch_size,
num_workers=2,
pin_memory=torch.cuda.is_available(),
max_iter_length=batch_size * 64,
)
config_table.add_row("数据", "评估数据类型", "流式数据")
config_table.add_row("数据", "总步数", str(total_steps))
console.print(config_table)
# 保存配置
config = {
"train_data_path": train_data_path,
@ -1202,6 +1412,10 @@ def train(
"auto_resume": auto_resume,
"max_iter_length": max_iter_length,
"compile": compile,
"moe_mode": moe_mode,
"is_train_preprocessed": is_train_preprocessed,
"is_eval_preprocessed": is_eval_preprocessed,
"total_steps": total_steps,
}
config_file = output_path / "training_config.json"
@ -1210,52 +1424,6 @@ def train(
logger.info(f"Configuration saved to {config_file}")
# 创建数据加载器
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
# 训练数据集
train_dataset = PinyinInputDataset(
data_path=train_data_path,
max_workers=-1, # 自动选择worker数量
max_iter_length=max_iter_length,
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=100000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
# 训练数据加载器
# 注意PinyinInputDataset是IterableDataset所以不能使用shuffle参数
# 多worker配置每个worker处理数据集的一个分片由dataset.__iter__中的shard处理
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
max_iter_length=max_iter_length,
)
# 评估数据集(使用相同的设置,但可以调整参数)
eval_dataset = PinyinInputDataset(
data_path=eval_data_path,
max_workers=-1,
max_iter_length=batch_size * 64, # 评估集较小
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=2000000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=batch_size,
num_workers=2, # 评估使用较少的worker
pin_memory=torch.cuda.is_available(),
max_iter_length=batch_size * 64,
)
console.print("[bold cyan]正在创建模型...[/bold cyan]")
model = InputMethodEngine(
vocab_size=vocab_size,
@ -1267,6 +1435,7 @@ def train(
num_experts=num_experts,
max_seq_len=max_seq_len,
compile=compile,
moe_mode=moe_mode,
)
console.print(
@ -1279,7 +1448,7 @@ def train(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
total_steps=int(max_iter_length * num_epochs / batch_size),
total_steps=total_steps,
output_dir=output_dir,
num_epochs=num_epochs,
learning_rate=learning_rate,
@ -1339,21 +1508,46 @@ def evaluate(
@app.command()
def export(
checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
output_path: str = typer.Option(
"./exported_model.onnx", "--output", "-o", help="输出路径"
output_dir: str = typer.Option(
"./exported_models", "--output", "-o", help="输出目录"
),
device: str = typer.Option("cpu", "--device", help="导出设备cpu/cuda"),
skip_verification: bool = typer.Option(
False, "--skip-verification", help="跳过ONNX模型验证"
),
):
"""
导出模型为ONNX格式
导出模型为ONNX格式生成 context_encoder.onnx decoder.onnx
"""
from .onnx_export import check_onnx_available, run_full_export
console = Console()
console.print(f"[bold cyan]导出模型到: {output_path}[/bold cyan]")
console.print("[bold cyan]━━━ ONNX 模型导出 ━━━[/bold cyan]")
console.print(f" 检查点: {checkpoint_path}")
console.print(f" 输出目录: {output_dir}")
console.print(f" 设备: {device}")
# 这里应该实现导出逻辑
# 1. 加载检查点
# 2. 导出为ONNX
if not check_onnx_available():
raise typer.Exit(1)
console.print("[yellow]导出功能待实现[/yellow]")
try:
context_encoder_path, decoder_path, config = run_full_export(
checkpoint_path=checkpoint_path,
output_dir=output_dir,
device=device,
skip_verification=skip_verification,
)
except Exception as e:
console.print(f"[bold red]导出失败: {e}[/bold red]")
raise typer.Exit(1)
console.print("\n[bold green]✓ 导出完成![/bold green]")
console.print(f" context_encoder.onnx -> {context_encoder_path}")
console.print(f" decoder.onnx -> {decoder_path}")
console.print(
f"\n MoE 层使用 'all' 模式(全量计算 {config.get('num_experts', '?')} 个专家),"
f"稀疏化优化可后续迭代"
)
if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
import os
import sys
sys.path.append("src")
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
import time
import torch
@ -26,7 +27,7 @@ from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset
tokenizer = AutoTokenizer.from_pretrained(
Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer"
Path(__file__).parent.parent / "src" / "model" / "assets" / "tokenizer"
)
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
@ -47,8 +48,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常"
part2 = "piaoliang"
part1 = "从中国回家后,他觉得世界上最好的城市就是"
part2 = "shanghai"
pinyin_ids = text_to_pinyin_ids(part2)
len_py = len(pinyin_ids)
if len_py < 24:
@ -56,9 +57,9 @@ if len_py < 24:
else:
pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0]
masked_labels = [22, 0, 0, 0, 0, 0, 0, 0]
part3 = ""
part4 = "可行|特别|伤害"
part4 = ""
encoded = tokenizer(
f"{part4}|{part1}",
@ -83,8 +84,11 @@ sample = {
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/best_model.ptrom", map_location="cpu")
checkpoint = torch.load(
"/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu"
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
input_ids = sample["input_ids"]
token_type_ids = sample["token_type_ids"]
@ -97,12 +101,13 @@ for k, v in sample.items():
print(f"{k}: {v}")
start = time.time()
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
sort_res = sorted(
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
)
print(sort_res[0:5])
with torch.no_grad():
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
print(f"计算时长: {(time.time() - start) * 1000:4f}ms")
sort_res = sorted(
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
)
print(sort_res[0:5])
query_engine = QueryEngine()
query_engine.load()

View File

@ -392,7 +392,13 @@ def check_compile_issues():
issues = []
# 检查 components.py 中的潜在问题
with open("src/model/components.py", "r") as f:
components_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"src",
"model",
"components.py",
)
with open(components_path, "r") as f:
content = f.read()
# 检查 float('-inf')

View File

@ -4,9 +4,13 @@
解决设备转换和权重加载问题
"""
import os
import sys
from pathlib import Path
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
@ -196,7 +200,7 @@ def test_id_mapping():
query_engine = QueryEngine()
stats_path = (
Path(__file__).parent
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"

127
tests/test_dataset.py Normal file
View File

@ -0,0 +1,127 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import time
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.model import InputMethodEngine
from model.query import QueryEngine
import random
import re
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset
from model.dataset import PinyinInputDataset
def worker_init_fn(worker_id: int) -> None:
"""
初始化每个DataLoader worker的随机种子确保可复现性
Args:
worker_id: worker的ID
"""
worker_seed = torch.initial_seed() % (2**32)
np.random.seed(worker_seed)
random.seed(worker_seed)
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
自定义批处理函数将多个样本组合成一个batch
支持动态padding根据batch内最大序列长度进行padding
"""
input_ids_list = [item["input_ids"] for item in batch]
token_type_ids_list = [item["token_type_ids"] for item in batch]
attention_mask_list = [item["attention_mask"] for item in batch]
target_len = max(ids.shape[0] for ids in input_ids_list)
padded_input_ids = []
padded_token_type_ids = []
padded_attention_mask = []
for ids, tt_ids, mask in zip(
input_ids_list, token_type_ids_list, attention_mask_list
):
seq_len = ids.shape[0]
if seq_len < target_len:
pad_len = target_len - seq_len
padded_input_ids.append(
torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)])
)
padded_token_type_ids.append(
torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)])
)
padded_attention_mask.append(
torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])
)
elif seq_len > target_len:
padded_input_ids.append(ids[:target_len])
padded_token_type_ids.append(tt_ids[:target_len])
padded_attention_mask.append(mask[:target_len])
else:
padded_input_ids.append(ids)
padded_token_type_ids.append(tt_ids)
padded_attention_mask.append(mask)
input_ids = torch.stack(padded_input_ids)
token_type_ids = torch.stack(padded_token_type_ids)
attention_mask = torch.stack(padded_attention_mask)
labels = torch.stack([item["label"].squeeze(0) for item in batch])
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
prefixes = [item["prefix"] for item in batch]
suffixes = [item["suffix"] for item in batch]
pinyins = [item["pinyin"] for item in batch]
return {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_mask": attention_mask,
"labels": labels,
"history_slot_ids": history_slot_ids,
"prefix": prefixes,
"suffix": suffixes,
"pinyin": pinyins,
"pinyin_ids": pinyin_ids,
}
train_dataset = PinyinInputDataset(
data_path="/home/songsenand/Data/corpus/CCI-Data/",
max_workers=-1, # 自动选择worker数量
max_iter_length=1000000,
text_field="text",
py_style_weight=(90, 2, 1),
shuffle_buffer_size=20000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
dataloader = DataLoader(
train_dataset,
batch_size=512,
num_workers=16,
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=2, # 减少预取以避免内存问题
persistent_workers=True,
)
for i, shape in tqdm(enumerate(dataloader), total=1000000 / 512):
pass

View File

@ -6,8 +6,8 @@ import torch.nn as nn
from rich.console import Console
from torch.utils.data import DataLoader, Dataset
# 添加src目录到路径
sys.path.insert(0, str(Path(__file__).parent))
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.model.model import InputMethodEngine
from src.model.trainer import Trainer

409
verify_onnx.py Normal file
View File

@ -0,0 +1,409 @@
#!/usr/bin/env python3
"""
ONNX模型验证脚本
验证导出的ONNX模型与原始PyTorch模型输出的一致性
"""
import argparse
import os
import sys
from pathlib import Path
import numpy as np
import torch
import onnxruntime as ort
# 添加src目录到路径
sys.path.insert(0, str(Path(__file__).parent))
from src.model.export_models import create_export_models_from_checkpoint
def compare_outputs(pytorch_output, onnx_output, name="output", rtol=1e-3, atol=1e-5):
"""
比较PyTorch和ONNX输出
Args:
pytorch_output: PyTorch张量
onnx_output: ONNX Runtime输出numpy数组
name: 输出名称用于错误信息
rtol: 相对容差
atol: 绝对容差
Returns:
bool: 是否匹配
"""
# 转换PyTorch输出为numpy
if isinstance(pytorch_output, torch.Tensor):
pytorch_np = pytorch_output.detach().cpu().numpy()
else:
pytorch_np = np.array(pytorch_output)
# 确保形状一致
if pytorch_np.shape != onnx_output.shape:
print(
f"{name} 形状不匹配: PyTorch {pytorch_np.shape} != ONNX {onnx_output.shape}"
)
return False
# 计算差异
diff = np.abs(pytorch_np - onnx_output)
max_diff = np.max(diff)
mean_diff = np.mean(diff)
# 检查是否在容差范围内
is_close = np.allclose(pytorch_np, onnx_output, rtol=rtol, atol=atol)
if is_close:
print(f"{name} 匹配: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
else:
print(f"{name} 不匹配: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
print(
f" 范围: PyTorch [{pytorch_np.min():.6f}, {pytorch_np.max():.6f}], "
f"ONNX [{onnx_output.min():.6f}, {onnx_output.max():.6f}]"
)
return is_close
def verify_context_encoder(checkpoint_path, onnx_path, device="cpu"):
"""
验证上下文编码器
Args:
checkpoint_path: PyTorch checkpoint路径
onnx_path: ONNX模型路径
device: 设备
Returns:
bool: 验证是否通过
"""
print(f"\n🔍 验证上下文编码器: {onnx_path}")
# 加载PyTorch模型
context_encoder_export, _, config = create_export_models_from_checkpoint(
checkpoint_path, device
)
# 创建ONNX Runtime会话
session = ort.InferenceSession(
onnx_path,
providers=[
"CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
],
)
# 创建测试输入
batch_size = 2 # 使用batch_size=2测试动态批处理
seq_len = config.get("max_seq_len", 128)
pinyin_len = 24
input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
# 随机屏蔽一些位置
attention_mask[:, seq_len // 2 :] = 0
# PyTorch推理
with torch.no_grad():
pytorch_outputs = context_encoder_export(input_ids, pinyin_ids, attention_mask)
# ONNX推理
onnx_inputs = {
"input_ids": input_ids.numpy(),
"pinyin_ids": pinyin_ids.numpy(),
"attention_mask": attention_mask.numpy(),
}
onnx_outputs = session.run(None, onnx_inputs)
# 比较输出
output_names = ["context_H", "pinyin_P", "context_mask", "pinyin_mask"]
all_match = True
for i, name in enumerate(output_names):
if i < len(pytorch_outputs) and i < len(onnx_outputs):
match = compare_outputs(pytorch_outputs[i], onnx_outputs[i], name)
all_match = all_match and match
return all_match
def verify_decoder(checkpoint_path, onnx_path, device="cpu"):
"""
验证解码器
Args:
checkpoint_path: PyTorch checkpoint路径
onnx_path: ONNX模型路径
device: 设备
Returns:
bool: 验证是否通过
"""
print(f"\n🔍 验证解码器: {onnx_path}")
# 加载PyTorch模型
_, decoder_export, config = create_export_models_from_checkpoint(
checkpoint_path, device
)
# 创建ONNX Runtime会话
session = ort.InferenceSession(
onnx_path,
providers=[
"CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
],
)
# 创建测试输入
batch_size = 2
seq_len = config.get("max_seq_len", 128)
pinyin_len = 24
dim = config.get("dim", 512)
num_slots = config.get("num_slots", 8)
context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
history_slot_ids = torch.randint(0, 100, (batch_size, num_slots), dtype=torch.long)
# PyTorch推理
with torch.no_grad():
pytorch_output = decoder_export(
context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
)
# ONNX推理
onnx_inputs = {
"context_H": context_H.numpy(),
"pinyin_P": pinyin_P.numpy(),
"history_slot_ids": history_slot_ids.numpy(),
"context_mask": context_mask.numpy(),
"pinyin_mask": pinyin_mask.numpy(),
}
onnx_outputs = session.run(None, onnx_inputs)
# 比较输出
return compare_outputs(pytorch_output, onnx_outputs[0], "logits")
def verify_end_to_end(
checkpoint_path, context_encoder_path, decoder_path, device="cpu"
):
"""
端到端验证比较完整推理流程
Args:
checkpoint_path: PyTorch checkpoint路径
context_encoder_path: 上下文编码器ONNX路径
decoder_path: 解码器ONNX路径
device: 设备
Returns:
bool: 验证是否通过
"""
print(f"\n🔍 端到端验证")
# 加载原始PyTorch模型
from src.model.model import InputMethodEngine
checkpoint = torch.load(checkpoint_path, map_location=device)
if "config" in checkpoint:
config = checkpoint["config"]
else:
config = {
"vocab_size": 10019,
"pinyin_vocab_size": 30,
"dim": 512,
"num_slots": 8,
"n_layers": 4,
"n_heads": 4,
"num_experts": 10,
"max_seq_len": 128,
}
model = InputMethodEngine(
vocab_size=config.get("vocab_size", 10019),
pinyin_vocab_size=config.get("pinyin_vocab_size", 30),
dim=config.get("dim", 512),
num_slots=config.get("num_slots", 8),
n_layers=config.get("n_layers", 4),
n_heads=config.get("n_heads", 4),
num_experts=config.get("num_experts", 10),
max_seq_len=config.get("max_seq_len", 128),
compile=False,
)
if "model_state_dict" in checkpoint:
model.load_state_dict(checkpoint["model_state_dict"])
else:
model.load_state_dict(checkpoint)
model.eval()
model.to(device)
# 创建ONNX Runtime会话
context_session = ort.InferenceSession(
context_encoder_path,
providers=[
"CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
],
)
decoder_session = ort.InferenceSession(
decoder_path,
providers=[
"CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
],
)
# 创建测试输入
batch_size = 1
seq_len = config.get("max_seq_len", 128)
pinyin_len = 24
input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
history_slot_ids = torch.randint(0, 100, (batch_size, 8), dtype=torch.long)
# PyTorch完整推理
with torch.no_grad():
pytorch_logits = model(
input_ids=input_ids,
token_type_ids=torch.zeros_like(input_ids), # 简化处理
attention_mask=attention_mask,
pinyin_ids=pinyin_ids,
history_slot_ids=history_slot_ids,
)
# ONNX推理流程
# 1. 上下文编码器
context_inputs = {
"input_ids": input_ids.numpy(),
"pinyin_ids": pinyin_ids.numpy(),
"attention_mask": attention_mask.numpy(),
}
context_outputs = context_session.run(None, context_inputs)
context_H, pinyin_P, context_mask, pinyin_mask = context_outputs
# 2. 解码器
decoder_inputs = {
"context_H": context_H,
"pinyin_P": pinyin_P,
"history_slot_ids": history_slot_ids.numpy(),
"context_mask": context_mask,
"pinyin_mask": pinyin_mask,
}
onnx_outputs = decoder_session.run(None, decoder_inputs)
# 比较输出
return compare_outputs(pytorch_logits, onnx_outputs[0], "end_to_end_logits")
def main():
parser = argparse.ArgumentParser(description="ONNX模型验证")
parser.add_argument(
"--checkpoint", "-c", type=str, required=True, help="PyTorch checkpoint路径"
)
parser.add_argument(
"--context-encoder",
type=str,
help="上下文编码器ONNX路径默认: ./exported_models/context_encoder.onnx",
)
parser.add_argument(
"--decoder",
type=str,
help="解码器ONNX路径默认: ./exported_models/decoder.onnx",
)
parser.add_argument(
"--output-dir",
"-o",
type=str,
default="./exported_models",
help="导出目录(如果未指定单个模型路径)",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="验证设备(默认: cpu",
)
parser.add_argument(
"--skip-context", action="store_true", help="跳过上下文编码器验证"
)
parser.add_argument("--skip-decoder", action="store_true", help="跳解码器验证")
parser.add_argument("--skip-end-to-end", action="store_true", help="跳过端到端验证")
args = parser.parse_args()
# 确定模型路径
if args.context_encoder:
context_encoder_path = args.context_encoder
else:
context_encoder_path = os.path.join(args.output_dir, "context_encoder.onnx")
if args.decoder:
decoder_path = args.decoder
else:
decoder_path = os.path.join(args.output_dir, "decoder.onnx")
print("🔬 ONNX模型验证")
print("=" * 60)
print(f"Checkpoint: {args.checkpoint}")
print(f"上下文编码器: {context_encoder_path}")
print(f"解码器: {decoder_path}")
print(f"设备: {args.device}")
print()
all_pass = True
# 验证上下文编码器
if not args.skip_context and os.path.exists(context_encoder_path):
if verify_context_encoder(args.checkpoint, context_encoder_path, args.device):
print("✅ 上下文编码器验证通过")
else:
print("❌ 上下文编码器验证失败")
all_pass = False
elif not args.skip_context:
print("⚠️ 上下文编码器文件不存在,跳过验证")
# 验证解码器
if not args.skip_decoder and os.path.exists(decoder_path):
if verify_decoder(args.checkpoint, decoder_path, args.device):
print("✅ 解码器验证通过")
else:
print("❌ 解码器验证失败")
all_pass = False
elif not args.skip_decoder:
print("⚠️ 解码器文件不存在,跳过验证")
# 端到端验证
if (
not args.skip_end_to_end
and os.path.exists(context_encoder_path)
and os.path.exists(decoder_path)
):
if verify_end_to_end(
args.checkpoint, context_encoder_path, decoder_path, args.device
):
print("✅ 端到端验证通过")
else:
print("❌ 端到端验证失败")
all_pass = False
print("\n" + "=" * 60)
if all_pass:
print("🎉 所有验证通过ONNX模型与PyTorch模型输出一致")
else:
print("❌ 部分验证失败,请检查模型导出过程")
sys.exit(1)
if __name__ == "__main__":
main()