Data Loading

class ParaketDataset(data_folder, parameter_names, nuisance_params=None)

File-level PyTorch dataset over a folder of .feather simulation files.

Each __getitem__ call loads one feather file and returns a (theta, x) pair. Call to_tensor_dataset() before training to pre-load everything into RAM as a flat TensorDataset.

Parameters:
  • data_folder (Path) – Folder containing the .feather simulation files.

  • parameter_names (list[str])

  • nuisance_params (list[str] | None)
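File-level indexing differs from sample-level indexing. The minimal sketch below illustrates the difference. It is not the package's implementation: the class name FileLevelDataset and the injectable loader callable are hypothetical. An in-memory dictionary stands in for reading .feather files, so the example runs without pandas or pyarrow.

```python
from pathlib import Path
from typing import Callable

import numpy as np

# Hypothetical loader signature: path -> (theta, x), both float32 arrays.
Loader = Callable[[Path], tuple[np.ndarray, np.ndarray]]


class FileLevelDataset:
    """Hypothetical sketch: one item per file, not per sample."""

    def __init__(self, files: list[Path], loader: Loader) -> None:
        self.files = sorted(files)  # deterministic file order
        self.loader = loader

    def __len__(self) -> int:
        # Length counts files, not samples.
        return len(self.files)

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        # Each access loads one whole file (here: from the in-memory stand-in).
        return self.loader(self.files[idx])


# In-memory stand-in for two simulation files with 3 and 5 samples.
fake_store = {
    Path("sim_0.feather"): (np.zeros((3, 2), np.float32), np.zeros((3, 4), np.float32)),
    Path("sim_1.feather"): (np.ones((5, 2), np.float32), np.ones((5, 4), np.float32)),
}
ds = FileLevelDataset(list(fake_store), fake_store.__getitem__)
print(len(ds))         # 2 files
print(ds[1][0].shape)  # (5, 2): one file's theta block
```

Each item is a whole file of varying size, not a single sample. That is why flattening with to_tensor_dataset() before training makes batching and shuffling operate on individual samples.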

to_tensor_dataset(device='cpu', verbose=True)

Pre-load all feather files into a flat TensorDataset.

Concatenates all (theta, x) pairs along the sample dimension. This avoids repeated disk reads per epoch during training.

Parameters:
  • device (str) – Target device for the output tensors.

  • verbose (bool)

Return type:

TensorDataset

Returns:

A TensorDataset of (theta_tensor, x_tensor) with shapes (n_total_samples, n_params) and (n_total_samples, n_bins).
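Here is a minimal sketch of the flattening step. It assumes each file yields float32 numpy arrays, as from_feather() documents. The helper name flatten_to_tensor_dataset is hypothetical. It concatenates the per-file blocks along axis 0 (the sample dimension) and wraps the result in torch.utils.data.TensorDataset.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset


def flatten_to_tensor_dataset(
    pairs: list[tuple[np.ndarray, np.ndarray]], device: str = "cpu"
) -> TensorDataset:
    """Hypothetical helper: stack per-file (theta, x) blocks sample-wise."""
    theta = np.concatenate([t for t, _ in pairs], axis=0)  # (n_total, n_params)
    x = np.concatenate([xs for _, xs in pairs], axis=0)    # (n_total, n_bins)
    return TensorDataset(
        torch.from_numpy(theta).to(device),
        torch.from_numpy(x).to(device),
    )


# Two "files" with 3 and 5 samples each.
pairs = [
    (np.zeros((3, 2), np.float32), np.zeros((3, 4), np.float32)),
    (np.ones((5, 2), np.float32), np.ones((5, 4), np.float32)),
]
flat = flatten_to_tensor_dataset(pairs)
theta_t, x_t = flat.tensors
print(len(flat), tuple(theta_t.shape), tuple(x_t.shape))
```

After flattening, each dataset item is one sample. A standard DataLoader can then batch and shuffle without touching the disk again.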

Utilities

from_feather(file_name, parameter_names, nuisance_pars=None)

Load a (theta, x) pair from a feather file.

Parameters:
  • file_name (Path) – Path to the .feather file.

  • parameter_names (list[str]) – Ordered parameter names used for nuisance filtering.

  • nuisance_pars (list[str] | None) – fnmatch patterns for parameters to exclude from theta. If None, all parameters are returned.

Return type:

tuple[ndarray[tuple[Any, ...], dtype[float32]], ndarray[tuple[Any, ...], dtype[float32]]]

Returns:

Tuple of (theta, x) as float32 numpy arrays.

Raises:

FileNotFoundError – If file_name does not exist.

to_feather(file_name, theta_values, x_values)

Write a (theta, x) pair to a feather file.

Parameters:
  • file_name (Path) – Destination path. Must end in .feather.

  • theta_values (ndarray[tuple[Any, ...], dtype[float32]]) – Parameter array of shape (n_samples, n_params).

  • x_values (ndarray[tuple[Any, ...], dtype[float32]]) – Observable array of shape (n_samples, n_bins).

Raises:

ValueError – If file_name does not have a .feather suffix.

Return type:

None
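The documented failure modes of from_feather() and to_feather() reduce to two path checks. The sketch below implements them with pathlib only. The function names check_feather_destination and check_feather_source are hypothetical, and the real functions may perform these checks differently (for example, with respect to suffix case).

```python
from pathlib import Path


def check_feather_destination(file_name: Path) -> None:
    """Hypothetical: mirror to_feather()'s documented ValueError."""
    # Case-sensitive comparison; the real check may differ.
    if Path(file_name).suffix != ".feather":
        raise ValueError(f"{file_name} must have a .feather suffix")


def check_feather_source(file_name: Path) -> None:
    """Hypothetical: mirror from_feather()'s documented FileNotFoundError."""
    if not Path(file_name).exists():
        raise FileNotFoundError(file_name)


check_feather_destination(Path("out/sim_0.feather"))  # passes silently
try:
    check_feather_destination(Path("out/sim_0.csv"))
except ValueError as err:
    print(err)
```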

filter_nuisance(parameter_names, nuisance_pars, theta)

Remove nuisance parameters from a theta array by name pattern.

Parameters:
  • parameter_names (list[str]) – Ordered parameter names; the length must match theta.shape[1].

  • nuisance_pars (list[str]) – fnmatch patterns for parameters to exclude (e.g. ["syst_*"]).

  • theta (ndarray[tuple[Any, ...], dtype[float32]]) – Parameter array of shape (n_samples, n_params).

Return type:

ndarray[tuple[Any, ...], dtype[float32]]

Returns:

Filtered array with nuisance columns removed.

Raises:

ValueError – If len(parameter_names) != theta.shape[1].
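The following is a minimal sketch of the filtering behaviour described above, using Python's fnmatch module and numpy column indexing. The name filter_nuisance_sketch is hypothetical, and the package's own implementation may differ in detail.

```python
from fnmatch import fnmatch

import numpy as np


def filter_nuisance_sketch(
    parameter_names: list[str], nuisance_pars: list[str], theta: np.ndarray
) -> np.ndarray:
    """Drop theta columns whose name matches any fnmatch pattern."""
    if len(parameter_names) != theta.shape[1]:
        raise ValueError(
            f"{len(parameter_names)} names for {theta.shape[1]} theta columns"
        )
    # Keep the column indices whose names match none of the patterns.
    keep = [
        i
        for i, name in enumerate(parameter_names)
        if not any(fnmatch(name, pat) for pat in nuisance_pars)
    ]
    return theta[:, keep]


names = ["mu", "sigma", "syst_jes", "syst_lumi"]
theta = np.arange(8, dtype=np.float32).reshape(2, 4)
print(filter_nuisance_sketch(names, ["syst_*"], theta))
# [[0. 1.]
#  [4. 5.]]
```

Note that fnmatch.fnmatch normalizes case according to the operating system (it is case-insensitive on Windows). fnmatch.fnmatchcase is the case-sensitive alternative.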

Lightning Module

class SBIDataModule(dataset, config)

Lightning data module over a pre-loaded (theta, x) dataset.

The dataset is expected to have been pre-loaded into CPU RAM via to_tensor_dataset() before this module is constructed.

Under DDP, Lightning automatically wraps each DataLoader's sampler in a DistributedSampler, which partitions the index space across ranks. The underlying TensorDataset tensors stay in CPU memory (there is no .to(device) call on the dataset itself). As a result, each rank reads only its own slice of indices, and no data is copied between processes.
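The sketch below shows why ranks never overlap. It reproduces the strided index assignment that torch.utils.data.DistributedSampler uses with shuffle=False and drop_last=False. First, the index list is padded by wrapping around until its length divides evenly by the number of ranks. Then rank r takes every num_replicas-th index, starting at r. This is a pure-Python re-derivation for illustration, not a call into PyTorch. With shuffling enabled, the sampler first permutes the indices using a generator seeded with seed + epoch.

```python
import math


def rank_indices(n: int, num_replicas: int, rank: int) -> list[int]:
    """Mirror DistributedSampler's index split (shuffle=False, drop_last=False)."""
    num_samples = math.ceil(n / num_replicas)  # samples per rank
    total_size = num_samples * num_replicas
    indices = list(range(n))
    # Pad by wrapping around so every rank gets the same count.
    padding = total_size - n
    if padding:
        indices += (indices * math.ceil(padding / n))[:padding]
    # Strided slice: rank r takes r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


for r in range(4):
    print(r, rank_indices(10, num_replicas=4, rank=r))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

With 10 samples on 4 ranks, samples 0 and 1 are visited twice per epoch because of the padding. Every rank still draws only its own slice of the index space.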

Note

The random split uses a fixed seed of 42 so that all DDP ranks produce identical train / validation index sets. If you change this seed, change it consistently across all ranks.
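The sketch below shows why a fixed seed keeps ranks in agreement. It assumes the split is done with torch.utils.data.random_split and a seeded torch.Generator; the source states the seed but not the exact call. Two independent calls stand in for two ranks and produce identical index sets. The 90/10 split sizes are arbitrary choices for illustration.

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.randn(100, 3), torch.randn(100, 8))


def split_indices(seed: int = 42) -> tuple[list[int], list[int]]:
    # A fresh generator, seeded identically, as each rank would create it.
    gen = torch.Generator().manual_seed(seed)
    train, val = random_split(dataset, [90, 10], generator=gen)
    return list(train.indices), list(val.indices)


rank0 = split_indices()
rank1 = split_indices()
print(rank0 == rank1)  # True: both "ranks" get the same train/validation split
```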

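Parameters:
  • dataset (TensorDataset) – Pre-loaded (theta, x) dataset, e.g. the output of ParaketDataset.to_tensor_dataset().

  • config

setup(stage=None)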

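Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

  • stage (str | None) – Either 'fit', 'validate', 'test', or 'predict'.

Return type:

None

Example:

import lightning.pytorch as pl
from torch import nn


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        # Runs on a single process per node: download/preprocess to disk only.
        download_data()  # placeholder helper
        tokenize()  # placeholder helper

        # Don't do this: state assigned here is not set on the other processes.
        self.something = "else"

    def setup(self, stage):
        # Runs on every process, so assigning state here is safe.
        data = load_data(...)  # placeholder helper
        self.l1 = nn.Linear(28, data.num_classes)
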
train_dataloader()

Return the training DataLoader.

Under DDP, shuffling is performed by the DistributedSampler that Lightning injects; on single-device runs, it comes from shuffle=True. drop_last=True discards each rank's final incomplete batch, so every batch on every rank has the same size.

Raises:

RuntimeError – If setup() has not been called.

Return type:

DataLoader
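The effect of drop_last=True on batch counts can be worked out with integer arithmetic. The sketch assumes the distributed sampler pads rather than drops samples, which is DistributedSampler's default, so each rank receives ceil(n_train / world_size) samples. The dataset size, world size, and batch size below are arbitrary example values, and batches_per_rank is a hypothetical helper.

```python
import math


def batches_per_rank(n_train: int, world_size: int, batch_size: int) -> tuple[int, int]:
    """Return (full batches per rank, samples dropped per rank) with drop_last=True."""
    per_rank = math.ceil(n_train / world_size)  # padded per-rank sample count
    n_batches = per_rank // batch_size          # incomplete final batch is discarded
    return n_batches, per_rank - n_batches * batch_size


print(batches_per_rank(1000, world_size=3, batch_size=64))  # (5, 14)
print(batches_per_rank(1000, world_size=1, batch_size=64))  # (15, 40)
```

Every optimizer step therefore sees exactly batch_size samples per rank, at the cost of a few samples per epoch. Because the sampler reshuffles each epoch, the dropped samples change from epoch to epoch.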

val_dataloader()

Return the validation DataLoader.

Raises:

RuntimeError – If setup() has not been called.

Return type:

DataLoader
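Both dataloader methods raise RuntimeError when setup() has not run. The sketch below shows that guard pattern. The class name GuardedDataModule and the attributes train_set and val_set are hypothetical, and the real module's attributes may differ. The sketch uses plain Python lists instead of Lightning and DataLoader objects so it runs standalone.

```python
from __future__ import annotations


class GuardedDataModule:
    """Hypothetical sketch of the setup()-before-dataloader guard."""

    def __init__(self, data: list[int]) -> None:
        self.data = data
        self.train_set: list[int] | None = None  # filled in by setup()
        self.val_set: list[int] | None = None

    def setup(self, stage: str | None = None) -> None:
        split = int(0.8 * len(self.data))
        self.train_set, self.val_set = self.data[:split], self.data[split:]

    def _require(self, subset: list[int] | None, name: str) -> list[int]:
        if subset is None:
            raise RuntimeError(f"call setup() before {name}()")
        return subset

    def train_dataloader(self) -> list[int]:
        # A real module would wrap this subset in a DataLoader.
        return self._require(self.train_set, "train_dataloader")

    def val_dataloader(self) -> list[int]:
        return self._require(self.val_set, "val_dataloader")


dm = GuardedDataModule(list(range(10)))
try:
    dm.train_dataloader()
except RuntimeError as err:
    print(err)  # call setup() before train_dataloader()
dm.setup("fit")
print(len(dm.train_dataloader()), len(dm.val_dataloader()))  # 8 2
```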