//! SUIME demo binary: loads the ONNX input-method model and HuggingFace
//! tokenizer from `assets/` and times several prediction calls.
use std::path::PathBuf;
use std::time::Instant;
mod model;
mod tokenizers;
mod vocabs;
use tokenizers::HFTokenizer;
// use crate::vocabs::Dictionary;
use crate::model::OnnxModel;
/// Entry point: loads the tokenizer and ONNX model from the crate's
/// `assets/` directory, then runs and times three prediction variants
/// (raw predict, sorted logits, sorted probabilities) on a fixed sample.
fn main() {
    // Resolve asset paths relative to the crate root so the binary works
    // regardless of the current working directory.
    let assets = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets");
    let tokenizer_json_path = assets.join("tokenizer.json");
    let model_path = assets.join("20260228suinput_fp32.onnx");

    // Loading either asset failing is unrecoverable for this demo binary,
    // so fail fast with a descriptive message instead of a bare unwrap.
    let mut session = OnnxModel::new(model_path, 4).expect("failed to load ONNX model");
    let tokenizer = HFTokenizer::new(tokenizer_json_path).expect("failed to load tokenizer");

    // 1) Raw predict — also serves as a warm-up for the session.
    let start = Instant::now(); // start timing
    let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
    let _ = session.predict(sample).unwrap();
    println!("predict Time elapsed: {:?}", start.elapsed()); // stop timing

    // 2) Sorted pairs without softmax (raw logits).
    let start = Instant::now(); // start timing
    let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
    let logits = session
        .predict_to_sorted_pairs_simple(sample, false)
        .unwrap();
    println!("predict Time elapsed: {:?}", start.elapsed()); // stop timing

    // 3) Sorted pairs with softmax (probabilities).
    let start = Instant::now(); // start timing
    let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
    let probs = session
        .predict_to_sorted_pairs_simple(sample, true)
        .unwrap();
    println!("predict Time elapsed: {:?}", start.elapsed()); // stop timing

    // Show the top-10 entries of each result. Clamp to the actual length so
    // a short output vector cannot cause an out-of-bounds panic, and print
    // the slice directly (no intermediate `to_vec()` allocation).
    println!("logits: {:?}", &logits[..logits.len().min(10)]);
    println!("probs: {:?}", &probs[..probs.len().min(10)]);
}