618 lines
23 KiB
Python
618 lines
23 KiB
Python
import warnings
|
||
|
||
warnings.filterwarnings("ignore", message=".*pkg_resources.*")
|
||
|
||
import jieba
|
||
import math
|
||
import random
|
||
from importlib.resources import files
|
||
from pathlib import Path
|
||
from typing import Dict, List, Optional, Set, Tuple
|
||
|
||
import numpy as np
|
||
import torch
|
||
from datasets import load_dataset
|
||
from loguru import logger
|
||
from modelscope import AutoTokenizer
|
||
from pypinyin import lazy_pinyin
|
||
from pypinyin.contrib.tone_convert import to_initials
|
||
from torch.utils.data import IterableDataset
|
||
|
||
from .query import QueryEngine
|
||
|
||
|
||
# Vocabulary for pinyin characters. Id 0 is reserved for PAD/UNK, lowercase
# letters a-z take ids 1-26, and the three syllable separators follow.
CHAR_TO_ID: Dict[str, int] = {
    letter: position
    for position, letter in enumerate("abcdefghijklmnopqrstuvwxyz", start=1)
}
CHAR_TO_ID.update({"`": 27, "'": 28, "-": 29})  # backtick, apostrophe, hyphen
|
||
|
||
|
||
# Raise jieba's logger threshold to INFO at import time, presumably to hide its
# DEBUG-level output (e.g. dictionary-loading messages) — TODO confirm intent.
jieba.setLogLevel(jieba.logging.INFO)
|
||
|
||
|
||
def segment_text(text: str) -> List[str]:
    """Segment *text* into words with jieba in dictionary-only mode.

    HMM-based discovery of unknown words is disabled (``HMM=False``), so only
    words from jieba's dictionary are produced as multi-character tokens.

    Args:
        text: The string to segment.

    Returns:
        The list of tokens in their original order.
    """
    token_stream = jieba.cut(text, HMM=False)
    return list(token_stream)
|
||
|
||
|
||
def build_word_boundaries(words: List[str]) -> List[Tuple[int, int]]:
    """Map consecutive words to half-open ``(start, end)`` character spans.

    Spans are computed by accumulating word lengths from offset 0, so they
    line up with the original text only when the words concatenate back to it.

    Args:
        words: Tokens in text order.

    Returns:
        One ``(start, end)`` tuple per word; empty for an empty input.
    """
    end_offsets: List[int] = []
    running_total = 0
    for token in words:
        running_total += len(token)
        end_offsets.append(running_total)
    start_offsets = [0] + end_offsets[:-1]
    return list(zip(start_offsets, end_offsets))
|
||
|
||
|
||
def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """Convert a pinyin string to a list of character ids.

    Every character is looked up in ``CHAR_TO_ID``. Lowercase ``a``-``z`` map
    to 1-26, and the separators backtick, apostrophe and hyphen map to 27, 28
    and 29. Any other character (uppercase letters, digits, Chinese
    characters, whitespace, ...) maps to 0, the shared PAD/UNK id.

    Args:
        pinyin_str: Pinyin text, optionally containing separators.

    Returns:
        One id per input character; an empty list for an empty string.
    """
    # dict.get with a default turns unknown characters into 0 (PAD/UNK).
    return [CHAR_TO_ID.get(ch, 0) for ch in pinyin_str]
|
||
|
||
|
||
class PinyinInputDataset(IterableDataset):
    """Streaming dataset that turns raw text into pinyin-input training samples.

    Each text record is processed as follows:

    - It is segmented with jieba and converted to per-character pinyin.
    - It is walked word by word; very short neighbouring words may be merged,
      and a word may be "broken" into a prefix plus a continuation.
    - For every target character, one or more samples are built. A sample
      contains the (possibly abbreviated) pinyin typed for the span, left and
      optional right context, optional hint words, and the label ids already
      produced earlier in the same span (history).

    Label frequencies from ``QueryEngine`` drive a power-law resampling step
    that repeats rare labels and thins frequent ones. Samples leave through a
    shuffle buffer that retains a fraction of its contents after each flush.
    """

    # Default for ``length_weights`` (word length -> relative weight). Held as
    # a class constant instead of a mutable default argument.
    _DEFAULT_LENGTH_WEIGHTS: Dict[int, int] = {
        1: 10,
        2: 50,
        3: 50,
        4: 40,
        5: 15,
        6: 10,
        7: 5,
        8: 2,
    }

    def __init__(
        self,
        data_path: str,
        max_workers: int = -1,
        max_iter_length=1e6,
        max_seq_length=128,
        text_field: str = "text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size: int = 100000,
        retention_ratio: float = 0.8,
        length_weights: Optional[Dict[int, int]] = None,
        merge_short_words_prob: float = 0.5,
        merge_max_short_words: int = 3,
        merge_max_total_chars: int = 6,
        low_freq_repeat: float = 50.0,
        high_freq_repeat: float = 0.1,
        data_kwargs: Optional[Dict] = None,
        target_labels: Optional[Set[int]] = None,
    ):
        """Load tokenizer, streaming dataset and frequency tables.

        Args:
            data_path: Dataset name/path passed to ``datasets.load_dataset``
                (split ``"train"``, streaming mode).
            max_workers: If > 0, the maximum number of DataLoader workers that
                read data. Workers beyond that count yield nothing. If <= 0,
                every DataLoader worker reads a shard.
            max_iter_length: Total number of samples to yield, split across
                the active workers.
            max_seq_length: Tokenizer truncation length.
            text_field: Record key that holds the raw text.
            py_style_weight: Relative weights of the three pinyin styles
                (full pinyin, initials, first letter).
            shuffle_buffer_size: Buffer size that triggers a shuffle flush.
            retention_ratio: Fraction of the buffer kept after a flush; must
                lie strictly between 0 and 1.
            length_weights: Word-length weights, stored as
                ``possible_lengths`` / ``weights``. Defaults to
                ``_DEFAULT_LENGTH_WEIGHTS``.
            merge_short_words_prob: Per-step probability of merging the next
                short (<= 2 Chinese chars) word into the current span.
            merge_max_short_words: Maximum number of words in a merged span.
            merge_max_total_chars: Maximum Chinese characters in a merged span.
            low_freq_repeat: Expected repeats for a label at ``min_freq``.
            high_freq_repeat: Expected repeats for a label at ``max_freq``.
            data_kwargs: Extra ``load_dataset`` kwargs; they override the
                defaults above.
            target_labels: If given, only labels in this set emit samples.

        Raises:
            ValueError: On an invalid ``retention_ratio``, a non-positive
                retention size, or non-positive repeat targets.
        """
        # Validate cheap arguments before the expensive tokenizer/dataset loads.
        if not (0 < retention_ratio < 1):
            raise ValueError(
                f"retention_ratio必须在0和1之间,当前值: {retention_ratio}"
            )
        retention_size = int(shuffle_buffer_size * retention_ratio)
        if retention_size <= 0:
            raise ValueError(
                f"计算出的retention_size必须大于0,当前值: {retention_size} (shuffle_buffer_size={shuffle_buffer_size}, retention_ratio={retention_ratio})"
            )
        if length_weights is None:
            length_weights = dict(self._DEFAULT_LENGTH_WEIGHTS)

        # Frequency-adjustment parameters (power-law smoothing).
        self.min_freq = 109
        self.low_freq_repeat = low_freq_repeat
        self.high_freq_repeat = high_freq_repeat
        # Probability that a multi-character span is split into prefix + continuation.
        self.word_break_prob = 0.10
        # Distribution over continuation lengths 1..8 used when a span is split.
        self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
        # Repeat multiplier indexed by the label's position inside its span.
        self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
        self.merge_short_words_prob = merge_short_words_prob
        self.merge_max_short_words = merge_max_short_words
        self.merge_max_total_chars = merge_max_total_chars

        self.data_kwargs = data_kwargs or {}
        self.target_labels = target_labels

        jieba.initialize()

        self.tokenizer = AutoTokenizer.from_pretrained(
            Path(str(files(__package__))) / "assets" / "tokenizer"
        )
        self.data_path = data_path

        self.max_iter_length = max_iter_length
        self.max_seq_length = max_seq_length
        self.text_field = text_field
        load_kwargs = {"split": "train", "streaming": True}
        load_kwargs.update(self.data_kwargs)
        self.dataset = load_dataset(data_path, **load_kwargs)
        self.max_workers = max_workers
        self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
        self.shuffle_buffer_size = shuffle_buffer_size
        self.retention_ratio = retention_ratio
        self.retention_size = retention_size
        self.possible_lengths = list(length_weights.keys())
        self.weights = list(length_weights.values())

        self.query_engine = QueryEngine()
        self.query_engine.load()

        # Corpus frequency per label id (looked up with the label as key).
        self.sample_freqs = self.query_engine.get_all_weights()
        self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0

        # Power-law smoothing E(f) = C * f^(-alpha). Calibrated so that
        # E(min_freq) == low_freq_repeat and E(max_freq) == high_freq_repeat.
        if self.max_freq > self.min_freq:
            if self.low_freq_repeat <= 0 or self.high_freq_repeat <= 0:
                raise ValueError(
                    "low_freq_repeat and high_freq_repeat must be positive, "
                    f"got {self.low_freq_repeat} and {self.high_freq_repeat}"
                )
            self.alpha = math.log(
                self.low_freq_repeat / self.high_freq_repeat
            ) / math.log(self.max_freq / self.min_freq)
            self.C = self.low_freq_repeat * (self.min_freq**self.alpha)
        else:
            self.alpha = 0.0
            self.C = 1.0

    def adjust_frequency(self, freq: int) -> int:
        """Return how many times to emit a label with frequency *freq*; 0 drops it.

        The expected count is ``E(freq) = C * freq^(-alpha)``, which keeps the
        frequency ordering with a single continuous function.

        - If ``E >= 1``, a Poisson draw with mean ``E`` is returned, clamped to
          at least 1.
        - Otherwise a Bernoulli draw returns 1 with probability ``E``, else 0.
        - Non-positive frequencies always return 0.
        """
        if freq <= 0:
            return 0

        expected = self.C * (freq ** (-self.alpha))

        if expected >= 1.0:
            # Poisson repeat count, never below one sample.
            repeat_count = np.random.poisson(expected)
            return max(1, repeat_count)
        # Bernoulli keep/drop with probability `expected`.
        return 1 if random.random() < expected else 0

    def generate_pinyin(self, text: str) -> List[str]:
        """Convert *text* to a list with exactly one entry per character.

        ``lazy_pinyin`` runs on the whole text, and an ``errors`` callback keeps
        the output aligned with the input:

        - A character pypinyin cannot convert falls back to
          ``QueryEngine.query_by_char`` (first result) when the engine
          classifies it as Chinese.
        - Non-Chinese characters, and Chinese ones the engine has no reading
          for, are kept as-is.

        If the lengths still differ, the conversion is redone character by
        character.

        Args:
            text: Input string.

        Returns:
            A list whose length equals ``len(text)``; empty for empty input.
        """
        if not text:
            return []

        def _fallback(chars):
            # lazy_pinyin passes runs of unconvertible characters as one
            # string; handle them one by one to keep the 1:1 alignment.
            result = []
            for char in chars:
                if self.query_engine.is_chinese_char(char):
                    ids = self.query_engine.query_by_char(char, limit=1)
                    if ids:
                        result.append(ids[0][1])
                    else:
                        result.append(char)
                else:
                    result.append(char)
            return result

        pinyin_list = lazy_pinyin(text, errors=_fallback)

        # Defensive check: on a (rare) length mismatch, convert per character.
        if len(pinyin_list) != len(text):
            logger.warning(
                f"pinyin length mismatch: text_len={len(text)}, "
                f"pinyin_len={len(pinyin_list)}, text={text[:50]!r}"
            )
            pinyin_list = []
            for c in text:
                result = lazy_pinyin(c, errors=_fallback)
                pinyin_list.append(result[0] if result else c)

        return pinyin_list

    def get_mask_pinyin(
        self, text: str, pinyin_list: List[str]
    ) -> Tuple[int, List[str]]:
        """Abbreviate the pinyin of *text* using one style for the whole span.

        A single style is drawn from ``py_style_weight`` for every character.
        This avoids the probability of a fully spelled multi-character word
        decaying exponentially with its length. The styles are:

        - 0: full pinyin.
        - 1: ``to_initials``, falling back to the first letter when it
          returns an empty string.
        - 2: first letter only.

        Processing stops at the first non-Chinese character.

        Returns:
            ``(count, mask)`` where ``count == len(mask)``.
        """
        style = random.random()
        cumulative = 0.0
        style_idx = 0
        for i, w in enumerate(self.py_style_weight):
            cumulative += w
            if style < cumulative:
                style_idx = i
                break

        mask_pinyin = []
        for i in range(len(text)):
            if not self.query_engine.is_chinese_char(text[i]):
                break
            full_py = pinyin_list[i]
            if style_idx == 0:
                py = full_py
            elif style_idx == 1:
                py = to_initials(full_py)
                if py == "":
                    py = full_py[0]
            else:
                py = full_py[0]
            mask_pinyin.append(py)
        return len(mask_pinyin), mask_pinyin

    def _compute_pinyin_ids(self, pinyin_str: str, max_len: int = 24) -> torch.Tensor:
        """Encode *pinyin_str* as a long tensor of exactly *max_len* ids.

        The ids come from ``text_to_pinyin_ids``. Longer input is truncated;
        shorter input is right-padded with 0.
        """
        pinyin_ids = text_to_pinyin_ids(pinyin_str)[:max_len]
        pinyin_ids.extend([0] * (max_len - len(pinyin_ids)))
        return torch.tensor(pinyin_ids, dtype=torch.long)

    def _build_single_sample(
        self,
        label: int,
        history: list,
        text: str,
        word_start: int,
        word_end: int,
        part2: str,
        pinyin_ids: torch.Tensor,
        words: list,
    ) -> dict:
        """Build one sample; every call re-samples the random context.

        The context pieces are:

        - part1: left context ending at ``word_start``. Its length is drawn
          from N(36, 6^2) and clamped to ``[0, min(48, word_start)]``.
        - part3: with probability 0.3, 1-16 characters starting at
          ``word_end``.
        - part4: with probability 0.3, 1-3 distinct words from ``words``
          joined by ``|``.

        ``history`` is copied and left-truncated or right-padded with 0 to
        length 8.
        """
        part1_len = min(max(int(random.gauss(36, 6)), 0), 48, word_start)
        part1 = text[word_start - part1_len : word_start]

        part3 = ""
        if random.random() > 0.7:
            part3 = text[word_end : word_end + random.randint(1, 16)]

        part4 = ""
        if random.random() > 0.7 and words:
            num_words = random.randint(1, 3)
            selected_words = random.sample(words, min(num_words, len(words)))
            part4 = "|".join(selected_words)

        encoded = self.tokenizer(
            f"{part4}|{part1}",
            part3,
            max_length=self.max_seq_length,
            truncation=True,
            return_token_type_ids=True,
        )

        # Keep the most recent 8 history ids, right-padded with 0.
        hist = list(history)[-8:]
        hist.extend([0] * (8 - len(hist)))

        return {
            "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
            "token_type_ids": torch.tensor(encoded["token_type_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
            "label": torch.tensor([label], dtype=torch.long),
            "history_slot_ids": torch.tensor(hist, dtype=torch.long),
            "prefix": f"{part4}^{part1}",
            "suffix": part3,
            "pinyin": part2,
            "pinyin_ids": pinyin_ids,
        }

    def _chinese_positions(self, text: str, start: int, end: int) -> List[int]:
        """Indices in ``text[start:end]`` that the query engine deems Chinese."""
        return [
            i for i in range(start, end) if self.query_engine.is_chinese_char(text[i])
        ]

    def _merge_short_words(
        self,
        text: str,
        word_boundaries: List[Tuple[int, int]],
        idx: int,
        char_positions: List[int],
    ) -> Tuple[List[int], int]:
        """Greedily merge following short words into the span at ``idx``.

        A following word is merged only if all of these hold:

        - it has 1-2 Chinese characters;
        - the merged span stays within ``merge_max_total_chars`` characters
          and ``merge_max_short_words`` words;
        - a draw succeeds with probability ``merge_short_words_prob``.

        Returns:
            ``(positions, end_idx)``. ``positions`` holds the merged Chinese
            character indices and ``end_idx`` is the first word index not
            consumed. It equals ``idx + 1`` when nothing was merged.
        """
        merged = list(char_positions)
        merged_count = 1
        next_idx = idx + 1
        while next_idx < len(word_boundaries):
            ns, ne = word_boundaries[next_idx]
            next_positions = self._chinese_positions(text, ns, ne)
            next_len = len(next_positions)
            if next_len == 0 or next_len > 2:
                break
            if len(merged) + next_len > self.merge_max_total_chars:
                break
            if merged_count + 1 > self.merge_max_short_words:
                break
            if random.random() > self.merge_short_words_prob:
                break
            merged.extend(next_positions)
            merged_count += 1
            next_idx += 1
        return merged, next_idx

    @staticmethod
    def _choose_split_char() -> str:
        """Draw a syllable separator: none (90%), backtick, apostrophe (4% each), hyphen (2%)."""
        r = random.random()
        if r < 0.9:
            return ""
        if r < 0.94:
            return "`"
        if r < 0.98:
            return "'"
        return "-"

    def _encode_pinyin(
        self, text: str, pinyin_list: List[str], positions: List[int]
    ) -> Tuple[str, torch.Tensor]:
        """Mask the pinyin of ``positions`` and join it with a random separator.

        Returns:
            ``(part2, pinyin_ids)``: the typed pinyin string and its id tensor.
        """
        chars = "".join(text[i] for i in positions)
        pinyin = [pinyin_list[i] for i in positions]
        _, mask_pinyin = self.get_mask_pinyin(chars, pinyin)
        part2 = self._choose_split_char().join(mask_pinyin)
        return part2, self._compute_pinyin_ids(part2)

    def _lookup_labels(
        self, text: str, pinyin_list: List[str], positions: List[int]
    ) -> Optional[List[int]]:
        """Map each (character, pinyin) at ``positions`` to its label id.

        Returns None and logs an error when the query engine result has no
        ``.id`` attribute (raised as ``AttributeError``), e.g. when it returns
        None for an unknown pair.
        """
        try:
            return [
                self.query_engine.get_char_info_by_char_pinyin(
                    text[i], pinyin_list[i]
                ).id
                for i in positions
            ]
        except AttributeError as e:
            chars = "".join(text[i] for i in positions)
            pinyin = [pinyin_list[i] for i in positions]
            logger.error(f"e: {e}, (text, pinyin): {chars} - {pinyin}")
            return None

    def _emit_label_samples(
        self,
        labels: List[int],
        text: str,
        word_start: int,
        word_end: int,
        part2: str,
        pinyin_ids: torch.Tensor,
        words: list,
    ) -> List[dict]:
        """Build the samples for one span, one label at a time.

        For each label:

        - ``adjust_frequency`` picks a base repeat count. A count of 0, or a
          label outside ``target_labels``, emits nothing but still extends
          the history.
        - Otherwise ``base * weight`` samples (at least one) are built, where
          ``weight`` depends on the label's position in the span. Each sample
          gets freshly sampled context.
        """
        samples: List[dict] = []
        history: List[int] = []
        for label_idx, label in enumerate(labels):
            # NOTE(review): the EOS label 0 gets frequency 0 here unless
            # sample_freqs has a key 0, in which case it is never emitted —
            # confirm this is intended.
            base_repeats = self.adjust_frequency(self.sample_freqs.get(label, 0))
            wanted = base_repeats > 0 and (
                self.target_labels is None or label in self.target_labels
            )
            if wanted:
                weight = (
                    self._history_weights[label_idx]
                    if label_idx < len(self._history_weights)
                    else 3.0
                )
                repeats = max(1, int(base_repeats * weight))
                for _ in range(repeats):
                    samples.append(
                        self._build_single_sample(
                            label=label,
                            history=history,
                            text=text,
                            word_start=word_start,
                            word_end=word_end,
                            part2=part2,
                            pinyin_ids=pinyin_ids,
                            words=words,
                        )
                    )
            history.append(label)
        return samples

    def _sample_cont_length(self) -> int:
        """Draw a continuation length from ``cont_length_probs`` (1-based); 4 if none hit."""
        r = random.random()
        cumulative = 0.0
        for length, p in enumerate(self.cont_length_probs, start=1):
            cumulative += p
            if r < cumulative:
                return length
        return 4

    def _continuation_samples(
        self, text: str, pinyin_list: List[str], words: list, cont_start: int
    ) -> List[dict]:
        """Phase 2: samples for the text that continues after a broken prefix.

        The target starts at ``cont_start``. It may run into following words
        but stops at the first non-Chinese character or at the drawn
        continuation length. With probability 0.1 an EOS label (0) is
        appended.
        """
        target_len = self._sample_cont_length()
        cont_positions: List[int] = []
        pos = cont_start
        while (
            len(cont_positions) < target_len
            and pos < len(text)
            and self.query_engine.is_chinese_char(text[pos])
        ):
            cont_positions.append(pos)
            pos += 1
        if not cont_positions:
            return []

        part2, pinyin_ids = self._encode_pinyin(text, pinyin_list, cont_positions)
        labels = self._lookup_labels(text, pinyin_list, cont_positions)
        if labels is None:
            return []
        if random.random() <= 0.1:
            labels.append(0)
        return self._emit_label_samples(
            labels,
            text,
            cont_start,
            cont_positions[-1] + 1,
            part2,
            pinyin_ids,
            words,
        )

    def _samples_from_text(self, text: str) -> List[dict]:
        """Build every sample for one text record, in generation order."""
        words = segment_text(text)
        word_boundaries = build_word_boundaries(words)
        pinyin_list = self.generate_pinyin(text)
        samples: List[dict] = []

        idx = 0
        while idx < len(word_boundaries):
            word_start, word_end = word_boundaries[idx]
            char_positions = self._chinese_positions(text, word_start, word_end)
            if not char_positions:
                idx += 1
                continue

            merge_end_idx = idx + 1
            if len(char_positions) <= 2:
                # Optionally glue following short words into one target span.
                char_positions, merge_end_idx = self._merge_short_words(
                    text, word_boundaries, idx, char_positions
                )
                word_end = word_boundaries[merge_end_idx - 1][1]
            word_len_chars = len(char_positions)

            should_break = (
                word_len_chars > 1 and random.random() < self.word_break_prob
            )
            if should_break:
                break_pos = random.randint(1, word_len_chars - 1)
            else:
                break_pos = word_len_chars

            # Phase 1: the whole span, or only its prefix when breaking.
            prefix_positions = char_positions[:break_pos]
            part2, pinyin_ids = self._encode_pinyin(text, pinyin_list, prefix_positions)
            labels = self._lookup_labels(text, pinyin_list, prefix_positions)
            if labels is None:
                idx = merge_end_idx
                continue
            # Whole spans get a trailing EOS 10% of the time; broken prefixes never do.
            if not should_break and random.random() <= 0.1:
                labels.append(0)
            samples.extend(
                self._emit_label_samples(
                    labels, text, word_start, word_end, part2, pinyin_ids, words
                )
            )

            # Phase 2: continuation after the break point.
            if should_break and break_pos < word_len_chars:
                samples.extend(
                    self._continuation_samples(
                        text, pinyin_list, words, char_positions[break_pos]
                    )
                )

            # Always advance past the consumed words.
            idx = merge_end_idx
        return samples

    def _resolve_worker_split(self):
        """Return ``(dataset, quota)`` for the calling process, or None if idle.

        In the main process the full dataset and ``max_iter_length`` are
        returned. Inside a DataLoader worker, the following happens:

        - Python's and NumPy's global RNGs are reseeded.
        - The number of reading workers becomes ``max_workers`` clamped to
          the real worker count; workers beyond it get None.
        - Each reader gets one shard and an equal share of the quota; the
          last reader also takes the remainder.

        The clamp matters because a larger ``max_workers`` would leave shards
        unread and shrink every worker's quota.
        """
        total_quota = int(self.max_iter_length)
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            return self.dataset, total_quota

        worker_id = worker_info.id
        num_workers = worker_info.num_workers
        if self.max_workers > 0:
            num_workers = min(self.max_workers, num_workers)

        base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
        seed = base_seed + worker_id
        random.seed(seed % (2**32))
        np.random.seed(seed % (2**32))

        if worker_id >= num_workers:
            return None

        try:
            worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
        except (IndexError, ValueError):
            worker_dataset = self.dataset

        base_quota, remainder = divmod(total_quota, num_workers)
        if worker_id == num_workers - 1:
            return worker_dataset, base_quota + remainder
        return worker_dataset, base_quota

    def __iter__(self):
        """Yield samples (dicts from ``_build_single_sample``) for this worker.

        Samples collect in a buffer. Once it holds at least
        ``shuffle_buffer_size`` items, it is shuffled, everything except
        ``retention_size`` items is yielded, and the rest carry over. The
        remainder is shuffled and yielded at the end. Output stops once the
        worker's quota is reached.
        """
        split = self._resolve_worker_split()
        if split is None:
            return
        worker_dataset, worker_quota = split

        emitted = 0
        buffer: List[dict] = []
        for record in worker_dataset:
            if emitted >= worker_quota:
                break
            text = record.get(self.text_field, "")
            if not text:
                continue
            buffer.extend(self._samples_from_text(text))

            # Shuffle buffer with partial retention.
            if len(buffer) >= self.shuffle_buffer_size:
                order = np.random.permutation(len(buffer))
                keep = min(self.retention_size, len(buffer))
                release = len(buffer) - keep
                for i in order[:release]:
                    if emitted >= worker_quota:
                        return
                    yield buffer[i]
                    emitted += 1
                buffer = [buffer[i] for i in order[release:]]

        # Flush whatever is left.
        if buffer:
            for i in np.random.permutation(len(buffer)):
                if emitted >= worker_quota:
                    return
                yield buffer[i]
                emitted += 1
|