Source code for mach3sbitools.simulator.priors.prior

"""
Prior distribution for MaCh3 SBI.

Constructs a composite prior from four distribution types, checked in order:

1. **Cyclical** — parameters matching *cyclical_parameters* patterns, forced
   to bounds of ``[-2π, 2π]``.
2. **Flipped Uniform** — parameters matching *flipped_parameters* patterns.
   Support is ``[lower, upper] + [-upper, -lower]`` where ``lower``/``upper``
   are read from the parameter's existing bounds (which must be positive and
   satisfy ``0 < lower < upper``).
3. **Flat (Uniform)** — parameters flagged via *flat_msk* and not cyclical or
   flipped.
4. **Gaussian** — all remaining parameters, modelled as a
   :class:`~torch.distributions.MultivariateNormal`.
"""

import fnmatch
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import TypeAlias

import numpy as np
import torch
from torch.distributions import Uniform, constraints

from mach3sbitools.utils import TorchDeviceHandler, get_logger

from ..simulator_injector import SimulatorProtocol
from .cyclical_distribution import CyclicalDistribution
from .dataclasses import PriorData
from .flipped_uniform_distribution import FlippedUniformDistribution
from .truncated_gaussian_distribution import TruncatedGaussianDistribution

# Shape-like values: a ``torch.Size`` or a plain list/tuple of ints.
# NOTE(review): not referenced anywhere else in the visible part of this
# module — confirm whether it is still needed by importers.
size_: TypeAlias = torch.Size | list[int] | tuple[int, ...]


class PriorNotFound(Exception):
    """Signal that a prior file is missing or does not hold a usable prior."""


# Module-level logger obtained from the project's ``get_logger`` helper; the
# module-level functions in this file use it for info and warning messages.
logger = get_logger()


@dataclass(frozen=True)
class MaskDistributionMap:
    """
    Associates a boolean parameter mask with its distribution.

    :param mask: Boolean tensor of shape ``(n_params,)`` selecting the
        parameters governed by *distribution*.
    :param distribution: The :class:`torch.distributions.Distribution`
        for the selected parameters.
    """

    mask: torch.Tensor
    distribution: torch.distributions.Distribution

    def to(self, device: torch.device | str) -> "MaskDistributionMap":
        """
        Move *mask* to *device* (distribution tensors are not moved).

        :param device: Target PyTorch device.
        :returns: New :class:`MaskDistributionMap` with mask on *device*.
        """
        return MaskDistributionMap(
            mask=self.mask.to(device), distribution=self.distribution
        )


class Prior(torch.distributions.Distribution):
    """
    Composite MaCh3 prior combining cyclical, flipped-uniform, flat, and
    Gaussian components.

    Designed to replicate MaCh3's prior construction
    (https://github.com/mach3-software/MaCh3) and satisfy the ``sbi``
    :class:`~torch.distributions.Distribution` interface.

    Parameters are assigned to distributions in the following order:

    - **Cyclical** — matched by *cyclical_parameters* (fnmatch patterns).
    - **Flipped Uniform** — matched by *flipped_parameters* (fnmatch
      patterns). Each matched parameter must have positive bounds
      ``0 < lower_bound < upper_bound``; its support becomes
      ``[lower, upper] + [-upper, -lower]``.
    - **Flat** — flagged by *flat_msk* and not cyclical or flipped.
    - **Gaussian** — everything else.

    .. warning::
        The nuisance filter is fixed at construction time. Calling
        :meth:`set_nuisance_filter` after construction is not supported —
        the distribution masks are built once against the filtered parameter
        set and cannot be safely remapped afterwards. To change the nuisance
        filter, construct a new :class:`Prior`.
    """

    def __init__(
        self,
        prior_data: PriorData,
        flat_msk: list[bool] | None = None,
        cyclical_parameters: list[str] | None = None,
        nuisance_parameters: list[str] | None = None,
        flipped_parameters: list[str] | None = None,
    ):
        """
        Construct the composite prior.

        :param prior_data: Raw prior arrays (names, nominals, bounds,
            covariance).
        :param flat_msk: Per-parameter flat flags, index-aligned with
            *prior_data* (i.e. the unfiltered parameter list). A list that is
            already aligned with the nuisance-filtered set is accepted too.
            Defaults to all ``False`` if not provided.
        :param cyclical_parameters: fnmatch patterns selecting cyclical
            parameters. Matched parameters receive bounds of ``±2π``.
        :param nuisance_parameters: fnmatch patterns selecting parameters to
            exclude. Fixed at construction time — cannot be changed later.
        :param flipped_parameters: fnmatch patterns selecting parameters that
            use a bimodal uniform prior over
            ``[lower, upper] + [-upper, -lower]``. The bounds are read from
            *prior_data* and must satisfy ``0 < lower < upper``.
        :raises ValueError: If a flipped parameter's bounds violate
            ``0 < lower < upper``.
        """
        self.device_handler = TorchDeviceHandler()
        self._prior_data = prior_data

        # Apply nuisance filter once — masks are built against the filtered set
        # and cannot be safely remapped if the filter changes afterwards.
        self._nuisance_filter = self._build_nuisance_filter(nuisance_parameters)
        self._priors: list[MaskDistributionMap] = []

        n_params = len(self.prior_data.nominals)

        # ── Cyclical mask ──────────────────────────────────────────────────
        cyclical_mask = self._pattern_mask(cyclical_parameters, n_params)
        if cyclical_mask.any():
            self._force_cyclical_bounds(cyclical_mask)
            self._priors.append(self._get_cyclical_map(cyclical_mask))

        # ── Flipped-uniform mask ───────────────────────────────────────────
        # Flipped params must not also be cyclical.
        flipped_mask = (
            self._pattern_mask(flipped_parameters, n_params) & ~cyclical_mask
        )

        # Store for use in check_bounds — flipped params are valid in EITHER
        # region so the standard lower/upper bounds check would incorrectly
        # reject samples drawn from the negative region.
        self._flipped_mask = flipped_mask

        if flipped_mask.any():
            self._priors.extend(self._get_flipped_maps(flipped_mask))

        # ── Flat mask ──────────────────────────────────────────────────────
        if flat_msk is None:
            flat_flags = torch.zeros(
                n_params, dtype=torch.bool, device=self.device_handler.device
            )
        else:
            flat_flags = self.device_handler.to_tensor(flat_msk).bool()
            # The flags are documented as aligned with the *unfiltered*
            # parameter list, whereas every other mask here is sized against
            # the nuisance-filtered set. Drop the nuisance entries so the
            # shapes agree; flags already sized to the filtered set pass
            # through untouched.
            n_all = len(self._nuisance_filter)
            if len(flat_flags) != n_params and len(flat_flags) == n_all:
                flat_flags = flat_flags[self._nuisance_filter.bool()]

        flat_mask = flat_flags & ~cyclical_mask & ~flipped_mask

        if flat_mask.any():
            self._priors.append(self._get_flat_map(flat_mask))

        # ── Gaussian mask ──────────────────────────────────────────────────
        gaussian_mask = ~cyclical_mask & ~flipped_mask & ~flat_mask

        if gaussian_mask.any():
            self._priors.append(self._get_gaussian_map(gaussian_mask))

        super().__init__(
            batch_shape=torch.Size(),
            event_shape=torch.Size([n_params]),
            validate_args=False,
        )

    # ── Private helpers ────────────────────────────────────────────────────────

    def _build_nuisance_filter(
        self, nuisance_patterns: list[str] | None
    ) -> torch.Tensor:
        """
        Build a boolean keep-mask from *nuisance_patterns*.

        :param nuisance_patterns: fnmatch patterns, or ``None`` to keep all.
        :returns: Boolean tensor of shape ``(n_all_params,)``; ``True`` marks
            parameters that are kept (i.e. match no nuisance pattern).
        """
        if nuisance_patterns is None:
            n_pars = len(self._prior_data.parameter_names)
            return torch.ones(
                n_pars, dtype=torch.bool, device=self.device_handler.device
            )
        keep = [
            not any(fnmatch.fnmatch(p, n) for n in nuisance_patterns)
            for p in self._prior_data.parameter_names
        ]
        return self.device_handler.to_tensor(keep)

    def _pattern_mask(self, patterns: list[str] | None, n_params: int) -> torch.Tensor:
        """
        Build a boolean mask over the active parameters from fnmatch patterns.

        :param patterns: fnmatch patterns; ``None`` or an empty list selects
            nothing.
        :param n_params: Number of active (non-nuisance) parameters.
        :returns: Boolean tensor of shape ``(n_params,)``.
        """
        if not patterns:
            return torch.zeros(
                n_params, dtype=torch.bool, device=self.device_handler.device
            )
        matches = [
            any(fnmatch.fnmatch(name, pattern) for pattern in patterns)
            for name in self.prior_data.parameter_names
        ]
        # Explicit .bool() keeps the mask usable with ``~`` and ``&`` whatever
        # dtype the device handler picks for a list of Python bools.
        return self.device_handler.to_tensor(matches).bool()

    def _force_cyclical_bounds(self, cyclical_mask: torch.Tensor) -> None:
        """
        Overwrite the bounds of the cyclical parameters with ``±2π``.

        The write goes straight into the unfiltered bound arrays. Indexing the
        prior data with the nuisance filter may hand back a copy, in which
        case assigning into that temporary object would silently be lost.

        :param cyclical_mask: Boolean tensor over the *active* parameters.
        """
        keep = self._nuisance_filter.bool()
        # Scatter the filtered-set mask back onto the full parameter list.
        full_mask = torch.zeros_like(keep)
        full_mask[keep] = cyclical_mask.to(keep.device)

        lower = self._prior_data.lower_bounds
        upper = self._prior_data.upper_bounds
        lower[full_mask.to(lower.device)] = -2 * torch.pi
        upper[full_mask.to(upper.device)] = 2 * torch.pi

    # ── Private distribution builders ──────────────────────────────────────────

    def _get_cyclical_map(self, cyclical_mask: torch.Tensor) -> MaskDistributionMap:
        """Pair *cyclical_mask* with a cyclical distribution built from the nominals."""
        cyclical_data = self.prior_data[cyclical_mask]
        cyclical_dist = CyclicalDistribution(cyclical_data.nominals)
        return MaskDistributionMap(cyclical_mask, cyclical_dist)

    def _get_flipped_maps(
        self, flipped_mask: torch.Tensor
    ) -> list[MaskDistributionMap]:
        """
        Build one :class:`MaskDistributionMap` per flipped parameter.

        Each flipped parameter may have *different* ``lower``/``upper`` bounds,
        so a separate :class:`~.FlippedUniformDistribution` is created for each
        one. The individual single-parameter masks are disjoint and together
        cover exactly the bits set in *flipped_mask*.

        :param flipped_mask: Boolean tensor of shape ``(n_params,)`` with
            ``True`` for every flipped parameter.
        :returns: List of :class:`MaskDistributionMap` objects, one per flipped
            parameter.
        :raises ValueError: If any matched parameter has bounds that violate
            ``0 < lower < upper``.
        """
        # Resolve the filtered view once instead of on every loop iteration.
        active = self.prior_data
        maps: list[MaskDistributionMap] = []
        flipped_indices = flipped_mask.nonzero(as_tuple=True)[0]

        for idx in flipped_indices:
            # Build a single-parameter mask
            single_mask = torch.zeros(
                len(flipped_mask), dtype=torch.bool, device=self.device_handler.device
            )
            single_mask[idx] = True

            lower_val = float(active.lower_bounds[idx].item())
            upper_val = float(active.upper_bounds[idx].item())
            param_name = active.parameter_names[idx.item()]

            if lower_val <= 0 or upper_val <= lower_val:
                raise ValueError(
                    f"Flipped parameter '{param_name}' has bounds "
                    f"[{lower_val}, {upper_val}] but FlippedUniformDistribution "
                    f"requires 0 < lower < upper. "
                    f"Set positive bounds in your simulator config."
                )

            single_data = active[single_mask]
            dist = FlippedUniformDistribution(
                nominals=single_data.nominals,
                lower=lower_val,
                upper=upper_val,
            )
            maps.append(MaskDistributionMap(single_mask, dist))

        return maps

    def _get_flat_map(self, flat_mask: torch.Tensor) -> MaskDistributionMap:
        """Pair *flat_mask* with a uniform distribution over the selected bounds."""
        flat_data = self.prior_data[flat_mask]
        flat_dist = Uniform(flat_data.lower_bounds, flat_data.upper_bounds)
        return MaskDistributionMap(flat_mask, flat_dist)

    def _get_gaussian_map(self, gaussian_mask: torch.Tensor) -> MaskDistributionMap:
        """Pair *gaussian_mask* with a truncated Gaussian over the selected parameters."""
        gaussian_data = self.prior_data[gaussian_mask]
        dist = TruncatedGaussianDistribution(
            mean=gaussian_data.nominals,
            covariance=gaussian_data.covariance_matrix,
            lower_bounds=gaussian_data.lower_bounds,
            upper_bounds=gaussian_data.upper_bounds,
        )
        return MaskDistributionMap(gaussian_mask, dist)

    def _empty_samples(self, sample_shape: torch.Size) -> torch.Tensor:
        """Allocate the double-precision output buffer shared by sample/rsample."""
        return torch.empty(
            (*sample_shape, self.n_params),
            dtype=torch.double,
            device=self.device_handler.device,
        )

    # ── Properties ─────────────────────────────────────────────────────────────

    @property
    def prior_data(self) -> PriorData:
        """Active :class:`PriorData` after applying the nuisance filter."""
        return self._prior_data[self._nuisance_filter]

    @property
    def mean(self) -> torch.Tensor:
        """Prior mean — the nominal parameter values."""
        return self.device_handler.to_tensor(self.prior_data.nominals)

    @property
    def n_params(self) -> int:
        """Number of active (non-nuisance) parameters."""
        return len(self.prior_data.nominals)

    @property
    def variance(self) -> torch.Tensor:
        """
        Per-parameter prior variance, assembled from all sub-distributions.

        The mask for each sub-distribution is sized against the filtered
        parameter set (same as the tensor being filled), so shapes are always
        consistent.
        """
        variance = torch.zeros(self.n_params, device=self.device_handler.device)
        for mask_map in self._priors:
            variance[mask_map.mask] = mask_map.distribution.variance
        return variance

    @property
    def support(self):
        """
        Independent interval support defined by the parameter bounds.

        Only the ``[lower, upper]`` interval is represented here; the mirrored
        negative region of flipped parameters is handled by
        :meth:`check_bounds`, not by this constraint.
        """
        return constraints.independent(
            constraints.interval(
                self.prior_data.lower_bounds, self.prior_data.upper_bounds
            ),
            1,
        )

    # ── Distribution interface ─────────────────────────────────────────────────

    def sample(self, sample_shape=torch.Size([])) -> torch.Tensor:
        """
        Draw samples from the composite prior.

        Each component delegates to its own distribution's ``sample()`` method.
        Gaussian parameters use
        :class:`~mach3sbitools.simulator.priors.truncated_gaussian_distribution.TruncatedGaussianDistribution`
        which draws exact, rejection-free samples via the inverse-CDF method.

        :param sample_shape: Batch shape. Pass ``torch.Size([n])`` for *n*
            independent samples.
        :returns: Tensor of shape ``(*sample_shape, n_params)``.
        """
        sample_shape = torch.Size(sample_shape)
        samples = self._empty_samples(sample_shape)
        for mask_map in self._priors:
            samples[..., mask_map.mask] = mask_map.distribution.sample(sample_shape).to(
                torch.double
            )
        return samples

    def rsample(self, sample_shape=torch.Size([])) -> torch.Tensor:
        """
        Draw reparameterised samples (where supported by sub-distributions).

        The output buffer is allocated on the same device as in
        :meth:`sample`, so both methods return tensors on the handler's device.

        :param sample_shape: Batch shape.
        :returns: Tensor of shape ``(*sample_shape, n_params)``.
        """
        sample_shape = torch.Size(sample_shape)
        samples = self._empty_samples(sample_shape)
        for mask_map in self._priors:
            samples[..., mask_map.mask] = mask_map.distribution.rsample(
                sample_shape
            ).to(torch.double)
        return samples

    def check_bounds(self, params: torch.Tensor) -> torch.Tensor:
        """
        Check whether each sample in *params* lies within the prior support.

        For most parameters this is the standard interval
        ``[lower_bound, upper_bound]``. For flipped parameters the support is
        the union of two symmetric regions
        ``[lower, upper] + [-upper, -lower]``, so a sample is valid when its
        *absolute value* lies in ``[lower, upper]``.

        :param params: Tensor of shape ``(n_samples, n_params)``; any leading
            batch shape is accepted as long as the last axis is ``n_params``.
        :returns: Boolean tensor with the last axis reduced, i.e. shape
            ``(n_samples,)`` for 2-D input.
        """
        lb = self.prior_data.lower_bounds.to(params.device)
        ub = self.prior_data.upper_bounds.to(params.device)
        # Keep the index mask on the same device as the tensor it indexes,
        # exactly as is done for the bounds above.
        flipped = self._flipped_mask.to(params.device)

        # Standard interval check for all parameters
        in_bounds = (params >= lb) & (params <= ub)

        # Override flipped parameters: valid iff |value| in [lower, upper]
        if flipped.any():
            abs_params = params.abs()
            in_flipped = (abs_params >= lb) & (abs_params <= ub)
            in_bounds[..., flipped] = in_flipped[..., flipped]

        return self.device_handler.to_tensor(in_bounds.all(dim=-1))

    # ── Persistence ────────────────────────────────────────────────────────────

    def save(self, output_path: Path) -> None:
        """
        Pickle the prior to *output_path*.

        :param output_path: Destination file path (a plain string is accepted
            and converted). Parent directories are created automatically.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("wb") as f:
            pickle.dump(self, f)

    def to(self, device: torch.device | str) -> "Prior":
        """
        Move the prior's own tensors to *device* in-place.

        This relocates the prior data, the nuisance filter, the flipped mask
        and the mask of every component. The component distributions
        themselves are not moved (:meth:`MaskDistributionMap.to` only
        relocates the mask), and ``device_handler`` is left unchanged.

        :param device: Target PyTorch device.
        :returns: ``self``, for chaining.
        """
        self._prior_data = self._prior_data.to(device)
        for i, mask_map in enumerate(self._priors):
            self._priors[i] = mask_map.to(device)
        self._nuisance_filter = self._nuisance_filter.to(device)
        self._flipped_mask = self._flipped_mask.to(device)
        return self
# ── Module-level helpers ─────────────────────────────────────────────────────── def _check_boundary( nominal: torch.Tensor, error: torch.Tensor, lower_bound: torch.Tensor, upper_bound: torch.Tensor, parameter_names: np.ndarray, ) -> None: """ Warn if any parameter has bounds further than 10σ from its nominal. :param nominal: Nominal values, shape ``(n_params,)``. :param error: 1σ errors, shape ``(n_params,)``. :param lower_bound: Lower bounds, shape ``(n_params,)``. :param upper_bound: Upper bounds, shape ``(n_params,)``. :param parameter_names: Parameter name strings, shape ``(n_params,)``. """ warning_thresh = 10 warning_ub = nominal + error * warning_thresh warning_lb = nominal - error * warning_thresh mask = (lower_bound < warning_lb) | (upper_bound > warning_ub) if not any(mask): return logger.warning( f"The following parameters have boundaries > {warning_thresh:d}σ from their prior nominal" ) for param_info in zip( parameter_names[mask.cpu().numpy()], nominal[mask], error[mask], lower_bound[mask], upper_bound[mask], ): logger.warning( " '{:s}' | Nominal: {:4f}, Error {:4f} | Lower Bnd {:4f}, Upper Bnd {:4f}".format( *param_info ) )
def create_prior(
    simulator_instance: SimulatorProtocol,
    nuisance_pars: list[str] | None = None,
    cyclical_pars: list[str] | None = None,
    flipped_pars: list[str] | None = None,
) -> Prior:
    """
    Build a :class:`Prior` from the parameter metadata of a simulator.

    Nominals, errors, bounds, names, the covariance matrix and the per-parameter
    flat flags are all read from *simulator_instance*. Parameters whose bounds
    extend beyond 10σ of their nominal are reported through the logger before
    the prior is assembled.

    .. code-block:: python

        prior = create_prior(
            simulator,
            nuisance_pars=["syst_*"],
            cyclical_pars=["angle"],
            flipped_pars=["delta_cp"],
        )
        prior.save(Path("prior.pkl"))

    :param simulator_instance: An object implementing
        :class:`SimulatorProtocol`.
    :param nuisance_pars: fnmatch patterns for parameters to exclude from the
        prior (e.g. ``['syst_*']``).
    :param cyclical_pars: fnmatch patterns for parameters that should use a
        cyclical sinusoidal prior over ``[-2π, 2π]``.
    :param flipped_pars: fnmatch patterns for parameters that should use a
        bimodal uniform prior over ``[lower, upper] + [-upper, -lower]``,
        where ``lower``/``upper`` come from the simulator's parameter bounds.
    :returns: Configured :class:`Prior` ready for use with ``sbi``.
    """
    logger.info("Creating Prior")
    handler = TorchDeviceHandler()

    # The simulator is queried in a fixed order: nominals, errors, bounds,
    # names, covariance, then the flat flag of each parameter.
    nominal_t = handler.to_tensor(simulator_instance.get_parameter_nominals())
    error_t = handler.to_tensor(simulator_instance.get_parameter_errors())

    raw_lower, raw_upper = simulator_instance.get_parameter_bounds()
    lower_t = handler.to_tensor(raw_lower)
    upper_t = handler.to_tensor(raw_upper)

    param_names = np.array(simulator_instance.get_parameter_names(), dtype=str)

    _check_boundary(nominal_t, error_t, lower_t, upper_t, param_names)

    covariance_t = handler.to_tensor(simulator_instance.get_covariance_matrix())
    flat_flags = list(map(simulator_instance.get_is_flat, range(len(param_names))))

    prior = Prior(
        prior_data=PriorData(
            parameter_names=param_names,
            nominals=nominal_t,
            lower_bounds=lower_t,
            upper_bounds=upper_t,
            covariance_matrix=covariance_t,
        ),
        flat_msk=flat_flags,
        nuisance_parameters=nuisance_pars,
        cyclical_parameters=cyclical_pars,
        flipped_parameters=flipped_pars,
    )

    get_logger().info("Prior constructed")
    return prior
def load_prior(prior_path: Path, device=torch.device("cpu")) -> Prior:
    """
    Load a pickled :class:`Prior` from disk.

    .. code-block:: python

        prior = load_prior(Path("prior.pkl"))

    :param prior_path: Path to a ``.pkl`` file produced by :meth:`Prior.save`.
        A plain string is accepted and converted.
    :param device: Device to move the prior to after loading. Defaults to CPU.
    :returns: The loaded :class:`Prior`, moved to *device*.
    :raises PriorNotFound: If *prior_path* is not an existing regular file, or
        the unpickled object is not a :class:`Prior`. Errors raised by
        ``pickle.load`` itself propagate unchanged.
    """
    prior_path = Path(prior_path)

    # ``is_file()`` is already False for a missing path, so no separate
    # ``exists()`` check is needed.
    if not prior_path.is_file():
        # Exceptions do not interpolate logging-style ``%s`` arguments; build
        # the message explicitly so ``str(exc)`` is readable.
        raise PriorNotFound(f"Could not find prior {prior_path}")

    # SECURITY: unpickling executes code embedded in the file. Only load prior
    # files that come from a trusted source.
    with prior_path.open("rb") as f:
        prior = pickle.load(f)

    if not isinstance(prior, Prior):
        raise PriorNotFound(
            f"No valid prior in {prior_path}. Instead found {type(prior)}"
        )
    return prior.to(device)