import random
import re
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset

from .query import QueryEngine

_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")

CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)}  # a-z -> 1-26
CHAR_TO_ID["`"] = 27  # explicit backtick separator
CHAR_TO_ID["'"] = 28  # explicit apostrophe separator
CHAR_TO_ID["-"] = 29  # explicit hyphen separator


def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """
    Convert a pinyin string into a list of character IDs.

    Supports a-z plus the separators `, ' and -. Any other character maps to 0
    (PAD/UNK).
    """
    # dict.get(key, default) handles unknown characters by returning 0.
    return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]


class PinyinInputDataset(IterableDataset):
    def __init__(
        self,
        data_path: str,
        max_workers: int = -1,
        max_iter_length=1e6,
        max_seq_length=128,
        text_field: str = "text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size: int = 100000,
        retention_ratio: float = 0.5,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    ):
        # Frequency-adjustment parameters (tune as needed).
        self.drop_start_freq = 30_000_000
        self.max_drop_prob = 0.8
        self.repeat_end_freq = 10_000
        self.max_repeat_expect = 50
        self.min_freq = 109

        self.tokenizer = AutoTokenizer.from_pretrained(
            Path(str(files(__package__))) / "assets" / "tokenizer"
        )
        self.data_path = data_path
        self.max_iter_length = max_iter_length
        self.max_seq_length = max_seq_length
        self.text_field = text_field
        self.dataset = load_dataset(data_path, split="train", streaming=True)
        self.max_workers = max_workers
        self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
        self.shuffle_buffer_size = shuffle_buffer_size
        self.retention_ratio = retention_ratio
        if not (0 < retention_ratio < 1):
            raise ValueError(
                f"retention_ratio must be between 0 and 1, got: {retention_ratio}"
            )
        self.retention_size = int(shuffle_buffer_size * retention_ratio)
        if self.retention_size <= 0:
            raise ValueError(
                f"computed retention_size must be greater than 0, got: {self.retention_size} "
                f"(shuffle_buffer_size={shuffle_buffer_size}, retention_ratio={retention_ratio})"
            )
        self.possible_lengths = list(length_weights.keys())
        self.weights = list(length_weights.values())
        self.query_engine = QueryEngine()
        self.query_engine.load()
        # Target-character frequencies used for peak clipping / valley filling.
        self.sample_freqs = self.query_engine.get_all_weights()

    def adjust_frequency(self, freq: int) -> int:
        """Clip peaks and fill valleys: map a frequency to a sampling count (0 = drop)."""
        # 1. Peak clipping (high-frequency characters).
        if freq >= self.drop_start_freq:
            # Linear drop probability.
            max_freq = max(self.sample_freqs)  # or use a predefined global maximum
            if max_freq == self.drop_start_freq:
                drop_prob = 0.0
            else:
                drop_prob = (
                    self.max_drop_prob
                    * (freq - self.drop_start_freq)
                    / (max_freq - self.drop_start_freq)
                )
            if random.random() < drop_prob:
                return 0
            else:
                return 1
        # 2. Valley filling (low-frequency characters).
        elif freq <= self.repeat_end_freq:
            # Linear repeat expectation.
            if freq <= self.min_freq:
                repeat_expect = self.max_repeat_expect
            else:
                if self.repeat_end_freq == self.min_freq:
                    repeat_expect = 0
                else:
                    repeat_expect = (
                        self.max_repeat_expect
                        * (self.repeat_end_freq - freq)
                        / (self.repeat_end_freq - self.min_freq)
                    )
            # Randomize the repeat count with a Poisson draw.
            repeat_count = np.random.poisson(repeat_expect)
            return max(1, repeat_count)
        # 3. Mid-frequency characters.
        else:
            return 1

    # Generate the pinyin sequence for a text.
    def generate_pinyin(self, text: str) -> List[str]:
        """
        Convert a single text into a list of pinyin strings.

        Properties:
        1. Strict one-to-one alignment: len(result) == len(text)
        2. Good polyphone accuracy: relies on pypinyin's internal word segmentation
        3. Fast: pre-allocated result list, no extra object creation

        Args:
            text: input string

        Returns:
            List[str]: pinyin for Chinese characters, the original character otherwise
        """
        # 1. Fast path for empty input.
        if not text:
            return []

        text_len = len(text)
        # 2. Pre-allocate the result list with placeholders.
        #    None or "" would both work; "" makes the later check simpler.
        result: List[str] = [""] * text_len

        # 3. Iterate over every contiguous run of Chinese characters.
        for match in _HANZI_RE.finditer(text):
            start_idx = match.start()
            hanzi_segment = match.group()

            # 4. Core conversion: let pypinyin segment the run into words.
            #    lazy_pinyin returns tone-less pinyin (Style.NORMAL).
            pinyin_list = lazy_pinyin(hanzi_segment)

            # 5. Robustness fallback:
            #    pypinyin normally returns exactly one pinyin per character.
            #    If not (very rare, e.g. unusual Unicode punctuation matched as a
            #    Chinese character), fall back to per-character conversion.
            if len(pinyin_list) != len(hanzi_segment):
                pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]

            # 6. Fill the pre-allocated slots by index.
            #    Slightly faster and clearer than slice assignment
            #    (result[start:end] = pinyin_list).
            for i, py in enumerate(pinyin_list):
                result[start_idx + i] = py

        # 7. Fill in the non-Chinese characters: wherever a slot is still empty,
        #    copy the original character. Cheap for pure-Chinese text, required
        #    for mixed text.
        for i, char in enumerate(text):
            if not result[i]:
                result[i] = char

        return result

    # Build the pinyin for the characters to be predicted, with style augmentation.
    def get_mask_pinyin(
        self, text: str, pinyin_list: List[str]
    ) -> Tuple[int, List[str]]:
        mask_pinyin = []
        for i in range(len(text)):
            if not self.query_engine.is_chinese_char(text[i]):
                break
            else:
                # Randomly pick the full pinyin, its initials, or its first letter.
                py = np.random.choice(
                    (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
                    p=self.py_style_weight,
                )
                if py == "":
                    py = pinyin_list[i][0]
                mask_pinyin.append(py)
        return len(mask_pinyin), mask_pinyin

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = (
                self.max_workers if self.max_workers > 0 else worker_info.num_workers
            )
            base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
            seed = base_seed + worker_id
            random.seed(seed % (2**32))
            np.random.seed(seed % (2**32))
            # Safety check: a worker whose id is >= num_workers must not produce
            # anything. This can happen when self.max_workers is smaller than the
            # actual number of DataLoader workers.
            if worker_id >= num_workers:
                return  # empty iterator
            # Keep the sharded dataset in a local variable to avoid races.
            worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
            # Compute each worker's quota.
            # Cast max_iter_length to int to guarantee integer division.
            total_quota = int(self.max_iter_length)
            base_quota = total_quota // num_workers
            remainder = total_quota % num_workers
            # The last worker picks up the remainder, if any.
            if worker_id == num_workers - 1:
                worker_quota = base_quota + remainder
            else:
                worker_quota = base_quota
        else:
            # Single-worker case: use the full quota and no sharding.
            worker_quota = int(self.max_iter_length)
            num_workers = 1
            worker_dataset = self.dataset

        # Each worker keeps its own iteration counter.
        current_iter_index = 0
        batch_samples = []

        for sample in worker_dataset:
            # Stop once the quota is reached.
            if current_iter_index >= worker_quota:
                break
            text = sample.get(self.text_field, "")
            if text:
                pinyin_list = self.generate_pinyin(text)
                for i in range(len(text)):
                    # Check the quota before processing each character.
                    if current_iter_index >= worker_quota:
                        break
                    labels = []
                    # Skip characters that are not in the character inventory.
                    # part1 is the preceding context: text[0:i] when i < 48,
                    # otherwise the last 48 characters, text[i-48:i].
                    if not self.query_engine.is_chinese_char(text[i]):
                        continue
                    if i < 48:
                        part1 = text[0:i]
                    else:
                        part1 = text[i - 48 : i]
                    # Draw pinyin_len in [1, 8] from a peaked discrete
                    # distribution (mode at 3), clamp the slice to the end of the
                    # text, then build part2 from pinyin_list[i:i+pinyin_len].
                    # get_mask_pinyin may shorten it further at the first
                    # non-Chinese character.
                    pinyin_len = np.random.choice(
                        range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
                    )
                    py_end = min(i + pinyin_len, len(text))
                    pinyin_len, part2 = self.get_mask_pinyin(
                        text[i:py_end], pinyin_list[i:py_end]
                    )
                    split_char = np.random.choice(
                        ["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
                    )
                    part2 = split_char.join(part2)
                    pinyin_ids = text_to_pinyin_ids(part2)
                    len_py = len(pinyin_ids)
                    if len_py < 24:
                        pinyin_ids.extend([0] * (24 - len_py))
                    else:
                        pinyin_ids = pinyin_ids[:24]
                    pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
                    # part3 is following-context text and is empty with
                    # probability 0.70; otherwise it starts at position
                    # i + pinyin_len and spans x characters, with x drawn
                    # uniformly from 1-16.
                    part3 = ""
                    if random.random() > 0.7:
                        part3 = text[
                            i + pinyin_len : i + pinyin_len + np.random.choice(range(1, 17))
                        ]
                    # part4 is empty with probability 0.50; otherwise it is 1-5
                    # substrings of the text, each starting at a random position
                    # and spanning x extra characters with x drawn uniformly from
                    # 2-6, joined with "|".
                    part4 = ""
                    if random.random() > 0.5:
                        # Build 1-5 substrings.
                        num_strings = random.randint(1, 5)
                        string_list = []
                        for _ in range(num_strings):
                            # Random start position.
                            start_pos = random.randint(0, len(text) - 1)
                            # Random span length x in [2, 6].
                            x = random.randint(2, 6)
                            # Take the start character plus the next x characters.
                            end_pos = min(start_pos + x + 1, len(text))
                            string_list.append(text[start_pos:end_pos])
                        # Join all substrings with "|".
                        part4 = "|".join(string_list)
                    try:
                        labels = [
                            self.query_engine.get_char_info_by_char_pinyin(c, p).id
                            for c, p in zip(
                                text[i : i + pinyin_len],
                                pinyin_list[i : i + pinyin_len],
                            )
                        ]
                    except AttributeError as e:
                        logger.error(
                            f"e: {e}, (text, pinyin): {text[i : i + pinyin_len]} - "
                            f"{pinyin_list[i : i + pinyin_len]}"
                        )
                        continue
                    if random.random() <= 0.1:
                        labels.append(0)

                    # History slots: collect them from the characters before the
                    # prediction position i (consistent with eval.py).
                    history_slot_list = []
                    for j in range(i - 1, max(-1, i - 100), -1):
                        if j < 0:
                            break
                        char = text[j]
                        if self.query_engine.is_chinese_char(char):
                            try:
                                results = self.query_engine.query_by_char(char, limit=1)
                                if results and results[0][0] > 0:
                                    history_slot_list.append(results[0][0])
                            except Exception:
                                pass
                            if len(history_slot_list) >= 8:
                                break
                    encoded = self.tokenizer(
                        f"{part4}|{part1}",
                        part3,
                        max_length=self.max_seq_length,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                        return_token_type_ids=True,
                    )
                    samples = []
                    for label in labels:
                        repeats = self.adjust_frequency(label)
                        # History slots extracted from text[0:i] (consistent with eval.py).
                        masked_labels = history_slot_list[:]
                        len_l = len(masked_labels)
                        masked_labels.extend([0] * (8 - len_l))
                        samples.extend(
                            [
                                {
                                    "input_ids": encoded["input_ids"],
                                    "token_type_ids": encoded["token_type_ids"],
                                    "attention_mask": encoded["attention_mask"],
                                    "label": torch.tensor([label], dtype=torch.long),
                                    "history_slot_ids": torch.tensor(
                                        masked_labels, dtype=torch.long
                                    ),
                                    "prefix": f"{part4}^{part1}",
                                    "suffix": part3,
                                    "pinyin": part2,
                                    "pinyin_ids": pinyin_ids,
                                }
                            ]
                            * repeats
                        )
                    # Add the new samples to the shuffle buffer.
                    batch_samples.extend(samples)
                    # Shuffle buffer: single buffer with partial retention.
                    if len(batch_samples) >= self.shuffle_buffer_size:
                        # Shuffle the whole buffer.
                        indices = np.random.permutation(len(batch_samples))
                        # Actual retention size (never more than the buffer holds).
                        actual_retention = min(self.retention_size, len(batch_samples))
                        # Number of samples to emit.
                        output_count = len(batch_samples) - actual_retention
                        # Emit the first output_count shuffled samples. Use out_idx
                        # to avoid clashing with the outer character-loop variable i.
                        for out_idx in range(output_count):
                            if current_iter_index >= worker_quota:
                                # Quota exhausted: drop the buffer and stop.
                                batch_samples = []
                                return
                            yield batch_samples[indices[out_idx]]
                            current_iter_index += 1
                        # Keep the remaining actual_retention samples in the buffer.
                        retained_samples = [
                            batch_samples[idx] for idx in indices[output_count:]
                        ]
                        batch_samples = retained_samples

        # Flush whatever is left in the buffer.
        if batch_samples:
            indices = np.random.permutation(len(batch_samples))
            for idx in indices:
                if current_iter_index >= worker_quota:
                    return
                yield batch_samples[idx]
                current_iter_index += 1
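

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original pipeline):
# drains a few batches through a DataLoader to show the sample layout. The
# dataset id "my_org/zh_corpus" is a placeholder for any streaming-compatible
# dataset with a "text" field, the package's tokenizer and QueryEngine assets
# must already be in place, and because of the relative import above the
# module has to be run inside its package, e.g. `python -m <package>.<module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    demo_dataset = PinyinInputDataset(
        data_path="my_org/zh_corpus",  # placeholder streaming dataset id
        max_iter_length=1_000,  # small quota for a quick smoke test
        shuffle_buffer_size=200,
        retention_ratio=0.5,
    )
    loader = DataLoader(demo_dataset, batch_size=8, num_workers=0)
    for step, batch in enumerate(loader):
        # input_ids keeps the extra dim from return_tensors="pt": (B, 1, seq_len);
        # pinyin_ids is (B, 24).
        print(step, batch["input_ids"].shape, batch["pinyin_ids"].shape)
        if step >= 2:
            break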