fix(tokenizer): 移除异常处理,直接加载指定tokenizer
This commit is contained in:
parent
a0e4d25b2f
commit
1cdef19153
5
eval.py
5
eval.py
|
|
@@ -78,16 +78,11 @@ class TextEvaluator:
|
|||
|
||||
def load_tokenizer(self):
|
||||
"""加载tokenizer"""
|
||||
try:
|
||||
tokenizer_path = (
|
||||
Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
|
||||
print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 无法加载tokenizer: {e}")
|
||||
print("使用默认的bert-base-chinese tokenizer")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
||||
|
||||
def load_query_engine(self):
|
||||
"""加载查询引擎用于字符-ID转换"""
|
||||
|
|
|
|||
|
|
@@ -180,14 +180,15 @@ class SlotMemory(nn.Module):
|
|||
Args:
|
||||
history_ids: [batch, total_steps]
|
||||
Flattened sequence of history tokens.
|
||||
Empty positions should be filled with a special PAD or handled via mask.
|
||||
Zero positions use start_emb as learnable query.
|
||||
Returns:
|
||||
S: [batch, total_steps, 512] Slot sequence representation [1]
|
||||
"""
|
||||
# Embed history tokens
|
||||
S = self.emb(history_ids) # [B, 24, 512]
|
||||
S = self.emb(history_ids) # [B, num_slots, dim]
|
||||
|
||||
zero_mask = (history_ids == 0).unsqueeze(-1).float() # [B, num_slots, 1]
|
||||
S = S * (1 - zero_mask) + self.start_emb * zero_mask
|
||||
|
||||
# Add positional embeddings
|
||||
pos_ids = (
|
||||
torch.arange(S.size(1), device=S.device).unsqueeze(0).expand_as(history_ids)
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -42,7 +42,8 @@ class PinyinInputDataset(IterableDataset):
|
|||
max_seq_length=128,
|
||||
text_field: str = "text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size: int = 5000,
|
||||
shuffle_buffer_size: int = 100000,
|
||||
retention_ratio: float = 0.5,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
):
|
||||
# 频率调整参数 (可根据需要调整)
|
||||
|
|
@@ -64,6 +65,16 @@ class PinyinInputDataset(IterableDataset):
|
|||
self.max_workers = max_workers
|
||||
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
|
||||
self.shuffle_buffer_size = shuffle_buffer_size
|
||||
self.retention_ratio = retention_ratio
|
||||
if not (0 < retention_ratio < 1):
|
||||
raise ValueError(
|
||||
f"retention_ratio必须在0和1之间,当前值: {retention_ratio}"
|
||||
)
|
||||
self.retention_size = int(shuffle_buffer_size * retention_ratio)
|
||||
if self.retention_size <= 0:
|
||||
raise ValueError(
|
||||
f"计算出的retention_size必须大于0,当前值: {self.retention_size} (shuffle_buffer_size={shuffle_buffer_size}, retention_ratio={retention_ratio})"
|
||||
)
|
||||
self.possible_lengths = list(length_weights.keys())
|
||||
self.weights = list(length_weights.values())
|
||||
|
||||
|
|
@@ -360,17 +371,31 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 添加到缓冲区
|
||||
batch_samples.extend(samples)
|
||||
|
||||
# 处理shuffle buffer
|
||||
# 处理shuffle buffer - 单缓冲区半保留方案
|
||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||
# 全量打乱缓冲区
|
||||
indices = np.random.permutation(len(batch_samples))
|
||||
for idx in indices:
|
||||
|
||||
# 计算实际保留大小(不超过缓冲区大小)
|
||||
actual_retention = min(self.retention_size, len(batch_samples))
|
||||
|
||||
# 计算输出数量
|
||||
output_count = len(batch_samples) - actual_retention
|
||||
|
||||
# 输出前output_count个样本
|
||||
for i in range(output_count):
|
||||
if current_iter_index >= worker_quota:
|
||||
# 清空batch_samples并返回
|
||||
# 配额用完,清空缓冲区并返回
|
||||
batch_samples = []
|
||||
return # 使用return而不是break,因为我们在生成器函数中
|
||||
yield batch_samples[idx]
|
||||
return
|
||||
yield batch_samples[indices[i]]
|
||||
current_iter_index += 1
|
||||
batch_samples = []
|
||||
|
||||
# 保留后actual_retention个样本(不清空缓冲区)
|
||||
retained_samples = [
|
||||
batch_samples[idx] for idx in indices[output_count:]
|
||||
]
|
||||
batch_samples = retained_samples
|
||||
|
||||
# 处理剩余的样本
|
||||
if batch_samples:
|
||||
|
|
|
|||
|
|
@@ -124,20 +124,18 @@ class InputMethodEngine(nn.Module):
|
|||
moe_out = self.moe(fused)
|
||||
|
||||
# 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile)
|
||||
# 使用显式形状重塑,彻底杜绝广播歧义
|
||||
batch_size = input_ids.size(0)
|
||||
slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
|
||||
numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
|
||||
denominator = slot_mask.view(batch_size, -1).sum(dim=1) # [batch]
|
||||
# 若无有效槽位,使用上下文 H 的掩码均值
|
||||
if torch.all(denominator == 0):
|
||||
# H: [batch, seq_len, dim], attention_mask: [batch, seq_len]
|
||||
ctx_mask = attention_mask.float().unsqueeze(-1) # [batch, seq_len, 1]
|
||||
ctx_sum = (H * ctx_mask).sum(dim=1) # [batch, dim]
|
||||
ctx_cnt = ctx_mask.sum(dim=1) + 1e-8 # [batch, 1]
|
||||
pooled = ctx_sum / ctx_cnt
|
||||
else:
|
||||
pooled = numerator / (denominator.unsqueeze(-1) + 1e-8) # [batch, dim]
|
||||
|
||||
all_slot_mean = moe_out.mean(dim=1) # [batch, dim]
|
||||
all_zero = denominator == 0 # [batch]
|
||||
pooled = torch.where(
|
||||
all_zero.unsqueeze(-1),
|
||||
all_slot_mean,
|
||||
numerator / (denominator.unsqueeze(-1) + 1e-8),
|
||||
)
|
||||
|
||||
logits = self.classifier(pooled) # [batch, vocab_size]
|
||||
return logits
|
||||
|
|
|
|||
|
|
@@ -1070,7 +1070,7 @@ def train(
|
|||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=50000,
|
||||
shuffle_buffer_size=100000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
|
|
@@ -1093,7 +1093,7 @@ def train(
|
|||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=500000,
|
||||
shuffle_buffer_size=100000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
|
|
@@ -1388,7 +1388,7 @@ def expand_and_train(
|
|||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=5000,
|
||||
shuffle_buffer_size=100000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
|
|
@@ -1409,7 +1409,7 @@ def expand_and_train(
|
|||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=50000,
|
||||
shuffle_buffer_size=100000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
|
|
@@ -1711,7 +1711,7 @@ def expand_finetune(
|
|||
max_seq_length=final_max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=5000,
|
||||
shuffle_buffer_size=100000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
|
|
@@ -1730,7 +1730,7 @@ def expand_finetune(
|
|||
max_seq_length=final_max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=50000,
|
||||
shuffle_buffer_size=100000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
|
|
@@ -1802,4 +1802,3 @@ def expand_finetune(
|
|||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
||||
|
|
|
|||
8
test.py
8
test.py
|
|
@@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
|||
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
||||
|
||||
|
||||
part1 = "他是一名大学生,在上海读"
|
||||
part2 = "dayi"
|
||||
part1 = "明明是国庆节,可是因为月底要结账,财务部所有人都"
|
||||
part2 = "bxu"
|
||||
pinyin_ids = text_to_pinyin_ids(part2)
|
||||
len_py = len(pinyin_ids)
|
||||
if len_py < 24:
|
||||
|
|
@@ -57,7 +57,7 @@ else:
|
|||
pinyin_ids = pinyin_ids[:24]
|
||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
||||
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
|
||||
part3 = "。"
|
||||
part3 = ""
|
||||
part4 = "可行|特别|伤害"
|
||||
|
||||
encoded = tokenizer(
|
||||
|
|
@@ -83,7 +83,7 @@ sample = {
|
|||
|
||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||
|
||||
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
|
||||
checkpoint = torch.load("/home/songsenand/下载/20260411(acc34)final_model.ptrom", map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue