feat: 添加常用汉字拼音统计文件

This commit is contained in:
songsenand 2026-03-02 01:31:22 +08:00
parent 258b33b10b
commit ef8488bb36
8 changed files with 142821 additions and 22 deletions

3
.gitignore vendored
View File

@ -181,3 +181,6 @@ cython_debug/
# Added by cargo
/target
*/*.onnx

View File

@ -7,5 +7,8 @@ edition = "2024"
anyhow = "1.0.102"
lazy_static = "1.5.0"
ndarray = "0.17.2"
ort = "2.0.0-rc.11"
serde = "1.0.228"
serde_json = "1.0.149"
tempfile = "3.26.0"
tokenizers = "0.22.2"

142546
assets/vocabs.json Normal file

File diff suppressed because it is too large Load Diff

0
src/config.rs Normal file
View File

View File

@ -1,26 +1,64 @@
mod tokenizers;
use std::path::PathBuf;
use std::time::Instant;
mod model;
mod tokenizers;
mod vocabs;
use tokenizers::HFTokenizer;
// use crate::vocabs::Dictionary;
use crate::model::OnnxModel;
fn main() {
    // NOTE(review): this span is a rendered diff that interleaves the OLD and
    // the NEW version of main() without +/- markers and with indentation
    // stripped, so it is not valid Rust as shown; comments below annotate
    // intent only.
    let mut tokenizer_json_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    tokenizer_json_path.push("assets");
    tokenizer_json_path.push("tokenizer.json");
    // Example: exercising HFTokenizer (this match block belongs to the OLD
    // version of main(), removed by this commit).
    match tokenizers::HFTokenizer::new(tokenizer_json_path) {
        Ok(mut tokenizer) => {
            match tokenizer.gen_predict_sample("hello world", "ni hao") {
                Ok(model_input) => {
                    println!("Model input generated successfully");
                    println!("Input IDs: {:?}", model_input.input_ids);
                    println!("PG value: {}", model_input.pg[[0]]);
                }
                Err(e) => eprintln!("Error generating model input: {}", e),
            }
        }
        Err(e) => eprintln!("Error loading tokenizer: {}", e),
    // NEW version of main() resumes here: load the ONNX model, then time
    // three prediction calls (raw predict, sorted logits, sorted probs).
    let mut model_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    model_path.push("assets");
    model_path.push("20260228suinput_fp32.onnx");
    let mut session = OnnxModel::new(model_path, 4).unwrap();
    let tokenizer = HFTokenizer::new(tokenizer_json_path).unwrap();
    let start = Instant::now(); // start timing
    let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
    let _ = session.predict(sample).unwrap();
    let duration = start.elapsed(); // stop timing
    print!("predict Time elapsed: {:?}\n", duration);
    let start = Instant::now(); // start timing
    let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
    let logits = session
        .predict_to_sorted_pairs_simple(sample, false)
        .unwrap();
    let duration = start.elapsed(); // stop timing
    print!("predict Time elapsed: {:?}\n", duration);
    let start = Instant::now(); // start timing
    let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
    let probs = session
        .predict_to_sorted_pairs_simple(sample, true)
        .unwrap();
    let duration = start.elapsed(); // stop timing
    print!("predict Time elapsed: {:?}\n", duration);
    // Show only the top-10 scored pairs from each run.
    print!("logits: {:?}", &logits[0..10].to_vec());
    print!("probs: {:?}", &probs[0..10].to_vec());
    // Commented-out earlier experiment kept by the author.
    /*
    if let Ok(tokenizer) = HFTokenizer::new(tokenizer_json_path) {
        println!("Tokenizer loaded successfully");
        if let Ok(sample) = tokenizer.gen_predict_sample("从北京到上", "hai") {
            let logits: ndarray::ArrayBase<
                ndarray::OwnedRepr<f32>,
                ndarray::Dim<ndarray::IxDynImpl>,
                f32,
            > = session.predict(sample).unwrap();
            let duration = start.elapsed(); // stop timing
            println!("Time elapsed: {:?}", duration);
            println!("Model input generated successfully");
            println!("Logits: {:?}", logits);
        }
    }*/
}

119
src/model.rs Normal file
View File

@ -0,0 +1,119 @@
use anyhow::{Result};
use std::cmp::Ordering;
use ndarray::{Array, Axis, IxDyn};
use ort::{inputs, session::Session, value::Tensor};
use std::path::Path;
use std::sync::{Arc, Mutex};
use crate::tokenizers::ModelInput;
/// Wrapper around an ONNX Runtime session for running the prediction model.
pub struct OnnxModel {
    /// Inference session; wrapped in Arc<Mutex<..>> for thread safety (per
    /// the original comment), though no clone of the Arc is visible here.
    session: Arc<Mutex<Session>>,
    /// Model input names, collected in the order the model declares them.
    input_names: Vec<String>,
}
impl OnnxModel {
    /// Load an ONNX model from disk.
    ///
    /// # Arguments
    /// * `model_path` - path to the ONNX model file
    /// * `intra_threads` - number of threads used to parallelize a single operator
    ///
    /// # Errors
    /// Returns an error when `intra_threads` is 0, when the session cannot be
    /// created, or when the model declares fewer than four inputs.
    pub fn new<P: AsRef<Path>>(model_path: P, intra_threads: usize) -> Result<Self> {
        if intra_threads == 0 {
            return Err(anyhow::anyhow!("intra_threads must be greater than 0"));
        }
        // Build the ONNX Runtime session.
        let session = Session::builder()?
            .with_intra_threads(intra_threads)?
            .commit_from_file(model_path)?;
        // Collect the input names in the order the model declares them.
        let input_names: Vec<String> = session
            .inputs()
            .iter()
            .map(|input: &ort::value::Outlet| input.name().to_string())
            .collect();
        // `predict` binds inputs by position 0..=3; fail fast here instead of
        // panicking with an opaque out-of-bounds index later.
        if input_names.len() < 4 {
            return Err(anyhow::anyhow!(
                "model must declare at least 4 inputs, found {}",
                input_names.len()
            ));
        }
        Ok(OnnxModel {
            session: Arc::new(Mutex::new(session)),
            input_names,
        })
    }

    /// Run inference and return the logits of the first batch element.
    ///
    /// # Arguments
    /// * `input` - model input produced by the tokenizer
    pub fn predict(&mut self, input: ModelInput) -> Result<Array<f32, IxDyn>> {
        // Convert the ndarray inputs into ort tensors.
        let input_ids_tensor = Tensor::from_array(input.input_ids.into_dyn())?;
        let attention_mask_tensor = Tensor::from_array(input.attention_mask.into_dyn())?;
        let token_type_ids_tensor = Tensor::from_array(input.token_type_ids.into_dyn())?;
        let pg_tensor = Tensor::from_array(input.pg.into_dyn())?;
        // NOTE(review): assumes the model declares its inputs in the order
        // input_ids, attention_mask, token_type_ids, pg — confirm against the
        // exported model definition.
        let inputs = inputs![
            self.input_names[0].as_str() => input_ids_tensor,
            self.input_names[1].as_str() => attention_mask_tensor,
            self.input_names[2].as_str() => token_type_ids_tensor,
            self.input_names[3].as_str() => pg_tensor,
        ];
        // Report a poisoned mutex as an error instead of panicking.
        let mut session = self
            .session
            .lock()
            .map_err(|_| anyhow::anyhow!("model session mutex is poisoned"))?;
        // Run inference.
        let outputs: ort::session::SessionOutputs<'_> = session.run(inputs)?;
        // The first output holds the logits; extract it as an f32 array.
        let output_tensor = outputs[0].try_extract_array::<f32>()?;
        // Output shape is assumed to be [batch_size, vocab_size]; return row 0.
        Ok(output_tensor.index_axis(Axis(0), 0).to_owned())
    }

    /// Numerically stable softmax over all elements of `logits`.
    fn softmax(&self, logits: &Array<f32, IxDyn>) -> Array<f32, IxDyn> {
        // Subtract the max before exponentiating to avoid overflow.
        let max_val = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let exps = logits.mapv(|x| (x - max_val).exp());
        let sum_exps: f32 = exps.sum();
        // Normalize to probabilities; element-wise ops preserve the shape, so
        // no fallible reshape is needed.
        exps / sum_exps
    }

    /// Run inference and return `(index, score)` pairs sorted descending by score.
    ///
    /// # Arguments
    /// * `input` - model input produced by the tokenizer
    /// * `with_softmax` - when true, scores are softmax probabilities;
    ///   otherwise they are the raw logits
    pub fn predict_to_sorted_pairs_simple(
        &mut self,
        input: ModelInput,
        with_softmax: bool,
    ) -> Result<Vec<(usize, f32)>> {
        let logits_raw = self.predict(input)?;
        let scores = if with_softmax {
            self.softmax(&logits_raw)
        } else {
            logits_raw
        };
        // Pair every vocabulary index with its score.
        let mut pairs: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
        // Sort high-to-low; NaNs compare as Equal so the sort cannot panic.
        pairs.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        Ok(pairs)
    }
}

View File

@ -3,7 +3,9 @@ use anyhow::Result;
use ndarray::{Array1, Array2};
use std::collections::HashMap;
use std::path::Path;
use tokenizers::{EncodeInput, Encoding, Tokenizer};
use tokenizers::{
EncodeInput, Encoding, PaddingParams, Tokenizer, TruncationParams,
};
// 拼音组映射表 (PG map) - 使用 const array 初始化 HashMap
const PG_MAP: &[(&str, i64)] = &[
@ -65,7 +67,7 @@ impl HFTokenizer {
/// `tokenizer_path_or_name`: 指向预训练 tokenizer 配置文件的路径 (例如 tokenizer.json)
pub fn new<P: AsRef<Path>>(tokenizer_path_or_name: P) -> Result<Self> {
// 加载 tokenizer 配置文件 (通常名为 tokenizer.json)
let tokenizer = Tokenizer::from_file(tokenizer_path_or_name.as_ref()).map_err(|e| {
let mut tokenizer = Tokenizer::from_file(tokenizer_path_or_name.as_ref()).map_err(|e| {
anyhow::anyhow!(
"Failed to load tokenizer from {:?}: {}",
tokenizer_path_or_name.as_ref(),
@ -73,13 +75,27 @@ impl HFTokenizer {
)
})?;
let padding_params = PaddingParams {
strategy: tokenizers::PaddingStrategy::Fixed(88),
..Default::default()
};
let truncation_params = TruncationParams {
max_length: 88,
..Default::default()
};
tokenizer.with_padding(Some(padding_params));
tokenizer
.with_truncation(Some(truncation_params))
.map_err(|e| anyhow::anyhow!("Failed to set truncation: {}", e))?;
Ok(HFTokenizer { tokenizer })
}
/// 使用 `tokenizers` crate 的 EncodeInput 构造函数来处理文本对
/// `text` - 第一句输入 (通常是文本)
/// `py` - 第二句输入 (通常是拼音)
fn encode_pair(&mut self, text: &str, py: &str) -> Result<Encoding> {
fn encode_pair(&self, text: &str, py: &str) -> Result<Encoding> {
let input = EncodeInput::Dual(text.into(), py.into());
// encode 方法会应用 pre_tokenizer, normalizer, post_processor, truncation, padding
let encoding = self
@ -100,7 +116,7 @@ impl HFTokenizer {
/// # Returns
///
/// * `Result<ModelInput>` - 包含模型所需输入的结构体,包含 hint (input_ids, attention_mask, token_type_ids) 和 pg array
pub fn gen_predict_sample(&mut self, text: &str, py: &str) -> Result<ModelInput> {
pub fn gen_predict_sample(&self, text: &str, py: &str) -> Result<ModelInput> {
let encoding = self.encode_pair(text, py)?;
// 获取分词结果

74
src/vocabs.rs Normal file
View File

@ -0,0 +1,74 @@
use anyhow::{Context, Result};
use serde::Deserialize;
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
/// Information for a single character–pinyin pair.
#[derive(Debug, Deserialize, Clone)]
pub struct CharInfo {
    /// Numeric entry ID (the JSON `pairs` map uses its string form as the key).
    pub id: u32,
    /// The character itself; the JSON field is named "char".
    #[serde(rename = "char")]
    pub character: String,
    /// Pinyin reading of the character.
    pub pinyin: String,
    /// Occurrence count from the statistics data.
    pub count: u64,
}
/// Root JSON structure (only the fields needed here are declared).
#[derive(Debug, Deserialize)]
struct RawStatistics {
    /// Map keyed by the string form of the entry ID.
    pairs: HashMap<String, CharInfo>,
    // Other metadata fields in the JSON are ignored during deserialization.
}
/// Dictionary lookup engine providing O(1) (average-case HashMap) ID-to-info mapping.
pub struct Dictionary {
    /// Map from numeric entry ID to the full character info.
    id_to_charinfo: HashMap<u32, CharInfo>,
}
impl Dictionary {
    /// Load the dictionary from a JSON statistics file.
    ///
    /// # Errors
    /// Fails when the file cannot be opened, when the JSON cannot be parsed,
    /// or when a key in `pairs` is not a numeric ID.
    pub fn from_json_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path).context("无法打开字典文件")?;
        let reader = BufReader::new(file);
        let raw: RawStatistics =
            serde_json::from_reader(reader).context("无法解析 JSON 字典")?;
        let mut id_to_charinfo = HashMap::with_capacity(raw.pairs.len());
        for (id_str, info) in raw.pairs {
            let id = id_str
                .parse::<u32>()
                .with_context(|| format!("无效的 ID 字符串: {}", id_str))?;
            // Key by the parsed JSON key, which the schema documents as the
            // string form of the entry ID. (Previously the parsed value was
            // discarded and `info.id` was trusted instead — TODO confirm the
            // two always agree in the source data.)
            id_to_charinfo.insert(id, info);
        }
        Ok(Dictionary { id_to_charinfo })
    }

    /// Look up the character for an ID (used to fill Candidate.text).
    pub fn get_char_by_id(&self, id: u32) -> Option<&str> {
        self.id_to_charinfo.get(&id).map(|info| info.character.as_str())
    }

    /// Look up the pinyin for an ID.
    pub fn get_pinyin_by_id(&self, id: u32) -> Option<&str> {
        self.id_to_charinfo.get(&id).map(|info| info.pinyin.as_str())
    }

    /// Look up the occurrence count for an ID.
    pub fn get_count_by_id(&self, id: u32) -> Option<u64> {
        self.id_to_charinfo.get(&id).map(|info| info.count)
    }

    /// Get a reference to the full `CharInfo` for an ID.
    pub fn get_char_info(&self, id: u32) -> Option<&CharInfo> {
        self.id_to_charinfo.get(&id)
    }

    /// Number of entries stored in the dictionary.
    pub fn len(&self) -> usize {
        self.id_to_charinfo.len()
    }

    /// True when the dictionary holds no entries.
    /// (Companion to `len`; clippy `len_without_is_empty`.)
    pub fn is_empty(&self) -> bool {
        self.id_to_charinfo.is_empty()
    }
}