194 lines
6.8 KiB
Rust
194 lines
6.8 KiB
Rust
mod config;
|
||
mod filter;
|
||
mod model;
|
||
mod protocol;
|
||
mod tokenizers;
|
||
mod vocabs;
|
||
|
||
use crate::config::{Config, SocketConfig};
|
||
use crate::model::OnnxModel;
|
||
use crate::tokenizers::HFTokenizer;
|
||
use crate::vocabs::Dictionary;
|
||
use crate::filter::filter_candidates;
|
||
use crate::protocol::{Request, Response, Candidate};
|
||
use std::error::Error;
|
||
use std::io::{Read, Write};
|
||
use std::sync::{Arc, Mutex};
|
||
use std::thread;
|
||
|
||
#[cfg(unix)]
|
||
use std::os::unix::net::{UnixListener, UnixStream};
|
||
#[cfg(unix)]
|
||
use std::os::unix::fs::PermissionsExt; // 用于设置文件权限
|
||
#[cfg(unix)]
|
||
use std::fs::set_permissions;
|
||
|
||
#[cfg(windows)]
|
||
use std::net::{TcpListener, TcpStream};
|
||
|
||
// 添加 ctrlc 依赖,用于捕获退出信号(仅在 Unix 下需要)
|
||
#[cfg(unix)]
|
||
use ctrlc;
|
||
|
||
fn handle_client(mut stream: impl Read + Write, model: Arc<Mutex<OnnxModel>>, tokenizer: Arc<Mutex<HFTokenizer>>, dict: Arc<Dictionary>) -> Result<(), Box<dyn Error>> {
|
||
loop {
|
||
// 读取4字节长度头
|
||
let mut len_buf = [0u8; 4];
|
||
if let Err(e) = stream.read_exact(&mut len_buf) {
|
||
// 如果读取失败,可能是连接关闭,正常退出循环
|
||
if e.kind() == std::io::ErrorKind::UnexpectedEof {
|
||
println!("客户端关闭连接");
|
||
break;
|
||
}
|
||
return Err(Box::new(e)); // 其他错误上报
|
||
}
|
||
let payload_len = u32::from_be_bytes(len_buf) as usize;
|
||
|
||
// 读取MessagePack载荷
|
||
let mut payload = vec![0u8; payload_len];
|
||
stream.read_exact(&mut payload)?;
|
||
|
||
// 反序列化请求
|
||
let request: Request = rmp_serde::from_slice(&payload)?;
|
||
println!("接收请求: {:?}", request);
|
||
|
||
// 准备模型输入
|
||
let model_input = {
|
||
let tokenizer = tokenizer.lock().unwrap();
|
||
tokenizer.gen_predict_sample(&request.context, &request.pinyin)?
|
||
};
|
||
|
||
// 进行预测
|
||
let predicted_pairs = {
|
||
let mut model = model.lock().unwrap();
|
||
model.predict_to_sorted_pairs_simple(model_input, true)?
|
||
};
|
||
|
||
// 筛选候选词
|
||
let filtered = filter_candidates(&request.pinyin, predicted_pairs, &dict);
|
||
|
||
// 转换为Response
|
||
let candidates: Vec<Candidate> = filtered
|
||
.into_iter()
|
||
.map(|fc| Candidate {
|
||
id: fc.id as u32,
|
||
text: fc.text,
|
||
weight: fc.weight,
|
||
})
|
||
.collect();
|
||
let response = Response {
|
||
candidates,
|
||
offset: 0,
|
||
limit: 10,
|
||
};
|
||
|
||
// 序列化响应
|
||
let response_payload = rmp_serde::to_vec(&response)?;
|
||
let response_len = response_payload.len() as u32;
|
||
let response_len_bytes = response_len.to_be_bytes();
|
||
|
||
// 发送响应
|
||
stream.write_all(&response_len_bytes)?;
|
||
stream.write_all(&response_payload)?;
|
||
// 确保数据发送完成(flush 在 Write 实现中通常自动处理,但可显式调用)
|
||
stream.flush()?;
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
|
||
fn main() -> Result<(), Box<dyn Error>> {
|
||
// 加载配置
|
||
let config = Config::load()?;
|
||
println!("配置加载成功: {:?}", config);
|
||
|
||
// 初始化组件
|
||
let model = OnnxModel::new(&config.model_path, 1)?;
|
||
let tokenizer = HFTokenizer::new(&config.tokenizer_path)?;
|
||
let dict = Dictionary::from_json_file(&config.vocabs_path)?;
|
||
|
||
let model = Arc::new(Mutex::new(model));
|
||
let tokenizer = Arc::new(Mutex::new(tokenizer));
|
||
let dict = Arc::new(dict);
|
||
|
||
// 创建监听器
|
||
match config.socket {
|
||
SocketConfig::UnixDomainSocket(path) => {
|
||
#[cfg(unix)]
|
||
{
|
||
// 尝试删除可能残留的socket文件
|
||
if let Err(e) = std::fs::remove_file(&path) {
|
||
// 如果文件不存在,忽略错误;否则打印警告
|
||
if e.kind() != std::io::ErrorKind::NotFound {
|
||
eprintln!("警告:无法删除socket文件 {}: {}", path, e);
|
||
}
|
||
}
|
||
let listener = UnixListener::bind(&path)?;
|
||
println!("监听Unix Domain Socket在: {}", path);
|
||
|
||
// 设置socket文件权限为600(仅所有者可读写),减少被其他进程误操作的风险
|
||
let mut perms = std::fs::metadata(&path)?.permissions();
|
||
perms.set_mode(0o600); // 只允许所有者读写
|
||
set_permissions(&path, perms)?;
|
||
|
||
// 注册信号处理,确保程序退出时清理socket文件
|
||
let socket_path = path.clone();
|
||
ctrlc::set_handler(move || {
|
||
println!("收到退出信号,清理socket文件: {}", socket_path);
|
||
let _ = std::fs::remove_file(&socket_path);
|
||
std::process::exit(0);
|
||
})?;
|
||
|
||
for stream in listener.incoming() {
|
||
match stream {
|
||
Ok(stream) => {
|
||
println!("受理连接");
|
||
let model = Arc::clone(&model);
|
||
let tokenizer = Arc::clone(&tokenizer);
|
||
let dict = Arc::clone(&dict);
|
||
thread::spawn(move || {
|
||
if let Err(e) = handle_client(stream, model, tokenizer, dict) {
|
||
eprintln!("处理连接错误: {}", e);
|
||
}
|
||
});
|
||
}
|
||
Err(e) => eprintln!("接受连接错误: {}", e),
|
||
}
|
||
}
|
||
}
|
||
#[cfg(not(unix))]
|
||
{
|
||
return Err("Unix Domain Socket仅在Unix系统上支持".into());
|
||
}
|
||
}
|
||
SocketConfig::TcpSocket(addr) => {
|
||
#[cfg(windows)]
|
||
{
|
||
let listener = TcpListener::bind(&addr)?;
|
||
println!("监听TCP Socket在: {}", addr);
|
||
for stream in listener.incoming() {
|
||
match stream {
|
||
Ok(stream) => {
|
||
let model = Arc::clone(&model);
|
||
let tokenizer = Arc::clone(&tokenizer);
|
||
let dict = Arc::clone(&dict);
|
||
thread::spawn(move || {
|
||
if let Err(e) = handle_client(stream, model, tokenizer, dict) {
|
||
eprintln!("处理连接错误: {}", e);
|
||
}
|
||
});
|
||
}
|
||
Err(e) => eprintln!("接受连接错误: {}", e),
|
||
}
|
||
}
|
||
}
|
||
#[cfg(not(windows))]
|
||
{
|
||
return Err("TCP Socket仅在Windows系统上支持".into());
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|