fix(tokenizer): remove exception handling and load the specified tokenizer directly

songsenand 2026-04-11 07:32:05 +08:00
parent a0e4d25b2f
commit 1cdef19153
6 changed files with 60 additions and 42 deletions

eval.py
View File

@@ -78,16 +78,11 @@ class TextEvaluator:
     def load_tokenizer(self):
         """Load the tokenizer"""
-        try:
-            tokenizer_path = (
-                Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
-            )
-            self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
-            print(f"✅ Tokenizer loaded, vocab size: {self.tokenizer.vocab_size}")
-        except Exception as e:
-            print(f"⚠️ Failed to load tokenizer: {e}")
-            print("Falling back to the default bert-base-chinese tokenizer")
-            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
+        tokenizer_path = (
+            Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
+        print(f"✅ Tokenizer loaded, vocab size: {self.tokenizer.vocab_size}")
 
     def load_query_engine(self):
         """Load the query engine for character-to-ID conversion"""

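With the fallback removed, a missing or corrupt tokenizer directory now raises at load time instead of silently substituting bert-base-chinese. A minimal sketch of the new behavior (the standalone script is hypothetical; the path and calls mirror the diff above):

```python
from pathlib import Path

from transformers import AutoTokenizer

tokenizer_path = Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"

# from_pretrained raises (typically OSError) if the directory is missing
# or incomplete; there is no bert-base-chinese fallback anymore.
tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
print(f"vocab size: {tokenizer.vocab_size}")
```

Failing fast is arguably the right call for an evaluator: a silently substituted vocabulary would make accuracy numbers incomparable across runs.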
View File

@@ -180,14 +180,15 @@ class SlotMemory(nn.Module):
         Args:
             history_ids: [batch, total_steps]
                 Flattened sequence of history tokens.
-                Empty positions should be filled with a special PAD or handled via mask.
+                Zero positions use start_emb as a learnable query.
 
         Returns:
             S: [batch, total_steps, 512] Slot sequence representation
         """
         # Embed history tokens
-        S = self.emb(history_ids)  # [B, 24, 512]
+        S = self.emb(history_ids)  # [B, num_slots, dim]
+        zero_mask = (history_ids == 0).unsqueeze(-1).float()  # [B, num_slots, 1]
+        S = S * (1 - zero_mask) + self.start_emb * zero_mask
 
         # Add positional embeddings
         pos_ids = (
             torch.arange(S.size(1), device=S.device).unsqueeze(0).expand_as(history_ids)
         )
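The zero-mask substitution can be checked in isolation. A minimal sketch, assuming a toy embedding table where index 0 marks an empty slot and a learnable start_emb of the same width as in the diff:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(100, 512)                 # toy vocabulary; dim matches the diff
start_emb = nn.Parameter(torch.randn(512))   # learnable query for empty slots

history_ids = torch.tensor([[5, 9, 0, 0]])   # [B=1, num_slots=4]; zeros are empty
S = emb(history_ids)                         # [1, 4, 512]
zero_mask = (history_ids == 0).unsqueeze(-1).float()  # [1, 4, 1]

# Empty slots are replaced by start_emb; filled slots keep their embedding.
S = S * (1 - zero_mask) + start_emb * zero_mask
assert torch.allclose(S[0, 2], start_emb)
```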

View File

@@ -42,7 +42,8 @@ class PinyinInputDataset(IterableDataset):
         max_seq_length=128,
         text_field: str = "text",
         py_style_weight=(9, 2, 1),
-        shuffle_buffer_size: int = 5000,
+        shuffle_buffer_size: int = 100000,
+        retention_ratio: float = 0.5,
         length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
     ):
         # Frequency adjustment parameters (tune as needed)
@@ -64,6 +65,16 @@ class PinyinInputDataset(IterableDataset):
         self.max_workers = max_workers
         self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
         self.shuffle_buffer_size = shuffle_buffer_size
+        self.retention_ratio = retention_ratio
+        if not (0 < retention_ratio < 1):
+            raise ValueError(
+                f"retention_ratio must be between 0 and 1; got: {retention_ratio}"
+            )
+        self.retention_size = int(shuffle_buffer_size * retention_ratio)
+        if self.retention_size <= 0:
+            raise ValueError(
+                f"computed retention_size must be greater than 0; got: {self.retention_size} (shuffle_buffer_size={shuffle_buffer_size}, retention_ratio={retention_ratio})"
+            )
         self.possible_lengths = list(length_weights.keys())
         self.weights = list(length_weights.values())
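With the defaults introduced in this commit (shuffle_buffer_size=100000, retention_ratio=0.5), the retained share works out as below; the snippet is only a sanity check, not repository code:

```python
shuffle_buffer_size = 100000
retention_ratio = 0.5

retention_size = int(shuffle_buffer_size * retention_ratio)
assert retention_size == 50000                         # carried over after each flush
assert shuffle_buffer_size - retention_size == 50000   # yielded per flush
# Ratios outside (0, 1) are rejected up front by the new ValueError checks.
```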
@@ -360,17 +371,31 @@ class PinyinInputDataset(IterableDataset):
                 # Append to the buffer
                 batch_samples.extend(samples)
 
-                # Process the shuffle buffer
+                # Process the shuffle buffer - single-buffer half-retention scheme
                 if len(batch_samples) >= self.shuffle_buffer_size:
+                    # Shuffle the entire buffer
                     indices = np.random.permutation(len(batch_samples))
-                    for idx in indices:
+                    # Actual retention size (never more than the buffer holds)
+                    actual_retention = min(self.retention_size, len(batch_samples))
+                    # Number of samples to emit
+                    output_count = len(batch_samples) - actual_retention
+                    # Yield the first output_count samples
+                    for i in range(output_count):
                         if current_iter_index >= worker_quota:
-                            # Clear batch_samples and return
+                            # Quota exhausted: clear the buffer and return
                             batch_samples = []
-                            return  # return rather than break, since this is a generator
-                        yield batch_samples[idx]
+                            return
+                        yield batch_samples[indices[i]]
                         current_iter_index += 1
-                    batch_samples = []
+                    # Retain the last actual_retention samples instead of clearing the buffer
+                    retained_samples = [
+                        batch_samples[idx] for idx in indices[output_count:]
+                    ]
+                    batch_samples = retained_samples
 
         # Handle any remaining samples
         if batch_samples:
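Detached from the dataset class, the half-retention idea looks like the sketch below. All names (partial_retention_shuffle, stream, buffer_size, retention) are hypothetical; the real implementation above additionally enforces a per-worker quota:

```python
import numpy as np

def partial_retention_shuffle(stream, buffer_size=8, retention=4):
    """Yield items in shuffled order while always retaining `retention`
    items in the buffer, so later arrivals can mix with earlier ones."""
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            indices = np.random.permutation(len(buffer))
            output_count = len(buffer) - min(retention, len(buffer))
            for i in range(output_count):  # emit the first slice...
                yield buffer[indices[i]]
            buffer = [buffer[idx] for idx in indices[output_count:]]  # ...keep the rest
    yield from buffer  # flush whatever remains at end of stream

print(list(partial_retention_shuffle(range(12))))
```

Compared with flushing the whole buffer, retention lets samples from adjacent source shards interleave across flush boundaries, which improves shuffle quality for sequentially read data.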

View File

@@ -124,20 +124,18 @@ class InputMethodEngine(nn.Module):
         moe_out = self.moe(fused)
 
         # 5. Pooling and classification: mean over the slot dimension
         #    (masked pooling, fully torch.compile compatible)
         # Use explicit reshaping to rule out any broadcasting ambiguity
         batch_size = input_ids.size(0)
         slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
         numerator = (moe_out * slot_mask).sum(dim=1)  # [batch, dim]
         denominator = slot_mask.view(batch_size, -1).sum(dim=1)  # [batch]
-        # If no slot is valid, fall back to the masked mean of the context H
-        if torch.all(denominator == 0):
-            # H: [batch, seq_len, dim], attention_mask: [batch, seq_len]
-            ctx_mask = attention_mask.float().unsqueeze(-1)  # [batch, seq_len, 1]
-            ctx_sum = (H * ctx_mask).sum(dim=1)  # [batch, dim]
-            ctx_cnt = ctx_mask.sum(dim=1) + 1e-8  # [batch, 1]
-            pooled = ctx_sum / ctx_cnt
-        else:
-            pooled = numerator / (denominator.unsqueeze(-1) + 1e-8)  # [batch, dim]
+        all_slot_mean = moe_out.mean(dim=1)  # [batch, dim]
+        all_zero = denominator == 0  # [batch]
+        pooled = torch.where(
+            all_zero.unsqueeze(-1),
+            all_slot_mean,
+            numerator / (denominator.unsqueeze(-1) + 1e-8),
+        )
         logits = self.classifier(pooled)  # [batch, vocab_size]
         return logits
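Replacing the data-dependent `if torch.all(...)` branch with `torch.where` removes Python-level control flow from the graph, which is what torch.compile needs to avoid graph breaks and per-branch recompilation. Note the fallback itself also changes: from the masked mean of the context H to the plain slot mean. A toy-shaped sketch of the pattern (values are hypothetical):

```python
import torch

moe_out = torch.randn(2, 3, 4)                  # [batch, num_slots, dim]
slot_mask = torch.tensor([[1., 1., 0.],         # sample 0: two valid slots
                          [0., 0., 0.]]).unsqueeze(-1)  # sample 1: none

numerator = (moe_out * slot_mask).sum(dim=1)    # [batch, dim]
denominator = slot_mask.sum(dim=(1, 2))         # [batch]

# Both branches are computed; torch.where selects per sample, so a sample
# with no valid slots falls back to the unmasked mean without a Python if.
pooled = torch.where(
    (denominator == 0).unsqueeze(-1),
    moe_out.mean(dim=1),
    numerator / (denominator.unsqueeze(-1) + 1e-8),
)
```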

View File

@@ -1070,7 +1070,7 @@ def train(
         max_seq_length=max_seq_len,
         text_field="text",
         py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=50000,
+        shuffle_buffer_size=100000,
         length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
     )
@@ -1093,7 +1093,7 @@ def train(
         max_seq_length=max_seq_len,
         text_field="text",
         py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=500000,
+        shuffle_buffer_size=100000,
         length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
     )
@@ -1388,7 +1388,7 @@ def expand_and_train(
         max_seq_length=max_seq_len,
         text_field="text",
         py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=5000,
+        shuffle_buffer_size=100000,
         length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
     )
@@ -1409,7 +1409,7 @@ def expand_and_train(
         max_seq_length=max_seq_len,
         text_field="text",
         py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=50000,
+        shuffle_buffer_size=100000,
         length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
     )
@@ -1711,7 +1711,7 @@ def expand_finetune(
         max_seq_length=final_max_seq_len,
         text_field="text",
         py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=5000,
+        shuffle_buffer_size=100000,
         length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
     )
@@ -1730,7 +1730,7 @@ def expand_finetune(
         max_seq_length=final_max_seq_len,
         text_field="text",
         py_style_weight=(9, 2, 1),
-        shuffle_buffer_size=50000,
+        shuffle_buffer_size=100000,
         length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
     )
@@ -1802,4 +1802,3 @@ def expand_finetune(
 
 if __name__ == "__main__":
     app()

View File

@@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
     return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
 
-part1 = "他是一名大学生,在上海读"
-part2 = "dayi"
+part1 = "明明是国庆节,可是因为月底要结账,财务部所有人都"
+part2 = "bxu"
 
 pinyin_ids = text_to_pinyin_ids(part2)
 len_py = len(pinyin_ids)
 if len_py < 24:
@@ -57,7 +57,7 @@ else:
     pinyin_ids = pinyin_ids[:24]
 
 pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
 masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
-part3 = ""
+part3 = ""
 part4 = "可行|特别|伤害"
 
 encoded = tokenizer(
@@ -83,7 +83,7 @@ sample = {
 model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
 
-checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
+checkpoint = torch.load("/home/songsenand/下载/20260411(acc34)final_model.ptrom", map_location="cpu")
 model.load_state_dict(checkpoint["model_state_dict"])
 
 input_ids = sample["input_ids"]
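A typical smoke test then runs a forward pass and inspects the top predictions. The forward signature is not shown in this diff, so the keyword arguments below are assumptions inferred from the tensors the script builds:

```python
model.eval()
with torch.no_grad():
    # Assumed signature; align with InputMethodEngine.forward() as actually defined.
    logits = model(
        input_ids=sample["input_ids"],
        attention_mask=sample["attention_mask"],
        pinyin_ids=pinyin_ids,
    )
    top5 = torch.topk(logits, k=5, dim=-1)
    print(top5.indices, top5.values)
```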