fix(tokenizer): 移除异常处理,直接加载指定tokenizer
This commit is contained in:
parent
a0e4d25b2f
commit
1cdef19153
15
eval.py
15
eval.py
|
|
@ -78,16 +78,11 @@ class TextEvaluator:
|
||||||
|
|
||||||
def load_tokenizer(self):
|
def load_tokenizer(self):
|
||||||
"""加载tokenizer"""
|
"""加载tokenizer"""
|
||||||
try:
|
tokenizer_path = (
|
||||||
tokenizer_path = (
|
Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
|
||||||
Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
|
)
|
||||||
)
|
self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
|
print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}")
|
||||||
print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠️ 无法加载tokenizer: {e}")
|
|
||||||
print("使用默认的bert-base-chinese tokenizer")
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
|
||||||
|
|
||||||
def load_query_engine(self):
|
def load_query_engine(self):
|
||||||
"""加载查询引擎用于字符-ID转换"""
|
"""加载查询引擎用于字符-ID转换"""
|
||||||
|
|
|
||||||
|
|
@ -180,14 +180,15 @@ class SlotMemory(nn.Module):
|
||||||
Args:
|
Args:
|
||||||
history_ids: [batch, total_steps]
|
history_ids: [batch, total_steps]
|
||||||
Flattened sequence of history tokens.
|
Flattened sequence of history tokens.
|
||||||
Empty positions should be filled with a special PAD or handled via mask.
|
Zero positions use start_emb as learnable query.
|
||||||
Returns:
|
Returns:
|
||||||
S: [batch, total_steps, 512] Slot sequence representation [1]
|
S: [batch, total_steps, 512] Slot sequence representation [1]
|
||||||
"""
|
"""
|
||||||
# Embed history tokens
|
S = self.emb(history_ids) # [B, num_slots, dim]
|
||||||
S = self.emb(history_ids) # [B, 24, 512]
|
|
||||||
|
zero_mask = (history_ids == 0).unsqueeze(-1).float() # [B, num_slots, 1]
|
||||||
|
S = S * (1 - zero_mask) + self.start_emb * zero_mask
|
||||||
|
|
||||||
# Add positional embeddings
|
|
||||||
pos_ids = (
|
pos_ids = (
|
||||||
torch.arange(S.size(1), device=S.device).unsqueeze(0).expand_as(history_ids)
|
torch.arange(S.size(1), device=S.device).unsqueeze(0).expand_as(history_ids)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,8 @@ class PinyinInputDataset(IterableDataset):
|
||||||
max_seq_length=128,
|
max_seq_length=128,
|
||||||
text_field: str = "text",
|
text_field: str = "text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size: int = 5000,
|
shuffle_buffer_size: int = 100000,
|
||||||
|
retention_ratio: float = 0.5,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
):
|
):
|
||||||
# 频率调整参数 (可根据需要调整)
|
# 频率调整参数 (可根据需要调整)
|
||||||
|
|
@ -64,6 +65,16 @@ class PinyinInputDataset(IterableDataset):
|
||||||
self.max_workers = max_workers
|
self.max_workers = max_workers
|
||||||
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
|
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
|
||||||
self.shuffle_buffer_size = shuffle_buffer_size
|
self.shuffle_buffer_size = shuffle_buffer_size
|
||||||
|
self.retention_ratio = retention_ratio
|
||||||
|
if not (0 < retention_ratio < 1):
|
||||||
|
raise ValueError(
|
||||||
|
f"retention_ratio必须在0和1之间,当前值: {retention_ratio}"
|
||||||
|
)
|
||||||
|
self.retention_size = int(shuffle_buffer_size * retention_ratio)
|
||||||
|
if self.retention_size <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"计算出的retention_size必须大于0,当前值: {self.retention_size} (shuffle_buffer_size={shuffle_buffer_size}, retention_ratio={retention_ratio})"
|
||||||
|
)
|
||||||
self.possible_lengths = list(length_weights.keys())
|
self.possible_lengths = list(length_weights.keys())
|
||||||
self.weights = list(length_weights.values())
|
self.weights = list(length_weights.values())
|
||||||
|
|
||||||
|
|
@ -360,17 +371,31 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 添加到缓冲区
|
# 添加到缓冲区
|
||||||
batch_samples.extend(samples)
|
batch_samples.extend(samples)
|
||||||
|
|
||||||
# 处理shuffle buffer
|
# 处理shuffle buffer - 单缓冲区半保留方案
|
||||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||||
|
# 全量打乱缓冲区
|
||||||
indices = np.random.permutation(len(batch_samples))
|
indices = np.random.permutation(len(batch_samples))
|
||||||
for idx in indices:
|
|
||||||
|
# 计算实际保留大小(不超过缓冲区大小)
|
||||||
|
actual_retention = min(self.retention_size, len(batch_samples))
|
||||||
|
|
||||||
|
# 计算输出数量
|
||||||
|
output_count = len(batch_samples) - actual_retention
|
||||||
|
|
||||||
|
# 输出前output_count个样本
|
||||||
|
for i in range(output_count):
|
||||||
if current_iter_index >= worker_quota:
|
if current_iter_index >= worker_quota:
|
||||||
# 清空batch_samples并返回
|
# 配额用完,清空缓冲区并返回
|
||||||
batch_samples = []
|
batch_samples = []
|
||||||
return # 使用return而不是break,因为我们在生成器函数中
|
return
|
||||||
yield batch_samples[idx]
|
yield batch_samples[indices[i]]
|
||||||
current_iter_index += 1
|
current_iter_index += 1
|
||||||
batch_samples = []
|
|
||||||
|
# 保留后actual_retention个样本(不清空缓冲区)
|
||||||
|
retained_samples = [
|
||||||
|
batch_samples[idx] for idx in indices[output_count:]
|
||||||
|
]
|
||||||
|
batch_samples = retained_samples
|
||||||
|
|
||||||
# 处理剩余的样本
|
# 处理剩余的样本
|
||||||
if batch_samples:
|
if batch_samples:
|
||||||
|
|
|
||||||
|
|
@ -124,20 +124,18 @@ class InputMethodEngine(nn.Module):
|
||||||
moe_out = self.moe(fused)
|
moe_out = self.moe(fused)
|
||||||
|
|
||||||
# 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile)
|
# 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile)
|
||||||
# 使用显式形状重塑,彻底杜绝广播歧义
|
|
||||||
batch_size = input_ids.size(0)
|
batch_size = input_ids.size(0)
|
||||||
slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
|
slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
|
||||||
numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
|
numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
|
||||||
denominator = slot_mask.view(batch_size, -1).sum(dim=1) # [batch]
|
denominator = slot_mask.view(batch_size, -1).sum(dim=1) # [batch]
|
||||||
# 若无有效槽位,使用上下文 H 的掩码均值
|
|
||||||
if torch.all(denominator == 0):
|
all_slot_mean = moe_out.mean(dim=1) # [batch, dim]
|
||||||
# H: [batch, seq_len, dim], attention_mask: [batch, seq_len]
|
all_zero = denominator == 0 # [batch]
|
||||||
ctx_mask = attention_mask.float().unsqueeze(-1) # [batch, seq_len, 1]
|
pooled = torch.where(
|
||||||
ctx_sum = (H * ctx_mask).sum(dim=1) # [batch, dim]
|
all_zero.unsqueeze(-1),
|
||||||
ctx_cnt = ctx_mask.sum(dim=1) + 1e-8 # [batch, 1]
|
all_slot_mean,
|
||||||
pooled = ctx_sum / ctx_cnt
|
numerator / (denominator.unsqueeze(-1) + 1e-8),
|
||||||
else:
|
)
|
||||||
pooled = numerator / (denominator.unsqueeze(-1) + 1e-8) # [batch, dim]
|
|
||||||
|
|
||||||
logits = self.classifier(pooled) # [batch, vocab_size]
|
logits = self.classifier(pooled) # [batch, vocab_size]
|
||||||
return logits
|
return logits
|
||||||
|
|
|
||||||
|
|
@ -1070,7 +1070,7 @@ def train(
|
||||||
max_seq_length=max_seq_len,
|
max_seq_length=max_seq_len,
|
||||||
text_field="text",
|
text_field="text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size=50000,
|
shuffle_buffer_size=100000,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1093,7 +1093,7 @@ def train(
|
||||||
max_seq_length=max_seq_len,
|
max_seq_length=max_seq_len,
|
||||||
text_field="text",
|
text_field="text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size=500000,
|
shuffle_buffer_size=100000,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1388,7 +1388,7 @@ def expand_and_train(
|
||||||
max_seq_length=max_seq_len,
|
max_seq_length=max_seq_len,
|
||||||
text_field="text",
|
text_field="text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size=5000,
|
shuffle_buffer_size=100000,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1409,7 +1409,7 @@ def expand_and_train(
|
||||||
max_seq_length=max_seq_len,
|
max_seq_length=max_seq_len,
|
||||||
text_field="text",
|
text_field="text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size=50000,
|
shuffle_buffer_size=100000,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1711,7 +1711,7 @@ def expand_finetune(
|
||||||
max_seq_length=final_max_seq_len,
|
max_seq_length=final_max_seq_len,
|
||||||
text_field="text",
|
text_field="text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size=5000,
|
shuffle_buffer_size=100000,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1730,7 +1730,7 @@ def expand_finetune(
|
||||||
max_seq_length=final_max_seq_len,
|
max_seq_length=final_max_seq_len,
|
||||||
text_field="text",
|
text_field="text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size=50000,
|
shuffle_buffer_size=100000,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1802,4 +1802,3 @@ def expand_finetune(
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app()
|
app()
|
||||||
|
|
||||||
|
|
|
||||||
8
test.py
8
test.py
|
|
@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
||||||
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
||||||
|
|
||||||
|
|
||||||
part1 = "他是一名大学生,在上海读"
|
part1 = "明明是国庆节,可是因为月底要结账,财务部所有人都"
|
||||||
part2 = "dayi"
|
part2 = "bxu"
|
||||||
pinyin_ids = text_to_pinyin_ids(part2)
|
pinyin_ids = text_to_pinyin_ids(part2)
|
||||||
len_py = len(pinyin_ids)
|
len_py = len(pinyin_ids)
|
||||||
if len_py < 24:
|
if len_py < 24:
|
||||||
|
|
@ -57,7 +57,7 @@ else:
|
||||||
pinyin_ids = pinyin_ids[:24]
|
pinyin_ids = pinyin_ids[:24]
|
||||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
||||||
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
|
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
|
||||||
part3 = "。"
|
part3 = ""
|
||||||
part4 = "可行|特别|伤害"
|
part4 = "可行|特别|伤害"
|
||||||
|
|
||||||
encoded = tokenizer(
|
encoded = tokenizer(
|
||||||
|
|
@ -83,7 +83,7 @@ sample = {
|
||||||
|
|
||||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||||
|
|
||||||
checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu")
|
checkpoint = torch.load("/home/songsenand/下载/20260411(acc34)final_model.ptrom", map_location="cpu")
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
|
||||||
input_ids = sample["input_ids"]
|
input_ids = sample["input_ids"]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue