Source code for MuyGPyS.gp.deformation.isotropy

# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT


from typing import Callable, Dict, List, Optional, Tuple

import MuyGPyS._src.math as mm
from MuyGPyS._src.util import auto_str
from MuyGPyS.gp.deformation.deformation_fn import DeformationFn
from MuyGPyS.gp.hyperparameter import ScalarParam


[docs]@auto_str class Isotropy(DeformationFn): """ An isotropic deformation model. Isotropy defines a scaled elementwise distance function :math:`d_ell(\\cdot, \\cdot)`, and is paramterized by a scalar :math:`\\ell>0` length scale hyperparameter. .. math:: d_\\ell(\\mathbf{x}, \\mathbf{y}) = \\sum_{i=0}^d \\frac{d(\\mathbf{x}_i, \\mathbf{y}_i)}{\\ell} Args: metric: A callable metric function that takes a tensor of shape `(..., feature_count)` whose last dimension lists the elementwise differences between a pair of feature vectors and returns a tensor of shape `(...)`, having collapsed the last dimension into a scalar difference. length_scale: Some scalar nonnegative hyperparameter object. """ def __init__( self, metric: Callable, length_scale: ScalarParam, ): if not isinstance(length_scale, ScalarParam): raise ValueError( "Expected ScalarParam type for length_scale, not " f"{type(length_scale)}" ) self.length_scale = length_scale self._dist_fn = metric
[docs] def __call__( self, diffs: mm.ndarray, length_scale: Optional[float] = None, **kwargs ) -> mm.ndarray: """ Apply isotropic deformation to an elementwise difference tensor. This function is not intended to be invoked directly by a user. It is instead functionally incorporated into some :class:`MuyGPyS.gp.kernels.KernelFn` in its constructor. Args: diffs: A tensor of pairwise differences of shape `(..., feature_count)`. length_scale: A floating point length scale. Returns: A crosswise distance matrix of shape `(data_count, nn_count)` or a pairwise distance tensor of shape `(data_count, nn_count, nn_count)` whose last two dimensions are pairwise distance matrices. """ if length_scale is None: length_scale = self.length_scale() return self._dist_fn(diffs / length_scale)
[docs] def get_opt_params( self, ) -> Tuple[List[str], List[float], List[Tuple[float, float]]]: """ Report lists of unfixed hyperparameter names, values, and bounds. Returns ------- names: A list of unfixed hyperparameter names. params: A list of unfixed hyperparameter values. bounds: A list of unfixed hyperparameter bound tuples. """ names: List[str] = [] params: List[float] = [] bounds: List[Tuple[float, float]] = [] self.length_scale.append_lists("length_scale", names, params, bounds) return names, params, bounds
[docs] def populate_length_scale(self, hyperparameters: Dict) -> None: """ Populates the hyperparameter dictionary of a KernelFn object with `self.length_scale` of the Isotropy object. Args: hyperparameters: A dict containing the hyperparameters of a KernelFn object. """ hyperparameters["length_scale"] = self.length_scale
[docs] def embed_fn(self, fn: Callable) -> Callable: """ Augments a function to automatically apply the deformation to a difference tensor. Args: fn: A Callable with signature `(diffs, *args, **kwargs) -> mm.ndarray` taking a difference tensor `diffs` with shape `(..., feature_count)`. Returns: A new Callable that applies the deformation to `diffs`, removing the last tensor dimension by collapsing the feature-wise differences into scalar distances. Also adds a `length_scale` kwarg, making the function drivable by keyword optimization. """ def embedded_fn(diffs, *args, length_scale=None, **kwargs): return fn(self(diffs, length_scale=length_scale), *args, **kwargs) return embedded_fn