mod config; mod filter; mod model; mod protocol; mod tokenizers; mod vocabs; use crate::config::{Config, SocketConfig}; use crate::model::OnnxModel; use crate::tokenizers::HFTokenizer; use crate::vocabs::Dictionary; use crate::filter::filter_candidates; use crate::protocol::{Request, Response, Candidate}; use std::error::Error; use std::io::{Read, Write}; use std::sync::{Arc, Mutex}; use std::thread; #[cfg(unix)] use std::os::unix::net::{UnixListener, UnixStream}; #[cfg(unix)] use std::os::unix::fs::PermissionsExt; // 用于设置文件权限 #[cfg(unix)] use std::fs::set_permissions; #[cfg(windows)] use std::net::{TcpListener, TcpStream}; // 添加 ctrlc 依赖,用于捕获退出信号(仅在 Unix 下需要) #[cfg(unix)] use ctrlc; fn handle_client(mut stream: impl Read + Write, model: Arc>, tokenizer: Arc>, dict: Arc) -> Result<(), Box> { loop { // 读取4字节长度头 let mut len_buf = [0u8; 4]; if let Err(e) = stream.read_exact(&mut len_buf) { // 如果读取失败,可能是连接关闭,正常退出循环 if e.kind() == std::io::ErrorKind::UnexpectedEof { println!("客户端关闭连接"); break; } return Err(Box::new(e)); // 其他错误上报 } let payload_len = u32::from_be_bytes(len_buf) as usize; // 读取MessagePack载荷 let mut payload = vec![0u8; payload_len]; stream.read_exact(&mut payload)?; // 反序列化请求 let request: Request = rmp_serde::from_slice(&payload)?; println!("接收请求: {:?}", request); // 准备模型输入 let model_input = { let tokenizer = tokenizer.lock().unwrap(); tokenizer.gen_predict_sample(&request.context, &request.pinyin)? }; // 进行预测 let predicted_pairs = { let mut model = model.lock().unwrap(); model.predict_to_sorted_pairs_simple(model_input, true)? }; // 筛选候选词 let filtered = filter_candidates(&request.pinyin, predicted_pairs, &dict); // 转换为Response let candidates: Vec = filtered .into_iter() .map(|fc| Candidate { id: fc.id as u32, text: fc.text, weight: fc.weight, }) .collect(); let response = Response { candidates, offset: 0, limit: 10, }; // 序列化响应 let response_payload = rmp_serde::to_vec(&response)?; let response_len = response_payload.len() as u32; let response_len_bytes = response_len.to_be_bytes(); // 发送响应 stream.write_all(&response_len_bytes)?; stream.write_all(&response_payload)?; // 确保数据发送完成(flush 在 Write 实现中通常自动处理,但可显式调用) stream.flush()?; } Ok(()) } fn main() -> Result<(), Box> { // 加载配置 let config = Config::load()?; println!("配置加载成功: {:?}", config); // 初始化组件 let model = OnnxModel::new(&config.model_path, 1)?; let tokenizer = HFTokenizer::new(&config.tokenizer_path)?; let dict = Dictionary::from_json_file(&config.vocabs_path)?; let model = Arc::new(Mutex::new(model)); let tokenizer = Arc::new(Mutex::new(tokenizer)); let dict = Arc::new(dict); // 创建监听器 match config.socket { SocketConfig::UnixDomainSocket(path) => { #[cfg(unix)] { // 尝试删除可能残留的socket文件 if let Err(e) = std::fs::remove_file(&path) { // 如果文件不存在,忽略错误;否则打印警告 if e.kind() != std::io::ErrorKind::NotFound { eprintln!("警告:无法删除socket文件 {}: {}", path, e); } } let listener = UnixListener::bind(&path)?; println!("监听Unix Domain Socket在: {}", path); // 设置socket文件权限为600(仅所有者可读写),减少被其他进程误操作的风险 let mut perms = std::fs::metadata(&path)?.permissions(); perms.set_mode(0o600); // 只允许所有者读写 set_permissions(&path, perms)?; // 注册信号处理,确保程序退出时清理socket文件 let socket_path = path.clone(); ctrlc::set_handler(move || { println!("收到退出信号,清理socket文件: {}", socket_path); let _ = std::fs::remove_file(&socket_path); std::process::exit(0); })?; for stream in listener.incoming() { match stream { Ok(stream) => { println!("受理连接"); let model = Arc::clone(&model); let tokenizer = Arc::clone(&tokenizer); let dict = Arc::clone(&dict); thread::spawn(move || { if let Err(e) = handle_client(stream, model, tokenizer, dict) { eprintln!("处理连接错误: {}", e); } }); } Err(e) => eprintln!("接受连接错误: {}", e), } } } #[cfg(not(unix))] { return Err("Unix Domain Socket仅在Unix系统上支持".into()); } } SocketConfig::TcpSocket(addr) => { #[cfg(windows)] { let listener = TcpListener::bind(&addr)?; println!("监听TCP Socket在: {}", addr); for stream in listener.incoming() { match stream { Ok(stream) => { let model = Arc::clone(&model); let tokenizer = Arc::clone(&tokenizer); let dict = Arc::clone(&dict); thread::spawn(move || { if let Err(e) = handle_client(stream, model, tokenizer, dict) { eprintln!("处理连接错误: {}", e); } }); } Err(e) => eprintln!("接受连接错误: {}", e), } } } #[cfg(not(windows))] { return Err("TCP Socket仅在Windows系统上支持".into()); } } } Ok(()) }