Source code for MuyGPyS.gp.deformation.isotropy

# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT


from typing import Optional, Union

import MuyGPyS._src.math as mm
from MuyGPyS._src.mpi_utils import mpi_chunk
from MuyGPyS._src.util import auto_str
from MuyGPyS.gp.deformation.deformation_fn import DeformationFn
from MuyGPyS.gp.deformation.metric import MetricFn
from MuyGPyS.gp.hyperparameter import ScalarParam, NamedParam
from MuyGPyS.gp.hyperparameter.experimental import (
    HierarchicalParam,
    NamedHierarchicalParam,
)


@auto_str
class Isotropy(DeformationFn):
    """
    An isotropic deformation model.

    Isotropy defines a scaled elementwise distance function
    :math:`d_\\ell(\\cdot, \\cdot)`, and is parameterized by a scalar
    :math:`\\ell>0` length scale hyperparameter.

    .. math::
        d_\\ell(\\mathbf{x}, \\mathbf{y}) =
        \\sum_{i=0}^d \\frac{d(\\mathbf{x}_i, \\mathbf{y}_i)}{\\ell}

    Args:
        metric:
            A MetricFn object defining the behavior of the feature metric
            space.
        length_scale:
            Some scalar nonnegative hyperparameter object: either a
            `ScalarParam` or an (experimental) `HierarchicalParam`.

    Raises:
        ValueError:
            If `length_scale` is neither a `ScalarParam` nor a
            `HierarchicalParam`.
    """

    def __init__(
        self,
        metric: MetricFn,
        length_scale: Union[ScalarParam, HierarchicalParam],
    ):
        # Wrap the hyperparameter in the named variant matching its type,
        # labelled "length_scale".
        # NOTE: this isinstance dispatch is brittle and should be refactored.
        if isinstance(length_scale, ScalarParam):
            self.length_scale = NamedParam("length_scale", length_scale)
        elif isinstance(length_scale, HierarchicalParam):
            self.length_scale = NamedHierarchicalParam(
                "length_scale", length_scale
            )
        else:
            raise ValueError(
                "Expected ScalarParam or HierarchicalParam type for "
                f"length_scale, not {type(length_scale)}"
            )
        self.metric = metric

    def __call__(
        self,
        dists: mm.ndarray,
        length_scale: Optional[Union[float, mm.ndarray]] = None,
        **kwargs,
    ) -> mm.ndarray:
        """
        Apply isotropic deformation to an elementwise distance tensor.

        This function is not intended to be invoked directly by a user. It is
        instead functionally incorporated into some
        :class:`MuyGPyS.gp.kernels.KernelFn` in its constructor.

        Args:
            dists:
                A tensor of distances between sets of observables.
            length_scale:
                A floating point length scale, or an array whose single axis
                is aligned with the leading dimension of `dists`. If `None`,
                the value is obtained from `self.length_scale(**kwargs)`.
            kwargs:
                Forwarded to `self.length_scale` when `length_scale` is
                `None`.

        Returns:
            A scaled distance matrix of shape `(data_count, nn_count)` or a
            pairwise distance tensor of shape
            `(data_count, nn_count, nn_count)` whose last two dimensions are
            pairwise distance matrices.
        """
        if length_scale is None:
            length_scale = self.length_scale(**kwargs)
        # A non-scalar length scale is indexed as [:, None, ..., None], i.e.
        # it is treated as one value per entry of the leading axis of `dists`
        # and given singleton trailing axes so it broadcasts along that axis
        # (assumes a 1-D length scale array).
        # NOTE: this placement of the broadcasting logic is brittle.
        if (
            isinstance(length_scale, mm.ndarray)
            and len(length_scale.shape) > 0
        ):
            shape = [None] * dists.ndim
            shape[0] = slice(None)
            length_scale = length_scale[tuple(shape)]
        return self.metric.apply_length_scale(dists, length_scale)

    @mpi_chunk(return_count=1)
    def pairwise_tensor(
        self,
        data: mm.ndarray,
        nn_indices: mm.ndarray,
        **kwargs,
    ) -> mm.ndarray:
        """
        Compute a pairwise distance tensor among sets of nearest neighbors.

        Takes a full dataset of records of interest `data` and produces the
        pairwise distances between the elements indicated by each row of
        `nn_indices`.

        Args:
            data:
                The data matrix of shape `(batch_count, feature_count)`
                containing batch elements.
            nn_indices:
                An integral matrix of shape `(batch_count, nn_count)` listing
                the nearest neighbor indices for the batch of data points.

        Returns:
            A tensor of shape `(batch_count, nn_count, nn_count)` containing
            the `(nn_count, nn_count)`-shaped pairwise nearest neighbor
            distance tensors corresponding to each of the batch elements.
        """
        return self.metric.pairwise_distances(data, nn_indices)

    @mpi_chunk(return_count=1)
    def crosswise_tensor(
        self,
        data: mm.ndarray,
        nn_data: mm.ndarray,
        data_indices: mm.ndarray,
        nn_indices: mm.ndarray,
        **kwargs,
    ) -> mm.ndarray:
        """
        Compute a crosswise distance tensor between data and their nearest
        neighbors.

        Takes full datasets of records of interest `data` and neighbor
        candidates `nn_data` and produces a scalar distance between each
        element of `data` indicated by `data_indices` and each of the nearest
        neighbors in `nn_data` as indicated by the corresponding rows of
        `nn_indices`. `data` and `nn_data` can refer to the same dataset.

        Args:
            data:
                The data matrix of shape `(data_count, feature_count)`
                containing batch elements.
            nn_data:
                The data matrix of shape `(candidate_count, feature_count)`
                containing the universe of candidate neighbors for the batch
                elements. Might be the same as `data`.
            data_indices:
                An integral vector of shape `(batch_count,)` containing the
                indices of the batch.
            nn_indices:
                An integral matrix of shape `(batch_count, nn_count)` listing
                the nearest neighbor indices for the batch of data points.

        Returns:
            A tensor of shape `(batch_count, nn_count)` whose second dimension
            indicates distance vectors between each batch element and its
            nearest neighbors.
        """
        return self.metric.crosswise_distances(
            data, nn_data, data_indices, nn_indices
        )
@auto_str
class DifferenceIsotropy(Isotropy):
    """
    An isotropic deformation model that reasons about differences rather
    than distances.

    DifferenceIsotropy defines a scaled elementwise distance function
    :math:`d_\\ell(\\cdot, \\cdot)`, and is parameterized by a scalar
    :math:`\\ell>0` length scale hyperparameter. Unlike `Isotropy`, its
    tensors hold elementwise differences: the differences are divided by
    the length scale and only then passed through the metric.

    .. math::
        d_\\ell(\\mathbf{x}, \\mathbf{y}) =
        \\sum_{i=0}^d \\frac{d(\\mathbf{x}_i, \\mathbf{y}_i)}{\\ell}

    Args:
        metric:
            A MetricFn object defining the behavior of the feature metric
            space.
        length_scale:
            Some scalar nonnegative hyperparameter object.
    """

    def __call__(
        self,
        dists: mm.ndarray,
        length_scale: Optional[Union[float, mm.ndarray]] = None,
        **kwargs,
    ) -> mm.ndarray:
        """
        Apply isotropic deformation to an elementwise difference tensor.

        This function is not intended to be invoked directly by a user. It is
        instead functionally incorporated into some
        :class:`MuyGPyS.gp.kernels.KernelFn` in its constructor.

        Args:
            dists:
                A tensor of elementwise differences between sets of
                observables.
            length_scale:
                A floating point length scale, or an array whose single axis
                is aligned with the leading dimension of `dists`. If `None`,
                the value is obtained from `self.length_scale(**kwargs)`.
            kwargs:
                Forwarded to `self.length_scale` when `length_scale` is
                `None`.

        Returns:
            The result of applying `self.metric` to the scaled differences;
            per the original documentation, a distance matrix of shape
            `(data_count, nn_count)` or a pairwise distance tensor of shape
            `(data_count, nn_count, nn_count)`.
        """
        if length_scale is None:
            # Forward kwargs exactly as Isotropy.__call__ does, so that any
            # keyword inputs the length-scale parameter consumes reach it
            # (presumably e.g. batch features for hierarchical parameters —
            # confirm).
            length_scale = self.length_scale(**kwargs)
        # A non-scalar length scale is indexed as [:, None, ..., None], i.e.
        # one value per entry of the leading axis of `dists`. Without this,
        # a 1-D array would broadcast against the trailing axis instead.
        # Mirrors the broadcasting performed in Isotropy.__call__.
        if (
            isinstance(length_scale, mm.ndarray)
            and len(length_scale.shape) > 0
        ):
            shape = [None] * dists.ndim
            shape[0] = slice(None)
            length_scale = length_scale[tuple(shape)]
        return self.metric(dists / length_scale)

    @mpi_chunk(return_count=1)
    def pairwise_tensor(
        self,
        data: mm.ndarray,
        nn_indices: mm.ndarray,
        **kwargs,
    ) -> mm.ndarray:
        """
        Compute a pairwise difference tensor among sets of nearest neighbors.

        Takes a full dataset of records of interest `data` and produces the
        pairwise differences between the elements indicated by each row of
        `nn_indices`.

        Args:
            data:
                The data matrix of shape `(batch_count, feature_count)`
                containing batch elements.
            nn_indices:
                An integral matrix of shape `(batch_count, nn_count)` listing
                the nearest neighbor indices for the batch of data points.

        Returns:
            The output of `self.metric.pairwise_differences`. The original
            documentation gives shape `(batch_count, nn_count, nn_count)`;
            NOTE(review): a difference tensor presumably carries a trailing
            feature axis as well — confirm against the MetricFn.
        """
        return self.metric.pairwise_differences(data, nn_indices)

    @mpi_chunk(return_count=1)
    def crosswise_tensor(
        self,
        data: mm.ndarray,
        nn_data: mm.ndarray,
        data_indices: mm.ndarray,
        nn_indices: mm.ndarray,
        **kwargs,
    ) -> mm.ndarray:
        """
        Compute a crosswise difference tensor between data and their nearest
        neighbors.

        Takes full datasets of records of interest `data` and neighbor
        candidates `nn_data` and produces the difference between each element
        of `data` indicated by `data_indices` and each of the nearest
        neighbors in `nn_data` as indicated by the corresponding rows of
        `nn_indices`. `data` and `nn_data` can refer to the same dataset.

        Args:
            data:
                The data matrix of shape `(data_count, feature_count)`
                containing batch elements.
            nn_data:
                The data matrix of shape `(candidate_count, feature_count)`
                containing the universe of candidate neighbors for the batch
                elements. Might be the same as `data`.
            data_indices:
                An integral vector of shape `(batch_count,)` containing the
                indices of the batch.
            nn_indices:
                An integral matrix of shape `(batch_count, nn_count)` listing
                the nearest neighbor indices for the batch of data points.

        Returns:
            The output of `self.metric.crosswise_differences`. The original
            documentation gives shape `(batch_count, nn_count)`;
            NOTE(review): a difference tensor presumably carries a trailing
            feature axis as well — confirm against the MetricFn.
        """
        return self.metric.crosswise_differences(
            data, nn_data, data_indices, nn_indices
        )