Compare commits
No commits in common. "c143d793ec31037912eef79ee8f194a39db70389" and "7abf0edce1553d438f36876e7577e05ed3b45237" have entirely different histories.
c143d793ec
...
7abf0edce1
|
|
@ -4,8 +4,3 @@ version = "0.1.0"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.102"
|
|
||||||
lazy_static = "1.5.0"
|
|
||||||
ndarray = "0.17.2"
|
|
||||||
tempfile = "3.26.0"
|
|
||||||
tokenizers = "0.22.2"
|
|
||||||
|
|
|
||||||
21278
assets/tokenizer.json
21278
assets/tokenizer.json
File diff suppressed because it is too large
Load Diff
25
src/main.rs
25
src/main.rs
|
|
@ -1,26 +1,3 @@
|
||||||
mod tokenizers;
|
|
||||||
|
|
||||||
use std::path::PathBuf;
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
println!("Hello, world!");
|
||||||
|
|
||||||
let mut tokenizer_json_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
|
||||||
tokenizer_json_path.push("assets");
|
|
||||||
tokenizer_json_path.push("tokenizer.json");
|
|
||||||
|
|
||||||
// 示例:使用 HFTokenizer
|
|
||||||
match tokenizers::HFTokenizer::new(tokenizer_json_path) {
|
|
||||||
Ok(mut tokenizer) => {
|
|
||||||
match tokenizer.gen_predict_sample("hello world", "ni hao") {
|
|
||||||
Ok(model_input) => {
|
|
||||||
println!("Model input generated successfully");
|
|
||||||
println!("Input IDs: {:?}", model_input.input_ids);
|
|
||||||
println!("PG value: {}", model_input.pg[[0]]);
|
|
||||||
}
|
|
||||||
Err(e) => eprintln!("Error generating model input: {}", e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => eprintln!("Error loading tokenizer: {}", e),
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,158 +0,0 @@
|
||||||
// tokenizer.rs
|
|
||||||
use anyhow::Result;
|
|
||||||
use ndarray::{Array1, Array2};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::path::Path;
|
|
||||||
use tokenizers::{EncodeInput, Encoding, Tokenizer};
|
|
||||||
|
|
||||||
// 拼音组映射表 (PG map) - 使用 const array 初始化 HashMap
|
|
||||||
const PG_MAP: &[(&str, i64)] = &[
|
|
||||||
("r", 0),
|
|
||||||
("l", 0),
|
|
||||||
("p", 1),
|
|
||||||
("d", 1),
|
|
||||||
("h", 2),
|
|
||||||
("f", 2),
|
|
||||||
("g", 3),
|
|
||||||
("m", 3),
|
|
||||||
("z", 4),
|
|
||||||
("o", 4),
|
|
||||||
("t", 5),
|
|
||||||
("q", 5),
|
|
||||||
("b", 6),
|
|
||||||
("w", 6),
|
|
||||||
("j", 7),
|
|
||||||
("e", 7),
|
|
||||||
("k", 8),
|
|
||||||
("c", 8),
|
|
||||||
("s", 9),
|
|
||||||
("a", 9),
|
|
||||||
("n", 10),
|
|
||||||
("x", 10),
|
|
||||||
("y", 11),
|
|
||||||
];
|
|
||||||
|
|
||||||
// 使用 lazy_static 初始化 HashMap
|
|
||||||
lazy_static::lazy_static! {
|
|
||||||
static ref PG: HashMap<String, i64> = {
|
|
||||||
let mut m = HashMap::new();
|
|
||||||
for &(k, v) in PG_MAP.iter() {
|
|
||||||
m.insert(k.to_string(), v);
|
|
||||||
}
|
|
||||||
m
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// 模型输入结构体,用于组织 hint 和 pg 输入
|
|
||||||
// 现在使用 ndarray::Array 类型
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ModelInput {
|
|
||||||
pub input_ids: Array2<i64>, // (batch_size, sequence_length)
|
|
||||||
pub attention_mask: Array2<i64>, // (batch_size, sequence_length)
|
|
||||||
pub token_type_ids: Array2<i64>, // (batch_size, sequence_length)
|
|
||||||
|
|
||||||
pub pg: Array1<i64>, // 使用 Array1 代表一个 1D 向量
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 封装了 tokenizers crate 的 Rust Tokenizer,并提供与 ONNX Runtime 模型兼容的接口
|
|
||||||
pub struct HFTokenizer {
|
|
||||||
/// 内部使用的 tokenizers crate 的 Tokenizer 实例
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HFTokenizer {
|
|
||||||
/// 创建一个新的 HFTokenizer 实例
|
|
||||||
/// `tokenizer_path_or_name`: 指向预训练 tokenizer 配置文件的路径 (例如 tokenizer.json)
|
|
||||||
pub fn new<P: AsRef<Path>>(tokenizer_path_or_name: P) -> Result<Self> {
|
|
||||||
// 加载 tokenizer 配置文件 (通常名为 tokenizer.json)
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_path_or_name.as_ref()).map_err(|e| {
|
|
||||||
anyhow::anyhow!(
|
|
||||||
"Failed to load tokenizer from {:?}: {}",
|
|
||||||
tokenizer_path_or_name.as_ref(),
|
|
||||||
e
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(HFTokenizer { tokenizer })
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 使用 `tokenizers` crate 的 EncodeInput 构造函数来处理文本对
|
|
||||||
/// `text` - 第一句输入 (通常是文本)
|
|
||||||
/// `py` - 第二句输入 (通常是拼音)
|
|
||||||
fn encode_pair(&mut self, text: &str, py: &str) -> Result<Encoding> {
|
|
||||||
let input = EncodeInput::Dual(text.into(), py.into());
|
|
||||||
// encode 方法会应用 pre_tokenizer, normalizer, post_processor, truncation, padding
|
|
||||||
let encoding = self
|
|
||||||
.tokenizer
|
|
||||||
.encode(input, true) // true for add_special_tokens
|
|
||||||
.map_err(|e| anyhow::anyhow!("Tokenization error: {}", e))?;
|
|
||||||
|
|
||||||
Ok(encoding)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 生成用于预测的样本数据
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `text` - 输入的文本字符串
|
|
||||||
/// * `py` - 输入的拼音字符串
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `Result<ModelInput>` - 包含模型所需输入的结构体,包含 hint (input_ids, attention_mask, token_type_ids) 和 pg array
|
|
||||||
pub fn gen_predict_sample(&mut self, text: &str, py: &str) -> Result<ModelInput> {
|
|
||||||
let encoding = self.encode_pair(text, py)?;
|
|
||||||
|
|
||||||
// 获取分词结果
|
|
||||||
let input_ids: Vec<i32> = encoding.get_ids().iter().map(|&id| id as i32).collect();
|
|
||||||
let attention_mask: Vec<i32> = encoding
|
|
||||||
.get_attention_mask()
|
|
||||||
.iter()
|
|
||||||
.map(|&mask| mask as i32)
|
|
||||||
.collect();
|
|
||||||
// token_type_ids 是由 post_processor (如 BertProcessing) 自动生成的
|
|
||||||
let token_type_ids: Vec<i32> = encoding
|
|
||||||
.get_type_ids()
|
|
||||||
.iter()
|
|
||||||
.map(|&ty| ty as i32)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// 获取序列长度
|
|
||||||
let seq_len = input_ids.len();
|
|
||||||
|
|
||||||
// 转换为 ndarray::Array,并重塑为 (batch_size=1, sequence_length)
|
|
||||||
// ONNX Runtime 通常期望明确的批次维度
|
|
||||||
let input_ids_nd = Array1::from_vec(input_ids)
|
|
||||||
.into_shape_with_order((1, seq_len))
|
|
||||||
.unwrap();
|
|
||||||
let attention_mask_nd = Array1::from_vec(attention_mask)
|
|
||||||
.into_shape_with_order((1, seq_len))
|
|
||||||
.unwrap();
|
|
||||||
let token_type_ids_nd = Array1::from_vec(token_type_ids)
|
|
||||||
.into_shape_with_order((1, seq_len))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// --- PG Logic ---
|
|
||||||
let py_first_char = py
|
|
||||||
.chars()
|
|
||||||
.next()
|
|
||||||
.map(|c| c.to_string())
|
|
||||||
.unwrap_or_else(|| "unknown".to_string());
|
|
||||||
let pg_val = *PG.get(&py_first_char).unwrap_or(&12); // Default to 12 if key not found
|
|
||||||
let pg_nd = Array1::from_elem(1, pg_val); // 直接创建一维数组
|
|
||||||
|
|
||||||
Ok(ModelInput {
|
|
||||||
input_ids: input_ids_nd.mapv(|x| x as i64), // Convert i32 to i64 for ONNX
|
|
||||||
attention_mask: attention_mask_nd.mapv(|x| x as i64),
|
|
||||||
token_type_ids: token_type_ids_nd.mapv(|x| x as i64),
|
|
||||||
|
|
||||||
pg: pg_nd.mapv(|x| x as i64),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
/*
|
|
||||||
/// 预测函数 (此函数通常在模型结构体内实现,此处仅为演示如何调用 gen_predict_sample)
|
|
||||||
/// 这里只是展示如何准备输入,实际的 forward pass 需要在模型的 impl 中完成
|
|
||||||
pub fn prepare_for_prediction(&mut self, text: &str, py: &str) -> Result<ModelInput> {
|
|
||||||
self.gen_predict_sample(text, py)
|
|
||||||
}*/
|
|
||||||
}
|
|
||||||
Loading…
Reference in New Issue