feat(dataset): 引入幂律平滑方案优化频率调整逻辑

This commit is contained in:
songsenand 2026-04-30 08:10:34 +08:00
parent 4ded2d656f
commit 8b41bcdc6f
6 changed files with 405 additions and 90 deletions

0
samples.csv Normal file
View File

View File

@ -1,4 +1,9 @@
import warnings
warnings.filterwarnings("ignore", message=".*pkg_resources.*")
import jieba
import math
import random
import re
from importlib.resources import files
@ -69,13 +74,13 @@ class PinyinInputDataset(IterableDataset):
merge_short_words_prob: float = 0.5,
merge_max_short_words: int = 3,
merge_max_total_chars: int = 6,
low_freq_repeat: float = 50.0,
high_freq_repeat: float = 0.1,
):
# 频率调整参数 (可根据需要调整)
self.drop_start_freq = 10_000_000
self.max_drop_prob = 0.9
self.repeat_end_freq = 10_000
self.max_repeat_expect = 50
# 频率调整参数 - 幂律平滑方案
self.min_freq = 109
self.low_freq_repeat = low_freq_repeat
self.high_freq_repeat = high_freq_repeat
self.word_break_prob = 0.10
self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
@ -117,50 +122,35 @@ class PinyinInputDataset(IterableDataset):
self.sample_freqs = self.query_engine.get_all_weights()
self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0
def adjust_frequency(self, freq: int) -> int:
"""削峰填谷 - 根据频率调整采样次数0表示丢弃"""
# 1. 削峰处理(高频字)
if freq >= self.drop_start_freq:
# 线性丢弃概率计算
max_freq = self.max_freq # 使用预计算的最大频率值
if max_freq <= self.drop_start_freq:
drop_prob = 0.0
else:
drop_prob = (
self.max_drop_prob
* (freq - self.drop_start_freq)
/ (max_freq - self.drop_start_freq)
)
if random.random() < drop_prob:
return 0
else:
return 1
# 2. 填谷处理(低频字)
elif freq <= self.repeat_end_freq:
# 线性重复期望计算
if freq <= self.min_freq:
repeat_expect = self.max_repeat_expect
else:
if self.repeat_end_freq == self.min_freq:
repeat_expect = 0
else:
repeat_expect = (
self.max_repeat_expect
* (self.repeat_end_freq - freq)
/ (self.repeat_end_freq - self.min_freq)
)
# 使用泊松分布实现随机重复
repeat_count = np.random.poisson(repeat_expect)
if repeat_expect < 1.0:
# 小期望值时,以概率 repeat_expect 采样 1 次
return 1 if random.random() < repeat_expect else 0
else:
return max(1, repeat_count) # 原逻辑
# 3. 中间频率字
# 计算幂律平滑参数
if self.max_freq > self.min_freq:
self.alpha = math.log(
self.low_freq_repeat / self.high_freq_repeat
) / math.log(self.max_freq / self.min_freq)
self.C = self.low_freq_repeat * (self.min_freq**self.alpha)
else:
return 1
self.alpha = 0.0
self.C = 1.0
def adjust_frequency(self, freq: int) -> int:
"""削峰填谷 - 根据频率调整采样次数0表示丢弃
使用幂律平滑方案E(freq) = C × freq^(-α)
保持频率排序关系单个连续函数
"""
if freq <= 0:
return 0
# 计算期望采样次数
expected = self.C * (freq ** (-self.alpha))
# 采样策略
if expected >= 1.0:
# 泊松分布重复
repeat_count = np.random.poisson(expected)
return max(1, repeat_count)
else:
# 伯努利采样以概率expected返回1否则返回0
return 1 if random.random() < expected else 0
# 生成对应文本的拼音
def generate_pinyin(self, text: str) -> List[str]:
@ -215,22 +205,33 @@ class PinyinInputDataset(IterableDataset):
return result
# 生成需要预测汉字对应的拼音,并进行加强
def get_mask_pinyin(
self, text: str, pinyin_list: List[str]
) -> Tuple[int, List[str]]:
# 整词统一拼音风格,避免多字词完整拼音概率指数衰减
style = random.random()
cumulative = 0.0
style_idx = 0
for i, w in enumerate(self.py_style_weight):
cumulative += w
if style < cumulative:
style_idx = i
break
mask_pinyin = []
for i in range(len(text)):
if not self.query_engine.is_chinese_char(text[i]):
break
else:
py = np.random.choice(
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
p=self.py_style_weight,
)
full_py = pinyin_list[i]
if style_idx == 0:
py = full_py
elif style_idx == 1:
py = to_initials(full_py)
if py == "":
py = pinyin_list[i][0]
mask_pinyin.append(py)
py = full_py[0]
else:
py = full_py[0]
mask_pinyin.append(py)
return len(mask_pinyin), mask_pinyin
def _compute_pinyin_ids(self, pinyin_str: str) -> torch.Tensor:
@ -271,9 +272,13 @@ class PinyinInputDataset(IterableDataset):
history.extend([0] * (8 - len(history)))
sample_dict = {
"input_ids": encoded["input_ids"],
"token_type_ids": encoded["token_type_ids"],
"attention_mask": encoded["attention_mask"],
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
"token_type_ids": torch.tensor(
encoded["token_type_ids"], dtype=torch.long
),
"attention_mask": torch.tensor(
encoded["attention_mask"], dtype=torch.long
),
"label": torch.tensor([label], dtype=torch.long),
"history_slot_ids": torch.tensor(history, dtype=torch.long),
"prefix": f"{part4}^{part1}",
@ -401,9 +406,15 @@ class PinyinInputDataset(IterableDataset):
prefix_pinyin = [pinyin_list[i] for i in prefix_positions]
_, mask_pinyin = self.get_mask_pinyin(prefix_text, prefix_pinyin)
split_char = np.random.choice(
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
)
r = random.random()
if r < 0.9:
split_char = ""
elif r < 0.94:
split_char = "`"
elif r < 0.98:
split_char = "'"
else:
split_char = "-"
part2 = split_char.join(mask_pinyin)
pinyin_ids = self._compute_pinyin_ids(part2)
@ -431,7 +442,7 @@ class PinyinInputDataset(IterableDataset):
# part3: 词后文本
part3 = ""
if random.random() > 0.7:
part3 = text[word_end : word_end + np.random.choice(range(1, 17))]
part3 = text[word_end : word_end + random.randint(1, 16)]
# part4: 词提示
part4 = ""
@ -447,9 +458,7 @@ class PinyinInputDataset(IterableDataset):
f"{part4}|{part1}",
part3,
max_length=self.max_seq_length,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=True,
)
@ -469,7 +478,15 @@ class PinyinInputDataset(IterableDataset):
cont_start = char_positions[break_pos]
# 续接目标:从断点开始,可延伸到后续词,遇到非汉字停止
target_len = np.random.choice(range(1, 9), p=self.cont_length_probs)
cont_r = random.random()
cont_probs = self.cont_length_probs
cont_cumulative = 0.0
target_len = 4
for cont_len, cont_p in enumerate(cont_probs):
cont_cumulative += cont_p
if cont_r < cont_cumulative:
target_len = cont_len + 1
break
cont_positions = []
pos = cont_start
while len(cont_positions) < target_len and pos < len(text):
@ -486,9 +503,15 @@ class PinyinInputDataset(IterableDataset):
cont_pinyin = [pinyin_list[i] for i in cont_positions]
_, mask_pinyin_cont = self.get_mask_pinyin(cont_text, cont_pinyin)
split_char_cont = np.random.choice(
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
)
r2 = random.random()
if r2 < 0.9:
split_char_cont = ""
elif r2 < 0.94:
split_char_cont = "`"
elif r2 < 0.98:
split_char_cont = "'"
else:
split_char_cont = "-"
part2_cont = split_char_cont.join(mask_pinyin_cont)
pinyin_ids_cont = self._compute_pinyin_ids(part2_cont)
@ -517,17 +540,13 @@ class PinyinInputDataset(IterableDataset):
cont_end = cont_positions[-1] + 1
part3_cont = ""
if random.random() > 0.7:
part3_cont = text[
cont_end : cont_end + np.random.choice(range(1, 17))
]
part3_cont = text[cont_end : cont_end + random.randint(1, 16)]
encoded_cont = self.tokenizer(
f"{part4}|{part1_cont}",
part3_cont,
max_length=self.max_seq_length,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=True,
)

View File

@ -231,21 +231,21 @@ def main():
console.print("[bold yellow]====== Labels 分布分析 ======[/bold yellow]")
counter, total = analyze_labels(dataset, max_shards=args.max_shards)
# 获取词表总大小
vocab_size = len(query_engine._id_to_info) # 不含 EOS (id=0)
# 获取词表总大小_id_to_info 已包含 EOS id=0
vocab_size = len(query_engine._id_to_info)
appeared_ids = set(counter.keys())
all_ids = set(range(0, vocab_size + 1)) # +1 包含 id=0 (EOS)
all_ids = set(query_engine._id_to_info.keys())
missing_ids = all_ids - appeared_ids
console.print(f"\n总样本数: {total:,}")
console.print(f"词表大小: {vocab_size + 1:,} (含 EOS)")
console.print(f"词表大小: {vocab_size:,} (含 EOS)")
console.print(f"唯一标签数: {len(counter):,}")
console.print(
f"EOS (id=0) 出现次数: {counter.get(0, 0):,} ({counter.get(0, 0) / total * 100:.2f}%)"
)
console.print(
f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size + 1:,} ({len(missing_ids) / (vocab_size + 1) * 100:.2f}%)[/bold red]"
f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size:,} ({len(missing_ids) / vocab_size * 100:.2f}%)[/bold red]"
)
most_common = counter.most_common(args.top_k)
@ -328,7 +328,7 @@ def main():
table_dist.add_row(
"未出现",
f"{len(missing_ids):,}",
f"{len(missing_ids) / (vocab_size + 1) * 100:.1f}%",
f"{len(missing_ids) / vocab_size * 100:.1f}%",
)
console.print(table_dist)

View File

@ -40,7 +40,7 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from .dataset import PinyinInputDataset
from .trainer import collate_fn, worker_init_fn
from .trainer import preprocess_collate_fn, worker_init_fn
FIELDS = [
"input_ids",
@ -257,7 +257,7 @@ def main():
num_workers=num_train_workers,
pin_memory=False,
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
collate_fn=preprocess_collate_fn(args.max_seq_length),
prefetch_factor=2,
persistent_workers=True if num_train_workers > 0 else False,
)
@ -280,7 +280,7 @@ def main():
num_workers=num_eval_workers,
pin_memory=False,
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
collate_fn=preprocess_collate_fn(args.max_seq_length),
prefetch_factor=2,
persistent_workers=True if num_eval_workers > 0 else False,
)

View File

@ -0,0 +1,250 @@
#!/usr/bin/env python3
"""
缺失字符补充工具
步骤 1: find-missing 扫描已预处理数据找出从未出现的 label ID输出 JSON
步骤 2: generate-template 根据 JSON 生成 JSONL 占位文件供用户手动填入包含缺失字的真实文本
用法
python -m model.supplement_missing find-missing \
--preprocessed-dir ./preprocessed/train \
--output missing_chars.json
python -m model.supplement_missing generate-template \
--missing-chars missing_chars.json \
--output supplement_texts.jsonl
"""
import argparse
import json
from pathlib import Path
from typing import Set
import numpy as np
from loguru import logger
from rich.console import Console
from rich.table import Table
from tqdm import tqdm
from .query import QueryEngine
def scan_labels(preprocessed_dir: Path) -> Set[int]:
"""扫描预处理目录中所有 .npz 分片,收集所有出现过的 label ID"""
appeared: Set[int] = set()
shard_files = sorted(preprocessed_dir.glob("shard_*.npz"))
if not shard_files:
logger.warning(f"未找到 .npz 分片文件: {preprocessed_dir}")
return appeared
for shard_path in tqdm(shard_files, desc="扫描分片", unit="shard"):
data = np.load(shard_path)
labels = data["labels"].astype(np.int64)
if labels.ndim > 1 and labels.shape[-1] == 1:
labels = labels.squeeze(-1)
unique_ids = np.unique(labels)
appeared.update(int(uid) for uid in unique_ids)
del data
return appeared
def cmd_find_missing(args):
console = Console()
preprocessed_dir = Path(args.preprocessed_dir)
if not preprocessed_dir.exists():
console.print(f"[bold red]目录不存在: {preprocessed_dir}[/bold red]")
return
metadata_path = preprocessed_dir / "metadata.json"
if not metadata_path.exists():
console.print(f"[bold red]未找到 metadata.json: {metadata_path}[/bold red]")
return
with open(metadata_path, "r", encoding="utf-8") as f:
metadata = json.load(f)
console.print(
f"[bold cyan]预处理数据: {metadata['num_samples']:,} 样本, {metadata['num_shards']} 分片[/bold cyan]"
)
console.print("[bold cyan]扫描 labels...[/bold cyan]")
appeared = scan_labels(preprocessed_dir)
console.print("[bold cyan]加载 QueryEngine...[/bold cyan]")
query_engine = QueryEngine()
query_engine.load()
all_ids = set(query_engine._id_to_info.keys())
missing_ids = all_ids - appeared
missing_chars = []
for mid in sorted(missing_ids):
if mid == 0:
continue
info = query_engine.query_by_id(mid)
if info is not None:
missing_chars.append(
{
"id": info.id,
"char": info.char,
"pinyin": info.pinyin,
"count": info.count,
}
)
result = {
"missing_count": len(missing_chars),
"missing_chars": missing_chars,
}
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
console.print(f"\n[bold green]=== 扫描完成 ===[/bold green]")
console.print(f"词表大小: {len(all_ids):,} (含 EOS)")
console.print(f"已出现标签: {len(appeared):,}")
console.print(
f"[bold red]缺失标签: {len(missing_ids):,}[/bold red] (其中非 EOS: {len(missing_chars)})"
)
if missing_chars:
table = Table(
title=f"缺失字符 (共 {len(missing_chars)} 个)",
show_header=True,
header_style="bold magenta",
)
table.add_column("ID", style="cyan", width=8)
table.add_column("字符", style="yellow", width=6)
table.add_column("拼音", style="green", width=12)
table.add_column("语料频次", style="red", width=12)
for entry in missing_chars:
table.add_row(
str(entry["id"]),
entry["char"],
entry["pinyin"],
f"{entry['count']:,}",
)
console.print(table)
console.print(f"\n已输出到: {output_path}")
def cmd_generate_template(args):
console = Console()
missing_path = Path(args.missing_chars)
if not missing_path.exists():
console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
return
with open(missing_path, "r", encoding="utf-8") as f:
data = json.load(f)
missing_chars = data.get("missing_chars", [])
if not missing_chars:
console.print("[bold green]没有缺失字符,无需生成模板[/bold green]")
return
num_entries = args.num_entries
total_lines = len(missing_chars) * num_entries
console.print(f"[bold cyan]缺失字符数: {len(missing_chars)}[/bold cyan]")
console.print(f"[bold cyan]每字符模板数: {num_entries}[/bold cyan]")
console.print(f"[bold cyan]总模板行数: {total_lines}[/bold cyan]")
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
for entry in missing_chars:
for i in range(num_entries):
line = json.dumps(
{"text": f"请在这里输入包含「{entry['char']}」字的第{i + 1}条文本"},
ensure_ascii=False,
)
f.write(line + "\n")
console.print(f"[bold green]模板已生成: {output_path}[/bold green]")
console.print(
f"{total_lines} 条({len(missing_chars)} 字符 × {num_entries} 条/字符),"
f"请手动编辑该文件,将占位文本替换为包含对应字符的真实文本。"
)
def main():
parser = argparse.ArgumentParser(
description="缺失字符补充工具",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
子命令
find-missing 扫描已预处理数据找出从未出现的 label ID
generate-template 根据缺失字符 JSON 生成 JSONL 占位文件
示例
python -m model.supplement_missing find-missing \\
--preprocessed-dir ./preprocessed/train \\
--output missing_chars.json
python -m model.supplement_missing generate-template \\
--missing-chars missing_chars.json \\
--output supplement_texts.jsonl
""",
)
subparsers = parser.add_subparsers(dest="command", help="子命令")
# find-missing
p_find = subparsers.add_parser("find-missing", help="扫描预处理数据,找出缺失标签")
p_find.add_argument(
"--preprocessed-dir",
type=str,
required=True,
help="预处理数据目录(包含 shard_*.npz 和 metadata.json",
)
p_find.add_argument(
"--output",
type=str,
default="missing_chars.json",
help="输出 JSON 文件路径(默认: missing_chars.json",
)
# generate-template
p_gen = subparsers.add_parser("generate-template", help="生成补充文本模板")
p_gen.add_argument(
"--missing-chars",
type=str,
required=True,
help="缺失字符 JSON 文件路径(由 find-missing 生成)",
)
p_gen.add_argument(
"--output",
type=str,
default="supplement_texts.jsonl",
help="输出 JSONL 文件路径(默认: supplement_texts.jsonl",
)
p_gen.add_argument(
"--num-entries",
type=int,
default=3,
help="每个缺失字符生成的模板条数(默认: 3",
)
args = parser.parse_args()
if args.command is None:
parser.print_help()
return
if args.command == "find-missing":
cmd_find_missing(args)
elif args.command == "generate-template":
cmd_generate_template(args)
app = main
if __name__ == "__main__":
main()

View File

@ -975,25 +975,62 @@ def worker_init_fn(worker_id: int) -> None:
random.seed(worker_seed)
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
def collate_fn(batch: List[Dict[str, Any]], max_seq_length: int = 0) -> Dict[str, Any]:
"""
自定义批处理函数将多个样本组合成一个batch
自定义批处理函数将多个样本组合成一个batch
支持动态填充根据batch内最大序列长度进行padding而非固定max_length
max_seq_length > 0 pad到指定长度用于预处理
Args:
batch: 样本列表每个样本是一个字典
max_seq_length: 目标序列长度0表示动态padding
Returns:
批处理后的字典tensor字段已stack字符串字段保持为列表
"""
# 处理tensor字段 - 使用squeeze去除多余的batch维度
input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
input_ids_list = [item["input_ids"] for item in batch]
token_type_ids_list = [item["token_type_ids"] for item in batch]
attention_mask_list = [item["attention_mask"] for item in batch]
if max_seq_length > 0:
target_len = max_seq_length
else:
target_len = max(ids.shape[0] for ids in input_ids_list)
padded_input_ids = []
padded_token_type_ids = []
padded_attention_mask = []
for ids, tt_ids, mask in zip(
input_ids_list, token_type_ids_list, attention_mask_list
):
seq_len = ids.shape[0]
if seq_len < target_len:
pad_len = target_len - seq_len
padded_input_ids.append(
torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)])
)
padded_token_type_ids.append(
torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)])
)
padded_attention_mask.append(
torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])
)
elif seq_len > target_len:
padded_input_ids.append(ids[:target_len])
padded_token_type_ids.append(tt_ids[:target_len])
padded_attention_mask.append(mask[:target_len])
else:
padded_input_ids.append(ids)
padded_token_type_ids.append(tt_ids)
padded_attention_mask.append(mask)
input_ids = torch.stack(padded_input_ids)
token_type_ids = torch.stack(padded_token_type_ids)
attention_mask = torch.stack(padded_attention_mask)
labels = torch.stack([item["label"].squeeze(0) for item in batch])
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
# 字符串字段保持为列表
prefixes = [item["prefix"] for item in batch]
suffixes = [item["suffix"] for item in batch]
pinyins = [item["pinyin"] for item in batch]
@ -1011,6 +1048,15 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
}
def preprocess_collate_fn(max_seq_length: int):
"""创建用于预处理的collate_fn始终pad到max_seq_length"""
def _collate(batch):
return collate_fn(batch, max_seq_length=max_seq_length)
return _collate
# Typer CLI应用
def create_dataloader(
dataset,