649 lines
20 KiB
Python
649 lines
20 KiB
Python
import json
import os
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from torch.utils.data import DataLoader, IterableDataset
||
# Disable the HF tokenizers' internal parallelism: it emits warnings (and can
# misbehave) when the process forks afterwards, as DataLoader workers do here.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||
|
||
class PinyinInputDataset(IterableDataset):
    """
    Simulated pinyin input-method (IME) dataset.

    Features:
    1. Streams the source corpus, so memory use stays flat.
    2. On-the-fly pinyin conversion (heteronyms handled by pypinyin).
    3. Several sampling strategies for the preceding context.
    4. Pinyin truncation to simulate incomplete typing.
    5. Built-in peak-clipping / valley-filling to balance character frequencies.
    6. Buffered shuffling with multi-process (DataLoader worker) support.
    """

    def __init__(
        self,
        data_dir: str,
        query_engine,
        tokenizer_name: str = "iic/nlp_structbert_backbone_tiny_std",
        max_len: int = 88,
        text_field: str = "text",
        batch_query_size: int = 1000,
        # shuffling
        shuffle: bool = True,
        shuffle_buffer_size: int = 100,
        # peak-clipping / valley-filling thresholds
        max_freq: int = 434748359,  # frequency of "的" (highest observed)
        min_freq: int = 109,  # frequency of "蓚" (lowest observed)
        drop_start_freq: int = 30000000,  # frequency at which dropping begins
        repeat_end_freq: int = 10000,  # frequency at which repeating begins
        max_drop_prob: float = 0.8,  # drop probability at max_freq
        max_repeat_expect: float = 50.0,  # expected repeat count at min_freq
        py_group_json_file: Path = Path(__file__).parent.parent.parent / "py_group.json",
    ):
        """
        Initialize the dataset.

        Args:
            data_dir: dataset directory, passed to ``datasets.load_dataset``
            query_engine: QueryEngine instance used for character lookups
            tokenizer_name: tokenizer name loaded via modelscope
            max_len: maximum tokenized sequence length
            text_field: name of the text field in each dataset item
            batch_query_size: how many characters to accumulate before one
                bulk query-engine lookup
            shuffle: whether to shuffle samples within each flushed batch
            shuffle_buffer_size: shuffle buffer size
            max_freq: highest character frequency (peak of the distribution)
            min_freq: lowest character frequency (valley of the distribution)
            drop_start_freq: frequency above which peak-clipping starts
            repeat_end_freq: frequency below which valley-filling starts
            max_drop_prob: drop probability applied at ``max_freq``
            max_repeat_expect: expected repeat count applied at ``min_freq``
            py_group_json_file: path to the pinyin-group JSON file
                (note: this default is evaluated once, at import time)
        """
        self.query_engine = query_engine
        self.max_len = max_len
        self.text_field = text_field
        self.batch_query_size = batch_query_size

        # Shuffling parameters.
        self.shuffle = shuffle
        # NOTE(review): shuffle_buffer_size is stored but the visible
        # shuffling code does not read it — confirm whether it is still used.
        self.shuffle_buffer_size = shuffle_buffer_size

        # Peak-clipping / valley-filling parameters.
        self.max_freq = max_freq
        self.min_freq = min_freq
        self.drop_start_freq = drop_start_freq
        self.repeat_end_freq = repeat_end_freq
        self.max_drop_prob = max_drop_prob
        self.max_repeat_expect = max_repeat_expect

        # Load the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Total character count, kept for downstream statistics.
        stats = query_engine.get_statistics()
        self.total_chars = stats.get("valid_input_character_count", 0)

        # Cache of (char, pinyin) -> info-dict lookups (see batch_get_char_info).
        self.char_info_cache = {}

        # Streaming dataset: items are read lazily during iteration.
        self.dataset = load_dataset(data_dir, split="train", streaming=True)

        with open(py_group_json_file, "r") as f:
            self.py_groups = json.load(f)
def get_next_chinese_chars(
    self,
    text: str,
    start_idx: int,
    max_count: int = 3,
    pinyin_list: Optional[List[str]] = None,
) -> List[Tuple[str, str]]:
    """
    Collect up to ``max_count`` Chinese characters that follow ``start_idx``.

    Non-Chinese characters are skipped (they do not count toward the limit).

    Args:
        text: full text
        start_idx: position of the current character; scanning starts at
            ``start_idx + 1``
        max_count: maximum number of characters to collect
        pinyin_list: optional precomputed per-character pinyin for ``text``
            (same length as ``text``); computed lazily when omitted

    Returns:
        List of ``(character, pinyin)`` tuples, possibly shorter than
        ``max_count``.
    """
    result: List[Tuple[str, str]] = []
    for i in range(start_idx + 1, len(text)):
        if len(result) >= max_count:
            break
        char = text[i]
        if not self.query_engine.is_chinese_char(char):
            # Non-Chinese character: skip and keep scanning.
            continue
        try:
            if pinyin_list is None:
                # Compute once; reused for the remainder of this scan.
                pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
            if i < len(pinyin_list):
                result.append((char, pinyin_list[i]))
        except Exception:
            # Best-effort: a pinyin-conversion failure ends the scan.
            break
    return result
||
def sample_context(self, context: str) -> str:
    """
    Sample the preceding context with one of three strategies (1/3 each).

    Args:
        context: raw preceding text (at most 100 characters)

    Returns:
        At most 54 sampled characters (may be shorter).
    """
    if not context:
        return ""

    n = len(context)
    roll = random.random()

    # Strategy 1: the 54 characters closest to the target character.
    if roll < 0.333:
        return context if n < 54 else context[-54:]

    # Strategy 2: 46 consecutive characters from a random position.
    if roll < 0.667:
        if n <= 46:
            return context
        lo = random.randint(0, n - 46)
        return context[lo : lo + 46]

    # Strategy 3: six random 7-char segments plus the last 12 characters.
    if n < 12:
        return context

    tail = context[-12:]
    head = context[:-12] if n > 12 else ""
    head_len = len(head)

    if head_len < 7:
        # Not enough material for a segment: fall back to the tail alone.
        return tail

    pieces = []
    for _ in range(6):
        if head_len < 7:
            break
        lo = random.randint(0, head_len - 7)
        pieces.append(head[lo : lo + 7])

    combined = "".join(pieces) + tail

    # Pad from the front of the original context if we fell short of 54.
    if len(combined) < 54:
        missing = 54 - len(combined)
        if n >= missing:
            combined = context[:missing] + combined

    return combined[:54]
||
def truncate_pinyin(self, pinyin: str) -> str:
    """
    Randomly truncate one pinyin to simulate partially-typed input.

    10% of the time the result is empty, 50% of the time the pinyin is
    kept whole, and 40% of the time a random non-empty strict prefix is
    returned (single-letter pinyin is never truncated).

    Args:
        pinyin: the full pinyin string

    Returns:
        A (possibly empty) prefix of ``pinyin``.
    """
    if not pinyin:
        return ""

    roll = random.random()
    if roll < 0.1:
        # Drop entirely.
        return ""
    if roll < 0.6:
        # Keep whole.
        return pinyin

    # Random strict prefix of length 1 .. len-1.
    full = len(pinyin)
    if full <= 1:
        return pinyin
    return pinyin[: random.randint(1, full - 1)]
||
def process_pinyin_sequence(self, pinyin_list: List[str]) -> str:
    """
    Truncate each pinyin in turn and concatenate the surviving prefixes.

    Processing stops at the first pinyin that truncates to the empty
    string; the joined result is capped at 18 characters.

    Args:
        pinyin_list: pinyin strings (length 1-4)

    Returns:
        The concatenated (possibly empty) pinyin string.
    """
    pieces = []
    for py in pinyin_list:
        part = self.truncate_pinyin(py)
        if not part:
            # An empty truncation ends the simulated input here.
            break
        pieces.append(part)

    joined = "".join(pieces)
    # Cap the simulated input length at 18 characters.
    return joined[:18] if len(joined) > 18 else joined
||
def adjust_frequency(self, freq: int) -> int:
    """
    Peak-clipping / valley-filling: decide how many times to emit a sample.

    High-frequency characters (``freq >= drop_start_freq``) are randomly
    dropped with probability growing linearly from 0 at ``drop_start_freq``
    to ``max_drop_prob`` at ``max_freq`` (clamped for frequencies beyond
    ``max_freq``). Low-frequency characters (``freq <= repeat_end_freq``)
    are repeated with a Poisson-distributed count whose expectation grows
    linearly up to ``max_repeat_expect`` at ``min_freq``.

    Args:
        freq: observed frequency of the character

    Returns:
        Number of times to emit the sample; 0 means drop it.
    """
    # 1. Peak clipping (high-frequency characters).
    if freq >= self.drop_start_freq:
        if self.max_freq == self.drop_start_freq:
            drop_prob = 0.0  # avoid division by zero
        else:
            drop_prob = (
                self.max_drop_prob
                * (freq - self.drop_start_freq)
                / (self.max_freq - self.drop_start_freq)
            )
            # Clamp: a frequency above max_freq must not push the drop
            # probability past the configured maximum (previously it could
            # exceed max_drop_prob, even 1.0).
            drop_prob = min(drop_prob, self.max_drop_prob)
        return 0 if random.random() < drop_prob else 1

    # 2. Valley filling (low-frequency characters).
    if freq <= self.repeat_end_freq:
        if freq <= self.min_freq:
            repeat_expect = self.max_repeat_expect
        elif self.repeat_end_freq == self.min_freq:
            repeat_expect = 0  # avoid division by zero
        else:
            # Linear interpolation between the two thresholds.
            repeat_expect = (
                self.max_repeat_expect
                * (self.repeat_end_freq - freq)
                / (self.repeat_end_freq - self.min_freq)
            )
        # Poisson sampling keeps the expectation while adding variance;
        # always emit the sample at least once.
        return max(1, np.random.poisson(repeat_expect))

    # 3. Mid-frequency characters pass through unchanged.
    return 1
||
def batch_get_char_info(
    self, char_pinyin_pairs: List[Tuple[str, str]]
) -> Dict[Tuple[str, str], Any]:
    """
    Look up (character, pinyin) info, reading from and filling the cache.

    Args:
        char_pinyin_pairs: list of ``(character, pinyin)`` pairs

    Returns:
        Mapping pair -> ``{"id", "freq", "char", "pinyin"}`` dict, or
        ``None`` when the pair is unknown to the query engine. Negative
        results are cached as well.
    """
    results: Dict[Tuple[str, str], Any] = {}
    missing: List[Tuple[str, str]] = []

    # Serve what we can from the cache (cached None values count as hits).
    for pair in char_pinyin_pairs:
        try:
            results[pair] = self.char_info_cache[pair]
        except KeyError:
            missing.append(pair)

    # One bulk query for everything the cache did not have.
    if missing:
        fetched = self.query_engine.batch_get_char_pinyin_info(missing)
        for pair, raw in fetched.items():
            info = (
                {
                    "id": raw.id,
                    "freq": raw.count,
                    "char": raw.char,
                    "pinyin": raw.pinyin,
                }
                if raw
                else None
            )
            results[pair] = info
            self.char_info_cache[pair] = info

    return results
||
def _process_batch(self, char_pinyin_batch, char_positions, text):
    """
    Turn a batch of collected character positions into training samples.

    Args:
        char_pinyin_batch: ``(char, pinyin)`` pairs to look up in bulk
        char_positions: per-position dicts carrying ``char``, ``pinyin``,
            ``next_pinyins`` and ``context``
        text: source text (currently unused; kept for interface stability)

    Returns:
        List of sample dicts; a sample may appear several times when the
        valley-filling logic asks for repetition.
    """
    # One bulk lookup for the whole batch.
    char_info_map = self.batch_get_char_info(char_pinyin_batch)

    batch_samples = []

    for pos_info in char_positions:
        char = pos_info["char"]
        pinyin = pos_info["pinyin"]

        char_info = char_info_map.get((char, pinyin))
        if not char_info:
            # Unknown (char, pinyin) pair: skip.
            continue

        # Peak-clipping / valley-filling: 0 -> drop, k -> emit k times.
        # (The per-sample logger.info call removed from this spot logged
        # every sample at INFO level on the training hot path.)
        repeat = self.adjust_frequency(char_info["freq"])
        if repeat <= 0:
            continue

        # Sample the preceding context and simulate partial pinyin input.
        sampled_context = self.sample_context(pos_info["context"])
        processed_pinyin = self.process_pinyin_sequence(pos_info["next_pinyins"])
        if not processed_pinyin:
            continue

        hint = self.tokenizer(
            sampled_context,
            processed_pinyin,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        sample = {
            "hint": hint,
            "txt": sampled_context,
            "py": processed_pinyin,
            "char_id": torch.tensor([char_info["id"]]),
            "char": char,
            "freq": char_info["freq"],
        }

        # The same dict object is emitted `repeat` times; samples are
        # treated as read-only downstream, so sharing is safe.
        batch_samples.extend([sample] * repeat)

    return batch_samples
||
def _shuffle_and_yield(self, batch_samples):
    """
    Yield samples, in a random order when shuffling is enabled.

    Uses a numpy index permutation instead of an in-place
    ``random.shuffle`` of the sample list.
    """
    if self.shuffle and batch_samples:
        order = np.random.permutation(len(batch_samples))
        for position in order:
            yield batch_samples[position]
    else:
        # Shuffling disabled (or nothing to shuffle): pass through as-is.
        yield from batch_samples
||
def __iter__(self):
    """
    Iterate over training samples; supports multiple DataLoader workers.

    Each worker derives its own random seed from the torch base seed plus
    the worker id, so workers draw different (but reproducible) random
    sequences.

    Yields:
        One sample dict at a time (see ``_process_batch``).
    """
    worker_info = torch.utils.data.get_worker_info()

    if worker_info is not None:
        worker_id = worker_info.id
        # base seed + worker id -> distinct but deterministic per worker
        base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
        seed = base_seed + worker_id
        random.seed(seed % (2**32))
        np.random.seed(seed % (2**32))

    for item in self.dataset:
        text = item.get(self.text_field, "")
        if not text:
            continue

        # Per-character pinyin; non-Chinese characters map to themselves,
        # keeping the list aligned with `text` by index.
        pinyin_list = lazy_pinyin(text, errors=lambda x: [c for c in x])
        # Positions of Chinese characters only.
        chinese_positions = [
            i for i, char in enumerate(text)
            if self.query_engine.is_chinese_char(char)
        ]

        # Accumulators for batched query-engine lookups.
        char_pinyin_batch = []
        char_positions = []

        for i in chinese_positions:  # only Chinese-character positions
            char = text[i]
            py = pinyin_list[i]

            # Up to 3 following Chinese characters (multi-char pinyin input).
            next_chars = self.get_next_chinese_chars(
                text, i, max_count=3, pinyin_list=pinyin_list
            )
            next_pinyins = [py] + [p for _, p in next_chars]
            # Preceding context, at most 100 characters.
            context = text[max(0, i - 100) : i]

            char_pinyin_batch.append((char, py))
            char_positions.append(
                {
                    "index": i,
                    "char": char,
                    "pinyin": py,
                    "next_pinyins": next_pinyins,
                    "context": context,
                    "next_chars": next_chars,
                }
            )

            # Flush once the batch is large enough.
            if len(char_pinyin_batch) >= self.batch_query_size:
                batch_samples = self._process_batch(
                    char_pinyin_batch, char_positions, text
                )
                yield from self._shuffle_and_yield(batch_samples)
                char_pinyin_batch = []
                char_positions = []
        # Flush whatever is left for this text.
        if char_pinyin_batch:
            batch_samples = self._process_batch(
                char_pinyin_batch, char_positions, text
            )
            yield from self._shuffle_and_yield(batch_samples)
||
def __len__(self):
    """
    Length is unknown for a streaming dataset; return the sentinel ``-1``.

    Note:
        A negative ``__len__`` is unconventional — the builtin ``len()``
        rejects negative values, so callers should read this attribute
        directly rather than calling ``len()`` on the dataset.
    """
    return -1
||
# 辅助函数,用于DataLoader
|
||
def worker_init_fn(worker_id):
    """
    Seed the python/numpy/torch RNGs for one DataLoader worker.

    The seed is torch's base seed plus the worker id (reduced mod 2**32,
    the upper bound numpy accepts), giving each worker a distinct but
    reproducible random stream.
    """
    seed = (torch.initial_seed() + worker_id) % (2**32)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
||
def custom_collate_with_txt(batch):
    """
    Collate samples into one batch, keeping the text fields.

    Args:
        batch: list of sample dicts produced by the dataset

    Returns:
        Dict with concatenated ``hint`` tensors and ``char_id``, plus the
        per-sample ``char``/``txt``/``py`` lists; ``{}`` for an empty batch.
    """
    if not batch:
        return {}

    # Concatenate the tokenized hint tensors along the batch dimension.
    input_ids = torch.cat([sample["hint"]["input_ids"] for sample in batch])
    attention_mask = torch.cat([sample["hint"]["attention_mask"] for sample in batch])

    return {
        "hint": {"input_ids": input_ids, "attention_mask": attention_mask},
        "char_id": torch.cat([sample["char_id"] for sample in batch]),
        "char": [sample["char"] for sample in batch],
        "txt": [sample["txt"] for sample in batch],
        "py": [sample["py"] for sample in batch],
    }
||
def custom_collate(batch):
    """
    Collate samples into one batch (tensor fields and pinyin only).

    Args:
        batch: list of sample dicts produced by the dataset

    Returns:
        Dict with concatenated ``hint`` tensors, ``char_id`` and the
        per-sample ``py`` list; ``{}`` for an empty batch.
    """
    if not batch:
        return {}

    # Concatenate the tokenized hint tensors along the batch dimension.
    hints = [sample["hint"] for sample in batch]

    return {
        "hint": {
            "input_ids": torch.cat([h["input_ids"] for h in hints]),
            "attention_mask": torch.cat([h["attention_mask"] for h in hints]),
        },
        "char_id": torch.cat([sample["char_id"] for sample in batch]),
        "py": [sample["py"] for sample in batch],
    }
|
||
# Usage example / smoke test (requires the local `query` module and corpus).
if __name__ == "__main__":
    from query import QueryEngine
    from tqdm import tqdm

    # Initialize the query engine from precomputed char/pinyin statistics.
    query_engine = QueryEngine()
    query_engine.load("./pinyin_char_pairs_info.json")

    # Build the streaming dataset.
    dataset = PinyinInputDataset(
        data_dir="/home/songsenand/Data/corpus/CCI-Data/",
        query_engine=query_engine,
        tokenizer_name="iic/nlp_structbert_backbone_tiny_std",
        max_len=88,
        batch_query_size=300,
        shuffle=True,
        shuffle_buffer_size=4000,
    )

    logger.info("数据集初始化")
    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        num_workers=15,
        worker_init_fn=worker_init_fn,
        pin_memory=True if torch.cuda.is_available() else False,
        collate_fn=custom_collate_with_txt,
        prefetch_factor=8,
        persistent_workers=True,
        shuffle=False,  # shuffling is handled inside the dataset itself
    )

    # Disabled profiling snippet, kept as-is for reference.
    """import cProfile

    def profile_func(dataloader):
        for i, sample in tqdm(enumerate(dataloader), total=3000):
            if i >= 3000:
                break
        return


    cProfile.run('profile_func(dataloader)')

    """
    # Smoke-test: pull up to 3000 batches through the DataLoader.
    try:
        logger.info("测试数据集")
        for i, sample in tqdm(enumerate(dataloader), total=3000):
            if i >= 3000:
                break
            """
            print(f"Sample {i+1}:")
            print(f" Char: {sample['char']}, Id: {sample['char_id'].item()}, Freq: {sample.get('freq', 'N/A')}")
            print(f" Pinyin: {sample['py']}")
            print(f" Context length: {len(sample['txt'])}")
            print(f" Hint shape: {sample['hint']['input_ids'].shape}")
            print()
            """
    except StopIteration:
        # NOTE(review): a for-loop absorbs StopIteration itself, so this
        # handler likely never fires — confirm the intended behavior.
        print("数据集为空")