# SUInput/src/suinput/dataset.py

import os
import random
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from torch.utils.data import DataLoader, IterableDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Pinyin initial-letter key groups: every lowercase letter that can start a
# pinyin syllable maps to one of 8 key groups (0-7).  Built from the group
# strings so the grouping is visible at a glance.
PG = {
    letter: group
    for group, letters in enumerate(
        ("yke", "lwf", "qas", "xbr", "omz", "gnc", "tpd", "jh")
    )
    for letter in letters
}
class PinyinInputDataset(IterableDataset):
    """Streaming dataset that simulates pinyin IME (input-method) input.

    Features:
        1. Streams the source dataset — memory friendly.
        2. On-the-fly pinyin conversion with heteronym handling.
        3. Multiple sampling strategies for the preceding context.
        4. Pinyin truncation to simulate incomplete typing.
        5. Built-in peak-clipping / valley-filling to balance the
           character-frequency distribution.
        6. Buffered shuffling with multi-worker (DataLoader) support.
    """

    def __init__(
        self,
        data_dir: str,
        query_engine,
        tokenizer_name: str = "iic/nlp_structbert_backbone_tiny_std",
        max_len: int = 88,
        text_field: str = "text",
        batch_query_size: int = 1000,
        # shuffling parameters
        shuffle: bool = True,
        shuffle_buffer_size: int = 10000,
        # peak-clipping / valley-filling parameters
        max_freq: int = 434748359,  # frequency of "的"
        min_freq: int = 109,  # frequency of "蓚"
        drop_start_freq: int = 30000000,  # frequency where dropping starts
        repeat_end_freq: int = 10000,  # frequency where repeating starts
        max_drop_prob: float = 0.8,  # maximum drop probability
        max_repeat_expect: float = 50.0,  # maximum expected repeat count
        # FIX: tuple instead of list — a mutable default argument is shared
        # across instances (classic Python pitfall).  The value is only read
        # by index in sample_context(), so this is backward compatible.
        sample_context_section: Tuple[float, float, float] = (0.90, 0.95, 1),
    ):
        """Initialize the dataset.

        Args:
            data_dir: dataset directory (passed to ``load_dataset``).
            query_engine: QueryEngine instance for char/pinyin lookups.
            tokenizer_name: tokenizer name for ``AutoTokenizer``.
            max_len: maximum tokenized sequence length.
            text_field: name of the text field in each dataset item.
            batch_query_size: size of batched character-info queries.
            shuffle: whether to shuffle samples.
            shuffle_buffer_size: shuffle buffer size.
            max_freq: highest character frequency in the corpus.
            min_freq: lowest character frequency in the corpus.
            drop_start_freq: frequency threshold where peak-clipping starts.
            repeat_end_freq: frequency threshold where valley-filling starts.
            max_drop_prob: drop probability at the highest frequency.
            max_repeat_expect: expected repeat count at the lowest frequency.
            sample_context_section: cumulative probability boundaries for
                the three context-sampling strategies.
        """
        self.query_engine = query_engine
        self.max_len = max_len
        self.text_field = text_field
        self.batch_query_size = batch_query_size
        # Shuffling parameters.
        self.shuffle = shuffle
        self.shuffle_buffer_size = shuffle_buffer_size
        # Peak-clipping / valley-filling parameters.
        self.max_freq = max_freq
        self.min_freq = min_freq
        self.drop_start_freq = drop_start_freq
        self.repeat_end_freq = repeat_end_freq
        self.max_drop_prob = max_drop_prob
        self.max_repeat_expect = max_repeat_expect
        # Load the tokenizer (downloads on first use).
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Total valid character count, kept for frequency computations.
        stats = query_engine.get_statistics()
        self.total_chars = stats.get("valid_input_character_count", 0)
        # (char, pinyin) -> info cache used by batch_get_char_info().
        self.char_info_cache = {}
        # Stream the dataset rather than loading it fully into memory.
        self.dataset = load_dataset(data_dir, split="train", streaming=True)
        # Pinyin initial-letter key groups.
        self.pg_groups = PG
        # Probability boundaries for the context-sampling strategies.
        self.sample_context_section = sample_context_section
def get_next_chinese_chars(
self,
text: str,
start_idx: int,
max_count: int = 3,
pinyin_list: List[str] = None,
) -> List[Tuple[str, str]]:
"""
获取后续的中文字符及其拼音
Args:
text: 完整文本
start_idx: 起始位置
max_count: 最大字符数
Returns:
列表,每个元素为(字符, 拼音)
"""
result = []
count = 0
for i in range(start_idx + 1, len(text)):
if count >= max_count:
break
char = text[i]
if self.query_engine.is_chinese_char(char):
# 获取拼音注意这里需要确保拼音列表长度与text一致
try:
# 重新计算整个text的拼音可能效率低但确保准确
# 实际实现中可以考虑缓存或优化
if pinyin_list is None:
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
if i < len(pinyin_list):
result.append((char, pinyin_list[i]))
count += 1
except Exception:
break
else:
# 非汉字,继续查找
continue
return result
def sample_context(self, context: str) -> str:
"""
三种方式采样前文上下文
Args:
context: 原始前文最多100字符
Returns:
采样后的54个字符
"""
if not context:
return ""
# 确保有足够长度
context_len = len(context)
# 随机选择采样方式
choice = random.random()
if choice < self.sample_context_section[0]:
# 方式1: 靠近汉字的54个字符
return context[-54:] if context_len >= 54 else context
elif choice < self.sample_context_section[1]:
# 方式2: 随机位置取46个连续字符
if context_len <= 46:
return context
start = random.randint(0, context_len - 46)
return context[start : start + 46]
else:
# 方式3: 12+6×7组合
if context_len < 12:
return context
# 最后12个字符
last_12 = context[-12:]
# 从剩下的前88个字符中随机取6段每段7个字符
remaining = context[:-12] if context_len > 12 else ""
remaining_len = len(remaining)
if remaining_len < 7:
# 如果不够7个字符直接返回最后12个字符
return last_12
segments = []
for _ in range(6):
if remaining_len < 7:
break
start = random.randint(0, remaining_len - 7)
segment = remaining[start : start + 7]
segments.append(segment)
# 拼接
combined = "".join(segments)
result = combined + last_12
# 确保总长度为54可能不足
if len(result) < 54:
# 如果不够,从前面补一些字符
needed = 54 - len(result)
if context_len >= needed:
result = context[:needed] + result
# 截断到54字符
return result[:54]
def truncate_pinyin(self, pinyin: str) -> str:
"""
截断拼音
Args:
pinyin: 原始拼音
Returns:
截断后的拼音,可能为空字符串
"""
if not pinyin:
return ""
# 随机决定截断方式
rand_val = random.random()
if rand_val < 0.1:
# 10%概率截断为空
return ""
elif rand_val < 0.9:
# 80%概率不截断
return pinyin
else:
# 10%概率随机截断
# 均匀分配剩余概率给各种截断长度
max_len = len(pinyin)
if max_len <= 1:
return pinyin
# 随机选择截断长度 (1 到 max_len-1)
trunc_len = random.randint(1, max_len - 1)
return pinyin[:trunc_len]
def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
"""
处理拼音序列,逐个截断并拼接
Args:
pinyin_list: 拼音列表长度1-4
Returns:
拼接后的拼音字符串
"""
result_parts = []
for pinyin in pinyin_list:
truncated = self.truncate_pinyin(pinyin)
if not truncated:
# 如果某个拼音截断为空,则停止
break
result_parts.append(truncated)
if not result_parts:
return ""
result = "".join(result_parts)
# 限制最大长度
if len(result) > 18:
result = result[:18]
return result
def adjust_frequency(self, freq: int) -> int:
"""
削峰填谷 - 根据频率调整采样
Args:
freq: 当前字符频率
Returns:
调整后的采样次数0表示丢弃
"""
# 1. 削峰处理(高频字,>= 3000W开始丢弃
if freq >= self.drop_start_freq:
# 线性丢弃概率3000W时丢弃概率为0434748359时丢弃概率为0.8
# 使用线性插值计算丢弃概率
if self.max_freq == self.drop_start_freq:
drop_prob = 0.0 # 防止除零
else:
drop_prob = (
self.max_drop_prob
* (freq - self.drop_start_freq)
/ (self.max_freq - self.drop_start_freq)
)
# 根据丢弃概率决定是否保留
if random.random() < drop_prob:
return 0 # 丢弃该样本
else:
return 1 # 保留,但不重复
# 2. 填谷处理(低频字,<= 1W开始重复
elif freq <= self.repeat_end_freq:
# 线性重复期望1W时重复期望为0109时重复期望为50
# 使用线性插值计算期望重复次数
if freq <= self.min_freq:
repeat_expect = self.max_repeat_expect # 最低频字重复期望为50
else:
if self.repeat_end_freq == self.min_freq:
repeat_expect = 0 # 防止除零
else:
# 线性插值公式
repeat_expect = (
self.max_repeat_expect
* (self.repeat_end_freq - freq)
/ (self.repeat_end_freq - self.min_freq)
)
# 期望重复次数转换为实际重复次数
# 使用泊松分布实现期望重复,确保有随机性
repeat_count = np.random.poisson(repeat_expect)
# 确保至少返回1次
return max(1, repeat_count)
# 3. 中等频率字1W < freq < 3000W
else:
return 1 # 保持原样
def batch_get_char_info(
self, char_pinyin_pairs: List[Tuple[str, str]]
) -> Dict[Tuple[str, str], Any]:
"""
批量获取字符信息
Args:
char_pinyin_pairs: [(字符, 拼音), ...]
Returns:
字典key为(字符, 拼音)value为(id, 频率)或None
"""
results = {}
# 先检查缓存
uncached_pairs = []
for pair in char_pinyin_pairs:
if pair in self.char_info_cache:
results[pair] = self.char_info_cache[pair]
else:
uncached_pairs.append(pair)
# 批量查询未缓存的
if uncached_pairs:
# 使用query_engine批量查询
char_infos = self.query_engine.batch_get_char_pinyin_info(uncached_pairs)
for pair, char_info in char_infos.items():
if char_info:
info = {
"id": char_info.id,
"freq": char_info.count,
"char": char_info.char,
"pinyin": char_info.pinyin,
}
else:
info = None
results[pair] = info
self.char_info_cache[pair] = info
return results
    def _process_batch(self, char_pinyin_batch, char_positions, text):
        """Turn a batch of collected character positions into training samples.

        Args:
            char_pinyin_batch: [(char, pinyin), ...] pairs for bulk lookup.
            char_positions: per-position dicts (built in ``__iter__``) with
                keys "char", "pinyin", "next_pinyins", "context".
            text: the full source text (currently unused in this method).
        Returns:
            List of sample dicts; a sample may appear several times when
            valley-filling asks for repetition (the SAME dict object is
            appended repeatedly — callers must not mutate samples in place).
        """
        # Bulk, cache-backed lookup of character info (id / frequency).
        char_info_map = self.batch_get_char_info(char_pinyin_batch)
        batch_samples = []
        for pos_info in char_positions:
            char = pos_info["char"]
            pinyin = pos_info["pinyin"]
            next_pinyins = pos_info["next_pinyins"]
            context = pos_info["context"]
            # Skip characters the query engine does not know.
            char_info = char_info_map.get((char, pinyin))
            if not char_info:
                continue
            # Peak-clipping / valley-filling: 0 = drop, >1 = repeat.
            adjust_factor = self.adjust_frequency(char_info["freq"])
            if adjust_factor <= 0:
                continue
            # Sample the preceding context.
            sampled_context = self.sample_context(context)
            # Simulate (possibly truncated) pinyin typing.
            processed_pinyin = self.process_pinyin_sequence(next_pinyins)
            # Tokenize context + pinyin as a sentence pair.
            hint = self.tokenizer(
                sampled_context,
                processed_pinyin,
                max_length=self.max_len,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            prob = random.random()
            # Key-group id of the first pinyin letter; 8 = "no pinyin" bucket.
            pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
            # With 10% probability blank out the pinyin string.
            # NOTE(review): ``hint`` is still tokenized with the full pinyin
            # even when ``py`` is blanked here — confirm this is intentional.
            if prob < 0.1:
                py = ""
            else:
                py = processed_pinyin
            # Assemble the sample.
            sample = {
                "hint": hint,
                "txt": sampled_context,
                "py": py,
                "char_id": torch.tensor([char_info["id"]]),
                "char": char,
                "freq": char_info["freq"],
                "pg": torch.tensor(
                    [pg]
                ),
            }
            # Emit the sample adjust_factor times (valley-filling).
            for _ in range(adjust_factor):
                batch_samples.append(sample)
        return batch_samples
def _shuffle_and_yield(self, batch_samples):
"""优化打乱逻辑"""
if not self.shuffle:
yield from batch_samples
return
# 使用numpy批量操作代替random.shuffle
if batch_samples:
indices = np.random.permutation(len(batch_samples))
for idx in indices:
yield batch_samples[idx]
    def __iter__(self):
        """Stream training samples; supports multi-worker DataLoaders.

        Yields:
            Sample dicts produced by ``_process_batch`` (shuffled in
            buffer-sized chunks when ``self.shuffle`` is True).
        """
        # Give each DataLoader worker a distinct, deterministic random seed.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            # base seed + worker id -> a different stream per worker.
            # NOTE(review): torch already seeds workers differently; the
            # extra +worker_id offset is belt-and-braces — confirm intent.
            base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
            seed = base_seed + worker_id
            random.seed(seed % (2**32))
            np.random.seed(seed % (2**32))
        batch_samples = []
        for item in self.dataset:
            text = item.get(self.text_field, "")
            if not text:
                continue
            # Whole-text pinyin; non-Chinese chars pass through unchanged.
            pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
            # Positions of the Chinese characters in the text.
            chinese_positions = [
                i
                for i, char in enumerate(text)
                if self.query_engine.is_chinese_char(char)
            ]
            # Collect (char, pinyin) pairs for batched info lookup.
            char_pinyin_batch = []
            char_positions = []
            for i in chinese_positions:  # only visit Chinese characters
                char = text[i]
                py = pinyin_list[i]
                # Up to 3 following Chinese characters with their pinyin.
                next_chars = self.get_next_chinese_chars(
                    text, i, max_count=3, pinyin_list=pinyin_list
                )
                next_pinyins = [py] + [p for _, p in next_chars]
                # Preceding context, at most 100 characters.
                context = text[max(0, i - 100) : i]
                # Queue for the bulk query.
                char_pinyin_batch.append((char, py))
                char_positions.append(
                    {
                        "index": i,
                        "char": char,
                        "pinyin": py,
                        "next_pinyins": next_pinyins,
                        "context": context,
                        "next_chars": next_chars,
                    }
                )
                # Flush once the lookup batch is large enough.
                if len(char_pinyin_batch) >= self.batch_query_size:
                    batch_samples += self._process_batch(
                        char_pinyin_batch, char_positions, text
                    )
                    char_pinyin_batch = []
                    char_positions = []
                    # Emit (and optionally shuffle) once the buffer fills.
                    if len(batch_samples) >= self.shuffle_buffer_size:
                        yield from self._shuffle_and_yield(batch_samples)
                        batch_samples = []
            # Flush the characters left over for this text.
            if char_pinyin_batch:
                batch_samples += self._process_batch(
                    char_pinyin_batch, char_positions, text
                )
        # Final flush of anything still buffered.
        yield from self._shuffle_and_yield(batch_samples)
def __len__(self):
"""
由于是流式数据集,无法预先知道长度
返回:
返回一个估计值或-1
"""
return -1
# Helper functions for DataLoader.
def worker_init_fn(worker_id):
    """DataLoader worker init: give every worker its own deterministic seed."""
    # Offset the base seed by the worker id, then clamp into the 32-bit
    # range that numpy accepts.
    worker_seed = (torch.initial_seed() + worker_id) % (2**32)
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)
def custom_collate_with_txt(batch):
    """Collate samples, keeping text fields (char/txt/py) alongside tensors.

    Args:
        batch: list of sample dicts from PinyinInputDataset.
    Returns:
        Batched dict, or {} for an empty batch.
    """
    if not batch:
        return {}
    # Concatenate the tokenizer outputs along the batch dimension.
    hint_ids = torch.cat([s["hint"]["input_ids"] for s in batch])
    hint_mask = torch.cat([s["hint"]["attention_mask"] for s in batch])
    return {
        "hint": {"input_ids": hint_ids, "attention_mask": hint_mask},
        "char_id": torch.cat([s["char_id"] for s in batch]),
        "char": [s["char"] for s in batch],
        "txt": [s["txt"] for s in batch],
        "py": [s["py"] for s in batch],
        "pg": torch.cat([s["pg"] for s in batch]),
    }
def custom_collate(batch):
    """Collate samples into tensors only (no text fields kept).

    Args:
        batch: list of sample dicts from PinyinInputDataset.
    Returns:
        Batched dict, or {} for an empty batch.
    """
    if not batch:
        return {}
    # Concatenate the tokenizer outputs along the batch dimension.
    hint_ids = torch.cat([s["hint"]["input_ids"] for s in batch])
    hint_mask = torch.cat([s["hint"]["attention_mask"] for s in batch])
    return {
        "hint": {"input_ids": hint_ids, "attention_mask": hint_mask},
        "char_id": torch.cat([s["char_id"] for s in batch]),
        "pg": torch.cat([s["pg"] for s in batch]),
    }