diff --git a/src/model/components.py b/src/model/components.py index 1bb6173..b5b67bf 100644 --- a/src/model/components.py +++ b/src/model/components.py @@ -341,15 +341,37 @@ class CrossAttentionFusion(nn.Module): # 4. 专家混合层 (MoE Layer) # 对应 README: 20个专家 [1], 使用 components.py 中的 Expert 类 # ------------------------------------------------------------------ + + +@torch.compiler.allow_in_graph +def _sparse_moe_dispatch(x_flat, experts, topk_indices, topk_weights, num_experts): + output = torch.zeros_like(x_flat) + for e in range(num_experts): + mask = topk_indices == e + idx, k_idx = mask.nonzero(as_tuple=True) + if idx.numel() > 0: + w = topk_weights[idx, k_idx].unsqueeze(-1) + output.index_add_(0, idx, w * experts[e](x_flat[idx])) + return output + + class MoELayer(nn.Module): - def __init__(self, dim=512, num_experts=10, top_k=3, num_resblocks=8): + """ + moe_mode 支持三种策略: + - "all": 计算全部专家,torch.compile 不断裂 (当前默认) + - "sparse": 只计算被路由到的专家 (产生 graph break) + - "sparse_allow_graph": 稀疏 MoE,通过 allow_in_graph 避免 graph break + """ + + def __init__( + self, dim=512, num_experts=10, top_k=3, num_resblocks=8, moe_mode="all" + ): super().__init__() self.num_experts = num_experts self.top_k = top_k self.dim = dim + self.moe_mode = moe_mode - # Import Expert from your existing components - # Assuming Expert class is defined as in components.py [2] self.experts = nn.ModuleList( [ Expert( @@ -362,48 +384,41 @@ class MoELayer(nn.Module): ] ) - # Gating Network [2] self.gate = nn.Linear(dim, num_experts) def forward(self, x): - """ - 并行化 MoE 前向传播,完全兼容 torch.compile 和 AMP。 - - Args: - x: [batch, seq_len, dim] - Returns: - out: [batch, seq_len, dim] - """ B, L, D = x.shape num_tokens = B * L + x_flat = x.view(num_tokens, D) - # 展平输入以便处理 - x_flat = x.view(num_tokens, D) # [B*L, D] + gates = self.gate(x_flat) + topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) + topk_weights = F.softmax(topk_weights, dim=-1) - # 1. 计算门控分数 - gates = self.gate(x_flat) # [B*L, num_experts] + if self.moe_mode == "all": + expert_outputs = torch.stack( + [expert(x_flat) for expert in self.experts], dim=1 + ) + indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) + selected_outputs = torch.gather(expert_outputs, 1, indices_expanded) + weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) + out_flat = weighted_outputs.sum(dim=1) - # 2. 选择 Top-K 专家 - topk_weights, topk_indices = torch.topk(gates, self.top_k, dim=-1) # [B*L, K] + elif self.moe_mode == "sparse": + out_flat = torch.zeros_like(x_flat) + for e in range(self.num_experts): + mask = topk_indices == e + idx, k_idx = mask.nonzero(as_tuple=True) + if idx.numel() > 0: + w = topk_weights[idx, k_idx].unsqueeze(-1) + out_flat.index_add_(0, idx, w * self.experts[e](x_flat[idx])) - # 归一化权重 - topk_weights = F.softmax(topk_weights, dim=-1) # [B*L, K] + elif self.moe_mode == "sparse_allow_graph": + out_flat = _sparse_moe_dispatch( + x_flat, self.experts, topk_indices, topk_weights, self.num_experts + ) - # 3. 并行计算所有专家(消除 Python 循环中的动态控制流) - # torch.compile 会展开此列表推导式,因为 num_experts 是编译时常量 - expert_outputs = torch.stack( - [expert(x_flat) for expert in self.experts], dim=1 - ) # [B*L, num_experts, D] + else: + raise ValueError(f"Unknown moe_mode: {self.moe_mode}") - # 4. 使用 gather 选择对应专家的输出 - # 扩展索引以匹配 expert_outputs 的维度 [B*L, num_experts, D] - indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, D) # [B*L, K, D] - selected_outputs = torch.gather( - expert_outputs, 1, indices_expanded - ) # [B*L, K, D] - # 5. 加权求和 - weighted_outputs = selected_outputs * topk_weights.unsqueeze(-1) # [B*L, K, D] - out_flat = weighted_outputs.sum(dim=1) # [B*L, D] - - # 恢复原始形状 return out_flat.view(B, L, D) diff --git a/src/model/inspect_preprocessed.py b/src/model/inspect_preprocessed.py index 312cfe7..8c0cb5f 100644 --- a/src/model/inspect_preprocessed.py +++ b/src/model/inspect_preprocessed.py @@ -4,7 +4,8 @@ 功能: 1. 统计 labels 的分布(出现次数、比例,最大/最小,未出现标签数) -2. 随机抽样还原为人类可读文本,导出为 CSV 文件 +2. 统计 history_slot_ids 的有效长度分布 +3. 抽样还原为人类可读文本,导出为 CSV 文件 用法: python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train @@ -14,7 +15,6 @@ import argparse import csv import json -import random from collections import Counter from pathlib import Path @@ -86,6 +86,36 @@ def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0): return counter, total +def analyze_history_slots(dataset: PreProcessedDataset, max_shards: int = 0): + """统计 history_slot_ids 的有效长度分布(非零元素个数)""" + logger.info("正在统计 history_slot_ids 长度分布...") + counter = Counter() + total = 0 + + num_shards = dataset._num_shards if dataset._is_sharded else 1 + effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards + + pbar = tqdm(range(effective_shards), desc="统计 history slots", unit="shard") + + for shard_idx in pbar: + if dataset._is_sharded: + shard_data = dict(np.load(dataset.data_dir / f"shard_{shard_idx:06d}.npz")) + history_slots = shard_data["history_slot_ids"].astype(np.int64) + else: + history_slots = dataset.history_slot_ids[:].astype(np.int64) + + lengths = np.count_nonzero(history_slots, axis=1) + unique, counts = np.unique(lengths, return_counts=True) + for uid, cnt in zip(unique, counts): + counter[int(uid)] += cnt + total += len(history_slots) + + if dataset._is_sharded: + del shard_data + + return counter, total + + def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict: """将一个样本还原为人类可读格式""" input_ids = ( @@ -139,6 +169,7 @@ def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict: pinyin_str = decode_pinyin_ids(pinyin_ids) label_info = query_engine.query_by_id(labels) history_str = decode_history(history_ids, query_engine) + history_slot_length = sum(1 for h in history_ids if h != 0) return { "context": context_text, @@ -150,6 +181,7 @@ def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict: else f"", "label_count": label_info.count if label_info else 0, "history": history_str, + "history_slot_length": history_slot_length, "full_tokens": token_text, } @@ -168,7 +200,7 @@ def main(): "--num-samples", type=int, default=50, - help="随机抽样的样本数量(默认50)", + help="抽样的样本数量(默认50,取前 N 个样本)", ) parser.add_argument( "--output", @@ -182,24 +214,14 @@ def main(): default=0, help="统计 labels 时最多读取的分片数(0=全部)", ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="随机种子", - ) parser.add_argument( "--top-k", type=int, default=30, help="显示出现次数最多和最少的标签数量", ) - args = parser.parse_args() - random.seed(args.seed) - np.random.seed(args.seed) - if args.output is None: args.output = str(Path(args.data_dir) / "samples.csv") @@ -332,13 +354,56 @@ def main(): ) console.print(table_dist) - # ====== 2. 随机抽样还原 → CSV ====== - num_samples = min(args.num_samples, len(dataset)) - console.print( - f"\n[bold yellow]====== 随机抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]" + # ====== 2. 历史槽位长度分析 ====== + console.print("\n[bold yellow]====== 历史槽位长度分析 ======[/bold yellow]") + history_counter, history_total = analyze_history_slots( + dataset, max_shards=args.max_shards ) - indices = random.sample(range(len(dataset)), num_samples) + if not history_counter: + console.print("[yellow] 无历史槽位数据[/yellow]") + else: + sorted_items = sorted(history_counter.items()) + lengths_arr = [l for l, _ in sorted_items] + counts_arr = [c for _, c in sorted_items] + weighted_sum = sum(l * c for l, c in zip(lengths_arr, counts_arr)) + avg_length = weighted_sum / history_total if history_total > 0 else 0 + + cumsum = 0 + median_length = 0 + for length, count in sorted_items: + cumsum += count + if cumsum >= history_total / 2: + median_length = length + break + + console.print(f"\n总样本数: {history_total:,}") + console.print(f"最大历史槽位数: {max(lengths_arr)}") + console.print(f"最小历史槽位数: {min(lengths_arr)}") + console.print(f"平均历史槽位数: {avg_length:.2f}") + console.print(f"中位数历史槽位数: {median_length}") + + hist_table = Table( + title="历史槽位有效长度分布", + show_header=True, + header_style="bold magenta", + ) + hist_table.add_column("槽位数", style="cyan", width=10) + hist_table.add_column("样本数", style="green", width=12) + hist_table.add_column("占比", style="yellow", width=10) + + for length, count in sorted_items: + pct = count / history_total * 100 + hist_table.add_row(str(length), f"{count:,}", f"{pct:.2f}%") + console.print(hist_table) + + # ====== 3. 抽样还原 → CSV ====== + num_samples = min(args.num_samples, len(dataset)) + console.print( + f"\n[bold yellow]====== 抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]" + ) + + indices = list(range(num_samples)) csv_path = Path(args.output) csv_path.parent.mkdir(parents=True, exist_ok=True) @@ -352,6 +417,7 @@ def main(): "context", "suffix", "history", + "history_slot_length", "full_tokens", ] @@ -372,6 +438,7 @@ def main(): decoded["context"], decoded["suffix"], decoded["history"], + decoded["history_slot_length"], decoded["full_tokens"], ] ) diff --git a/src/model/merge_shards.py b/src/model/merge_shards.py new file mode 100644 index 0000000..e6ed5de --- /dev/null +++ b/src/model/merge_shards.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +"""将分片 .npz 合并为单个 .npz 文件,合并后删除原分片""" + +import argparse +import gc +import json +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +FIELDS = [ + "input_ids", + "token_type_ids", + "attention_mask", + "labels", + "history_slot_ids", + "pinyin_ids", +] + + +def merge_split(data_dir: Path, output_name: str = "shard_000000.npz"): + with open(data_dir / "metadata.json") as f: + meta = json.load(f) + num_shards = meta["num_shards"] + if num_shards <= 1: + print(f"{data_dir}: already single shard, skip") + return + + merged = {} + for field in FIELDS: + pieces = [] + for i in tqdm( + range(num_shards), desc=f"Merging {data_dir.name}/{field}", unit="shard" + ): + data = np.load(data_dir / f"shard_{i:06d}.npz") + pieces.append(data[field]) + data.close() + merged[field] = np.concatenate(pieces, axis=0) + del pieces + gc.collect() + + total = len(merged["labels"]) + # 先删原分片(避免与新文件名冲突) + for i in range(num_shards): + (data_dir / f"shard_{i:06d}.npz").unlink() + # 再写合并文件 + np.savez_compressed(data_dir / output_name, **merged) + del merged + gc.collect() + + # 更新 metadata + meta["num_shards"] = 1 + meta["shard_size"] = total + with open(data_dir / "metadata.json", "w", encoding="utf-8") as f: + json.dump(meta, f, indent=2, ensure_ascii=False) + + print(f"{data_dir.name}: {num_shards} shards → 1 shard, {total:,} samples") + + +def main(): + parser = argparse.ArgumentParser(description="合并分片 .npz 为单个文件") + parser.add_argument( + "--input-dir", type=str, required=True, help="数据集目录(含 train/ 和 eval/)" + ) + parser.add_argument( + "--output-name", type=str, default="shard_000000.npz", help="合并后文件名" + ) + parser.add_argument( + "--train-only", action="store_true", help="仅合并 train,跳过 eval" + ) + + args = parser.parse_args() + root = Path(args.input_dir) + + for split in ["train", "eval"]: + split_dir = root / split + if not split_dir.exists(): + print(f"{split_dir}: not found, skip") + continue + if args.train_only and split == "eval": + continue + merge_split(split_dir, args.output_name) + + +if __name__ == "__main__": + main() diff --git a/src/model/model.py b/src/model/model.py index 6591170..a5ac2c4 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -35,12 +35,13 @@ class InputMethodEngine(nn.Module): vocab_size: int = 10019, pinyin_vocab_size: int = 28, dim: int = 512, - num_slots: int = 8, # 历史槽位数量 (对应 README 中的 8 个槽位) - n_layers: int = 4, # Transformer 层数 - n_heads: int = 4, # 注意力头数 - num_experts: int = 10, # MoE 专家数量 - max_seq_len: int = 128, # 最大上下文长度 - compile: bool = False, # 是否开启 torch.compile 优化 + num_slots: int = 8, + n_layers: int = 4, + n_heads: int = 4, + num_experts: int = 10, + max_seq_len: int = 128, + compile: bool = False, + moe_mode: str = "all", # "all" / "sparse" / "sparse_allow_graph" ): super().__init__() self.dim = dim @@ -72,7 +73,13 @@ class InputMethodEngine(nn.Module): self.cross_attn = CrossAttentionFusion(dim=dim, n_heads=n_heads) # 4. 混合专家层 (MoE) - self.moe = MoELayer(dim=dim, num_experts=num_experts, top_k=3, num_resblocks=12) + self.moe = MoELayer( + dim=dim, + num_experts=num_experts, + top_k=3, + num_resblocks=12, + moe_mode=moe_mode, + ) # 5. 槽位注意力池化 self.slot_attention = nn.Linear(dim, 1) diff --git a/src/model/subsample.py b/src/model/subsample.py new file mode 100644 index 0000000..16e5829 --- /dev/null +++ b/src/model/subsample.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python3 +""" +子采样脚本:从预处理的 .npz 分片中抽取子集。 + +策略: + 1. 第1遍扫描:只读 labels 字段,统计每个 label ID 的样本数,记录分片大小 + 2. 中间计算:按每 ID 最多 N 个样本封顶(硬封顶),不足目标总量则从超额池补足 + 3. 第2遍扫描:读取全部字段,按精确保留配额抽取训练样本,同时抽取评估集 + +用法: + python -m model.subsample \ + --input-dir ./preprocessed \ + --output-dir ./subsampled \ + --cap-per-label 300000 \ + --target-total 100000000 \ + --num-eval 2560 +""" + +import argparse +import gc +import json +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +from loguru import logger +from rich.console import Console +from tqdm import tqdm + +FIELDS = [ + "input_ids", + "token_type_ids", + "attention_mask", + "labels", + "history_slot_ids", + "pinyin_ids", +] + + +def _global_to_shard( + positions: np.ndarray, shard_sizes: List[int] +) -> Dict[int, List[int]]: + """将全局位置映射为 {shard_idx: [local_indices]}""" + cumsum = np.cumsum([0] + shard_sizes) + mapping: Dict[int, List[int]] = {} + for pos in positions: + shard_idx = int(np.searchsorted(cumsum, pos, side="right") - 1) + local_idx = pos - cumsum[shard_idx] + mapping.setdefault(shard_idx, []).append(local_idx) + return mapping + + +def pass1_count(input_dir: Path, split: str) -> Tuple[Dict[int, int], List[int]]: + """ + 第1遍扫描:只读 labels,统计每个 label ID 的总样本数,记录各分片大小。 + 返回 (label_counts, shard_sizes)。 + """ + metadata_path = input_dir / split / "metadata.json" + with open(metadata_path) as f: + metadata = json.load(f) + + num_shards = metadata["num_shards"] + label_counts: Dict[int, int] = {} + shard_sizes: List[int] = [] + + pbar = tqdm(total=num_shards, desc="Pass 1: counting labels", unit="shard") + + for shard_idx in range(num_shards): + shard_path = input_dir / split / f"shard_{shard_idx:06d}.npz" + data = np.load(shard_path) + labels = data["labels"] + n = len(labels) + shard_sizes.append(n) + + unique, counts = np.unique(labels, return_counts=True) + for uid, cnt in zip(unique, counts): + uid = int(uid) + label_counts[uid] = label_counts.get(uid, 0) + int(cnt) + + data.close() + pbar.update(1) + + pbar.close() + return label_counts, shard_sizes + + +def compute_quotas( + label_counts: Dict[int, int], + cap: int = 300_000, + target_total: int = 100_000_000, +) -> Dict[int, int]: + """ + 计算每个 label 的精确保留配额(硬封顶)。 + + 策略: + - 每 ID 保留 min(count, cap) + - 若封顶后总量 >= target_total,直接用封顶策略 + - 若封顶后总量 < target_total,从超额池(count > cap 的部分)等比抽取补足 + """ + capped_total = sum(min(cnt, cap) for cnt in label_counts.values()) + + if capped_total >= target_total: + return {label: min(cnt, cap) for label, cnt in label_counts.items()} + + # 封顶不足,从超额池补足 + deficit = target_total - capped_total + excess_per_label = {lbl: max(0, cnt - cap) for lbl, cnt in label_counts.items()} + excess_total = sum(excess_per_label.values()) + + quotas: Dict[int, int] = {} + if excess_total > 0: + ratio = min(1.0, deficit / excess_total) + for label, cnt in label_counts.items(): + base = min(cnt, cap) + extra = int(excess_per_label[label] * ratio) + quotas[label] = base + extra + else: + quotas = {label: min(cnt, cap) for label, cnt in label_counts.items()} + + return quotas + + +def pass2_subsample( + input_dir: Path, + output_dir: Path, + split: str, + quotas: Dict[int, int], + eval_map: Dict[int, List[int]], + shard_sizes: List[int], + shard_size: int = 5_000_000, +) -> Tuple[int, int]: + """ + 第2遍扫描:读取全部字段,抽取评估样本 + 按精确保留配额子采样训练样本。 + + quotas: {label_id: exact_number_to_keep} + 返回 (train_kept, eval_kept)。 + """ + metadata_path = input_dir / split / "metadata.json" + with open(metadata_path) as f: + metadata = json.load(f) + + num_shards = metadata["num_shards"] + max_seq_length = metadata["max_seq_length"] + + output_train_dir = output_dir / "train" + output_eval_dir = output_dir / "eval" + output_train_dir.mkdir(parents=True, exist_ok=True) + output_eval_dir.mkdir(parents=True, exist_ok=True) + + rng = np.random.RandomState() + remaining = dict(quotas) # {label_id: remaining_to_keep} + + train_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS} + train_buf_count = 0 + train_shard_idx = 0 + total_train_kept = 0 + + eval_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS} + total_eval_kept = 0 + + pbar = tqdm(total=num_shards, desc="Pass 2: subsampling", unit="shard") + + for src_shard_idx in range(num_shards): + shard_path = input_dir / split / f"shard_{src_shard_idx:06d}.npz" + data = np.load(shard_path) + n = shard_sizes[src_shard_idx] + + eval_local = np.array(eval_map.get(src_shard_idx, []), dtype=np.int64) + + is_eval = np.zeros(n, dtype=bool) + if len(eval_local) > 0: + is_eval[eval_local] = True + + # ── 第一步:只加载 labels,计算掩码 ── + labels = data["labels"] + + # 抽取评估样本的 labels + if len(eval_local) > 0: + eval_buffers["labels"].append(labels[eval_local].copy()) + + # 计算训练保留位置 + train_candidates = labels[~is_eval] + train_n = len(train_candidates) + n_kept = 0 + keep_original_indices = np.array([], dtype=np.int64) + + if train_n > 0: + sort_idx = np.argsort(train_candidates, kind="stable") + sorted_labels = train_candidates[sort_idx] + unique_vals, starts = np.unique(sorted_labels, return_index=True) + ends = np.append(starts[1:], train_n) + + train_keep_mask = np.zeros(train_n, dtype=bool) + for i in range(len(unique_vals)): + label_val = int(unique_vals[i]) + start, end = starts[i], ends[i] + cnt_in_shard = end - start + need = remaining.get(label_val, 0) + select = min(cnt_in_shard, need) + if select >= cnt_in_shard: + train_keep_mask[sort_idx[start:end]] = True + elif select > 0: + chosen = rng.choice(cnt_in_shard, size=select, replace=False) + train_keep_mask[sort_idx[start + chosen]] = True + remaining[label_val] = max(0, remaining.get(label_val, 0) - select) + + original_train_indices = np.where(~is_eval)[0] + keep_original_indices = original_train_indices[train_keep_mask] + n_kept = len(keep_original_indices) + + # 释放标签相关的临时数组 + del train_candidates, sort_idx, sorted_labels + del unique_vals, starts, ends, train_keep_mask + del original_train_indices + gc.collect() + + del labels, is_eval + gc.collect() + + # ── 第二步:逐字段加载,抽取评估 + 训练 ── + # 评估:跳过 labels(已在第一步抽取) + if len(eval_local) > 0: + for f in FIELDS: + if f == "labels": + continue + arr = data[f] + eval_buffers[f].append(arr[eval_local].copy()) + del arr + gc.collect() + total_eval_kept += len(eval_local) + + # 训练:逐字段加载,立即删除 + if n_kept > 0: + for f in FIELDS: + arr = data[f] + train_buffers[f].append(arr[keep_original_indices].copy()) + del arr + gc.collect() + train_buf_count += n_kept + total_train_kept += n_kept + + del keep_original_indices + data.close() + gc.collect() + + if train_buf_count >= shard_size: + merged = {} + for f in FIELDS: + merged[f] = np.concatenate(train_buffers[f], axis=0) + np.savez_compressed( + output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged + ) + logger.debug( + f"Saved train shard {train_shard_idx}: {train_buf_count} samples" + ) + train_shard_idx += 1 + train_buffers = {f: [] for f in FIELDS} + train_buf_count = 0 + del merged + gc.collect() + + pbar.update(1) + + pbar.close() + + # 剩余缓冲 + if train_buf_count > 0: + merged = {} + for f in FIELDS: + merged[f] = np.concatenate(train_buffers[f], axis=0) + np.savez_compressed( + output_train_dir / f"shard_{train_shard_idx:06d}.npz", **merged + ) + logger.debug(f"Saved train shard {train_shard_idx}: {train_buf_count} samples") + train_shard_idx += 1 + + # 评估数据 + if total_eval_kept > 0: + eval_merged = {} + for f in FIELDS: + eval_merged[f] = np.concatenate(eval_buffers[f], axis=0) + np.savez_compressed(output_eval_dir / "shard_000000.npz", **eval_merged) + else: + logger.warning("No eval samples extracted!") + + # 写 metadata + train_metadata = { + "num_samples": total_train_kept, + "max_seq_length": max_seq_length, + "dtype": "int16", + "fields": FIELDS, + "shard_size": shard_size, + "num_shards": train_shard_idx, + } + with open(output_train_dir / "metadata.json", "w", encoding="utf-8") as f: + json.dump(train_metadata, f, indent=2, ensure_ascii=False) + + eval_metadata = { + "num_samples": total_eval_kept, + "max_seq_length": max_seq_length, + "dtype": "int16", + "fields": FIELDS, + "shard_size": total_eval_kept, + "num_shards": 1, + } + with open(output_eval_dir / "metadata.json", "w", encoding="utf-8") as f: + json.dump(eval_metadata, f, indent=2, ensure_ascii=False) + + return total_train_kept, total_eval_kept + + +def main(): + console = Console() + + parser = argparse.ArgumentParser(description="从预处理 .npz 分片中抽取子集") + parser.add_argument("--input-dir", type=str, required=True, help="输入预处理目录") + parser.add_argument("--output-dir", type=str, required=True, help="输出子采样目录") + parser.add_argument( + "--cap-per-label", + type=int, + default=300_000, + help="每个 label ID 最大保留样本数", + ) + parser.add_argument( + "--target-total", + type=int, + default=100_000_000, + help="目标训练集样本总数", + ) + parser.add_argument( + "--shard-size", + type=int, + default=5_000_000, + help="输出分片大小(样本数)", + ) + parser.add_argument("--num-eval", type=int, default=2560, help="评估集样本数") + + args = parser.parse_args() + input_dir = Path(args.input_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + console.print("[bold cyan]=== 子采样预处理数据 ===[/bold cyan]") + console.print(f"输入目录: {input_dir}") + console.print(f"输出目录: {output_dir}") + console.print(f"每 ID 封顶: {args.cap_per_label:,}") + console.print(f"目标训练集: {args.target_total:,}") + console.print(f"评估集: {args.num_eval}") + console.print() + + # ── 第 1 遍:统计 ── + console.print("[bold]第 1 遍扫描:统计标签分布...[/bold]") + label_counts, shard_sizes = pass1_count(input_dir, "train") + + total_samples = sum(shard_sizes) + num_labels = len(label_counts) + capped_total = sum(min(cnt, args.cap_per_label) for cnt in label_counts.values()) + + console.print(f"分片数: {len(shard_sizes)}") + console.print(f"总样本数: {total_samples:,}") + console.print(f"标签种类: {num_labels}") + console.print(f"封顶后样本数 (每ID≤{args.cap_per_label:,}): {capped_total:,}") + console.print() + + # ── 计算保留配额 ── + quotas = compute_quotas( + label_counts, + cap=args.cap_per_label, + target_total=args.target_total, + ) + expected_train = sum(quotas.values()) + console.print(f"期望训练集大小: {expected_train:,}") + + # 标签分布统计 + stats_n_capped = sum(1 for cnt in label_counts.values() if cnt > args.cap_per_label) + stats_n_below = num_labels - stats_n_capped + stats_n_quota_full = sum(1 for lbl, q in quotas.items() if q >= args.cap_per_label) + console.print( + f"超过封顶的标签数: {stats_n_capped}, 不足封顶的标签数: {stats_n_below}, " + f"配额打满(≥{args.cap_per_label // 10000}万): {stats_n_quota_full}" + ) + console.print() + + # ── 抽取评估集位置 ── + rng = np.random.RandomState() + eval_positions = rng.choice(total_samples, size=args.num_eval, replace=False) + eval_positions.sort() + eval_map = _global_to_shard(eval_positions, shard_sizes) + console.print(f"评估集: {args.num_eval} 个位置已分配到 {len(eval_map)} 个分片中") + console.print() + + # ── 第 2 遍:子采样 ── + console.print("[bold]第 2 遍扫描:子采样 + 抽取评估集...[/bold]") + train_count, eval_count = pass2_subsample( + input_dir, + output_dir, + "train", + quotas, + eval_map, + shard_sizes, + args.shard_size, + ) + + # ── 输出总结 ── + console.print() + console.print("[bold green]=== 完成 ===[/bold green]") + console.print(f"训练集: {train_count:,} 样本") + console.print(f"评估集: {eval_count:,} 样本") + + for split in ["train", "eval"]: + sdir = output_dir / split + if sdir.exists(): + total_size = sum( + f.stat().st_size for f in sdir.iterdir() if f.suffix == ".npz" + ) + meta_path = sdir / "metadata.json" + if meta_path.exists(): + with open(meta_path) as mf: + meta = json.load(mf) + else: + meta = {} + console.print( + f" {split}/: {total_size / (1024**3):.2f} GB, " + f"{meta.get('num_shards', '?')} shards" + ) + + +if __name__ == "__main__": + main() diff --git a/src/model/trainer.py b/src/model/trainer.py index 165c996..6f7f8b9 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -1157,6 +1157,11 @@ def train( "--compile/--no-compile", help="是否开启 torch.compile 优化(需 PyTorch 2.0+)", ), + moe_mode: str = typer.Option( + "all", + "--moe-mode", + help="MoE 计算策略: all(全量计算), sparse(稀疏计算), sparse_allow_graph(稀疏+allow_in_graph)", + ), ): """ 训练输入法模型 @@ -1211,6 +1216,7 @@ def train( config_table.add_row("模型", "MoE专家数", str(num_experts)) config_table.add_row("模型", "使用拼音", str(use_pinyin)) config_table.add_row("模型", "编译优化", str(compile)) + config_table.add_row("模型", "MoE策略", moe_mode) config_table.add_row("训练", "训练轮数", str(num_epochs)) config_table.add_row("训练", "学习率", f"{learning_rate:.2e}") @@ -1331,6 +1337,7 @@ def train( "auto_resume": auto_resume, "max_iter_length": max_iter_length, "compile": compile, + "moe_mode": moe_mode, "is_train_preprocessed": is_train_preprocessed, "is_eval_preprocessed": is_eval_preprocessed, "total_steps": total_steps, @@ -1353,6 +1360,7 @@ def train( num_experts=num_experts, max_seq_len=max_seq_len, compile=compile, + moe_mode=moe_mode, ) console.print(