feat(fcitx5-ext): 添加与SUIME服务器通信的Socket客户端及预测请求功能

This commit is contained in:
songsenand 2026-03-09 08:41:04 +08:00
parent 142c52823a
commit acbf38b4e8
16 changed files with 959 additions and 58 deletions

View File

@ -5,10 +5,12 @@ edition = "2024"
[dependencies]
anyhow = "1.0.102"
ctrlc = "3.5.2"
directories = "6.0.0"
lazy_static = "1.5.0"
ndarray = "0.17.2"
ort = "2.0.0-rc.11"
rmp-serde = "1.3.1"
serde = "1.0.228"
serde_json = "1.0.149"
tempfile = "3.26.0"

View File

@ -78,7 +78,9 @@ SUIME/
### 文件说明
- **filter.rs** 将OnnxModel::predict预测的结果转化为汉字、拼音、权重排除和用户输入的拼音不相符的汉字比如用户输入的是shanghai预测的结果为……),特将会被排除,为了简化,直接粗暴的将预测汉字的拼音首字母和用户输入的拼音首字母进行对比,相同视为相符,不相同视为不符。
- **filter.rs**：将 `OnnxModel::predict` 预测的结果转化为汉字、对应汉字的完整拼音、权重和剩余拼音，并排除与用户输入的拼音不相符的汉字。
- 示例：用户输入的是 shanghai，若预测结果中某个汉字的拼音与输入不相符，该汉字将会被排除。
- 返回的结果里面还应该包含拼音被消耗后剩余的部分。比如 shanghai 转化后的结果可能为（shang，0.9，hai），也可能是（sha，0.001，nghai）；再比如输入为 shhai，转化的结果为（shang，0.6，hai）。
---

35
fcitx5-ext/CMakeLists.txt Normal file
View File

@ -0,0 +1,35 @@
# Build script for the SUIME fcitx5 addon: builds the `suime` shared library
# and installs it together with its addon and input-method descriptors.
cmake_minimum_required(VERSION 3.10)
project(suime-fcitx5 LANGUAGES CXX)
# The sources are compiled as C++20, and the compiler must support it.
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# fcitx5 and msgpack-c are located through pkg-config.
find_package(PkgConfig REQUIRED)
pkg_check_modules(FCITX5 REQUIRED Fcitx5Core Fcitx5Utils)
pkg_check_modules(MSGPACK REQUIRED msgpack-c)
# The addon itself: the engine plus the socket client it uses.
add_library(suime SHARED src/suime.cpp src/socket_client.cpp)
target_include_directories(suime PRIVATE ${FCITX5_INCLUDE_DIRS} ${MSGPACK_INCLUDE_DIRS} ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_link_libraries(suime ${FCITX5_LIBRARIES} ${MSGPACK_LIBRARIES})
install(TARGETS suime DESTINATION lib/fcitx5)
# Addon descriptor.
# Generated from the .conf.in template into a .conf file (allows variable
# substitution if it is ever needed).
configure_file(suime-addon.conf.in suime.conf)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/suime.conf
DESTINATION share/fcitx5/addon/)
# Input-method descriptor.
# Generated under a different name in the build tree so it does not collide
# with the addon descriptor above, then renamed to suime.conf on install into
# the inputmethod directory.
configure_file(suime-im.conf.in suime-im.conf)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/suime-im.conf
RENAME suime.conf
DESTINATION share/fcitx5/inputmethod)

View File

@ -0,0 +1,28 @@
#ifndef PROTOCOL_HPP
#define PROTOCOL_HPP
#include <string>
#include <vector>
#include <msgpack.hpp>
// Prediction request sent to the server.
// MSGPACK_DEFINE lists the fields that are (de)serialized, in this order.
struct Request {
std::string pinyin;
std::string context;
MSGPACK_DEFINE(pinyin, context);
};
// One candidate returned by the server.
// MSGPACK_DEFINE lists the fields that are (de)serialized, in this order.
struct Candidate {
    // Numeric members carry default initializers so a default-constructed
    // Candidate never holds indeterminate values.
    uint32_t id = 0;
    std::string text;
    float weight = 0.0f;
    MSGPACK_DEFINE(id, text, weight);
};
// Server reply: a list of candidates plus the offset/limit values sent with it.
// MSGPACK_DEFINE lists the fields that are (de)serialized, in this order.
struct Response {
    std::vector<Candidate> candidates;
    // Default initializers keep `Response res;` (default-initialization)
    // from leaving these members indeterminate.
    size_t offset = 0;
    size_t limit = 0;
    MSGPACK_DEFINE(candidates, offset, limit);
};
#endif // PROTOCOL_HPP

View File

@ -0,0 +1,156 @@
#include "socket_client.hpp"
#include "protocol.hpp"
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#include <cstring>
#include <iostream>
#include <vector>
#include <msgpack.hpp>
#include <arpa/inet.h>
#include <cerrno>
#include <sys/time.h>
SocketClient::SocketClient() : sock_fd_(-1) {}
SocketClient::~SocketClient() {
disconnect();
}
bool SocketClient::connect(const std::string& socket_path) {
sock_fd_ = socket(AF_UNIX, SOCK_STREAM, 0);
if (sock_fd_ < 0) {
std::cerr << "SocketClient: Failed to create socket: " << strerror(errno) << std::endl;
return false;
}
// 设置超时
struct timeval tv;
tv.tv_sec = 5; // 5秒超时
tv.tv_usec = 0;
if (setsockopt(sock_fd_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0) {
std::cerr << "SocketClient: Failed to set receive timeout: " << strerror(errno) << std::endl;
close(sock_fd_);
sock_fd_ = -1;
return false;
}
if (setsockopt(sock_fd_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) < 0) {
std::cerr << "SocketClient: Failed to set send timeout: " << strerror(errno) << std::endl;
close(sock_fd_);
sock_fd_ = -1;
return false;
}
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
strncpy(addr.sun_path, socket_path.c_str(), sizeof(addr.sun_path) - 1);
addr.sun_path[sizeof(addr.sun_path) - 1] = '\0';
if (::connect(sock_fd_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0) {
std::cerr << "SocketClient: Failed to connect to server at " << socket_path << ": " << strerror(errno) << std::endl;
close(sock_fd_);
sock_fd_ = -1;
return false;
}
std::cout << "SocketClient: Connected to server via UDS at " << socket_path << std::endl;
return true;
}
// Closes the connection if there is one.
// Returns true when a descriptor was closed, false when there was nothing to close.
bool SocketClient::disconnect() {
    const bool hasDescriptor = !(sock_fd_ < 0);
    if (!hasDescriptor) {
        return false;
    }

    close(sock_fd_);
    sock_fd_ = -1;
    std::cout << "SocketClient: Disconnected" << std::endl;
    return true;
}
// True while the client holds a descriptor (any non-negative value).
bool SocketClient::isConnected() const {
    const bool noDescriptor = sock_fd_ < 0;
    return !noDescriptor;
}
namespace {

// Largest response payload accepted from the server. The length prefix is
// read off the wire, so it is bounded before being used as an allocation size.
constexpr uint32_t kMaxResponseBytes = 16u * 1024u * 1024u;

// Outcome of a blocking transfer of an exact number of bytes.
enum class IoStatus { Ok, Closed, Failed };

// Sends exactly `size` bytes, retrying on EINTR. MSG_NOSIGNAL turns a closed
// peer into an EPIPE error instead of a SIGPIPE delivered to the process.
IoStatus sendAll(int fd, const char* data, size_t size) {
    size_t done = 0;
    while (done < size) {
        const ssize_t n = ::send(fd, data + done, size - done, MSG_NOSIGNAL);
        if (n < 0 && errno == EINTR) continue;  // interrupted by a signal, retry
        if (n <= 0) return IoStatus::Failed;
        done += static_cast<size_t>(n);
    }
    return IoStatus::Ok;
}

// Receives exactly `size` bytes, retrying on EINTR. Reports Closed when the
// peer shuts the connection down before all bytes have arrived.
IoStatus recvAll(int fd, char* data, size_t size) {
    size_t done = 0;
    while (done < size) {
        const ssize_t n = ::recv(fd, data + done, size - done, 0);
        if (n < 0) {
            if (errno == EINTR) continue;  // interrupted by a signal, retry
            return IoStatus::Failed;
        }
        if (n == 0) return IoStatus::Closed;
        done += static_cast<size_t>(n);
    }
    return IoStatus::Ok;
}

}  // namespace

// Sends `req` and waits for the reply.
//
// Wire format in both directions: a 4-byte big-endian payload length followed
// by the MessagePack payload.
//
// Returns the decoded Response, or an empty Response{} on any failure. When
// the failure is a transport error (send/receive error, timeout, peer closed,
// oversized length prefix) the connection is also closed: the byte stream is
// then at an unknown position, and closing makes isConnected() report false
// so the caller can reconnect. A payload that merely fails to decode leaves
// the connection open because the frame was consumed completely.
Response SocketClient::sendRequest(const Request& req) {
    if (sock_fd_ < 0) {
        std::cerr << "SocketClient: Not connected, cannot send request" << std::endl;
        return Response{};
    }

    // Serialize request to MessagePack.
    msgpack::sbuffer buffer;
    msgpack::pack(buffer, req);

    // Shared transport-failure path: report, drop the connection, return empty.
    const auto fail = [this](const char* what, bool with_errno) {
        const int saved_errno = errno;  // capture before any further calls
        std::cerr << "SocketClient: " << what;
        if (with_errno) {
            std::cerr << ", error: " << strerror(saved_errno);
        }
        std::cerr << std::endl;
        disconnect();
        return Response{};
    };

    // Length header, then payload.
    const uint32_t len_be = htonl(static_cast<uint32_t>(buffer.size()));
    if (sendAll(sock_fd_, reinterpret_cast<const char*>(&len_be), sizeof(len_be)) != IoStatus::Ok) {
        return fail("Failed to send header", true);
    }
    if (sendAll(sock_fd_, buffer.data(), buffer.size()) != IoStatus::Ok) {
        return fail("Failed to send payload", true);
    }

    // Response header.
    uint32_t resp_len_be = 0;
    switch (recvAll(sock_fd_, reinterpret_cast<char*>(&resp_len_be), sizeof(resp_len_be))) {
    case IoStatus::Failed:
        return fail("Failed to read response header", true);
    case IoStatus::Closed:
        return fail("Connection closed by server while reading header", false);
    case IoStatus::Ok:
        break;
    }
    const uint32_t resp_len = ntohl(resp_len_be);
    if (resp_len > kMaxResponseBytes) {
        return fail("Response length exceeds limit", false);
    }

    // Response payload.
    std::vector<char> resp_buffer(resp_len);
    switch (recvAll(sock_fd_, resp_buffer.data(), resp_buffer.size())) {
    case IoStatus::Failed:
        return fail("Failed to read response payload", true);
    case IoStatus::Closed:
        return fail("Connection closed by server while reading payload", false);
    case IoStatus::Ok:
        break;
    }

    // Deserialize response.
    try {
        msgpack::object_handle oh = msgpack::unpack(resp_buffer.data(), resp_buffer.size());
        Response res;
        oh.get().convert(res);
        return res;
    } catch (const std::exception& e) {
        std::cerr << "SocketClient: Failed to deserialize response: " << e.what() << std::endl;
        return Response{};
    }
}

View File

@ -0,0 +1,21 @@
#ifndef SOCKET_CLIENT_HPP
#define SOCKET_CLIENT_HPP
#include <string>
#include "protocol.hpp"
// Client for the SUIME server socket. Owns one socket descriptor, which the
// destructor closes.
class SocketClient {
public:
    SocketClient();
    ~SocketClient();

    // The descriptor has a single owner: a copy would close the same
    // descriptor twice, so copying is disabled.
    SocketClient(const SocketClient&) = delete;
    SocketClient& operator=(const SocketClient&) = delete;

    // Connects to the Unix domain socket at `socket_path`; true on success.
    bool connect(const std::string& socket_path);
    // Closes the connection; true if a connection was open.
    bool disconnect();
    // True while a descriptor is held.
    bool isConnected() const;
    // Sends `req` and returns the server's reply (empty Response on failure).
    Response sendRequest(const Request& req);

private:
    int sock_fd_;  // -1 when not connected
};
#endif // SOCKET_CLIENT_HPP

219
fcitx5-ext/src/suime.cpp Normal file
View File

@ -0,0 +1,219 @@
#include "protocol.hpp"
#include "socket_client.hpp"
#include <fcitx-utils/event.h>
#include <fcitx/addonfactory.h>
#include <fcitx/addonmanager.h>
#include <fcitx/candidatelist.h>
#include <fcitx/inputcontext.h>
#include <fcitx/inputmethodengine.h>
#include <fcitx/inputpanel.h>
#include <fcitx/instance.h>
#include <fcitx/inputmethodentry.h> // 确保包含此头文件
#include <iostream>
#include <string>
#include <vector>
#include <chrono>
#include <memory>
// Custom candidate word; defines what happens when the user selects it.
class SuimeEngine;  // forward declaration, defined below

class SuimeCandidateWord : public fcitx::CandidateWord {
public:
    // `text` is the string shown in the candidate list; `engine` is the
    // engine notified on selection (stored as a non-owning pointer).
    SuimeCandidateWord(const std::string &text, SuimeEngine *engine)
        : CandidateWord(fcitx::Text(text)),
          engine_(engine) {}

    // Implemented after SuimeEngine, once that type is complete.
    void select(fcitx::InputContext *ic) const override;

private:
    SuimeEngine *engine_;
};
class SuimeEngine : public fcitx::InputMethodEngineV2 {
public:
SuimeEngine(fcitx::Instance *instance)
: instance_(instance), socket_path_("/tmp/su-ime.sock"),
currentIc_(nullptr), timeoutDuration_(50) {
if (!socket_client_.connect(socket_path_)) {
std::cerr << "Failed to connect to SUIME server" << std::endl;
} else {
std::cout << "Connected to SUIME server via UDS" << std::endl;
}
}
~SuimeEngine() {
// 无需手动调用 removeEvent
socket_client_.disconnect();
}
void keyEvent(const fcitx::InputMethodEntry &entry, fcitx::KeyEvent &keyEvent) override {
auto *ic = keyEvent.inputContext();
if (!ic || keyEvent.isRelease()) return;
currentIc_ = ic; // 存储当前输入上下文
resetTimer(); // 重置定时器
std::string keyStr = keyEvent.key().toString();
std::cerr << "SuIME keyEvent: key=" << keyStr << ", current_pinyin_=" << current_pinyin_ << std::endl;
// 空格键处理:发送预测请求
if (keyEvent.key().check(fcitx::Key("space"))) {
std::cerr << "Space detected, sending prediction request" << std::endl;
sendPredictionRequest(ic);
keyEvent.filter();
return;
}
// 字母处理:只接受可打印的 ASCII 字母,并转换为小写累积
if (keyStr.length() == 1) {
char c = keyStr[0];
if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) {
// 转换为小写并添加
current_pinyin_ += std::tolower(c);
std::cerr << "Added char, current_pinyin_ now: " << current_pinyin_ << std::endl;
sendPredictionRequest(ic); // 新增:发送预测请求
keyEvent.filter();
return; // 确保返回,避免后续处理
}
}
// 退格处理
else if (keyEvent.key().check(fcitx::Key("BackSpace"))) {
if (!current_pinyin_.empty()) {
current_pinyin_.pop_back();
std::cerr << "Backspace, current_pinyin_ now: " << current_pinyin_ << std::endl;
}
sendPredictionRequest(ic); // 新增:发送预测请求
keyEvent.filter();
return; // 确保返回
}
// 其他键不处理,可能触发其他事件
}
void reset(const fcitx::InputMethodEntry &entry,
fcitx::InputContextEvent &event) override {
current_pinyin_.clear();
if (auto *ic = event.inputContext()) {
ic->inputPanel().reset();
ic->updateUserInterface(fcitx::UserInterfaceComponent::InputPanel);
}
currentIc_ = nullptr;
timeoutEvent_.reset(); // 取消定时器
}
std::vector<fcitx::InputMethodEntry> listInputMethods() override {
std::vector<fcitx::InputMethodEntry> entries;
entries.emplace_back("suime", "SUIME", "suime", "zh_CN");
entries.back().setLabel("SuIME");
return entries;
}
void clearPinyin() {
current_pinyin_.clear();
}
private:
void sendPredictionRequest(fcitx::InputContext *ic) {
std::cout << "Sending pinyin to server: " << current_pinyin_ << std::endl; // 打印到控制台
std::cerr << "sendPredictionRequest called" << std::endl;
// 检查连接状态,如果断开则尝试重新连接
if (!socket_client_.isConnected()) {
std::cerr << "Connection lost, attempting to reconnect..." << std::endl;
if (!socket_client_.connect(socket_path_)) {
std::cerr << "Failed to reconnect to SUIME server" << std::endl;
// 保持当前候选词列表不变,避免候选框消失
return;
} else {
std::cout << "Reconnected to SUIME server via UDS" << std::endl;
}
}
if (current_pinyin_.empty()) {
std::cerr << " current_pinyin_ empty, abort" << std::endl;
return;
}
std::cerr << " preparing request for pinyin=" << current_pinyin_ << std::endl;
Request req;
req.pinyin = current_pinyin_;
req.context = getContext(ic);
std::cerr << " calling socket_client_.sendRequest..." << std::endl;
Response res = socket_client_.sendRequest(req);
std::cerr << " sendRequest returned, candidates count=" << res.candidates.size() << std::endl;
// 即使候选词为空,也更新候选词列表以确保候选框状态一致
updateCandidates(ic, res.candidates);
}
void updateCandidates(fcitx::InputContext *ic,
const std::vector<Candidate> &candidates) {
auto candidateList = std::make_unique<fcitx::CommonCandidateList>();
candidateList->setPageSize(5);
candidateList->setSelectionKey(
fcitx::Key::keyListFromString("1234567890"));
candidateList->setLayoutHint(fcitx::CandidateLayoutHint::Vertical);
for (const auto &cand : candidates) {
candidateList->append(
std::make_unique<SuimeCandidateWord>(cand.text, this));
}
ic->inputPanel().setCandidateList(std::move(candidateList));
ic->updateUserInterface(fcitx::UserInterfaceComponent::InputPanel);
}
std::string getContext(fcitx::InputContext *ic) {
// 获取光标前文本作为上下文
auto text = ic->surroundingText();
std::string context = text.text().substr(0, text.cursor());
return context;
}
void resetTimer() {
timeoutEvent_.reset(); // 取消旧定时器
if (currentIc_) {
uint64_t usec = std::chrono::duration_cast<std::chrono::microseconds>(timeoutDuration_).count();
timeoutEvent_ = instance_->eventLoop().addTimeEvent(
CLOCK_MONOTONIC,
usec,
0, // accuracy
[this](fcitx::EventSourceTime*, uint64_t) {
this->handleTimeout();
return false; // 单次触
}
);
}
}
// handleTimeout 保持无参即可
void handleTimeout() {
std::cout << "50ms timeout triggered, sending accumulated pinyin: "
<< current_pinyin_ << std::endl;
if (currentIc_ && !current_pinyin_.empty()) {
sendPredictionRequest(currentIc_);
}
}
fcitx::Instance *instance_;
SocketClient socket_client_;
std::string current_pinyin_;
std::string socket_path_;
fcitx::InputContext *currentIc_;
std::chrono::milliseconds timeoutDuration_;
std::unique_ptr<fcitx::EventSource> timeoutEvent_;
};
// Commits the candidate text, clears the engine's buffered pinyin and removes
// the candidate list from the panel so stale candidates are not left on
// screen after the commit.
void SuimeCandidateWord::select(fcitx::InputContext *ic) const {
    // Copy what is needed first: resetting the input panel below releases the
    // candidate list that this word belongs to, so no member may be touched
    // after that point.
    SuimeEngine *const engine = engine_;
    const std::string committed = text().toString();

    ic->commitString(committed);
    if (engine) {
        engine->clearPinyin();  // the choice consumes the accumulated pinyin
    }
    ic->inputPanel().reset();
    ic->updateUserInterface(fcitx::UserInterfaceComponent::InputPanel);
}
// Addon factory: creates the engine for the addon manager.
class SuimeEngineFactory : public fcitx::AddonFactory {
public:
    // Builds a SuimeEngine bound to the manager's fcitx instance and hands
    // the raw pointer back to the caller.
    fcitx::AddonInstance *create(fcitx::AddonManager *manager) override {
        fcitx::Instance *const owner = manager->instance();
        return new SuimeEngine(owner);
    }
};

FCITX_ADDON_FACTORY(SuimeEngineFactory)

View File

@ -0,0 +1,12 @@
# Addon descriptor for the SUIME engine.
[Addon]
Name=SuIME
Name[zh_CN]=SUIME
Category=InputMethod
Version=0.1.0
# Shared library that implements the addon.
Library=libsuime
Type=SharedLibrary
# NOTE(review): presumably defers loading until the input method is used — confirm.
OnDemand=True
Configurable=False
# Addons this one depends on.
[Addon/Dependencies]
0=punctuation

View File

@ -0,0 +1,8 @@
# Input-method descriptor for SuIME.
[InputMethod]
Name=SuIME
Name[zh_CN]=SuIME
Icon=fcitx-keyboard-cn
Label=Su
LangCode=zh_CN
# Name of the addon that provides this input method.
Addon=suime
Configurable=False

View File

@ -16,6 +16,7 @@ pub struct Config {
/// 额外配置,为未来扩展预留
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra: Option<toml::Value>,
pub vocabs_path: String
}
#[derive(Debug, Serialize, Deserialize)]
@ -60,6 +61,9 @@ impl Default for Config {
let tokenizer_path = base_dir.join("assets/tokenizer.json")
.to_string_lossy()
.to_string();
let vocabs_path = base_dir.join("assets/vocabs.json")
.to_string_lossy()
.to_string();
#[cfg(unix)]
let socket = SocketConfig::UnixDomainSocket("/tmp/su-ime.sock".to_string());
@ -71,6 +75,7 @@ impl Default for Config {
model_path,
tokenizer_path,
socket,
vocabs_path,
extra: None, // 默认无额外配置
}
}

81
src/filter.rs Normal file
View File

@ -0,0 +1,81 @@
use crate::vocabs::Dictionary;
/// A candidate that passed filtering, together with the pinyin it matched and
/// the pinyin that is left over.
#[derive(Debug, Clone)]
pub struct FilteredCandidate {
// Identifier taken from the predicted (id, weight) pair.
pub id: usize,
// Text looked up in the dictionary for `id`.
pub text: String,
// Weight taken from the predicted (id, weight) pair.
pub weight: f32,
// Full pinyin of the candidate as stored in the dictionary.
pub matched_pinyin: String,
// Part of the user input that follows the prefix shared with the candidate's pinyin.
pub remaining_pinyin: String,
}
/// Filters predicted candidates against the pinyin typed by the user.
///
/// A prediction is kept when the dictionary knows its pinyin and text and the
/// pinyin shares a non-empty prefix with `pinyin`. Predictions whose weight is
/// NaN are dropped. The result is sorted by weight, highest first, and capped
/// at 100 entries.
///
/// # Arguments
/// * `pinyin` - Pinyin string typed by the user, e.g. "shanghai".
/// * `predicted` - `(id, weight)` pairs produced by the model.
/// * `dict` - Dictionary used to look up the pinyin and text of each id.
///
/// # Returns
/// The matching candidates with their matched and remaining pinyin.
pub fn filter_candidates(pinyin: &str, predicted: Vec<(usize, f32)>, dict: &Dictionary) -> Vec<FilteredCandidate> {
    const MAX_CANDIDATES: usize = 100;

    let mut candidates: Vec<FilteredCandidate> = predicted
        .into_iter()
        // A NaN weight cannot be ranked, so such predictions are skipped.
        .filter(|(_, weight)| !weight.is_nan())
        .filter_map(|(id, weight)| {
            let char_pinyin = dict.get_pinyin_by_id(id)?;
            let (matched_pinyin, remaining_pinyin) = consume_pinyin(pinyin, char_pinyin)?;
            let text = dict.get_char_by_id(id)?;
            Some(FilteredCandidate {
                id,
                text: text.to_string(),
                weight,
                matched_pinyin,
                remaining_pinyin,
            })
        })
        .collect();

    // Descending by weight. `total_cmp` is a total order, which `sort_by`
    // requires; a comparator that maps incomparable values to `Equal` is not.
    candidates.sort_by(|a, b| b.weight.total_cmp(&a.weight));

    // Cap in place instead of cloning the leading slice into a new vector.
    candidates.truncate(MAX_CANDIDATES);
    candidates
}
/// Checks whether `input` starts with some non-empty prefix of `char_pinyin`.
///
/// On a match returns the candidate's full pinyin together with the part of
/// `input` that follows the shared prefix; returns `None` when the two strings
/// share no leading character.
fn consume_pinyin(input: &str, char_pinyin: &str) -> Option<(String, String)> {
    let matched_chars = common_prefix_length(input, char_pinyin);
    if matched_chars == 0 {
        return None;
    }

    // `common_prefix_length` counts characters, while slicing needs a byte
    // offset. Translate the count so multi-byte input is never cut inside a
    // character (which would panic).
    let byte_offset = input
        .char_indices()
        .nth(matched_chars)
        .map_or(input.len(), |(offset, _)| offset);

    Some((char_pinyin.to_string(), input[byte_offset..].to_string()))
}
/// Returns how many leading characters `s1` and `s2` have in common.
///
/// The comparison walks `char`s rather than bytes, so the result is a
/// character count even for multi-byte UTF-8 text.
pub fn common_prefix_length(s1: &str, s2: &str) -> usize {
    s1.chars()
        .zip(s2.chars())
        // Stop at the first differing pair; `zip` already stops when either
        // string runs out.
        .take_while(|(left, right)| left == right)
        .count()
}

View File

@ -1,59 +1,193 @@
use std::path::PathBuf;
use std::time::Instant;
mod config;
mod filter;
mod model;
mod protocol;
mod tokenizers;
mod vocabs;
mod config;
use tokenizers::HFTokenizer;
// use crate::vocabs::Dictionary;
use crate::config::{Config, SocketConfig};
use crate::model::OnnxModel;
use crate::config::Config;
use crate::tokenizers::HFTokenizer;
use crate::vocabs::Dictionary;
use crate::filter::filter_candidates;
use crate::protocol::{Request, Response, Candidate};
use std::error::Error;
use std::io::{Read, Write};
use std::sync::{Arc, Mutex};
use std::thread;
fn main() {
let config = Config::load().unwrap() ;
let mut session = OnnxModel::new(config.model_path, 4).unwrap();
let tokenizer = HFTokenizer::new(config.tokenizer_path).unwrap();
let start = Instant::now(); // 开始计时
let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
let _ = session.predict(sample).unwrap();
let duration = start.elapsed(); // 结束计时
print!("predict Time elapsed: {:?}\n", duration);
#[cfg(unix)]
use std::os::unix::net::{UnixListener, UnixStream};
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt; // 用于设置文件权限
#[cfg(unix)]
use std::fs::set_permissions;
let start = Instant::now(); // 开始计时
let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
let logits = session
.predict_to_sorted_pairs_simple(sample, false)
.unwrap();
let duration = start.elapsed(); // 结束计时
print!("predict Time elapsed: {:?}\n", duration);
#[cfg(windows)]
use std::net::{TcpListener, TcpStream};
let start = Instant::now(); // 开始计时
let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap();
let probs = session
.predict_to_sorted_pairs_simple(sample, true)
.unwrap();
let duration = start.elapsed(); // 结束计时
print!("predict Time elapsed: {:?}\n", duration);
print!("logits: {:?}", &logits[0..10].to_vec());
print!("probs: {:?}", &probs[0..10].to_vec());
// 添加 ctrlc 依赖,用于捕获退出信号(仅在 Unix 下需要)
#[cfg(unix)]
use ctrlc;
/*
if let Ok(tokenizer) = HFTokenizer::new(tokenizer_json_path) {
println!("Tokenizer loaded successfully");
if let Ok(sample) = tokenizer.gen_predict_sample("从北京到上", "hai") {
let logits: ndarray::ArrayBase<
ndarray::OwnedRepr<f32>,
ndarray::Dim<ndarray::IxDynImpl>,
f32,
> = session.predict(sample).unwrap();
let duration = start.elapsed(); // 结束计时
println!("Time elapsed: {:?}", duration);
println!("Model input generated successfully");
println!("Logits: {:?}", logits);
fn handle_client(mut stream: impl Read + Write, model: Arc<Mutex<OnnxModel>>, tokenizer: Arc<Mutex<HFTokenizer>>, dict: Arc<Dictionary>) -> Result<(), Box<dyn Error>> {
loop {
// 读取4字节长度头
let mut len_buf = [0u8; 4];
if let Err(e) = stream.read_exact(&mut len_buf) {
// 如果读取失败,可能是连接关闭,正常退出循环
if e.kind() == std::io::ErrorKind::UnexpectedEof {
println!("客户端关闭连接");
break;
}
return Err(Box::new(e)); // 其他错误上报
}
}*/
let payload_len = u32::from_be_bytes(len_buf) as usize;
// 读取MessagePack载荷
let mut payload = vec![0u8; payload_len];
stream.read_exact(&mut payload)?;
// 反序列化请求
let request: Request = rmp_serde::from_slice(&payload)?;
println!("接收请求: {:?}", request);
// 准备模型输入
let model_input = {
let tokenizer = tokenizer.lock().unwrap();
tokenizer.gen_predict_sample(&request.context, &request.pinyin)?
};
// 进行预测
let predicted_pairs = {
let mut model = model.lock().unwrap();
model.predict_to_sorted_pairs_simple(model_input, true)?
};
// 筛选候选词
let filtered = filter_candidates(&request.pinyin, predicted_pairs, &dict);
// 转换为Response
let candidates: Vec<Candidate> = filtered
.into_iter()
.map(|fc| Candidate {
id: fc.id as u32,
text: fc.text,
weight: fc.weight,
})
.collect();
let response = Response {
candidates,
offset: 0,
limit: 10,
};
// 序列化响应
let response_payload = rmp_serde::to_vec(&response)?;
let response_len = response_payload.len() as u32;
let response_len_bytes = response_len.to_be_bytes();
// 发送响应
stream.write_all(&response_len_bytes)?;
stream.write_all(&response_payload)?;
// 确保数据发送完成flush 在 Write 实现中通常自动处理,但可显式调用)
stream.flush()?;
}
Ok(())
}
/// Server entry point: loads the configuration, builds the model, tokenizer
/// and dictionary, then accepts connections on the configured socket and
/// serves each one on its own thread via `handle_client`.
fn main() -> Result<(), Box<dyn Error>> {
// Load the configuration.
let config = Config::load()?;
println!("配置加载成功: {:?}", config);
// Initialise the components.
let model = OnnxModel::new(&config.model_path, 1)?;
let tokenizer = HFTokenizer::new(&config.tokenizer_path)?;
let dict = Dictionary::from_json_file(&config.vocabs_path)?;
// Wrapped so they can be handed to the per-connection threads; the model and
// tokenizer sit behind a mutex, the dictionary is shared without one.
let model = Arc::new(Mutex::new(model));
let tokenizer = Arc::new(Mutex::new(tokenizer));
let dict = Arc::new(dict);
// Create the listener.
match config.socket {
SocketConfig::UnixDomainSocket(path) => {
#[cfg(unix)]
{
// Try to remove a socket file that may be left over from a previous run.
if let Err(e) = std::fs::remove_file(&path) {
// A missing file is fine; anything else only produces a warning.
if e.kind() != std::io::ErrorKind::NotFound {
eprintln!("警告无法删除socket文件 {}: {}", path, e);
}
}
let listener = UnixListener::bind(&path)?;
println!("监听Unix Domain Socket在: {}", path);
// Restrict the socket file to mode 600 (owner read/write only) to reduce
// the risk of other processes touching it.
let mut perms = std::fs::metadata(&path)?.permissions();
perms.set_mode(0o600); // owner read/write only
set_permissions(&path, perms)?;
// Register a signal handler so the socket file is removed when the program exits.
let socket_path = path.clone();
ctrlc::set_handler(move || {
println!("收到退出信号清理socket文件: {}", socket_path);
let _ = std::fs::remove_file(&socket_path);
std::process::exit(0);
})?;
// Accept loop: each connection gets its own thread and its own handles to
// the shared components.
for stream in listener.incoming() {
match stream {
Ok(stream) => {
println!("受理连接");
let model = Arc::clone(&model);
let tokenizer = Arc::clone(&tokenizer);
let dict = Arc::clone(&dict);
thread::spawn(move || {
if let Err(e) = handle_client(stream, model, tokenizer, dict) {
eprintln!("处理连接错误: {}", e);
}
});
}
Err(e) => eprintln!("接受连接错误: {}", e),
}
}
}
// Non-Unix builds reject a Unix domain socket configuration.
#[cfg(not(unix))]
{
return Err("Unix Domain Socket仅在Unix系统上支持".into());
}
}
SocketConfig::TcpSocket(addr) => {
#[cfg(windows)]
{
let listener = TcpListener::bind(&addr)?;
println!("监听TCP Socket在: {}", addr);
// Same accept loop as the Unix branch, over TCP.
for stream in listener.incoming() {
match stream {
Ok(stream) => {
let model = Arc::clone(&model);
let tokenizer = Arc::clone(&tokenizer);
let dict = Arc::clone(&dict);
thread::spawn(move || {
if let Err(e) = handle_client(stream, model, tokenizer, dict) {
eprintln!("处理连接错误: {}", e);
}
});
}
Err(e) => eprintln!("接受连接错误: {}", e),
}
}
}
// Non-Windows builds reject a TCP socket configuration.
#[cfg(not(windows))]
{
return Err("TCP Socket仅在Windows系统上支持".into());
}
}
}
Ok(())
}

21
src/protocol.rs Normal file
View File

@ -0,0 +1,21 @@
use serde::{Serialize, Deserialize};
/// Prediction request received from a client.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
pub pinyin: String, // Pinyin string typed by the user
pub context: String, // Text before the cursor (used for context-aware prediction)
}
/// A single candidate returned to the client.
#[derive(Debug, Serialize, Deserialize)]
pub struct Candidate {
pub id: u32, // Entry ID
pub text: String, // Candidate text (the client could also look it up by ID)
pub weight: f32, // Confidence / weight
}
/// Reply to a `Request`: the candidate list plus offset/limit values.
// NOTE(review): `offset`/`limit` look like paging fields — confirm their meaning with the protocol owner.
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
pub candidates: Vec<Candidate>,
pub offset: usize,
pub limit: usize,
}

85
src/socket.rs Normal file
View File

@ -0,0 +1,85 @@
use std::io::{Read, Write};
use std::os::unix::net::{UnixListener, UnixStream};
use rmp_serde::{from_slice, to_vec};
use crate::protocol::{Request, Response};
/// Binds a Unix domain socket at `socket_path` and serves incoming
/// connections, each on its own thread so one client cannot block the others.
///
/// Returns an error only when binding fails; failures to accept a single
/// connection are logged and the loop continues.
pub fn start_server(socket_path: &str) -> std::io::Result<()> {
    let listener = UnixListener::bind(socket_path)?;
    println!("服务器监听在: {}", socket_path);

    listener.incoming().for_each(|incoming| match incoming {
        Ok(client) => {
            // The join handle is dropped on purpose: the thread runs detached.
            std::thread::spawn(move || handle_client(client));
        }
        Err(err) => eprintln!("接受连接失败: {}", err),
    });

    Ok(())
}
/// Serves one request on `stream`: reads a length-prefixed MessagePack
/// `Request`, processes it and writes back a length-prefixed MessagePack
/// `Response`.
///
/// Framing in both directions is a 4-byte big-endian length followed by the
/// payload. Every failure is logged and ends the exchange; nothing is returned
/// to the caller.
fn handle_client(mut stream: UnixStream) {
    // Upper bound for a request payload. The length prefix comes from the
    // client, so it must not be trusted as an allocation size.
    const MAX_REQUEST_LEN: usize = 1024 * 1024;

    // Read the message length header (4 bytes, big-endian).
    let mut len_buf = [0u8; 4];
    if let Err(e) = stream.read_exact(&mut len_buf) {
        eprintln!("读取长度失败: {}", e);
        return;
    }
    let len = u32::from_be_bytes(len_buf) as usize;
    if len > MAX_REQUEST_LEN {
        eprintln!("请求长度超出上限: {} 字节", len);
        return;
    }

    // Read the serialized MessagePack data.
    let mut data_buf = vec![0u8; len];
    if let Err(e) = stream.read_exact(&mut data_buf) {
        eprintln!("读取数据失败: {}", e);
        return;
    }

    // Deserialize into a Request.
    let request: Request = match from_slice(&data_buf) {
        Ok(req) => req,
        Err(e) => {
            eprintln!("反序列化请求失败: {}", e);
            return;
        }
    };

    // Process the request (this is where the business modules are invoked).
    let response = process_request(request);

    // Serialize the Response to MessagePack.
    let response_data = match to_vec(&response) {
        Ok(data) => data,
        Err(e) => {
            eprintln!("序列化响应失败: {}", e);
            return;
        }
    };

    // Send the response length, then the data.
    let len_bytes = (response_data.len() as u32).to_be_bytes();
    if let Err(e) = stream.write_all(&len_bytes) {
        eprintln!("写入长度失败: {}", e);
        return;
    }
    if let Err(e) = stream.write_all(&response_data) {
        eprintln!("写入数据失败: {}", e);
    }
}
/// Placeholder request handler: ignores the request and answers with an empty
/// candidate list (offset 0, limit 0). The real implementation is expected to
/// call into the filtering modules.
fn process_request(_request: Request) -> Response {
    let candidates = Vec::new();
    Response {
        candidates,
        offset: 0,
        limit: 0,
    }
}

View File

@ -8,7 +8,7 @@ use std::path::Path;
/// 单个字符-拼音对的信息
#[derive(Debug, Deserialize, Clone)]
pub struct CharInfo {
pub id: u32,
pub id: usize,
#[serde(rename = "char")]
pub character: String,
pub pinyin: String,
@ -24,7 +24,7 @@ struct RawStatistics {
/// 字典查询引擎,提供 O(1) 的 ID 到信息的映射
pub struct Dictionary {
id_to_charinfo: HashMap<u32, CharInfo>,
id_to_charinfo: HashMap<usize, CharInfo>,
}
impl Dictionary {
@ -38,7 +38,7 @@ impl Dictionary {
let mut id_to_charinfo = HashMap::with_capacity(raw.pairs.len());
for (id_str, info) in raw.pairs {
let id = id_str
.parse::<u32>()
.parse::<usize>()
.with_context(|| format!("无效的 ID 字符串: {}", id_str))?;
// 可选:验证 id 与 info.id 一致,此处忽略不一致的情况(信任输入数据)
id_to_charinfo.insert(info.id, info);
@ -48,22 +48,22 @@ impl Dictionary {
}
/// 通过 ID 获取汉字(用于填充 Candidate.text
pub fn get_char_by_id(&self, id: u32) -> Option<&str> {
pub fn get_char_by_id(&self, id: usize) -> Option<&str> {
self.id_to_charinfo.get(&id).map(|info| info.character.as_str())
}
/// 通过 ID 获取拼音
pub fn get_pinyin_by_id(&self, id: u32) -> Option<&str> {
pub fn get_pinyin_by_id(&self, id: usize) -> Option<&str> {
self.id_to_charinfo.get(&id).map(|info| info.pinyin.as_str())
}
/// 通过 ID 获取出现次数
pub fn get_count_by_id(&self, id: u32) -> Option<u64> {
pub fn get_count_by_id(&self, id: usize) -> Option<u64> {
self.id_to_charinfo.get(&id).map(|info| info.count)
}
/// 获取完整的 CharInfo 引用
pub fn get_char_info(&self, id: u32) -> Option<&CharInfo> {
pub fn get_char_info(&self, id: usize) -> Option<&CharInfo> {
self.id_to_charinfo.get(&id)
}

92
test.py Normal file
View File

@ -0,0 +1,92 @@
#!/usr/bin/env python3
import socket
import struct
import msgpack
import sys
import os
import time
# Test data: five requests sent to the server in order. Each entry carries the
# typed pinyin and the text before the cursor (empty for the first four).
test_requests = [
{"pinyin": "shanghai", "context": ""},
{"pinyin": "beijing", "context": ""},
{"pinyin": "nihao", "context": ""},
{"pinyin": "xiexie", "context": ""},
{"pinyin": "f", "context": "返回汉字的完整拼音和剩余部"}
]
def get_socket_address():
    """Return the ``(address, kind)`` pair for the current platform.

    Windows uses a TCP endpoint ``("127.0.0.1", 23333)`` with kind ``"tcp"``;
    every other platform uses the Unix domain socket path
    ``"/tmp/su-ime.sock"`` with kind ``"unix"``.
    """
    on_windows = sys.platform == "win32"
    if not on_windows:
        # Unix-like systems: Unix domain socket.
        return "/tmp/su-ime.sock", "unix"
    # Windows: TCP socket.
    return ("127.0.0.1", 23333), "tcp"
def connect_to_server(address, socket_type):
    """Open a stream socket to the server and return it connected.

    Args:
        address: Socket path when ``socket_type`` is ``"unix"``, otherwise a
            ``(host, port)`` pair.
        socket_type: ``"unix"`` for a Unix domain socket; any other value
            selects TCP.

    Returns:
        The connected socket object.

    Raises:
        OSError: If the connection cannot be established. The socket is closed
            before the error propagates, so a failed attempt leaks nothing.
    """
    use_unix = socket_type == "unix"
    # AF_UNIX is only looked up when it is actually requested.
    family = socket.AF_UNIX if use_unix else socket.AF_INET
    sock = socket.socket(family, socket.SOCK_STREAM)
    try:
        if use_unix:
            sock.connect(address)
        else:  # tcp
            host, port = address
            sock.connect((host, port))
    except BaseException:
        sock.close()
        raise
    return sock
def _recv_exact(sock, size, what):
    """Read exactly ``size`` bytes from ``sock``.

    ``recv`` may return fewer bytes than requested, so it is called until the
    whole amount has arrived. Raises ``ValueError`` naming ``what`` if the peer
    closes the connection first.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise ValueError(f"Failed to read response {what}")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def send_request(sock, request):
    """Send one request over ``sock`` and return the decoded response.

    Both directions use the same framing: a 4-byte big-endian unsigned length
    followed by a MessagePack payload.

    Args:
        sock: Connected stream socket.
        request: Object to serialize with MessagePack.

    Returns:
        The MessagePack-decoded response.

    Raises:
        ValueError: If the connection ends before the complete header or
            payload has been received.
    """
    # Serialize the request and prefix it with its length.
    payload = msgpack.packb(request)
    header = struct.pack(">I", len(payload))
    sock.sendall(header + payload)

    # Read the response header, then exactly as many payload bytes as it announces.
    response_len = struct.unpack(">I", _recv_exact(sock, 4, "header"))[0]
    response_payload = _recv_exact(sock, response_len, "payload")

    # Deserialize the response.
    return msgpack.unpackb(response_payload, raw=False)
def main():
    """Connect to the server, send every test request and print the results.

    A failure on a single request is reported and the run continues; a failure
    to connect is reported and ends the process with exit status 1.
    """
    address, socket_type = get_socket_address()
    try:
        sock = connect_to_server(address, socket_type)
        print(f"Connected to server via {socket_type} at {address}")

        index = 0
        for req in test_requests:
            index += 1
            print(f"\nSending request {index}: {req}")
            try:
                start = time.time()
                response = send_request(sock, req)
                print(f'cost time: {time.time() - start}')
                # Only the first ten entries of the first response element are shown.
                print(f"Response {index}: {response[0][0:10]}")
            except Exception as e:
                print(f"Error sending request {index}: {e}")

        sock.close()
        print("\nAll requests sent.")
    except Exception as e:
        print(f"Connection error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()