Source code for mach3sbitools.data_loaders.sbi_data_module
"""
PyTorch Lightning data module for SBI simulation datasets.
Dataset sharing strategy
------------------------
The ``TensorDataset`` passed to this module lives in **CPU RAM** and is
**not copied per DDP rank**. Lightning's built-in ``DistributedSampler``
(activated automatically when ``strategy="ddp"``) gives each rank a
disjoint slice of indices, so every GPU reads only its own share from the
shared tensor without any inter-process data replication.
This is the correct pattern for in-memory datasets under DDP:
* Each rank receives the full ``TensorDataset`` reference (shared memory).
* ``DistributedSampler`` partitions the index space; ``pin_memory=True``
then stages each rank's batches in page-locked (pinned) host memory so
the subsequent host-to-GPU copy is fast.
* There is zero redundant I/O or memory duplication.
``num_workers=0`` is kept throughout: spawning worker processes for an
already-RAM-resident tensor would only add IPC overhead.
"""
from __future__ import annotations
import warnings
import lightning as L
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
from mach3sbitools.utils.config import TrainingConfig
# Silence two known-noisy UserWarnings at import time.
# The DataLoaders in this module deliberately use ``num_workers=0``, so the
# "num_workers ... bottleneck" hint is not actionable here.
warnings.filterwarnings("ignore", ".*num_workers.*bottleneck.*", UserWarning)
# Deprecation chatter mentioning ``LeafSpec``.
warnings.filterwarnings("ignore", ".*LeafSpec.*deprecated.*", UserWarning)
[docs]
class SBIDataModule(L.LightningDataModule):
    """
    Lightning data module over a pre-loaded ``(theta, x)`` dataset.

    The dataset is expected to have been pre-loaded into CPU RAM via
    :meth:`~mach3sbitools.data_loaders.ParaketDataset.to_tensor_dataset`
    before this module is constructed.

    Under DDP, Lightning automatically wraps each DataLoader's sampler in a
    ``DistributedSampler``, which partitions the index space across ranks.
    Because the underlying ``TensorDataset`` tensors are kept in CPU shared
    memory (no ``.to(device)`` call on the dataset itself), each rank reads
    only its own slice — no data is copied between processes.

    .. note::
        The random split uses a fixed seed (``42`` unless ``split_seed`` is
        given) so that all DDP ranks produce identical train / validation
        index sets. If you change this seed, change it consistently across
        all ranks.
    """

    def __init__(
        self,
        dataset: TensorDataset,
        config: TrainingConfig,
        *,
        split_seed: int = 42,
    ) -> None:
        """
        :param dataset: Pre-loaded ``(theta, x)`` :class:`~torch.utils.data.TensorDataset`
            in CPU RAM, produced by
            :meth:`~mach3sbitools.data_loaders.ParaketDataset.to_tensor_dataset`.
        :param config: Training configuration supplying ``validation_fraction``
            and ``batch_size``.
        :param split_seed: Seed for the train / validation split generator.
            Must be identical on every DDP rank.
        """
        super().__init__()
        self.dataset = dataset
        self.config = config
        self.split_seed = split_seed
        # Keep the batch size as a plain attribute as well; both dataloaders
        # read it from here (presumably so external tooling can adjust it
        # after construction — TODO confirm).
        self.batch_size = config.batch_size
        # Populated by setup(); None until then.
        self.train_dataset: Dataset | None = None
        self.val_dataset: Dataset | None = None

    def setup(self, stage: str | None = None) -> None:
        """
        Split the dataset into training and validation subsets.

        The validation subset holds ``int(len(dataset) * validation_fraction)``
        samples; the remainder goes to the training subset. The split is
        drawn from a generator seeded with ``split_seed``, so repeated calls
        (and calls on different ranks) yield the same partition.

        :param stage: Lightning stage name; accepted for API compatibility
            and not used.
        :raises ValueError: If ``config.validation_fraction`` lies outside
            ``[0, 1]``.
        """
        # Register the warning filters again at setup time. NOTE(review):
        # presumably the filter state can differ from import time by the time
        # setup runs — confirm before removing.
        warnings.filterwarnings(
            "ignore",
            message=".*num_workers.*bottleneck.*",
            category=UserWarning,
        )
        warnings.filterwarnings(
            "ignore",
            message=".*LeafSpec.*",
            category=UserWarning,
        )

        fraction = self.config.validation_fraction
        # A fraction above 1 would give a negative training length that
        # still sums to len(dataset); reject it explicitly instead.
        if not 0.0 <= fraction <= 1.0:
            raise ValueError(
                f"validation_fraction must be within [0, 1], got {fraction!r}."
            )

        n_total = len(self.dataset)
        n_val = int(n_total * fraction)
        n_train = n_total - n_val
        self.train_dataset, self.val_dataset = random_split(
            self.dataset,
            [n_train, n_val],
            generator=torch.Generator().manual_seed(self.split_seed),
        )

    def train_dataloader(self) -> DataLoader:
        """
        Return the training DataLoader.

        Shuffling is handled by Lightning's ``DistributedSampler`` under DDP,
        or by ``shuffle=True`` on single-device runs. ``drop_last=True``
        keeps batch sizes uniform across ranks.

        :raises RuntimeError: If :meth:`setup` has not been called.
        """
        if self.train_dataset is None:
            raise RuntimeError("Training set has not been set; call setup() first.")
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=0,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        """
        Return the validation DataLoader.

        Uses ``self.batch_size`` (the same attribute as the training loader)
        so that both loaders stay in sync if the attribute is changed after
        construction. The final partial batch is kept.

        :raises RuntimeError: If :meth:`setup` has not been called.
        """
        if self.val_dataset is None:
            raise RuntimeError("Validation set has not been set; call setup() first.")
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )