From 8b41bcdc6f60562429c22ea583765ffb016e1e80 Mon Sep 17 00:00:00 2001 From: songsenand Date: Thu, 30 Apr 2026 08:10:34 +0800 Subject: [PATCH] =?UTF-8?q?feat(dataset):=20=E5=BC=95=E5=85=A5=E5=B9=82?= =?UTF-8?q?=E5=BE=8B=E5=B9=B3=E6=BB=91=E6=96=B9=E6=A1=88=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E9=A2=91=E7=8E=87=E8=B0=83=E6=95=B4=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- samples.csv | 0 src/model/dataset.py | 167 +++++++++++--------- src/model/inspect_preprocessed.py | 12 +- src/model/preprocess.py | 6 +- src/model/supplement_missing.py | 250 ++++++++++++++++++++++++++++++ src/model/trainer.py | 60 ++++++- 6 files changed, 405 insertions(+), 90 deletions(-) create mode 100644 samples.csv create mode 100644 src/model/supplement_missing.py diff --git a/samples.csv b/samples.csv new file mode 100644 index 0000000..e69de29 diff --git a/src/model/dataset.py b/src/model/dataset.py index 60125bc..9571291 100644 --- a/src/model/dataset.py +++ b/src/model/dataset.py @@ -1,4 +1,9 @@ +import warnings + +warnings.filterwarnings("ignore", message=".*pkg_resources.*") + import jieba +import math import random import re from importlib.resources import files @@ -69,13 +74,13 @@ class PinyinInputDataset(IterableDataset): merge_short_words_prob: float = 0.5, merge_max_short_words: int = 3, merge_max_total_chars: int = 6, + low_freq_repeat: float = 50.0, + high_freq_repeat: float = 0.1, ): - # 频率调整参数 (可根据需要调整) - self.drop_start_freq = 10_000_000 - self.max_drop_prob = 0.9 - self.repeat_end_freq = 10_000 - self.max_repeat_expect = 50 + # 频率调整参数 - 幂律平滑方案 self.min_freq = 109 + self.low_freq_repeat = low_freq_repeat + self.high_freq_repeat = high_freq_repeat self.word_break_prob = 0.10 self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04] self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0] @@ -117,50 +122,35 @@ class PinyinInputDataset(IterableDataset): self.sample_freqs = self.query_engine.get_all_weights() self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0 - def adjust_frequency(self, freq: int) -> int: - """削峰填谷 - 根据频率调整采样次数,0表示丢弃""" - # 1. 削峰处理(高频字) - if freq >= self.drop_start_freq: - # 线性丢弃概率计算 - max_freq = self.max_freq # 使用预计算的最大频率值 - if max_freq <= self.drop_start_freq: - drop_prob = 0.0 - else: - drop_prob = ( - self.max_drop_prob - * (freq - self.drop_start_freq) - / (max_freq - self.drop_start_freq) - ) - if random.random() < drop_prob: - return 0 - else: - return 1 - - # 2. 填谷处理(低频字) - elif freq <= self.repeat_end_freq: - # 线性重复期望计算 - if freq <= self.min_freq: - repeat_expect = self.max_repeat_expect - else: - if self.repeat_end_freq == self.min_freq: - repeat_expect = 0 - else: - repeat_expect = ( - self.max_repeat_expect - * (self.repeat_end_freq - freq) - / (self.repeat_end_freq - self.min_freq) - ) - # 使用泊松分布实现随机重复 - repeat_count = np.random.poisson(repeat_expect) - if repeat_expect < 1.0: - # 小期望值时,以概率 repeat_expect 采样 1 次 - return 1 if random.random() < repeat_expect else 0 - else: - return max(1, repeat_count) # 原逻辑 - - # 3. 中间频率字 + # 计算幂律平滑参数 + if self.max_freq > self.min_freq: + self.alpha = math.log( + self.low_freq_repeat / self.high_freq_repeat + ) / math.log(self.max_freq / self.min_freq) + self.C = self.low_freq_repeat * (self.min_freq**self.alpha) else: - return 1 + self.alpha = 0.0 + self.C = 1.0 + + def adjust_frequency(self, freq: int) -> int: + """削峰填谷 - 根据频率调整采样次数,0表示丢弃 + 使用幂律平滑方案:E(freq) = C × freq^(-α) + 保持频率排序关系,单个连续函数 + """ + if freq <= 0: + return 0 + + # 计算期望采样次数 + expected = self.C * (freq ** (-self.alpha)) + + # 采样策略 + if expected >= 1.0: + # 泊松分布重复 + repeat_count = np.random.poisson(expected) + return max(1, repeat_count) + else: + # 伯努利采样:以概率expected返回1,否则返回0 + return 1 if random.random() < expected else 0 # 生成对应文本的拼音 def generate_pinyin(self, text: str) -> List[str]: @@ -215,22 +205,33 @@ class PinyinInputDataset(IterableDataset): return result - # 生成需要预测汉字对应的拼音,并进行加强 def get_mask_pinyin( self, text: str, pinyin_list: List[str] ) -> Tuple[int, List[str]]: + # 整词统一拼音风格,避免多字词完整拼音概率指数衰减 + style = random.random() + cumulative = 0.0 + style_idx = 0 + for i, w in enumerate(self.py_style_weight): + cumulative += w + if style < cumulative: + style_idx = i + break + mask_pinyin = [] for i in range(len(text)): if not self.query_engine.is_chinese_char(text[i]): break - else: - py = np.random.choice( - (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]), - p=self.py_style_weight, - ) + full_py = pinyin_list[i] + if style_idx == 0: + py = full_py + elif style_idx == 1: + py = to_initials(full_py) if py == "": - py = pinyin_list[i][0] - mask_pinyin.append(py) + py = full_py[0] + else: + py = full_py[0] + mask_pinyin.append(py) return len(mask_pinyin), mask_pinyin def _compute_pinyin_ids(self, pinyin_str: str) -> torch.Tensor: @@ -271,9 +272,13 @@ class PinyinInputDataset(IterableDataset): history.extend([0] * (8 - len(history))) sample_dict = { - "input_ids": encoded["input_ids"], - "token_type_ids": encoded["token_type_ids"], - "attention_mask": encoded["attention_mask"], + "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long), + "token_type_ids": torch.tensor( + encoded["token_type_ids"], dtype=torch.long + ), + "attention_mask": torch.tensor( + encoded["attention_mask"], dtype=torch.long + ), "label": torch.tensor([label], dtype=torch.long), "history_slot_ids": torch.tensor(history, dtype=torch.long), "prefix": f"{part4}^{part1}", @@ -401,9 +406,15 @@ class PinyinInputDataset(IterableDataset): prefix_pinyin = [pinyin_list[i] for i in prefix_positions] _, mask_pinyin = self.get_mask_pinyin(prefix_text, prefix_pinyin) - split_char = np.random.choice( - ["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02] - ) + r = random.random() + if r < 0.9: + split_char = "" + elif r < 0.94: + split_char = "`" + elif r < 0.98: + split_char = "'" + else: + split_char = "-" part2 = split_char.join(mask_pinyin) pinyin_ids = self._compute_pinyin_ids(part2) @@ -431,7 +442,7 @@ class PinyinInputDataset(IterableDataset): # part3: 词后文本 part3 = "" if random.random() > 0.7: - part3 = text[word_end : word_end + np.random.choice(range(1, 17))] + part3 = text[word_end : word_end + random.randint(1, 16)] # part4: 词提示 part4 = "" @@ -447,9 +458,7 @@ class PinyinInputDataset(IterableDataset): f"{part4}|{part1}", part3, max_length=self.max_seq_length, - padding="max_length", truncation=True, - return_tensors="pt", return_token_type_ids=True, ) @@ -469,7 +478,15 @@ class PinyinInputDataset(IterableDataset): cont_start = char_positions[break_pos] # 续接目标:从断点开始,可延伸到后续词,遇到非汉字停止 - target_len = np.random.choice(range(1, 9), p=self.cont_length_probs) + cont_r = random.random() + cont_probs = self.cont_length_probs + cont_cumulative = 0.0 + target_len = 4 + for cont_len, cont_p in enumerate(cont_probs): + cont_cumulative += cont_p + if cont_r < cont_cumulative: + target_len = cont_len + 1 + break cont_positions = [] pos = cont_start while len(cont_positions) < target_len and pos < len(text): @@ -486,9 +503,15 @@ class PinyinInputDataset(IterableDataset): cont_pinyin = [pinyin_list[i] for i in cont_positions] _, mask_pinyin_cont = self.get_mask_pinyin(cont_text, cont_pinyin) - split_char_cont = np.random.choice( - ["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02] - ) + r2 = random.random() + if r2 < 0.9: + split_char_cont = "" + elif r2 < 0.94: + split_char_cont = "`" + elif r2 < 0.98: + split_char_cont = "'" + else: + split_char_cont = "-" part2_cont = split_char_cont.join(mask_pinyin_cont) pinyin_ids_cont = self._compute_pinyin_ids(part2_cont) @@ -517,17 +540,13 @@ class PinyinInputDataset(IterableDataset): cont_end = cont_positions[-1] + 1 part3_cont = "" if random.random() > 0.7: - part3_cont = text[ - cont_end : cont_end + np.random.choice(range(1, 17)) - ] + part3_cont = text[cont_end : cont_end + random.randint(1, 16)] encoded_cont = self.tokenizer( f"{part4}|{part1_cont}", part3_cont, max_length=self.max_seq_length, - padding="max_length", truncation=True, - return_tensors="pt", return_token_type_ids=True, ) diff --git a/src/model/inspect_preprocessed.py b/src/model/inspect_preprocessed.py index 80a7746..312cfe7 100644 --- a/src/model/inspect_preprocessed.py +++ b/src/model/inspect_preprocessed.py @@ -231,21 +231,21 @@ def main(): console.print("[bold yellow]====== Labels 分布分析 ======[/bold yellow]") counter, total = analyze_labels(dataset, max_shards=args.max_shards) - # 获取词表总大小 - vocab_size = len(query_engine._id_to_info) # 不含 EOS (id=0) + # 获取词表总大小(_id_to_info 已包含 EOS id=0) + vocab_size = len(query_engine._id_to_info) appeared_ids = set(counter.keys()) - all_ids = set(range(0, vocab_size + 1)) # +1 包含 id=0 (EOS) + all_ids = set(query_engine._id_to_info.keys()) missing_ids = all_ids - appeared_ids console.print(f"\n总样本数: {total:,}") - console.print(f"词表大小: {vocab_size + 1:,} (含 EOS)") + console.print(f"词表大小: {vocab_size:,} (含 EOS)") console.print(f"唯一标签数: {len(counter):,}") console.print( f"EOS (id=0) 出现次数: {counter.get(0, 0):,} ({counter.get(0, 0) / total * 100:.2f}%)" ) console.print( - f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size + 1:,} ({len(missing_ids) / (vocab_size + 1) * 100:.2f}%)[/bold red]" + f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size:,} ({len(missing_ids) / vocab_size * 100:.2f}%)[/bold red]" ) most_common = counter.most_common(args.top_k) @@ -328,7 +328,7 @@ def main(): table_dist.add_row( "未出现", f"{len(missing_ids):,}", - f"{len(missing_ids) / (vocab_size + 1) * 100:.1f}%", + f"{len(missing_ids) / vocab_size * 100:.1f}%", ) console.print(table_dist) diff --git a/src/model/preprocess.py b/src/model/preprocess.py index f5d8a6e..63eac20 100644 --- a/src/model/preprocess.py +++ b/src/model/preprocess.py @@ -40,7 +40,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm from .dataset import PinyinInputDataset -from .trainer import collate_fn, worker_init_fn +from .trainer import preprocess_collate_fn, worker_init_fn FIELDS = [ "input_ids", @@ -257,7 +257,7 @@ def main(): num_workers=num_train_workers, pin_memory=False, worker_init_fn=worker_init_fn, - collate_fn=collate_fn, + collate_fn=preprocess_collate_fn(args.max_seq_length), prefetch_factor=2, persistent_workers=True if num_train_workers > 0 else False, ) @@ -280,7 +280,7 @@ def main(): num_workers=num_eval_workers, pin_memory=False, worker_init_fn=worker_init_fn, - collate_fn=collate_fn, + collate_fn=preprocess_collate_fn(args.max_seq_length), prefetch_factor=2, persistent_workers=True if num_eval_workers > 0 else False, ) diff --git a/src/model/supplement_missing.py b/src/model/supplement_missing.py new file mode 100644 index 0000000..ba7dab5 --- /dev/null +++ b/src/model/supplement_missing.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +""" +缺失字符补充工具 + +步骤 1: find-missing — 扫描已预处理数据,找出从未出现的 label ID,输出 JSON +步骤 2: generate-template — 根据 JSON 生成 JSONL 占位文件,供用户手动填入包含缺失字的真实文本 + +用法: + python -m model.supplement_missing find-missing \ + --preprocessed-dir ./preprocessed/train \ + --output missing_chars.json + + python -m model.supplement_missing generate-template \ + --missing-chars missing_chars.json \ + --output supplement_texts.jsonl +""" + +import argparse +import json +from pathlib import Path +from typing import Set + +import numpy as np +from loguru import logger +from rich.console import Console +from rich.table import Table +from tqdm import tqdm + +from .query import QueryEngine + + +def scan_labels(preprocessed_dir: Path) -> Set[int]: + """扫描预处理目录中所有 .npz 分片,收集所有出现过的 label ID""" + appeared: Set[int] = set() + + shard_files = sorted(preprocessed_dir.glob("shard_*.npz")) + if not shard_files: + logger.warning(f"未找到 .npz 分片文件: {preprocessed_dir}") + return appeared + + for shard_path in tqdm(shard_files, desc="扫描分片", unit="shard"): + data = np.load(shard_path) + labels = data["labels"].astype(np.int64) + if labels.ndim > 1 and labels.shape[-1] == 1: + labels = labels.squeeze(-1) + unique_ids = np.unique(labels) + appeared.update(int(uid) for uid in unique_ids) + del data + + return appeared + + +def cmd_find_missing(args): + console = Console() + preprocessed_dir = Path(args.preprocessed_dir) + + if not preprocessed_dir.exists(): + console.print(f"[bold red]目录不存在: {preprocessed_dir}[/bold red]") + return + + metadata_path = preprocessed_dir / "metadata.json" + if not metadata_path.exists(): + console.print(f"[bold red]未找到 metadata.json: {metadata_path}[/bold red]") + return + + with open(metadata_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + console.print( + f"[bold cyan]预处理数据: {metadata['num_samples']:,} 样本, {metadata['num_shards']} 分片[/bold cyan]" + ) + + console.print("[bold cyan]扫描 labels...[/bold cyan]") + appeared = scan_labels(preprocessed_dir) + + console.print("[bold cyan]加载 QueryEngine...[/bold cyan]") + query_engine = QueryEngine() + query_engine.load() + + all_ids = set(query_engine._id_to_info.keys()) + missing_ids = all_ids - appeared + + missing_chars = [] + for mid in sorted(missing_ids): + if mid == 0: + continue + info = query_engine.query_by_id(mid) + if info is not None: + missing_chars.append( + { + "id": info.id, + "char": info.char, + "pinyin": info.pinyin, + "count": info.count, + } + ) + + result = { + "missing_count": len(missing_chars), + "missing_chars": missing_chars, + } + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(result, f, ensure_ascii=False, indent=2) + + console.print(f"\n[bold green]=== 扫描完成 ===[/bold green]") + console.print(f"词表大小: {len(all_ids):,} (含 EOS)") + console.print(f"已出现标签: {len(appeared):,}") + console.print( + f"[bold red]缺失标签: {len(missing_ids):,}[/bold red] (其中非 EOS: {len(missing_chars)})" + ) + + if missing_chars: + table = Table( + title=f"缺失字符 (共 {len(missing_chars)} 个)", + show_header=True, + header_style="bold magenta", + ) + table.add_column("ID", style="cyan", width=8) + table.add_column("字符", style="yellow", width=6) + table.add_column("拼音", style="green", width=12) + table.add_column("语料频次", style="red", width=12) + for entry in missing_chars: + table.add_row( + str(entry["id"]), + entry["char"], + entry["pinyin"], + f"{entry['count']:,}", + ) + console.print(table) + + console.print(f"\n已输出到: {output_path}") + + +def cmd_generate_template(args): + console = Console() + + missing_path = Path(args.missing_chars) + if not missing_path.exists(): + console.print(f"[bold red]文件不存在: {missing_path}[/bold red]") + return + + with open(missing_path, "r", encoding="utf-8") as f: + data = json.load(f) + + missing_chars = data.get("missing_chars", []) + if not missing_chars: + console.print("[bold green]没有缺失字符,无需生成模板[/bold green]") + return + + num_entries = args.num_entries + total_lines = len(missing_chars) * num_entries + + console.print(f"[bold cyan]缺失字符数: {len(missing_chars)}[/bold cyan]") + console.print(f"[bold cyan]每字符模板数: {num_entries}[/bold cyan]") + console.print(f"[bold cyan]总模板行数: {total_lines}[/bold cyan]") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w", encoding="utf-8") as f: + for entry in missing_chars: + for i in range(num_entries): + line = json.dumps( + {"text": f"请在这里输入包含「{entry['char']}」字的第{i + 1}条文本"}, + ensure_ascii=False, + ) + f.write(line + "\n") + + console.print(f"[bold green]模板已生成: {output_path}[/bold green]") + console.print( + f"共 {total_lines} 条({len(missing_chars)} 字符 × {num_entries} 条/字符)," + f"请手动编辑该文件,将占位文本替换为包含对应字符的真实文本。" + ) + + +def main(): + parser = argparse.ArgumentParser( + description="缺失字符补充工具", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +子命令: + find-missing 扫描已预处理数据,找出从未出现的 label ID + generate-template 根据缺失字符 JSON 生成 JSONL 占位文件 + +示例: + python -m model.supplement_missing find-missing \\ + --preprocessed-dir ./preprocessed/train \\ + --output missing_chars.json + + python -m model.supplement_missing generate-template \\ + --missing-chars missing_chars.json \\ + --output supplement_texts.jsonl +""", + ) + subparsers = parser.add_subparsers(dest="command", help="子命令") + + # find-missing + p_find = subparsers.add_parser("find-missing", help="扫描预处理数据,找出缺失标签") + p_find.add_argument( + "--preprocessed-dir", + type=str, + required=True, + help="预处理数据目录(包含 shard_*.npz 和 metadata.json)", + ) + p_find.add_argument( + "--output", + type=str, + default="missing_chars.json", + help="输出 JSON 文件路径(默认: missing_chars.json)", + ) + + # generate-template + p_gen = subparsers.add_parser("generate-template", help="生成补充文本模板") + p_gen.add_argument( + "--missing-chars", + type=str, + required=True, + help="缺失字符 JSON 文件路径(由 find-missing 生成)", + ) + p_gen.add_argument( + "--output", + type=str, + default="supplement_texts.jsonl", + help="输出 JSONL 文件路径(默认: supplement_texts.jsonl)", + ) + p_gen.add_argument( + "--num-entries", + type=int, + default=3, + help="每个缺失字符生成的模板条数(默认: 3)", + ) + + args = parser.parse_args() + + if args.command is None: + parser.print_help() + return + + if args.command == "find-missing": + cmd_find_missing(args) + elif args.command == "generate-template": + cmd_generate_template(args) + + +app = main + +if __name__ == "__main__": + main() diff --git a/src/model/trainer.py b/src/model/trainer.py index 1c27cfd..165c996 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -975,25 +975,62 @@ def worker_init_fn(worker_id: int) -> None: random.seed(worker_seed) -def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]: +def collate_fn(batch: List[Dict[str, Any]], max_seq_length: int = 0) -> Dict[str, Any]: """ - 自定义批处理函数,将多个样本组合成一个batch + 自定义批处理函数,将多个样本组合成一个batch。 + 支持动态填充:根据batch内最大序列长度进行padding,而非固定max_length。 + 当 max_seq_length > 0 时,pad到指定长度(用于预处理)。 Args: batch: 样本列表,每个样本是一个字典 + max_seq_length: 目标序列长度,0表示动态padding Returns: 批处理后的字典,tensor字段已stack,字符串字段保持为列表 """ - # 处理tensor字段 - 使用squeeze去除多余的batch维度 - input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch]) - token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch]) - attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch]) + input_ids_list = [item["input_ids"] for item in batch] + token_type_ids_list = [item["token_type_ids"] for item in batch] + attention_mask_list = [item["attention_mask"] for item in batch] + + if max_seq_length > 0: + target_len = max_seq_length + else: + target_len = max(ids.shape[0] for ids in input_ids_list) + + padded_input_ids = [] + padded_token_type_ids = [] + padded_attention_mask = [] + for ids, tt_ids, mask in zip( + input_ids_list, token_type_ids_list, attention_mask_list + ): + seq_len = ids.shape[0] + if seq_len < target_len: + pad_len = target_len - seq_len + padded_input_ids.append( + torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)]) + ) + padded_token_type_ids.append( + torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)]) + ) + padded_attention_mask.append( + torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)]) + ) + elif seq_len > target_len: + padded_input_ids.append(ids[:target_len]) + padded_token_type_ids.append(tt_ids[:target_len]) + padded_attention_mask.append(mask[:target_len]) + else: + padded_input_ids.append(ids) + padded_token_type_ids.append(tt_ids) + padded_attention_mask.append(mask) + + input_ids = torch.stack(padded_input_ids) + token_type_ids = torch.stack(padded_token_type_ids) + attention_mask = torch.stack(padded_attention_mask) labels = torch.stack([item["label"].squeeze(0) for item in batch]) history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch]) pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch]) - # 字符串字段保持为列表 prefixes = [item["prefix"] for item in batch] suffixes = [item["suffix"] for item in batch] pinyins = [item["pinyin"] for item in batch] @@ -1011,6 +1048,15 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]: } +def preprocess_collate_fn(max_seq_length: int): + """创建用于预处理的collate_fn,始终pad到max_seq_length""" + + def _collate(batch): + return collate_fn(batch, max_seq_length=max_seq_length) + + return _collate + + # Typer CLI应用 def create_dataloader( dataset,