diff --git a/Cargo.toml b/Cargo.toml index a6e00cb..05cf17e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,10 +5,12 @@ edition = "2024" [dependencies] anyhow = "1.0.102" +ctrlc = "3.5.2" directories = "6.0.0" lazy_static = "1.5.0" ndarray = "0.17.2" ort = "2.0.0-rc.11" +rmp-serde = "1.3.1" serde = "1.0.228" serde_json = "1.0.149" tempfile = "3.26.0" diff --git a/README.md b/README.md index c573c42..3de3e1d 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ graph TD end subgraph "Client Frontends" - L[Linux Client
fcitx5-ext] + L[Linux Client
fcitx5-ext] W[Windows Client
Rime/Weasel Mod] Mac[macOS Client
Rime/Squirrel Mod] end @@ -78,7 +78,9 @@ SUIME/ ### 文件说明 -- **filter.rs** 将OnnxModel::predict预测的结果转化为(汉字、拼音、权重),排除和用户输入的拼音不相符的汉字,比如用户输入的是shanghai,预测的结果为(上,商,尚,特,……),特将会被排除,为了简化,直接粗暴的将预测汉字的拼音首字母和用户输入的拼音首字母进行对比,相同视为相符,不相同视为不符。 +- **filter.rs** 将OnnxModel::predict预测的结果转化为(汉字、对应汉字的完整拼音、权重、剩余拼音),排除和用户输入的拼音不相符的汉字。 + - 示例:用户输入的是shanghai,预测的结果为(上,商,尚,特,……),特将会被排除。 + - 返回的结果里面还应该包含剩余被消耗后的拼音,比如shanghai,转化后的结果可能为(上,shang,0.9,hai)也可能是(沙,sha,0.001,nghai),再比如输入为shhai,转化的结果为(上,shang,0.6,hai) --- diff --git a/fcitx5-ext/CMakeLists.txt b/fcitx5-ext/CMakeLists.txt new file mode 100644 index 0000000..d396257 --- /dev/null +++ b/fcitx5-ext/CMakeLists.txt @@ -0,0 +1,35 @@ +cmake_minimum_required(VERSION 3.10) + +project(suime-fcitx5 LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(PkgConfig REQUIRED) + +pkg_check_modules(FCITX5 REQUIRED Fcitx5Core Fcitx5Utils) +pkg_check_modules(MSGPACK REQUIRED msgpack-c) + +add_library(suime SHARED src/suime.cpp src/socket_client.cpp) + +target_include_directories(suime PRIVATE ${FCITX5_INCLUDE_DIRS} ${MSGPACK_INCLUDE_DIRS} ${CMAKE_CURRENT_SOURCE_DIR}/src) + +target_link_libraries(suime ${FCITX5_LIBRARIES} ${MSGPACK_LIBRARIES}) + +install(TARGETS suime DESTINATION lib/fcitx5) + + +# 准备并安装 Addon 配置文件 +# 将 .conf.in 转换为 .conf (如果有变量替换需求),或者直接安装 +configure_file(suime-addon.conf.in suime.conf) +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/suime.conf + DESTINATION share/fcitx5/addon/) + + +# 准备并安装 InputMethod 配置文件 +# 注意:这里安装后的文件名也应该是 suime.conf,但在 inputmethod 目录下 +configure_file(suime-im.conf.in suime-im.conf) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/suime-im.conf + RENAME suime.conf + DESTINATION share/fcitx5/inputmethod) diff --git a/fcitx5-ext/src/protocol.hpp b/fcitx5-ext/src/protocol.hpp new file mode 100644 index 0000000..6a303f3 --- /dev/null +++ b/fcitx5-ext/src/protocol.hpp @@ -0,0 +1,28 @@ +#ifndef PROTOCOL_HPP +#define PROTOCOL_HPP + +#include +#include +#include + +struct Request { + std::string pinyin; + std::string context; + MSGPACK_DEFINE(pinyin, context); +}; + +struct Candidate { + uint32_t id; + std::string text; + float weight; + MSGPACK_DEFINE(id, text, weight); +}; + +struct Response { + std::vector candidates; + size_t offset; + size_t limit; + MSGPACK_DEFINE(candidates, offset, limit); +}; + +#endif // PROTOCOL_HPP \ No newline at end of file diff --git a/fcitx5-ext/src/socket_client.cpp b/fcitx5-ext/src/socket_client.cpp new file mode 100644 index 0000000..c3334b5 --- /dev/null +++ b/fcitx5-ext/src/socket_client.cpp @@ -0,0 +1,156 @@ +#include "socket_client.hpp" +#include "protocol.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +SocketClient::SocketClient() : sock_fd_(-1) {} + +SocketClient::~SocketClient() { + disconnect(); +} + +bool SocketClient::connect(const std::string& socket_path) { + sock_fd_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (sock_fd_ < 0) { + std::cerr << "SocketClient: Failed to create socket: " << strerror(errno) << std::endl; + return false; + } + + // 设置超时 + struct timeval tv; + tv.tv_sec = 5; // 5秒超时 + tv.tv_usec = 0; + if (setsockopt(sock_fd_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0) { + std::cerr << "SocketClient: Failed to set receive timeout: " << strerror(errno) << std::endl; + close(sock_fd_); + sock_fd_ = -1; + return false; + } + if (setsockopt(sock_fd_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) < 0) { + std::cerr << "SocketClient: Failed to set send timeout: " << strerror(errno) << std::endl; + close(sock_fd_); + sock_fd_ = -1; + return false; + } + + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, socket_path.c_str(), sizeof(addr.sun_path) - 1); + addr.sun_path[sizeof(addr.sun_path) - 1] = '\0'; + + if (::connect(sock_fd_, reinterpret_cast(&addr), sizeof(addr)) < 0) { + std::cerr << "SocketClient: Failed to connect to server at " << socket_path << ": " << strerror(errno) << std::endl; + close(sock_fd_); + sock_fd_ = -1; + return false; + } + + std::cout << "SocketClient: Connected to server via UDS at " << socket_path << std::endl; + return true; +} + +bool SocketClient::disconnect() { + if (sock_fd_ >= 0) { + close(sock_fd_); + sock_fd_ = -1; + std::cout << "SocketClient: Disconnected" << std::endl; + return true; + } + return false; +} + +bool SocketClient::isConnected() const { + return sock_fd_ >= 0; +} + +Response SocketClient::sendRequest(const Request& req) { + if (sock_fd_ < 0) { + std::cerr << "SocketClient: Not connected, cannot send request" << std::endl; + return Response{}; + } + + // Serialize request to MessagePack + msgpack::sbuffer buffer; + msgpack::pack(buffer, req); + + // Send TLV header (4-byte big-endian length) + uint32_t len = static_cast(buffer.size()); + uint32_t len_be = htonl(len); + ssize_t total_sent = 0; + while (total_sent < sizeof(len_be)) { + ssize_t sent = write(sock_fd_, reinterpret_cast(&len_be) + total_sent, sizeof(len_be) - total_sent); + if (sent < 0) { + if (errno == EINTR) continue; // 被信号中断,重试 + std::cerr << "SocketClient: Failed to send header, error: " << strerror(errno) << std::endl; + return Response{}; + } + total_sent += sent; + } + + // Send payload + total_sent = 0; + const char* data = buffer.data(); + size_t data_size = buffer.size(); + while (total_sent < static_cast(data_size)) { + ssize_t sent = write(sock_fd_, data + total_sent, data_size - total_sent); + if (sent < 0) { + if (errno == EINTR) continue; + std::cerr << "SocketClient: Failed to send payload, error: " << strerror(errno) << std::endl; + return Response{}; + } + total_sent += sent; + } + + // Receive response header + uint32_t resp_len_be; + ssize_t total_received = 0; + while (total_received < sizeof(resp_len_be)) { + ssize_t received = read(sock_fd_, reinterpret_cast(&resp_len_be) + total_received, sizeof(resp_len_be) - total_received); + if (received < 0) { + if (errno == EINTR) continue; + std::cerr << "SocketClient: Failed to read response header, error: " << strerror(errno) << std::endl; + return Response{}; + } else if (received == 0) { + std::cerr << "SocketClient: Connection closed by server while reading header" << std::endl; + return Response{}; + } + total_received += received; + } + uint32_t resp_len = ntohl(resp_len_be); + + // Receive response payload + std::vector resp_buffer(resp_len); + total_received = 0; + while (total_received < static_cast(resp_len)) { + ssize_t received = read(sock_fd_, resp_buffer.data() + total_received, resp_len - total_received); + if (received < 0) { + if (errno == EINTR) continue; + std::cerr << "SocketClient: Failed to read response payload, error: " << strerror(errno) << std::endl; + return Response{}; + } else if (received == 0) { + std::cerr << "SocketClient: Connection closed by server while reading payload" << std::endl; + return Response{}; + } + total_received += received; + } + + // Deserialize response + try { + msgpack::object_handle oh = msgpack::unpack(resp_buffer.data(), resp_len); + Response res; + oh.get().convert(res); + return res; + } catch (const std::exception& e) { + std::cerr << "SocketClient: Failed to deserialize response: " << e.what() << std::endl; + return Response{}; + } +} \ No newline at end of file diff --git a/fcitx5-ext/src/socket_client.hpp b/fcitx5-ext/src/socket_client.hpp new file mode 100644 index 0000000..22488f3 --- /dev/null +++ b/fcitx5-ext/src/socket_client.hpp @@ -0,0 +1,21 @@ +#ifndef SOCKET_CLIENT_HPP +#define SOCKET_CLIENT_HPP + +#include +#include "protocol.hpp" + +class SocketClient { +public: + SocketClient(); + ~SocketClient(); + + bool connect(const std::string& socket_path); + bool disconnect(); + bool isConnected() const; + Response sendRequest(const Request& req); + +private: + int sock_fd_; +}; + +#endif // SOCKET_CLIENT_HPP \ No newline at end of file diff --git a/fcitx5-ext/src/suime.cpp b/fcitx5-ext/src/suime.cpp new file mode 100644 index 0000000..3433a94 --- /dev/null +++ b/fcitx5-ext/src/suime.cpp @@ -0,0 +1,219 @@ +#include "protocol.hpp" +#include "socket_client.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include // 确保包含此头文件 +#include +#include +#include +#include +#include + +// 自定义候选词,处理选中时的行为 +class SuimeEngine; // 前向声明 +class SuimeCandidateWord : public fcitx::CandidateWord { +public: + SuimeCandidateWord(const std::string &text, SuimeEngine *engine) + : CandidateWord(fcitx::Text(text)), engine_(engine) {} + + void select(fcitx::InputContext *ic) const override; + +private: + SuimeEngine *engine_; +}; + +class SuimeEngine : public fcitx::InputMethodEngineV2 { +public: +SuimeEngine(fcitx::Instance *instance) + : instance_(instance), socket_path_("/tmp/su-ime.sock"), + currentIc_(nullptr), timeoutDuration_(50) { + if (!socket_client_.connect(socket_path_)) { + std::cerr << "Failed to connect to SUIME server" << std::endl; + } else { + std::cout << "Connected to SUIME server via UDS" << std::endl; + } +} + ~SuimeEngine() { + // 无需手动调用 removeEvent + socket_client_.disconnect(); + } + + + void keyEvent(const fcitx::InputMethodEntry &entry, fcitx::KeyEvent &keyEvent) override { + auto *ic = keyEvent.inputContext(); + if (!ic || keyEvent.isRelease()) return; + + currentIc_ = ic; // 存储当前输入上下文 + resetTimer(); // 重置定时器 + + std::string keyStr = keyEvent.key().toString(); + std::cerr << "SuIME keyEvent: key=" << keyStr << ", current_pinyin_=" << current_pinyin_ << std::endl; + + // 空格键处理:发送预测请求 + if (keyEvent.key().check(fcitx::Key("space"))) { + std::cerr << "Space detected, sending prediction request" << std::endl; + sendPredictionRequest(ic); + keyEvent.filter(); + return; + } + + // 字母处理:只接受可打印的 ASCII 字母,并转换为小写累积 + if (keyStr.length() == 1) { + char c = keyStr[0]; + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) { + // 转换为小写并添加 + current_pinyin_ += std::tolower(c); + std::cerr << "Added char, current_pinyin_ now: " << current_pinyin_ << std::endl; + sendPredictionRequest(ic); // 新增:发送预测请求 + keyEvent.filter(); + return; // 确保返回,避免后续处理 + } + } + // 退格处理 + else if (keyEvent.key().check(fcitx::Key("BackSpace"))) { + if (!current_pinyin_.empty()) { + current_pinyin_.pop_back(); + std::cerr << "Backspace, current_pinyin_ now: " << current_pinyin_ << std::endl; + } + sendPredictionRequest(ic); // 新增:发送预测请求 + keyEvent.filter(); + return; // 确保返回 + } + // 其他键不处理,可能触发其他事件 + } + + void reset(const fcitx::InputMethodEntry &entry, + fcitx::InputContextEvent &event) override { + current_pinyin_.clear(); + if (auto *ic = event.inputContext()) { + ic->inputPanel().reset(); + ic->updateUserInterface(fcitx::UserInterfaceComponent::InputPanel); + } + currentIc_ = nullptr; + timeoutEvent_.reset(); // 取消定时器 + } + + std::vector listInputMethods() override { + std::vector entries; + entries.emplace_back("suime", "SUIME", "suime", "zh_CN"); + entries.back().setLabel("SuIME"); + return entries; + } + + void clearPinyin() { + current_pinyin_.clear(); + } + +private: + void sendPredictionRequest(fcitx::InputContext *ic) { + std::cout << "Sending pinyin to server: " << current_pinyin_ << std::endl; // 打印到控制台 + std::cerr << "sendPredictionRequest called" << std::endl; + // 检查连接状态,如果断开则尝试重新连接 + if (!socket_client_.isConnected()) { + std::cerr << "Connection lost, attempting to reconnect..." << std::endl; + if (!socket_client_.connect(socket_path_)) { + std::cerr << "Failed to reconnect to SUIME server" << std::endl; + // 保持当前候选词列表不变,避免候选框消失 + return; + } else { + std::cout << "Reconnected to SUIME server via UDS" << std::endl; + } + } + if (current_pinyin_.empty()) { + std::cerr << " current_pinyin_ empty, abort" << std::endl; + return; + } + std::cerr << " preparing request for pinyin=" << current_pinyin_ << std::endl; + Request req; + req.pinyin = current_pinyin_; + req.context = getContext(ic); + + std::cerr << " calling socket_client_.sendRequest..." << std::endl; + Response res = socket_client_.sendRequest(req); + std::cerr << " sendRequest returned, candidates count=" << res.candidates.size() << std::endl; + + // 即使候选词为空,也更新候选词列表以确保候选框状态一致 + updateCandidates(ic, res.candidates); + } + + void updateCandidates(fcitx::InputContext *ic, + const std::vector &candidates) { + auto candidateList = std::make_unique(); + candidateList->setPageSize(5); + candidateList->setSelectionKey( + fcitx::Key::keyListFromString("1234567890")); + candidateList->setLayoutHint(fcitx::CandidateLayoutHint::Vertical); + + for (const auto &cand : candidates) { + candidateList->append( + std::make_unique(cand.text, this)); + } + + ic->inputPanel().setCandidateList(std::move(candidateList)); + ic->updateUserInterface(fcitx::UserInterfaceComponent::InputPanel); + } + + std::string getContext(fcitx::InputContext *ic) { + // 获取光标前文本作为上下文 + auto text = ic->surroundingText(); + std::string context = text.text().substr(0, text.cursor()); + return context; + } + + void resetTimer() { + timeoutEvent_.reset(); // 取消旧定时器 + if (currentIc_) { + uint64_t usec = std::chrono::duration_cast(timeoutDuration_).count(); + timeoutEvent_ = instance_->eventLoop().addTimeEvent( + CLOCK_MONOTONIC, + usec, + 0, // accuracy + [this](fcitx::EventSourceTime*, uint64_t) { + this->handleTimeout(); + return false; // 单次触 + } + ); + } + } + + // handleTimeout 保持无参即可 + void handleTimeout() { + std::cout << "50ms timeout triggered, sending accumulated pinyin: " + << current_pinyin_ << std::endl; + if (currentIc_ && !current_pinyin_.empty()) { + sendPredictionRequest(currentIc_); + } + } + + + fcitx::Instance *instance_; + SocketClient socket_client_; + std::string current_pinyin_; + std::string socket_path_; + fcitx::InputContext *currentIc_; + std::chrono::milliseconds timeoutDuration_; + std::unique_ptr timeoutEvent_; +}; + +void SuimeCandidateWord::select(fcitx::InputContext *ic) const { + ic->commitString(text().toString()); + if (engine_) { + engine_->clearPinyin(); // 用户做出选择后清空积累的拼音 + } +} + +// 插件工厂 +class SuimeEngineFactory : public fcitx::AddonFactory { +public: + fcitx::AddonInstance *create(fcitx::AddonManager *manager) override { + return new SuimeEngine(manager->instance()); + } +}; + +FCITX_ADDON_FACTORY(SuimeEngineFactory) diff --git a/fcitx5-ext/suime-addon.conf.in b/fcitx5-ext/suime-addon.conf.in new file mode 100644 index 0000000..de832aa --- /dev/null +++ b/fcitx5-ext/suime-addon.conf.in @@ -0,0 +1,12 @@ +[Addon] +Name=SuIME +Name[zh_CN]=SUIME +Category=InputMethod +Version=0.1.0 +Library=libsuime +Type=SharedLibrary +OnDemand=True +Configurable=False + +[Addon/Dependencies] +0=punctuation diff --git a/fcitx5-ext/suime-im.conf.in b/fcitx5-ext/suime-im.conf.in new file mode 100644 index 0000000..c4048cc --- /dev/null +++ b/fcitx5-ext/suime-im.conf.in @@ -0,0 +1,8 @@ +[InputMethod] +Name=SuIME +Name[zh_CN]=SuIME +Icon=fcitx-keyboard-cn +Label=Su +LangCode=zh_CN +Addon=suime +Configurable=False diff --git a/src/config.rs b/src/config.rs index 92531d3..752a61b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -16,6 +16,7 @@ pub struct Config { /// 额外配置,为未来扩展预留 #[serde(default, skip_serializing_if = "Option::is_none")] pub extra: Option, + pub vocabs_path: String } #[derive(Debug, Serialize, Deserialize)] @@ -60,6 +61,9 @@ impl Default for Config { let tokenizer_path = base_dir.join("assets/tokenizer.json") .to_string_lossy() .to_string(); + let vocabs_path = base_dir.join("assets/vocabs.json") + .to_string_lossy() + .to_string(); #[cfg(unix)] let socket = SocketConfig::UnixDomainSocket("/tmp/su-ime.sock".to_string()); @@ -71,6 +75,7 @@ impl Default for Config { model_path, tokenizer_path, socket, + vocabs_path, extra: None, // 默认无额外配置 } } diff --git a/src/filter.rs b/src/filter.rs new file mode 100644 index 0000000..b3462a8 --- /dev/null +++ b/src/filter.rs @@ -0,0 +1,81 @@ +use crate::vocabs::Dictionary; + +/// 筛选后的候选词,包含匹配的拼音部分和剩余拼音。 +#[derive(Debug, Clone)] +pub struct FilteredCandidate { + pub id: usize, + pub text: String, + pub weight: f32, + pub matched_pinyin: String, + pub remaining_pinyin: String, +} + +/// 对预测结果进行筛选,基于用户输入的拼音。 +/// +/// # Arguments +/// * `pinyin` - 用户输入的拼音字符串,如 "shanghai"。 +/// * `predicted` - 模型预测的 (id, weight) 对列表,通常来自 `OnnxModel::predict_to_sorted_pairs_simple`。 +/// * `dict` - 字典查询引擎,用于获取汉字的拼音和文本。 +/// +/// # Returns +/// 返回一个 `Vec`,包含匹配的候选词及其信息。 +pub fn filter_candidates(pinyin: &str, predicted: Vec<(usize, f32)>, dict: &Dictionary) -> Vec { + let mut candidates = Vec::new(); + for (id, weight) in predicted { + if let Some(char_pinyin) = dict.get_pinyin_by_id(id) { + if let Some((matched, remaining)) = consume_pinyin(pinyin, char_pinyin) { + if let Some(text) = dict.get_char_by_id(id) { + candidates.push(FilteredCandidate { + id, + text: text.to_string(), + weight, + matched_pinyin: matched, + remaining_pinyin: remaining, + }); + } + } + } + } + candidates.sort_by(|a, b| { + // 注意:我们要从大到小排序 + // f32 的 partial_cmp 返回 Option + // b.weight.partial_cmp(&a.weight) 实现了降序 (b - a) + // 如果 weight 可能为 NaN,需要决定如何处理,这里假设没有 NaN + b.weight.partial_cmp(&a.weight).unwrap_or(std::cmp::Ordering::Equal) + }); + let size = std::cmp::min(100, candidates.len()); + candidates[0..size].to_vec() +} + +/// 辅助函数:检查 `input` 是否以 `char_pinyin` 的某个前缀开头,返回汉字的完整拼音和剩余部分。 +fn consume_pinyin(input: &str, char_pinyin: &str) -> Option<(String, String)> { + let matched_len = common_prefix_length(input, char_pinyin); + if matched_len > 0 { + // 返回汉字的完整拼音作为匹配部分,以及剩余的输入拼音 + Some((char_pinyin.to_string(), input[matched_len..].to_string())) + } else { + None + } +} + + +pub fn common_prefix_length(s1: &str, s2: &str) -> usize { + // 使用 chars() 迭代器以正确处理 UTF-8 多字节字符(如中文、Emoji) + let mut chars1 = s1.chars(); + let mut chars2 = s2.chars(); + + let mut count = 0; + + loop { + match (chars1.next(), chars2.next()) { + // 两个都有字符,且相等 + (Some(c1), Some(c2)) if c1 == c2 => { + count += 1; + } + // 其他所有情况:不相等 或 其中一个/两个已结束 + _ => break, + } + } + + count +} diff --git a/src/main.rs b/src/main.rs index 3a154bf..81697cd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,59 +1,193 @@ -use std::path::PathBuf; -use std::time::Instant; - +mod config; +mod filter; mod model; +mod protocol; mod tokenizers; mod vocabs; -mod config; -use tokenizers::HFTokenizer; -// use crate::vocabs::Dictionary; +use crate::config::{Config, SocketConfig}; use crate::model::OnnxModel; -use crate::config::Config; +use crate::tokenizers::HFTokenizer; +use crate::vocabs::Dictionary; +use crate::filter::filter_candidates; +use crate::protocol::{Request, Response, Candidate}; +use std::error::Error; +use std::io::{Read, Write}; +use std::sync::{Arc, Mutex}; +use std::thread; -fn main() { - let config = Config::load().unwrap() ; - let mut session = OnnxModel::new(config.model_path, 4).unwrap(); - let tokenizer = HFTokenizer::new(config.tokenizer_path).unwrap(); - let start = Instant::now(); // 开始计时 - let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap(); - let _ = session.predict(sample).unwrap(); - let duration = start.elapsed(); // 结束计时 - print!("predict Time elapsed: {:?}\n", duration); +#[cfg(unix)] +use std::os::unix::net::{UnixListener, UnixStream}; +#[cfg(unix)] +use std::os::unix::fs::PermissionsExt; // 用于设置文件权限 +#[cfg(unix)] +use std::fs::set_permissions; - let start = Instant::now(); // 开始计时 - let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap(); - let logits = session - .predict_to_sorted_pairs_simple(sample, false) - .unwrap(); - let duration = start.elapsed(); // 结束计时 - print!("predict Time elapsed: {:?}\n", duration); +#[cfg(windows)] +use std::net::{TcpListener, TcpStream}; - let start = Instant::now(); // 开始计时 - let sample = tokenizer.gen_predict_sample("从北京到上", "hai").unwrap(); - let probs = session - .predict_to_sorted_pairs_simple(sample, true) - .unwrap(); - let duration = start.elapsed(); // 结束计时 - print!("predict Time elapsed: {:?}\n", duration); - print!("logits: {:?}", &logits[0..10].to_vec()); - print!("probs: {:?}", &probs[0..10].to_vec()); +// 添加 ctrlc 依赖,用于捕获退出信号(仅在 Unix 下需要) +#[cfg(unix)] +use ctrlc; - /* - if let Ok(tokenizer) = HFTokenizer::new(tokenizer_json_path) { - println!("Tokenizer loaded successfully"); - if let Ok(sample) = tokenizer.gen_predict_sample("从北京到上", "hai") { - let logits: ndarray::ArrayBase< - ndarray::OwnedRepr, - ndarray::Dim, - f32, - > = session.predict(sample).unwrap(); - - let duration = start.elapsed(); // 结束计时 - - println!("Time elapsed: {:?}", duration); - println!("Model input generated successfully"); - println!("Logits: {:?}", logits); +fn handle_client(mut stream: impl Read + Write, model: Arc>, tokenizer: Arc>, dict: Arc) -> Result<(), Box> { + loop { + // 读取4字节长度头 + let mut len_buf = [0u8; 4]; + if let Err(e) = stream.read_exact(&mut len_buf) { + // 如果读取失败,可能是连接关闭,正常退出循环 + if e.kind() == std::io::ErrorKind::UnexpectedEof { + println!("客户端关闭连接"); + break; + } + return Err(Box::new(e)); // 其他错误上报 } - }*/ + let payload_len = u32::from_be_bytes(len_buf) as usize; + + // 读取MessagePack载荷 + let mut payload = vec![0u8; payload_len]; + stream.read_exact(&mut payload)?; + + // 反序列化请求 + let request: Request = rmp_serde::from_slice(&payload)?; + println!("接收请求: {:?}", request); + + // 准备模型输入 + let model_input = { + let tokenizer = tokenizer.lock().unwrap(); + tokenizer.gen_predict_sample(&request.context, &request.pinyin)? + }; + + // 进行预测 + let predicted_pairs = { + let mut model = model.lock().unwrap(); + model.predict_to_sorted_pairs_simple(model_input, true)? + }; + + // 筛选候选词 + let filtered = filter_candidates(&request.pinyin, predicted_pairs, &dict); + + // 转换为Response + let candidates: Vec = filtered + .into_iter() + .map(|fc| Candidate { + id: fc.id as u32, + text: fc.text, + weight: fc.weight, + }) + .collect(); + let response = Response { + candidates, + offset: 0, + limit: 10, + }; + + // 序列化响应 + let response_payload = rmp_serde::to_vec(&response)?; + let response_len = response_payload.len() as u32; + let response_len_bytes = response_len.to_be_bytes(); + + // 发送响应 + stream.write_all(&response_len_bytes)?; + stream.write_all(&response_payload)?; + // 确保数据发送完成(flush 在 Write 实现中通常自动处理,但可显式调用) + stream.flush()?; + } + Ok(()) +} + + +fn main() -> Result<(), Box> { + // 加载配置 + let config = Config::load()?; + println!("配置加载成功: {:?}", config); + + // 初始化组件 + let model = OnnxModel::new(&config.model_path, 1)?; + let tokenizer = HFTokenizer::new(&config.tokenizer_path)?; + let dict = Dictionary::from_json_file(&config.vocabs_path)?; + + let model = Arc::new(Mutex::new(model)); + let tokenizer = Arc::new(Mutex::new(tokenizer)); + let dict = Arc::new(dict); + + // 创建监听器 + match config.socket { + SocketConfig::UnixDomainSocket(path) => { + #[cfg(unix)] + { + // 尝试删除可能残留的socket文件 + if let Err(e) = std::fs::remove_file(&path) { + // 如果文件不存在,忽略错误;否则打印警告 + if e.kind() != std::io::ErrorKind::NotFound { + eprintln!("警告:无法删除socket文件 {}: {}", path, e); + } + } + let listener = UnixListener::bind(&path)?; + println!("监听Unix Domain Socket在: {}", path); + + // 设置socket文件权限为600(仅所有者可读写),减少被其他进程误操作的风险 + let mut perms = std::fs::metadata(&path)?.permissions(); + perms.set_mode(0o600); // 只允许所有者读写 + set_permissions(&path, perms)?; + + // 注册信号处理,确保程序退出时清理socket文件 + let socket_path = path.clone(); + ctrlc::set_handler(move || { + println!("收到退出信号,清理socket文件: {}", socket_path); + let _ = std::fs::remove_file(&socket_path); + std::process::exit(0); + })?; + + for stream in listener.incoming() { + match stream { + Ok(stream) => { + println!("受理连接"); + let model = Arc::clone(&model); + let tokenizer = Arc::clone(&tokenizer); + let dict = Arc::clone(&dict); + thread::spawn(move || { + if let Err(e) = handle_client(stream, model, tokenizer, dict) { + eprintln!("处理连接错误: {}", e); + } + }); + } + Err(e) => eprintln!("接受连接错误: {}", e), + } + } + } + #[cfg(not(unix))] + { + return Err("Unix Domain Socket仅在Unix系统上支持".into()); + } + } + SocketConfig::TcpSocket(addr) => { + #[cfg(windows)] + { + let listener = TcpListener::bind(&addr)?; + println!("监听TCP Socket在: {}", addr); + for stream in listener.incoming() { + match stream { + Ok(stream) => { + let model = Arc::clone(&model); + let tokenizer = Arc::clone(&tokenizer); + let dict = Arc::clone(&dict); + thread::spawn(move || { + if let Err(e) = handle_client(stream, model, tokenizer, dict) { + eprintln!("处理连接错误: {}", e); + } + }); + } + Err(e) => eprintln!("接受连接错误: {}", e), + } + } + } + #[cfg(not(windows))] + { + return Err("TCP Socket仅在Windows系统上支持".into()); + } + } + } + + Ok(()) } diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..de82730 --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,21 @@ +use serde::{Serialize, Deserialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Request { + pub pinyin: String, // 用户输入的拼音串 + pub context: String, // 光标前文本 (用于上下文预测) +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Candidate { + pub id: u32, // 词项 ID + pub text: String, // 候选词文本 (也可由客户端查表) + pub weight: f32, // 置信度/权重 +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Response { + pub candidates: Vec, + pub offset: usize, + pub limit: usize, +} diff --git a/src/socket.rs b/src/socket.rs new file mode 100644 index 0000000..bc99cf7 --- /dev/null +++ b/src/socket.rs @@ -0,0 +1,85 @@ +use std::io::{Read, Write}; +use std::os::unix::net::{UnixListener, UnixStream}; +use rmp_serde::{from_slice, to_vec}; +use crate::protocol::{Request, Response}; + +pub fn start_server(socket_path: &str) -> std::io::Result<()> { + // 绑定到Unix Domain Socket路径 + let listener = UnixListener::bind(socket_path)?; + println!("服务器监听在: {}", socket_path); + + // 持续接受传入连接 + for stream in listener.incoming() { + match stream { + Ok(stream) => { + // 在新线程中处理每个客户端连接以避免阻塞 + std::thread::spawn(move || { + handle_client(stream); + }); + } + Err(e) => { + eprintln!("接受连接失败: {}", e); + } + } + } + Ok(()) +} + +fn handle_client(mut stream: UnixStream) { + // 读取消息长度头(4字节大端) + let mut len_buf = [0u8; 4]; + if let Err(e) = stream.read_exact(&mut len_buf) { + eprintln!("读取长度失败: {}", e); + return; + } + let len = u32::from_be_bytes(len_buf) as usize; + + // 读取序列化的MessagePack数据 + let mut data_buf = vec![0u8; len]; + if let Err(e) = stream.read_exact(&mut data_buf) { + eprintln!("读取数据失败: {}", e); + return; + } + + // 反序列化为Request结构 + let request: Request = match from_slice(&data_buf) { + Ok(req) => req, + Err(e) => { + eprintln!("反序列化请求失败: {}", e); + return; + } + }; + + // 处理请求(这里调用其他业务模块,如filter.rs) + let response = process_request(request); + + // 序列化Response为MessagePack + let response_data = match to_vec(&response) { + Ok(data) => data, + Err(e) => { + eprintln!("序列化响应失败: {}", e); + return; + } + }; + + // 发送响应长度和数据 + let response_len = response_data.len() as u32; + let len_bytes = response_len.to_be_bytes(); + if let Err(e) = stream.write_all(&len_bytes) { + eprintln!("写入长度失败: {}", e); + return; + } + if let Err(e) = stream.write_all(&response_data) { + eprintln!("写入数据失败: {}", e); + return; + } +} + +fn process_request(request: Request) -> Response { + // 示例处理:返回空的候选词列表,实际应集成filter.rs等模块 + Response { + candidates: Vec::new(), + offset: 0, + limit: 0, + } +} diff --git a/src/vocabs.rs b/src/vocabs.rs index d8e35b5..485175a 100644 --- a/src/vocabs.rs +++ b/src/vocabs.rs @@ -8,7 +8,7 @@ use std::path::Path; /// 单个字符-拼音对的信息 #[derive(Debug, Deserialize, Clone)] pub struct CharInfo { - pub id: u32, + pub id: usize, #[serde(rename = "char")] pub character: String, pub pinyin: String, @@ -24,7 +24,7 @@ struct RawStatistics { /// 字典查询引擎,提供 O(1) 的 ID 到信息的映射 pub struct Dictionary { - id_to_charinfo: HashMap, + id_to_charinfo: HashMap, } impl Dictionary { @@ -38,7 +38,7 @@ impl Dictionary { let mut id_to_charinfo = HashMap::with_capacity(raw.pairs.len()); for (id_str, info) in raw.pairs { let id = id_str - .parse::() + .parse::() .with_context(|| format!("无效的 ID 字符串: {}", id_str))?; // 可选:验证 id 与 info.id 一致,此处忽略不一致的情况(信任输入数据) id_to_charinfo.insert(info.id, info); @@ -48,22 +48,22 @@ impl Dictionary { } /// 通过 ID 获取汉字(用于填充 Candidate.text) - pub fn get_char_by_id(&self, id: u32) -> Option<&str> { + pub fn get_char_by_id(&self, id: usize) -> Option<&str> { self.id_to_charinfo.get(&id).map(|info| info.character.as_str()) } /// 通过 ID 获取拼音 - pub fn get_pinyin_by_id(&self, id: u32) -> Option<&str> { + pub fn get_pinyin_by_id(&self, id: usize) -> Option<&str> { self.id_to_charinfo.get(&id).map(|info| info.pinyin.as_str()) } /// 通过 ID 获取出现次数 - pub fn get_count_by_id(&self, id: u32) -> Option { + pub fn get_count_by_id(&self, id: usize) -> Option { self.id_to_charinfo.get(&id).map(|info| info.count) } /// 获取完整的 CharInfo 引用 - pub fn get_char_info(&self, id: u32) -> Option<&CharInfo> { + pub fn get_char_info(&self, id: usize) -> Option<&CharInfo> { self.id_to_charinfo.get(&id) } @@ -71,4 +71,4 @@ impl Dictionary { pub fn len(&self) -> usize { self.id_to_charinfo.len() } -} \ No newline at end of file +} diff --git a/test.py b/test.py new file mode 100644 index 0000000..9188cf4 --- /dev/null +++ b/test.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +import socket +import struct +import msgpack +import sys +import os +import time + +# Define test data: 5 requests with different pinyin and context +test_requests = [ + {"pinyin": "shanghai", "context": ""}, + {"pinyin": "beijing", "context": ""}, + {"pinyin": "nihao", "context": ""}, + {"pinyin": "xiexie", "context": ""}, + {"pinyin": "f", "context": "返回汉字的完整拼音和剩余部"} +] + +# Determine socket address based on operating system +def get_socket_address(): + if sys.platform == "win32": + # Windows: TCP socket + return ("127.0.0.1", 23333), "tcp" + else: + # Unix-like systems: Unix Domain Socket + return "/tmp/su-ime.sock", "unix" + +# Connect to server +def connect_to_server(address, socket_type): + if socket_type == "unix": + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(address) + else: # tcp + host, port = address + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((host, port)) + return sock + +# Send request and receive response +def send_request(sock, request): + # Serialize request to MessagePack + payload = msgpack.packb(request) + payload_len = len(payload) + # Create header: 4-byte big-endian unsigned integer + header = struct.pack(">I", payload_len) + # Send header and payload + sock.sendall(header + payload) + + # Receive response + # Read header + len_buf = sock.recv(4) + if len(len_buf) != 4: + raise ValueError("Failed to read response header") + response_len = struct.unpack(">I", len_buf)[0] + + # Read payload + response_payload = b"" + while len(response_payload) < response_len: + chunk = sock.recv(response_len - len(response_payload)) + if not chunk: + break + response_payload += chunk + + # Deserialize response + response = msgpack.unpackb(response_payload, raw=False) + return response + +# Main function +def main(): + address, socket_type = get_socket_address() + + try: + sock = connect_to_server(address, socket_type) + print(f"Connected to server via {socket_type} at {address}") + + for i, req in enumerate(test_requests, 1): + print(f"\nSending request {i}: {req}") + try: + start = time.time() + response = send_request(sock, req) + print(f'cost time: {time.time() - start}') + print(f"Response {i}: {response[0][0:10]}") + except Exception as e: + print(f"Error sending request {i}: {e}") + + sock.close() + print("\nAll requests sent.") + except Exception as e: + print(f"Connection error: {e}") + sys.exit(1) + +if __name__ == "__main__": + main()