feat(MoELayer): 添加 moe_mode 支持稀疏和图内计算策略
This commit is contained in:
parent
e8eab1f260
commit
432132a108
|
|
@ -341,15 +341,37 @@ class CrossAttentionFusion(nn.Module):
|
|||
# 4. 专家混合层 (MoE Layer)
|
||||
# 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@torch.compiler.allow_in_graph
|
||||
def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_experts):
|
||||
output = torch.zeros_like(x_flat)
|
||||
for e in range(num_experts):
|
||||
mask = topk_indices == e
|
||||
idx, k_idx = mask.nonzero(as_tuple=True)
|
||||
if idx.numel() > 0:
|
||||
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
||||
output.index_add_(0, idx, w * experts[e](x_flat[idx]))
|
||||
return output
|
||||
|
||||
|
||||
class MoELayer(nn.Module):
|
||||
def __init__(self, dim=512, num_experts=10, top_k=3, num_resblocks=8):
|
||||
"""
|
||||
moe_mode 支持三种策略:
|
||||
- "all": 计算全部专家,torch.compile 不断裂 (当前默认)
|
||||
- "sparse": 只计算被路由到的专家 (产生 graph break)
|
||||
- "sparse_allow_graph": 稀疏 MoE,通过 allow_in_graph 避免 graph break
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim=512, num_experts=10, top_k=3, num_resblocks=8, moe_mode="all"
|
||||
):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.dim = dim
|
||||
self.moe_mode = moe_mode
|
||||
|
||||
# Import Expert from your existing components
|
||||
# Assuming Expert class is defined as in components.py [2]
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
Expert(
|
||||
|
|
@ -362,48 +384,41 @@ class MoELayer(nn.Module):
|
|||
]
|
||||
)
|
||||
|
||||
# Gating Network [2]
|
||||
self.gate = nn.Linear(dim, num_experts)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
并行化 MoE 前向传播,完全兼容 torch.compile 和 AMP。
|
||||
|
||||
Args:
|
||||
x: [batch, seq_len, dim]
|
||||
Returns:
|
||||
out: [batch, seq_len, dim]
|
||||
"""
|
||||
B, L, D = x.shape
|
||||
num_tokens = B * L
|
||||
x_flat = x.view(num_tokens, D)
|
||||
|
||||
# 展平输入以便处理
|
||||
x_flat = x.view(num_tokens, D) # [B*L, D]
|
||||
gates = self.gate(x_flat)
|
||||
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)
|
||||
topk_weights = F.softmax(topk_weights, dim=-1)
|
||||
|
||||
# 1. 计算门控分数
|
||||
gates = self.gate(x_flat) # [B*L, num_experts]
|
||||
if self.moe_mode == "all":
|
||||
expert_outputs = torch.stack(
|
||||
[expert(x_flat) for expert in self.experts], dim=1
|
||||
)
|
||||
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D)
|
||||
selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)
|
||||
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)
|
||||
out_flat = weighted_outputs.sum(dim=1)
|
||||
|
||||
# 2. 选择 Top-K 专家
|
||||
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K]
|
||||
elif self.moe_mode == "sparse":
|
||||
out_flat = torch.zeros_like(x_flat)
|
||||
for e in range(self.num_experts):
|
||||
mask = topk_indices == e
|
||||
idx, k_idx = mask.nonzero(as_tuple=True)
|
||||
if idx.numel() > 0:
|
||||
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
||||
out_flat.index_add_(0, idx, w * self.experts[e](x_flat[idx]))
|
||||
|
||||
# 归一化权重
|
||||
topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K]
|
||||
elif self.moe_mode == "sparse_allow_graph":
|
||||
out_flat = _sparse_moe_dispatch(
|
||||
x_flat, self.experts, topk_indices, topk_weights, self.num_experts
|
||||
)
|
||||
|
||||
# 3. 并行计算所有专家(消除 Python 循环中的动态控制流)
|
||||
# torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量
|
||||
expert_outputs = torch.stack(
|
||||
[expert(x_flat) for expert in self.experts], dim=1
|
||||
) # [B*L, num_experts, D]
|
||||
else:
|
||||
raise ValueError(f"Unknown moe_mode: {self.moe_mode}")
|
||||
|
||||
# 4. 使用 gather 选择对应专家的输出
|
||||
# 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D]
|
||||
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D]
|
||||
selected_outputs = torch.gather(
|
||||
expert_outputs, 1, indices_expanded
|
||||
) # [B*L, K, D]
|
||||
# 5. 加权求和
|
||||
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D]
|
||||
out_flat = weighted_outputs.sum(dim=1) # [B*L, D]
|
||||
|
||||
# 恢复原始形状
|
||||
return out_flat.view(B, L, D)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@
|
|||
|
||||
功能:
|
||||
1. 统计 labels 的分布(出现次数、比例,最大/最小,未出现标签数)
|
||||
2. 随机抽样还原为人类可读文本,导出为 CSV 文件
|
||||
2. 统计 history_slot_ids 的有效长度分布
|
||||
3. 抽样还原为人类可读文本,导出为 CSV 文件
|
||||
|
||||
用法:
|
||||
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
|
||||
|
|
@ -14,7 +15,6 @@
|
|||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import random
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -86,6 +86,36 @@ def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0):
|
|||
return counter, total
|
||||
|
||||
|
||||
def analyze_history_slots(dataset: PreProcessedDataset, max_shards: int = 0):
|
||||
"""统计 history_slot_ids 的有效长度分布(非零元素个数)"""
|
||||
logger.info("正在统计 history_slot_ids 长度分布...")
|
||||
counter = Counter()
|
||||
total = 0
|
||||
|
||||
num_shards = dataset._num_shards if dataset._is_sharded else 1
|
||||
effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards
|
||||
|
||||
pbar = tqdm(range(effective_shards), desc="统计 history slots", unit="shard")
|
||||
|
||||
for shard_idx in pbar:
|
||||
if dataset._is_sharded:
|
||||
shard_data = dict(np.load(dataset.data_dir / f"shard_{shard_idx:06d}.npz"))
|
||||
history_slots = shard_data["history_slot_ids"].astype(np.int64)
|
||||
else:
|
||||
history_slots = dataset.history_slot_ids[:].astype(np.int64)
|
||||
|
||||
lengths = np.count_nonzero(history_slots, axis=1)
|
||||
unique, counts = np.unique(lengths, return_counts=True)
|
||||
for uid, cnt in zip(unique, counts):
|
||||
counter[int(uid)] += cnt
|
||||
total += len(history_slots)
|
||||
|
||||
if dataset._is_sharded:
|
||||
del shard_data
|
||||
|
||||
return counter, total
|
||||
|
||||
|
||||
def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
|
||||
"""将一个样本还原为人类可读格式"""
|
||||
input_ids = (
|
||||
|
|
@ -139,6 +169,7 @@ def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
|
|||
pinyin_str = decode_pinyin_ids(pinyin_ids)
|
||||
label_info = query_engine.query_by_id(labels)
|
||||
history_str = decode_history(history_ids, query_engine)
|
||||
history_slot_length = sum(1 for h in history_ids if h != 0)
|
||||
|
||||
return {
|
||||
"context": context_text,
|
||||
|
|
@ -150,6 +181,7 @@ def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
|
|||
else f"<ID:{labels}>",
|
||||
"label_count": label_info.count if label_info else 0,
|
||||
"history": history_str,
|
||||
"history_slot_length": history_slot_length,
|
||||
"full_tokens": token_text,
|
||||
}
|
||||
|
||||
|
|
@ -168,7 +200,7 @@ def main():
|
|||
"--num-samples",
|
||||
type=int,
|
||||
default=50,
|
||||
help="随机抽样的样本数量(默认50)",
|
||||
help="抽样的样本数量(默认50,取前 N 个样本)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
|
|
@ -182,24 +214,14 @@ def main():
|
|||
default=0,
|
||||
help="统计 labels 时最多读取的分片数(0=全部)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="随机种子",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=30,
|
||||
help="显示出现次数最多和最少的标签数量",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
if args.output is None:
|
||||
args.output = str(Path(args.data_dir) / "samples.csv")
|
||||
|
||||
|
|
@ -332,13 +354,56 @@ def main():
|
|||
)
|
||||
console.print(table_dist)
|
||||
|
||||
# ====== 2. 随机抽样还原 → CSV ======
|
||||
num_samples = min(args.num_samples, len(dataset))
|
||||
console.print(
|
||||
f"\n[bold yellow]====== 随机抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
|
||||
# ====== 2. 历史槽位长度分析 ======
|
||||
console.print("\n[bold yellow]====== 历史槽位长度分析 ======[/bold yellow]")
|
||||
history_counter, history_total = analyze_history_slots(
|
||||
dataset, max_shards=args.max_shards
|
||||
)
|
||||
|
||||
indices = random.sample(range(len(dataset)), num_samples)
|
||||
if not history_counter:
|
||||
console.print("[yellow] 无历史槽位数据[/yellow]")
|
||||
else:
|
||||
sorted_items = sorted(history_counter.items())
|
||||
lengths_arr = [l for l, _ in sorted_items]
|
||||
counts_arr = [c for _, c in sorted_items]
|
||||
weighted_sum = sum(l * c for l, c in zip(lengths_arr, counts_arr))
|
||||
avg_length = weighted_sum / history_total if history_total > 0 else 0
|
||||
|
||||
cumsum = 0
|
||||
median_length = 0
|
||||
for length, count in sorted_items:
|
||||
cumsum += count
|
||||
if cumsum >= history_total / 2:
|
||||
median_length = length
|
||||
break
|
||||
|
||||
console.print(f"\n总样本数: {history_total:,}")
|
||||
console.print(f"最大历史槽位数: {max(lengths_arr)}")
|
||||
console.print(f"最小历史槽位数: {min(lengths_arr)}")
|
||||
console.print(f"平均历史槽位数: {avg_length:.2f}")
|
||||
console.print(f"中位数历史槽位数: {median_length}")
|
||||
|
||||
hist_table = Table(
|
||||
title="历史槽位有效长度分布",
|
||||
show_header=True,
|
||||
header_style="bold magenta",
|
||||
)
|
||||
hist_table.add_column("槽位数", style="cyan", width=10)
|
||||
hist_table.add_column("样本数", style="green", width=12)
|
||||
hist_table.add_column("占比", style="yellow", width=10)
|
||||
|
||||
for length, count in sorted_items:
|
||||
pct = count / history_total * 100
|
||||
hist_table.add_row(str(length), f"{count:,}", f"{pct:.2f}%")
|
||||
console.print(hist_table)
|
||||
|
||||
# ====== 3. 抽样还原 → CSV ======
|
||||
num_samples = min(args.num_samples, len(dataset))
|
||||
console.print(
|
||||
f"\n[bold yellow]====== 抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
|
||||
)
|
||||
|
||||
indices = list(range(num_samples))
|
||||
|
||||
csv_path = Path(args.output)
|
||||
csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -352,6 +417,7 @@ def main():
|
|||
"context",
|
||||
"suffix",
|
||||
"history",
|
||||
"history_slot_length",
|
||||
"full_tokens",
|
||||
]
|
||||
|
||||
|
|
@ -372,6 +438,7 @@ def main():
|
|||
decoded["context"],
|
||||
decoded["suffix"],
|
||||
decoded["history"],
|
||||
decoded["history_slot_length"],
|
||||
decoded["full_tokens"],
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,87 @@
|
|||
#!/usr/bin/env python3
|
||||
"""将分片 .npz 合并为单个 .npz 文件,合并后删除原分片"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
FIELDS = [
|
||||
"input_ids",
|
||||
"token_type_ids",
|
||||
"attention_mask",
|
||||
"labels",
|
||||
"history_slot_ids",
|
||||
"pinyin_ids",
|
||||
]
|
||||
|
||||
|
||||
def merge_split(data_dir: Path, output_name: str = "shard_000000.npz"):
|
||||
with open(data_dir / "metadata.json") as f:
|
||||
meta = json.load(f)
|
||||
num_shards = meta["num_shards"]
|
||||
if num_shards <= 1:
|
||||
print(f"{data_dir}: already single shard, skip")
|
||||
return
|
||||
|
||||
merged = {}
|
||||
for field in FIELDS:
|
||||
pieces = []
|
||||
for i in tqdm(
|
||||
range(num_shards), desc=f"Merging {data_dir.name}/{field}", unit="shard"
|
||||
):
|
||||
data = np.load(data_dir / f"shard_{i:06d}.npz")
|
||||
pieces.append(data[field])
|
||||
data.close()
|
||||
merged[field] = np.concatenate(pieces, axis=0)
|
||||
del pieces
|
||||
gc.collect()
|
||||
|
||||
total = len(merged["labels"])
|
||||
# 先删原分片(避免与新文件名冲突)
|
||||
for i in range(num_shards):
|
||||
(data_dir / f"shard_{i:06d}.npz").unlink()
|
||||
# 再写合并文件
|
||||
np.savez_compressed(data_dir / output_name, **merged)
|
||||
del merged
|
||||
gc.collect()
|
||||
|
||||
# 更新 metadata
|
||||
meta["num_shards"] = 1
|
||||
meta["shard_size"] = total
|
||||
with open(data_dir / "metadata.json", "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"{data_dir.name}: {num_shards} shards → 1 shard, {total:,} samples")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="合并分片 .npz 为单个文件")
|
||||
parser.add_argument(
|
||||
"--input-dir", type=str, required=True, help="数据集目录(含 train/ 和 eval/)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-name", type=str, default="shard_000000.npz", help="合并后文件名"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train-only", action="store_true", help="仅合并 train,跳过 eval"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
root = Path(args.input_dir)
|
||||
|
||||
for split in ["train", "eval"]:
|
||||
split_dir = root / split
|
||||
if not split_dir.exists():
|
||||
print(f"{split_dir}: not found, skip")
|
||||
continue
|
||||
if args.train_only and split == "eval":
|
||||
continue
|
||||
merge_split(split_dir, args.output_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -35,12 +35,13 @@ class InputMethodEngine(nn.Module):
|
|||
vocab_size: int = 10019,
|
||||
pinyin_vocab_size: int = 28,
|
||||
dim: int = 512,
|
||||
num_slots: int = 8, # 历史槽位数量 (对应 README 中的 8 个槽位)
|
||||
n_layers: int = 4, # Transformer 层数
|
||||
n_heads: int = 4, # 注意力头数
|
||||
num_experts: int = 10, # MoE 专家数量
|
||||
max_seq_len: int = 128, # 最大上下文长度
|
||||
compile: bool = False, # 是否开启 torch.compile 优化
|
||||
num_slots: int = 8,
|
||||
n_layers: int = 4,
|
||||
n_heads: int = 4,
|
||||
num_experts: int = 10,
|
||||
max_seq_len: int = 128,
|
||||
compile: bool = False,
|
||||
moe_mode: str = "all", # "all" / "sparse" / "sparse_allow_graph"
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
@ -72,7 +73,13 @@ class InputMethodEngine(nn.Module):
|
|||
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
|
||||
|
||||
# 4. 混合专家层 (MoE)
|
||||
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=12)
|
||||
self.moe = MoELayer(
|
||||
dim=dim,
|
||||
num_experts=num_experts,
|
||||
top_k=3,
|
||||
num_resblocks=12,
|
||||
moe_mode=moe_mode,
|
||||
)
|
||||
|
||||
# 5. 槽位注意力池化
|
||||
self.slot_attention = nn.Linear(dim, 1)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,429 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
子采样脚本:从预处理的 .npz 分片中抽取子集。
|
||||
|
||||
策略:
|
||||
1. 第1遍扫描:只读 labels 字段,统计每个 label ID 的样本数,记录分片大小
|
||||
2. 中间计算:按每 ID 最多 N 个样本封顶(硬封顶),不足目标总量则从超额池补足
|
||||
3. 第2遍扫描:读取全部字段,按精确保留配额抽取训练样本,同时抽取评估集
|
||||
|
||||
用法:
|
||||
python -m model.subsample \
|
||||
--input-dir ./preprocessed \
|
||||
--output-dir ./subsampled \
|
||||
--cap-per-label 300000 \
|
||||
--target-total 100000000 \
|
||||
--num-eval 2560
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from tqdm import tqdm
|
||||
|
||||
FIELDS = [
|
||||
"input_ids",
|
||||
"token_type_ids",
|
||||
"attention_mask",
|
||||
"labels",
|
||||
"history_slot_ids",
|
||||
"pinyin_ids",
|
||||
]
|
||||
|
||||
|
||||
def _global_to_shard(
|
||||
positions: np.ndarray, shard_sizes: List[int]
|
||||
) -> Dict[int, List[int]]:
|
||||
"""将全局位置映射为 {shard_idx: [local_indices]}"""
|
||||
cumsum = np.cumsum([0] + shard_sizes)
|
||||
mapping: Dict[int, List[int]] = {}
|
||||
for pos in positions:
|
||||
shard_idx = int(np.searchsorted(cumsum, pos, side="right") - 1)
|
||||
local_idx = pos - cumsum[shard_idx]
|
||||
mapping.setdefault(shard_idx, []).append(local_idx)
|
||||
return mapping
|
||||
|
||||
|
||||
def pass1_count(input_dir: Path, split: str) -> Tuple[Dict[int, int], List[int]]:
|
||||
"""
|
||||
第1遍扫描:只读 labels,统计每个 label ID 的总样本数,记录各分片大小。
|
||||
返回 (label_counts, shard_sizes)。
|
||||
"""
|
||||
metadata_path = input_dir / split / "metadata.json"
|
||||
with open(metadata_path) as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
num_shards = metadata["num_shards"]
|
||||
label_counts: Dict[int, int] = {}
|
||||
shard_sizes: List[int] = []
|
||||
|
||||
pbar = tqdm(total=num_shards, desc="Pass 1: counting labels", unit="shard")
|
||||
|
||||
for shard_idx in range(num_shards):
|
||||
shard_path = input_dir / split / f"shard_{shard_idx:06d}.npz"
|
||||
data = np.load(shard_path)
|
||||
labels = data["labels"]
|
||||
n = len(labels)
|
||||
shard_sizes.append(n)
|
||||
|
||||
unique, counts = np.unique(labels, return_counts=True)
|
||||
for uid, cnt in zip(unique, counts):
|
||||
uid = int(uid)
|
||||
label_counts[uid] = label_counts.get(uid, 0) + int(cnt)
|
||||
|
||||
data.close()
|
||||
pbar.update(1)
|
||||
|
||||
pbar.close()
|
||||
return label_counts, shard_sizes
|
||||
|
||||
|
||||
def compute_quotas(
|
||||
label_counts: Dict[int, int],
|
||||
cap: int = 300_000,
|
||||
target_total: int = 100_000_000,
|
||||
) -> Dict[int, int]:
|
||||
"""
|
||||
计算每个 label 的精确保留配额(硬封顶)。
|
||||
|
||||
策略:
|
||||
- 每 ID 保留 min(count, cap)
|
||||
- 若封顶后总量 >= target_total,直接用封顶策略
|
||||
- 若封顶后总量 < target_total,从超额池(count > cap 的部分)等比抽取补足
|
||||
"""
|
||||
capped_total = sum(min(cnt, cap) for cnt in label_counts.values())
|
||||
|
||||
if capped_total >= target_total:
|
||||
return {label: min(cnt, cap) for label, cnt in label_counts.items()}
|
||||
|
||||
# 封顶不足,从超额池补足
|
||||
deficit = target_total - capped_total
|
||||
excess_per_label = {lbl: max(0, cnt - cap) for lbl, cnt in label_counts.items()}
|
||||
excess_total = sum(excess_per_label.values())
|
||||
|
||||
quotas: Dict[int, int] = {}
|
||||
if excess_total > 0:
|
||||
ratio = min(1.0, deficit / excess_total)
|
||||
for label, cnt in label_counts.items():
|
||||
base = min(cnt, cap)
|
||||
extra = int(excess_per_label[label] * ratio)
|
||||
quotas[label] = base + extra
|
||||
else:
|
||||
quotas = {label: min(cnt, cap) for label, cnt in label_counts.items()}
|
||||
|
||||
return quotas
|
||||
|
||||
|
||||
def pass2_subsample(
|
||||
input_dir: Path,
|
||||
output_dir: Path,
|
||||
split: str,
|
||||
quotas: Dict[int, int],
|
||||
eval_map: Dict[int, List[int]],
|
||||
shard_sizes: List[int],
|
||||
shard_size: int = 5_000_000,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
第2遍扫描:读取全部字段,抽取评估样本 + 按精确保留配额子采样训练样本。
|
||||
|
||||
quotas: {label_id: exact_number_to_keep}
|
||||
返回 (train_kept, eval_kept)。
|
||||
"""
|
||||
metadata_path = input_dir / split / "metadata.json"
|
||||
with open(metadata_path) as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
num_shards = metadata["num_shards"]
|
||||
max_seq_length = metadata["max_seq_length"]
|
||||
|
||||
output_train_dir = output_dir / "train"
|
||||
output_eval_dir = output_dir / "eval"
|
||||
output_train_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_eval_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
rng = np.random.RandomState()
|
||||
remaining = dict(quotas) # {label_id: remaining_to_keep}
|
||||
|
||||
train_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
|
||||
train_buf_count = 0
|
||||
train_shard_idx = 0
|
||||
total_train_kept = 0
|
||||
|
||||
eval_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
|
||||
total_eval_kept = 0
|
||||
|
||||
pbar = tqdm(total=num_shards, desc="Pass 2: subsampling", unit="shard")
|
||||
|
||||
for src_shard_idx in range(num_shards):
|
||||
shard_path = input_dir / split / f"shard_{src_shard_idx:06d}.npz"
|
||||
data = np.load(shard_path)
|
||||
n = shard_sizes[src_shard_idx]
|
||||
|
||||
eval_local = np.array(eval_map.get(src_shard_idx, []), dtype=np.int64)
|
||||
|
||||
is_eval = np.zeros(n, dtype=bool)
|
||||
if len(eval_local) > 0:
|
||||
is_eval[eval_local] = True
|
||||
|
||||
# ── 第一步:只加载 labels,计算掩码 ──
|
||||
labels = data["labels"]
|
||||
|
||||
# 抽取评估样本的 labels
|
||||
if len(eval_local) > 0:
|
||||
eval_buffers["labels"].append(labels[eval_local].copy())
|
||||
|
||||
# 计算训练保留位置
|
||||
train_candidates = labels[~is_eval]
|
||||
train_n = len(train_candidates)
|
||||
n_kept = 0
|
||||
keep_original_indices = np.array([], dtype=np.int64)
|
||||
|
||||
if train_n > 0:
|
||||
sort_idx = np.argsort(train_candidates, kind="stable")
|
||||
sorted_labels = train_candidates[sort_idx]
|
||||
unique_vals, starts = np.unique(sorted_labels, return_index=True)
|
||||
ends = np.append(starts[1:], train_n)
|
||||
|
||||
train_keep_mask = np.zeros(train_n, dtype=bool)
|
||||
for i in range(len(unique_vals)):
|
||||
label_val = int(unique_vals[i])
|
||||
start, end = starts[i], ends[i]
|
||||
cnt_in_shard = end - start
|
||||
need = remaining.get(label_val, 0)
|
||||
select = min(cnt_in_shard, need)
|
||||
if select >= cnt_in_shard:
|
||||
train_keep_mask[sort_idx[start:end]] = True
|
||||
elif select > 0:
|
||||
chosen = rng.choice(cnt_in_shard, size=select, replace=False)
|
||||
train_keep_mask[sort_idx[start + chosen]] = True
|
||||
remaining[label_val] = max(0, remaining.get(label_val, 0) - select)
|
||||
|
||||
original_train_indices = np.where(~is_eval)[0]
|
||||
keep_original_indices = original_train_indices[train_keep_mask]
|
||||
n_kept = len(keep_original_indices)
|
||||
|
||||
# 释放标签相关的临时数组
|
||||
del train_candidates, sort_idx, sorted_labels
|
||||
del unique_vals, starts, ends, train_keep_mask
|
||||
del original_train_indices
|
||||
gc.collect()
|
||||
|
||||
del labels, is_eval
|
||||
gc.collect()
|
||||
|
||||
# ── 第二步:逐字段加载,抽取评估 + 训练 ──
|
||||
# 评估:跳过 labels(已在第一步抽取)
|
||||
if len(eval_local) > 0:
|
||||
for f in FIELDS:
|
||||
if f == "labels":
|
||||
continue
|
||||
arr = data[f]
|
||||
eval_buffers[f].append(arr[eval_local].copy())
|
||||
del arr
|
||||
gc.collect()
|
||||
total_eval_kept += len(eval_local)
|
||||
|
||||
# 训练:逐字段加载,立即删除
|
||||
if n_kept > 0:
|
||||
for f in FIELDS:
|
||||
arr = data[f]
|
||||
train_buffers[f].append(arr[keep_original_indices].copy())
|
||||
del arr
|
||||
gc.collect()
|
||||
train_buf_count += n_kept
|
||||
total_train_kept += n_kept
|
||||
|
||||
del keep_original_indices
|
||||
data.close()
|
||||
gc.collect()
|
||||
|
||||
if train_buf_count >= shard_size:
|
||||
merged = {}
|
||||
for f in FIELDS:
|
||||
merged[f] = np.concatenate(train_buffers[f], axis=0)
|
||||
np.savez_compressed(
|
||||
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
|
||||
)
|
||||
logger.debug(
|
||||
f"Saved train shard {train_shard_idx}: {train_buf_count} samples"
|
||||
)
|
||||
train_shard_idx += 1
|
||||
train_buffers = {f: [] for f in FIELDS}
|
||||
train_buf_count = 0
|
||||
del merged
|
||||
gc.collect()
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
pbar.close()
|
||||
|
||||
# 剩余缓冲
|
||||
if train_buf_count > 0:
|
||||
merged = {}
|
||||
for f in FIELDS:
|
||||
merged[f] = np.concatenate(train_buffers[f], axis=0)
|
||||
np.savez_compressed(
|
||||
output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
|
||||
)
|
||||
logger.debug(f"Saved train shard {train_shard_idx}: {train_buf_count} samples")
|
||||
train_shard_idx += 1
|
||||
|
||||
# 评估数据
|
||||
if total_eval_kept > 0:
|
||||
eval_merged = {}
|
||||
for f in FIELDS:
|
||||
eval_merged[f] = np.concatenate(eval_buffers[f], axis=0)
|
||||
np.savez_compressed(output_eval_dir / "shard_000000.npz", **eval_merged)
|
||||
else:
|
||||
logger.warning("No eval samples extracted!")
|
||||
|
||||
# 写 metadata
|
||||
train_metadata = {
|
||||
"num_samples": total_train_kept,
|
||||
"max_seq_length": max_seq_length,
|
||||
"dtype": "int16",
|
||||
"fields": FIELDS,
|
||||
"shard_size": shard_size,
|
||||
"num_shards": train_shard_idx,
|
||||
}
|
||||
with open(output_train_dir / "metadata.json", "w", encoding="utf-8") as f:
|
||||
json.dump(train_metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
eval_metadata = {
|
||||
"num_samples": total_eval_kept,
|
||||
"max_seq_length": max_seq_length,
|
||||
"dtype": "int16",
|
||||
"fields": FIELDS,
|
||||
"shard_size": total_eval_kept,
|
||||
"num_shards": 1,
|
||||
}
|
||||
with open(output_eval_dir / "metadata.json", "w", encoding="utf-8") as f:
|
||||
json.dump(eval_metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
return total_train_kept, total_eval_kept
|
||||
|
||||
|
||||
def main():
|
||||
console = Console()
|
||||
|
||||
parser = argparse.ArgumentParser(description="从预处理 .npz 分片中抽取子集")
|
||||
parser.add_argument("--input-dir", type=str, required=True, help="输入预处理目录")
|
||||
parser.add_argument("--output-dir", type=str, required=True, help="输出子采样目录")
|
||||
parser.add_argument(
|
||||
"--cap-per-label",
|
||||
type=int,
|
||||
default=300_000,
|
||||
help="每个 label ID 最大保留样本数",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-total",
|
||||
type=int,
|
||||
default=100_000_000,
|
||||
help="目标训练集样本总数",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard-size",
|
||||
type=int,
|
||||
default=5_000_000,
|
||||
help="输出分片大小(样本数)",
|
||||
)
|
||||
parser.add_argument("--num-eval", type=int, default=2560, help="评估集样本数")
|
||||
|
||||
args = parser.parse_args()
|
||||
input_dir = Path(args.input_dir)
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
console.print("[bold cyan]=== 子采样预处理数据 ===[/bold cyan]")
|
||||
console.print(f"输入目录: {input_dir}")
|
||||
console.print(f"输出目录: {output_dir}")
|
||||
console.print(f"每 ID 封顶: {args.cap_per_label:,}")
|
||||
console.print(f"目标训练集: {args.target_total:,}")
|
||||
console.print(f"评估集: {args.num_eval}")
|
||||
console.print()
|
||||
|
||||
# ── 第 1 遍:统计 ──
|
||||
console.print("[bold]第 1 遍扫描:统计标签分布...[/bold]")
|
||||
label_counts, shard_sizes = pass1_count(input_dir, "train")
|
||||
|
||||
total_samples = sum(shard_sizes)
|
||||
num_labels = len(label_counts)
|
||||
capped_total = sum(min(cnt, args.cap_per_label) for cnt in label_counts.values())
|
||||
|
||||
console.print(f"分片数: {len(shard_sizes)}")
|
||||
console.print(f"总样本数: {total_samples:,}")
|
||||
console.print(f"标签种类: {num_labels}")
|
||||
console.print(f"封顶后样本数 (每ID≤{args.cap_per_label:,}): {capped_total:,}")
|
||||
console.print()
|
||||
|
||||
# ── 计算保留配额 ──
|
||||
quotas = compute_quotas(
|
||||
label_counts,
|
||||
cap=args.cap_per_label,
|
||||
target_total=args.target_total,
|
||||
)
|
||||
expected_train = sum(quotas.values())
|
||||
console.print(f"期望训练集大小: {expected_train:,}")
|
||||
|
||||
# 标签分布统计
|
||||
stats_n_capped = sum(1 for cnt in label_counts.values() if cnt > args.cap_per_label)
|
||||
stats_n_below = num_labels - stats_n_capped
|
||||
stats_n_quota_full = sum(1 for lbl, q in quotas.items() if q >= args.cap_per_label)
|
||||
console.print(
|
||||
f"超过封顶的标签数: {stats_n_capped}, 不足封顶的标签数: {stats_n_below}, "
|
||||
f"配额打满(≥{args.cap_per_label // 10000}万): {stats_n_quota_full}"
|
||||
)
|
||||
console.print()
|
||||
|
||||
# ── 抽取评估集位置 ──
|
||||
rng = np.random.RandomState()
|
||||
eval_positions = rng.choice(total_samples, size=args.num_eval, replace=False)
|
||||
eval_positions.sort()
|
||||
eval_map = _global_to_shard(eval_positions, shard_sizes)
|
||||
console.print(f"评估集: {args.num_eval} 个位置已分配到 {len(eval_map)} 个分片中")
|
||||
console.print()
|
||||
|
||||
# ── 第 2 遍:子采样 ──
|
||||
console.print("[bold]第 2 遍扫描:子采样 + 抽取评估集...[/bold]")
|
||||
train_count, eval_count = pass2_subsample(
|
||||
input_dir,
|
||||
output_dir,
|
||||
"train",
|
||||
quotas,
|
||||
eval_map,
|
||||
shard_sizes,
|
||||
args.shard_size,
|
||||
)
|
||||
|
||||
# ── 输出总结 ──
|
||||
console.print()
|
||||
console.print("[bold green]=== 完成 ===[/bold green]")
|
||||
console.print(f"训练集: {train_count:,} 样本")
|
||||
console.print(f"评估集: {eval_count:,} 样本")
|
||||
|
||||
for split in ["train", "eval"]:
|
||||
sdir = output_dir / split
|
||||
if sdir.exists():
|
||||
total_size = sum(
|
||||
f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz"
|
||||
)
|
||||
meta_path = sdir / "metadata.json"
|
||||
if meta_path.exists():
|
||||
with open(meta_path) as mf:
|
||||
meta = json.load(mf)
|
||||
else:
|
||||
meta = {}
|
||||
console.print(
|
||||
f" {split}/: {total_size / (1024**3):.2f} GB, "
|
||||
f"{meta.get('num_shards', '?')} shards"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1157,6 +1157,11 @@ def train(
|
|||
"--compile/--no-compile",
|
||||
help="是否开启 torch.compile 优化(需 PyTorch 2.0+)",
|
||||
),
|
||||
moe_mode: str = typer.Option(
|
||||
"all",
|
||||
"--moe-mode",
|
||||
help="MoE 计算策略: all(全量计算), sparse(稀疏计算), sparse_allow_graph(稀疏+allow_in_graph)",
|
||||
),
|
||||
):
|
||||
"""
|
||||
训练输入法模型
|
||||
|
|
@ -1211,6 +1216,7 @@ def train(
|
|||
config_table.add_row("模型", "MoE专家数", str(num_experts))
|
||||
config_table.add_row("模型", "使用拼音", str(use_pinyin))
|
||||
config_table.add_row("模型", "编译优化", str(compile))
|
||||
config_table.add_row("模型", "MoE策略", moe_mode)
|
||||
|
||||
config_table.add_row("训练", "训练轮数", str(num_epochs))
|
||||
config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
|
||||
|
|
@ -1331,6 +1337,7 @@ def train(
|
|||
"auto_resume": auto_resume,
|
||||
"max_iter_length": max_iter_length,
|
||||
"compile": compile,
|
||||
"moe_mode": moe_mode,
|
||||
"is_train_preprocessed": is_train_preprocessed,
|
||||
"is_eval_preprocessed": is_eval_preprocessed,
|
||||
"total_steps": total_steps,
|
||||
|
|
@ -1353,6 +1360,7 @@ def train(
|
|||
num_experts=num_experts,
|
||||
max_seq_len=max_seq_len,
|
||||
compile=compile,
|
||||
moe_mode=moe_mode,
|
||||
)
|
||||
|
||||
console.print(
|
||||
|
|
|
|||
Loading…
Reference in New Issue