fix(dataset): 修复加载分片数据集时未正确计算样本数的问题
This commit is contained in:
parent
483e4d4f98
commit
d0f1534086
|
|
@ -81,12 +81,17 @@ class PreProcessedDataset(Dataset):
|
||||||
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
|
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
|
||||||
self.metadata = json.load(f)
|
self.metadata = json.load(f)
|
||||||
|
|
||||||
self.num_samples = self.metadata["num_samples"]
|
|
||||||
self.max_seq_length = self.metadata["max_seq_length"]
|
self.max_seq_length = self.metadata["max_seq_length"]
|
||||||
self._shard_size: Optional[int] = self.metadata.get("shard_size")
|
self._shard_size: Optional[int] = self.metadata.get("shard_size")
|
||||||
self._num_shards: Optional[int] = self.metadata.get("num_shards")
|
|
||||||
|
|
||||||
if self._shard_size is not None and self._num_shards is not None:
|
if self._shard_size is not None:
|
||||||
|
shard_files = sorted(self.data_dir.glob("shard_*.npz"))
|
||||||
|
self._num_shards = len(shard_files)
|
||||||
|
if self._num_shards == 0:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"No shard_*.npz files found in {self.data_dir}"
|
||||||
|
)
|
||||||
|
self.num_samples = self._num_shards * self._shard_size
|
||||||
self._is_sharded = True
|
self._is_sharded = True
|
||||||
self._cache = _ShardCache(max_size=max_cache_shards)
|
self._cache = _ShardCache(max_size=max_cache_shards)
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -94,6 +99,7 @@ class PreProcessedDataset(Dataset):
|
||||||
f"{self._num_shards} shards, shard_size={self._shard_size:,}"
|
f"{self._num_shards} shards, shard_size={self._shard_size:,}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
self.num_samples = self.metadata["num_samples"]
|
||||||
self._is_sharded = False
|
self._is_sharded = False
|
||||||
self._load_single_files()
|
self._load_single_files()
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue