Inference

class InferenceHandler(prior_path, nuisance_pars=None)

High-level interface for NPE training and posterior sampling.

The handler supports three distinct lifecycles (fresh training, resuming from a checkpoint, and inference only), each expressed as a clean call sequence; see the module docstring.
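The call-order rules documented in the Requires and Raises notes on this page can be summarized with a small stand-in. `LifecycleStub` below is hypothetical and is not the real InferenceHandler: it only records which steps have run and enforces the preconditions stated here (for example, train_posterior() needs load_training_data() and create_posterior(), while resume_training() does not need create_posterior()). The method bodies are placeholders, and it is assumed that load_posterior() leaves both a density estimator and an NPE object in place.

```python
from pathlib import Path


class LifecycleStub:
    """Hypothetical stand-in that mimics only the call-order rules of InferenceHandler."""

    def __init__(self) -> None:
        self.has_data = False
        self.has_npe = False
        self.has_estimator = False
        self.has_posterior = False

    def set_dataset(self, data_folder: Path) -> None:
        self.data_folder = data_folder

    def load_training_data(self) -> None:
        self.has_data = True

    def create_posterior(self) -> None:
        self.has_npe = True

    def train_posterior(self) -> None:
        # Fresh training needs both simulations and an NPE object.
        if not (self.has_data and self.has_npe):
            raise ValueError("call load_training_data() and create_posterior() first")
        self.has_estimator = True

    def resume_training(self, checkpoint_path: Path) -> None:
        # The architecture comes from the checkpoint, so only the data is required.
        if not self.has_data:
            raise ValueError("call load_training_data() first")
        self.has_npe = self.has_estimator = True

    def load_posterior(self, checkpoint_path: Path) -> None:
        # Inference only: no training data is needed.
        self.has_npe = self.has_estimator = True

    def build_posterior(self) -> None:
        if not (self.has_npe and self.has_estimator):
            raise ValueError("no density estimator or NPE object present")
        self.has_posterior = True
```

Read this way, the sequences are roughly: fresh training runs set_dataset, load_training_data, create_posterior, train_posterior, build_posterior; resuming runs set_dataset, load_training_data, resume_training; inference only runs load_posterior, build_posterior, sample_posterior.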

Parameters:
  • prior_path (Path)

  • nuisance_pars (list[str] | None)

set_dataset(data_folder)

Point the handler at a folder of .feather simulation files.

Parameters:

data_folder (Path) – Directory containing .feather files.

Return type:

None
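A minimal sketch of what pointing at a folder of .feather files might involve, assuming the folder is simply globbed for matching files. `find_feather_files` is a hypothetical helper and is not part of mach3sbitools; the real method may validate the folder differently or defer file discovery to load_training_data().

```python
from pathlib import Path


def find_feather_files(data_folder: Path) -> list[Path]:
    """Return the .feather files in data_folder, sorted for a reproducible order."""
    folder = Path(data_folder)
    if not folder.is_dir():
        raise FileNotFoundError(f"{folder} is not a directory")
    # Only top-level *.feather files; other files in the folder are ignored.
    files = sorted(folder.glob("*.feather"))
    if not files:
        raise ValueError(f"no .feather files found in {folder}")
    return files
```

Sorting makes the file order independent of the filesystem, so the same folder always yields the same sequence of simulation files.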

load_training_data(verbose=True)

Parameters:

verbose (bool)

Return type:

None

create_posterior(config)

Build the NPE inference object and density estimator network.

Only the kwargs that the chosen model family actually accepts are forwarded to posterior_nn; unsupported kwargs (e.g. num_blocks for zuko-backed flows) are dropped with a DEBUG log line rather than raising a TypeError at runtime.

Parameters:

config (PosteriorConfig) – Architecture and hyperparameter settings.

Return type:

None
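The kwarg filtering described for create_posterior() can be sketched as a lookup of accepted options per model family. `SUPPORTED_KWARGS` and `filter_posterior_nn_kwargs` are hypothetical names, and the per-family sets are illustrative assumptions that only mirror the note above (num_blocks is dropped for zuko-backed flows); the real _build_posterior_nn_kwargs may derive the accepted options differently.

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical, illustrative table of the posterior_nn kwargs each model family accepts.
SUPPORTED_KWARGS: dict[str, set[str]] = {
    "maf": {"hidden_features", "num_transforms", "num_blocks"},
    "nsf": {"hidden_features", "num_transforms", "num_bins", "num_blocks"},
    "zuko_maf": {"hidden_features", "num_transforms"},
    "zuko_nsf": {"hidden_features", "num_transforms", "num_bins"},
}


def filter_posterior_nn_kwargs(model: str, kwargs: dict) -> dict:
    """Keep only the kwargs the model family accepts; log the rest at DEBUG level."""
    allowed = SUPPORTED_KWARGS[model]
    kept = {}
    for key, value in kwargs.items():
        if key in allowed:
            kept[key] = value
        else:
            # Drop instead of raising, so an unsupported option never reaches posterior_nn.
            logger.debug("Dropping unsupported kwarg %r for model %r", key, model)
    return kept
```

Filtering before the call means a config written for one family (for example, one that sets num_blocks) can be reused with another family without a TypeError.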

train_posterior(config, model_config=None)

Train the density estimator from scratch using PyTorch Lightning.

Requires load_training_data() and create_posterior() to have been called first.

Parameters:
  • config (TrainingConfig) – Training loop settings.

  • model_config (PosteriorConfig | None) – Architecture config embedded in every checkpoint.

Raises:

ValueError – If training data or the NPE object are missing.

Return type:

None

resume_training(checkpoint_path, config)

Resume training from an existing checkpoint.

Single entry point for the --resume_checkpoint flow. The model architecture is read directly from the checkpoint, so create_posterior() does not need to be called beforehand, and any --model / --hidden / etc. flags passed on the CLI are intentionally ignored to prevent silent architecture mismatches.

The same _build_posterior_nn_kwargs filtering used for a fresh run applies here, so zuko models resume without error regardless of which kwargs are stored in the checkpoint config.

Requires load_training_data() to have been called.

Parameters:
  • checkpoint_path (Path) – Path to a .ckpt produced by a previous run.

  • config (TrainingConfig) – Training loop settings for the resumed run.

Raises:
  • FileNotFoundError – If checkpoint_path does not exist.

  • ValueError – If training data has not been loaded.

Return type:

None
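The split described for resume_training(), with the architecture taken from the checkpoint and only the loop settings taken from the current run, can be sketched as a config merge. The checkpoint is modeled as a plain dict with a "model_config" entry (the key that load_posterior() documents); `ARCHITECTURE_KEYS`, `resolve_resume_configs`, and the option names are hypothetical.

```python
# Hypothetical architecture options that must never be overridden on resume.
ARCHITECTURE_KEYS = {"model", "hidden_features", "num_transforms", "num_blocks"}


def resolve_resume_configs(checkpoint: dict, cli_options: dict) -> tuple[dict, dict]:
    """Return (model_config, training_config) for a resumed run."""
    # Architecture: taken verbatim from the checkpoint.
    model_config = dict(checkpoint["model_config"])
    # Training loop: taken from the CLI, with any architecture flags discarded.
    training_config = {
        key: value for key, value in cli_options.items() if key not in ARCHITECTURE_KEYS
    }
    return model_config, training_config
```

Discarding the architecture flags outright, rather than letting them override the checkpoint, is what prevents a resumed run from silently building a network whose shape no longer matches the stored weights.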

load_posterior(checkpoint_path, posterior_config=None)

Load a trained density estimator from a checkpoint for inference only.

The PosteriorConfig is read from the checkpoint’s "model_config" key. _build_posterior_nn_kwargs filtering applies, so loading a zuko checkpoint works even if num_blocks is present in the stored config (it will simply be dropped).

Parameters:
  • checkpoint_path (Path) – Path to a .pt / .ckpt checkpoint.

  • posterior_config (PosteriorConfig | None) – Kept for backwards compatibility only; ignored when the checkpoint is self-contained.

Raises:
  • FileNotFoundError – If checkpoint_path does not exist.

  • ValueError – If no model config can be determined.

Return type:

None
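One way the model config might be resolved on load, assuming the checkpoint has already been deserialized into a dict (for example with torch.load). `resolve_model_config` is hypothetical: it prefers the checkpoint's "model_config" entry, falls back to the backwards-compatible posterior_config argument, and otherwise raises ValueError, matching the behavior documented above.

```python
from typing import Optional


def resolve_model_config(checkpoint: dict, posterior_config: Optional[dict] = None) -> dict:
    """Pick the architecture config used to rebuild the density estimator."""
    stored = checkpoint.get("model_config")
    if stored is not None:
        # Self-contained checkpoint: any explicitly passed config is ignored.
        return dict(stored)
    if posterior_config is not None:
        # Older checkpoint without an embedded config: fall back to the caller's config.
        return dict(posterior_config)
    raise ValueError("no model config found in checkpoint and none was provided")
```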

build_posterior()

Wrap the trained density estimator in an sbi posterior object.

Raises:

ValueError – If no density estimator or NPE object is present.

Return type:

None

sample_posterior(num_samples, x, **kwargs)

Draw samples from the posterior conditioned on x.

Parameters:
  • num_samples (int) – Number of posterior samples to draw.

  • x (list[float] | ndarray) – Observed data vector x_o.

Return type:

Tensor

Returns:

Tensor of shape (num_samples, n_params).
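Because the returned tensor has one row per draw and one column per parameter, per-parameter summaries reduce over the first axis. The sketch below works on a nested list (for example the result of samples.tolist()) using only the standard library; `summarize_samples` and the choice of a central 68% interval are illustrative assumptions. With the tensor itself, samples.mean(dim=0) gives the same per-parameter means.

```python
import statistics


def summarize_samples(samples: list[list[float]]) -> list[dict[str, float]]:
    """Per-parameter mean and central 68% interval from (num_samples, n_params) rows."""
    n_params = len(samples[0])
    summaries = []
    for j in range(n_params):
        column = [row[j] for row in samples]  # all draws of parameter j
        # quantiles(n=100) returns the 1st..99th percentiles as 99 cut points.
        pct = statistics.quantiles(column, n=100)
        summaries.append({"mean": statistics.fmean(column), "p16": pct[15], "p84": pct[83]})
    return summaries
```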

get_log_likelihood(theta, x, **kwargs)

Evaluate the log-likelihood of theta given observed data x.

Parameters:
  • theta (ndarray[tuple[Any, ...], dtype[float32]]) – Parameter array of shape (n_samples, n_params).

  • x (list[float] | ndarray) – Observed data vector x_o.

Return type:

Tensor

Returns:

Log-probability tensor of shape (n_samples,).
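Since theta has one row per candidate and the result has one entry per row, the log-probabilities can be used directly to rank candidate parameter sets. A stdlib-only sketch, assuming both inputs have been converted to Python lists; `best_candidate` is hypothetical, and on the tensor itself log_probs.argmax() gives the same index.

```python
def best_candidate(theta: list[list[float]], log_probs: list[float]) -> tuple[list[float], float]:
    """Return the row of theta with the highest log-probability, and that value."""
    if len(theta) != len(log_probs):
        raise ValueError("expected one log-probability per row of theta")
    # Index of the largest log-probability.
    best = max(range(len(log_probs)), key=log_probs.__getitem__)
    return theta[best], log_probs[best]
```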

Trainer

lightning_module

alias of the module mach3sbitools.inference.lightning_module