4.4 KiB
4.4 KiB
破词训练设计文档
背景
输入法用户在实际使用中通常是逐词输入的,而非逐字输入。例如输入"那边的特别漂亮的女孩是我的表姐"时,用户可能分词为:
那边 / 的 / 特别 / 漂亮 / 的 / 女孩 / 是 / 我 / 的 / 表姐
但为了增强模型的泛化能力,需要模拟用户从词中间断开的情况。例如用户可能只输入了"漂"就开始选字"亮"。
破词概念
术语定义
| 术语 | 说明 |
|---|---|
| 整词输入 | 用户输入完整词的拼音,如"piaoliang" |
| 破词输入 | 用户只输入词的部分拼音,如"piao" |
| 前缀 | 光标前的已确认文本 |
| 拼音 | 当前待选字的拼音(可能不完整) |
| 后缀 | 光标后的原文内容 |
场景示例
以词"漂亮"为例:
整词模式(90%概率):
光标前: 那边的特别
拼音: piaoliang
预测: 漂 → 亮
破词模式(10%概率):
光标前: 那边的特别漂
拼音: liang
预测: 亮
实现方案
分词策略
使用 jieba 分词器进行词语边界识别:
import jieba
words = list(jieba.cut(text, HMM=False))
# "那边的特别漂亮的女孩是我的表姐。"
# → ['那边', '的', '特别', '漂亮', '的', '女孩', '是', '我', '的', '表姐', '。']
两阶段样本生成
每个词生成样本时分为两个阶段:
Phase 1:前缀/整词阶段
- 整词(90%):
prefix_positions = 整个词的所有字符 - 破词前缀(10%):
prefix_positions = 词的前 break_pos 个字符
if should_break:
break_pos = random.randint(1, word_len_chars - 1) # 随机破开位置
else:
break_pos = word_len_chars # 整词
Phase 2:破词续接阶段(仅当破词时)
当破词发生时,从断点位置开始继续采样:
if should_break and break_pos < word_len_chars:
cont_start = char_positions[break_pos]
# 从断点开始采样后续字符
target_len = random_choice(1-8) # 采样长度
cont_positions = [cont_start, ...] # 后续字符位置
样本结构
每个字符生成一个训练样本,包含:
| 字段 | 说明 | 示例 |
|---|---|---|
part1 (prefix) |
光标前文本 | "那边的特别漂" |
part2 (pinyin) |
当前字拼音 | "liang" |
part3 (suffix) |
光标后文本 | "亮的女孩是我的表姐" |
part4 |
专有词提示 | "漂亮|特别" |
label |
目标汉字ID | 1234 |
history_slot_ids |
历史已确认字 | [0, 0, 0, 0, 0, 0, 0, 0] |
拼音增强策略
根据 py_style_weight 参数,拼音有以下三种形式:
| 形式 | 概率 | 示例 |
|---|---|---|
| 完整拼音 | 75% (9/12) | "piaoliang" |
| 仅声母 | 16.7% (2/12) | "pl" (通过 to_initials) |
| 仅首字母 | 8.3% (1/12) | "p" |
参数配置:py_style_weight=(9, 2, 1)
破词概率控制
word_break_prob 参数
控制每个词从中间断开的概率,默认为 10%:
self.word_break_prob = 0.10 # 10%概率从词中间破开
破词位置分布
对于长度为 N 的词,破开位置 break_pos 的分布:
break_pos = random.randint(1, N - 1)
- 2字词:break_pos = 1(100%在第1字后破开)
- 3字词:break_pos = 1 或 2(各50%)
- 4字词:break_pos = 1, 2, 或 3(各33%)
数据分布预期
理想分布
| 类别 | 预期比例 |
|---|---|
| 单字样本 | ~15% |
| 2字词整词 | ~30% |
| 3字词整词 | ~20% |
| 破词样本 | ~10% |
| 其他 | ~25% |
拼音不完整率
由于 py_style_weight=(9, 2, 1):
- 声母(initials):~16.7%
- 首字母:~8.3%
- 总计不完整:~25%
代码实现位置
主要实现文件:src/model/dataset.py
| 函数/类 | 行号 | 功能 |
|---|---|---|
segment_text() |
~30 | jieba分词 |
build_word_boundaries() |
~35 | 建立词边界映射 |
PinyinInputDataset.__iter__() |
~280 | 核心迭代逻辑 |
get_mask_pinyin() |
~215 | 拼音加强处理 |
_add_word_samples() |
~240 | 样本构建 |
注意事项
- 破词仅针对多字词:单字词(如"的"、“是”)不会破词
- 破词保持语义完整:破词后仍能根据上下文预测正确汉字
- 历史槽位模拟逐步确认:同一词内已确认的字会填入
history_slot_ids - 10% EOS标记:词尾有10%概率追加ID=0表示句子结束