diff --git a/src/model/preprocessed_dataset.py b/src/model/preprocessed_dataset.py index c61ac9d..4320bb0 100644 --- a/src/model/preprocessed_dataset.py +++ b/src/model/preprocessed_dataset.py @@ -81,12 +81,17 @@ class PreProcessedDataset(Dataset): with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f: self.metadata = json.load(f) - self.num_samples = self.metadata["num_samples"] self.max_seq_length = self.metadata["max_seq_length"] self._shard_size: Optional[int] = self.metadata.get("shard_size") - self._num_shards: Optional[int] = self.metadata.get("num_shards") - if self._shard_size is not None and self._num_shards is not None: + if self._shard_size is not None: + shard_files = sorted(self.data_dir.glob("shard_*.npz")) + self._num_shards = len(shard_files) + if self._num_shards == 0: + raise FileNotFoundError( + f"No shard_*.npz files found in {self.data_dir}" + ) + self.num_samples = self._num_shards * self._shard_size self._is_sharded = True self._cache = _ShardCache(max_size=max_cache_shards) logger.info( @@ -94,6 +99,7 @@ class PreProcessedDataset(Dataset): f"{self._num_shards} shards, shard_size={self._shard_size:,}" ) else: + self.num_samples = self.metadata["num_samples"] self._is_sharded = False self._load_single_files() logger.info(