feat(pinyin): 添加拼音LSTM编码器以支持多字符预测的逐步确认评估
This commit is contained in:
parent
68a6fc3533
commit
bb78e0afa0
258
eval.py
258
eval.py
|
|
@ -274,7 +274,7 @@ class TextEvaluator:
|
|||
part1 = text[i - 48 : i]
|
||||
|
||||
# part2: 拼音输入(随机长度1-8,高斯分布)
|
||||
pinyin_len_probs = [0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
|
||||
pinyin_len_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
|
||||
pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
|
||||
py_end = min(i + pinyin_len, len(text))
|
||||
|
||||
|
|
@ -332,53 +332,30 @@ class TextEvaluator:
|
|||
print(f"⚠️ 获取标签失败: {e}")
|
||||
labels = [0] * pinyin_len_actual
|
||||
|
||||
# 历史槽位:当前预测字符之前已确认的字符
|
||||
# 从位置i之前的字符中提取历史槽位(最多8个)
|
||||
# 注意:这里的逻辑与训练数据一致 - 从已确认的字符中提取
|
||||
# 历史槽位:模拟用户逐步确认过程
|
||||
# 对于多字符预测(pinyin_len_actual > 1),需要模拟用户逐步选择
|
||||
# 注意:评估时不知道正确答案,只能基于模型预测来构建历史
|
||||
history_slot_ids = []
|
||||
if self.query_engine:
|
||||
# 收集i之前的汉字字符(最多8个)
|
||||
for j in range(i - 1, max(-1, i - 100), -1):
|
||||
if j < 0:
|
||||
break
|
||||
char = text[j]
|
||||
if self.query_engine.is_chinese_char(char):
|
||||
# 尝试获取该字符的ID(与训练数据一致)
|
||||
results = self.query_engine.query_by_char(char, limit=1)
|
||||
if results:
|
||||
slot_id = results[0][0]
|
||||
# 只添加有效的非零ID(与训练数据一致)
|
||||
if slot_id > 0:
|
||||
history_slot_ids.append(slot_id)
|
||||
if len(history_slot_ids) >= 8:
|
||||
break
|
||||
|
||||
# 填充到8个槽位(使用0填充,与训练数据一致)
|
||||
# 如果强制指定槽位数量,使用模拟的历史(与训练数据分布一致)
|
||||
if force_slot_count is not None:
|
||||
force_slot_count = max(0, min(8, force_slot_count))
|
||||
# 创建模拟的历史槽位:前force_slot_count个为有效槽位,其余为0
|
||||
# 这里使用简单的模拟:有效槽位用1-1000的随机ID(与训练时随机填充逻辑一致)
|
||||
for _ in range(force_slot_count):
|
||||
history_slot_ids.append(random.randint(1, 1000))
|
||||
else:
|
||||
# 正常评估:历史槽位为空(模拟从零开始输入)
|
||||
# 注意:实际使用时,历史槽位应该来自之前的用户选择
|
||||
# 但为了评估公平,我们从空历史开始
|
||||
pass
|
||||
|
||||
# 填充到8个槽位
|
||||
if len(history_slot_ids) < 8:
|
||||
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
|
||||
else:
|
||||
history_slot_ids = history_slot_ids[:8]
|
||||
|
||||
# 强制设置有效槽位数量(用于按槽位数量评估)
|
||||
# 注意:这里不再使用随机ID填充,而是复制现有的有效槽位
|
||||
if force_slot_count is not None:
|
||||
force_slot_count = max(0, min(8, force_slot_count))
|
||||
# 获取所有非零槽位ID
|
||||
valid_ids = [s for s in history_slot_ids if s != 0]
|
||||
|
||||
if len(valid_ids) >= force_slot_count:
|
||||
# 如果有效槽位足够,直接取前force_slot_count个
|
||||
history_slot_ids = valid_ids[:force_slot_count]
|
||||
else:
|
||||
# 如果有效槽位不足,保留现有有效槽位,然后用0填充
|
||||
# 这样更符合训练数据的分布(槽位数量由实际历史决定)
|
||||
history_slot_ids = valid_ids[:]
|
||||
|
||||
# 填充到8个槽位
|
||||
while len(history_slot_ids) < 8:
|
||||
history_slot_ids.append(0)
|
||||
history_slot_ids = history_slot_ids[:8]
|
||||
|
||||
# Tokenize输入
|
||||
encoded = self.tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
|
|
@ -586,6 +563,115 @@ class TextEvaluator:
|
|||
|
||||
return result
|
||||
|
||||
def evaluate_sample_with_sequential_confirmation(
    self, sample: Dict, print_details: bool = True
) -> Dict:
    """
    Evaluate a sample while simulating step-by-step user confirmation.

    For a sample whose ``true_labels`` holds more than one entry, the model
    is run once per label. Before each run the history slots are rebuilt
    from the Top-1 predictions of the previous steps (as if the user had
    always accepted the first candidate), zero-padded to 8 slots. Any
    ``history_slot_ids`` already present in ``sample`` is replaced in a
    shallow copy; the caller's dict is not modified.

    Samples with zero or one label are delegated unchanged to
    ``self.evaluate_sample``.

    Args:
        sample: Sample dict. Reads ``true_labels`` and, for printing and
            the result, ``target_char``, ``pinyin`` and ``valid_slot_count``.
        print_details: When True, print a one-line summary of every step.

    Returns:
        For multi-label samples, a dict with the keys ``sample``,
        ``sequential_predictions``, ``correct`` (True only if every step
        was correct), ``step_correct``, ``target_char``, ``true_labels``
        and ``pinyin_len``. For single-label samples, whatever
        ``self.evaluate_sample`` returns.
    """
    true_labels = sample.get("true_labels", [])
    # Use an explicit length check: truthiness of a multi-element
    # array/tensor is ambiguous and would raise.
    has_labels = true_labels is not None and len(true_labels) > 0
    pinyin_len = len(true_labels) if has_labels else 1

    if pinyin_len <= 1:
        # Single-character prediction: the regular evaluation is enough.
        return self.evaluate_sample(sample, print_details)

    confirmed_history = []  # IDs the simulated user has confirmed so far
    all_predictions = []  # one record per step
    all_correct = []  # one plain bool per step

    # Shallow copy so that only the history entry is swapped per step.
    sample_copy = sample.copy()

    for step in range(pinyin_len):
        # Confirmed IDs first, then zero padding, always exactly 8 slots.
        history_slot_ids = (confirmed_history + [0] * 8)[:8]
        sample_copy["history_slot_ids"] = torch.tensor(
            history_slot_ids, dtype=torch.long
        ).unsqueeze(0)

        logits, probs = self.inference(sample_copy)
        analysis = self.analyze_probability_distribution(probs)

        pred_top_idx = analysis["top_indices"][0]
        pred_top_prob = analysis["top_probs"][0]

        # Coerce to a plain bool: the comparison may yield a numpy/torch
        # scalar depending on how the indices were produced.
        correct = bool(pred_top_idx == true_labels[step])

        all_predictions.append(
            {
                "predicted_id": pred_top_idx,
                "probability": pred_top_prob,
                "correct": correct,
            }
        )
        all_correct.append(correct)

        # Simulated user always picks the Top-1 candidate. ID 0 is not a
        # valid slot entry, so it is never pushed into the history.
        if pred_top_idx > 0:
            confirmed_history.append(pred_top_idx)

    # The sample counts as correct only if every step was correct.
    overall_correct = all(all_correct)

    if print_details:
        target_char = sample.get("target_char", "")
        pinyin = sample.get("pinyin", "")
        valid_slot_count = sample.get("valid_slot_count", 0)

        print(f" {target_char}", end="")
        print(f" | 拼音:{pinyin[:10]}{'...' if len(pinyin) > 10 else ''}", end="")
        print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_len}", end="")

        # Render each step as "<char><mark>", falling back to the raw ID
        # when no query engine is available or the lookup returns nothing.
        step_results = []
        for pred in all_predictions:
            char_info = (
                self.query_engine.query_by_id(pred["predicted_id"])
                if self.query_engine
                else None
            )
            char = char_info.char if char_info else f"[{pred['predicted_id']}]"
            correct_mark = "✓" if pred["correct"] else "✗"
            step_results.append(f"{char}{correct_mark}")

        print(f" | 逐步:{'→'.join(step_results)}", end="")
        print(f" | 结果:{'✓' if overall_correct else '✗'}")

    result = {
        "sample": sample,
        "sequential_predictions": all_predictions,
        "correct": overall_correct,
        "step_correct": all_correct,
        "target_char": sample.get("target_char", ""),
        "true_labels": true_labels,
        "pinyin_len": pinyin_len,
    }

    return result
|
||||
|
||||
def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]:
|
||||
"""
|
||||
评估给定文本,生成多个样本进行评估。
|
||||
|
|
@ -613,17 +699,11 @@ class TextEvaluator:
|
|||
|
||||
try:
|
||||
sample = self.create_sample_from_text(text)
|
||||
result = self.evaluate_sample(sample, print_details=True)
|
||||
results.append(result)
|
||||
|
||||
if result["analysis"]["low_variance"]:
|
||||
low_variance_samples.append(
|
||||
{
|
||||
"sample_idx": sample_idx,
|
||||
"max_prob": result["analysis"]["max_prob"],
|
||||
"entropy": result["analysis"]["entropy"],
|
||||
}
|
||||
# 使用逐步确认评估
|
||||
result = self.evaluate_sample_with_sequential_confirmation(
|
||||
sample, print_details=True
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 样本生成或评估失败: {e}")
|
||||
|
|
@ -670,10 +750,12 @@ class TextEvaluator:
|
|||
sample = self.create_sample_from_text(
|
||||
text, force_slot_count=slot_count
|
||||
)
|
||||
result = self.evaluate_sample(sample, print_details=False)
|
||||
# 使用逐步确认评估(更符合实际使用场景)
|
||||
result = self.evaluate_sample_with_sequential_confirmation(
|
||||
sample, print_details=False
|
||||
)
|
||||
results.append(result)
|
||||
if result["analysis"]["low_variance"]:
|
||||
low_var_count += 1
|
||||
# 注意:逐步确认评估没有low_variance分析
|
||||
except Exception:
|
||||
failed += 1
|
||||
continue
|
||||
|
|
@ -681,19 +763,24 @@ class TextEvaluator:
|
|||
if results:
|
||||
correct_count = sum(1 for r in results if r["correct"])
|
||||
accuracy = correct_count / len(results)
|
||||
max_probs = [r["analysis"]["max_prob"] for r in results]
|
||||
mean_max_prob = np.mean(max_probs)
|
||||
entropies = [r["analysis"]["entropy"] for r in results]
|
||||
mean_entropy = np.mean(entropies)
|
||||
|
||||
# 对于逐步确认评估,计算平均步数准确率
|
||||
step_accuracies = []
|
||||
for r in results:
|
||||
if "step_correct" in r:
|
||||
step_correct = r["step_correct"]
|
||||
if step_correct:
|
||||
step_acc = sum(step_correct) / len(step_correct)
|
||||
step_accuracies.append(step_acc)
|
||||
|
||||
mean_step_accuracy = np.mean(step_accuracies) if step_accuracies else 0
|
||||
|
||||
slot_results[slot_count] = {
|
||||
"results": results,
|
||||
"accuracy": accuracy,
|
||||
"correct": correct_count,
|
||||
"total": len(results),
|
||||
"mean_max_prob": mean_max_prob,
|
||||
"mean_entropy": mean_entropy,
|
||||
"low_variance_count": low_var_count,
|
||||
"mean_step_accuracy": mean_step_accuracy,
|
||||
"failed": failed,
|
||||
}
|
||||
|
||||
|
|
@ -703,9 +790,7 @@ class TextEvaluator:
|
|||
print(
|
||||
f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} "
|
||||
f"({correct_count:2d}/{len(results):2d}) "
|
||||
f"| 平均最高概率: {mean_max_prob:.4f} "
|
||||
f"| 平均熵: {mean_entropy:.2f} "
|
||||
f"| 低方差: {low_var_count}"
|
||||
f"| 平均步数准确率: {mean_step_accuracy:.1%}"
|
||||
)
|
||||
else:
|
||||
slot_results[slot_count] = {
|
||||
|
|
@ -713,9 +798,7 @@ class TextEvaluator:
|
|||
"accuracy": 0,
|
||||
"correct": 0,
|
||||
"total": 0,
|
||||
"mean_max_prob": 0,
|
||||
"mean_entropy": 0,
|
||||
"low_variance_count": 0,
|
||||
"mean_step_accuracy": 0,
|
||||
"failed": failed,
|
||||
}
|
||||
print(f" 槽位 {slot_count}/8 | 全部样本生成失败")
|
||||
|
|
@ -740,44 +823,27 @@ class TextEvaluator:
|
|||
accuracy = correct_count / len(results) if results else 0
|
||||
print(f"准确率: {accuracy:.1%} ({correct_count}/{len(results)})")
|
||||
|
||||
# 概率统计
|
||||
max_probs = [r["analysis"]["max_prob"] for r in results]
|
||||
mean_max_prob = np.mean(max_probs) if max_probs else 0
|
||||
print(f"平均最高概率: {mean_max_prob:.4f}")
|
||||
# 计算平均步数准确率(对于多字符预测)
|
||||
step_accuracies = []
|
||||
for r in results:
|
||||
if "step_correct" in r:
|
||||
step_correct = r["step_correct"]
|
||||
if step_correct:
|
||||
step_acc = sum(step_correct) / len(step_correct)
|
||||
step_accuracies.append(step_acc)
|
||||
|
||||
# 预测ID:0的比例
|
||||
zero_pred_count = sum(1 for r in results if r.get("predicted_label") == 0)
|
||||
zero_pred_ratio = zero_pred_count / len(results) if results else 0
|
||||
print(f"预测[空]比例: {zero_pred_ratio:.1%} ({zero_pred_count}/{len(results)})")
|
||||
|
||||
# 平均熵
|
||||
entropies = [r["analysis"]["entropy"] for r in results]
|
||||
mean_entropy = np.mean(entropies) if entropies else 0
|
||||
print(f"平均熵: {mean_entropy:.2f}")
|
||||
if step_accuracies:
|
||||
mean_step_accuracy = np.mean(step_accuracies)
|
||||
print(f"平均步数准确率: {mean_step_accuracy:.1%}")
|
||||
|
||||
# 槽位和拼音长度统计
|
||||
if results and "sample" in results[0]:
|
||||
valid_slots = [r["sample"].get("valid_slot_count", 0) for r in results]
|
||||
pinyin_lengths = [
|
||||
r["sample"].get("pinyin_input_length", 0) for r in results
|
||||
]
|
||||
pinyin_lengths = [r.get("pinyin_len", 1) for r in results]
|
||||
avg_slots = np.mean(valid_slots) if valid_slots else 0
|
||||
avg_pinyin = np.mean(pinyin_lengths) if pinyin_lengths else 0
|
||||
print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}")
|
||||
|
||||
# 低方差样本
|
||||
if low_variance_samples:
|
||||
print(f"低方差样本: {len(low_variance_samples)}个")
|
||||
if len(low_variance_samples) <= 5:
|
||||
for lv in low_variance_samples:
|
||||
print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}")
|
||||
else:
|
||||
for lv in low_variance_samples[:3]:
|
||||
print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}")
|
||||
print(f" ... 还有{len(low_variance_samples) - 3}个")
|
||||
else:
|
||||
print(f"低方差样本: 0个")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="评估模型在文本上的表现")
|
||||
|
|
|
|||
|
|
@ -26,6 +26,62 @@ class AttentionPooling(nn.Module):
|
|||
return pooled
|
||||
|
||||
|
||||
# ---------------------------- 拼音LSTM编码器 ----------------------------
|
||||
class PinyinLSTMEncoder(nn.Module):
    """Bidirectional LSTM that pools a pinyin embedding sequence into one vector.

    The final hidden states of the last LSTM layer (forward and backward
    direction) are concatenated, projected back to ``input_dim`` and
    layer-normalized.

    Args:
        input_dim: Size of each input embedding and of the returned vector.
        hidden_dim: Hidden size per direction. Defaults to ``input_dim // 2``.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers; ignored (set to 0.0)
            when ``num_layers`` is 1.
    """

    def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else input_dim // 2
        self.num_layers = num_layers
        self.dropout = dropout

        # Bidirectional LSTM over the pinyin sequence.
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=self.hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )

        # Project the concatenated (forward + backward) state to input_dim.
        self.proj = nn.Linear(self.hidden_dim * 2, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch, seq_len, input_dim] pinyin embeddings.
            mask: Optional [batch, seq_len] padding mask (0/False for
                padding). Only the per-row count of non-zero entries is
                used, so valid positions are treated as a prefix of the row.

        Returns:
            [batch, input_dim] global pinyin representation.
        """
        if mask is None:
            # No mask: every sequence is consumed at full length.
            _, (hidden, _) = self.lstm(x)
        else:
            # pack_padded_sequence rejects zero lengths, so a row that is
            # entirely padding is encoded from its first position instead
            # of aborting the whole batch.
            lengths = mask.sum(dim=1).clamp(min=1).to(torch.int64).cpu()
            packed = nn.utils.rnn.pack_padded_sequence(
                x, lengths, batch_first=True, enforce_sorted=False
            )
            _, (hidden, _) = self.lstm(packed)

        # hidden: [num_layers * 2, batch, hidden_dim]; the last two entries
        # are the last layer's forward and backward final states.
        hidden_concat = torch.cat([hidden[-2], hidden[-1]], dim=1)

        projected = self.proj(hidden_concat)
        return self.layer_norm(projected)
|
||||
|
||||
|
||||
# ---------------------------- 残差块 ----------------------------
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, dim, dropout_prob=0.3):
|
||||
|
|
@ -96,7 +152,7 @@ class ContextEncoder(nn.Module):
|
|||
).embeddings
|
||||
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
|
||||
self.pos_emb = nn.Embedding(max_len, dim)
|
||||
self.pinyin_pooling = AttentionPooling(dim)
|
||||
self.pinyin_pooling = PinyinLSTMEncoder(dim)
|
||||
|
||||
# Transformer Encoder (4 layers, 4 heads) [1]
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
|
|
@ -125,9 +181,10 @@ class ContextEncoder(nn.Module):
|
|||
|
||||
# 2. Embed and pool pinyin to global feature
|
||||
pinyin_emb = self.pinyin_emb(pinyin_ids) # [B, 24, dim]
|
||||
# 方式1:Attention Pooling(推荐)
|
||||
# LSTM encoder with masking for padding
|
||||
pinyin_mask = pinyin_ids != 0
|
||||
pinyin_global = self.pinyin_pooling(
|
||||
pinyin_emb, mask=None
|
||||
pinyin_emb, mask=pinyin_mask
|
||||
) # [B, dim] # 1. Embedding Fusion: Text + Pinyin + Position
|
||||
|
||||
# Broadcast pinyin to all text positions
|
||||
|
|
@ -279,7 +336,7 @@ class CrossAttentionFusion(nn.Module):
|
|||
# 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类
|
||||
# ------------------------------------------------------------------
|
||||
class MoELayer(nn.Module):
|
||||
def __init__(self, dim=512, num_experts=20, top_k=2, export_resblocks=4):
|
||||
def __init__(self, dim=512, num_experts=10, top_k=3, num_resblocks=8):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.top_k = top_k
|
||||
|
|
@ -292,7 +349,7 @@ class MoELayer(nn.Module):
|
|||
Expert(
|
||||
input_dim=dim,
|
||||
d_model=dim,
|
||||
num_resblocks=export_resblocks,
|
||||
num_resblocks=num_resblocks,
|
||||
output_multiplier=1,
|
||||
)
|
||||
for _ in range(num_experts)
|
||||
|
|
|
|||
|
|
@ -260,13 +260,29 @@ class PinyinInputDataset(IterableDataset):
|
|||
part1 = text[0:i]
|
||||
else:
|
||||
part1 = text[i - 48 : i]
|
||||
|
||||
# 方案C:提前检查从位置i开始连续有多少个字符在词库中
|
||||
max_valid_len = 0
|
||||
for j in range(i, min(i + 8, len(text))):
|
||||
if self.query_engine.is_chinese_char(text[j]):
|
||||
max_valid_len += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 如果没有可用字符,跳过
|
||||
if max_valid_len == 0:
|
||||
continue
|
||||
|
||||
# 首先取随机值pinyin_len(1-8),pinyin_len取值呈高斯分布,最大概率取3
|
||||
# 获取text[i + pinyin_len]字符,如果无法获取所指向的后,如果pinyin_len
|
||||
# part2的长度为x,取pinyin_list[i:i+pinyin_len],为part2
|
||||
# 但是需要注意边界条件
|
||||
pinyin_len = np.random.choice(
|
||||
range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
|
||||
target_len = np.random.choice(
|
||||
range(1, 9), p=[0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
|
||||
)
|
||||
# 根据实际可用长度调整
|
||||
pinyin_len = min(target_len, max_valid_len)
|
||||
|
||||
py_end = min(i + pinyin_len, len(text))
|
||||
pinyin_len, part2 = self.get_mask_pinyin(
|
||||
text[i:py_end], pinyin_list[i:py_end]
|
||||
|
|
@ -296,13 +312,13 @@ class PinyinInputDataset(IterableDataset):
|
|||
+ np.random.choice(range(1, 17))
|
||||
]
|
||||
|
||||
# part4为文本,0.50的概率为空
|
||||
# part4为文本,0.30的概率为空
|
||||
# 不为空则为1-5个连续字符串
|
||||
# 连续字符串的取值方法为:随机从字符库中取一个字符,以及该字符后x个字符
|
||||
# x为2-6中的任意整数,取值平均分布
|
||||
# 使用|将part4中的字符串连接起来
|
||||
part4 = ""
|
||||
if random.random() > 0.5:
|
||||
if random.random() > 0.7:
|
||||
# 生成1-5个连续字符串
|
||||
num_strings = random.randint(1, 5)
|
||||
string_list = []
|
||||
|
|
@ -332,22 +348,6 @@ class PinyinInputDataset(IterableDataset):
|
|||
if random.random() <= 0.1:
|
||||
labels.append(0)
|
||||
|
||||
# 提取历史槽位:从预测位置i之前的字符中获取(与eval.py一致)
|
||||
history_slot_list = []
|
||||
for j in range(i - 1, max(-1, i - 100), -1):
|
||||
if j < 0:
|
||||
break
|
||||
char = text[j]
|
||||
if self.query_engine.is_chinese_char(char):
|
||||
try:
|
||||
results = self.query_engine.query_by_char(char, limit=1)
|
||||
if results and results[0][0] > 0:
|
||||
history_slot_list.append(results[0][0])
|
||||
except Exception:
|
||||
pass
|
||||
if len(history_slot_list) >= 8:
|
||||
break
|
||||
|
||||
encoded = self.tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
part3,
|
||||
|
|
@ -358,11 +358,23 @@ class PinyinInputDataset(IterableDataset):
|
|||
return_token_type_ids=True,
|
||||
)
|
||||
samples = []
|
||||
# 历史槽位长度权重:增加长历史采样比例
|
||||
# 目标分布: H=0-2占45%, H=3-8占55%
|
||||
history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
|
||||
|
||||
# 修复变量名冲突:将内层循环变量i重命名为label_idx
|
||||
for label_idx, label in enumerate(labels):
|
||||
repeats = self.adjust_frequency(label)
|
||||
# 使用从text[0:i]提取的历史槽位(与eval.py一致)
|
||||
masked_labels = history_slot_list[:]
|
||||
base_repeats = self.adjust_frequency(label)
|
||||
# 根据历史槽位长度调整采样次数
|
||||
weight = (
|
||||
history_weights[label_idx]
|
||||
if label_idx < len(history_weights)
|
||||
else 3.0
|
||||
)
|
||||
repeats = max(1, int(base_repeats * weight))
|
||||
|
||||
# 历史槽位:同一拼音序列中已确认的字符(模拟用户逐步确认过程)
|
||||
masked_labels = labels[:label_idx]
|
||||
len_l = len(masked_labels)
|
||||
masked_labels.extend([0] * (8 - len_l))
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class InputMethodEngine(nn.Module):
|
|||
num_slots: int = 8, # 历史槽位数量 (对应 README 中的 8 个槽位)
|
||||
n_layers: int = 4, # Transformer 层数
|
||||
n_heads: int = 4, # 注意力头数
|
||||
num_experts: int = 20, # MoE 专家数量
|
||||
num_experts: int = 10, # MoE 专家数量
|
||||
max_seq_len: int = 128, # 最大上下文长度
|
||||
compile: bool = False, # 是否开启 torch.compile 优化
|
||||
):
|
||||
|
|
@ -72,9 +72,12 @@ class InputMethodEngine(nn.Module):
|
|||
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
|
||||
|
||||
# 4. 混合专家层 (MoE)
|
||||
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=2)
|
||||
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8)
|
||||
|
||||
# 5. 分类头
|
||||
# 5. 槽位注意力池化
|
||||
self.slot_attention = nn.Linear(dim, 1)
|
||||
|
||||
# 6. 分类头
|
||||
self.classifier = nn.Linear(dim, vocab_size)
|
||||
|
||||
# 开启 torch.compile 优化 (如果请求)
|
||||
|
|
@ -127,19 +130,14 @@ class InputMethodEngine(nn.Module):
|
|||
# 4. MoE 处理 -> [batch, num_slots, dim]
|
||||
moe_out = self.moe(fused)
|
||||
|
||||
# 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile)
|
||||
# 5. 槽位注意力池化
|
||||
batch_size = input_ids.size(0)
|
||||
slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
|
||||
numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
|
||||
denominator = slot_mask.view(batch_size, -1).sum(dim=1) # [batch]
|
||||
|
||||
all_slot_mean = moe_out.mean(dim=1) # [batch, dim]
|
||||
all_zero = denominator == 0 # [batch]
|
||||
pooled = torch.where(
|
||||
all_zero.unsqueeze(-1),
|
||||
all_slot_mean,
|
||||
numerator / (denominator.unsqueeze(-1) + 1e-8),
|
||||
)
|
||||
# 计算注意力分数 [batch, num_slots, 1] -> [batch, num_slots]
|
||||
slot_scores = self.slot_attention(moe_out).squeeze(-1)
|
||||
# 应用softmax获取注意力权重
|
||||
slot_weights = torch.softmax(slot_scores, dim=1) # [batch, num_slots]
|
||||
# 加权求和得到池化表示
|
||||
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) # [batch, dim]
|
||||
|
||||
logits = self.classifier(pooled) # [batch, vocab_size]
|
||||
return logits
|
||||
|
|
|
|||
|
|
@ -56,11 +56,11 @@ class Trainer:
|
|||
total_steps: int,
|
||||
output_dir: str = "./output",
|
||||
num_epochs: int = 10,
|
||||
learning_rate: float = 1e-4,
|
||||
learning_rate: float = 2e-4,
|
||||
min_learning_rate: float = 1e-6,
|
||||
weight_decay: float = 0.1,
|
||||
weight_decay: float = 0.05,
|
||||
warmup_ratio: float = 0.1,
|
||||
label_smoothing: float = 0.15,
|
||||
label_smoothing: float = 0.1,
|
||||
loss_weight: Optional[torch.Tensor] = None,
|
||||
grad_accum_steps: int = 1,
|
||||
clip_grad_norm: float = 1.0,
|
||||
|
|
@ -1060,14 +1060,14 @@ def train(
|
|||
# 训练参数
|
||||
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
|
||||
num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
|
||||
learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
|
||||
learning_rate: float = typer.Option(2e-4, "--learning-rate", "-lr", help="学习率"),
|
||||
min_learning_rate: float = typer.Option(
|
||||
1e-9, "--min-learning-rate", help="最小学习率"
|
||||
),
|
||||
weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
|
||||
weight_decay: float = typer.Option(0.05, "--weight-decay", help="权重衰减"),
|
||||
warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
|
||||
label_smoothing: float = typer.Option(
|
||||
0.15, "--label-smoothing", help="标签平滑参数"
|
||||
0.1, "--label-smoothing", help="标签平滑参数"
|
||||
),
|
||||
grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
|
||||
clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
|
||||
|
|
@ -1120,7 +1120,7 @@ def train(
|
|||
num_slots = 8
|
||||
n_layers = 4
|
||||
n_heads = 4
|
||||
num_experts = 20
|
||||
num_experts = 10
|
||||
max_seq_len = 128
|
||||
use_pinyin = True # 始终使用拼音
|
||||
|
||||
|
|
@ -1383,14 +1383,14 @@ def expand_and_train(
|
|||
# 训练参数
|
||||
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
|
||||
num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
|
||||
learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
|
||||
learning_rate: float = typer.Option(2e-4, "--learning-rate", "-lr", help="学习率"),
|
||||
min_learning_rate: float = typer.Option(
|
||||
1e-9, "--min-learning-rate", help="最小学习率"
|
||||
),
|
||||
weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
|
||||
weight_decay: float = typer.Option(0.05, "--weight-decay", help="权重衰减"),
|
||||
warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
|
||||
label_smoothing: float = typer.Option(
|
||||
0.15, "--label-smoothing", help="标签平滑参数"
|
||||
0.1, "--label-smoothing", help="标签平滑参数"
|
||||
),
|
||||
grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
|
||||
clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
|
||||
|
|
@ -1440,7 +1440,7 @@ def expand_and_train(
|
|||
num_slots = 8
|
||||
n_layers = 4
|
||||
n_heads = 4
|
||||
num_experts = 20
|
||||
num_experts = 10
|
||||
max_seq_len = 128
|
||||
use_pinyin = True # 始终使用拼音
|
||||
console = Console()
|
||||
|
|
|
|||
Loading…
Reference in New Issue