feat(model): 优化数据集处理逻辑,增加频率调整功能
This commit is contained in:
parent
886cb58cbb
commit
74416bfcb1
Binary file not shown.
|
|
@ -54,7 +54,7 @@
|
|||
- **优点**:保留特征维度,引入非线性,参数量适中。
|
||||
|
||||
## 5. 解码策略:束搜索
|
||||
- **搜索范围**:在每个槽位步内部执行束搜索,束宽设为 **k**(如5)。
|
||||
- **搜索范围**:在每个槽位步内部执行束搜索,束宽设为 **k**(默认为3)。
|
||||
- **候选维护**:每个候选路径独立维护历史槽位序列(拼接后的嵌入)及累计概率。
|
||||
- **终止条件**:
|
||||
1. 所有槽位已填满(8×3=24步);
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ dependencies = [
|
|||
"requests>=2.32.5",
|
||||
"rich>=14.3.1",
|
||||
"tensorboard>=2.20.0",
|
||||
"torch>=2.11.0",
|
||||
"transformers==5.1.0",
|
||||
"typer>=0.21.1",
|
||||
]
|
||||
|
|
@ -27,6 +28,7 @@ dependencies = [
|
|||
[tool.uv]
|
||||
# 设置当前项目的默认索引源
|
||||
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"autocommit",
|
||||
|
|
|
|||
|
|
@ -24,7 +24,18 @@ class PinyinInputDataset(IterableDataset):
|
|||
text_field: str = "text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size: int = 5000,
|
||||
length_weights = {
|
||||
1: 10, 2: 50, 3: 50, 4: 40,
|
||||
5: 15, 6: 10, 7: 5, 8: 2
|
||||
},
|
||||
):
|
||||
# 频率调整参数 (可根据需要调整)
|
||||
self.drop_start_freq = 30_000_000
|
||||
self.max_drop_prob = 0.8
|
||||
self.repeat_end_freq = 10_000
|
||||
self.max_repeat_expect = 50
|
||||
self.min_freq = 109
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(Path(files(__package__) / "assets" / "tokenizer"))
|
||||
self.data_path = data_path
|
||||
self.max_length = max_length
|
||||
|
|
@ -32,73 +43,58 @@ class PinyinInputDataset(IterableDataset):
|
|||
self.dataset = load_dataset(data_path, split="train", streaming=True)
|
||||
self.max_workers = max_workes
|
||||
self.py_style_weight = py_style_weight
|
||||
self.possible_lengths = list(length_weights.keys())
|
||||
self.weights = list(length_weights.values())
|
||||
|
||||
self.query_engine = QueryEngine()
|
||||
self.query_engine.load()
|
||||
self.shuffle_buffer_size = shuffle_buffer_size
|
||||
self.buffer = []
|
||||
|
||||
# 提取每个样本的目标字符及其频率
|
||||
self.sample_freqs = self.query_engine.get_all_weights()
|
||||
|
||||
|
||||
def smart_multi_segment_encode(self, texts):
|
||||
"""
|
||||
智能多段落编码
|
||||
|
||||
思路:使用分词器的基础功能,但灵活控制token_type_ids
|
||||
"""
|
||||
# 第一步:使用分词器单独编码每个段落
|
||||
encoded_segments = []
|
||||
for text in texts:
|
||||
# 注意:不添加特殊标记,我们后面统一处理
|
||||
encoded = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
encoded_segments.append(encoded)
|
||||
|
||||
# 第二步:构建完整序列
|
||||
tokens = []
|
||||
token_type_ids = []
|
||||
|
||||
# 添加[CLS]
|
||||
tokens.append(self.tokenizer.cls_token_id)
|
||||
token_type_ids.append(0) # CLS通常为0
|
||||
|
||||
# 添加各个段落
|
||||
for seg_idx, segment in enumerate(encoded_segments):
|
||||
# 当前段落的类型(0-3循环)
|
||||
current_type = seg_idx % 4
|
||||
|
||||
# 添加段落内容
|
||||
tokens.extend(segment)
|
||||
token_type_ids.extend([current_type] * len(segment))
|
||||
|
||||
# 添加[SEP](最后一个段落可以不加)
|
||||
if seg_idx < len(encoded_segments) - 1:
|
||||
tokens.append(self.tokenizer.sep_token_id)
|
||||
token_type_ids.append(current_type)
|
||||
def adjust_frequency(self, freq: int) -> int:
|
||||
"""削峰填谷 - 根据频率调整采样次数,0表示丢弃"""
|
||||
# 1. 削峰处理(高频字)
|
||||
if freq >= self.drop_start_freq:
|
||||
# 线性丢弃概率计算
|
||||
max_freq = max(self.sample_freqs) # 或使用预定义的全局最大值
|
||||
if max_freq == self.drop_start_freq:
|
||||
drop_prob = 0.0
|
||||
else:
|
||||
# 最后一个段落加[SEP]
|
||||
tokens.append(self.tokenizer.sep_token_id)
|
||||
token_type_ids.append(current_type)
|
||||
drop_prob = (
|
||||
self.max_drop_prob
|
||||
* (freq - self.drop_start_freq)
|
||||
/ (max_freq - self.drop_start_freq)
|
||||
)
|
||||
if random.random() < drop_prob:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
# 第三步:截断和填充
|
||||
if len(tokens) > self.max_length:
|
||||
tokens = tokens[:self.max_length]
|
||||
token_type_ids = token_type_ids[:self.max_length]
|
||||
# 2. 填谷处理(低频字)
|
||||
elif freq <= self.repeat_end_freq:
|
||||
# 线性重复期望计算
|
||||
if freq <= self.min_freq:
|
||||
repeat_expect = self.max_repeat_expect
|
||||
else:
|
||||
if self.repeat_end_freq == self.min_freq:
|
||||
repeat_expect = 0
|
||||
else:
|
||||
repeat_expect = (
|
||||
self.max_repeat_expect
|
||||
* (self.repeat_end_freq - freq)
|
||||
/ (self.repeat_end_freq - self.min_freq)
|
||||
)
|
||||
# 使用泊松分布实现随机重复
|
||||
repeat_count = np.random.poisson(repeat_expect)
|
||||
return max(1, repeat_count)
|
||||
|
||||
# 3. 中间频率字
|
||||
else:
|
||||
# 填充
|
||||
padding_length = self.max_length - len(tokens)
|
||||
tokens = tokens + [self.tokenizer.pad_token_id] * padding_length
|
||||
token_type_ids = token_type_ids + [0] * padding_length # 填充部分用0
|
||||
|
||||
# 第四步:创建attention mask
|
||||
attention_mask = [
|
||||
1 if token != self.tokenizer.pad_token_id else 0 for token in tokens
|
||||
]
|
||||
|
||||
return {
|
||||
"input_ids": torch.tensor([tokens]),
|
||||
"token_type_ids": torch.tensor([token_type_ids]),
|
||||
"attention_mask": torch.tensor([attention_mask]),
|
||||
}
|
||||
return 1
|
||||
|
||||
# 生成对应文本的拼音
|
||||
def generate_pinyin(self, text: str) -> List[List[str]]:
|
||||
|
|
@ -108,7 +104,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
def get_mask_pinyin(self, text: str, pinyin_list: List[str]) -> (int, List[str]):
|
||||
mask_pinyin = []
|
||||
for i in range(len(text)):
|
||||
if self.query_engine.is_chinese_char(text[i]):
|
||||
if not self.query_engine.is_chinese_char(text[i]):
|
||||
return i - 1, mask_pinyin
|
||||
else:
|
||||
py = random.choice(
|
||||
|
|
@ -157,7 +153,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
|
||||
)
|
||||
py_end = min(i + pinyin_len, len(text))
|
||||
part2 = self.get_mask_pinyin(text[i:py_end], pinyin_list[i:py_end])
|
||||
pinyin_len, part2 = self.get_mask_pinyin(text[i:py_end], pinyin_list[i:py_end])
|
||||
|
||||
# part3为文本,大概率(0.70)为空
|
||||
# 不为空则是i+pinyin_len所指向的字符以及所指向字符后x个字符
|
||||
|
|
@ -197,10 +193,24 @@ class PinyinInputDataset(IterableDataset):
|
|||
).id
|
||||
for c, p in zip(text[i:py_end], pinyin_list[i:py_end])
|
||||
]
|
||||
labels.append(0)
|
||||
|
||||
encoded = self.smart_multi_segment_encode([part1, part2, part3, part4])
|
||||
encoded["label"] = labels
|
||||
batch_samples.append(encoded)
|
||||
encoded = self.tokenizer(
|
||||
[part1, part2, part3, part4],
|
||||
truncation=True,
|
||||
max_length=self.max_seq_length,
|
||||
padding='max_length',
|
||||
return_tensors='pt', # 根据训练框架调整
|
||||
add_special_tokens=True, # 添加[CLS], [SEP]
|
||||
)
|
||||
repeats = self.adjust_frequency(min([self.sample_freqs[i] for i in labels]))
|
||||
sample = {
|
||||
'input_ids': encoding['input_ids'].squeeze(0),
|
||||
'token_type_ids': encoding['token_type_ids'].squeeze(0),
|
||||
'attention_mask': encoding['attention_mask'].squeeze(0),
|
||||
'labels': labels,
|
||||
}
|
||||
batch_samples.extend([sample] * repeats)
|
||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||
indices = np.random.permutation(len(batch_samples))
|
||||
self.buffer.extend([batch_samples[i] for i in indices])
|
||||
|
|
|
|||
|
|
@ -314,8 +314,7 @@ class QueryEngine:
|
|||
|
||||
def get_all_weights(self):
|
||||
"""获取所有字符-拼音对出现的次数 - O(n)时间复杂度"""
|
||||
items_sorted = sorted(self._id_to_info.items(), key=lambda x: x[0])
|
||||
return [info.count for _, info in items_sorted]
|
||||
return {k: v.count for k, v in self._id_to_info.items()}
|
||||
|
||||
def get_char_info_by_char_pinyin(
|
||||
self, char: str, pinyin: str
|
||||
|
|
|
|||
Loading…
Reference in New Issue