Compare commits

2 Commits

4 changed files with 21465 additions and 1 deletion

Cargo.toml

@@ -4,3 +4,8 @@ version = "0.1.0"
edition = "2024"

[dependencies]
anyhow = "1.0.102"
lazy_static = "1.5.0"
ndarray = "0.17.2"
tempfile = "3.26.0"
tokenizers = "0.22.2"

assets/tokenizer.json (new file, 21278 lines)

File diff suppressed because it is too large.

src/main.rs

@@ -1,3 +1,26 @@
mod tokenizers;

use std::path::PathBuf;

fn main() {
    println!("Hello, world!");

    let mut tokenizer_json_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    tokenizer_json_path.push("assets");
    tokenizer_json_path.push("tokenizer.json");

    // Example: using HFTokenizer
    match tokenizers::HFTokenizer::new(tokenizer_json_path) {
        Ok(mut tokenizer) => {
            match tokenizer.gen_predict_sample("hello world", "ni hao") {
                Ok(model_input) => {
                    println!("Model input generated successfully");
                    println!("Input IDs: {:?}", model_input.input_ids);
                    println!("PG value: {}", model_input.pg[[0]]);
                }
                Err(e) => eprintln!("Error generating model input: {}", e),
            }
        }
        Err(e) => eprintln!("Error loading tokenizer: {}", e),
    }
}

src/tokenizers.rs (new file, 158 lines)

@@ -0,0 +1,158 @@
// tokenizers.rs
use anyhow::Result;
use ndarray::{Array1, Array2};
use std::collections::HashMap;
use std::path::Path;
use tokenizers::{EncodeInput, Encoding, Tokenizer};
// Pinyin group map (PG map) - a const array used to initialize the HashMap
const PG_MAP: &[(&str, i64)] = &[
    ("r", 0),
    ("l", 0),
    ("p", 1),
    ("d", 1),
    ("h", 2),
    ("f", 2),
    ("g", 3),
    ("m", 3),
    ("z", 4),
    ("o", 4),
    ("t", 5),
    ("q", 5),
    ("b", 6),
    ("w", 6),
    ("j", 7),
    ("e", 7),
    ("k", 8),
    ("c", 8),
    ("s", 9),
    ("a", 9),
    ("n", 10),
    ("x", 10),
    ("y", 11),
];
// Initialize the HashMap via lazy_static
lazy_static::lazy_static! {
    static ref PG: HashMap<String, i64> = {
        let mut m = HashMap::new();
        for &(k, v) in PG_MAP.iter() {
            m.insert(k.to_string(), v);
        }
        m
    };
}
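// Illustrative lookup (not in the original commit): PG.get("n") returns
// Some(&10), and any initial missing from PG_MAP falls back to the default
// group 12 in gen_predict_sample below.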
// Model input struct that groups the hint and pg inputs,
// now using ndarray::Array types
#[derive(Debug)]
pub struct ModelInput {
    pub input_ids: Array2<i64>,      // (batch_size, sequence_length)
    pub attention_mask: Array2<i64>, // (batch_size, sequence_length)
    pub token_type_ids: Array2<i64>, // (batch_size, sequence_length)
    pub pg: Array1<i64>,             // an Array1 representing a 1-D vector
}
/// Wraps the Rust `tokenizers` crate Tokenizer and provides an interface
/// compatible with ONNX Runtime models
pub struct HFTokenizer {
    /// The underlying `tokenizers` crate Tokenizer instance
    tokenizer: Tokenizer,
}
impl HFTokenizer {
    /// Creates a new HFTokenizer instance.
    /// `tokenizer_path_or_name`: path to a pretrained tokenizer config file (e.g. tokenizer.json)
    pub fn new<P: AsRef<Path>>(tokenizer_path_or_name: P) -> Result<Self> {
        // Load the tokenizer config file (usually named tokenizer.json)
        let tokenizer = Tokenizer::from_file(tokenizer_path_or_name.as_ref()).map_err(|e| {
            anyhow::anyhow!(
                "Failed to load tokenizer from {:?}: {}",
                tokenizer_path_or_name.as_ref(),
                e
            )
        })?;
        Ok(HFTokenizer { tokenizer })
    }
    /// Uses the `tokenizers` crate's EncodeInput constructor to handle a text pair.
    /// `text` - the first sequence (usually the text)
    /// `py` - the second sequence (usually the pinyin)
    fn encode_pair(&mut self, text: &str, py: &str) -> Result<Encoding> {
        let input = EncodeInput::Dual(text.into(), py.into());
        // encode applies the pre_tokenizer, normalizer, post_processor, truncation, and padding
        let encoding = self
            .tokenizer
            .encode(input, true) // true for add_special_tokens
            .map_err(|e| anyhow::anyhow!("Tokenization error: {}", e))?;
        Ok(encoding)
    }
    /// Generates a sample for prediction.
    ///
    /// # Arguments
    ///
    /// * `text` - the input text string
    /// * `py` - the input pinyin string
    ///
    /// # Returns
    ///
    /// * `Result<ModelInput>` - a struct holding the model inputs: the hint
    ///   (input_ids, attention_mask, token_type_ids) and the pg array
    pub fn gen_predict_sample(&mut self, text: &str, py: &str) -> Result<ModelInput> {
        let encoding = self.encode_pair(text, py)?;
        // Collect the tokenization results
        let input_ids: Vec<i32> = encoding.get_ids().iter().map(|&id| id as i32).collect();
        let attention_mask: Vec<i32> = encoding
            .get_attention_mask()
            .iter()
            .map(|&mask| mask as i32)
            .collect();
        // token_type_ids are produced automatically by the post_processor (e.g. BertProcessing)
        let token_type_ids: Vec<i32> = encoding
            .get_type_ids()
            .iter()
            .map(|&ty| ty as i32)
            .collect();
        // Sequence length
        let seq_len = input_ids.len();
        // Convert to ndarray::Array and reshape to (batch_size=1, sequence_length);
        // ONNX Runtime usually expects an explicit batch dimension
        let input_ids_nd = Array1::from_vec(input_ids)
            .into_shape_with_order((1, seq_len))
            .unwrap();
        let attention_mask_nd = Array1::from_vec(attention_mask)
            .into_shape_with_order((1, seq_len))
            .unwrap();
        let token_type_ids_nd = Array1::from_vec(token_type_ids)
            .into_shape_with_order((1, seq_len))
            .unwrap();
        // --- PG logic ---
        let py_first_char = py
            .chars()
            .next()
            .map(|c| c.to_string())
            .unwrap_or_else(|| "unknown".to_string());
        let pg_val = *PG.get(&py_first_char).unwrap_or(&12); // default to 12 if the key is not found
        let pg_nd = Array1::from_elem(1, pg_val); // build the 1-D array directly
        Ok(ModelInput {
            input_ids: input_ids_nd.mapv(|x| x as i64), // convert i32 to i64 for ONNX
            attention_mask: attention_mask_nd.mapv(|x| x as i64),
            token_type_ids: token_type_ids_nd.mapv(|x| x as i64),
            pg: pg_nd.mapv(|x| x as i64),
        })
    }
    /*
    /// Prediction helper (normally implemented on the model struct; shown here only to
    /// demonstrate how gen_predict_sample would be called).
    /// This only prepares the inputs; the actual forward pass belongs in the model's impl.
    pub fn prepare_for_prediction(&mut self, text: &str, py: &str) -> Result<ModelInput> {
        self.gen_predict_sample(text, py)
    }*/
}
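
Below is a minimal test sketch, not part of the commit above, showing how gen_predict_sample could be exercised. It assumes assets/tokenizer.json is present and that the test module sits at the bottom of src/tokenizers.rs; the test name and assertions are illustrative.

#[cfg(test)]
mod tests {
    use super::HFTokenizer;
    use std::path::PathBuf;

    #[test]
    fn gen_predict_sample_produces_batched_tensors() {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("assets");
        path.push("tokenizer.json");

        let mut tokenizer = HFTokenizer::new(path).expect("tokenizer should load");
        let input = tokenizer
            .gen_predict_sample("hello world", "ni hao")
            .expect("sample should be generated");

        // All three hint tensors share the same (1, seq_len) shape.
        let seq_len = input.input_ids.shape()[1];
        assert_eq!(input.attention_mask.shape(), &[1, seq_len]);
        assert_eq!(input.token_type_ids.shape(), &[1, seq_len]);

        // "ni hao" starts with 'n', which PG_MAP assigns to group 10.
        assert_eq!(input.pg[[0]], 10);
    }
}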