Compare commits
2 Commits
7abf0edce1
...
c143d793ec
| Author | SHA1 | Date |
|---|---|---|
|
|
c143d793ec | |
|
|
085d90b5d3 |
|
|
@ -4,3 +4,8 @@ version = "0.1.0"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
anyhow = "1.0.102"
|
||||||
|
lazy_static = "1.5.0"
|
||||||
|
ndarray = "0.17.2"
|
||||||
|
tempfile = "3.26.0"
|
||||||
|
tokenizers = "0.22.2"
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
25
src/main.rs
25
src/main.rs
|
|
@ -1,3 +1,26 @@
|
||||||
|
mod tokenizers;
|
||||||
|
|
||||||
|
use std::path::PathBuf;
|
||||||
fn main() {
|
fn main() {
|
||||||
println!("Hello, world!");
|
|
||||||
|
|
||||||
|
let mut tokenizer_json_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||||
|
tokenizer_json_path.push("assets");
|
||||||
|
tokenizer_json_path.push("tokenizer.json");
|
||||||
|
|
||||||
|
// 示例:使用 HFTokenizer
|
||||||
|
match tokenizers::HFTokenizer::new(tokenizer_json_path) {
|
||||||
|
Ok(mut tokenizer) => {
|
||||||
|
match tokenizer.gen_predict_sample("hello world", "ni hao") {
|
||||||
|
Ok(model_input) => {
|
||||||
|
println!("Model input generated successfully");
|
||||||
|
println!("Input IDs: {:?}", model_input.input_ids);
|
||||||
|
println!("PG value: {}", model_input.pg[[0]]);
|
||||||
|
}
|
||||||
|
Err(e) => eprintln!("Error generating model input: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => eprintln!("Error loading tokenizer: {}", e),
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,158 @@
|
||||||
|
// tokenizer.rs
|
||||||
|
use anyhow::Result;
|
||||||
|
use ndarray::{Array1, Array2};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
use tokenizers::{EncodeInput, Encoding, Tokenizer};
|
||||||
|
|
||||||
|
// 拼音组映射表 (PG map) - 使用 const array 初始化 HashMap
|
||||||
|
const PG_MAP: &[(&str, i64)] = &[
|
||||||
|
("r", 0),
|
||||||
|
("l", 0),
|
||||||
|
("p", 1),
|
||||||
|
("d", 1),
|
||||||
|
("h", 2),
|
||||||
|
("f", 2),
|
||||||
|
("g", 3),
|
||||||
|
("m", 3),
|
||||||
|
("z", 4),
|
||||||
|
("o", 4),
|
||||||
|
("t", 5),
|
||||||
|
("q", 5),
|
||||||
|
("b", 6),
|
||||||
|
("w", 6),
|
||||||
|
("j", 7),
|
||||||
|
("e", 7),
|
||||||
|
("k", 8),
|
||||||
|
("c", 8),
|
||||||
|
("s", 9),
|
||||||
|
("a", 9),
|
||||||
|
("n", 10),
|
||||||
|
("x", 10),
|
||||||
|
("y", 11),
|
||||||
|
];
|
||||||
|
|
||||||
|
// 使用 lazy_static 初始化 HashMap
|
||||||
|
lazy_static::lazy_static! {
|
||||||
|
static ref PG: HashMap<String, i64> = {
|
||||||
|
let mut m = HashMap::new();
|
||||||
|
for &(k, v) in PG_MAP.iter() {
|
||||||
|
m.insert(k.to_string(), v);
|
||||||
|
}
|
||||||
|
m
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// 模型输入结构体,用于组织 hint 和 pg 输入
|
||||||
|
// 现在使用 ndarray::Array 类型
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ModelInput {
|
||||||
|
pub input_ids: Array2<i64>, // (batch_size, sequence_length)
|
||||||
|
pub attention_mask: Array2<i64>, // (batch_size, sequence_length)
|
||||||
|
pub token_type_ids: Array2<i64>, // (batch_size, sequence_length)
|
||||||
|
|
||||||
|
pub pg: Array1<i64>, // 使用 Array1 代表一个 1D 向量
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 封装了 tokenizers crate 的 Rust Tokenizer,并提供与 ONNX Runtime 模型兼容的接口
|
||||||
|
pub struct HFTokenizer {
|
||||||
|
/// 内部使用的 tokenizers crate 的 Tokenizer 实例
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HFTokenizer {
|
||||||
|
/// 创建一个新的 HFTokenizer 实例
|
||||||
|
/// `tokenizer_path_or_name`: 指向预训练 tokenizer 配置文件的路径 (例如 tokenizer.json)
|
||||||
|
pub fn new<P: AsRef<Path>>(tokenizer_path_or_name: P) -> Result<Self> {
|
||||||
|
// 加载 tokenizer 配置文件 (通常名为 tokenizer.json)
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_path_or_name.as_ref()).map_err(|e| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"Failed to load tokenizer from {:?}: {}",
|
||||||
|
tokenizer_path_or_name.as_ref(),
|
||||||
|
e
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(HFTokenizer { tokenizer })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用 `tokenizers` crate 的 EncodeInput 构造函数来处理文本对
|
||||||
|
/// `text` - 第一句输入 (通常是文本)
|
||||||
|
/// `py` - 第二句输入 (通常是拼音)
|
||||||
|
fn encode_pair(&mut self, text: &str, py: &str) -> Result<Encoding> {
|
||||||
|
let input = EncodeInput::Dual(text.into(), py.into());
|
||||||
|
// encode 方法会应用 pre_tokenizer, normalizer, post_processor, truncation, padding
|
||||||
|
let encoding = self
|
||||||
|
.tokenizer
|
||||||
|
.encode(input, true) // true for add_special_tokens
|
||||||
|
.map_err(|e| anyhow::anyhow!("Tokenization error: {}", e))?;
|
||||||
|
|
||||||
|
Ok(encoding)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 生成用于预测的样本数据
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `text` - 输入的文本字符串
|
||||||
|
/// * `py` - 输入的拼音字符串
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `Result<ModelInput>` - 包含模型所需输入的结构体,包含 hint (input_ids, attention_mask, token_type_ids) 和 pg array
|
||||||
|
pub fn gen_predict_sample(&mut self, text: &str, py: &str) -> Result<ModelInput> {
|
||||||
|
let encoding = self.encode_pair(text, py)?;
|
||||||
|
|
||||||
|
// 获取分词结果
|
||||||
|
let input_ids: Vec<i32> = encoding.get_ids().iter().map(|&id| id as i32).collect();
|
||||||
|
let attention_mask: Vec<i32> = encoding
|
||||||
|
.get_attention_mask()
|
||||||
|
.iter()
|
||||||
|
.map(|&mask| mask as i32)
|
||||||
|
.collect();
|
||||||
|
// token_type_ids 是由 post_processor (如 BertProcessing) 自动生成的
|
||||||
|
let token_type_ids: Vec<i32> = encoding
|
||||||
|
.get_type_ids()
|
||||||
|
.iter()
|
||||||
|
.map(|&ty| ty as i32)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// 获取序列长度
|
||||||
|
let seq_len = input_ids.len();
|
||||||
|
|
||||||
|
// 转换为 ndarray::Array,并重塑为 (batch_size=1, sequence_length)
|
||||||
|
// ONNX Runtime 通常期望明确的批次维度
|
||||||
|
let input_ids_nd = Array1::from_vec(input_ids)
|
||||||
|
.into_shape_with_order((1, seq_len))
|
||||||
|
.unwrap();
|
||||||
|
let attention_mask_nd = Array1::from_vec(attention_mask)
|
||||||
|
.into_shape_with_order((1, seq_len))
|
||||||
|
.unwrap();
|
||||||
|
let token_type_ids_nd = Array1::from_vec(token_type_ids)
|
||||||
|
.into_shape_with_order((1, seq_len))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// --- PG Logic ---
|
||||||
|
let py_first_char = py
|
||||||
|
.chars()
|
||||||
|
.next()
|
||||||
|
.map(|c| c.to_string())
|
||||||
|
.unwrap_or_else(|| "unknown".to_string());
|
||||||
|
let pg_val = *PG.get(&py_first_char).unwrap_or(&12); // Default to 12 if key not found
|
||||||
|
let pg_nd = Array1::from_elem(1, pg_val); // 直接创建一维数组
|
||||||
|
|
||||||
|
Ok(ModelInput {
|
||||||
|
input_ids: input_ids_nd.mapv(|x| x as i64), // Convert i32 to i64 for ONNX
|
||||||
|
attention_mask: attention_mask_nd.mapv(|x| x as i64),
|
||||||
|
token_type_ids: token_type_ids_nd.mapv(|x| x as i64),
|
||||||
|
|
||||||
|
pg: pg_nd.mapv(|x| x as i64),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
/// 预测函数 (此函数通常在模型结构体内实现,此处仅为演示如何调用 gen_predict_sample)
|
||||||
|
/// 这里只是展示如何准备输入,实际的 forward pass 需要在模型的 impl 中完成
|
||||||
|
pub fn prepare_for_prediction(&mut self, text: &str, py: &str) -> Result<ModelInput> {
|
||||||
|
self.gen_predict_sample(text, py)
|
||||||
|
}*/
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue