Source code for MuyGPyS.optimize.objective

# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
# SPDX-License-Identifier: MIT

Objective Handling

MuyGPyS includes predefined objective functions and convenience functions for
preparing them for use in optimization.

from typing import Callable, Dict, Optional

import MuyGPyS._src.math as mm

from MuyGPyS.optimize.loss import LossFn

def make_loo_crossval_fn(
    loss_fn: LossFn,
    kernel_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    scale_fn: Callable,
    pairwise_diffs: mm.ndarray,
    crosswise_diffs: mm.ndarray,
    batch_nn_targets: mm.ndarray,
    batch_targets: mm.ndarray,
    batch_features: Optional[mm.ndarray] = None,
    target_mask: Optional[mm.ndarray] = None,
    loss_kwargs: Optional[Dict] = None,
) -> Callable:
    """
    Prepare a leave-one-out cross validation function as a function purely of
    the hyperparameters to be optimized.

    This function is designed for use with
    :class:`MuyGPyS.optimize.chassis.OptimizeFn`.

    Args:
        loss_fn:
            The loss functor used to evaluate model performance. Its
            `make_predict_and_loss_fn` method is used to build the
            prediction-and-loss stage of the objective.
        kernel_fn:
            A function that realizes kernel tensors given a list of the free
            parameters.
        mean_fn:
            A function that realizes MuyGPs posterior mean prediction given a
            noise model.
        var_fn:
            A function that realizes MuyGPs posterior variance prediction given
            a noise model.
        scale_fn:
            A function that realizes variance scale parameter optimization
            given a noise model.
        pairwise_diffs:
            A tensor of shape
            `(batch_count, nn_count, nn_count, feature_count)` containing the
            `(nn_count, nn_count, feature_count)`-shaped pairwise nearest
            neighbor difference tensors corresponding to each of the batch
            elements.
        crosswise_diffs:
            A tensor of shape `(batch_count, nn_count, feature_count)` whose
            last two dimensions list the difference between each feature of
            each batch element and its nearest neighbors.
        batch_nn_targets:
            Tensor of floats of shape `(batch_count, nn_count, response_count)`
            containing the expected response for each nearest neighbor of each
            batch element.
        batch_targets:
            Matrix of floats of shape `(batch_count, response_count)` whose
            rows give the expected response for each batch element.
        batch_features:
            Matrix of floats of shape `(batch_count, feature_count)` whose rows
            give the features for each batch element. It is forwarded to
            `kernel_fn` as the keyword argument `batch_features` on every
            evaluation of the objective.
        target_mask:
            An array of indices, listing the output dimensions of the
            prediction to be used for optimization.
        loss_kwargs:
            A dict listing any additional kwargs to pass to the loss function.
            `None` (the default) is treated as an empty dict.

    Returns:
        A Callable `objective_fn`.
    """
    # Use a None sentinel rather than a mutable default so that no dict object
    # is shared between separate invocations of this function.
    if loss_kwargs is None:
        loss_kwargs = {}

    kernels_fn = make_kernels_fn(kernel_fn, pairwise_diffs, crosswise_diffs)

    # This is ad-hoc, and might need to be revisited.
    predict_and_loss_fn = loss_fn.make_predict_and_loss_fn(
        mean_fn,
        var_fn,
        scale_fn,
        batch_nn_targets,
        batch_targets,
        target_mask=target_mask,
        **loss_kwargs,
    )

    def obj_fn(*args, **kwargs):
        # The kernels additionally receive `batch_features`; the
        # prediction-and-loss stage receives only the caller's arguments.
        Kin, Kcross = kernels_fn(*args, batch_features=batch_features, **kwargs)
        return predict_and_loss_fn(Kin, Kcross, *args, **kwargs)

    return obj_fn
def make_kernels_fn(
    kernel_fn: Callable,
    pairwise_diffs: mm.ndarray,
    crosswise_diffs: mm.ndarray,
) -> Callable:
    """
    Bind a pair of difference tensors to a kernel function.

    Args:
        kernel_fn:
            A callable taking a difference tensor as its first positional
            argument, followed by any further positional and keyword
            arguments.
        pairwise_diffs:
            The difference tensor passed to `kernel_fn` to produce the first
            returned kernel.
        crosswise_diffs:
            The difference tensor passed to `kernel_fn` to produce the second
            returned kernel.

    Returns:
        A Callable that forwards all of its positional and keyword arguments to
        `kernel_fn` twice - first with `pairwise_diffs`, then with
        `crosswise_diffs` - and returns the two results as a tuple
        `(Kin, Kcross)`.
    """

    def evaluate_kernels(*params, **named_params):
        # Tuple elements are evaluated left to right, so the pairwise kernel is
        # always computed before the crosswise kernel.
        return (
            kernel_fn(pairwise_diffs, *params, **named_params),
            kernel_fn(crosswise_diffs, *params, **named_params),
        )

    return evaluate_kernels