feat(suinput): 引入拼音分组配置并优化上下文采样逻辑
This commit is contained in:
parent
2219f6530d
commit
fc71124484
|
|
@ -12,6 +12,32 @@ from torch.utils.data import DataLoader, IterableDataset
|
|||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Pinyin grouping table: maps the first letter of a pinyin string to a group
# id in 0..7. Each string below lists the letters of one group; the position
# of the string is the group id. The letters "i", "u" and "v" have no entry.
# Code that consumes this table uses 8 as the group id for an empty pinyin.
_PG_LETTERS_BY_GROUP = ("yke", "lwf", "qas", "xbr", "omz", "gnc", "tpd", "jh")

PG = {
    letter: group_id
    for group_id, letters in enumerate(_PG_LETTERS_BY_GROUP)
    for letter in letters
}
|
||||
|
||||
|
||||
class PinyinInputDataset(IterableDataset):
|
||||
"""
|
||||
|
|
@ -44,6 +70,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||
sample_context_section = [0.90, 0.95, 1]
|
||||
):
|
||||
"""
|
||||
初始化数据集
|
||||
|
|
@ -95,31 +122,12 @@ class PinyinInputDataset(IterableDataset):
|
|||
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
||||
|
||||
# 加载拼音分组
|
||||
self.pg_groups = {
|
||||
"y": 0,
|
||||
"k": 0,
|
||||
"e": 0,
|
||||
"l": 1,
|
||||
"w": 1,
|
||||
"f": 1,
|
||||
"q": 2,
|
||||
"a": 2,
|
||||
"s": 2,
|
||||
"x": 3,
|
||||
"b": 3,
|
||||
"r": 3,
|
||||
"o": 4,
|
||||
"m": 4,
|
||||
"z": 4,
|
||||
"g": 5,
|
||||
"n": 5,
|
||||
"c": 5,
|
||||
"t": 6,
|
||||
"p": 6,
|
||||
"d": 6,
|
||||
"j": 7,
|
||||
"h": 7,
|
||||
}
|
||||
self.pg_groups = PG
|
||||
|
||||
# 上下文采样方式概率区间
|
||||
self.sample_context_section = sample_context_section
|
||||
|
||||
|
||||
|
||||
def get_next_chinese_chars(
|
||||
self,
|
||||
|
|
@ -181,13 +189,13 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 确保有足够长度
|
||||
context_len = len(context)
|
||||
|
||||
# 随机选择采样方式 (各1/3概率)
|
||||
# 随机选择采样方式
|
||||
choice = random.random()
|
||||
|
||||
if choice < 0.3333:
|
||||
if choice < self.sample_context_section[0]:
|
||||
# 方式1: 靠近汉字的54个字符
|
||||
return context[-54:] if context_len >= 54 else context
|
||||
elif choice < 0.6667:
|
||||
elif choice < self.sample_context_section[1]:
|
||||
# 方式2: 随机位置取46个连续字符
|
||||
if context_len <= 46:
|
||||
return context
|
||||
|
|
@ -250,11 +258,11 @@ class PinyinInputDataset(IterableDataset):
|
|||
if rand_val < 0.1:
|
||||
# 10%概率截断为空
|
||||
return ""
|
||||
elif rand_val < 0.6:
|
||||
# 50%概率不截断
|
||||
elif rand_val < 0.9:
|
||||
# 80%概率不截断
|
||||
return pinyin
|
||||
else:
|
||||
# 40%概率随机截断
|
||||
# 10%概率随机截断
|
||||
# 均匀分配剩余概率给各种截断长度
|
||||
max_len = len(pinyin)
|
||||
if max_len <= 1:
|
||||
|
|
@ -432,16 +440,23 @@ class PinyinInputDataset(IterableDataset):
|
|||
return_tensors="pt",
|
||||
)
|
||||
|
||||
prob = random.random()
|
||||
pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
|
||||
if prob < 0.1:
|
||||
py = ""
|
||||
else:
|
||||
py = processed_pinyin
|
||||
|
||||
# 生成样本
|
||||
sample = {
|
||||
"hint": hint,
|
||||
"txt": sampled_context,
|
||||
"py": processed_pinyin,
|
||||
"py": py,
|
||||
"char_id": torch.tensor([char_info["id"]]),
|
||||
"char": char,
|
||||
"freq": char_info["freq"],
|
||||
"pg": torch.tensor(
|
||||
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
|
||||
[pg]
|
||||
),
|
||||
}
|
||||
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -9,53 +9,17 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from loguru import logger
|
||||
from modelscope import AutoModel
|
||||
from modelscope import AutoModel, AutoTokenizer
|
||||
from tqdm import tqdm
|
||||
|
||||
from .monitor import TrainingMonitor
|
||||
from suinput.dataset import PG
|
||||
|
||||
|
||||
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
    """Load every pickled object stored as ``*.pkl`` directly under ``path``.

    Args:
        path: Directory to scan. Defaults to the ``eval_dataset`` directory
            inside this package (resolved once, when the function is defined).

    Returns:
        A list with one unpickled object per ``*.pkl`` file, in the order
        returned by ``Path.glob`` (which is not guaranteed to be sorted).
    """
    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only point this at directories whose contents are trusted.
    loaded = []
    for file in Path(path).glob("*.pkl"):
        # Close each file deterministically instead of leaking the handle.
        with file.open("rb") as handle:
            loaded.append(pickle.load(handle))
    return loaded
|
||||
|
||||
|
||||
def round_to_power_of_two(x):
    """Map an integer to one of the powers of two 128, 256, 512 or 1024.

    Values below 1 return 0. Otherwise the exponent is ``x.bit_length()``
    clamped to the range 7..9, giving a candidate of 128, 256 or 512. Note
    that ``1 << x.bit_length()`` is the power of two strictly greater than
    ``x``, so for inputs up to 511 the candidate is always above ``x`` and is
    returned as is. Only when clamping leaves the candidate at 512 for a
    larger ``x`` can the next power (1024) win the distance comparison.

    Args:
        x: Integer input (``bit_length`` is called on it when ``x >= 1``).

    Returns:
        0 when ``x < 1``, otherwise the candidate power of two or its double,
        whichever is closer to ``x``; ties go to the larger value.
    """
    if x < 1:
        return 0

    exponent = max(7, min(9, x.bit_length()))
    candidate = 1 << exponent
    doubled = candidate << 1

    # Pick the nearer of the two powers; on an exact tie the larger one wins.
    if doubled - x > x - candidate:
        return candidate
    return doubled
|
||||
|
||||
|
||||
# Lookup table from an integer index (0..19) to a dimension size.
# NOTE(review): presumably the hidden dimension used per index when exporting
# the model; the consumer is not visible here, so confirm before relying on it.
EXPORT_HIDE_DIM = {
    **dict.fromkeys(range(0, 3), 1024),   # indices 0-2
    **dict.fromkeys(range(3, 19), 512),   # indices 3-18
    19: 256,
}
|
||||
|
||||
|
||||
# ---------------------------- 残差块 ----------------------------
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, dim, dropout_prob=0.1):
|
||||
|
|
@ -328,23 +292,37 @@ class MoEModel(nn.Module):
|
|||
accuracy = correct / total if total > 0 else 0.0
|
||||
return accuracy, avg_loss
|
||||
|
||||
def predict(self, sample, debug=False):
|
||||
def gen_predict_sample(self, text, py, tokenizer=None):
    """Build the model-input ``sample`` dict for one (text, pinyin) pair.

    Args:
        text: Context text, passed to the tokenizer as the first sequence.
        py: Pinyin string, passed to the tokenizer as the second sequence.
            Its first character selects the pinyin group via ``PG``. An empty
            string maps to group 8, the same fallback used when training
            samples are built.
        tokenizer: Optional tokenizer. When given, it replaces (and is cached
            as) ``self.tokenizer``. When omitted, the cached tokenizer is
            reused, or the default "iic/nlp_structbert_backbone_tiny_std"
            tokenizer is loaded on first use.

    Returns:
        dict with:
            - "hint": {"input_ids", "attention_mask"} tensors from the
              tokenizer (padded/truncated to ``max_length=88``).
            - "pg": a one-element tensor holding the pinyin group id.

    Raises:
        KeyError: if the first character of ``py`` has no entry in ``PG``.
    """
    # Explicit None check: a caller-supplied tokenizer always wins, without
    # relying on the tokenizer object's truthiness.
    if tokenizer is not None:
        self.tokenizer = tokenizer
    elif not hasattr(self, "tokenizer"):
        self.tokenizer = AutoTokenizer.from_pretrained(
            "iic/nlp_structbert_backbone_tiny_std"
        )

    hint = self.tokenizer(
        text,
        py,
        max_length=88,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Empty pinyin has no first letter to look up; use the dedicated group 8
    # instead of failing with IndexError on py[0].
    pg = PG[py[0]] if py else 8

    return {
        "hint": {
            "input_ids": hint["input_ids"],
            "attention_mask": hint["attention_mask"],
        },
        "pg": torch.tensor([pg]),
    }
|
||||
|
||||
def predict(self, text, py, tokenizer=None):
|
||||
"""
|
||||
基于 sample 字典进行预测,支持批量/单样本,可选调试打印错误样本信息。
|
||||
基于输入的文本和拼音,生成 sample 字典进行预测,支持批量/单样本,可选调试打印错误样本信息。
|
||||
|
||||
参数:
|
||||
sample : dict
|
||||
必须包含字段:
|
||||
- 'input_ids' : [batch, seq_len] 或 [seq_len] (单样本)
|
||||
- 'attention_mask': 同上
|
||||
- 'pg' : [batch] 或标量
|
||||
- 'char_id' : [batch] 或标量,真实标签(当 debug=True 时必须提供)
|
||||
调试时(debug=True)必须包含字段:
|
||||
- 'txt' : 字符串列表(batch)或单个字符串
|
||||
- 'char' : 字符串列表(batch)或单个字符串
|
||||
- 'py' : 字符串列表(batch)或单个字符串
|
||||
|
||||
debug : bool
|
||||
是否打印预测错误的样本信息。若为 True 但 sample 缺少 char_id/txt/char/py,抛出 ValueError。
|
||||
是否打印预测错误的样本信息。
|
||||
|
||||
返回:
|
||||
preds : torch.Tensor
|
||||
|
|
@ -354,6 +332,7 @@ class MoEModel(nn.Module):
|
|||
|
||||
# ------------------ 1. 提取并规范化输入 ------------------
|
||||
# 判断是否为单样本(input_ids 无 batch 维度)
|
||||
sample = self.gen_predict_sample(text, py, tokenizer)
|
||||
input_ids = sample["hint"]["input_ids"]
|
||||
attention_mask = sample["hint"]["attention_mask"]
|
||||
pg = sample["pg"]
|
||||
|
|
@ -375,50 +354,7 @@ class MoEModel(nn.Module):
|
|||
logits = self(input_ids, attention_mask, pg)
|
||||
preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch]
|
||||
|
||||
# ------------------ 4. 调试打印(错误样本) ------------------
|
||||
if debug:
|
||||
# 检查必需字段
|
||||
required_keys = ["char_id", "txt", "char", "py"]
|
||||
missing = [k for k in required_keys if k not in sample]
|
||||
if missing:
|
||||
raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}")
|
||||
|
||||
# 提取真实标签
|
||||
true_labels = sample["char_id"]
|
||||
if true_labels.dim() == 0:
|
||||
true_labels = true_labels.unsqueeze(0)
|
||||
|
||||
# 移动真实标签到相同设备
|
||||
true_labels = true_labels.to(self.device)
|
||||
|
||||
# 找出预测错误的索引
|
||||
incorrect_mask = preds != true_labels
|
||||
incorrect_indices = torch.where(incorrect_mask)[0]
|
||||
|
||||
if len(incorrect_indices) > 0:
|
||||
print("\n=== 预测错误样本 ===")
|
||||
# 获取调试字段(可能是列表或单个字符串)
|
||||
txts = sample["txt"]
|
||||
chars = sample["char"]
|
||||
pys = sample["py"]
|
||||
|
||||
# 统一转换为列表(如果输入是单个字符串)
|
||||
if isinstance(txts, str):
|
||||
txts = [txts]
|
||||
chars = [chars]
|
||||
pys = [pys]
|
||||
|
||||
for idx in incorrect_indices.cpu().numpy():
|
||||
print(f"样本索引 {idx}:")
|
||||
print(f" Text : {txts[idx]}")
|
||||
print(f" Char : {chars[idx]}")
|
||||
print(f" Pinyin: {pys[idx]}")
|
||||
print(
|
||||
f" 预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}"
|
||||
)
|
||||
print("===================\n")
|
||||
|
||||
# ------------------ 5. 返回结果(保持与输入维度一致) ------------------
|
||||
# ------------------ 4. 返回结果(保持与输入维度一致) ------------------
|
||||
if not has_batch_dim:
|
||||
return preds.squeeze(0) # 返回标量
|
||||
return preds
|
||||
|
|
|
|||
Loading…
Reference in New Issue