Source code for MuyGPyS.optimize.objective

# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

"""
Objective Handling

MuyGPyS includes predefined objective functions and convenience functions for
selecting them for use in optimization.
"""

from typing import Callable, Dict, Optional

import MuyGPyS._src.math as mm
from MuyGPyS.optimize.utils import _switch_on_loss_method


def make_obj_fn(obj_method: str, loss_method: str, *args, **kwargs) -> Callable:
    """
    Prepare an objective function as a function purely of the hyperparameters
    to be optimized.

    This function is designed for use with
    :func:`MuyGPyS.optimize.chassis.optimize_from_tensors()`.

    Args:
        obj_method:
            The name of the objective to construct. Currently only
            `"loo_crossval"` is supported.
        loss_method:
            Indicates the loss function to be used.
        args:
            Additional positional arguments forwarded unchanged to the
            objective factory (for `"loo_crossval"`, to
            :func:`make_loo_crossval_fn`, after `loss_method`).
        kwargs:
            Additional keyword arguments forwarded unchanged to the same
            objective factory.

    Returns:
        The Callable `objective_fn` produced by the selected objective
        factory.

    Raises:
        ValueError:
            If `obj_method` is not a supported objective method.
    """
    if obj_method == "loo_crossval":
        return make_loo_crossval_fn(loss_method, *args, **kwargs)
    raise ValueError(f"Unsupported objective method: {obj_method}")
def make_loo_crossval_fn(
    loss_method: str,
    loss_fn: Callable,
    kernel_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    pairwise_diffs: mm.ndarray,
    crosswise_diffs: mm.ndarray,
    batch_nn_targets: mm.ndarray,
    batch_targets: mm.ndarray,
    batch_features: Optional[mm.ndarray] = None,
    loss_kwargs: Optional[Dict] = None,
) -> Callable:
    """
    Prepare a leave-one-out cross validation function as a function purely of
    the hyperparameters to be optimized.

    This function is designed for use with
    :func:`MuyGPyS.optimize.chassis.optimize_from_tensors()`.

    Args:
        loss_method:
            Indicates the loss function to be used. It selects, via
            `_switch_on_loss_method`, which predict-and-loss factory is used.
        loss_fn:
            The loss function to be evaluated on the predictions.
        kernel_fn:
            A function that realizes kernel tensors given a list of the free
            parameters.
        mean_fn:
            A function that realizes MuyGPs posterior mean prediction given an
            epsilon value. The given value is unused if epsilon is fixed.
        var_fn:
            A function that realizes MuyGPs posterior variance prediction
            given an epsilon value. The given value is unused if epsilon is
            fixed.
        sigma_sq_fn:
            A function that realizes `sigma_sq` optimization given an epsilon
            value. The given value is unused if epsilon is fixed.
        pairwise_diffs:
            A tensor of shape `(batch_count, nn_count, nn_count, feature_count)`
            containing the `(nn_count, nn_count, feature_count)`-shaped
            pairwise nearest neighbor difference tensors corresponding to each
            of the batch elements.
        crosswise_diffs:
            A tensor of shape `(batch_count, nn_count, feature_count)` whose
            last two dimensions list the difference between each feature of
            each batch element and its nearest neighbors.
        batch_nn_targets:
            Tensor of floats of shape `(batch_count, nn_count, response_count)`
            containing the expected response for each nearest neighbor of each
            batch element.
        batch_targets:
            Matrix of floats of shape `(batch_count, response_count)` whose
            rows give the expected response for each batch element.
        batch_features:
            Matrix of floats of shape `(batch_count, feature_count)` whose rows
            give the features for each batch element. It is forwarded to the
            kernels function as the `batch_features` keyword argument.
        loss_kwargs:
            A dict listing any additional kwargs to pass to the loss function.
            `None` (the default) means no additional kwargs.

    Returns:
        A Callable `obj_fn`. It accepts the hyperparameters as positional
        and/or keyword arguments, realizes the kernel tensors `K` and `Kcross`
        from them, and returns the value of the selected predict-and-loss
        function.
    """
    # Use `None` instead of a mutable `dict()` default so that no dict object
    # is shared across calls, and so an explicit `None` is also accepted.
    if loss_kwargs is None:
        loss_kwargs = dict()

    kernels_fn = make_kernels_fn(kernel_fn, pairwise_diffs, crosswise_diffs)
    predict_and_loss_fn = _switch_on_loss_method(
        loss_method,
        make_raw_predict_and_loss_fn,
        make_raw_predict_and_loss_fn,
        make_var_predict_and_loss_fn,
        make_raw_predict_and_loss_fn,
        make_var_predict_and_loss_fn,
        loss_fn,
        mean_fn,
        var_fn,
        sigma_sq_fn,
        batch_nn_targets,
        batch_targets,
        **loss_kwargs,
    )

    def obj_fn(*args, **kwargs):
        K, Kcross = kernels_fn(*args, batch_features=batch_features, **kwargs)
        return predict_and_loss_fn(K, Kcross, *args, **kwargs)

    return obj_fn
def make_kernels_fn(
    kernel_fn: Callable,
    pairwise_diffs: mm.ndarray,
    crosswise_diffs: mm.ndarray,
) -> Callable:
    """
    Bind a kernel function to fixed pairwise and crosswise difference tensors.

    Args:
        kernel_fn:
            A function invoked as `kernel_fn(diffs, *args, **kwargs)`.
        pairwise_diffs:
            The difference tensor from which the nearest neighbor kernel `K`
            is realized.
        crosswise_diffs:
            The difference tensor from which the cross kernel `Kcross` is
            realized.

    Returns:
        A Callable that forwards all of its arguments to `kernel_fn`. It first
        calls `kernel_fn` with `pairwise_diffs`, then with `crosswise_diffs`,
        and returns the pair `(K, Kcross)`.
    """

    def _realize_kernels(*args, **kwargs):
        nn_kernel = kernel_fn(pairwise_diffs, *args, **kwargs)
        cross_kernel = kernel_fn(crosswise_diffs, *args, **kwargs)
        return nn_kernel, cross_kernel

    return _realize_kernels


def make_raw_predict_and_loss_fn(
    loss_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    batch_nn_targets: mm.ndarray,
    batch_targets: mm.ndarray,
    **loss_kwargs,
) -> Callable:
    """
    Build a closure that scores posterior mean predictions with a loss.

    Args:
        loss_fn:
            Invoked as `loss_fn(predictions, batch_targets, **loss_kwargs)`.
        mean_fn:
            Invoked as `mean_fn(K, Kcross, batch_nn_targets, **kwargs)` to
            produce the predictions.
        var_fn:
            Unused here; the parameter list mirrors
            `make_var_predict_and_loss_fn`.
        sigma_sq_fn:
            Unused here; the parameter list mirrors
            `make_var_predict_and_loss_fn`.
        batch_nn_targets:
            Nearest neighbor targets, passed to `mean_fn`.
        batch_targets:
            Expected responses, passed to `loss_fn`.
        loss_kwargs:
            Extra keyword arguments forwarded to `loss_fn`.

    Returns:
        A Callable taking `(K, Kcross, *args, **kwargs)` that returns the
        negated loss. Extra positional arguments are ignored. Keyword
        arguments are forwarded to `mean_fn`.
    """

    def _score_mean(K, Kcross, *args, **kwargs):
        mean_predictions = mean_fn(K, Kcross, batch_nn_targets, **kwargs)
        loss_value = loss_fn(mean_predictions, batch_targets, **loss_kwargs)
        # NOTE(review): the loss is negated, presumably because the caller
        # maximizes this value -- confirm against the optimizer.
        return -loss_value

    return _score_mean


def make_var_predict_and_loss_fn(
    loss_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    batch_nn_targets: mm.ndarray,
    batch_targets: mm.ndarray,
    **loss_kwargs,
) -> Callable:
    """
    Build a closure that scores posterior means and variances with a loss.

    Args:
        loss_fn:
            Invoked as
            `loss_fn(predictions, batch_targets, variances, sigma_sq,
            **loss_kwargs)`.
        mean_fn:
            Invoked as `mean_fn(K, Kcross, batch_nn_targets, **kwargs)` to
            produce the predictions.
        var_fn:
            Invoked as `var_fn(K, Kcross, **kwargs)` to produce the variances.
        sigma_sq_fn:
            Invoked as `sigma_sq_fn(K, batch_nn_targets, **kwargs)` to produce
            `sigma_sq`.
        batch_nn_targets:
            Nearest neighbor targets, passed to `mean_fn` and `sigma_sq_fn`.
        batch_targets:
            Expected responses, passed to `loss_fn`.
        loss_kwargs:
            Extra keyword arguments forwarded to `loss_fn`.

    Returns:
        A Callable taking `(K, Kcross, *args, **kwargs)` that returns the
        negated loss. Extra positional arguments are ignored. Keyword
        arguments are forwarded to `mean_fn`, `sigma_sq_fn` and `var_fn`.
    """

    def _score_mean_and_variance(K, Kcross, *args, **kwargs):
        # Evaluation order: mean, then sigma_sq, then variance.
        mean_predictions = mean_fn(K, Kcross, batch_nn_targets, **kwargs)
        scale = sigma_sq_fn(K, batch_nn_targets, **kwargs)
        posterior_variances = var_fn(K, Kcross, **kwargs)
        loss_value = loss_fn(
            mean_predictions,
            batch_targets,
            posterior_variances,
            scale,
            **loss_kwargs,
        )
        # NOTE(review): negated for the same (presumed) reason as in
        # make_raw_predict_and_loss_fn -- confirm against the optimizer.
        return -loss_value

    return _score_mean_and_variance