SUimeModelTraner/src/model/dataset.py

618 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import warnings
warnings.filterwarnings("ignore", message=".*pkg_resources.*")
import jieba
import math
import random
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset
from .query import QueryEngine
# Pinyin character vocabulary: a-z -> 1..26, followed by the three syllable
# separators backtick -> 27, apostrophe -> 28, hyphen -> 29.
# Id 0 is left free for PAD/UNK (see text_to_pinyin_ids).
CHAR_TO_ID: Dict[str, int] = {
    symbol: code
    for code, symbol in enumerate("abcdefghijklmnopqrstuvwxyz`'-", start=1)
}
# Raise jieba's log level to INFO, presumably to hide its DEBUG-level
# start-up messages (e.g. dictionary loading).
jieba.setLogLevel(jieba.logging.INFO)
def segment_text(text: str) -> List[str]:
    """Split ``text`` into words with jieba, HMM new-word discovery disabled.

    Returns the tokens in the order jieba produces them. Callers derive
    character offsets by summing token lengths, so they rely on the tokens
    concatenating back to ``text``.
    """
    tokens: List[str] = []
    tokens.extend(jieba.cut(text, HMM=False))
    return tokens
def build_word_boundaries(words: List[str]) -> List[Tuple[int, int]]:
    """Return the half-open ``(start, end)`` character span of every word.

    Spans are laid end to end starting at offset 0, i.e. they assume the
    words concatenate back to the text they were segmented from.
    """
    spans: List[Tuple[int, int]] = []
    offset = 0
    for length in map(len, words):
        spans.append((offset, offset + length))
        offset += length
    return spans
def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """Convert a pinyin string into one character id per input character.

    Ids come from ``CHAR_TO_ID``: letters a-z plus the separators backtick,
    apostrophe (') and hyphen (-). Any other character maps to 0 (PAD/UNK).
    """
    lookup = CHAR_TO_ID.get
    ids: List[int] = []
    for symbol in pinyin_str:
        ids.append(lookup(symbol, 0))
    return ids
class PinyinInputDataset(IterableDataset):
    """Streaming training dataset for a pinyin input method.

    Every text row of the underlying (streaming) dataset is segmented with
    jieba, converted to one pinyin entry per character, and walked word by
    word. For each word, which may be merged with following short words or
    broken at a random position, samples are built per target character
    label. The number of copies per label comes from a power-law frequency
    smoothing (``adjust_frequency``). Samples pass through a
    partial-retention shuffle buffer before being yielded.
    """

    # Default word-length -> weight mapping (stored on the instance as
    # ``possible_lengths`` / ``weights``). Copied per instance, never mutated.
    DEFAULT_LENGTH_WEIGHTS: Dict[int, int] = {
        1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2,
    }
    # Fixed length of the pinyin id vector (truncated / right-padded with 0).
    PINYIN_IDS_LENGTH = 24
    # Number of previous-label slots carried by every sample.
    HISTORY_SLOTS = 8
    # Probability of appending an EOS label (id 0) after a finished span.
    EOS_PROB = 0.1

    def __init__(
        self,
        data_path: str,
        max_workers: int = -1,
        max_iter_length=1e6,
        max_seq_length=128,
        text_field: str = "text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size: int = 100000,
        retention_ratio: float = 0.8,
        length_weights: Optional[Dict[int, int]] = None,
        merge_short_words_prob: float = 0.5,
        merge_max_short_words: int = 3,
        merge_max_total_chars: int = 6,
        low_freq_repeat: float = 50.0,
        high_freq_repeat: float = 0.1,
        data_kwargs: Optional[Dict] = None,
        target_labels: Optional[Set[int]] = None,
    ):
        """Load the tokenizer, the streaming dataset and the frequency table.

        Args:
            data_path: Path / id passed to ``datasets.load_dataset``.
            max_workers: Cap on the number of DataLoader workers that produce
                data; ``<= 0`` means no cap.
            max_iter_length: Number of samples yielded per pass, split across
                the active workers.
            max_seq_length: Tokenizer truncation length for the context.
            text_field: Column holding the raw text.
            py_style_weight: Relative weights of the pinyin styles
                (full syllable, initial, first letter).
            shuffle_buffer_size: Buffer size that triggers a partial flush.
            retention_ratio: Fraction of the buffer kept after a flush;
                must be strictly between 0 and 1.
            length_weights: Word-length -> weight mapping; ``None`` uses
                ``DEFAULT_LENGTH_WEIGHTS``.
            merge_short_words_prob: Probability of absorbing each following
                short (1-2 Chinese chars) word into the current short word.
            merge_max_short_words: Max number of words in a merged span.
            merge_max_total_chars: Max Chinese characters in a merged span.
            low_freq_repeat: Expected repeats at frequency ``min_freq``.
            high_freq_repeat: Expected repeats at the maximum frequency.
            data_kwargs: Extra ``load_dataset`` kwargs; they override the
                defaults ``split="train", streaming=True``.
            target_labels: If given, only these label ids produce samples
                (other labels still enter the history).

        Raises:
            ValueError: If ``retention_ratio`` is outside (0, 1) or the
                derived retention size is not positive.
        """
        # Validate cheap arguments before loading the tokenizer and dataset.
        if not (0 < retention_ratio < 1):
            raise ValueError(
                f"retention_ratio必须在0和1之间当前值: {retention_ratio}"
            )
        self.retention_size = int(shuffle_buffer_size * retention_ratio)
        if self.retention_size <= 0:
            raise ValueError(
                f"计算出的retention_size必须大于0当前值: {self.retention_size} (shuffle_buffer_size={shuffle_buffer_size}, retention_ratio={retention_ratio})"
            )
        if length_weights is None:
            length_weights = dict(self.DEFAULT_LENGTH_WEIGHTS)

        # Frequency-smoothing parameters (power-law scheme).
        self.min_freq = 109
        self.low_freq_repeat = low_freq_repeat
        self.high_freq_repeat = high_freq_repeat
        # Probability of breaking a multi-character word at a random position.
        self.word_break_prob = 0.10
        # Distribution of the continuation length 1..8 after a word break.
        self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
        # Repeat multiplier by position inside a span (later positions get
        # larger multipliers); positions past the end of the list use 3.0.
        self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
        self.merge_short_words_prob = merge_short_words_prob
        self.merge_max_short_words = merge_max_short_words
        self.merge_max_total_chars = merge_max_total_chars
        self.data_kwargs = data_kwargs or {}
        self.target_labels = target_labels

        jieba.initialize()
        self.tokenizer = AutoTokenizer.from_pretrained(
            Path(str(files(__package__))) / "assets" / "tokenizer"
        )

        self.data_path = data_path
        self.max_iter_length = max_iter_length
        self.max_seq_length = max_seq_length
        self.text_field = text_field
        load_kwargs = {"split": "train", "streaming": True}
        load_kwargs.update(self.data_kwargs)
        self.dataset = load_dataset(data_path, **load_kwargs)

        self.max_workers = max_workers
        self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
        self.shuffle_buffer_size = shuffle_buffer_size
        self.retention_ratio = retention_ratio
        self.possible_lengths = list(length_weights.keys())
        self.weights = list(length_weights.values())

        self.query_engine = QueryEngine()
        self.query_engine.load()
        # Frequency per label id (presumably keyed by the same ids that
        # get_char_info_by_char_pinyin returns — TODO confirm).
        self.sample_freqs = self.query_engine.get_all_weights()
        self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0
        # Fit E(freq) = C * freq^(-alpha) through (min_freq, low_freq_repeat)
        # and (max_freq, high_freq_repeat).
        if self.max_freq > self.min_freq:
            self.alpha = math.log(
                self.low_freq_repeat / self.high_freq_repeat
            ) / math.log(self.max_freq / self.min_freq)
            self.C = self.low_freq_repeat * (self.min_freq**self.alpha)
        else:
            self.alpha = 0.0
            self.C = 1.0

    def adjust_frequency(self, freq: int) -> int:
        """Return how many times a label with frequency ``freq`` is sampled.

        Power-law smoothing ``E(freq) = C * freq**(-alpha)`` flattens the
        frequency distribution with a single monotone function, so the
        frequency ranking is preserved. ``0`` means the occurrence is dropped.

        Expected counts >= 1 are drawn from a Poisson distribution (minimum
        1). Smaller expectations become a Bernoulli draw with probability
        ``E(freq)``.
        """
        if freq <= 0:
            return 0
        # NOTE(review): frequencies below min_freq give more than
        # low_freq_repeat expected repeats (no clamping) — confirm intended.
        expected = self.C * (freq ** (-self.alpha))
        if expected >= 1.0:
            # Cast: np.random.poisson returns a NumPy integer, not ``int``.
            return max(1, int(np.random.poisson(expected)))
        return 1 if random.random() < expected else 0

    def generate_pinyin(self, text: str) -> List[str]:
        """Convert ``text`` to pinyin with exactly one entry per character.

        ``lazy_pinyin`` runs over the whole text so phrase-level readings
        apply. Characters it cannot convert go through an ``errors``
        callback that:

        * for Chinese characters (per ``QueryEngine.is_chinese_char``) uses
          the top reading returned by ``query_by_char(char, limit=1)``, and
          keeps the character itself if there is none;
        * keeps every other character unchanged.

        If the result length still differs from ``len(text)``, a warning is
        logged and the text is converted character by character.

        Args:
            text: Input string.

        Returns:
            Pinyin syllables, or the original characters for non-Chinese
            input, aligned index by index with ``text``.
        """
        if not text:
            return []

        def _fallback(chars):
            # lazy_pinyin passes runs of unconvertible characters as one
            # string; handle them one by one so the output stays aligned.
            result = []
            for char in chars:
                if self.query_engine.is_chinese_char(char):
                    # presumably rows are (id, pinyin, ...) tuples — TODO confirm
                    ids = self.query_engine.query_by_char(char, limit=1)
                    result.append(ids[0][1] if ids else char)
                else:
                    result.append(char)
            return result

        pinyin_list = lazy_pinyin(text, errors=_fallback)
        # Defensive: if alignment is still off (rare), redo it per character.
        if len(pinyin_list) != len(text):
            logger.warning(
                f"pinyin length mismatch: text_len={len(text)}, "
                f"pinyin_len={len(pinyin_list)}, text={text[:50]!r}"
            )
            pinyin_list = []
            for c in text:
                result = lazy_pinyin(c, errors=_fallback)
                pinyin_list.append(result[0] if result else c)
        return pinyin_list

    def get_mask_pinyin(
        self, text: str, pinyin_list: List[str]
    ) -> Tuple[int, List[str]]:
        """Abbreviate the pinyin of ``text`` using one randomly drawn style.

        The style is drawn once for the whole span (weights
        ``py_style_weight``). Drawing per character would make fully typed
        multi-character words exponentially rare.

        Styles: 0 = full syllable; 1 = initial via ``to_initials``, falling
        back to the first letter when there is no initial; 2 = first letter.
        Conversion stops at the first non-Chinese character.

        Returns:
            ``(len(masked), masked)``, where ``masked`` is the list of
            abbreviated syllables.
        """
        draw = random.random()
        cumulative = 0.0
        style_idx = 0
        for i, w in enumerate(self.py_style_weight):
            cumulative += w
            if draw < cumulative:
                style_idx = i
                break

        mask_pinyin = []
        for i, char in enumerate(text):
            if not self.query_engine.is_chinese_char(char):
                break
            full_py = pinyin_list[i]
            if style_idx == 0:
                py = full_py
            elif style_idx == 1:
                py = to_initials(full_py) or full_py[:1]
            else:
                py = full_py[:1]
            mask_pinyin.append(py)
        return len(mask_pinyin), mask_pinyin

    def _compute_pinyin_ids(
        self, pinyin_str: str, max_len: Optional[int] = None
    ) -> torch.Tensor:
        """Encode ``pinyin_str`` as a fixed-length LongTensor.

        Ids come from ``text_to_pinyin_ids``. The sequence is truncated or
        right-padded with 0 to ``max_len`` (default ``PINYIN_IDS_LENGTH``).
        """
        if max_len is None:
            max_len = self.PINYIN_IDS_LENGTH
        ids = text_to_pinyin_ids(pinyin_str)[:max_len]
        ids.extend([0] * (max_len - len(ids)))
        return torch.tensor(ids, dtype=torch.long)

    @staticmethod
    def _pad_history(history, size: int = 8) -> List[int]:
        """Return the last ``size`` entries of ``history``, right-padded with 0."""
        hist = list(history)
        if len(hist) > size:
            hist = hist[len(hist) - size :]
        return hist + [0] * (size - len(hist))

    @staticmethod
    def _split_char_for(r: float) -> str:
        """Map a uniform draw ``r`` in [0, 1) to a syllable separator.

        90% none, 4% backtick, 4% apostrophe, 2% hyphen.
        """
        if r < 0.9:
            return ""
        if r < 0.94:
            return "`"
        if r < 0.98:
            return "'"
        return "-"

    @staticmethod
    def _cont_length_for(r: float, probs: List[float], default: int = 4) -> int:
        """Map a uniform draw ``r`` to a length 1..len(probs) via ``probs``.

        Returns ``default`` if ``r`` is not below the cumulative total
        (e.g. through float round-off).
        """
        cumulative = 0.0
        for length, p in enumerate(probs, start=1):
            cumulative += p
            if r < cumulative:
                return length
        return default

    @staticmethod
    def _effective_num_workers(max_workers: int, available: int) -> int:
        """Return the number of workers that actually produce data.

        ``max_workers`` caps the DataLoader's worker count (``<= 0`` = no
        cap). Capping with ``min`` guarantees every shard has a reader.
        Sharding into more pieces than there are workers would leave the
        extra shards, and their share of the quota, unread.
        """
        if max_workers > 0:
            return min(max_workers, available)
        return available

    @staticmethod
    def _worker_quota(total: int, num_workers: int, worker_id: int) -> int:
        """Split ``total`` evenly; the last worker also takes the remainder."""
        base, remainder = divmod(total, num_workers)
        return base + remainder if worker_id == num_workers - 1 else base

    def _build_single_sample(
        self,
        label: int,
        history: list,
        text: str,
        word_start: int,
        word_end: int,
        part2: str,
        pinyin_ids: torch.Tensor,
        words: list,
    ) -> dict:
        """Build one training sample. The context is re-sampled on every call.

        Context pieces:
            part1: text before ``word_start``; the length is drawn from
                N(36, 6^2) and clipped to ``[0, min(48, word_start)]``.
            part3: with probability 0.3, up to 1-16 chars from ``word_end``.
            part4: with probability 0.3, 1-3 random words joined by ``"|"``.

        The tokenizer encodes ``"{part4}|{part1}"`` as the first segment and
        ``part3`` as the second, truncated to ``max_seq_length``. ``history``
        is copied, then trimmed and padded to ``HISTORY_SLOTS`` entries.
        """
        part1_len = min(max(int(random.gauss(36, 6)), 0), 48, word_start)
        part1 = text[word_start - part1_len : word_start]

        part3 = ""
        if random.random() > 0.7:
            part3 = text[word_end : word_end + random.randint(1, 16)]

        part4 = ""
        if random.random() > 0.7 and words:
            num_words = random.randint(1, 3)
            selected_words = random.sample(words, min(num_words, len(words)))
            part4 = "|".join(selected_words)

        encoded = self.tokenizer(
            f"{part4}|{part1}",
            part3,
            max_length=self.max_seq_length,
            truncation=True,
            return_token_type_ids=True,
        )
        hist = self._pad_history(history, self.HISTORY_SLOTS)
        return {
            "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
            "token_type_ids": torch.tensor(encoded["token_type_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
            "label": torch.tensor([label], dtype=torch.long),
            "history_slot_ids": torch.tensor(hist, dtype=torch.long),
            "prefix": f"{part4}^{part1}",
            "suffix": part3,
            "pinyin": part2,
            "pinyin_ids": pinyin_ids,
        }

    def _encode_span(
        self, text: str, positions: List[int], pinyin_list: List[str]
    ) -> Tuple[str, torch.Tensor]:
        """Build the typed pinyin string (``part2``) and its ids for a span.

        Abbreviates the span's pinyin with ``get_mask_pinyin``, then joins
        the syllables with a randomly drawn separator.
        """
        span_text = "".join(text[i] for i in positions)
        span_pinyin = [pinyin_list[i] for i in positions]
        _, mask_pinyin = self.get_mask_pinyin(span_text, span_pinyin)
        part2 = self._split_char_for(random.random()).join(mask_pinyin)
        return part2, self._compute_pinyin_ids(part2)

    def _lookup_labels(
        self, text: str, positions: List[int], pinyin_list: List[str]
    ) -> Optional[List[int]]:
        """Map each (char, pinyin) pair at ``positions`` to its label id.

        Returns ``None`` after logging if a lookup fails with
        ``AttributeError`` (presumably the QueryEngine returned ``None`` for
        an unknown pair — TODO confirm).
        """
        try:
            return [
                self.query_engine.get_char_info_by_char_pinyin(
                    text[i], pinyin_list[i]
                ).id
                for i in positions
            ]
        except AttributeError as e:
            span_text = "".join(text[i] for i in positions)
            span_pinyin = [pinyin_list[i] for i in positions]
            logger.error(
                f"e: {e}, (text, pinyin): {span_text} - {span_pinyin}"
            )
            return None

    def _expand_labels(
        self,
        labels: List[int],
        out: List[dict],
        text: str,
        word_start: int,
        word_end: int,
        part2: str,
        pinyin_ids: torch.Tensor,
        words: List[str],
    ) -> None:
        """Append the repeated samples for every label of one span to ``out``.

        Each label's repeat count is ``adjust_frequency(freq)`` scaled by its
        position weight. Labels with 0 repeats, or outside
        ``target_labels``, produce no samples but still enter the history
        seen by later labels of the span.
        """
        history: List[int] = []
        for label_idx, label in enumerate(labels):
            # NOTE(review): the EOS label 0 only yields samples if
            # sample_freqs has a positive entry for 0 — confirm intended.
            base_repeats = self.adjust_frequency(self.sample_freqs.get(label, 0))
            wanted = self.target_labels is None or label in self.target_labels
            if base_repeats > 0 and wanted:
                weight = (
                    self._history_weights[label_idx]
                    if label_idx < len(self._history_weights)
                    else 3.0
                )
                repeats = max(1, int(base_repeats * weight))
                for _ in range(repeats):
                    out.append(
                        self._build_single_sample(
                            label=label,
                            history=history,
                            text=text,
                            word_start=word_start,
                            word_end=word_end,
                            part2=part2,
                            pinyin_ids=pinyin_ids,
                            words=words,
                        )
                    )
            history.append(label)

    def _merge_short_words(
        self, word_positions: List[List[int]], idx: int
    ) -> Tuple[List[int], int]:
        """Randomly absorb following short words into the word at ``idx``.

        A following word is absorbed only if all of these hold, checked in
        this order: it has 1-2 Chinese characters, the merged span stays
        within ``merge_max_total_chars`` characters and
        ``merge_max_short_words`` words, and a random draw passes
        ``merge_short_words_prob``. The first failing check stops merging.

        Returns:
            ``(positions, end_idx)``: the merged character positions and the
            index one past the last absorbed word (``idx + 1`` if none).
        """
        merged = list(word_positions[idx])
        count = 1
        end_idx = idx + 1
        while end_idx < len(word_positions):
            nxt = word_positions[end_idx]
            if not nxt or len(nxt) > 2:
                break
            if len(merged) + len(nxt) > self.merge_max_total_chars:
                break
            if count + 1 > self.merge_max_short_words:
                break
            if random.random() > self.merge_short_words_prob:
                break
            merged.extend(nxt)
            count += 1
            end_idx += 1
        return merged, end_idx

    def _emit_continuation(
        self,
        text: str,
        words: List[str],
        pinyin_list: List[str],
        is_han: List[bool],
        cont_start: int,
        out: List[dict],
    ) -> None:
        """Phase 2 of a broken word: emit samples for the part after the break.

        The continuation starts at ``cont_start`` and has a target length
        drawn from ``cont_length_probs``. It may extend into following words
        but stops at the first non-Chinese character or at the end of the
        text. An EOS label is appended with probability ``EOS_PROB``.
        """
        target_len = self._cont_length_for(random.random(), self.cont_length_probs)
        cont_positions: List[int] = []
        pos = cont_start
        while len(cont_positions) < target_len and pos < len(text) and is_han[pos]:
            cont_positions.append(pos)
            pos += 1
        if not cont_positions:
            return
        part2, pinyin_ids = self._encode_span(text, cont_positions, pinyin_list)
        labels = self._lookup_labels(text, cont_positions, pinyin_list)
        if labels is None:
            return
        if random.random() <= self.EOS_PROB:
            labels.append(0)
        self._expand_labels(
            labels, out, text, cont_start, cont_positions[-1] + 1,
            part2, pinyin_ids, words,
        )

    def _process_word(
        self,
        text: str,
        words: List[str],
        word_boundaries: List[Tuple[int, int]],
        word_positions: List[List[int]],
        pinyin_list: List[str],
        is_han: List[bool],
        idx: int,
        out: List[dict],
    ) -> int:
        """Emit the samples for the word at ``idx`` and return the next index.

        Phase 1 covers the whole word (possibly merged with following short
        words), or only its prefix when the word is randomly broken. Phase 2
        then continues from the break point. Words without Chinese
        characters are skipped.
        """
        char_positions = word_positions[idx]
        if not char_positions:
            return idx + 1

        word_start, word_end = word_boundaries[idx]
        next_word_idx = idx + 1
        if len(char_positions) <= 2:
            merged, merge_end = self._merge_short_words(word_positions, idx)
            if merge_end > idx + 1:
                char_positions = merged
                next_word_idx = merge_end
                word_end = word_boundaries[merge_end - 1][1]
        word_len_chars = len(char_positions)

        should_break = word_len_chars > 1 and random.random() < self.word_break_prob
        break_pos = (
            random.randint(1, word_len_chars - 1) if should_break else word_len_chars
        )

        # Phase 1: whole word, or the prefix before the break point.
        prefix_positions = char_positions[:break_pos]
        part2, pinyin_ids = self._encode_span(text, prefix_positions, pinyin_list)
        labels = self._lookup_labels(text, prefix_positions, pinyin_list)
        if labels is None:
            return next_word_idx
        # EOS only after a complete word, never after a broken prefix.
        if not should_break and random.random() <= self.EOS_PROB:
            labels.append(0)
        self._expand_labels(
            labels, out, text, word_start, word_end, part2, pinyin_ids, words
        )

        # Phase 2: continuation after the break point.
        if should_break and break_pos < word_len_chars:
            self._emit_continuation(
                text, words, pinyin_list, is_han, char_positions[break_pos], out
            )
        return next_word_idx

    def _worker_setup(self) -> Optional[Tuple[object, int]]:
        """Resolve this process's data source and sample quota.

        In the main process, the full dataset and ``max_iter_length`` are
        used. In a DataLoader worker, the ``random`` and NumPy RNGs are
        re-seeded per worker, the dataset is sharded across the active
        workers, and the quota is split between them.

        Returns:
            ``(dataset, quota)``, or ``None`` if this worker is beyond the
            active-worker cap and must produce nothing.
        """
        total_quota = int(self.max_iter_length)
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            return self.dataset, total_quota

        worker_id = worker_info.id
        num_workers = self._effective_num_workers(
            self.max_workers, worker_info.num_workers
        )
        base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
        seed = (base_seed + worker_id) % (2**32)
        random.seed(seed)
        np.random.seed(seed)
        if worker_id >= num_workers:
            return None
        try:
            worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
        except (IndexError, ValueError) as e:
            # Best effort: keep going unsharded, but make the resulting
            # cross-worker duplication visible.
            logger.warning(
                f"dataset.shard failed ({e}); worker {worker_id} falls back "
                f"to the full, unsharded dataset"
            )
            worker_dataset = self.dataset
        return worker_dataset, self._worker_quota(total_quota, num_workers, worker_id)

    def _split_buffer(self, buffer: List[dict]) -> Tuple[List[dict], List[dict]]:
        """Shuffle ``buffer`` and split it into ``(to_emit, to_retain)``.

        ``to_retain`` holds ``min(retention_size, len(buffer))`` samples.
        """
        order = np.random.permutation(len(buffer))
        cut = len(buffer) - min(self.retention_size, len(buffer))
        return [buffer[i] for i in order[:cut]], [buffer[i] for i in order[cut:]]

    def __iter__(self):
        """Yield training samples for this worker until its quota is reached.

        Samples are collected in a single shuffle buffer. Once it reaches
        ``shuffle_buffer_size``, a random subset is yielded and
        ``retention_size`` samples are kept to mix with later text. When the
        data runs out, the remainder is yielded in random order.
        """
        setup = self._worker_setup()
        if setup is None:
            return
        worker_dataset, worker_quota = setup

        current_iter_index = 0
        batch_samples: List[dict] = []
        for row in worker_dataset:
            if current_iter_index >= worker_quota:
                break
            text = row.get(self.text_field, "")
            if not text:
                continue
            words = segment_text(text)
            word_boundaries = build_word_boundaries(words)
            pinyin_list = self.generate_pinyin(text)
            # Classify each character once instead of once per lookup.
            is_han = [self.query_engine.is_chinese_char(c) for c in text]
            word_positions = [
                [i for i in range(start, end) if is_han[i]]
                for start, end in word_boundaries
            ]

            idx = 0
            while idx < len(word_boundaries):
                idx = self._process_word(
                    text, words, word_boundaries, word_positions,
                    pinyin_list, is_han, idx, batch_samples,
                )
                # Checked after every word to keep buffer memory bounded.
                if len(batch_samples) >= self.shuffle_buffer_size:
                    emit, batch_samples = self._split_buffer(batch_samples)
                    for item in emit:
                        if current_iter_index >= worker_quota:
                            return
                        yield item
                        current_iter_index += 1

        # Drain whatever is left, in random order.
        for i in np.random.permutation(len(batch_samples)):
            if current_iter_index >= worker_quota:
                return
            yield batch_samples[i]
            current_iter_index += 1