fix(dataset): 修复分片数据集时未正确计算样本数的问题
This commit is contained in:
parent
483e4d4f98
commit
d0f1534086
|
|
@ -81,12 +81,17 @@ class PreProcessedDataset(Dataset):
|
|||
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
|
||||
self.metadata = json.load(f)
|
||||
|
||||
self.num_samples = self.metadata["num_samples"]
|
||||
self.max_seq_length = self.metadata["max_seq_length"]
|
||||
self._shard_size: Optional[int] = self.metadata.get("shard_size")
|
||||
self._num_shards: Optional[int] = self.metadata.get("num_shards")
|
||||
|
||||
if self._shard_size is not None and self._num_shards is not None:
|
||||
if self._shard_size is not None:
|
||||
shard_files = sorted(self.data_dir.glob("shard_*.npz"))
|
||||
self._num_shards = len(shard_files)
|
||||
if self._num_shards == 0:
|
||||
raise FileNotFoundError(
|
||||
f"No shard_*.npz files found in {self.data_dir}"
|
||||
)
|
||||
self.num_samples = self._num_shards * self._shard_size
|
||||
self._is_sharded = True
|
||||
self._cache = _ShardCache(max_size=max_cache_shards)
|
||||
logger.info(
|
||||
|
|
@ -94,6 +99,7 @@ class PreProcessedDataset(Dataset):
|
|||
f"{self._num_shards} shards, shard_size={self._shard_size:,}"
|
||||
)
|
||||
else:
|
||||
self.num_samples = self.metadata["num_samples"]
|
||||
self._is_sharded = False
|
||||
self._load_single_files()
|
||||
logger.info(
|
||||
|
|
|
|||
Loading…
Reference in New Issue