Compare commits
15 Commits
33f56f709b
...
88955bcfdd
| Author | SHA1 | Date |
|---|---|---|
|
|
88955bcfdd | |
|
|
53f244de2f | |
|
|
71ef54e3d4 | |
|
|
722912f296 | |
|
|
0862b5b8fc | |
|
|
27beb7f0b1 | |
|
|
d0f1534086 | |
|
|
483e4d4f98 | |
|
|
432132a108 | |
|
|
e8eab1f260 | |
|
|
8b41bcdc6f | |
|
|
4ded2d656f | |
|
|
1b7da9ddd4 | |
|
|
710cfe7fc2 | |
|
|
3175ace9c5 |
|
|
@ -177,3 +177,9 @@ cython_debug/
|
|||
uv.lock
|
||||
|
||||
data/*
|
||||
|
||||
|
||||
**/*.onnx
|
||||
**/*.data
|
||||
**/*.npz
|
||||
**/*.pt
|
||||
|
|
|
|||
|
|
@ -0,0 +1,229 @@
|
|||
# 预处理数据预打乱方案
|
||||
|
||||
## 目标
|
||||
|
||||
用 CPU 机时预打乱数据,让训练时直接用 `shuffle=False` 顺序读取,消除跨分片缓存抖动和 CPU 利用率低的问题。
|
||||
|
||||
## 改动清单
|
||||
|
||||
### 1. `src/model/subsample.py` — 默认开启输出打乱
|
||||
|
||||
**函数签名变更:**
|
||||
```python
|
||||
def pass2_subsample(
|
||||
...,
|
||||
shuffle: bool = True, # 新增
|
||||
seed: int = 42, # 新增
|
||||
) -> Tuple[int, int]:
|
||||
```
|
||||
|
||||
**改动点:**
|
||||
- `rng = np.random.RandomState()` → `rng = np.random.RandomState(seed)`
|
||||
- 新增 `shuffle_rng = np.random.RandomState(seed + 1)` 用于输出打乱
|
||||
- 在两次 `np.savez_compressed` 写入前(line ~251 和 line ~273),插入:
|
||||
```python
|
||||
if shuffle and train_buf_count > 1:
|
||||
perm = shuffle_rng.permutation(train_buf_count)
|
||||
for f in FIELDS:
|
||||
merged[f] = merged[f][perm]
|
||||
```
|
||||
- main() 中新增参数:
|
||||
```python
|
||||
parser.add_argument("--no-shuffle", action="store_false", dest="shuffle",
|
||||
help="禁用输出分片内部打乱")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="随机种子(用于选择+打乱)")
|
||||
```
|
||||
- 调用 `pass2_subsample` 时传入 `shuffle=args.shuffle, seed=args.seed`
|
||||
|
||||
**metadata.json 新增字段:**
|
||||
```json
|
||||
"pre_shuffled": true,
|
||||
"seed": 42
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. 新建 `src/model/shuffle_npz.py` — 平衡+打乱已有数据
|
||||
|
||||
处理流程:
|
||||
|
||||
```
|
||||
Phase 1: npz → 逐字段打乱 → .npy(mmap 友好)
|
||||
输入: 19个不平衡 .npz 分片(100M样本,~10GB 压缩)
|
||||
输出: 114个临时 .npy 文件(6字段 × 19分片,~144GB 未压缩)
|
||||
内存峰值: ~10GB(单字段 int16 最大 ~5GB + permuted copy ~5GB)
|
||||
耗时: ~15-20分钟(解压+写入)
|
||||
|
||||
Phase 2: .npy → 平衡分配 → .npz
|
||||
输入: Phase 1 的 .npy 文件(mmap 模式)
|
||||
输出: 100个平衡 .npz 分片(每片100万样本,~100MB/片压缩后)
|
||||
每个输出分片 = 从19个源各取比例份额 → concatenate → shuffle → save
|
||||
内存峰值: ~3GB(1个输出缓冲 + mmap pages)
|
||||
耗时: ~10-15分钟(mmap读取+压缩写入)
|
||||
|
||||
总计: ~30-40分钟,峰值内存 ~10GB
|
||||
```
|
||||
|
||||
**磁盘需求:**
|
||||
- 临时 .npy 文件(Phase 1→2 中间产物):~144GB
|
||||
- 最终输出 .npz:~10GB
|
||||
- 临时文件在 Phase 2 完成后自动删除
|
||||
|
||||
**用法:**
|
||||
```bash
|
||||
python -m src.model.shuffle_npz \
|
||||
--input-dir /home/songsenand/DataSet/SubPro \
|
||||
--output-dir /home/songsenand/DataSet/SubPro_Shuffled \
|
||||
--shard-size 1000000 \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
**关键实现:**
|
||||
|
||||
Phase 1 — 逐字段加载+打乱:
|
||||
```python
|
||||
for src_idx in range(num_shards):
|
||||
data = np.load(shard_path) # lazy NpzFile
|
||||
n = shard_sizes[src_idx]
|
||||
perm = rng.permutation(n)
|
||||
|
||||
for field in FIELDS:
|
||||
arr = data[field].copy() # ~5GB peak (input_ids)
|
||||
shuffled = arr[perm] # ~5GB temp
|
||||
np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", shuffled)
|
||||
del arr, shuffled
|
||||
gc.collect()
|
||||
|
||||
data.close()
|
||||
```
|
||||
|
||||
Phase 2 — 平衡分配:
|
||||
```python
|
||||
# 打开所有源 mmap(读模式,零内存)
|
||||
src_mmaps = [
|
||||
{f: np.load(temp_dir / f / f"shard_{i:06d}.npy", mmap_mode='r') for f in FIELDS}
|
||||
for i in range(num_shards)
|
||||
]
|
||||
|
||||
for out_j in range(num_output_shards):
|
||||
buffers = {f: [] for f in FIELDS}
|
||||
|
||||
for src_i in range(num_shards):
|
||||
s = shard_sizes[src_i]
|
||||
start = (out_j * s) // num_output_shards
|
||||
end = ((out_j + 1) * s) // num_output_shards
|
||||
if start >= end:
|
||||
continue
|
||||
for f in FIELDS:
|
||||
chunk = src_mmaps[src_i][f][start:end].copy()
|
||||
buffers[f].append(chunk)
|
||||
|
||||
output = {f: np.concatenate(buffers[f]) for f in FIELDS}
|
||||
# 额外打乱(跨源混合)
|
||||
perm = rng.permutation(len(output[FIELDS[0]]))
|
||||
for f in FIELDS:
|
||||
output[f] = output[f][perm]
|
||||
|
||||
np.savez_compressed(output_dir / f"shard_{out_j:06d}.npz", **output)
|
||||
del output, buffers, perm
|
||||
gc.collect()
|
||||
```
|
||||
|
||||
**输出 metadata.json:**
|
||||
```json
|
||||
{
|
||||
"num_samples": 99998406,
|
||||
"max_seq_length": 128,
|
||||
"dtype": "int16",
|
||||
"fields": [...],
|
||||
"shard_size": 1000000,
|
||||
"num_shards": 100,
|
||||
"shard_sizes": [1000000, ..., 998406],
|
||||
"pre_shuffled": true,
|
||||
"seed": 42
|
||||
}
|
||||
```
|
||||
|
||||
**eval 目录处理:** 如果 `--input-dir/eval/` 存在,直接复制到 `--output-dir/eval/`(eval 数据量小,不需要打乱)
|
||||
|
||||
---
|
||||
|
||||
### 3. `src/model/trainer.py` — 预处理数据禁用 shuffle
|
||||
|
||||
**改动点(train 函数,line ~1258-1272):**
|
||||
|
||||
```python
|
||||
if is_train_preprocessed:
|
||||
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1)
|
||||
# pre_shuffled 数据不需要 DataLoader 的 RandomSampler
|
||||
shuffle_train = not train_dataset.metadata.get("pre_shuffled", False)
|
||||
total_steps = (len(train_dataset) // batch_size) * num_epochs
|
||||
# 支持 max_iter_length 限制总步数
|
||||
if max_iter_length > 0:
|
||||
max_steps_per_epoch = max_iter_length // batch_size
|
||||
total_steps = min(total_steps, max_steps_per_epoch * num_epochs)
|
||||
train_num_workers = min(num_workers, 1)
|
||||
train_dataloader = create_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=train_num_workers,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=shuffle_train, # 预打乱数据不 shuffle
|
||||
)
|
||||
```
|
||||
|
||||
**eval DataLoader(line ~1295-1303):**
|
||||
|
||||
```python
|
||||
if is_eval_preprocessed:
|
||||
eval_dataset = PreProcessedDataset(eval_data_path, max_cache_shards=1)
|
||||
eval_dataloader = create_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=0, # eval 数据小,单进程足够
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=False, # eval 不需要打乱
|
||||
)
|
||||
```
|
||||
|
||||
**`create_dataloader` 函数(line ~1076-1114):** 无需改动,`shuffle` 参数已透传。
|
||||
|
||||
---
|
||||
|
||||
### 4. `src/model/preprocessed_dataset.py`
|
||||
|
||||
现有代码无需修改。`PreProcessedDataset` 已经可以正确处理 `shuffle=False` 的情况(PyTorch 的 `SequentialSampler` 会按 0..N-1 顺序读取)。
|
||||
|
||||
`metadata["pre_shuffled"]` 字段由 subsample.py 和 shuffle_npz.py 在写入 metadata.json 时添加,训练代码读取判断即可。
|
||||
|
||||
---
|
||||
|
||||
## 执行顺序
|
||||
|
||||
```bash
|
||||
# Step 1: 打乱并平衡已有的 100M 数据集
|
||||
python -m src.model.shuffle_npz \
|
||||
--input-dir /home/songsenand/DataSet/SubPro \
|
||||
--output-dir /home/songsenand/DataSet/SubPro_Shuffled \
|
||||
--shard-size 1000000 \
|
||||
--seed 42
|
||||
|
||||
# Step 2: 用新数据训练(数据已打乱,顺序读取即可)
|
||||
uv run train-model train \
|
||||
--train-data-path /home/songsenand/DataSet/SubPro_Shuffled/train \
|
||||
--eval-data-path /home/songsenand/DataSet/SubPro_Shuffled/eval \
|
||||
-b 16 \
|
||||
-o ~/tmp \
|
||||
--eval-frequency 20 \
|
||||
--save-frequency 40
|
||||
```
|
||||
|
||||
## 预期效果
|
||||
|
||||
| 改动前 | 改动后 |
|
||||
|---|---|
|
||||
| 每 batch 跨 13 个 shard | 顺序读,1 个 shard 在缓存中 |
|
||||
| 每 batch 数据加载 2-3 分钟 | ~0.1-0.5 秒(纯 mmap/memory) |
|
||||
| CPU 利用率 10% | 正常(训练计算是瓶颈) |
|
||||
| 内存 40GB+ | <20GB(单 shard 1M 样本 ≈ 1.4GB) |
|
||||
135
README.md
135
README.md
|
|
@ -72,6 +72,36 @@
|
|||
* **目标槽位序列**:真实用户输入的文字 ID 序列,作为模型的监督信号 [1]。
|
||||
* **标签处理**:在每一个槽位步(Step),模型需要预测该步对应的真实文字 ID [1]。
|
||||
|
||||
#### 4.1.1 历史槽位(history_slot_ids)的设计策略
|
||||
|
||||
`history_slot_ids` 的语义是**当前拼音组内已经确认的字符**,用于模拟真实的输入法输入场景。
|
||||
|
||||
**设计原则:**
|
||||
|
||||
1. **拼音组隔离**:每个拼音组(jieba 分词的一个词或合并后的短词组)对应一次独立的拼音输入会话。拼音组开始时 `history_slot_ids` 为空(全 0),组内逐字累积。
|
||||
```
|
||||
输入: 特种兵 → 拼音组 "tezhongbing"
|
||||
Step 1: pinyin="tezhongbing", history=[] → 预测 特
|
||||
Step 2: pinyin="tezhongbing", history=[特] → 预测 种
|
||||
Step 3: pinyin="tezhongbing", history=[特, 种] → 预测 兵
|
||||
输入: 十分倾佩 → 拼音组 "shifen" + 拼音组 "qing" + 拼音组 "pei"
|
||||
拼音组 "qing": pinyin="qing", history=[] → 预测 倾
|
||||
拼音组 "pei": pinyin="pei", history=[] → 预测 佩
|
||||
```
|
||||
|
||||
2. **上下文承载前缀字**:当拼音组切换时,前一个拼音组已确认的字符(如"特")已写入文本,通过 BERT 编码的 `input_ids`(而非 `history_slot_ids`)传递给模型。这模拟了真实输入法环境——用户打完"te"确认"特"后,文本中已有"特",再输入"zhongbing"时,模型从上下文(`input_ids`)中能看到"特"。
|
||||
|
||||
3. **破词续接**:当一个词被按一定概率(~10%)拆分时(`should_break=True`),Phase 1 处理前缀(如"特"),Phase 2 处理续接(如"种兵")。Phase 2 的 `cont_processed_history = []` 是**正确的设计**——前缀已确认为文本的一部分,续接是一个新的拼音会话,history 从空开始。
|
||||
```
|
||||
"特种兵" 断在"特"后:
|
||||
Phase 1: pinyin="te", history=[] → 预测 特 → 确认为文本上下文
|
||||
Phase 2: pinyin="zhongbing", history=[] → 预测 种
|
||||
history=[种] → 预测 兵
|
||||
```
|
||||
Phase 2 中"特"已通过 `part1`(光标前文本)进入 BERT 上下文中,不在 `history_slot_ids` 中。
|
||||
|
||||
4. **短词合并**:相邻短词(≤2 字符,含单字词)有 50% 概率合并为一个拼音组。合并后组内字符共享 history 累积;未合并时各自为独立拼音组(每组的 history 都从空开始)。两种方式都对应真实输入法场景——用户可能一次性输入多个字,也可能逐字输入。
|
||||
|
||||
### 4.2 损失函数与优化
|
||||
* **损失函数**:使用 **CrossEntropyLoss** 计算每一步预测结果与真实标签之间的差异 [1]。
|
||||
* **掩码机制**:仅计算非填充位置(Non-padding positions)的损失,忽略无效的时间步 [1]。
|
||||
|
|
@ -790,112 +820,7 @@ train-model evaluate \
|
|||
- 在评估数据集上计算准确率、困惑度等指标
|
||||
- 生成详细的性能报告
|
||||
|
||||
### 6.8 模型扩容两阶段训练
|
||||
|
||||
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
|
||||
|
||||
#### 训练策略
|
||||
|
||||
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
|
||||
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
|
||||
|
||||
#### 基础用法
|
||||
|
||||
```bash
|
||||
train-model expand-and-train \
|
||||
--train-data-path "path/to/train/dataset" \
|
||||
--eval-data-path "path/to/eval/dataset" \
|
||||
--base-model-path "./pretrained/model.pt" \
|
||||
--new-model-spec "model:InputMethodEngine" \
|
||||
--num-experts 40 \
|
||||
--frozen-lr 2e-3 \
|
||||
--full-lr 5e-5 \
|
||||
--frozen-patience 8
|
||||
```
|
||||
|
||||
#### 完整参数示例
|
||||
|
||||
```bash
|
||||
train-model expand-and-train \
|
||||
--train-data-path "path/to/train/dataset" \
|
||||
--eval-data-path "path/to/eval/dataset" \
|
||||
--output-dir "./expansion_output" \
|
||||
--base-model-path "./pretrained/model.pt" \
|
||||
--new-model-spec "custom_model:ExpandedModel" \
|
||||
--vocab-size 10019 \
|
||||
--dim 512 \
|
||||
--num-experts 40 \
|
||||
--frozen-patience 10 \
|
||||
--frozen-lr 1e-3 \
|
||||
--full-lr 1e-4 \
|
||||
--frozen-scheduler cosine \
|
||||
--full-scheduler cosine \
|
||||
--batch-size 128 \
|
||||
--num-epochs 20 \
|
||||
--compile
|
||||
```
|
||||
|
||||
#### 参数详解
|
||||
|
||||
**模型扩容参数**
|
||||
- `--base-model-path`: 预训练基础模型检查点路径(必需)
|
||||
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
|
||||
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
|
||||
- 自定义模型类必须是 `InputMethodEngine` 的子类
|
||||
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类
|
||||
|
||||
**两阶段训练参数**
|
||||
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10)
|
||||
- `--frozen-lr`: 冻结阶段学习率(默认:1e-3)
|
||||
- `--full-lr`: 全量微调阶段学习率(默认:1e-4)
|
||||
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||
|
||||
**其他参数**
|
||||
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
|
||||
- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等
|
||||
|
||||
#### 使用场景
|
||||
|
||||
1. **增加专家数量**(20→40)
|
||||
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
|
||||
- 新增参数:新专家网络、gate层
|
||||
|
||||
2. **增加top_k值**(2→3)
|
||||
- 冻结效果:100% 参数可冻结(仅逻辑变化)
|
||||
- 新增参数:无
|
||||
|
||||
3. **修改专家内部结构**(如增加resblocks)
|
||||
- 冻结效果:~50% 参数可冻结(linear_in/output可冻结)
|
||||
- 新增参数:新增的resblocks层
|
||||
|
||||
4. **增加Transformer层数**(4→5)
|
||||
- 冻结效果:~80% 参数可冻结(前4层可冻结)
|
||||
- 新增参数:新增的第5层
|
||||
|
||||
#### 自定义模型类示例
|
||||
|
||||
```python
|
||||
# my_model.py
|
||||
from model.model import InputMethodEngine
|
||||
|
||||
class MyExpandedModel(InputMethodEngine):
|
||||
def __init__(self, num_experts=40, **kwargs):
|
||||
# 调用父类构造函数,覆盖num_experts参数
|
||||
super().__init__(num_experts=num_experts, **kwargs)
|
||||
# 可以在这里添加额外的层或修改现有层
|
||||
|
||||
# 使用命令
|
||||
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
|
||||
```
|
||||
|
||||
#### 注意事项
|
||||
|
||||
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
|
||||
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
|
||||
3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
|
||||
4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency`
|
||||
5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载
|
||||
|
||||
### 6.9 导出模型(开发中)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,436 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
束搜索算法演示
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行束搜索推理,
|
||||
模拟输入法场景:给定上下文、拼音和已确认字符,生成候选汉字序列。
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# 添加src目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
|
||||
class ONNXBeamSearch:
    """Beam-search decoder driven by two exported ONNX models.

    One ONNX session encodes the context and pinyin once; the second is
    called repeatedly, one step per hypothesis, to score the next token given
    the history slots built from the tokens chosen so far.
    """

    # Width of the decoder's ``history_slot_ids`` input.
    HISTORY_SLOTS = 8
    # Token id this search treats as the end marker (also used as padding).
    END_ID = 0

    def __init__(
        self, context_encoder_path: str, decoder_path: str, device: str = "cpu"
    ):
        """Load both ONNX sessions and cache their input/output descriptors.

        Args:
            context_encoder_path: Path to the context-encoder ONNX model.
            decoder_path: Path to the decoder ONNX model.
            device: "cuda" requests the CUDA provider with CPU fallback;
                any other value uses the CPU provider only.
        """
        self.device = device

        # Pick the ONNX Runtime execution providers.
        if device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        # Create the ONNX Runtime sessions.
        self.context_session = ort.InferenceSession(
            context_encoder_path, providers=providers
        )
        self.decoder_session = ort.InferenceSession(decoder_path, providers=providers)

        # Index the input/output descriptors by name.
        self.context_input_info = {
            node.name: node for node in self.context_session.get_inputs()
        }
        self.context_output_info = {
            node.name: node for node in self.context_session.get_outputs()
        }
        self.decoder_input_info = {
            node.name: node for node in self.decoder_session.get_inputs()
        }
        self.decoder_output_info = {
            node.name: node for node in self.decoder_session.get_outputs()
        }

        print("📦 ONNX模型加载完成")
        print(f" 上下文编码器输入: {list(self.context_input_info.keys())}")
        print(f" 上下文编码器输出: {list(self.context_output_info.keys())}")
        print(f" 解码器输入: {list(self.decoder_input_info.keys())}")
        print(f" 解码器输出: {list(self.decoder_output_info.keys())}")

    def prepare_inputs(
        self,
        text_before: str,
        text_after: str,
        pinyin: str,
        context_prompts: List[str] = None,
        tokenizer=None,
        query_engine=None,
        max_seq_len: int = 128,
    ) -> Dict[str, np.ndarray]:
        """Build placeholder model inputs for the demo.

        This is a simplified stand-in: the text, pinyin, prompt, tokenizer and
        query-engine arguments are accepted but not used, and the returned
        arrays are random. A real implementation would tokenize the text into
        ``input_ids``, convert the pinyin with ``text_to_pinyin_ids`` and
        build a matching ``attention_mask``.

        Args:
            text_before: Text before the cursor (currently unused).
            text_after: Text after the cursor (currently unused).
            pinyin: Pinyin being typed (currently unused).
            context_prompts: Context hint words (currently unused).
            tokenizer: Tokenizer (not implemented, unused).
            query_engine: Query engine (not implemented, unused).
            max_seq_len: Upper bound on the sequence length; the demo caps
                the generated length at 64.

        Returns:
            Dict with ``input_ids``, ``pinyin_ids`` and ``attention_mask``.
        """
        batch_size = 1
        seq_len = min(max_seq_len, 64)  # shortened length for the demo

        # Random stand-ins for real tokenized data.
        input_ids = np.random.randint(0, 1000, (batch_size, seq_len), dtype=np.int64)
        pinyin_ids = np.random.randint(
            0, 30, (batch_size, 24), dtype=np.int64
        )  # fixed length of 24
        attention_mask = np.ones((batch_size, seq_len), dtype=np.int64)

        return {
            "input_ids": input_ids,
            "pinyin_ids": pinyin_ids,
            "attention_mask": attention_mask,
        }

    @staticmethod
    def _concrete_shape(shape) -> Tuple[int, ...]:
        """Replace symbolic or unknown dimensions with 1 so zeros can be allocated."""
        return tuple(
            dim if isinstance(dim, int) and dim > 0 else 1 for dim in (shape or ())
        )

    def run_context_encoder(
        self, inputs: Dict[str, np.ndarray]
    ) -> Tuple[np.ndarray, ...]:
        """Run the context encoder.

        Every input the session declares is fed: values present in ``inputs``
        are used as-is, missing ones are replaced by a zero array.

        Args:
            inputs: Mapping from input name to array.

        Returns:
            The session outputs as a tuple; the demo unpacks them as
            (context_H, pinyin_P, context_mask, pinyin_mask).
        """
        onnx_inputs = {}
        for input_name, info in self.context_input_info.items():
            if input_name in inputs:
                onnx_inputs[input_name] = inputs[input_name]
                continue
            # Declared shapes may hold symbolic names or None for dynamic
            # axes, which np.zeros cannot consume directly.
            shape = self._concrete_shape(info.shape)
            # Simplification: any integer tensor type becomes int64,
            # everything else float32.
            fill_dtype = np.int64 if "int" in info.type else np.float32
            onnx_inputs[input_name] = np.zeros(shape, dtype=fill_dtype)

        outputs = self.context_session.run(None, onnx_inputs)

        return tuple(outputs)

    def run_decoder(
        self,
        context_H: np.ndarray,
        pinyin_P: np.ndarray,
        history_slot_ids: np.ndarray,
        context_mask: np.ndarray,
        pinyin_mask: np.ndarray,
    ) -> np.ndarray:
        """Run one decoder step.

        Shapes below are the ones stated by the original author and are not
        checked here.

        Args:
            context_H: Context encoding [batch, seq_len, 512].
            pinyin_P: Pinyin encoding [batch, 24, 512].
            history_slot_ids: History slot ids [batch, 8].
            context_mask: Context mask [batch, seq_len].
            pinyin_mask: Pinyin mask [batch, 24].

        Returns:
            The first decoder output, logits [batch, vocab_size].
        """
        inputs = {
            "context_H": context_H,
            "pinyin_P": pinyin_P,
            "history_slot_ids": history_slot_ids,
            "context_mask": context_mask,
            "pinyin_mask": pinyin_mask,
        }

        outputs = self.decoder_session.run(None, inputs)

        return outputs[0]

    def _history_array(self, tokens: List[int]) -> np.ndarray:
        """Return a [1, HISTORY_SLOTS] int64 array: last tokens, zero-padded on the right."""
        recent = list(tokens[-self.HISTORY_SLOTS:])
        recent += [0] * (self.HISTORY_SLOTS - len(recent))
        return np.array([recent], dtype=np.int64)

    @staticmethod
    def _length_penalized(score: float, length: int, length_penalty: float) -> float:
        """Return the score divided by (5 + length) ** length_penalty, used for ranking only."""
        return score / ((5 + length) ** length_penalty)

    def beam_search(
        self,
        context_H: np.ndarray,
        pinyin_P: np.ndarray,
        context_mask: np.ndarray,
        pinyin_mask: np.ndarray,
        beam_size: int = 5,
        max_length: int = 8,
        vocab_size: int = 10019,
        temperature: float = 1.0,
        length_penalty: float = 0.6,
        history_prefix: List[int] = None,
    ) -> List[Tuple[List[int], float]]:
        """Run beam search over the decoder.

        A hypothesis whose last token is ``END_ID`` is finished: it is carried
        into the next step unchanged instead of being expanded, so no tokens
        are ever generated after the end marker.

        Args:
            context_H: Context encoding.
            pinyin_P: Pinyin encoding.
            context_mask: Context mask.
            pinyin_mask: Pinyin mask.
            beam_size: Number of hypotheses kept after each step.
            max_length: Maximum number of generated tokens.
            vocab_size: Vocabulary size; caps the per-step candidate count.
            temperature: Logit temperature; must be positive.
            length_penalty: Exponent of the length penalty used for ranking.
            history_prefix: Ids of already-confirmed characters. They are
                placed in the history slots ahead of the generated tokens
                but are not included in the returned sequences.

        Returns:
            Hypotheses ordered best-first by length-penalized score, each as
            (generated ids, cumulative log-probability without penalty).

        Raises:
            ValueError: If ``beam_size`` < 1 or ``temperature`` <= 0.
        """
        if beam_size < 1:
            raise ValueError("beam_size must be at least 1")
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        prefix = list(history_prefix) if history_prefix else []

        # Start from one empty hypothesis with log-probability 0.
        beams = [([], 0.0)]

        for _ in range(max_length):
            candidates = []

            for seq, score in beams:
                if seq and seq[-1] == self.END_ID:
                    # Finished hypothesis: keep competing, do not extend.
                    candidates.append(
                        (
                            seq,
                            self._length_penalized(score, len(seq), length_penalty),
                            score,
                        )
                    )
                    continue

                history_array = self._history_array(prefix + seq)

                logits = self.run_decoder(
                    context_H, pinyin_P, history_array, context_mask, pinyin_mask
                )
                logits = logits / temperature

                probs = F.softmax(torch.from_numpy(logits[0]), dim=-1).numpy()

                # Take 2 * beam_size candidates per hypothesis for diversity.
                top_k = min(beam_size * 2, vocab_size)
                top_indices = np.argsort(probs)[-top_k:][::-1]

                for idx in top_indices:
                    new_seq = seq + [int(idx)]
                    # The epsilon guards against log(0).
                    new_score = score + float(np.log(probs[idx] + 1e-10))
                    candidates.append(
                        (
                            new_seq,
                            self._length_penalized(
                                new_score, len(new_seq), length_penalty
                            ),
                            new_score,
                        )
                    )

            # Prune on the penalized score but keep the raw score.
            candidates.sort(key=lambda item: item[1], reverse=True)
            beams = [(seq, raw) for seq, _, raw in candidates[:beam_size]]

            if all(seq and seq[-1] == self.END_ID for seq, _ in beams):
                break

        return beams

    def interactive_beam_search(self, beam_size: int = 3, max_length: int = 4):
        """Run the scripted demo scenario and print the resulting beams.

        Args:
            beam_size: Beam width passed to ``beam_search``.
            max_length: Maximum generated length passed to ``beam_search``.

        Returns:
            The list of (sequence, score) pairs from ``beam_search``.
        """
        print("\n" + "=" * 60)
        print("束搜索交互演示")
        print("=" * 60)

        print("\n📝 输入法场景模拟:")
        print(" 假设用户正在输入拼音 'nihao',已确认第一个字 '你'")
        print(" 上下文: 光标前文本 '今天天气很好',光标后文本 '我们去公园玩'")
        print(" 上下文提示: '张三,李四'(模型不掌握的专有名词)")
        print("-" * 60)

        # Hard-coded scenario; a real application would read these from the user.
        text_before = "今天天气很好"
        text_after = "我们去公园玩"
        pinyin = "hao"  # continue typing "hao"
        slot_chars = ["你"]  # characters already confirmed
        context_prompts = ["张三", "李四"]

        print("\n📋 输入参数:")
        print(f" 光标前文本: '{text_before}'")
        print(f" 光标后文本: '{text_after}'")
        print(f" 拼音: '{pinyin}'")
        print(f" 槽位历史: {slot_chars}")
        print(f" 上下文提示: {context_prompts}")

        # Simplified input preparation (random data).
        print("\n🔧 准备模型输入...")
        inputs = self.prepare_inputs(
            text_before=text_before,
            text_after=text_after,
            pinyin=pinyin,
            context_prompts=context_prompts,
        )

        print("🧠 运行上下文编码器...")
        context_outputs = self.run_context_encoder(inputs)
        context_H, pinyin_P, context_mask, pinyin_mask = context_outputs

        print("✅ 上下文编码完成")
        print(f" context_H形状: {context_H.shape}")
        print(f" pinyin_P形状: {pinyin_P.shape}")

        # Placeholder id for the confirmed character ('你' is assumed to be
        # 42); a real application would look it up via the query engine.
        slot_ids = [42] if slot_chars else []

        print(f"\n🔍 运行束搜索 (beam_size={beam_size}, max_length={max_length})...")
        beams = self.beam_search(
            context_H,
            pinyin_P,
            context_mask,
            pinyin_mask,
            beam_size=beam_size,
            max_length=max_length,
            vocab_size=10019,
            history_prefix=slot_ids,
        )

        print("\n🏆 束搜索结果:")
        print("-" * 50)
        for i, (seq, score) in enumerate(beams):
            # Ids are shown raw; mapping back to characters is not implemented.
            seq_str = " ".join(
                [f"ID:{token}" if token != 0 else "END" for token in seq]
            )
            print(f"{i + 1}. 序列: [{seq_str}]")
            print(f" 对数概率: {score:.4f}")
            print(f" 概率: {np.exp(score):.6f}")
            print()

        print("📝 说明:")
        print(" - 'END' 表示结束符 (ID: 0)")
        print(" - 实际应用中应将ID转换为汉字")
        print(" - 最高分数的序列作为最终预测")

        return beams


def main():
    """Parse CLI arguments, load the ONNX models and run the demo.

    Prints an error and returns early when either model file is missing.
    """
    parser = argparse.ArgumentParser(description="束搜索算法演示")
    parser.add_argument(
        "--context-encoder",
        type=str,
        default="./exported_models/context_encoder.onnx",
        help="上下文编码器ONNX路径(默认: ./exported_models/context_encoder.onnx)",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        default="./exported_models/decoder.onnx",
        help="解码器ONNX路径(默认: ./exported_models/decoder.onnx)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="推理设备(默认: cpu)",
    )
    parser.add_argument("--beam-size", type=int, default=3, help="束大小(默认: 3)")
    parser.add_argument(
        "--max-length", type=int, default=4, help="最大生成长度(默认: 4)"
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        default=True,
        help="交互模式(默认: True)",
    )
    # --interactive defaults to True and could never be switched off, so
    # provide an explicit opt-out that writes to the same destination.
    parser.add_argument(
        "--no-interactive",
        dest="interactive",
        action="store_false",
        help="禁用交互模式",
    )

    args = parser.parse_args()

    # Bail out early if either model file is missing.
    if not os.path.exists(args.context_encoder):
        print(f"❌ 上下文编码器文件不存在: {args.context_encoder}")
        print(" 请先运行 export_onnx.py 导出模型")
        return

    if not os.path.exists(args.decoder):
        print(f"❌ 解码器文件不存在: {args.decoder}")
        print(" 请先运行 export_onnx.py 导出模型")
        return

    print("🚀 束搜索算法演示")
    print("=" * 60)
    print(f"上下文编码器: {args.context_encoder}")
    print(f"解码器: {args.decoder}")
    print(f"设备: {args.device}")
    print(f"束大小: {args.beam_size}")
    print(f"最大长度: {args.max_length}")

    beam_searcher = ONNXBeamSearch(
        context_encoder_path=args.context_encoder,
        decoder_path=args.decoder,
        device=args.device,
    )

    if args.interactive:
        # Pass the CLI values through so they actually affect the search.
        beam_searcher.interactive_beam_search(
            beam_size=args.beam_size, max_length=args.max_length
        )

    print("\n" + "=" * 60)
    print("🎉 演示完成")
    print("\n💡 下一步:")
    print(" 1. 实现完整的输入预处理(tokenizer和拼音转换)")
    print(" 2. 集成查询引擎以将ID转换为汉字")
    print(" 3. 根据实际场景调整束搜索参数")
    print(" 4. 性能优化:批量处理、缓存等")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
107
docs/TRAINING.md
107
docs/TRAINING.md
|
|
@ -484,113 +484,6 @@ train-model evaluate \
|
|||
- 在评估数据集上计算准确率、困惑度等指标
|
||||
- 生成详细的性能报告
|
||||
|
||||
### 模型扩容两阶段训练
|
||||
|
||||
当需要增加模型容量(如增加专家数量、修改层结构等)时,可以使用 `expand-and-train` 命令进行两阶段训练:先冻结匹配层训练新增参数,然后全量微调。
|
||||
|
||||
#### 训练策略
|
||||
|
||||
1. **冻结阶段**:只训练形状不匹配的新增参数(如新增的专家、扩容的层等)
|
||||
2. **全量微调阶段**:当验证损失连续 `--frozen-patience` 次不下降时,自动解冻所有层进行全量训练
|
||||
|
||||
#### 基础用法
|
||||
|
||||
```bash
|
||||
train-model expand-and-train \
|
||||
--train-data-path "path/to/train/dataset" \
|
||||
--eval-data-path "path/to/eval/dataset" \
|
||||
--base-model-path "./pretrained/model.pt" \
|
||||
--new-model-spec "model:InputMethodEngine" \
|
||||
--num-experts 40 \
|
||||
--frozen-lr 2e-3 \
|
||||
--full-lr 5e-5 \
|
||||
--frozen-patience 8
|
||||
```
|
||||
|
||||
#### 完整参数示例
|
||||
|
||||
```bash
|
||||
train-model expand-and-train \
|
||||
--train-data-path "path/to/train/dataset" \
|
||||
--eval-data-path "path/to/eval/dataset" \
|
||||
--output-dir "./expansion_output" \
|
||||
--base-model-path "./pretrained/model.pt" \
|
||||
--new-model-spec "custom_model:ExpandedModel" \
|
||||
--vocab-size 10019 \
|
||||
--dim 512 \
|
||||
--num-experts 40 \
|
||||
--frozen-patience 10 \
|
||||
--frozen-lr 1e-3 \
|
||||
--full-lr 1e-4 \
|
||||
--frozen-scheduler cosine \
|
||||
--full-scheduler cosine \
|
||||
--batch-size 128 \
|
||||
--num-epochs 20 \
|
||||
--compile
|
||||
```
|
||||
|
||||
#### 参数详解
|
||||
|
||||
**模型扩容参数**
|
||||
- `--base-model-path`: 预训练基础模型检查点路径(必需)
|
||||
- `--new-model-spec`: 新模型规格,格式:`模块名:类名`,如 `model:InputMethodEngine`(必需)
|
||||
- 支持任意路径的模块导入,模块文件需包含自定义的模型类
|
||||
- 自定义模型类必须是 `InputMethodEngine` 的子类
|
||||
- 示例:`my_model:MyExpandedModel` 对应 `my_model.py` 中的 `MyExpandedModel` 类
|
||||
|
||||
**两阶段训练参数**
|
||||
- `--frozen-patience`: 冻结阶段验证损失连续不下降的评估次数,触发切换到全量微调(默认:10)
|
||||
- `--frozen-lr`: 冻结阶段学习率(默认:1e-3)
|
||||
- `--full-lr`: 全量微调阶段学习率(默认:1e-4)
|
||||
- `--frozen-scheduler`: 冻结阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||
- `--full-scheduler`: 全量微调阶段学习率调度器,可选 `cosine` 或 `plateau`(默认:`cosine`)
|
||||
|
||||
**其他参数**
|
||||
- 支持所有 `train` 子命令的通用参数(数据参数、模型参数、训练参数等)
|
||||
- 继承现有的训练基础设施:混合精度训练、TensorBoard日志、checkpoint保存等
|
||||
|
||||
#### 使用场景
|
||||
|
||||
1. **增加专家数量**(20→40)
|
||||
- 冻结效果:~70% 参数可冻结(已有专家权重、注意力层等)
|
||||
- 新增参数:新专家网络、gate层
|
||||
|
||||
2. **增加top_k值**(2→3)
|
||||
- 冻结效果:100% 参数可冻结(仅逻辑变化)
|
||||
- 新增参数:无
|
||||
|
||||
3. **修改专家内部结构**(如增加resblocks)
|
||||
- 冻结效果:~50% 参数可冻结(linear_in/output可冻结)
|
||||
- 新增参数:新增的resblocks层
|
||||
|
||||
4. **增加Transformer层数**(4→5)
|
||||
- 冻结效果:~80% 参数可冻结(前4层可冻结)
|
||||
- 新增参数:新增的第5层
|
||||
|
||||
#### 自定义模型类示例
|
||||
|
||||
```python
|
||||
# my_model.py
|
||||
from model.model import InputMethodEngine
|
||||
|
||||
class MyExpandedModel(InputMethodEngine):
|
||||
def __init__(self, num_experts=40, **kwargs):
|
||||
# 调用父类构造函数,覆盖num_experts参数
|
||||
super().__init__(num_experts=num_experts, **kwargs)
|
||||
# 可以在这里添加额外的层或修改现有层
|
||||
|
||||
# 使用命令
|
||||
# train-model expand-and-train --new-model-spec "my_model:MyExpandedModel" ...
|
||||
```
|
||||
|
||||
#### 注意事项
|
||||
|
||||
1. **模型类要求**:自定义模型类必须是 `InputMethodEngine` 的子类
|
||||
2. **冻结条件**:只有权重形状完全匹配的层才会被冻结
|
||||
3. **性能保持**:MoE层保持"计算所有专家+Top-K选择"方案,确保 `torch.compile` 下的最佳性能
|
||||
4. **阶段切换**:基于评估频率而非epoch,建议适当调高 `--eval-frequency`
|
||||
5. **模块导入**:支持任意路径的模块,通过Python标准导入机制加载
|
||||
|
||||
### 导出模型(开发中)
|
||||
|
||||
当前导出功能尚在开发中:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,168 @@
|
|||
# 破词训练设计文档
|
||||
|
||||
## 背景
|
||||
|
||||
输入法用户在实际使用中通常是**逐词输入**的,而非逐字输入。例如输入"那边的特别漂亮的女孩是我的表姐"时,用户可能分词为:
|
||||
|
||||
```
|
||||
那边 / 的 / 特别 / 漂亮 / 的 / 女孩 / 是 / 我 / 的 / 表姐
|
||||
```
|
||||
|
||||
但为了增强模型的泛化能力,需要模拟用户**从词中间断开**的情况。例如用户可能先只输入"piao"并确认"漂",再单独输入"liang"来选"亮"。
|
||||
|
||||
## 破词概念
|
||||
|
||||
### 术语定义
|
||||
|
||||
| 术语 | 说明 |
|
||||
|------|------|
|
||||
| 整词输入 | 用户输入完整词的拼音,如"piaoliang" |
|
||||
| 破词输入 | 用户只输入词的部分拼音,如"piao" |
|
||||
| 前缀 | 光标前的已确认文本 |
|
||||
| 拼音 | 当前待选字的拼音(可能不完整) |
|
||||
| 后缀 | 光标后的原文内容 |
|
||||
|
||||
### 场景示例
|
||||
|
||||
以词"漂亮"为例:
|
||||
|
||||
**整词模式(90%概率):**
|
||||
```
|
||||
光标前: 那边的特别
|
||||
拼音: piaoliang
|
||||
预测: 漂 → 亮
|
||||
```
|
||||
|
||||
**破词模式(10%概率):**
|
||||
```
|
||||
光标前: 那边的特别漂
|
||||
拼音: liang
|
||||
预测: 亮
|
||||
```
|
||||
|
||||
## 实现方案
|
||||
|
||||
### 分词策略
|
||||
|
||||
使用 jieba 分词器进行词语边界识别:
|
||||
|
||||
```python
|
||||
import jieba
|
||||
words = list(jieba.cut(text, HMM=False))
|
||||
# "那边的特别漂亮的女孩是我的表姐。"
|
||||
# → ['那边', '的', '特别', '漂亮', '的', '女孩', '是', '我', '的', '表姐', '。']
|
||||
```
|
||||
|
||||
### 两阶段样本生成
|
||||
|
||||
每个词生成样本时分为两个阶段:
|
||||
|
||||
#### Phase 1:前缀/整词阶段
|
||||
|
||||
- **整词(90%)**:`prefix_positions = 整个词的所有字符`
|
||||
- **破词前缀(10%)**:`prefix_positions = 词的前 break_pos 个字符`
|
||||
|
||||
```python
|
||||
if should_break:
|
||||
break_pos = random.randint(1, word_len_chars - 1) # 随机破开位置
|
||||
else:
|
||||
break_pos = word_len_chars # 整词
|
||||
```
|
||||
|
||||
#### Phase 2:破词续接阶段(仅当破词时)
|
||||
|
||||
当破词发生时,从断点位置开始继续采样:
|
||||
|
||||
```python
|
||||
if should_break and break_pos < word_len_chars:
|
||||
cont_start = char_positions[break_pos]
|
||||
# 从断点开始采样后续字符
|
||||
target_len = random_choice(1-8) # 采样长度
|
||||
cont_positions = [cont_start, ...] # 后续字符位置
|
||||
```
|
||||
|
||||
### 样本结构
|
||||
|
||||
每个字符生成一个训练样本,包含:
|
||||
|
||||
| 字段 | 说明 | 示例 |
|
||||
|------|------|------|
|
||||
| `part1` (prefix) | 光标前文本 | "那边的特别漂" |
|
||||
| `part2` (pinyin) | 当前字拼音 | "liang" |
|
||||
| `part3` (suffix) | 光标后文本 | "亮的女孩是我的表姐" |
|
||||
| `part4` | 专有词提示 | "漂亮\|特别" |
|
||||
| `label` | 目标汉字ID | 1234 |
|
||||
| `history_slot_ids` | 历史已确认字 | [0, 0, 0, 0, 0, 0, 0, 0] |
|
||||
|
||||
### 拼音增强策略
|
||||
|
||||
根据 `py_style_weight` 参数,拼音有以下三种形式:
|
||||
|
||||
| 形式 | 概率 | 示例 |
|
||||
|------|------|------|
|
||||
| 完整拼音 | 75% (9/12) | "piaoliang" |
|
||||
| 仅声母 | 16.7% (2/12) | "pl" (通过 to_initials) |
|
||||
| 仅首字母 | 8.3% (1/12) | "p" |
|
||||
|
||||
参数配置:`py_style_weight=(9, 2, 1)`
|
||||
|
||||
## 破词概率控制
|
||||
|
||||
### word_break_prob 参数
|
||||
|
||||
控制每个词从中间断开的概率,默认为 **10%**:
|
||||
|
||||
```python
|
||||
self.word_break_prob = 0.10 # 10%概率从词中间破开
|
||||
```
|
||||
|
||||
### 破词位置分布
|
||||
|
||||
对于长度为 N 的词,破开位置 `break_pos` 的分布:
|
||||
|
||||
```python
|
||||
break_pos = random.randint(1, N - 1)
|
||||
```
|
||||
|
||||
- 2字词:break_pos = 1(100%在第1字后破开)
|
||||
- 3字词:break_pos = 1 或 2(各50%)
|
||||
- 4字词:break_pos = 1, 2, 或 3(各33%)
|
||||
|
||||
## 数据分布预期
|
||||
|
||||
### 理想分布
|
||||
|
||||
| 类别 | 预期比例 |
|
||||
|------|----------|
|
||||
| 单字样本 | ~15% |
|
||||
| 2字词整词 | ~30% |
|
||||
| 3字词整词 | ~20% |
|
||||
| 破词样本 | ~10% |
|
||||
| 其他 | ~25% |
|
||||
|
||||
### 拼音不完整率
|
||||
|
||||
由于 `py_style_weight=(9, 2, 1)`:
|
||||
|
||||
- 声母(initials):~16.7%
|
||||
- 首字母:~8.3%
|
||||
- **总计不完整**:~25%
|
||||
|
||||
## 代码实现位置
|
||||
|
||||
主要实现文件:`src/model/dataset.py`
|
||||
|
||||
| 函数/类 | 行号 | 功能 |
|
||||
|---------|------|------|
|
||||
| `segment_text()` | ~30 | jieba分词 |
|
||||
| `build_word_boundaries()` | ~35 | 建立词边界映射 |
|
||||
| `PinyinInputDataset.__iter__()` | ~280 | 核心迭代逻辑 |
|
||||
| `get_mask_pinyin()` | ~215 | 拼音加强处理 |
|
||||
| `_add_word_samples()` | ~240 | 样本构建 |
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **破词仅针对多字词**:单字词(如"的"、"是")不会破词
|
||||
2. **破词保持语义完整**:破词后仍能根据上下文预测正确汉字
|
||||
3. **历史槽位模拟逐步确认**:同一词内已确认的字会填入 `history_slot_ids`
|
||||
4. **10% EOS标记**:词尾有10%概率追加ID=0表示句子结束
|
||||
31
eval.py
31
eval.py
|
|
@ -14,7 +14,6 @@ eval.py - 评估模型在给定文本上的表现
|
|||
|
||||
import argparse
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
|
@ -33,8 +32,6 @@ from src.model.model import InputMethodEngine
|
|||
from src.model.query import QueryEngine
|
||||
from src.model.dataset import text_to_pinyin_ids
|
||||
|
||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
||||
|
||||
class TextEvaluator:
|
||||
def __init__(
|
||||
|
|
@ -171,34 +168,20 @@ class TextEvaluator:
|
|||
|
||||
def generate_pinyin(self, text: str) -> List[str]:
|
||||
"""
|
||||
流式处理单条文本,转换为拼音列表。
|
||||
将文本转换为拼音列表。对整段文本调用 lazy_pinyin,
|
||||
利用 pypinyin 内部的分词能力处理多音字。
|
||||
参考 dataset.py 中的 generate_pinyin 方法。
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
text_len = len(text)
|
||||
result: List[str] = [""] * text_len
|
||||
pinyin_list = lazy_pinyin(text)
|
||||
|
||||
# 遍历所有连续汉字片段
|
||||
for match in _HANZI_RE.finditer(text):
|
||||
start_idx = match.start()
|
||||
hanzi_segment = match.group()
|
||||
# 健壮性兜底:若长度不匹配(极罕见),降级为逐字转换
|
||||
if len(pinyin_list) != len(text):
|
||||
pinyin_list = [lazy_pinyin(c)[0] for c in text]
|
||||
|
||||
pinyin_list = lazy_pinyin(hanzi_segment)
|
||||
|
||||
if len(pinyin_list) != len(hanzi_segment):
|
||||
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
|
||||
|
||||
for i, py in enumerate(pinyin_list):
|
||||
result[start_idx + i] = py
|
||||
|
||||
# 填充非汉字字符
|
||||
for i, char in enumerate(text):
|
||||
if not result[i]:
|
||||
result[i] = char
|
||||
|
||||
return result
|
||||
return pinyin_list
|
||||
|
||||
def get_mask_pinyin(
|
||||
self, text: str, pinyin_list: List[str]
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,86 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
输入法模型ONNX导出脚本
|
||||
|
||||
将模型导出为两个ONNX部分:
|
||||
1. context_encoder.onnx - 上下文编码器
|
||||
2. decoder.onnx - 解码器
|
||||
|
||||
使用方法:
|
||||
python export_onnx.py --checkpoint ./output/checkpoints/best_model.pt --output-dir ./exported_models
|
||||
|
||||
依赖安装:
|
||||
pip install onnx onnxruntime
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from src.model.onnx_export import check_onnx_available, run_full_export
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="输入法模型ONNX导出")
|
||||
parser.add_argument(
|
||||
"--checkpoint", "-c", type=str, required=True, help="模型checkpoint路径"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
"-o",
|
||||
type=str,
|
||||
default="./exported_models",
|
||||
help="输出目录(默认: ./exported_models)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
choices=["cpu", "cuda"],
|
||||
help="导出设备(默认: cpu)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-verification", action="store_true", help="跳过ONNX模型验证"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not check_onnx_available():
|
||||
sys.exit(1)
|
||||
|
||||
print(f"输出目录: {Path(args.output_dir).absolute()}")
|
||||
|
||||
context_encoder_path, decoder_path, config = run_full_export(
|
||||
checkpoint_path=args.checkpoint,
|
||||
output_dir=args.output_dir,
|
||||
device=args.device,
|
||||
skip_verification=args.skip_verification,
|
||||
)
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("ONNX导出完成!")
|
||||
print("=" * 60)
|
||||
print("生成的模型文件:")
|
||||
print(f" - {context_encoder_path}")
|
||||
print(f" - {decoder_path}")
|
||||
print(f" - {output_dir / 'example_inputs.npz'}")
|
||||
print(f" - {output_dir / 'example_inputs.pt'}")
|
||||
print(f" - {output_dir / 'inference_example.py'}")
|
||||
print()
|
||||
print("使用方法:")
|
||||
print(f" 1. 检查模型: python -m onnx.checker {context_encoder_path}")
|
||||
print(f" 2. 运行推理示例: cd {output_dir} && python inference_example.py")
|
||||
print(f" 3. 集成到您的应用: 参考 inference_example.py 中的 ONNXInference 类")
|
||||
print()
|
||||
print("注意:")
|
||||
print(" - 请确保安装了 onnxruntime: pip install onnxruntime")
|
||||
print(" - GPU推理需要 onnxruntime-gpu: pip install onnxruntime-gpu")
|
||||
print(" - MoE 层当前使用 'all' 模式(全量计算),稀疏化优化可后续迭代")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型推理示例
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行推理,
|
||||
包括束搜索(beam search)算法。
|
||||
"""
|
||||
import os
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX模型推理器"""
|
||||
|
||||
def __init__(self, context_encoder_path, decoder_path):
|
||||
self.context_encoder_session = ort.InferenceSession(
|
||||
context_encoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
self.decoder_session = ort.InferenceSession(
|
||||
decoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||
|
||||
print(f"上下文编码器输入: {self.context_input_names}")
|
||||
print(f"上下文编码器输出: {self.context_output_names}")
|
||||
print(f"解码器输入: {self.decoder_input_names}")
|
||||
print(f"解码器输出: {self.decoder_output_names}")
|
||||
|
||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||
raise NotImplementedError("请实现实际的输入预处理")
|
||||
|
||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||
inputs = {
|
||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||
}
|
||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||
return (
|
||||
torch.from_numpy(context_H),
|
||||
torch.from_numpy(pinyin_P),
|
||||
torch.from_numpy(context_mask),
|
||||
torch.from_numpy(pinyin_mask),
|
||||
)
|
||||
|
||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||
inputs = {
|
||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||
}
|
||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||
logits = outputs[0]
|
||||
return torch.from_numpy(logits)
|
||||
|
||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||
beam_size=5, max_length=10, vocab_size=10019):
|
||||
beams = [([], 0.0)]
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
for seq, score in beams:
|
||||
if len(seq) < 8:
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
history = seq[-8:]
|
||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_tensor,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
probs = F.softmax(logits[0], dim=-1)
|
||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||
for prob, idx in zip(top_probs, top_indices):
|
||||
new_seq = seq + [idx.item()]
|
||||
new_score = score + torch.log(prob).item()
|
||||
new_beams.append((new_seq, new_score))
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||
beams = new_beams[:beam_size]
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||
if all_ended:
|
||||
break
|
||||
return beams
|
||||
|
||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||
input_ids, pinyin_ids, attention_mask
|
||||
)
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_slot_ids,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""示例主函数"""
|
||||
print("ONNX模型推理示例")
|
||||
print("=" * 60)
|
||||
|
||||
context_encoder_path = "context_encoder.onnx"
|
||||
decoder_path = "decoder.onnx"
|
||||
|
||||
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
||||
print("错误: 找不到ONNX模型文件")
|
||||
print("请先运行 train-model export 导出模型")
|
||||
return
|
||||
|
||||
inference = ONNXInference(context_encoder_path, decoder_path)
|
||||
print("\u2705 ONNX推理器初始化完成")
|
||||
print("请参考此示例实现完整的输入法推理流程")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,230 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型推理示例
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行推理,
|
||||
包括束搜索(beam search)算法。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX模型推理器"""
|
||||
|
||||
def __init__(self, context_encoder_path, decoder_path):
|
||||
"""
|
||||
初始化ONNX推理器
|
||||
|
||||
Args:
|
||||
context_encoder_path: 上下文编码器ONNX模型路径
|
||||
decoder_path: 解码器ONNX模型路径
|
||||
"""
|
||||
# 创建ONNX Runtime会话
|
||||
self.context_encoder_session = ort.InferenceSession(
|
||||
context_encoder_path,
|
||||
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
|
||||
)
|
||||
self.decoder_session = ort.InferenceSession(
|
||||
decoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
# 获取输入输出名称
|
||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||
|
||||
print(f"上下文编码器输入: {self.context_input_names}")
|
||||
print(f"上下文编码器输出: {self.context_output_names}")
|
||||
print(f"解码器输入: {self.decoder_input_names}")
|
||||
print(f"解码器输出: {self.decoder_output_names}")
|
||||
|
||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||
"""
|
||||
准备模型输入(与原始推理脚本保持一致)
|
||||
|
||||
注意: 这里需要实现文本到token的转换
|
||||
为了简化示例,假设已经实现了相关函数
|
||||
"""
|
||||
# 这里应该调用实际的预处理函数
|
||||
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
|
||||
raise NotImplementedError("请实现实际的输入预处理")
|
||||
|
||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||
"""
|
||||
运行上下文编码器
|
||||
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
pinyin_ids: [batch, 24]
|
||||
attention_mask: [batch, seq_len]
|
||||
|
||||
Returns:
|
||||
context_H, pinyin_P, context_mask, pinyin_mask
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||
|
||||
return (
|
||||
torch.from_numpy(context_H),
|
||||
torch.from_numpy(pinyin_P),
|
||||
torch.from_numpy(context_mask),
|
||||
torch.from_numpy(pinyin_mask),
|
||||
)
|
||||
|
||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||
"""
|
||||
运行解码器
|
||||
|
||||
Args:
|
||||
context_H: [batch, seq_len, 512]
|
||||
pinyin_P: [batch, 24, 512]
|
||||
history_slot_ids: [batch, 8]
|
||||
context_mask: [batch, seq_len]
|
||||
pinyin_mask: [batch, 24]
|
||||
|
||||
Returns:
|
||||
logits: [batch, vocab_size]
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
logits = outputs[0]
|
||||
return torch.from_numpy(logits)
|
||||
|
||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||
beam_size=5, max_length=10, vocab_size=10019):
|
||||
"""
|
||||
束搜索算法示例
|
||||
|
||||
Args:
|
||||
context_H: 上下文编码
|
||||
pinyin_P: 拼音编码
|
||||
context_mask: 上下文掩码
|
||||
pinyin_mask: 拼音掩码
|
||||
beam_size: 束大小
|
||||
max_length: 最大生成长度
|
||||
vocab_size: 词汇表大小
|
||||
|
||||
Returns:
|
||||
最佳序列列表
|
||||
"""
|
||||
# 初始束:空序列,分数为0
|
||||
beams = [([], 0.0)] # (序列, 对数概率)
|
||||
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
|
||||
for seq, score in beams:
|
||||
# 构建history_slot_ids:已确认的字符ID
|
||||
if len(seq) < 8:
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
history = seq[-8:] # 只保留最近8个
|
||||
|
||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||
|
||||
# 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_tensor,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
# 获取概率
|
||||
probs = F.softmax(logits[0], dim=-1)
|
||||
|
||||
# 获取top-k候选
|
||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||
|
||||
# 扩展束
|
||||
for prob, idx in zip(top_probs, top_indices):
|
||||
new_seq = seq + [idx.item()]
|
||||
new_score = score + torch.log(prob).item()
|
||||
new_beams.append((new_seq, new_score))
|
||||
|
||||
# 剪枝:保留beam_size个最佳候选
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||
beams = new_beams[:beam_size]
|
||||
|
||||
# 检查是否所有序列都已结束(以结束符0结尾)
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||
if all_ended:
|
||||
break
|
||||
|
||||
return beams
|
||||
|
||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||
"""
|
||||
单步预测
|
||||
|
||||
Args:
|
||||
input_ids: 输入token IDs
|
||||
pinyin_ids: 拼音IDs
|
||||
attention_mask: 注意力掩码
|
||||
history_slot_ids: 历史槽位IDs
|
||||
|
||||
Returns:
|
||||
预测logits
|
||||
"""
|
||||
# 1. 运行上下文编码器
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||
input_ids, pinyin_ids, attention_mask
|
||||
)
|
||||
|
||||
# 2. 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_slot_ids,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""示例主函数"""
|
||||
print("ONNX模型推理示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化推理器
|
||||
context_encoder_path = "context_encoder.onnx"
|
||||
decoder_path = "decoder.onnx"
|
||||
|
||||
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
||||
print("错误: 找不到ONNX模型文件")
|
||||
print("请先运行export_onnx.py导出模型")
|
||||
return
|
||||
|
||||
inference = ONNXInference(context_encoder_path, decoder_path)
|
||||
|
||||
print("✅ ONNX推理器初始化完成")
|
||||
print("请参考此示例实现完整的输入法推理流程")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,230 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型推理示例
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行推理,
|
||||
包括束搜索(beam search)算法。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX模型推理器"""
|
||||
|
||||
def __init__(self, context_encoder_path, decoder_path):
|
||||
"""
|
||||
初始化ONNX推理器
|
||||
|
||||
Args:
|
||||
context_encoder_path: 上下文编码器ONNX模型路径
|
||||
decoder_path: 解码器ONNX模型路径
|
||||
"""
|
||||
# 创建ONNX Runtime会话
|
||||
self.context_encoder_session = ort.InferenceSession(
|
||||
context_encoder_path,
|
||||
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
|
||||
)
|
||||
self.decoder_session = ort.InferenceSession(
|
||||
decoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
# 获取输入输出名称
|
||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||
|
||||
print(f"上下文编码器输入: {self.context_input_names}")
|
||||
print(f"上下文编码器输出: {self.context_output_names}")
|
||||
print(f"解码器输入: {self.decoder_input_names}")
|
||||
print(f"解码器输出: {self.decoder_output_names}")
|
||||
|
||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||
"""
|
||||
准备模型输入(与原始推理脚本保持一致)
|
||||
|
||||
注意: 这里需要实现文本到token的转换
|
||||
为了简化示例,假设已经实现了相关函数
|
||||
"""
|
||||
# 这里应该调用实际的预处理函数
|
||||
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
|
||||
raise NotImplementedError("请实现实际的输入预处理")
|
||||
|
||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||
"""
|
||||
运行上下文编码器
|
||||
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
pinyin_ids: [batch, 24]
|
||||
attention_mask: [batch, seq_len]
|
||||
|
||||
Returns:
|
||||
context_H, pinyin_P, context_mask, pinyin_mask
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||
|
||||
return (
|
||||
torch.from_numpy(context_H),
|
||||
torch.from_numpy(pinyin_P),
|
||||
torch.from_numpy(context_mask),
|
||||
torch.from_numpy(pinyin_mask),
|
||||
)
|
||||
|
||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||
"""
|
||||
运行解码器
|
||||
|
||||
Args:
|
||||
context_H: [batch, seq_len, 512]
|
||||
pinyin_P: [batch, 24, 512]
|
||||
history_slot_ids: [batch, 8]
|
||||
context_mask: [batch, seq_len]
|
||||
pinyin_mask: [batch, 24]
|
||||
|
||||
Returns:
|
||||
logits: [batch, vocab_size]
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
logits = outputs[0]
|
||||
return torch.from_numpy(logits)
|
||||
|
||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||
beam_size=5, max_length=10, vocab_size=10019):
|
||||
"""
|
||||
束搜索算法示例
|
||||
|
||||
Args:
|
||||
context_H: 上下文编码
|
||||
pinyin_P: 拼音编码
|
||||
context_mask: 上下文掩码
|
||||
pinyin_mask: 拼音掩码
|
||||
beam_size: 束大小
|
||||
max_length: 最大生成长度
|
||||
vocab_size: 词汇表大小
|
||||
|
||||
Returns:
|
||||
最佳序列列表
|
||||
"""
|
||||
# 初始束:空序列,分数为0
|
||||
beams = [([], 0.0)] # (序列, 对数概率)
|
||||
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
|
||||
for seq, score in beams:
|
||||
# 构建history_slot_ids:已确认的字符ID
|
||||
if len(seq) < 8:
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
history = seq[-8:] # 只保留最近8个
|
||||
|
||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||
|
||||
# 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_tensor,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
# 获取概率
|
||||
probs = F.softmax(logits[0], dim=-1)
|
||||
|
||||
# 获取top-k候选
|
||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||
|
||||
# 扩展束
|
||||
for prob, idx in zip(top_probs, top_indices):
|
||||
new_seq = seq + [idx.item()]
|
||||
new_score = score + torch.log(prob).item()
|
||||
new_beams.append((new_seq, new_score))
|
||||
|
||||
# 剪枝:保留beam_size个最佳候选
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||
beams = new_beams[:beam_size]
|
||||
|
||||
# 检查是否所有序列都已结束(以结束符0结尾)
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||
if all_ended:
|
||||
break
|
||||
|
||||
return beams
|
||||
|
||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||
"""
|
||||
单步预测
|
||||
|
||||
Args:
|
||||
input_ids: 输入token IDs
|
||||
pinyin_ids: 拼音IDs
|
||||
attention_mask: 注意力掩码
|
||||
history_slot_ids: 历史槽位IDs
|
||||
|
||||
Returns:
|
||||
预测logits
|
||||
"""
|
||||
# 1. 运行上下文编码器
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||
input_ids, pinyin_ids, attention_mask
|
||||
)
|
||||
|
||||
# 2. 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_slot_ids,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""示例主函数"""
|
||||
print("ONNX模型推理示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化推理器
|
||||
context_encoder_path = "context_encoder.onnx"
|
||||
decoder_path = "decoder.onnx"
|
||||
|
||||
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
||||
print("错误: 找不到ONNX模型文件")
|
||||
print("请先运行export_onnx.py导出模型")
|
||||
return
|
||||
|
||||
inference = ONNXInference(context_encoder_path, decoder_path)
|
||||
|
||||
print("✅ ONNX推理器初始化完成")
|
||||
print("请参考此示例实现完整的输入法推理流程")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,230 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型推理示例
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行推理,
|
||||
包括束搜索(beam search)算法。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX模型推理器"""
|
||||
|
||||
def __init__(self, context_encoder_path, decoder_path):
|
||||
"""
|
||||
初始化ONNX推理器
|
||||
|
||||
Args:
|
||||
context_encoder_path: 上下文编码器ONNX模型路径
|
||||
decoder_path: 解码器ONNX模型路径
|
||||
"""
|
||||
# 创建ONNX Runtime会话
|
||||
self.context_encoder_session = ort.InferenceSession(
|
||||
context_encoder_path,
|
||||
providers=['CPUExecutionProvider'] # 或 'CUDAExecutionProvider'
|
||||
)
|
||||
self.decoder_session = ort.InferenceSession(
|
||||
decoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
# 获取输入输出名称
|
||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||
|
||||
print(f"上下文编码器输入: {self.context_input_names}")
|
||||
print(f"上下文编码器输出: {self.context_output_names}")
|
||||
print(f"解码器输入: {self.decoder_input_names}")
|
||||
print(f"解码器输出: {self.decoder_output_names}")
|
||||
|
||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||
"""
|
||||
准备模型输入(与原始推理脚本保持一致)
|
||||
|
||||
注意: 这里需要实现文本到token的转换
|
||||
为了简化示例,假设已经实现了相关函数
|
||||
"""
|
||||
# 这里应该调用实际的预处理函数
|
||||
# 返回: input_ids, pinyin_ids, attention_mask, history_slot_ids
|
||||
raise NotImplementedError("请实现实际的输入预处理")
|
||||
|
||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||
"""
|
||||
运行上下文编码器
|
||||
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
pinyin_ids: [batch, 24]
|
||||
attention_mask: [batch, seq_len]
|
||||
|
||||
Returns:
|
||||
context_H, pinyin_P, context_mask, pinyin_mask
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||
|
||||
return (
|
||||
torch.from_numpy(context_H),
|
||||
torch.from_numpy(pinyin_P),
|
||||
torch.from_numpy(context_mask),
|
||||
torch.from_numpy(pinyin_mask),
|
||||
)
|
||||
|
||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||
"""
|
||||
运行解码器
|
||||
|
||||
Args:
|
||||
context_H: [batch, seq_len, 512]
|
||||
pinyin_P: [batch, 24, 512]
|
||||
history_slot_ids: [batch, 8]
|
||||
context_mask: [batch, seq_len]
|
||||
pinyin_mask: [batch, 24]
|
||||
|
||||
Returns:
|
||||
logits: [batch, vocab_size]
|
||||
"""
|
||||
# 准备输入
|
||||
inputs = {
|
||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||
}
|
||||
|
||||
# 运行推理
|
||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||
|
||||
# 解包输出
|
||||
logits = outputs[0]
|
||||
return torch.from_numpy(logits)
|
||||
|
||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||
beam_size=5, max_length=10, vocab_size=10019):
|
||||
"""
|
||||
束搜索算法示例
|
||||
|
||||
Args:
|
||||
context_H: 上下文编码
|
||||
pinyin_P: 拼音编码
|
||||
context_mask: 上下文掩码
|
||||
pinyin_mask: 拼音掩码
|
||||
beam_size: 束大小
|
||||
max_length: 最大生成长度
|
||||
vocab_size: 词汇表大小
|
||||
|
||||
Returns:
|
||||
最佳序列列表
|
||||
"""
|
||||
# 初始束:空序列,分数为0
|
||||
beams = [([], 0.0)] # (序列, 对数概率)
|
||||
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
|
||||
for seq, score in beams:
|
||||
# 构建history_slot_ids:已确认的字符ID
|
||||
if len(seq) < 8:
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
history = seq[-8:] # 只保留最近8个
|
||||
|
||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||
|
||||
# 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_tensor,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
# 获取概率
|
||||
probs = F.softmax(logits[0], dim=-1)
|
||||
|
||||
# 获取top-k候选
|
||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||
|
||||
# 扩展束
|
||||
for prob, idx in zip(top_probs, top_indices):
|
||||
new_seq = seq + [idx.item()]
|
||||
new_score = score + torch.log(prob).item()
|
||||
new_beams.append((new_seq, new_score))
|
||||
|
||||
# 剪枝:保留beam_size个最佳候选
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||
beams = new_beams[:beam_size]
|
||||
|
||||
# 检查是否所有序列都已结束(以结束符0结尾)
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||
if all_ended:
|
||||
break
|
||||
|
||||
return beams
|
||||
|
||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||
"""
|
||||
单步预测
|
||||
|
||||
Args:
|
||||
input_ids: 输入token IDs
|
||||
pinyin_ids: 拼音IDs
|
||||
attention_mask: 注意力掩码
|
||||
history_slot_ids: 历史槽位IDs
|
||||
|
||||
Returns:
|
||||
预测logits
|
||||
"""
|
||||
# 1. 运行上下文编码器
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||
input_ids, pinyin_ids, attention_mask
|
||||
)
|
||||
|
||||
# 2. 运行解码器
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_slot_ids,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""示例主函数"""
|
||||
print("ONNX模型推理示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化推理器
|
||||
context_encoder_path = "context_encoder.onnx"
|
||||
decoder_path = "decoder.onnx"
|
||||
|
||||
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
||||
print("错误: 找不到ONNX模型文件")
|
||||
print("请先运行export_onnx.py导出模型")
|
||||
return
|
||||
|
||||
inference = ONNXInference(context_encoder_path, decoder_path)
|
||||
|
||||
print("✅ ONNX推理器初始化完成")
|
||||
print("请参考此示例实现完整的输入法推理流程")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
Frequency Analysis Results
|
||||
==================================================
|
||||
Min frequency: 1
|
||||
Max frequency: 494748360
|
||||
Mean frequency: 560007.90
|
||||
Standard deviation: 5730144.34
|
||||
10th percentile: 3
|
||||
50th percentile: 93
|
||||
90th percentile: 331538
|
||||
IDs in range 5000-5500 min: 5594
|
||||
IDs in range 5000-5500 max: 9569
|
||||
File diff suppressed because it is too large
Load Diff
137
inference.py
137
inference.py
|
|
@ -22,6 +22,8 @@
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
|
@ -60,6 +62,98 @@ class InputMethodInference:
|
|||
|
||||
print(f"✅ 推理器初始化完成 (设备: {self.device})")
|
||||
|
||||
# 尝试启用readline以获得更好的行编辑功能
|
||||
try:
|
||||
import readline
|
||||
|
||||
# 设置readline的补全分隔符(该调用本身并不设置UTF-8编码)
|
||||
readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
|
||||
print("📝 readline已启用,支持更好的行编辑功能")
|
||||
except ImportError:
|
||||
print("📝 readline不可用,使用标准输入")
|
||||
|
||||
def _safe_input(self, prompt: str, default: str = "") -> str:
|
||||
"""
|
||||
安全的输入函数,尝试正确处理UTF-8字符和退格键
|
||||
|
||||
Args:
|
||||
prompt: 提示文本
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
用户输入的字符串
|
||||
"""
|
||||
try:
|
||||
# 显示提示和默认值
|
||||
if default:
|
||||
full_prompt = f"{prompt} [{default}]: "
|
||||
else:
|
||||
full_prompt = f"{prompt}: "
|
||||
|
||||
# 使用标准input
|
||||
result = input(full_prompt)
|
||||
|
||||
# 如果用户直接回车且存在默认值,则返回默认值
|
||||
if not result and default:
|
||||
return default
|
||||
|
||||
return result.strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
# 用户按Ctrl+D或Ctrl+C
|
||||
print()
|
||||
return ""
|
||||
except Exception as e:
|
||||
# 其他错误
|
||||
print(f"\n⚠️ 输入错误: {e}")
|
||||
return ""
|
||||
|
||||
def _clean_pinyin_input(self, pinyin: str) -> str:
|
||||
"""
|
||||
清理拼音输入字符串,处理退格键等特殊字符
|
||||
|
||||
拼音只允许: a-z, `, ', -
|
||||
中文字符和其他字符会被忽略
|
||||
|
||||
Args:
|
||||
pinyin: 原始拼音输入字符串
|
||||
|
||||
Returns:
|
||||
清理后的拼音字符串
|
||||
"""
|
||||
if not pinyin:
|
||||
return ""
|
||||
|
||||
result = []
|
||||
for c in pinyin:
|
||||
# 检查是否为合法拼音字符 (a-z, `, ', -)
|
||||
# 注意: 中文字符的isalpha()也返回True,所以需要额外检查
|
||||
is_valid_pinyin_char = (
|
||||
("a" <= c <= "z")
|
||||
or ("A" <= c <= "Z") # 允许大写字母,转换为小写
|
||||
or c in ["`", "'", "-"]
|
||||
)
|
||||
|
||||
if is_valid_pinyin_char:
|
||||
# 合法拼音字符,转换为小写
|
||||
result.append(c.lower())
|
||||
elif c == " ":
|
||||
# 空格忽略
|
||||
continue
|
||||
elif c == "\b" or c == "\x7f" or c == "\x08":
|
||||
# 退格键、删除键:删除前一个字符
|
||||
if result:
|
||||
result.pop()
|
||||
elif c == "\x1b":
|
||||
# ESC键:清空所有输入
|
||||
result.clear()
|
||||
else:
|
||||
# 其他字符(包括中文字符)忽略
|
||||
# 注意:这里不添加到result,所以退格键无法删除它们
|
||||
# 但用户可能在拼音输入中误输入中文字符,应该忽略
|
||||
pass
|
||||
|
||||
return "".join(result)
|
||||
|
||||
def load_model(self):
|
||||
"""加载训练好的模型"""
|
||||
# 创建模型实例(不编译)
|
||||
|
|
@ -184,7 +278,7 @@ class InputMethodInference:
|
|||
"""
|
||||
|
||||
# 1. 构建tokenizer输入
|
||||
# 根据dataset.py,格式为: "part4|part1" 和 part3
|
||||
# 根据test.py和dataset.py,格式为: "part4|part1" 和 part3
|
||||
# part4: 上下文提示(专有词汇、姓名等,模型不掌握)
|
||||
# part1: text_before
|
||||
# part3: text_after
|
||||
|
|
@ -192,13 +286,14 @@ class InputMethodInference:
|
|||
# 处理上下文提示
|
||||
context_text = "|".join(context_prompts) if context_prompts else ""
|
||||
|
||||
# 构建输入文本
|
||||
# 构建输入文本 - 与test.py保持一致
|
||||
# test.py: f"{part4}|{part1}" 作为第一个参数,part3作为第二个参数
|
||||
if context_text:
|
||||
input_text = f"{context_text}|{text_before}"
|
||||
else:
|
||||
input_text = text_before
|
||||
|
||||
# 2. Tokenize
|
||||
# 2. Tokenize - 与test.py保持一致
|
||||
encoded = self.tokenizer(
|
||||
input_text,
|
||||
text_after,
|
||||
|
|
@ -209,8 +304,11 @@ class InputMethodInference:
|
|||
return_token_type_ids=True,
|
||||
)
|
||||
|
||||
# 3. 处理拼音输入
|
||||
pinyin_ids = text_to_pinyin_ids(pinyin)
|
||||
# 3. 处理拼音输入 - 与test.py保持一致
|
||||
# 首先清理拼音字符串,处理退格键等特殊字符
|
||||
cleaned_pinyin = self._clean_pinyin_input(pinyin)
|
||||
|
||||
pinyin_ids = text_to_pinyin_ids(cleaned_pinyin)
|
||||
if len(pinyin_ids) < 24:
|
||||
pinyin_ids.extend([0] * (24 - len(pinyin_ids)))
|
||||
else:
|
||||
|
|
@ -396,7 +494,16 @@ class InputMethodInference:
|
|||
print("\n" + "=" * 60)
|
||||
print("输入法模型推理 - 交互模式")
|
||||
print("=" * 60)
|
||||
print("说明:")
|
||||
|
||||
# 检查终端编码
|
||||
encoding = sys.stdout.encoding or "unknown"
|
||||
print(f"终端编码: {encoding}")
|
||||
if encoding.lower() not in ["utf-8", "utf8"]:
|
||||
print("⚠️ 警告: 终端编码不是UTF-8,中文输入可能有问题")
|
||||
print(" 建议设置: export LANG=en_US.UTF-8")
|
||||
print(" 或设置: export LC_ALL=en_US.UTF-8")
|
||||
|
||||
print("\n说明:")
|
||||
print(" - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)")
|
||||
print(" - 光标前文本: 光标前的连续文本")
|
||||
print(" - 光标后文本: 光标后的连续文本")
|
||||
|
|
@ -411,7 +518,7 @@ class InputMethodInference:
|
|||
print("第1步: 上下文提示(模型不掌握的专有词汇、姓名等)")
|
||||
print("格式: 用逗号分隔多个词汇,可为空")
|
||||
print("示例: 张三,李四,北京大学")
|
||||
context_input = input("请输入上下文提示(直接回车跳过): ").strip()
|
||||
context_input = self._safe_input("请输入上下文提示(直接回车跳过)")
|
||||
|
||||
if context_input.lower() in ["quit", "exit", "q"]:
|
||||
print("退出交互模式")
|
||||
|
|
@ -434,7 +541,7 @@ class InputMethodInference:
|
|||
print("第2步: 光标前文本")
|
||||
print("说明: 光标前的连续文本内容")
|
||||
print("示例: 今天天气很好")
|
||||
text_before = input("请输入光标前文本: ").strip()
|
||||
text_before = self._safe_input("请输入光标前文本")
|
||||
|
||||
if text_before.lower() in ["quit", "exit", "q"]:
|
||||
print("退出交互模式")
|
||||
|
|
@ -446,7 +553,7 @@ class InputMethodInference:
|
|||
print("第3步: 光标后文本")
|
||||
print("说明: 光标后的连续文本内容")
|
||||
print("示例: 我们去公园玩")
|
||||
text_after = input("请输入光标后文本: ").strip()
|
||||
text_after = self._safe_input("请输入光标后文本")
|
||||
|
||||
if text_after.lower() in ["quit", "exit", "q"]:
|
||||
print("退出交互模式")
|
||||
|
|
@ -458,7 +565,7 @@ class InputMethodInference:
|
|||
print("第4步: 拼音输入")
|
||||
print("说明: 当前正在输入的拼音")
|
||||
print("示例: tian, shang, hao")
|
||||
pinyin = input("请输入拼音: ").strip()
|
||||
pinyin = self._safe_input("请输入拼音")
|
||||
|
||||
if pinyin.lower() in ["quit", "exit", "q"]:
|
||||
print("退出交互模式")
|
||||
|
|
@ -471,7 +578,7 @@ class InputMethodInference:
|
|||
print("说明: 用户已确认的输入历史,用逗号分隔")
|
||||
print("示例: 上 (表示输入'shanghai'已确认'上')")
|
||||
print(" 今天,天气 (表示已确认两个词)")
|
||||
slot_input = input("请输入槽位历史(直接回车表示无): ").strip()
|
||||
slot_input = self._safe_input("请输入槽位历史(直接回车表示无)")
|
||||
|
||||
if slot_input.lower() in ["quit", "exit", "q"]:
|
||||
print("退出交互模式")
|
||||
|
|
@ -524,7 +631,9 @@ class InputMethodInference:
|
|||
|
||||
# 询问是否继续
|
||||
print("\n" + "-" * 40)
|
||||
continue_input = input("是否继续推理?(y/n): ").strip().lower()
|
||||
continue_input = (
|
||||
self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
|
||||
)
|
||||
if continue_input not in ["y", "yes", ""]:
|
||||
print("退出交互模式")
|
||||
break
|
||||
|
|
@ -539,7 +648,9 @@ class InputMethodInference:
|
|||
traceback.print_exc()
|
||||
|
||||
# 询问是否继续
|
||||
continue_input = input("\n是否继续?(y/n): ").strip().lower()
|
||||
continue_input = (
|
||||
self._safe_input("\n是否继续?(y/n)", "y").strip().lower()
|
||||
)
|
||||
if continue_input not in ["y", "yes", ""]:
|
||||
print("退出交互模式")
|
||||
break
|
||||
|
|
|
|||
|
|
@ -0,0 +1,763 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX输入法模型推理脚本
|
||||
|
||||
使用ONNX Runtime进行推理,测量每个阶段的执行时长
|
||||
|
||||
使用方法:
|
||||
python onnx_inference.py --context-encoder exported_models/context_encoder.onnx --decoder exported_models/decoder.onnx
|
||||
|
||||
交互模式: 分步询问输入
|
||||
1. 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)
|
||||
2. 光标前文本: 光标前的连续文本
|
||||
3. 光标后文本: 光标后的连续文本
|
||||
4. 拼音: 当前输入的拼音
|
||||
5. 槽位历史: 用户已确认的输入历史(如输入shanghai已确认"上")
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from modelscope import AutoTokenizer
|
||||
|
||||
from src.model.dataset import text_to_pinyin_ids
|
||||
from src.model.query import QueryEngine
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX输入法模型推理器"""
|
||||
|
||||
def __init__(
    self,
    context_encoder_path: str,
    decoder_path: str,
    vocab_size: int = 10019,
    device: str = "cpu",
    use_beam_search: bool = False,
    beam_size: int = 5,
):
    """
    Load both ONNX sessions, the tokenizer and the query engine.

    Each loading stage is timed and its duration (in milliseconds) is
    stored on the instance and printed.

    Args:
        context_encoder_path: Path of the context-encoder ONNX file.
        decoder_path: Path of the decoder ONNX file.
        vocab_size: Stored on the instance as ``self.vocab_size``.
        device: "cuda" requests the CUDA execution provider in the loaders;
            any other value selects CPU only.
        use_beam_search: Default decoding mode used by the interactive loop.
        beam_size: Beam width used when beam search is enabled.
    """

    def ms_since(started: float) -> float:
        return (time.perf_counter() - started) * 1000

    self.vocab_size, self.device = vocab_size, device
    self.use_beam_search, self.beam_size = use_beam_search, beam_size

    # Context encoder
    print(f"正在加载上下文编码器: {context_encoder_path}")
    started = time.perf_counter()
    self.load_context_encoder(context_encoder_path)
    self.context_encoder_load_time = ms_since(started)
    print(f" ✅ 上下文编码器加载完成 ({self.context_encoder_load_time:.2f}ms)")

    # Decoder
    print(f"正在加载解码器: {decoder_path}")
    started = time.perf_counter()
    self.load_decoder(decoder_path)
    self.decoder_load_time = ms_since(started)
    print(f" ✅ 解码器加载完成 ({self.decoder_load_time:.2f}ms)")

    # Tokenizer
    print("正在加载tokenizer...")
    started = time.perf_counter()
    self.load_tokenizer()
    self.tokenizer_load_time = ms_since(started)
    print(f" ✅ Tokenizer加载完成 ({self.tokenizer_load_time:.2f}ms)")

    # Query engine
    print("正在加载查询引擎...")
    started = time.perf_counter()
    self.load_query_engine()
    self.query_engine_load_time = ms_since(started)
    print(f" ✅ 查询引擎加载完成 ({self.query_engine_load_time:.2f}ms)")

    stage_times = (
        self.context_encoder_load_time,
        self.decoder_load_time,
        self.tokenizer_load_time,
        self.query_engine_load_time,
    )
    print(f"\n✅ 推理器初始化完成 (设备: {device})")
    print(f" 总加载时间: {sum(stage_times):.2f}ms")

    # Importing readline is optional; it is simply skipped on platforms
    # where the module is unavailable.
    try:
        import readline
    except ImportError:
        pass
    else:
        readline.set_completer_delims(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
|
||||
|
||||
def load_context_encoder(self, model_path: str):
    """
    Create the ONNX Runtime session for the context encoder.

    Stores the session in ``self.context_encoder_session`` and the names
    of its graph inputs and outputs in ``self.context_input_names`` and
    ``self.context_output_names``.

    Args:
        model_path: Path of the context-encoder ONNX file.
    """
    if self.device == "cuda":
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]

    session = ort.InferenceSession(model_path, providers=providers)
    self.context_encoder_session = session

    self.context_input_names = [node.name for node in session.get_inputs()]
    self.context_output_names = [node.name for node in session.get_outputs()]
|
||||
|
||||
def load_decoder(self, model_path: str):
    """
    Create the ONNX Runtime session for the decoder.

    Stores the session in ``self.decoder_session`` and the names of its
    graph inputs and outputs in ``self.decoder_input_names`` and
    ``self.decoder_output_names``.

    Args:
        model_path: Path of the decoder ONNX file.
    """
    providers = ["CPUExecutionProvider"]
    if self.device == "cuda":
        providers = ["CUDAExecutionProvider"] + providers

    session = ort.InferenceSession(model_path, providers=providers)
    self.decoder_session = session

    self.decoder_input_names = [node.name for node in session.get_inputs()]
    self.decoder_output_names = [node.name for node in session.get_outputs()]
|
||||
|
||||
def load_tokenizer(self):
    """
    Load the tokenizer into ``self.tokenizer``.

    The tokenizer bundled under ``src/model/assets/tokenizer`` (relative
    to this file) is tried first; if that fails for any reason a notice is
    printed and "bert-base-chinese" is loaded instead.
    """
    try:
        bundled_dir = Path(__file__).parent.joinpath(
            "src", "model", "assets", "tokenizer"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(str(bundled_dir))
    except Exception:
        print(" ⚠️ 无法加载自定义tokenizer,使用bert-base-chinese")
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
||||
|
||||
def load_query_engine(self):
|
||||
"""加载查询引擎"""
|
||||
try:
|
||||
self.query_engine = QueryEngine()
|
||||
stats_path = (
|
||||
Path(__file__).parent
|
||||
/ "src"
|
||||
/ "model"
|
||||
/ "assets"
|
||||
/ "pinyin_char_statistics.json"
|
||||
)
|
||||
if stats_path.exists():
|
||||
self.query_engine.load(stats_path)
|
||||
except Exception:
|
||||
self.query_engine = None
|
||||
|
||||
def char_to_id(self, char: str, pinyin: Optional[str] = None) -> int:
|
||||
"""将汉字转换为ID"""
|
||||
if char == "//":
|
||||
return 0
|
||||
|
||||
if self.query_engine is None:
|
||||
return ord(char) if len(char) == 1 else 0
|
||||
|
||||
try:
|
||||
if pinyin is not None:
|
||||
info = self.query_engine.get_char_info_by_char_pinyin(char, pinyin)
|
||||
if info:
|
||||
return info.id
|
||||
|
||||
results = self.query_engine.query_by_char(char, limit=1)
|
||||
if results:
|
||||
return results[0][0]
|
||||
return 0
|
||||
except:
|
||||
return 0
|
||||
|
||||
def id_to_char(self, id: int) -> str:
|
||||
"""将ID转换为汉字"""
|
||||
if id == 0:
|
||||
return "//"
|
||||
|
||||
if self.query_engine is None:
|
||||
return chr(id) if id < 0x110000 else f"[ID:{id}]"
|
||||
|
||||
try:
|
||||
info = self.query_engine.query_by_id(id)
|
||||
return info.char if info else f"[ID:{id}]"
|
||||
except:
|
||||
return f"[ID:{id}]"
|
||||
|
||||
def _clean_pinyin_input(self, pinyin: str) -> str:
|
||||
"""清理拼音输入字符串"""
|
||||
if not pinyin:
|
||||
return ""
|
||||
|
||||
result = []
|
||||
for c in pinyin:
|
||||
is_valid = ("a" <= c <= "z") or ("A" <= c <= "Z") or c in ["`", "'", "-"]
|
||||
if is_valid:
|
||||
result.append(c.lower())
|
||||
elif c == " ":
|
||||
continue
|
||||
elif c in ["\b", "\x7f", "\x08"]:
|
||||
if result:
|
||||
result.pop()
|
||||
elif c == "\x1b":
|
||||
result.clear()
|
||||
return "".join(result)
|
||||
|
||||
def _safe_input(self, prompt: str, default: str = "") -> str:
|
||||
"""安全的输入函数"""
|
||||
try:
|
||||
full_prompt = f"{prompt} [{default}]: " if default else f"{prompt}: "
|
||||
result = input(full_prompt)
|
||||
if not result and default:
|
||||
return default
|
||||
return result.strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print()
|
||||
return ""
|
||||
except Exception as e:
|
||||
print(f"\n⚠️ 输入错误: {e}")
|
||||
return ""
|
||||
|
||||
def prepare_inputs(
    self,
    context_prompts: List[str],
    text_before: str,
    text_after: str,
    pinyin: str,
    slot_chars: List[str],
    max_seq_len: int = 128,
) -> dict:
    """
    Turn the raw user inputs into the arrays fed to the ONNX sessions.

    Args:
        context_prompts: Hint words; joined with "|" and prefixed to
            ``text_before`` when non-empty.
        text_before: Text before the cursor (first tokenizer segment).
        text_after: Text after the cursor (second tokenizer segment).
        pinyin: Raw pinyin; cleaned before being converted to IDs.
        slot_chars: Already confirmed characters, converted to IDs.
        max_seq_len: Tokenizer ``max_length`` (padding and truncation).

    Returns:
        dict with keys:
            "preprocess_time": elapsed time of this call in milliseconds,
            "input_ids" / "attention_mask": tokenizer outputs as numpy arrays,
            "pinyin_ids": int64 array of shape (1, 24),
            "history_slot_ids": int64 array of shape (1, 8).
    """
    started = time.perf_counter()

    # 1. First tokenizer segment: optional hints, then the text before the cursor.
    context_text = "|".join(context_prompts) if context_prompts else ""
    input_text = f"{context_text}|{text_before}" if context_text else text_before

    # 2. Tokenize the (before, after) pair to a fixed length.
    encoded = self.tokenizer(
        input_text,
        text_after,
        max_length=max_seq_len,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=True,
    )
    input_ids = encoded["input_ids"].numpy()
    attention_mask = encoded["attention_mask"].numpy()

    # 3. Pinyin IDs, zero-padded or truncated to exactly 24 entries.
    pinyin_ids = text_to_pinyin_ids(self._clean_pinyin_input(pinyin))
    missing = 24 - len(pinyin_ids)
    if missing > 0:
        pinyin_ids.extend([0] * missing)
    else:
        pinyin_ids = pinyin_ids[:24]
    pinyin_ids = np.array([pinyin_ids], dtype=np.int64)

    # 4. Confirmed-slot IDs, zero-padded or truncated to exactly 8 entries.
    slot_ids = [self.char_to_id(ch) for ch in slot_chars]
    slot_ids = (slot_ids + [0] * 8)[:8]
    history_slot_ids = np.array([slot_ids], dtype=np.int64)

    return {
        "preprocess_time": (time.perf_counter() - started) * 1000,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "pinyin_ids": pinyin_ids,
        "history_slot_ids": history_slot_ids,
    }
|
||||
|
||||
def run_context_encoder(
|
||||
self, input_ids: np.ndarray, pinyin_ids: np.ndarray, attention_mask: np.ndarray
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
运行上下文编码器
|
||||
|
||||
Returns:
|
||||
context_H, pinyin_P, context_mask, pinyin_mask
|
||||
"""
|
||||
context_start = time.perf_counter()
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"pinyin_ids": pinyin_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||
|
||||
self.last_context_encoder_time = (time.perf_counter() - context_start) * 1000
|
||||
|
||||
return context_H, pinyin_P, context_mask, pinyin_mask
|
||||
|
||||
def run_decoder(
|
||||
self,
|
||||
context_H: np.ndarray,
|
||||
pinyin_P: np.ndarray,
|
||||
history_slot_ids: np.ndarray,
|
||||
context_mask: np.ndarray,
|
||||
pinyin_mask: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
运行解码器
|
||||
|
||||
Returns:
|
||||
logits: [batch, vocab_size]
|
||||
"""
|
||||
decoder_start = time.perf_counter()
|
||||
|
||||
inputs = {
|
||||
"context_H": context_H,
|
||||
"pinyin_P": pinyin_P,
|
||||
"history_slot_ids": history_slot_ids,
|
||||
"context_mask": context_mask,
|
||||
"pinyin_mask": pinyin_mask,
|
||||
}
|
||||
|
||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||
|
||||
self.last_decoder_time = (time.perf_counter() - decoder_start) * 1000
|
||||
|
||||
return outputs[0]
|
||||
|
||||
def predict(
|
||||
self,
|
||||
context_prompts: List[str],
|
||||
text_before: str,
|
||||
text_after: str,
|
||||
pinyin: str,
|
||||
slot_chars: List[str],
|
||||
top_k: int = 20,
|
||||
use_beam_search: bool = False,
|
||||
beam_size: int = 5,
|
||||
max_length: int = 10,
|
||||
) -> Tuple[List[Tuple[str, float, int]], dict]:
|
||||
"""
|
||||
执行推理
|
||||
|
||||
Args:
|
||||
context_prompts: 上下文提示
|
||||
text_before: 光标前文本
|
||||
text_after: 光标后文本
|
||||
pinyin: 当前输入的拼音
|
||||
slot_chars: 槽位内的汉字列表
|
||||
top_k: 返回top-k个预测结果
|
||||
use_beam_search: 是否使用束搜索
|
||||
beam_size: 束大小
|
||||
max_length: 最大生成长度
|
||||
|
||||
Returns:
|
||||
(predictions, timing_info)
|
||||
predictions: List[Tuple[char, prob, id]]
|
||||
timing_info: 各阶段耗时字典
|
||||
"""
|
||||
total_start = time.perf_counter()
|
||||
|
||||
# 阶段1: 预处理
|
||||
prep_start = time.perf_counter()
|
||||
inputs = self.prepare_inputs(
|
||||
context_prompts, text_before, text_after, pinyin, slot_chars
|
||||
)
|
||||
preprocess_time = inputs["preprocess_time"]
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
attention_mask = inputs["attention_mask"]
|
||||
pinyin_ids = inputs["pinyin_ids"]
|
||||
history_slot_ids = inputs["history_slot_ids"]
|
||||
prep_time = (time.perf_counter() - prep_start) * 1000
|
||||
|
||||
# 阶段2: 上下文编码
|
||||
context_start = time.perf_counter()
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||
input_ids, pinyin_ids, attention_mask
|
||||
)
|
||||
context_encoder_time = self.last_context_encoder_time
|
||||
|
||||
if use_beam_search:
|
||||
# 阶段3: 束搜索解码
|
||||
decode_start = time.perf_counter()
|
||||
predictions, beam_decode_time = self._beam_search_decode(
|
||||
context_H,
|
||||
pinyin_P,
|
||||
context_mask,
|
||||
pinyin_mask,
|
||||
beam_size,
|
||||
max_length,
|
||||
top_k,
|
||||
)
|
||||
decoder_time = beam_decode_time
|
||||
else:
|
||||
# 阶段3: 单步解码
|
||||
decode_start = time.perf_counter()
|
||||
logits = self.run_decoder(
|
||||
context_H,
|
||||
pinyin_P,
|
||||
history_slot_ids,
|
||||
context_mask,
|
||||
pinyin_mask,
|
||||
)
|
||||
|
||||
# 阶段4: 后处理
|
||||
postprocess_start = time.perf_counter()
|
||||
probs = self._softmax(logits)
|
||||
top_indices, top_probs = self._topk(probs, top_k)
|
||||
|
||||
predictions = []
|
||||
for i in range(top_k):
|
||||
idx = int(top_indices[0, i])
|
||||
prob = float(top_probs[0, i])
|
||||
char = self.id_to_char(idx)
|
||||
predictions.append((char, prob, idx))
|
||||
|
||||
postprocess_time = (time.perf_counter() - postprocess_start) * 1000
|
||||
decoder_time = self.last_decoder_time
|
||||
|
||||
total_time = (time.perf_counter() - total_start) * 1000
|
||||
|
||||
timing_info = {
|
||||
"预处理": prep_time,
|
||||
"上下文编码": context_encoder_time,
|
||||
"解码": decoder_time,
|
||||
"后处理": postprocess_time if not use_beam_search else 0,
|
||||
"总耗时": total_time,
|
||||
}
|
||||
|
||||
return predictions, timing_info
|
||||
|
||||
def _softmax(self, logits: np.ndarray) -> np.ndarray:
|
||||
"""计算softmax"""
|
||||
exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
|
||||
return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
|
||||
|
||||
def _topk(self, probs: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""获取top-k"""
|
||||
topk_indices = np.argsort(probs, axis=-1)[:, -k:][:, ::-1]
|
||||
topk_probs = np.take_along_axis(probs, topk_indices, axis=-1)
|
||||
return topk_indices, topk_probs
|
||||
|
||||
def _beam_search_decode(
|
||||
self,
|
||||
context_H: np.ndarray,
|
||||
pinyin_P: np.ndarray,
|
||||
context_mask: np.ndarray,
|
||||
pinyin_mask: np.ndarray,
|
||||
beam_size: int,
|
||||
max_length: int,
|
||||
top_k: int,
|
||||
) -> Tuple[List[Tuple[str, float, int]], float]:
|
||||
"""束搜索解码"""
|
||||
beams = [([], 0.0)] # (序列, 对数概率)
|
||||
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
|
||||
for seq, score in beams:
|
||||
if len(seq) < 8:
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
history = seq[-8:]
|
||||
|
||||
history_tensor = np.array([history], dtype=np.int64)
|
||||
|
||||
logits = self.run_decoder(
|
||||
context_H,
|
||||
pinyin_P,
|
||||
history_tensor,
|
||||
context_mask,
|
||||
pinyin_mask,
|
||||
)
|
||||
|
||||
probs = self._softmax(logits)[0]
|
||||
topk_indices = np.argsort(probs)[-beam_size:][::-1]
|
||||
topk_probs = probs[topk_indices]
|
||||
|
||||
for idx, prob in zip(topk_indices, topk_probs):
|
||||
new_seq = seq + [int(idx)]
|
||||
new_score = score + np.log(prob + 1e-10)
|
||||
new_beams.append((new_seq, new_score))
|
||||
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||
beams = new_beams[:beam_size]
|
||||
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||
if all_ended:
|
||||
break
|
||||
|
||||
# 返回top-k个候选
|
||||
predictions = []
|
||||
for seq, score in beams[:top_k]:
|
||||
if seq:
|
||||
char = self.id_to_char(seq[-1])
|
||||
prob = np.exp(score / max(len(seq), 1))
|
||||
else:
|
||||
char = self.id_to_char(0)
|
||||
prob = 0.0
|
||||
predictions.append((char, prob, seq[-1] if seq else 0))
|
||||
|
||||
decode_time = self.last_decoder_time # 只记录最后一次解码的时间
|
||||
|
||||
return predictions, decode_time
|
||||
|
||||
def interactive_mode(self):
    """
    Run the interactive inference loop on the terminal.

    Each round asks for five inputs (context hints, text before the
    cursor, text after the cursor, pinyin, confirmed slot history), runs
    ``predict`` and prints the timings, the predictions and, when a query
    engine is available, reference characters for the pinyin. Entering
    "quit", "exit" or "q" at any prompt, or answering the final prompt
    with anything other than "y"/"yes"/empty, ends the loop. Errors raised
    during a round are printed with a traceback and the loop continues.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 40
    exit_words = ("quit", "exit", "q")

    def split_csv(raw):
        # Comma-separated items with blanks removed.
        return [piece.strip() for piece in raw.split(",") if piece.strip()]

    print("\n" + heavy_rule)
    print("ONNX输入法模型推理 - 交互模式")
    print(heavy_rule)

    encoding = sys.stdout.encoding or "unknown"
    print(f"终端编码: {encoding}")

    print("\n说明:")
    for help_line in (
        " - 上下文提示: 模型不掌握的专有词汇、姓名等(可为空)",
        " - 光标前文本: 光标前的连续文本",
        " - 光标后文本: 光标后的连续文本",
        " - 拼音: 当前输入的拼音",
        " - 槽位历史: 用户已确认的输入历史",
    ):
        print(help_line)
    if self.use_beam_search:
        print(f" - 解码模式: 束搜索 (beam_size={self.beam_size})")
    else:
        print(" - 解码模式: 单步解码 (使用 --beam 启用束搜索)")
    print("提示: 输入 'quit' 或 'exit' 或 'q' 可随时退出")
    print("-" * 60)

    while True:
        try:
            # Step 1: context hints
            print("\n" + heavy_rule)
            context_input = self._safe_input("第1步: 上下文提示(直接回车跳过)")
            if context_input.lower() in exit_words:
                break
            context_prompts = split_csv(context_input)

            # Step 2: text before the cursor
            print("\n" + light_rule)
            text_before = self._safe_input("第2步: 光标前文本")
            if text_before.lower() in exit_words:
                break

            # Step 3: text after the cursor
            print("\n" + light_rule)
            text_after = self._safe_input("第3步: 光标后文本")
            if text_after.lower() in exit_words:
                break

            # Step 4: pinyin
            print("\n" + light_rule)
            pinyin = self._safe_input("第4步: 拼音输入")
            if pinyin.lower() in exit_words:
                break

            # Step 5: confirmed slot history
            print("\n" + light_rule)
            slot_input = self._safe_input("第5步: 槽位历史(直接回车表示无)")
            if slot_input.lower() in exit_words:
                break
            slot_chars = split_csv(slot_input)

            print("\n" + heavy_rule)
            print("📝 输入汇总:")
            print(f" 上下文提示: {context_prompts or '无'}")
            print(f" 光标前文本: '{text_before}'")
            print(f" 光标后文本: '{text_after}'")
            print(f" 拼音: '{pinyin}'")
            print(f" 槽位历史: {slot_chars or '无'}")

            print("\n🔮 推理中...")
            predictions, timing_info = self.predict(
                context_prompts,
                text_before,
                text_after,
                pinyin,
                slot_chars,
                top_k=20,
                use_beam_search=self.use_beam_search,
                beam_size=self.beam_size,
            )

            # Timing table (stages that took no time are omitted)
            print(f"\n⏱️ 执行时间统计:")
            print(light_rule)
            for stage, duration in timing_info.items():
                if duration > 0:
                    print(f" {stage:<12}: {duration:>8.2f} ms")
            print(light_rule)

            # Predictions
            print("\n🏆 Top-20 预测结果:")
            print("-" * 50)
            for rank, (char, prob, idx) in enumerate(predictions, start=1):
                if char == "//":
                    print(f"{rank:2d}. {'//':<4} (结束符) - 概率: {prob:.4f}")
                else:
                    print(f"{rank:2d}. {char:<4} (ID: {idx:>5}) - 概率: {prob:.4f}")

            # Reference characters for the pinyin as typed
            if pinyin and self.query_engine:
                print(f"\n📖 拼音 '{pinyin}' 的常见汉字:")
                pinyin_results = self.query_engine.query_by_pinyin(pinyin, limit=10)
                if pinyin_results:
                    for _pid, char, count in pinyin_results:
                        print(f" {char} (频次: {count:,})")

            print("\n" + light_rule)
            answer = self._safe_input("是否继续推理?(y/n)", "y").strip().lower()
            if answer not in ("y", "yes", ""):
                break

        except KeyboardInterrupt:
            print("\n\n退出交互模式")
            break
        except Exception as e:
            print(f"\n❌ 推理出错: {e}")
            import traceback

            traceback.print_exc()
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point for ONNX inference.

    Parses the arguments, checks that both model files exist (exiting with
    status 1 otherwise), builds an ``ONNXInference`` and then optionally
    runs a fixed test inference (``--test``) and the interactive loop.

    ``--interactive`` is on by default and, being a ``store_true`` flag,
    could not be turned off. ``--no-interactive`` is added so that, for
    example, ``--test --no-interactive`` runs the test and exits.
    """
    parser = argparse.ArgumentParser(description="ONNX输入法模型推理")
    parser.add_argument(
        "--context-encoder",
        "-c",
        type=str,
        required=True,
        help="上下文编码器ONNX模型路径",
    )
    parser.add_argument(
        "--decoder",
        "-d",
        type=str,
        required=True,
        help="解码器ONNX模型路径",
    )
    parser.add_argument(
        "--vocab-size",
        type=int,
        default=10019,
        help="词汇表大小 (默认: 10019)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="推理设备 (默认: cpu)",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        default=True,
        help="交互模式 (默认: True)",
    )
    # Shares dest with --interactive so the existing flag keeps working.
    parser.add_argument(
        "--no-interactive",
        action="store_false",
        dest="interactive",
        help="禁用交互模式",
    )
    parser.add_argument("--test", action="store_true", help="运行测试推理")
    parser.add_argument(
        "--beam",
        action="store_true",
        help="使用束搜索解码 (默认: 单步解码)",
    )
    parser.add_argument(
        "--beam-size",
        type=int,
        default=5,
        help="束大小 (默认: 5)",
    )

    args = parser.parse_args()

    # Fail early when a model file is missing.
    if not os.path.exists(args.context_encoder):
        print(f"❌ 错误: 上下文编码器文件不存在: {args.context_encoder}")
        sys.exit(1)
    if not os.path.exists(args.decoder):
        print(f"❌ 错误: 解码器文件不存在: {args.decoder}")
        sys.exit(1)

    inference = ONNXInference(
        context_encoder_path=args.context_encoder,
        decoder_path=args.decoder,
        vocab_size=args.vocab_size,
        device=args.device,
        use_beam_search=args.beam,
        beam_size=args.beam_size,
    )

    # Fixed test inference
    if args.test:
        print("\n🧪 运行测试推理...")
        print("测试场景: 输入'shanghai',已确认第一个字'上',继续输入'tian'")
        print("上下文提示: 张三、李四(模型不掌握的专有名词)")

        predictions, timing_info = inference.predict(
            context_prompts=["张三", "李四"],
            text_before="今天天气",
            text_after="很好",
            pinyin="tian",
            slot_chars=["上"],
            use_beam_search=args.beam,
            beam_size=args.beam_size,
        )

        print(f"\n⏱️ 执行时间统计:")
        print("-" * 40)
        for stage, duration in timing_info.items():
            if duration > 0:
                print(f" {stage:<12}: {duration:>8.2f} ms")
        print("-" * 40)

        print(f"\nTop-5 结果:")
        for i, (char, prob, idx) in enumerate(predictions[:5]):
            if char == "//":
                print(f" {i + 1}. // (结束符) - 概率: {prob:.4f}")
            else:
                print(f" {i + 1}. {char} (ID: {idx}) - 概率: {prob:.4f}")

    # Interactive loop (skipped with --no-interactive)
    if args.interactive:
        inference.interactive_mode()
|
||||
|
||||
|
||||
# Script entry point: parse the command line and run inference only when
# this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -14,6 +14,7 @@ dependencies = [
|
|||
"onnxruntime>=1.24.2",
|
||||
"pandas>=3.0.0",
|
||||
"plotly>=5.0.0",
|
||||
"jieba>=0.42.1",
|
||||
"pypinyin>=0.55.0",
|
||||
"requests>=2.32.5",
|
||||
"rich>=14.3.1",
|
||||
|
|
@ -23,11 +24,14 @@ dependencies = [
|
|||
"transformers==5.1.0",
|
||||
"typer>=0.21.1",
|
||||
"waitress>=3.0.2",
|
||||
"onnx>=1.21.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
train-model = "model.trainer:app"
|
||||
monitor-training = "model.monitor:app"
|
||||
preprocess-model = "model.preprocess:main"
|
||||
inspect-preprocessed = "model.inspect_preprocessed:main"
|
||||
|
||||
[tool.uv]
|
||||
# 设置当前项目的默认索引源
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
|
|
@ -0,0 +1,266 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Analyze frequency distribution in pinyin_char_statistics.json
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main():
    """Print frequency statistics for ``pinyin_char_statistics.json``.

    The JSON file is looked up at ``src/model/assets/`` two directory levels
    above this script.  The function prints basic statistics, percentiles, an
    ID-versus-frequency ordering check, a close-up of IDs 5000-5500, a
    log-binned histogram and a low-frequency breakdown, then writes a short
    summary to ``frequency_analysis_results.txt`` in the current directory.

    Exits with status 1 when the JSON file is missing and returns early when
    no entry carries both an ``id`` and a ``count``.
    """
    json_path = (
        Path(__file__).parent.parent
        / "src"
        / "model"
        / "assets"
        / "pinyin_char_statistics.json"
    )
    if not json_path.exists():
        print(f"Error: File not found: {json_path}")
        sys.exit(1)

    print(f"Loading {json_path}...")
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    print(f"Timestamp: {data.get('timestamp')}")
    print(f"Total characters: {data.get('total_characters')}")
    print(f"Total pinyins: {data.get('total_pinyins')}")
    print(f"Valid input character count: {data.get('valid_input_character_count')}")

    pairs = data.get("pairs", {})
    print(f"Number of pairs: {len(pairs)}")

    # Collect every count plus an id -> count lookup.
    counts = []
    id_to_count = {}
    for key, pair in pairs.items():
        try:
            char_id = pair.get("id")
            count = pair.get("count")
            if char_id is not None and count is not None:
                counts.append(count)
                id_to_count[char_id] = count
        except (ValueError, TypeError) as e:
            print(f"Warning: Could not parse pair {key}: {e}")
            continue

    if not counts:
        print("No valid count data found.")
        return

    # Basic statistics
    n = len(counts)
    min_count = min(counts)
    max_count = max(counts)
    total_count = sum(counts)
    mean_count = total_count / n

    sorted_counts = sorted(counts)

    def percentile(fraction):
        # int() truncation keeps the index inside [0, n - 1] for fraction < 1.
        return sorted_counts[int(fraction * n)]

    p10 = percentile(0.1)
    p25 = percentile(0.25)
    p50 = percentile(0.5)
    p75 = percentile(0.75)
    p90 = percentile(0.9)
    p99 = percentile(0.99)

    # Population variance / standard deviation
    variance = sum((x - mean_count) ** 2 for x in counts) / n
    std_dev = math.sqrt(variance)

    print("\n=== BASIC STATISTICS ===")
    print(f"Min frequency: {min_count}")
    print(f"Max frequency: {max_count}")
    print(f"Mean frequency: {mean_count:.2f}")
    print(f"Standard deviation: {std_dev:.2f}")
    print(f"Total frequency sum: {total_count}")
    print(f"Number of entries: {n}")

    print("\n=== PERCENTILES ===")
    print(f"10th percentile: {p10}")
    print(f"25th percentile: {p25}")
    print(f"50th percentile (median): {p50}")
    print(f"75th percentile: {p75}")
    print(f"90th percentile: {p90}")
    print(f"99th percentile: {p99}")

    # IDs holding the extreme counts
    min_ids = [cid for cid, c in id_to_count.items() if c == min_count]
    max_ids = [cid for cid, c in id_to_count.items() if c == max_count]

    print(f"\nIDs with min frequency ({min_count}): {min_ids}")
    print(f"IDs with max frequency ({max_count}): {max_ids}")

    # Check whether counts fall as IDs rise.
    ids = list(id_to_count)
    sorted_by_id = sorted(ids)
    counts_by_id = [id_to_count[cid] for cid in sorted_by_id]

    decreases = 0
    increases = 0
    for prev, cur in zip(counts_by_id, counts_by_id[1:]):
        if cur < prev:
            decreases += 1
        elif cur > prev:
            increases += 1

    # A single ID has no neighbouring pair; report 0% instead of dividing by 0.
    transitions = len(counts_by_id) - 1
    pct_decreasing = decreases / transitions * 100 if transitions > 0 else 0.0

    print("\n=== ID ORDER ANALYSIS ===")
    print(f"Total pairs: {len(counts_by_id)}")
    print(f"Decreases as ID increases: {decreases} times")
    print(f"Increases as ID increases: {increases} times")
    print(f"Percentage decreasing: {pct_decreasing:.2f}%")

    # Simplified rank comparison: position by ID vs. position by descending count.
    sorted_by_count = sorted(ids, key=lambda cid: id_to_count[cid], reverse=True)
    rank_by_id = {cid: rank for rank, cid in enumerate(sorted_by_id)}
    rank_by_count = {cid: rank for rank, cid in enumerate(sorted_by_count)}

    rank_diffs = [abs(rank_by_id[cid] - rank_by_count[cid]) for cid in ids]
    avg_rank_diff = sum(rank_diffs) / len(rank_diffs)
    max_rank_diff = max(rank_diffs)

    print(
        f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}"
    )
    print(f"Maximum rank difference: {max_rank_diff}")

    # Close-up of IDs 5000-5500 (inclusive)
    print("\n=== ANALYSIS OF ID RANGE 5000-5500 ===")
    range_ids = [cid for cid in range(5000, 5501) if cid in id_to_count]
    range_counts = [id_to_count[cid] for cid in range_ids]

    # Stay None when the range is empty so the report can print 'N/A'.
    range_min = None
    range_max = None

    if range_counts:
        range_min = min(range_counts)
        range_max = max(range_counts)
        range_mean = sum(range_counts) / len(range_counts)
        range_sorted = sorted(range_counts)
        range_n = len(range_counts)
        range_p10 = range_sorted[int(0.1 * range_n)]
        range_p50 = range_sorted[int(0.5 * range_n)]
        range_p90 = range_sorted[int(0.9 * range_n)]

        print(f"IDs in range 5000-5500: {len(range_counts)}")
        print(f"Min frequency in range: {range_min}")
        print(f"Max frequency in range: {range_max}")
        print(f"Mean frequency in range: {range_mean:.2f}")
        print(f"10th percentile in range: {range_p10}")
        print(f"50th percentile in range: {range_p50}")
        print(f"90th percentile in range: {range_p90}")

        min_in_range_ids = [cid for cid in range_ids if id_to_count[cid] == range_min]
        more = "..." if len(min_in_range_ids) > 10 else ""
        print(f"IDs with min frequency in range: {min_in_range_ids[:10]}{more}")
    else:
        print("No IDs found in range 5000-5500")

    # Histogram of frequencies (log bins)
    print("\n=== FREQUENCY DISTRIBUTION (LOG BINS) ===")
    if max_count > 0:
        log_min = math.log10(min_count) if min_count > 0 else 0
        log_max = math.log10(max_count)
        log_span = log_max - log_min
        num_bins = 20
        bin_edges = [10 ** (log_min + i * log_span / num_bins) for i in range(num_bins + 1)]

        hist = [0] * num_bins
        for count in counts:
            if count > 0:
                if log_span > 0:
                    raw_idx = int((math.log10(count) - log_min) / log_span * num_bins)
                else:
                    # Every positive count shares one magnitude: single bin.
                    raw_idx = 0
                # Clamp both ends; a negative index would wrap to the last bin.
                hist[max(0, min(raw_idx, num_bins - 1))] += 1

        print("Log-scale histogram (count range -> frequency count):")
        for i in range(num_bins):
            if hist[i] > 0:
                lower = bin_edges[i]
                upper = bin_edges[i + 1]
                print(f"  {lower:.2e} - {upper:.2e}: {hist[i]} entries")

    # Zero / near-zero frequencies
    zero_count = sum(1 for c in counts if c == 0)
    low_count = sum(1 for c in counts if 0 < c <= 10)
    very_low_count = sum(1 for c in counts if 0 < c <= 100)

    print("\n=== LOW FREQUENCY ANALYSIS ===")
    print(f"Entries with zero frequency: {zero_count}")
    print(f"Entries with frequency <= 10: {low_count}")
    print(f"Entries with frequency <= 100: {very_low_count}")

    # Smallest non-zero frequency and the IDs holding it
    non_zero_counts = [c for c in counts if c > 0]
    if non_zero_counts:
        actual_min = min(non_zero_counts)
        print(f"Actual min frequency (non-zero): {actual_min}")
        actual_min_ids = [cid for cid, c in id_to_count.items() if c == actual_min]
        more = "..." if len(actual_min_ids) > 10 else ""
        print(f"IDs with actual min frequency: {actual_min_ids[:10]}{more}")

    # Summary for smoothing algorithm design
    print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===")
    print(
        f"Frequency range spans {max_count / min_count if min_count > 0 else 'inf'}:1 ratio"
    )
    print(f"Most entries ({p50}) have frequency around {p50}")
    print(f"Top 10% of entries have frequency > {p90}")
    print(f"Bottom 10% of entries have frequency < {p10}")
    print(
        f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency"
    )

    # Persist the headline numbers for later reference.
    output_file = "frequency_analysis_results.txt"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("Frequency Analysis Results\n")
        f.write("=" * 50 + "\n")
        f.write(f"Min frequency: {min_count}\n")
        f.write(f"Max frequency: {max_count}\n")
        f.write(f"Mean frequency: {mean_count:.2f}\n")
        f.write(f"Standard deviation: {std_dev:.2f}\n")
        f.write(f"10th percentile: {p10}\n")
        f.write(f"50th percentile: {p50}\n")
        f.write(f"90th percentile: {p90}\n")
        f.write(
            f"IDs in range 5000-5500 min: {range_min if range_min is not None else 'N/A'}\n"
        )
        f.write(
            f"IDs in range 5000-5500 max: {range_max if range_max is not None else 'N/A'}\n"
        )

    print(f"\nDetailed results saved to {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Analyze specific ID ranges in pinyin_char_statistics.json
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main():
    """Inspect specific ID ranges of ``pinyin_char_statistics.json``.

    The JSON file is looked up at ``src/model/assets/`` two directory levels
    above this script.  The function prints a close-up of IDs 5000-5500, the
    entries whose count equals 1, and how counts behave as IDs increase
    across the whole table.

    Exits with status 1 when the JSON file is missing.
    """
    # Function-local import: Counter is not among this script's top-level imports.
    from collections import Counter

    json_path = (
        Path(__file__).parent.parent
        / "src"
        / "model"
        / "assets"
        / "pinyin_char_statistics.json"
    )
    if not json_path.exists():
        print(f"Error: File not found: {json_path}")
        sys.exit(1)

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    pairs = data.get("pairs", {})

    # One pass builds both lookups.  The first pair seen for an ID wins in
    # first_pair_by_id, matching a front-to-back scan that stops at the first hit.
    id_to_count = {}
    first_pair_by_id = {}
    for pair in pairs.values():
        char_id = pair.get("id")
        if char_id is None:
            continue
        first_pair_by_id.setdefault(char_id, pair)
        count = pair.get("count")
        if count is not None:
            id_to_count[char_id] = count

    # Analyze range 5000-5500 (inclusive) in detail
    print("ID range 5000-5500 detailed analysis:")
    print("ID\tCount\tChar\tPinyin")

    range_data = []
    for char_id in range(5000, 5501):
        if char_id not in id_to_count:
            continue
        pair = first_pair_by_id[char_id]
        char = pair.get("char", "")
        pinyin = pair.get("pinyin", "")
        count = pair.get("count", 0)
        range_data.append((char_id, count, char, pinyin))
        if char_id % 100 == 0:  # Print every 100th for overview
            print(f"{char_id}\t{count}\t{char}\t{pinyin}")

    if range_data:
        min_item = min(range_data, key=lambda item: item[1])
        max_item = max(range_data, key=lambda item: item[1])
        print(
            f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, "
            f"char '{min_item[2]}', pinyin '{min_item[3]}'"
        )
        print(
            f"Max in range: ID {max_item[0]}, count {max_item[1]}, "
            f"char '{max_item[2]}', pinyin '{max_item[3]}'"
        )

        # Monotonicity of counts inside the range
        counts = [item[1] for item in range_data]
        neighbours = list(zip(counts, counts[1:]))
        increasing = all(a <= b for a, b in neighbours)
        decreasing = all(a >= b for a, b in neighbours)
        print(f"Monotonic in range: increasing={increasing}, decreasing={decreasing}")

        # Frequency plateaus inside the range
        most_common = Counter(counts).most_common(5)
        print(f"Most common frequencies in range: {most_common}")

    # Analyze the tail (IDs with frequency 1)
    print("\n\nAnalysis of frequency=1 entries:")
    freq_one_ids = [cid for cid, c in id_to_count.items() if c == 1]
    print(f"Number of entries with frequency=1: {len(freq_one_ids)}")
    if freq_one_ids:
        print(f"ID range of frequency=1: {min(freq_one_ids)} to {max(freq_one_ids)}")
        print(f"First 10 IDs: {freq_one_ids[:10]}")
        print(f"Last 10 IDs: {freq_one_ids[-10:]}")

        # Contiguous means every sorted neighbour differs by exactly 1.
        sorted_ids = sorted(freq_one_ids)
        contiguous = all(a + 1 == b for a, b in zip(sorted_ids, sorted_ids[1:]))
        print(f"Are they contiguous IDs? {contiguous}")

        print("\nSample characters with frequency=1:")
        shown = 0
        for pair in pairs.values():
            if pair.get("count") != 1:
                continue
            print(
                f"  ID {pair.get('id')}: char '{pair.get('char')}', "
                f"pinyin '{pair.get('pinyin')}'"
            )
            shown += 1
            if shown >= 10:
                break  # ten samples are enough; stop scanning the table

    # Check overall ID-frequency ordering
    print("\n\nOverall ID-frequency ordering analysis:")
    all_ids = sorted(id_to_count)
    all_counts = [id_to_count[cid] for cid in all_ids]

    # Count maximal runs (length > 1) in which counts never increase.
    non_increasing_segments = 0
    current_segment_length = 1
    for prev, cur in zip(all_counts, all_counts[1:]):
        if cur <= prev:
            current_segment_length += 1
        else:
            if current_segment_length > 1:
                non_increasing_segments += 1
            current_segment_length = 1
    if current_segment_length > 1:
        non_increasing_segments += 1

    print(f"Total IDs: {len(all_ids)}")
    print(f"Non-increasing segments: {non_increasing_segments}")

    # Frequency values shared by more than one ID, largest groups first
    overall_freq_count = Counter(all_counts)
    plateaus = [(freq, size) for freq, size in overall_freq_count.items() if size > 1]
    plateaus_sorted = sorted(plateaus, key=lambda item: item[1], reverse=True)[:10]
    print("Top 10 frequency plateaus (freq: count of IDs sharing that freq):")
    for freq, size in plateaus_sorted:
        print(f"  {freq}: {size} IDs")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,10 +1,15 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
|
||||
from model.dataset import PinyinInputDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from model.trainer import collate_fn, worker_init_fn
|
||||
|
||||
|
||||
data = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
|
||||
data = PinyinInputDataset("/home/songsenand/Data/corpus/CCI-Data/")
|
||||
|
||||
dataloader = DataLoader(
|
||||
data,
|
||||
|
|
@ -18,5 +23,5 @@ dataloader = DataLoader(
|
|||
)
|
||||
|
||||
for i in dataloader:
|
||||
print((i['labels'] == 1).sum())
|
||||
print((i["labels"] == 1).sum())
|
||||
break
|
||||
|
|
@ -0,0 +1,227 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive frequency distribution analysis
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import math
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main():
    """Print a comprehensive frequency report for ``pinyin_char_statistics.json``.

    The JSON file is looked up at ``src/model/assets/`` two directory levels
    above this script.  The report covers rank-based percentiles, a cumulative
    distribution, the split around the count threshold 109, a rank-frequency
    table, the frequency spectrum, per-ID-range statistics, an ID ordering
    check and plateau grouping.  The full rank/frequency table is written to
    ``rank_freq.csv`` in the current directory.

    Exits with status 1 when the JSON file is missing and returns early when
    no entry carries a ``count``.
    """
    json_path = (
        Path(__file__).parent.parent
        / "src"
        / "model"
        / "assets"
        / "pinyin_char_statistics.json"
    )
    if not json_path.exists():
        print(f"Error: File not found: {json_path}")
        sys.exit(1)

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    pairs = data.get("pairs", {})

    # Extract counts
    counts = []
    for pair in pairs.values():
        count = pair.get("count")
        if count is not None:
            counts.append(count)

    n = len(counts)
    print(f"Total entries: {n}")
    if n == 0:
        # min()/max() and every ratio below need at least one entry.
        print("No valid count data found.")
        return

    # Sort descending for rank-frequency analysis
    counts_sorted_desc = sorted(counts, reverse=True)

    min_count = min(counts)
    max_count = max(counts)

    # Rank-based percentiles: p% of the way down the descending list.
    percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
    print("\n=== PERCENTILE DISTRIBUTION ===")
    for p in percentiles:
        idx = int(p * n)
        value = counts_sorted_desc[idx]
        print(f"{p * 100:5.1f}%: {value:>12} (rank ~{idx})")

    # Cumulative distribution
    print("\n=== CUMULATIVE DISTRIBUTION ===")
    thresholds = [
        1, 2, 3, 5, 10, 20, 50, 100, 200, 500,
        1000, 2000, 5000, 10000, 20000, 50000,
        100000, 200000, 500000, 1000000, 5000000,
        10000000, 50000000, 100000000, 500000000,
    ]
    for thresh in thresholds:
        if thresh > max_count:
            break
        below = sum(1 for c in counts if c <= thresh)
        print(f"Count <= {thresh:10}: {below:6} entries ({below / n * 100:5.1f}%)")

    # Split around the threshold 109
    threshold = 109
    print(f"\n=== ANALYSIS OF THRESHOLD {threshold} ===")
    counts_ge_threshold = [c for c in counts if c >= threshold]
    at_or_above_109 = len(counts_ge_threshold)
    below_109 = sum(1 for c in counts if c < threshold)
    print(f"Entries with count < {threshold}: {below_109} ({below_109 / n * 100:.1f}%)")
    print(
        f"Entries with count >= {threshold}: {at_or_above_109} "
        f"({at_or_above_109 / n * 100:.1f}%)"
    )

    # Smallest count that still clears the threshold
    if counts_ge_threshold:
        actual_min_ge_109 = min(counts_ge_threshold)
        print(f"Actual min frequency among those >= {threshold}: {actual_min_ge_109}")

    # Rank-frequency analysis (Zipf's law)
    print("\n=== RANK-FREQUENCY ANALYSIS (Top 100) ===")
    print("Rank\tFrequency\tlog(rank)\tlog(freq)")
    # Slicing tolerates tables with fewer than 100 entries.
    for rank, freq in enumerate(counts_sorted_desc[:100], 1):
        # log() is undefined for non-positive counts; show -inf instead of crashing.
        log_freq = f"{math.log(freq):.3f}" if freq > 0 else "-inf"
        print(f"{rank}\t{freq}\t{math.log(rank):.3f}\t{log_freq}")

    # Frequency spectrum (how many distinct frequencies)
    freq_counter = Counter(counts)
    print("\n=== FREQUENCY SPECTRUM ===")
    print(f"Distinct frequency values: {len(freq_counter)}")

    print("\nTop 20 most common frequencies (plateau sizes):")
    for freq, freq_count in freq_counter.most_common(20):
        print(f"  Frequency {freq}: {freq_count} entries")

    # Analyze ID ranges
    print("\n=== ID RANGE ANALYSIS ===")
    id_to_count = {}
    for pair in pairs.values():
        char_id = pair.get("id")
        count = pair.get("count")
        if char_id is not None and count is not None:
            id_to_count[char_id] = count

    # (start, end, label) with end exclusive, as consumed by range() below.
    ranges = [
        (0, 100, "Top 100 IDs"),
        (100, 500, "IDs 100-500"),
        (500, 1000, "IDs 500-1000"),
        (1000, 2000, "IDs 1000-2000"),
        (2000, 5000, "IDs 2000-5000"),
        (5000, 5500, "IDs 5000-5500 (user mentioned)"),
        (5500, 6000, "IDs 5500-6000"),
        (10000, 10500, "IDs 10000-10500"),
        (15000, 15500, "IDs 15000-15500"),
        (19000, 19500, "IDs 19000-19500 (before freq=1)"),
        (19499, 20647, "IDs with freq=1"),
    ]

    for start, end, label in ranges:
        range_counts = [
            id_to_count[cid] for cid in range(start, end) if cid in id_to_count
        ]
        if range_counts:
            min_c = min(range_counts)
            max_c = max(range_counts)
            mean_c = sum(range_counts) / len(range_counts)
            median_c = sorted(range_counts)[len(range_counts) // 2]
            print(
                f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, "
                f"mean={mean_c:.1f}, median={median_c}"
            )

    # Check if IDs are perfectly sorted by frequency
    print("\n=== ID ORDER VERIFICATION ===")
    all_ids = sorted(id_to_count)
    all_counts = [id_to_count[cid] for cid in all_ids]

    # A violation is any count that rises relative to the previous ID.
    violations = 0
    for i in range(1, len(all_counts)):
        if all_counts[i] > all_counts[i - 1]:
            violations += 1
            if violations <= 5:  # only the first few are printed in full
                print(
                    f"Violation at ID {all_ids[i]}: {all_counts[i]} > "
                    f"{all_counts[i - 1]} (ID {all_ids[i - 1]})"
                )

    print(f"Total violations of non-increasing order: {violations}")

    # Group consecutive IDs (in ID order) that share the same count.
    print("\n=== FREQUENCY GROUPING ANALYSIS ===")
    current_freq = None
    group_start = None
    # Each entry: (frequency, index of first ID, last ID, group size)
    group_sizes = []

    for i, count in enumerate(all_counts):
        if count != current_freq:
            if current_freq is not None:
                group_sizes.append(
                    (current_freq, group_start, all_ids[i - 1], i - group_start)
                )
            current_freq = count
            group_start = i

    # Last group
    if current_freq is not None:
        group_sizes.append(
            (current_freq, group_start, all_ids[-1], len(all_ids) - group_start)
        )

    group_sizes.sort(key=lambda group: group[3], reverse=True)
    print("Top 10 largest frequency groups (plateaus):")
    for freq, start_id_idx, end_id, size in group_sizes[:10]:
        start_id = all_ids[start_id_idx]
        print(f"  Frequency {freq}: IDs {start_id}-{end_id} ({size} entries)")

    # Summary for smoothing algorithm
    print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===")
    # Report the ordering actually measured above rather than assuming it.
    if violations == 0:
        print("1. IDs are perfectly sorted by frequency (non-increasing).")
    else:
        print(
            f"1. IDs are NOT perfectly sorted by frequency "
            f"({violations} violations of non-increasing order)."
        )
    ratio = f"{max_count / min_count:.1e}" if min_count > 0 else "inf"
    print(f"2. Frequency range: {min_count} to {max_count} (ratio {ratio}:1).")
    print(
        f"3. {below_109} entries ({below_109 / n * 100:.1f}%) "
        f"have frequency < {threshold}."
    )
    print(f"4. Median frequency: {counts_sorted_desc[n // 2]}.")
    # The list is descending, so the ascending index int(0.9 * n) is mirrored
    # to obtain the value that 90% of entries do not exceed.
    p90_ascending = counts_sorted_desc[n - 1 - int(0.9 * n)]
    print(f"5. 90% of entries have frequency <= {p90_ascending}.")
    print(
        f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01 * n)]}."
    )
    print("7. Large frequency plateaus exist (many IDs share same frequency).")
    print("8. Smoothing should handle extreme frequency ratios (1:5e8).")

    # Save data for plotting
    with open("rank_freq.csv", "w", encoding="utf-8") as f:
        f.write("rank,frequency\n")
        f.writelines(
            f"{rank},{freq}\n" for rank, freq in enumerate(counts_sorted_desc, 1)
        )
    print("\nRank-frequency data saved to rank_freq.csv")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,409 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
临时迁移训练脚本:在预训练模型基础上重新训练,支持冻结 context_encoder。
|
||||
|
||||
与 src/model/trainer.py 的 train 命令行为完全一致,额外增加:
|
||||
- --pretrained-checkpoint: 加载预训练权重(必需的迁移学习源)
|
||||
- --freeze-context-encoder: 冻结 context_encoder 层(默认开启)
|
||||
|
||||
运行方式:
|
||||
python scripts/finetune_slots.py \
|
||||
--pretrained-checkpoint ./output/checkpoints/best_model.pt \
|
||||
--train-data-path /path/to/train_data \
|
||||
--eval-data-path /path/to/eval_data \
|
||||
--output-dir ./finetune_output \
|
||||
--freeze-context-encoder
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
|
||||
from model.model import InputMethodEngine
|
||||
from model.trainer import (
|
||||
Trainer,
|
||||
create_dataloader,
|
||||
worker_init_fn,
|
||||
)
|
||||
from model.dataset import PinyinInputDataset
|
||||
from model.preprocessed_dataset import (
|
||||
PreProcessedDataset,
|
||||
is_preprocessed_data,
|
||||
)
|
||||
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="迁移学习训练:加载预训练模型,冻结指定层后重新训练",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
# === 数据参数 ===
|
||||
parser.add_argument("--train-data-path", "-t", required=True, help="训练数据集路径")
|
||||
parser.add_argument("--eval-data-path", "-e", required=True, help="评估数据集路径")
|
||||
parser.add_argument("--output-dir", "-o", default="./finetune_output", help="输出目录")
|
||||
parser.add_argument("--max-iter-length", type=int,
|
||||
default=1024 * 1024 * 128, help="每个 epoch 最大样本数")
|
||||
|
||||
# === 迁移学习参数 ===
|
||||
parser.add_argument("--pretrained-checkpoint", "-c", required=True,
|
||||
help="预训练模型检查点路径")
|
||||
parser.add_argument("--freeze-context-encoder", action="store_true", default=True,
|
||||
help="冻结 context_encoder 层 (默认开启)")
|
||||
parser.add_argument("--no-freeze-context-encoder", dest="freeze_context_encoder",
|
||||
action="store_false",
|
||||
help="不冻结 context_encoder")
|
||||
|
||||
# === 训练参数 ===
|
||||
parser.add_argument("--batch-size", "-b", type=int, default=128, help="批次大小")
|
||||
parser.add_argument("--num-epochs", type=int, default=10, help="训练轮数")
|
||||
parser.add_argument("--learning-rate", "-lr", type=float, default=2e-4,
|
||||
help="学习率")
|
||||
parser.add_argument("--min-learning-rate", type=float, default=1e-9,
|
||||
help="最小学习率")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.05, help="权重衰减")
|
||||
parser.add_argument("--warmup-ratio", type=float, default=0.1, help="热身步数比例")
|
||||
parser.add_argument("--label-smoothing", type=float, default=0.1,
|
||||
help="标签平滑参数")
|
||||
parser.add_argument("--grad-accum-steps", type=int, default=1,
|
||||
help="梯度累积步数")
|
||||
parser.add_argument("--clip-grad-norm", type=float, default=1.0,
|
||||
help="梯度裁剪范数")
|
||||
parser.add_argument("--eval-frequency", type=int, default=500, help="评估频率")
|
||||
parser.add_argument("--save-frequency", type=int, default=1000, help="保存频率")
|
||||
|
||||
# === 其他参数 ===
|
||||
parser.add_argument("--mixed-precision", action="store_true", default=True)
|
||||
parser.add_argument("--no-mixed-precision", dest="mixed_precision",
|
||||
action="store_false", help="禁用混合精度")
|
||||
parser.add_argument("--num-workers", type=int, default=2, help="数据加载worker数")
|
||||
parser.add_argument("--tensorboard", action="store_true", default=True)
|
||||
parser.add_argument("--no-tensorboard", dest="tensorboard", action="store_false",
|
||||
help="禁用 TensorBoard")
|
||||
parser.add_argument("--seed", type=int, default=42, help="随机种子")
|
||||
parser.add_argument("--compile", action="store_true", default=False,
|
||||
help="使用 torch.compile 优化")
|
||||
parser.add_argument("--moe-mode", default="all",
|
||||
choices=["all", "sparse", "sparse_allow_graph"],
|
||||
help="MoE 计算策略")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# ================================================================
|
||||
# 初始化
|
||||
# ================================================================
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
if torch.cuda.is_available():
|
||||
torch.set_float32_matmul_precision("high")
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
console = Console()
|
||||
output_path = Path(args.output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ================================================================
|
||||
# 模型常量 (与 trainer.py 保持一致)
|
||||
# ================================================================
|
||||
vocab_size = 10019
|
||||
pinyin_vocab_size = 30
|
||||
dim = 512
|
||||
num_slots = 8
|
||||
n_layers = 4
|
||||
n_heads = 4
|
||||
num_experts = 10
|
||||
max_seq_len = 128
|
||||
|
||||
# ================================================================
|
||||
# 打印配置
|
||||
# ================================================================
|
||||
console.print(Panel.fit(
|
||||
"[bold cyan]迁移学习训练配置[/bold cyan]", border_style="cyan"))
|
||||
config_table = Table(show_header=True, header_style="bold magenta")
|
||||
config_table.add_column("Category", style="cyan")
|
||||
config_table.add_column("Parameter", style="green")
|
||||
config_table.add_column("Value", style="yellow")
|
||||
|
||||
config_table.add_row("迁移学习", "预训练检查点", args.pretrained_checkpoint)
|
||||
config_table.add_row("迁移学习", "冻结 context_encoder",
|
||||
str(args.freeze_context_encoder))
|
||||
config_table.add_row("数据", "训练数据路径", args.train_data_path)
|
||||
config_table.add_row("数据", "评估数据路径", args.eval_data_path)
|
||||
config_table.add_row("数据", "输出目录", args.output_dir)
|
||||
config_table.add_row("数据", "批次大小", str(args.batch_size))
|
||||
config_table.add_row("数据", "Worker数量", str(args.num_workers))
|
||||
config_table.add_row("模型", "MoE策略", args.moe_mode)
|
||||
config_table.add_row("模型", "编译优化", str(args.compile))
|
||||
config_table.add_row("训练", "训练轮数", str(args.num_epochs))
|
||||
config_table.add_row("训练", "学习率", f"{args.learning_rate:.2e}")
|
||||
config_table.add_row("训练", "最小学习率", f"{args.min_learning_rate:.2e}")
|
||||
config_table.add_row("训练", "权重衰减", str(args.weight_decay))
|
||||
config_table.add_row("训练", "热身比例", str(args.warmup_ratio))
|
||||
config_table.add_row("训练", "标签平滑", str(args.label_smoothing))
|
||||
config_table.add_row("训练", "梯度累积", str(args.grad_accum_steps))
|
||||
config_table.add_row("训练", "梯度裁剪", str(args.clip_grad_norm))
|
||||
config_table.add_row("训练", "混合精度", str(args.mixed_precision))
|
||||
|
||||
# ================================================================
|
||||
# 创建数据加载器 (逻辑与 trainer.py CLI 完全一致)
|
||||
# ================================================================
|
||||
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
|
||||
|
||||
is_train_preprocessed = is_preprocessed_data(args.train_data_path)
|
||||
is_eval_preprocessed = is_preprocessed_data(args.eval_data_path)
|
||||
|
||||
if is_train_preprocessed:
|
||||
train_dataset = PreProcessedDataset(args.train_data_path,
|
||||
max_cache_shards=2)
|
||||
pre_shuffled = train_dataset.metadata.get("pre_shuffled", False)
|
||||
shuffle_train = not pre_shuffled
|
||||
if args.max_iter_length > 0:
|
||||
capped_samples = min(len(train_dataset), args.max_iter_length)
|
||||
else:
|
||||
capped_samples = len(train_dataset)
|
||||
total_steps = (capped_samples // args.batch_size) * args.num_epochs
|
||||
train_num_workers = min(args.num_workers, 1)
|
||||
logger.info(
|
||||
f"Preprocessed dataset: {len(train_dataset):,} samples, "
|
||||
f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, "
|
||||
f"workers={train_num_workers}, steps={total_steps:,}")
|
||||
train_dataloader = create_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=train_num_workers,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=shuffle_train,
|
||||
)
|
||||
config_table.add_row("数据", "训练数据类型", "预处理数据")
|
||||
else:
|
||||
train_dataset = PinyinInputDataset(
|
||||
data_path=args.train_data_path,
|
||||
max_workers=-1,
|
||||
max_iter_length=args.max_iter_length,
|
||||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=2000000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40,
|
||||
5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
total_steps = int(args.max_iter_length *
|
||||
args.num_epochs / args.batch_size)
|
||||
train_dataloader = create_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
max_iter_length=args.max_iter_length,
|
||||
)
|
||||
config_table.add_row("数据", "训练数据类型", "流式数据")
|
||||
|
||||
if is_eval_preprocessed:
|
||||
eval_dataset = PreProcessedDataset(args.eval_data_path,
|
||||
max_cache_shards=1)
|
||||
eval_dataloader = create_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=0,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=False,
|
||||
)
|
||||
config_table.add_row("数据", "评估数据类型", "预处理数据")
|
||||
else:
|
||||
eval_dataset = PinyinInputDataset(
|
||||
data_path=args.eval_data_path,
|
||||
max_workers=-1,
|
||||
max_iter_length=args.batch_size * 64,
|
||||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=2000000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40,
|
||||
5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
eval_dataloader = create_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=2,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
max_iter_length=args.batch_size * 64,
|
||||
)
|
||||
config_table.add_row("数据", "评估数据类型", "流式数据")
|
||||
|
||||
config_table.add_row("数据", "总步数", str(total_steps))
|
||||
console.print(config_table)
|
||||
|
||||
# ================================================================
|
||||
# 创建模型并加载预训练权重
|
||||
# ================================================================
|
||||
console.print("[bold cyan]正在创建模型并加载预训练权重...[/bold cyan]")
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = InputMethodEngine(
|
||||
vocab_size=vocab_size,
|
||||
pinyin_vocab_size=pinyin_vocab_size,
|
||||
dim=dim,
|
||||
num_slots=num_slots,
|
||||
n_layers=n_layers,
|
||||
n_heads=n_heads,
|
||||
num_experts=num_experts,
|
||||
max_seq_len=max_seq_len,
|
||||
compile=args.compile,
|
||||
moe_mode=args.moe_mode,
|
||||
)
|
||||
model.to(device)
|
||||
|
||||
# 加载预训练权重
|
||||
pretrained_path = Path(args.pretrained_checkpoint)
|
||||
if not pretrained_path.exists():
|
||||
console.print(f"[red]❌ 预训练检查点不存在: {args.pretrained_checkpoint}[/red]")
|
||||
sys.exit(1)
|
||||
|
||||
checkpoint = torch.load(args.pretrained_checkpoint, map_location=device)
|
||||
if "model_state_dict" in checkpoint:
|
||||
pretrained_weights = checkpoint["model_state_dict"]
|
||||
else:
|
||||
pretrained_weights = checkpoint
|
||||
|
||||
missing_keys, unexpected_keys = model.load_state_dict(
|
||||
pretrained_weights, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
console.print(f"[yellow]⚠ 缺失的键 ({len(missing_keys)}): "
|
||||
f"{missing_keys[:5]}...[/yellow]")
|
||||
if unexpected_keys:
|
||||
console.print(f"[yellow]⚠ 多余的键 ({len(unexpected_keys)}): "
|
||||
f"{unexpected_keys[:5]}...[/yellow]")
|
||||
|
||||
console.print(f"[green]✓ 预训练权重加载完成[/green]")
|
||||
|
||||
# ================================================================
|
||||
# 冻结 context_encoder
|
||||
# ================================================================
|
||||
if args.freeze_context_encoder:
|
||||
console.print("[bold cyan]正在冻结 context_encoder...[/bold cyan]")
|
||||
frozen_count = 0
|
||||
trainable_count = 0
|
||||
for name, param in model.named_parameters():
|
||||
if name.startswith("context_encoder"):
|
||||
param.requires_grad = False
|
||||
frozen_count += param.numel()
|
||||
else:
|
||||
trainable_count += param.numel()
|
||||
|
||||
total_params = frozen_count + trainable_count
|
||||
console.print(f"[green]✓ context_encoder 已冻结[/green]")
|
||||
logger.info(
|
||||
f"冻结参数: {frozen_count:,} / {total_params:,} "
|
||||
f"({frozen_count / total_params * 100:.1f}%), "
|
||||
f"可训练参数: {trainable_count:,} / {total_params:,} "
|
||||
f"({trainable_count / total_params * 100:.1f}%)")
|
||||
else:
|
||||
logger.info("未冻结任何层,全模型参与训练")
|
||||
|
||||
# ================================================================
|
||||
# 保存配置
|
||||
# ================================================================
|
||||
config = {
|
||||
"pretrained_checkpoint": args.pretrained_checkpoint,
|
||||
"freeze_context_encoder": args.freeze_context_encoder,
|
||||
"train_data_path": args.train_data_path,
|
||||
"eval_data_path": args.eval_data_path,
|
||||
"output_dir": args.output_dir,
|
||||
"batch_size": args.batch_size,
|
||||
"num_epochs": args.num_epochs,
|
||||
"learning_rate": args.learning_rate,
|
||||
"min_learning_rate": args.min_learning_rate,
|
||||
"weight_decay": args.weight_decay,
|
||||
"warmup_ratio": args.warmup_ratio,
|
||||
"label_smoothing": args.label_smoothing,
|
||||
"grad_accum_steps": args.grad_accum_steps,
|
||||
"clip_grad_norm": args.clip_grad_norm,
|
||||
"eval_frequency": args.eval_frequency,
|
||||
"save_frequency": args.save_frequency,
|
||||
"mixed_precision": args.mixed_precision,
|
||||
"num_workers": args.num_workers,
|
||||
"use_tensorboard": args.tensorboard,
|
||||
"seed": args.seed,
|
||||
"compile": args.compile,
|
||||
"moe_mode": args.moe_mode,
|
||||
"total_steps": total_steps,
|
||||
"vocab_size": vocab_size,
|
||||
"pinyin_vocab_size": pinyin_vocab_size,
|
||||
"dim": dim,
|
||||
"num_slots": num_slots,
|
||||
"n_layers": n_layers,
|
||||
"n_heads": n_heads,
|
||||
"num_experts": num_experts,
|
||||
"max_seq_len": max_seq_len,
|
||||
"is_train_preprocessed": is_train_preprocessed,
|
||||
"is_eval_preprocessed": is_eval_preprocessed,
|
||||
}
|
||||
config_file = output_path / "training_config.json"
|
||||
with open(config_file, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Configuration saved to {config_file}")
|
||||
|
||||
# ================================================================
|
||||
# 创建 Trainer 并开始训练
|
||||
# ================================================================
|
||||
console.print("[bold cyan]正在创建训练器...[/bold cyan]")
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
total_steps=total_steps,
|
||||
output_dir=args.output_dir,
|
||||
num_epochs=args.num_epochs,
|
||||
learning_rate=args.learning_rate,
|
||||
min_learning_rate=args.min_learning_rate,
|
||||
weight_decay=args.weight_decay,
|
||||
warmup_ratio=args.warmup_ratio,
|
||||
label_smoothing=args.label_smoothing,
|
||||
grad_accum_steps=args.grad_accum_steps,
|
||||
clip_grad_norm=args.clip_grad_norm,
|
||||
eval_frequency=args.eval_frequency,
|
||||
save_frequency=args.save_frequency,
|
||||
mixed_precision=args.mixed_precision,
|
||||
device=device,
|
||||
use_tensorboard=args.tensorboard,
|
||||
status_file="training_status.json",
|
||||
)
|
||||
console.print("[green]✓ 训练器创建完成[/green]")
|
||||
|
||||
console.print("\n[bold cyan]开始训练...[/bold cyan]")
|
||||
console.print(f"开始时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
try:
|
||||
trainer.train(
|
||||
resume_from=None,
|
||||
reset_training_state=False,
|
||||
auto_resume=False, # 迁移学习从头开始,不自动恢复
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
console.print("[bold green]训练被终止[/bold green]")
|
||||
trainer.save_checkpoint("interrupted_model.pt")
|
||||
|
||||
console.print("[bold green]✓ 训练完成![/bold green]")
|
||||
console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
console.print(f"模型和日志保存在: {args.output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -3,6 +3,7 @@ from pathlib import Path
|
|||
from typing import Dict, Any
|
||||
import sys
|
||||
|
||||
|
||||
def modify_pinyin_statistics(file_path: Path) -> None:
|
||||
"""
|
||||
一次性修改拼音统计JSON文件。
|
||||
|
|
@ -13,7 +14,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
|||
"""
|
||||
# 1. 加载原数据
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data: Dict[str, Any] = json.load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"错误:文件不存在 {file_path}", file=sys.stderr)
|
||||
|
|
@ -34,7 +35,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
|||
"id": 0,
|
||||
"char": "",
|
||||
"pinyin": "",
|
||||
"count": original_zero_count + 1 # 原count + 1
|
||||
"count": original_zero_count + 1, # 原count + 1
|
||||
}
|
||||
|
||||
# 3.2 处理其他所有记录,键和id都+1
|
||||
|
|
@ -54,15 +55,16 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
|||
# 这里保持原时间戳不变,因为是一次性修改
|
||||
|
||||
# 写回文件,保持可读格式
|
||||
backup_path = file_path.with_suffix('.json.bak')
|
||||
backup_path = file_path.with_suffix(".json.bak")
|
||||
try:
|
||||
# 先备份原文件
|
||||
import shutil
|
||||
|
||||
shutil.copy2(file_path, backup_path)
|
||||
print(f"已创建备份: {backup_path}")
|
||||
|
||||
# 写入新数据
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"修改完成!")
|
||||
|
|
@ -78,15 +80,21 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
|||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
# 假设你的JSON文件在当前目录
|
||||
json_file = Path("./src/model/assets/pinyin_char_statistics.json")
|
||||
# JSON文件相对于项目根目录
|
||||
json_file = (
|
||||
Path(__file__).parent.parent
|
||||
/ "src"
|
||||
/ "model"
|
||||
/ "assets"
|
||||
/ "pinyin_char_statistics.json"
|
||||
)
|
||||
|
||||
# 执行修改
|
||||
modify_pinyin_statistics(json_file)
|
||||
|
||||
# 验证修改:读取并显示前几条记录
|
||||
print("\n验证前5条记录:")
|
||||
with open(json_file, 'r', encoding='utf-8') as f:
|
||||
with open(json_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for i in range(5):
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Visualize frequency distribution with ASCII plots
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def ascii_histogram(data, bins=20, width=60):
    """Render *data* as a horizontal ASCII histogram.

    Logarithmic bins are used when the data is non-negative and spans more
    than three orders of magnitude (largest value / smallest positive value
    > 1000); otherwise the bins are linear. Empty bins are omitted from the
    output.

    Args:
        data: Sequence of numbers to bin.
        bins: Number of bins; must be at least 1.
        width: Length, in characters, of the bar drawn for the fullest bin.

    Returns:
        One line per non-empty bin, joined with newlines, or "" when *data*
        is empty.

    Raises:
        ValueError: If *bins* is smaller than 1.
    """
    if not data:
        return ""
    if bins < 1:
        raise ValueError("bins must be a positive integer")

    min_val = min(data)
    max_val = max(data)
    # Only positive values can be placed on a log axis; zeros are left out of
    # the log histogram instead of crashing the ratio test below.
    positives = [val for val in data if val > 0]
    use_log_bins = (
        min_val >= 0 and bool(positives) and max_val / min(positives) > 1000
    )

    hist = [0] * bins
    if use_log_bins:
        log_min = math.log10(min(positives))
        log_max = math.log10(max_val)
        log_span = log_max - log_min  # > 0 because the ratio exceeds 1000
        bin_edges = [10 ** (log_min + i * log_span / bins) for i in range(bins + 1)]
        for val in positives:
            bin_idx = min(
                int((math.log10(val) - log_min) / log_span * bins), bins - 1
            )
            hist[bin_idx] += 1
        bin_labels = [
            f"{bin_edges[i]:.1e}-{bin_edges[i + 1]:.1e}" for i in range(bins)
        ]
    else:
        span = max_val - min_val
        bin_width = span / bins
        bin_edges = [min_val + i * bin_width for i in range(bins + 1)]
        for val in data:
            # A zero span means every value is identical: use the first bin
            # rather than dividing by zero.
            bin_idx = min(int((val - min_val) / span * bins), bins - 1) if span else 0
            hist[bin_idx] += 1
        bin_labels = [
            f"{bin_edges[i]:.1f}-{bin_edges[i + 1]:.1f}" for i in range(bins)
        ]

    max_count = max(hist)
    lines = [
        f"{label:20} | {'#' * int(count / max_count * width)} {count}"
        for label, count in zip(bin_labels, hist)
        if count
    ]
    return "\n".join(lines)
|
||||
|
||||
|
||||
def _log_scaled_bar(log_value, log_max, width=40):
    """Return a '#' bar covering log_value / log_max of *width* ('' if log_max <= 0)."""
    if log_max <= 0:
        return ""
    return "#" * int(log_value / log_max * width)


def main():
    """Print an ASCII analysis of the pinyin/character frequency statistics.

    Loads ``src/model/assets/pinyin_char_statistics.json`` (resolved relative
    to this script), then prints a histogram, a rank-frequency plot, an
    ID-vs-frequency sample, a Zipf's-law check, the frequency spectrum and a
    few summary statistics. Finally writes ``id_vs_freq.csv`` to the current
    working directory.

    Small or degenerate inputs (no entries, fewer than 50 entries, a maximum
    frequency of 1, a minimum frequency of 0) are reported instead of raising.
    """
    from collections import Counter

    json_path = (
        Path(__file__).parent.parent
        / "src"
        / "model"
        / "assets"
        / "pinyin_char_statistics.json"
    )
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    pairs = data.get("pairs", {})
    counts = [pair["count"] for pair in pairs.values() if pair.get("count") is not None]

    print("FREQUENCY DISTRIBUTION ANALYSIS")
    print("=" * 60)

    if not counts:
        # Every section below indexes into the counts; bail out early.
        print("\nNo frequency data found; nothing to analyze.")
        return

    print("\n1. ASCII Histogram (log bins):")
    print(ascii_histogram(counts, bins=20, width=60))

    # Rank-frequency plot in ASCII
    print("\n2. Rank-Frequency Relationship (Top 50):")
    counts_sorted_desc = sorted(counts, reverse=True)
    max_freq = counts_sorted_desc[0]
    log_max_freq = math.log(max_freq) if max_freq > 0 else 0.0

    # Slicing keeps this safe when fewer than 50 entries exist.
    for rank, freq in enumerate(counts_sorted_desc[:50], start=1):
        log_freq = math.log(freq) if freq > 0 else 0.0
        bar = _log_scaled_bar(log_freq, log_max_freq)
        print(f"Rank {rank:3}: {freq:12} {bar}")

    # ID vs Frequency plot (sampled)
    print("\n3. ID vs Frequency (sampled every 500 IDs):")
    id_to_count = {
        pair["id"]: pair["count"]
        for pair in pairs.values()
        if pair.get("id") is not None and pair.get("count") is not None
    }

    print("ID Frequency log10(freq)")
    if id_to_count:
        max_id = max(id_to_count)
        log10_max_freq = math.log10(max_freq) if max_freq > 0 else 0.0
        for char_id in range(0, max_id + 1, 500):
            if char_id in id_to_count:
                freq = id_to_count[char_id]
                log_freq = math.log10(freq) if freq > 0 else 0
                bar = _log_scaled_bar(log_freq, log10_max_freq)
                print(f"{char_id:6} {freq:10} {log_freq:6.2f} {bar}")

    # Zipf's law fit
    print("\n4. Zipf's Law Analysis:")
    print(" Rank * Frequency ≈ constant for Zipf's law")
    print(" Top 10 ranks:")
    top_ranked = list(enumerate(counts_sorted_desc[:10], start=1))
    for rank, freq in top_ranked:
        product = rank * freq
        print(f" Rank {rank}: {freq:12} rank*freq = {product:.3e}")

    # Check if product is roughly constant
    products = [rank * freq for rank, freq in top_ranked]
    avg_product = sum(products) / len(products)
    std_product = math.sqrt(
        sum((p - avg_product) ** 2 for p in products) / len(products)
    )
    # The products cover ranks 1..len(products), so label them accordingly.
    print(
        f" Average product (ranks 1-{len(products)}): "
        f"{avg_product:.3e} ± {std_product:.3e}"
    )
    if avg_product:
        print(f" Coefficient of variation: {std_product / avg_product * 100:.1f}%")
    else:
        print(" Coefficient of variation: n/a (average product is 0)")

    # Frequency spectrum
    freq_counter = Counter(counts)
    print("\n5. Frequency Spectrum (how many entries have each frequency):")
    print(" Frequency Count Cumulative")
    cum = 0
    for freq in sorted(freq_counter)[:20]:
        count = freq_counter[freq]
        cum += count
        print(f" {freq:10} {count:6} {cum:6}")

    # Summary statistics
    print("\n6. Key Statistics:")
    n = len(counts)
    lowest = counts_sorted_desc[-1]
    highest = counts_sorted_desc[0]
    print(f" Total entries: {n}")
    print(f" Min frequency: {lowest}")
    print(f" Max frequency: {highest}")
    if lowest > 0:
        print(f" Ratio max/min: {highest / lowest:.2e}")
    else:
        print(" Ratio max/min: undefined (min frequency is 0)")

    # NOTE: the index is taken from the descending-sorted list, so the value
    # printed for p is the entry at rank ~p*n counted from the most frequent.
    percentiles = [0.01, 0.1, 0.5, 0.9, 0.99]
    for p in percentiles:
        idx = int(p * n)
        value = counts_sorted_desc[idx]
        print(f" {p * 100:5.1f}th percentile: {value:12} (rank ~{idx})")

    # Save data for external plotting
    with open("id_vs_freq.csv", "w", encoding="utf-8") as f:
        f.write("id,frequency\n")
        f.writelines(
            f"{char_id},{id_to_count[char_id]}\n" for char_id in sorted(id_to_count)
        )
    print("\nData saved to id_vs_freq.csv for external plotting")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,189 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
数据分布查看脚本
|
||||
主要关注 labels 为:0-10, 200-210, 2900-2910, 5100-5110, 9520-9530 的分布情况
|
||||
|
||||
使用方法:
|
||||
uv run python src/analyze_data.py --data_path ./data/train.txt
|
||||
|
||||
注意:
|
||||
data_path 可以是:
|
||||
- 本地文本文件 (每行一个文本)
|
||||
- HuggingFace 数据集路径
|
||||
- 目录路径 (该目录下需要有 dataset_info.json 或类似配置文件)
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from model.dataset import PinyinInputDataset
|
||||
from model.query import QueryEngine
|
||||
|
||||
|
||||
_id2char_cache = {}
|
||||
|
||||
|
||||
def get_char_by_id(query_engine: QueryEngine, char_id: int) -> str:
    """Return the character for *char_id*, memoised in the module-level cache.

    ID 0 always maps to ``"<EOS>"``. Any other ID is looked up once through
    ``query_engine.query_by_id``; when that lookup yields nothing, the
    placeholder ``"<ID:n>"`` is cached and returned instead.
    """
    if char_id == 0:
        return "<EOS>"

    try:
        return _id2char_cache[char_id]
    except KeyError:
        pass

    info = query_engine.query_by_id(char_id)
    resolved = info.char if info else f"<ID:{char_id}>"
    _id2char_cache[char_id] = resolved
    return resolved
|
||||
|
||||
|
||||
def analyze_label_distribution(
    dataset: PinyinInputDataset,
    sample_size: int = 10000,
    query_engine: QueryEngine = None,
    num_workers: int = 16,
):
    """Print how often labels fall into a few fixed ID ranges, plus sample details.

    Iterates *dataset* one sample at a time through a ``DataLoader``, counts
    the labels that land in the target ranges, and keeps the first 200
    samples so that 20 of them (shuffled) can be printed in full.

    Args:
        dataset: Dataset whose batches expose the keys ``label``, ``prefix``,
            ``suffix``, ``pinyin`` and ``history_slot_ids``.
        sample_size: Maximum number of samples to inspect.
        query_engine: Optional engine used to turn label/history IDs into
            characters; when omitted, IDs are shown as-is.
        num_workers: Worker processes for the ``DataLoader`` (default 16,
            the value that was previously hard-coded).
    """
    target_ranges = [
        (0, 10),
        (200, 210),
        (2900, 2910),
        (5100, 5110),
        (9520, 9530),
    ]
    max_examples = 200  # size of the pool the printed samples are drawn from

    id_counts = defaultdict(int)
    all_examples = []
    total_count = 0

    dataloader = DataLoader(dataset, batch_size=1, num_workers=num_workers)

    for batch in dataloader:
        if total_count >= sample_size:
            break
        total_count += 1

        label = batch["label"].item()
        # The target ranges do not overlap, so a label is counted at most once.
        if any(start <= label <= end for start, end in target_ranges):
            id_counts[label] += 1

        if len(all_examples) >= max_examples:
            continue

        # batch_size=1: every field carries a single element at index 0.
        prefix = batch["prefix"][0]
        suffix = batch["suffix"][0]
        pinyin = batch["pinyin"][0]
        history = batch["history_slot_ids"][0].tolist()
        part4 = prefix.split("^")[0] if "^" in prefix else ""

        if query_engine:
            label_char = get_char_by_id(query_engine, label)
            history_chars = [get_char_by_id(query_engine, hid) for hid in history]
        else:
            label_char = f"<ID:{label}>"
            history_chars = history

        all_examples.append(
            {
                "label": label,
                "label_char": label_char,
                "prefix": prefix,
                "suffix": suffix,
                "pinyin": pinyin,
                "history": history,
                "history_chars": history_chars,
                "part4": part4,
            }
        )

    print(f"\n{'=' * 60}")
    print(f"采样总数: {total_count}")
    print(f"{'=' * 60}\n")

    print("Label ID 分布统计:")
    print("-" * 50)
    for start, end in target_ranges:
        print(f"\n区间 [{start:5d} - {end:5d}]:")
        for label_id in range(start, end + 1):
            count = id_counts[label_id]
            percentage = (count / total_count * 100) if total_count > 0 else 0
            bar = "█" * min(count, 50)  # cap the bar at 50 cells
            print(f" ID {label_id:5d}: {count:6d} ({percentage:.3f}%) {bar}")

    print("\n")
    print("=" * 80)
    print("随机抽取 20 个样本详情:")
    print("=" * 80)

    random.shuffle(all_examples)
    for idx, ex in enumerate(all_examples[:20], 1):
        print(f"\n样本 {idx}:")
        print(f" Label: {ex['label']} ({ex['label_char']})")
        print(f" Part4: {ex['part4']}")
        print(f" 光标前: {ex['prefix']}")
        print(f" 光标后: {ex['suffix']}")
        print(f" 拼音: {ex['pinyin']}")
        print(f" 历史槽位: {ex['history']}")
        print(f" 历史汉字: {ex['history_chars']}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: load a dataset and print its label distribution.

    Parses ``--data_path``, ``--sample_size`` and ``--max_workers``, builds a
    ``PinyinInputDataset`` limited to ``sample_size`` items, and hands it to
    ``analyze_label_distribution`` together with a loaded ``QueryEngine``.
    If the dataset cannot be constructed, a troubleshooting message is
    printed and the function returns without analysing anything.
    """
    import argparse

    default_data_path = os.path.expanduser("~/Data/corpus/CCI-Data/")
    usage_epilog = f"""
默认数据集路径: {default_data_path}

示例:
uv run python src/analyze_data.py --data_path {default_data_path} --sample_size 10000
uv run python src/analyze_data.py --data_path ./data/eval.txt --sample_size 5000
"""

    cli = argparse.ArgumentParser(
        description="分析数据集 label 分布",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_epilog,
    )
    cli.add_argument(
        "--data_path",
        type=str,
        default=default_data_path,
        help="数据集路径 (本地文件或HuggingFace路径)",
    )
    cli.add_argument("--sample_size", type=int, default=10000, help="采样大小")
    cli.add_argument("--max_workers", type=int, default=-1, help="DataLoader workers")
    opts = cli.parse_args()

    print(f"加载数据集: {opts.data_path}")
    print()

    try:
        dataset = PinyinInputDataset(
            data_path=opts.data_path,
            max_workers=opts.max_workers,
            max_iter_length=opts.sample_size,
        )
    except Exception as e:
        # Top-level boundary of a CLI tool: report the cause and give hints.
        failure_report = (
            f"\n错误: 无法加载数据集 '{opts.data_path}'",
            f"原因: {e}",
            "\n请确保:",
            " 1. 数据路径存在且正确",
            " 2. 如果是本地文本文件,每行应该是一个 JSON 对象或纯文本",
            " 3. 如果是 HuggingFace 数据集,路径应该正确",
        )
        for message in failure_report:
            print(message)
        return

    engine = QueryEngine()
    engine.load()
    analyze_label_distribution(
        dataset, sample_size=opts.sample_size, query_engine=engine
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"id": 0,
|
||||
"char": "",
|
||||
"pinyin": "",
|
||||
"count": 11067734826
|
||||
"count": 494748360
|
||||
},
|
||||
"1": {
|
||||
"id": 1,
|
||||
|
|
|
|||
|
|
@ -53,27 +53,7 @@ class PinyinLSTMEncoder(nn.Module):
|
|||
self.layer_norm = nn.LayerNorm(input_dim)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
Args:
|
||||
x: [batch, seq_len, input_dim] pinyin embeddings
|
||||
mask: [batch, seq_len] optional padding mask (True for valid, False for padding)
|
||||
Returns:
|
||||
output: [batch, seq_len, input_dim] 每位置的拼音编码
|
||||
"""
|
||||
total_len = x.size(1)
|
||||
|
||||
if mask is not None:
|
||||
lengths = mask.sum(dim=1).cpu().clamp(min=1)
|
||||
packed = nn.utils.rnn.pack_padded_sequence(
|
||||
x, lengths, batch_first=True, enforce_sorted=False
|
||||
)
|
||||
packed_out, (hidden, cell) = self.lstm(packed)
|
||||
output, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
packed_out, batch_first=True, total_length=total_len
|
||||
)
|
||||
else:
|
||||
output, (hidden, cell) = self.lstm(x)
|
||||
|
||||
output, _ = self.lstm(x)
|
||||
projected = self.proj(output)
|
||||
return self.layer_norm(projected)
|
||||
|
||||
|
|
@ -341,15 +321,37 @@ class CrossAttentionFusion(nn.Module):
|
|||
# 4. 专家混合层 (MoE Layer)
|
||||
# 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@torch.compiler.allow_in_graph
|
||||
def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_experts):
|
||||
output = torch.zeros_like(x_flat)
|
||||
for e in range(num_experts):
|
||||
mask = topk_indices == e
|
||||
idx, k_idx = mask.nonzero(as_tuple=True)
|
||||
if idx.numel() > 0:
|
||||
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
||||
output.index_add_(0, idx, (w * experts[e](x_flat[idx])).to(output.dtype))
|
||||
return output
|
||||
|
||||
|
||||
class MoELayer(nn.Module):
|
||||
def __init__(self, dim=512, num_experts=10, top_k=3, num_resblocks=8):
|
||||
"""
|
||||
moe_mode 支持三种策略:
|
||||
- "all": 计算全部专家,torch.compile 不断裂 (当前默认)
|
||||
- "sparse": 只计算被路由到的专家 (产生 graph break)
|
||||
- "sparse_allow_graph": 稀疏 MoE,通过 allow_in_graph 避免 graph break
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim=512, num_experts=10, top_k=3, num_resblocks=8, moe_mode="all"
|
||||
):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.dim = dim
|
||||
self.moe_mode = moe_mode
|
||||
|
||||
# Import Expert from your existing components
|
||||
# Assuming Expert class is defined as in components.py [2]
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
Expert(
|
||||
|
|
@ -362,48 +364,45 @@ class MoELayer(nn.Module):
|
|||
]
|
||||
)
|
||||
|
||||
# Gating Network [2]
|
||||
self.gate = nn.Linear(dim, num_experts)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
并行化 MoE 前向传播,完全兼容 torch.compile 和 AMP。
|
||||
|
||||
Args:
|
||||
x: [batch, seq_len, dim]
|
||||
Returns:
|
||||
out: [batch, seq_len, dim]
|
||||
"""
|
||||
B, L, D = x.shape
|
||||
num_tokens = B * L
|
||||
x_flat = x.view(num_tokens, D)
|
||||
|
||||
# 展平输入以便处理
|
||||
x_flat = x.view(num_tokens, D) # [B*L, D]
|
||||
gates = self.gate(x_flat)
|
||||
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)
|
||||
topk_weights = F.softmax(topk_weights, dim=-1)
|
||||
|
||||
# 1. 计算门控分数
|
||||
gates = self.gate(x_flat) # [B*L, num_experts]
|
||||
|
||||
# 2. 选择 Top-K 专家
|
||||
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K]
|
||||
|
||||
# 归一化权重
|
||||
topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K]
|
||||
|
||||
# 3. 并行计算所有专家(消除 Python 循环中的动态控制流)
|
||||
# torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量
|
||||
if self.moe_mode == "all":
|
||||
expert_outputs = torch.stack(
|
||||
[expert(x_flat) for expert in self.experts], dim=1
|
||||
) # [B*L, num_experts, D]
|
||||
)
|
||||
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D)
|
||||
selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)
|
||||
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)
|
||||
out_flat = weighted_outputs.sum(dim=1)
|
||||
|
||||
# 4. 使用 gather 选择对应专家的输出
|
||||
# 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D]
|
||||
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D]
|
||||
selected_outputs = torch.gather(
|
||||
expert_outputs, 1, indices_expanded
|
||||
) # [B*L, K, D]
|
||||
# 5. 加权求和
|
||||
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D]
|
||||
out_flat = weighted_outputs.sum(dim=1) # [B*L, D]
|
||||
elif self.moe_mode == "sparse":
|
||||
out_flat = torch.zeros_like(x_flat)
|
||||
for e in range(self.num_experts):
|
||||
mask = topk_indices == e
|
||||
idx, k_idx = mask.nonzero(as_tuple=True)
|
||||
if idx.numel() > 0:
|
||||
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
||||
out_flat.index_add_(
|
||||
0,
|
||||
idx,
|
||||
(w * self.experts[e](x_flat[idx])).to(out_flat.dtype),
|
||||
)
|
||||
|
||||
elif self.moe_mode == "sparse_allow_graph":
|
||||
out_flat = _sparse_moe_dispatch(
|
||||
x_flat, self.experts, topk_indices, topk_weights, self.num_experts
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown moe_mode: {self.moe_mode}")
|
||||
|
||||
# 恢复原始形状
|
||||
return out_flat.view(B, L, D)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,13 @@
|
|||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", message=".*pkg_resources.*")
|
||||
|
||||
import jieba
|
||||
import math
|
||||
import random
|
||||
import re
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -15,7 +20,6 @@ from torch.utils.data import IterableDataset
|
|||
|
||||
from .query import QueryEngine
|
||||
|
||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
||||
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
|
||||
CHAR_TO_ID["`"] = 27 # 显式添加反引号
|
||||
|
|
@ -23,6 +27,26 @@ CHAR_TO_ID["'"] = 28 # 显式添加引号
|
|||
CHAR_TO_ID["-"] = 29 # 显式添加短横
|
||||
|
||||
|
||||
jieba.setLogLevel(jieba.logging.INFO)
|
||||
|
||||
|
||||
def segment_text(text: str) -> List[str]:
    """Split *text* into words with jieba (HMM disabled) and return them as a list."""
    word_iter = jieba.cut(text, HMM=False)
    return list(word_iter)
|
||||
|
||||
|
||||
def build_word_boundaries(words: List[str]) -> List[Tuple[int, int]]:
    """Return the ``(start, end)`` character span of every word in order.

    Spans are computed by accumulating word lengths, so each word starts
    where the previous one ended and ``end`` is exclusive.
    """
    # Running total of lengths gives every word's end offset.
    ends = []
    offset = 0
    for token in words:
        offset += len(token)
        ends.append(offset)

    # Each start is the previous word's end; the first word starts at 0.
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))
|
||||
|
||||
|
||||
def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
||||
"""
|
||||
将拼音字符串转换为 ID 列表。
|
||||
|
|
@ -43,15 +67,31 @@ class PinyinInputDataset(IterableDataset):
|
|||
text_field: str = "text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size: int = 100000,
|
||||
retention_ratio: float = 0.5,
|
||||
retention_ratio: float = 0.8,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
merge_short_words_prob: float = 0.5,
|
||||
merge_max_short_words: int = 3,
|
||||
merge_max_total_chars: int = 6,
|
||||
low_freq_repeat: float = 50.0,
|
||||
high_freq_repeat: float = 0.1,
|
||||
data_kwargs: Optional[Dict] = None,
|
||||
target_labels: Optional[Set[int]] = None,
|
||||
):
|
||||
# 频率调整参数 (可根据需要调整)
|
||||
self.drop_start_freq = 30_000_000
|
||||
self.max_drop_prob = 0.8
|
||||
self.repeat_end_freq = 10_000
|
||||
self.max_repeat_expect = 50
|
||||
# 频率调整参数 - 幂律平滑方案
|
||||
self.min_freq = 109
|
||||
self.low_freq_repeat = low_freq_repeat
|
||||
self.high_freq_repeat = high_freq_repeat
|
||||
self.word_break_prob = 0.10
|
||||
self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
|
||||
self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
|
||||
self.merge_short_words_prob = merge_short_words_prob
|
||||
self.merge_max_short_words = merge_max_short_words
|
||||
self.merge_max_total_chars = merge_max_total_chars
|
||||
|
||||
self.data_kwargs = data_kwargs or {}
|
||||
self.target_labels = target_labels
|
||||
|
||||
jieba.initialize()
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
Path(str(files(__package__))) / "assets" / "tokenizer"
|
||||
|
|
@ -61,7 +101,9 @@ class PinyinInputDataset(IterableDataset):
|
|||
self.max_iter_length = max_iter_length
|
||||
self.max_seq_length = max_seq_length
|
||||
self.text_field = text_field
|
||||
self.dataset = load_dataset(data_path, split="train", streaming=True)
|
||||
load_kwargs = {"split": "train", "streaming": True}
|
||||
load_kwargs.update(self.data_kwargs)
|
||||
self.dataset = load_dataset(data_path, **load_kwargs)
|
||||
self.max_workers = max_workers
|
||||
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
|
||||
self.shuffle_buffer_size = shuffle_buffer_size
|
||||
|
|
@ -83,57 +125,48 @@ class PinyinInputDataset(IterableDataset):
|
|||
|
||||
# 提取每个样本的目标字符及其频率
|
||||
self.sample_freqs = self.query_engine.get_all_weights()
|
||||
self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0
|
||||
|
||||
# 计算幂律平滑参数
|
||||
if self.max_freq > self.min_freq:
|
||||
self.alpha = math.log(
|
||||
self.low_freq_repeat / self.high_freq_repeat
|
||||
) / math.log(self.max_freq / self.min_freq)
|
||||
self.C = self.low_freq_repeat * (self.min_freq**self.alpha)
|
||||
else:
|
||||
self.alpha = 0.0
|
||||
self.C = 1.0
|
||||
|
||||
def adjust_frequency(self, freq: int) -> int:
|
||||
"""削峰填谷 - 根据频率调整采样次数,0表示丢弃"""
|
||||
# 1. 削峰处理(高频字)
|
||||
if freq >= self.drop_start_freq:
|
||||
# 线性丢弃概率计算
|
||||
max_freq = max(self.sample_freqs) # 或使用预定义的全局最大值
|
||||
if max_freq == self.drop_start_freq:
|
||||
drop_prob = 0.0
|
||||
else:
|
||||
drop_prob = (
|
||||
self.max_drop_prob
|
||||
* (freq - self.drop_start_freq)
|
||||
/ (max_freq - self.drop_start_freq)
|
||||
)
|
||||
if random.random() < drop_prob:
|
||||
"""削峰填谷 - 根据频率调整采样次数,0表示丢弃
|
||||
使用幂律平滑方案:E(freq) = C × freq^(-α)
|
||||
保持频率排序关系,单个连续函数
|
||||
"""
|
||||
if freq <= 0:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
# 2. 填谷处理(低频字)
|
||||
elif freq <= self.repeat_end_freq:
|
||||
# 线性重复期望计算
|
||||
if freq <= self.min_freq:
|
||||
repeat_expect = self.max_repeat_expect
|
||||
else:
|
||||
if self.repeat_end_freq == self.min_freq:
|
||||
repeat_expect = 0
|
||||
else:
|
||||
repeat_expect = (
|
||||
self.max_repeat_expect
|
||||
* (self.repeat_end_freq - freq)
|
||||
/ (self.repeat_end_freq - self.min_freq)
|
||||
)
|
||||
# 使用泊松分布实现随机重复
|
||||
repeat_count = np.random.poisson(repeat_expect)
|
||||
# 计算期望采样次数
|
||||
expected = self.C * (freq ** (-self.alpha))
|
||||
|
||||
# 采样策略
|
||||
if expected >= 1.0:
|
||||
# 泊松分布重复
|
||||
repeat_count = np.random.poisson(expected)
|
||||
return max(1, repeat_count)
|
||||
|
||||
# 3. 中间频率字
|
||||
else:
|
||||
return 1
|
||||
# 伯努利采样:以概率expected返回1,否则返回0
|
||||
return 1 if random.random() < expected else 0
|
||||
|
||||
# 生成对应文本的拼音
|
||||
def generate_pinyin(self, text: str) -> List[str]:
|
||||
"""
|
||||
流式处理单条文本,转换为拼音列表。
|
||||
将文本转换为拼音列表。对整段文本调用 lazy_pinyin,
|
||||
利用 errors 回调确保一一对应,对生僻字从 QueryEngine 回退。
|
||||
|
||||
特性:
|
||||
1. 严格一一对应:len(result) == len(text)
|
||||
2. 高多音字准确率:利用 pypinyin 内部的词语分词能力
|
||||
3. 高性能:预分配内存,无多余对象创建
|
||||
2. 对 pypinyin 不认识的生僻字,回退到 QueryEngine 最高频读音
|
||||
3. 非汉字字符原样占位
|
||||
|
||||
Args:
|
||||
text: 输入字符串
|
||||
|
|
@ -144,58 +177,130 @@ class PinyinInputDataset(IterableDataset):
|
|||
if not text:
|
||||
return []
|
||||
|
||||
text_len = len(text)
|
||||
# 2. 预分配结果列表,初始化占位符。
|
||||
# 使用 None 或空字符串均可,这里用空字符串方便后续判断
|
||||
result: List[str] = [""] * text_len
|
||||
|
||||
# 3. 遍历所有连续汉字片段
|
||||
for match in _HANZI_RE.finditer(text):
|
||||
start_idx = match.start()
|
||||
hanzi_segment = match.group()
|
||||
|
||||
# 4. 核心转换:利用 pypinyin 的分词能力处理该片段
|
||||
# style=Style.NORMAL 获取不带声调的拼音
|
||||
pinyin_list = lazy_pinyin(hanzi_segment)
|
||||
|
||||
# 5. 健壮性兜底:
|
||||
# 正常情况下,pypinyin 返回的拼音数应等于汉字数。
|
||||
# 若不等(极罕见,如遇到特殊 Unicode 标点被误判为汉字),降级为单字转换
|
||||
if len(pinyin_list) != len(hanzi_segment):
|
||||
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
|
||||
|
||||
# 6. 直接通过索引填充到预分配的位置
|
||||
# 这比 list slicing assignment (result[start:end] = pinyin_list) 略快且更直观
|
||||
for i, py in enumerate(pinyin_list):
|
||||
result[start_idx + i] = py
|
||||
|
||||
# 7. 填充非汉字字符
|
||||
# 遍历原文,如果 result 对应位置为空,则填入原字符
|
||||
# 注意:对于纯汉字文本,这一步很快;对于混合文本,这是必要的
|
||||
for i, char in enumerate(text):
|
||||
if not result[i]:
|
||||
result[i] = char
|
||||
|
||||
def _fallback(chars):
|
||||
# lazy_pinyin 会把连续无拼音的字符聚合成一个字符串传入,
|
||||
# 必须逐字符处理,确保返回列表长度与输入字符数一致。
|
||||
result = []
|
||||
for char in chars:
|
||||
if self.query_engine.is_chinese_char(char):
|
||||
ids = self.query_engine.query_by_char(char, limit=1)
|
||||
if ids:
|
||||
result.append(ids[0][1])
|
||||
else:
|
||||
result.append(char)
|
||||
else:
|
||||
result.append(char)
|
||||
return result
|
||||
|
||||
# 生成需要预测汉字对应的拼音,并进行加强
|
||||
pinyin_list = lazy_pinyin(text, errors=_fallback)
|
||||
|
||||
# 防御性校验:若长度仍不匹配(极罕见),逐字回退
|
||||
if len(pinyin_list) != len(text):
|
||||
logger.warning(
|
||||
f"pinyin length mismatch: text_len={len(text)}, "
|
||||
f"pinyin_len={len(pinyin_list)}, text={text[:50]!r}"
|
||||
)
|
||||
pinyin_list = []
|
||||
for c in text:
|
||||
result = lazy_pinyin(c, errors=_fallback)
|
||||
pinyin_list.append(result[0] if result else c)
|
||||
|
||||
return pinyin_list
|
||||
|
||||
def get_mask_pinyin(
|
||||
self, text: str, pinyin_list: List[str]
|
||||
) -> Tuple[int, List[str]]:
|
||||
# 整词统一拼音风格,避免多字词完整拼音概率指数衰减
|
||||
style = random.random()
|
||||
cumulative = 0.0
|
||||
style_idx = 0
|
||||
for i, w in enumerate(self.py_style_weight):
|
||||
cumulative += w
|
||||
if style < cumulative:
|
||||
style_idx = i
|
||||
break
|
||||
|
||||
mask_pinyin = []
|
||||
for i in range(len(text)):
|
||||
if not self.query_engine.is_chinese_char(text[i]):
|
||||
break
|
||||
else:
|
||||
py = np.random.choice(
|
||||
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
|
||||
p=self.py_style_weight,
|
||||
)
|
||||
full_py = pinyin_list[i]
|
||||
if style_idx == 0:
|
||||
py = full_py
|
||||
elif style_idx == 1:
|
||||
py = to_initials(full_py)
|
||||
if py == "":
|
||||
py = pinyin_list[i][0]
|
||||
py = full_py[0]
|
||||
else:
|
||||
py = full_py[0]
|
||||
mask_pinyin.append(py)
|
||||
return len(mask_pinyin), mask_pinyin
|
||||
|
||||
def _compute_pinyin_ids(self, pinyin_str: str) -> torch.Tensor:
    """Encode a pinyin string as a length-24 tensor of ids.

    The ids come from ``text_to_pinyin_ids``; shorter sequences are
    right-padded with 0 and longer ones are cut after 24 entries.

    Args:
        pinyin_str: Pinyin text to encode.

    Returns:
        A ``torch.long`` tensor holding exactly 24 ids.
    """
    raw_ids = text_to_pinyin_ids(pinyin_str)
    # A negative repeat count yields an empty list, so one expression
    # covers both the padding and the truncation case.
    fixed_ids = raw_ids[:24] + [0] * (24 - len(raw_ids))
    return torch.tensor(fixed_ids, dtype=torch.long)
|
||||
|
||||
def _build_single_sample(
|
||||
self,
|
||||
label: int,
|
||||
history: list,
|
||||
text: str,
|
||||
word_start: int,
|
||||
word_end: int,
|
||||
part2: str,
|
||||
pinyin_ids: torch.Tensor,
|
||||
words: list,
|
||||
) -> dict:
|
||||
"""构造单条样本,每次调用都会重新随机采样上下文"""
|
||||
|
||||
# part1 长度:高斯分布 N(36, 6^2),截断 [0, min(48, word_start)]
|
||||
part1_len = min(max(int(random.gauss(36, 6)), 0), 48, word_start)
|
||||
part1 = text[word_start - part1_len : word_start]
|
||||
|
||||
# part3:每次重新 roll
|
||||
part3 = ""
|
||||
if random.random() > 0.7:
|
||||
part3 = text[word_end : word_end + random.randint(1, 16)]
|
||||
|
||||
# part4:每次重新 roll
|
||||
part4 = ""
|
||||
if random.random() > 0.7 and words:
|
||||
num_words = random.randint(1, 3)
|
||||
selected_words = random.sample(words, min(num_words, len(words)))
|
||||
part4 = "|".join(selected_words)
|
||||
|
||||
encoded = self.tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
part3,
|
||||
max_length=self.max_seq_length,
|
||||
truncation=True,
|
||||
return_token_type_ids=True,
|
||||
)
|
||||
|
||||
# 确保 history 长度为 8
|
||||
hist = list(history)
|
||||
if len(hist) > 8:
|
||||
hist = hist[-8:]
|
||||
while len(hist) < 8:
|
||||
hist.append(0)
|
||||
|
||||
return {
|
||||
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
|
||||
"token_type_ids": torch.tensor(encoded["token_type_ids"], dtype=torch.long),
|
||||
"attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
|
||||
"label": torch.tensor([label], dtype=torch.long),
|
||||
"history_slot_ids": torch.tensor(hist, dtype=torch.long),
|
||||
"prefix": f"{part4}^{part1}",
|
||||
"suffix": part3,
|
||||
"pinyin": part2,
|
||||
"pinyin_ids": pinyin_ids,
|
||||
}
|
||||
|
||||
def __iter__(self):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
if worker_info is not None:
|
||||
|
|
@ -208,219 +313,295 @@ class PinyinInputDataset(IterableDataset):
|
|||
random.seed(seed % (2**32))
|
||||
np.random.seed(seed % (2**32))
|
||||
|
||||
# 安全检查:如果worker_id >= num_workers,则该worker不应该工作
|
||||
# 这可能发生在self.max_workers小于实际worker数量时
|
||||
if worker_id >= num_workers:
|
||||
return # 产生空迭代器
|
||||
return
|
||||
|
||||
# 使用局部变量存储分片数据集,避免竞争条件
|
||||
worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
|
||||
try:
|
||||
worker_dataset = self.dataset.shard(
|
||||
num_shards=num_workers, index=worker_id
|
||||
)
|
||||
except (IndexError, ValueError):
|
||||
worker_dataset = self.dataset
|
||||
|
||||
# 计算每个worker的配额
|
||||
# 将 max_iter_length 转换为整数以确保整数除法
|
||||
total_quota = int(self.max_iter_length)
|
||||
base_quota = total_quota // num_workers
|
||||
remainder = total_quota % num_workers
|
||||
|
||||
# 最后一个worker处理剩余的样本(如果有余数)
|
||||
if worker_id == num_workers - 1:
|
||||
worker_quota = base_quota + remainder
|
||||
else:
|
||||
worker_quota = base_quota
|
||||
else:
|
||||
# 单worker情况,使用全部配额
|
||||
worker_quota = int(self.max_iter_length)
|
||||
num_workers = 1
|
||||
worker_dataset = self.dataset # 不使用分片
|
||||
worker_dataset = self.dataset
|
||||
|
||||
# 每个worker有自己的迭代计数器
|
||||
current_iter_index = 0
|
||||
|
||||
batch_samples = []
|
||||
for sample in worker_dataset:
|
||||
# 检查是否达到最大迭代次数
|
||||
if current_iter_index >= worker_quota:
|
||||
break
|
||||
|
||||
text = sample.get(self.text_field, "")
|
||||
if text:
|
||||
if not text:
|
||||
continue
|
||||
|
||||
words = segment_text(text)
|
||||
word_boundaries = build_word_boundaries(words)
|
||||
pinyin_list = self.generate_pinyin(text)
|
||||
for i in range(len(text)):
|
||||
# 在开始处理每个字符前检查配额
|
||||
if current_iter_index >= worker_quota:
|
||||
break
|
||||
|
||||
labels = []
|
||||
# 如果text[i]不在字符库中,则跳过
|
||||
# 当i小于48时候,则将part1取text[0:i]
|
||||
# 当i大于48时候,则将part1取text[i-48:i]
|
||||
if not self.query_engine.is_chinese_char(text[i]):
|
||||
continue
|
||||
if i < 48:
|
||||
part1 = text[0:i]
|
||||
else:
|
||||
part1 = text[i - 48 : i]
|
||||
idx = 0
|
||||
while idx < len(word_boundaries):
|
||||
word_start, word_end = word_boundaries[idx]
|
||||
|
||||
# 方案C:提前检查从位置i开始连续有多少个字符在词库中
|
||||
max_valid_len = 0
|
||||
for j in range(i, min(i + 8, len(text))):
|
||||
if self.query_engine.is_chinese_char(text[j]):
|
||||
max_valid_len += 1
|
||||
else:
|
||||
break
|
||||
char_positions = []
|
||||
for i in range(word_start, word_end):
|
||||
if self.query_engine.is_chinese_char(text[i]):
|
||||
char_positions.append(i)
|
||||
|
||||
# 如果没有可用字符,跳过
|
||||
if max_valid_len == 0:
|
||||
if not char_positions:
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
# 首先取随机值pinyin_len(1-8),pinyin_len取值呈高斯分布,最大概率取3
|
||||
# 获取text[i + pinyin_len]字符,如果无法获取所指向的后,如果pinyin_len
|
||||
# part2的长度为x,取pinyin_list[i:i+pinyin_len],为part2
|
||||
# 但是需要注意边界条件
|
||||
target_len = np.random.choice(
|
||||
range(1, 9), p=[0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
|
||||
)
|
||||
# 根据实际可用长度调整
|
||||
pinyin_len = min(target_len, max_valid_len)
|
||||
word_len_chars = len(char_positions)
|
||||
|
||||
py_end = min(i + pinyin_len, len(text))
|
||||
pinyin_len, part2 = self.get_mask_pinyin(
|
||||
text[i:py_end], pinyin_list[i:py_end]
|
||||
merge_end_idx = idx + 1
|
||||
if word_len_chars <= 2:
|
||||
accumulated_positions = list(char_positions)
|
||||
accumulated_count = 1
|
||||
next_idx = idx + 1
|
||||
|
||||
while next_idx < len(word_boundaries):
|
||||
ns, ne = word_boundaries[next_idx]
|
||||
next_positions = []
|
||||
for i in range(ns, ne):
|
||||
if self.query_engine.is_chinese_char(text[i]):
|
||||
next_positions.append(i)
|
||||
next_len = len(next_positions)
|
||||
|
||||
if next_len == 0 or next_len > 2:
|
||||
break
|
||||
if (
|
||||
len(accumulated_positions) + next_len
|
||||
> self.merge_max_total_chars
|
||||
):
|
||||
break
|
||||
if accumulated_count + 1 > self.merge_max_short_words:
|
||||
break
|
||||
if random.random() > self.merge_short_words_prob:
|
||||
break
|
||||
|
||||
accumulated_positions.extend(next_positions)
|
||||
accumulated_count += 1
|
||||
next_idx += 1
|
||||
|
||||
if accumulated_count > 1:
|
||||
char_positions = accumulated_positions
|
||||
word_len_chars = len(char_positions)
|
||||
merge_end_idx = next_idx
|
||||
word_start = word_boundaries[idx][0]
|
||||
word_end = word_boundaries[next_idx - 1][1]
|
||||
|
||||
should_break = (
|
||||
word_len_chars > 1 and random.random() < self.word_break_prob
|
||||
)
|
||||
|
||||
split_char = np.random.choice(
|
||||
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
|
||||
)
|
||||
|
||||
part2 = split_char.join(part2)
|
||||
pinyin_ids = text_to_pinyin_ids(part2)
|
||||
len_py = len(pinyin_ids)
|
||||
if len_py < 24:
|
||||
pinyin_ids.extend([0] * (24 - len_py))
|
||||
if should_break:
|
||||
break_pos = random.randint(1, word_len_chars - 1)
|
||||
else:
|
||||
pinyin_ids = pinyin_ids[:24]
|
||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
|
||||
break_pos = word_len_chars
|
||||
|
||||
# part3为文本,大概率(0.70)为空
|
||||
# 不为空则是i+pinyin_len所指向的字符以及所指向字符后x个字符
|
||||
# x为1-16中的任意整数,取值平均分布
|
||||
part3 = ""
|
||||
if random.random() > 0.7:
|
||||
part3 = text[
|
||||
i + pinyin_len : i
|
||||
+ pinyin_len
|
||||
+ np.random.choice(range(1, 17))
|
||||
]
|
||||
# ========== Phase 1: 前缀/整词 ==========
|
||||
prefix_positions = char_positions[:break_pos]
|
||||
prefix_text = "".join(text[i] for i in prefix_positions)
|
||||
prefix_pinyin = [pinyin_list[i] for i in prefix_positions]
|
||||
|
||||
_, mask_pinyin = self.get_mask_pinyin(prefix_text, prefix_pinyin)
|
||||
r = random.random()
|
||||
if r < 0.9:
|
||||
split_char = ""
|
||||
elif r < 0.94:
|
||||
split_char = "`"
|
||||
elif r < 0.98:
|
||||
split_char = "'"
|
||||
else:
|
||||
split_char = "-"
|
||||
part2 = split_char.join(mask_pinyin)
|
||||
pinyin_ids = self._compute_pinyin_ids(part2)
|
||||
|
||||
# part4为文本,0.30的概率为空
|
||||
# 不为空则为1-5个连续字符串
|
||||
# 连续字符串的取值方法为:随机从字符库中取一个字符,以及该字符后x个字符
|
||||
# x为2-6中的任意整数,取值平均分布
|
||||
# 使用|将part4中的字符串连接起来
|
||||
part4 = ""
|
||||
if random.random() > 0.7:
|
||||
# 生成1-5个连续字符串
|
||||
num_strings = random.randint(1, 5)
|
||||
string_list = []
|
||||
for _ in range(num_strings):
|
||||
# 随机选择起始位置
|
||||
start_pos = random.randint(0, len(text) - 1)
|
||||
# 随机选择x的值(2-6)
|
||||
x = random.randint(2, 6)
|
||||
# 获取连续字符串
|
||||
end_pos = min(start_pos + x + 1, len(text))
|
||||
string_list.append(text[start_pos:end_pos])
|
||||
# 用|连接所有字符串
|
||||
part4 = "|".join(string_list)
|
||||
try:
|
||||
labels = [
|
||||
self.query_engine.get_char_info_by_char_pinyin(c, p).id
|
||||
for c, p in zip(
|
||||
text[i : i + pinyin_len],
|
||||
pinyin_list[i : i + pinyin_len],
|
||||
)
|
||||
self.query_engine.get_char_info_by_char_pinyin(
|
||||
text[i], pinyin_list[i]
|
||||
).id
|
||||
for i in prefix_positions
|
||||
]
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"e: {e}, (text, pinyin): {text[i : i + pinyin_len]} - {pinyin_list[i : i + pinyin_len]}"
|
||||
f"e: {e}, (text, pinyin): {prefix_text} - {prefix_pinyin}"
|
||||
)
|
||||
idx = merge_end_idx
|
||||
continue
|
||||
if random.random() <= 0.1:
|
||||
|
||||
# 整词末尾 10% 概率追加 EOS(破词前缀不加)
|
||||
if not should_break and random.random() <= 0.1:
|
||||
labels.append(0)
|
||||
|
||||
encoded = self.tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
part3,
|
||||
max_length=self.max_seq_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_token_type_ids=True,
|
||||
)
|
||||
samples = []
|
||||
# 历史槽位长度权重:增加长历史采样比例
|
||||
# 目标分布: H=0-2占45%, H=3-8占55%
|
||||
history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
|
||||
|
||||
# 修复变量名冲突:将内层循环变量i重命名为label_idx
|
||||
# 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
|
||||
processed_history = []
|
||||
for label_idx, label in enumerate(labels):
|
||||
base_repeats = self.adjust_frequency(label)
|
||||
# 根据历史槽位长度调整采样次数
|
||||
base_repeats = self.adjust_frequency(
|
||||
self.sample_freqs.get(label, 0)
|
||||
)
|
||||
if base_repeats == 0:
|
||||
processed_history.append(label)
|
||||
continue
|
||||
if (
|
||||
self.target_labels is not None
|
||||
and label not in self.target_labels
|
||||
):
|
||||
processed_history.append(label)
|
||||
continue
|
||||
|
||||
weight = (
|
||||
history_weights[label_idx]
|
||||
if label_idx < len(history_weights)
|
||||
self._history_weights[label_idx]
|
||||
if label_idx < len(self._history_weights)
|
||||
else 3.0
|
||||
)
|
||||
repeats = max(1, int(base_repeats * weight))
|
||||
|
||||
# 历史槽位:同一拼音序列中已确认的字符(模拟用户逐步确认过程)
|
||||
masked_labels = labels[:label_idx]
|
||||
len_l = len(masked_labels)
|
||||
masked_labels.extend([0] * (8 - len_l))
|
||||
|
||||
samples.extend(
|
||||
[
|
||||
{
|
||||
"input_ids": encoded["input_ids"],
|
||||
"token_type_ids": encoded["token_type_ids"],
|
||||
"attention_mask": encoded["attention_mask"],
|
||||
"label": torch.tensor([label], dtype=torch.long),
|
||||
"history_slot_ids": torch.tensor(
|
||||
masked_labels, dtype=torch.long
|
||||
),
|
||||
"prefix": f"{part4}^{part1}",
|
||||
"suffix": part3,
|
||||
"pinyin": part2,
|
||||
"pinyin_ids": pinyin_ids,
|
||||
}
|
||||
]
|
||||
* repeats
|
||||
for _ in range(repeats):
|
||||
sample = self._build_single_sample(
|
||||
label=label,
|
||||
history=processed_history,
|
||||
text=text,
|
||||
word_start=word_start,
|
||||
word_end=word_end,
|
||||
part2=part2,
|
||||
pinyin_ids=pinyin_ids,
|
||||
words=words,
|
||||
)
|
||||
batch_samples.append(sample)
|
||||
|
||||
# 添加到缓冲区
|
||||
batch_samples.extend(samples)
|
||||
processed_history.append(label)
|
||||
|
||||
# ========== Phase 2: 破词续接 ==========
|
||||
if should_break and break_pos < word_len_chars:
|
||||
cont_start = char_positions[break_pos]
|
||||
|
||||
# 续接目标:从断点开始,可延伸到后续词,遇到非汉字停止
|
||||
cont_r = random.random()
|
||||
cont_probs = self.cont_length_probs
|
||||
cont_cumulative = 0.0
|
||||
target_len = 4
|
||||
for cont_len, cont_p in enumerate(cont_probs):
|
||||
cont_cumulative += cont_p
|
||||
if cont_r < cont_cumulative:
|
||||
target_len = cont_len + 1
|
||||
break
|
||||
cont_positions = []
|
||||
pos = cont_start
|
||||
while len(cont_positions) < target_len and pos < len(text):
|
||||
if self.query_engine.is_chinese_char(text[pos]):
|
||||
cont_positions.append(pos)
|
||||
else:
|
||||
break
|
||||
pos += 1
|
||||
|
||||
if not cont_positions:
|
||||
continue
|
||||
|
||||
cont_text = "".join(text[i] for i in cont_positions)
|
||||
cont_pinyin = [pinyin_list[i] for i in cont_positions]
|
||||
|
||||
_, mask_pinyin_cont = self.get_mask_pinyin(cont_text, cont_pinyin)
|
||||
r2 = random.random()
|
||||
if r2 < 0.9:
|
||||
split_char_cont = ""
|
||||
elif r2 < 0.94:
|
||||
split_char_cont = "`"
|
||||
elif r2 < 0.98:
|
||||
split_char_cont = "'"
|
||||
else:
|
||||
split_char_cont = "-"
|
||||
part2_cont = split_char_cont.join(mask_pinyin_cont)
|
||||
pinyin_ids_cont = self._compute_pinyin_ids(part2_cont)
|
||||
|
||||
try:
|
||||
cont_labels = [
|
||||
self.query_engine.get_char_info_by_char_pinyin(
|
||||
text[i], pinyin_list[i]
|
||||
).id
|
||||
for i in cont_positions
|
||||
]
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"e: {e}, (text, pinyin): {cont_text} - {cont_pinyin}"
|
||||
)
|
||||
idx = merge_end_idx
|
||||
continue
|
||||
|
||||
# 续接末尾 10% 概率追加 EOS
|
||||
if random.random() <= 0.1:
|
||||
cont_labels.append(0)
|
||||
|
||||
# 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
|
||||
cont_processed_history = []
|
||||
cont_end = cont_positions[-1] + 1
|
||||
for label_idx, label in enumerate(cont_labels):
|
||||
base_repeats = self.adjust_frequency(
|
||||
self.sample_freqs.get(label, 0)
|
||||
)
|
||||
if base_repeats == 0:
|
||||
cont_processed_history.append(label)
|
||||
continue
|
||||
if (
|
||||
self.target_labels is not None
|
||||
and label not in self.target_labels
|
||||
):
|
||||
cont_processed_history.append(label)
|
||||
continue
|
||||
|
||||
weight = (
|
||||
self._history_weights[label_idx]
|
||||
if label_idx < len(self._history_weights)
|
||||
else 3.0
|
||||
)
|
||||
repeats = max(1, int(base_repeats * weight))
|
||||
|
||||
for _ in range(repeats):
|
||||
sample = self._build_single_sample(
|
||||
label=label,
|
||||
history=cont_processed_history,
|
||||
text=text,
|
||||
word_start=cont_start,
|
||||
word_end=cont_end,
|
||||
part2=part2_cont,
|
||||
pinyin_ids=pinyin_ids_cont,
|
||||
words=words,
|
||||
)
|
||||
batch_samples.append(sample)
|
||||
|
||||
cont_processed_history.append(label)
|
||||
|
||||
idx = merge_end_idx
|
||||
|
||||
# 处理shuffle buffer - 单缓冲区半保留方案
|
||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||
# 全量打乱缓冲区
|
||||
indices = np.random.permutation(len(batch_samples))
|
||||
|
||||
# 计算实际保留大小(不超过缓冲区大小)
|
||||
actual_retention = min(self.retention_size, len(batch_samples))
|
||||
|
||||
# 计算输出数量
|
||||
output_count = len(batch_samples) - actual_retention
|
||||
|
||||
# 输出前output_count个样本
|
||||
for i in range(output_count):
|
||||
if current_iter_index >= worker_quota:
|
||||
# 配额用完,清空缓冲区并返回
|
||||
batch_samples = []
|
||||
return
|
||||
yield batch_samples[indices[i]]
|
||||
current_iter_index += 1
|
||||
|
||||
# 保留后actual_retention个样本(不清空缓冲区)
|
||||
retained_samples = [
|
||||
batch_samples[idx] for idx in indices[output_count:]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,169 @@
|
|||
"""
|
||||
ONNX导出专用组件
|
||||
|
||||
为了支持ONNX导出,对原始组件进行修改:
|
||||
1. 移除packed sequence操作
|
||||
2. 处理动态形状问题
|
||||
3. 确保所有操作符都ONNX兼容
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ExportPinyinLSTMEncoder(nn.Module):
    """Bidirectional LSTM pinyin encoder written for ONNX export.

    The input sequence is fed to the LSTM as-is (no packed sequences); every
    time step is then projected back to ``input_dim`` and layer-normalised.
    """

    def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
        """Create the LSTM, the output projection and the layer norm.

        Args:
            input_dim: Feature size of the input and of the returned tensor.
            hidden_dim: Per-direction LSTM hidden size; ``input_dim // 2``
                is used when it is ``None``.
            num_layers: Number of stacked LSTM layers.
            dropout: Inter-layer LSTM dropout, only passed to the LSTM when
                ``num_layers > 1``.
        """
        super().__init__()
        if hidden_dim is None:
            hidden_dim = input_dim // 2

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        # nn.LSTM warns about dropout on a single layer, so zero it there.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=inter_layer_dropout,
        )

        # Both directions are concatenated, hence 2 * hidden_dim inputs.
        self.proj = nn.Linear(2 * hidden_dim, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x, mask=None):
        """Encode ``x`` step by step.

        Args:
            x: Batch-first input of feature size ``input_dim``.
            mask: Accepted for signature compatibility but not used; padded
                steps are processed like any other step.

        Returns:
            The layer-normalised projection of the LSTM output, with the
            same leading dimensions as ``x`` and feature size ``input_dim``.
        """
        sequence_features, _ = self.lstm(x)
        return self.layer_norm(self.proj(sequence_features))
|
||||
|
||||
|
||||
class ExportContextEncoder(nn.Module):
    """Context encoder variant intended for ONNX export.

    Combines pretrained text embeddings with learned position embeddings,
    runs them through a Transformer encoder, and encodes the pinyin ids with
    ``ExportPinyinLSTMEncoder``.
    """

    def __init__(
        self, vocab_size, pinyin_vocab_size, dim=512, n_layers=4, n_heads=4, max_len=128
    ):
        """Build all sub-modules.

        Args:
            vocab_size: Accepted for signature compatibility; not used here
                because the text embedding comes from the pretrained model.
            pinyin_vocab_size: Number of rows of the pinyin embedding.
            dim: Model width shared by all sub-modules.
            n_layers: Number of Transformer encoder layers.
            n_heads: Number of attention heads per layer.
            max_len: Number of learned positions; longer inputs are
                truncated to this length in ``forward``.
        """
        super().__init__()
        self.dim = dim
        self.max_len = max_len

        # Imported lazily so that importing this module does not require
        # modelscope; from_pretrained fetches/loads the pretrained backbone.
        from modelscope import AutoModel

        self.text_emb = AutoModel.from_pretrained(
            "iic/nlp_structbert_backbone_lite_std"
        ).embeddings

        self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
        self.pos_emb = nn.Embedding(max_len, dim)
        self.pinyin_pooling = ExportPinyinLSTMEncoder(dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=n_heads,
            dim_feedforward=dim * 4,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.ln = nn.LayerNorm(dim)

    def forward(self, text_ids, pinyin_ids, mask=None):
        """Encode the text context and the pinyin sequence.

        Args:
            text_ids: Token ids, ``[B, seq_len]``.
            pinyin_ids: Pinyin ids, ``[B, pinyin_len]``.
            mask: Optional attention mask aligned with ``text_ids``; entries
                equal to 0 are treated as padding.

        Returns:
            ``(H, P)``: the encoded context ``[B, min(seq_len, max_len), dim]``
            and the encoded pinyin ``[B, pinyin_len, dim]``.
        """
        text_emb = self.text_emb(text_ids)  # [B, seq_len, dim]

        # Only max_len positions are learned, so longer inputs are cut.
        seq_len = text_emb.size(1)
        if seq_len > self.max_len:
            seq_len = self.max_len

        pos_ids = torch.arange(
            seq_len, device=text_ids.device, dtype=torch.long
        ).unsqueeze(0)

        # Always slice to seq_len. The previous `seq_len < self.max_len`
        # test skipped the slice exactly when truncation was required, so
        # over-long inputs failed with a shape mismatch.
        x = text_emb[:, :seq_len, :] + self.pos_emb(pos_ids)

        # True marks padding; the mask is truncated together with the input
        # so that it keeps matching x.
        src_mask = None
        if mask is not None:
            src_mask = (mask[:, :seq_len] == 0).to(torch.bool)

        H = self.ln(self.transformer(x, src_key_padding_mask=src_mask))

        # The pinyin branch deliberately receives no mask.
        pinyin_emb = self.pinyin_emb(pinyin_ids)  # [B, pinyin_len, dim]
        P = self.pinyin_pooling(pinyin_emb)  # [B, pinyin_len, dim]

        return H, P
|
||||
|
||||
|
||||
def create_export_context_encoder(original_context_encoder):
    """Create an ``ExportContextEncoder`` mirroring an existing encoder.

    The export encoder is built with hyper-parameters read from
    ``original_context_encoder`` and then receives every weight whose
    state-dict key exists in both models.

    Args:
        original_context_encoder: Trained encoder exposing ``dim``,
            ``pinyin_emb``, ``pos_emb`` and ``transformer`` (and optionally
            ``vocab_size``).

    Returns:
        The new ``ExportContextEncoder`` with the shared weights loaded.

    Raises:
        RuntimeError: Raised by ``load_state_dict`` if a shared key has a
            different shape in the two models.
    """
    first_layer = original_context_encoder.transformer.layers[0]
    config = {
        "vocab_size": getattr(original_context_encoder, "vocab_size", 10019),
        "pinyin_vocab_size": original_context_encoder.pinyin_emb.num_embeddings,
        "dim": original_context_encoder.dim,
        "n_layers": original_context_encoder.transformer.num_layers,
        "n_heads": first_layer.self_attn.num_heads,
        "max_len": original_context_encoder.pos_emb.num_embeddings,
    }

    export_encoder = ExportContextEncoder(**config)

    source_state = original_context_encoder.state_dict()
    export_state = export_encoder.state_dict()

    # Overwrite every entry the original model also has; keys that exist
    # only in the export model keep their freshly initialised values.
    shared_keys = [key for key in export_state if key in source_state]
    for key in shared_keys:
        export_state[key] = source_state[key]

    # No separate handling of pos_emb.weight is needed: max_len is taken
    # from the original pos_emb, so the shapes agree and the loop above has
    # already copied it. The former extra copy ran after both entries
    # aliased the same tensor and therefore changed nothing.
    export_encoder.load_state_dict(export_state)

    return export_encoder
|
||||
|
|
@ -0,0 +1,291 @@
|
|||
"""
|
||||
ONNX导出模型定义
|
||||
|
||||
定义两个子模型用于ONNX导出:
|
||||
1. ContextEncoderExport: 输入文本和拼音,输出上下文编码
|
||||
2. DecoderExport: 输入上下文编码、拼音编码和槽位历史,输出logits
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# 修补modelscope以支持AutoModel导入
|
||||
import sys
|
||||
import modelscope
|
||||
|
||||
# Compatibility shim: when the installed modelscope does not expose
# ``AutoModel``, alias it to ``Model`` so that ``from modelscope import
# AutoModel`` keeps working in the modules imported below.
if not hasattr(modelscope, "AutoModel"):
    from modelscope import Model

    setattr(modelscope, "AutoModel", Model)
    setattr(sys.modules["modelscope"], "AutoModel", Model)
|
||||
|
||||
from .components import (
|
||||
ContextEncoder,
|
||||
CrossAttentionFusion,
|
||||
MoELayer,
|
||||
SlotMemory,
|
||||
PinyinLSTMEncoder,
|
||||
)
|
||||
|
||||
|
||||
class ContextEncoderExport(nn.Module):
    """Export wrapper around a trained context encoder.

    Inputs: ``input_ids``, ``pinyin_ids``, ``attention_mask``.
    Outputs: ``context_H``, ``pinyin_P``, ``context_mask``, ``pinyin_mask``.

    The text branch reuses the wrapped encoder's modules. The pinyin branch
    is rebuilt here as a plain LSTM + projection + layer norm (no packed
    sequences) and initialised from the wrapped encoder's weights.
    """

    def __init__(self, context_encoder: ContextEncoder):
        """Wrap ``context_encoder`` and clone its pinyin LSTM branch.

        Args:
            context_encoder: Trained encoder whose ``pinyin_pooling`` exposes
                ``input_dim``, ``hidden_dim``, ``num_layers``, ``dropout``,
                ``lstm``, ``proj`` and ``layer_norm``.
        """
        super().__init__()
        self.context_encoder = context_encoder
        self.dim = context_encoder.dim

        original_pooling = context_encoder.pinyin_pooling
        self.pinyin_lstm = nn.LSTM(
            input_size=original_pooling.input_dim,
            hidden_size=original_pooling.hidden_dim,
            num_layers=original_pooling.num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=original_pooling.dropout
            if original_pooling.num_layers > 1
            else 0.0,
        )
        self.pinyin_proj = nn.Linear(
            original_pooling.hidden_dim * 2, original_pooling.input_dim
        )
        self.pinyin_ln = nn.LayerNorm(original_pooling.input_dim)

        # Copy the trained weights into the export-friendly modules.
        self.pinyin_lstm.load_state_dict(original_pooling.lstm.state_dict())
        self.pinyin_proj.load_state_dict(original_pooling.proj.state_dict())
        self.pinyin_ln.load_state_dict(original_pooling.layer_norm.state_dict())

    def forward(self, input_ids, pinyin_ids, attention_mask):
        """Encode the text context and the pinyin sequence.

        Args:
            input_ids: ``[batch_size, seq_len]`` token ids.
            pinyin_ids: ``[batch_size, pinyin_len]`` pinyin ids, 0 = padding.
            attention_mask: ``[batch_size, seq_len]``, 0 = padding. ``None``
                is treated as "every position is valid".

        Returns:
            context_H: ``[batch_size, seq_len, dim]``; positions beyond the
                number of learned position embeddings are zero-filled.
            pinyin_P: ``[batch_size, pinyin_len, dim]``.
            context_mask: ``[batch_size, seq_len]`` int32, 1 = padding. The
                zero-filled tail of ``context_H`` is flagged as padding.
            pinyin_mask: ``[batch_size, pinyin_len]`` int32, 1 = padding.
        """
        encoder = self.context_encoder
        pos_emb = encoder.pos_emb

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        text_emb_out = encoder.text_emb(input_ids)  # [B, seq_len, dim]

        # Only pos_emb.num_embeddings positions are learned; cut the rest.
        original_seq_len = text_emb_out.size(1)
        seq_len = original_seq_len
        if seq_len > pos_emb.num_embeddings:
            seq_len = pos_emb.num_embeddings
        truncated = seq_len < original_seq_len

        pos_ids = torch.arange(
            seq_len, device=input_ids.device, dtype=torch.long
        ).unsqueeze(0)

        # True marks padding.
        padding_mask = (attention_mask == 0).to(torch.bool)

        if truncated:
            x = text_emb_out[:, :seq_len, :] + pos_emb(pos_ids)
            # The key-padding mask must be cut together with x, otherwise
            # the transformer receives mismatched shapes.
            src_mask = padding_mask[:, :seq_len]
        else:
            x = text_emb_out + pos_emb(pos_ids)
            src_mask = padding_mask

        H = encoder.ln(encoder.transformer(x, src_key_padding_mask=src_mask))

        if truncated:
            # Restore the caller-visible length: zero features for the cut
            # positions, and mark them as padding so downstream attention
            # does not attend to them.
            tail_len = original_seq_len - seq_len
            zero_tail = torch.zeros(
                H.size(0),
                tail_len,
                H.size(2),
                device=H.device,
                dtype=H.dtype,
            )
            H = torch.cat([H, zero_tail], dim=1)
            masked_tail = torch.ones(
                padding_mask.size(0),
                tail_len,
                device=padding_mask.device,
                dtype=torch.bool,
            )
            padding_mask = torch.cat([src_mask, masked_tail], dim=1)

        # Pinyin branch: plain LSTM over the sequence, no packed sequences.
        pinyin_emb_out = encoder.pinyin_emb(pinyin_ids)  # [B, pinyin_len, dim]
        pinyin_lstm_out, _ = self.pinyin_lstm(pinyin_emb_out)
        P = self.pinyin_ln(self.pinyin_proj(pinyin_lstm_out))

        # Masks are emitted as int32 for the exported graph.
        context_mask = padding_mask.to(torch.int32)
        pinyin_mask = (pinyin_ids == 0).to(torch.int32)

        return H, P, context_mask, pinyin_mask
|
||||
|
||||
|
||||
class DecoderExport(nn.Module):
    """Export wrapper for the decoding half of the model.

    Inputs: ``context_H``, ``pinyin_P``, ``history_slot_ids``,
    ``context_mask``, ``pinyin_mask``. Output: ``logits``.
    """

    def __init__(
        self,
        slot_memory: SlotMemory,
        cross_attn: CrossAttentionFusion,
        moe: MoELayer,
        slot_attention: nn.Module,
        classifier: nn.Module,
        num_slots: int = 8,
        dim: int = 512,
    ):
        """Store the already-trained sub-modules.

        Args:
            slot_memory: Maps history slot ids to slot representations.
            cross_attn: Fuses slots with the context and pinyin encodings.
            moe: Mixture-of-experts layer applied to the fused slots.
            slot_attention: Produces one score per slot for pooling.
            classifier: Maps the pooled vector to logits.
            num_slots: Number of history slots per sample.
            dim: Model width.
        """
        super().__init__()
        self.slot_memory = slot_memory
        self.cross_attn = cross_attn
        self.moe = moe
        self.slot_attention = slot_attention
        self.classifier = classifier
        self.num_slots = num_slots
        self.dim = dim

    def forward(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
        """Compute logits from encoder outputs and the slot history.

        Args:
            context_H: ``[batch_size, seq_len, dim]``.
            pinyin_P: ``[batch_size, pinyin_len, dim]``.
            history_slot_ids: Reshapeable to ``[batch_size, num_slots]``.
            context_mask: ``[batch_size, seq_len]`` int32, 1 = padding.
            pinyin_mask: ``[batch_size, pinyin_len]`` int32, 1 = padding.

        Returns:
            ``logits`` of shape ``[batch_size, vocab_size]``.
        """
        # Force the history into [batch_size, num_slots].
        slot_ids = history_slot_ids.view(context_H.size(0), self.num_slots)

        # 1. Slot memory -> [batch_size, num_slots, dim]
        slots = self.slot_memory(slot_ids)

        # 2. Cross-attention fusion; the int32 masks are converted to bool
        #    before being handed to the attention module.
        fused = self.cross_attn(
            slots,
            context_H,
            pinyin_P,
            context_mask=context_mask.to(torch.bool),
            pinyin_mask=pinyin_mask.to(torch.bool),
        )

        # 3. Mixture of experts -> [batch_size, num_slots, dim]
        expert_out = self.moe(fused)

        # 4. Attention pooling over the slot axis -> [batch_size, dim]
        scores = self.slot_attention(expert_out).squeeze(-1)
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)
        pooled = (expert_out * weights).sum(dim=1)

        # 5. Classification head -> [batch_size, vocab_size]
        return self.classifier(pooled)
|
||||
|
||||
|
||||
def create_export_models_from_checkpoint(checkpoint_path, device="cpu"):
    """Rebuild the model from a checkpoint and wrap it for export.

    Args:
        checkpoint_path: Path of the model checkpoint.
        device: Device the checkpoint and the model are loaded onto.

    Returns:
        A tuple ``(context_encoder_export, decoder_export, config)``: the
        ``ContextEncoderExport`` and ``DecoderExport`` instances, both in
        eval mode, and the configuration dict that was used (the
        checkpoint's ``"config"`` entry, or the defaults when it is absent).
    """
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Model defaults, used for any setting the checkpoint does not provide.
    default_config = {
        "vocab_size": 10019,
        "pinyin_vocab_size": 30,
        "dim": 512,
        "num_slots": 8,
        "n_layers": 4,
        "n_heads": 4,
        "num_experts": 10,
        "max_seq_len": 128,
    }
    config = checkpoint["config"] if "config" in checkpoint else default_config

    from .model import InputMethodEngine

    engine_kwargs = {
        name: config.get(name, fallback) for name, fallback in default_config.items()
    }
    model = InputMethodEngine(**engine_kwargs, compile=False)

    # The checkpoint is either a wrapper dict holding "model_state_dict" or
    # the bare state dict itself.
    model.load_state_dict(checkpoint.get("model_state_dict", checkpoint))

    model.eval()
    model.to(device)

    context_encoder_export = ContextEncoderExport(model.context_encoder)
    decoder_export = DecoderExport(
        slot_memory=model.slot_memory,
        cross_attn=model.cross_attn,
        moe=model.moe,
        slot_attention=model.slot_attention,
        classifier=model.classifier,
        num_slots=model.num_slots,
        dim=model.dim,
    )

    for exported in (context_encoder_export, decoder_export):
        exported.eval()

    return context_encoder_export, decoder_export, config
|
||||
|
|
@ -0,0 +1,468 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
预处理数据质量分析脚本
|
||||
|
||||
功能:
|
||||
1. 统计 labels 的分布(出现次数、比例,最大/最小,未出现标签数)
|
||||
2. 统计 history_slot_ids 的有效长度分布
|
||||
3. 抽样还原为人类可读文本,导出为 CSV 文件
|
||||
|
||||
用法:
|
||||
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
|
||||
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train --num-samples 50 --output samples.csv
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from tqdm import tqdm
|
||||
|
||||
from .char_info import CharInfo
|
||||
from .dataset import CHAR_TO_ID
|
||||
from .preprocessed_dataset import PreProcessedDataset
|
||||
from .query import QueryEngine
|
||||
|
||||
ID_TO_CHAR = {v: k for k, v in CHAR_TO_ID.items()}
|
||||
|
||||
|
||||
def decode_pinyin_ids(pinyin_ids: list) -> str:
    """Turn a sequence of pinyin ids back into a pinyin string.

    Decoding stops at the first id equal to 0. Ids that have no entry in
    ``ID_TO_CHAR`` are rendered as ``"?"``.
    """

    def _chars_before_zero():
        # Yield one character per id until the first 0 is reached.
        for pinyin_id in pinyin_ids:
            if pinyin_id == 0:
                return
            yield ID_TO_CHAR.get(pinyin_id, "?")

    return "".join(_chars_before_zero())
|
||||
|
||||
|
||||
def decode_history(history_ids: list, query_engine: QueryEngine) -> str:
    """Render history slot ids as a readable string joined by ``" | "``.

    Each id becomes ``"<PAD>"`` when it is 0, ``"char(pinyin)"`` when
    ``query_engine.query_by_id`` returns an entry, and ``"<ID:n>"`` otherwise.
    """

    def _render(slot_id):
        if slot_id == 0:
            return "<PAD>"
        entry = query_engine.query_by_id(slot_id)
        if entry is None:
            return f"<ID:{slot_id}>"
        return f"{entry.char}({entry.pinyin})"

    return " | ".join([_render(slot_id) for slot_id in history_ids])
|
||||
|
||||
|
||||
def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0):
    """Count how often each label id occurs, shard by shard.

    Args:
        dataset: Dataset to inspect. For a sharded dataset the ``labels``
            array of each ``shard_XXXXXX.npz`` under ``dataset.data_dir`` is
            read; otherwise ``dataset.labels`` is used as a single shard.
        max_shards: Upper bound on the number of shards read. Values <= 0
            mean "read every shard".

    Returns:
        ``(counter, total)`` where ``counter`` maps label id -> occurrence
        count (plain Python ints) and ``total`` is the number of labels seen.
    """
    logger.info("正在统计 labels 分布...")
    counter = Counter()
    total = 0

    num_shards = dataset._num_shards if dataset._is_sharded else 1
    effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards

    for shard_idx in tqdm(range(effective_shards), desc="统计 labels", unit="shard"):
        if dataset._is_sharded:
            shard_path = dataset.data_dir / f"shard_{shard_idx:06d}.npz"
            # NpzFile loads members lazily: only "labels" is decompressed, and
            # the context manager closes the underlying file handle.
            with np.load(shard_path) as shard:
                labels = shard["labels"].astype(np.int64)
        else:
            labels = dataset.labels[:].astype(np.int64)

        unique, counts = np.unique(labels, return_counts=True)
        # tolist() converts numpy scalars to Python ints for keys and values.
        counter.update(dict(zip(unique.tolist(), counts.tolist())))
        total += len(labels)

    return counter, total
|
||||
|
||||
|
||||
def analyze_history_slots(dataset: PreProcessedDataset, max_shards: int = 0):
    """Count samples by their number of non-zero history slot ids.

    Args:
        dataset: Dataset to inspect. For a sharded dataset the
            ``history_slot_ids`` array of each ``shard_XXXXXX.npz`` under
            ``dataset.data_dir`` is read; otherwise
            ``dataset.history_slot_ids`` is used as a single shard.
        max_shards: Upper bound on the number of shards read. Values <= 0
            mean "read every shard".

    Returns:
        ``(counter, total)`` where ``counter`` maps effective length
        (non-zero entries per row) -> number of samples, and ``total`` is the
        number of samples seen.
    """
    logger.info("正在统计 history_slot_ids 长度分布...")
    counter = Counter()
    total = 0

    num_shards = dataset._num_shards if dataset._is_sharded else 1
    effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards

    for shard_idx in tqdm(
        range(effective_shards), desc="统计 history slots", unit="shard"
    ):
        if dataset._is_sharded:
            shard_path = dataset.data_dir / f"shard_{shard_idx:06d}.npz"
            # Only the requested member is decompressed; the handle is closed
            # as soon as the array has been materialised.
            with np.load(shard_path) as shard:
                history_slots = shard["history_slot_ids"].astype(np.int64)
        else:
            history_slots = dataset.history_slot_ids[:].astype(np.int64)

        lengths = np.count_nonzero(history_slots, axis=1)
        unique, counts = np.unique(lengths, return_counts=True)
        # tolist() converts numpy scalars to Python ints for keys and values.
        counter.update(dict(zip(unique.tolist(), counts.tolist())))
        total += len(history_slots)

    return counter, total
|
||||
|
||||
|
||||
def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
    """Convert one preprocessed sample into human-readable fields.

    Returns a dict with the decoded context (tokens whose type id is 0 before
    the first type-1 token), the suffix (tokens with type id 1), the pinyin
    string, label information, the rendered history and the full token text.
    """

    def _as_list(value):
        # Array-like values expose tolist(); anything else is used as-is.
        return value.tolist() if hasattr(value, "tolist") else value

    input_ids = _as_list(sample["input_ids"])
    token_type_ids = _as_list(sample["token_type_ids"])
    raw_label = sample["labels"]
    labels = raw_label.item() if hasattr(raw_label, "item") else raw_label
    history_ids = _as_list(sample["history_slot_ids"])
    pinyin_ids = _as_list(sample["pinyin_ids"])

    # Full token text, special tokens included.
    token_text = tokenizer.decode(input_ids, skip_special_tokens=False)

    # Split into sentence A / sentence B at the first token whose type id is 1.
    first_b = next(
        (pos for pos, type_id in enumerate(token_type_ids) if type_id == 1), None
    )
    if first_b is None:
        sent_a_ids = input_ids
        sent_b_ids = []
    else:
        sent_a_ids = [
            tok
            for tok, type_id in zip(input_ids[:first_b], token_type_ids[:first_b])
            if type_id == 0
        ]
        sent_b_ids = [
            tok for tok, type_id in zip(input_ids, token_type_ids) if type_id == 1
        ]

    context_text = tokenizer.decode(sent_a_ids, skip_special_tokens=True)
    if sent_b_ids:
        suffix_text = tokenizer.decode(sent_b_ids, skip_special_tokens=True)
    else:
        suffix_text = ""

    pinyin_str = decode_pinyin_ids(pinyin_ids)
    label_info = query_engine.query_by_id(labels)
    history_str = decode_history(history_ids, query_engine)
    history_slot_length = len([slot for slot in history_ids if slot != 0])

    if label_info:
        label_char = f"{label_info.char}({label_info.pinyin})"
        label_count = label_info.count
    else:
        label_char = f"<ID:{labels}>"
        label_count = 0

    return {
        "context": context_text,
        "suffix": suffix_text,
        "pinyin": pinyin_str,
        "label_id": labels,
        "label_char": label_char,
        "label_count": label_count,
        "history": history_str,
        "history_slot_length": history_slot_length,
        "full_tokens": token_text,
    }
|
||||
|
||||
|
||||
def _safe_percent(part, whole) -> float:
    """Return ``part / whole`` as a percentage, or ``0.0`` when ``whole`` is zero."""
    return part / whole * 100 if whole else 0.0


def _print_label_table(console, query_engine, title, ranked_labels, total, decimals):
    """Print a ranked label table: rank, id, char(pinyin), count and share of ``total``."""
    table = Table(title=title, show_header=True, header_style="bold magenta")
    table.add_column("排名", style="cyan", width=6)
    table.add_column("ID", style="green", width=8)
    table.add_column("字符(拼音)", style="yellow", width=20)
    table.add_column("频次", style="red", width=12)
    table.add_column("占比", style="blue", width=10)

    for rank, (label_id, count) in enumerate(ranked_labels, 1):
        info = query_engine.query_by_id(label_id)
        label_str = f"{info.char}({info.pinyin})" if info else f"<ID:{label_id}>"
        pct = _safe_percent(count, total)
        table.add_row(
            str(rank), str(label_id), label_str, f"{count:,}", f"{pct:.{decimals}f}%"
        )
    console.print(table)


def main():
    """CLI entry point: print label / history-slot statistics and export samples to CSV.

    Steps:
        1. Label distribution (top/bottom-k tables and a frequency histogram).
        2. Distribution of the effective history-slot length.
        3. Decode the first ``--num-samples`` samples and write them to a CSV file.

    Percentages are reported as 0 instead of raising ``ZeroDivisionError`` when
    the dataset or the vocabulary is empty.
    """
    console = Console()

    parser = argparse.ArgumentParser(description="预处理数据质量分析")
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="预处理数据目录(train/ 或 eval/)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=50,
        help="抽样的样本数量(默认50,取前 N 个样本)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="CSV 输出文件路径(默认: <data-dir>/samples.csv)",
    )
    parser.add_argument(
        "--max-shards",
        type=int,
        default=0,
        help="统计 labels 时最多读取的分片数(0=全部)",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=30,
        help="显示出现次数最多和最少的标签数量",
    )
    args = parser.parse_args()

    if args.output is None:
        args.output = str(Path(args.data_dir) / "samples.csv")

    # Load the dataset
    logger.info(f"加载数据集: {args.data_dir}")
    dataset = PreProcessedDataset(args.data_dir)
    console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
    if dataset._is_sharded:
        console.print(
            f" 分片数: {dataset._num_shards}, 每分片: {min(dataset._shard_sizes):,} - {max(dataset._shard_sizes):,} 样本"
        )
    console.print()

    # Load the QueryEngine
    logger.info("加载 QueryEngine...")
    query_engine = QueryEngine()
    query_engine.load()

    # Load the tokenizer (imported lazily, only needed by this command)
    logger.info("加载 Tokenizer...")
    from importlib.resources import files as pkg_files
    from modelscope import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        Path(str(pkg_files(__package__))) / "assets" / "tokenizer"
    )

    # ====== 1. Label distribution ======
    console.print("[bold yellow]====== Labels 分布分析 ======[/bold yellow]")
    counter, total = analyze_labels(dataset, max_shards=args.max_shards)

    # Vocabulary size is the number of entries in the engine's id -> info map.
    vocab_size = len(query_engine._id_to_info)

    appeared_ids = set(counter.keys())
    all_ids = set(query_engine._id_to_info.keys())
    missing_ids = all_ids - appeared_ids

    eos_count = counter.get(0, 0)
    console.print(f"\n总样本数: {total:,}")
    console.print(f"词表大小: {vocab_size:,} (含 EOS)")
    console.print(f"唯一标签数: {len(counter):,}")
    console.print(
        f"EOS (id=0) 出现次数: {eos_count:,} ({_safe_percent(eos_count, total):.2f}%)"
    )
    console.print(
        f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size:,} ({_safe_percent(len(missing_ids), vocab_size):.2f}%)[/bold red]"
    )

    most_common = counter.most_common(args.top_k)
    least_common = (
        counter.most_common()[: -args.top_k - 1 : -1]
        if len(counter) > args.top_k
        else counter.most_common()
    )

    # Most frequent labels
    _print_label_table(
        console,
        query_engine,
        f"出现次数最多的 {args.top_k} 个标签",
        most_common,
        total,
        3,
    )

    # Least frequent labels (more decimals because the shares are tiny)
    _print_label_table(
        console,
        query_engine,
        f"出现次数最少的 {min(args.top_k, len(counter))} 个标签",
        least_common,
        total,
        6,
    )

    # Frequency histogram
    table_dist = Table(
        title="频次分布概览", show_header=True, header_style="bold magenta"
    )
    table_dist.add_column("频次区间", style="cyan")
    table_dist.add_column("标签数", style="green")
    table_dist.add_column("占总标签数比例", style="yellow")

    bins = [
        (1, 10),
        (11, 100),
        (101, 1000),
        (1001, 10000),
        (10001, 100000),
        (100001, 1000000),
        (1000001, float("inf")),
    ]
    for lo, hi in bins:
        count_in_bin = sum(1 for c in counter.values() if lo <= c <= hi)
        if count_in_bin > 0:
            hi_str = str(int(hi)) if hi != float("inf") else "∞"
            table_dist.add_row(
                f"{lo}-{hi_str}",
                f"{count_in_bin:,}",
                f"{count_in_bin / len(counter) * 100:.1f}%",
            )
    # Labels that never appear (share is relative to the whole vocabulary)
    if len(missing_ids) > 0:
        table_dist.add_row(
            "未出现",
            f"{len(missing_ids):,}",
            f"{len(missing_ids) / vocab_size * 100:.1f}%",
        )
    console.print(table_dist)

    # ====== 2. History slot length ======
    console.print("\n[bold yellow]====== 历史槽位长度分析 ======[/bold yellow]")
    history_counter, history_total = analyze_history_slots(
        dataset, max_shards=args.max_shards
    )

    if not history_counter:
        console.print("[yellow] 无历史槽位数据[/yellow]")
    else:
        sorted_items = sorted(history_counter.items())
        lengths_arr = [length for length, _ in sorted_items]
        weighted_sum = sum(length * count for length, count in sorted_items)
        avg_length = weighted_sum / history_total if history_total > 0 else 0

        # Median: first length whose cumulative count reaches half the samples.
        cumsum = 0
        median_length = 0
        for length, count in sorted_items:
            cumsum += count
            if cumsum >= history_total / 2:
                median_length = length
                break

        console.print(f"\n总样本数: {history_total:,}")
        console.print(f"最大历史槽位数: {max(lengths_arr)}")
        console.print(f"最小历史槽位数: {min(lengths_arr)}")
        console.print(f"平均历史槽位数: {avg_length:.2f}")
        console.print(f"中位数历史槽位数: {median_length}")

        hist_table = Table(
            title="历史槽位有效长度分布",
            show_header=True,
            header_style="bold magenta",
        )
        hist_table.add_column("槽位数", style="cyan", width=10)
        hist_table.add_column("样本数", style="green", width=12)
        hist_table.add_column("占比", style="yellow", width=10)

        for length, count in sorted_items:
            pct = _safe_percent(count, history_total)
            hist_table.add_row(str(length), f"{count:,}", f"{pct:.2f}%")
        console.print(hist_table)

    # ====== 3. Decode samples -> CSV ======
    num_samples = min(args.num_samples, len(dataset))
    console.print(
        f"\n[bold yellow]====== 抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
    )

    csv_path = Path(args.output)
    csv_path.parent.mkdir(parents=True, exist_ok=True)

    csv_headers = [
        "index",
        "pinyin",
        "label_char",
        "label_id",
        "label_count",
        "context",
        "suffix",
        "history",
        "history_slot_length",
        "full_tokens",
    ]

    # Keep the first few decoded rows so the preview below does not have to
    # re-read the CSV file that was just written.
    preview_rows = []
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(csv_headers)

        for idx in tqdm(range(num_samples), desc="解码样本", unit="sample"):
            decoded = decode_sample(dataset[idx], tokenizer, query_engine)
            # Every column after "index" is a key of the decoded dict.
            writer.writerow([idx] + [decoded[column] for column in csv_headers[1:]])
            if len(preview_rows) < 5:
                preview_rows.append(decoded)

    console.print(
        f"[bold green]✓ 已导出 {num_samples} 个样本到 {csv_path}[/bold green]"
    )

    # Short preview of the first five samples
    console.print(f"\n[bold cyan]前 {min(5, num_samples)} 个样本概览:[/bold cyan]")
    for i, row in enumerate(preview_rows):
        console.print(
            f" [{i + 1}] 拼音={row['pinyin']} 目标={row['label_char']} "
            f"上下文={row['context'][:50]}..."
        )

    console.print("\n[bold green]分析完成[/bold green]")
|
||||
|
||||
|
||||
app = main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
#!/usr/bin/env python3
|
||||
"""将分片 .npz 合并为单个 .npz 文件,合并后删除原分片"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
FIELDS = [
|
||||
"input_ids",
|
||||
"token_type_ids",
|
||||
"attention_mask",
|
||||
"labels",
|
||||
"history_slot_ids",
|
||||
"pinyin_ids",
|
||||
]
|
||||
|
||||
|
||||
def merge_split(data_dir: Path, output_name: str = "shard_000000.npz"):
    """Merge every ``shard_XXXXXX.npz`` in ``data_dir`` into one ``.npz`` file.

    The merged archive is first written to a temporary file next to the
    target. The original shards are removed only after that write has
    succeeded, so a failed write (disk full, interruption) cannot destroy the
    dataset. This needs enough free disk space to hold the merged file
    alongside the original shards for a short time.

    Args:
        data_dir: Directory containing ``metadata.json`` and the shard files.
        output_name: File name of the merged archive. As with
            ``np.savez_compressed``, ``.npz`` is appended when missing.

    Side effects:
        Deletes the original shards, writes the merged file and rewrites
        ``metadata.json`` with ``num_shards = 1`` and ``shard_size`` set to
        the total number of samples.
    """
    meta_path = data_dir / "metadata.json"
    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)
    num_shards = meta["num_shards"]
    if num_shards <= 1:
        print(f"{data_dir}: already single shard, skip")
        return

    shard_paths = [data_dir / f"shard_{i:06d}.npz" for i in range(num_shards)]

    merged = {}
    # One field at a time keeps peak memory at roughly "merged so far + one field".
    for field in FIELDS:
        pieces = []
        for shard_path in tqdm(
            shard_paths, desc=f"Merging {data_dir.name}/{field}", unit="shard"
        ):
            # The context manager closes the archive even if reading fails.
            with np.load(shard_path) as data:
                pieces.append(data[field])
        merged[field] = np.concatenate(pieces, axis=0)
        del pieces
        gc.collect()

    total = len(merged["labels"])

    # Mirror numpy's naming rule so the final file name matches what a direct
    # np.savez_compressed(data_dir / output_name) call would have produced.
    final_path = data_dir / output_name
    if not final_path.name.endswith(".npz"):
        final_path = final_path.with_name(final_path.name + ".npz")
    tmp_path = final_path.with_name(f".merge_tmp_{final_path.name}")

    try:
        np.savez_compressed(tmp_path, **merged)
    except BaseException:
        # Drop the partial file; the original shards are still untouched.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    del merged
    gc.collect()

    # Remove the originals first so stale shards never sit next to the merged
    # file, then move the finished archive into place.
    for shard_path in shard_paths:
        shard_path.unlink()
    tmp_path.replace(final_path)

    # Update metadata
    meta["num_shards"] = 1
    meta["shard_size"] = total
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)

    print(f"{data_dir.name}: {num_shards} shards → 1 shard, {total:,} samples")
|
||||
|
||||
|
||||
def main():
    """Parse CLI options and merge the ``train`` / ``eval`` splits under ``--input-dir``."""
    cli = argparse.ArgumentParser(description="合并分片 .npz 为单个文件")
    cli.add_argument(
        "--input-dir",
        type=str,
        required=True,
        help="数据集目录(含 train/ 和 eval/)",
    )
    cli.add_argument(
        "--output-name",
        type=str,
        default="shard_000000.npz",
        help="合并后文件名",
    )
    cli.add_argument(
        "--train-only",
        action="store_true",
        help="仅合并 train,跳过 eval",
    )
    options = cli.parse_args()

    dataset_root = Path(options.input_dir)
    for split_name in ("train", "eval"):
        split_path = dataset_root / split_name
        if split_path.exists():
            # --train-only leaves an existing eval split untouched.
            if not (options.train_only and split_name == "eval"):
                merge_split(split_path, options.output_name)
        else:
            print(f"{split_path}: not found, skip")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -35,12 +35,13 @@ class InputMethodEngine(nn.Module):
|
|||
vocab_size: int = 10019,
|
||||
pinyin_vocab_size: int = 28,
|
||||
dim: int = 512,
|
||||
num_slots: int = 8, # 历史槽位数量 (对应 README 中的 8 个槽位)
|
||||
n_layers: int = 4, # Transformer 层数
|
||||
n_heads: int = 4, # 注意力头数
|
||||
num_experts: int = 10, # MoE 专家数量
|
||||
max_seq_len: int = 128, # 最大上下文长度
|
||||
compile: bool = False, # 是否开启 torch.compile 优化
|
||||
num_slots: int = 8,
|
||||
n_layers: int = 4,
|
||||
n_heads: int = 4,
|
||||
num_experts: int = 10,
|
||||
max_seq_len: int = 128,
|
||||
compile: bool = False,
|
||||
moe_mode: str = "all", # "all" / "sparse" / "sparse_allow_graph"
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
@ -72,7 +73,13 @@ class InputMethodEngine(nn.Module):
|
|||
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
|
||||
|
||||
# 4. 混合专家层 (MoE)
|
||||
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8)
|
||||
self.moe = MoELayer(
|
||||
dim=dim,
|
||||
num_experts=num_experts,
|
||||
top_k=3,
|
||||
num_resblocks=12,
|
||||
moe_mode=moe_mode,
|
||||
)
|
||||
|
||||
# 5. 槽位注意力池化
|
||||
self.slot_attention = nn.Linear(dim, 1)
|
||||
|
|
@ -80,24 +87,12 @@ class InputMethodEngine(nn.Module):
|
|||
# 6. 分类头
|
||||
self.classifier = nn.Linear(dim, vocab_size)
|
||||
|
||||
# 开启 torch.compile 优化 (如果请求)
|
||||
# 在模型编译时添加优化选项
|
||||
if compile:
|
||||
from torch._inductor.select_algorithm import TritonTemplate
|
||||
|
||||
TritonTemplate.all_templates.clear()
|
||||
|
||||
self.forward = torch.compile(
|
||||
self.forward,
|
||||
mode="reduce-overhead",
|
||||
fullgraph=False,
|
||||
dynamic=False,
|
||||
# options={
|
||||
# "epilogue_fusion": True,
|
||||
# "max_autotune": True,
|
||||
# "triton.cudagraphs": True,
|
||||
# "reorder_for_compute_comm_overlap": False,
|
||||
# },
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
|
@ -115,6 +110,10 @@ class InputMethodEngine(nn.Module):
|
|||
|
||||
history_slot_ids = history_slot_ids.view(-1, self.num_slots)
|
||||
|
||||
is_zero = (history_slot_ids == 0)
|
||||
cumsum_zero = is_zero.cumsum(dim=1)
|
||||
slot_mask = (cumsum_zero <= 1).to(torch.get_default_dtype())
|
||||
|
||||
H, P = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
|
||||
|
||||
S = self.slot_memory(history_slot_ids)
|
||||
|
|
@ -125,11 +124,15 @@ class InputMethodEngine(nn.Module):
|
|||
S, H, P, context_mask=context_mask, pinyin_mask=pinyin_mask
|
||||
)
|
||||
|
||||
fused = fused * slot_mask.unsqueeze(-1)
|
||||
|
||||
moe_out = self.moe(fused)
|
||||
|
||||
batch_size = input_ids.size(0)
|
||||
slot_scores = self.slot_attention(moe_out).squeeze(-1)
|
||||
slot_weights = torch.softmax(slot_scores, dim=1)
|
||||
slot_weights = slot_weights * slot_mask
|
||||
slot_weights = slot_weights / (slot_weights.sum(dim=1, keepdim=True) + 1e-12)
|
||||
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)
|
||||
|
||||
logits = self.classifier(pooled)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,366 @@
|
|||
"""
|
||||
ONNX模型导出核心逻辑
|
||||
|
||||
将 InputMethodEngine 模型导出为两个 ONNX 文件:
|
||||
1. context_encoder.onnx - 上下文编码器(可复用)
|
||||
2. decoder.onnx - 解码器
|
||||
|
||||
共用此模块的入口:
|
||||
- CLI: train-model export
|
||||
- 脚本: python export_onnx.py(薄壳)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .export_models import create_export_models_from_checkpoint
|
||||
|
||||
|
||||
def check_onnx_available() -> bool:
    """Return True when both ``onnx`` and ``onnxruntime`` can be imported.

    On failure an installation hint is printed and False is returned.
    """
    try:
        for module_name in ("onnx", "onnxruntime"):
            __import__(module_name)
    except ImportError:
        for message in (
            "错误: ONNX导出需要以下依赖:",
            " pip install onnx onnxruntime",
            "请安装后重试",
        ):
            print(message)
        return False
    return True
|
||||
|
||||
|
||||
def export_context_encoder(
    model,
    output_path: str,
    config: Dict,
    skip_verification: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Export the context-encoder wrapper to ONNX using random dummy inputs.

    Args:
        model: Export wrapper that is called as
            ``model(input_ids, pinyin_ids, attention_mask)``.
        output_path: Destination ``.onnx`` file.
        config: Model configuration. ``max_seq_len`` sets the dummy sequence
            length and ``pinyin_vocab_size`` bounds the dummy pinyin ids, so
            they stay inside the model's pinyin embedding table.
        skip_verification: When True, skip the ``onnx.checker`` pass.

    Returns:
        The dummy ``(input_ids, pinyin_ids, attention_mask)`` used for export.
    """
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    # A fixed bound of 30 could produce out-of-range ids for a smaller table.
    pinyin_vocab_size = config.get("pinyin_vocab_size", 30)

    # Build the dummy inputs on the model's device so tracing does not hit a
    # CPU/accelerator mismatch; fall back to CPU when there are no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cpu")

    input_ids = torch.randint(
        0, 100, (batch_size, seq_len), dtype=torch.long, device=device
    )
    pinyin_ids = torch.randint(
        0, pinyin_vocab_size, (batch_size, pinyin_len), dtype=torch.long, device=device
    )
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=device)

    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "pinyin_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
    }

    torch.onnx.export(
        model,
        (input_ids, pinyin_ids, attention_mask),
        output_path,
        input_names=["input_ids", "pinyin_ids", "attention_mask"],
        output_names=["context_H", "pinyin_P", "context_mask", "pinyin_mask"],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )

    print(f" 上下文编码器导出完成 -> {output_path}")

    if not skip_verification:
        # Best-effort check: a verification failure is reported, not raised.
        try:
            import onnx

            onnx.checker.check_model(onnx.load(output_path))
            print(" ONNX 模型验证通过")
        except Exception as e:
            print(f" ONNX 模型验证警告: {e}")

    return input_ids, pinyin_ids, attention_mask
|
||||
|
||||
|
||||
def export_decoder(
    model,
    output_path: str,
    config: Dict,
    example_inputs: Optional[Tuple] = None,
    skip_verification: bool = False,
) -> Tuple[torch.Tensor, ...]:
    """Export the decoder wrapper to ONNX.

    Args:
        model: Export wrapper that is called as
            ``model(context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask)``.
        output_path: Destination ``.onnx`` file.
        config: Model configuration (``max_seq_len``, ``dim``, ``num_slots``).
        example_inputs: Optional ``(context_H, pinyin_P, context_mask,
            pinyin_mask)`` tuple, typically real encoder outputs. Random
            tensors are generated when omitted.
        skip_verification: When True, skip the ``onnx.checker`` pass.

    Returns:
        ``(context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask)``
        as used for the export.
    """
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    dim = config.get("dim", 512)
    num_slots = config.get("num_slots", 8)

    if example_inputs is not None:
        context_H, pinyin_P, context_mask, pinyin_mask = example_inputs
        batch_size = context_H.size(0)
        # Keep the generated slot ids on the same device as the given inputs.
        device = context_H.device
    else:
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device("cpu")
        context_H = torch.randn(
            batch_size, seq_len, dim, dtype=torch.float32, device=device
        )
        pinyin_P = torch.randn(
            batch_size, pinyin_len, dim, dtype=torch.float32, device=device
        )
        context_mask = torch.randint(
            0, 2, (batch_size, seq_len), dtype=torch.int32, device=device
        )
        pinyin_mask = torch.randint(
            0, 2, (batch_size, pinyin_len), dtype=torch.int32, device=device
        )

    history_slot_ids = torch.randint(
        0, 100, (batch_size, num_slots), dtype=torch.long, device=device
    )

    dynamic_axes = {
        "context_H": {0: "batch_size", 1: "seq_len"},
        "pinyin_P": {0: "batch_size"},
        "history_slot_ids": {0: "batch_size"},
        "context_mask": {0: "batch_size", 1: "seq_len"},
        "pinyin_mask": {0: "batch_size"},
        "logits": {0: "batch_size"},
    }

    torch.onnx.export(
        model,
        (context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask),
        output_path,
        input_names=[
            "context_H",
            "pinyin_P",
            "history_slot_ids",
            "context_mask",
            "pinyin_mask",
        ],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
    )

    print(f" 解码器导出完成 -> {output_path}")

    if not skip_verification:
        # Best-effort check: a verification failure is reported, not raised.
        try:
            import onnx

            onnx.checker.check_model(onnx.load(output_path))
            print(" ONNX 模型验证通过")
        except Exception as e:
            print(f" ONNX 模型验证警告: {e}")

    return context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask
|
||||
|
||||
|
||||
def save_example_inputs(output_dir: str, example_inputs_dict: Dict) -> None:
    """Save example tensors as ``example_inputs.npz`` and ``example_inputs.pt``.

    Args:
        output_dir: Existing directory that receives both files.
        example_inputs_dict: Mapping of name -> tensor or tuple of tensors.
            Tuple members are stored in the ``.npz`` as ``"<name>_<index>"``;
            values that are not tensors are left out of the ``.npz`` but are
            still part of the ``.pt`` file, which stores the dict unchanged.
    """

    def _to_numpy(tensor: torch.Tensor):
        # detach() first: numpy() refuses tensors that still require grad.
        return tensor.detach().cpu().numpy()

    npz_path = os.path.join(output_dir, "example_inputs.npz")
    np_data = {}
    for key, value in example_inputs_dict.items():
        if isinstance(value, torch.Tensor):
            np_data[key] = _to_numpy(value)
        elif isinstance(value, tuple):
            for i, item in enumerate(value):
                if isinstance(item, torch.Tensor):
                    np_data[f"{key}_{i}"] = _to_numpy(item)
    np.savez(npz_path, **np_data)
    print(f" 示例输入已保存到: {npz_path}")

    torch_path = os.path.join(output_dir, "example_inputs.pt")
    torch.save(example_inputs_dict, torch_path)
    print(f" PyTorch 示例输入已保存到: {torch_path}")
|
||||
|
||||
|
||||
_INFERENCE_EXAMPLE_TEMPLATE = '''#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型推理示例
|
||||
|
||||
展示如何使用导出的两个ONNX模型进行推理,
|
||||
包括束搜索(beam search)算法。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class ONNXInference:
|
||||
"""ONNX模型推理器"""
|
||||
|
||||
def __init__(self, context_encoder_path, decoder_path):
|
||||
self.context_encoder_session = ort.InferenceSession(
|
||||
context_encoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
self.decoder_session = ort.InferenceSession(
|
||||
decoder_path,
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
self.context_input_names = [input.name for input in self.context_encoder_session.get_inputs()]
|
||||
self.context_output_names = [output.name for output in self.context_encoder_session.get_outputs()]
|
||||
self.decoder_input_names = [input.name for input in self.decoder_session.get_inputs()]
|
||||
self.decoder_output_names = [output.name for output in self.decoder_session.get_outputs()]
|
||||
|
||||
print(f"上下文编码器输入: {self.context_input_names}")
|
||||
print(f"上下文编码器输出: {self.context_output_names}")
|
||||
print(f"解码器输入: {self.decoder_input_names}")
|
||||
print(f"解码器输出: {self.decoder_output_names}")
|
||||
|
||||
def prepare_inputs(self, text_before, text_after, pinyin, slot_chars, tokenizer, query_engine, max_seq_len=128):
|
||||
raise NotImplementedError("请实现实际的输入预处理")
|
||||
|
||||
def run_context_encoder(self, input_ids, pinyin_ids, attention_mask):
|
||||
inputs = {
|
||||
"input_ids": input_ids.numpy() if isinstance(input_ids, torch.Tensor) else input_ids,
|
||||
"pinyin_ids": pinyin_ids.numpy() if isinstance(pinyin_ids, torch.Tensor) else pinyin_ids,
|
||||
"attention_mask": attention_mask.numpy() if isinstance(attention_mask, torch.Tensor) else attention_mask,
|
||||
}
|
||||
outputs = self.context_encoder_session.run(self.context_output_names, inputs)
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = outputs
|
||||
return (
|
||||
torch.from_numpy(context_H),
|
||||
torch.from_numpy(pinyin_P),
|
||||
torch.from_numpy(context_mask),
|
||||
torch.from_numpy(pinyin_mask),
|
||||
)
|
||||
|
||||
def run_decoder(self, context_H, pinyin_P, history_slot_ids, context_mask, pinyin_mask):
|
||||
inputs = {
|
||||
"context_H": context_H.numpy() if isinstance(context_H, torch.Tensor) else context_H,
|
||||
"pinyin_P": pinyin_P.numpy() if isinstance(pinyin_P, torch.Tensor) else pinyin_P,
|
||||
"history_slot_ids": history_slot_ids.numpy() if isinstance(history_slot_ids, torch.Tensor) else history_slot_ids,
|
||||
"context_mask": context_mask.numpy() if isinstance(context_mask, torch.Tensor) else context_mask,
|
||||
"pinyin_mask": pinyin_mask.numpy() if isinstance(pinyin_mask, torch.Tensor) else pinyin_mask,
|
||||
}
|
||||
outputs = self.decoder_session.run(self.decoder_output_names, inputs)
|
||||
logits = outputs[0]
|
||||
return torch.from_numpy(logits)
|
||||
|
||||
def beam_search(self, context_H, pinyin_P, context_mask, pinyin_mask,
|
||||
beam_size=5, max_length=10, vocab_size=10019):
|
||||
beams = [([], 0.0)]
|
||||
for step in range(max_length):
|
||||
new_beams = []
|
||||
for seq, score in beams:
|
||||
if len(seq) < 8:
|
||||
history = seq + [0] * (8 - len(seq))
|
||||
else:
|
||||
history = seq[-8:]
|
||||
history_tensor = torch.tensor([history], dtype=torch.long)
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_tensor,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
probs = F.softmax(logits[0], dim=-1)
|
||||
top_probs, top_indices = torch.topk(probs, beam_size)
|
||||
for prob, idx in zip(top_probs, top_indices):
|
||||
new_seq = seq + [idx.item()]
|
||||
new_score = score + torch.log(prob).item()
|
||||
new_beams.append((new_seq, new_score))
|
||||
new_beams.sort(key=lambda x: x[1], reverse=True)
|
||||
beams = new_beams[:beam_size]
|
||||
all_ended = all(seq[-1] == 0 for seq, _ in beams if seq)
|
||||
if all_ended:
|
||||
break
|
||||
return beams
|
||||
|
||||
def predict_single(self, input_ids, pinyin_ids, attention_mask, history_slot_ids):
|
||||
context_H, pinyin_P, context_mask, pinyin_mask = self.run_context_encoder(
|
||||
input_ids, pinyin_ids, attention_mask
|
||||
)
|
||||
logits = self.run_decoder(
|
||||
context_H, pinyin_P, history_slot_ids,
|
||||
context_mask, pinyin_mask
|
||||
)
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""示例主函数"""
|
||||
print("ONNX模型推理示例")
|
||||
print("=" * 60)
|
||||
|
||||
context_encoder_path = "context_encoder.onnx"
|
||||
decoder_path = "decoder.onnx"
|
||||
|
||||
if not os.path.exists(context_encoder_path) or not os.path.exists(decoder_path):
|
||||
print("错误: 找不到ONNX模型文件")
|
||||
print("请先运行 train-model export 导出模型")
|
||||
return
|
||||
|
||||
inference = ONNXInference(context_encoder_path, decoder_path)
|
||||
print("\\u2705 ONNX推理器初始化完成")
|
||||
print("请参考此示例实现完整的输入法推理流程")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
'''
|
||||
|
||||
|
||||
def create_inference_example(output_dir: str, config: Dict) -> None:
    """Write the bundled inference example to ``<output_dir>/inference_example.py``.

    ``config`` is accepted but not used by this function.
    """
    example_path = os.path.join(output_dir, "inference_example.py")
    Path(example_path).write_text(_INFERENCE_EXAMPLE_TEMPLATE, encoding="utf-8")
    print(f" 推理示例脚本已保存到: {example_path}")
|
||||
|
||||
|
||||
def run_full_export(
    checkpoint_path: str,
    output_dir: str,
    device: str = "cpu",
    skip_verification: bool = False,
) -> Tuple[str, str, Dict]:
    """Run the end-to-end ONNX export for one checkpoint.

    Steps, in order: build the export wrappers from the checkpoint, export the
    context encoder, run it once (gradients disabled) on the example inputs to
    obtain example inputs for the decoder, export the decoder, then save the
    example inputs and the inference example script.

    Args:
        checkpoint_path: Checkpoint handed to the export-model factory.
        output_dir: Destination directory; created if it does not exist.
        device: Device string forwarded to the export-model factory.
        skip_verification: Forwarded unchanged to both export helpers.

    Returns:
        ``(context_encoder_path, decoder_path, config)``.
    """
    export_root = Path(output_dir)
    export_root.mkdir(parents=True, exist_ok=True)

    print(f"加载 checkpoint: {checkpoint_path}")
    encoder_model, decoder_model, config = create_export_models_from_checkpoint(
        checkpoint_path, device
    )

    config_summary = (
        f"模型配置: vocab_size={config.get('vocab_size')}, "
        f"dim={config.get('dim')}, "
        f"num_slots={config.get('num_slots')}, "
        f"num_experts={config.get('num_experts')}, "
        f"moe_mode={config.get('moe_mode', 'all')}"
    )
    print(config_summary)
    print("正在导出模型...")

    context_encoder_path = str(export_root / "context_encoder.onnx")
    decoder_path = str(export_root / "decoder.onnx")

    example_inputs = export_context_encoder(
        encoder_model,
        context_encoder_path,
        config,
        skip_verification=skip_verification,
    )

    # One forward pass of the encoder yields the tensors the decoder export
    # uses as its own example inputs.
    with torch.no_grad():
        encoder_outputs = encoder_model(*example_inputs)
    context_H, pinyin_P, context_mask, pinyin_mask = encoder_outputs

    export_decoder(
        decoder_model,
        decoder_path,
        config,
        example_inputs=(context_H, pinyin_P, context_mask, pinyin_mask),
        skip_verification=skip_verification,
    )

    encoder_input_names = ("input_ids", "pinyin_ids", "attention_mask")
    example_inputs_dict = {
        name: example_inputs[position]
        for position, name in enumerate(encoder_input_names)
    }
    example_inputs_dict["context_H"] = context_H
    example_inputs_dict["pinyin_P"] = pinyin_P
    example_inputs_dict["context_mask"] = context_mask
    example_inputs_dict["pinyin_mask"] = pinyin_mask

    save_example_inputs(output_dir, example_inputs_dict)
    create_inference_example(output_dir, config)

    return context_encoder_path, decoder_path, config
|
||||
|
|
@ -0,0 +1,433 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
预处理脚本:将 PinyinInputDataset 的输出转换为分片压缩 .npz 文件。
|
||||
|
||||
采用分片流式写入,内存峰值固定为 shard_size 级别,不随总样本数增长。
|
||||
每个分片使用 np.savez_compressed 保存(zlib 压缩),GPU 服务器无需解压到硬盘。
|
||||
|
||||
用法:
|
||||
python -m model.preprocess \
|
||||
--train-data-path "some/hf_dataset" \
|
||||
--eval-data-path "some/hf_dataset" \
|
||||
--output-dir ./preprocessed \
|
||||
--num-train-samples 5000000 \
|
||||
--num-eval-samples 8192
|
||||
|
||||
生成目录结构:
|
||||
output_dir/
|
||||
train/
|
||||
metadata.json
|
||||
shard_000000.npz (5M样本, 6个字段, zlib压缩)
|
||||
shard_000001.npz
|
||||
...
|
||||
eval/
|
||||
metadata.json
|
||||
shard_000000.npz
|
||||
...
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import struct
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .dataset import PinyinInputDataset
|
||||
from .trainer import preprocess_collate_fn, worker_init_fn
|
||||
|
||||
FIELDS = [
|
||||
"input_ids",
|
||||
"token_type_ids",
|
||||
"attention_mask",
|
||||
"labels",
|
||||
"history_slot_ids",
|
||||
"pinyin_ids",
|
||||
]
|
||||
|
||||
|
||||
def _extract_batch(batch: dict, take: int) -> Dict[str, np.ndarray]:
    """Slice the first ``take`` samples of every field into int16 NumPy arrays.

    Args:
        batch: Mapping from each name in ``FIELDS`` to a tensor supporting
            slicing and ``.numpy()``.
        take: Number of leading samples to keep from each field.

    Returns:
        Dict keyed by ``FIELDS`` (in that order). ``labels`` is squeezed to drop
        a trailing singleton axis when it has more than one dimension.
    """

    def _as_int16(name: str) -> np.ndarray:
        # NOTE(review): astype(int16) wraps silently for values outside
        # [-32768, 32767] — confirm every id space fits.
        values = batch[name][:take].numpy().astype(np.int16)
        has_trailing_unit_axis = values.ndim > 1 and values.shape[-1] == 1
        if name == "labels" and has_trailing_unit_axis:
            return values.squeeze(-1)
        return values

    return {name: _as_int16(name) for name in FIELDS}
|
||||
|
||||
|
||||
def collect_samples(
    dataloader: DataLoader,
    num_samples: int,
    output_dir: Path,
    split_name: str,
    max_seq_length: int = 128,
    shard_size: int = 5_000_000,
) -> int:
    """Stream batches from ``dataloader`` into compressed ``.npz`` shards.

    Batches are buffered in memory and flushed to
    ``<output_dir>/<split_name>/shard_XXXXXX.npz`` once at least ``shard_size``
    samples have accumulated; a final partial shard is written at the end.
    A ``metadata.json`` describing the split is written next to the shards.

    A shard is flushed on a batch boundary, so it may hold slightly more than
    ``shard_size`` samples when the batch size does not divide ``shard_size``.
    The *actual* per-shard counts are recorded in ``shard_sizes`` so readers
    can build exact offsets.

    Args:
        dataloader: Iterable of batches; each batch maps every name in
            ``FIELDS`` to a tensor.
        num_samples: Maximum number of samples to collect.
        output_dir: Root output directory.
        split_name: Sub-directory name (e.g. ``"train"``), also used in logs.
        max_seq_length: Recorded in the metadata only.
        shard_size: Nominal number of samples per shard (flush threshold).

    Returns:
        Number of samples actually written.
    """
    split_dir = output_dir / split_name
    split_dir.mkdir(parents=True, exist_ok=True)

    shard_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
    shard_count = 0
    shard_sizes: List[int] = []  # real sample count of every shard written
    total = 0

    def _flush_shard() -> None:
        """Write the buffered batches as the next shard and reset the buffers."""
        nonlocal shard_buffers, shard_count
        shard_idx = len(shard_sizes)
        merged = {f: np.concatenate(shard_buffers[f], axis=0) for f in FIELDS}
        np.savez_compressed(split_dir / f"shard_{shard_idx:06d}.npz", **merged)
        logger.debug(f"Saved {split_name} shard {shard_idx}: {shard_count} samples")
        shard_sizes.append(shard_count)
        shard_buffers = {f: [] for f in FIELDS}
        shard_count = 0
        # Release the merged arrays before the next shard starts filling up.
        del merged
        gc.collect()

    # The context manager closes the progress bar even if a batch fails.
    with tqdm(
        total=num_samples, desc=f"Processing {split_name}", unit="samples"
    ) as pbar:
        for batch in dataloader:
            remaining = num_samples - total
            if remaining <= 0:
                break
            take = min(batch["input_ids"].size(0), remaining)

            extracted = _extract_batch(batch, take)
            for f in FIELDS:
                shard_buffers[f].append(extracted[f])

            shard_count += take
            total += take
            pbar.update(take)

            if shard_count >= shard_size:
                _flush_shard()

            if total >= num_samples:
                break

        # Write the last, partially filled shard.
        if shard_count > 0:
            _flush_shard()

    num_shards = len(shard_sizes)

    metadata = {
        "num_samples": total,
        "max_seq_length": max_seq_length,
        "dtype": "int16",
        "fields": FIELDS,
        "shard_size": shard_size,
        "num_shards": num_shards,
        "shard_sizes": shard_sizes,
    }
    with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
        json.dump(metadata, fp, indent=2, ensure_ascii=False)

    total_size = sum(
        f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
    )
    logger.info(
        f"{split_name}: {total} samples in {num_shards} shards, "
        f"{total_size / (1024**3):.2f} GB (compressed)"
    )

    return total
|
||||
|
||||
|
||||
def main():
    """CLI entry point: preprocess the train and eval datasets into ``.npz`` shards.

    Parses the command line, seeds torch and NumPy, builds one
    ``PinyinInputDataset`` + ``DataLoader`` pair per split, collects the
    requested number of samples with ``collect_samples`` and prints a summary
    of the compressed output size.

    ``--shuffle-buffer-size`` is forwarded to the datasets. Its default is 100,
    the value that used to be hard-coded while the flag was ignored, so runs
    that do not pass the flag behave as before.
    """
    console = Console()

    parser = argparse.ArgumentParser(description="预处理数据集为分片压缩npz文件")
    parser.add_argument(
        "--train-data-path",
        type=str,
        required=True,
        help="训练数据集路径(HuggingFace格式)",
    )
    parser.add_argument(
        "--eval-data-path",
        type=str,
        required=True,
        help="评估数据集路径(HuggingFace格式)",
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
    parser.add_argument(
        "--num-train-samples", type=int, required=True, help="训练集样本数量"
    )
    parser.add_argument(
        "--num-eval-samples", type=int, required=True, help="评估集样本数量"
    )
    parser.add_argument("--batch-size", type=int, default=128, help="批大小")
    parser.add_argument(
        "--num-workers", type=int, default=2, help="DataLoader worker数量"
    )
    parser.add_argument("--max-seq-length", type=int, default=128, help="最大序列长度")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument(
        "--shard-size",
        type=int,
        default=5_000_000,
        help="分片大小(样本数),控制内存峰值(默认500万,约2.9GB/分片未压缩)",
    )
    parser.add_argument(
        "--py-style-weight",
        type=str,
        default="9,2,1",
        help="拼音风格权重(逗号分隔)",
    )
    parser.add_argument(
        "--shuffle-buffer-size",
        type=int,
        default=100,
        help="数据集shuffle缓冲区大小",
    )
    parser.add_argument(
        "--length-weights",
        type=str,
        default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
        help="词长权重",
    )

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
    length_weights = {
        int(k): int(v)
        for k, v in (item.split(":") for item in args.length_weights.split(","))
    }

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # The datasets are allowed to iterate up to 5x the requested sample count.
    train_max_iter = args.num_train_samples * 5
    eval_max_iter = args.num_eval_samples * 5

    shard_mem_gb = args.shard_size * 578 / (1024**3)
    console.print("[bold cyan]=== 数据预处理 ===[/bold cyan]")
    console.print(f"训练集目标: {args.num_train_samples:,} 样本")
    console.print(f"评估集目标: {args.num_eval_samples:,} 样本")
    console.print(f"输出目录: {output_dir}")
    console.print("数据类型: int16")
    console.print(
        f"分片大小: {args.shard_size:,} 样本 (约 {shard_mem_gb:.1f} GB/分片 未压缩)"
    )
    console.print()

    num_train_workers = args.num_workers
    num_eval_workers = max(1, args.num_workers // 2)

    def _build_loader(data_path: str, workers: int, max_iter: int) -> DataLoader:
        """Create the dataset for ``data_path`` and wrap it in a DataLoader."""
        dataset = PinyinInputDataset(
            data_path=data_path,
            max_workers=workers,
            max_iter_length=max_iter,
            max_seq_length=args.max_seq_length,
            text_field="text",
            py_style_weight=py_style_weight,
            shuffle_buffer_size=args.shuffle_buffer_size,
            length_weights=length_weights,
        )
        # prefetch_factor / persistent_workers are only valid with worker
        # processes; passing them with num_workers=0 makes DataLoader raise.
        worker_kwargs = {}
        if workers > 0:
            worker_kwargs = {"prefetch_factor": 2, "persistent_workers": True}
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            num_workers=workers,
            pin_memory=False,
            worker_init_fn=worker_init_fn,
            collate_fn=preprocess_collate_fn(args.max_seq_length),
            **worker_kwargs,
        )

    console.print("[bold cyan]创建训练数据集...[/bold cyan]")
    train_dataloader = _build_loader(
        args.train_data_path, num_train_workers, train_max_iter
    )

    console.print("[bold cyan]创建评估数据集...[/bold cyan]")
    eval_dataloader = _build_loader(
        args.eval_data_path, num_eval_workers, eval_max_iter
    )

    logger.info("开始收集训练数据...")
    train_count = collect_samples(
        train_dataloader,
        args.num_train_samples,
        output_dir,
        "train",
        args.max_seq_length,
        args.shard_size,
    )

    if train_count < args.num_train_samples:
        logger.warning(
            f"训练集样本不足: 目标 {args.num_train_samples}, 实际 {train_count}"
        )

    logger.info("开始收集评估数据...")
    eval_count = collect_samples(
        eval_dataloader,
        args.num_eval_samples,
        output_dir,
        "eval",
        args.max_seq_length,
        args.shard_size,
    )

    if eval_count < args.num_eval_samples:
        logger.warning(
            f"评估集样本不足: 目标 {args.num_eval_samples}, 实际 {eval_count}"
        )

    console.print("\n[bold green]=== 预处理完成 ===[/bold green]")
    console.print(f"训练集: {train_count:,} 样本")
    console.print(f"评估集: {eval_count:,} 样本")
    console.print(f"输出目录: {output_dir}")

    for split in ["train", "eval"]:
        split_dir = output_dir / split
        if split_dir.exists():
            total_size = sum(
                f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
            )
            console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")
|
||||
|
||||
|
||||
def resplit_shards(input_dir: str, output_dir: str, target_size: int = 1_000_000):
    """Split oversized ``.npz`` shards into chunks of at most ``target_size`` samples.

    Every ``shard_*.npz`` in ``input_dir`` (sorted by name) is loaded fully
    into memory once, cut into consecutive chunks and re-saved to
    ``output_dir`` as ``shard_XXXXXX.npz`` with a global running index. The
    input ``metadata.json`` is copied with ``num_samples``, ``num_shards`` and
    ``shard_sizes`` updated and ``shard_size`` removed.

    Peak memory is roughly one whole input shard (all fields) plus one output
    chunk, so this is best run on a machine with plenty of RAM.

    Args:
        input_dir: Directory containing ``metadata.json`` and ``shard_*.npz``.
        output_dir: Destination directory; created if missing. Must differ
            from ``input_dir`` because output names would overwrite inputs.
        target_size: Maximum number of samples per output shard.

    Raises:
        ValueError: If ``target_size`` is not positive or both directories
            resolve to the same path.
        FileNotFoundError: If ``metadata.json`` is missing from ``input_dir``.
    """
    if target_size <= 0:
        raise ValueError(f"target_size must be positive, got {target_size}")

    input_path = Path(input_dir)
    output_path = Path(output_dir)
    if input_path.resolve() == output_path.resolve():
        # Output shards reuse the shard_XXXXXX.npz naming and would clobber
        # input shards that have not been read yet.
        raise ValueError("output_dir must be different from input_dir")
    output_path.mkdir(parents=True, exist_ok=True)

    metadata_path = input_path / "metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(f"metadata.json not found in {input_dir}")

    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    shard_files = sorted(input_path.glob("shard_*.npz"))
    console = Console()
    console.print(
        f"[bold]Resplit {len(shard_files)} shards → {target_size:,} samples/chunk[/bold]"
    )

    all_shard_sizes: List[int] = []

    for sf in tqdm(shard_files, desc="Shards"):
        # NpzFile decompresses a member on every access, so materialise each
        # field exactly once per shard instead of once per output chunk.
        with np.load(str(sf)) as data:
            arrays: Dict[str, np.ndarray] = {name: data[name] for name in data.files}
        if not arrays:
            continue
        n_samples = next(iter(arrays.values())).shape[0]

        for start in range(0, n_samples, target_size):
            end = min(start + target_size, n_samples)
            # astype() already returns a fresh array, no extra copy needed.
            chunk_data = {
                name: arr[start:end].astype(np.int16) for name, arr in arrays.items()
            }
            out_file = output_path / f"shard_{len(all_shard_sizes):06d}.npz"
            np.savez_compressed(out_file, **chunk_data)
            all_shard_sizes.append(end - start)
            del chunk_data
            gc.collect()

        del arrays
        gc.collect()

    total_samples = sum(all_shard_sizes)
    num_out_shards = len(all_shard_sizes)

    metadata["num_samples"] = total_samples
    metadata["num_shards"] = num_out_shards
    metadata["shard_sizes"] = all_shard_sizes
    # Output shards are no longer uniformly sized; drop the nominal size.
    metadata.pop("shard_size", None)
    with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    console.print(
        f"[green]Done: {total_samples:,} samples in {num_out_shards} shards[/green]"
    )
|
||||
|
||||
|
||||
def _dispatch():
    """Route the command line to the ``resplit`` sub-command or to ``main``.

    When the first CLI argument is the literal ``resplit``, the remaining
    arguments are parsed for ``resplit_shards``; any other invocation runs
    ``main``, which parses ``sys.argv`` itself.
    """
    import sys

    cli_args = sys.argv[1:]
    if not cli_args or cli_args[0] != "resplit":
        main()
        return

    import argparse as ap

    resplit_parser = ap.ArgumentParser(description="拆分过大的 .npz 分片")
    resplit_parser.add_argument(
        "--input-dir", required=True, help="输入目录含 metadata.json 和 shard_*.npz"
    )
    resplit_parser.add_argument("--output-dir", required=True, help="输出目录")
    resplit_parser.add_argument(
        "--target-size", type=int, default=1_000_000, help="目标分片大小(样本数)"
    )
    # Skip the "resplit" token itself.
    parsed = resplit_parser.parse_args(cli_args[1:])
    resplit_shards(parsed.input_dir, parsed.output_dir, parsed.target_size)
|
||||
|
||||
|
||||
# Public alias for the dispatcher. Presumably referenced by a console-script
# entry point in the packaging config — TODO confirm.
app = _dispatch


# Allow running the module directly as a script.
if __name__ == "__main__":
    _dispatch()
|
||||
|
|
@ -0,0 +1,243 @@
|
|||
"""
|
||||
预处理数据集加载器
|
||||
|
||||
支持两种格式:
|
||||
1. 分片压缩格式(.npz):从压缩文件中按需加载分片,LRU 缓存最多持有 max_cache_shards 个分片
|
||||
2. 单体格式(.npy):向后兼容,使用 mmap 零拷贝加载
|
||||
|
||||
GPU 服务器仅需存放压缩后的 .npz 文件,无需解压到硬盘。
|
||||
"""
|
||||
|
||||
import time
|
||||
import gc
|
||||
import json
|
||||
import struct
|
||||
import zipfile
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
FIELDS = [
|
||||
"input_ids",
|
||||
"token_type_ids",
|
||||
"attention_mask",
|
||||
"labels",
|
||||
"history_slot_ids",
|
||||
"pinyin_ids",
|
||||
]
|
||||
|
||||
|
||||
def _read_shard_size(npz_path: Path) -> int:
    """Return the number of samples in a ``.npz`` shard by reading one header.

    Only the ``.npy`` header of the alphabetically first archive member is
    read; the array data itself is not loaded. The sample count is the first
    dimension of that member's shape.

    Args:
        npz_path: Path to the ``.npz`` archive.

    Returns:
        Length of the first axis of the first member.

    Raises:
        ValueError: If the archive is empty, the member is not a NumPy file,
            or the member is a 0-d array with no first axis.
    """
    # Local import keeps this function self-contained; the header is parsed
    # with literal_eval instead of eval so file contents are never executed.
    import ast

    with zipfile.ZipFile(npz_path, "r") as z:
        members = sorted(z.namelist())
        if not members:
            raise ValueError(f"Empty npz archive: {npz_path}")
        with z.open(members[0]) as f:
            magic = f.read(6)
            if magic != b"\x93NUMPY":
                raise ValueError(f"Not a numpy file: {npz_path}")
            major, _minor = struct.unpack("<BB", f.read(2))
            # Format 1.x stores the header length as uint16, 2.x+ as uint32.
            if major >= 2:
                header_len = struct.unpack("<I", f.read(4))[0]
            else:
                header_len = struct.unpack("<H", f.read(2))[0]
            raw_header = f.read(header_len)

    # Format 3.x headers are UTF-8 encoded; earlier versions use latin1.
    header_text = raw_header.decode("utf-8" if major >= 3 else "latin1")
    header = ast.literal_eval(header_text)
    shape = header["shape"]
    if not shape:
        raise ValueError(f"Array in {npz_path} is 0-dimensional; no sample axis")
    return shape[0]
|
||||
|
||||
|
||||
def _build_offsets(sizes: List[int]) -> List[int]:
    """Return cumulative start offsets for consecutive shards.

    The result has ``len(sizes) + 1`` entries: entry ``i`` is the global index
    of the first sample of shard ``i`` and the last entry is the total count.
    """
    running_total = 0
    offsets = [running_total]
    for size in sizes:
        running_total = running_total + size
        offsets.append(running_total)
    return offsets
|
||||
|
||||
|
||||
def is_preprocessed_data(path: str) -> bool:
    """Return True when ``path`` is a directory that contains ``metadata.json``."""
    candidate = Path(path)
    if not candidate.is_dir():
        return False
    return (candidate / "metadata.json").exists()
|
||||
|
||||
|
||||
class _ShardCache:
    """Least-recently-used cache of loaded shards, holding at most ``max_size`` entries."""

    def __init__(self, max_size: int = 2):
        self.max_size = max_size
        # Ordered oldest -> most recently used.
        self._cache: OrderedDict[int, Dict[str, np.ndarray]] = OrderedDict()

    def get(self, shard_idx: int, loader_fn) -> Dict[str, np.ndarray]:
        """Return the shard for ``shard_idx``, loading it via ``loader_fn`` on a miss.

        A hit marks the shard as most recently used. A miss calls
        ``loader_fn(shard_idx)``, stores the result, and evicts the least
        recently used entries until at most ``max_size`` remain.
        """
        entries = self._cache
        try:
            cached = entries[shard_idx]
        except KeyError:
            pass
        else:
            entries.move_to_end(shard_idx)
            return cached

        loaded = loader_fn(shard_idx)
        # A new key is appended at the most-recently-used end.
        entries[shard_idx] = loaded

        while len(entries) > self.max_size:
            entries.popitem(last=False)
            # Reclaim the evicted arrays right away.
            gc.collect()

        return loaded

    def clear(self):
        """Drop every cached shard and trigger a garbage collection."""
        self._cache.clear()
        gc.collect()
|
||||
|
||||
|
||||
class PreProcessedDataset(Dataset):
    """Map-style dataset over preprocessed samples, auto-detecting the on-disk format.

    - Sharded format: the directory contains ``shard_*.npz`` files. Shards are
      loaded on demand and kept in an LRU cache of ``max_cache_shards``
      entries.
    - Single-file format (backward compatible): one ``<field>.npy`` per field,
      opened with ``mmap_mode="r"``.

    Every field is returned as a ``torch.long`` tensor regardless of the
    dtype stored on disk.
    """

    def __init__(self, data_dir: str, max_cache_shards: int = 1):
        """Read ``metadata.json`` and prepare shard bookkeeping or mmap handles.

        Args:
            data_dir: Directory holding ``metadata.json`` and the data files.
            max_cache_shards: Maximum number of shards kept in memory at once
                (sharded format only).
        """
        t_start = time.perf_counter()
        self.data_dir = Path(data_dir)

        with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

        self.max_seq_length = self.metadata["max_seq_length"]

        shard_files = sorted(self.data_dir.glob("shard_*.npz"))
        if shard_files:
            # Keep the discovered paths: shard i is always shard_files[i], so
            # sizes and loads stay consistent whatever the file numbering is.
            self._shard_files = shard_files
            self._num_shards = len(shard_files)
            if (
                "shard_sizes" in self.metadata
                and len(self.metadata["shard_sizes"]) == self._num_shards
            ):
                self._shard_sizes = self.metadata["shard_sizes"]
                logger.info(
                    f"Using shard_sizes from metadata.json ({self._num_shards} shards)"
                )
            else:
                t_scan = time.perf_counter()
                self._shard_sizes = [_read_shard_size(sf) for sf in shard_files]
                logger.info(
                    f"Scanned {self._num_shards} shard headers in "
                    f"{time.perf_counter() - t_scan:.2f}s"
                )
            self._shard_offsets = _build_offsets(self._shard_sizes)
            self.num_samples = self._shard_offsets[-1]
            self._is_sharded = True
            self._cache = _ShardCache(max_size=max_cache_shards)
            logger.info(
                f"Loaded sharded dataset: {self.num_samples:,} samples, "
                f"{self._num_shards} shards, "
                f"init in {time.perf_counter() - t_start:.2f}s"
            )
        else:
            self.num_samples = self.metadata["num_samples"]
            self._is_sharded = False
            self._load_single_files()
            logger.info(
                f"Loaded single-file dataset: {self.num_samples:,} samples (mmap), "
                f"init in {time.perf_counter() - t_start:.2f}s"
            )

    def _load_single_files(self):
        """Backward compatibility: memory-map one ``<field>.npy`` per field.

        Each array is exposed as an attribute named after its field
        (``self.input_ids``, ``self.labels``, ...).
        """
        for name in FIELDS:
            setattr(
                self, name, np.load(self.data_dir / f"{name}.npy", mmap_mode="r")
            )

    def _load_shard(self, shard_idx: int) -> Dict[str, np.ndarray]:
        """Load every array of shard ``shard_idx`` into memory, dtype unchanged."""
        # The context manager closes the underlying zip handle once all
        # members have been materialised.
        with np.load(self._shard_files[shard_idx]) as npz:
            return {name: npz[name] for name in npz.files}

    def _locate_shard(self, idx: int) -> int:
        """Return the index of the shard containing global sample ``idx``."""
        # Binary search for the first offset strictly greater than idx; the
        # owning shard is the one just before it (empty shards are skipped).
        lo, hi = 0, len(self._shard_offsets) - 1
        while lo < hi:
            mid = (lo + hi) // 2
            if self._shard_offsets[mid] <= idx:
                lo = mid + 1
            else:
                hi = mid
        return lo - 1

    def __len__(self) -> int:
        """Return the total number of samples."""
        return self.num_samples

    def __getitem__(self, idx: int) -> dict:
        """Return sample ``idx`` as a dict of ``torch.long`` tensors keyed by field.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``; negative
                indices are not supported.
        """
        if not 0 <= idx < self.num_samples:
            raise IndexError(
                f"Index {idx} out of range for dataset with {self.num_samples} samples"
            )

        if not hasattr(self, "_first_access_logged"):
            self._first_access_logged = True
            logger.info("First __getitem__ call (initial shard load may be slow)")

        if self._is_sharded:
            shard_idx = self._locate_shard(idx)
            local_idx = idx - self._shard_offsets[shard_idx]
            shard_data = self._cache.get(shard_idx, self._load_shard)
            rows = {name: shard_data[name][local_idx] for name in FIELDS}
        else:
            rows = {name: getattr(self, name)[idx] for name in FIELDS}

        result = {}
        for name, row in rows.items():
            if name == "labels":
                # torch.tensor copies, so it handles whatever indexing returns.
                result[name] = torch.tensor(row, dtype=torch.long)
            else:
                result[name] = torch.from_numpy(row.astype(np.int64))
        return result
|
||||
|
||||
|
||||
def preprocessed_collate_fn(batch):
    """Collate preprocessed samples by stacking each tensor field along a new first axis.

    Only the six tensor fields are handled; the samples carry no string
    fields (prefix/suffix/pinyin).
    """
    tensor_keys = (
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "labels",
        "history_slot_ids",
        "pinyin_ids",
    )
    return {
        key: torch.stack([sample[key] for sample in batch]) for key in tensor_keys
    }
|
||||
|
|
@ -0,0 +1,318 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
打乱并平衡预处理 .npz 分片。
|
||||
|
||||
两阶段处理:
|
||||
Phase 1: 逐分片内部打乱 → 写入临时 .npy(mmap 友好)
|
||||
Phase 2: 从临时 .npy 按比例分配到平衡的输出 .npz 分片
|
||||
|
||||
用法:
|
||||
python -m src.model.shuffle_npz \
|
||||
--input-dir /path/to/subsampled \
|
||||
--output-dir /path/to/shuffled \
|
||||
--shard-size 1000000 \
|
||||
--seed 42
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from tqdm import tqdm
|
||||
|
||||
FIELDS = [
|
||||
"input_ids",
|
||||
"token_type_ids",
|
||||
"attention_mask",
|
||||
"labels",
|
||||
"history_slot_ids",
|
||||
"pinyin_ids",
|
||||
]
|
||||
|
||||
|
||||
def _phase1_shuffle_to_npy(
    input_dir: Path,
    split: str,
    temp_dir: Path,
    shard_sizes: List[int],
    seed: int,
):
    """Phase 1: shuffle each input shard internally and dump it as ``.npy`` files.

    For every shard one permutation is drawn and applied to all fields, so
    rows stay aligned across fields. Each field is written to
    ``temp_dir/<field>/shard_<idx>.npy``.

    Roughly two copies of one field are held in memory at a time (the
    decompressed array and its permuted copy).

    Args:
        input_dir: Root input directory containing ``<split>/``.
        split: Split sub-directory name, e.g. ``"train"``.
        temp_dir: Directory receiving the per-field ``.npy`` files.
        shard_sizes: Expected number of samples in each input shard.
        seed: Seed for the permutation RNG.

    Raises:
        ValueError: If the metadata shard count or an array length disagrees
            with ``shard_sizes``; continuing would silently drop samples.
    """
    metadata_path = input_dir / split / "metadata.json"
    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    num_shards = metadata["num_shards"]
    if num_shards != len(shard_sizes):
        raise ValueError(
            f"metadata num_shards={num_shards} does not match "
            f"{len(shard_sizes)} shard sizes"
        )
    rng = np.random.RandomState(seed)

    for field in FIELDS:
        (temp_dir / field).mkdir(parents=True, exist_ok=True)

    with tqdm(
        total=num_shards,
        desc=f"Phase 1: shuffling {split}",
        unit="shard",
    ) as pbar:
        for src_idx in range(num_shards):
            shard_path = input_dir / split / f"shard_{src_idx:06d}.npz"
            n = shard_sizes[src_idx]
            perm = rng.permutation(n)

            with np.load(shard_path) as data:
                for field in FIELDS:
                    # Accessing a member already yields a fresh in-memory
                    # array, so no extra copy is needed before indexing.
                    arr = data[field]
                    if arr.shape[0] != n:
                        raise ValueError(
                            f"{shard_path}: field '{field}' has {arr.shape[0]} "
                            f"samples, expected {n}"
                        )
                    np.save(temp_dir / field / f"shard_{src_idx:06d}.npy", arr[perm])
                    del arr
                    gc.collect()

            pbar.update(1)
|
||||
|
||||
|
||||
def _phase2_rebalance_to_npz(
    temp_dir: Path,
    output_dir: Path,
    split: str,
    shard_sizes: List[int],
    target_shard_size: int,
    max_seq_length: int,
    seed: int,
) -> List[int]:
    """Phase 2: redistribute the memory-mapped ``.npy`` data into balanced ``.npz`` shards.

    Output shard ``j`` takes the ``j``-th proportional slice of every source
    shard; the slices are concatenated, shuffled with one shared permutation
    and saved to ``output_dir/<split>/``. Only one output shard is held in
    memory at a time.

    A slot that receives no samples (possible when source shards are smaller
    than the number of output shards) is skipped, and written shards are
    numbered contiguously.

    Args:
        temp_dir: Directory with ``<field>/shard_<idx>.npy`` files from phase 1.
        output_dir: Root output directory.
        split: Split sub-directory name.
        shard_sizes: Number of samples in each source shard.
        target_shard_size: Desired number of samples per output shard.
        max_seq_length: Unused here; kept for interface compatibility.
        seed: Base seed; the RNG is seeded with ``seed + 2``.

    Returns:
        Sample counts of the output shards, in file order.
    """
    output_split_dir = output_dir / split
    output_split_dir.mkdir(parents=True, exist_ok=True)

    total_samples = sum(shard_sizes)
    num_output_shards = (total_samples + target_shard_size - 1) // target_shard_size
    rng = np.random.RandomState(seed + 2)

    logger.info(
        f"Phase 2: distributing {total_samples:,} samples into "
        f"{num_output_shards} output shards (~{target_shard_size:,} each)"
    )

    # Memory-map every source array; pages are only read when sliced below.
    src_mmaps: List[Dict[str, np.ndarray]] = []
    for src_idx in range(len(shard_sizes)):
        shard_mmap = {}
        for field in FIELDS:
            npy_path = temp_dir / field / f"shard_{src_idx:06d}.npy"
            shard_mmap[field] = np.load(npy_path, mmap_mode="r")
        src_mmaps.append(shard_mmap)

    output_shard_sizes: List[int] = []

    with tqdm(
        total=num_output_shards,
        desc=f"Phase 2: writing {split}",
        unit="shard",
    ) as pbar:
        for out_j in range(num_output_shards):
            buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}

            for src_i, s in enumerate(shard_sizes):
                # Integer arithmetic splits each source into num_output_shards
                # contiguous, non-overlapping slices covering all s samples.
                start = (out_j * s) // num_output_shards
                end = ((out_j + 1) * s) // num_output_shards
                if start >= end:
                    continue

                for field in FIELDS:
                    chunk = src_mmaps[src_i][field][start:end].copy()
                    buffers[field].append(chunk)

            if not buffers[FIELDS[0]]:
                # No source contributed to this slot; nothing to write.
                del buffers
                pbar.update(1)
                continue

            output = {}
            for field in FIELDS:
                output[field] = (
                    np.concatenate(buffers[field])
                    if len(buffers[field]) > 1
                    else buffers[field][0]
                )

            out_count = len(output[FIELDS[0]])
            if out_count > 1:
                perm = rng.permutation(out_count)
                for field in FIELDS:
                    output[field] = output[field][perm]

            # Number by shards actually written so file names stay contiguous.
            out_idx = len(output_shard_sizes)
            np.savez_compressed(output_split_dir / f"shard_{out_idx:06d}.npz", **output)
            output_shard_sizes.append(out_count)

            del output, buffers
            gc.collect()
            pbar.update(1)

    # Drop the mmap references so the temp files can be removed.
    src_mmaps.clear()
    gc.collect()

    return output_shard_sizes
|
||||
|
||||
|
||||
def _copy_eval(input_dir: Path, output_dir: Path):
    """Copy ``input_dir/eval`` to ``output_dir/eval`` unchanged.

    Does nothing when the source directory is missing. An existing
    destination directory is removed first so the copy is a full replacement.
    """
    source = input_dir / "eval"
    if source.exists():
        destination = output_dir / "eval"
        logger.info(f"Copying eval data from {source}")
        if destination.exists():
            shutil.rmtree(destination)
        shutil.copytree(source, destination)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: shuffle and rebalance the ``train`` shards, copy ``eval``.

    Runs phase 1 (per-shard shuffle into temporary ``.npy`` files) and phase 2
    (proportional redistribution into balanced ``.npz`` shards), writes the new
    ``train/metadata.json``, copies the eval split and prints a size summary.

    Temporary files go to a fresh directory created inside ``--output-dir``
    unless ``--temp-dir`` is given, in which case
    ``<temp-dir>/shuffle_npz_temp`` is (re)created. The temporary directory is
    removed afterwards, including on failure.
    """
    # Local import: tempfile is only needed by this entry point.
    import tempfile

    console = Console()

    parser = argparse.ArgumentParser(description="打乱并平衡预处理 .npz 分片")
    parser.add_argument(
        "--input-dir", type=str, required=True, help="输入目录(含 train/ 和 eval/)"
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
    parser.add_argument(
        "--shard-size",
        type=int,
        default=1_000_000,
        help="输出分片大小(样本数),默认 100 万",
    )
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument(
        "--temp-dir",
        type=str,
        default=None,
        help="临时文件目录(默认在输出目录下自动创建)",
    )

    args = parser.parse_args()
    if args.shard_size <= 0:
        parser.error("--shard-size must be a positive integer")

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_meta_path = input_dir / "train" / "metadata.json"
    if not train_meta_path.exists():
        console.print(f"[red]错误: {train_meta_path} 不存在[/red]")
        return

    with open(train_meta_path, encoding="utf-8") as f:
        train_meta = json.load(f)

    max_seq_length = train_meta["max_seq_length"]
    shard_sizes: List[int] = train_meta["shard_sizes"]
    total_samples = sum(shard_sizes)

    console.print("[bold cyan]=== 打乱并平衡预处理数据 ===[/bold cyan]")
    console.print(f"输入目录: {input_dir}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"总样本数: {total_samples:,}")
    console.print(f"源分片数: {len(shard_sizes)}")
    console.print(
        f"目标分片大小: {args.shard_size:,} "
        f"(约 {(total_samples + args.shard_size - 1) // args.shard_size} 个分片)"
    )
    console.print(f"随机种子: {args.seed}")
    console.print()

    # Temporary workspace: never a machine-specific absolute path.
    if args.temp_dir is None:
        temp_dir = Path(tempfile.mkdtemp(prefix="shuffle_npz_", dir=output_dir))
    else:
        temp_dir = Path(args.temp_dir) / "shuffle_npz_temp"
        if temp_dir.exists():
            shutil.rmtree(temp_dir)
        temp_dir.mkdir(parents=True, exist_ok=True)
    temp_train_dir = temp_dir / "train"
    temp_train_dir.mkdir(parents=True, exist_ok=True)

    try:
        # -- Phase 1 --
        console.print("[bold]Phase 1: 逐分片打乱 → 临时 .npy[/bold]")
        _phase1_shuffle_to_npy(
            input_dir=input_dir,
            split="train",
            temp_dir=temp_train_dir,
            shard_sizes=shard_sizes,
            seed=args.seed,
        )

        # -- Phase 2 --
        console.print("[bold]Phase 2: 按比例分配到平衡输出分片[/bold]")
        output_shard_sizes = _phase2_rebalance_to_npz(
            temp_dir=temp_train_dir,
            output_dir=output_dir,
            split="train",
            shard_sizes=shard_sizes,
            target_shard_size=args.shard_size,
            max_seq_length=max_seq_length,
            seed=args.seed,
        )

        # -- Write metadata --
        train_output_meta = {
            "num_samples": sum(output_shard_sizes),
            "max_seq_length": max_seq_length,
            "dtype": "int16",
            "fields": FIELDS,
            "shard_size": args.shard_size,
            "num_shards": len(output_shard_sizes),
            "shard_sizes": output_shard_sizes,
            "pre_shuffled": True,
            "seed": args.seed,
        }
        out_train_dir = output_dir / "train"
        with open(out_train_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(train_output_meta, f, indent=2, ensure_ascii=False)

        # -- Copy eval --
        _copy_eval(input_dir, output_dir)

    finally:
        # Always remove the temporary workspace, even when a phase failed.
        console.print("[dim]清理临时文件...[/dim]")
        if temp_dir.exists():
            shutil.rmtree(temp_dir)

    # -- Summary --
    console.print()
    console.print("[bold green]=== 完成 ===[/bold green]")
    console.print(
        f"train/: {sum(output_shard_sizes):,} 样本, "
        f"{len(output_shard_sizes)} 分片, "
        f"pre_shuffled=True"
    )

    for sdir_name in ["train", "eval"]:
        sdir = output_dir / sdir_name
        if sdir.exists():
            total_size = sum(
                f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz"
            )
            meta_path = sdir / "metadata.json"
            if meta_path.exists():
                with open(meta_path, encoding="utf-8") as mf:
                    meta = json.load(mf)
            else:
                meta = {}
            console.print(
                f" {sdir_name}/: {total_size / (1024**3):.2f} GB, "
                f"{meta.get('num_shards', '?')} shards"
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,464 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
子采样脚本:从预处理的 .npz 分片中抽取子集。
|
||||
|
||||
策略:
|
||||
1. 第1遍扫描:只读 labels 字段,统计每个 label ID 的样本数,记录分片大小
|
||||
2. 中间计算:按每 ID 最多 N 个样本封顶(硬封顶),不足目标总量则从超额池补足
|
||||
3. 第2遍扫描:读取全部字段,按精确保留配额抽取训练样本,同时抽取评估集
|
||||
|
||||
用法:
|
||||
python -m model.subsample \
|
||||
--input-dir ./preprocessed \
|
||||
--output-dir ./subsampled \
|
||||
--cap-per-label 300000 \
|
||||
--target-total 100000000 \
|
||||
--num-eval 2560
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from tqdm import tqdm
|
||||
|
||||
FIELDS = [
|
||||
"input_ids",
|
||||
"token_type_ids",
|
||||
"attention_mask",
|
||||
"labels",
|
||||
"history_slot_ids",
|
||||
"pinyin_ids",
|
||||
]
|
||||
|
||||
|
||||
def _global_to_shard(
|
||||
positions: np.ndarray, shard_sizes: List[int]
|
||||
) -> Dict[int, List[int]]:
|
||||
"""将全局位置映射为 {shard_idx: [local_indices]}"""
|
||||
cumsum = np.cumsum([0] + shard_sizes)
|
||||
mapping: Dict[int, List[int]] = {}
|
||||
for pos in positions:
|
||||
shard_idx = int(np.searchsorted(cumsum, pos, side="right") - 1)
|
||||
local_idx = pos - cumsum[shard_idx]
|
||||
mapping.setdefault(shard_idx, []).append(local_idx)
|
||||
return mapping
|
||||
|
||||
|
||||
def pass1_count(input_dir: Path, split: str) -> Tuple[Dict[int, int], List[int]]:
    """First pass: count samples per label ID and record each shard's size.

    Only the ``labels`` array of every shard is read.

    Args:
        input_dir: Root directory holding one sub-directory per split.
        split: Name of the split sub-directory (it must contain
            ``metadata.json`` and ``shard_XXXXXX.npz`` files).

    Returns:
        ``(label_counts, shard_sizes)``: ``label_counts`` maps each label ID
        to its total number of occurrences, and ``shard_sizes[i]`` is the
        number of samples in shard ``i``.
    """
    split_dir = input_dir / split
    with open(split_dir / "metadata.json", encoding="utf-8") as f:
        metadata = json.load(f)

    num_shards = metadata["num_shards"]
    label_counts: Dict[int, int] = {}
    shard_sizes: List[int] = []

    # Both the progress bar and every shard handle are context-managed so
    # they are released even when a shard is missing or corrupt.
    with tqdm(total=num_shards, desc="Pass 1: counting labels", unit="shard") as pbar:
        for shard_idx in range(num_shards):
            shard_path = split_dir / f"shard_{shard_idx:06d}.npz"
            with np.load(shard_path) as data:
                labels = data["labels"]

            shard_sizes.append(len(labels))

            unique, counts = np.unique(labels, return_counts=True)
            for uid, cnt in zip(unique, counts):
                uid = int(uid)
                label_counts[uid] = label_counts.get(uid, 0) + int(cnt)

            pbar.update(1)

    return label_counts, shard_sizes
|
||||
|
||||
|
||||
def compute_quotas(
    label_counts: Dict[int, int],
    cap: int = 300_000,
    target_total: int = 100_000_000,
) -> Dict[int, int]:
    """Compute the exact number of samples to keep for every label.

    Strategy:
      * Every label keeps ``min(count, cap)`` samples.
      * If the capped total already reaches ``target_total``, that is the
        result (the cap is a hard cap; the total may exceed the target).
      * Otherwise the deficit is filled from the excess pool (the part of
        each label's count above ``cap``), proportionally to each label's
        excess.

    The proportional split uses integer arithmetic with largest-remainder
    rounding, so the result sums to exactly
    ``capped_total + min(deficit, excess_total)``. No samples are lost to
    float truncation, and no quota ever exceeds the label's real count.

    Args:
        label_counts: Mapping ``label_id -> number of available samples``.
        cap: Per-label cap applied before any top-up.
        target_total: Desired total number of kept samples.

    Returns:
        Mapping ``label_id -> quota``, in the iteration order of
        ``label_counts``.
    """
    quotas: Dict[int, int] = {
        label: min(cnt, cap) for label, cnt in label_counts.items()
    }
    capped_total = sum(quotas.values())
    if capped_total >= target_total:
        return quotas

    # Capping alone is not enough: top up from the labels that were capped.
    excess_per_label = {
        label: cnt - cap for label, cnt in label_counts.items() if cnt > cap
    }
    excess_total = sum(excess_per_label.values())
    if excess_total == 0:
        return quotas

    # Never hand out more than the pool actually contains.
    deficit = min(target_total - capped_total, excess_total)

    assigned = 0
    remainders = []
    for label, excess in excess_per_label.items():
        share, rem = divmod(excess * deficit, excess_total)
        quotas[label] += share
        assigned += share
        remainders.append((rem, label))

    # Hand the leftover units to the labels with the largest fractional
    # remainders; the stable sort keeps ties in input order (deterministic).
    leftover = deficit - assigned
    remainders.sort(key=lambda item: -item[0])
    for _, label in remainders[:leftover]:
        quotas[label] += 1

    return quotas
|
||||
|
||||
|
||||
def _select_train_indices(
    labels: np.ndarray,
    is_eval: np.ndarray,
    remaining: Dict[int, int],
    rng: np.random.RandomState,
) -> np.ndarray:
    """Choose which rows of one shard are kept for training.

    Rows flagged in ``is_eval`` are never selected. For every label present
    in the shard, up to ``remaining[label]`` rows are kept: all of them when
    the remaining quota covers the shard's count, otherwise a random subset
    drawn without replacement from ``rng``. ``remaining`` is decremented in
    place.

    Returns:
        Indices into the original shard rows, in ascending order.
    """
    candidate_rows = np.where(~is_eval)[0]
    train_candidates = labels[candidate_rows]
    train_n = len(train_candidates)
    if train_n == 0:
        return np.array([], dtype=np.int64)

    # A stable sort groups equal labels into contiguous runs [start, end).
    sort_idx = np.argsort(train_candidates, kind="stable")
    sorted_labels = train_candidates[sort_idx]
    unique_vals, starts = np.unique(sorted_labels, return_index=True)
    ends = np.append(starts[1:], train_n)

    keep_mask = np.zeros(train_n, dtype=bool)
    for raw_label, start, end in zip(
        unique_vals.tolist(), starts.tolist(), ends.tolist()
    ):
        label_val = int(raw_label)
        cnt_in_shard = end - start
        need = remaining.get(label_val, 0)
        select = min(cnt_in_shard, need)
        if select >= cnt_in_shard:
            keep_mask[sort_idx[start:end]] = True
        elif select > 0:
            chosen = rng.choice(cnt_in_shard, size=select, replace=False)
            keep_mask[sort_idx[start + chosen]] = True
        remaining[label_val] = max(0, need - select)

    return candidate_rows[keep_mask]


def _write_train_shard(
    output_train_dir: Path,
    shard_idx: int,
    buffers: Dict[str, List[np.ndarray]],
    count: int,
    shuffle: bool,
    shuffle_rng: np.random.RandomState,
) -> None:
    """Concatenate the buffered rows, optionally shuffle, and save one shard.

    The same permutation is applied to every field so rows stay aligned.
    """
    merged = {f: np.concatenate(buffers[f], axis=0) for f in FIELDS}
    if shuffle and count > 1:
        perm = shuffle_rng.permutation(count)
        for f in FIELDS:
            merged[f] = merged[f][perm]
    np.savez_compressed(output_train_dir / f"shard_{shard_idx:06d}.npz", **merged)
    logger.debug(f"Saved train shard {shard_idx}: {count} samples")


def pass2_subsample(
    input_dir: Path,
    output_dir: Path,
    split: str,
    quotas: Dict[int, int],
    eval_map: Dict[int, List[int]],
    shard_sizes: List[int],
    shard_size: int = 5_000_000,
    shuffle: bool = True,
    seed: int = 42,
) -> Tuple[int, int]:
    """Second pass: extract eval samples and subsample training samples.

    Every source shard is read field by field. Rows listed in ``eval_map``
    go to the eval set; the remaining rows are subsampled per label according
    to ``quotas`` and buffered until at least ``shard_size`` rows are
    collected, at which point an output shard is written.

    Args:
        input_dir: Root directory of the preprocessed input data.
        output_dir: Root output directory; ``train/`` and ``eval/`` are
            created beneath it. Existing ``train/shard_*.npz`` are removed.
        split: Input split sub-directory to read from.
        quotas: ``{label_id: exact_number_to_keep}``.
        eval_map: ``{source_shard_idx: [local_row_indices]}`` of eval rows.
        shard_sizes: Number of rows in each source shard.
        shard_size: Minimum number of rows that triggers writing a shard.
        shuffle: Whether to shuffle rows inside each output shard.
        seed: Seed for label selection; ``seed + 1`` seeds the shuffling.

    Returns:
        ``(train_kept, eval_kept)``.
    """
    split_dir = input_dir / split
    with open(split_dir / "metadata.json", encoding="utf-8") as f:
        metadata = json.load(f)

    num_shards = metadata["num_shards"]
    max_seq_length = metadata["max_seq_length"]

    output_train_dir = output_dir / "train"
    output_eval_dir = output_dir / "eval"
    output_train_dir.mkdir(parents=True, exist_ok=True)
    output_eval_dir.mkdir(parents=True, exist_ok=True)

    # Remove stale output shards so leftovers cannot pollute the new dataset.
    for old_shard in sorted(output_train_dir.glob("shard_*.npz")):
        old_shard.unlink()
        logger.debug(f"Removed old shard: {old_shard.name}")

    rng = np.random.RandomState(seed)
    shuffle_rng = np.random.RandomState(seed + 1)
    remaining = dict(quotas)  # {label_id: remaining_to_keep}

    train_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
    train_buf_count = 0
    train_shard_sizes: List[int] = []
    total_train_kept = 0

    eval_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
    total_eval_kept = 0

    with tqdm(total=num_shards, desc="Pass 2: subsampling", unit="shard") as pbar:
        for src_shard_idx in range(num_shards):
            shard_path = split_dir / f"shard_{src_shard_idx:06d}.npz"
            n = shard_sizes[src_shard_idx]

            eval_local = np.array(eval_map.get(src_shard_idx, []), dtype=np.int64)
            has_eval = len(eval_local) > 0
            is_eval = np.zeros(n, dtype=bool)
            if has_eval:
                is_eval[eval_local] = True

            # The context manager closes the archive even if a field is
            # missing or a read fails midway.
            with np.load(shard_path) as data:
                # Step 1: labels alone decide which rows are kept.
                labels = data["labels"]
                keep_rows = _select_train_indices(labels, is_eval, remaining, rng)
                n_kept = len(keep_rows)
                del is_eval

                # Step 2: load one field at a time to bound peak memory, and
                # serve both the eval and train extraction from that single
                # load. Integer-array indexing already returns a copy.
                if has_eval or n_kept > 0:
                    for f in FIELDS:
                        arr = labels if f == "labels" else data[f]
                        if has_eval:
                            eval_buffers[f].append(arr[eval_local])
                        if n_kept > 0:
                            train_buffers[f].append(arr[keep_rows])
                        del arr
                del labels

            total_eval_kept += len(eval_local)
            train_buf_count += n_kept
            total_train_kept += n_kept
            del keep_rows
            gc.collect()

            if train_buf_count >= shard_size and train_buf_count > 0:
                _write_train_shard(
                    output_train_dir,
                    len(train_shard_sizes),
                    train_buffers,
                    train_buf_count,
                    shuffle,
                    shuffle_rng,
                )
                train_shard_sizes.append(train_buf_count)
                train_buffers = {f: [] for f in FIELDS}
                train_buf_count = 0
                gc.collect()

            pbar.update(1)

    # Flush whatever is left in the buffer as a final, smaller shard.
    if train_buf_count > 0:
        _write_train_shard(
            output_train_dir,
            len(train_shard_sizes),
            train_buffers,
            train_buf_count,
            shuffle,
            shuffle_rng,
        )
        train_shard_sizes.append(train_buf_count)

    # Eval data always goes into a single shard.
    if total_eval_kept > 0:
        eval_merged = {f: np.concatenate(eval_buffers[f], axis=0) for f in FIELDS}
        np.savez_compressed(output_eval_dir / "shard_000000.npz", **eval_merged)
    else:
        logger.warning("No eval samples extracted!")

    train_metadata = {
        "num_samples": total_train_kept,
        "max_seq_length": max_seq_length,
        "dtype": "int16",
        "fields": FIELDS,
        "shard_size": shard_size,
        "num_shards": len(train_shard_sizes),
        "shard_sizes": train_shard_sizes,
        "pre_shuffled": shuffle,
        "seed": seed,
    }
    with open(output_train_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(train_metadata, f, indent=2, ensure_ascii=False)

    eval_metadata = {
        "num_samples": total_eval_kept,
        "max_seq_length": max_seq_length,
        "dtype": "int16",
        "fields": FIELDS,
        "shard_size": total_eval_kept,
        # Do not advertise a shard that was never written.
        "num_shards": 1 if total_eval_kept > 0 else 0,
    }
    with open(output_eval_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(eval_metadata, f, indent=2, ensure_ascii=False)

    return total_train_kept, total_eval_kept
|
||||
|
||||
|
||||
def main():
    """CLI entry point: count labels, compute quotas, then subsample.

    Runs pass 1 (label statistics), derives per-label quotas, picks the
    global positions of the eval samples, runs pass 2 (subsampling plus eval
    extraction), and prints a size summary of the output directories.
    """
    console = Console()

    parser = argparse.ArgumentParser(description="从预处理 .npz 分片中抽取子集")
    parser.add_argument("--input-dir", type=str, required=True, help="输入预处理目录")
    parser.add_argument("--output-dir", type=str, required=True, help="输出子采样目录")
    parser.add_argument(
        "--cap-per-label",
        type=int,
        default=300_000,
        help="每个 label ID 最大保留样本数",
    )
    parser.add_argument(
        "--target-total",
        type=int,
        default=100_000_000,
        help="目标训练集样本总数",
    )
    parser.add_argument(
        "--shard-size",
        type=int,
        default=5_000_000,
        help="输出分片大小(样本数)",
    )
    parser.add_argument("--num-eval", type=int, default=2560, help="评估集样本数")
    parser.add_argument(
        "--no-shuffle",
        action="store_false",
        dest="shuffle",
        help="禁用输出分片内部打乱",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="随机种子(用于标签选择 + 输出打乱)"
    )

    args = parser.parse_args()
    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    console.print("[bold cyan]=== 子采样预处理数据 ===[/bold cyan]")
    console.print(f"输入目录: {input_dir}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"每 ID 封顶: {args.cap_per_label:,}")
    console.print(f"目标训练集: {args.target_total:,}")
    console.print(f"评估集: {args.num_eval}")
    console.print(f"输出分片打乱: {'是' if args.shuffle else '否'}")
    console.print(f"随机种子: {args.seed}")
    console.print()

    # Pass 1: label statistics.
    console.print("[bold]第 1 遍扫描:统计标签分布...[/bold]")
    label_counts, shard_sizes = pass1_count(input_dir, "train")

    total_samples = sum(shard_sizes)
    num_labels = len(label_counts)
    capped_total = sum(min(cnt, args.cap_per_label) for cnt in label_counts.values())

    console.print(f"分片数: {len(shard_sizes)}")
    console.print(f"总样本数: {total_samples:,}")
    console.print(f"标签种类: {num_labels}")
    console.print(f"封顶后样本数 (每ID≤{args.cap_per_label:,}): {capped_total:,}")
    console.print()

    # Sampling without replacement cannot draw more rows than exist; report
    # that as a CLI usage error instead of an opaque NumPy exception.
    if not 0 <= args.num_eval <= total_samples:
        parser.error(
            f"--num-eval 必须在 0 到输入样本总数 ({total_samples:,}) 之间,"
            f"当前为 {args.num_eval}"
        )

    # Per-label quotas.
    quotas = compute_quotas(
        label_counts,
        cap=args.cap_per_label,
        target_total=args.target_total,
    )
    expected_train = sum(quotas.values())
    console.print(f"期望训练集大小: {expected_train:,}")

    # Label distribution statistics.
    stats_n_capped = sum(1 for cnt in label_counts.values() if cnt > args.cap_per_label)
    stats_n_below = num_labels - stats_n_capped
    stats_n_quota_full = sum(1 for q in quotas.values() if q >= args.cap_per_label)
    console.print(
        f"超过封顶的标签数: {stats_n_capped}, 不足封顶的标签数: {stats_n_below}, "
        f"配额打满(≥{args.cap_per_label // 10000}万): {stats_n_quota_full}"
    )
    console.print()

    # Choose the global positions of the eval samples. A dedicated RNG keeps
    # them independent of the selection and shuffle streams used in pass 2.
    eval_rng = np.random.RandomState(args.seed + 100)
    eval_positions = eval_rng.choice(total_samples, size=args.num_eval, replace=False)
    eval_positions.sort()
    eval_map = _global_to_shard(eval_positions, shard_sizes)
    console.print(f"评估集: {args.num_eval} 个位置已分配到 {len(eval_map)} 个分片中")
    console.print()

    # Pass 2: subsample and extract the eval set.
    console.print("[bold]第 2 遍扫描:子采样 + 抽取评估集...[/bold]")
    train_count, eval_count = pass2_subsample(
        input_dir,
        output_dir,
        "train",
        quotas,
        eval_map,
        shard_sizes,
        args.shard_size,
        shuffle=args.shuffle,
        seed=args.seed,
    )

    # Summary.
    console.print()
    console.print("[bold green]=== 完成 ===[/bold green]")
    console.print(f"训练集: {train_count:,} 样本")
    console.print(f"评估集: {eval_count:,} 样本")

    for split in ["train", "eval"]:
        sdir = output_dir / split
        if sdir.exists():
            total_size = sum(
                f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz"
            )
            meta_path = sdir / "metadata.json"
            if meta_path.exists():
                # The metadata is written as UTF-8; read it back the same way.
                with open(meta_path, encoding="utf-8") as mf:
                    meta = json.load(mf)
            else:
                meta = {}
            console.print(
                f" {split}/: {total_size / (1024**3):.2f} GB, "
                f"{meta.get('num_shards', '?')} shards"
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,443 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
缺失字符补充工具
|
||||
|
||||
步骤 1: find-missing — 扫描已预处理数据,找出从未出现的 label ID,输出 JSON
|
||||
步骤 2: generate-template — 根据 JSON 生成 JSONL 占位文件,供用户手动填入包含缺失字的真实文本
|
||||
步骤 3: preprocess-supplement — 将填好的 JSONL 文本预处理为 .npz 分片,输出到独立目录
|
||||
|
||||
用法:
|
||||
python -m model.supplement_missing find-missing \
|
||||
--preprocessed-dir ./preprocessed/train \
|
||||
--output missing_chars.json
|
||||
|
||||
python -m model.supplement_missing generate-template \
|
||||
--missing-chars missing_chars.json \
|
||||
--output supplement_texts.jsonl
|
||||
|
||||
python -m model.supplement_missing preprocess-supplement \
|
||||
--missing-chars missing_chars.json \
|
||||
--supplement-texts supplement_texts.jsonl \
|
||||
--output-dir ./preprocessed/supplement \
|
||||
--num-samples 100000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Set
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .dataset import PinyinInputDataset
|
||||
from .preprocess import collect_samples
|
||||
from .query import QueryEngine
|
||||
from .trainer import preprocess_collate_fn, worker_init_fn
|
||||
|
||||
|
||||
def scan_labels(preprocessed_dir: Path) -> Set[int]:
    """Collect every label ID that occurs in the ``shard_*.npz`` files.

    Args:
        preprocessed_dir: Directory that directly contains the shards.

    Returns:
        The set of distinct label IDs found across all shards. It is empty
        (and a warning is logged) when the directory holds no shard.
    """
    appeared: Set[int] = set()

    shard_files = sorted(preprocessed_dir.glob("shard_*.npz"))
    if not shard_files:
        logger.warning(f"未找到 .npz 分片文件: {preprocessed_dir}")
        return appeared

    for shard_path in tqdm(shard_files, desc="扫描分片", unit="shard"):
        # Close each archive deterministically rather than relying on
        # garbage collection to release the file handle.
        with np.load(shard_path) as data:
            labels = data["labels"].astype(np.int64)
        # np.unique flattens its input, so a trailing singleton axis on the
        # labels array needs no special handling.
        appeared.update(np.unique(labels).tolist())

    return appeared
|
||||
|
||||
|
||||
def cmd_find_missing(args):
    """Handle ``find-missing``: report vocabulary IDs absent from the data.

    Reads ``metadata.json`` and all shards under ``args.preprocessed_dir``,
    compares the labels seen against the IDs known to ``QueryEngine``, writes
    the missing (non-zero) entries to ``args.output`` as JSON, and prints a
    summary table. Returns early with an error message when the directory or
    its metadata file does not exist.
    """
    console = Console()
    data_dir = Path(args.preprocessed_dir)

    if not data_dir.exists():
        console.print(f"[bold red]目录不存在: {data_dir}[/bold red]")
        return

    metadata_path = data_dir / "metadata.json"
    if not metadata_path.exists():
        console.print(f"[bold red]未找到 metadata.json: {metadata_path}[/bold red]")
        return

    with open(metadata_path, "r", encoding="utf-8") as fh:
        metadata = json.load(fh)
    console.print(
        f"[bold cyan]预处理数据: {metadata['num_samples']:,} 样本, {metadata['num_shards']} 分片[/bold cyan]"
    )

    console.print("[bold cyan]扫描 labels...[/bold cyan]")
    appeared = scan_labels(data_dir)

    console.print("[bold cyan]加载 QueryEngine...[/bold cyan]")
    query_engine = QueryEngine()
    query_engine.load()

    all_ids = set(query_engine._id_to_info.keys())
    missing_ids = all_ids - appeared

    # ID 0 is skipped here; the summary below reports it as EOS.
    missing_chars = []
    for label_id in sorted(missing_ids - {0}):
        info = query_engine.query_by_id(label_id)
        if info is None:
            continue
        missing_chars.append(
            {
                "id": info.id,
                "char": info.char,
                "pinyin": info.pinyin,
                "count": info.count,
            }
        )

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    report = {"missing_count": len(missing_chars), "missing_chars": missing_chars}
    with open(output_path, "w", encoding="utf-8") as fh:
        json.dump(report, fh, ensure_ascii=False, indent=2)

    console.print("\n[bold green]=== 扫描完成 ===[/bold green]")
    console.print(f"词表大小: {len(all_ids):,} (含 EOS)")
    console.print(f"已出现标签: {len(appeared):,}")
    console.print(
        f"[bold red]缺失标签: {len(missing_ids):,}[/bold red] (其中非 EOS: {len(missing_chars)})"
    )

    if missing_chars:
        table = Table(
            title=f"缺失字符 (共 {len(missing_chars)} 个)",
            show_header=True,
            header_style="bold magenta",
        )
        column_specs = (
            ("ID", "cyan", 8),
            ("字符", "yellow", 6),
            ("拼音", "green", 12),
            ("语料频次", "red", 12),
        )
        for header, style, width in column_specs:
            table.add_column(header, style=style, width=width)
        for entry in missing_chars:
            table.add_row(
                str(entry["id"]),
                entry["char"],
                entry["pinyin"],
                f"{entry['count']:,}",
            )
        console.print(table)

    console.print(f"\n已输出到: {output_path}")
|
||||
|
||||
|
||||
def cmd_generate_template(args):
    """Handle ``generate-template``: write a JSONL file of placeholder texts.

    For every entry in the ``missing_chars`` list of ``args.missing_chars``,
    ``args.num_entries`` placeholder lines are written to ``args.output``.
    Each line is a JSON object with a single ``text`` key that the user is
    expected to replace by hand. Returns early when the input file does not
    exist or lists no missing characters.
    """
    console = Console()

    missing_path = Path(args.missing_chars)
    if not missing_path.exists():
        console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
        return

    with open(missing_path, "r", encoding="utf-8") as fh:
        missing_chars = json.load(fh).get("missing_chars", [])

    if not missing_chars:
        console.print("[bold green]没有缺失字符,无需生成模板[/bold green]")
        return

    num_entries = args.num_entries
    total_lines = len(missing_chars) * num_entries

    console.print(f"[bold cyan]缺失字符数: {len(missing_chars)}[/bold cyan]")
    console.print(f"[bold cyan]每字符模板数: {num_entries}[/bold cyan]")
    console.print(f"[bold cyan]总模板行数: {total_lines}[/bold cyan]")

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # One JSON object per line: characters in input order, then entry number.
    with open(output_path, "w", encoding="utf-8") as fh:
        for entry in missing_chars:
            for ordinal in range(1, num_entries + 1):
                placeholder = {
                    "text": f"请在这里输入包含「{entry['char']}」字的第{ordinal}条文本"
                }
                fh.write(json.dumps(placeholder, ensure_ascii=False) + "\n")

    console.print(f"[bold green]模板已生成: {output_path}[/bold green]")
    console.print(
        f"共 {total_lines} 条({len(missing_chars)} 字符 × {num_entries} 条/字符),"
        f"请手动编辑该文件,将占位文本替换为包含对应字符的真实文本。"
    )
|
||||
|
||||
|
||||
def cmd_preprocess_supplement(args):
    """Handle ``preprocess-supplement``: turn the filled JSONL into shards.

    Builds a ``PinyinInputDataset`` over ``args.supplement_texts`` restricted
    to the missing label IDs (plus ID 0), feeds it through a ``DataLoader``,
    and lets ``collect_samples`` write up to ``args.num_samples`` samples
    into ``args.output_dir``. Returns early when the missing-characters file
    does not exist or is empty.
    """
    console = Console()

    # Load the missing-character report.
    missing_path = Path(args.missing_chars)
    if not missing_path.exists():
        console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
        return

    with open(missing_path, "r", encoding="utf-8") as fh:
        missing_chars = json.load(fh).get("missing_chars", [])

    if not missing_chars:
        console.print("[bold green]没有缺失字符,无需处理[/bold green]")
        return

    # ID 0 (EOS) is always part of the target label set.
    target_labels = {entry["id"] for entry in missing_chars} | {0}

    # Parse "a,b,c" and "len:weight,..." style options.
    py_style_weight = tuple(map(int, args.py_style_weight.split(",")))
    length_weights = {}
    for pair in args.length_weights.split(","):
        length_str, weight_str = pair.split(":")
        length_key = int(length_str)
        length_weights[length_key] = int(weight_str)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    num_workers = args.num_workers
    max_iter = args.num_samples * 5

    console.print("[bold cyan]=== 补充数据预处理 ===[/bold cyan]")
    console.print(f"补充文本: {args.supplement_texts}")
    console.print(f"缺失字符数: {len(missing_chars)}")
    console.print(f"目标样本: {args.num_samples:,}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"Worker 数: {num_workers}")
    console.print()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    console.print("[bold cyan]创建补充数据集...[/bold cyan]")
    dataset = PinyinInputDataset(
        data_path="json",
        max_workers=num_workers,
        max_iter_length=max_iter,
        max_seq_length=args.max_seq_length,
        text_field="text",
        py_style_weight=py_style_weight,
        shuffle_buffer_size=100,
        length_weights=length_weights,
        data_kwargs={"data_files": args.supplement_texts, "streaming": False},
        target_labels=target_labels,
    )

    loader_options = dict(
        batch_size=args.batch_size,
        num_workers=num_workers,
        pin_memory=False,
        worker_init_fn=worker_init_fn,
        collate_fn=preprocess_collate_fn(args.max_seq_length),
    )
    # These two options are only set when worker processes are used.
    if num_workers > 0:
        loader_options.update(prefetch_factor=2, persistent_workers=True)

    dataloader = DataLoader(dataset, **loader_options)

    logger.info("开始收集补充数据...")
    count = collect_samples(
        dataloader,
        args.num_samples,
        output_dir,
        "supplement",
        args.max_seq_length,
        args.shard_size,
    )

    if count < args.num_samples:
        logger.warning(f"补充样本不足: 目标 {args.num_samples}, 实际 {count}")

    console.print("\n[bold green]=== 补充预处理完成 ===[/bold green]")
    console.print(f"生成样本: {count:,}")
    console.print(f"输出目录: {output_dir}")

    npz_bytes = 0
    for path in output_dir.iterdir():
        if path.suffix == ".npz":
            npz_bytes += path.stat().st_size
    console.print(f"总大小: {npz_bytes / (1024**3):.2f} GB (compressed)")
    console.print()
    console.print(
        "[bold yellow]提示[/bold yellow]: 请检查补充数据质量,清洗无误后手动将 shard_*.npz 合并到 train/ 目录并更新 metadata.json"
    )
|
||||
|
||||
|
||||
def main():
    """CLI entry point: build the sub-command parsers and dispatch.

    Three sub-commands are registered: ``find-missing``,
    ``generate-template`` and ``preprocess-supplement``. With no sub-command
    the help text is printed and the function returns.
    """
    parser = argparse.ArgumentParser(
        description="缺失字符补充工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
子命令:
find-missing 扫描已预处理数据,找出从未出现的 label ID
generate-template 根据缺失字符 JSON 生成 JSONL 占位文件
preprocess-supplement 将填好的 JSONL 预处理为 .npz 分片(独立目录)

示例:
python -m model.supplement_missing find-missing \\
--preprocessed-dir ./preprocessed/train \\
--output missing_chars.json

python -m model.supplement_missing generate-template \\
--missing-chars missing_chars.json \\
--output supplement_texts.jsonl

python -m model.supplement_missing preprocess-supplement \\
--missing-chars missing_chars.json \\
--supplement-texts supplement_texts.jsonl \\
--output-dir ./preprocessed/supplement \\
--num-samples 100000
""",
    )
    subparsers = parser.add_subparsers(dest="command", help="子命令")

    def register(sub, specs):
        # Each spec is (flag, add_argument keyword arguments).
        for flag, options in specs:
            sub.add_argument(flag, **options)

    # find-missing
    p_find = subparsers.add_parser("find-missing", help="扫描预处理数据,找出缺失标签")
    register(
        p_find,
        [
            (
                "--preprocessed-dir",
                dict(
                    type=str,
                    required=True,
                    help="预处理数据目录(包含 shard_*.npz 和 metadata.json)",
                ),
            ),
            (
                "--output",
                dict(
                    type=str,
                    default="missing_chars.json",
                    help="输出 JSON 文件路径(默认: missing_chars.json)",
                ),
            ),
        ],
    )

    # generate-template
    p_gen = subparsers.add_parser("generate-template", help="生成补充文本模板")
    register(
        p_gen,
        [
            (
                "--missing-chars",
                dict(
                    type=str,
                    required=True,
                    help="缺失字符 JSON 文件路径(由 find-missing 生成)",
                ),
            ),
            (
                "--output",
                dict(
                    type=str,
                    default="supplement_texts.jsonl",
                    help="输出 JSONL 文件路径(默认: supplement_texts.jsonl)",
                ),
            ),
            (
                "--num-entries",
                dict(
                    type=int,
                    default=3,
                    help="每个缺失字符生成的模板条数(默认: 3)",
                ),
            ),
        ],
    )

    # preprocess-supplement
    p_pre = subparsers.add_parser(
        "preprocess-supplement", help="将 JSONL 预处理为 .npz 分片"
    )
    register(
        p_pre,
        [
            (
                "--missing-chars",
                dict(
                    type=str,
                    required=True,
                    help="缺失字符 JSON 文件路径(由 find-missing 生成)",
                ),
            ),
            (
                "--supplement-texts",
                dict(
                    type=str,
                    required=True,
                    help="已填写的补充文本 JSONL 文件路径",
                ),
            ),
            (
                "--output-dir",
                dict(
                    type=str,
                    required=True,
                    help="输出目录(独立目录,不会覆盖已有数据)",
                ),
            ),
            (
                "--num-samples",
                dict(type=int, required=True, help="目标样本数量"),
            ),
            (
                "--batch-size",
                dict(type=int, default=128, help="批大小(默认: 128)"),
            ),
            (
                "--num-workers",
                dict(
                    type=int,
                    default=0,
                    help="DataLoader worker 数量。本地 JSONL 小文件建议 0(默认: 0)",
                ),
            ),
            (
                "--max-seq-length",
                dict(type=int, default=128, help="最大序列长度(默认: 128)"),
            ),
            (
                "--seed",
                dict(type=int, default=42, help="随机种子(默认: 42)"),
            ),
            (
                "--shard-size",
                dict(
                    type=int,
                    default=5_000_000,
                    help="分片大小(样本数),控制内存峰值(默认: 500万)",
                ),
            ),
            (
                "--py-style-weight",
                dict(
                    type=str,
                    default="9,2,1",
                    help="拼音风格权重(逗号分隔,默认: 9,2,1)",
                ),
            ),
            (
                "--length-weights",
                dict(
                    type=str,
                    default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
                    help="词长权重(默认: 1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2)",
                ),
            ),
        ],
    )

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        return

    handlers = {
        "find-missing": cmd_find_missing,
        "generate-template": cmd_generate_template,
        "preprocess-supplement": cmd_preprocess_supplement,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args)
|
||||
|
||||
|
||||
app = main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -28,6 +28,11 @@ from torch.utils.data import DataLoader
|
|||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from .dataset import PinyinInputDataset
|
||||
from .preprocessed_dataset import (
|
||||
PreProcessedDataset,
|
||||
is_preprocessed_data,
|
||||
preprocessed_collate_fn,
|
||||
)
|
||||
|
||||
# 导入模型和数据
|
||||
from .model import InputMethodEngine
|
||||
|
|
@ -97,6 +102,11 @@ class Trainer:
|
|||
"""
|
||||
self.model = model
|
||||
self.train_dataloader = train_dataloader
|
||||
if isinstance(eval_dataloader, DataLoader) and not isinstance(
|
||||
eval_dataloader.dataset, torch.utils.data.IterableDataset
|
||||
):
|
||||
self.eval_dataloader = eval_dataloader
|
||||
else:
|
||||
self.eval_dataloader = list([i for i in eval_dataloader])
|
||||
self.output_dir = Path(output_dir)
|
||||
self.num_epochs = num_epochs
|
||||
|
|
@ -620,14 +630,14 @@ class Trainer:
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to write training status: {e}")
|
||||
|
||||
def _create_progress_bar(self) -> Progress:
|
||||
"""创建Rich进度条"""
|
||||
def _create_progress(self) -> Progress:
|
||||
return Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TextColumn("|"),
|
||||
TimeElapsedColumn(),
|
||||
TextColumn("|"),
|
||||
TimeRemainingColumn(),
|
||||
console=self.console,
|
||||
expand=True,
|
||||
|
|
@ -694,20 +704,30 @@ class Trainer:
|
|||
accumulated_accuracy = 0.0
|
||||
accumulation_counter = 0
|
||||
|
||||
# 创建进度条
|
||||
with self._create_progress_bar() as progress:
|
||||
# 计算每个 epoch 的步数
|
||||
steps_per_epoch = max(1, self.total_steps // self.num_epochs)
|
||||
|
||||
# 创建双进度条:epoch 级 + epoch 内 batch 级
|
||||
with self._create_progress() as progress:
|
||||
epoch_task = progress.add_task(
|
||||
f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
|
||||
total=self.total_steps,
|
||||
total=self.num_epochs,
|
||||
completed=self.current_epoch,
|
||||
)
|
||||
batch_task = progress.add_task(
|
||||
"[green]Batch",
|
||||
total=steps_per_epoch,
|
||||
)
|
||||
|
||||
# 训练循环
|
||||
for epoch in range(self.current_epoch, self.num_epochs):
|
||||
self.current_epoch = epoch
|
||||
progress.update(
|
||||
epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}"
|
||||
progress.reset(
|
||||
batch_task,
|
||||
total=steps_per_epoch,
|
||||
description=f"[green]Epoch {epoch + 1} Step 0/{steps_per_epoch}",
|
||||
)
|
||||
|
||||
epoch_step = 0
|
||||
for batch_idx, batch in enumerate(self.train_dataloader):
|
||||
# 更新学习率
|
||||
current_lr = self._update_learning_rate()
|
||||
|
|
@ -733,16 +753,18 @@ class Trainer:
|
|||
self.scaler.update()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# 更新进度条
|
||||
# 更新 batch 进度条 (每 step 推进一次)
|
||||
progress.update(
|
||||
epoch_task,
|
||||
batch_task,
|
||||
advance=1,
|
||||
description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | "
|
||||
f"Step {global_step}/{self.total_steps} | "
|
||||
f"Loss: {loss:.4f} | "
|
||||
f"LR: {current_lr:.2e}",
|
||||
description=f"[green]Epoch {epoch + 1} "
|
||||
f"Step {epoch_step + 1}/{steps_per_epoch}"
|
||||
f" | Loss: {loss:.4f}"
|
||||
f" | LR: {current_lr:.2e}",
|
||||
)
|
||||
|
||||
epoch_step += 1
|
||||
|
||||
# 定期评估和记录
|
||||
if (global_step + 1) % self.eval_frequency == 0:
|
||||
# 计算平均指标
|
||||
|
|
@ -770,7 +792,6 @@ class Trainer:
|
|||
# 更新最佳模型
|
||||
if eval_metrics["eval_loss"] < self.best_eval_loss:
|
||||
self.best_eval_loss = eval_metrics["eval_loss"]
|
||||
# 只保存best_model,不创建额外的checkpoint文件
|
||||
self.save_checkpoint("best_model.pt", is_best=True)
|
||||
|
||||
# 记录到TensorBoard
|
||||
|
|
@ -791,7 +812,7 @@ class Trainer:
|
|||
f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
|
||||
)
|
||||
|
||||
progress.console.log(log_text)
|
||||
progress.log(log_text)
|
||||
|
||||
# 重置累积指标
|
||||
accumulated_loss = 0.0
|
||||
|
|
@ -808,12 +829,30 @@ class Trainer:
|
|||
|
||||
# 检查是否达到总步数
|
||||
if global_step >= self.total_steps:
|
||||
progress.update(epoch_task, completed=self.total_steps)
|
||||
break
|
||||
|
||||
# 进度条不重置,显示整体训练进度
|
||||
# epoch 内步数检查
|
||||
if epoch_step >= steps_per_epoch:
|
||||
break
|
||||
|
||||
# 每个 epoch 结束后保存检查点(循环覆盖,只保留最后 3 个)
|
||||
# epoch 完成
|
||||
progress.update(epoch_task, advance=1)
|
||||
|
||||
# 每个 epoch 结束后评估并保存 best model
|
||||
epoch_eval_metrics = self.evaluate()
|
||||
if epoch_eval_metrics:
|
||||
epoch_log_metrics = {
|
||||
"epoch/eval_loss": epoch_eval_metrics["eval_loss"],
|
||||
"epoch/eval_accuracy": epoch_eval_metrics["eval_accuracy"],
|
||||
}
|
||||
if self.writer is not None:
|
||||
for key, value in epoch_log_metrics.items():
|
||||
self.writer.add_scalar(key, value, global_step)
|
||||
if epoch_eval_metrics["eval_loss"] < self.best_eval_loss:
|
||||
self.best_eval_loss = epoch_eval_metrics["eval_loss"]
|
||||
self.save_checkpoint("best_model.pt", is_best=True)
|
||||
|
||||
# 每个 epoch 结束后保存检查点
|
||||
self.save_epoch_checkpoint(epoch + 1)
|
||||
|
||||
# 检查是否达到总步数
|
||||
|
|
@ -823,6 +862,34 @@ class Trainer:
|
|||
# 训练完成
|
||||
logger.info("Training completed!")
|
||||
|
||||
# 最终评估并保存 best model
|
||||
logger.info("Running final evaluation...")
|
||||
final_eval_metrics = self.evaluate()
|
||||
if final_eval_metrics:
|
||||
final_log_metrics = {
|
||||
"train/loss": accumulated_loss / accumulation_counter
|
||||
if accumulation_counter > 0
|
||||
else 0.0,
|
||||
"train/accuracy": accumulated_accuracy / accumulation_counter
|
||||
if accumulation_counter > 0
|
||||
else 0.0,
|
||||
"train/learning_rate": self._get_current_lr(),
|
||||
"eval/loss": final_eval_metrics["eval_loss"],
|
||||
"eval/accuracy": final_eval_metrics["eval_accuracy"],
|
||||
}
|
||||
self._log_to_tensorboard(final_log_metrics, global_step)
|
||||
logger.info(
|
||||
f"Final eval loss: {final_eval_metrics['eval_loss']:.4f}, "
|
||||
f"Final eval accuracy: {final_eval_metrics['eval_accuracy']:.4f}"
|
||||
)
|
||||
if final_eval_metrics["eval_loss"] < self.best_eval_loss:
|
||||
self.best_eval_loss = final_eval_metrics["eval_loss"]
|
||||
self.save_checkpoint("best_model.pt", is_best=True)
|
||||
|
||||
# 保存最终模型
|
||||
self.save_checkpoint("final_model.pt")
|
||||
logger.info("Final model saved.")
|
||||
|
||||
# 显示保存的 epoch checkpoint 信息
|
||||
if self.epoch_checkpoints:
|
||||
sorted_checkpoints = self.get_epoch_checkpoints()
|
||||
|
|
@ -965,25 +1032,62 @@ def worker_init_fn(worker_id: int) -> None:
|
|||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
def collate_fn(batch: List[Dict[str, Any]], max_seq_length: int = 0) -> Dict[str, Any]:
|
||||
"""
|
||||
自定义批处理函数,将多个样本组合成一个batch
|
||||
自定义批处理函数,将多个样本组合成一个batch。
|
||||
支持动态填充:根据batch内最大序列长度进行padding,而非固定max_length。
|
||||
当 max_seq_length > 0 时,pad到指定长度(用于预处理)。
|
||||
|
||||
Args:
|
||||
batch: 样本列表,每个样本是一个字典
|
||||
max_seq_length: 目标序列长度,0表示动态padding
|
||||
|
||||
Returns:
|
||||
批处理后的字典,tensor字段已stack,字符串字段保持为列表
|
||||
"""
|
||||
# 处理tensor字段 - 使用squeeze去除多余的batch维度
|
||||
input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
|
||||
token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
|
||||
attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
|
||||
input_ids_list = [item["input_ids"] for item in batch]
|
||||
token_type_ids_list = [item["token_type_ids"] for item in batch]
|
||||
attention_mask_list = [item["attention_mask"] for item in batch]
|
||||
|
||||
if max_seq_length > 0:
|
||||
target_len = max_seq_length
|
||||
else:
|
||||
target_len = max(ids.shape[0] for ids in input_ids_list)
|
||||
|
||||
padded_input_ids = []
|
||||
padded_token_type_ids = []
|
||||
padded_attention_mask = []
|
||||
for ids, tt_ids, mask in zip(
|
||||
input_ids_list, token_type_ids_list, attention_mask_list
|
||||
):
|
||||
seq_len = ids.shape[0]
|
||||
if seq_len < target_len:
|
||||
pad_len = target_len - seq_len
|
||||
padded_input_ids.append(
|
||||
torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)])
|
||||
)
|
||||
padded_token_type_ids.append(
|
||||
torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)])
|
||||
)
|
||||
padded_attention_mask.append(
|
||||
torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])
|
||||
)
|
||||
elif seq_len > target_len:
|
||||
padded_input_ids.append(ids[:target_len])
|
||||
padded_token_type_ids.append(tt_ids[:target_len])
|
||||
padded_attention_mask.append(mask[:target_len])
|
||||
else:
|
||||
padded_input_ids.append(ids)
|
||||
padded_token_type_ids.append(tt_ids)
|
||||
padded_attention_mask.append(mask)
|
||||
|
||||
input_ids = torch.stack(padded_input_ids)
|
||||
token_type_ids = torch.stack(padded_token_type_ids)
|
||||
attention_mask = torch.stack(padded_attention_mask)
|
||||
labels = torch.stack([item["label"].squeeze(0) for item in batch])
|
||||
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
|
||||
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
|
||||
|
||||
# 字符串字段保持为列表
|
||||
prefixes = [item["prefix"] for item in batch]
|
||||
suffixes = [item["suffix"] for item in batch]
|
||||
pinyins = [item["pinyin"] for item in batch]
|
||||
|
|
@ -1001,9 +1105,18 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def preprocess_collate_fn(max_seq_length: int):
    """
    Build a collate function that always pads/truncates to ``max_seq_length``.

    The returned callable forwards its batch to ``collate_fn`` with
    ``max_seq_length`` bound as a keyword argument.

    Args:
        max_seq_length: Fixed target sequence length handed to ``collate_fn``.

    Returns:
        A callable taking a batch (list of sample dicts) and returning whatever
        ``collate_fn(batch, max_seq_length=max_seq_length)`` returns.
    """
    # Imported locally so the block does not depend on a module-level import.
    from functools import partial

    # A partial over the module-level collate_fn can be pickled, unlike a nested
    # closure; DataLoader workers started with "spawn"/"forkserver" must pickle
    # the collate function.
    return partial(collate_fn, max_seq_length=max_seq_length)
|
||||
|
||||
|
||||
# Typer CLI应用
|
||||
def create_dataloader(
|
||||
dataset: PinyinInputDataset,
|
||||
dataset,
|
||||
batch_size: int,
|
||||
num_workers: int = 2,
|
||||
pin_memory: bool = True,
|
||||
|
|
@ -1011,29 +1124,35 @@ def create_dataloader(
|
|||
max_iter_length: Optional[int] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
创建数据加载器,优先使用DataLoader2,如果不可用则回退到DataLoader。
|
||||
专门针对流式数据集优化。
|
||||
创建数据加载器,自动识别数据集类型。
|
||||
|
||||
Args:
|
||||
dataset: PinyinInputDataset实例
|
||||
batch_size: 批次大小
|
||||
num_workers: worker数量(对于流式数据集建议为2)
|
||||
pin_memory: 是否固定内存
|
||||
shuffle: 是否打乱(流式数据集内部处理打乱)
|
||||
max_iter_length: 最大迭代长度,用于计算总步数
|
||||
|
||||
Returns:
|
||||
数据加载器实例
|
||||
- PinyinInputDataset(IterableDataset):使用流式加载
|
||||
- PreProcessedDataset(map-style):使用标准加载,支持 shuffle
|
||||
"""
|
||||
if isinstance(dataset, PreProcessedDataset):
|
||||
logger.info(f"📊 使用预处理数据集,样本数: {len(dataset)}")
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=preprocessed_collate_fn,
|
||||
drop_last=True,
|
||||
persistent_workers=True if num_workers > 0 else False,
|
||||
)
|
||||
|
||||
logger.info(f"📊 使用标准DataLoader,worker数量: {num_workers}")
|
||||
fixed_max_seq_length = getattr(dataset, "max_seq_length", 128)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
worker_init_fn=worker_init_fn,
|
||||
collate_fn=collate_fn,
|
||||
prefetch_factor=2, # 减少预取以避免内存问题
|
||||
collate_fn=preprocess_collate_fn(fixed_max_seq_length),
|
||||
drop_last=True,
|
||||
prefetch_factor=2,
|
||||
persistent_workers=True,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
|
@ -1098,6 +1217,11 @@ def train(
|
|||
"--compile/--no-compile",
|
||||
help="是否开启 torch.compile 优化(需 PyTorch 2.0+)",
|
||||
),
|
||||
moe_mode: str = typer.Option(
|
||||
"all",
|
||||
"--moe-mode",
|
||||
help="MoE 计算策略: all(全量计算), sparse(稀疏计算), sparse_allow_graph(稀疏+allow_in_graph)",
|
||||
),
|
||||
):
|
||||
"""
|
||||
训练输入法模型
|
||||
|
|
@ -1152,6 +1276,7 @@ def train(
|
|||
config_table.add_row("模型", "MoE专家数", str(num_experts))
|
||||
config_table.add_row("模型", "使用拼音", str(use_pinyin))
|
||||
config_table.add_row("模型", "编译优化", str(compile))
|
||||
config_table.add_row("模型", "MoE策略", moe_mode)
|
||||
|
||||
config_table.add_row("训练", "训练轮数", str(num_epochs))
|
||||
config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
|
||||
|
|
@ -1164,12 +1289,97 @@ def train(
|
|||
config_table.add_row("训练", "混合精度", str(mixed_precision))
|
||||
|
||||
config_table.add_row("其他", "自动恢复", str(auto_resume))
|
||||
console.print(config_table)
|
||||
|
||||
# 创建输出目录
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 检测数据类型并创建数据加载器
|
||||
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
|
||||
|
||||
is_train_preprocessed = is_preprocessed_data(train_data_path)
|
||||
is_eval_preprocessed = is_preprocessed_data(eval_data_path)
|
||||
|
||||
if is_train_preprocessed:
|
||||
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=2)
|
||||
pre_shuffled = train_dataset.metadata.get("pre_shuffled", False)
|
||||
# 预打乱数据不需要 DataLoader 的 RandomSampler(避免跨分片解压抖动)
|
||||
shuffle_train = not pre_shuffled
|
||||
if max_iter_length > 0:
|
||||
max_samples_per_epoch = max_iter_length
|
||||
capped_samples = min(len(train_dataset), max_samples_per_epoch)
|
||||
else:
|
||||
capped_samples = len(train_dataset)
|
||||
total_steps = (capped_samples // batch_size) * num_epochs
|
||||
train_num_workers = min(num_workers, 1)
|
||||
logger.info(
|
||||
f"Preprocessed dataset: {len(train_dataset):,} samples, "
|
||||
f"shuffle={shuffle_train}, pre_shuffled={pre_shuffled}, "
|
||||
f"workers={train_num_workers}, steps={total_steps:,}"
|
||||
)
|
||||
train_dataloader = create_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=train_num_workers,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=shuffle_train,
|
||||
)
|
||||
config_table.add_row("数据", "训练数据类型", "预处理数据")
|
||||
config_table.add_row("数据", "预打乱", str(pre_shuffled))
|
||||
else:
|
||||
train_dataset = PinyinInputDataset(
|
||||
data_path=train_data_path,
|
||||
max_workers=-1,
|
||||
max_iter_length=max_iter_length,
|
||||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=2000000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
total_steps = int(max_iter_length * num_epochs / batch_size)
|
||||
train_dataloader = create_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
max_iter_length=max_iter_length,
|
||||
)
|
||||
config_table.add_row("数据", "训练数据类型", "流式数据")
|
||||
|
||||
if is_eval_preprocessed:
|
||||
eval_dataset = PreProcessedDataset(eval_data_path, max_cache_shards=1)
|
||||
eval_dataloader = create_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=0,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=False,
|
||||
)
|
||||
config_table.add_row("数据", "评估数据类型", "预处理数据")
|
||||
else:
|
||||
eval_dataset = PinyinInputDataset(
|
||||
data_path=eval_data_path,
|
||||
max_workers=-1,
|
||||
max_iter_length=batch_size * 64,
|
||||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=2000000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
eval_dataloader = create_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=2,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
max_iter_length=batch_size * 64,
|
||||
)
|
||||
config_table.add_row("数据", "评估数据类型", "流式数据")
|
||||
|
||||
config_table.add_row("数据", "总步数", str(total_steps))
|
||||
console.print(config_table)
|
||||
|
||||
# 保存配置
|
||||
config = {
|
||||
"train_data_path": train_data_path,
|
||||
|
|
@ -1202,6 +1412,10 @@ def train(
|
|||
"auto_resume": auto_resume,
|
||||
"max_iter_length": max_iter_length,
|
||||
"compile": compile,
|
||||
"moe_mode": moe_mode,
|
||||
"is_train_preprocessed": is_train_preprocessed,
|
||||
"is_eval_preprocessed": is_eval_preprocessed,
|
||||
"total_steps": total_steps,
|
||||
}
|
||||
|
||||
config_file = output_path / "training_config.json"
|
||||
|
|
@ -1210,52 +1424,6 @@ def train(
|
|||
|
||||
logger.info(f"Configuration saved to {config_file}")
|
||||
|
||||
# 创建数据加载器
|
||||
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
|
||||
|
||||
# 训练数据集
|
||||
train_dataset = PinyinInputDataset(
|
||||
data_path=train_data_path,
|
||||
max_workers=-1, # 自动选择worker数量
|
||||
max_iter_length=max_iter_length,
|
||||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=100000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
# 训练数据加载器
|
||||
# 注意:PinyinInputDataset是IterableDataset,所以不能使用shuffle参数
|
||||
# 多worker配置:每个worker处理数据集的一个分片,由dataset.__iter__中的shard处理
|
||||
train_dataloader = create_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
max_iter_length=max_iter_length,
|
||||
)
|
||||
|
||||
# 评估数据集(使用相同的设置,但可以调整参数)
|
||||
eval_dataset = PinyinInputDataset(
|
||||
data_path=eval_data_path,
|
||||
max_workers=-1,
|
||||
max_iter_length=batch_size * 64, # 评估集较小
|
||||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=2000000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
eval_dataloader = create_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=2, # 评估使用较少的worker
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
max_iter_length=batch_size * 64,
|
||||
)
|
||||
|
||||
console.print("[bold cyan]正在创建模型...[/bold cyan]")
|
||||
model = InputMethodEngine(
|
||||
vocab_size=vocab_size,
|
||||
|
|
@ -1267,6 +1435,7 @@ def train(
|
|||
num_experts=num_experts,
|
||||
max_seq_len=max_seq_len,
|
||||
compile=compile,
|
||||
moe_mode=moe_mode,
|
||||
)
|
||||
|
||||
console.print(
|
||||
|
|
@ -1279,7 +1448,7 @@ def train(
|
|||
model=model,
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
total_steps=int(max_iter_length * num_epochs / batch_size),
|
||||
total_steps=total_steps,
|
||||
output_dir=output_dir,
|
||||
num_epochs=num_epochs,
|
||||
learning_rate=learning_rate,
|
||||
|
|
@ -1339,21 +1508,46 @@ def evaluate(
|
|||
@app.command()
|
||||
def export(
|
||||
checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
|
||||
output_path: str = typer.Option(
|
||||
"./exported_model.onnx", "--output", "-o", help="输出路径"
|
||||
output_dir: str = typer.Option(
|
||||
"./exported_models", "--output", "-o", help="输出目录"
|
||||
),
|
||||
device: str = typer.Option("cpu", "--device", help="导出设备(cpu/cuda)"),
|
||||
skip_verification: bool = typer.Option(
|
||||
False, "--skip-verification", help="跳过ONNX模型验证"
|
||||
),
|
||||
):
|
||||
"""
|
||||
导出模型为ONNX格式
|
||||
导出模型为ONNX格式(生成 context_encoder.onnx 和 decoder.onnx)
|
||||
"""
|
||||
from .onnx_export import check_onnx_available, run_full_export
|
||||
|
||||
console = Console()
|
||||
console.print(f"[bold cyan]导出模型到: {output_path}[/bold cyan]")
|
||||
console.print("[bold cyan]━━━ ONNX 模型导出 ━━━[/bold cyan]")
|
||||
console.print(f" 检查点: {checkpoint_path}")
|
||||
console.print(f" 输出目录: {output_dir}")
|
||||
console.print(f" 设备: {device}")
|
||||
|
||||
# 这里应该实现导出逻辑
|
||||
# 1. 加载检查点
|
||||
# 2. 导出为ONNX
|
||||
if not check_onnx_available():
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print("[yellow]导出功能待实现[/yellow]")
|
||||
try:
|
||||
context_encoder_path, decoder_path, config = run_full_export(
|
||||
checkpoint_path=checkpoint_path,
|
||||
output_dir=output_dir,
|
||||
device=device,
|
||||
skip_verification=skip_verification,
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]导出失败: {e}[/bold red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print("\n[bold green]✓ 导出完成![/bold green]")
|
||||
console.print(f" context_encoder.onnx -> {context_encoder_path}")
|
||||
console.print(f" decoder.onnx -> {decoder_path}")
|
||||
console.print(
|
||||
f"\n MoE 层使用 'all' 模式(全量计算 {config.get('num_experts', '?')} 个专家),"
|
||||
f"稀疏化优化可后续迭代"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append("src")
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
|
||||
import time
|
||||
import torch
|
||||
|
|
@ -26,7 +27,7 @@ from pypinyin.contrib.tone_convert import to_initials
|
|||
from torch.utils.data import IterableDataset
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer"
|
||||
Path(__file__).parent.parent / "src" / "model" / "assets" / "tokenizer"
|
||||
)
|
||||
|
||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
|
@ -47,8 +48,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
|||
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
||||
|
||||
|
||||
part1 = "杉杉看了柳柳一眼,默默地同情了一下。她这个堂姐长得非常"
|
||||
part2 = "piaoliang"
|
||||
part1 = "从中国回家后,他觉得世界上最好的城市就是"
|
||||
part2 = "shanghai"
|
||||
pinyin_ids = text_to_pinyin_ids(part2)
|
||||
len_py = len(pinyin_ids)
|
||||
if len_py < 24:
|
||||
|
|
@ -56,9 +57,9 @@ if len_py < 24:
|
|||
else:
|
||||
pinyin_ids = pinyin_ids[:24]
|
||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
||||
masked_labels = [1986, 0, 0, 0, 0, 0, 0, 0]
|
||||
masked_labels = [22, 0, 0, 0, 0, 0, 0, 0]
|
||||
part3 = ""
|
||||
part4 = "可行|特别|伤害"
|
||||
part4 = ""
|
||||
|
||||
encoded = tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
|
|
@ -83,8 +84,11 @@ sample = {
|
|||
|
||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||
|
||||
checkpoint = torch.load("/home/songsenand/下载/best_model.ptrom", map_location="cpu")
|
||||
checkpoint = torch.load(
|
||||
"/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu"
|
||||
)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.eval()
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
token_type_ids = sample["token_type_ids"]
|
||||
|
|
@ -97,8 +101,9 @@ for k, v in sample.items():
|
|||
print(f"{k}: {v}")
|
||||
|
||||
start = time.time()
|
||||
with torch.no_grad():
|
||||
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
||||
print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
|
||||
print(f"计算时长: {(time.time() - start) * 1000:4f}ms")
|
||||
sort_res = sorted(
|
||||
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
|
||||
)
|
||||
|
|
@ -392,7 +392,13 @@ def check_compile_issues():
|
|||
issues = []
|
||||
|
||||
# 检查 components.py 中的潜在问题
|
||||
with open("src/model/components.py", "r") as f:
|
||||
components_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"src",
|
||||
"model",
|
||||
"components.py",
|
||||
)
|
||||
with open(components_path, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
# 检查 float('-inf')
|
||||
|
|
@ -4,9 +4,13 @@
|
|||
解决设备转换和权重加载问题
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
|
|
@ -196,7 +200,7 @@ def test_id_mapping():
|
|||
|
||||
query_engine = QueryEngine()
|
||||
stats_path = (
|
||||
Path(__file__).parent
|
||||
Path(__file__).parent.parent
|
||||
/ "src"
|
||||
/ "model"
|
||||
/ "assets"
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import time
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from model.model import InputMethodEngine
|
||||
from model.query import QueryEngine
|
||||
|
||||
import random
|
||||
import re
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from loguru import logger
|
||||
from modelscope import AutoTokenizer
|
||||
from pypinyin import lazy_pinyin
|
||||
from pypinyin.contrib.tone_convert import to_initials
|
||||
from torch.utils.data import IterableDataset
|
||||
from model.dataset import PinyinInputDataset
|
||||
|
||||
|
||||
def worker_init_fn(worker_id: int) -> None:
    """
    Seed the NumPy and ``random`` generators inside a DataLoader worker.

    Both generators receive the same value: ``torch.initial_seed()`` reduced
    to its low 32 bits.

    Args:
        worker_id: ID of the worker (not used by the seeding itself).
    """
    # Masking with 0xFFFFFFFF is the same reduction as taking the value
    # modulo 2**32 for Python integers.
    seed_32bit = torch.initial_seed() & 0xFFFFFFFF
    random.seed(seed_32bit)
    np.random.seed(seed_32bit)
|
||||
|
||||
|
||||
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Combine a list of samples into one batch with dynamic padding.

    ``input_ids``, ``token_type_ids`` and ``attention_mask`` are right-padded
    with zeros up to the longest ``input_ids`` in the batch and then stacked.
    ``label`` has its leading dimension squeezed before stacking;
    ``history_slot_ids`` and ``pinyin_ids`` are stacked as-is. The string
    fields ``prefix``, ``suffix`` and ``pinyin`` are returned as plain lists.

    Args:
        batch: List of sample dictionaries.

    Returns:
        Dictionary with stacked tensor fields and list-valued string fields.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    # The target is the batch maximum, so no sequence can be longer than it
    # and no truncation branch is needed.
    target_len = max(item["input_ids"].shape[0] for item in batch)

    def _pad(seq):
        """Right-pad a 1-D tensor with zeros up to ``target_len``."""
        pad_len = target_len - seq.shape[0]
        if pad_len == 0:
            return seq
        # new_zeros keeps the dtype and device of the tensor being padded.
        return torch.cat([seq, seq.new_zeros(pad_len)])

    return {
        "input_ids": torch.stack([_pad(item["input_ids"]) for item in batch]),
        "token_type_ids": torch.stack(
            [_pad(item["token_type_ids"]) for item in batch]
        ),
        "attention_mask": torch.stack(
            [_pad(item["attention_mask"]) for item in batch]
        ),
        "labels": torch.stack([item["label"].squeeze(0) for item in batch]),
        "history_slot_ids": torch.stack(
            [item["history_slot_ids"] for item in batch]
        ),
        "prefix": [item["prefix"] for item in batch],
        "suffix": [item["suffix"] for item in batch],
        "pinyin": [item["pinyin"] for item in batch],
        "pinyin_ids": torch.stack([item["pinyin_ids"] for item in batch]),
    }
|
||||
|
||||
|
||||
# Benchmark sizes, named once so the dataset limit, the batch size and the
# progress-bar total cannot drift apart.
MAX_ITER_LENGTH = 1000000
BATCH_SIZE = 512

# Building and iterating a multi-worker DataLoader is kept under the main guard
# so that worker processes started with "spawn"/"forkserver" (which re-import
# this module) do not recursively start the benchmark.
if __name__ == "__main__":
    train_dataset = PinyinInputDataset(
        data_path="/home/songsenand/Data/corpus/CCI-Data/",
        max_workers=-1,  # -1: let the dataset choose its worker count
        max_iter_length=MAX_ITER_LENGTH,
        text_field="text",
        py_style_weight=(90, 2, 1),
        shuffle_buffer_size=20000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=16,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=2,  # small prefetch to limit memory use
        persistent_workers=True,
    )

    # Drain the loader; batches are discarded, only throughput is observed.
    for _ in tqdm(dataloader, total=MAX_ITER_LENGTH / BATCH_SIZE):
        pass
|
||||
|
|
@ -6,8 +6,8 @@ import torch.nn as nn
|
|||
from rich.console import Console
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
# 添加src目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from src.model.model import InputMethodEngine
|
||||
from src.model.trainer import Trainer
|
||||
|
|
@ -0,0 +1,409 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
ONNX模型验证脚本
|
||||
|
||||
验证导出的ONNX模型与原始PyTorch模型输出的一致性
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import onnxruntime as ort
|
||||
|
||||
# 添加src目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from src.model.export_models import create_export_models_from_checkpoint
|
||||
|
||||
|
||||
def compare_outputs(pytorch_output, onnx_output, name="output", rtol=1e-3, atol=1e-5):
    """
    Compare a PyTorch output with the corresponding ONNX Runtime output.

    Both values are converted to NumPy arrays. Differences are computed on
    float64 copies so that boolean outputs (e.g. masks) can be subtracted and
    formatted; empty arrays are treated as matching.

    Args:
        pytorch_output: PyTorch tensor, or any value ``np.asarray`` accepts.
        onnx_output: ONNX Runtime output (NumPy array).
        name: Output name used in the printed report.
        rtol: Relative tolerance passed to ``np.allclose``.
        atol: Absolute tolerance passed to ``np.allclose``.

    Returns:
        bool: True when the shapes are equal and all values are within
        tolerance, False otherwise.
    """
    # Convert the PyTorch side to NumPy (detached and moved to host memory).
    if isinstance(pytorch_output, torch.Tensor):
        pytorch_np = pytorch_output.detach().cpu().numpy()
    else:
        pytorch_np = np.asarray(pytorch_output)
    onnx_np = np.asarray(onnx_output)

    # Shapes must agree before any element-wise comparison.
    if pytorch_np.shape != onnx_np.shape:
        print(
            f"❌ {name} 形状不匹配: PyTorch {pytorch_np.shape} != ONNX {onnx_np.shape}"
        )
        return False

    if pytorch_np.size == 0:
        # np.max / np.min raise on empty arrays; there is nothing to differ.
        max_diff = mean_diff = 0.0
        is_close = True
    else:
        # float64 copies: NumPy rejects `-` on boolean arrays.
        pytorch_np = pytorch_np.astype(np.float64)
        onnx_np = onnx_np.astype(np.float64)
        diff = np.abs(pytorch_np - onnx_np)
        max_diff = np.max(diff)
        mean_diff = np.mean(diff)
        is_close = bool(np.allclose(pytorch_np, onnx_np, rtol=rtol, atol=atol))

    if is_close:
        print(f"✅ {name} 匹配: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
    else:
        print(f"❌ {name} 不匹配: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
        print(
            f"   范围: PyTorch [{pytorch_np.min():.6f}, {pytorch_np.max():.6f}], "
            f"ONNX [{onnx_np.min():.6f}, {onnx_np.max():.6f}]"
        )

    return is_close
|
||||
|
||||
|
||||
def verify_context_encoder(checkpoint_path, onnx_path, device="cpu"):
    """
    Verify the exported context-encoder ONNX model against the PyTorch model.

    The same generated inputs are run through the PyTorch export wrapper and
    through an ONNX Runtime session, and each output pair is checked with
    ``compare_outputs``.

    Args:
        checkpoint_path: Path to the PyTorch checkpoint.
        onnx_path: Path to the ONNX model.
        device: "cpu" selects the CPU execution provider; any other value
            selects the CUDA execution provider.

    Returns:
        bool: True when every compared output matches and both runtimes
        returned the same number of outputs.
    """
    print(f"\n🔍 验证上下文编码器: {onnx_path}")

    # Load the PyTorch export wrapper and its configuration.
    context_encoder_export, _, config = create_export_models_from_checkpoint(
        checkpoint_path, device
    )

    # Create the ONNX Runtime session.
    provider = "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
    session = ort.InferenceSession(onnx_path, providers=[provider])

    # Test inputs; batch_size=2 exercises the dynamic batch dimension.
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24

    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    pinyin_ids = torch.randint(0, 30, (batch_size, pinyin_len), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    # Zero out the second half of every sequence (deterministic, not random).
    attention_mask[:, seq_len // 2 :] = 0

    # PyTorch inference. Inputs are moved to `device` because the model was
    # requested for that device (assumes the loader places it there — TODO
    # confirm); for "cpu" the move is a no-op.
    with torch.no_grad():
        pytorch_outputs = context_encoder_export(
            input_ids.to(device), pinyin_ids.to(device), attention_mask.to(device)
        )

    # ONNX inference is always fed from the host-side tensors.
    onnx_inputs = {
        "input_ids": input_ids.numpy(),
        "pinyin_ids": pinyin_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    onnx_outputs = session.run(None, onnx_inputs)

    output_names = ["context_H", "pinyin_P", "context_mask", "pinyin_mask"]
    all_match = True

    # A differing output count must fail verification instead of being
    # silently skipped by the pairwise loop below.
    if len(pytorch_outputs) != len(onnx_outputs):
        print(
            f"❌ 输出数量不一致: PyTorch {len(pytorch_outputs)} != ONNX {len(onnx_outputs)}"
        )
        all_match = False

    # zip stops at the shortest sequence, so only outputs present on both
    # sides are compared.
    for name, pytorch_out, onnx_out in zip(output_names, pytorch_outputs, onnx_outputs):
        match = compare_outputs(pytorch_out, onnx_out, name)
        all_match = all_match and match

    return all_match
|
||||
|
||||
|
||||
def verify_decoder(checkpoint_path, onnx_path, device="cpu"):
    """
    Verify the exported decoder ONNX model against the PyTorch model.

    Random inputs are run through the PyTorch decoder export wrapper and
    through an ONNX Runtime session, and the first ONNX output is compared
    with the PyTorch output under the name "logits".

    Args:
        checkpoint_path: Path to the PyTorch checkpoint.
        onnx_path: Path to the ONNX model.
        device: "cpu" selects the CPU execution provider; any other value
            selects the CUDA execution provider.

    Returns:
        bool: True when the outputs match within the default tolerances of
        ``compare_outputs``.
    """
    print(f"\n🔍 验证解码器: {onnx_path}")

    # Load the PyTorch export wrapper and its configuration.
    _, decoder_export, config = create_export_models_from_checkpoint(
        checkpoint_path, device
    )

    # Create the ONNX Runtime session.
    provider = "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
    session = ort.InferenceSession(onnx_path, providers=[provider])

    # Test input sizes, taken from the checkpoint config with fallbacks.
    batch_size = 2
    seq_len = config.get("max_seq_len", 128)
    pinyin_len = 24
    dim = config.get("dim", 512)
    num_slots = config.get("num_slots", 8)

    context_H = torch.randn(batch_size, seq_len, dim, dtype=torch.float32)
    pinyin_P = torch.randn(batch_size, pinyin_len, dim, dtype=torch.float32)
    context_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int32)
    pinyin_mask = torch.randint(0, 2, (batch_size, pinyin_len), dtype=torch.int32)
    history_slot_ids = torch.randint(0, 100, (batch_size, num_slots), dtype=torch.long)

    # PyTorch inference. Inputs are moved to `device` because the model was
    # requested for that device (assumes the loader places it there — TODO
    # confirm); for "cpu" the move is a no-op.
    with torch.no_grad():
        pytorch_output = decoder_export(
            context_H.to(device),
            pinyin_P.to(device),
            history_slot_ids.to(device),
            context_mask.to(device),
            pinyin_mask.to(device),
        )

    # ONNX inference is always fed from the host-side tensors.
    onnx_inputs = {
        "context_H": context_H.numpy(),
        "pinyin_P": pinyin_P.numpy(),
        "history_slot_ids": history_slot_ids.numpy(),
        "context_mask": context_mask.numpy(),
        "pinyin_mask": pinyin_mask.numpy(),
    }
    onnx_outputs = session.run(None, onnx_inputs)

    return compare_outputs(pytorch_output, onnx_outputs[0], "logits")
|
||||
|
||||
|
||||
def verify_end_to_end(
    checkpoint_path, context_encoder_path, decoder_path, device="cpu"
):
    """
    End-to-end verification of the exported ONNX pipeline.

    Runs the full PyTorch ``InputMethodEngine`` forward pass and the two-stage
    ONNX pipeline (context encoder followed by decoder) on the same random
    inputs and compares the resulting logits.

    Args:
        checkpoint_path: Path to the PyTorch checkpoint.
        context_encoder_path: Path to the context-encoder ONNX file.
        decoder_path: Path to the decoder ONNX file.
        device: "cpu" or anything else; any non-"cpu" value selects the
            CUDA execution provider for ONNX Runtime.

    Returns:
        Whatever ``compare_outputs`` returns (used by the caller as a
        pass/fail flag).
    """
    print("\n🔍 端到端验证")

    # Imported lazily so the module can be loaded without the model package.
    from src.model.model import InputMethodEngine

    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Every hyper-parameter is read with a default below, so a checkpoint
    # without an embedded config simply falls back to those defaults.
    config = checkpoint["config"] if "config" in checkpoint else {}
    pinyin_vocab_size = config.get("pinyin_vocab_size", 30)
    num_slots = config.get("num_slots", 8)
    max_seq_len = config.get("max_seq_len", 128)

    model = InputMethodEngine(
        vocab_size=config.get("vocab_size", 10019),
        pinyin_vocab_size=pinyin_vocab_size,
        dim=config.get("dim", 512),
        num_slots=num_slots,
        n_layers=config.get("n_layers", 4),
        n_heads=config.get("n_heads", 4),
        num_experts=config.get("num_experts", 10),
        max_seq_len=max_seq_len,
        compile=False,
    )

    # Accept both a full training checkpoint and a bare state dict.
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    model.to(device)

    # Create the ONNX Runtime sessions; both use the same execution provider.
    providers = [
        "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
    ]
    context_session = ort.InferenceSession(context_encoder_path, providers=providers)
    decoder_session = ort.InferenceSession(decoder_path, providers=providers)

    # Build the test inputs on the CPU so they can be handed to ONNX Runtime
    # as numpy arrays; the PyTorch call below gets device-side copies.
    batch_size = 1
    seq_len = max_seq_len
    pinyin_len = 24

    input_ids = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.long)
    # Keep the pinyin ids inside the model's pinyin vocabulary.
    pinyin_ids = torch.randint(
        0, pinyin_vocab_size, (batch_size, pinyin_len), dtype=torch.long
    )
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    # The slot count must match the value the model was constructed with.
    history_slot_ids = torch.randint(
        0, 100, (batch_size, num_slots), dtype=torch.long
    )

    # Full PyTorch inference. Inputs are moved to the model's device; without
    # this the forward pass fails with a device mismatch when device != "cpu".
    with torch.no_grad():
        pytorch_logits = model(
            input_ids=input_ids.to(device),
            token_type_ids=torch.zeros_like(input_ids).to(device),  # simplified
            attention_mask=attention_mask.to(device),
            pinyin_ids=pinyin_ids.to(device),
            history_slot_ids=history_slot_ids.to(device),
        )

    # Bring the result back to the CPU so it can be compared with numpy output.
    if isinstance(pytorch_logits, torch.Tensor):
        pytorch_logits = pytorch_logits.cpu()

    # ONNX inference, stage 1: context encoder.
    context_inputs = {
        "input_ids": input_ids.numpy(),
        "pinyin_ids": pinyin_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    context_outputs = context_session.run(None, context_inputs)
    context_H, pinyin_P, context_mask, pinyin_mask = context_outputs

    # ONNX inference, stage 2: decoder, fed with the encoder outputs.
    decoder_inputs = {
        "context_H": context_H,
        "pinyin_P": pinyin_P,
        "history_slot_ids": history_slot_ids.numpy(),
        "context_mask": context_mask,
        "pinyin_mask": pinyin_mask,
    }
    onnx_outputs = decoder_session.run(None, decoder_inputs)

    # Compare the PyTorch logits with the first ONNX decoder output.
    return compare_outputs(pytorch_logits, onnx_outputs[0], "end_to_end_logits")
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point for ONNX model verification.

    Parses the arguments, resolves the ONNX model paths (explicit paths win
    over ``<output-dir>/context_encoder.onnx`` and ``<output-dir>/decoder.onnx``),
    then runs the context-encoder, decoder and end-to-end checks unless they
    are skipped. A stage whose ONNX file is missing is reported with a
    warning and does not count as a failure.

    Exits the process with status 1 if any executed verification fails.
    """
    parser = argparse.ArgumentParser(description="ONNX模型验证")
    parser.add_argument(
        "--checkpoint", "-c", type=str, required=True, help="PyTorch checkpoint路径"
    )
    parser.add_argument(
        "--context-encoder",
        type=str,
        help="上下文编码器ONNX路径(默认: ./exported_models/context_encoder.onnx)",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        help="解码器ONNX路径(默认: ./exported_models/decoder.onnx)",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default="./exported_models",
        help="导出目录(如果未指定单个模型路径)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="验证设备(默认: cpu)",
    )
    parser.add_argument(
        "--skip-context", action="store_true", help="跳过上下文编码器验证"
    )
    parser.add_argument("--skip-decoder", action="store_true", help="跳过解码器验证")
    parser.add_argument("--skip-end-to-end", action="store_true", help="跳过端到端验证")

    args = parser.parse_args()

    # Resolve model paths: an explicit path overrides the output directory.
    context_encoder_path = args.context_encoder or os.path.join(
        args.output_dir, "context_encoder.onnx"
    )
    decoder_path = args.decoder or os.path.join(args.output_dir, "decoder.onnx")

    print("🔬 ONNX模型验证")
    print("=" * 60)
    print(f"Checkpoint: {args.checkpoint}")
    print(f"上下文编码器: {context_encoder_path}")
    print(f"解码器: {decoder_path}")
    print(f"设备: {args.device}")
    print()

    all_pass = True

    # Verify the context encoder.
    if not args.skip_context and os.path.exists(context_encoder_path):
        if verify_context_encoder(args.checkpoint, context_encoder_path, args.device):
            print("✅ 上下文编码器验证通过")
        else:
            print("❌ 上下文编码器验证失败")
            all_pass = False
    elif not args.skip_context:
        print("⚠️ 上下文编码器文件不存在,跳过验证")

    # Verify the decoder.
    if not args.skip_decoder and os.path.exists(decoder_path):
        if verify_decoder(args.checkpoint, decoder_path, args.device):
            print("✅ 解码器验证通过")
        else:
            print("❌ 解码器验证失败")
            all_pass = False
    elif not args.skip_decoder:
        print("⚠️ 解码器文件不存在,跳过验证")

    # End-to-end verification needs both ONNX files.
    if (
        not args.skip_end_to_end
        and os.path.exists(context_encoder_path)
        and os.path.exists(decoder_path)
    ):
        if verify_end_to_end(
            args.checkpoint, context_encoder_path, decoder_path, args.device
        ):
            print("✅ 端到端验证通过")
        else:
            print("❌ 端到端验证失败")
            all_pass = False
    elif not args.skip_end_to_end:
        # Report the skip like the other stages instead of silently doing nothing.
        print("⚠️ ONNX模型文件不存在,跳过端到端验证")

    print("\n" + "=" * 60)
    if all_pass:
        print("🎉 所有验证通过!ONNX模型与PyTorch模型输出一致")
    else:
        print("❌ 部分验证失败,请检查模型导出过程")
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the verification CLI only when executed as a script, not on import.
    main()
|
||||
Loading…
Reference in New Issue