feat(MoELayer): 添加 moe_mode 支持稀疏和图内计算策略

This commit is contained in:
songsenand 2026-05-10 10:26:44 +08:00
parent e8eab1f260
commit 432132a108
6 changed files with 674 additions and 61 deletions

View File

@ -341,15 +341,37 @@ class CrossAttentionFusion(nn.Module):
# 4. 专家混合层 (MoE Layer)
# 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类
# ------------------------------------------------------------------
@torch.compiler.allow_in_graph
def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_experts):
output = torch.zeros_like(x_flat)
for e in range(num_experts):
mask = topk_indices == e
idx, k_idx = mask.nonzero(as_tuple=True)
if idx.numel() > 0:
w = topk_weights[idx, k_idx].unsqueeze(-1)
output.index_add_(0, idx, w * experts[e](x_flat[idx]))
return output
class MoELayer(nn.Module):
def __init__(self, dim=512, num_experts=10, top_k=3, num_resblocks=8):
"""
moe_mode 支持三种策略:
- "all": 计算全部专家torch.compile 不断裂 (当前默认)
- "sparse": 只计算被路由到的专家 (产生 graph break)
- "sparse_allow_graph": 稀疏 MoE通过 allow_in_graph 避免 graph break
"""
def __init__(
self, dim=512, num_experts=10, top_k=3, num_resblocks=8, moe_mode="all"
):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.dim = dim
self.moe_mode = moe_mode
# Import Expert from your existing components
# Assuming Expert class is defined as in components.py [2]
self.experts = nn.ModuleList(
[
Expert(
@ -362,48 +384,41 @@ class MoELayer(nn.Module):
]
)
# Gating Network [2]
self.gate = nn.Linear(dim, num_experts)
def forward(self, x):
"""
并行化 MoE 前向传播完全兼容 torch.compile AMP
Args:
x: [batch, seq_len, dim]
Returns:
out: [batch, seq_len, dim]
"""
B, L, D = x.shape
num_tokens = B * L
x_flat = x.view(num_tokens, D)
# 展平输入以便处理
x_flat = x.view(num_tokens, D) # [B*L, D]
gates = self.gate(x_flat)
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)
topk_weights = F.softmax(topk_weights, dim=-1)
# 1. 计算门控分数
gates = self.gate(x_flat) # [B*L, num_experts]
if self.moe_mode == "all":
expert_outputs = torch.stack(
[expert(x_flat) for expert in self.experts], dim=1
)
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D)
selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)
out_flat = weighted_outputs.sum(dim=1)
# 2. 选择 Top-K 专家
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K]
elif self.moe_mode == "sparse":
out_flat = torch.zeros_like(x_flat)
for e in range(self.num_experts):
mask = topk_indices == e
idx, k_idx = mask.nonzero(as_tuple=True)
if idx.numel() > 0:
w = topk_weights[idx, k_idx].unsqueeze(-1)
out_flat.index_add_(0, idx, w * self.experts[e](x_flat[idx]))
# 归一化权重
topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K]
elif self.moe_mode == "sparse_allow_graph":
out_flat = _sparse_moe_dispatch(
x_flat, self.experts, topk_indices, topk_weights, self.num_experts
)
# 3. 并行计算所有专家(消除 Python 循环中的动态控制流)
# torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量
expert_outputs = torch.stack(
[expert(x_flat) for expert in self.experts], dim=1
) # [B*L, num_experts, D]
else:
raise ValueError(f"Unknown moe_mode: {self.moe_mode}")
# 4. 使用 gather 选择对应专家的输出
# 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D]
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D]
selected_outputs = torch.gather(
expert_outputs, 1, indices_expanded
) # [B*L, K, D]
# 5. 加权求和
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D]
out_flat = weighted_outputs.sum(dim=1) # [B*L, D]
# 恢复原始形状
return out_flat.view(B, L, D)

View File

@ -4,7 +4,8 @@
功能
1. 统计 labels 的分布出现次数比例最大/最小未出现标签数
2. 随机抽样还原为人类可读文本导出为 CSV 文件
2. 统计 history_slot_ids 的有效长度分布
3. 抽样还原为人类可读文本导出为 CSV 文件
用法
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
@ -14,7 +15,6 @@
import argparse
import csv
import json
import random
from collections import Counter
from pathlib import Path
@ -86,6 +86,36 @@ def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0):
return counter, total
def analyze_history_slots(dataset: PreProcessedDataset, max_shards: int = 0):
    """Count samples by the number of non-zero entries in ``history_slot_ids``.

    Args:
        dataset: Dataset exposing ``_is_sharded``, ``_num_shards``, ``data_dir``
            and, when not sharded, ``history_slot_ids``.
        max_shards: Maximum number of shards to read; ``0`` (or negative)
            reads all of them.

    Returns:
        ``(counter, total)`` where ``counter`` maps an effective length to the
        number of samples with that length and ``total`` is the number of
        samples inspected.
    """
    logger.info("正在统计 history_slot_ids 长度分布...")
    counter = Counter()
    total = 0
    num_shards = dataset._num_shards if dataset._is_sharded else 1
    effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards
    pbar = tqdm(range(effective_shards), desc="统计 history slots", unit="shard")
    for shard_idx in pbar:
        if dataset._is_sharded:
            shard_path = dataset.data_dir / f"shard_{shard_idx:06d}.npz"
            # Read only the field we need and close the archive right away,
            # instead of decompressing every array in the shard.
            with np.load(shard_path) as shard_data:
                history_slots = shard_data["history_slot_ids"]
        else:
            history_slots = np.asarray(dataset.history_slot_ids[:])
        # Counting non-zero entries does not depend on dtype, so no widening
        # copy is made.
        lengths = np.count_nonzero(history_slots, axis=1)
        unique, counts = np.unique(lengths, return_counts=True)
        for uid, cnt in zip(unique, counts):
            # Store plain Python ints rather than NumPy scalars.
            counter[int(uid)] += int(cnt)
        total += len(history_slots)
    return counter, total
def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
"""将一个样本还原为人类可读格式"""
input_ids = (
@ -139,6 +169,7 @@ def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
pinyin_str = decode_pinyin_ids(pinyin_ids)
label_info = query_engine.query_by_id(labels)
history_str = decode_history(history_ids, query_engine)
history_slot_length = sum(1 for h in history_ids if h != 0)
return {
"context": context_text,
@ -150,6 +181,7 @@ def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
else f"<ID:{labels}>",
"label_count": label_info.count if label_info else 0,
"history": history_str,
"history_slot_length": history_slot_length,
"full_tokens": token_text,
}
@ -168,7 +200,7 @@ def main():
"--num-samples",
type=int,
default=50,
help="随机抽样的样本数量默认50",
help="抽样的样本数量默认50,取前 N 个样本",
)
parser.add_argument(
"--output",
@ -182,24 +214,14 @@ def main():
default=0,
help="统计 labels 时最多读取的分片数0=全部)",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="随机种子",
)
parser.add_argument(
"--top-k",
type=int,
default=30,
help="显示出现次数最多和最少的标签数量",
)
args = parser.parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
if args.output is None:
args.output = str(Path(args.data_dir) / "samples.csv")
@ -332,13 +354,56 @@ def main():
)
console.print(table_dist)
# ====== 2. 随机抽样还原 → CSV ======
num_samples = min(args.num_samples, len(dataset))
console.print(
f"\n[bold yellow]====== 随机抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
# ====== 2. 历史槽位长度分析 ======
console.print("\n[bold yellow]====== 历史槽位长度分析 ======[/bold yellow]")
history_counter, history_total = analyze_history_slots(
dataset, max_shards=args.max_shards
)
indices = random.sample(range(len(dataset)), num_samples)
if not history_counter:
console.print("[yellow] 无历史槽位数据[/yellow]")
else:
sorted_items = sorted(history_counter.items())
lengths_arr = [l for l, _ in sorted_items]
counts_arr = [c for _, c in sorted_items]
weighted_sum = sum(l * c for l, c in zip(lengths_arr, counts_arr))
avg_length = weighted_sum / history_total if history_total > 0 else 0
cumsum = 0
median_length = 0
for length, count in sorted_items:
cumsum += count
if cumsum >= history_total / 2:
median_length = length
break
console.print(f"\n总样本数: {history_total:,}")
console.print(f"最大历史槽位数: {max(lengths_arr)}")
console.print(f"最小历史槽位数: {min(lengths_arr)}")
console.print(f"平均历史槽位数: {avg_length:.2f}")
console.print(f"中位数历史槽位数: {median_length}")
hist_table = Table(
title="历史槽位有效长度分布",
show_header=True,
header_style="bold magenta",
)
hist_table.add_column("槽位数", style="cyan", width=10)
hist_table.add_column("样本数", style="green", width=12)
hist_table.add_column("占比", style="yellow", width=10)
for length, count in sorted_items:
pct = count / history_total * 100
hist_table.add_row(str(length), f"{count:,}", f"{pct:.2f}%")
console.print(hist_table)
# ====== 3. 抽样还原 → CSV ======
num_samples = min(args.num_samples, len(dataset))
console.print(
f"\n[bold yellow]====== 抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
)
indices = list(range(num_samples))
csv_path = Path(args.output)
csv_path.parent.mkdir(parents=True, exist_ok=True)
@ -352,6 +417,7 @@ def main():
"context",
"suffix",
"history",
"history_slot_length",
"full_tokens",
]
@ -372,6 +438,7 @@ def main():
decoded["context"],
decoded["suffix"],
decoded["history"],
decoded["history_slot_length"],
decoded["full_tokens"],
]
)

87
src/model/merge_shards.py Normal file
View File

@ -0,0 +1,87 @@
#!/usr/bin/env python3
"""将分片 .npz 合并为单个 .npz 文件,合并后删除原分片"""
import argparse
import gc
import json
from pathlib import Path
import numpy as np
from tqdm import tqdm
FIELDS = [
"input_ids",
"token_type_ids",
"attention_mask",
"labels",
"history_slot_ids",
"pinyin_ids",
]
def merge_split(data_dir: Path, output_name: str = "shard_000000.npz"):
    """Merge all ``shard_XXXXXX.npz`` files of one split into a single archive.

    The merged archive is written to a temporary file first; the source shards
    are removed only after that write succeeded, and the temporary file is then
    renamed to its final name. A failed write therefore never destroys data.
    ``metadata.json`` is updated to describe the single merged shard.

    Args:
        data_dir: Split directory containing ``metadata.json`` and the shards.
        output_name: File name of the merged archive inside ``data_dir``.
    """
    meta_path = data_dir / "metadata.json"
    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)
    num_shards = meta["num_shards"]
    if num_shards <= 1:
        print(f"{data_dir}: already single shard, skip")
        return

    shard_paths = [data_dir / f"shard_{i:06d}.npz" for i in range(num_shards)]

    merged = {}
    for field in FIELDS:
        pieces = []
        for shard_path in tqdm(
            shard_paths, desc=f"Merging {data_dir.name}/{field}", unit="shard"
        ):
            # The context manager closes the archive even if the read fails.
            with np.load(shard_path) as data:
                pieces.append(data[field])
        merged[field] = np.concatenate(pieces, axis=0)
        del pieces
        gc.collect()
    total = len(merged["labels"])

    # np.savez_compressed appends ".npz" when the name lacks it; mirror that so
    # the final file name is the same one a direct write would have produced.
    final_name = output_name if output_name.endswith(".npz") else f"{output_name}.npz"
    final_path = data_dir / final_name
    tmp_path = data_dir / f"{final_name}.tmp.npz"

    # Write the merged data before touching the source shards.
    try:
        np.savez_compressed(tmp_path, **merged)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    del merged
    gc.collect()

    # The merged copy is safely on disk; drop the sources (one of them may
    # carry the final name) and move the merged file into place.
    for shard_path in shard_paths:
        shard_path.unlink()
    tmp_path.replace(final_path)

    # Update metadata to describe the single merged shard.
    meta["num_shards"] = 1
    meta["shard_size"] = total
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    print(f"{data_dir.name}: {num_shards} shards → 1 shard, {total:,} samples")
def main():
    """Parse CLI arguments and merge the ``train`` and ``eval`` splits."""
    cli = argparse.ArgumentParser(description="合并分片 .npz 为单个文件")
    cli.add_argument(
        "--input-dir",
        type=str,
        required=True,
        help="数据集目录(含 train/ 和 eval/",
    )
    cli.add_argument(
        "--output-name",
        type=str,
        default="shard_000000.npz",
        help="合并后文件名",
    )
    cli.add_argument(
        "--train-only",
        action="store_true",
        help="仅合并 train跳过 eval",
    )
    options = cli.parse_args()

    dataset_root = Path(options.input_dir)
    for split_name in ("train", "eval"):
        target = dataset_root / split_name
        if not target.exists():
            # A missing directory is reported before the --train-only filter.
            print(f"{target}: not found, skip")
        elif not (options.train_only and split_name == "eval"):
            merge_split(target, options.output_name)
if __name__ == "__main__":
main()

View File

@ -35,12 +35,13 @@ class InputMethodEngine(nn.Module):
vocab_size: int = 10019,
pinyin_vocab_size: int = 28,
dim: int = 512,
num_slots: int = 8, # 历史槽位数量 (对应 README 中的 8 个槽位)
n_layers: int = 4, # Transformer 层数
n_heads: int = 4, # 注意力头数
num_experts: int = 10, # MoE 专家数量
max_seq_len: int = 128, # 最大上下文长度
compile: bool = False, # 是否开启 torch.compile 优化
num_slots: int = 8,
n_layers: int = 4,
n_heads: int = 4,
num_experts: int = 10,
max_seq_len: int = 128,
compile: bool = False,
moe_mode: str = "all", # "all" / "sparse" / "sparse_allow_graph"
):
super().__init__()
self.dim = dim
@ -72,7 +73,13 @@ class InputMethodEngine(nn.Module):
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
# 4. 混合专家层 (MoE)
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=12)
self.moe = MoELayer(
dim=dim,
num_experts=num_experts,
top_k=3,
num_resblocks=12,
moe_mode=moe_mode,
)
# 5. 槽位注意力池化
self.slot_attention = nn.Linear(dim, 1)

429
src/model/subsample.py Normal file
View File

@ -0,0 +1,429 @@
#!/usr/bin/env python3
"""
子采样脚本从预处理的 .npz 分片中抽取子集
策略
1. 第1遍扫描只读 labels 字段统计每个 label ID 的样本数记录分片大小
2. 中间计算按每 ID 最多 N 个样本封顶硬封顶不足目标总量则从超额池补足
3. 第2遍扫描读取全部字段按精确保留配额抽取训练样本同时抽取评估集
用法
python -m model.subsample \
--input-dir ./preprocessed \
--output-dir ./subsampled \
--cap-per-label 300000 \
--target-total 100000000 \
--num-eval 2560
"""
import argparse
import gc
import json
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
from loguru import logger
from rich.console import Console
from tqdm import tqdm
FIELDS = [
"input_ids",
"token_type_ids",
"attention_mask",
"labels",
"history_slot_ids",
"pinyin_ids",
]
def _global_to_shard(
    positions: np.ndarray, shard_sizes: List[int]
) -> Dict[int, List[int]]:
    """Group global sample positions by the shard that contains them.

    Args:
        positions: Global sample positions, counted across all shards in order.
        shard_sizes: Number of samples in each shard, in shard order.

    Returns:
        ``{shard_idx: [local_indices]}``; keys appear in the order their first
        position is encountered, and local indices keep the input order.
    """
    # offsets[i] is the global position of the first sample of shard i.
    offsets = np.cumsum([0] + shard_sizes)
    grouped: Dict[int, List[int]] = {}
    for global_pos in positions:
        # side="right" skips past empty shards that share the same offset.
        owner = int(np.searchsorted(offsets, global_pos, side="right")) - 1
        if owner not in grouped:
            grouped[owner] = []
        grouped[owner].append(global_pos - offsets[owner])
    return grouped
def pass1_count(input_dir: Path, split: str) -> Tuple[Dict[int, int], List[int]]:
    """First pass: read only ``labels`` and count samples per label id.

    Args:
        input_dir: Root directory of the preprocessed dataset.
        split: Name of the split sub-directory to scan.

    Returns:
        ``(label_counts, shard_sizes)`` where ``label_counts`` maps a label id
        to its total number of samples and ``shard_sizes`` lists the number of
        samples of every shard in shard order.
    """
    split_dir = input_dir / split
    with open(split_dir / "metadata.json", encoding="utf-8") as f:
        metadata = json.load(f)
    num_shards = metadata["num_shards"]

    label_counts: Dict[int, int] = {}
    shard_sizes: List[int] = []
    # Context managers close the progress bar and each archive even when a
    # shard is missing or corrupt.
    with tqdm(total=num_shards, desc="Pass 1: counting labels", unit="shard") as pbar:
        for shard_idx in range(num_shards):
            shard_path = split_dir / f"shard_{shard_idx:06d}.npz"
            with np.load(shard_path) as data:
                labels = data["labels"]
            shard_sizes.append(len(labels))
            unique, counts = np.unique(labels, return_counts=True)
            for uid, cnt in zip(unique, counts):
                uid = int(uid)
                label_counts[uid] = label_counts.get(uid, 0) + int(cnt)
            pbar.update(1)
    return label_counts, shard_sizes
def compute_quotas(
    label_counts: Dict[int, int],
    cap: int = 300_000,
    target_total: int = 100_000_000,
) -> Dict[int, int]:
    """Compute how many samples to keep for every label.

    Strategy:
        - Every label keeps ``min(count, cap)`` samples.
        - If that capped total already reaches ``target_total``, it is used
          as is.
        - Otherwise the deficit is drawn from the excess pool (the part of each
          label above ``cap``) proportionally to each label's excess. Shares
          are apportioned with integer largest-remainder rounding, so the
          result sums to exactly ``target_total`` whenever the pool is large
          enough, instead of undershooting through per-label truncation.

    Args:
        label_counts: Mapping ``label_id -> number of available samples``.
        cap: Per-label cap applied before topping up from the excess pool.
        target_total: Desired total number of kept samples.

    Returns:
        Mapping ``label_id -> number of samples to keep``; never more than the
        label's available count.
    """
    quotas: Dict[int, int] = {
        label: min(cnt, cap) for label, cnt in label_counts.items()
    }
    capped_total = sum(quotas.values())
    if capped_total >= target_total:
        return quotas

    excess = {label: cnt - cap for label, cnt in label_counts.items() if cnt > cap}
    excess_total = sum(excess.values())
    if excess_total == 0:
        # No label exceeds the cap, so there is nothing to top up from.
        return quotas

    # Never take more than the pool holds.
    deficit = min(target_total - capped_total, excess_total)

    granted = 0
    ranked = []
    for label, surplus in excess.items():
        share, remainder = divmod(surplus * deficit, excess_total)
        quotas[label] += share
        granted += share
        ranked.append((-remainder, label))

    # Hand the units lost to flooring to the labels with the largest
    # remainders (ties broken by label id for determinism). A label with a
    # non-zero remainder always has at least one unused excess sample.
    ranked.sort()
    for _, label in ranked[: deficit - granted]:
        quotas[label] += 1
    return quotas
def pass2_subsample(
    input_dir: Path,
    output_dir: Path,
    split: str,
    quotas: Dict[int, int],
    eval_map: Dict[int, List[int]],
    shard_sizes: List[int],
    shard_size: int = 5_000_000,
) -> Tuple[int, int]:
    """
    Second pass: read every field, extract the evaluation samples and
    subsample the training samples according to the per-label quotas.

    Args:
        input_dir: Root directory of the preprocessed dataset.
        output_dir: Root directory receiving ``train/`` and ``eval/``.
        split: Name of the input split sub-directory to read.
        quotas: ``{label_id: exact_number_to_keep}``; labels missing from the
            mapping keep no samples.
        eval_map: ``{shard_idx: [local_indices]}`` of samples reserved for the
            evaluation set; these are excluded from the training candidates.
        shard_sizes: Number of samples of every input shard, in shard order.
        shard_size: Buffered sample count at which a training shard is written.

    Returns:
        ``(train_kept, eval_kept)``
    """
    metadata_path = input_dir / split / "metadata.json"
    with open(metadata_path) as f:
        metadata = json.load(f)
    num_shards = metadata["num_shards"]
    max_seq_length = metadata["max_seq_length"]
    output_train_dir = output_dir / "train"
    output_eval_dir = output_dir / "eval"
    output_train_dir.mkdir(parents=True, exist_ok=True)
    output_eval_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the generator is unseeded, so two runs select different
    # samples — confirm this is intended.
    rng = np.random.RandomState()
    remaining = dict(quotas)  # {label_id: remaining_to_keep}
    train_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
    train_buf_count = 0
    train_shard_idx = 0
    total_train_kept = 0
    eval_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
    total_eval_kept = 0
    pbar = tqdm(total=num_shards, desc="Pass 2: subsampling", unit="shard")
    for src_shard_idx in range(num_shards):
        shard_path = input_dir / split / f"shard_{src_shard_idx:06d}.npz"
        data = np.load(shard_path)
        n = shard_sizes[src_shard_idx]
        eval_local = np.array(eval_map.get(src_shard_idx, []), dtype=np.int64)
        # Boolean mask over the shard marking samples reserved for evaluation.
        is_eval = np.zeros(n, dtype=bool)
        if len(eval_local) > 0:
            is_eval[eval_local] = True
        # ── Step 1: load only labels and compute the keep mask ──
        labels = data["labels"]
        # Extract the labels of the evaluation samples.
        if len(eval_local) > 0:
            eval_buffers["labels"].append(labels[eval_local].copy())
        # Work out which training candidates to keep.
        train_candidates = labels[~is_eval]
        train_n = len(train_candidates)
        n_kept = 0
        keep_original_indices = np.array([], dtype=np.int64)
        if train_n > 0:
            # A stable sort groups equal labels into contiguous runs;
            # starts/ends delimit the run of each unique label.
            sort_idx = np.argsort(train_candidates, kind="stable")
            sorted_labels = train_candidates[sort_idx]
            unique_vals, starts = np.unique(sorted_labels, return_index=True)
            ends = np.append(starts[1:], train_n)
            train_keep_mask = np.zeros(train_n, dtype=bool)
            for i in range(len(unique_vals)):
                label_val = int(unique_vals[i])
                start, end = starts[i], ends[i]
                cnt_in_shard = end - start
                need = remaining.get(label_val, 0)
                select = min(cnt_in_shard, need)
                if select >= cnt_in_shard:
                    # Quota covers the whole run: keep every sample.
                    train_keep_mask[sort_idx[start:end]] = True
                elif select > 0:
                    # Quota covers part of the run: pick a random subset.
                    chosen = rng.choice(cnt_in_shard, size=select, replace=False)
                    train_keep_mask[sort_idx[start + chosen]] = True
                # The quota is consumed in shard order, so earlier shards are
                # served before later ones.
                remaining[label_val] = max(0, remaining.get(label_val, 0) - select)
            # Map positions among the training candidates back to positions in
            # the full shard.
            original_train_indices = np.where(~is_eval)[0]
            keep_original_indices = original_train_indices[train_keep_mask]
            n_kept = len(keep_original_indices)
            # Release the label-related temporaries.
            del train_candidates, sort_idx, sorted_labels
            del unique_vals, starts, ends, train_keep_mask
            del original_train_indices
            gc.collect()
        del labels, is_eval
        gc.collect()
        # ── Step 2: load field by field, extract evaluation + training ──
        # Evaluation: skip labels (already extracted in step 1).
        if len(eval_local) > 0:
            for f in FIELDS:
                if f == "labels":
                    continue
                arr = data[f]
                eval_buffers[f].append(arr[eval_local].copy())
                del arr
                gc.collect()
            total_eval_kept += len(eval_local)
        # Training: load one field at a time and drop it immediately.
        if n_kept > 0:
            for f in FIELDS:
                arr = data[f]
                train_buffers[f].append(arr[keep_original_indices].copy())
                del arr
                gc.collect()
            train_buf_count += n_kept
            total_train_kept += n_kept
        del keep_original_indices
        data.close()
        gc.collect()
        # NOTE(review): the whole buffer is flushed once it reaches shard_size,
        # so a written shard can hold more than shard_size samples while the
        # metadata below records shard_size — verify that readers do not
        # assume fixed-size shards.
        if train_buf_count >= shard_size:
            merged = {}
            for f in FIELDS:
                merged[f] = np.concatenate(train_buffers[f], axis=0)
            np.savez_compressed(
                output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
            )
            logger.debug(
                f"Saved train shard {train_shard_idx}: {train_buf_count} samples"
            )
            train_shard_idx += 1
            train_buffers = {f: [] for f in FIELDS}
            train_buf_count = 0
            del merged
            gc.collect()
        pbar.update(1)
    pbar.close()
    # Flush whatever is left in the training buffers.
    if train_buf_count > 0:
        merged = {}
        for f in FIELDS:
            merged[f] = np.concatenate(train_buffers[f], axis=0)
        np.savez_compressed(
            output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
        )
        logger.debug(f"Saved train shard {train_shard_idx}: {train_buf_count} samples")
        train_shard_idx += 1
    # Evaluation data is written as a single shard.
    if total_eval_kept > 0:
        eval_merged = {}
        for f in FIELDS:
            eval_merged[f] = np.concatenate(eval_buffers[f], axis=0)
        np.savez_compressed(output_eval_dir / "shard_000000.npz", **eval_merged)
    else:
        logger.warning("No eval samples extracted!")
    # Write metadata.
    # NOTE(review): "dtype" is hard-coded rather than read from the arrays —
    # confirm it matches the input data.
    train_metadata = {
        "num_samples": total_train_kept,
        "max_seq_length": max_seq_length,
        "dtype": "int16",
        "fields": FIELDS,
        "shard_size": shard_size,
        "num_shards": train_shard_idx,
    }
    with open(output_train_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(train_metadata, f, indent=2, ensure_ascii=False)
    # NOTE(review): this records num_shards=1 even when no evaluation shard
    # was written above.
    eval_metadata = {
        "num_samples": total_eval_kept,
        "max_seq_length": max_seq_length,
        "dtype": "int16",
        "fields": FIELDS,
        "shard_size": total_eval_kept,
        "num_shards": 1,
    }
    with open(output_eval_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(eval_metadata, f, indent=2, ensure_ascii=False)
    return total_train_kept, total_eval_kept
def main():
    """CLI entry point: subsample a preprocessed dataset in two passes.

    Pass 1 counts samples per label, quotas are derived from those counts,
    evaluation positions are drawn at random, and pass 2 writes the training
    and evaluation shards. ``--num-eval`` is clamped to the number of
    available samples so that drawing without replacement cannot fail.
    """
    console = Console()
    parser = argparse.ArgumentParser(description="从预处理 .npz 分片中抽取子集")
    parser.add_argument("--input-dir", type=str, required=True, help="输入预处理目录")
    parser.add_argument("--output-dir", type=str, required=True, help="输出子采样目录")
    parser.add_argument(
        "--cap-per-label",
        type=int,
        default=300_000,
        help="每个 label ID 最大保留样本数",
    )
    parser.add_argument(
        "--target-total",
        type=int,
        default=100_000_000,
        help="目标训练集样本总数",
    )
    parser.add_argument(
        "--shard-size",
        type=int,
        default=5_000_000,
        help="输出分片大小(样本数)",
    )
    parser.add_argument("--num-eval", type=int, default=2560, help="评估集样本数")
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    console.print("[bold cyan]=== 子采样预处理数据 ===[/bold cyan]")
    console.print(f"输入目录: {input_dir}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"每 ID 封顶: {args.cap_per_label:,}")
    console.print(f"目标训练集: {args.target_total:,}")
    console.print(f"评估集: {args.num_eval}")
    console.print()

    # ── Pass 1: statistics ──
    console.print("[bold]第 1 遍扫描:统计标签分布...[/bold]")
    label_counts, shard_sizes = pass1_count(input_dir, "train")
    total_samples = sum(shard_sizes)
    num_labels = len(label_counts)
    capped_total = sum(min(cnt, args.cap_per_label) for cnt in label_counts.values())
    console.print(f"分片数: {len(shard_sizes)}")
    console.print(f"总样本数: {total_samples:,}")
    console.print(f"标签种类: {num_labels}")
    console.print(f"封顶后样本数 (每ID≤{args.cap_per_label:,}): {capped_total:,}")
    console.print()

    # ── Per-label quotas ──
    quotas = compute_quotas(
        label_counts,
        cap=args.cap_per_label,
        target_total=args.target_total,
    )
    expected_train = sum(quotas.values())
    console.print(f"期望训练集大小: {expected_train:,}")

    # Label distribution summary.
    stats_n_capped = sum(1 for cnt in label_counts.values() if cnt > args.cap_per_label)
    stats_n_below = num_labels - stats_n_capped
    stats_n_quota_full = sum(1 for q in quotas.values() if q >= args.cap_per_label)
    console.print(
        f"超过封顶的标签数: {stats_n_capped}, 不足封顶的标签数: {stats_n_below}, "
        f"配额打满(≥{args.cap_per_label // 10000}万): {stats_n_quota_full}"
    )
    console.print()

    # ── Draw evaluation positions ──
    # Sampling without replacement raises if more positions are requested than
    # exist, so clamp the request to the valid range first.
    num_eval = max(0, min(args.num_eval, total_samples))
    if num_eval != args.num_eval:
        console.print(
            f"[yellow]评估集样本数 {args.num_eval} 超出可用范围,已调整为 {num_eval}[/yellow]"
        )
    rng = np.random.RandomState()
    if num_eval > 0:
        eval_positions = rng.choice(total_samples, size=num_eval, replace=False)
        eval_positions.sort()
    else:
        eval_positions = np.array([], dtype=np.int64)
    eval_map = _global_to_shard(eval_positions, shard_sizes)
    console.print(f"评估集: {num_eval} 个位置已分配到 {len(eval_map)} 个分片中")
    console.print()

    # ── Pass 2: subsampling ──
    console.print("[bold]第 2 遍扫描:子采样 + 抽取评估集...[/bold]")
    train_count, eval_count = pass2_subsample(
        input_dir,
        output_dir,
        "train",
        quotas,
        eval_map,
        shard_sizes,
        args.shard_size,
    )

    # ── Summary ──
    console.print()
    console.print("[bold green]=== 完成 ===[/bold green]")
    console.print(f"训练集: {train_count:,} 样本")
    console.print(f"评估集: {eval_count:,} 样本")
    for split in ["train", "eval"]:
        sdir = output_dir / split
        if sdir.exists():
            total_size = sum(
                f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz"
            )
            meta_path = sdir / "metadata.json"
            if meta_path.exists():
                # The metadata is written as UTF-8, so read it back the same way.
                with open(meta_path, encoding="utf-8") as mf:
                    meta = json.load(mf)
            else:
                meta = {}
            console.print(
                f"  {split}/: {total_size / (1024**3):.2f} GB, "
                f"{meta.get('num_shards', '?')} shards"
            )
if __name__ == "__main__":
main()

View File

@ -1157,6 +1157,11 @@ def train(
"--compile/--no-compile",
help="是否开启 torch.compile 优化(需 PyTorch 2.0+",
),
moe_mode: str = typer.Option(
"all",
"--moe-mode",
help="MoE 计算策略: all全量计算, sparse稀疏计算, sparse_allow_graph稀疏+allow_in_graph",
),
):
"""
训练输入法模型
@ -1211,6 +1216,7 @@ def train(
config_table.add_row("模型", "MoE专家数", str(num_experts))
config_table.add_row("模型", "使用拼音", str(use_pinyin))
config_table.add_row("模型", "编译优化", str(compile))
config_table.add_row("模型", "MoE策略", moe_mode)
config_table.add_row("训练", "训练轮数", str(num_epochs))
config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
@ -1331,6 +1337,7 @@ def train(
"auto_resume": auto_resume,
"max_iter_length": max_iter_length,
"compile": compile,
"moe_mode": moe_mode,
"is_train_preprocessed": is_train_preprocessed,
"is_eval_preprocessed": is_eval_preprocessed,
"total_steps": total_steps,
@ -1353,6 +1360,7 @@ def train(
num_experts=num_experts,
max_seq_len=max_seq_len,
compile=compile,
moe_mode=moe_mode,
)
console.print(