feat(model): optimize dataset processing logic and add frequency adjustment

songsenand 2026-04-02 21:30:03 +08:00
parent 886cb58cbb
commit 74416bfcb1
6 changed files with 2055 additions and 1739 deletions

BIN .README.md.kate-swp (new file)
Binary file not shown.


@@ -54,7 +54,7 @@
 - **Advantages**: keeps the feature dimension, introduces non-linearity, and has a moderate parameter count.
 ## 5. Decoding strategy: beam search
-- **Search scope**: beam search runs inside each slot step, with beam width **k** (e.g. 5).
+- **Search scope**: beam search runs inside each slot step, with beam width **k** (default 3).
 - **Candidate maintenance**: each candidate path independently maintains its historical slot sequence (concatenated embeddings) and a cumulative probability.
 - **Termination conditions**:
   1. All slots are filled (8×3 = 24 steps)
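The slot-wise beam search described in these bullets is easy to prototype. Below is a minimal, self-contained sketch, not the repository's implementation: `candidates_for_slot` and `score` are hypothetical stand-ins for the model's per-slot vocabulary and log-probability.

```python
import heapq
from typing import Callable, List, Tuple

def slot_beam_search(
    num_slots: int,
    candidates_for_slot: Callable[[int], List[int]],
    score: Callable[[List[int], int], float],
    k: int = 3,  # beam width, matching the new default
) -> List[int]:
    # Each beam entry holds (cumulative log-probability, slot sequence so far)
    beams: List[Tuple[float, List[int]]] = [(0.0, [])]
    for slot in range(num_slots):  # e.g. 8 slots x 3 = 24 steps
        expanded = []
        for logp, seq in beams:
            for cand in candidates_for_slot(slot):
                # Extend each surviving path and accumulate its score
                expanded.append((logp + score(seq, cand), seq + [cand]))
        # Keep only the k best paths at every step
        beams = heapq.nlargest(k, expanded, key=lambda b: b[0])
    return max(beams, key=lambda b: b[0])[1]

# Toy usage: two candidates per slot, scoring always prefers candidate 0
print(slot_beam_search(3, lambda s: [0, 1], lambda seq, c: -float(c)))  # [0, 0, 0]
```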


@@ -20,6 +20,7 @@ dependencies = [
     "requests>=2.32.5",
     "rich>=14.3.1",
     "tensorboard>=2.20.0",
+    "torch>=2.11.0",
     "transformers==5.1.0",
     "typer>=0.21.1",
 ]
@@ -27,6 +28,7 @@
 [tool.uv]
 # Set the default package index for this project
 index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
 [dependency-groups]
 dev = [
     "autocommit",


@@ -24,7 +24,18 @@ class PinyinInputDataset(IterableDataset):
         text_field: str = "text",
         py_style_weight=(9, 2, 1),
         shuffle_buffer_size: int = 5000,
+        length_weights = {
+            1: 10, 2: 50, 3: 50, 4: 40,
+            5: 15, 6: 10, 7: 5, 8: 2
+        },
     ):
+        # Frequency-adjustment parameters (tune as needed)
+        self.drop_start_freq = 30_000_000
+        self.max_drop_prob = 0.8
+        self.repeat_end_freq = 10_000
+        self.max_repeat_expect = 50
+        self.min_freq = 109
         self.tokenizer = AutoTokenizer.from_pretrained(Path(files(__package__) / "assets" / "tokenizer"))
         self.data_path = data_path
         self.max_length = max_length
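The new `length_weights` mapping presumably feeds a weighted draw of sample lengths (the draw site is outside this hunk), along the lines of this standalone illustration:

```python
import random

length_weights = {1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}
lengths, weights = list(length_weights.keys()), list(length_weights.values())
# Weighted sampling: lengths 2-3 are 25x more likely than length 8
print(random.choices(lengths, weights=weights, k=5))
```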
@@ -32,73 +43,58 @@ class PinyinInputDataset(IterableDataset):
         self.dataset = load_dataset(data_path, split="train", streaming=True)
         self.max_workers = max_workes
         self.py_style_weight = py_style_weight
+        self.possible_lengths = list(length_weights.keys())
+        self.weights = list(length_weights.values())
         self.query_engine = QueryEngine()
         self.query_engine.load()
         self.shuffle_buffer_size = shuffle_buffer_size
         self.buffer = []
+        # Target characters of each sample and their frequencies
+        self.sample_freqs = self.query_engine.get_all_weights()
 
-    def smart_multi_segment_encode(self, texts):
-        """
-        Smart multi-segment encoding.
-        Idea: use the tokenizer's basic functionality, but control token_type_ids flexibly.
-        """
-        # Step 1: encode each segment separately with the tokenizer
-        encoded_segments = []
-        for text in texts:
-            # Note: no special tokens here; they are handled uniformly below
-            encoded = self.tokenizer.encode(text, add_special_tokens=False)
-            encoded_segments.append(encoded)
-        # Step 2: build the full sequence
-        tokens = []
-        token_type_ids = []
-        # Add [CLS]
-        tokens.append(self.tokenizer.cls_token_id)
-        token_type_ids.append(0)  # [CLS] is usually type 0
-        # Add each segment
-        for seg_idx, segment in enumerate(encoded_segments):
-            # Type of the current segment (cycles through 0-3)
-            current_type = seg_idx % 4
-            # Append the segment content
-            tokens.extend(segment)
-            token_type_ids.extend([current_type] * len(segment))
-            # Append [SEP] (optional after the last segment)
-            if seg_idx < len(encoded_segments) - 1:
-                tokens.append(self.tokenizer.sep_token_id)
-                token_type_ids.append(current_type)
-            else:
-                # Append [SEP] after the last segment as well
-                tokens.append(self.tokenizer.sep_token_id)
-                token_type_ids.append(current_type)
-        # Step 3: truncate and pad
-        if len(tokens) > self.max_length:
-            tokens = tokens[:self.max_length]
-            token_type_ids = token_type_ids[:self.max_length]
-        else:
-            # Pad
-            padding_length = self.max_length - len(tokens)
-            tokens = tokens + [self.tokenizer.pad_token_id] * padding_length
-            token_type_ids = token_type_ids + [0] * padding_length  # pad with type 0
-        # Step 4: build the attention mask
-        attention_mask = [
-            1 if token != self.tokenizer.pad_token_id else 0 for token in tokens
-        ]
-        return {
-            "input_ids": torch.tensor([tokens]),
-            "token_type_ids": torch.tensor([token_type_ids]),
-            "attention_mask": torch.tensor([attention_mask]),
-        }
+    def adjust_frequency(self, freq: int) -> int:
+        """Clip peaks and fill valleys: adjust the sampling count by frequency (0 means drop)."""
+        # 1. Peak clipping (high-frequency characters)
+        if freq >= self.drop_start_freq:
+            # Linear drop probability
+            max_freq = max(self.sample_freqs.values())  # or use a predefined global maximum
+            if max_freq == self.drop_start_freq:
+                drop_prob = 0.0
+            else:
+                drop_prob = (
+                    self.max_drop_prob
+                    * (freq - self.drop_start_freq)
+                    / (max_freq - self.drop_start_freq)
+                )
+            if random.random() < drop_prob:
+                return 0
+            else:
+                return 1
+        # 2. Valley filling (low-frequency characters)
+        elif freq <= self.repeat_end_freq:
+            # Linear repeat expectation
+            if freq <= self.min_freq:
+                repeat_expect = self.max_repeat_expect
+            else:
+                if self.repeat_end_freq == self.min_freq:
+                    repeat_expect = 0
+                else:
+                    repeat_expect = (
+                        self.max_repeat_expect
+                        * (self.repeat_end_freq - freq)
+                        / (self.repeat_end_freq - self.min_freq)
+                    )
+            # Randomize the repetition with a Poisson draw
+            repeat_count = np.random.poisson(repeat_expect)
+            return max(1, repeat_count)
+        # 3. Mid-frequency characters
+        else:
+            return 1
 
     # Generate the pinyin for the given text
     def generate_pinyin(self, text: str) -> List[List[str]]:
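The thresholds above define two linear ramps: the drop probability rises from 0 at `drop_start_freq` to `max_drop_prob` at the corpus maximum, and the expected repeat count rises from 0 at `repeat_end_freq` to `max_repeat_expect` at `min_freq`. A standalone sanity check using the diff's constants (the 120M maximum is an assumed value):

```python
import numpy as np

drop_start_freq, max_drop_prob = 30_000_000, 0.8
repeat_end_freq, max_repeat_expect, min_freq = 10_000, 50, 109
max_freq = 120_000_000  # assumed global maximum count

# High-frequency character: drop probability grows linearly toward max_drop_prob
freq = 75_000_000
drop_prob = max_drop_prob * (freq - drop_start_freq) / (max_freq - drop_start_freq)
print(round(drop_prob, 2))  # 0.4, so the sample is kept with probability 0.6

# Low-frequency character: expected repeats grow linearly toward max_repeat_expect
freq = 5_000
expect = max_repeat_expect * (repeat_end_freq - freq) / (repeat_end_freq - min_freq)
print(round(expect, 1))                   # ~25.3 expected repeats
print(max(1, np.random.poisson(expect)))  # actual sampled repeat count
```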
@@ -108,7 +104,7 @@ class PinyinInputDataset(IterableDataset):
     def get_mask_pinyin(self, text: str, pinyin_list: List[str]) -> (int, List[str]):
         mask_pinyin = []
         for i in range(len(text)):
-            if self.query_engine.is_chinese_char(text[i]):
+            if not self.query_engine.is_chinese_char(text[i]):
                 return i - 1, mask_pinyin
             else:
                 py = random.choice(
@@ -157,7 +153,7 @@ class PinyinInputDataset(IterableDataset):
                 range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
             )
             py_end = min(i + pinyin_len, len(text))
-            part2 = self.get_mask_pinyin(text[i:py_end], pinyin_list[i:py_end])
+            pinyin_len, part2 = self.get_mask_pinyin(text[i:py_end], pinyin_list[i:py_end])
             # part3 is plain text; with high probability (0.70) it is empty
             # if non-empty, it is the character at i + pinyin_len plus the x characters after it
@@ -197,10 +193,24 @@ class PinyinInputDataset(IterableDataset):
                     ).id
                     for c, p in zip(text[i:py_end], pinyin_list[i:py_end])
                 ]
                 labels.append(0)
-                encoded = self.smart_multi_segment_encode([part1, part2, part3, part4])
-                encoded["label"] = labels
-                batch_samples.append(encoded)
+                encoded = self.tokenizer(
+                    [part1, part2, part3, part4],
+                    truncation=True,
+                    max_length=self.max_length,
+                    padding='max_length',
+                    return_tensors='pt',  # adjust to the training framework
+                    add_special_tokens=True,  # add [CLS], [SEP]
+                )
+                repeats = self.adjust_frequency(min([self.sample_freqs[i] for i in labels]))
+                sample = {
+                    'input_ids': encoded['input_ids'].squeeze(0),
+                    'token_type_ids': encoded['token_type_ids'].squeeze(0),
+                    'attention_mask': encoded['attention_mask'].squeeze(0),
+                    'labels': labels,
+                }
+                batch_samples.extend([sample] * repeats)
                 if len(batch_samples) >= self.shuffle_buffer_size:
                     indices = np.random.permutation(len(batch_samples))
                     self.buffer.extend([batch_samples[i] for i in indices])
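One behavioral note on this replacement: `smart_multi_segment_encode` packed all four parts into a single `[CLS] ... [SEP]` sequence with cycling `token_type_ids`, whereas calling the tokenizer on a list of four strings encodes them as a batch of four independent sequences, so each tensor has shape `(4, max_length)` and `squeeze(0)` leaves it unchanged. A quick check (`bert-base-chinese` is only a stand-in, not the project's tokenizer):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-chinese")  # stand-in tokenizer
enc = tok(["part1", "part2", "part3", "part4"],
          truncation=True, max_length=32,
          padding="max_length", return_tensors="pt")
print(enc["input_ids"].shape)             # torch.Size([4, 32]): a batch of four sequences
print(enc["input_ids"].squeeze(0).shape)  # same shape; squeeze(0) is a no-op on a 4-row batch
```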


@@ -314,8 +314,7 @@ class QueryEngine:
     def get_all_weights(self):
         """Get the occurrence count of every character-pinyin pair - O(n) time complexity"""
-        items_sorted = sorted(self._id_to_info.items(), key=lambda x: x[0])
-        return [info.count for _, info in items_sorted]
+        return {k: v.count for k, v in self._id_to_info.items()}
 
     def get_char_info_by_char_pinyin(
         self, char: str, pinyin: str
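Since `get_all_weights()` now returns a mapping from id to count rather than a count list sorted by id, callers must aggregate over `.values()`; `max()` on the dict itself would return the largest key. A minimal illustration with stand-in data:

```python
# Stand-in for query_engine.get_all_weights(): {id: count}
freqs = {0: 42, 7: 30_000_000, 12: 109}
max_count = max(freqs.values())  # note: max(freqs) would return the largest *key*
rare = [i for i, c in freqs.items() if c <= 10_000]
print(max_count, rare)  # 30000000 [0, 12]
```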

uv.lock (3655 changed lines)
File diff suppressed because it is too large.