"""Streaming dataset that turns raw Chinese text into pinyin-input training samples.

For every jieba word in a text record (short words may be randomly merged,
multi-char words may be randomly broken into a typed prefix plus a
continuation), the dataset emits samples per target character label carrying
the typed pinyin, left/right text context and the previously committed labels
as history. Samples are repeated/dropped according to a power-law frequency
rebalancing and passed through a partial-retention shuffle buffer before being
yielded.
"""

import math
import random
import warnings
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

# Installed before any third-party import below so that pkg_resources
# deprecation warnings emitted while those packages load are silenced.
warnings.filterwarnings("ignore", message=".*pkg_resources.*")

import jieba  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from datasets import load_dataset  # noqa: E402
from loguru import logger  # noqa: E402
from modelscope import AutoTokenizer  # noqa: E402
from pypinyin import lazy_pinyin  # noqa: E402
from pypinyin.contrib.tone_convert import to_initials  # noqa: E402
from torch.utils.data import IterableDataset  # noqa: E402

from .query import QueryEngine  # noqa: E402

CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)}  # a-z -> 1-26
CHAR_TO_ID["`"] = 27  # explicit backtick separator
CHAR_TO_ID["'"] = 28  # explicit apostrophe separator
CHAR_TO_ID["-"] = 29  # explicit hyphen separator

# Default for PinyinInputDataset(length_weights=...). Kept at module level
# instead of as a mutable default argument. Only stored on the dataset as
# possible_lengths/weights; not read elsewhere in this module.
DEFAULT_LENGTH_WEIGHTS: Dict[int, int] = {
    1: 10,
    2: 50,
    3: 50,
    4: 40,
    5: 15,
    6: 10,
    7: 5,
    8: 2,
}

jieba.setLogLevel(jieba.logging.INFO)


def segment_text(text: str) -> List[str]:
    """Segment *text* with jieba (HMM disabled) and return the word list."""
    return list(jieba.cut(text, HMM=False))


def build_word_boundaries(words: List[str]) -> List[Tuple[int, int]]:
    """Return half-open ``(start, end)`` character spans for consecutive *words*.

    Spans are obtained by accumulating word lengths, so they index into
    ``"".join(words)``; callers rely on that string being the original text.
    """
    boundaries = []
    pos = 0
    for word in words:
        boundaries.append((pos, pos + len(word)))
        pos += len(word)
    return boundaries


def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """Map each character of *pinyin_str* to its id in ``CHAR_TO_ID``.

    ``a``-``z`` map to 1-26 and the separators backtick, apostrophe and hyphen
    map to 27-29. Any other character maps to 0 (PAD/UNK).
    """
    return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]


class PinyinInputDataset(IterableDataset):
    """Iterable dataset of pinyin-to-character training samples.

    Args:
        data_path: Dataset name/path for ``datasets.load_dataset``. It is loaded
            with ``split="train", streaming=True`` unless ``data_kwargs``
            overrides those.
        max_workers: Upper bound on the number of DataLoader workers that
            produce data. Values ``<= 0`` mean all workers.
        max_iter_length: Total number of samples to yield across all workers.
        max_seq_length: Tokenizer truncation length.
        text_field: Record field that holds the raw text.
        py_style_weight: Relative weights of the pinyin styles (full pinyin,
            initials, first letter). They are normalised to sum to 1.
        shuffle_buffer_size: Buffer size that triggers a partial flush.
        retention_ratio: Fraction of the buffer retained after each flush.
            Must lie strictly between 0 and 1.
        length_weights: Length -> weight mapping exposed as
            ``possible_lengths``/``weights``. Defaults to
            ``DEFAULT_LENGTH_WEIGHTS``.
        merge_short_words_prob: Probability of absorbing each following short
            word (1-2 Chinese chars) into the current short word.
        merge_max_short_words: Max number of words in one merged group.
        merge_max_total_chars: Max number of Chinese chars in one merged group.
        low_freq_repeat: Expected repeats of a label whose frequency is
            ``min_freq``.
        high_freq_repeat: Expected repeats of the most frequent label.
        data_kwargs: Extra keyword arguments forwarded to ``load_dataset``.
        target_labels: If given, only these labels produce samples.

    Raises:
        ValueError: If ``retention_ratio`` is not in (0, 1), or if the derived
            retention size is not positive.
    """

    # Fixed length of the pinyin id tensor (zero-padded or truncated).
    PINYIN_IDS_LEN = 24
    # Number of history slots per sample (zero-padded; keeps the last entries).
    HISTORY_LEN = 8

    def __init__(
        self,
        data_path: str,
        max_workers: int = -1,
        max_iter_length=1e6,
        max_seq_length=128,
        text_field: str = "text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size: int = 100000,
        retention_ratio: float = 0.8,
        length_weights: Optional[Dict[int, int]] = None,
        merge_short_words_prob: float = 0.5,
        merge_max_short_words: int = 3,
        merge_max_total_chars: int = 6,
        low_freq_repeat: float = 50.0,
        high_freq_repeat: float = 0.1,
        data_kwargs: Optional[Dict] = None,
        target_labels: Optional[Set[int]] = None,
    ):
        # Validate cheap arguments before loading the tokenizer / dataset /
        # query engine, so bad configs fail fast.
        if not (0 < retention_ratio < 1):
            raise ValueError(
                f"retention_ratio必须在0和1之间,当前值: {retention_ratio}"
            )
        retention_size = int(shuffle_buffer_size * retention_ratio)
        if retention_size <= 0:
            raise ValueError(
                f"计算出的retention_size必须大于0,当前值: {retention_size} (shuffle_buffer_size={shuffle_buffer_size}, retention_ratio={retention_ratio})"
            )
        if length_weights is None:
            length_weights = DEFAULT_LENGTH_WEIGHTS

        # Frequency rebalancing parameters (power-law smoothing, see
        # adjust_frequency).
        self.min_freq = 109
        self.low_freq_repeat = low_freq_repeat
        self.high_freq_repeat = high_freq_repeat

        # Probability of breaking a multi-char word into prefix + continuation.
        self.word_break_prob = 0.10
        # Probabilities of continuation lengths 1..8 after a word break.
        self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
        # Repeat multiplier by a label's position inside its span. Positions
        # past the end of this list use 3.0.
        self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]

        self.merge_short_words_prob = merge_short_words_prob
        self.merge_max_short_words = merge_max_short_words
        self.merge_max_total_chars = merge_max_total_chars
        self.data_kwargs = data_kwargs or {}
        self.target_labels = target_labels

        jieba.initialize()
        self.tokenizer = AutoTokenizer.from_pretrained(
            Path(str(files(__package__))) / "assets" / "tokenizer"
        )

        self.data_path = data_path
        self.max_iter_length = max_iter_length
        self.max_seq_length = max_seq_length
        self.text_field = text_field
        load_kwargs = {"split": "train", "streaming": True}
        load_kwargs.update(self.data_kwargs)
        self.dataset = load_dataset(data_path, **load_kwargs)
        self.max_workers = max_workers
        self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)

        self.shuffle_buffer_size = shuffle_buffer_size
        self.retention_ratio = retention_ratio
        self.retention_size = retention_size

        self.possible_lengths = list(length_weights.keys())
        self.weights = list(length_weights.values())

        self.query_engine = QueryEngine()
        self.query_engine.load()
        # label -> frequency weight, used by adjust_frequency.
        self.sample_freqs = self.query_engine.get_all_weights()
        self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0

        # Fit E(freq) = C * freq**(-alpha) through (min_freq, low_freq_repeat)
        # and (max_freq, high_freq_repeat).
        if self.max_freq > self.min_freq:
            self.alpha = math.log(
                self.low_freq_repeat / self.high_freq_repeat
            ) / math.log(self.max_freq / self.min_freq)
            self.C = self.low_freq_repeat * (self.min_freq**self.alpha)
        else:
            self.alpha = 0.0
            self.C = 1.0

    def adjust_frequency(self, freq: int) -> int:
        """Return how many times to emit a label of frequency *freq* (0 = drop).

        Uses the single continuous power law ``E = C * freq**(-alpha)``, which
        is monotone in ``freq`` and so preserves the frequency ordering. When
        ``E >= 1``, the count is drawn from Poisson(E) and clamped to at
        least 1. Otherwise the label is kept once with probability ``E`` and
        dropped otherwise.
        """
        if freq <= 0:
            return 0
        expected = self.C * (freq ** (-self.alpha))
        if expected >= 1.0:
            return max(1, np.random.poisson(expected))
        return 1 if random.random() < expected else 0

    def generate_pinyin(self, text: str) -> List[str]:
        """Convert *text* to a per-character pinyin list.

        Calls ``lazy_pinyin`` on the whole text and uses its ``errors``
        callback to keep the output aligned with the input.

        Guarantees:
            1. ``len(result) == len(text)``. If the bulk call breaks this, the
               text is redone character by character.
            2. Chinese characters unknown to pypinyin fall back to the top
               ``QueryEngine.query_by_char`` reading.
            3. Non-Chinese characters are kept as-is as placeholders.

        Args:
            text: Input string.

        Returns:
            A list of pinyin syllables or the original non-Chinese characters.
        """
        if not text:
            return []

        def _fallback(chars):
            # lazy_pinyin passes a run of consecutive characters it has no
            # pinyin for as one string. Handle it character by character so
            # the output stays aligned with the input.
            result = []
            for char in chars:
                if self.query_engine.is_chinese_char(char):
                    ids = self.query_engine.query_by_char(char, limit=1)
                    result.append(ids[0][1] if ids else char)
                else:
                    result.append(char)
            return result

        pinyin_list = lazy_pinyin(text, errors=_fallback)

        # Defensive check: if the lengths still disagree, redo it per character.
        if len(pinyin_list) != len(text):
            logger.warning(
                f"pinyin length mismatch: text_len={len(text)}, "
                f"pinyin_len={len(pinyin_list)}, text={text[:50]!r}"
            )
            pinyin_list = []
            for c in text:
                result = lazy_pinyin(c, errors=_fallback)
                pinyin_list.append(result[0] if result else c)
        return pinyin_list

    def get_mask_pinyin(
        self, text: str, pinyin_list: List[str]
    ) -> Tuple[int, List[str]]:
        """Render the pinyin a user would type for *text*.

        One style is drawn per call from ``py_style_weight`` and applied to the
        whole span, so the chance of a multi-char word being typed in full does
        not decay exponentially with its length. The styles are:

        * 0: full pinyin.
        * 1: initials, falling back to the first letter when ``to_initials``
          returns an empty string.
        * otherwise: first letter only.

        Processing stops at the first non-Chinese character of *text*.

        Returns:
            ``(count, pieces)``, where ``count == len(pieces)``.
        """
        style = random.random()
        cumulative = 0.0
        style_idx = 0
        for i, w in enumerate(self.py_style_weight):
            cumulative += w
            if style < cumulative:
                style_idx = i
                break

        mask_pinyin = []
        for i, char in enumerate(text):
            if not self.query_engine.is_chinese_char(char):
                break
            full_py = pinyin_list[i]
            if style_idx == 0:
                py = full_py
            elif style_idx == 1:
                py = to_initials(full_py) or full_py[0]
            else:
                py = full_py[0]
            mask_pinyin.append(py)
        return len(mask_pinyin), mask_pinyin

    def _compute_pinyin_ids(self, pinyin_str: str) -> torch.Tensor:
        """Encode *pinyin_str* as a LongTensor of exactly ``PINYIN_IDS_LEN`` ids."""
        ids = text_to_pinyin_ids(pinyin_str)[: self.PINYIN_IDS_LEN]
        ids += [0] * (self.PINYIN_IDS_LEN - len(ids))
        return torch.tensor(ids, dtype=torch.long)

    def _build_single_sample(
        self,
        label: int,
        history: list,
        text: str,
        word_start: int,
        word_end: int,
        part2: str,
        pinyin_ids: torch.Tensor,
        words: list,
    ) -> dict:
        """Build one sample, re-sampling the surrounding context on every call.

        Context pieces:

        * part1: text immediately before ``word_start``. Its length is drawn
          from N(36, 6^2) and clipped to ``[0, min(48, word_start)]``.
        * part3: with probability 0.3, 1-16 chars right after ``word_end``.
        * part4: with probability 0.3, 1-3 distinct random words of the
          record, joined by ``|``.

        The tokenizer encodes ``"{part4}|{part1}"`` paired with ``part3``.
        ``history`` is zero-padded, or cut to its last ``HISTORY_LEN`` entries.
        """
        part1_len = min(max(int(random.gauss(36, 6)), 0), 48, word_start)
        part1 = text[word_start - part1_len : word_start]

        part3 = ""
        if random.random() > 0.7:
            part3 = text[word_end : word_end + random.randint(1, 16)]

        part4 = ""
        if random.random() > 0.7 and words:
            num_words = random.randint(1, 3)
            selected_words = random.sample(words, min(num_words, len(words)))
            part4 = "|".join(selected_words)

        encoded = self.tokenizer(
            f"{part4}|{part1}",
            part3,
            max_length=self.max_seq_length,
            truncation=True,
            return_token_type_ids=True,
        )

        hist = list(history)[-self.HISTORY_LEN :]
        hist += [0] * (self.HISTORY_LEN - len(hist))

        return {
            "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
            "token_type_ids": torch.tensor(encoded["token_type_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
            "label": torch.tensor([label], dtype=torch.long),
            "history_slot_ids": torch.tensor(hist, dtype=torch.long),
            "prefix": f"{part4}^{part1}",
            "suffix": part3,
            "pinyin": part2,
            "pinyin_ids": pinyin_ids,
        }

    # ------------------------------------------------------------------
    # Sampling helpers used by __iter__
    # ------------------------------------------------------------------

    @staticmethod
    def _sample_split_char() -> str:
        """Draw the syllable separator: none 90%, ` 4%, ' 4%, - 2%."""
        r = random.random()
        if r < 0.9:
            return ""
        if r < 0.94:
            return "`"
        if r < 0.98:
            return "'"
        return "-"

    def _sample_cont_length(self) -> int:
        """Draw a continuation length from ``cont_length_probs`` (1-based).

        Falls back to 4 if float rounding leaves the draw past the last bucket.
        """
        r = random.random()
        cumulative = 0.0
        for length, p in enumerate(self.cont_length_probs, start=1):
            cumulative += p
            if r < cumulative:
                return length
        return 4

    def _chinese_positions(self, text: str, start: int, end: int) -> List[int]:
        """Return the indices in ``text[start:end]`` that hold Chinese characters."""
        is_chinese = self.query_engine.is_chinese_char
        return [i for i in range(start, end) if is_chinese(text[i])]

    def _merge_following_short_words(
        self,
        text: str,
        word_boundaries: List[Tuple[int, int]],
        idx: int,
        char_positions: List[int],
    ) -> Tuple[List[int], int]:
        """Randomly absorb the following short words into the short word at *idx*.

        Merging stops at the first following word with 0 or more than 2
        Chinese characters, when a merge limit would be exceeded, or when the
        ``merge_short_words_prob`` coin flip fails.

        Returns:
            ``(merged_positions, next_idx)``. ``next_idx == idx + 1`` means
            nothing was merged.
        """
        merged = list(char_positions)
        word_count = 1
        next_idx = idx + 1
        while next_idx < len(word_boundaries):
            ns, ne = word_boundaries[next_idx]
            next_positions = self._chinese_positions(text, ns, ne)
            if not 0 < len(next_positions) <= 2:
                break
            if len(merged) + len(next_positions) > self.merge_max_total_chars:
                break
            if word_count + 1 > self.merge_max_short_words:
                break
            if random.random() > self.merge_short_words_prob:
                break
            merged.extend(next_positions)
            word_count += 1
            next_idx += 1
        return merged, next_idx

    def _lookup_labels(
        self,
        text: str,
        positions: List[int],
        pinyin_list: List[str],
        span_text: str,
        span_pinyin: List[str],
    ) -> Optional[List[int]]:
        """Return the label id for each (char, pinyin) at *positions*.

        Returns None, after logging, when a lookup has no ``.id``. Presumably
        the engine returns None for an unknown pair; confirm against
        ``QueryEngine``.
        """
        labels = []
        for i in positions:
            try:
                labels.append(
                    self.query_engine.get_char_info_by_char_pinyin(
                        text[i], pinyin_list[i]
                    ).id
                )
            except AttributeError as e:
                logger.error(f"e: {e}, (text, pinyin): {span_text} - {span_pinyin}")
                return None
        return labels

    def _emit_span(
        self,
        text: str,
        words: List[str],
        pinyin_list: List[str],
        positions: List[int],
        *,
        span_start: int,
        span_end: int,
        allow_eos: bool,
        out: List[dict],
    ) -> bool:
        """Build samples for the characters at *positions* and append them to *out*.

        The typed pinyin is rendered once per span. Each label is then
        rebalanced by ``adjust_frequency``, scaled by its positional weight, and
        emitted that many times, with the context re-sampled for every repeat.
        Labels that are dropped or not in ``target_labels`` still enter the
        history of later labels.

        Returns:
            False when label lookup failed. In that case nothing is appended.
        """
        span_text = "".join(text[i] for i in positions)
        span_pinyin = [pinyin_list[i] for i in positions]
        _, mask_pinyin = self.get_mask_pinyin(span_text, span_pinyin)
        part2 = self._sample_split_char().join(mask_pinyin)
        pinyin_ids = self._compute_pinyin_ids(part2)

        labels = self._lookup_labels(text, positions, pinyin_list, span_text, span_pinyin)
        if labels is None:
            return False

        # 10% chance to append the EOS label (0).
        # NOTE(review): label 0 only produces samples if sample_freqs contains
        # key 0; otherwise adjust_frequency(0) returns 0 and it is dropped.
        if allow_eos and random.random() <= 0.1:
            labels.append(0)

        history: List[int] = []
        for label_idx, label in enumerate(labels):
            base_repeats = self.adjust_frequency(self.sample_freqs.get(label, 0))
            skip = base_repeats == 0 or (
                self.target_labels is not None and label not in self.target_labels
            )
            if not skip:
                weight = (
                    self._history_weights[label_idx]
                    if label_idx < len(self._history_weights)
                    else 3.0
                )
                repeats = max(1, int(base_repeats * weight))
                for _ in range(repeats):
                    out.append(
                        self._build_single_sample(
                            label=label,
                            history=history,
                            text=text,
                            word_start=span_start,
                            word_end=span_end,
                            part2=part2,
                            pinyin_ids=pinyin_ids,
                            words=words,
                        )
                    )
            history.append(label)
        return True

    def _process_word(
        self,
        text: str,
        words: List[str],
        word_boundaries: List[Tuple[int, int]],
        pinyin_list: List[str],
        idx: int,
        out: List[dict],
    ) -> int:
        """Generate samples for the word at ``word_boundaries[idx]`` into *out*.

        Processing has two phases:

        * Phase 1 emits the whole word, or only its typed prefix when the word
          is broken. A word of 1-2 Chinese characters may first absorb the
          following short words.
        * Phase 2 runs only for a broken word. It emits a continuation that
          starts at the break point and may extend into later words, stopping
          at the first non-Chinese character.

        Returns:
            The index of the next word to process, i.e. past any merged words.
            The index always advances, so the caller's loop terminates.
        """
        word_start, word_end = word_boundaries[idx]
        char_positions = self._chinese_positions(text, word_start, word_end)
        if not char_positions:
            return idx + 1

        merge_end_idx = idx + 1
        if len(char_positions) <= 2:
            merged, next_idx = self._merge_following_short_words(
                text, word_boundaries, idx, char_positions
            )
            if next_idx > idx + 1:
                char_positions = merged
                merge_end_idx = next_idx
                word_end = word_boundaries[next_idx - 1][1]
        word_len_chars = len(char_positions)

        should_break = word_len_chars > 1 and random.random() < self.word_break_prob
        break_pos = (
            random.randint(1, word_len_chars - 1) if should_break else word_len_chars
        )

        # Phase 1: whole word (EOS allowed) or the prefix of a broken word.
        ok = self._emit_span(
            text,
            words,
            pinyin_list,
            char_positions[:break_pos],
            span_start=word_start,
            span_end=word_end,
            allow_eos=not should_break,
            out=out,
        )
        if not ok or not should_break:
            return merge_end_idx

        # Phase 2: continuation after the break point.
        cont_start = char_positions[break_pos]
        target_len = self._sample_cont_length()
        cont_positions = []
        pos = cont_start
        while (
            len(cont_positions) < target_len
            and pos < len(text)
            and self.query_engine.is_chinese_char(text[pos])
        ):
            cont_positions.append(pos)
            pos += 1
        if cont_positions:
            self._emit_span(
                text,
                words,
                pinyin_list,
                cont_positions,
                span_start=cont_start,
                span_end=cont_positions[-1] + 1,
                allow_eos=True,
                out=out,
            )
        return merge_end_idx

    def _split_shuffle_buffer(
        self, buffer: List[dict]
    ) -> Tuple[List[dict], List[dict]]:
        """Shuffle *buffer* and split it into (samples to yield now, samples to keep).

        At most ``retention_size`` samples are kept, so they mix with samples
        produced later.
        """
        order = np.random.permutation(len(buffer))
        cut = len(buffer) - min(self.retention_size, len(buffer))
        ready = [buffer[i] for i in order[:cut]]
        retained = [buffer[i] for i in order[cut:]]
        return ready, retained

    def __iter__(self):
        """Yield shuffled training samples until this worker's quota is reached."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            # max_workers only caps the number of producing workers. Without
            # min(), a cap above the DataLoader's worker count would create
            # shards that no worker reads and would under-produce the quota.
            if self.max_workers > 0:
                num_workers = min(self.max_workers, worker_info.num_workers)
            else:
                num_workers = worker_info.num_workers
            base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
            seed = (base_seed + worker_id) % (2**32)
            random.seed(seed)
            np.random.seed(seed)
            if worker_id >= num_workers:
                # Workers beyond the max_workers cap produce nothing.
                return
            try:
                worker_dataset = self.dataset.shard(
                    num_shards=num_workers, index=worker_id
                )
            except (IndexError, ValueError):
                # NOTE(review): every worker then iterates the full stream;
                # confirm the dataset library splits it across workers.
                worker_dataset = self.dataset
            # Split the total quota evenly; the last worker takes the remainder.
            base_quota, remainder = divmod(int(self.max_iter_length), num_workers)
            worker_quota = base_quota + (
                remainder if worker_id == num_workers - 1 else 0
            )
        else:
            worker_quota = int(self.max_iter_length)
            worker_dataset = self.dataset

        emitted = 0
        buffer: List[dict] = []
        for record in worker_dataset:
            if emitted >= worker_quota:
                break
            text = record.get(self.text_field, "")
            if not text:
                continue
            words = segment_text(text)
            word_boundaries = build_word_boundaries(words)
            pinyin_list = self.generate_pinyin(text)

            idx = 0
            while idx < len(word_boundaries):
                idx = self._process_word(
                    text, words, word_boundaries, pinyin_list, idx, buffer
                )

                # Single buffer with partial retention: once full, yield a
                # shuffled part of it and keep the rest for mixing.
                if len(buffer) >= self.shuffle_buffer_size:
                    ready, buffer = self._split_shuffle_buffer(buffer)
                    for item in ready:
                        if emitted >= worker_quota:
                            return
                        yield item
                        emitted += 1

        # Drain whatever remains, in random order.
        for i in np.random.permutation(len(buffer)):
            if emitted >= worker_quota:
                return
            yield buffer[i]
            emitted += 1