624 lines
19 KiB
Python
624 lines
19 KiB
Python
import os
|
||
import random
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
import numpy as np
|
||
import torch
|
||
from datasets import load_dataset
|
||
from loguru import logger
|
||
from modelscope import AutoTokenizer
|
||
from pypinyin import lazy_pinyin
|
||
from torch.utils.data import DataLoader, IterableDataset
|
||
|
||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||
|
||
# Keyboard-group lookup: first letter of a pinyin -> key-group id (0-7).
# Letters sharing a physical key group map to the same id.
PG = {
    letter: group
    for group, letters in enumerate(
        ("yke", "lwf", "qas", "xbr", "omz", "gnc", "tpd", "jh")
    )
    for letter in letters
}
||
class PinyinInputDataset(IterableDataset):
    """
    Pinyin input-method simulation dataset.

    Features:
    1. Streams the source dataset (memory friendly).
    2. On-the-fly pinyin conversion with heteronym handling.
    3. Multiple sampling strategies for the preceding context.
    4. Pinyin truncation to simulate incomplete input.
    5. Built-in peak-clipping / valley-filling to balance character frequency.
    6. Buffered shuffling with multi-process (DataLoader worker) support.
    """

    def __init__(
        self,
        data_dir: str,
        query_engine,
        tokenizer_name: str = "iic/nlp_structbert_backbone_tiny_std",
        max_len: int = 88,
        text_field: str = "text",
        batch_query_size: int = 1000,
        # shuffling
        shuffle: bool = True,
        shuffle_buffer_size: int = 10000,
        # peak-clipping / valley-filling
        max_freq: int = 434748359,  # frequency of "的" (most frequent char)
        min_freq: int = 109,  # frequency of "蓚" (least frequent char)
        drop_start_freq: int = 30000000,  # frequency above which dropping starts
        repeat_end_freq: int = 10000,  # frequency below which repeating starts
        max_drop_prob: float = 0.8,  # drop probability at max_freq
        max_repeat_expect: float = 50.0,  # expected repeats at min_freq
        sample_context_section: Optional[List[float]] = None,
    ):
        """
        Initialize the dataset.

        Args:
            data_dir: dataset directory (passed to ``datasets.load_dataset``)
            query_engine: QueryEngine instance used for character lookups
            tokenizer_name: tokenizer name for ``AutoTokenizer.from_pretrained``
            max_len: maximum tokenized sequence length
            text_field: name of the text field in each dataset item
            batch_query_size: batch size for character-info queries
            shuffle: whether to shuffle emitted samples
            shuffle_buffer_size: size of the shuffle buffer
            max_freq: highest character frequency in the corpus
            min_freq: lowest character frequency in the corpus
            drop_start_freq: frequency threshold where peak-clipping starts
            repeat_end_freq: frequency threshold where valley-filling starts
            max_drop_prob: drop probability for the most frequent character
            max_repeat_expect: expected repetition for the least frequent character
            sample_context_section: cumulative probabilities selecting the three
                context-sampling strategies; defaults to ``[0.90, 0.95, 1]``.
        """
        self.query_engine = query_engine
        self.max_len = max_len
        self.text_field = text_field
        self.batch_query_size = batch_query_size

        # Shuffling parameters.
        self.shuffle = shuffle
        self.shuffle_buffer_size = shuffle_buffer_size

        # Peak-clipping / valley-filling parameters.
        self.max_freq = max_freq
        self.min_freq = min_freq
        self.drop_start_freq = drop_start_freq
        self.repeat_end_freq = repeat_end_freq
        self.max_drop_prob = max_drop_prob
        self.max_repeat_expect = max_repeat_expect

        # Load the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Total valid character count, kept for downstream frequency math.
        stats = query_engine.get_statistics()
        self.total_chars = stats.get("valid_input_character_count", 0)

        # Cache of (char, pinyin) -> info-dict lookups.
        self.char_info_cache = {}

        # Open the dataset in streaming mode.
        self.dataset = load_dataset(data_dir, split="train", streaming=True)

        # Pinyin key-group mapping.
        self.pg_groups = PG

        # Fix: the original signature used a mutable default argument
        # (``sample_context_section=[0.90, 0.95, 1]``), which is shared across
        # calls; materialize the default here instead.
        self.sample_context_section = (
            [0.90, 0.95, 1]
            if sample_context_section is None
            else sample_context_section
        )
||
def get_next_chinese_chars(
|
||
self,
|
||
text: str,
|
||
start_idx: int,
|
||
max_count: int = 3,
|
||
pinyin_list: List[str] = None,
|
||
) -> List[Tuple[str, str]]:
|
||
"""
|
||
获取后续的中文字符及其拼音
|
||
|
||
Args:
|
||
text: 完整文本
|
||
start_idx: 起始位置
|
||
max_count: 最大字符数
|
||
|
||
Returns:
|
||
列表,每个元素为(字符, 拼音)
|
||
"""
|
||
result = []
|
||
count = 0
|
||
|
||
for i in range(start_idx + 1, len(text)):
|
||
if count >= max_count:
|
||
break
|
||
|
||
char = text[i]
|
||
if self.query_engine.is_chinese_char(char):
|
||
# 获取拼音(注意:这里需要确保拼音列表长度与text一致)
|
||
try:
|
||
# 重新计算整个text的拼音可能效率低,但确保准确
|
||
# 实际实现中可以考虑缓存或优化
|
||
if pinyin_list is None:
|
||
pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
|
||
if i < len(pinyin_list):
|
||
result.append((char, pinyin_list[i]))
|
||
count += 1
|
||
except Exception:
|
||
break
|
||
else:
|
||
# 非汉字,继续查找
|
||
continue
|
||
|
||
return result
|
||
|
||
def sample_context(self, context: str) -> str:
    """
    Sample the preceding context using one of three strategies.

    The strategy is selected by a single uniform draw compared against the
    cumulative probabilities in ``self.sample_context_section``.

    Args:
        context: original preceding text (at most 100 characters)

    Returns:
        Up to 54 sampled characters (may be fewer for short contexts).
    """
    if not context:
        return ""

    # Cache the length; used by every strategy below.
    context_len = len(context)

    # One uniform draw selects the sampling strategy.
    choice = random.random()

    if choice < self.sample_context_section[0]:
        # Strategy 1: the 54 characters closest to the target character.
        return context[-54:] if context_len >= 54 else context
    elif choice < self.sample_context_section[1]:
        # Strategy 2: 46 consecutive characters from a random position.
        # NOTE(review): this branch yields 46 characters, not 54 — confirm
        # the shorter window is intentional.
        if context_len <= 46:
            return context
        start = random.randint(0, context_len - 46)
        return context[start : start + 46]
    else:
        # Strategy 3: 6 random 7-char segments + the last 12 characters.
        if context_len < 12:
            return context

        # The 12 characters immediately preceding the target.
        last_12 = context[-12:]

        # Segments are drawn from everything before those last 12 characters.
        remaining = context[:-12] if context_len > 12 else ""
        remaining_len = len(remaining)

        if remaining_len < 7:
            # Not enough material for even one segment.
            return last_12

        # Up to six 7-character segments at independent random offsets
        # (segments may overlap).
        segments = []
        for _ in range(6):
            if remaining_len < 7:
                break
            start = random.randint(0, remaining_len - 7)
            segment = remaining[start : start + 7]
            segments.append(segment)

        # Concatenate segments, then append the anchor window.
        combined = "".join(segments)
        result = combined + last_12

        # Pad from the *front* of the context if shorter than 54 (this may
        # duplicate characters already sampled — acceptable for augmentation).
        if len(result) < 54:
            needed = 54 - len(result)
            if context_len >= needed:
                result = context[:needed] + result

        # Hard cap at 54 characters.
        return result[:54]
||
def truncate_pinyin(self, pinyin: str) -> str:
    """
    Randomly truncate one pinyin to simulate incomplete typing.

    Distribution: 10% -> empty string, 80% -> unchanged,
    10% -> a random proper prefix (single-letter pinyin stays unchanged).

    Args:
        pinyin: the original pinyin

    Returns:
        The (possibly empty) truncated pinyin.
    """
    if not pinyin:
        return ""

    roll = random.random()

    # 10%: drop the pinyin entirely.
    if roll < 0.1:
        return ""

    # 80%: keep it as typed.
    if roll < 0.9:
        return pinyin

    # 10%: keep a random non-empty proper prefix.
    n = len(pinyin)
    if n <= 1:
        return pinyin
    keep = random.randint(1, n - 1)
    return pinyin[:keep]
||
def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
|
||
"""
|
||
处理拼音序列,逐个截断并拼接
|
||
|
||
Args:
|
||
pinyin_list: 拼音列表,长度1-4
|
||
|
||
Returns:
|
||
拼接后的拼音字符串
|
||
"""
|
||
result_parts = []
|
||
|
||
for pinyin in pinyin_list:
|
||
truncated = self.truncate_pinyin(pinyin)
|
||
if not truncated:
|
||
# 如果某个拼音截断为空,则停止
|
||
break
|
||
result_parts.append(truncated)
|
||
|
||
if not result_parts:
|
||
return ""
|
||
|
||
result = "".join(result_parts)
|
||
|
||
# 限制最大长度
|
||
if len(result) > 18:
|
||
result = result[:18]
|
||
|
||
return result
|
||
|
||
def adjust_frequency(self, freq: int) -> int:
|
||
"""
|
||
削峰填谷 - 根据频率调整采样
|
||
|
||
Args:
|
||
freq: 当前字符频率
|
||
|
||
Returns:
|
||
调整后的采样次数,0表示丢弃
|
||
"""
|
||
# 1. 削峰处理(高频字,>= 3000W开始丢弃)
|
||
if freq >= self.drop_start_freq:
|
||
# 线性丢弃概率:3000W时丢弃概率为0,434748359时丢弃概率为0.8
|
||
# 使用线性插值计算丢弃概率
|
||
if self.max_freq == self.drop_start_freq:
|
||
drop_prob = 0.0 # 防止除零
|
||
else:
|
||
drop_prob = (
|
||
self.max_drop_prob
|
||
* (freq - self.drop_start_freq)
|
||
/ (self.max_freq - self.drop_start_freq)
|
||
)
|
||
|
||
# 根据丢弃概率决定是否保留
|
||
if random.random() < drop_prob:
|
||
return 0 # 丢弃该样本
|
||
else:
|
||
return 1 # 保留,但不重复
|
||
|
||
# 2. 填谷处理(低频字,<= 1W开始重复)
|
||
elif freq <= self.repeat_end_freq:
|
||
# 线性重复期望:1W时重复期望为0,109时重复期望为50
|
||
# 使用线性插值计算期望重复次数
|
||
if freq <= self.min_freq:
|
||
repeat_expect = self.max_repeat_expect # 最低频字重复期望为50
|
||
else:
|
||
if self.repeat_end_freq == self.min_freq:
|
||
repeat_expect = 0 # 防止除零
|
||
else:
|
||
# 线性插值公式
|
||
repeat_expect = (
|
||
self.max_repeat_expect
|
||
* (self.repeat_end_freq - freq)
|
||
/ (self.repeat_end_freq - self.min_freq)
|
||
)
|
||
|
||
# 期望重复次数转换为实际重复次数
|
||
# 使用泊松分布实现期望重复,确保有随机性
|
||
repeat_count = np.random.poisson(repeat_expect)
|
||
|
||
# 确保至少返回1次
|
||
return max(1, repeat_count)
|
||
|
||
# 3. 中等频率字(1W < freq < 3000W)
|
||
else:
|
||
return 1 # 保持原样
|
||
|
||
def batch_get_char_info(
|
||
self, char_pinyin_pairs: List[Tuple[str, str]]
|
||
) -> Dict[Tuple[str, str], Any]:
|
||
"""
|
||
批量获取字符信息
|
||
|
||
Args:
|
||
char_pinyin_pairs: [(字符, 拼音), ...]
|
||
|
||
Returns:
|
||
字典,key为(字符, 拼音),value为(id, 频率)或None
|
||
"""
|
||
results = {}
|
||
|
||
# 先检查缓存
|
||
uncached_pairs = []
|
||
for pair in char_pinyin_pairs:
|
||
if pair in self.char_info_cache:
|
||
results[pair] = self.char_info_cache[pair]
|
||
else:
|
||
uncached_pairs.append(pair)
|
||
|
||
# 批量查询未缓存的
|
||
if uncached_pairs:
|
||
# 使用query_engine批量查询
|
||
char_infos = self.query_engine.batch_get_char_pinyin_info(uncached_pairs)
|
||
for pair, char_info in char_infos.items():
|
||
if char_info:
|
||
info = {
|
||
"id": char_info.id,
|
||
"freq": char_info.count,
|
||
"char": char_info.char,
|
||
"pinyin": char_info.pinyin,
|
||
}
|
||
else:
|
||
info = None
|
||
|
||
results[pair] = info
|
||
self.char_info_cache[pair] = info
|
||
|
||
return results
|
||
|
||
def _process_batch(self, char_pinyin_batch, char_positions, text):
    """
    Turn a batch of (char, pinyin) occurrences into training samples.

    Args:
        char_pinyin_batch: [(char, pinyin), ...] looked up in one batch query.
        char_positions: per-occurrence dicts carrying "char", "pinyin",
            "next_pinyins", "context" (built by ``__iter__``).
        text: the source text (currently unused in this method).

    Returns:
        List of sample dicts; peak-clipping may drop an occurrence
        entirely and valley-filling may repeat it several times.
    """
    # Batch lookup of character info (served from cache where possible).
    char_info_map = self.batch_get_char_info(char_pinyin_batch)

    batch_samples = []

    for pos_info in char_positions:
        char = pos_info["char"]
        pinyin = pos_info["pinyin"]
        next_pinyins = pos_info["next_pinyins"]
        context = pos_info["context"]

        # Skip characters unknown to the query engine.
        char_info = char_info_map.get((char, pinyin))
        if not char_info:
            continue

        # Frequency-based multiplier: 0 drops, >1 repeats.
        adjust_factor = self.adjust_frequency(char_info["freq"])
        if adjust_factor <= 0:
            continue

        # Sample the preceding context (one of three strategies).
        sampled_context = self.sample_context(context)

        # Simulated, possibly truncated, pinyin input.
        processed_pinyin = self.process_pinyin_sequence(next_pinyins)

        # Tokenize context + pinyin as a sentence pair, padded to max_len.
        hint = self.tokenizer(
            sampled_context,
            processed_pinyin,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        prob = random.random()
        # Key-group of the first pinyin letter; 8 means "no pinyin".
        # NOTE(review): this indexing raises KeyError if the first letter is
        # not a key of PG — confirm every possible leading letter is covered.
        pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
        # With 10% probability hide the pinyin *text*; note `hint` above was
        # still tokenized from the full pinyin.
        if prob < 0.1:
            py = ""
        else:
            py = processed_pinyin

        # Assemble the sample record.
        sample = {
            "hint": hint,
            "txt": sampled_context,
            "py": py,
            "char_id": torch.tensor([char_info["id"]]),
            "char": char,
            "freq": char_info["freq"],
            "pg": torch.tensor(
                [pg]
            ),
        }

        # Emit the sample adjust_factor times (valley-filling repetition).
        for _ in range(adjust_factor):
            batch_samples.append(sample)

    return batch_samples
||
def _shuffle_and_yield(self, batch_samples):
|
||
"""优化打乱逻辑"""
|
||
if not self.shuffle:
|
||
yield from batch_samples
|
||
return
|
||
|
||
# 使用numpy批量操作代替random.shuffle
|
||
if batch_samples:
|
||
indices = np.random.permutation(len(batch_samples))
|
||
for idx in indices:
|
||
yield batch_samples[idx]
|
||
|
||
def __iter__(self):
    """
    Iterate over generated samples; safe under multiple DataLoader workers.

    Yields:
        One sample dict at a time (see ``_process_batch`` for the schema).
    """
    # Give each DataLoader worker its own deterministic random stream.
    worker_info = torch.utils.data.get_worker_info()

    if worker_info is not None:
        worker_id = worker_info.id
        # base_seed + worker_id: distinct but reproducible per worker.
        base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
        seed = base_seed + worker_id
        # Both seeders require values < 2**32.
        random.seed(seed % (2**32))
        np.random.seed(seed % (2**32))

    batch_samples = []
    for item in self.dataset:
        text = item.get(self.text_field, "")
        if not text:
            continue

        # Per-character pinyin; non-Chinese characters pass through as-is,
        # keeping the list aligned with `text`.
        pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
        # Positions of Chinese characters only.
        chinese_positions = [
            i
            for i, char in enumerate(text)
            if self.query_engine.is_chinese_char(char)
        ]

        # Accumulators for one batched character-info query.
        char_pinyin_batch = []
        char_positions = []

        for i in chinese_positions:  # only Chinese character positions
            char = text[i]
            py = pinyin_list[i]

            # Up to 3 following Chinese characters and their pinyin.
            next_chars = self.get_next_chinese_chars(
                text, i, max_count=3, pinyin_list=pinyin_list
            )
            # Current char's pinyin plus the following ones.
            next_pinyins = [py] + [p for _, p in next_chars]
            # Preceding context, at most 100 characters.
            context = text[max(0, i - 100) : i]

            # Queue for the batched lookup.
            char_pinyin_batch.append((char, py))
            char_positions.append(
                {
                    "index": i,
                    "char": char,
                    "pinyin": py,
                    "next_pinyins": next_pinyins,
                    "context": context,
                    "next_chars": next_chars,
                }
            )

            # Flush a full query batch into the sample buffer.
            if len(char_pinyin_batch) >= self.batch_query_size:
                batch_samples += self._process_batch(
                    char_pinyin_batch, char_positions, text
                )
                char_pinyin_batch = []
                char_positions = []
                # When the shuffle buffer is full, emit (shuffled) and reset.
                if len(batch_samples) >= self.shuffle_buffer_size:
                    yield from self._shuffle_and_yield(batch_samples)
                    batch_samples = []
        # Flush this text's remaining characters into the sample buffer.
        # NOTE(review): original indentation was ambiguous in the source;
        # this flush is assumed to run once per text — confirm.
        if char_pinyin_batch:
            batch_samples += self._process_batch(
                char_pinyin_batch, char_positions, text
            )
    # Final drain of anything still buffered after the dataset is exhausted.
    yield from self._shuffle_and_yield(batch_samples)
||
def __len__(self):
|
||
"""
|
||
由于是流式数据集,无法预先知道长度
|
||
|
||
返回:
|
||
返回一个估计值或-1
|
||
"""
|
||
return -1
|
||
|
||
|
||
# 辅助函数,用于DataLoader
|
||
# Helper for DataLoader(worker_init_fn=...).
def worker_init_fn(worker_id):
    """Seed random / numpy / torch per DataLoader worker for reproducibility."""
    # Derive a distinct per-worker seed; both random and numpy require
    # seeds below 2**32.
    seed = (torch.initial_seed() + worker_id) % (2**32)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
||
def custom_collate_with_txt(batch):
    """Collate samples into one batch, keeping text fields as Python lists."""
    if not batch:
        return {}

    # Concatenate tokenizer outputs along the batch dimension.
    input_ids = torch.cat([sample["hint"]["input_ids"] for sample in batch])
    attention_mask = torch.cat([sample["hint"]["attention_mask"] for sample in batch])

    return {
        "hint": {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        },
        "char_id": torch.cat([sample["char_id"] for sample in batch]),
        "char": [sample["char"] for sample in batch],
        "txt": [sample["txt"] for sample in batch],
        "py": [sample["py"] for sample in batch],
        "pg": torch.cat([sample["pg"] for sample in batch]),
    }
||
def custom_collate(batch):
    """Collate samples into one batch of tensors (no text fields)."""
    if not batch:
        return {}

    # Concatenate tokenizer outputs along the batch dimension.
    input_ids = torch.cat([sample["hint"]["input_ids"] for sample in batch])
    attention_mask = torch.cat([sample["hint"]["attention_mask"] for sample in batch])

    return {
        "hint": {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        },
        "char_id": torch.cat([sample["char_id"] for sample in batch]),
        "pg": torch.cat([sample["pg"] for sample in batch]),
    }