feat(dataset): 引入幂律平滑方案优化频率调整逻辑
This commit is contained in:
parent
4ded2d656f
commit
8b41bcdc6f
|
|
|
|
@ -1,4 +1,9 @@
|
|||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", message=".*pkg_resources.*")
|
||||
|
||||
import jieba
|
||||
import math
|
||||
import random
|
||||
import re
|
||||
from importlib.resources import files
|
||||
|
|
@ -69,13 +74,13 @@ class PinyinInputDataset(IterableDataset):
|
|||
merge_short_words_prob: float = 0.5,
|
||||
merge_max_short_words: int = 3,
|
||||
merge_max_total_chars: int = 6,
|
||||
low_freq_repeat: float = 50.0,
|
||||
high_freq_repeat: float = 0.1,
|
||||
):
|
||||
# 频率调整参数 (可根据需要调整)
|
||||
self.drop_start_freq = 10_000_000
|
||||
self.max_drop_prob = 0.9
|
||||
self.repeat_end_freq = 10_000
|
||||
self.max_repeat_expect = 50
|
||||
# 频率调整参数 - 幂律平滑方案
|
||||
self.min_freq = 109
|
||||
self.low_freq_repeat = low_freq_repeat
|
||||
self.high_freq_repeat = high_freq_repeat
|
||||
self.word_break_prob = 0.10
|
||||
self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
|
||||
self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
|
||||
|
|
@ -117,50 +122,35 @@ class PinyinInputDataset(IterableDataset):
|
|||
self.sample_freqs = self.query_engine.get_all_weights()
|
||||
self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0
|
||||
|
||||
# 计算幂律平滑参数
|
||||
if self.max_freq > self.min_freq:
|
||||
self.alpha = math.log(
|
||||
self.low_freq_repeat / self.high_freq_repeat
|
||||
) / math.log(self.max_freq / self.min_freq)
|
||||
self.C = self.low_freq_repeat * (self.min_freq**self.alpha)
|
||||
else:
|
||||
self.alpha = 0.0
|
||||
self.C = 1.0
|
||||
|
||||
def adjust_frequency(self, freq: int) -> int:
|
||||
"""削峰填谷 - 根据频率调整采样次数,0表示丢弃"""
|
||||
# 1. 削峰处理(高频字)
|
||||
if freq >= self.drop_start_freq:
|
||||
# 线性丢弃概率计算
|
||||
max_freq = self.max_freq # 使用预计算的最大频率值
|
||||
if max_freq <= self.drop_start_freq:
|
||||
drop_prob = 0.0
|
||||
else:
|
||||
drop_prob = (
|
||||
self.max_drop_prob
|
||||
* (freq - self.drop_start_freq)
|
||||
/ (max_freq - self.drop_start_freq)
|
||||
)
|
||||
if random.random() < drop_prob:
|
||||
"""削峰填谷 - 根据频率调整采样次数,0表示丢弃
|
||||
使用幂律平滑方案:E(freq) = C × freq^(-α)
|
||||
保持频率排序关系,单个连续函数
|
||||
"""
|
||||
if freq <= 0:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
# 2. 填谷处理(低频字)
|
||||
elif freq <= self.repeat_end_freq:
|
||||
# 线性重复期望计算
|
||||
if freq <= self.min_freq:
|
||||
repeat_expect = self.max_repeat_expect
|
||||
else:
|
||||
if self.repeat_end_freq == self.min_freq:
|
||||
repeat_expect = 0
|
||||
else:
|
||||
repeat_expect = (
|
||||
self.max_repeat_expect
|
||||
* (self.repeat_end_freq - freq)
|
||||
/ (self.repeat_end_freq - self.min_freq)
|
||||
)
|
||||
# 使用泊松分布实现随机重复
|
||||
repeat_count = np.random.poisson(repeat_expect)
|
||||
if repeat_expect < 1.0:
|
||||
# 小期望值时,以概率 repeat_expect 采样 1 次
|
||||
return 1 if random.random() < repeat_expect else 0
|
||||
else:
|
||||
return max(1, repeat_count) # 原逻辑
|
||||
# 计算期望采样次数
|
||||
expected = self.C * (freq ** (-self.alpha))
|
||||
|
||||
# 3. 中间频率字
|
||||
# 采样策略
|
||||
if expected >= 1.0:
|
||||
# 泊松分布重复
|
||||
repeat_count = np.random.poisson(expected)
|
||||
return max(1, repeat_count)
|
||||
else:
|
||||
return 1
|
||||
# 伯努利采样:以概率expected返回1,否则返回0
|
||||
return 1 if random.random() < expected else 0
|
||||
|
||||
# 生成对应文本的拼音
|
||||
def generate_pinyin(self, text: str) -> List[str]:
|
||||
|
|
@ -215,21 +205,32 @@ class PinyinInputDataset(IterableDataset):
|
|||
|
||||
return result
|
||||
|
||||
# 生成需要预测汉字对应的拼音,并进行加强
|
||||
def get_mask_pinyin(
|
||||
self, text: str, pinyin_list: List[str]
|
||||
) -> Tuple[int, List[str]]:
|
||||
# 整词统一拼音风格,避免多字词完整拼音概率指数衰减
|
||||
style = random.random()
|
||||
cumulative = 0.0
|
||||
style_idx = 0
|
||||
for i, w in enumerate(self.py_style_weight):
|
||||
cumulative += w
|
||||
if style < cumulative:
|
||||
style_idx = i
|
||||
break
|
||||
|
||||
mask_pinyin = []
|
||||
for i in range(len(text)):
|
||||
if not self.query_engine.is_chinese_char(text[i]):
|
||||
break
|
||||
else:
|
||||
py = np.random.choice(
|
||||
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
|
||||
p=self.py_style_weight,
|
||||
)
|
||||
full_py = pinyin_list[i]
|
||||
if style_idx == 0:
|
||||
py = full_py
|
||||
elif style_idx == 1:
|
||||
py = to_initials(full_py)
|
||||
if py == "":
|
||||
py = pinyin_list[i][0]
|
||||
py = full_py[0]
|
||||
else:
|
||||
py = full_py[0]
|
||||
mask_pinyin.append(py)
|
||||
return len(mask_pinyin), mask_pinyin
|
||||
|
||||
|
|
@ -271,9 +272,13 @@ class PinyinInputDataset(IterableDataset):
|
|||
history.extend([0] * (8 - len(history)))
|
||||
|
||||
sample_dict = {
|
||||
"input_ids": encoded["input_ids"],
|
||||
"token_type_ids": encoded["token_type_ids"],
|
||||
"attention_mask": encoded["attention_mask"],
|
||||
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
|
||||
"token_type_ids": torch.tensor(
|
||||
encoded["token_type_ids"], dtype=torch.long
|
||||
),
|
||||
"attention_mask": torch.tensor(
|
||||
encoded["attention_mask"], dtype=torch.long
|
||||
),
|
||||
"label": torch.tensor([label], dtype=torch.long),
|
||||
"history_slot_ids": torch.tensor(history, dtype=torch.long),
|
||||
"prefix": f"{part4}^{part1}",
|
||||
|
|
@ -401,9 +406,15 @@ class PinyinInputDataset(IterableDataset):
|
|||
prefix_pinyin = [pinyin_list[i] for i in prefix_positions]
|
||||
|
||||
_, mask_pinyin = self.get_mask_pinyin(prefix_text, prefix_pinyin)
|
||||
split_char = np.random.choice(
|
||||
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
|
||||
)
|
||||
r = random.random()
|
||||
if r < 0.9:
|
||||
split_char = ""
|
||||
elif r < 0.94:
|
||||
split_char = "`"
|
||||
elif r < 0.98:
|
||||
split_char = "'"
|
||||
else:
|
||||
split_char = "-"
|
||||
part2 = split_char.join(mask_pinyin)
|
||||
pinyin_ids = self._compute_pinyin_ids(part2)
|
||||
|
||||
|
|
@ -431,7 +442,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
# part3: 词后文本
|
||||
part3 = ""
|
||||
if random.random() > 0.7:
|
||||
part3 = text[word_end : word_end + np.random.choice(range(1, 17))]
|
||||
part3 = text[word_end : word_end + random.randint(1, 16)]
|
||||
|
||||
# part4: 词提示
|
||||
part4 = ""
|
||||
|
|
@ -447,9 +458,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
f"{part4}|{part1}",
|
||||
part3,
|
||||
max_length=self.max_seq_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_token_type_ids=True,
|
||||
)
|
||||
|
||||
|
|
@ -469,7 +478,15 @@ class PinyinInputDataset(IterableDataset):
|
|||
cont_start = char_positions[break_pos]
|
||||
|
||||
# 续接目标:从断点开始,可延伸到后续词,遇到非汉字停止
|
||||
target_len = np.random.choice(range(1, 9), p=self.cont_length_probs)
|
||||
cont_r = random.random()
|
||||
cont_probs = self.cont_length_probs
|
||||
cont_cumulative = 0.0
|
||||
target_len = 4
|
||||
for cont_len, cont_p in enumerate(cont_probs):
|
||||
cont_cumulative += cont_p
|
||||
if cont_r < cont_cumulative:
|
||||
target_len = cont_len + 1
|
||||
break
|
||||
cont_positions = []
|
||||
pos = cont_start
|
||||
while len(cont_positions) < target_len and pos < len(text):
|
||||
|
|
@ -486,9 +503,15 @@ class PinyinInputDataset(IterableDataset):
|
|||
cont_pinyin = [pinyin_list[i] for i in cont_positions]
|
||||
|
||||
_, mask_pinyin_cont = self.get_mask_pinyin(cont_text, cont_pinyin)
|
||||
split_char_cont = np.random.choice(
|
||||
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
|
||||
)
|
||||
r2 = random.random()
|
||||
if r2 < 0.9:
|
||||
split_char_cont = ""
|
||||
elif r2 < 0.94:
|
||||
split_char_cont = "`"
|
||||
elif r2 < 0.98:
|
||||
split_char_cont = "'"
|
||||
else:
|
||||
split_char_cont = "-"
|
||||
part2_cont = split_char_cont.join(mask_pinyin_cont)
|
||||
pinyin_ids_cont = self._compute_pinyin_ids(part2_cont)
|
||||
|
||||
|
|
@ -517,17 +540,13 @@ class PinyinInputDataset(IterableDataset):
|
|||
cont_end = cont_positions[-1] + 1
|
||||
part3_cont = ""
|
||||
if random.random() > 0.7:
|
||||
part3_cont = text[
|
||||
cont_end : cont_end + np.random.choice(range(1, 17))
|
||||
]
|
||||
part3_cont = text[cont_end : cont_end + random.randint(1, 16)]
|
||||
|
||||
encoded_cont = self.tokenizer(
|
||||
f"{part4}|{part1_cont}",
|
||||
part3_cont,
|
||||
max_length=self.max_seq_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_token_type_ids=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -231,21 +231,21 @@ def main():
|
|||
console.print("[bold yellow]====== Labels 分布分析 ======[/bold yellow]")
|
||||
counter, total = analyze_labels(dataset, max_shards=args.max_shards)
|
||||
|
||||
# 获取词表总大小
|
||||
vocab_size = len(query_engine._id_to_info) # 不含 EOS (id=0)
|
||||
# 获取词表总大小(_id_to_info 已包含 EOS id=0)
|
||||
vocab_size = len(query_engine._id_to_info)
|
||||
|
||||
appeared_ids = set(counter.keys())
|
||||
all_ids = set(range(0, vocab_size + 1)) # +1 包含 id=0 (EOS)
|
||||
all_ids = set(query_engine._id_to_info.keys())
|
||||
missing_ids = all_ids - appeared_ids
|
||||
|
||||
console.print(f"\n总样本数: {total:,}")
|
||||
console.print(f"词表大小: {vocab_size + 1:,} (含 EOS)")
|
||||
console.print(f"词表大小: {vocab_size:,} (含 EOS)")
|
||||
console.print(f"唯一标签数: {len(counter):,}")
|
||||
console.print(
|
||||
f"EOS (id=0) 出现次数: {counter.get(0, 0):,} ({counter.get(0, 0) / total * 100:.2f}%)"
|
||||
)
|
||||
console.print(
|
||||
f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size + 1:,} ({len(missing_ids) / (vocab_size + 1) * 100:.2f}%)[/bold red]"
|
||||
f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size:,} ({len(missing_ids) / vocab_size * 100:.2f}%)[/bold red]"
|
||||
)
|
||||
|
||||
most_common = counter.most_common(args.top_k)
|
||||
|
|
@ -328,7 +328,7 @@ def main():
|
|||
table_dist.add_row(
|
||||
"未出现",
|
||||
f"{len(missing_ids):,}",
|
||||
f"{len(missing_ids) / (vocab_size + 1) * 100:.1f}%",
|
||||
f"{len(missing_ids) / vocab_size * 100:.1f}%",
|
||||
)
|
||||
console.print(table_dist)
|
||||
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ from torch.utils.data import DataLoader
|
|||
from tqdm import tqdm
|
||||
|
||||
from .dataset import PinyinInputDataset
|
||||
from .trainer import collate_fn, worker_init_fn
|
||||
from .trainer import preprocess_collate_fn, worker_init_fn
|
||||
|
||||
FIELDS = [
|
||||
"input_ids",
|
||||
|
|
@ -257,7 +257,7 @@ def main():
|
|||
num_workers=num_train_workers,
|
||||
pin_memory=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
collate_fn=collate_fn,
|
||||
collate_fn=preprocess_collate_fn(args.max_seq_length),
|
||||
prefetch_factor=2,
|
||||
persistent_workers=True if num_train_workers > 0 else False,
|
||||
)
|
||||
|
|
@ -280,7 +280,7 @@ def main():
|
|||
num_workers=num_eval_workers,
|
||||
pin_memory=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
collate_fn=collate_fn,
|
||||
collate_fn=preprocess_collate_fn(args.max_seq_length),
|
||||
prefetch_factor=2,
|
||||
persistent_workers=True if num_eval_workers > 0 else False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,250 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
缺失字符补充工具
|
||||
|
||||
步骤 1: find-missing — 扫描已预处理数据,找出从未出现的 label ID,输出 JSON
|
||||
步骤 2: generate-template — 根据 JSON 生成 JSONL 占位文件,供用户手动填入包含缺失字的真实文本
|
||||
|
||||
用法:
|
||||
python -m model.supplement_missing find-missing \
|
||||
--preprocessed-dir ./preprocessed/train \
|
||||
--output missing_chars.json
|
||||
|
||||
python -m model.supplement_missing generate-template \
|
||||
--missing-chars missing_chars.json \
|
||||
--output supplement_texts.jsonl
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Set
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from tqdm import tqdm
|
||||
|
||||
from .query import QueryEngine
|
||||
|
||||
|
||||
def scan_labels(preprocessed_dir: Path) -> Set[int]:
    """Collect every label ID that occurs in the preprocessed shards.

    Looks for files matching ``shard_*.npz`` directly inside
    ``preprocessed_dir`` and reads the ``labels`` array of each one.

    Args:
        preprocessed_dir: Directory that holds the ``shard_*.npz`` files.

    Returns:
        The set of distinct label IDs (as Python ints) seen across all
        shards. Empty when no shard file is found (a warning is logged).
    """
    appeared: Set[int] = set()

    shard_files = sorted(preprocessed_dir.glob("shard_*.npz"))
    if not shard_files:
        logger.warning(f"未找到 .npz 分片文件: {preprocessed_dir}")
        return appeared

    for shard_path in tqdm(shard_files, desc="扫描分片", unit="shard"):
        # The object returned for an .npz archive keeps the file open; the
        # context manager closes it deterministically instead of relying on
        # ``del`` / garbage collection.
        with np.load(shard_path) as data:
            labels = data["labels"].astype(np.int64)

        # Drop a trailing singleton axis, e.g. labels stored as (N, 1).
        if labels.ndim > 1 and labels.shape[-1] == 1:
            labels = labels.squeeze(-1)

        # ``tolist`` yields plain Python ints, matching the declared Set[int].
        appeared.update(np.unique(labels).tolist())

    return appeared
|
||||
|
||||
|
||||
def cmd_find_missing(args):
    """Handle the ``find-missing`` subcommand.

    Reads ``metadata.json`` from the preprocessed directory, scans the shards
    for the label IDs that occur, compares them with the IDs known to the
    ``QueryEngine`` and writes the missing non-zero IDs (with char, pinyin and
    count) to a JSON file. A summary and a table are printed to the console.

    Args:
        args: Parsed CLI namespace; ``preprocessed_dir`` and ``output`` are
            read from it.
    """
    console = Console()
    data_dir = Path(args.preprocessed_dir)

    # Guard clauses: report the problem and stop when an input is missing.
    if not data_dir.exists():
        console.print(f"[bold red]目录不存在: {data_dir}[/bold red]")
        return

    meta_file = data_dir / "metadata.json"
    if not meta_file.exists():
        console.print(f"[bold red]未找到 metadata.json: {meta_file}[/bold red]")
        return

    with open(meta_file, "r", encoding="utf-8") as fh:
        meta = json.load(fh)
    console.print(
        f"[bold cyan]预处理数据: {meta['num_samples']:,} 样本, {meta['num_shards']} 分片[/bold cyan]"
    )

    console.print("[bold cyan]扫描 labels...[/bold cyan]")
    seen_ids = scan_labels(data_dir)

    console.print("[bold cyan]加载 QueryEngine...[/bold cyan]")
    engine = QueryEngine()
    engine.load()

    vocab_ids = set(engine._id_to_info.keys())
    absent_ids = vocab_ids - seen_ids

    # id 0 (reported as EOS in the summary) is never listed as a character.
    looked_up = (
        engine.query_by_id(label_id)
        for label_id in sorted(absent_ids)
        if label_id != 0
    )
    missing_chars = [
        {
            "id": info.id,
            "char": info.char,
            "pinyin": info.pinyin,
            "count": info.count,
        }
        for info in looked_up
        if info is not None
    ]

    report = {
        "missing_count": len(missing_chars),
        "missing_chars": missing_chars,
    }

    out_file = Path(args.output)
    out_file.parent.mkdir(parents=True, exist_ok=True)
    with open(out_file, "w", encoding="utf-8") as fh:
        json.dump(report, fh, ensure_ascii=False, indent=2)

    console.print("\n[bold green]=== 扫描完成 ===[/bold green]")
    console.print(f"词表大小: {len(vocab_ids):,} (含 EOS)")
    console.print(f"已出现标签: {len(seen_ids):,}")
    console.print(
        f"[bold red]缺失标签: {len(absent_ids):,}[/bold red] (其中非 EOS: {len(missing_chars)})"
    )

    if missing_chars:
        table = Table(
            title=f"缺失字符 (共 {len(missing_chars)} 个)",
            show_header=True,
            header_style="bold magenta",
        )
        column_specs = (
            ("ID", "cyan", 8),
            ("字符", "yellow", 6),
            ("拼音", "green", 12),
            ("语料频次", "red", 12),
        )
        for header, colour, width in column_specs:
            table.add_column(header, style=colour, width=width)
        for item in missing_chars:
            table.add_row(
                str(item["id"]),
                item["char"],
                item["pinyin"],
                f"{item['count']:,}",
            )
        console.print(table)

    console.print(f"\n已输出到: {out_file}")
|
||||
|
||||
|
||||
def cmd_generate_template(args):
    """Handle the ``generate-template`` subcommand.

    Loads the missing-character JSON and writes a JSONL file containing
    ``num_entries`` placeholder lines per missing character, each of the form
    ``{"text": ...}``. Progress and a closing hint are printed to the console.

    Args:
        args: Parsed CLI namespace; ``missing_chars``, ``num_entries`` and
            ``output`` are read from it.
    """
    console = Console()

    source = Path(args.missing_chars)
    if not source.exists():
        console.print(f"[bold red]文件不存在: {source}[/bold red]")
        return

    with open(source, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    missing_chars = payload.get("missing_chars", [])
    if not missing_chars:
        console.print("[bold green]没有缺失字符,无需生成模板[/bold green]")
        return

    num_entries = args.num_entries
    total_lines = len(missing_chars) * num_entries

    for banner in (
        f"[bold cyan]缺失字符数: {len(missing_chars)}[/bold cyan]",
        f"[bold cyan]每字符模板数: {num_entries}[/bold cyan]",
        f"[bold cyan]总模板行数: {total_lines}[/bold cyan]",
    ):
        console.print(banner)

    target = Path(args.output)
    target.parent.mkdir(parents=True, exist_ok=True)

    with open(target, "w", encoding="utf-8") as fh:
        # Lines are produced lazily, one JSON object per line, in the same
        # order as the nested loops: characters outer, entry number inner.
        fh.writelines(
            json.dumps(
                {"text": f"请在这里输入包含「{entry['char']}」字的第{number}条文本"},
                ensure_ascii=False,
            )
            + "\n"
            for entry in missing_chars
            for number in range(1, num_entries + 1)
        )

    console.print(f"[bold green]模板已生成: {target}[/bold green]")
    console.print(
        f"共 {total_lines} 条({len(missing_chars)} 字符 × {num_entries} 条/字符),"
        f"请手动编辑该文件,将占位文本替换为包含对应字符的真实文本。"
    )
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for the missing-character supplement tool.

    Defines the ``find-missing`` and ``generate-template`` subcommands, parses
    the process arguments and dispatches to the matching handler. When no
    subcommand is given, the help text is printed and the function returns.
    """
    # NOTE(review): the leading whitespace inside this help text could not be
    # recovered from the visible source; the layout below is a reconstruction.
    usage_epilog = """
子命令:
  find-missing       扫描已预处理数据,找出从未出现的 label ID
  generate-template  根据缺失字符 JSON 生成 JSONL 占位文件

示例:
  python -m model.supplement_missing find-missing \\
      --preprocessed-dir ./preprocessed/train \\
      --output missing_chars.json

  python -m model.supplement_missing generate-template \\
      --missing-chars missing_chars.json \\
      --output supplement_texts.jsonl
"""
    parser = argparse.ArgumentParser(
        description="缺失字符补充工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_epilog,
    )
    commands = parser.add_subparsers(dest="command", help="子命令")

    # find-missing
    find_parser = commands.add_parser(
        "find-missing", help="扫描预处理数据,找出缺失标签"
    )
    find_parser.add_argument(
        "--preprocessed-dir",
        type=str,
        required=True,
        help="预处理数据目录(包含 shard_*.npz 和 metadata.json)",
    )
    find_parser.add_argument(
        "--output",
        type=str,
        default="missing_chars.json",
        help="输出 JSON 文件路径(默认: missing_chars.json)",
    )

    # generate-template
    template_parser = commands.add_parser(
        "generate-template", help="生成补充文本模板"
    )
    template_parser.add_argument(
        "--missing-chars",
        type=str,
        required=True,
        help="缺失字符 JSON 文件路径(由 find-missing 生成)",
    )
    template_parser.add_argument(
        "--output",
        type=str,
        default="supplement_texts.jsonl",
        help="输出 JSONL 文件路径(默认: supplement_texts.jsonl)",
    )
    template_parser.add_argument(
        "--num-entries",
        type=int,
        default=3,
        help="每个缺失字符生成的模板条数(默认: 3)",
    )

    args = parser.parse_args()

    # The subcommand is optional, so fall back to showing the help text.
    if args.command is None:
        parser.print_help()
        return

    # Handlers are resolved only after a subcommand is known to be present.
    handlers = {
        "find-missing": cmd_find_missing,
        "generate-template": cmd_generate_template,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args)
|
||||
|
||||
|
||||
# Expose ``main`` under the additional name ``app``
# (presumably referenced by an entry-point declaration elsewhere — TODO confirm).
app = main


# Run the CLI when this module is executed directly.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -975,25 +975,62 @@ def worker_init_fn(worker_id: int) -> None:
|
|||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
def collate_fn(batch: List[Dict[str, Any]], max_seq_length: int = 0) -> Dict[str, Any]:
|
||||
"""
|
||||
自定义批处理函数,将多个样本组合成一个batch
|
||||
自定义批处理函数,将多个样本组合成一个batch。
|
||||
支持动态填充:根据batch内最大序列长度进行padding,而非固定max_length。
|
||||
当 max_seq_length > 0 时,pad到指定长度(用于预处理)。
|
||||
|
||||
Args:
|
||||
batch: 样本列表,每个样本是一个字典
|
||||
max_seq_length: 目标序列长度,0表示动态padding
|
||||
|
||||
Returns:
|
||||
批处理后的字典,tensor字段已stack,字符串字段保持为列表
|
||||
"""
|
||||
# 处理tensor字段 - 使用squeeze去除多余的batch维度
|
||||
input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
|
||||
token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
|
||||
attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
|
||||
input_ids_list = [item["input_ids"] for item in batch]
|
||||
token_type_ids_list = [item["token_type_ids"] for item in batch]
|
||||
attention_mask_list = [item["attention_mask"] for item in batch]
|
||||
|
||||
if max_seq_length > 0:
|
||||
target_len = max_seq_length
|
||||
else:
|
||||
target_len = max(ids.shape[0] for ids in input_ids_list)
|
||||
|
||||
padded_input_ids = []
|
||||
padded_token_type_ids = []
|
||||
padded_attention_mask = []
|
||||
for ids, tt_ids, mask in zip(
|
||||
input_ids_list, token_type_ids_list, attention_mask_list
|
||||
):
|
||||
seq_len = ids.shape[0]
|
||||
if seq_len < target_len:
|
||||
pad_len = target_len - seq_len
|
||||
padded_input_ids.append(
|
||||
torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)])
|
||||
)
|
||||
padded_token_type_ids.append(
|
||||
torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)])
|
||||
)
|
||||
padded_attention_mask.append(
|
||||
torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])
|
||||
)
|
||||
elif seq_len > target_len:
|
||||
padded_input_ids.append(ids[:target_len])
|
||||
padded_token_type_ids.append(tt_ids[:target_len])
|
||||
padded_attention_mask.append(mask[:target_len])
|
||||
else:
|
||||
padded_input_ids.append(ids)
|
||||
padded_token_type_ids.append(tt_ids)
|
||||
padded_attention_mask.append(mask)
|
||||
|
||||
input_ids = torch.stack(padded_input_ids)
|
||||
token_type_ids = torch.stack(padded_token_type_ids)
|
||||
attention_mask = torch.stack(padded_attention_mask)
|
||||
labels = torch.stack([item["label"].squeeze(0) for item in batch])
|
||||
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
|
||||
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
|
||||
|
||||
# 字符串字段保持为列表
|
||||
prefixes = [item["prefix"] for item in batch]
|
||||
suffixes = [item["suffix"] for item in batch]
|
||||
pinyins = [item["pinyin"] for item in batch]
|
||||
|
|
@ -1011,6 +1048,15 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def preprocess_collate_fn(max_seq_length: int):
    """Create a collate function for preprocessing with a fixed pad length.

    Args:
        max_seq_length: Sequence length forwarded to ``collate_fn`` as its
            ``max_seq_length`` keyword argument.

    Returns:
        A callable that takes a batch and returns
        ``collate_fn(batch, max_seq_length=max_seq_length)``.
    """
    # Imported locally so this function does not rely on a file-level import.
    from functools import partial

    # A partial over the module-level ``collate_fn`` can be pickled, whereas a
    # function defined inside this factory cannot. DataLoader worker processes
    # started with the "spawn" method must pickle the collate function.
    return partial(collate_fn, max_seq_length=max_seq_length)
|
||||
|
||||
|
||||
# Typer CLI应用
|
||||
def create_dataloader(
|
||||
dataset,
|
||||
|
|
|
|||
Loading…
Reference in New Issue