Source code for MuyGPyS.optimize.objective

# Copyright 2021 Lawrence Livermore National Security, LLC and other MuyGPyS
# Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

"""Objective and Loss Function Handling

MuyGPyS includes predefined loss functions and convenience functions for 
indicating them to optimization.
"""

import numpy as np

from typing import Callable

from scipy.special import softmax
from sklearn.metrics import log_loss


[docs]def get_loss_func(loss_method: str) -> Callable: """ Select a loss function based upon string key. Currently supports strings `"log"` or `"cross-entropy"` for :func:`MuyGPyS.optimize.objective.cross_entropy_fn` and `"mse"` for :func:`MuyGPyS.optimize.objective.mse_fn`. Args: predictions: The predicted response of shape `(batch_count, response_count)`. targets: The expected response of shape `(batch_count, response_count)`. Returns: The loss function Callable. Raises: NotImplementedError: Unrecognized strings will result in an error. """ loss_method = loss_method.lower() if loss_method == "cross-entropy" or loss_method == "log": return cross_entropy_fn elif loss_method == "mse": return mse_fn else: raise NotImplementedError( f"Loss function {loss_method} is not implemented." )
[docs]def cross_entropy_fn( predictions: np.ndarray, targets: np.ndarray, ) -> float: """ Cross entropy function. Computes the cross entropy loss the predicted versus known response. Transforms `predictions` to be row-stochastic, and ensures that `targets` contains no negative elements. Args: predictions: The predicted response of shape `(batch_count, response_count)`. targets: The expected response of shape `(batch_count, response_count)`. Returns: The cross-entropy loss of the prediction. """ one_hot_targets = np.zeros(targets.shape) one_hot_targets[targets > 0.0] = 1.0 return log_loss( one_hot_targets, softmax(predictions, axis=1), eps=1e-6, normalize=False )
[docs]def mse_fn( predictions: np.ndarray, targets: np.ndarray, ) -> float: """ Mean squared error function. Computes mean squared error loss of the predicted versus known response. Treats multivariate outputs as interchangeable in terms of loss penalty. Args: predictions: The predicted response of shape `(batch_count, response_count)`. targets: The expected response of shape `(batch_count, response_count)`. Returns: The mse loss of the prediction. """ batch_count = predictions.shape[0] response_count = predictions.shape[1] squared_errors = np.sum((predictions - targets) ** 2) return squared_errors / (batch_count * response_count)
[docs]def loo_crossval( x0: np.ndarray, objective_fn: Callable, kernel_fn: Callable, predict_fn: Callable, pairwise_dists: np.ndarray, crosswise_dists: np.ndarray, batch_nn_targets: np.ndarray, batch_targets: np.ndarray, ) -> float: """ Leave-one-out cross validation. Returns leave-one-out cross validation performance for a set `MuyGPS` object. Predicts on all of the training data at once. Args: x0: Current guess for hyperparameter values of shape `(opt_count,)`. objective_fn: The function to be optimized. Can be any function that accepts two `numpy.ndarray` objects indicating the prediction and target values, in that order. kernel_fn: A function that realizes kernel tensors given a list of the free parameters. predict_fn: A function that realizes MuyGPs prediction given an epsilon value. The given value is unused if epsilon is fixed. pairwise_dists: Distance tensor of floats of shape `(batch_count, nn_count, nn_count)` whose second two dimensions give the pairwise distances between the nearest neighbors of each batch element. crosswise_dists: Distance matrix of floats of shape `(batch_count, nn_count)` whose rows give the distances between each batch element and its nearest neighbors. batch_nn_targets: Tensor of floats of shape `(batch_count, nn_count, response_count)` containing the expected response for each nearest neighbor of each batch element. batch_targets: Matrix of floats of shape `(batch_count, response_count)` whose rows give the expected response for each batch element. Returns: The evaluation of `objective_fn` on the predicted versus expected response. """ K = kernel_fn(pairwise_dists, x0) Kcross = kernel_fn(crosswise_dists, x0) predictions = predict_fn( K, Kcross, batch_nn_targets, x0[-1], ) return objective_fn(predictions, batch_targets)
def make_loo_crossval_fn( loss_fn: Callable, kernel_fn: Callable, predict_fn: Callable, pairwise_dists: np.ndarray, crosswise_dists: np.ndarray, batch_nn_targets: np.ndarray, batch_targets: np.ndarray, ) -> Callable: def caller_fn(x0): return loo_crossval( x0, loss_fn, kernel_fn, predict_fn, pairwise_dists, crosswise_dists, batch_nn_targets, batch_targets, ) return caller_fn