From 0862b5b8fc32660b84b45a35143303fc6ef99a10 Mon Sep 17 00:00:00 2001 From: songsenand Date: Mon, 11 May 2026 22:05:05 +0800 Subject: [PATCH] =?UTF-8?q?fix(PreProcessedDataset):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E7=B1=BB=E5=9E=8B=E8=BD=AC=E6=8D=A2=EF=BC=8C?= =?UTF-8?q?=E9=81=BF=E5=85=8D=E5=86=85=E5=AD=98=E5=A4=8D=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/model/preprocessed_dataset.py | 19 ++++++++----------- src/model/trainer.py | 4 ++-- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/model/preprocessed_dataset.py b/src/model/preprocessed_dataset.py index 13faf36..5f1eb51 100644 --- a/src/model/preprocessed_dataset.py +++ b/src/model/preprocessed_dataset.py @@ -93,7 +93,7 @@ class PreProcessedDataset(Dataset): 所有数据以 int16 存储,读取时转为 torch.long (int64)。 """ - def __init__(self, data_dir: str, max_cache_shards: int = 2): + def __init__(self, data_dir: str, max_cache_shards: int = 1): self.data_dir = Path(data_dir) with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f: @@ -144,12 +144,9 @@ class PreProcessedDataset(Dataset): self.pinyin_ids = np.load(self.data_dir / "pinyin_ids.npy", mmap_mode="r") def _load_shard(self, shard_idx: int) -> Dict[str, np.ndarray]: - """加载一个 .npz 分片到内存""" + """加载一个 .npz 分片到内存(保持原始 int16,不转换)""" shard_path = self.data_dir / f"shard_{shard_idx:06d}.npz" - data = dict(np.load(shard_path)) - for key in data: - data[key] = data[key].astype(np.int64) - return data + return dict(np.load(shard_path)) def __len__(self) -> int: return self.num_samples @@ -173,22 +170,22 @@ class PreProcessedDataset(Dataset): shard_data = self._cache.get(shard_idx, self._load_shard) return { "input_ids": torch.from_numpy( - shard_data["input_ids"][local_idx].copy() + shard_data["input_ids"][local_idx].astype(np.int64) ), "token_type_ids": torch.from_numpy( - shard_data["token_type_ids"][local_idx].copy() + shard_data["token_type_ids"][local_idx].astype(np.int64) ), "attention_mask": torch.from_numpy( - shard_data["attention_mask"][local_idx].copy() + shard_data["attention_mask"][local_idx].astype(np.int64) ), "labels": torch.tensor( shard_data["labels"][local_idx], dtype=torch.long ), "history_slot_ids": torch.from_numpy( - shard_data["history_slot_ids"][local_idx].copy() + shard_data["history_slot_ids"][local_idx].astype(np.int64) ), "pinyin_ids": torch.from_numpy( - shard_data["pinyin_ids"][local_idx].copy() + shard_data["pinyin_ids"][local_idx].astype(np.int64) ), } else: diff --git a/src/model/trainer.py b/src/model/trainer.py index 2fde42e..ffe65b7 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -1260,7 +1260,7 @@ def train( is_eval_preprocessed = is_preprocessed_data(eval_data_path) if is_train_preprocessed: - train_dataset = PreProcessedDataset(train_data_path) + train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1) total_steps = (len(train_dataset) // batch_size) * num_epochs train_dataloader = create_dataloader( dataset=train_dataset, @@ -1292,7 +1292,7 @@ def train( config_table.add_row("数据", "训练数据类型", "流式数据") if is_eval_preprocessed: - eval_dataset = PreProcessedDataset(eval_data_path) + eval_dataset = PreProcessedDataset(eval_data_path, max_cache_shards=1) eval_dataloader = create_dataloader( dataset=eval_dataset, batch_size=batch_size,