feat(pinyin): 添加拼音LSTM编码器以支持多字符预测的逐步确认评估

This commit is contained in:
songsenand 2026-04-11 22:58:56 +08:00
parent 68a6fc3533
commit bb78e0afa0
5 changed files with 283 additions and 150 deletions

258
eval.py
View File

@ -274,7 +274,7 @@ class TextEvaluator:
part1 = text[i - 48 : i]
# part2: 拼音输入随机长度1-8高斯分布
pinyin_len_probs = [0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
pinyin_len_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
py_end = min(i + pinyin_len, len(text))
@ -332,53 +332,30 @@ class TextEvaluator:
print(f"⚠️ 获取标签失败: {e}")
labels = [0] * pinyin_len_actual
# 历史槽位:当前预测字符之前已确认的字符
# 从位置i之前的字符中提取历史槽位最多8个
# 注意:这里的逻辑与训练数据一致 - 从已确认的字符中提取
# 历史槽位:模拟用户逐步确认过程
# 对于多字符预测pinyin_len_actual > 1需要模拟用户逐步选择
# 注意:评估时不知道正确答案,只能基于模型预测来构建历史
history_slot_ids = []
if self.query_engine:
# 收集i之前的汉字字符最多8个
for j in range(i - 1, max(-1, i - 100), -1):
if j < 0:
break
char = text[j]
if self.query_engine.is_chinese_char(char):
# 尝试获取该字符的ID与训练数据一致
results = self.query_engine.query_by_char(char, limit=1)
if results:
slot_id = results[0][0]
# 只添加有效的非零ID与训练数据一致
if slot_id > 0:
history_slot_ids.append(slot_id)
if len(history_slot_ids) >= 8:
break
# 填充到8个槽位使用0填充与训练数据一致
# 如果强制指定槽位数量,使用模拟的历史(与训练数据分布一致)
if force_slot_count is not None:
force_slot_count = max(0, min(8, force_slot_count))
# 创建模拟的历史槽位前force_slot_count个为有效槽位其余为0
# 这里使用简单的模拟有效槽位用1-1000的随机ID与训练时随机填充逻辑一致
for _ in range(force_slot_count):
history_slot_ids.append(random.randint(1, 1000))
else:
# 正常评估:历史槽位为空(模拟从零开始输入)
# 注意:实际使用时,历史槽位应该来自之前的用户选择
# 但为了评估公平,我们从空历史开始
pass
# 填充到8个槽位
if len(history_slot_ids) < 8:
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
else:
history_slot_ids = history_slot_ids[:8]
# 强制设置有效槽位数量(用于按槽位数量评估)
# 注意这里不再使用随机ID填充而是复制现有的有效槽位
if force_slot_count is not None:
force_slot_count = max(0, min(8, force_slot_count))
# 获取所有非零槽位ID
valid_ids = [s for s in history_slot_ids if s != 0]
if len(valid_ids) >= force_slot_count:
# 如果有效槽位足够直接取前force_slot_count个
history_slot_ids = valid_ids[:force_slot_count]
else:
# 如果有效槽位不足保留现有有效槽位然后用0填充
# 这样更符合训练数据的分布(槽位数量由实际历史决定)
history_slot_ids = valid_ids[:]
# 填充到8个槽位
while len(history_slot_ids) < 8:
history_slot_ids.append(0)
history_slot_ids = history_slot_ids[:8]
# Tokenize输入
encoded = self.tokenizer(
f"{part4}|{part1}",
@ -586,6 +563,115 @@ class TextEvaluator:
return result
def evaluate_sample_with_sequential_confirmation(
    self, sample: Dict, print_details: bool = True
) -> Dict:
    """Evaluate a sample while simulating step-by-step user confirmation.

    For pinyin inputs that cover more than one character, the user is
    simulated to always accept the model's Top-1 prediction at each step;
    the accepted ids are fed back as history slots for the next step.
    Note the evaluator never peeks at the true labels when building the
    history — only model predictions are confirmed.

    Args:
        sample: sample dict (as produced by ``create_sample_from_text``).
        print_details: when True, print a one-line per-sample summary.

    Returns:
        Evaluation-result dict; the sample counts as correct only when
        every step's Top-1 prediction matches its true label.
    """
    true_labels = sample.get("true_labels", [])
    pinyin_len = len(true_labels) if true_labels else 1

    # Single-character prediction: the plain evaluator is sufficient.
    if pinyin_len <= 1:
        return self.evaluate_sample(sample, print_details)

    chosen_ids = []    # ids the simulated user has confirmed so far
    step_records = []  # per-step prediction records
    step_flags = []    # per-step correctness flags
    # Shallow copy is enough: only the history-slot tensor is replaced.
    working_sample = sample.copy()

    for step in range(pinyin_len):
        # Pad (or trim) the confirmed ids to exactly 8 history slots.
        slot_ids = (chosen_ids + [0] * 8)[:8]
        working_sample["history_slot_ids"] = torch.tensor(
            slot_ids, dtype=torch.long
        ).unsqueeze(0)

        # Run inference with the history accumulated so far.
        logits, probs = self.inference(working_sample)
        analysis = self.analyze_probability_distribution(probs)
        top_id = analysis["top_indices"][0]
        top_prob = analysis["top_probs"][0]

        # Only steps that actually have a label can be judged.
        is_hit = top_id == true_labels[step] if step < len(true_labels) else False

        step_records.append(
            {
                "predicted_id": top_id,
                "probability": top_prob,
                "correct": is_hit,
            }
        )
        step_flags.append(is_hit)

        # The simulated user always picks Top-1; only a non-zero id counts
        # as a confirmed character.
        if top_id > 0:
            chosen_ids.append(top_id)

    # The whole sample is correct only when every step was correct.
    overall_correct = all(step_flags)

    if print_details:
        target_char = sample.get("target_char", "")
        pinyin = sample.get("pinyin", "")
        valid_slot_count = sample.get("valid_slot_count", 0)
        print(f" {target_char}", end="")
        print(f" | 拼音:{pinyin[:10]}{'...' if len(pinyin) > 10 else ''}", end="")
        print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_len}", end="")
        # Render each step as "<char><✓|✗>".
        step_marks = []
        for record in step_records:
            info = (
                self.query_engine.query_by_id(record["predicted_id"])
                if self.query_engine
                else None
            )
            shown = info.char if info else f"[{record['predicted_id']}]"
            mark = "✓" if record["correct"] else "✗"
            step_marks.append(f"{shown}{mark}")
        print(f" | 逐步:{''.join(step_marks)}", end="")
        print(f" | 结果:{'✅' if overall_correct else '❌'}")

    return {
        "sample": sample,
        "sequential_predictions": step_records,
        "correct": overall_correct,
        "step_correct": step_flags,
        "target_char": sample.get("target_char", ""),
        "true_labels": true_labels,
        "pinyin_len": pinyin_len,
    }
def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]:
"""
评估给定文本生成多个样本进行评估
@ -613,18 +699,12 @@ class TextEvaluator:
try:
sample = self.create_sample_from_text(text)
result = self.evaluate_sample(sample, print_details=True)
# 使用逐步确认评估
result = self.evaluate_sample_with_sequential_confirmation(
sample, print_details=True
)
results.append(result)
if result["analysis"]["low_variance"]:
low_variance_samples.append(
{
"sample_idx": sample_idx,
"max_prob": result["analysis"]["max_prob"],
"entropy": result["analysis"]["entropy"],
}
)
except Exception as e:
print(f"❌ 样本生成或评估失败: {e}")
import traceback
@ -670,10 +750,12 @@ class TextEvaluator:
sample = self.create_sample_from_text(
text, force_slot_count=slot_count
)
result = self.evaluate_sample(sample, print_details=False)
# 使用逐步确认评估(更符合实际使用场景)
result = self.evaluate_sample_with_sequential_confirmation(
sample, print_details=False
)
results.append(result)
if result["analysis"]["low_variance"]:
low_var_count += 1
# 注意逐步确认评估没有low_variance分析
except Exception:
failed += 1
continue
@ -681,19 +763,24 @@ class TextEvaluator:
if results:
correct_count = sum(1 for r in results if r["correct"])
accuracy = correct_count / len(results)
max_probs = [r["analysis"]["max_prob"] for r in results]
mean_max_prob = np.mean(max_probs)
entropies = [r["analysis"]["entropy"] for r in results]
mean_entropy = np.mean(entropies)
# 对于逐步确认评估,计算平均步数准确率
step_accuracies = []
for r in results:
if "step_correct" in r:
step_correct = r["step_correct"]
if step_correct:
step_acc = sum(step_correct) / len(step_correct)
step_accuracies.append(step_acc)
mean_step_accuracy = np.mean(step_accuracies) if step_accuracies else 0
slot_results[slot_count] = {
"results": results,
"accuracy": accuracy,
"correct": correct_count,
"total": len(results),
"mean_max_prob": mean_max_prob,
"mean_entropy": mean_entropy,
"low_variance_count": low_var_count,
"mean_step_accuracy": mean_step_accuracy,
"failed": failed,
}
@ -703,9 +790,7 @@ class TextEvaluator:
print(
f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} "
f"({correct_count:2d}/{len(results):2d}) "
f"| 平均最高概率: {mean_max_prob:.4f} "
f"| 平均熵: {mean_entropy:.2f} "
f"| 低方差: {low_var_count}"
f"| 平均步数准确率: {mean_step_accuracy:.1%}"
)
else:
slot_results[slot_count] = {
@ -713,9 +798,7 @@ class TextEvaluator:
"accuracy": 0,
"correct": 0,
"total": 0,
"mean_max_prob": 0,
"mean_entropy": 0,
"low_variance_count": 0,
"mean_step_accuracy": 0,
"failed": failed,
}
print(f" 槽位 {slot_count}/8 | 全部样本生成失败")
@ -740,44 +823,27 @@ class TextEvaluator:
accuracy = correct_count / len(results) if results else 0
print(f"准确率: {accuracy:.1%} ({correct_count}/{len(results)})")
# 概率统计
max_probs = [r["analysis"]["max_prob"] for r in results]
mean_max_prob = np.mean(max_probs) if max_probs else 0
print(f"平均最高概率: {mean_max_prob:.4f}")
# 计算平均步数准确率(对于多字符预测)
step_accuracies = []
for r in results:
if "step_correct" in r:
step_correct = r["step_correct"]
if step_correct:
step_acc = sum(step_correct) / len(step_correct)
step_accuracies.append(step_acc)
# 预测ID:0的比例
zero_pred_count = sum(1 for r in results if r.get("predicted_label") == 0)
zero_pred_ratio = zero_pred_count / len(results) if results else 0
print(f"预测[空]比例: {zero_pred_ratio:.1%} ({zero_pred_count}/{len(results)})")
# 平均熵
entropies = [r["analysis"]["entropy"] for r in results]
mean_entropy = np.mean(entropies) if entropies else 0
print(f"平均熵: {mean_entropy:.2f}")
if step_accuracies:
mean_step_accuracy = np.mean(step_accuracies)
print(f"平均步数准确率: {mean_step_accuracy:.1%}")
# 槽位和拼音长度统计
if results and "sample" in results[0]:
valid_slots = [r["sample"].get("valid_slot_count", 0) for r in results]
pinyin_lengths = [
r["sample"].get("pinyin_input_length", 0) for r in results
]
pinyin_lengths = [r.get("pinyin_len", 1) for r in results]
avg_slots = np.mean(valid_slots) if valid_slots else 0
avg_pinyin = np.mean(pinyin_lengths) if pinyin_lengths else 0
print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}")
# 低方差样本
if low_variance_samples:
print(f"低方差样本: {len(low_variance_samples)}")
if len(low_variance_samples) <= 5:
for lv in low_variance_samples:
print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}")
else:
for lv in low_variance_samples[:3]:
print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}")
print(f" ... 还有{len(low_variance_samples) - 3}")
else:
print(f"低方差样本: 0个")
def main():
parser = argparse.ArgumentParser(description="评估模型在文本上的表现")

View File

@ -26,6 +26,62 @@ class AttentionPooling(nn.Module):
return pooled
# ---------------------------- 拼音LSTM编码器 ----------------------------
class PinyinLSTMEncoder(nn.Module):
    """Encode a pinyin embedding sequence into one global vector.

    A bidirectional LSTM reads the sequence; the last layer's final forward
    and backward hidden states are concatenated, projected back down to
    ``input_dim``, and layer-normalized.
    """

    def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
        """
        Args:
            input_dim: dimensionality of the input embeddings (also the
                output dimensionality).
            hidden_dim: LSTM hidden size per direction; defaults to
                ``input_dim // 2`` so the bidirectional concat is ~input_dim.
            num_layers: number of stacked LSTM layers.
            dropout: inter-layer dropout (only meaningful when num_layers > 1).
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else input_dim // 2
        self.num_layers = num_layers
        self.dropout = dropout
        # Bidirectional LSTM over the pinyin sequence.
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=self.hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            # nn.LSTM warns if dropout is set with a single layer.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        # Project concatenated (forward + backward) states back to input_dim.
        self.proj = nn.Linear(self.hidden_dim * 2, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch, seq_len, input_dim] pinyin embeddings.
            mask: [batch, seq_len] optional padding mask (0/False = padding).

        Returns:
            pooled: [batch, input_dim] global pinyin representation.
        """
        if mask is not None:
            # Lengths for pack_padded_sequence. Clamp to >= 1: an
            # all-padding row (e.g. empty pinyin input) would otherwise make
            # pack_padded_sequence raise. Such rows then encode only their
            # first (padding) step, which is harmless noise.
            lengths = mask.sum(dim=1).clamp(min=1).cpu()
            packed = nn.utils.rnn.pack_padded_sequence(
                x, lengths, batch_first=True, enforce_sorted=False
            )
            _, (hidden, _) = self.lstm(packed)
        else:
            # No mask: assume every sequence uses the full length.
            _, (hidden, _) = self.lstm(x)
        # hidden: [num_layers * 2, batch, hidden_dim]; rows -2 / -1 are the
        # last layer's forward and backward directions respectively.
        hidden_concat = torch.cat([hidden[-2], hidden[-1]], dim=1)
        projected = self.proj(hidden_concat)
        return self.layer_norm(projected)
# ---------------------------- 残差块 ----------------------------
class ResidualBlock(nn.Module):
def __init__(self, dim, dropout_prob=0.3):
@ -96,7 +152,7 @@ class ContextEncoder(nn.Module):
).embeddings
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
self.pos_emb = nn.Embedding(max_len, dim)
self.pinyin_pooling = AttentionPooling(dim)
self.pinyin_pooling = PinyinLSTMEncoder(dim)
# Transformer Encoder (4 layers, 4 heads) [1]
encoder_layer = nn.TransformerEncoderLayer(
@ -125,9 +181,10 @@ class ContextEncoder(nn.Module):
# 2. Embed and pool pinyin to global feature
pinyin_emb = self.pinyin_emb(pinyin_ids) # [B, 24, dim]
# 方式1Attention Pooling推荐
# LSTM encoder with masking for padding
pinyin_mask = pinyin_ids != 0
pinyin_global = self.pinyin_pooling(
pinyin_emb, mask=None
pinyin_emb, mask=pinyin_mask
) # [B, dim] # 1. Embedding Fusion: Text + Pinyin + Position
# Broadcast pinyin to all text positions
@ -279,7 +336,7 @@ class CrossAttentionFusion(nn.Module):
# 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类
# ------------------------------------------------------------------
class MoELayer(nn.Module):
def __init__(self, dim=512, num_experts=20, top_k=2, export_resblocks=4):
def __init__(self, dim=512, num_experts=10, top_k=3, num_resblocks=8):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
@ -292,7 +349,7 @@ class MoELayer(nn.Module):
Expert(
input_dim=dim,
d_model=dim,
num_resblocks=export_resblocks,
num_resblocks=num_resblocks,
output_multiplier=1,
)
for _ in range(num_experts)

View File

@ -260,13 +260,29 @@ class PinyinInputDataset(IterableDataset):
part1 = text[0:i]
else:
part1 = text[i - 48 : i]
# 方案C提前检查从位置i开始连续有多少个字符在词库中
max_valid_len = 0
for j in range(i, min(i + 8, len(text))):
if self.query_engine.is_chinese_char(text[j]):
max_valid_len += 1
else:
break
# 如果没有可用字符,跳过
if max_valid_len == 0:
continue
# 首先取随机值pinyin_len1-8pinyin_len取值呈高斯分布最大概率取3
# 取pinyin_list[i:i+pinyin_len]作为part2
# 若i+pinyin_len超出文本末尾则截断到len(text)pinyin_len随之缩短
# 但是需要注意边界条件
pinyin_len = np.random.choice(
range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
target_len = np.random.choice(
range(1, 9), p=[0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
)
# 根据实际可用长度调整
pinyin_len = min(target_len, max_valid_len)
py_end = min(i + pinyin_len, len(text))
pinyin_len, part2 = self.get_mask_pinyin(
text[i:py_end], pinyin_list[i:py_end]
@ -296,13 +312,13 @@ class PinyinInputDataset(IterableDataset):
+ np.random.choice(range(1, 17))
]
# part4为文本0.50的概率为空
# part4为文本0.30的概率为空
# 不为空则为1-5个连续字符串
# 连续字符串的取值方法为随机从字符库中取一个字符以及该字符后x个字符
# x为2-6中的任意整数取值平均分布
# 使用|将part4中的字符串连接起来
part4 = ""
if random.random() > 0.5:
if random.random() > 0.7:
# 生成1-5个连续字符串
num_strings = random.randint(1, 5)
string_list = []
@ -332,22 +348,6 @@ class PinyinInputDataset(IterableDataset):
if random.random() <= 0.1:
labels.append(0)
# 提取历史槽位从预测位置i之前的字符中获取与eval.py一致
history_slot_list = []
for j in range(i - 1, max(-1, i - 100), -1):
if j < 0:
break
char = text[j]
if self.query_engine.is_chinese_char(char):
try:
results = self.query_engine.query_by_char(char, limit=1)
if results and results[0][0] > 0:
history_slot_list.append(results[0][0])
except Exception:
pass
if len(history_slot_list) >= 8:
break
encoded = self.tokenizer(
f"{part4}|{part1}",
part3,
@ -358,11 +358,23 @@ class PinyinInputDataset(IterableDataset):
return_token_type_ids=True,
)
samples = []
# 历史槽位长度权重:增加长历史采样比例
# 目标分布: H=0-2占45%, H=3-8占55%
history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
# 修复变量名冲突将内层循环变量i重命名为label_idx
for label_idx, label in enumerate(labels):
repeats = self.adjust_frequency(label)
# 使用从text[0:i]提取的历史槽位与eval.py一致
masked_labels = history_slot_list[:]
base_repeats = self.adjust_frequency(label)
# 根据历史槽位长度调整采样次数
weight = (
history_weights[label_idx]
if label_idx < len(history_weights)
else 3.0
)
repeats = max(1, int(base_repeats * weight))
# 历史槽位:同一拼音序列中已确认的字符(模拟用户逐步确认过程)
masked_labels = labels[:label_idx]
len_l = len(masked_labels)
masked_labels.extend([0] * (8 - len_l))

View File

@ -38,7 +38,7 @@ class InputMethodEngine(nn.Module):
num_slots: int = 8, # 历史槽位数量 (对应 README 中的 8 个槽位)
n_layers: int = 4, # Transformer 层数
n_heads: int = 4, # 注意力头数
num_experts: int = 20, # MoE 专家数量
num_experts: int = 10, # MoE 专家数量
max_seq_len: int = 128, # 最大上下文长度
compile: bool = False, # 是否开启 torch.compile 优化
):
@ -72,9 +72,12 @@ class InputMethodEngine(nn.Module):
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
# 4. 混合专家层 (MoE)
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=2)
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8)
# 5. 分类头
# 5. 槽位注意力池化
self.slot_attention = nn.Linear(dim, 1)
# 6. 分类头
self.classifier = nn.Linear(dim, vocab_size)
# 开启 torch.compile 优化 (如果请求)
@ -127,19 +130,14 @@ class InputMethodEngine(nn.Module):
# 4. MoE 处理 -> [batch, num_slots, dim]
moe_out = self.moe(fused)
# 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile
# 5. 槽位注意力池化
batch_size = input_ids.size(0)
slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1)
numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim]
denominator = slot_mask.view(batch_size, -1).sum(dim=1) # [batch]
all_slot_mean = moe_out.mean(dim=1) # [batch, dim]
all_zero = denominator == 0 # [batch]
pooled = torch.where(
all_zero.unsqueeze(-1),
all_slot_mean,
numerator / (denominator.unsqueeze(-1) + 1e-8),
)
# 计算注意力分数 [batch, num_slots, 1] -> [batch, num_slots]
slot_scores = self.slot_attention(moe_out).squeeze(-1)
# 应用softmax获取注意力权重
slot_weights = torch.softmax(slot_scores, dim=1) # [batch, num_slots]
# 加权求和得到池化表示
pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) # [batch, dim]
logits = self.classifier(pooled) # [batch, vocab_size]
return logits

View File

@ -56,11 +56,11 @@ class Trainer:
total_steps: int,
output_dir: str = "./output",
num_epochs: int = 10,
learning_rate: float = 1e-4,
learning_rate: float = 2e-4,
min_learning_rate: float = 1e-6,
weight_decay: float = 0.1,
weight_decay: float = 0.05,
warmup_ratio: float = 0.1,
label_smoothing: float = 0.15,
label_smoothing: float = 0.1,
loss_weight: Optional[torch.Tensor] = None,
grad_accum_steps: int = 1,
clip_grad_norm: float = 1.0,
@ -1060,14 +1060,14 @@ def train(
# 训练参数
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
learning_rate: float = typer.Option(2e-4, "--learning-rate", "-lr", help="学习率"),
min_learning_rate: float = typer.Option(
1e-9, "--min-learning-rate", help="最小学习率"
),
weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
weight_decay: float = typer.Option(0.05, "--weight-decay", help="权重衰减"),
warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
label_smoothing: float = typer.Option(
0.15, "--label-smoothing", help="标签平滑参数"
0.1, "--label-smoothing", help="标签平滑参数"
),
grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
@ -1120,7 +1120,7 @@ def train(
num_slots = 8
n_layers = 4
n_heads = 4
num_experts = 20
num_experts = 10
max_seq_len = 128
use_pinyin = True # 始终使用拼音
@ -1383,14 +1383,14 @@ def expand_and_train(
# 训练参数
batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
learning_rate: float = typer.Option(2e-4, "--learning-rate", "-lr", help="学习率"),
min_learning_rate: float = typer.Option(
1e-9, "--min-learning-rate", help="最小学习率"
),
weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
weight_decay: float = typer.Option(0.05, "--weight-decay", help="权重衰减"),
warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
label_smoothing: float = typer.Option(
0.15, "--label-smoothing", help="标签平滑参数"
0.1, "--label-smoothing", help="标签平滑参数"
),
grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
@ -1440,7 +1440,7 @@ def expand_and_train(
num_slots = 8
n_layers = 4
n_heads = 4
num_experts = 20
num_experts = 10
max_seq_len = 128
use_pinyin = True # 始终使用拼音
console = Console()