Source code for MuyGPyS.optimize.sigma_sq

# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

"""
Convenience functions for optimizing the `sigma_sq` parameter of
:class:`~MuyGPyS.gp.muygps.MuyGPS` objects.

Currently only supports an analytic approximation, but will support other
methods in the future.
"""

from copy import deepcopy
from typing import Callable, Optional

import MuyGPyS._src.math as mm
from MuyGPyS.gp import MuyGPS, MultivariateMuyGPS as MMuyGPS
from MuyGPyS._src.optimize.sigma_sq import _analytic_sigma_sq_optim
from MuyGPyS.optimize.utils import _switch_on_sigma_method
from MuyGPyS.gp.noise.perturbation import select_perturb_fn


[docs]def muygps_sigma_sq_optim( muygps: MuyGPS, pairwise_diffs: mm.ndarray, nn_targets: mm.ndarray, sigma_method: Optional[str] = "analytic", ) -> MuyGPS: """ Optimize the value of the :math:`\\sigma^2` scale parameter for each response dimension. The optimization to be applied depends upon the value of `sigma_method`. Args: muygps: The model to be optimized. pairwise_diffs: A tensor of shape `(batch_count, nn_count, nn_count, feature_count)` containing the `(nn_count, nn_count, feature_count)`-shaped pairwise nearest neighbor difference tensors corresponding to each of the batch elements. nn_targets: Tensor of floats of shape `(batch_count, nn_count, response_count)` containing the expected response for each nearest neighbor of each batch element. sigma_method: The optimization method to apply. Currently only supports `"analytic"` and `None`. Returns: A new MuyGPs model whose sigma_sq parameter has been optimized. """ return _switch_on_sigma_method( sigma_method, muygps_analytic_sigma_sq_optim, lambda muygps, pairwise_diffs, nn_targets: muygps, muygps, pairwise_diffs, nn_targets, )
[docs]def mmuygps_sigma_sq_optim( mmuygps: MMuyGPS, pairwise_diffs: mm.ndarray, nn_targets: mm.ndarray, sigma_method: Optional[str] = "analytic", ) -> MMuyGPS: """ Optimize the value of the :math:`\\sigma^2` scale parameter for each response dimension of a MultivariateMuyGPS object. The optimization to be applied depends upon the value of `sigma_method`. Args: mmuygps: The model to be optimized. pairwise_diffs: A tensor of shape `(batch_count, nn_count, nn_count, feature_count)` containing the `(nn_count, nn_count, feature_count)`-shaped pairwise nearest neighbor difference tensors corresponding to each of the batch elements. nn_targets: Tensor of floats of shape `(batch_count, nn_count, response_count)` containing the expected response for each nearest neighbor of each batch element. sigma_method: The optimization method to apply. Currently only supports `"analytic"` and `None`. Returns: A new MultivariateMuyGPs model whose sigma_sq parameter has been optimized. """ return _switch_on_sigma_method( sigma_method, mmuygps_analytic_sigma_sq_optim, lambda mmuygps, pairwise_diffs, nn_targets: mmuygps, mmuygps, pairwise_diffs, nn_targets, )
def make_sigma_sq_optim( sigma_method: Optional[str], muygps: MuyGPS ) -> Callable: return _switch_on_sigma_method( sigma_method, make_analytic_sigma_sq_optim, make_none_sigma_sq_optim, muygps, _analytic_sigma_sq_optim, select_perturb_fn(muygps.eps), ) def make_none_sigma_sq_optim(muygps: MuyGPS, *args) -> Callable: return lambda: lambda K, nn_targets, **kwargs: muygps.sigma_sq() def make_analytic_sigma_sq_optim( muygps: MuyGPS, analytic_optim_fn, perturb_fn: Callable ) -> Callable: if not muygps.eps.fixed(): def ss_opt_fn(K, nn_targets, **kwargs): return analytic_optim_fn(perturb_fn(K, kwargs["eps"]), nn_targets) else: def ss_opt_fn(K, nn_targets, **kwargs): return analytic_optim_fn(perturb_fn(K, muygps.eps()), nn_targets) return ss_opt_fn
[docs]def muygps_analytic_sigma_sq_optim( muygps: MuyGPS, pairwise_diffs: mm.ndarray, nn_targets: mm.ndarray ) -> MuyGPS: """ Optimize the value of the :math:`\\sigma^2` scale parameter for each response dimension. We approximate :math:`\\sigma^2` by way of averaging over the analytic solution from each local kernel. .. math:: \\sigma^2 = \\frac{1}{bk} * \\sum_{i \\in B} Y_{nn_i}^T K_{nn_i}^{-1} Y_{nn_i} Here :math:`Y_{nn_i}` and :math:`K_{nn_i}` are the target and kernel matrices with respect to the nearest neighbor set in scope, where :math:`k` is the number of nearest neighbors and :math:`b = |B|` is the number of batch elements considered. Args: muygps: The model to be optimized. pairwise_diffs: A tensor of shape `(batch_count, nn_count, nn_count, feature_count)` containing the `(nn_count, nn_count, feature_count)`-shaped pairwise nearest neighbor difference tensors corresponding to each of the batch elements. nn_targets: Tensor of floats of shape `(batch_count, nn_count, response_count)` containing the expected response for each nearest neighbor of each batch element. Returns: A new MuyGPs model whose sigma_sq parameter has been optimized. """ ret = deepcopy(muygps) perturb_fn = select_perturb_fn(ret.eps) K = ret.kernel(pairwise_diffs) ss = _analytic_sigma_sq_optim(perturb_fn(K, ret.eps()), nn_targets) ret.sigma_sq._set(ss) return ret
[docs]def mmuygps_analytic_sigma_sq_optim( mmuygps: MMuyGPS, pairwise_diffs: mm.ndarray, nn_targets: mm.ndarray ) -> MMuyGPS: """ Optimize the value of the :math:`\\sigma^2` scale parameter for each response dimension. We approximate :math:`\\sigma^2` by way of averaging over the analytic solution from each local kernel. .. math:: \\sigma^2 = \\frac{1}{bk} * \\sum_{i \\in B} Y_{nn_i}^T K_{nn_i}^{-1} Y_{nn_i} Here :math:`Y_{nn_i}` and :math:`K_{nn_i}` are the target and kernel matrices with respect to the nearest neighbor set in scope, where :math:`k` is the number of nearest neighbors and :math:`b = |B|` is the number of batch elements considered. Args: muygps: The model to be optimized. pairwise_diffs: A tensor of shape `(batch_count, nn_count, nn_count, feature_count)` containing the `(nn_count, nn_count, feature_count)`-shaped pairwise nearest neighbor difference tensors corresponding to each of the batch elements. nn_targets: Tensor of floats of shape `(batch_count, nn_count, response_count)` containing the expected response for each nearest neighbor of each batch element. Returns: A new MuyGPs model whose sigma_sq parameter has been optimized. """ ret = deepcopy(mmuygps) batch_count, nn_count, response_count = nn_targets.shape if response_count != len(ret.models): raise ValueError( f"Response count ({response_count}) does not match the number " f"of models ({len(ret.models)})." ) sigma_sqs = mm.zeros((response_count,)) for i, model in enumerate(ret.models): perturb_fn = select_perturb_fn(model.eps) K = model.kernel(pairwise_diffs) new_sigma_val = _analytic_sigma_sq_optim( perturb_fn(K, model.eps()), nn_targets[:, :, i].reshape(batch_count, nn_count, 1), )[0] sigma_sqs = mm.assign(sigma_sqs, new_sigma_val, i) model.sigma_sq._set(mm.atleast_1d(sigma_sqs[i])) ret.sigma_sq._set(sigma_sqs) return ret