From d0f153408646ac4e181fef8c6dc276b42c0133d1 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sun, 10 May 2026 23:34:08 +0800 Subject: [PATCH] =?UTF-8?q?fix(dataset):=20=E4=BF=AE=E5=A4=8D=E5=88=86?= =?UTF-8?q?=E7=89=87=E6=95=B0=E6=8D=AE=E9=9B=86=E6=97=B6=E6=9C=AA=E6=AD=A3?= =?UTF-8?q?=E7=A1=AE=E8=AE=A1=E7=AE=97=E6=A0=B7=E6=9C=AC=E6=95=B0=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/model/preprocessed_dataset.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/model/preprocessed_dataset.py b/src/model/preprocessed_dataset.py index c61ac9d..4320bb0 100644 --- a/src/model/preprocessed_dataset.py +++ b/src/model/preprocessed_dataset.py @@ -81,12 +81,17 @@ class PreProcessedDataset(Dataset): with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f: self.metadata = json.load(f) - self.num_samples = self.metadata["num_samples"] self.max_seq_length = self.metadata["max_seq_length"] self._shard_size: Optional[int] = self.metadata.get("shard_size") - self._num_shards: Optional[int] = self.metadata.get("num_shards") - if self._shard_size is not None and self._num_shards is not None: + if self._shard_size is not None: + shard_files = sorted(self.data_dir.glob("shard_*.npz")) + self._num_shards = len(shard_files) + if self._num_shards == 0: + raise FileNotFoundError( + f"No shard_*.npz files found in {self.data_dir}" + ) + self.num_samples = self._num_shards * self._shard_size self._is_sharded = True self._cache = _ShardCache(max_size=max_cache_shards) logger.info( @@ -94,6 +99,7 @@ class PreProcessedDataset(Dataset): f"{self._num_shards} shards, shard_size={self._shard_size:,}" ) else: + self.num_samples = self.metadata["num_samples"] self._is_sharded = False self._load_single_files() logger.info(