# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT
"""
Objective Handling
MuyGPyS includes predefined objective functions and convenience functions for
indicating them to optimization.
"""
from typing import Callable, Dict, Optional
import MuyGPyS._src.math as mm
from MuyGPyS.optimize.utils import _switch_on_loss_method
def make_obj_fn(obj_method: str, loss_method: str, *args, **kwargs) -> Callable:
    """
    Build the objective function selected by `obj_method`.

    The resulting callable is a function purely of the hyperparameters to be
    optimized. The only objective recognized here is `"loo_crossval"`, whose
    construction is delegated to :func:`make_loo_crossval_fn`.

    Args:
        obj_method:
            The name of the objective function. Only `"loo_crossval"` is
            supported.
        loss_method:
            Indicates the loss function to be used. Passed as the first
            argument to :func:`make_loo_crossval_fn`.
        *args:
            Remaining positional arguments, forwarded unchanged to
            :func:`make_loo_crossval_fn`.
        **kwargs:
            Remaining keyword arguments, forwarded unchanged to
            :func:`make_loo_crossval_fn`.

    Returns:
        The Callable `objective_fn` produced by :func:`make_loo_crossval_fn`.

    Raises:
        ValueError:
            If `obj_method` is anything other than `"loo_crossval"`.
    """
    # Guard clause: reject everything except the single supported objective.
    if obj_method != "loo_crossval":
        raise ValueError(f"Unsupported objective method: {obj_method}")
    return make_loo_crossval_fn(loss_method, *args, **kwargs)
def make_loo_crossval_fn(
    loss_method: str,
    loss_fn: Callable,
    kernel_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    pairwise_diffs: mm.ndarray,
    crosswise_diffs: mm.ndarray,
    batch_nn_targets: mm.ndarray,
    batch_targets: mm.ndarray,
    batch_features: Optional[mm.ndarray] = None,
    loss_kwargs: Optional[Dict] = None,
) -> Callable:
    """
    Prepare a leave-one-out cross validation function as a function purely of
    the hyperparameters to be optimized.

    The returned callable realizes the kernel tensors from the stored
    difference tensors, then evaluates the predict-and-loss function that was
    selected according to `loss_method`.

    Args:
        loss_method:
            Indicates the loss function to be used.
        loss_fn:
            The loss function to evaluate. It is handed to the
            predict-and-loss factory selected by `loss_method`.
        kernel_fn:
            A function that realizes kernel tensors given a list of the free
            parameters.
        mean_fn:
            A function that realizes MuyGPs posterior mean prediction given an
            epsilon value. The given value is unused if epsilon is fixed.
        var_fn:
            A function that realizes MuyGPs posterior variance prediction given
            an epsilon value. The given value is unused if epsilon is fixed.
        sigma_sq_fn:
            A function that realizes `sigma_sq` optimization given an epsilon
            value. The given value is unused if epsilon is fixed.
        pairwise_diffs:
            A tensor of shape `(batch_count, nn_count, nn_count, feature_count)`
            containing the `(nn_count, nn_count, feature_count)`-shaped pairwise
            nearest neighbor difference tensors corresponding to each of the
            batch elements.
        crosswise_diffs:
            A tensor of shape `(batch_count, nn_count, feature_count)` whose
            last two dimensions list the difference between each feature of each
            batch element and its nearest neighbors.
        batch_nn_targets:
            Tensor of floats of shape `(batch_count, nn_count, response_count)`
            containing the expected response for each nearest neighbor of each
            batch element.
        batch_targets:
            Matrix of floats of shape `(batch_count, response_count)` whose rows
            give the expected response for each batch element.
        batch_features:
            Matrix of floats of shape `(batch_count, feature_count)` whose rows
            give the features for each batch element. Passed to the kernel
            function as the `batch_features` keyword argument.
        loss_kwargs:
            A dict listing any additional kwargs to pass to the loss function.
            `None` (the default) is treated as an empty dict.

    Returns:
        A Callable `objective_fn` that accepts the hyperparameters as
        positional and/or keyword arguments and returns the value computed by
        the selected predict-and-loss function.
    """
    # A `None` sentinel replaces the former `dict()` default so that no dict
    # instance is shared between calls.
    if loss_kwargs is None:
        loss_kwargs = {}

    kernels_fn = make_kernels_fn(kernel_fn, pairwise_diffs, crosswise_diffs)

    # NOTE(review): the five factories are positional; presumably each slot
    # corresponds to one loss method recognized by `_switch_on_loss_method` --
    # confirm the slot order against MuyGPyS.optimize.utils before reordering.
    predict_and_loss_fn = _switch_on_loss_method(
        loss_method,
        make_raw_predict_and_loss_fn,
        make_raw_predict_and_loss_fn,
        make_var_predict_and_loss_fn,
        make_raw_predict_and_loss_fn,
        make_var_predict_and_loss_fn,
        loss_fn,
        mean_fn,
        var_fn,
        sigma_sq_fn,
        batch_nn_targets,
        batch_targets,
        **loss_kwargs,
    )

    def obj_fn(*args, **kwargs):
        # Realize both kernel tensors for the current hyperparameters, then
        # hand them to the prediction/loss stage with the same arguments.
        K, Kcross = kernels_fn(*args, batch_features=batch_features, **kwargs)
        return predict_and_loss_fn(K, Kcross, *args, **kwargs)

    return obj_fn
def make_kernels_fn(
    kernel_fn: Callable,
    pairwise_diffs: mm.ndarray,
    crosswise_diffs: mm.ndarray,
) -> Callable:
    """
    Bind a kernel function to fixed pairwise and crosswise difference tensors.

    Args:
        kernel_fn:
            The function invoked to realize each kernel tensor. It receives a
            difference tensor as its first argument, followed by whatever
            arguments are given to the returned callable.
        pairwise_diffs:
            The difference tensor used for the first `kernel_fn` call.
        crosswise_diffs:
            The difference tensor used for the second `kernel_fn` call.

    Returns:
        A Callable `kernels_fn(*args, **kwargs)` returning the tuple
        `(K, Kcross)`, where `K` is `kernel_fn` applied to `pairwise_diffs`
        and `Kcross` is `kernel_fn` applied to `crosswise_diffs`.
    """

    def kernels_fn(*args, **kwargs):
        # Tuple elements evaluate left to right, so the pairwise kernel is
        # always realized before the crosswise one.
        return (
            kernel_fn(pairwise_diffs, *args, **kwargs),
            kernel_fn(crosswise_diffs, *args, **kwargs),
        )

    return kernels_fn
def make_raw_predict_and_loss_fn(
    loss_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    batch_nn_targets: mm.ndarray,
    batch_targets: mm.ndarray,
    **loss_kwargs,
) -> Callable:
    """
    Build a predict-and-loss function that uses mean predictions only.

    Args:
        loss_fn:
            The loss function, called as
            `loss_fn(predictions, batch_targets, **loss_kwargs)`.
        mean_fn:
            The function producing predictions, called as
            `mean_fn(K, Kcross, batch_nn_targets, **kwargs)`.
        var_fn:
            Unused by this factory; accepted so that the signature matches
            :func:`make_var_predict_and_loss_fn`.
        sigma_sq_fn:
            Unused by this factory; accepted so that the signature matches
            :func:`make_var_predict_and_loss_fn`.
        batch_nn_targets:
            The nearest neighbor targets passed to `mean_fn`.
        batch_targets:
            The targets that the predictions are compared against.
        **loss_kwargs:
            Additional keyword arguments forwarded to `loss_fn`.

    Returns:
        A Callable `predict_and_loss_fn(K, Kcross, *args, **kwargs)` returning
        the negated value of `loss_fn`.
    """

    def predict_and_loss_fn(K, Kcross, *args, **kwargs):
        # Extra positional arguments are accepted but deliberately not
        # forwarded; only keyword arguments reach `mean_fn`.
        mean_prediction = mean_fn(K, Kcross, batch_nn_targets, **kwargs)
        loss_value = loss_fn(mean_prediction, batch_targets, **loss_kwargs)
        return -loss_value

    return predict_and_loss_fn
def make_var_predict_and_loss_fn(
    loss_fn: Callable,
    mean_fn: Callable,
    var_fn: Callable,
    sigma_sq_fn: Callable,
    batch_nn_targets: mm.ndarray,
    batch_targets: mm.ndarray,
    **loss_kwargs,
) -> Callable:
    """
    Build a predict-and-loss function that uses means, variances and
    `sigma_sq`.

    Args:
        loss_fn:
            The loss function, called as
            `loss_fn(predictions, batch_targets, variances, sigma_sq,
            **loss_kwargs)`.
        mean_fn:
            The function producing predictions, called as
            `mean_fn(K, Kcross, batch_nn_targets, **kwargs)`.
        var_fn:
            The function producing variances, called as
            `var_fn(K, Kcross, **kwargs)`.
        sigma_sq_fn:
            The function producing `sigma_sq`, called as
            `sigma_sq_fn(K, batch_nn_targets, **kwargs)`.
        batch_nn_targets:
            The nearest neighbor targets passed to `mean_fn` and
            `sigma_sq_fn`.
        batch_targets:
            The targets that the predictions are compared against.
        **loss_kwargs:
            Additional keyword arguments forwarded to `loss_fn`.

    Returns:
        A Callable `predict_and_loss_fn(K, Kcross, *args, **kwargs)` returning
        the negated value of `loss_fn`.
    """

    def predict_and_loss_fn(K, Kcross, *args, **kwargs):
        # Call order is mean, then sigma_sq, then variance. Extra positional
        # arguments are accepted but not forwarded to any of them.
        posterior_means = mean_fn(K, Kcross, batch_nn_targets, **kwargs)
        sigma_sq_value = sigma_sq_fn(K, batch_nn_targets, **kwargs)
        posterior_vars = var_fn(K, Kcross, **kwargs)
        loss_value = loss_fn(
            posterior_means,
            batch_targets,
            posterior_vars,
            sigma_sq_value,
            **loss_kwargs,
        )
        return -loss_value

    return predict_and_loss_fn