feat(dataset): 添加四段式文本编码方法并优化拼音处理逻辑

This commit is contained in:
songsenand 2026-04-03 08:19:43 +08:00
parent 5061cbe873
commit fd49058764
3 changed files with 1902 additions and 1838 deletions

View File

@ -24,10 +24,7 @@ class PinyinInputDataset(IterableDataset):
text_field: str = "text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size: int = 5000,
length_weights = {
1: 10, 2: 50, 3: 50, 4: 40,
5: 15, 6: 10, 7: 5, 8: 2
},
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
):
# 频率调整参数 (可根据需要调整)
self.drop_start_freq = 30_000_000
@ -36,13 +33,16 @@ class PinyinInputDataset(IterableDataset):
self.max_repeat_expect = 50
self.min_freq = 109
self.tokenizer = AutoTokenizer.from_pretrained(Path(files(__package__) / "assets" / "tokenizer"))
self.tokenizer = AutoTokenizer.from_pretrained(
Path(files(__package__) / "assets" / "tokenizer")
)
self.data_path = data_path
self.max_length = max_length
self.text_field = text_field
self.dataset = load_dataset(data_path, split="train", streaming=True)
self.max_workers = max_workes
self.py_style_weight = py_style_weight
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
self.shuffle_buffer_size = shuffle_buffer_size
self.possible_lengths = list(length_weights.keys())
self.weights = list(length_weights.values())
@ -54,7 +54,6 @@ class PinyinInputDataset(IterableDataset):
# 提取每个样本的目标字符及其频率
self.sample_freqs = self.query_engine.get_all_weights()
def adjust_frequency(self, freq: int) -> int:
"""削峰填谷 - 根据频率调整采样次数0表示丢弃"""
# 1. 削峰处理(高频字)
@ -96,6 +95,70 @@ class PinyinInputDataset(IterableDataset):
else:
return 1
def tokenize_with_four_seg(self, parts: List[str]) -> Dict[str, Any]:
    """Encode up to four text segments into one fixed-length sequence.

    Layout: [CLS] seg0 [SEP] seg1 [SEP] seg2 [SEP] seg3 [SEP], where each
    segment's tokens (and its trailing [SEP]) carry token_type_id equal to
    the segment index. Earlier segments have priority: each segment is
    truncated to the space left by the ones before it, and encoding stops
    once ``self.max_length`` is reached. The output is padded to exactly
    ``self.max_length``.

    Args:
        parts: up to four strings; empty strings are skipped.

    Returns:
        dict with "input_ids", "token_type_ids" and "attention_mask",
        each a 1-D ``torch.LongTensor`` of length ``self.max_length``.
    """
    # Start with [CLS]; it belongs to segment 0.
    input_ids = [self.tokenizer.cls_token_id]
    token_type_ids = [0]

    for seg_idx, part in enumerate(parts):
        if not part:
            continue
        # Tokenize the single part without special tokens. Truncation is
        # handled manually below so that higher-priority (earlier)
        # segments stay intact.
        encoded_part = self.tokenizer(
            part, add_special_tokens=False, truncation=False
        )
        part_ids = encoded_part["input_ids"]

        # Reserve one slot for this segment's trailing [SEP].
        remaining_space = self.max_length - len(input_ids) - 1
        if len(part_ids) > remaining_space:
            part_ids = part_ids[:remaining_space]
        if not part_ids:
            continue

        input_ids.extend(part_ids)
        # The segment's type id is simply its index (0, 1, 2, 3).
        token_type_ids.extend([seg_idx] * len(part_ids))

        # [SEP] inherits the current segment's type id.
        input_ids.append(self.tokenizer.sep_token_id)
        token_type_ids.append(seg_idx)

        # Stop early once the maximum length is reached.
        if len(input_ids) >= self.max_length:
            break

    # Defensive final truncation (the loop bookkeeping already respects
    # max_length, so this branch should not normally trigger).
    if len(input_ids) > self.max_length:
        input_ids = input_ids[: self.max_length]
        token_type_ids = token_type_ids[: self.max_length]

    # Build the attention mask from the real content length instead of
    # comparing token ids against pad_token_id: a legitimate content token
    # whose id happens to equal pad_token_id would otherwise be masked out.
    content_len = len(input_ids)
    pad_len = self.max_length - content_len
    input_ids += [self.tokenizer.pad_token_id] * pad_len
    token_type_ids += [0] * pad_len  # padding conventionally uses type 0
    attention_mask = [1] * content_len + [0] * pad_len

    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
    }
# Produce the pinyin sequence for the given text.
def generate_pinyin(self, text: str) -> List[List[str]]:
    """Return the lazy pinyin of *text*; any chunk that has no pinyin
    reading is split into its individual characters via the errors
    callback."""
    return lazy_pinyin(text, errors=lambda chunk: list(chunk))
@ -105,11 +168,11 @@ class PinyinInputDataset(IterableDataset):
mask_pinyin = []
for i in range(len(text)):
if not self.query_engine.is_chinese_char(text[i]):
return i - 1, mask_pinyin
return i, mask_pinyin
else:
py = random.choice(
py = np.random.choice(
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
weights=self.py_style_weight,
p=self.py_style_weight,
)
if py == "":
py = pinyin_list[i][0]
@ -136,10 +199,11 @@ class PinyinInputDataset(IterableDataset):
if text:
pinyin_list = self.generate_pinyin(text)
for i in range(len(text)):
labels = []
# 如果text[i]不再字符库中,则跳过
# 当i小于48时候则将part1取text[0:i]
# 当i大于48时候则将part1取text[i-48:i]
if self.query_engine.is_chinese_char(text[i]):
if not self.query_engine.is_chinese_char(text[i]):
continue
if i < 48:
part1 = text[0:i]
@ -153,7 +217,10 @@ class PinyinInputDataset(IterableDataset):
range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
)
py_end = min(i + pinyin_len, len(text))
pinyin_len, part2 = self.get_mask_pinyin(text[i:py_end], pinyin_list[i:py_end])
pinyin_len, part2 = self.get_mask_pinyin(
text[i:py_end], pinyin_list[i:py_end]
)
part2 = "".join(part2)
# part3为文本大概率0.70)为空
# 不为空则是i+pinyin_len所指向的字符以及所指向字符后x个字符
@ -188,27 +255,31 @@ class PinyinInputDataset(IterableDataset):
part4 = "|".join(string_list)
labels = [
self.query_engine.get_char_info_by_char_pinyin(
c, p
).id
for c, p in zip(text[i:py_end], pinyin_list[i:py_end])
self.query_engine.get_char_info_by_char_pinyin(c, p).id
for c, p in zip(text[i:i + pinyin_len], pinyin_list[i:i + pinyin_len])
]
labels.append(0)
encoded = self.tokenizer(
[part1, part2, part3, part4],
truncation=True,
max_length=self.max_seq_length,
padding='max_length',
return_tensors='pt', # 根据训练框架调整
add_special_tokens=True, # 添加[CLS], [SEP]
encoded = self.tokenize_with_four_seg(
[
part1,
part2,
part3,
part4,
]
)
repeats = self.adjust_frequency(
min([self.sample_freqs[i] for i in labels])
)
repeats = self.adjust_frequency(min([self.sample_freqs[i] for i in labels]))
sample = {
'input_ids': encoding['input_ids'].squeeze(0),
'token_type_ids': encoding['token_type_ids'].squeeze(0),
'attention_mask': encoding['attention_mask'].squeeze(0),
'labels': labels,
"input_ids": encoded["input_ids"],
"token_type_ids": encoded["token_type_ids"],
"attention_mask": encoded["attention_mask"],
"labels": torch.tensor(labels, dtype=torch.long),
"part1": part1,
"part2": part2,
"part3": part3,
"part4": part4,
}
batch_samples.extend([sample] * repeats)
if len(batch_samples) >= self.shuffle_buffer_size:
@ -216,5 +287,3 @@ class PinyinInputDataset(IterableDataset):
self.buffer.extend([batch_samples[i] for i in indices])
batch_samples = []
yield from self.buffer

7
test.py Normal file
View File

@ -0,0 +1,7 @@
from model.dataset import PinyinInputDataset

# Smoke test: stream the dataset and print the labels of the first
# twelve samples, then stop.
dataset = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
for index, sample in enumerate(dataset):
    print(sample['labels'])
    if index >= 11:
        break

3604
uv.lock

File diff suppressed because it is too large Load Diff