Source code for qp.mixmod_pdf

"""This module implements a PDF distribution sub-class using a Gaussian mixture model
"""

import numpy as np
from scipy import stats as sps
from scipy.stats import rv_continuous

from qp.conversion_funcs import extract_mixmod_fit_samples
from qp.factory import add_class
from qp.pdf_gen import Pdf_rows_gen
from qp.test_data import MEAN_MIXMOD, STD_MIXMOD, TEST_XVALS, WEIGHT_MIXMOD
from qp.utils import get_eval_case, interpolate_multi_x_y, reshape_to_pdf_size


class mixmod_gen(Pdf_rows_gen):
    """Mixture model based distribution

    Notes
    -----
    This implements a PDF using a Gaussian Mixture model

    The relevant data members are:

    means:  (npdf, ncomp) means of the Gaussians
    stds:  (npdf, ncomp) standard deviations of the Gaussians
    weights: (npdf, ncomp) weights for the Gaussians

    The pdf() and cdf() are exact, and are computed as a weighted sum of
    the pdf() and cdf() of the component Gaussians.

    The ppf() is computed by computing the cdf() values on a fixed
    grid and interpolating the inverse function.
    """

    # pylint: disable=protected-access

    name = "mixmod"
    version = 0

    _support_mask = rv_continuous._support_mask

    def __init__(self, means, stds, weights, *args, **kwargs):
        """
        Create a new distribution from the parameters of a Gaussian mixture

        Parameters
        ----------
        means : array_like
            The means of the Gaussians
        stds:  array_like
            The standard deviations of the Gaussians
        weights : array_like
            The weights to attach to the Gaussians. Weights should sum up to one.
            If not, the weights are interpreted as relative weights.

        Raises
        ------
        ValueError
            If any weight is negative.
        DeprecationWarning
            Raised (not merely emitted) by `_scipy_version_warning` when the
            installed scipy is older than 1.8.0.

        Notes
        -----
        The last axis of `means` is taken as the component axis; the
        leading axes define the ensemble shape passed to the parent class.
        Each row of weights is divided by its sum, so a row whose weights
        are all zero yields NaN weights.
        """
        self._scipy_version_warning()

        # The parameters are documented as array_like, so coerce them once
        # here: plain lists/tuples have no `.shape`.  Existing ndarrays
        # (and ndarray subclasses) pass through unchanged.
        means = np.asanyarray(means)
        stds = np.asanyarray(stds)
        weights = np.asanyarray(weights)

        self._means = reshape_to_pdf_size(means, -1)
        self._stds = reshape_to_pdf_size(stds, -1)
        self._weights = reshape_to_pdf_size(weights, -1)

        # Ensemble shape is everything except the trailing component axis.
        kwargs["shape"] = means.shape[:-1]
        self._ncomps = means.shape[-1]
        super().__init__(*args, **kwargs)

        # Only strictly negative weights are rejected; zeros are accepted.
        if np.any(self._weights < 0):
            raise ValueError("All weights need to be larger than zero")

        # Interpret the weights as relative: normalize each row to sum to one.
        self._weights = self._weights / self._weights.sum(axis=1)[:, None]

        self._addobjdata("weights", self._weights)
        self._addobjdata("stds", self._stds)
        self._addobjdata("means", self._means)

    def _scipy_version_warning(self):
        """Refuse to run against scipy releases older than 1.8.0

        Returns silently when the major version is above 1 or the minor
        version is above 7; otherwise raises.

        Raises
        ------
        DeprecationWarning
            If the installed scipy version is older than 1.8.0.
        """
        import scipy  # pylint: disable=import-outside-toplevel

        scipy_version = scipy.__version__
        vtuple = scipy_version.split(".")
        if int(vtuple[0]) > 1 or int(vtuple[1]) > 7:
            return
        raise DeprecationWarning(
            f"Mixmod_gen will not work correctly with scipy version < 1.8.0, you have {scipy_version}"
        )  # pragma: no cover

    @property
    def weights(self):
        """Return weights to attach to the Gaussians"""
        return self._weights

    @property
    def means(self):
        """Return means of the Gaussians"""
        return self._means

    @property
    def stds(self):
        """Return standard deviations of the Gaussians"""
        return self._stds

    def _pdf(self, x, row):
        """Evaluate the mixture PDF as the weighted sum of the component Gaussian PDFs

        `row` selects which PDFs' parameters are used and `x` holds the
        points to evaluate at.
        NOTE(review): the exact shapes of `x` and `row` are set by the
        parent class evaluation machinery -- confirm against Pdf_rows_gen.
        """
        # pylint: disable=arguments-differ
        if np.ndim(x) > 1:  # pragma: no cover
            x = np.expand_dims(x, -2)
        # Swap the last two axes of the selected parameters so they
        # broadcast against x, then sum the weighted contributions.
        return (
            self.weights[row].swapaxes(-2, -1)
            * sps.norm(
                loc=self._means[row].swapaxes(-2, -1),
                scale=self._stds[row].swapaxes(-2, -1),
            ).pdf(x)
        ).sum(axis=0)

    def _cdf(self, x, row):
        """Evaluate the mixture CDF as the weighted sum of the component Gaussian CDFs

        Mirrors `_pdf`, using the component `cdf` instead of `pdf`.
        NOTE(review): the exact shapes of `x` and `row` are set by the
        parent class evaluation machinery -- confirm against Pdf_rows_gen.
        """
        # pylint: disable=arguments-differ
        if np.ndim(x) > 1:  # pragma: no cover
            x = np.expand_dims(x, -2)
        return (
            self.weights[row].swapaxes(-2, -1)
            * sps.norm(
                loc=self._means[row].swapaxes(-2, -1),
                scale=self._stds[row].swapaxes(-2, -1),
            ).cdf(x)
        ).sum(axis=0)

    def _ppf(self, x, row):
        """Evaluate the inverse CDF by interpolating the CDF tabulated on a fixed grid

        The grid has 201 evenly spaced points between the smallest
        ``mean - 6 * std`` and the largest ``mean + 6 * std`` taken over
        every component of every PDF held by this object.  Out-of-range
        requests are filled with the grid end points rather than raising.

        Raises
        ------
        ValueError
            If `get_eval_case` reports a case other than 1 or 3.
        """
        # pylint: disable=arguments-differ
        min_val = np.min(self._means - 6 * self._stds)
        max_val = np.max(self._means + 6 * self._stds)
        grid = np.linspace(min_val, max_val, 201)
        case_idx, _, rr = get_eval_case(x, row)
        if case_idx == 1:
            cdf_vals = self.cdf(grid, rr)
        elif case_idx == 3:
            cdf_vals = self.cdf(grid, np.expand_dims(rr, -1))
        else:  # pragma: no cover
            raise ValueError(
                f"Oops, we can't handle this kind of input to mixmod._ppf {case_idx}"
            )
        # Invert the CDF: the tabulated cdf values act as the abscissa and
        # the grid as the ordinate.
        return interpolate_multi_x_y(
            x, row, cdf_vals, grid, bounds_error=False, fill_value=(min_val, max_val)
        ).ravel()

    def _updated_ctor_param(self):
        """
        Add the means, stds and weights as additional constructor arguments
        """
        dct = super()._updated_ctor_param()
        dct["means"] = self._means
        dct["stds"] = self._stds
        dct["weights"] = self._weights
        return dct

    @classmethod
    def get_allocation_kwds(cls, npdf, **kwargs):
        """Return the keywords necessary to create an 'empty' hdf5 file with npdf entries
        for iterative file writeout.  We only need to allocate the objdata columns, as
        the metadata can be written when we finalize the file.

        Parameters
        ----------
        npdf: int
            number of *total* PDFs that will be written out
        kwargs: dict
            dictionary of kwargs needed to create the ensemble

        Returns
        -------
        dict
            Maps each of "means", "stds" and "weights" to a
            ``((npdf, ncomp), "f4")`` tuple, where ``ncomp`` is the size of
            the last axis of ``kwargs["means"]``.

        Raises
        ------
        ValueError
            If "means" is missing from kwargs.
        """
        if "means" not in kwargs:  # pragma: no cover
            raise ValueError("required argument means not included in kwargs")

        ncomp = np.shape(kwargs["means"])[-1]
        return dict(
            means=((npdf, ncomp), "f4"),
            stds=((npdf, ncomp), "f4"),
            weights=((npdf, ncomp), "f4"),
        )

    @classmethod
    def add_mappings(cls):
        """
        Add this classes mappings to the conversion dictionary
        """
        cls._add_creation_method(cls.create, None)
        cls._add_extraction_method(extract_mixmod_fit_samples, None)

    @classmethod
    def make_test_data(cls):
        """Make data for unit tests"""
        # `mixmod` is the module-level alias for `mixmod_gen.create`; it is
        # bound after the class body, so it is only resolved at call time.
        cls.test_data = dict(
            mixmod=dict(
                gen_func=mixmod,
                ctor_data=dict(
                    weights=WEIGHT_MIXMOD, means=MEAN_MIXMOD, stds=STD_MIXMOD
                ),
                convert_data={},
                test_xvals=TEST_XVALS,
                atol_diff2=1.0,
            )
        )
# Module-level factory alias; `make_test_data` passes it as `gen_func`.
mixmod = mixmod_gen.create

# Register the class with the qp factory at import time.
add_class(mixmod_gen)