SUimeModelTraner/docs/WORD_BREAK_DESIGN.md

4.4 KiB
Raw Permalink Blame History

破词训练设计文档

背景

输入法用户在实际使用中通常是逐词输入的,而非逐字输入。例如输入"那边的特别漂亮的女孩是我的表姐"时,用户可能分词为:

那边 / 的 / 特别 / 漂亮 / 的 / 女孩 / 是 / 我 / 的 / 表姐

但为了增强模型的泛化能力,需要模拟用户从词中间断开的情况。例如用户可能只输入了"漂"就开始选字"亮"。

破词概念

术语定义

术语 说明
整词输入 用户输入完整词的拼音,如"piaoliang"
破词输入 用户只输入词的部分拼音,如"piao"
前缀 光标前的已确认文本
拼音 当前待选字的拼音(可能不完整)
后缀 光标后的原文内容

场景示例

以词"漂亮"为例:

整词模式90%概率):

光标前: 那边的特别
拼音:   piaoliang
预测:   漂 → 亮

破词模式10%概率):

光标前: 那边的特别漂
拼音:   liang
预测:   亮

实现方案

分词策略

使用 jieba 分词器进行词语边界识别:

import jieba
words = list(jieba.cut(text, HMM=False))
# "那边的特别漂亮的女孩是我的表姐。"
# → ['那边', '的', '特别', '漂亮', '的', '女孩', '是', '我', '的', '表姐', '。']

两阶段样本生成

每个词生成样本时分为两个阶段:

Phase 1前缀/整词阶段

  • 整词90%prefix_positions = 整个词的所有字符
  • 破词前缀10%prefix_positions = 词的前 break_pos 个字符
if should_break:
    break_pos = random.randint(1, word_len_chars - 1)  # 随机破开位置
else:
    break_pos = word_len_chars  # 整词

Phase 2破词续接阶段仅当破词时

当破词发生时,从断点位置开始继续采样:

if should_break and break_pos < word_len_chars:
    cont_start = char_positions[break_pos]
    # 从断点开始采样后续字符
    target_len = random_choice(1-8)  # 采样长度
    cont_positions = [cont_start, ...]  # 后续字符位置

样本结构

每个字符生成一个训练样本,包含:

字段 说明 示例
part1 (prefix) 光标前文本 "那边的特别漂"
part2 (pinyin) 当前字拼音 "liang"
part3 (suffix) 光标后文本 "亮的女孩是我的表姐"
part4 专有词提示 "漂亮|特别"
label 目标汉字ID 1234
history_slot_ids 历史已确认字 [0, 0, 0, 0, 0, 0, 0, 0]

拼音增强策略

根据 py_style_weight 参数,拼音有以下三种形式:

形式 概率 示例
完整拼音 75% (9/12) "piaoliang"
仅声母 16.7% (2/12) "pl" (通过 to_initials)
仅首字母 8.3% (1/12) "p"

参数配置:py_style_weight=(9, 2, 1)

破词概率控制

word_break_prob 参数

控制每个词从中间断开的概率,默认为 10%

self.word_break_prob = 0.10  # 10%概率从词中间破开

破词位置分布

对于长度为 N 的词,破开位置 break_pos 的分布:

break_pos = random.randint(1, N - 1)
  • 2字词break_pos = 1100%在第1字后破开
  • 3字词break_pos = 1 或 2各50%
  • 4字词break_pos = 1, 2, 或 3各33%

数据分布预期

理想分布

类别 预期比例
单字样本 ~15%
2字词整词 ~30%
3字词整词 ~20%
破词样本 ~10%
其他 ~25%

拼音不完整率

由于 py_style_weight=(9, 2, 1)

  • 声母initials~16.7%
  • 首字母:~8.3%
  • 总计不完整~25%

代码实现位置

主要实现文件:src/model/dataset.py

函数/类 行号 功能
segment_text() ~30 jieba分词
build_word_boundaries() ~35 建立词边界映射
PinyinInputDataset.__iter__() ~280 核心迭代逻辑
get_mask_pinyin() ~215 拼音加强处理
_add_word_samples() ~240 样本构建

注意事项

  1. 破词仅针对多字词:单字词(如"的"、“是”)不会破词
  2. 破词保持语义完整:破词后仍能根据上下文预测正确汉字
  3. 历史槽位模拟逐步确认:同一词内已确认的字会填入 history_slot_ids
  4. 10% EOS标记词尾有10%概率追加ID=0表示句子结束