feat(pinyin): add pinyin LSTM encoder and sequential-confirmation evaluation for multi-character prediction
parent 68a6fc3533
commit bb78e0afa0

 eval.py | 258
@@ -274,7 +274,7 @@ class TextEvaluator:
            part1 = text[i - 48 : i]

        # part2: 拼音输入(随机长度1-8,高斯分布)
-        pinyin_len_probs = [0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
+        pinyin_len_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
        pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
        py_end = min(i + pinyin_len, len(text))

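Note: a quick sanity check on the new pinyin-length distribution (probabilities copied from the diff above; the check itself is only illustrative). Both vectors sum to 1, and the change shifts the expected pinyin input length from roughly 3.4 to roughly 3.8 characters:

import numpy as np

old_probs = [0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
new_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
lengths = np.arange(1, 9)

print(sum(old_probs), sum(new_probs))   # 1.0 1.0
print((lengths * old_probs).sum())      # ~3.43 expected length before
print((lengths * new_probs).sum())      # ~3.82 expected length after

# Sampling as done in eval.py and the dataset code:
pinyin_len = np.random.choice(lengths, p=new_probs)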
@@ -332,53 +332,30 @@ class TextEvaluator:
            print(f"⚠️ 获取标签失败: {e}")
            labels = [0] * pinyin_len_actual

-        # 历史槽位:当前预测字符之前已确认的字符
-        # 从位置i之前的字符中提取历史槽位(最多8个)
-        # 注意:这里的逻辑与训练数据一致 - 从已确认的字符中提取
+        # 历史槽位:模拟用户逐步确认过程
+        # 对于多字符预测(pinyin_len_actual > 1),需要模拟用户逐步选择
+        # 注意:评估时不知道正确答案,只能基于模型预测来构建历史
        history_slot_ids = []
-        if self.query_engine:
-            # 收集i之前的汉字字符(最多8个)
-            for j in range(i - 1, max(-1, i - 100), -1):
-                if j < 0:
-                    break
-                char = text[j]
-                if self.query_engine.is_chinese_char(char):
-                    # 尝试获取该字符的ID(与训练数据一致)
-                    results = self.query_engine.query_by_char(char, limit=1)
-                    if results:
-                        slot_id = results[0][0]
-                        # 只添加有效的非零ID(与训练数据一致)
-                        if slot_id > 0:
-                            history_slot_ids.append(slot_id)
-                            if len(history_slot_ids) >= 8:
-                                break
-
-        # 填充到8个槽位(使用0填充,与训练数据一致)
+        # 如果强制指定槽位数量,使用模拟的历史(与训练数据分布一致)
+        if force_slot_count is not None:
+            force_slot_count = max(0, min(8, force_slot_count))
+            # 创建模拟的历史槽位:前force_slot_count个为有效槽位,其余为0
+            # 这里使用简单的模拟:有效槽位用1-1000的随机ID(与训练时随机填充逻辑一致)
+            for _ in range(force_slot_count):
+                history_slot_ids.append(random.randint(1, 1000))
+        else:
+            # 正常评估:历史槽位为空(模拟从零开始输入)
+            # 注意:实际使用时,历史槽位应该来自之前的用户选择
+            # 但为了评估公平,我们从空历史开始
+            pass
+
+        # 填充到8个槽位
        if len(history_slot_ids) < 8:
            history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
        else:
            history_slot_ids = history_slot_ids[:8]

-        # 强制设置有效槽位数量(用于按槽位数量评估)
-        # 注意:这里不再使用随机ID填充,而是复制现有的有效槽位
-        if force_slot_count is not None:
-            force_slot_count = max(0, min(8, force_slot_count))
-            # 获取所有非零槽位ID
-            valid_ids = [s for s in history_slot_ids if s != 0]
-
-            if len(valid_ids) >= force_slot_count:
-                # 如果有效槽位足够,直接取前force_slot_count个
-                history_slot_ids = valid_ids[:force_slot_count]
-            else:
-                # 如果有效槽位不足,保留现有有效槽位,然后用0填充
-                # 这样更符合训练数据的分布(槽位数量由实际历史决定)
-                history_slot_ids = valid_ids[:]
-
-            # 填充到8个槽位
-            while len(history_slot_ids) < 8:
-                history_slot_ids.append(0)
-            history_slot_ids = history_slot_ids[:8]
-
        # Tokenize输入
        encoded = self.tokenizer(
            f"{part4}|{part1}",

@@ -586,6 +563,115 @@ class TextEvaluator:

        return result

+    def evaluate_sample_with_sequential_confirmation(
+        self, sample: Dict, print_details: bool = True
+    ) -> Dict:
+        """
+        评估样本并模拟用户逐步确认过程(用于多字符预测)。
+
+        对于拼音长度>1的情况,模拟用户逐步选择Top-1预测作为已确认字符。
+
+        Returns:
+            评估结果字典
+        """
+        # 获取真实标签和拼音长度
+        true_labels = sample.get("true_labels", [])
+        pinyin_len = len(true_labels) if true_labels else 1
+
+        if pinyin_len <= 1:
+            # 单字符预测,直接使用普通评估
+            return self.evaluate_sample(sample, print_details)
+
+        # 多字符预测:模拟逐步确认
+        confirmed_history = []  # 用户已确认的字符ID
+        all_predictions = []  # 每一步的预测结果
+        all_correct = []  # 每一步是否正确
+
+        # 复制样本用于逐步推理
+        sample_copy = sample.copy()
+
+        for step in range(pinyin_len):
+            # 更新历史槽位:使用已确认的字符
+            history_slot_ids = confirmed_history[:]
+            if len(history_slot_ids) < 8:
+                history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
+            else:
+                history_slot_ids = history_slot_ids[:8]
+
+            # 更新样本的历史槽位
+            sample_copy["history_slot_ids"] = torch.tensor(
+                history_slot_ids, dtype=torch.long
+            ).unsqueeze(0)
+
+            # 执行推理
+            logits, probs = self.inference(sample_copy)
+            analysis = self.analyze_probability_distribution(probs)
+
+            # 获取Top-1预测
+            pred_top_idx = analysis["top_indices"][0]
+            pred_top_prob = analysis["top_probs"][0]
+
+            # 检查是否正确
+            correct = False
+            if step < len(true_labels):
+                correct = pred_top_idx == true_labels[step]
+
+            # 记录结果
+            all_predictions.append(
+                {
+                    "predicted_id": pred_top_idx,
+                    "probability": pred_top_prob,
+                    "correct": correct,
+                }
+            )
+            all_correct.append(correct)
+
+            # 模拟用户选择:将Top-1预测添加到已确认历史
+            # 注意:这里假设用户总是选择Top-1结果
+            if pred_top_idx > 0:  # 只添加有效ID
+                confirmed_history.append(pred_top_idx)
+
+        # 计算整体准确率:所有步骤都正确才算正确
+        overall_correct = all(all_correct)
+
+        # 打印结果
+        if print_details:
+            target_char = sample.get("target_char", "")
+            pinyin = sample.get("pinyin", "")
+            valid_slot_count = sample.get("valid_slot_count", 0)
+
+            print(f" {target_char}", end="")
+            print(f" | 拼音:{pinyin[:10]}{'...' if len(pinyin) > 10 else ''}", end="")
+            print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_len}", end="")
+
+            # 显示逐步预测结果
+            step_results = []
+            for i, pred in enumerate(all_predictions):
+                char_info = (
+                    self.query_engine.query_by_id(pred["predicted_id"])
+                    if self.query_engine
+                    else None
+                )
+                char = char_info.char if char_info else f"[{pred['predicted_id']}]"
+                correct_mark = "✓" if pred["correct"] else "✗"
+                step_results.append(f"{char}{correct_mark}")
+
+            print(f" | 逐步:{'→'.join(step_results)}", end="")
+            print(f" | 结果:{'✓' if overall_correct else '✗'}")
+
+        # 返回评估结果
+        result = {
+            "sample": sample,
+            "sequential_predictions": all_predictions,
+            "correct": overall_correct,
+            "step_correct": all_correct,
+            "target_char": sample.get("target_char", ""),
+            "true_labels": true_labels,
+            "pinyin_len": pinyin_len,
+        }
+
+        return result
+
    def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]:
        """
        评估给定文本,生成多个样本进行评估。

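Note: a minimal usage sketch of the new sequential-confirmation evaluation. It assumes an already-constructed TextEvaluator instance named `evaluator`; the sample text is made up, and only the method names mirror the diff above:

# Hypothetical driver for one sampled position, evaluated step by step.
text = "今天天气很好,我们一起去公园散步"
sample = evaluator.create_sample_from_text(text)
result = evaluator.evaluate_sample_with_sequential_confirmation(
    sample, print_details=True
)

# A sample counts as correct only if every step's Top-1 prediction is correct.
print(result["correct"], result["step_correct"], result["pinyin_len"])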
@@ -613,17 +699,11 @@ class TextEvaluator:

            try:
                sample = self.create_sample_from_text(text)
-                result = self.evaluate_sample(sample, print_details=True)
-                results.append(result)
-
-                if result["analysis"]["low_variance"]:
-                    low_variance_samples.append(
-                        {
-                            "sample_idx": sample_idx,
-                            "max_prob": result["analysis"]["max_prob"],
-                            "entropy": result["analysis"]["entropy"],
-                        }
-                    )
+                # 使用逐步确认评估
+                result = self.evaluate_sample_with_sequential_confirmation(
+                    sample, print_details=True
+                )
+                results.append(result)

            except Exception as e:
                print(f"❌ 样本生成或评估失败: {e}")

@@ -670,10 +750,12 @@ class TextEvaluator:
                sample = self.create_sample_from_text(
                    text, force_slot_count=slot_count
                )
-                result = self.evaluate_sample(sample, print_details=False)
+                # 使用逐步确认评估(更符合实际使用场景)
+                result = self.evaluate_sample_with_sequential_confirmation(
+                    sample, print_details=False
+                )
                results.append(result)
-                if result["analysis"]["low_variance"]:
-                    low_var_count += 1
+                # 注意:逐步确认评估没有low_variance分析
            except Exception:
                failed += 1
                continue

@@ -681,19 +763,24 @@ class TextEvaluator:
            if results:
                correct_count = sum(1 for r in results if r["correct"])
                accuracy = correct_count / len(results)
-                max_probs = [r["analysis"]["max_prob"] for r in results]
-                mean_max_prob = np.mean(max_probs)
-                entropies = [r["analysis"]["entropy"] for r in results]
-                mean_entropy = np.mean(entropies)
+
+                # 对于逐步确认评估,计算平均步数准确率
+                step_accuracies = []
+                for r in results:
+                    if "step_correct" in r:
+                        step_correct = r["step_correct"]
+                        if step_correct:
+                            step_acc = sum(step_correct) / len(step_correct)
+                            step_accuracies.append(step_acc)
+
+                mean_step_accuracy = np.mean(step_accuracies) if step_accuracies else 0
+
                slot_results[slot_count] = {
                    "results": results,
                    "accuracy": accuracy,
                    "correct": correct_count,
                    "total": len(results),
-                    "mean_max_prob": mean_max_prob,
-                    "mean_entropy": mean_entropy,
-                    "low_variance_count": low_var_count,
+                    "mean_step_accuracy": mean_step_accuracy,
                    "failed": failed,
                }

@@ -703,9 +790,7 @@ class TextEvaluator:
                print(
                    f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} "
                    f"({correct_count:2d}/{len(results):2d}) "
-                    f"| 平均最高概率: {mean_max_prob:.4f} "
-                    f"| 平均熵: {mean_entropy:.2f} "
-                    f"| 低方差: {low_var_count}"
+                    f"| 平均步数准确率: {mean_step_accuracy:.1%}"
                )
            else:
                slot_results[slot_count] = {

@@ -713,9 +798,7 @@ class TextEvaluator:
                    "accuracy": 0,
                    "correct": 0,
                    "total": 0,
-                    "mean_max_prob": 0,
-                    "mean_entropy": 0,
-                    "low_variance_count": 0,
+                    "mean_step_accuracy": 0,
                    "failed": failed,
                }
                print(f" 槽位 {slot_count}/8 | 全部样本生成失败")

@@ -740,44 +823,27 @@ class TextEvaluator:
        accuracy = correct_count / len(results) if results else 0
        print(f"准确率: {accuracy:.1%} ({correct_count}/{len(results)})")

-        # 概率统计
-        max_probs = [r["analysis"]["max_prob"] for r in results]
-        mean_max_prob = np.mean(max_probs) if max_probs else 0
-        print(f"平均最高概率: {mean_max_prob:.4f}")
-
-        # 预测ID:0的比例
-        zero_pred_count = sum(1 for r in results if r.get("predicted_label") == 0)
-        zero_pred_ratio = zero_pred_count / len(results) if results else 0
-        print(f"预测[空]比例: {zero_pred_ratio:.1%} ({zero_pred_count}/{len(results)})")
-
-        # 平均熵
-        entropies = [r["analysis"]["entropy"] for r in results]
-        mean_entropy = np.mean(entropies) if entropies else 0
-        print(f"平均熵: {mean_entropy:.2f}")
+        # 计算平均步数准确率(对于多字符预测)
+        step_accuracies = []
+        for r in results:
+            if "step_correct" in r:
+                step_correct = r["step_correct"]
+                if step_correct:
+                    step_acc = sum(step_correct) / len(step_correct)
+                    step_accuracies.append(step_acc)
+
+        if step_accuracies:
+            mean_step_accuracy = np.mean(step_accuracies)
+            print(f"平均步数准确率: {mean_step_accuracy:.1%}")

        # 槽位和拼音长度统计
        if results and "sample" in results[0]:
            valid_slots = [r["sample"].get("valid_slot_count", 0) for r in results]
-            pinyin_lengths = [
-                r["sample"].get("pinyin_input_length", 0) for r in results
-            ]
+            pinyin_lengths = [r.get("pinyin_len", 1) for r in results]
            avg_slots = np.mean(valid_slots) if valid_slots else 0
            avg_pinyin = np.mean(pinyin_lengths) if pinyin_lengths else 0
            print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}")
-
-        # 低方差样本
-        if low_variance_samples:
-            print(f"低方差样本: {len(low_variance_samples)}个")
-            if len(low_variance_samples) <= 5:
-                for lv in low_variance_samples:
-                    print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}")
-            else:
-                for lv in low_variance_samples[:3]:
-                    print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}")
-                print(f" ... 还有{len(low_variance_samples) - 3}个")
-        else:
-            print(f"低方差样本: 0个")


 def main():
    parser = argparse.ArgumentParser(description="评估模型在文本上的表现")

@@ -26,6 +26,62 @@ class AttentionPooling(nn.Module):
        return pooled


+# ---------------------------- 拼音LSTM编码器 ----------------------------
+class PinyinLSTMEncoder(nn.Module):
+    def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
+        super().__init__()
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim if hidden_dim is not None else input_dim // 2
+        self.num_layers = num_layers
+        self.dropout = dropout
+
+        # Bidirectional LSTM
+        self.lstm = nn.LSTM(
+            input_size=input_dim,
+            hidden_size=self.hidden_dim,
+            num_layers=num_layers,
+            bidirectional=True,
+            batch_first=True,
+            dropout=dropout if num_layers > 1 else 0.0,
+        )
+
+        # Project concatenated hidden states to input_dim
+        self.proj = nn.Linear(self.hidden_dim * 2, input_dim)
+        self.layer_norm = nn.LayerNorm(input_dim)
+
+    def forward(self, x, mask=None):
+        """
+        Args:
+            x: [batch, seq_len, input_dim] pinyin embeddings
+            mask: [batch, seq_len] optional padding mask (0 for padding)
+        Returns:
+            pooled: [batch, input_dim] global pinyin representation
+        """
+        if mask is not None:
+            # lengths for pack_padded_sequence
+            lengths = mask.sum(dim=1).cpu()
+            # pack sequence
+            packed = nn.utils.rnn.pack_padded_sequence(
+                x, lengths, batch_first=True, enforce_sorted=False
+            )
+            packed_out, (hidden, cell) = self.lstm(packed)
+            # hidden shape: [num_layers * 2, batch, hidden_dim]
+            # Take last layer's forward and backward hidden states
+            forward_hidden = hidden[-2, :, :]  # last layer forward
+            backward_hidden = hidden[-1, :, :]  # last layer backward
+            hidden_concat = torch.cat([forward_hidden, backward_hidden], dim=1)
+        else:
+            # No mask, assume all sequences same length
+            output, (hidden, cell) = self.lstm(x)
+            # hidden shape: [num_layers * 2, batch, hidden_dim]
+            forward_hidden = hidden[-2, :, :]
+            backward_hidden = hidden[-1, :, :]
+            hidden_concat = torch.cat([forward_hidden, backward_hidden], dim=1)
+
+        projected = self.proj(hidden_concat)
+        return self.layer_norm(projected)
+
+
 # ---------------------------- 残差块 ----------------------------
 class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout_prob=0.3):

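Note: a minimal standalone sketch of how the new PinyinLSTMEncoder is exercised, assuming the class from the diff above is importable; the batch contents and vocabulary size are made up. One caveat worth flagging: pack_padded_sequence requires every length to be at least 1, so a row whose mask is all padding would fail.

import torch
import torch.nn as nn

dim = 512
encoder = PinyinLSTMEncoder(input_dim=dim)        # hidden_dim defaults to dim // 2
pinyin_ids = torch.tensor([[5, 12, 9, 0, 0, 0],   # toy batch, 0 = padding
                           [3, 7, 0, 0, 0, 0]])
emb = nn.Embedding(100, dim)(pinyin_ids)          # [2, 6, 512] pinyin embeddings
mask = pinyin_ids != 0                            # [2, 6], True on real tokens
pooled = encoder(emb, mask=mask)                  # [2, 512] global pinyin feature
print(pooled.shape)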
@@ -96,7 +152,7 @@ class ContextEncoder(nn.Module):
        ).embeddings
        self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
        self.pos_emb = nn.Embedding(max_len, dim)
-        self.pinyin_pooling = AttentionPooling(dim)
+        self.pinyin_pooling = PinyinLSTMEncoder(dim)

        # Transformer Encoder (4 layers, 4 heads) [1]
        encoder_layer = nn.TransformerEncoderLayer(

@@ -125,9 +181,10 @@ class ContextEncoder(nn.Module):

        # 2. Embed and pool pinyin to global feature
        pinyin_emb = self.pinyin_emb(pinyin_ids)  # [B, 24, dim]
-        # 方式1:Attention Pooling(推荐)
+        # LSTM encoder with masking for padding
+        pinyin_mask = pinyin_ids != 0
        pinyin_global = self.pinyin_pooling(
-            pinyin_emb, mask=None
+            pinyin_emb, mask=pinyin_mask
        )  # [B, dim]  # 1. Embedding Fusion: Text + Pinyin + Position

        # Broadcast pinyin to all text positions

@@ -279,7 +336,7 @@ class CrossAttentionFusion(nn.Module):
 # 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类
 # ------------------------------------------------------------------
 class MoELayer(nn.Module):
-    def __init__(self, dim=512, num_experts=20, top_k=2, export_resblocks=4):
+    def __init__(self, dim=512, num_experts=10, top_k=3, num_resblocks=8):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

@@ -292,7 +349,7 @@ class MoELayer(nn.Module):
                Expert(
                    input_dim=dim,
                    d_model=dim,
-                    num_resblocks=export_resblocks,
+                    num_resblocks=num_resblocks,
                    output_multiplier=1,
                )
                for _ in range(num_experts)

@@ -260,13 +260,29 @@ class PinyinInputDataset(IterableDataset):
                part1 = text[0:i]
            else:
                part1 = text[i - 48 : i]

+            # 方案C:提前检查从位置i开始连续有多少个字符在词库中
+            max_valid_len = 0
+            for j in range(i, min(i + 8, len(text))):
+                if self.query_engine.is_chinese_char(text[j]):
+                    max_valid_len += 1
+                else:
+                    break
+
+            # 如果没有可用字符,跳过
+            if max_valid_len == 0:
+                continue
+
            # 首先取随机值pinyin_len(1-8),pinyin_len取值呈高斯分布,最大概率取3
            # 获取text[i + pinyin_len]字符,如果无法获取所指向的后,如果pinyin_len
            # part2的长度为x,取pinyin_list[i:i+pinyin_len],为part2
            # 但是需要注意边界条件
-            pinyin_len = np.random.choice(
-                range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
+            target_len = np.random.choice(
+                range(1, 9), p=[0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
            )
+            # 根据实际可用长度调整
+            pinyin_len = min(target_len, max_valid_len)
+
            py_end = min(i + pinyin_len, len(text))
            pinyin_len, part2 = self.get_mask_pinyin(
                text[i:py_end], pinyin_list[i:py_end]

@@ -296,13 +312,13 @@ class PinyinInputDataset(IterableDataset):
                + np.random.choice(range(1, 17))
            ]

-            # part4为文本,0.50的概率为空
+            # part4为文本,0.30的概率为空
            # 不为空则为1-5个连续字符串
            # 连续字符串的取值方法为:随机从字符库中取一个字符,以及该字符后x个字符
            # x为2-6中的任意整数,取值平均分布
            # 使用|将part4中的字符串连接起来
            part4 = ""
-            if random.random() > 0.5:
+            if random.random() > 0.7:
                # 生成1-5个连续字符串
                num_strings = random.randint(1, 5)
                string_list = []

@@ -332,22 +348,6 @@ class PinyinInputDataset(IterableDataset):
            if random.random() <= 0.1:
                labels.append(0)

-            # 提取历史槽位:从预测位置i之前的字符中获取(与eval.py一致)
-            history_slot_list = []
-            for j in range(i - 1, max(-1, i - 100), -1):
-                if j < 0:
-                    break
-                char = text[j]
-                if self.query_engine.is_chinese_char(char):
-                    try:
-                        results = self.query_engine.query_by_char(char, limit=1)
-                        if results and results[0][0] > 0:
-                            history_slot_list.append(results[0][0])
-                    except Exception:
-                        pass
-                    if len(history_slot_list) >= 8:
-                        break
-
            encoded = self.tokenizer(
                f"{part4}|{part1}",
                part3,

@@ -358,11 +358,23 @@ class PinyinInputDataset(IterableDataset):
                return_token_type_ids=True,
            )
            samples = []
+            # 历史槽位长度权重:增加长历史采样比例
+            # 目标分布: H=0-2占45%, H=3-8占55%
+            history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
+
            # 修复变量名冲突:将内层循环变量i重命名为label_idx
            for label_idx, label in enumerate(labels):
-                repeats = self.adjust_frequency(label)
-                # 使用从text[0:i]提取的历史槽位(与eval.py一致)
-                masked_labels = history_slot_list[:]
+                base_repeats = self.adjust_frequency(label)
+                # 根据历史槽位长度调整采样次数
+                weight = (
+                    history_weights[label_idx]
+                    if label_idx < len(history_weights)
+                    else 3.0
+                )
+                repeats = max(1, int(base_repeats * weight))
+
+                # 历史槽位:同一拼音序列中已确认的字符(模拟用户逐步确认过程)
+                masked_labels = labels[:label_idx]
                len_l = len(masked_labels)
                masked_labels.extend([0] * (8 - len_l))

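Note: a small sketch of how the per-label repeat count now scales with the confirmed-history length. The weights are copied from the diff above; the base repeat value stands in for whatever self.adjust_frequency(label) returns and is purely hypothetical.

# Index = number of already-confirmed characters in the same pinyin sequence.
history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]

base_repeats = 4  # hypothetical adjust_frequency output
for label_idx in range(9):
    weight = history_weights[label_idx] if label_idx < len(history_weights) else 3.0
    repeats = max(1, int(base_repeats * weight))
    print(label_idx, repeats)  # e.g. 0 -> 1, 3 -> 3, 8 -> 16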
@@ -38,7 +38,7 @@ class InputMethodEngine(nn.Module):
        num_slots: int = 8,  # 历史槽位数量 (对应 README 中的 8 个槽位)
        n_layers: int = 4,  # Transformer 层数
        n_heads: int = 4,  # 注意力头数
-        num_experts: int = 20,  # MoE 专家数量
+        num_experts: int = 10,  # MoE 专家数量
        max_seq_len: int = 128,  # 最大上下文长度
        compile: bool = False,  # 是否开启 torch.compile 优化
    ):

@@ -72,9 +72,12 @@ class InputMethodEngine(nn.Module):
        self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)

        # 4. 混合专家层 (MoE)
-        self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=2)
+        self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8)

-        # 5. 分类头
+        # 5. 槽位注意力池化
+        self.slot_attention = nn.Linear(dim, 1)
+
+        # 6. 分类头
        self.classifier = nn.Linear(dim, vocab_size)

        # 开启 torch.compile 优化 (如果请求)

@@ -127,19 +130,14 @@ class InputMethodEngine(nn.Module):
        # 4. MoE 处理 -> [batch, num_slots, dim]
        moe_out = self.moe(fused)

-        # 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile)
+        # 5. 槽位注意力池化
        batch_size = input_ids.size(0)
-        slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
-        numerator = (moe_out * slot_mask).sum(dim=1)  # [batch, dim]
-        denominator = slot_mask.view(batch_size, -1).sum(dim=1)  # [batch]
-        all_slot_mean = moe_out.mean(dim=1)  # [batch, dim]
-        all_zero = denominator == 0  # [batch]
-        pooled = torch.where(
-            all_zero.unsqueeze(-1),
-            all_slot_mean,
-            numerator / (denominator.unsqueeze(-1) + 1e-8),
-        )
+        # 计算注意力分数 [batch, num_slots, 1] -> [batch, num_slots]
+        slot_scores = self.slot_attention(moe_out).squeeze(-1)
+        # 应用softmax获取注意力权重
+        slot_weights = torch.softmax(slot_scores, dim=1)  # [batch, num_slots]
+        # 加权求和得到池化表示
+        pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)  # [batch, dim]

        logits = self.classifier(pooled)  # [batch, vocab_size]
        return logits

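Note: the mask-based mean over non-empty slots is replaced by learned attention pooling. A self-contained sketch of just that pooling step, with shapes as in the comments above and random stand-in tensors:

import torch
import torch.nn as nn

batch, num_slots, dim = 4, 8, 512
moe_out = torch.randn(batch, num_slots, dim)      # stand-in for the MoE output
slot_attention = nn.Linear(dim, 1)                # same module added in __init__

slot_scores = slot_attention(moe_out).squeeze(-1)             # [batch, num_slots]
slot_weights = torch.softmax(slot_scores, dim=1)              # weights sum to 1 per sample
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1)    # [batch, dim]
print(pooled.shape)                                           # torch.Size([4, 512])

Unlike the old masked mean, the softmax here is not masked, so zero-ID (padded) slots also receive some weight; the model has to learn to down-weight them through slot_attention.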
@@ -56,11 +56,11 @@ class Trainer:
        total_steps: int,
        output_dir: str = "./output",
        num_epochs: int = 10,
-        learning_rate: float = 1e-4,
+        learning_rate: float = 2e-4,
        min_learning_rate: float = 1e-6,
-        weight_decay: float = 0.1,
+        weight_decay: float = 0.05,
        warmup_ratio: float = 0.1,
-        label_smoothing: float = 0.15,
+        label_smoothing: float = 0.1,
        loss_weight: Optional[torch.Tensor] = None,
        grad_accum_steps: int = 1,
        clip_grad_norm: float = 1.0,

@@ -1060,14 +1060,14 @@ def train(
    # 训练参数
    batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
    num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
-    learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
+    learning_rate: float = typer.Option(2e-4, "--learning-rate", "-lr", help="学习率"),
    min_learning_rate: float = typer.Option(
        1e-9, "--min-learning-rate", help="最小学习率"
    ),
-    weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
+    weight_decay: float = typer.Option(0.05, "--weight-decay", help="权重衰减"),
    warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
    label_smoothing: float = typer.Option(
-        0.15, "--label-smoothing", help="标签平滑参数"
+        0.1, "--label-smoothing", help="标签平滑参数"
    ),
    grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
    clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),

@@ -1120,7 +1120,7 @@ def train(
    num_slots = 8
    n_layers = 4
    n_heads = 4
-    num_experts = 20
+    num_experts = 10
    max_seq_len = 128
    use_pinyin = True  # 始终使用拼音

@@ -1383,14 +1383,14 @@ def expand_and_train(
    # 训练参数
    batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
    num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
-    learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
+    learning_rate: float = typer.Option(2e-4, "--learning-rate", "-lr", help="学习率"),
    min_learning_rate: float = typer.Option(
        1e-9, "--min-learning-rate", help="最小学习率"
    ),
-    weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
+    weight_decay: float = typer.Option(0.05, "--weight-decay", help="权重衰减"),
    warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
    label_smoothing: float = typer.Option(
-        0.15, "--label-smoothing", help="标签平滑参数"
+        0.1, "--label-smoothing", help="标签平滑参数"
    ),
    grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
    clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),

@@ -1440,7 +1440,7 @@ def expand_and_train(
    num_slots = 8
    n_layers = 4
    n_heads = 4
-    num_experts = 20
+    num_experts = 10
    max_seq_len = 128
    use_pinyin = True  # 始终使用拼音
    console = Console()