Data Loading
- class ParaketDataset(data_folder, parameter_names, nuisance_params=None)
File-level PyTorch dataset over a folder of .feather simulation files.
Each __getitem__ call loads one feather file and returns a (theta, x) pair. Call to_tensor_dataset() before training to pre-load everything into RAM as a flat TensorDataset.
- Parameters:
data_folder (Path)
parameter_names (list[str])
nuisance_params (list[str] | None)
- to_tensor_dataset(device='cpu', verbose=True)
Pre-load all feather files into a flat TensorDataset.
Concatenates all (theta, x) pairs along the sample dimension. This avoids repeated disk reads per epoch during training.
- Parameters:
device (str) – Target device for the output tensors.
verbose (bool)
- Return type:
TensorDataset
- Returns:
A TensorDataset of (theta_tensor, x_tensor) with shapes (n_total_samples, n_params) and (n_total_samples, n_bins).
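The concatenation along the sample dimension can be pictured with a small NumPy sketch. It is not the library's implementation: the per-file arrays are synthetic, `concat_pairs` is a hypothetical helper, and NumPy arrays stand in for the torch tensors that to_tensor_dataset() returns.

```python
import numpy as np


def concat_pairs(pairs):
    """Stack per-file (theta, x) arrays along axis 0 (the sample dimension)."""
    thetas, xs = zip(*pairs)
    return np.concatenate(thetas, axis=0), np.concatenate(xs, axis=0)


# Three synthetic "files" holding 4, 2 and 5 samples (hypothetical sizes),
# each with 3 parameters and 10 bins.
n_params, n_bins = 3, 10
pairs = [
    (
        np.zeros((n, n_params), dtype=np.float32),
        np.ones((n, n_bins), dtype=np.float32),
    )
    for n in (4, 2, 5)
]

theta_all, x_all = concat_pairs(pairs)
print(theta_all.shape, x_all.shape)  # (11, 3) (11, 10)
```

The first dimension of both outputs is the total sample count over all files, which matches the documented (n_total_samples, n_params) and (n_total_samples, n_bins) shapes.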
Utilities
- from_feather(file_name, parameter_names, nuisance_pars=None)
Load a (theta, x) pair from a feather file.
- Parameters:
file_name (Path) – Path to the .feather file.
parameter_names (list[str]) – Ordered parameter names used for nuisance filtering.
nuisance_pars (list[str] | None) – fnmatch patterns for parameters to exclude from theta. None returns all parameters.
- Return type:
tuple[ndarray[tuple[Any, ...], dtype[float32]], ndarray[tuple[Any, ...], dtype[float32]]]
- Returns:
Tuple of (theta, x) as float32 numpy arrays.
- Raises:
FileNotFoundError – If file_name does not exist.
- to_feather(file_name, theta_values, x_values)
Write a (theta, x) pair to a feather file.
- Parameters:
file_name (Path) – Destination path. Must end in .feather.
theta_values (ndarray[tuple[Any, ...], dtype[float32]]) – Parameter array of shape (n_samples, n_params).
x_values (ndarray[tuple[Any, ...], dtype[float32]]) – Observable array of shape (n_samples, n_bins).
- Raises:
ValueError – If file_name does not have a .feather suffix.
- Return type:
None
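The suffix requirement could be enforced along the following lines. This is a guess at the kind of check involved, not the library's code: `check_feather_suffix`, the error message and the example paths are all hypothetical.

```python
from pathlib import Path


def check_feather_suffix(file_name):
    """Return the path unchanged, or raise ValueError if it lacks '.feather'."""
    path = Path(file_name)
    if path.suffix != ".feather":
        raise ValueError(f"expected a .feather file, got {path.name!r}")
    return path


check_feather_suffix("sims/run_000.feather")  # accepted

try:
    check_feather_suffix("sims/run_000.parquet")
except ValueError as err:
    print(err)  # expected a .feather file, got 'run_000.parquet'
```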
- filter_nuisance(parameter_names, nuisance_pars, theta)
Remove nuisance parameters from a theta array by name pattern.
- Parameters:
parameter_names (list[str]) – Ordered parameter names, length must match theta.shape[1].
nuisance_pars (list[str]) – fnmatch patterns for parameters to exclude (e.g. ["syst_*"]).
theta (ndarray[tuple[Any, ...], dtype[float32]]) – Parameter array of shape (n_samples, n_params).
- Return type:
ndarray[tuple[Any, ...], dtype[float32]]
- Returns:
Filtered array with nuisance columns removed.
- Raises:
ValueError – If len(parameter_names) != theta.shape[1].
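Pattern-based column filtering of this kind can be sketched with the standard-library fnmatch module. The code below is an illustration under assumptions, not the library's implementation: `filter_nuisance_sketch` and the parameter names are made up for the example.

```python
from fnmatch import fnmatch

import numpy as np


def filter_nuisance_sketch(parameter_names, nuisance_pars, theta):
    """Drop every column of theta whose name matches any fnmatch pattern."""
    if len(parameter_names) != theta.shape[1]:
        raise ValueError("parameter_names must have one entry per theta column")
    # Indices of the columns whose names match none of the patterns.
    keep = [
        i
        for i, name in enumerate(parameter_names)
        if not any(fnmatch(name, pattern) for pattern in nuisance_pars)
    ]
    return theta[:, keep]


names = ["mass", "syst_scale", "width", "syst_shift"]  # illustrative names
theta = np.arange(8, dtype=np.float32).reshape(2, 4)
filtered = filter_nuisance_sketch(names, ["syst_*"], theta)
print(filtered.shape)  # (2, 2): only the "mass" and "width" columns remain
```

Column order is preserved, so the surviving columns line up with the surviving entries of the name list.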
Lightning Module
- class SBIDataModule(dataset, config)
Lightning data module over a pre-loaded (theta, x) dataset.
The dataset is expected to have been pre-loaded into CPU RAM via to_tensor_dataset() before this module is constructed.
Under DDP, Lightning automatically wraps each DataLoader’s sampler in a DistributedSampler, which partitions the index space across ranks. Because the underlying TensorDataset tensors are kept in CPU shared memory (no .to(device) call on the dataset itself), each rank reads only its own slice; no data is copied between processes.
Note
The random split uses a fixed seed of 42 so that all DDP ranks produce identical train / validation index sets. If you change this seed, change it consistently across all ranks.
- Parameters:
dataset (TensorDataset)
config (TrainingConfig)
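The seeded split and the per-rank partition can both be imitated in plain Python. This is a sketch only: `random.Random` stands in for a seeded torch generator, the strided slice approximates how a DistributedSampler assigns indices (shuffling and padding are ignored), and `split_indices`, `rank_slice` and the sample counts are illustrative choices not taken from the source.

```python
import random


def split_indices(n_samples, n_val, seed=42):
    """Shuffle 0..n_samples-1 with a fixed seed and cut into train / val."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    return indices[n_val:], indices[:n_val]


def rank_slice(indices, rank, world_size):
    """Give each rank every world_size-th index, starting at its own rank."""
    return indices[rank::world_size]


world_size = 4
# Every simulated rank runs the same seeded split independently ...
splits = [split_indices(100, n_val=20) for _ in range(world_size)]
# ... and they all agree on which samples are train and which are val.
assert all(s == splits[0] for s in splits)

train, val = splits[0]
shards = [rank_slice(train, r, world_size) for r in range(world_size)]
# The shards are disjoint and together cover the training set exactly once.
assert sorted(i for shard in shards for i in shard) == sorted(train)
print([len(shard) for shard in shards])  # [20, 20, 20, 20]
```

If one rank used a different seed, its train and validation sets would no longer match the others, and samples could leak from one rank's validation set into another rank's training shard.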
- setup(stage=None)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage (str | None) – either 'fit', 'validate', 'test', or 'predict'
- Return type:
None
Example:
    class LitModel(...):
        def __init__(self):
            self.l1 = None

        def prepare_data(self):
            download_data()
            tokenize()

            # don't do this
            self.something = else

        def setup(self, stage):
            data = load_data(...)
            self.l1 = nn.Linear(28, data.num_classes)
- train_dataloader()
Return the training DataLoader.
Shuffling is handled by Lightning’s DistributedSampler under DDP, or by shuffle=True on single-device runs. drop_last=True keeps batch sizes uniform across ranks.
- Raises:
RuntimeError – If setup() has not been called.
- Return type:
DataLoader
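The effect of drop_last on batch sizes can be shown with a few lines of plain Python. `batch_sizes` is a hypothetical helper that mimics only the batching arithmetic of a DataLoader, and the sample and batch counts are arbitrary.

```python
def batch_sizes(n_samples, batch_size, drop_last):
    """Sizes of the batches a loader would yield over n_samples items."""
    full, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    # Without drop_last, the leftover samples form a final, smaller batch.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


print(batch_sizes(10, 4, drop_last=False))  # [4, 4, 2]
print(batch_sizes(10, 4, drop_last=True))   # [4, 4]
```

With drop_last=True the trailing partial batch is discarded, so every batch on every rank has exactly batch_size samples.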