diff --git a/src/suinput/dataset.py b/src/suinput/dataset.py index 611ad24..f1921b5 100644 --- a/src/suinput/dataset.py +++ b/src/suinput/dataset.py @@ -12,6 +12,32 @@ from torch.utils.data import DataLoader, IterableDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" +PG = { + "y": 0, + "k": 0, + "e": 0, + "l": 1, + "w": 1, + "f": 1, + "q": 2, + "a": 2, + "s": 2, + "x": 3, + "b": 3, + "r": 3, + "o": 4, + "m": 4, + "z": 4, + "g": 5, + "n": 5, + "c": 5, + "t": 6, + "p": 6, + "d": 6, + "j": 7, + "h": 7, + } + class PinyinInputDataset(IterableDataset): """ @@ -44,6 +70,7 @@ class PinyinInputDataset(IterableDataset): repeat_end_freq: int = 10000, # 开始重复的阈值 max_drop_prob: float = 0.8, # 最大丢弃概率 max_repeat_expect: float = 50.0, # 最大重复期望 + sample_context_section = [0.90, 0.95, 1] ): """ 初始化数据集 @@ -95,32 +122,13 @@ class PinyinInputDataset(IterableDataset): self.dataset = load_dataset(data_dir, split="train", streaming=True) # 加载拼音分组 - self.pg_groups = { - "y": 0, - "k": 0, - "e": 0, - "l": 1, - "w": 1, - "f": 1, - "q": 2, - "a": 2, - "s": 2, - "x": 3, - "b": 3, - "r": 3, - "o": 4, - "m": 4, - "z": 4, - "g": 5, - "n": 5, - "c": 5, - "t": 6, - "p": 6, - "d": 6, - "j": 7, - "h": 7, - } + self.pg_groups = PG + # 上下文采样方式概率区间 + self.sample_context_section = sample_context_section + + + def get_next_chinese_chars( self, text: str, @@ -181,13 +189,13 @@ class PinyinInputDataset(IterableDataset): # 确保有足够长度 context_len = len(context) - # 随机选择采样方式 (各1/3概率) + # 随机选择采样方式 choice = random.random() - if choice < 0.3333: + if choice < self.sample_context_section[0]: # 方式1: 靠近汉字的54个字符 return context[-54:] if context_len >= 54 else context - elif choice < 0.6667: + elif choice < self.sample_context_section[1]: # 方式2: 随机位置取46个连续字符 if context_len <= 46: return context @@ -250,11 +258,11 @@ class PinyinInputDataset(IterableDataset): if rand_val < 0.1: # 10%概率截断为空 return "" - elif rand_val < 0.6: - # 50%概率不截断 + elif rand_val < 0.9: + # 80%概率不截断 return pinyin else: - # 40%概率随机截断 + # 10%概率随机截断 # 均匀分配剩余概率给各种截断长度 max_len = len(pinyin) if max_len <= 1: @@ -432,16 +440,23 @@ class PinyinInputDataset(IterableDataset): return_tensors="pt", ) + prob = random.random() + pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8 + if prob < 0.1: + py = "" + else: + py = processed_pinyin + # 生成样本 sample = { "hint": hint, "txt": sampled_context, - "py": processed_pinyin, + "py": py, "char_id": torch.tensor([char_info["id"]]), "char": char, "freq": char_info["freq"], "pg": torch.tensor( - [self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8] + [pg] ), } diff --git a/src/trainer/eval_dataset/sample_0.pkl b/src/trainer/eval_dataset/sample_0.pkl index 332a16e..036ff3d 100644 Binary files a/src/trainer/eval_dataset/sample_0.pkl and b/src/trainer/eval_dataset/sample_0.pkl differ diff --git a/src/trainer/eval_dataset/sample_1.pkl b/src/trainer/eval_dataset/sample_1.pkl index b743a6e..cf2c7e3 100644 Binary files a/src/trainer/eval_dataset/sample_1.pkl and b/src/trainer/eval_dataset/sample_1.pkl differ diff --git a/src/trainer/eval_dataset/sample_2.pkl b/src/trainer/eval_dataset/sample_2.pkl index 53f3dbd..a9da721 100644 Binary files a/src/trainer/eval_dataset/sample_2.pkl and b/src/trainer/eval_dataset/sample_2.pkl differ diff --git a/src/trainer/eval_dataset/sample_3.pkl b/src/trainer/eval_dataset/sample_3.pkl index 31147ce..9964f8e 100644 Binary files a/src/trainer/eval_dataset/sample_3.pkl and b/src/trainer/eval_dataset/sample_3.pkl differ diff --git a/src/trainer/eval_dataset/sample_4.pkl b/src/trainer/eval_dataset/sample_4.pkl index 750108c..bda53a5 100644 Binary files a/src/trainer/eval_dataset/sample_4.pkl and b/src/trainer/eval_dataset/sample_4.pkl differ diff --git a/src/trainer/model.py b/src/trainer/model.py index bc86280..eb04e65 100644 --- a/src/trainer/model.py +++ b/src/trainer/model.py @@ -9,53 +9,17 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from loguru import logger -from modelscope import AutoModel +from modelscope import AutoModel, AutoTokenizer from tqdm import tqdm from .monitor import TrainingMonitor +from suinput.dataset import PG def eval_dataloader(path: Union[str, Path] = (files(__package__) / "eval_dataset")): return [pickle.load(file.open("rb")) for file in Path(path).glob("*.pkl")] -def round_to_power_of_two(x): - if x < 1: - return 0 - n = x.bit_length() - n = min(max(7, n), 9) - lower = 1 << (n) # 小于等于x的最大2的幂次 - upper = lower << 1 # 大于x的最小2的幂次 - if x - lower < upper - x: - return lower - else: - return upper - - -EXPORT_HIDE_DIM = { - 0: 1024, - 1: 1024, - 2: 1024, - 3: 512, - 4: 512, - 5: 512, - 6: 512, - 7: 512, - 8: 512, - 9: 512, - 10: 512, - 11: 512, - 12: 512, - 13: 512, - 14: 512, - 15: 512, - 16: 512, - 17: 512, - 18: 512, - 19: 256, -} - - # ---------------------------- 残差块 ---------------------------- class ResidualBlock(nn.Module): def __init__(self, dim, dropout_prob=0.1): @@ -328,23 +292,37 @@ class MoEModel(nn.Module): accuracy = correct / total if total > 0 else 0.0 return accuracy, avg_loss - def predict(self, sample, debug=False): + def gen_predict_sample(self, text, py, tokenizer=None): + if tokenizer is None and not hasattr(self, "tokenizer"): + self.tokenizer = AutoTokenizer.from_pretrained( + "iic/nlp_structbert_backbone_tiny_std" + ) + else: + self.tokenizer = tokenizer or self.tokenizer + hint = self.tokenizer( + text, + py, + max_length=88, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + sample = {} + sample['hint'] = { + "input_ids": hint["input_ids"], + "attention_mask": hint["attention_mask"], + } + sample['pg'] = torch.tensor([PG[py[0]]]) + return sample + + def predict(self, text, py, tokenizer=None): """ - 基于 sample 字典进行预测,支持批量/单样本,可选调试打印错误样本信息。 + 基于输入的文本和拼音,生成 sample 字典进行预测,支持批量/单样本,可选调试打印错误样本信息。 参数: - sample : dict - 必须包含字段: - - 'input_ids' : [batch, seq_len] 或 [seq_len] (单样本) - - 'attention_mask': 同上 - - 'pg' : [batch] 或标量 - - 'char_id' : [batch] 或标量,真实标签(当 debug=True 时必须提供) - 调试时(debug=True)必须包含字段: - - 'txt' : 字符串列表(batch)或单个字符串 - - 'char' : 字符串列表(batch)或单个字符串 - - 'py' : 字符串列表(batch)或单个字符串 + debug : bool - 是否打印预测错误的样本信息。若为 True 但 sample 缺少 char_id/txt/char/py,抛出 ValueError。 + 是否打印预测错误的样本信息。 返回: preds : torch.Tensor @@ -354,6 +332,7 @@ class MoEModel(nn.Module): # ------------------ 1. 提取并规范化输入 ------------------ # 判断是否为单样本(input_ids 无 batch 维度) + sample = self.gen_predict_sample(text, py, tokenizer) input_ids = sample["hint"]["input_ids"] attention_mask = sample["hint"]["attention_mask"] pg = sample["pg"] @@ -375,50 +354,7 @@ class MoEModel(nn.Module): logits = self(input_ids, attention_mask, pg) preds = torch.softmax(logits, dim=-1).argmax(dim=-1) # [batch] - # ------------------ 4. 调试打印(错误样本) ------------------ - if debug: - # 检查必需字段 - required_keys = ["char_id", "txt", "char", "py"] - missing = [k for k in required_keys if k not in sample] - if missing: - raise ValueError(f"debug=True 时 sample 必须包含字段: {missing}") - - # 提取真实标签 - true_labels = sample["char_id"] - if true_labels.dim() == 0: - true_labels = true_labels.unsqueeze(0) - - # 移动真实标签到相同设备 - true_labels = true_labels.to(self.device) - - # 找出预测错误的索引 - incorrect_mask = preds != true_labels - incorrect_indices = torch.where(incorrect_mask)[0] - - if len(incorrect_indices) > 0: - print("\n=== 预测错误样本 ===") - # 获取调试字段(可能是列表或单个字符串) - txts = sample["txt"] - chars = sample["char"] - pys = sample["py"] - - # 统一转换为列表(如果输入是单个字符串) - if isinstance(txts, str): - txts = [txts] - chars = [chars] - pys = [pys] - - for idx in incorrect_indices.cpu().numpy(): - print(f"样本索引 {idx}:") - print(f" Text : {txts[idx]}") - print(f" Char : {chars[idx]}") - print(f" Pinyin: {pys[idx]}") - print( - f" 预测标签: {preds[idx].item()}, 真实标签: {true_labels[idx].item()}" - ) - print("===================\n") - - # ------------------ 5. 返回结果(保持与输入维度一致) ------------------ + # ------------------ 4. 返回结果(保持与输入维度一致) ------------------ if not has_batch_dim: return preds.squeeze(0) # 返回标量 return preds