# SUimeModelTraner/src/model/dataset.py
# (Web-viewer residue removed: "425 lines / 18 KiB / Python / Raw Blame History"
# and an ambiguous-Unicode warning banner — not part of the source.)

import random
import re
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset
from .query import QueryEngine
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
CHAR_TO_ID["`"] = 27 # 显式添加反引号
CHAR_TO_ID["'"] = 28 # 显式添加引号
CHAR_TO_ID["-"] = 29 # 显式添加短横
def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
"""
将拼音字符串转换为 ID 列表。
支持 a-z 和 `。
未知字符映射为 0 (PAD/UNK)。
"""
# 使用 dict.get(key, default) 处理未知字符,默认返回 0
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
class PinyinInputDataset(IterableDataset):
    """Streaming dataset that turns raw Chinese text into pinyin-IME training samples.

    For each Chinese character position in each streamed text it builds a sample
    from: left context (part1), a randomized pinyin query (part2), optional right
    context (part3), optional distractor substrings (part4), the target character
    ids (labels), and up to 8 history-slot ids from the preceding characters.
    Samples pass through a shuffle buffer with partial retention before being
    yielded, and sample counts are rebalanced by character frequency
    (peak clipping / valley filling).
    """

    def __init__(
        self,
        data_path: str,
        max_workers: int = -1,
        max_iter_length=1e6,
        max_seq_length=128,
        text_field: str = "text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size: int = 100000,
        retention_ratio: float = 0.5,
        # NOTE(review): mutable default argument — benign here because it is
        # only read, never mutated, but a None default would be safer.
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    ):
        """
        Args:
            data_path: dataset path/name, loaded via `datasets` in streaming mode.
            max_workers: cap on DataLoader workers used for sharding; values
                <= 0 mean "use worker_info.num_workers".
            max_iter_length: total number of samples to yield across all workers.
            max_seq_length: tokenizer max_length (padding/truncation target).
            text_field: key of the text column in each dataset record.
            py_style_weight: relative weights for rendering each syllable as
                (full pinyin, initials, first letter); normalized below.
            shuffle_buffer_size: flush threshold of the shuffle buffer.
            retention_ratio: fraction of the buffer retained (and re-shuffled)
                at each flush; must be strictly between 0 and 1.
            length_weights: mapping pinyin-query length -> sampling weight.
                NOTE(review): stored as possible_lengths/weights but the draw in
                __iter__ uses a hard-coded probability vector — confirm intent.
        """
        # Frequency-rebalancing parameters (tune as needed).
        self.drop_start_freq = 30_000_000  # at/above this, samples may be dropped
        self.max_drop_prob = 0.8  # drop probability cap for the most frequent char
        self.repeat_end_freq = 10_000  # at/below this, samples may be repeated
        self.max_repeat_expect = 50  # Poisson-mean cap for repetition
        self.min_freq = 109  # below this, repeat expectation is maxed out
        self.tokenizer = AutoTokenizer.from_pretrained(
            Path(str(files(__package__))) / "assets" / "tokenizer"
        )
        self.data_path = data_path
        self.max_iter_length = max_iter_length
        self.max_seq_length = max_seq_length
        self.text_field = text_field
        self.dataset = load_dataset(data_path, split="train", streaming=True)
        self.max_workers = max_workers
        # Normalize the style weights into a probability vector for np.random.choice.
        self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
        self.shuffle_buffer_size = shuffle_buffer_size
        self.retention_ratio = retention_ratio
        if not (0 < retention_ratio < 1):
            raise ValueError(
                f"retention_ratio必须在0和1之间当前值: {retention_ratio}"
            )
        self.retention_size = int(shuffle_buffer_size * retention_ratio)
        if self.retention_size <= 0:
            raise ValueError(
                f"计算出的retention_size必须大于0当前值: {self.retention_size} (shuffle_buffer_size={shuffle_buffer_size}, retention_ratio={retention_ratio})"
            )
        self.possible_lengths = list(length_weights.keys())
        self.weights = list(length_weights.values())
        self.query_engine = QueryEngine()
        self.query_engine.load()
        # Per-character target frequencies, consumed by adjust_frequency().
        self.sample_freqs = self.query_engine.get_all_weights()

    def adjust_frequency(self, freq: int) -> int:
        """Peak clipping / valley filling: return the emit count for a sample
        given its frequency (0 means drop).

        - freq >= drop_start_freq: drop with probability growing linearly up to
          max_drop_prob at the global maximum frequency.
        - freq <= repeat_end_freq: repeat with a Poisson-drawn count whose mean
          grows linearly toward max_repeat_expect as freq approaches min_freq
          (always at least 1).
        - otherwise: emit exactly once.
        """
        # 1. Peak clipping (high-frequency characters).
        if freq >= self.drop_start_freq:
            # Linear drop-probability computation.
            max_freq = max(self.sample_freqs)  # or use a predefined global maximum
            if max_freq == self.drop_start_freq:
                drop_prob = 0.0
            else:
                drop_prob = (
                    self.max_drop_prob
                    * (freq - self.drop_start_freq)
                    / (max_freq - self.drop_start_freq)
                )
            if random.random() < drop_prob:
                return 0
            else:
                return 1
        # 2. Valley filling (low-frequency characters).
        elif freq <= self.repeat_end_freq:
            # Linear repeat-expectation computation.
            if freq <= self.min_freq:
                repeat_expect = self.max_repeat_expect
            else:
                if self.repeat_end_freq == self.min_freq:
                    repeat_expect = 0
                else:
                    repeat_expect = (
                        self.max_repeat_expect
                        * (self.repeat_end_freq - freq)
                        / (self.repeat_end_freq - self.min_freq)
                    )
            # Randomized repetition via a Poisson draw.
            repeat_count = np.random.poisson(repeat_expect)
            return max(1, repeat_count)
        # 3. Mid-frequency characters.
        else:
            return 1

    # Generate the pinyin corresponding to a text.
    def generate_pinyin(self, text: str) -> List[str]:
        """Convert one text to a pinyin list, streaming-friendly.

        Properties:
        1. Strict one-to-one mapping: len(result) == len(text).
        2. Good polyphone accuracy: relies on pypinyin's internal word
           segmentation over each hanzi run.
        3. Fast: pre-allocated result list, no extra object churn.
        Args:
            text: input string.
        Returns:
            List[str]: pinyin for hanzi positions, the original character
            at non-hanzi positions.
        """
        if not text:
            return []
        text_len = len(text)
        # 2. Pre-allocate the result list with placeholders.
        #    None or "" would both work; "" makes the later emptiness check trivial.
        result: List[str] = [""] * text_len
        # 3. Walk every maximal run of consecutive hanzi.
        for match in _HANZI_RE.finditer(text):
            start_idx = match.start()
            hanzi_segment = match.group()
            # 4. Core conversion: let pypinyin segment the run.
            #    lazy_pinyin yields tone-less pinyin (Style.NORMAL).
            pinyin_list = lazy_pinyin(hanzi_segment)
            # 5. Robustness fallback:
            #    pypinyin normally returns exactly one pinyin per hanzi.
            #    If not (very rare, e.g. odd Unicode punctuation matched as
            #    hanzi), degrade to per-character conversion.
            if len(pinyin_list) != len(hanzi_segment):
                pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
            # 6. Fill the pre-allocated slots directly by index.
            #    Slightly faster and clearer than slice assignment.
            for i, py in enumerate(pinyin_list):
                result[start_idx + i] = py
        # 7. Fill non-hanzi positions with the original character.
        #    Cheap for pure-hanzi text; necessary for mixed text.
        for i, char in enumerate(text):
            if not result[i]:
                result[i] = char
        return result

    # Build the pinyin for the characters to predict, with augmentation.
    def get_mask_pinyin(
        self, text: str, pinyin_list: List[str]
    ) -> Tuple[int, List[str]]:
        """Return (n, pinyins) for the longest all-hanzi prefix of `text`;
        each syllable is randomly rendered as full pinyin, initials, or first
        letter per self.py_style_weight."""
        mask_pinyin = []
        for i in range(len(text)):
            # Stop at the first non-hanzi character; the effective length
            # (len(mask_pinyin)) is returned to the caller.
            if not self.query_engine.is_chinese_char(text[i]):
                break
            else:
                py = np.random.choice(
                    (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
                    p=self.py_style_weight,
                )
                # to_initials yields "" for zero-initial syllables (e.g. "an");
                # fall back to the first letter in that case.
                if py == "":
                    py = pinyin_list[i][0]
                mask_pinyin.append(py)
        return len(mask_pinyin), mask_pinyin

    def __iter__(self):
        """Yield training samples, sharded across DataLoader workers and
        shuffled through a partial-retention buffer."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = (
                self.max_workers if self.max_workers > 0 else worker_info.num_workers
            )
            # Derive a per-worker RNG seed so workers sample independently.
            base_seed = torch.initial_seed() if hasattr(torch, "initial_seed") else 42
            seed = base_seed + worker_id
            random.seed(seed % (2**32))
            np.random.seed(seed % (2**32))
            # Safety check: a worker with worker_id >= num_workers must not run.
            # This can happen when self.max_workers is smaller than the actual
            # number of workers.
            if worker_id >= num_workers:
                return  # act as an empty iterator
            # Keep the sharded dataset in a local to avoid race conditions.
            worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
            # Compute each worker's quota.
            # Cast max_iter_length to int to guarantee integer division.
            total_quota = int(self.max_iter_length)
            base_quota = total_quota // num_workers
            remainder = total_quota % num_workers
            # The last worker absorbs the remainder, if any.
            if worker_id == num_workers - 1:
                worker_quota = base_quota + remainder
            else:
                worker_quota = base_quota
        else:
            # Single-worker case: use the full quota.
            worker_quota = int(self.max_iter_length)
            num_workers = 1
            worker_dataset = self.dataset  # no sharding
        # Each worker keeps its own iteration counter.
        current_iter_index = 0
        batch_samples = []
        for sample in worker_dataset:
            # Stop once the quota is exhausted.
            if current_iter_index >= worker_quota:
                break
            text = sample.get(self.text_field, "")
            if text:
                pinyin_list = self.generate_pinyin(text)
                for i in range(len(text)):
                    # Re-check the quota before processing each character.
                    if current_iter_index >= worker_quota:
                        break
                    labels = []
                    # Skip text[i] if it is not in the character inventory.
                    # part1 (left context): text[0:i] when i < 48,
                    # otherwise the trailing 48 chars text[i-48:i].
                    if not self.query_engine.is_chinese_char(text[i]):
                        continue
                    if i < 48:
                        part1 = text[0:i]
                    else:
                        part1 = text[i - 48 : i]
                    # Draw pinyin_len in 1..8 from a bell-shaped distribution
                    # peaking at 3, then clip the span at the end of the text;
                    # get_mask_pinyin may shorten it further at the first
                    # non-hanzi character and returns the effective length.
                    pinyin_len = np.random.choice(
                        range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
                    )
                    py_end = min(i + pinyin_len, len(text))
                    pinyin_len, part2 = self.get_mask_pinyin(
                        text[i:py_end], pinyin_list[i:py_end]
                    )
                    # Occasionally join syllables with an explicit separator.
                    split_char = np.random.choice(
                        ["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
                    )
                    part2 = split_char.join(part2)
                    pinyin_ids = text_to_pinyin_ids(part2)
                    # Pad/truncate the pinyin id vector to a fixed length of 24.
                    len_py = len(pinyin_ids)
                    if len_py < 24:
                        pinyin_ids.extend([0] * (24 - len_py))
                    else:
                        pinyin_ids = pinyin_ids[:24]
                    pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
                    # part3 (right context): empty with probability ~0.70;
                    # otherwise the character at i+pinyin_len plus the x
                    # following characters, x uniform in 1-16.
                    part3 = ""
                    if random.random() > 0.7:
                        part3 = text[
                            i + pinyin_len : i
                            + pinyin_len
                            + np.random.choice(range(1, 17))
                        ]
                    # part4 (distractors): empty with probability 0.50;
                    # otherwise 1-5 substrings, each taken from a random start
                    # position with x extra characters (x uniform in 2-6),
                    # joined with '|'.
                    part4 = ""
                    if random.random() > 0.5:
                        # Generate 1-5 substrings.
                        num_strings = random.randint(1, 5)
                        string_list = []
                        for _ in range(num_strings):
                            # Random start position.
                            start_pos = random.randint(0, len(text) - 1)
                            # Random extra length x in 2-6.
                            x = random.randint(2, 6)
                            # Take the substring, clipped at the text end.
                            end_pos = min(start_pos + x + 1, len(text))
                            string_list.append(text[start_pos:end_pos])
                        # Join all substrings with '|'.
                        part4 = "|".join(string_list)
                    try:
                        # Target ids for the characters covered by the query;
                        # raises AttributeError when a (char, pinyin) pair is
                        # unknown (lookup returns None).
                        labels = [
                            self.query_engine.get_char_info_by_char_pinyin(c, p).id
                            for c, p in zip(
                                text[i : i + pinyin_len],
                                pinyin_list[i : i + pinyin_len],
                            )
                        ]
                    except AttributeError as e:
                        logger.error(
                            f"e: {e}, (text, pinyin): {text[i : i + pinyin_len]} - {pinyin_list[i : i + pinyin_len]}"
                        )
                        continue
                    # With probability 0.1 append a 0 label (presumably a
                    # stop/none target — confirm against the model head).
                    if random.random() <= 0.1:
                        labels.append(0)
                    # History slots: up to 8 ids collected from the hanzi
                    # preceding position i (scanning backwards, at most 100
                    # chars), consistent with eval.py.
                    history_slot_list = []
                    for j in range(i - 1, max(-1, i - 100), -1):
                        if j < 0:
                            break
                        char = text[j]
                        if self.query_engine.is_chinese_char(char):
                            try:
                                results = self.query_engine.query_by_char(char, limit=1)
                                if results and results[0][0] > 0:
                                    history_slot_list.append(results[0][0])
                            except Exception:
                                # Best-effort: a failed lookup just skips the slot.
                                pass
                        if len(history_slot_list) >= 8:
                            break
                    encoded = self.tokenizer(
                        f"{part4}|{part1}",
                        part3,
                        max_length=self.max_seq_length,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                        return_token_type_ids=True,
                    )
                    samples = []
                    # Inner loop variable named label_idx to avoid clashing
                    # with the outer index i.
                    for label_idx, label in enumerate(labels):
                        # NOTE(review): adjust_frequency is documented as taking
                        # a frequency, but `label` is a character id — confirm
                        # the id doubles as a frequency key upstream.
                        repeats = self.adjust_frequency(label)
                        # History slots extracted from text[0:i] (matches eval.py),
                        # zero-padded to exactly 8 entries.
                        masked_labels = history_slot_list[:]
                        len_l = len(masked_labels)
                        masked_labels.extend([0] * (8 - len_l))
                        samples.extend(
                            [
                                {
                                    "input_ids": encoded["input_ids"],
                                    "token_type_ids": encoded["token_type_ids"],
                                    "attention_mask": encoded["attention_mask"],
                                    "label": torch.tensor([label], dtype=torch.long),
                                    "history_slot_ids": torch.tensor(
                                        masked_labels, dtype=torch.long
                                    ),
                                    "prefix": f"{part4}^{part1}",
                                    "suffix": part3,
                                    "pinyin": part2,
                                    "pinyin_ids": pinyin_ids,
                                }
                            ]
                            * repeats
                        )
                    # Append to the shuffle buffer.
                    batch_samples.extend(samples)
                    # Shuffle-buffer flush — single buffer, partial retention.
                    if len(batch_samples) >= self.shuffle_buffer_size:
                        # Shuffle the whole buffer.
                        indices = np.random.permutation(len(batch_samples))
                        # Actual retention size (never above the buffer size).
                        actual_retention = min(self.retention_size, len(batch_samples))
                        # Number of samples to emit this flush.
                        output_count = len(batch_samples) - actual_retention
                        # Emit the first output_count shuffled samples.
                        for i in range(output_count):
                            if current_iter_index >= worker_quota:
                                # Quota exhausted: clear the buffer and stop.
                                batch_samples = []
                                return
                            yield batch_samples[indices[i]]
                            current_iter_index += 1
                        # Keep the last actual_retention samples; the buffer
                        # is carried over rather than cleared.
                        retained_samples = [
                            batch_samples[idx] for idx in indices[output_count:]
                        ]
                        batch_samples = retained_samples
        # Drain whatever remains in the buffer.
        if batch_samples:
            indices = np.random.permutation(len(batch_samples))
            for idx in indices:
                if current_iter_index >= worker_quota:
                    return
                yield batch_samples[idx]
                current_iter_index += 1