feat(suinput): 引入拼音分组配置并优化上下文采样逻辑
This commit is contained in:
parent
2219f6530d
commit
fc71124484
|
|
@ -12,6 +12,32 @@ from torch.utils.data import DataLoader, IterableDataset
|
||||||
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
# Pinyin grouping table: maps the first letter of a pinyin string to a group
# id in 0-7. Each string below lists the letters of one group, and its
# position in the tuple is the group id. The letters "i", "u" and "v" have no
# entry. Consumers in this codebase use id 8 when the pinyin is empty.
PG = {
    letter: group_id
    for group_id, letters in enumerate(
        ("yke", "lwf", "qas", "xbr", "omz", "gnc", "tpd", "jh")
    )
    for letter in letters
}
|
||||||
|
|
||||||
|
|
||||||
class PinyinInputDataset(IterableDataset):
|
class PinyinInputDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
|
|
@ -44,6 +70,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||||
|
sample_context_section = [0.90, 0.95, 1]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化数据集
|
初始化数据集
|
||||||
|
|
@ -95,31 +122,12 @@ class PinyinInputDataset(IterableDataset):
|
||||||
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
self.dataset = load_dataset(data_dir, split="train", streaming=True)
|
||||||
|
|
||||||
# 加载拼音分组
|
# 加载拼音分组
|
||||||
self.pg_groups = {
|
self.pg_groups = PG
|
||||||
"y": 0,
|
|
||||||
"k": 0,
|
# 上下文采样方式概率区间
|
||||||
"e": 0,
|
self.sample_context_section = sample_context_section
|
||||||
"l": 1,
|
|
||||||
"w": 1,
|
|
||||||
"f": 1,
|
|
||||||
"q": 2,
|
|
||||||
"a": 2,
|
|
||||||
"s": 2,
|
|
||||||
"x": 3,
|
|
||||||
"b": 3,
|
|
||||||
"r": 3,
|
|
||||||
"o": 4,
|
|
||||||
"m": 4,
|
|
||||||
"z": 4,
|
|
||||||
"g": 5,
|
|
||||||
"n": 5,
|
|
||||||
"c": 5,
|
|
||||||
"t": 6,
|
|
||||||
"p": 6,
|
|
||||||
"d": 6,
|
|
||||||
"j": 7,
|
|
||||||
"h": 7,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_next_chinese_chars(
|
def get_next_chinese_chars(
|
||||||
self,
|
self,
|
||||||
|
|
@ -181,13 +189,13 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 确保有足够长度
|
# 确保有足够长度
|
||||||
context_len = len(context)
|
context_len = len(context)
|
||||||
|
|
||||||
# 随机选择采样方式 (各1/3概率)
|
# 随机选择采样方式
|
||||||
choice = random.random()
|
choice = random.random()
|
||||||
|
|
||||||
if choice < 0.3333:
|
if choice < self.sample_context_section[0]:
|
||||||
# 方式1: 靠近汉字的54个字符
|
# 方式1: 靠近汉字的54个字符
|
||||||
return context[-54:] if context_len >= 54 else context
|
return context[-54:] if context_len >= 54 else context
|
||||||
elif choice < 0.6667:
|
elif choice < self.sample_context_section[1]:
|
||||||
# 方式2: 随机位置取46个连续字符
|
# 方式2: 随机位置取46个连续字符
|
||||||
if context_len <= 46:
|
if context_len <= 46:
|
||||||
return context
|
return context
|
||||||
|
|
@ -250,11 +258,11 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if rand_val < 0.1:
|
if rand_val < 0.1:
|
||||||
# 10%概率截断为空
|
# 10%概率截断为空
|
||||||
return ""
|
return ""
|
||||||
elif rand_val < 0.6:
|
elif rand_val < 0.9:
|
||||||
# 50%概率不截断
|
# 80%概率不截断
|
||||||
return pinyin
|
return pinyin
|
||||||
else:
|
else:
|
||||||
# 40%概率随机截断
|
# 10%概率随机截断
|
||||||
# 均匀分配剩余概率给各种截断长度
|
# 均匀分配剩余概率给各种截断长度
|
||||||
max_len = len(pinyin)
|
max_len = len(pinyin)
|
||||||
if max_len <= 1:
|
if max_len <= 1:
|
||||||
|
|
@ -432,16 +440,23 @@ class PinyinInputDataset(IterableDataset):
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prob = random.random()
|
||||||
|
pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
|
||||||
|
if prob < 0.1:
|
||||||
|
py = ""
|
||||||
|
else:
|
||||||
|
py = processed_pinyin
|
||||||
|
|
||||||
# 生成样本
|
# 生成样本
|
||||||
sample = {
|
sample = {
|
||||||
"hint": hint,
|
"hint": hint,
|
||||||
"txt": sampled_context,
|
"txt": sampled_context,
|
||||||
"py": processed_pinyin,
|
"py": py,
|
||||||
"char_id": torch.tensor([char_info["id"]]),
|
"char_id": torch.tensor([char_info["id"]]),
|
||||||
"char": char,
|
"char": char,
|
||||||
"freq": char_info["freq"],
|
"freq": char_info["freq"],
|
||||||
"pg": torch.tensor(
|
"pg": torch.tensor(
|
||||||
[self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8]
|
[pg]
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -9,53 +9,17 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from modelscope import AutoModel
|
from modelscope import AutoModel, AutoTokenizer
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .monitor import TrainingMonitor
|
from .monitor import TrainingMonitor
|
||||||
|
from suinput.dataset import PG
|
||||||
|
|
||||||
|
|
||||||
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
|
def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")):
|
||||||
return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")]
|
return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")]
|
||||||
|
|
||||||
|
|
||||||
def round_to_power_of_two(x):
|
|
||||||
if x < 1:
|
|
||||||
return 0
|
|
||||||
n = x.bit_length()
|
|
||||||
n = min(max(7, n), 9)
|
|
||||||
lower = 1 << (n) # 小于等于x的最大2的幂次
|
|
||||||
upper = lower << 1 # 大于x的最小2的幂次
|
|
||||||
if x - lower < upper - x:
|
|
||||||
return lower
|
|
||||||
else:
|
|
||||||
return upper
|
|
||||||
|
|
||||||
|
|
||||||
EXPORT_HIDE_DIM = {
|
|
||||||
0: 1024,
|
|
||||||
1: 1024,
|
|
||||||
2: 1024,
|
|
||||||
3: 512,
|
|
||||||
4: 512,
|
|
||||||
5: 512,
|
|
||||||
6: 512,
|
|
||||||
7: 512,
|
|
||||||
8: 512,
|
|
||||||
9: 512,
|
|
||||||
10: 512,
|
|
||||||
11: 512,
|
|
||||||
12: 512,
|
|
||||||
13: 512,
|
|
||||||
14: 512,
|
|
||||||
15: 512,
|
|
||||||
16: 512,
|
|
||||||
17: 512,
|
|
||||||
18: 512,
|
|
||||||
19: 256,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------- 残差块 ----------------------------
|
# ---------------------------- 残差块 ----------------------------
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, dim, dropout_prob=0.1):
|
def __init__(self, dim, dropout_prob=0.1):
|
||||||
|
|
@ -328,23 +292,37 @@ class MoEModel(nn.Module):
|
||||||
accuracy = correct / total if total > 0 else 0.0
|
accuracy = correct / total if total > 0 else 0.0
|
||||||
return accuracy, avg_loss
|
return accuracy, avg_loss
|
||||||
|
|
||||||
def predict(self, sample, debug=False):
|
def gen_predict_sample(self, text, py, tokenizer=None):
    """
    Build the model input dict for one (text, pinyin) pair.

    Args:
        text: Context text, passed to the tokenizer as its first
            positional argument.
        py: Pinyin string, passed to the tokenizer as its second
            positional argument. May be empty.
        tokenizer: Optional tokenizer to use. When given it is also stored
            on ``self.tokenizer`` and reused by later calls. When omitted,
            the stored tokenizer is used, and the
            "iic/nlp_structbert_backbone_tiny_std" tokenizer is loaded on
            first use.

    Returns:
        dict with two entries:
            - ``"hint"``: ``{"input_ids", "attention_mask"}`` taken from the
              tokenizer output (requested with ``max_length=88``,
              ``padding="max_length"``, ``truncation=True``,
              ``return_tensors="pt"``).
            - ``"pg"``: one-element tensor holding the pinyin group id.

    Raises:
        KeyError: If ``py`` is non-empty and its first character is not a
            key of ``PG``.
    """
    # Explicit None checks: relying on truthiness (`tokenizer or ...`) would
    # silently ignore a tokenizer object that happens to evaluate as falsy.
    if tokenizer is not None:
        self.tokenizer = tokenizer
    elif not hasattr(self, "tokenizer"):
        self.tokenizer = AutoTokenizer.from_pretrained(
            "iic/nlp_structbert_backbone_tiny_std"
        )

    hint = self.tokenizer(
        text,
        py,
        max_length=88,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # An empty pinyin has no first letter to look up. Use group id 8, the
    # same value the training dataset assigns when the pinyin is empty,
    # instead of failing with IndexError on py[0].
    pg = PG[py[0]] if py else 8

    return {
        "hint": {
            "input_ids": hint["input_ids"],
            "attention_mask": hint["attention_mask"],
        },
        "pg": torch.tensor([pg]),
    }
|
||||||
|
|
||||||
|
def predict(self, text, py, tokenizer=None):
|
||||||
"""
|
"""
|
||||||
基于 sample 字典进行预测,支持批量/单样本,可选调试打印错误样本信息。
|
基于输入的文本和拼音,生成 sample 字典进行预测,支持批量/单样本,可选调试打印错误样本信息。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
sample : dict
|
|
||||||
必须包含字段:
|
|
||||||
- 'input_ids' : [batch, seq_len] 或 [seq_len] (单样本)
|
|
||||||
- 'attention_mask': 同上
|
|
||||||
- 'pg' : [batch] 或标量
|
|
||||||
- 'char_id' : [batch] 或标量,真实标签(当 debug=True 时必须提供)
|
|
||||||
调试时(debug=True)必须包含字段:
|
|
||||||
- 'txt' : 字符串列表(batch)或单个字符串
|
|
||||||
- 'char' : 字符串列表(batch)或单个字符串
|
|
||||||
- 'py' : 字符串列表(batch)或单个字符串
|
|
||||||
debug : bool
|
debug : bool
|
||||||
是否打印预测错误的样本信息。若为 True 但 sample 缺少 char_id/txt/char/py,抛出 ValueError。
|
是否打印预测错误的样本信息。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
preds : torch.Tensor
|
preds : torch.Tensor
|
||||||
|
|
@ -354,6 +332,7 @@ class MoEModel(nn.Module):
|
||||||
|
|
||||||
# ------------------ 1. 提取并规范化输入 ------------------
|
# ------------------ 1. 提取并规范化输入 ------------------
|
||||||
# 判断是否为单样本(input_ids 无 batch 维度)
|
# 判断是否为单样本(input_ids 无 batch 维度)
|
||||||
|
sample = self.gen_predict_sample(text, py, tokenizer)
|
||||||
input_ids = sample["hint"]["input_ids"]
|
input_ids = sample["hint"]["input_ids"]
|
||||||
attention_mask = sample["hint"]["attention_mask"]
|
attention_mask = sample["hint"]["attention_mask"]
|
||||||
pg = sample["pg"]
|
pg = sample["pg"]
|
||||||
|
|
@ -375,50 +354,7 @@ class MoEModel(nn.Module):
|
||||||
logits = self(input_ids, attention_mask, pg)
|
logits = self(input_ids, attention_mask, pg)
|
||||||
preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch]
|
preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch]
|
||||||
|
|
||||||
# ------------------ 4. 调试打印(错误样本) ------------------
|
# ------------------ 4. 返回结果(保持与输入维度一致) ------------------
|
||||||
if debug:
|
|
||||||
# 检查必需字段
|
|
||||||
required_keys = ["char_id", "txt", "char", "py"]
|
|
||||||
missing = [k for k in required_keys if k not in sample]
|
|
||||||
if missing:
|
|
||||||
raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}")
|
|
||||||
|
|
||||||
# 提取真实标签
|
|
||||||
true_labels = sample["char_id"]
|
|
||||||
if true_labels.dim() == 0:
|
|
||||||
true_labels = true_labels.unsqueeze(0)
|
|
||||||
|
|
||||||
# 移动真实标签到相同设备
|
|
||||||
true_labels = true_labels.to(self.device)
|
|
||||||
|
|
||||||
# 找出预测错误的索引
|
|
||||||
incorrect_mask = preds != true_labels
|
|
||||||
incorrect_indices = torch.where(incorrect_mask)[0]
|
|
||||||
|
|
||||||
if len(incorrect_indices) > 0:
|
|
||||||
print("\n=== 预测错误样本 ===")
|
|
||||||
# 获取调试字段(可能是列表或单个字符串)
|
|
||||||
txts = sample["txt"]
|
|
||||||
chars = sample["char"]
|
|
||||||
pys = sample["py"]
|
|
||||||
|
|
||||||
# 统一转换为列表(如果输入是单个字符串)
|
|
||||||
if isinstance(txts, str):
|
|
||||||
txts = [txts]
|
|
||||||
chars = [chars]
|
|
||||||
pys = [pys]
|
|
||||||
|
|
||||||
for idx in incorrect_indices.cpu().numpy():
|
|
||||||
print(f"样本索引 {idx}:")
|
|
||||||
print(f" Text : {txts[idx]}")
|
|
||||||
print(f" Char : {chars[idx]}")
|
|
||||||
print(f" Pinyin: {pys[idx]}")
|
|
||||||
print(
|
|
||||||
f" 预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}"
|
|
||||||
)
|
|
||||||
print("===================\n")
|
|
||||||
|
|
||||||
# ------------------ 5. 返回结果(保持与输入维度一致) ------------------
|
|
||||||
if not has_batch_dim:
|
if not has_batch_dim:
|
||||||
return preds.squeeze(0) # 返回标量
|
return preds.squeeze(0) # 返回标量
|
||||||
return preds
|
return preds
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue