SUIME/src/main.rs

194 lines
6.8 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

mod config;
mod filter;
mod model;
mod protocol;
mod tokenizers;
mod vocabs;
use crate::config::{Config, SocketConfig};
use crate::model::OnnxModel;
use crate::tokenizers::HFTokenizer;
use crate::vocabs::Dictionary;
use crate::filter::filter_candidates;
use crate::protocol::{Request, Response, Candidate};
use std::error::Error;
use std::io::{Read, Write};
use std::sync::{Arc, Mutex};
use std::thread;
#[cfg(unix)]
use std::os::unix::net::{UnixListener, UnixStream};
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt; // 用于设置文件权限
#[cfg(unix)]
use std::fs::set_permissions;
#[cfg(windows)]
use std::net::{TcpListener, TcpStream};
// 添加 ctrlc 依赖,用于捕获退出信号(仅在 Unix 下需要)
#[cfg(unix)]
use ctrlc;
/// Serves a single client connection until the peer closes it.
///
/// Framing is identical in both directions: a 4-byte big-endian length followed by a
/// MessagePack payload of exactly that many bytes. For every request the function
/// builds the model input with the tokenizer, runs the model, filters the predictions
/// against the dictionary and writes one `Response` frame back.
///
/// # Parameters
/// * `stream` - the connected byte stream (any `Read + Write`).
/// * `model` - shared model; locked only for the duration of one prediction.
/// * `tokenizer` - shared tokenizer; locked only while building the model input.
/// * `dict` - shared dictionary passed to `filter_candidates`.
///
/// # Errors
/// Returns an error when
/// * reading the header fails for any reason other than end-of-stream,
/// * the announced request length exceeds the per-request cap,
/// * the payload cannot be read or decoded,
/// * a shared mutex is poisoned,
/// * tokenization or prediction fails,
/// * the response cannot be encoded, does not fit in a `u32` length, or cannot be written.
///
/// End-of-stream while waiting for a header is treated as a normal disconnect and
/// yields `Ok(())`.
fn handle_client(mut stream: impl Read + Write, model: Arc<Mutex<OnnxModel>>, tokenizer: Arc<Mutex<HFTokenizer>>, dict: Arc<Dictionary>) -> Result<(), Box<dyn Error>> {
    // Cap on one request payload. The length prefix comes straight from the peer, so
    // without a cap a client could make us allocate up to 4 GiB per request.
    // NOTE(review): 1 MiB is a chosen limit - confirm it covers the largest real request.
    const MAX_REQUEST_LEN: usize = 1 << 20;

    loop {
        // Read the 4-byte length header.
        let mut len_buf = [0u8; 4];
        match stream.read_exact(&mut len_buf) {
            Ok(()) => {}
            // EOF here means the client hung up; leave the loop normally.
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                println!("客户端关闭连接");
                break;
            }
            // Any other I/O error is reported to the caller.
            Err(e) => return Err(Box::new(e)),
        }

        let payload_len = u32::from_be_bytes(len_buf) as usize;
        if payload_len > MAX_REQUEST_LEN {
            return Err(format!(
                "请求载荷过大: {} 字节 (上限 {} 字节)",
                payload_len, MAX_REQUEST_LEN
            )
            .into());
        }

        // Read the MessagePack payload.
        let mut payload = vec![0u8; payload_len];
        stream.read_exact(&mut payload)?;

        // Decode the request.
        let request: Request = rmp_serde::from_slice(&payload)?;
        println!("接收请求: {:?}", request);

        // Build the model input; the tokenizer lock is released at the end of the block.
        let model_input = {
            let tokenizer = tokenizer
                .lock()
                .map_err(|_| "tokenizer 互斥锁已中毒")?;
            tokenizer.gen_predict_sample(&request.context, &request.pinyin)?
        };

        // Run the prediction; the model lock is released at the end of the block.
        let predicted_pairs = {
            let mut model = model.lock().map_err(|_| "model 互斥锁已中毒")?;
            model.predict_to_sorted_pairs_simple(model_input, true)?
        };

        // Filter the predictions against the dictionary.
        let filtered = filter_candidates(&request.pinyin, predicted_pairs, &dict);

        // Convert the filtered entries into wire-level candidates.
        let candidates: Vec<Candidate> = filtered
            .into_iter()
            .map(|fc| Candidate {
                // NOTE(review): `as u32` truncates if `fc.id` can exceed u32::MAX - confirm its type.
                id: fc.id as u32,
                text: fc.text,
                weight: fc.weight,
            })
            .collect();
        let response = Response {
            candidates,
            offset: 0,
            limit: 10,
        };

        // Encode the response; refuse to send a frame whose length would not fit the
        // 4-byte header instead of silently truncating it.
        let response_payload = rmp_serde::to_vec(&response)?;
        let response_len = u32::try_from(response_payload.len())?;

        // Send header and payload as one buffer so the frame goes out in a single write.
        let mut frame = Vec::with_capacity(4 + response_payload.len());
        frame.extend_from_slice(&response_len.to_be_bytes());
        frame.extend_from_slice(&response_payload);
        stream.write_all(&frame)?;
        // Flush explicitly in case `stream` is a buffered writer.
        stream.flush()?;
    }
    Ok(())
}
fn main() -> Result<(), Box<dyn Error>> {
// 加载配置
let config = Config::load()?;
println!("配置加载成功: {:?}", config);
// 初始化组件
let model = OnnxModel::new(&config.model_path, 1)?;
let tokenizer = HFTokenizer::new(&config.tokenizer_path)?;
let dict = Dictionary::from_json_file(&config.vocabs_path)?;
let model = Arc::new(Mutex::new(model));
let tokenizer = Arc::new(Mutex::new(tokenizer));
let dict = Arc::new(dict);
// 创建监听器
match config.socket {
SocketConfig::UnixDomainSocket(path) => {
#[cfg(unix)]
{
// 尝试删除可能残留的socket文件
if let Err(e) = std::fs::remove_file(&path) {
// 如果文件不存在,忽略错误;否则打印警告
if e.kind() != std::io::ErrorKind::NotFound {
eprintln!("警告无法删除socket文件 {}: {}", path, e);
}
}
let listener = UnixListener::bind(&path)?;
println!("监听Unix Domain Socket在: {}", path);
// 设置socket文件权限为600仅所有者可读写减少被其他进程误操作的风险
let mut perms = std::fs::metadata(&path)?.permissions();
perms.set_mode(0o600); // 只允许所有者读写
set_permissions(&path, perms)?;
// 注册信号处理确保程序退出时清理socket文件
let socket_path = path.clone();
ctrlc::set_handler(move || {
println!("收到退出信号清理socket文件: {}", socket_path);
let _ = std::fs::remove_file(&socket_path);
std::process::exit(0);
})?;
for stream in listener.incoming() {
match stream {
Ok(stream) => {
println!("受理连接");
let model = Arc::clone(&model);
let tokenizer = Arc::clone(&tokenizer);
let dict = Arc::clone(&dict);
thread::spawn(move || {
if let Err(e) = handle_client(stream, model, tokenizer, dict) {
eprintln!("处理连接错误: {}", e);
}
});
}
Err(e) => eprintln!("接受连接错误: {}", e),
}
}
}
#[cfg(not(unix))]
{
return Err("Unix Domain Socket仅在Unix系统上支持".into());
}
}
SocketConfig::TcpSocket(addr) => {
#[cfg(windows)]
{
let listener = TcpListener::bind(&addr)?;
println!("监听TCP Socket在: {}", addr);
for stream in listener.incoming() {
match stream {
Ok(stream) => {
let model = Arc::clone(&model);
let tokenizer = Arc::clone(&tokenizer);
let dict = Arc::clone(&dict);
thread::spawn(move || {
if let Err(e) = handle_client(stream, model, tokenizer, dict) {
eprintln!("处理连接错误: {}", e);
}
});
}
Err(e) => eprintln!("接受连接错误: {}", e),
}
}
}
#[cfg(not(windows))]
{
return Err("TCP Socket仅在Windows系统上支持".into());
}
}
}
Ok(())
}