feat(dataset): add four-segment text encoding method and improve pinyin handling logic
parent 5061cbe873
commit fd49058764
@@ -24,10 +24,7 @@ class PinyinInputDataset(IterableDataset):
         text_field: str = "text",
         py_style_weight=(9, 2, 1),
         shuffle_buffer_size: int = 5000,
-        length_weights = {
-            1: 10, 2: 50, 3: 50, 4: 40,
-            5: 15, 6: 10, 7: 5, 8: 2
-        },
+        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
     ):
         # Frequency adjustment parameters (tune as needed)
         self.drop_start_freq = 30_000_000
@@ -36,13 +33,16 @@ class PinyinInputDataset(IterableDataset):
         self.max_repeat_expect = 50
         self.min_freq = 109

-        self.tokenizer = AutoTokenizer.from_pretrained(Path(files(__package__) / "assets" / "tokenizer"))
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            Path(files(__package__) / "assets" / "tokenizer")
+        )
         self.data_path = data_path
         self.max_length = max_length
         self.text_field = text_field
         self.dataset = load_dataset(data_path, split="train", streaming=True)
         self.max_workers = max_workes
-        self.py_style_weight = py_style_weight
+        self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
+        self.shuffle_buffer_size = shuffle_buffer_size
         self.possible_lengths = list(length_weights.keys())
         self.weights = list(length_weights.values())

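Note on the normalization added above: `np.random.choice` (used below in `get_mask_pinyin` and for sampling span lengths) requires `p` to be a probability vector summing to 1, which the raw `py_style_weight=(9, 2, 1)` is not. A minimal sketch of the same pattern applied to `length_weights` (values copied from the signature; the normalization here is illustrative — the `p` list hard-coded later in the diff is already normalized):

# A minimal sketch, assuming length_weights as defined in __init__ above.
import numpy as np

length_weights = {1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}
possible_lengths = list(length_weights.keys())
weights = np.array(list(length_weights.values()), dtype=float)

# np.random.choice requires p to sum to 1, hence the normalization.
pinyin_len = np.random.choice(possible_lengths, p=weights / weights.sum())
print(pinyin_len)  # e.g. 3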
@@ -54,7 +54,6 @@ class PinyinInputDataset(IterableDataset):
         # Extract each sample's target character and its frequency
         self.sample_freqs = self.query_engine.get_all_weights()

-
     def adjust_frequency(self, freq: int) -> int:
         """Clip peaks, fill valleys: adjust the sample repeat count by frequency; 0 means drop."""
         # 1. Peak clipping (high-frequency characters)
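The body of `adjust_frequency` is not part of this diff; for orientation, a purely hypothetical sketch of the clip-peaks/fill-valleys policy its docstring describes, using the thresholds set in `__init__` (the exact formula is an assumption):

import random

# Hypothetical sketch only -- the real body of adjust_frequency is outside this diff.
def adjust_frequency(freq, min_freq=109, drop_start_freq=30_000_000, max_repeat_expect=50):
    if freq < min_freq:
        return 0  # 0 means drop the sample entirely
    if freq >= drop_start_freq:
        # clip the peak: keep very-high-frequency samples only some of the time
        return 1 if random.random() < drop_start_freq / freq else 0
    # fill the valley: oversample rarer characters, capped at max_repeat_expect
    return min(max_repeat_expect, max(1, round(drop_start_freq / freq)))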
@@ -96,6 +95,70 @@ class PinyinInputDataset(IterableDataset):
         else:
             return 1

+    def tokenize_with_four_seg(self, parts: List[str]) -> Dict[str, Any]:
+        input_ids = []
+        token_type_ids = []
+
+        # Add [CLS] (type 0)
+        cls_id = self.tokenizer.cls_token_id
+        input_ids.append(cls_id)
+        token_type_ids.append(0)
+
+        for seg_idx, part in enumerate(parts):
+            if not part:
+                continue
+
+            # Tokenize this part on its own, without special tokens.
+            # Note: do not truncate here; truncate once at the end so that
+            # higher-priority segments (e.g. part1) stay intact.
+            encoded_part = self.tokenizer(
+                part, add_special_tokens=False, truncation=False
+            )
+
+            part_ids = encoded_part["input_ids"]
+
+            # If adding this part would exceed MAX_LEN - 1 (one slot reserved
+            # for the final SEP or truncation), truncate the current part.
+            remaining_space = (
+                self.max_length - len(input_ids) - 1
+            )  # -1 for final SEP or safety
+            if len(part_ids) > remaining_space:
+                part_ids = part_ids[:remaining_space]
+
+            if not part_ids:
+                continue
+
+            input_ids.extend(part_ids)
+            # The type_id of the current segment is simply seg_idx (0, 1, 2, 3)
+            token_type_ids.extend([seg_idx] * len(part_ids))
+
+            # Add [SEP] (its type follows the current segment)
+            sep_id = self.tokenizer.sep_token_id
+            input_ids.append(sep_id)
+            token_type_ids.append(seg_idx)
+
+            # Exit early once the maximum length is reached
+            if len(input_ids) >= self.max_length:
+                break
+
+        # 4. Handle padding or the final truncation
+        if len(input_ids) > self.max_length:
+            input_ids = input_ids[: self.max_length]
+            token_type_ids = token_type_ids[: self.max_length]
+        else:
+            pad_len = self.max_length - len(input_ids)
+            input_ids += [self.tokenizer.pad_token_id] * pad_len
+            token_type_ids += [0] * pad_len  # padding type is conventionally 0
+
+        # 5. Build the attention mask
+        attention_mask = [
+            1 if i != self.tokenizer.pad_token_id else 0 for i in input_ids
+        ]
+
+        return {
+            "input_ids": torch.tensor(input_ids, dtype=torch.long),
+            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
+            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
+        }
+
     # Generate the pinyin for a given text
     def generate_pinyin(self, text: str) -> List[List[str]]:
         return lazy_pinyin(text, errors=lambda x: [c for c in x])
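The new `tokenize_with_four_seg` packs the four parts as `[CLS] part1 [SEP] part2 [SEP] part3 [SEP] part4 [SEP]`, with `token_type_ids` running 0–3, one value per segment. A standalone sketch of the same packing scheme, assuming `bert-base-chinese` as a stand-in for the tokenizer bundled under `assets/tokenizer` (the example parts are invented):

# A minimal sketch of the four-segment packing; "bert-base-chinese" is an
# assumed stand-in for the bundled tokenizer, and the parts are toy values.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

parts = ["今天天气", "bucuo", "呢", "错|措|挫"]  # part1..part4
max_length = 32

input_ids = [tokenizer.cls_token_id]
token_type_ids = [0]
for seg_idx, part in enumerate(parts):
    ids = tokenizer(part, add_special_tokens=False)["input_ids"]
    ids = ids[: max_length - len(input_ids) - 1]  # reserve one slot for [SEP]
    input_ids += ids + [tokenizer.sep_token_id]
    token_type_ids += [seg_idx] * (len(ids) + 1)

print(tokenizer.convert_ids_to_tokens(input_ids))
print(token_type_ids)  # 0…0 1…1 2…2 3…3, one run per segment

Since segment ids reach 3, the downstream model's token-type embedding table needs `type_vocab_size >= 4` rather than BERT's default of 2.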
@@ -105,11 +168,11 @@ class PinyinInputDataset(IterableDataset):
         mask_pinyin = []
         for i in range(len(text)):
             if not self.query_engine.is_chinese_char(text[i]):
-                return i - 1, mask_pinyin
+                return i, mask_pinyin
             else:
-                py = random.choice(
+                py = np.random.choice(
                     (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
-                    weights=self.py_style_weight,
+                    p=self.py_style_weight,
                 )
                 if py == "":
                     py = pinyin_list[i][0]
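This hunk fixes a real API mismatch: `random.choice` accepts no `weights` argument (that is `random.choices`, which returns a list), whereas `np.random.choice` takes a probability vector `p` that must sum to 1 — hence the normalization of `py_style_weight` in `__init__`. A minimal sketch of the three style candidates, using 中 / "zhong" as the example syllable:

import numpy as np

# The three candidate styles for 中: full pinyin, initial, first letter.
styles = ("zhong", "zh", "z")
p = np.array((9, 2, 1)) / 12  # normalized py_style_weight

print(np.random.choice(styles, p=p))  # most often "zhong"

The `if py == ""` fallback exists because `to_initials` returns an empty string for zero-initial syllables such as "an".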
@@ -136,10 +199,11 @@ class PinyinInputDataset(IterableDataset):
             if text:
                 pinyin_list = self.generate_pinyin(text)
                 for i in range(len(text)):
+                    labels = []
                     # Skip text[i] if it is not in the character set
                     # When i < 48, part1 is text[0:i]
                     # When i >= 48, part1 is text[i-48:i]
-                    if self.query_engine.is_chinese_char(text[i]):
+                    if not self.query_engine.is_chinese_char(text[i]):
                         continue
                     if i < 48:
                         part1 = text[0:i]
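The `i < 48` branching above clamps part1 to a fixed 48-character left-context window; it is equivalent to the one-liner below, shown only as a sketch:

# Equivalent to the i < 48 branching: take at most 48 chars of left context.
part1 = text[max(0, i - 48):i]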
@@ -153,7 +217,10 @@ class PinyinInputDataset(IterableDataset):
                         range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
                     )
                     py_end = min(i + pinyin_len, len(text))
-                    pinyin_len, part2 = self.get_mask_pinyin(text[i:py_end], pinyin_list[i:py_end])
+                    pinyin_len, part2 = self.get_mask_pinyin(
+                        text[i:py_end], pinyin_list[i:py_end]
+                    )
+                    part2 = "".join(part2)

                     # part3 is plain text; with high probability (0.70) it is empty
                     # Otherwise it is the character at i+pinyin_len plus the x characters after it
@@ -188,27 +255,31 @@ class PinyinInputDataset(IterableDataset):
                     part4 = "|".join(string_list)

                     labels = [
-                        self.query_engine.get_char_info_by_char_pinyin(
-                            c, p
-                        ).id
-                        for c, p in zip(text[i:py_end], pinyin_list[i:py_end])
+                        self.query_engine.get_char_info_by_char_pinyin(c, p).id
+                        for c, p in zip(text[i:i + pinyin_len], pinyin_list[i:i + pinyin_len])
                     ]
                     labels.append(0)

-                    encoded = self.tokenizer(
-                        [part1, part2, part3, part4],
-                        truncation=True,
-                        max_length=self.max_seq_length,
-                        padding='max_length',
-                        return_tensors='pt',  # adjust to the training framework
-                        add_special_tokens=True,  # adds [CLS], [SEP]
-                    )
-                    repeats = self.adjust_frequency(min([self.sample_freqs[i] for i in labels]))
+                    encoded = self.tokenize_with_four_seg(
+                        [
+                            part1,
+                            part2,
+                            part3,
+                            part4,
+                        ]
+                    )
+                    repeats = self.adjust_frequency(
+                        min([self.sample_freqs[i] for i in labels])
+                    )

                     sample = {
-                        'input_ids': encoding['input_ids'].squeeze(0),
-                        'token_type_ids': encoding['token_type_ids'].squeeze(0),
-                        'attention_mask': encoding['attention_mask'].squeeze(0),
-                        'labels': labels,
+                        "input_ids": encoded["input_ids"],
+                        "token_type_ids": encoded["token_type_ids"],
+                        "attention_mask": encoded["attention_mask"],
+                        "labels": torch.tensor(labels, dtype=torch.long),
+                        "part1": part1,
+                        "part2": part2,
+                        "part3": part3,
+                        "part4": part4,
                     }
                     batch_samples.extend([sample] * repeats)
                     if len(batch_samples) >= self.shuffle_buffer_size:
@@ -216,5 +287,3 @@ class PinyinInputDataset(IterableDataset):
                         self.buffer.extend([batch_samples[i] for i in indices])
                         batch_samples = []
         yield from self.buffer
-
-
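For context, the buffer logic above is the standard shuffle-buffer pattern for streaming datasets: accumulate `shuffle_buffer_size` samples, permute them, and yield. A minimal self-contained sketch (the function name and generator shape are illustrative, not the class's actual control flow):

import numpy as np

def shuffle_stream(stream, buffer_size=5000):
    """Yield samples from an iterable in locally shuffled order."""
    buffer = []
    for sample in stream:
        buffer.append(sample)
        if len(buffer) >= buffer_size:
            for i in np.random.permutation(len(buffer)):
                yield buffer[i]
            buffer = []
    # flush whatever remains at the end of the stream
    for i in np.random.permutation(len(buffer)):
        yield buffer[i]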
@@ -0,0 +1,7 @@
+from model.dataset import PinyinInputDataset
+
+dataset = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
+for i, line in enumerate(dataset):
+    print(line['labels'])
+    if i > 10:
+        break