Source code for qp.parameterizations.spline.spline

"""This module implements a PDT distribution sub-class using splines"""

from __future__ import annotations
import numpy as np

from scipy.interpolate import splev, splint, splrep, interp1d
from scipy.special import errstate  # pylint: disable=no-name-in-module
from scipy.stats import rv_continuous
from typing import Mapping, Optional
from numpy.typing import ArrayLike

from .spline_utils import (
    extract_samples,
    spline_extract_xy_vals,
    build_kdes,
    evaluate_kdes,
    normalize_spline,
    build_splines,
)
from ...core.factory import add_class
from ...core.ensemble import Ensemble
from ..base import Pdf_rows_gen
from ...plotting import get_axes_and_xlims, plot_pdf_on_axes
from ...utils.array import reshape_to_pdf_size


class spline_gen(Pdf_rows_gen):
    """Spline based distribution

    Notes
    -----
    This implements PDFs using a set of splines

    The relevant data members are:

    - `splx`: (npdf, n) spline-knot x-values
    - `sply`: (npdf, n) spline-knot y-values
    - `spln`: (npdf) spline-knot order parameters

    The pdf() for the ith pdf will return the result of
    `scipy.interpolate.splev(x, (splx[i], sply[i], spln[i]))`

    The cdf() for the ith pdf will return the result of
    `scipy.interpolate.splint(xmin, x, (splx[i], sply[i], spln[i]))`,
    where `xmin` is the smallest value found in `splx`.

    The ppf() is currently disabled and raises `NotImplementedError`;
    `_ppf()` inverts the cdf() as evaluated on a fixed, evenly spaced grid.
    """

    # pylint: disable=protected-access

    name = "spline"
    version = 0

    _support_mask = rv_continuous._support_mask

    def __init__(self, *args, **kwargs) -> None:
        """Create a new distribution from a set of spline representations

        Parameters
        ----------
        splx : ArrayLike
            The x-values of the spline knots
        sply : ArrayLike
            The y-values of the spline knots; must have the same shape as
            `splx`
        spln : ArrayLike
            The order of the spline knots

        Raises
        ------
        ValueError
            If `splx` or `sply` is missing, or if their shapes differ.

        Notes
        -----
        `splx`, `sply` and `spln` are passed as keyword arguments. For each
        row they are handed to scipy as the tuple
        `(splx[i], sply[i], spln[i])`.
        """
        splx = kwargs.pop("splx", None)
        sply = kwargs.pop("sply", None)
        spln = kwargs.pop("spln", None)

        if splx is None:  # pragma: no cover
            raise ValueError("splx must be provided")
        if sply is None:  # pragma: no cover
            raise ValueError("sply must be provided")

        # Accept any array-like (lists, tuples, ...) while leaving existing
        # arrays, including ndarray subclasses, untouched.
        splx = np.asanyarray(splx)
        sply = np.asanyarray(sply)
        if spln is not None:
            spln = np.asanyarray(spln)
        # NOTE(review): spln=None is forwarded unchanged to
        # reshape_to_pdf_size; _pdf/_cdf later call `.item()` on its rows, so
        # confirm whether None is actually a supported value.

        if splx.shape != sply.shape:  # pragma: no cover
            raise ValueError(
                f"Shape of xvals ({splx.shape}) != shape of yvals ({sply.shape})"
            )

        # Global support limits, shared by every row; _xmin is the lower
        # integration bound used by _cdf.
        self._xmin = np.min(splx)
        self._xmax = np.max(splx)

        self._splx = reshape_to_pdf_size(splx, -1)
        self._sply = reshape_to_pdf_size(sply, -1)
        self._spln = reshape_to_pdf_size(spln, -1)
        kwargs["shape"] = self._splx.shape
        super().__init__(*args, **kwargs)

        self._addobjdata("splx", self._splx)
        self._addobjdata("sply", self._sply)
        self._addobjdata("spln", self._spln)

    @staticmethod
    def build_normed_splines(xvals, yvals, **kwargs):
        """Build a set of normalized splines using the x and y values

        Parameters
        ----------
        xvals : ArrayLike
            The x-values used to do the interpolation
        yvals : ArrayLike
            The y-values used to do the interpolation
        **kwargs
            Forwarded to `normalize_spline`

        Returns
        -------
        splx : ArrayLike
            The x-values of the spline knots
        sply : ArrayLike
            The y-values of the spline knots
        spln : ArrayLike
            The order of the spline knots

        Raises
        ------
        ValueError
            If `xvals` and `yvals` do not have the same shape.
        """
        xvals = np.asanyarray(xvals)
        yvals = np.asanyarray(yvals)
        if xvals.shape != yvals.shape:  # pragma: no cover
            raise ValueError(
                f"Shape of xvals ({xvals.shape}) != shape of yvals ({yvals.shape})"
            )
        xmin = np.min(xvals)
        xmax = np.max(xvals)

        # make sure xvals and yvals are 2d
        if np.ndim(xvals) == 1:
            xvals = np.expand_dims(xvals, axis=0)
        if np.ndim(yvals) == 1:
            yvals = np.expand_dims(yvals, axis=0)

        yvals = normalize_spline(xvals, yvals, limits=(xmin, xmax), **kwargs)
        return build_splines(xvals, yvals)

    @classmethod
    def create_from_xy_vals(cls, xvals, yvals, **kwargs):
        """Create a new distribution using the given x and y values

        Parameters
        ----------
        xvals : ArrayLike
            The x-values used to do the interpolation
        yvals : ArrayLike
            The y-values used to do the interpolation
        **kwargs
            Forwarded both to `build_normed_splines` and to the call that
            produces the returned object

        Returns
        -------
        pdf_obj
            The result of calling the newly built generator with `kwargs`
        """
        splx, sply, spln = cls.build_normed_splines(xvals, yvals, **kwargs)
        gen_obj = cls(splx=splx, sply=sply, spln=spln)
        return gen_obj(**kwargs)

    @classmethod
    def create_from_samples(cls, xvals, samples, **kwargs):
        """Create a new distribution from samples, via KDEs evaluated on a grid

        Parameters
        ----------
        xvals : ArrayLike
            The x-values at which the KDEs are evaluated
        samples : ArrayLike
            The sample values used to build the KDE
        **kwargs
            Forwarded to `create_from_xy_vals`; any "yvals" entry is dropped
            because the y-values are computed here

        Returns
        -------
        pdf_obj
            The requested PDF, as returned by `create_from_xy_vals`
        """
        kdes = build_kdes(samples)
        kwargs.pop("yvals", None)
        yvals = evaluate_kdes(xvals, kdes)
        # Repeat the x grid once per entry along the first axis of `samples`.
        xvals_expand = (np.expand_dims(xvals, -1) * np.ones(samples.shape[0])).T
        return cls.create_from_xy_vals(xvals_expand, yvals, **kwargs)

    @property
    def splx(self) -> np.ndarray:
        """Return x-values of the spline knots"""
        return self._splx

    @property
    def sply(self) -> np.ndarray:
        """Return y-values of the spline knots"""
        return self._sply

    @property
    def spln(self) -> np.ndarray:
        """Return order of the spline knots"""
        return self._spln

    def _pdf(self, x, row):
        """Evaluate the spline of each requested row at `x`

        `x` and `row` are broadcast against each other element-wise and the
        result is returned flattened.
        """

        # pylint: disable=arguments-differ
        def pdf_row(xv, irow):
            return splev(
                xv, (self._splx[irow], self._sply[irow], self._spln[irow].item())
            )

        with errstate(all="ignore"):
            vv = np.vectorize(pdf_row)
            return vv(x, row).ravel()

    def _cdf(self, x, row):
        """Integrate the spline of each requested row from `_xmin` up to `x`

        `x` and `row` are broadcast against each other element-wise and the
        result is returned flattened.
        """

        # pylint: disable=arguments-differ
        def cdf_row(xv, irow):
            return splint(
                self._xmin,
                xv,
                (self._splx[irow], self._sply[irow], self._spln[irow].item()),
            )

        with errstate(all="ignore"):
            vv = np.vectorize(cdf_row)
            return vv(x, row).ravel()

    def ppf(self, quants):
        """Disabled: always raises `NotImplementedError`

        Raises
        ------
        NotImplementedError
            Always, until the spline ppf implementation is fixed.
        """
        # FIXME: remove this function once the issue with spline ppf is fixed
        raise NotImplementedError(
            "This function is buggy and currently not working properly, and will be restored once it's been fixed."
        )

    @staticmethod
    def _inverse_cdf_interpolator(grid, cdf_row):
        """Build an interpolator mapping cdf values back to grid positions"""
        # Filter out the bits where the cdf fluctuates down: sort by cdf
        # value, then keep only points where both the cdf and the grid
        # position strictly increase relative to the previous sorted point.
        order = np.argsort(cdf_row)
        sorted_vals = cdf_row[order]
        sorted_grid = grid[order]
        mask = np.zeros(len(sorted_grid), dtype=bool)
        mask[1:] = sorted_grid[1:] > sorted_grid[:-1]
        mask[1:] &= sorted_vals[1:] > sorted_vals[:-1]
        # NOTE(review): an earlier revision computed a copy of the kept cdf
        # values rescaled by its last entry but never used it; the
        # interpolation has always been built on the un-rescaled values.

        # Swap x and y to obtain the inverse function; quantiles outside the
        # tabulated range map to the grid points of the smallest and largest
        # cdf values.
        return interp1d(
            sorted_vals[mask],
            sorted_grid[mask],
            bounds_error=False,
            fill_value=(sorted_grid[0], sorted_grid[-1]),
        )

    def _ppf(self, quants, row):
        """Invert the cdf of each requested row on a fixed 1001-point grid

        `quants` and `row` are broadcast against each other element-wise and
        the result is returned flattened.
        """
        # pylint: disable=arguments-differ
        # get the cdfs on a grid
        n_pts = 1001
        grid = np.linspace(self._xmin, self._xmax, n_pts)
        unique_rows = np.unique(row)
        # Rows vary along axis 0 and grid points along axis 1, so the
        # flattened result reshapes cleanly to (n_rows, n_pts).
        cdf_vals = self._cdf(grid, np.expand_dims(unique_rows, -1)).reshape(
            len(unique_rows), n_pts
        )

        # One inverse-cdf interpolator per distinct row, built once rather
        # than once per evaluated element.
        interpolators = [
            self._inverse_cdf_interpolator(grid, cdf_row) for cdf_row in cdf_vals
        ]

        def ppf_row(quantsv, irow):
            # cdf_vals is indexed by position within unique_rows (which
            # np.unique returns sorted), not by the row id itself.
            return interpolators[np.searchsorted(unique_rows, irow)](quantsv)

        with errstate(all="ignore"):
            vv = np.vectorize(ppf_row)
            ret_vals = vv(quants, row).ravel()
        return ret_vals

    def _updated_ctor_param(self):
        """
        Add the spline arrays as additional constructor arguments
        """
        dct = super()._updated_ctor_param()
        dct["splx"] = self._splx
        dct["sply"] = self._sply
        dct["spln"] = self._spln
        return dct

    @classmethod
    def get_allocation_kwds(
        cls, npdf: int, **kwargs
    ) -> dict[str, tuple[tuple[int, int], str]]:
        """
        Return the keywords necessary to create an 'empty' hdf5 file with npdf entries
        for iterative file writeout. We only need to allocate the objdata columns, as
        the metadata can be written when we finalize the file.

        Parameters
        ----------
        npdf : int
            number of *total* PDFs that will be written out
        kwargs
            dictionary of kwargs needed to create the ensemble; must contain
            "splx", whose shape is used for the `splx` and `sply` columns

        Returns
        -------
        dict[str, tuple[tuple[int, int], str]]
            Maps each objdata column name to a (shape, dtype string) pair

        Raises
        ------
        ValueError
            If "splx" is not present in `kwargs`.
        """
        if "splx" not in kwargs:  # pragma: no cover
            raise ValueError("required argument splx not included in kwargs")
        shape = np.shape(kwargs["splx"])
        return dict(splx=(shape, "f4"), sply=(shape, "f4"), spln=((shape[0], 1), "i4"))

    @classmethod
    def plot_native(cls, pdf, **kwargs):
        """Plot the PDF in a way that is particular to this type of distribution

        For a spline this shows the spline knots
        """
        axes, _, kw = get_axes_and_xlims(**kwargs)
        xvals = pdf.dist.splx[pdf.kwds["row"]]
        return plot_pdf_on_axes(axes, pdf, xvals, **kw)

    @classmethod
    def add_mappings(cls) -> None:
        """
        Add this classes mappings to the conversion dictionary
        """
        cls._add_creation_method(cls.create, None)
        cls._add_creation_method(cls.create_from_xy_vals, "xy")
        cls._add_creation_method(cls.create_from_samples, "samples")
        cls._add_extraction_method(spline_extract_xy_vals, "xy")
        cls._add_extraction_method(extract_samples, "samples")

    @classmethod
    def create_ensemble(
        cls,
        splx: ArrayLike,
        sply: ArrayLike,
        spln: Optional[ArrayLike] = None,
        ancil: Optional[Mapping] = None,
        method: Optional[str] = None,
    ) -> Ensemble:
        """Creates an Ensemble of distributions parameterized via a set of splines.

        Parameters
        ----------
        splx : ArrayLike
            The x-values of the spline knots
        sply : ArrayLike
            The y-values of the spline knots
        spln : ArrayLike, optional
            The order of the spline knots, by default None
        ancil : Optional[Mapping], optional
            A dictionary of metadata for the distributions, where any arrays have the same length
            as the number of distributions, by default None
        method : Optional[str], optional
            The string of the creation method to use, by default None.

        Returns
        -------
        Ensemble
            An Ensemble object containing all of the given distributions.
        """
        data = {"splx": splx, "sply": sply, "spln": spln}
        return Ensemble(cls, data, ancil, method)
# Module-level aliases for the generator class and two of its alternate
# constructors.
spline = spline_gen
spline_from_xy = spline_gen.create_from_xy_vals
spline_from_samples = spline_gen.create_from_samples

# Hand the class to core.factory.add_class.
# NOTE(review): presumably this registers it with the package factory —
# confirm against core.factory.
add_class(spline_gen)