import os
import random
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from torch.utils.data import DataLoader, IterableDataset

# HuggingFace tokenizers spawn threads that deadlock with DataLoader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Mapping from a pinyin initial letter to its group id (0-11).
# Pinyin syllables never start with "i"/"u"/"v" (written "y"/"w"), so the
# listed letters cover all reachable initials; 12 is used as the
# "no group / unknown" sentinel elsewhere in this module.
PG = {
    "r": 0, "l": 0,
    "p": 1, "d": 1,
    "h": 2, "f": 2,
    "g": 3, "m": 3,
    "z": 4, "o": 4,
    "t": 5, "q": 5,
    "b": 6, "w": 6,
    "j": 7, "e": 7,
    "k": 8, "c": 8,
    "s": 9, "a": 9,
    "n": 10, "x": 10,
    "y": 11,
}


class PinyinInputDataset(IterableDataset):
    """Streaming dataset that simulates pinyin input-method (IME) samples.

    Features:
        1. Streams the source dataset, keeping memory usage flat.
        2. On-the-fly pinyin conversion with heteronym handling.
        3. Multiple sampling strategies for the preceding context.
        4. Random pinyin truncation to simulate incomplete user input.
        5. Built-in peak-clipping / valley-filling to balance the character
           frequency distribution.
        6. Buffered shuffling with multi-process (DataLoader worker) support.
    """

    def __init__(
        self,
        data_dir: str,
        query_engine,
        tokenizer_name: str = "iic/nlp_structbert_backbone_lite_std",
        max_len: int = 88,
        text_field: str = "text",
        batch_query_size: int = 1000,
        # shuffling parameters
        shuffle: bool = True,
        shuffle_buffer_size: int = 10000,
        # frequency-balancing parameters
        max_freq: int = 434748359,  # frequency of "的"
        min_freq: int = 109,  # frequency of "蓚"
        drop_start_freq: int = 30000000,  # threshold where dropping starts
        repeat_end_freq: int = 10000,  # threshold where repeating starts
        max_drop_prob: float = 0.8,  # maximum drop probability
        max_repeat_expect: float = 50.0,  # maximum expected repeat count
        # Tuple (not list) so the default is immutable; value unchanged.
        sample_context_section=(0.90, 0.95, 1),
        drop_py_rate: float = 0,
    ):
        """Initialize the dataset.

        Args:
            data_dir: Dataset directory (passed to ``load_dataset``).
            query_engine: QueryEngine instance for character/pinyin lookup.
            tokenizer_name: Name of the tokenizer to load.
            max_len: Maximum tokenized sequence length.
            text_field: Name of the text field in each dataset item.
            batch_query_size: Number of (char, pinyin) pairs per batch query.
            shuffle: Whether to shuffle emitted samples.
            shuffle_buffer_size: Size of the shuffle buffer.
            max_freq: Highest character frequency in the corpus.
            min_freq: Lowest character frequency in the corpus.
            drop_start_freq: Frequency above which peak-clipping kicks in.
            repeat_end_freq: Frequency below which valley-filling kicks in.
            max_drop_prob: Drop probability at the highest frequency.
            max_repeat_expect: Expected repeats at the lowest frequency.
            sample_context_section: Cumulative probability boundaries for the
                three context-sampling strategies.
            drop_py_rate: Probability of dropping the pinyin hint entirely.
        """
        self.query_engine = query_engine
        self.max_len = max_len
        self.text_field = text_field
        self.batch_query_size = batch_query_size

        # Shuffle settings.
        self.shuffle = shuffle
        self.shuffle_buffer_size = shuffle_buffer_size

        # Peak-clipping / valley-filling settings.
        self.max_freq = max_freq
        self.min_freq = min_freq
        self.drop_start_freq = drop_start_freq
        self.repeat_end_freq = repeat_end_freq
        self.max_drop_prob = max_drop_prob
        self.max_repeat_expect = max_repeat_expect

        # Load the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Total valid character count, kept for downstream statistics.
        stats = query_engine.get_statistics()
        self.total_chars = stats.get("valid_input_character_count", 0)

        # Cache for (char, pinyin) -> info lookups.
        # NOTE(review): unbounded; grows with the distinct pair vocabulary.
        self.char_info_cache = {}

        # Stream the dataset instead of materializing it.
        self.dataset = load_dataset(data_dir, split="train", streaming=True)

        # Pinyin-initial group table.
        self.pg_groups = PG

        # Cumulative probability boundaries for context-sampling strategies.
        self.sample_context_section = sample_context_section
        self.drop_py_rate = drop_py_rate

    def get_next_chinese_chars(
        self,
        text: str,
        start_idx: int,
        max_count: int = 3,
        pinyin_list: Optional[List[str]] = None,
    ) -> List[Tuple[str, str]]:
        """Collect the next Chinese characters after ``start_idx`` with pinyin.

        Args:
            text: Full input text.
            start_idx: Index to start searching after (exclusive).
            max_count: Maximum number of Chinese characters to return.
            pinyin_list: Precomputed per-character pinyin list for ``text``;
                computed lazily when not provided.

        Returns:
            List of ``(character, pinyin)`` tuples, at most ``max_count`` long.
        """
        result = []
        count = 0
        for i in range(start_idx + 1, len(text)):
            if count >= max_count:
                break
            char = text[i]
            # Non-Chinese characters are skipped, not counted.
            if not self.query_engine.is_chinese_char(char):
                continue
            try:
                # Lazily compute pinyin for the whole text; non-Chinese
                # characters map to themselves so indices stay aligned.
                if pinyin_list is None:
                    pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
                if i < len(pinyin_list):
                    result.append((char, pinyin_list[i]))
                    count += 1
            except Exception:
                # Stop collecting on any conversion failure.
                break
        return result

    def sample_context(self, context: str) -> str:
        """Sample the preceding context using one of three strategies.

        Args:
            context: Raw preceding text (at most 100 characters).

        Returns:
            A sampled context of at most 54 characters.
        """
        if not context:
            return ""

        context_len = len(context)
        choice = random.random()

        if choice < self.sample_context_section[0]:
            # Strategy 1: the 54 characters closest to the target character.
            return context[-54:] if context_len >= 54 else context

        if choice < self.sample_context_section[1]:
            # Strategy 2: 46 consecutive characters from a random position.
            if context_len <= 46:
                return context
            start = random.randint(0, context_len - 46)
            return context[start : start + 46]

        # Strategy 3: last 12 characters + 6 random segments of 7 characters.
        if context_len < 12:
            return context
        last_12 = context[-12:]
        remaining = context[:-12] if context_len > 12 else ""
        remaining_len = len(remaining)
        if remaining_len < 7:
            # Not enough characters for a segment; fall back to the tail.
            return last_12
        segments = []
        for _ in range(6):
            if remaining_len < 7:
                break
            start = random.randint(0, remaining_len - 7)
            segments.append(remaining[start : start + 7])
        result = "".join(segments) + last_12
        # Pad from the front of the context if we came up short of 54.
        if len(result) < 54:
            needed = 54 - len(result)
            if context_len >= needed:
                result = context[:needed] + result
        return result[:54]

    def truncate_pinyin(self, pinyin: str) -> str:
        """Randomly truncate a pinyin string to simulate partial typing.

        Args:
            pinyin: Full pinyin of one character.

        Returns:
            Possibly-truncated pinyin; may be the empty string.
        """
        if not pinyin:
            return ""
        rand_val = random.random()
        if rand_val < 0.1:
            # 10%: drop the pinyin entirely.
            return ""
        if rand_val < 0.9:
            # 80%: keep it intact.
            return pinyin
        # 10%: truncate to a uniformly random prefix of length 1..len-1.
        full_len = len(pinyin)
        if full_len <= 1:
            return pinyin
        trunc_len = random.randint(1, full_len - 1)
        return pinyin[:trunc_len]

    def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
        """Truncate each pinyin in turn and concatenate the results.

        Args:
            pinyin_list: List of pinyins (length 1-4).

        Returns:
            Concatenated pinyin string, capped at 18 characters. Processing
            stops at the first pinyin that truncates to empty.
        """
        result_parts = []
        for pinyin in pinyin_list:
            truncated = self.truncate_pinyin(pinyin)
            if not truncated:
                # An empty truncation ends the simulated input stream.
                break
            result_parts.append(truncated)
        if not result_parts:
            return ""
        result = "".join(result_parts)
        return result[:18] if len(result) > 18 else result

    def adjust_frequency(self, freq: int) -> int:
        """Peak-clip / valley-fill: decide how many times to emit a sample.

        Args:
            freq: Corpus frequency of the current character.

        Returns:
            Number of times to emit the sample; 0 means drop it.
        """
        # 1. Peak-clipping for high-frequency characters.
        if freq >= self.drop_start_freq:
            # Linear drop probability: 0 at drop_start_freq, max_drop_prob
            # at max_freq.
            if self.max_freq == self.drop_start_freq:
                drop_prob = 0.0  # guard against division by zero
            else:
                drop_prob = (
                    self.max_drop_prob
                    * (freq - self.drop_start_freq)
                    / (self.max_freq - self.drop_start_freq)
                )
            return 0 if random.random() < drop_prob else 1

        # 2. Valley-filling for low-frequency characters.
        if freq <= self.repeat_end_freq:
            # Linear repeat expectation: 0 at repeat_end_freq,
            # max_repeat_expect at (or below) min_freq.
            if freq <= self.min_freq:
                repeat_expect = self.max_repeat_expect
            elif self.repeat_end_freq == self.min_freq:
                repeat_expect = 0  # guard against division by zero
            else:
                repeat_expect = (
                    self.max_repeat_expect
                    * (self.repeat_end_freq - freq)
                    / (self.repeat_end_freq - self.min_freq)
                )
            # Poisson sampling turns the expectation into a random count.
            repeat_count = np.random.poisson(repeat_expect)
            return max(1, repeat_count)

        # 3. Mid-frequency characters pass through unchanged.
        return 1

    def batch_get_char_info(
        self, char_pinyin_pairs: List[Tuple[str, str]]
    ) -> Dict[Tuple[str, str], Any]:
        """Fetch character info for many (char, pinyin) pairs, with caching.

        Args:
            char_pinyin_pairs: List of ``(character, pinyin)`` tuples.

        Returns:
            Dict keyed by ``(character, pinyin)``; each value is a dict with
            ``id``/``freq``/``char``/``pinyin`` or ``None`` when unknown.
        """
        results = {}

        # Serve what we can from the cache.
        uncached_pairs = []
        for pair in char_pinyin_pairs:
            if pair in self.char_info_cache:
                results[pair] = self.char_info_cache[pair]
            else:
                uncached_pairs.append(pair)

        # Batch-query the rest and populate the cache (including misses,
        # so unknown pairs are not re-queried).
        if uncached_pairs:
            char_infos = self.query_engine.batch_get_char_pinyin_info(uncached_pairs)
            for pair, char_info in char_infos.items():
                if char_info:
                    info = {
                        "id": char_info.id,
                        "freq": char_info.count,
                        "char": char_info.char,
                        "pinyin": char_info.pinyin,
                    }
                else:
                    info = None
                results[pair] = info
                self.char_info_cache[pair] = info

        return results

    def _process_batch(self, char_pinyin_batch, char_positions, text):
        """Turn a batch of collected character positions into samples."""
        char_info_map = self.batch_get_char_info(char_pinyin_batch)
        batch_samples = []

        for pos_info in char_positions:
            char = pos_info["char"]
            pinyin = pos_info["pinyin"]
            next_pinyins = pos_info["next_pinyins"]
            context = pos_info["context"]

            # Skip characters the query engine does not know.
            char_info = char_info_map.get((char, pinyin))
            if not char_info:
                continue

            # Peak-clip / valley-fill: 0 drops the sample.
            adjust_factor = self.adjust_frequency(char_info["freq"])
            if adjust_factor <= 0:
                continue

            sampled_context = self.sample_context(context)
            processed_pinyin = self.process_pinyin_sequence(next_pinyins)

            # Group id of the first pinyin letter; 12 = "no group".
            # .get() guards against initials missing from PG instead of
            # raising KeyError mid-iteration.
            pg = (
                self.pg_groups.get(processed_pinyin[0], 12)
                if processed_pinyin
                else 12
            )

            # Randomly drop the pinyin hint to simulate pure-context input.
            py = "" if random.random() < self.drop_py_rate else processed_pinyin

            hint = self.tokenizer(
                sampled_context,
                py,
                max_length=self.max_len,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                return_token_type_ids=True,
            )

            sample = {
                "hint": hint,
                "txt": sampled_context,
                "py": py,
                "char_id": torch.tensor([char_info["id"]]),
                "char": char,
                "freq": char_info["freq"],
                "pg": torch.tensor([pg]),
            }

            # Emit the sample adjust_factor times (valley-filling repeats).
            for _ in range(adjust_factor):
                batch_samples.append(sample)

        return batch_samples

    def _shuffle_and_yield(self, batch_samples):
        """Yield samples, optionally in a random permutation order."""
        if not self.shuffle:
            yield from batch_samples
            return
        # numpy permutation instead of random.shuffle for batch efficiency.
        if batch_samples:
            for idx in np.random.permutation(len(batch_samples)):
                yield batch_samples[idx]

    def __iter__(self):
        """Iterate over samples; multi-worker safe.

        Yields:
            One sample dict at a time.
        """
        # Give every DataLoader worker its own deterministic random stream.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
            seed = base_seed + worker_id
            random.seed(seed % (2**32))
            np.random.seed(seed % (2**32))

        batch_samples = []
        for item in self.dataset:
            text = item.get(self.text_field, "")
            if not text:
                continue

            # Per-character pinyin; non-Chinese characters map to themselves
            # so the list stays index-aligned with the text.
            pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])

            chinese_positions = [
                i for i, char in enumerate(text)
                if self.query_engine.is_chinese_char(char)
            ]

            # Collect (char, pinyin) pairs and per-position metadata, then
            # query the engine in batches.
            char_pinyin_batch = []
            char_positions = []
            for i in chinese_positions:
                char = text[i]
                py = pinyin_list[i]

                # Pinyin of up to 3 following Chinese characters.
                next_chars = self.get_next_chinese_chars(
                    text, i, max_count=3, pinyin_list=pinyin_list
                )
                next_pinyins = [py] + [p for _, p in next_chars]

                # Preceding context, at most 100 characters.
                context = text[max(0, i - 100) : i]

                char_pinyin_batch.append((char, py))
                char_positions.append(
                    {
                        "index": i,
                        "char": char,
                        "pinyin": py,
                        "next_pinyins": next_pinyins,
                        "context": context,
                        "next_chars": next_chars,
                    }
                )

                # Flush a full query batch.
                if len(char_pinyin_batch) >= self.batch_query_size:
                    batch_samples += self._process_batch(
                        char_pinyin_batch, char_positions, text
                    )
                    char_pinyin_batch = []
                    char_positions = []
                    # Flush the shuffle buffer when it is full.
                    if len(batch_samples) >= self.shuffle_buffer_size:
                        yield from self._shuffle_and_yield(batch_samples)
                        batch_samples = []

            # Process the tail of this text.
            if char_pinyin_batch:
                batch_samples += self._process_batch(
                    char_pinyin_batch, char_positions, text
                )

        # Drain whatever is left in the buffer.
        yield from self._shuffle_and_yield(batch_samples)

    def __len__(self):
        """Streaming dataset: the true length is unknown.

        Returns:
            -1 as a sentinel. NOTE(review): calling the builtin ``len()`` on
            this dataset raises ``ValueError`` (negative length); callers
            should not rely on ``len()`` for streaming datasets.
        """
        return -1


def worker_init_fn(worker_id):
    """DataLoader worker init: derive a per-worker seed for all RNGs."""
    seed = torch.initial_seed() + worker_id
    random.seed(seed % (2**32))
    np.random.seed(seed % (2**32))
    torch.manual_seed(seed % (2**32))


def custom_collate_with_txt(batch):
    """Collate samples, keeping text fields (txt/py/char) alongside tensors."""
    if not batch:
        return {}
    hints = [item["hint"] for item in batch]
    return {
        "hint": {
            "input_ids": torch.cat([h["input_ids"] for h in hints]),
            "attention_mask": torch.cat([h["attention_mask"] for h in hints]),
            "token_type_ids": torch.cat([h["token_type_ids"] for h in hints]),
        },
        "char_id": torch.cat([item["char_id"] for item in batch]),
        "char": [item["char"] for item in batch],
        "txt": [item["txt"] for item in batch],
        "py": [item["py"] for item in batch],
        "pg": torch.cat([item["pg"] for item in batch]),
    }


def custom_collate(batch):
    """Collate samples, keeping only the tensor fields needed for training."""
    if not batch:
        return {}
    hints = [item["hint"] for item in batch]
    return {
        "hint": {
            "input_ids": torch.cat([h["input_ids"] for h in hints]),
            "attention_mask": torch.cat([h["attention_mask"] for h in hints]),
        },
        "char_id": torch.cat([item["char_id"] for item in batch]),
        "pg": torch.cat([item["pg"] for item in batch]),
    }