feat(suinput): 引入拼音分组配置并优化上下文采样逻辑

This commit is contained in:
songsenand 2026-02-22 09:30:39 +08:00
parent 2219f6530d
commit fc71124484
7 changed files with 79 additions and 128 deletions

View File

@ -12,6 +12,32 @@ from torch.utils.data import DataLoader, IterableDataset
# Avoid the HF tokenizers fork-safety warning when DataLoader workers fork.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Pinyin first-letter -> group id lookup table. Each string below holds the
# letters of one group; enumerate() assigns the group index. Consumers use
# group id 8 as the fallback for an empty pinyin string.
_LETTER_GROUPS = ("yke", "lwf", "qas", "xbr", "omz", "gnc", "tpd", "jh")
PG = {
    letter: group_id
    for group_id, letters in enumerate(_LETTER_GROUPS)
    for letter in letters
}
class PinyinInputDataset(IterableDataset):
"""
@ -44,6 +70,7 @@ class PinyinInputDataset(IterableDataset):
repeat_end_freq: int = 10000, # 开始重复的阈值
max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望
sample_context_section = [0.90, 0.95, 1]
):
"""
初始化数据集
@ -95,32 +122,13 @@ class PinyinInputDataset(IterableDataset):
self.dataset = load_dataset(data_dir, split="train", streaming=True)
# 加载拼音分组
self.pg_groups = {
"y": 0,
"k": 0,
"e": 0,
"l": 1,
"w": 1,
"f": 1,
"q": 2,
"a": 2,
"s": 2,
"x": 3,
"b": 3,
"r": 3,
"o": 4,
"m": 4,
"z": 4,
"g": 5,
"n": 5,
"c": 5,
"t": 6,
"p": 6,
"d": 6,
"j": 7,
"h": 7,
}
self.pg_groups = PG
# 上下文采样方式概率区间
self.sample_context_section = sample_context_section
def get_next_chinese_chars(
self,
text: str,
@ -181,13 +189,13 @@ class PinyinInputDataset(IterableDataset):
# 确保有足够长度
context_len = len(context)
# 随机选择采样方式 (各1/3概率)
# 随机选择采样方式
choice = random.random()
if choice < 0.3333:
if choice < self.sample_context_section[0]:
# 方式1: 靠近汉字的54个字符
return context[-54:] if context_len >= 54 else context
elif choice < 0.6667:
elif choice < self.sample_context_section[1]:
# 方式2: 随机位置取46个连续字符
if context_len <= 46:
return context
@ -250,11 +258,11 @@ class PinyinInputDataset(IterableDataset):
if rand_val < 0.1:
# 10%概率截断为空
return ""
elif rand_val < 0.6:
# 50%概率不截断
elif rand_val < 0.9:
# 80%概率不截断
return pinyin
else:
# 40%概率随机截断
# 10%概率随机截断
# 均匀分配剩余概率给各种截断长度
max_len = len(pinyin)
if max_len <= 1:
@ -432,16 +440,23 @@ class PinyinInputDataset(IterableDataset):
return_tensors="pt",
)
prob = random.random()
pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
if prob < 0.1:
py = ""
else:
py = processed_pinyin
# 生成样本
sample = {
"hint": hint,
"txt": sampled_context,
"py": processed_pinyin,
"py": py,
"char_id": torch.tensor([char_info["id"]]),
"char": char,
"freq": char_info["freq"],
"pg": torch.tensor(
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
[pg]
),
}

View File

@ -9,53 +9,17 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger
from modelscope import AutoModel
from modelscope import AutoModel, AutoTokenizer
from tqdm import tqdm
from .monitor import TrainingMonitor
from suinput.dataset import PG
def eval_dataloader(path: Union[str, Path, None] = None):
    """Load every pickled evaluation sample found under *path*.

    Parameters:
        path: directory containing ``*.pkl`` files. When ``None``, the
            package-bundled ``eval_dataset`` directory is used; the default
            is now resolved lazily at call time (the old def-time default
            touched package resources at import time).

    Returns:
        list of unpickled objects, one per ``*.pkl`` file, in glob order.

    NOTE(review): ``pickle.load`` executes arbitrary code — only load
    trusted, project-produced ``.pkl`` files here.
    """
    if path is None:
        path = files(__package__) / "eval_dataset"
    samples = []
    for pkl_file in Path(path).glob("*.pkl"):
        # `with` guarantees each handle is closed; the original
        # comprehension opened files and never closed them.
        with pkl_file.open("rb") as fh:
            samples.append(pickle.load(fh))
    return samples
def round_to_power_of_two(x):
    """Snap a positive integer onto one of {128, 256, 512, 1024}.

    The bit length of ``x`` is clamped to the range [7, 9], giving a lower
    candidate ``1 << bits`` and an upper candidate of twice that; whichever
    candidate lies closer to ``x`` is returned (ties go to the upper one).
    Any ``x`` below 1 yields 0.
    """
    if x < 1:
        return 0
    bits = min(max(x.bit_length(), 7), 9)
    low = 1 << bits
    high = low * 2
    return low if x - low < high - x else high
# Export hidden dimension for each index 0-19: indices 0-2 use 1024,
# index 19 uses 256, and everything in between uses 512.
EXPORT_HIDE_DIM = {index: 512 for index in range(20)}
EXPORT_HIDE_DIM.update({0: 1024, 1: 1024, 2: 1024, 19: 256})
# ---------------------------- 残差块 ----------------------------
class ResidualBlock(nn.Module):
def __init__(self, dim, dropout_prob=0.1):
@ -328,23 +292,37 @@ class MoEModel(nn.Module):
accuracy = correct / total if total > 0 else 0.0
return accuracy, avg_loss
def predict(self, sample, debug=False):
def gen_predict_sample(self, text, py, tokenizer=None):
    """Build a prediction sample dict from raw context text and pinyin.

    Parameters:
        text: context string preceding the character to predict.
        py: pinyin prefix string; may be empty.
        tokenizer: optional HF-style tokenizer. When omitted, a default
            tokenizer is loaded once and cached on ``self.tokenizer``.

    Returns:
        dict with ``'hint'`` (``input_ids``/``attention_mask`` tensors) and
        ``'pg'`` (pinyin-group id tensor).
    """
    if tokenizer is not None:
        # Explicit `is not None` check instead of `tokenizer or ...`:
        # the old truthiness test could silently discard a falsy tokenizer.
        self.tokenizer = tokenizer
    elif not hasattr(self, "tokenizer"):
        # Lazy-load the default tokenizer once and cache it on the model.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "iic/nlp_structbert_backbone_tiny_std"
        )
    hint = self.tokenizer(
        text,
        py,
        max_length=88,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return {
        "hint": {
            "input_ids": hint["input_ids"],
            "attention_mask": hint["attention_mask"],
        },
        # Empty pinyin maps to group 8, mirroring the training dataset's
        # fallback; the old PG[py[0]] raised IndexError on an empty string.
        "pg": torch.tensor([PG[py[0]] if py else 8]),
    }
def predict(self, text, py, tokenizer=None):
"""
基于 sample 字典进行预测支持批量/单样本可选调试打印错误样本信息
基于输入的文本和拼音生成 sample 字典进行预测支持批量/单样本可选调试打印错误样本信息
参数
sample : dict
必须包含字段
- 'input_ids' : [batch, seq_len] [seq_len] (单样本)
- 'attention_mask': 同上
- 'pg' : [batch] 或标量
- 'char_id' : [batch] 或标量真实标签 debug=True 时必须提供
调试时debug=True必须包含字段
- 'txt' : 字符串列表batch或单个字符串
- 'char' : 字符串列表batch或单个字符串
- 'py' : 字符串列表batch或单个字符串
debug : bool
是否打印预测错误的样本信息若为 True sample 缺少 char_id/txt/char/py抛出 ValueError
是否打印预测错误的样本信息
返回
preds : torch.Tensor
@ -354,6 +332,7 @@ class MoEModel(nn.Module):
# ------------------ 1. 提取并规范化输入 ------------------
# 判断是否为单样本input_ids 无 batch 维度)
sample = self.gen_predict_sample(text, py, tokenizer)
input_ids = sample["hint"]["input_ids"]
attention_mask = sample["hint"]["attention_mask"]
pg = sample["pg"]
@ -375,50 +354,7 @@ class MoEModel(nn.Module):
logits = self(input_ids, attention_mask, pg)
preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch]
# ------------------ 4. 调试打印(错误样本) ------------------
if debug:
# 检查必需字段
required_keys = ["char_id", "txt", "char", "py"]
missing = [k for k in required_keys if k not in sample]
if missing:
raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}")
# 提取真实标签
true_labels = sample["char_id"]
if true_labels.dim() == 0:
true_labels = true_labels.unsqueeze(0)
# 移动真实标签到相同设备
true_labels = true_labels.to(self.device)
# 找出预测错误的索引
incorrect_mask = preds != true_labels
incorrect_indices = torch.where(incorrect_mask)[0]
if len(incorrect_indices) > 0:
print("\n=== 预测错误样本 ===")
# 获取调试字段(可能是列表或单个字符串)
txts = sample["txt"]
chars = sample["char"]
pys = sample["py"]
# 统一转换为列表(如果输入是单个字符串)
if isinstance(txts, str):
txts = [txts]
chars = [chars]
pys = [pys]
for idx in incorrect_indices.cpu().numpy():
print(f"样本索引 {idx}:")
print(f" Text : {txts[idx]}")
print(f" Char : {chars[idx]}")
print(f" Pinyin: {pys[idx]}")
print(
f" 预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}"
)
print("===================\n")
# ------------------ 5. 返回结果(保持与输入维度一致) ------------------
# ------------------ 4. 返回结果(保持与输入维度一致) ------------------
if not has_batch_dim:
return preds.squeeze(0) # 返回标量
return preds