425 lines
18 KiB
Python
425 lines
18 KiB
Python
import random
|
||
import re
|
||
from importlib.resources import files
|
||
from pathlib import Path
|
||
from typing import Dict, List, Tuple
|
||
|
||
import numpy as np
|
||
import torch
|
||
from datasets import load_dataset
|
||
from loguru import logger
|
||
from modelscope import AutoTokenizer
|
||
from pypinyin import lazy_pinyin
|
||
from pypinyin.contrib.tone_convert import to_initials
|
||
from torch.utils.data import IterableDataset
|
||
|
||
from .query import QueryEngine
|
||
|
||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||
|
||
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
|
||
CHAR_TO_ID["`"] = 27 # 显式添加反引号
|
||
CHAR_TO_ID["'"] = 28 # 显式添加引号
|
||
CHAR_TO_ID["-"] = 29 # 显式添加短横
|
||
|
||
|
||
def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
    """Convert a pinyin string into its list of integer character IDs.

    Characters 'a'-'z' and the separators (backtick, apostrophe, hyphen)
    map through CHAR_TO_ID; any unknown character maps to 0 (PAD/UNK).
    """
    # Bind the lookup once; dict.get(key, default) absorbs unknown chars.
    lookup = CHAR_TO_ID.get
    return [lookup(char, 0) for char in pinyin_str]
|
||
|
||
|
||
class PinyinInputDataset(IterableDataset):
    """Streaming dataset that turns raw text into pinyin-IME training samples.

    Each yielded sample pairs a tokenized context window (left context,
    optional right context, optional distractor substrings) with the pinyin
    of the character span to predict. Sample counts are rebalanced by
    character frequency: very frequent characters are probabilistically
    dropped, rare ones are repeated (Poisson-distributed).
    """

    def __init__(
        self,
        data_path: str,
        max_workers: int = -1,
        max_iter_length=1e6,
        max_seq_length=128,
        text_field: str = "text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size: int = 100000,
        retention_ratio: float = 0.5,
        length_weights=None,
    ):
        """Build the streaming dataset.

        Args:
            data_path: dataset path/name passed to ``load_dataset``.
            max_workers: cap on DataLoader workers; <=0 means use the
                actual worker count reported by torch.
            max_iter_length: global cap on the number of samples yielded.
            max_seq_length: tokenizer max length (padded/truncated).
            text_field: field of each dataset record holding the text.
            py_style_weight: relative weights for choosing (full pinyin,
                initials, first letter) per predicted character.
            shuffle_buffer_size: size at which the shuffle buffer flushes.
            retention_ratio: fraction of the buffer retained at each flush;
                must be strictly between 0 and 1.
            length_weights: mapping of prediction-span length -> weight.

        Raises:
            ValueError: if retention_ratio is out of range or the computed
                retention size is not positive.
        """
        # BUGFIX: the original signature used a mutable dict as a default
        # argument (shared across instances); use None and materialize the
        # same default here instead.
        if length_weights is None:
            length_weights = {1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}

        # Fail fast on bad ratios before any expensive loading below.
        if not (0 < retention_ratio < 1):
            raise ValueError(
                f"retention_ratio必须在0和1之间,当前值: {retention_ratio}"
            )
        self.retention_size = int(shuffle_buffer_size * retention_ratio)
        if self.retention_size <= 0:
            raise ValueError(
                f"计算出的retention_size必须大于0,当前值: {self.retention_size} (shuffle_buffer_size={shuffle_buffer_size}, retention_ratio={retention_ratio})"
            )

        # Frequency-rebalancing parameters (tune as needed).
        self.drop_start_freq = 30_000_000  # above this, start dropping
        self.max_drop_prob = 0.8  # drop probability at the corpus maximum
        self.repeat_end_freq = 10_000  # below this, start repeating
        self.max_repeat_expect = 50  # expected repeats at/below min_freq
        self.min_freq = 109

        self.tokenizer = AutoTokenizer.from_pretrained(
            Path(str(files(__package__))) / "assets" / "tokenizer"
        )
        self.data_path = data_path

        self.max_iter_length = max_iter_length
        self.max_seq_length = max_seq_length
        self.text_field = text_field
        self.dataset = load_dataset(data_path, split="train", streaming=True)
        self.max_workers = max_workers
        # Normalized probabilities for (full pinyin, initials, first letter).
        self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
        self.shuffle_buffer_size = shuffle_buffer_size
        self.retention_ratio = retention_ratio
        self.possible_lengths = list(length_weights.keys())
        self.weights = list(length_weights.values())

        self.query_engine = QueryEngine()
        self.query_engine.load()

        # Per-character target frequencies used by adjust_frequency().
        self.sample_freqs = self.query_engine.get_all_weights()
        # PERF: lazy cache for max(self.sample_freqs) — the table is fixed
        # after __init__, so it need not be recomputed per call.
        self._max_freq = None

    def adjust_frequency(self, freq: int) -> int:
        """Peak-clipping / valley-filling sample-count adjustment.

        Returns how many times a sample with character frequency ``freq``
        should be emitted; 0 means drop it.
        """
        # 1. Clip peaks (very frequent characters): drop with a probability
        #    growing linearly from 0 at drop_start_freq to max_drop_prob at
        #    the corpus maximum frequency.
        if freq >= self.drop_start_freq:
            # BUGFIX/PERF: the original recomputed max(self.sample_freqs) on
            # every call in this hot path; cache it lazily instead (assumes
            # sample_freqs is immutable after __init__ — it is only assigned
            # there).
            if self._max_freq is None:
                self._max_freq = max(self.sample_freqs)
            max_freq = self._max_freq
            if max_freq == self.drop_start_freq:
                drop_prob = 0.0
            else:
                drop_prob = (
                    self.max_drop_prob
                    * (freq - self.drop_start_freq)
                    / (max_freq - self.drop_start_freq)
                )
            if random.random() < drop_prob:
                return 0
            else:
                return 1

        # 2. Fill valleys (rare characters): expected repetition grows
        #    linearly from 0 at repeat_end_freq up to max_repeat_expect at
        #    min_freq; Poisson sampling randomizes the actual count.
        elif freq <= self.repeat_end_freq:
            if freq <= self.min_freq:
                repeat_expect = self.max_repeat_expect
            else:
                if self.repeat_end_freq == self.min_freq:
                    repeat_expect = 0
                else:
                    repeat_expect = (
                        self.max_repeat_expect
                        * (self.repeat_end_freq - freq)
                        / (self.repeat_end_freq - self.min_freq)
                    )
            repeat_count = np.random.poisson(repeat_expect)
            # Always emit at least once in the low-frequency band.
            return max(1, repeat_count)

        # 3. Mid-band frequencies pass through unchanged.
        else:
            return 1

    # Convert a text string to per-character pinyin.
    def generate_pinyin(self, text: str) -> List[str]:
        """Convert one text string into a pinyin list (streaming-friendly).

        Properties:
            1. Strict one-to-one mapping: len(result) == len(text)
            2. Good polyphone accuracy: relies on pypinyin's internal word
               segmentation
            3. Fast: pre-allocated result list, minimal object churn

        Args:
            text: input string

        Returns:
            List[str]: pinyin for Chinese characters, the original
            character for everything else
        """
        if not text:
            return []

        text_len = len(text)
        # Pre-allocate the result with empty-string placeholders so
        # untouched (non-Chinese) slots are easy to detect afterwards.
        result: List[str] = [""] * text_len

        # Walk every maximal run of Chinese characters.
        for match in _HANZI_RE.finditer(text):
            start_idx = match.start()
            hanzi_segment = match.group()

            # Core conversion: lazy_pinyin segments the run into words,
            # which improves polyphone disambiguation.
            pinyin_list = lazy_pinyin(hanzi_segment)

            # Robustness fallback: pypinyin should return exactly one
            # pinyin per character; if not (very rare, e.g. odd Unicode
            # punctuation matched by the regex), degrade to per-character
            # conversion.
            if len(pinyin_list) != len(hanzi_segment):
                pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]

            # Fill the converted pinyin directly into pre-allocated slots.
            for i, py in enumerate(pinyin_list):
                result[start_idx + i] = py

        # Fill remaining (non-Chinese) positions with the original chars.
        for i, char in enumerate(text):
            if not result[i]:
                result[i] = char

        return result

    # Build the pinyin of the prediction span, with style augmentation.
    def get_mask_pinyin(
        self, text: str, pinyin_list: List[str]
    ) -> Tuple[int, List[str]]:
        """Pick the pinyin representation for each character to predict.

        Walks ``text`` from the left and stops at the first non-Chinese
        character, so the effective span may be shorter than requested.
        For each Chinese character, randomly chooses one of (full pinyin,
        initials, first letter) according to ``py_style_weight`` to mimic
        abbreviated IME input.

        Returns:
            Tuple of (effective span length, chosen pinyin list).
        """
        mask_pinyin = []
        for i in range(len(text)):
            if not self.query_engine.is_chinese_char(text[i]):
                break
            else:
                py = np.random.choice(
                    (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
                    p=self.py_style_weight,
                )
                # to_initials() yields "" for vowel-initial syllables
                # (e.g. "an"); fall back to the first letter.
                if py == "":
                    py = pinyin_list[i][0]
                mask_pinyin.append(py)
        return len(mask_pinyin), mask_pinyin

    def __iter__(self):
        """Yield training samples, sharded across DataLoader workers and
        shuffled through a single-buffer retention scheme."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = (
                self.max_workers if self.max_workers > 0 else worker_info.num_workers
            )
            # Derive a distinct, reproducible RNG seed per worker.
            base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
            seed = base_seed + worker_id
            random.seed(seed % (2**32))
            np.random.seed(seed % (2**32))

            # Safety: a worker beyond num_workers has no shard to process
            # (possible when self.max_workers < the actual worker count).
            if worker_id >= num_workers:
                return  # empty iterator

            # Shard into a local variable to avoid races on self.dataset.
            worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)

            # Split the global quota evenly; the last worker absorbs the
            # remainder. Cast first to guarantee integer division.
            total_quota = int(self.max_iter_length)
            base_quota = total_quota // num_workers
            remainder = total_quota % num_workers
            if worker_id == num_workers - 1:
                worker_quota = base_quota + remainder
            else:
                worker_quota = base_quota
        else:
            # Single-worker case: full quota, no sharding.
            worker_quota = int(self.max_iter_length)
            num_workers = 1
            worker_dataset = self.dataset  # unsharded

        # Per-worker count of samples yielded so far.
        current_iter_index = 0

        batch_samples = []
        for sample in worker_dataset:
            # Stop once this worker's quota is exhausted.
            if current_iter_index >= worker_quota:
                break

            text = sample.get(self.text_field, "")
            if text:
                pinyin_list = self.generate_pinyin(text)
                for i in range(len(text)):
                    # Re-check the quota before each character position.
                    if current_iter_index >= worker_quota:
                        break

                    # Skip positions outside the known Chinese char set.
                    if not self.query_engine.is_chinese_char(text[i]):
                        continue
                    # part1: up to 48 characters of left context.
                    if i < 48:
                        part1 = text[0:i]
                    else:
                        part1 = text[i - 48 : i]
                    # Draw the prediction span length (1-8); roughly
                    # bell-shaped with the mode at 3.
                    pinyin_len = np.random.choice(
                        range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
                    )
                    py_end = min(i + pinyin_len, len(text))
                    # get_mask_pinyin may shorten the span at the first
                    # non-Chinese character; pinyin_len becomes the
                    # effective span length from here on.
                    pinyin_len, part2 = self.get_mask_pinyin(
                        text[i:py_end], pinyin_list[i:py_end]
                    )

                    # Occasionally join syllables with an explicit
                    # separator character.
                    split_char = np.random.choice(
                        ["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
                    )

                    part2 = split_char.join(part2)
                    # Pad/truncate the pinyin IDs to a fixed length of 24.
                    pinyin_ids = text_to_pinyin_ids(part2)
                    len_py = len(pinyin_ids)
                    if len_py < 24:
                        pinyin_ids.extend([0] * (24 - len_py))
                    else:
                        pinyin_ids = pinyin_ids[:24]
                    pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)

                    # part3: right context; empty with probability 0.70,
                    # otherwise the 1-16 characters (uniform) following
                    # the prediction span.
                    part3 = ""
                    if random.random() > 0.7:
                        part3 = text[
                            i + pinyin_len : i
                            + pinyin_len
                            + np.random.choice(range(1, 17))
                        ]

                    # part4: empty with probability 0.50, otherwise 1-5
                    # random substrings of the text (a random start char
                    # plus the 2-6 following chars, uniform), '|'-joined.
                    part4 = ""
                    if random.random() > 0.5:
                        num_strings = random.randint(1, 5)
                        string_list = []
                        for _ in range(num_strings):
                            start_pos = random.randint(0, len(text) - 1)
                            x = random.randint(2, 6)
                            end_pos = min(start_pos + x + 1, len(text))
                            string_list.append(text[start_pos:end_pos])
                        part4 = "|".join(string_list)
                    try:
                        labels = [
                            self.query_engine.get_char_info_by_char_pinyin(c, p).id
                            for c, p in zip(
                                text[i : i + pinyin_len],
                                pinyin_list[i : i + pinyin_len],
                            )
                        ]
                    except AttributeError as e:
                        # get_char_info_by_char_pinyin returned no object
                        # for an unknown (char, pinyin) pair; skip this
                        # position rather than crash.
                        logger.error(
                            f"e: {e}, (text, pinyin): {text[i : i + pinyin_len]} - {pinyin_list[i : i + pinyin_len]}"
                        )
                        continue
                    # With 10% probability append a 0 label — presumably
                    # a stop/no-char marker; TODO confirm against the
                    # model's label semantics.
                    if random.random() <= 0.1:
                        labels.append(0)

                    # History slots: ids of up to 8 Chinese characters
                    # preceding position i, scanning at most 100 chars
                    # backwards (kept consistent with eval.py).
                    history_slot_list = []
                    for j in range(i - 1, max(-1, i - 100), -1):
                        if j < 0:
                            break
                        char = text[j]
                        if self.query_engine.is_chinese_char(char):
                            try:
                                results = self.query_engine.query_by_char(char, limit=1)
                                if results and results[0][0] > 0:
                                    history_slot_list.append(results[0][0])
                            except Exception:
                                # Best-effort lookup; missing history
                                # slots are simply omitted.
                                pass
                        if len(history_slot_list) >= 8:
                            break

                    encoded = self.tokenizer(
                        f"{part4}|{part1}",
                        part3,
                        max_length=self.max_seq_length,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                        return_token_type_ids=True,
                    )
                    samples = []
                    # (Loop variable deliberately not named `i` — it would
                    # shadow the character index above.)
                    for label_idx, label in enumerate(labels):
                        repeats = self.adjust_frequency(label)
                        # Pad the history slots (taken from text[0:i], as
                        # in eval.py) to a fixed length of 8.
                        masked_labels = history_slot_list[:]
                        len_l = len(masked_labels)
                        masked_labels.extend([0] * (8 - len_l))

                        # BUGFIX: the original used [dict] * repeats, which
                        # repeats the *same* dict object; build a fresh dict
                        # per repeat so downstream mutation cannot alias.
                        samples.extend(
                            {
                                "input_ids": encoded["input_ids"],
                                "token_type_ids": encoded["token_type_ids"],
                                "attention_mask": encoded["attention_mask"],
                                "label": torch.tensor([label], dtype=torch.long),
                                "history_slot_ids": torch.tensor(
                                    masked_labels, dtype=torch.long
                                ),
                                # NOTE(review): prefix joins with '^' while
                                # the tokenizer input above joins with '|';
                                # confirm this mismatch is intentional.
                                "prefix": f"{part4}^{part1}",
                                "suffix": part3,
                                "pinyin": part2,
                                "pinyin_ids": pinyin_ids,
                            }
                            for _ in range(repeats)
                        )

                    # Push into the shuffle buffer.
                    batch_samples.extend(samples)

                    # Shuffle buffer: single-buffer, partial-retention
                    # scheme — flush when full, keep retention_size items.
                    if len(batch_samples) >= self.shuffle_buffer_size:
                        # Shuffle the whole buffer.
                        indices = np.random.permutation(len(batch_samples))

                        # Actual retention cannot exceed the buffer size.
                        actual_retention = min(self.retention_size, len(batch_samples))

                        # Number of samples to emit this flush.
                        output_count = len(batch_samples) - actual_retention

                        # Emit the first output_count shuffled samples.
                        for k in range(output_count):
                            if current_iter_index >= worker_quota:
                                # Quota exhausted: drop the buffer, stop.
                                batch_samples = []
                                return
                            yield batch_samples[indices[k]]
                            current_iter_index += 1

                        # Retain the last actual_retention shuffled samples
                        # for the next round (buffer is not cleared).
                        retained_samples = [
                            batch_samples[idx] for idx in indices[output_count:]
                        ]
                        batch_samples = retained_samples

        # Drain whatever remains in the buffer, shuffled.
        if batch_samples:
            indices = np.random.permutation(len(batch_samples))
            for idx in indices:
                if current_iter_index >= worker_quota:
                    return
                yield batch_samples[idx]
                current_iter_index += 1