Source code for MuyGPyS.gp.kernels.kernel_fn

# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

"""
Hyperparameters and kernel functors

Defines kernel functors (inheriting
:class:`~MuyGPyS.gp.kernels.kernel_fn.KernelFn`) that transform crosswise
difference tensors into cross-covariance matrices and pairwise difference
matrices into covariance or kernel tensors.

See the following example to initialize an :class:`MuyGPyS.gp.kernels.Matern`
object. Other kernel functors are similar, but require different
hyperparameters.

Example:
    >>> from MuyGPyS.gp.kernels import Matern
    >>> kern = Matern(
    ...     smoothness=Parameter("log_sample", (0.1, 2.5)),
    ...     deformation=Isotropy(
    ...         metric=l2,
    ...         length_scale=Parameter(1.0),
    ...     ),
    ... )

One uses a previously computed `pairwise_diffs` tensor (see
:func:`MuyGPyS.gp.tensor.pairwise_tensor`) to compute a kernel tensor whose
second and third dimensions contain square kernel matrices. Similarly, one
uses a previously computed `crosswise_diffs` matrix (see
:func:`MuyGPyS.gp.tensor.crosswise_diffs`) to compute a cross-covariance
matrix. See the following example, which assumes that you have already
constructed the difference `numpy.ndarrays` and the kernel `kern` as shown
above.

Example:
    >>> K = kern(pairwise_diffs)
    >>> Kcross = kern(crosswise_diffs)
"""

from typing import Callable, Dict, List, Tuple

import MuyGPyS._src.math as mm
from MuyGPyS.gp.deformation import DeformationFn


class KernelFn:
    """
    A kernel functor.

    Base class for kernel functors that include a hyperparameter Dict and a
    call mechanism. Subclasses are expected to override :meth:`_make`,
    :meth:`__call__`, and :meth:`get_opt_fn`, all of which raise
    `NotImplementedError` here.

    Args:
        deformation:
            The deformation functor stored on the kernel. It is asked to
            populate the hyperparameter dict via `populate_length_scale` and
            to report the optimizable parameters via `get_opt_params`.
    """

    def __init__(
        self,
        deformation: DeformationFn,
    ):
        """
        Initialize dict holding hyperparameters.
        """
        self._hyperparameters: Dict = dict()
        self.deformation = deformation
        self._make_base()

    def _make_base(self):
        # The deformation registers its own entries in the shared
        # hyperparameter dict.
        self.deformation.populate_length_scale(self._hyperparameters)

    def _make(self):
        raise NotImplementedError("_make is not implemented for base KernelFn")

    def set_params(self, **kwargs) -> None:
        """
        Reset hyperparameters using hyperparameter dict(s).

        Args:
            kwargs:
                Hyperparameter kwargs. Each key must name an entry of the
                hyperparameter dict; the associated value is forwarded to
                that entry's `_set` method.

        Raises:
            KeyError:
                If a keyword does not name a registered hyperparameter.
        """
        for name, value in kwargs.items():
            self._hyperparameters[name]._set(value)

    def __call__(self, diffs: mm.ndarray, **kwargs) -> mm.ndarray:
        raise NotImplementedError(
            "__call__ is not implemented for base KernelFn"
        )

    def get_opt_fn(self) -> Callable:
        raise NotImplementedError(
            "get_opt_fn is not implemented for base KernelFn"
        )

    def get_opt_params(
        self,
    ) -> Tuple[List[str], List[float], List[Tuple[float, float]]]:
        """
        Report lists of unfixed hyperparameter names, values, and bounds.

        The base class delegates entirely to the deformation's
        `get_opt_params`.

        Returns
        -------
            names:
                A list of unfixed hyperparameter names.
            params:
                A list of unfixed hyperparameter values.
            bounds:
                A list of unfixed hyperparameter bound tuples.
        """
        return self.deformation.get_opt_params()

    def __str__(self) -> str:
        """
        Print state of hyperparameter dict.

        Intended only for testing purposes. Emits one
        `name : value - bounds` line per hyperparameter, with no trailing
        newline; an empty dict yields an empty string.
        """
        return "\n".join(
            f"{name} : {param()} - {param.get_bounds()}"
            for name, param in self._hyperparameters.items()
        )