diff --git a/eval.py b/eval.py index 28d914f..0eb5a55 100644 --- a/eval.py +++ b/eval.py @@ -274,7 +274,7 @@ class TextEvaluator: part1 = text[i - 48 : i] # part2: 拼音输入(随机长度1-8,高斯分布) - pinyin_len_probs = [0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02] + pinyin_len_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04] pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs) py_end = min(i + pinyin_len, len(text)) @@ -332,53 +332,30 @@ class TextEvaluator: print(f"⚠️ 获取标签失败: {e}") labels = [0] * pinyin_len_actual - # 历史槽位:当前预测字符之前已确认的字符 - # 从位置i之前的字符中提取历史槽位(最多8个) - # 注意:这里的逻辑与训练数据一致 - 从已确认的字符中提取 + # 历史槽位:模拟用户逐步确认过程 + # 对于多字符预测(pinyin_len_actual > 1),需要模拟用户逐步选择 + # 注意:评估时不知道正确答案,只能基于模型预测来构建历史 history_slot_ids = [] - if self.query_engine: - # 收集i之前的汉字字符(最多8个) - for j in range(i - 1, max(-1, i - 100), -1): - if j < 0: - break - char = text[j] - if self.query_engine.is_chinese_char(char): - # 尝试获取该字符的ID(与训练数据一致) - results = self.query_engine.query_by_char(char, limit=1) - if results: - slot_id = results[0][0] - # 只添加有效的非零ID(与训练数据一致) - if slot_id > 0: - history_slot_ids.append(slot_id) - if len(history_slot_ids) >= 8: - break - # 填充到8个槽位(使用0填充,与训练数据一致) + # 如果强制指定槽位数量,使用模拟的历史(与训练数据分布一致) + if force_slot_count is not None: + force_slot_count = max(0, min(8, force_slot_count)) + # 创建模拟的历史槽位:前force_slot_count个为有效槽位,其余为0 + # 这里使用简单的模拟:有效槽位用1-1000的随机ID(与训练时随机填充逻辑一致) + for _ in range(force_slot_count): + history_slot_ids.append(random.randint(1, 1000)) + else: + # 正常评估:历史槽位为空(模拟从零开始输入) + # 注意:实际使用时,历史槽位应该来自之前的用户选择 + # 但为了评估公平,我们从空历史开始 + pass + + # 填充到8个槽位 if len(history_slot_ids) < 8: history_slot_ids.extend([0] * (8 - len(history_slot_ids))) else: history_slot_ids = history_slot_ids[:8] - # 强制设置有效槽位数量(用于按槽位数量评估) - # 注意:这里不再使用随机ID填充,而是复制现有的有效槽位 - if force_slot_count is not None: - force_slot_count = max(0, min(8, force_slot_count)) - # 获取所有非零槽位ID - valid_ids = [s for s in history_slot_ids if s != 0] - - if len(valid_ids) >= force_slot_count: - # 
def evaluate_sample_with_sequential_confirmation(
    self, sample: Dict, print_details: bool = True
) -> Dict:
    """Evaluate a sample while simulating step-by-step user confirmation.

    For a sample with more than one true label, the model is queried once
    per label. After every step the Top-1 prediction is treated as the
    character the user confirmed and is appended to the history slots used
    for the next step. Samples with zero or one label are delegated to
    ``self.evaluate_sample``.

    Args:
        sample: Sample dict. Reads ``true_labels`` and, when printing,
            ``target_char``, ``pinyin`` and ``valid_slot_count``. If
            ``history_slot_ids`` is present, its non-zero entries seed the
            confirmed history (assumes 0 marks an empty slot, matching the
            zero padding written below -- confirm against the sample
            builder).
        print_details: Print a one-line per-sample report when True.

    Returns:
        For multi-label samples, a dict with keys ``sample``,
        ``sequential_predictions``, ``correct`` (True only if every step
        was correct), ``step_correct``, ``target_char``, ``true_labels``
        and ``pinyin_len``. For single-label samples, whatever
        ``self.evaluate_sample`` returns.
    """
    num_slots = 8  # width of the history tensor handed to the model

    true_labels = sample.get("true_labels")
    if true_labels is None:
        true_labels = []
    # len() instead of truthiness so tensor / ndarray labels do not raise.
    pinyin_len = max(1, len(true_labels))

    if pinyin_len <= 1:
        # Single-character prediction: the plain evaluation is sufficient.
        return self.evaluate_sample(sample, print_details)

    # Start from the history already attached to the sample (e.g. slots
    # forced by the sample builder) instead of discarding it; otherwise a
    # per-slot-count evaluation would always run with an empty history.
    seeded = sample.get("history_slot_ids")
    if seeded is None:
        confirmed_history = []
    else:
        confirmed_history = [
            int(slot_id)
            for slot_id in torch.as_tensor(seeded).flatten().tolist()
            if int(slot_id) != 0
        ]

    all_predictions = []  # per-step prediction records
    all_correct = []  # per-step correctness flags

    # Shallow copy: only the history key is replaced, the caller's sample
    # dict is left untouched.
    step_sample = sample.copy()

    for step in range(pinyin_len):
        # Keep the most recent entries and zero-pad to the fixed width.
        # NOTE(review): keeping the newest entries on overflow is a choice
        # made here; confirm it matches the slot order used in training.
        window = confirmed_history[-num_slots:]
        window = window + [0] * (num_slots - len(window))
        step_sample["history_slot_ids"] = torch.tensor(
            window, dtype=torch.long
        ).unsqueeze(0)

        _, probs = self.inference(step_sample)
        analysis = self.analyze_probability_distribution(probs)

        # Normalise to plain Python scalars so results compare and
        # serialise predictably regardless of the analysis backend.
        top_id = int(analysis["top_indices"][0])
        top_prob = float(analysis["top_probs"][0])
        is_correct = top_id == int(true_labels[step])

        all_predictions.append(
            {
                "predicted_id": top_id,
                "probability": top_prob,
                "correct": is_correct,
            }
        )
        all_correct.append(is_correct)

        # Simulated user always accepts the Top-1 candidate; id 0 is not a
        # valid character and is never written into the history.
        if top_id > 0:
            confirmed_history.append(top_id)

    # The sample counts as correct only when every step was correct.
    overall_correct = all(all_correct)

    if print_details:
        target_char = sample.get("target_char", "")
        pinyin = sample.get("pinyin", "")
        valid_slot_count = sample.get("valid_slot_count", 0)

        print(f" {target_char}", end="")
        print(f" | 拼音:{pinyin[:10]}{'...' if len(pinyin) > 10 else ''}", end="")
        print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_len}", end="")

        # Render each step as "<char><mark>", falling back to the raw id
        # when no query engine is available or the id is unknown.
        step_results = []
        for pred in all_predictions:
            char_info = (
                self.query_engine.query_by_id(pred["predicted_id"])
                if self.query_engine
                else None
            )
            char = char_info.char if char_info else f"[{pred['predicted_id']}]"
            correct_mark = "✓" if pred["correct"] else "✗"
            step_results.append(f"{char}{correct_mark}")

        print(f" | 逐步:{'→'.join(step_results)}", end="")
        print(f" | 结果:{'✓' if overall_correct else '✗'}")

    return {
        "sample": sample,
        "sequential_predictions": all_predictions,
        "correct": overall_correct,
        "step_correct": all_correct,
        "target_char": sample.get("target_char", ""),
        "true_labels": true_labels,
        "pinyin_len": pinyin_len,
    }
result = self.evaluate_sample_with_sequential_confirmation( + sample, print_details=False + ) results.append(result) - if result["analysis"]["low_variance"]: - low_var_count += 1 + # 注意:逐步确认评估没有low_variance分析 except Exception: failed += 1 continue @@ -681,19 +763,24 @@ class TextEvaluator: if results: correct_count = sum(1 for r in results if r["correct"]) accuracy = correct_count / len(results) - max_probs = [r["analysis"]["max_prob"] for r in results] - mean_max_prob = np.mean(max_probs) - entropies = [r["analysis"]["entropy"] for r in results] - mean_entropy = np.mean(entropies) + + # 对于逐步确认评估,计算平均步数准确率 + step_accuracies = [] + for r in results: + if "step_correct" in r: + step_correct = r["step_correct"] + if step_correct: + step_acc = sum(step_correct) / len(step_correct) + step_accuracies.append(step_acc) + + mean_step_accuracy = np.mean(step_accuracies) if step_accuracies else 0 slot_results[slot_count] = { "results": results, "accuracy": accuracy, "correct": correct_count, "total": len(results), - "mean_max_prob": mean_max_prob, - "mean_entropy": mean_entropy, - "low_variance_count": low_var_count, + "mean_step_accuracy": mean_step_accuracy, "failed": failed, } @@ -703,9 +790,7 @@ class TextEvaluator: print( f" 槽位 {slot_count}/8 | {bar} {accuracy:6.1%} " f"({correct_count:2d}/{len(results):2d}) " - f"| 平均最高概率: {mean_max_prob:.4f} " - f"| 平均熵: {mean_entropy:.2f} " - f"| 低方差: {low_var_count}" + f"| 平均步数准确率: {mean_step_accuracy:.1%}" ) else: slot_results[slot_count] = { @@ -713,9 +798,7 @@ class TextEvaluator: "accuracy": 0, "correct": 0, "total": 0, - "mean_max_prob": 0, - "mean_entropy": 0, - "low_variance_count": 0, + "mean_step_accuracy": 0, "failed": failed, } print(f" 槽位 {slot_count}/8 | 全部样本生成失败") @@ -740,44 +823,27 @@ class TextEvaluator: accuracy = correct_count / len(results) if results else 0 print(f"准确率: {accuracy:.1%} ({correct_count}/{len(results)})") - # 概率统计 - max_probs = [r["analysis"]["max_prob"] for r in results] - mean_max_prob = 
np.mean(max_probs) if max_probs else 0 - print(f"平均最高概率: {mean_max_prob:.4f}") + # 计算平均步数准确率(对于多字符预测) + step_accuracies = [] + for r in results: + if "step_correct" in r: + step_correct = r["step_correct"] + if step_correct: + step_acc = sum(step_correct) / len(step_correct) + step_accuracies.append(step_acc) - # 预测ID:0的比例 - zero_pred_count = sum(1 for r in results if r.get("predicted_label") == 0) - zero_pred_ratio = zero_pred_count / len(results) if results else 0 - print(f"预测[空]比例: {zero_pred_ratio:.1%} ({zero_pred_count}/{len(results)})") - - # 平均熵 - entropies = [r["analysis"]["entropy"] for r in results] - mean_entropy = np.mean(entropies) if entropies else 0 - print(f"平均熵: {mean_entropy:.2f}") + if step_accuracies: + mean_step_accuracy = np.mean(step_accuracies) + print(f"平均步数准确率: {mean_step_accuracy:.1%}") # 槽位和拼音长度统计 if results and "sample" in results[0]: valid_slots = [r["sample"].get("valid_slot_count", 0) for r in results] - pinyin_lengths = [ - r["sample"].get("pinyin_input_length", 0) for r in results - ] + pinyin_lengths = [r.get("pinyin_len", 1) for r in results] avg_slots = np.mean(valid_slots) if valid_slots else 0 avg_pinyin = np.mean(pinyin_lengths) if pinyin_lengths else 0 print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}") - # 低方差样本 - if low_variance_samples: - print(f"低方差样本: {len(low_variance_samples)}个") - if len(low_variance_samples) <= 5: - for lv in low_variance_samples: - print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}") - else: - for lv in low_variance_samples[:3]: - print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}") - print(f" ... 
# ---------------------------- Pinyin LSTM encoder ----------------------------
class PinyinLSTMEncoder(nn.Module):
    """Encode a pinyin embedding sequence into one global feature vector.

    A bidirectional LSTM reads the sequence; the final hidden states of the
    last layer (forward and backward direction) are concatenated, projected
    back to ``input_dim`` and layer-normalised.

    Args:
        input_dim: Size of each input embedding and of the output vector.
        hidden_dim: Per-direction LSTM hidden size. Defaults to
            ``input_dim // 2`` (at least 1).
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between LSTM layers; only applied when
            ``num_layers > 1``.
    """

    def __init__(self, input_dim, hidden_dim=None, num_layers=2, dropout=0.2):
        super().__init__()
        self.input_dim = input_dim
        # Guard against a zero hidden size when input_dim is 1.
        self.hidden_dim = (
            hidden_dim if hidden_dim is not None else max(1, input_dim // 2)
        )
        self.num_layers = num_layers
        self.dropout = dropout

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=self.hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )

        # Map the concatenated forward/backward states back to input_dim.
        self.proj = nn.Linear(self.hidden_dim * 2, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x, mask=None):
        """Return the pooled representation of ``x``.

        Args:
            x: ``[batch, seq_len, input_dim]`` pinyin embeddings.
            mask: Optional ``[batch, seq_len]`` mask, non-zero/True for real
                tokens and 0/False for padding. Only the number of valid
                positions per row is used, so padding is assumed to be
                trailing -- TODO confirm for all callers.

        Returns:
            ``[batch, input_dim]`` tensor.
        """
        if mask is None:
            # Without a mask every sequence is treated as full length.
            _, (hidden, _) = self.lstm(x)
        else:
            # pack_padded_sequence needs integer lengths on the CPU and
            # rejects zero-length rows, so an all-padding row is clamped to
            # length 1 (it then encodes its first padding embedding).
            lengths = mask.sum(dim=1).to(torch.int64).clamp(min=1).cpu()
            packed = nn.utils.rnn.pack_padded_sequence(
                x, lengths, batch_first=True, enforce_sorted=False
            )
            _, (hidden, _) = self.lstm(packed)

        # hidden: [num_layers * 2, batch, hidden_dim]; the last two entries
        # are the last layer's forward and backward final states.
        hidden_concat = torch.cat([hidden[-2], hidden[-1]], dim=1)
        return self.layer_norm(self.proj(hidden_concat))
Embedding Fusion: Text + Pinyin + Position # Broadcast pinyin to all text positions @@ -279,7 +336,7 @@ class CrossAttentionFusion(nn.Module): # 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类 # ------------------------------------------------------------------ class MoELayer(nn.Module): - def __init__(self, dim=512, num_experts=20, top_k=2, export_resblocks=4): + def __init__(self, dim=512, num_experts=10, top_k=3, num_resblocks=8): super().__init__() self.num_experts = num_experts self.top_k = top_k @@ -292,7 +349,7 @@ class MoELayer(nn.Module): Expert( input_dim=dim, d_model=dim, - num_resblocks=export_resblocks, + num_resblocks=num_resblocks, output_multiplier=1, ) for _ in range(num_experts) diff --git a/src/model/dataset.py b/src/model/dataset.py index 935f447..131ed34 100644 --- a/src/model/dataset.py +++ b/src/model/dataset.py @@ -260,13 +260,29 @@ class PinyinInputDataset(IterableDataset): part1 = text[0:i] else: part1 = text[i - 48 : i] + + # 方案C:提前检查从位置i开始连续有多少个字符在词库中 + max_valid_len = 0 + for j in range(i, min(i + 8, len(text))): + if self.query_engine.is_chinese_char(text[j]): + max_valid_len += 1 + else: + break + + # 如果没有可用字符,跳过 + if max_valid_len == 0: + continue + # 首先取随机值pinyin_len(1-8),pinyin_len取值呈高斯分布,最大概率取3 # 获取text[i + pinyin_len]字符,如果无法获取所指向的后,如果pinyin_len # part2的长度为x,取pinyin_list[i:i+pinyin_len],为part2 # 但是需要注意边界条件 - pinyin_len = np.random.choice( - range(1, 9), p=[0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02] + target_len = np.random.choice( + range(1, 9), p=[0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04] ) + # 根据实际可用长度调整 + pinyin_len = min(target_len, max_valid_len) + py_end = min(i + pinyin_len, len(text)) pinyin_len, part2 = self.get_mask_pinyin( text[i:py_end], pinyin_list[i:py_end] @@ -296,13 +312,13 @@ class PinyinInputDataset(IterableDataset): + np.random.choice(range(1, 17)) ] - # part4为文本,0.50的概率为空 + # part4为文本,0.30的概率为空 # 不为空则为1-5个连续字符串 # 连续字符串的取值方法为:随机从字符库中取一个字符,以及该字符后x个字符 # x为2-6中的任意整数,取值平均分布 # 使用|将part4中的字符串连接起来 part4 = 
"" - if random.random() > 0.5: + if random.random() > 0.7: # 生成1-5个连续字符串 num_strings = random.randint(1, 5) string_list = [] @@ -332,22 +348,6 @@ class PinyinInputDataset(IterableDataset): if random.random() <= 0.1: labels.append(0) - # 提取历史槽位:从预测位置i之前的字符中获取(与eval.py一致) - history_slot_list = [] - for j in range(i - 1, max(-1, i - 100), -1): - if j < 0: - break - char = text[j] - if self.query_engine.is_chinese_char(char): - try: - results = self.query_engine.query_by_char(char, limit=1) - if results and results[0][0] > 0: - history_slot_list.append(results[0][0]) - except Exception: - pass - if len(history_slot_list) >= 8: - break - encoded = self.tokenizer( f"{part4}|{part1}", part3, @@ -358,11 +358,23 @@ class PinyinInputDataset(IterableDataset): return_token_type_ids=True, ) samples = [] + # 历史槽位长度权重:增加长历史采样比例 + # 目标分布: H=0-2占45%, H=3-8占55% + history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0] + # 修复变量名冲突:将内层循环变量i重命名为label_idx for label_idx, label in enumerate(labels): - repeats = self.adjust_frequency(label) - # 使用从text[0:i]提取的历史槽位(与eval.py一致) - masked_labels = history_slot_list[:] + base_repeats = self.adjust_frequency(label) + # 根据历史槽位长度调整采样次数 + weight = ( + history_weights[label_idx] + if label_idx < len(history_weights) + else 3.0 + ) + repeats = max(1, int(base_repeats * weight)) + + # 历史槽位:同一拼音序列中已确认的字符(模拟用户逐步确认过程) + masked_labels = labels[:label_idx] len_l = len(masked_labels) masked_labels.extend([0] * (8 - len_l)) diff --git a/src/model/model.py b/src/model/model.py index b65834c..67035c4 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -38,7 +38,7 @@ class InputMethodEngine(nn.Module): num_slots: int = 8, # 历史槽位数量 (对应 README 中的 8 个槽位) n_layers: int = 4, # Transformer 层数 n_heads: int = 4, # 注意力头数 - num_experts: int = 20, # MoE 专家数量 + num_experts: int = 10, # MoE 专家数量 max_seq_len: int = 128, # 最大上下文长度 compile: bool = False, # 是否开启 torch.compile 优化 ): @@ -72,9 +72,12 @@ class InputMethodEngine(nn.Module): self.cross_attn = 
CrossAttentionFusion(dim=dim, n_heads=n_heads) # 4. 混合专家层 (MoE) - self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=2) + self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=8) - # 5. 分类头 + # 5. 槽位注意力池化 + self.slot_attention = nn.Linear(dim, 1) + + # 6. 分类头 self.classifier = nn.Linear(dim, vocab_size) # 开启 torch.compile 优化 (如果请求) @@ -127,19 +130,14 @@ class InputMethodEngine(nn.Module): # 4. MoE 处理 -> [batch, num_slots, dim] moe_out = self.moe(fused) - # 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile) + # 5. 槽位注意力池化 batch_size = input_ids.size(0) - slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1) - numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim] - denominator = slot_mask.view(batch_size, -1).sum(dim=1) # [batch] - - all_slot_mean = moe_out.mean(dim=1) # [batch, dim] - all_zero = denominator == 0 # [batch] - pooled = torch.where( - all_zero.unsqueeze(-1), - all_slot_mean, - numerator / (denominator.unsqueeze(-1) + 1e-8), - ) + # 计算注意力分数 [batch, num_slots, 1] -> [batch, num_slots] + slot_scores = self.slot_attention(moe_out).squeeze(-1) + # 应用softmax获取注意力权重 + slot_weights = torch.softmax(slot_scores, dim=1) # [batch, num_slots] + # 加权求和得到池化表示 + pooled = (moe_out * slot_weights.unsqueeze(-1)).sum(dim=1) # [batch, dim] logits = self.classifier(pooled) # [batch, vocab_size] return logits diff --git a/src/model/trainer.py b/src/model/trainer.py index 0e497be..38c2f47 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -56,11 +56,11 @@ class Trainer: total_steps: int, output_dir: str = "./output", num_epochs: int = 10, - learning_rate: float = 1e-4, + learning_rate: float = 2e-4, min_learning_rate: float = 1e-6, - weight_decay: float = 0.1, + weight_decay: float = 0.05, warmup_ratio: float = 0.1, - label_smoothing: float = 0.15, + label_smoothing: float = 0.1, loss_weight: Optional[torch.Tensor] = None, grad_accum_steps: int = 1, clip_grad_norm: float = 1.0, @@ -1060,14 +1060,14 @@ 
def train( # 训练参数 batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"), num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"), - learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"), + learning_rate: float = typer.Option(2e-4, "--learning-rate", "-lr", help="学习率"), min_learning_rate: float = typer.Option( 1e-9, "--min-learning-rate", help="最小学习率" ), - weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"), + weight_decay: float = typer.Option(0.05, "--weight-decay", help="权重衰减"), warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"), label_smoothing: float = typer.Option( - 0.15, "--label-smoothing", help="标签平滑参数" + 0.1, "--label-smoothing", help="标签平滑参数" ), grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"), clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"), @@ -1120,7 +1120,7 @@ def train( num_slots = 8 n_layers = 4 n_heads = 4 - num_experts = 20 + num_experts = 10 max_seq_len = 128 use_pinyin = True # 始终使用拼音 @@ -1383,14 +1383,14 @@ def expand_and_train( # 训练参数 batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"), num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"), - learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"), + learning_rate: float = typer.Option(2e-4, "--learning-rate", "-lr", help="学习率"), min_learning_rate: float = typer.Option( 1e-9, "--min-learning-rate", help="最小学习率" ), - weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"), + weight_decay: float = typer.Option(0.05, "--weight-decay", help="权重衰减"), warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"), label_smoothing: float = typer.Option( - 0.15, "--label-smoothing", help="标签平滑参数" + 0.1, "--label-smoothing", help="标签平滑参数" ), grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"), clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", 
help="梯度裁剪范数"), @@ -1440,7 +1440,7 @@ def expand_and_train( num_slots = 8 n_layers = 4 n_heads = 4 - num_experts = 20 + num_experts = 10 max_seq_len = 128 use_pinyin = True # 始终使用拼音 console = Console()