feat: 添加常用汉字拼音统计文件
This commit is contained in:
parent
258b33b10b
commit
ef8488bb36
|
|
@ -181,3 +181,6 @@ cython_debug/
|
||||||
# Added by cargo
|
# Added by cargo
|
||||||
|
|
||||||
/target
|
/target
|
||||||
|
|
||||||
|
|
||||||
|
*/*.onnx
|
||||||
|
|
@ -7,5 +7,8 @@ edition = "2024"
|
||||||
anyhow = "1.0.102"
|
anyhow = "1.0.102"
|
||||||
lazy_static = "1.5.0"
|
lazy_static = "1.5.0"
|
||||||
ndarray = "0.17.2"
|
ndarray = "0.17.2"
|
||||||
|
ort = "2.0.0-rc.11"
|
||||||
|
serde = "1.0.228"
|
||||||
|
serde_json = "1.0.149"
|
||||||
tempfile = "3.26.0"
|
tempfile = "3.26.0"
|
||||||
tokenizers = "0.22.2"
|
tokenizers = "0.22.2"
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
72
src/main.rs
72
src/main.rs
|
|
@ -1,26 +1,64 @@
|
||||||
mod tokenizers;
|
|
||||||
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
mod model;
|
||||||
|
mod tokenizers;
|
||||||
|
mod vocabs;
|
||||||
|
|
||||||
|
use tokenizers::HFTokenizer;
|
||||||
|
// use crate::vocabs::Dictionary;
|
||||||
|
use crate::model::OnnxModel;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
|
||||||
|
|
||||||
let mut tokenizer_json_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
let mut tokenizer_json_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||||
tokenizer_json_path.push("assets");
|
tokenizer_json_path.push("assets");
|
||||||
tokenizer_json_path.push("tokenizer.json");
|
tokenizer_json_path.push("tokenizer.json");
|
||||||
|
|
||||||
// 示例:使用 HFTokenizer
|
let mut model_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||||
match tokenizers::HFTokenizer::new(tokenizer_json_path) {
|
model_path.push("assets");
|
||||||
Ok(mut tokenizer) => {
|
model_path.push("20260228suinput_fp32.onnx");
|
||||||
match tokenizer.gen_predict_sample("hello world", "ni hao") {
|
|
||||||
Ok(model_input) => {
|
|
||||||
println!("Model input generated successfully");
|
|
||||||
println!("Input IDs: {:?}", model_input.input_ids);
|
|
||||||
println!("PG value: {}", model_input.pg[[0]]);
|
|
||||||
}
|
|
||||||
Err(e) => eprintln!("Error generating model input: {}", e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => eprintln!("Error loading tokenizer: {}", e),
|
|
||||||
|
|
||||||
|
let mut session = OnnxModel::new(model_path, 4).unwrap();
|
||||||
|
let tokenizer = HFTokenizer::new(tokenizer_json_path).unwrap();
|
||||||
|
let start = Instant::now(); // 开始计时
|
||||||
|
let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
|
||||||
|
let _ = session.predict(sample).unwrap();
|
||||||
|
let duration = start.elapsed(); // 结束计时
|
||||||
|
print!("predict Time elapsed: {:?}\n", duration);
|
||||||
|
|
||||||
|
let start = Instant::now(); // 开始计时
|
||||||
|
let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
|
||||||
|
let logits = session
|
||||||
|
.predict_to_sorted_pairs_simple(sample, false)
|
||||||
|
.unwrap();
|
||||||
|
let duration = start.elapsed(); // 结束计时
|
||||||
|
print!("predict Time elapsed: {:?}\n", duration);
|
||||||
|
|
||||||
|
let start = Instant::now(); // 开始计时
|
||||||
|
let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
|
||||||
|
let probs = session
|
||||||
|
.predict_to_sorted_pairs_simple(sample, true)
|
||||||
|
.unwrap();
|
||||||
|
let duration = start.elapsed(); // 结束计时
|
||||||
|
print!("predict Time elapsed: {:?}\n", duration);
|
||||||
|
print!("logits: {:?}", &logits[0..10].to_vec());
|
||||||
|
print!("probs: {:?}", &probs[0..10].to_vec());
|
||||||
|
|
||||||
|
/*
|
||||||
|
if let Ok(tokenizer) = HFTokenizer::new(tokenizer_json_path) {
|
||||||
|
println!("Tokenizer loaded successfully");
|
||||||
|
if let Ok(sample) = tokenizer.gen_predict_sample("从北京到上", "hai") {
|
||||||
|
let logits: ndarray::ArrayBase<
|
||||||
|
ndarray::OwnedRepr<f32>,
|
||||||
|
ndarray::Dim<ndarray::IxDynImpl>,
|
||||||
|
f32,
|
||||||
|
> = session.predict(sample).unwrap();
|
||||||
|
|
||||||
|
let duration = start.elapsed(); // 结束计时
|
||||||
|
|
||||||
|
println!("Time elapsed: {:?}", duration);
|
||||||
|
println!("Model input generated successfully");
|
||||||
|
println!("Logits: {:?}", logits);
|
||||||
}
|
}
|
||||||
|
}*/
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,119 @@
|
||||||
|
use anyhow::{Result};
|
||||||
|
use std::cmp::Ordering;
|
||||||
|
|
||||||
|
use ndarray::{Array, Axis, IxDyn};
|
||||||
|
use ort::{inputs, session::Session, value::Tensor};
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use crate::tokenizers::ModelInput;
|
||||||
|
|
||||||
|
pub struct OnnxModel {
|
||||||
|
session: Arc<Mutex<Session>>, // 推理会话(线程安全)
|
||||||
|
input_names: Vec<String>, // 输入名称
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OnnxModel {
|
||||||
|
/// 加载 ONNX 模型
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `model_path` - ONNX 模型文件路径
|
||||||
|
/// * `intra_threads` - 配置单个操作中用于并行化的线程数。
|
||||||
|
pub fn new<P: AsRef<Path>>(model_path: P, intra_threads: usize) -> Result<Self> {
|
||||||
|
if intra_threads == 0 {
|
||||||
|
return Err(anyhow::anyhow!("intra_threads must be greater than 0"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 ONNX Runtime 会话构建器
|
||||||
|
let session = Session::builder()?
|
||||||
|
.with_intra_threads(intra_threads)?
|
||||||
|
.commit_from_file(model_path)?;
|
||||||
|
let input_names: Vec<String> = session
|
||||||
|
.inputs()
|
||||||
|
.iter()
|
||||||
|
.map(|input: &ort::value::Outlet| input.name().to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// 创建 ONNX Runtime 会话
|
||||||
|
Ok(OnnxModel {
|
||||||
|
session: Arc::new(Mutex::new(session)),
|
||||||
|
input_names: input_names
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行推理,返回模型输出的logits
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `input` - 由 tokenizer 生成的模型输入
|
||||||
|
pub fn predict(&mut self, input: ModelInput) -> Result<Array<f32, IxDyn>> {
|
||||||
|
// 将 ndarray 转换为 ort::Tensor
|
||||||
|
let input_ids_tensor = Tensor::from_array(input.input_ids.into_dyn())?;
|
||||||
|
let attention_mask_tensor = Tensor::from_array(input.attention_mask.into_dyn())?;
|
||||||
|
let token_type_ids_tensor = Tensor::from_array(input.token_type_ids.into_dyn())?;
|
||||||
|
let pg_tensor = Tensor::from_array(input.pg.into_dyn())?;
|
||||||
|
|
||||||
|
// 使用 ort::inputs! 宏构建输入(按模型定义的顺序)
|
||||||
|
// 注意:这里假设输入顺序是固定的,实际可能需要根据模型定义调整
|
||||||
|
let inputs = inputs![
|
||||||
|
self.input_names[0].as_str() => input_ids_tensor,
|
||||||
|
self.input_names[1].as_str() => attention_mask_tensor,
|
||||||
|
self.input_names[2].as_str() => token_type_ids_tensor,
|
||||||
|
self.input_names[3].as_str() => pg_tensor,
|
||||||
|
];
|
||||||
|
let mut session = (*self.session).lock().unwrap();
|
||||||
|
// 运行推理
|
||||||
|
let outputs: ort::session::SessionOutputs<'_> = session.run(inputs)?;
|
||||||
|
|
||||||
|
// 获取第一个输出(logits)
|
||||||
|
let output_value = &outputs[0];
|
||||||
|
|
||||||
|
// 提取为 Array<f32>
|
||||||
|
let output_tensor = output_value.try_extract_array::<f32>()?;
|
||||||
|
|
||||||
|
// 假设输出形状为 [batch_size, vocab_size]
|
||||||
|
Ok(output_tensor.index_axis(Axis(0), 0).to_owned())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn softmax(&self, logits: &Array<f32, IxDyn>) -> Array<f32, IxDyn> {
|
||||||
|
// 计算softmax
|
||||||
|
let max_val = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
|
||||||
|
let exps: Vec<f32> = logits.iter().map(|&x| (x - max_val).exp()).collect();
|
||||||
|
let sum_exps: f32 = exps.iter().sum();
|
||||||
|
|
||||||
|
// 归一化得到概率
|
||||||
|
let normalized: Vec<f32> = exps.iter().map(|&exp| exp / sum_exps).collect();
|
||||||
|
|
||||||
|
// 使用原始数组的形状创建新数组
|
||||||
|
Array::from_shape_vec(logits.shape().to_vec(), normalized).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行推理,返回模型输出的logits,并按logits值从大到小排序
|
||||||
|
pub fn predict_to_sorted_pairs_simple(&mut self, input: ModelInput, with_softmax: bool) -> Result<Vec<(usize, f32)>> {
|
||||||
|
|
||||||
|
let logits_raw = self.predict(input)?;
|
||||||
|
|
||||||
|
let logits = if with_softmax {
|
||||||
|
self.softmax(&logits_raw)
|
||||||
|
}else{
|
||||||
|
logits_raw
|
||||||
|
};
|
||||||
|
|
||||||
|
// 将(索引, logits值)收集到Vec中
|
||||||
|
let mut pairs: Vec<(usize, f32)> = logits
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(idx, &value)| (idx, value))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// 按logits值从大到小排序
|
||||||
|
pairs.sort_by(|a, b| {
|
||||||
|
b.1.partial_cmp(&a.1)
|
||||||
|
.unwrap_or(Ordering::Equal)
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(pairs)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -3,7 +3,9 @@ use anyhow::Result;
|
||||||
use ndarray::{Array1, Array2};
|
use ndarray::{Array1, Array2};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use tokenizers::{EncodeInput, Encoding, Tokenizer};
|
use tokenizers::{
|
||||||
|
EncodeInput, Encoding, PaddingParams, Tokenizer, TruncationParams,
|
||||||
|
};
|
||||||
|
|
||||||
// 拼音组映射表 (PG map) - 使用 const array 初始化 HashMap
|
// 拼音组映射表 (PG map) - 使用 const array 初始化 HashMap
|
||||||
const PG_MAP: &[(&str, i64)] = &[
|
const PG_MAP: &[(&str, i64)] = &[
|
||||||
|
|
@ -65,7 +67,7 @@ impl HFTokenizer {
|
||||||
/// `tokenizer_path_or_name`: 指向预训练 tokenizer 配置文件的路径 (例如 tokenizer.json)
|
/// `tokenizer_path_or_name`: 指向预训练 tokenizer 配置文件的路径 (例如 tokenizer.json)
|
||||||
pub fn new<P: AsRef<Path>>(tokenizer_path_or_name: P) -> Result<Self> {
|
pub fn new<P: AsRef<Path>>(tokenizer_path_or_name: P) -> Result<Self> {
|
||||||
// 加载 tokenizer 配置文件 (通常名为 tokenizer.json)
|
// 加载 tokenizer 配置文件 (通常名为 tokenizer.json)
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_path_or_name.as_ref()).map_err(|e| {
|
let mut tokenizer = Tokenizer::from_file(tokenizer_path_or_name.as_ref()).map_err(|e| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"Failed to load tokenizer from {:?}: {}",
|
"Failed to load tokenizer from {:?}: {}",
|
||||||
tokenizer_path_or_name.as_ref(),
|
tokenizer_path_or_name.as_ref(),
|
||||||
|
|
@ -73,13 +75,27 @@ impl HFTokenizer {
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
let padding_params = PaddingParams {
|
||||||
|
strategy: tokenizers::PaddingStrategy::Fixed(88),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let truncation_params = TruncationParams {
|
||||||
|
max_length: 88,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
tokenizer.with_padding(Some(padding_params));
|
||||||
|
tokenizer
|
||||||
|
.with_truncation(Some(truncation_params))
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to set truncation: {}", e))?;
|
||||||
|
|
||||||
Ok(HFTokenizer { tokenizer })
|
Ok(HFTokenizer { tokenizer })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 使用 `tokenizers` crate 的 EncodeInput 构造函数来处理文本对
|
/// 使用 `tokenizers` crate 的 EncodeInput 构造函数来处理文本对
|
||||||
/// `text` - 第一句输入 (通常是文本)
|
/// `text` - 第一句输入 (通常是文本)
|
||||||
/// `py` - 第二句输入 (通常是拼音)
|
/// `py` - 第二句输入 (通常是拼音)
|
||||||
fn encode_pair(&mut self, text: &str, py: &str) -> Result<Encoding> {
|
fn encode_pair(&self, text: &str, py: &str) -> Result<Encoding> {
|
||||||
let input = EncodeInput::Dual(text.into(), py.into());
|
let input = EncodeInput::Dual(text.into(), py.into());
|
||||||
// encode 方法会应用 pre_tokenizer, normalizer, post_processor, truncation, padding
|
// encode 方法会应用 pre_tokenizer, normalizer, post_processor, truncation, padding
|
||||||
let encoding = self
|
let encoding = self
|
||||||
|
|
@ -100,7 +116,7 @@ impl HFTokenizer {
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
/// * `Result<ModelInput>` - 包含模型所需输入的结构体,包含 hint (input_ids, attention_mask, token_type_ids) 和 pg array
|
/// * `Result<ModelInput>` - 包含模型所需输入的结构体,包含 hint (input_ids, attention_mask, token_type_ids) 和 pg array
|
||||||
pub fn gen_predict_sample(&mut self, text: &str, py: &str) -> Result<ModelInput> {
|
pub fn gen_predict_sample(&self, text: &str, py: &str) -> Result<ModelInput> {
|
||||||
let encoding = self.encode_pair(text, py)?;
|
let encoding = self.encode_pair(text, py)?;
|
||||||
|
|
||||||
// 获取分词结果
|
// 获取分词结果
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::BufReader;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
/// 单个字符-拼音对的信息
|
||||||
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
|
pub struct CharInfo {
|
||||||
|
pub id: u32,
|
||||||
|
#[serde(rename = "char")]
|
||||||
|
pub character: String,
|
||||||
|
pub pinyin: String,
|
||||||
|
pub count: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JSON 根结构(仅包含需要的字段)
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct RawStatistics {
|
||||||
|
pairs: HashMap<String, CharInfo>, // 键为字符串形式的 ID
|
||||||
|
// 忽略其他元数据字段
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 字典查询引擎,提供 O(1) 的 ID 到信息的映射
|
||||||
|
pub struct Dictionary {
|
||||||
|
id_to_charinfo: HashMap<u32, CharInfo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Dictionary {
|
||||||
|
/// 从 JSON 文件加载字典
|
||||||
|
pub fn from_json_file<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||||
|
let file = File::open(path).context("无法打开字典文件")?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
let raw: RawStatistics = serde_json::from_reader(reader)
|
||||||
|
.context("无法解析 JSON 字典")?;
|
||||||
|
|
||||||
|
let mut id_to_charinfo = HashMap::with_capacity(raw.pairs.len());
|
||||||
|
for (id_str, info) in raw.pairs {
|
||||||
|
let id = id_str
|
||||||
|
.parse::<u32>()
|
||||||
|
.with_context(|| format!("无效的 ID 字符串: {}", id_str))?;
|
||||||
|
// 可选:验证 id 与 info.id 一致,此处忽略不一致的情况(信任输入数据)
|
||||||
|
id_to_charinfo.insert(info.id, info);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Dictionary { id_to_charinfo })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 通过 ID 获取汉字(用于填充 Candidate.text)
|
||||||
|
pub fn get_char_by_id(&self, id: u32) -> Option<&str> {
|
||||||
|
self.id_to_charinfo.get(&id).map(|info| info.character.as_str())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 通过 ID 获取拼音
|
||||||
|
pub fn get_pinyin_by_id(&self, id: u32) -> Option<&str> {
|
||||||
|
self.id_to_charinfo.get(&id).map(|info| info.pinyin.as_str())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 通过 ID 获取出现次数
|
||||||
|
pub fn get_count_by_id(&self, id: u32) -> Option<u64> {
|
||||||
|
self.id_to_charinfo.get(&id).map(|info| info.count)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取完整的 CharInfo 引用
|
||||||
|
pub fn get_char_info(&self, id: u32) -> Option<&CharInfo> {
|
||||||
|
self.id_to_charinfo.get(&id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 返回字典中存储的条目数量
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.id_to_charinfo.len()
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue