fix(dataset): 修复分片数据集时未正确计算样本数的问题

This commit is contained in:
songsenand 2026-05-10 23:34:08 +08:00
parent 483e4d4f98
commit d0f1534086
1 changed files with 9 additions and 3 deletions

View File

@ -81,12 +81,17 @@ class PreProcessedDataset(Dataset):
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
self.metadata = json.load(f)
self.num_samples = self.metadata["num_samples"]
self.max_seq_length = self.metadata["max_seq_length"]
self._shard_size: Optional[int] = self.metadata.get("shard_size")
self._num_shards: Optional[int] = self.metadata.get("num_shards")
if self._shard_size is not None and self._num_shards is not None:
if self._shard_size is not None:
shard_files = sorted(self.data_dir.glob("shard_*.npz"))
self._num_shards = len(shard_files)
if self._num_shards == 0:
raise FileNotFoundError(
f"No shard_*.npz files found in {self.data_dir}"
)
self.num_samples = self._num_shards * self._shard_size
self._is_sharded = True
self._cache = _ShardCache(max_size=max_cache_shards)
logger.info(
@ -94,6 +99,7 @@ class PreProcessedDataset(Dataset):
f"{self._num_shards} shards, shard_size={self._shard_size:,}"
)
else:
self.num_samples = self.metadata["num_samples"]
self._is_sharded = False
self._load_single_files()
logger.info(