fix(PreProcessedDataset): 修复数据类型转换——分片缓存保持原始 int16,按样本读取时再转为 int64,避免整片转换为 int64 带来的内存开销
This commit is contained in:
parent
27beb7f0b1
commit
0862b5b8fc
|
|
@ -93,7 +93,7 @@ class PreProcessedDataset(Dataset):
|
||||||
所有数据以 int16 存储,读取时转为 torch.long (int64)。
|
所有数据以 int16 存储,读取时转为 torch.long (int64)。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_dir: str, max_cache_shards: int = 2):
|
def __init__(self, data_dir: str, max_cache_shards: int = 1):
|
||||||
self.data_dir = Path(data_dir)
|
self.data_dir = Path(data_dir)
|
||||||
|
|
||||||
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
|
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
|
||||||
|
|
@ -144,12 +144,9 @@ class PreProcessedDataset(Dataset):
|
||||||
self.pinyin_ids = np.load(self.data_dir / "pinyin_ids.npy", mmap_mode="r")
|
self.pinyin_ids = np.load(self.data_dir / "pinyin_ids.npy", mmap_mode="r")
|
||||||
|
|
||||||
def _load_shard(self, shard_idx: int) -> Dict[str, np.ndarray]:
|
def _load_shard(self, shard_idx: int) -> Dict[str, np.ndarray]:
|
||||||
"""加载一个 .npz 分片到内存"""
|
"""加载一个 .npz 分片到内存(保持原始 int16,不转换)"""
|
||||||
shard_path = self.data_dir / f"shard_{shard_idx:06d}.npz"
|
shard_path = self.data_dir / f"shard_{shard_idx:06d}.npz"
|
||||||
data = dict(np.load(shard_path))
|
return dict(np.load(shard_path))
|
||||||
for key in data:
|
|
||||||
data[key] = data[key].astype(np.int64)
|
|
||||||
return data
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.num_samples
|
return self.num_samples
|
||||||
|
|
@ -173,22 +170,22 @@ class PreProcessedDataset(Dataset):
|
||||||
shard_data = self._cache.get(shard_idx, self._load_shard)
|
shard_data = self._cache.get(shard_idx, self._load_shard)
|
||||||
return {
|
return {
|
||||||
"input_ids": torch.from_numpy(
|
"input_ids": torch.from_numpy(
|
||||||
shard_data["input_ids"][local_idx].copy()
|
shard_data["input_ids"][local_idx].astype(np.int64)
|
||||||
),
|
),
|
||||||
"token_type_ids": torch.from_numpy(
|
"token_type_ids": torch.from_numpy(
|
||||||
shard_data["token_type_ids"][local_idx].copy()
|
shard_data["token_type_ids"][local_idx].astype(np.int64)
|
||||||
),
|
),
|
||||||
"attention_mask": torch.from_numpy(
|
"attention_mask": torch.from_numpy(
|
||||||
shard_data["attention_mask"][local_idx].copy()
|
shard_data["attention_mask"][local_idx].astype(np.int64)
|
||||||
),
|
),
|
||||||
"labels": torch.tensor(
|
"labels": torch.tensor(
|
||||||
shard_data["labels"][local_idx], dtype=torch.long
|
shard_data["labels"][local_idx], dtype=torch.long
|
||||||
),
|
),
|
||||||
"history_slot_ids": torch.from_numpy(
|
"history_slot_ids": torch.from_numpy(
|
||||||
shard_data["history_slot_ids"][local_idx].copy()
|
shard_data["history_slot_ids"][local_idx].astype(np.int64)
|
||||||
),
|
),
|
||||||
"pinyin_ids": torch.from_numpy(
|
"pinyin_ids": torch.from_numpy(
|
||||||
shard_data["pinyin_ids"][local_idx].copy()
|
shard_data["pinyin_ids"][local_idx].astype(np.int64)
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1260,7 +1260,7 @@ def train(
|
||||||
is_eval_preprocessed = is_preprocessed_data(eval_data_path)
|
is_eval_preprocessed = is_preprocessed_data(eval_data_path)
|
||||||
|
|
||||||
if is_train_preprocessed:
|
if is_train_preprocessed:
|
||||||
train_dataset = PreProcessedDataset(train_data_path)
|
train_dataset = PreProcessedDataset(train_data_path, max_cache_shards=1)
|
||||||
total_steps = (len(train_dataset) // batch_size) * num_epochs
|
total_steps = (len(train_dataset) // batch_size) * num_epochs
|
||||||
train_dataloader = create_dataloader(
|
train_dataloader = create_dataloader(
|
||||||
dataset=train_dataset,
|
dataset=train_dataset,
|
||||||
|
|
@ -1292,7 +1292,7 @@ def train(
|
||||||
config_table.add_row("数据", "训练数据类型", "流式数据")
|
config_table.add_row("数据", "训练数据类型", "流式数据")
|
||||||
|
|
||||||
if is_eval_preprocessed:
|
if is_eval_preprocessed:
|
||||||
eval_dataset = PreProcessedDataset(eval_data_path)
|
eval_dataset = PreProcessedDataset(eval_data_path, max_cache_shards=1)
|
||||||
eval_dataloader = create_dataloader(
|
eval_dataloader = create_dataloader(
|
||||||
dataset=eval_dataset,
|
dataset=eval_dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue