Source code for MuyGPyS.optimize.objective

# Copyright 2021-2022 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

"""
Objective Handling

MuyGPyS includes predefined objective functions and convenience functions for
indicating them to optimization.
"""

import numpy as np

from typing import Callable, Tuple

from MuyGPyS import config

from MuyGPyS.optimize.utils import _switch_on_opt_method, _switch_on_loss_method
from MuyGPyS.optimize.loss import get_loss_func


def make_obj_fn(
    obj_method: str, opt_method: str, loss_method: str, *args
) -> Callable:
    """
    Prepare an objective function as a function purely of the hyperparameters
    to be optimized.

    This function is designed for use with
    :func:`MuyGPyS.optimize.chassis.optimize_from_tensors()`, and the format
    depends on the `opt_method` argument.

    Args:
        obj_method:
            The name of the objective function to be minimized. Only
            `"loo_crossval"` is handled here.
        opt_method:
            The name of the optimization method to be utilized.
        loss_method:
            Indicates the loss function to be used.
        args:
            Remaining positional arguments, forwarded unchanged (after
            `opt_method` and `loss_method`) to the objective-specific
            factory. For `"loo_crossval"` that factory is
            `make_loo_crossval_fn`.

    Returns:
        A Callable `objective_fn`, whose format depends on `opt_method`.

    Raises:
        ValueError:
            If `obj_method` does not name a supported objective.
    """
    if obj_method == "loo_crossval":
        return make_loo_crossval_fn(opt_method, loss_method, *args)
    else:
        raise ValueError(f"Unsupported objective method: {obj_method}")
def make_loo_crossval_fn(
    opt_method: str,
    loss_method: str,
    loss_fn: Callable,
    kernel_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    pairwise_dists: np.ndarray,
    crosswise_dists: np.ndarray,
    batch_nn_targets: np.ndarray,
    batch_targets: np.ndarray,
) -> Callable:
    """
    Prepare a leave-one-out cross validation function as a function purely of
    the hyperparameters to be optimized.

    This function is designed for use with
    :func:`MuyGPyS.optimize.chassis.optimize_from_tensors()`, and the format
    depends on the `opt_method` argument.

    Args:
        opt_method:
            The name of the optimization method to be utilized.
        loss_method:
            Indicates the loss function to be used.
        loss_fn:
            The loss function itself. It is forwarded to the selected
            predict-and-loss factory, which applies it to the predictions
            and `batch_targets`.
        kernel_fn:
            A function that realizes kernel tensors given a list of the free
            parameters.
        mean_fn:
            A function that realizes MuyGPs posterior mean prediction given an
            epsilon value. The given value is unused if epsilon is fixed.
        var_fn:
            A function that realizes MuyGPs posterior variance prediction given
            an epsilon value. The given value is unused if epsilon is fixed.
        sigma_sq_fn:
            A function that realizes `sigma_sq` optimization given an epsilon
            value. The given value is unused if epsilon is fixed.
        pairwise_dists:
            Distance tensor of floats of shape
            `(batch_count, nn_count, nn_count)` whose second two dimensions
            give the pairwise distances between the nearest neighbors of each
            batch element.
        crosswise_dists:
            Distance matrix of floats of shape `(batch_count, nn_count)` whose
            rows give the distances between each batch element and its nearest
            neighbors.
        batch_nn_targets:
            Tensor of floats of shape `(batch_count, nn_count, response_count)`
            containing the expected response for each nearest neighbor of each
            batch element.
        batch_targets:
            Matrix of floats of shape `(batch_count, response_count)` whose
            rows give the expected response for each batch element.

    Returns:
        A Callable `objective_fn`, whose format depends on `opt_method`.
    """
    # Kernel realization: the kwargs-style factory is listed first and the
    # array-style factory second; the trailing arguments go to whichever
    # factory is chosen.
    # NOTE(review): the selection rule lives in `_switch_on_opt_method`
    # (MuyGPyS.optimize.utils) -- confirm there.
    kernels_fn = _switch_on_opt_method(
        opt_method,
        make_kwargs_kernels_fn,
        make_array_kernels_fn,
        kernel_fn,
        pairwise_dists,
        crosswise_dists,
    )
    # Prediction + loss: the first two candidates are both the raw
    # (mean-only) factory and the third is the variance-aware factory.
    # NOTE(review): which loss name maps to which slot is decided by
    # `_switch_on_loss_method` -- confirm there.
    predict_and_loss_fn = _switch_on_loss_method(
        loss_method,
        make_raw_predict_and_loss_fn,
        make_raw_predict_and_loss_fn,
        make_var_predict_and_loss_fn,
        opt_method,
        loss_fn,
        mean_fn,
        var_fn,
        sigma_sq_fn,
        batch_nn_targets,
        batch_targets,
    )
    # Compose the two pieces into the final objective, again in the calling
    # convention matching `opt_method`.
    return _switch_on_opt_method(
        opt_method,
        make_kwargs_obj_fn,
        make_array_obj_fn,
        kernels_fn,
        predict_and_loss_fn,
    )
def make_array_obj_fn(
    kernels_fn: Callable, predict_and_loss_fn: Callable
) -> Callable:
    """
    Compose kernel realization and prediction/loss into a single objective
    that takes one positional argument.

    Args:
        kernels_fn:
            Callable taking `x0` and returning the pair `(K, Kcross)`.
        predict_and_loss_fn:
            Callable taking `(K, Kcross, x0)` and returning the objective
            value.

    Returns:
        A Callable `obj_fn(x0)`.
    """

    def obj_fn(x0):
        K, Kcross = kernels_fn(x0)
        return predict_and_loss_fn(K, Kcross, x0)

    return obj_fn


def make_kwargs_obj_fn(
    kernels_fn: Callable, predict_and_loss_fn: Callable, **kwargs
) -> Callable:
    """
    Compose kernel realization and prediction/loss into a single objective
    that takes keyword arguments.

    Args:
        kernels_fn:
            Callable taking keyword arguments and returning the pair
            `(K, Kcross)`.
        predict_and_loss_fn:
            Callable taking `(K, Kcross, **kwargs)` and returning the
            objective value.
        kwargs:
            Accepted for interface compatibility and ignored; the returned
            objective only sees the keyword arguments it is called with.

    Returns:
        A Callable `obj_fn(**kwargs)`.
    """

    def obj_fn(**kwargs):
        K, Kcross = kernels_fn(**kwargs)
        return predict_and_loss_fn(K, Kcross, **kwargs)

    return obj_fn


def make_array_kernels_fn(
    kernel_fn: Callable,
    pairwise_dists: np.ndarray,
    crosswise_dists: np.ndarray,
) -> Callable:
    """
    Bind the distance tensors to `kernel_fn`, leaving a function of a single
    positional argument.

    Args:
        kernel_fn:
            Callable invoked as `kernel_fn(dists, x0)`.
        pairwise_dists:
            Distances used to realize `K`.
        crosswise_dists:
            Distances used to realize `Kcross`.

    Returns:
        A Callable `kernels_fn(x0)` returning `(K, Kcross)`.
    """

    def kernels_fn(x0):
        K = kernel_fn(pairwise_dists, x0)
        Kcross = kernel_fn(crosswise_dists, x0)
        return K, Kcross

    return kernels_fn


def make_kwargs_kernels_fn(
    kernel_fn: Callable,
    pairwise_dists: np.ndarray,
    crosswise_dists: np.ndarray,
) -> Callable:
    """
    Bind the distance tensors to `kernel_fn`, leaving a function of keyword
    arguments.

    Args:
        kernel_fn:
            Callable invoked as `kernel_fn(dists, **kwargs)`.
        pairwise_dists:
            Distances used to realize `K`.
        crosswise_dists:
            Distances used to realize `Kcross`.

    Returns:
        A Callable `kernels_fn(**kwargs)` returning `(K, Kcross)`.
    """

    def kernels_fn(**kwargs):
        K = kernel_fn(pairwise_dists, **kwargs)
        Kcross = kernel_fn(crosswise_dists, **kwargs)
        return K, Kcross

    return kernels_fn


def make_raw_predict_and_loss_fn(
    opt_method: str,
    loss_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    batch_nn_targets: np.ndarray,
    batch_targets: np.ndarray,
) -> Callable:
    """
    Build a mean-only predict-and-loss function in the calling convention
    matching `opt_method`.

    The kwargs-style factory is offered first and the array-style factory
    second; the remaining arguments are forwarded to the chosen factory.
    NOTE(review): the selection rule lives in `_switch_on_opt_method`
    (MuyGPyS.optimize.utils) -- confirm there.

    Returns:
        Whatever the selected factory returns.
    """
    return _switch_on_opt_method(
        opt_method,
        make_raw_kwargs_predict_and_loss_fn,
        make_raw_array_predict_and_loss_fn,
        loss_fn,
        mean_fn,
        var_fn,
        sigma_sq_fn,
        batch_nn_targets,
        batch_targets,
    )


def make_raw_array_predict_and_loss_fn(
    loss_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    batch_nn_targets: np.ndarray,
    batch_targets: np.ndarray,
) -> Callable:
    """
    Build a mean-only predict-and-loss function of a positional `x0`.

    `var_fn` and `sigma_sq_fn` are unused here; they are accepted so that
    this factory shares a signature with the variance-aware factories.

    Returns:
        A Callable `predict_and_loss_fn(K, Kcross, x0)` returning
        `loss_fn(predictions, batch_targets)`.
    """

    def predict_and_loss_fn(K, Kcross, x0):
        predictions = mean_fn(
            K,
            Kcross,
            batch_nn_targets,
            x0,
        )
        return loss_fn(predictions, batch_targets)

    return predict_and_loss_fn


def make_raw_kwargs_predict_and_loss_fn(
    loss_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    batch_nn_targets: np.ndarray,
    batch_targets: np.ndarray,
) -> Callable:
    """
    Build a mean-only predict-and-loss function of keyword arguments.

    `var_fn` and `sigma_sq_fn` are unused here; they are accepted so that
    this factory shares a signature with the variance-aware factories.

    Returns:
        A Callable `predict_and_loss_fn(K, Kcross, **kwargs)` returning the
        NEGATED value `-loss_fn(predictions, batch_targets)`.
    """

    def predict_and_loss_fn(K, Kcross, **kwargs):
        predictions = mean_fn(
            K,
            Kcross,
            batch_nn_targets,
            **kwargs,
        )
        # The sign flip is deliberate and differs from the array variant.
        # NOTE(review): presumably the kwargs-style optimizer maximizes its
        # objective -- confirm against the optimizer in use.
        return -loss_fn(predictions, batch_targets)

    return predict_and_loss_fn


def make_var_predict_and_loss_fn(
    opt_method: str,
    loss_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    batch_nn_targets: np.ndarray,
    batch_targets: np.ndarray,
) -> Callable:
    """
    Build a variance-aware predict-and-loss function in the calling
    convention matching `opt_method`.

    The kwargs-style factory is offered first and the array-style factory
    second; the remaining arguments are forwarded to the chosen factory.
    NOTE(review): the selection rule lives in `_switch_on_opt_method`
    (MuyGPyS.optimize.utils) -- confirm there.

    Returns:
        Whatever the selected factory returns.
    """
    return _switch_on_opt_method(
        opt_method,
        make_var_kwargs_predict_and_loss_fn,
        make_var_array_predict_and_loss_fn,
        loss_fn,
        mean_fn,
        var_fn,
        sigma_sq_fn,
        batch_nn_targets,
        batch_targets,
    )


def make_var_array_predict_and_loss_fn(
    loss_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    batch_nn_targets: np.ndarray,
    batch_targets: np.ndarray,
) -> Callable:
    """
    Build a variance-aware predict-and-loss function of a positional `x0`.

    Returns:
        A Callable `predict_and_loss_fn(K, Kcross, x0)` returning
        `loss_fn(predictions, batch_targets, variances, sigma_sq)`.
    """

    def predict_and_loss_fn(K, Kcross, x0):
        predictions = mean_fn(
            K,
            Kcross,
            batch_nn_targets,
            x0,
        )
        sigma_sq = sigma_sq_fn(K, batch_nn_targets, x0)
        variances = var_fn(K, Kcross, x0)
        return loss_fn(predictions, batch_targets, variances, sigma_sq)

    return predict_and_loss_fn


def make_var_kwargs_predict_and_loss_fn(
    loss_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    batch_nn_targets: np.ndarray,
    batch_targets: np.ndarray,
) -> Callable:
    """
    Build a variance-aware predict-and-loss function of keyword arguments.

    Returns:
        A Callable `predict_and_loss_fn(K, Kcross, **kwargs)` returning the
        NEGATED value
        `-loss_fn(predictions, batch_targets, variances, sigma_sq)`.
    """

    def predict_and_loss_fn(K, Kcross, **kwargs):
        predictions = mean_fn(
            K,
            Kcross,
            batch_nn_targets,
            **kwargs,
        )
        sigma_sq = sigma_sq_fn(K, batch_nn_targets, **kwargs)
        variances = var_fn(K, Kcross, **kwargs)
        # The sign flip is deliberate and differs from the array variant.
        # NOTE(review): presumably the kwargs-style optimizer maximizes its
        # objective -- confirm against the optimizer in use.
        return -loss_fn(predictions, batch_targets, variances, sigma_sq)

    return predict_and_loss_fn