feat(suinput): 引入拼音分组配置并优化上下文采样逻辑

This commit is contained in:
songsen · 2026-02-22 09:30:39 +08:00
parent 2219f6530d
commit fc71124484
7 changed files with 79 additions and 128 deletions

View File

@ -12,6 +12,32 @@ from torch.utils.data import DataLoader, IterableDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Pinyin first-letter -> group id ("pg") attached to each sample.
# Letters are bucketed three per group (the last group holds two); the
# letters "i", "u", "v" never begin a valid pinyin syllable and are
# therefore absent. Group id 8 is used elsewhere in this file for the
# empty-pinyin case, so ids here stop at 7.
_PG_LETTER_GROUPS = ("yke", "lwf", "qas", "xbr", "omz", "gnc", "tpd", "jh")
PG = {
    letter: group_id
    for group_id, letters in enumerate(_PG_LETTER_GROUPS)
    for letter in letters
}
class PinyinInputDataset(IterableDataset): class PinyinInputDataset(IterableDataset):
""" """
@ -44,6 +70,7 @@ class PinyinInputDataset(IterableDataset):
repeat_end_freq: int = 10000, # 开始重复的阈值 repeat_end_freq: int = 10000, # 开始重复的阈值
max_drop_prob: float = 0.8, # 最大丢弃概率 max_drop_prob: float = 0.8, # 最大丢弃概率
max_repeat_expect: float = 50.0, # 最大重复期望 max_repeat_expect: float = 50.0, # 最大重复期望
sample_context_section = [0.90, 0.95, 1]
): ):
""" """
初始化数据集 初始化数据集
@ -95,31 +122,12 @@ class PinyinInputDataset(IterableDataset):
self.dataset = load_dataset(data_dir, split="train", streaming=True) self.dataset = load_dataset(data_dir, split="train", streaming=True)
# 加载拼音分组 # 加载拼音分组
self.pg_groups = { self.pg_groups = PG
"y": 0,
"k": 0, # 上下文采样方式概率区间
"e": 0, self.sample_context_section = sample_context_section
"l": 1,
"w": 1,
"f": 1,
"q": 2,
"a": 2,
"s": 2,
"x": 3,
"b": 3,
"r": 3,
"o": 4,
"m": 4,
"z": 4,
"g": 5,
"n": 5,
"c": 5,
"t": 6,
"p": 6,
"d": 6,
"j": 7,
"h": 7,
}
def get_next_chinese_chars( def get_next_chinese_chars(
self, self,
@ -181,13 +189,13 @@ class PinyinInputDataset(IterableDataset):
# 确保有足够长度 # 确保有足够长度
context_len = len(context) context_len = len(context)
# 随机选择采样方式 (各1/3概率) # 随机选择采样方式
choice = random.random() choice = random.random()
if choice < 0.3333: if choice < self.sample_context_section[0]:
# 方式1: 靠近汉字的54个字符 # 方式1: 靠近汉字的54个字符
return context[-54:] if context_len >= 54 else context return context[-54:] if context_len >= 54 else context
elif choice < 0.6667: elif choice < self.sample_context_section[1]:
# 方式2: 随机位置取46个连续字符 # 方式2: 随机位置取46个连续字符
if context_len <= 46: if context_len <= 46:
return context return context
@ -250,11 +258,11 @@ class PinyinInputDataset(IterableDataset):
if rand_val < 0.1: if rand_val < 0.1:
# 10%概率截断为空 # 10%概率截断为空
return "" return ""
elif rand_val < 0.6: elif rand_val < 0.9:
# 50%概率不截断 # 80%概率不截断
return pinyin return pinyin
else: else:
# 40%概率随机截断 # 10%概率随机截断
# 均匀分配剩余概率给各种截断长度 # 均匀分配剩余概率给各种截断长度
max_len = len(pinyin) max_len = len(pinyin)
if max_len <= 1: if max_len <= 1:
@ -432,16 +440,23 @@ class PinyinInputDataset(IterableDataset):
return_tensors="pt", return_tensors="pt",
) )
prob = random.random()
pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
if prob < 0.1:
py = ""
else:
py = processed_pinyin
# 生成样本 # 生成样本
sample = { sample = {
"hint": hint, "hint": hint,
"txt": sampled_context, "txt": sampled_context,
"py": processed_pinyin, "py": py,
"char_id": torch.tensor([char_info["id"]]), "char_id": torch.tensor([char_info["id"]]),
"char": char, "char": char,
"freq": char_info["freq"], "freq": char_info["freq"],
"pg": torch.tensor( "pg": torch.tensor(
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8] [pg]
), ),
} }

View File

@ -9,53 +9,17 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from loguru import logger from loguru import logger
from modelscope import AutoModel from modelscope import AutoModel, AutoTokenizer
from tqdm import tqdm from tqdm import tqdm
from .monitor import TrainingMonitor from .monitor import TrainingMonitor
from suinput.dataset import PG
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
    """Load every pickled evaluation sample stored as ``*.pkl`` under *path*.

    Parameters
    ----------
    path : str | Path
        Directory containing the pickled samples; defaults to the
        ``eval_dataset`` directory packaged with this module.

    Returns
    -------
    list
        One deserialized object per ``*.pkl`` file found.

    Fix: the original opened each file with ``file.open("rb")`` inside a
    list comprehension and never closed the handles (resource leak);
    each file is now read inside a ``with`` block.
    NOTE(review): ``pickle.load`` assumes these files are trusted
    artifacts produced by this project — never point *path* at
    untrusted data.
    """
    samples = []
    for pkl_file in Path(path).glob("*.pkl"):
        with pkl_file.open("rb") as fh:
            samples.append(pickle.load(fh))
    return samples
def round_to_power_of_two(x):
    """Snap a positive integer onto the nearer of two adjacent powers of two.

    The candidate pair is derived from ``x.bit_length()`` clamped to the
    range [7, 9], so the result is always one of 128, 256, 512 or 1024
    (values below 1 return 0). Note that ``1 << bit_length(x)`` is the
    smallest power of two strictly greater than ``x``, so for small ``x``
    both candidates can exceed ``x`` and the lower one (128) wins.
    """
    if x < 1:
        return 0
    exponent = min(max(x.bit_length(), 7), 9)  # clamp to [7, 9]
    low_candidate = 1 << exponent        # one of 128 / 256 / 512
    high_candidate = low_candidate << 1  # one of 256 / 512 / 1024
    # Pick whichever candidate is closer to x (ties go to the higher one).
    if x - low_candidate < high_candidate - x:
        return low_candidate
    return high_candidate
# Hidden dimension used at export time, keyed by index 0-19 (presumably a
# layer/expert index — TODO confirm against the export code): the first
# three entries are widest (1024), the last is narrowest (256), and
# everything in between is 512.
EXPORT_HIDE_DIM = {
    idx: 1024 if idx <= 2 else (256 if idx == 19 else 512)
    for idx in range(20)
}
# ---------------------------- 残差块 ---------------------------- # ---------------------------- 残差块 ----------------------------
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, dim, dropout_prob=0.1): def __init__(self, dim, dropout_prob=0.1):
@ -328,23 +292,37 @@ class MoEModel(nn.Module):
accuracy = correct / total if total > 0 else 0.0 accuracy = correct / total if total > 0 else 0.0
return accuracy, avg_loss return accuracy, avg_loss
def gen_predict_sample(self, text, py, tokenizer=None):
    """Build a single inference sample from raw context text and pinyin.

    Parameters
    ----------
    text : str
        Context string preceding the character to predict.
    py : str
        Pinyin (possibly truncated) of the target character; may be "".
    tokenizer : optional
        Tokenizer to use. If omitted, a previously cached tokenizer is
        reused, or the default StructBERT tokenizer is loaded lazily and
        cached on ``self``.

    Returns
    -------
    dict
        ``{"hint": {"input_ids", "attention_mask"}, "pg": tensor}`` as
        expected by ``predict``.
    """
    # Resolve the tokenizer: an explicit argument wins, then the cached
    # one, then the lazily loaded default pretrained tokenizer.
    if tokenizer is not None:
        self.tokenizer = tokenizer
    elif not hasattr(self, "tokenizer"):
        self.tokenizer = AutoTokenizer.from_pretrained(
            "iic/nlp_structbert_backbone_tiny_std"
        )
    hint = self.tokenizer(
        text,
        py,
        max_length=88,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # Fix: the original indexed py[0] unconditionally and raised
    # IndexError for empty pinyin; the training dataset maps empty
    # pinyin to group 8, so mirror that mapping here.
    group = PG[py[0]] if py else 8
    return {
        "hint": {
            "input_ids": hint["input_ids"],
            "attention_mask": hint["attention_mask"],
        },
        "pg": torch.tensor([group]),
    }
def predict(self, text, py, tokenizer=None):
""" """
基于 sample 字典进行预测支持批量/单样本可选调试打印错误样本信息 基于输入的文本和拼音生成 sample 字典进行预测支持批量/单样本可选调试打印错误样本信息
参数 参数
sample : dict
必须包含字段
- 'input_ids' : [batch, seq_len] [seq_len] (单样本)
- 'attention_mask': 同上
- 'pg' : [batch] 或标量
- 'char_id' : [batch] 或标量真实标签 debug=True 时必须提供
调试时debug=True必须包含字段
- 'txt' : 字符串列表batch或单个字符串
- 'char' : 字符串列表batch或单个字符串
- 'py' : 字符串列表batch或单个字符串
debug : bool debug : bool
是否打印预测错误的样本信息若为 True sample 缺少 char_id/txt/char/py抛出 ValueError 是否打印预测错误的样本信息
返回 返回
preds : torch.Tensor preds : torch.Tensor
@ -354,6 +332,7 @@ class MoEModel(nn.Module):
# ------------------ 1. 提取并规范化输入 ------------------ # ------------------ 1. 提取并规范化输入 ------------------
# 判断是否为单样本input_ids 无 batch 维度) # 判断是否为单样本input_ids 无 batch 维度)
sample = self.gen_predict_sample(text, py, tokenizer)
input_ids = sample["hint"]["input_ids"] input_ids = sample["hint"]["input_ids"]
attention_mask = sample["hint"]["attention_mask"] attention_mask = sample["hint"]["attention_mask"]
pg = sample["pg"] pg = sample["pg"]
@ -375,50 +354,7 @@ class MoEModel(nn.Module):
logits = self(input_ids, attention_mask, pg) logits = self(input_ids, attention_mask, pg)
preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch] preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch]
# ------------------ 4. 调试打印(错误样本) ------------------ # ------------------ 4. 返回结果(保持与输入维度一致) ------------------
if debug:
# 检查必需字段
required_keys = ["char_id", "txt", "char", "py"]
missing = [k for k in required_keys if k not in sample]
if missing:
raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}")
# 提取真实标签
true_labels = sample["char_id"]
if true_labels.dim() == 0:
true_labels = true_labels.unsqueeze(0)
# 移动真实标签到相同设备
true_labels = true_labels.to(self.device)
# 找出预测错误的索引
incorrect_mask = preds != true_labels
incorrect_indices = torch.where(incorrect_mask)[0]
if len(incorrect_indices) > 0:
print("\n=== 预测错误样本 ===")
# 获取调试字段(可能是列表或单个字符串)
txts = sample["txt"]
chars = sample["char"]
pys = sample["py"]
# 统一转换为列表(如果输入是单个字符串)
if isinstance(txts, str):
txts = [txts]
chars = [chars]
pys = [pys]
for idx in incorrect_indices.cpu().numpy():
print(f"样本索引 {idx}:")
print(f" Text : {txts[idx]}")
print(f" Char : {chars[idx]}")
print(f" Pinyin: {pys[idx]}")
print(
f" 预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}"
)
print("===================\n")
# ------------------ 5. 返回结果(保持与输入维度一致) ------------------
if not has_batch_dim: if not has_batch_dim:
return preds.squeeze(0) # 返回标量 return preds.squeeze(0) # 返回标量
return preds return preds