feat(dataset): 添加四段式文本编码方法并优化拼音处理逻辑
This commit is contained in:
parent
5061cbe873
commit
fd49058764
|
|
@ -24,10 +24,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
text_field: str = "text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size: int = 5000,
|
||||
length_weights = {
|
||||
1: 10, 2: 50, 3: 50, 4: 40,
|
||||
5: 15, 6: 10, 7: 5, 8: 2
|
||||
},
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
):
|
||||
# 频率调整参数 (可根据需要调整)
|
||||
self.drop_start_freq = 30_000_000
|
||||
|
|
@ -36,13 +33,16 @@ class PinyinInputDataset(IterableDataset):
|
|||
self.max_repeat_expect = 50
|
||||
self.min_freq = 109
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(Path(files(__package__) / "assets" / "tokenizer"))
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
Path(files(__package__) / "assets" / "tokenizer")
|
||||
)
|
||||
self.data_path = data_path
|
||||
self.max_length = max_length
|
||||
self.text_field = text_field
|
||||
self.dataset = load_dataset(data_path, split="train", streaming=True)
|
||||
self.max_workers = max_workes
|
||||
self.py_style_weight = py_style_weight
|
||||
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
|
||||
self.shuffle_buffer_size = shuffle_buffer_size
|
||||
self.possible_lengths = list(length_weights.keys())
|
||||
self.weights = list(length_weights.values())
|
||||
|
||||
|
|
@ -54,7 +54,6 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 提取每个样本的目标字符及其频率
|
||||
self.sample_freqs = self.query_engine.get_all_weights()
|
||||
|
||||
|
||||
def adjust_frequency(self, freq: int) -> int:
|
||||
"""削峰填谷 - 根据频率调整采样次数,0表示丢弃"""
|
||||
# 1. 削峰处理(高频字)
|
||||
|
|
@ -96,6 +95,70 @@ class PinyinInputDataset(IterableDataset):
|
|||
else:
|
||||
return 1
|
||||
|
||||
def tokenize_with_four_seg(self, parts: List[str]) -> Dict[str, Any]:
    """Encode ordered text segments into one fixed-length model input.

    The sequence is laid out as ``[CLS] seg [SEP] seg [SEP] ...`` and then
    padded with ``pad_token_id`` up to ``self.max_length``. The tokens of a
    segment, and the [SEP] that closes it, use the segment's index in
    ``parts`` as their token type id; [CLS] and padding use type 0.

    Segments are consumed in order, so earlier segments have priority when
    space runs out: a segment that does not fit is cut to the room that is
    left (one slot is always reserved for its [SEP]) and any later segments
    are dropped. Empty segments, and segments that tokenize to nothing, are
    skipped and contribute no [SEP].

    The attention mask is derived from the number of real tokens rather than
    by comparing ids against ``pad_token_id``, so a genuine token that shares
    its id with the pad token is still attended to.

    Args:
        parts: Segment strings in priority order.

    Returns:
        A dict with "input_ids", "token_type_ids" and "attention_mask",
        each a 1-D ``torch.long`` tensor.
    """
    tokenizer = self.tokenizer
    max_len = self.max_length

    # [CLS] opens the sequence and always has token type 0.
    input_ids = [tokenizer.cls_token_id]
    token_type_ids = [0]

    for seg_idx, part in enumerate(parts):
        # Room for this segment's tokens, keeping one slot for its [SEP].
        room = max_len - len(input_ids) - 1
        if room <= 0:
            # Nothing more can fit; later segments are dropped.
            break
        if not part:
            continue

        # Tokenize the segment on its own, without special tokens and
        # without tokenizer-side truncation; the cut to the remaining room
        # is done here so that earlier segments stay complete.
        encoded_part = tokenizer(part, add_special_tokens=False, truncation=False)
        part_ids = encoded_part["input_ids"][:room]
        if not part_ids:
            continue

        input_ids.extend(part_ids)
        input_ids.append(tokenizer.sep_token_id)
        # The closing [SEP] shares the type id of the segment it ends.
        token_type_ids.extend([seg_idx] * (len(part_ids) + 1))

    # Defensive cut: only changes anything when max_length is smaller than 1,
    # because the loop above never grows the sequence past max_length.
    input_ids = input_ids[:max_len]
    token_type_ids = token_type_ids[:max_len]

    real_len = len(input_ids)
    pad_len = max_len - real_len
    input_ids += [tokenizer.pad_token_id] * pad_len
    token_type_ids += [0] * pad_len  # padding positions use type 0
    attention_mask = [1] * real_len + [0] * pad_len

    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
    }
|
||||
|
||||
def generate_pinyin(self, text: str) -> List[str]:
    """Return the pinyin of ``text`` as a flat list of strings.

    Conversion is delegated to ``lazy_pinyin``. Text that has no pinyin is
    passed to the ``errors`` callback, which splits it into single
    characters so that each of them becomes its own list entry.

    Args:
        text: The text to convert.

    Returns:
        The list produced by ``lazy_pinyin``; callers in this class index it
        position-by-position alongside ``text``.
    """
    # list(s) on a string yields its characters, same as the former
    # ``lambda x: [c for c in x]``.
    return lazy_pinyin(text, errors=list)
|
||||
|
|
@ -105,11 +168,11 @@ class PinyinInputDataset(IterableDataset):
|
|||
mask_pinyin = []
|
||||
for i in range(len(text)):
|
||||
if not self.query_engine.is_chinese_char(text[i]):
|
||||
return i - 1, mask_pinyin
|
||||
return i, mask_pinyin
|
||||
else:
|
||||
py = random.choice(
|
||||
py = np.random.choice(
|
||||
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
|
||||
weights=self.py_style_weight,
|
||||
p=self.py_style_weight,
|
||||
)
|
||||
if py == "":
|
||||
py = pinyin_list[i][0]
|
||||
|
|
@ -136,10 +199,11 @@ class PinyinInputDataset(IterableDataset):
|
|||
if text:
|
||||
pinyin_list = self.generate_pinyin(text)
|
||||
for i in range(len(text)):
|
||||
labels = []
|
||||
# 如果text[i]不再字符库中,则跳过
|
||||
# 当i小于48时候,则将part1取text[0:i]
|
||||
# 当i大于48时候,则将part1取text[i-48:i]
|
||||
if self.query_engine.is_chinese_char(text[i]):
|
||||
if not self.query_engine.is_chinese_char(text[i]):
|
||||
continue
|
||||
if i < 48:
|
||||
part1 = text[0:i]
|
||||
|
|
@ -153,7 +217,10 @@ class PinyinInputDataset(IterableDataset):
|
|||
range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
|
||||
)
|
||||
py_end = min(i + pinyin_len, len(text))
|
||||
pinyin_len, part2 = self.get_mask_pinyin(text[i:py_end], pinyin_list[i:py_end])
|
||||
pinyin_len, part2 = self.get_mask_pinyin(
|
||||
text[i:py_end], pinyin_list[i:py_end]
|
||||
)
|
||||
part2 = "".join(part2)
|
||||
|
||||
# part3为文本,大概率(0.70)为空
|
||||
# 不为空则是i+pinyin_len所指向的字符以及所指向字符后x个字符
|
||||
|
|
@ -188,27 +255,31 @@ class PinyinInputDataset(IterableDataset):
|
|||
part4 = "|".join(string_list)
|
||||
|
||||
labels = [
|
||||
self.query_engine.get_char_info_by_char_pinyin(
|
||||
c, p
|
||||
).id
|
||||
for c, p in zip(text[i:py_end], pinyin_list[i:py_end])
|
||||
self.query_engine.get_char_info_by_char_pinyin(c, p).id
|
||||
for c, p in zip(text[i:i + pinyin_len], pinyin_list[i:i + pinyin_len])
|
||||
]
|
||||
labels.append(0)
|
||||
|
||||
encoded = self.tokenizer(
|
||||
[part1, part2, part3, part4],
|
||||
truncation=True,
|
||||
max_length=self.max_seq_length,
|
||||
padding='max_length',
|
||||
return_tensors='pt', # 根据训练框架调整
|
||||
add_special_tokens=True, # 添加[CLS], [SEP]
|
||||
encoded = self.tokenize_with_four_seg(
|
||||
[
|
||||
part1,
|
||||
part2,
|
||||
part3,
|
||||
part4,
|
||||
]
|
||||
)
|
||||
repeats = self.adjust_frequency(
|
||||
min([self.sample_freqs[i] for i in labels])
|
||||
)
|
||||
repeats = self.adjust_frequency(min([self.sample_freqs[i] for i in labels]))
|
||||
sample = {
|
||||
'input_ids': encoding['input_ids'].squeeze(0),
|
||||
'token_type_ids': encoding['token_type_ids'].squeeze(0),
|
||||
'attention_mask': encoding['attention_mask'].squeeze(0),
|
||||
'labels': labels,
|
||||
"input_ids": encoded["input_ids"],
|
||||
"token_type_ids": encoded["token_type_ids"],
|
||||
"attention_mask": encoded["attention_mask"],
|
||||
"labels": torch.tensor(labels, dtype=torch.long),
|
||||
"part1": part1,
|
||||
"part2": part2,
|
||||
"part3": part3,
|
||||
"part4": part4,
|
||||
}
|
||||
batch_samples.extend([sample] * repeats)
|
||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||
|
|
@ -216,5 +287,3 @@ class PinyinInputDataset(IterableDataset):
|
|||
self.buffer.extend([batch_samples[i] for i in indices])
|
||||
batch_samples = []
|
||||
yield from self.buffer
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
from model.dataset import PinyinInputDataset

# Smoke check: stream the local CCI-Data corpus through the dataset and print
# the labels of the first 12 samples.
dataset = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')

# zip() stops as soon as range(12) is exhausted, so no 13th sample is pulled
# from the dataset.
for _, sample in zip(range(12), dataset):
    print(sample['labels'])
|
||||
Loading…
Reference in New Issue