feat(MoELayer): 添加 moe_mode 支持稀疏和图内计算策略
This commit is contained in:
parent
e8eab1f260
commit
432132a108
|
|
@ -341,15 +341,37 @@ class CrossAttentionFusion(nn.Module):
|
||||||
# 4. 专家混合层 (MoE Layer)
|
# 4. 专家混合层 (MoE Layer)
|
||||||
# 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类
|
# 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compiler.allow_in_graph
|
||||||
|
def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_experts):
|
||||||
|
output = torch.zeros_like(x_flat)
|
||||||
|
for e in range(num_experts):
|
||||||
|
mask = topk_indices == e
|
||||||
|
idx, k_idx = mask.nonzero(as_tuple=True)
|
||||||
|
if idx.numel() > 0:
|
||||||
|
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
||||||
|
output.index_add_(0, idx, w * experts[e](x_flat[idx]))
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class MoELayer(nn.Module):
|
class MoELayer(nn.Module):
|
||||||
def __init__(self, dim=512, num_experts=10, top_k=3, num_resblocks=8):
|
"""
|
||||||
|
moe_mode 支持三种策略:
|
||||||
|
- "all": 计算全部专家,torch.compile 不断裂 (当前默认)
|
||||||
|
- "sparse": 只计算被路由到的专家 (产生 graph break)
|
||||||
|
- "sparse_allow_graph": 稀疏 MoE,通过 allow_in_graph 避免 graph break
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, dim=512, num_experts=10, top_k=3, num_resblocks=8, moe_mode="all"
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
self.moe_mode = moe_mode
|
||||||
|
|
||||||
# Import Expert from your existing components
|
|
||||||
# Assuming Expert class is defined as in components.py [2]
|
|
||||||
self.experts = nn.ModuleList(
|
self.experts = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Expert(
|
Expert(
|
||||||
|
|
@ -362,48 +384,41 @@ class MoELayer(nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Gating Network [2]
|
|
||||||
self.gate = nn.Linear(dim, num_experts)
|
self.gate = nn.Linear(dim, num_experts)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
|
||||||
并行化 MoE 前向传播,完全兼容 torch.compile 和 AMP。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: [batch, seq_len, dim]
|
|
||||||
Returns:
|
|
||||||
out: [batch, seq_len, dim]
|
|
||||||
"""
|
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
num_tokens = B * L
|
num_tokens = B * L
|
||||||
|
x_flat = x.view(num_tokens, D)
|
||||||
|
|
||||||
# 展平输入以便处理
|
gates = self.gate(x_flat)
|
||||||
x_flat = x.view(num_tokens, D) # [B*L, D]
|
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1)
|
||||||
|
topk_weights = F.softmax(topk_weights, dim=-1)
|
||||||
|
|
||||||
# 1. 计算门控分数
|
if self.moe_mode == "all":
|
||||||
gates = self.gate(x_flat) # [B*L, num_experts]
|
expert_outputs = torch.stack(
|
||||||
|
[expert(x_flat) for expert in self.experts], dim=1
|
||||||
|
)
|
||||||
|
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D)
|
||||||
|
selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)
|
||||||
|
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1)
|
||||||
|
out_flat = weighted_outputs.sum(dim=1)
|
||||||
|
|
||||||
# 2. 选择 Top-K 专家
|
elif self.moe_mode == "sparse":
|
||||||
topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K]
|
out_flat = torch.zeros_like(x_flat)
|
||||||
|
for e in range(self.num_experts):
|
||||||
|
mask = topk_indices == e
|
||||||
|
idx, k_idx = mask.nonzero(as_tuple=True)
|
||||||
|
if idx.numel() > 0:
|
||||||
|
w = topk_weights[idx, k_idx].unsqueeze(-1)
|
||||||
|
out_flat.index_add_(0, idx, w * self.experts[e](x_flat[idx]))
|
||||||
|
|
||||||
# 归一化权重
|
elif self.moe_mode == "sparse_allow_graph":
|
||||||
topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K]
|
out_flat = _sparse_moe_dispatch(
|
||||||
|
x_flat, self.experts, topk_indices, topk_weights, self.num_experts
|
||||||
|
)
|
||||||
|
|
||||||
# 3. 并行计算所有专家(消除 Python 循环中的动态控制流)
|
else:
|
||||||
# torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量
|
raise ValueError(f"Unknown moe_mode: {self.moe_mode}")
|
||||||
expert_outputs = torch.stack(
|
|
||||||
[expert(x_flat) for expert in self.experts], dim=1
|
|
||||||
) # [B*L, num_experts, D]
|
|
||||||
|
|
||||||
# 4. 使用 gather 选择对应专家的输出
|
|
||||||
# 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D]
|
|
||||||
indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D]
|
|
||||||
selected_outputs = torch.gather(
|
|
||||||
expert_outputs, 1, indices_expanded
|
|
||||||
) # [B*L, K, D]
|
|
||||||
# 5. 加权求和
|
|
||||||
weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D]
|
|
||||||
out_flat = weighted_outputs.sum(dim=1) # [B*L, D]
|
|
||||||
|
|
||||||
# 恢复原始形状
|
|
||||||
return out_flat.view(B, L, D)
|
return out_flat.view(B, L, D)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,8 @@
|
||||||
|
|
||||||
功能:
|
功能:
|
||||||
1. 统计 labels 的分布(出现次数、比例,最大/最小,未出现标签数)
|
1. 统计 labels 的分布(出现次数、比例,最大/最小,未出现标签数)
|
||||||
2. 随机抽样还原为人类可读文本,导出为 CSV 文件
|
2. 统计 history_slot_ids 的有效长度分布
|
||||||
|
3. 抽样还原为人类可读文本,导出为 CSV 文件
|
||||||
|
|
||||||
用法:
|
用法:
|
||||||
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
|
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
|
||||||
|
|
@ -14,7 +15,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
import csv
|
import csv
|
||||||
import json
|
import json
|
||||||
import random
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
@ -86,6 +86,36 @@ def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0):
|
||||||
return counter, total
|
return counter, total
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_history_slots(dataset: PreProcessedDataset, max_shards: int = 0):
    """Tally how many non-zero entries each sample's ``history_slot_ids`` row has.

    Args:
        dataset: source dataset. Reads ``_is_sharded``, ``_num_shards`` and
            ``data_dir`` when sharded, otherwise ``history_slot_ids``.
        max_shards: upper bound on the number of shards scanned; ``0`` (or any
            non-positive value) scans all of them.

    Returns:
        ``(counter, total)`` where ``counter`` maps a non-zero count to the
        number of samples with that count, and ``total`` is the number of
        samples scanned.
    """
    logger.info("正在统计 history_slot_ids 长度分布...")
    counter = Counter()
    total = 0

    num_shards = dataset._num_shards if dataset._is_sharded else 1
    effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards

    pbar = tqdm(range(effective_shards), desc="统计 history slots", unit="shard")

    for shard_idx in pbar:
        if dataset._is_sharded:
            shard_path = dataset.data_dir / f"shard_{shard_idx:06d}.npz"
            # NpzFile is lazy: pull only the one array we need and close the
            # archive right away instead of decompressing every field.
            with np.load(shard_path) as shard:
                history_slots = shard["history_slot_ids"]
        else:
            history_slots = np.asarray(dataset.history_slot_ids[:])

        # count_nonzero works on any integer dtype, so no widening copy is needed.
        lengths = np.count_nonzero(history_slots, axis=1)
        unique, counts = np.unique(lengths, return_counts=True)
        # tolist() yields plain Python ints so the Counter holds no numpy scalars.
        for length, cnt in zip(unique.tolist(), counts.tolist()):
            counter[length] += cnt
        total += len(history_slots)

    return counter, total
|
||||||
|
|
||||||
|
|
||||||
def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
|
def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
|
||||||
"""将一个样本还原为人类可读格式"""
|
"""将一个样本还原为人类可读格式"""
|
||||||
input_ids = (
|
input_ids = (
|
||||||
|
|
@ -139,6 +169,7 @@ def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
|
||||||
pinyin_str = decode_pinyin_ids(pinyin_ids)
|
pinyin_str = decode_pinyin_ids(pinyin_ids)
|
||||||
label_info = query_engine.query_by_id(labels)
|
label_info = query_engine.query_by_id(labels)
|
||||||
history_str = decode_history(history_ids, query_engine)
|
history_str = decode_history(history_ids, query_engine)
|
||||||
|
history_slot_length = sum(1 for h in history_ids if h != 0)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"context": context_text,
|
"context": context_text,
|
||||||
|
|
@ -150,6 +181,7 @@ def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
|
||||||
else f"<ID:{labels}>",
|
else f"<ID:{labels}>",
|
||||||
"label_count": label_info.count if label_info else 0,
|
"label_count": label_info.count if label_info else 0,
|
||||||
"history": history_str,
|
"history": history_str,
|
||||||
|
"history_slot_length": history_slot_length,
|
||||||
"full_tokens": token_text,
|
"full_tokens": token_text,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -168,7 +200,7 @@ def main():
|
||||||
"--num-samples",
|
"--num-samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=50,
|
default=50,
|
||||||
help="随机抽样的样本数量(默认50)",
|
help="抽样的样本数量(默认50,取前 N 个样本)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output",
|
"--output",
|
||||||
|
|
@ -182,24 +214,14 @@ def main():
|
||||||
default=0,
|
default=0,
|
||||||
help="统计 labels 时最多读取的分片数(0=全部)",
|
help="统计 labels 时最多读取的分片数(0=全部)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--seed",
|
|
||||||
type=int,
|
|
||||||
default=42,
|
|
||||||
help="随机种子",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--top-k",
|
"--top-k",
|
||||||
type=int,
|
type=int,
|
||||||
default=30,
|
default=30,
|
||||||
help="显示出现次数最多和最少的标签数量",
|
help="显示出现次数最多和最少的标签数量",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
random.seed(args.seed)
|
|
||||||
np.random.seed(args.seed)
|
|
||||||
|
|
||||||
if args.output is None:
|
if args.output is None:
|
||||||
args.output = str(Path(args.data_dir) / "samples.csv")
|
args.output = str(Path(args.data_dir) / "samples.csv")
|
||||||
|
|
||||||
|
|
@ -332,13 +354,56 @@ def main():
|
||||||
)
|
)
|
||||||
console.print(table_dist)
|
console.print(table_dist)
|
||||||
|
|
||||||
# ====== 2. 随机抽样还原 → CSV ======
|
# ====== 2. 历史槽位长度分析 ======
|
||||||
num_samples = min(args.num_samples, len(dataset))
|
console.print("\n[bold yellow]====== 历史槽位长度分析 ======[/bold yellow]")
|
||||||
console.print(
|
history_counter, history_total = analyze_history_slots(
|
||||||
f"\n[bold yellow]====== 随机抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
|
dataset, max_shards=args.max_shards
|
||||||
)
|
)
|
||||||
|
|
||||||
indices = random.sample(range(len(dataset)), num_samples)
|
if not history_counter:
|
||||||
|
console.print("[yellow] 无历史槽位数据[/yellow]")
|
||||||
|
else:
|
||||||
|
sorted_items = sorted(history_counter.items())
|
||||||
|
lengths_arr = [l for l, _ in sorted_items]
|
||||||
|
counts_arr = [c for _, c in sorted_items]
|
||||||
|
weighted_sum = sum(l * c for l, c in zip(lengths_arr, counts_arr))
|
||||||
|
avg_length = weighted_sum / history_total if history_total > 0 else 0
|
||||||
|
|
||||||
|
cumsum = 0
|
||||||
|
median_length = 0
|
||||||
|
for length, count in sorted_items:
|
||||||
|
cumsum += count
|
||||||
|
if cumsum >= history_total / 2:
|
||||||
|
median_length = length
|
||||||
|
break
|
||||||
|
|
||||||
|
console.print(f"\n总样本数: {history_total:,}")
|
||||||
|
console.print(f"最大历史槽位数: {max(lengths_arr)}")
|
||||||
|
console.print(f"最小历史槽位数: {min(lengths_arr)}")
|
||||||
|
console.print(f"平均历史槽位数: {avg_length:.2f}")
|
||||||
|
console.print(f"中位数历史槽位数: {median_length}")
|
||||||
|
|
||||||
|
hist_table = Table(
|
||||||
|
title="历史槽位有效长度分布",
|
||||||
|
show_header=True,
|
||||||
|
header_style="bold magenta",
|
||||||
|
)
|
||||||
|
hist_table.add_column("槽位数", style="cyan", width=10)
|
||||||
|
hist_table.add_column("样本数", style="green", width=12)
|
||||||
|
hist_table.add_column("占比", style="yellow", width=10)
|
||||||
|
|
||||||
|
for length, count in sorted_items:
|
||||||
|
pct = count / history_total * 100
|
||||||
|
hist_table.add_row(str(length), f"{count:,}", f"{pct:.2f}%")
|
||||||
|
console.print(hist_table)
|
||||||
|
|
||||||
|
# ====== 3. 抽样还原 → CSV ======
|
||||||
|
num_samples = min(args.num_samples, len(dataset))
|
||||||
|
console.print(
|
||||||
|
f"\n[bold yellow]====== 抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
|
||||||
|
)
|
||||||
|
|
||||||
|
indices = list(range(num_samples))
|
||||||
|
|
||||||
csv_path = Path(args.output)
|
csv_path = Path(args.output)
|
||||||
csv_path.parent.mkdir(parents=True, exist_ok=True)
|
csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -352,6 +417,7 @@ def main():
|
||||||
"context",
|
"context",
|
||||||
"suffix",
|
"suffix",
|
||||||
"history",
|
"history",
|
||||||
|
"history_slot_length",
|
||||||
"full_tokens",
|
"full_tokens",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -372,6 +438,7 @@ def main():
|
||||||
decoded["context"],
|
decoded["context"],
|
||||||
decoded["suffix"],
|
decoded["suffix"],
|
||||||
decoded["history"],
|
decoded["history"],
|
||||||
|
decoded["history_slot_length"],
|
||||||
decoded["full_tokens"],
|
decoded["full_tokens"],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""将分片 .npz 合并为单个 .npz 文件,合并后删除原分片"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
FIELDS = [
|
||||||
|
"input_ids",
|
||||||
|
"token_type_ids",
|
||||||
|
"attention_mask",
|
||||||
|
"labels",
|
||||||
|
"history_slot_ids",
|
||||||
|
"pinyin_ids",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def merge_split(data_dir: Path, output_name: str = "shard_000000.npz"):
    """Merge every ``shard_XXXXXX.npz`` in ``data_dir`` into a single archive.

    The merged archive is first written under a temporary name; the source
    shards are removed only after that write has succeeded, and the temporary
    file is then renamed to its final name. A failed or interrupted write
    therefore never destroys the original data.

    Args:
        data_dir: split directory containing ``metadata.json`` and the shards.
        output_name: file name of the merged archive. As with
            ``np.savez_compressed``, ``.npz`` is appended when missing.

    Returns:
        None. ``metadata.json`` is rewritten with ``num_shards = 1`` and
        ``shard_size`` set to the merged sample count.
    """
    meta_path = data_dir / "metadata.json"
    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)
    num_shards = meta["num_shards"]
    if num_shards <= 1:
        print(f"{data_dir}: already single shard, skip")
        return

    shard_paths = [data_dir / f"shard_{i:06d}.npz" for i in range(num_shards)]

    # Merge one field at a time to keep peak memory at roughly one field.
    merged = {}
    for field in FIELDS:
        pieces = []
        for shard_path in tqdm(
            shard_paths, desc=f"Merging {data_dir.name}/{field}", unit="shard"
        ):
            with np.load(shard_path) as data:
                pieces.append(data[field])
        merged[field] = np.concatenate(pieces, axis=0)
        del pieces
        gc.collect()

    total = len(merged["labels"])

    # Mirror numpy's naming rule so we know where the final file ends up.
    final_name = output_name if output_name.endswith(".npz") else output_name + ".npz"
    final_path = data_dir / final_name
    # Ends in ".npz" so numpy does not alter it; cannot clash with a shard name.
    tmp_path = data_dir / (final_name + ".tmp.npz")

    try:
        np.savez_compressed(tmp_path, **merged)
    except BaseException:
        # Drop the partial file; the original shards are still intact.
        tmp_path.unlink(missing_ok=True)
        raise
    del merged
    gc.collect()

    # The merged data is safely on disk: now the shards can go. They are
    # removed before the rename because the final name may equal a shard name.
    for shard_path in shard_paths:
        shard_path.unlink()
    tmp_path.replace(final_path)

    # Update metadata last so it never describes a layout that is not on disk yet.
    meta["num_shards"] = 1
    meta["shard_size"] = total
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)

    print(f"{data_dir.name}: {num_shards} shards → 1 shard, {total:,} samples")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: merge the shards of ``train/`` and ``eval/``.

    Each split directory under ``--input-dir`` is merged in turn. A missing
    directory is reported and skipped; ``--train-only`` leaves ``eval``
    untouched.
    """
    parser = argparse.ArgumentParser(description="合并分片 .npz 为单个文件")
    cli_options = (
        (
            "--input-dir",
            {"type": str, "required": True, "help": "数据集目录(含 train/ 和 eval/)"},
        ),
        (
            "--output-name",
            {"type": str, "default": "shard_000000.npz", "help": "合并后文件名"},
        ),
        (
            "--train-only",
            {"action": "store_true", "help": "仅合并 train,跳过 eval"},
        ),
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    root = Path(args.input_dir)

    for split in ("train", "eval"):
        split_dir = root / split
        if not split_dir.exists():
            # Reported even for eval under --train-only, like the existence
            # check coming first.
            print(f"{split_dir}: not found, skip")
        elif not (args.train_only and split == "eval"):
            merge_split(split_dir, args.output_name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -35,12 +35,13 @@ class InputMethodEngine(nn.Module):
|
||||||
vocab_size: int = 10019,
|
vocab_size: int = 10019,
|
||||||
pinyin_vocab_size: int = 28,
|
pinyin_vocab_size: int = 28,
|
||||||
dim: int = 512,
|
dim: int = 512,
|
||||||
num_slots: int = 8, # 历史槽位数量 (对应 README 中的 8 个槽位)
|
num_slots: int = 8,
|
||||||
n_layers: int = 4, # Transformer 层数
|
n_layers: int = 4,
|
||||||
n_heads: int = 4, # 注意力头数
|
n_heads: int = 4,
|
||||||
num_experts: int = 10, # MoE 专家数量
|
num_experts: int = 10,
|
||||||
max_seq_len: int = 128, # 最大上下文长度
|
max_seq_len: int = 128,
|
||||||
compile: bool = False, # 是否开启 torch.compile 优化
|
compile: bool = False,
|
||||||
|
moe_mode: str = "all", # "all" / "sparse" / "sparse_allow_graph"
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
@ -72,7 +73,13 @@ class InputMethodEngine(nn.Module):
|
||||||
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
|
self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads)
|
||||||
|
|
||||||
# 4. 混合专家层 (MoE)
|
# 4. 混合专家层 (MoE)
|
||||||
self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=12)
|
self.moe = MoELayer(
|
||||||
|
dim=dim,
|
||||||
|
num_experts=num_experts,
|
||||||
|
top_k=3,
|
||||||
|
num_resblocks=12,
|
||||||
|
moe_mode=moe_mode,
|
||||||
|
)
|
||||||
|
|
||||||
# 5. 槽位注意力池化
|
# 5. 槽位注意力池化
|
||||||
self.slot_attention = nn.Linear(dim, 1)
|
self.slot_attention = nn.Linear(dim, 1)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,429 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
子采样脚本:从预处理的 .npz 分片中抽取子集。
|
||||||
|
|
||||||
|
策略:
|
||||||
|
1. 第1遍扫描:只读 labels 字段,统计每个 label ID 的样本数,记录分片大小
|
||||||
|
2. 中间计算:按每 ID 最多 N 个样本封顶(硬封顶),不足目标总量则从超额池补足
|
||||||
|
3. 第2遍扫描:读取全部字段,按精确保留配额抽取训练样本,同时抽取评估集
|
||||||
|
|
||||||
|
用法:
|
||||||
|
python -m model.subsample \
|
||||||
|
--input-dir ./preprocessed \
|
||||||
|
--output-dir ./subsampled \
|
||||||
|
--cap-per-label 300000 \
|
||||||
|
--target-total 100000000 \
|
||||||
|
--num-eval 2560
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from loguru import logger
|
||||||
|
from rich.console import Console
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
FIELDS = [
|
||||||
|
"input_ids",
|
||||||
|
"token_type_ids",
|
||||||
|
"attention_mask",
|
||||||
|
"labels",
|
||||||
|
"history_slot_ids",
|
||||||
|
"pinyin_ids",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _global_to_shard(
    positions: np.ndarray, shard_sizes: List[int]
) -> Dict[int, List[int]]:
    """Translate global sample positions into per-shard local indices.

    Args:
        positions: global positions, counted across all shards in order.
        shard_sizes: number of samples in each shard, in shard order.

    Returns:
        ``{shard_idx: [local_indices]}``; local indices keep the order in
        which their positions were given. Shards with no position are absent.
    """
    # offsets[k] is the global position of the first sample of shard k.
    offsets = np.cumsum([0] + shard_sizes)
    pos_arr = np.asarray(positions)

    # side="right" puts a position equal to a boundary into the later shard,
    # which also steps over empty shards.
    shard_ids = np.searchsorted(offsets, pos_arr, side="right") - 1
    local_ids = pos_arr - offsets[shard_ids]

    mapping: Dict[int, List[int]] = {}
    for shard_id, local_id in zip(shard_ids, local_ids):
        mapping.setdefault(int(shard_id), []).append(local_id)
    return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def pass1_count(input_dir: Path, split: str) -> Tuple[Dict[int, int], List[int]]:
    """First pass: read only ``labels`` from every shard and count them.

    Args:
        input_dir: root directory of the preprocessed data.
        split: sub-directory name holding ``metadata.json`` and the shards.

    Returns:
        ``(label_counts, shard_sizes)``: the total number of samples per label
        id, and the number of samples in each shard in shard order.
    """
    split_dir = input_dir / split
    with open(split_dir / "metadata.json", encoding="utf-8") as f:
        metadata = json.load(f)

    num_shards = metadata["num_shards"]
    label_counts: Dict[int, int] = {}
    shard_sizes: List[int] = []

    # Context managers make sure the progress bar and each archive are closed
    # even when a shard is missing or corrupt.
    with tqdm(total=num_shards, desc="Pass 1: counting labels", unit="shard") as pbar:
        for shard_idx in range(num_shards):
            shard_path = split_dir / f"shard_{shard_idx:06d}.npz"
            with np.load(shard_path) as data:
                # Only this member is decompressed; the other fields stay on disk.
                labels = data["labels"]
            shard_sizes.append(len(labels))

            unique, counts = np.unique(labels, return_counts=True)
            for uid, cnt in zip(unique, counts):
                uid = int(uid)
                label_counts[uid] = label_counts.get(uid, 0) + int(cnt)

            pbar.update(1)

    return label_counts, shard_sizes
|
||||||
|
|
||||||
|
|
||||||
|
def compute_quotas(
    label_counts: Dict[int, int],
    cap: int = 300_000,
    target_total: int = 100_000_000,
) -> Dict[int, int]:
    """Decide how many samples of each label to keep.

    Every label first gets ``min(count, cap)``. If that already reaches
    ``target_total`` (or no label exceeds the cap), this capped allocation is
    returned. Otherwise the shortfall is taken from the samples above the cap,
    each label contributing the same fraction of its own excess.

    Args:
        label_counts: number of available samples per label id.
        cap: per-label ceiling applied before any top-up.
        target_total: desired overall number of kept samples.

    Returns:
        ``{label_id: number_to_keep}``. The top-up is truncated per label, so
        the sum may end up slightly below ``target_total``.
    """
    capped = {label: min(cnt, cap) for label, cnt in label_counts.items()}
    capped_total = sum(capped.values())
    if capped_total >= target_total:
        return capped

    # Samples beyond the cap are the only pool the top-up can draw from.
    overflow = {label: max(0, cnt - cap) for label, cnt in label_counts.items()}
    overflow_total = sum(overflow.values())
    if overflow_total <= 0:
        return capped

    # Fraction of each label's overflow to take; never more than all of it.
    ratio = min(1.0, (target_total - capped_total) / overflow_total)
    return {
        label: kept + int(overflow[label] * ratio) for label, kept in capped.items()
    }
|
||||||
|
|
||||||
|
|
||||||
|
def pass2_subsample(
    input_dir: Path,
    output_dir: Path,
    split: str,
    quotas: Dict[int, int],
    eval_map: Dict[int, List[int]],
    shard_sizes: List[int],
    shard_size: int = 5_000_000,
) -> Tuple[int, int]:
    """Second pass: extract the eval samples and subsample the training set.

    For every source shard the rows listed in ``eval_map`` go to the eval set.
    Among the remaining rows, each label keeps at most what is left of its
    quota; when a shard holds more rows of a label than the remaining quota,
    the kept rows are drawn at random without replacement. Quotas are consumed
    shard by shard in order. Kept training rows are buffered and written as
    ``train/shard_XXXXXX.npz`` once at least ``shard_size`` rows are buffered;
    eval rows are written as one ``eval/shard_000000.npz``. A ``metadata.json``
    is written for both outputs.

    Args:
        input_dir: root directory of the preprocessed data.
        output_dir: root directory receiving ``train/`` and ``eval/``.
        split: input sub-directory to read from.
        quotas: ``{label_id: exact_number_to_keep}``; labels absent from the
            mapping keep nothing.
        eval_map: ``{source_shard_idx: [local_row_indices]}`` of eval rows.
        shard_sizes: number of rows of each source shard, in shard order.
        shard_size: minimum number of buffered rows that triggers writing an
            output training shard.

    Returns:
        ``(train_kept, eval_kept)``.
    """
    with open(input_dir / split / "metadata.json", encoding="utf-8") as f:
        metadata = json.load(f)

    num_shards = metadata["num_shards"]
    max_seq_length = metadata["max_seq_length"]

    output_train_dir = output_dir / "train"
    output_eval_dir = output_dir / "eval"
    output_train_dir.mkdir(parents=True, exist_ok=True)
    output_eval_dir.mkdir(parents=True, exist_ok=True)

    rng = np.random.RandomState()  # unseeded: selection differs between runs
    remaining = dict(quotas)  # {label_id: remaining_to_keep}

    train_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
    train_buf_count = 0
    train_shard_idx = 0
    total_train_kept = 0

    eval_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
    total_eval_kept = 0

    def flush_train_shard() -> None:
        """Write the buffered training rows as the next output shard and reset."""
        nonlocal train_buffers, train_buf_count, train_shard_idx
        merged = {f: np.concatenate(train_buffers[f], axis=0) for f in FIELDS}
        np.savez_compressed(
            output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged
        )
        logger.debug(f"Saved train shard {train_shard_idx}: {train_buf_count} samples")
        train_shard_idx += 1
        train_buffers = {f: [] for f in FIELDS}
        train_buf_count = 0
        del merged
        gc.collect()

    # Context managers close the progress bar and each archive even if a
    # shard turns out to be missing or corrupt.
    with tqdm(total=num_shards, desc="Pass 2: subsampling", unit="shard") as pbar:
        for src_shard_idx in range(num_shards):
            shard_path = input_dir / split / f"shard_{src_shard_idx:06d}.npz"
            n = shard_sizes[src_shard_idx]

            eval_local = np.array(eval_map.get(src_shard_idx, []), dtype=np.int64)
            has_eval = len(eval_local) > 0

            is_eval = np.zeros(n, dtype=bool)
            if has_eval:
                is_eval[eval_local] = True

            with np.load(shard_path) as data:
                # ── Step 1: load only the labels and decide which rows to keep ──
                labels = data["labels"]

                if has_eval:
                    # Fancy indexing already yields an independent array.
                    eval_buffers["labels"].append(labels[eval_local])

                train_candidates = labels[~is_eval]
                train_n = len(train_candidates)
                n_kept = 0
                keep_original_indices = np.array([], dtype=np.int64)

                if train_n > 0:
                    # A stable sort groups equal labels into contiguous runs;
                    # starts/ends delimit the run of each distinct label.
                    sort_idx = np.argsort(train_candidates, kind="stable")
                    sorted_labels = train_candidates[sort_idx]
                    unique_vals, starts = np.unique(sorted_labels, return_index=True)
                    ends = np.append(starts[1:], train_n)

                    train_keep_mask = np.zeros(train_n, dtype=bool)
                    for raw_label, start, end in zip(unique_vals, starts, ends):
                        label_val = int(raw_label)
                        cnt_in_shard = end - start
                        need = remaining.get(label_val, 0)
                        select = min(cnt_in_shard, need)
                        if select >= cnt_in_shard:
                            # Quota covers the whole run: keep every row.
                            train_keep_mask[sort_idx[start:end]] = True
                        elif select > 0:
                            chosen = rng.choice(
                                cnt_in_shard, size=select, replace=False
                            )
                            train_keep_mask[sort_idx[start + chosen]] = True
                        remaining[label_val] = max(0, need - select)

                    # Map positions among the non-eval rows back to shard rows.
                    original_train_indices = np.where(~is_eval)[0]
                    keep_original_indices = original_train_indices[train_keep_mask]
                    n_kept = len(keep_original_indices)

                    # Release the label-related temporaries before loading
                    # the wide fields.
                    del sort_idx, sorted_labels, unique_vals, starts, ends
                    del train_keep_mask, original_train_indices

                del train_candidates, labels, is_eval
                gc.collect()

                # ── Step 2: load the fields one at a time and copy out rows ──
                # Each full field array is a temporary that is freed as soon
                # as the selected rows have been copied out of it.
                if has_eval:
                    for f in FIELDS:
                        if f == "labels":
                            continue  # already extracted in step 1
                        eval_buffers[f].append(data[f][eval_local])
                    total_eval_kept += len(eval_local)

                if n_kept > 0:
                    for f in FIELDS:
                        train_buffers[f].append(data[f][keep_original_indices])
                    train_buf_count += n_kept
                    total_train_kept += n_kept

                del keep_original_indices

            gc.collect()

            if train_buf_count >= shard_size:
                flush_train_shard()

            pbar.update(1)

    # Whatever is still buffered becomes the last (smaller) training shard.
    if train_buf_count > 0:
        flush_train_shard()

    # Eval data
    if total_eval_kept > 0:
        eval_merged = {f: np.concatenate(eval_buffers[f], axis=0) for f in FIELDS}
        np.savez_compressed(output_eval_dir / "shard_000000.npz", **eval_merged)
    else:
        # NOTE(review): the eval metadata below still reports one shard even
        # though no shard file exists in this case; confirm readers cope.
        logger.warning("No eval samples extracted!")

    # Write metadata
    train_metadata = {
        "num_samples": total_train_kept,
        "max_seq_length": max_seq_length,
        "dtype": "int16",
        "fields": FIELDS,
        "shard_size": shard_size,
        "num_shards": train_shard_idx,
    }
    with open(output_train_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(train_metadata, f, indent=2, ensure_ascii=False)

    eval_metadata = {
        "num_samples": total_eval_kept,
        "max_seq_length": max_seq_length,
        "dtype": "int16",
        "fields": FIELDS,
        "shard_size": total_eval_kept,
        "num_shards": 1,
    }
    with open(output_eval_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(eval_metadata, f, indent=2, ensure_ascii=False)

    return total_train_kept, total_eval_kept
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point of the subsampling script.

    Steps: parse the arguments, count labels over the ``train`` split (pass 1),
    derive per-label quotas, draw the eval positions at random over all
    samples, run the subsampling pass (pass 2), and print a size summary of
    the output directories.
    """
    console = Console()

    parser = argparse.ArgumentParser(description="从预处理 .npz 分片中抽取子集")
    parser.add_argument("--input-dir", type=str, required=True, help="输入预处理目录")
    parser.add_argument("--output-dir", type=str, required=True, help="输出子采样目录")
    parser.add_argument(
        "--cap-per-label",
        type=int,
        default=300_000,
        help="每个 label ID 最大保留样本数",
    )
    parser.add_argument(
        "--target-total",
        type=int,
        default=100_000_000,
        help="目标训练集样本总数",
    )
    parser.add_argument(
        "--shard-size",
        type=int,
        default=5_000_000,
        help="输出分片大小(样本数)",
    )
    parser.add_argument("--num-eval", type=int, default=2560, help="评估集样本数")

    args = parser.parse_args()
    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    console.print("[bold cyan]=== 子采样预处理数据 ===[/bold cyan]")
    console.print(f"输入目录: {input_dir}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"每 ID 封顶: {args.cap_per_label:,}")
    console.print(f"目标训练集: {args.target_total:,}")
    console.print(f"评估集: {args.num_eval}")
    console.print()

    # ── Pass 1: statistics ──
    console.print("[bold]第 1 遍扫描:统计标签分布...[/bold]")
    label_counts, shard_sizes = pass1_count(input_dir, "train")

    total_samples = sum(shard_sizes)
    num_labels = len(label_counts)
    capped_total = sum(min(cnt, args.cap_per_label) for cnt in label_counts.values())

    console.print(f"分片数: {len(shard_sizes)}")
    console.print(f"总样本数: {total_samples:,}")
    console.print(f"标签种类: {num_labels}")
    console.print(f"封顶后样本数 (每ID≤{args.cap_per_label:,}): {capped_total:,}")
    console.print()

    # ── Per-label quotas ──
    quotas = compute_quotas(
        label_counts,
        cap=args.cap_per_label,
        target_total=args.target_total,
    )
    expected_train = sum(quotas.values())
    console.print(f"期望训练集大小: {expected_train:,}")

    # Label distribution statistics
    stats_n_capped = sum(1 for cnt in label_counts.values() if cnt > args.cap_per_label)
    stats_n_below = num_labels - stats_n_capped
    stats_n_quota_full = sum(1 for q in quotas.values() if q >= args.cap_per_label)
    console.print(
        f"超过封顶的标签数: {stats_n_capped}, 不足封顶的标签数: {stats_n_below}, "
        f"配额打满(≥{args.cap_per_label // 10000}万): {stats_n_quota_full}"
    )
    console.print()

    # ── Draw the eval positions ──
    # Sampling without replacement cannot return more items than exist, so a
    # too-large (or negative) request is clamped instead of raising ValueError.
    num_eval = max(0, min(args.num_eval, total_samples))
    if num_eval != args.num_eval:
        logger.warning(
            f"--num-eval {args.num_eval} adjusted to {num_eval} "
            f"(dataset has {total_samples} samples)"
        )

    if num_eval > 0:
        rng = np.random.RandomState()  # unseeded: eval set differs between runs
        eval_positions = rng.choice(total_samples, size=num_eval, replace=False)
        eval_positions.sort()
    else:
        eval_positions = np.array([], dtype=np.int64)
    eval_map = _global_to_shard(eval_positions, shard_sizes)
    console.print(f"评估集: {num_eval} 个位置已分配到 {len(eval_map)} 个分片中")
    console.print()

    # ── Pass 2: subsampling ──
    console.print("[bold]第 2 遍扫描:子采样 + 抽取评估集...[/bold]")
    train_count, eval_count = pass2_subsample(
        input_dir,
        output_dir,
        "train",
        quotas,
        eval_map,
        shard_sizes,
        args.shard_size,
    )

    # ── Summary ──
    console.print()
    console.print("[bold green]=== 完成 ===[/bold green]")
    console.print(f"训练集: {train_count:,} 样本")
    console.print(f"评估集: {eval_count:,} 样本")

    for split in ["train", "eval"]:
        sdir = output_dir / split
        if not sdir.exists():
            continue
        total_size = sum(f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz")
        meta_path = sdir / "metadata.json"
        if meta_path.exists():
            with open(meta_path, encoding="utf-8") as mf:
                meta = json.load(mf)
        else:
            meta = {}
        console.print(
            f" {split}/: {total_size / (1024**3):.2f} GB, "
            f"{meta.get('num_shards', '?')} shards"
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -1157,6 +1157,11 @@ def train(
|
||||||
"--compile/--no-compile",
|
"--compile/--no-compile",
|
||||||
help="是否开启 torch.compile 优化(需 PyTorch 2.0+)",
|
help="是否开启 torch.compile 优化(需 PyTorch 2.0+)",
|
||||||
),
|
),
|
||||||
|
moe_mode: str = typer.Option(
|
||||||
|
"all",
|
||||||
|
"--moe-mode",
|
||||||
|
help="MoE 计算策略: all(全量计算), sparse(稀疏计算), sparse_allow_graph(稀疏+allow_in_graph)",
|
||||||
|
),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
训练输入法模型
|
训练输入法模型
|
||||||
|
|
@ -1211,6 +1216,7 @@ def train(
|
||||||
config_table.add_row("模型", "MoE专家数", str(num_experts))
|
config_table.add_row("模型", "MoE专家数", str(num_experts))
|
||||||
config_table.add_row("模型", "使用拼音", str(use_pinyin))
|
config_table.add_row("模型", "使用拼音", str(use_pinyin))
|
||||||
config_table.add_row("模型", "编译优化", str(compile))
|
config_table.add_row("模型", "编译优化", str(compile))
|
||||||
|
config_table.add_row("模型", "MoE策略", moe_mode)
|
||||||
|
|
||||||
config_table.add_row("训练", "训练轮数", str(num_epochs))
|
config_table.add_row("训练", "训练轮数", str(num_epochs))
|
||||||
config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
|
config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
|
||||||
|
|
@ -1331,6 +1337,7 @@ def train(
|
||||||
"auto_resume": auto_resume,
|
"auto_resume": auto_resume,
|
||||||
"max_iter_length": max_iter_length,
|
"max_iter_length": max_iter_length,
|
||||||
"compile": compile,
|
"compile": compile,
|
||||||
|
"moe_mode": moe_mode,
|
||||||
"is_train_preprocessed": is_train_preprocessed,
|
"is_train_preprocessed": is_train_preprocessed,
|
||||||
"is_eval_preprocessed": is_eval_preprocessed,
|
"is_eval_preprocessed": is_eval_preprocessed,
|
||||||
"total_steps": total_steps,
|
"total_steps": total_steps,
|
||||||
|
|
@ -1353,6 +1360,7 @@ def train(
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
compile=compile,
|
compile=compile,
|
||||||
|
moe_mode=moe_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue