# Source code for MuyGPyS.gp.hyperparameter.scale

# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

"""
Variance scale hyperparameter

Defines :class:`ScaleFn`, the base :math:`\\sigma^2` variance scale functor,
together with :class:`FixedScale`, whose optimizer leaves the value unchanged,
and :class:`AnalyticScale`, whose optimizer estimates the value analytically
from the (noise-perturbed) local kernels.
"""

from typing import Callable, Tuple, Type

import MuyGPyS._src.math as mm
from MuyGPyS._src.util import _fullname
from MuyGPyS._src.optimize.scale import _analytic_scale_optim


class ScaleFn:
    """
    A :math:`\\sigma^2` covariance scale parameter base functor.

    :math:`\\sigma^2` is a scaling parameter that one multiplies with the
    found diagonal variances of a :class:`MuyGPyS.gp.muygps.MuyGPS` or
    :class:`MuyGPyS.gp.muygps.MultivariateMuyGPS` regression in order to
    obtain the predicted posterior variance. Trained values assume a number of
    dimensions equal to the number of response dimensions, and correspond to
    scalar scaling parameters along the corresponding dimensions.

    Args:
        response_count:
            The integer number of response dimensions.
    """

    def __init__(
        self,
        response_count: int = 1,
        _backend_ones: Callable = mm.ones,
        _backend_ndarray: Type = mm.ndarray,
        _backend_ftype: Type = mm.ftype,
        _backend_farray: Callable = mm.farray,
        _backend_outer: Callable = mm.outer,
        **kwargs,
    ):
        # The parameter starts untrained with one entry per response
        # dimension, all set to one.
        self.val = _backend_ones(response_count)
        self._trained = False
        # Injectable array-library hooks; the defaults come from
        # MuyGPyS._src.math.
        self._backend_ndarray = _backend_ndarray
        self._backend_ftype = _backend_ftype
        self._backend_farray = _backend_farray
        self._backend_outer = _backend_outer

    def __str__(self, **kwargs):
        return f"{type(self).__name__}({self.val})"

    def _set(self, val: mm.ndarray) -> None:
        """
        Value setter.

        Replaces the stored value and marks the parameter as trained.

        Args:
            val:
                The new value of the hyperparameter. Must be an instance of
                the backend array type and have the same shape as the current
                value. It is converted to the backend float dtype if needed.

        Raises:
            ValueError:
                If `val` is not a backend array, or if its shape differs from
                the shape of the current value.
        """
        if not isinstance(val, self._backend_ndarray):
            raise ValueError(
                f"Expected {_fullname(self._backend_ndarray)} for variance "
                f"scale value update, not {_fullname(val.__class__)}"
            )
        if self.val.shape != val.shape:
            raise ValueError(
                "Bad attempt to assign variance scale of shape "
                f"{self.val.shape} a value of shape {val.shape}"
            )
        # Coerce to the backend floating-point dtype so downstream math sees a
        # consistent type.
        if val.dtype != self._backend_ftype:
            val = self._backend_farray(val)
        # Note: this rebinds `self.val` to a new object rather than mutating
        # the old array in place.
        self.val = val
        self._trained = True

    def __call__(self) -> mm.ndarray:
        """
        Value accessor.

        Returns:
            The current value of the hyperparameter.
        """
        return self.val

    @property
    def trained(self) -> bool:
        """
        Report whether the value has been set.

        Returns:
            `True` if trained, `False` otherwise.
        """
        return self._trained

    @property
    def shape(self) -> Tuple[int, ...]:
        """
        Report the shape of the scale parameter.

        Returns:
            The shape of the scale parameter.
        """
        return self.val.shape

    def scale_fn(self, fn: Callable) -> Callable:
        """
        Modify a function to outer product its output with `scale`.

        Args:
            fn:
                A function.

        Returns:
            A function that returns the outer product of the output of `fn`
            and a scale. The scale may be passed explicitly through the
            `scale` keyword argument. If it is omitted, the current value of
            this parameter is read at call time, so values assigned later via
            `_set` are honored.
        """

        def scaled_fn(*args, scale=None, **kwargs):
            # Resolve the default lazily. A `scale=self()` default would be
            # evaluated once, when `scale_fn` runs. Because `_set` rebinds
            # `self.val`, that captured array would go stale after training.
            if scale is None:
                scale = self()
            return self._backend_outer(fn(*args, **kwargs), scale)

        return scaled_fn

    def get_opt_fn(self, muygps) -> Callable:
        """
        Return a function that leaves the variance scale unchanged.

        Subclasses override this to supply an actual optimization strategy.

        Args:
            muygps:
                The model whose current scale is reported.

        Returns:
            A function with signature `(K, nn_targets, *args, **kwargs)` that
            ignores its arguments and returns `muygps.scale()`.
        """

        def noop_scale_opt_fn(K, nn_targets, *args, **kwargs):
            return muygps.scale()

        return noop_scale_opt_fn
class FixedScale(ScaleFn):
    """
    A :math:`\\sigma^2` covariance scale parameter that is never optimized.

    The optimizer returned by `get_opt_fn` does no fitting. It simply reports
    the model's current scale, so this parameter is insensitive to
    optimization.

    Args:
        response_count:
            The integer number of response dimensions.
    """

    def get_opt_fn(self, muygps) -> Callable:
        """
        Build the trivial "optimizer" for a fixed variance scale.

        Args:
            muygps:
                The model whose current scale is reported. Nothing else about
                it is used.

        Returns:
            A function with signature `(K, nn_targets, *args, **kwargs)` that
            disregards every argument and returns `muygps.scale()`.
        """

        def noop_scale_opt_fn(K, nn_targets, *args, **kwargs):
            # The kernel/target inputs are accepted only so that this
            # function shares the call signature of the analytic optimizer.
            current_scale = muygps.scale()
            return current_scale

        return noop_scale_opt_fn
class AnalyticScale(ScaleFn):
    """
    An optimizable :math:`\\sigma^2` covariance scale parameter.

    Behaves like
    :class:`~MuyGPyS.gp.hyperparameter.scale.FixedScale`, except that the
    function produced by `get_opt_fn` computes the scale analytically from
    the local kernels instead of echoing the current value.

    Args:
        response_count:
            The integer number of response dimensions.
    """

    def get_opt_fn(self, muygps) -> Callable:
        """
        Get a function that analytically optimizes :math:`\\sigma^2` for each
        response dimension.

        :math:`\\sigma^2` is approximated by averaging the analytic solution
        obtained from every local kernel:

        .. math::
            \\sigma^2 = \\frac{1}{bk} \\sum_{i \\in B}
                        Y_{nn_i}^T K_{nn_i}^{-1} Y_{nn_i}

        Here :math:`Y_{nn_i}` and :math:`K_{nn_i}` are the target and kernel
        matrices restricted to the nearest neighbor set in scope, :math:`k` is
        the number of nearest neighbors, and :math:`b = |B|` is the number of
        batch elements considered.

        Args:
            muygps:
                The model used to perturb the kernel; its `noise` model is
                applied to `K` before solving.

        Returns:
            A function with signature
            `(K, nn_targets, *args, **kwargs) -> mm.ndarray` that perturbs the
            `(batch_count, nn_count, nn_count)` tensor `K` with `muygps`'s
            noise model and then solves it against the
            `(batch_count, nn_count, response_count)` tensor `nn_targets`.
        """

        def analytic_scale_opt_fn(K, nn_targets, *args, **kwargs):
            # Look up the noise model at call time so the current one is used.
            perturbed_K = muygps.noise.perturb(K)
            return _analytic_scale_optim(perturbed_K, nn_targets)

        return analytic_scale_opt_fn