Source code for MuyGPyS.optimize.loss

# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

"""
Loss Function Handling

MuyGPyS includes predefined loss functions and convenience functions for
indicating them to optimization.
"""

from typing import Callable

import MuyGPyS._src.math as mm
from MuyGPyS._src.optimize.loss import (
    _mse_fn,
    _cross_entropy_fn,
    _lool_fn,
    _lool_fn_unscaled,
    _pseudo_huber_fn,
    _looph_fn,
)


[docs]def make_raw_predict_and_loss_fn( loss_fn: Callable, mean_fn: Callable, var_fn: Callable, scale_fn: Callable, batch_nn_targets: mm.ndarray, batch_targets: mm.ndarray, **loss_kwargs, ) -> Callable: """ Make a predict_and_loss function that depends only on the posterior mean. Assembles a new function with signature `(K, Kcross, *args, **kwargs)` that computes the posterior mean and uses the passed `loss_fn` to score it against the batch targets. Args: loss_fn: A loss function Callable with signature `(predictions, responses, **kwargs)`, where `predictions` and `targets` are matrices of shape `(batch_count, response_count)`. mean_fn: A MuyGPS posterior mean function Callable with signature `(K, Kcross, batch_nn_targets)`, which are tensors of shape `(batch_count, nn_count, nn_count)`, `(batch_count, nn_count)`, and `(batch_count, nn_count, response_count)`, respectively. var_fn: A MuyGPS posterior variance function Callable with signature `(K, Kcross)`, which are tensors of shape `(batch_count, nn_count, nn_count)` and `(batch_count, nn_count)`, respectively. Unused by this function, but still required by the signature. scale_fn: A MuyGPS `scale` optimization function Callable with signature `(K, batch_nn_targets)`, which are tensors of shape `(batch_count, nn_count, nn_count)` and `(batch_count, nn_count, response_count)`, respectively. Unused by this function, but still required by the signature. batch_nn_targets: A tensor of shape `(batch_count, nn_count, response_count)` containing the expected response of the nearest neighbors of each batch element. batch_targets: A matrix of shape `(batch_count, response_count)` containing the expected response of each batch element. loss_kwargs: Additionall keyword arguments used by the loss function. Returns: A Callable with signature `(K, Kcross, *args, **kwargs) -> float` that computes the posterior mean and applies the loss function to it and the `batch_targets`. """ def predict_and_loss_fn(K, Kcross, *args, **kwargs): predictions = mean_fn( K, Kcross, batch_nn_targets, **kwargs, ) return -loss_fn(predictions, batch_targets, **loss_kwargs) return predict_and_loss_fn
[docs]def make_var_predict_and_loss_fn( loss_fn: Callable, mean_fn: Callable, var_fn: Callable, scale_fn: Callable, batch_nn_targets: mm.ndarray, batch_targets: mm.ndarray, **loss_kwargs, ) -> Callable: """ Make a predict_and_loss function that depends on the posterior mean and variance. Assembles a new function with signature `(K, Kcross, *args, **kwargs)` that computes the posterior mean and variance and uses the passed `loss_fn` to score them against the batch targets. Args: loss_fn: A loss function Callable with signature `(predictions, responses, **kwargs)`, where `predictions` and `targets` are matrices of shape `(batch_count, response_count)`. mean_fn: A MuyGPS posterior mean function Callable with signature `(K, Kcross, batch_nn_targets)`, which are tensors of shape `(batch_count, nn_count, nn_count)`, `(batch_count, nn_count)`, and `(batch_count, nn_count, response_count)`, respectively. var_fn: A MuyGPS posterior variance function Callable with signature `(K, Kcross)`, which are tensors of shape `(batch_count, nn_count, nn_count)` and `(batch_count, nn_count)`, respectively. scale_fn: A MuyGPS `scale` optimization function Callable with signature `(K, batch_nn_targets)`, which are tensors of shape `(batch_count, nn_count, nn_count)` and `(batch_count, nn_count, response_count)`, respectively. batch_nn_targets: A tensor of shape `(batch_count, nn_count, response_count)` containing the expected response of the nearest neighbors of each batch element. batch_targets: A matrix of shape `(batch_count, response_count)` containing the expected response of each batch element. loss_kwargs: Additionall keyword arguments used by the loss function. Returns: A Callable with signature `(K, Kcross, *args, **kwargs) -> float` that computes the posterior mean and applies the loss function to it and the `batch_targets`. """ def predict_and_loss_fn(K, Kcross, *args, **kwargs): predictions = mean_fn( K, Kcross, batch_nn_targets, **kwargs, ) scale = scale_fn(K, batch_nn_targets, **kwargs) variances = var_fn(K, Kcross, **kwargs) return -loss_fn( predictions, batch_targets, variances, scale, **loss_kwargs ) return predict_and_loss_fn
[docs]class LossFn: """ Loss functor class. MuyGPyS-compatible loss functions are objects of this class. Creating a new loss function is as simple as instantiation a new `LossFn` object. Args: loss_fn: A Callable with signature `(predictions, targets, **kwargs)` or `(predictions, targets, variances, scale, **kwargs)` tha computes a floating-point loss score of a set of predictions given posterior means and possibly posterior variances. Individual loss functions can implement different `kwargs` as needed. make_precit_and_loss_fn: A Callable with signature `(loss_fn, mean_fn, var_fn, scale_fn, batch_nn_targets, batch_targets, **loss_kwargs)` that produces a function that computes posterior predictions and scores them using the loss function. :func:~MuyGPyS.optimize.loss._make_raw_predict_and_loss_fn` and :func:~MuyGPyS.optimize.loss._make_var_predict_and_loss_fn` are two candidates. """ def __init__(self, loss_fn: Callable, make_predict_and_loss_fn: Callable): self._fn = loss_fn self._make_predict_and_loss_fn = make_predict_and_loss_fn def __call__(self, *args, **kwargs): return self._fn(*args, **kwargs) def make_predict_and_loss_fn(self, *args, **kwargs): return self._make_predict_and_loss_fn(self._fn, *args, **kwargs)
cross_entropy_fn = LossFn(_cross_entropy_fn, make_raw_predict_and_loss_fn) """ Cross entropy function. Computes the cross entropy loss the predicted versus known response. Transforms `predictions` to be row-stochastic, and ensures that `targets` contains no negative elements. Only defined for two or more labels. For a sample with true labels :math:`y_i \\in \\{0, 1\\}` and estimates :math:`\\bar{\\mu(x_i)} = \\textrm{Pr}(y = 1)`, the function computes .. math:: \\ell_\\textrm{cross-entropy}(\\bar{\\mu}, y) = \\sum_{i=1}^{b} y_i \\log(\\bar{\\mu}_i) - (1 - y_i) \\log(1 - \\bar{\\mu}_i). The numpy backend uses `sklearn's implementation <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html>`_. Args: predictions: The predicted response of shape `(batch_count, response_count)`. targets: The expected response of shape `(batch_count, response_count)`. eps: Probabilities are clipped to the range `[eps, 1 - eps]`. Returns: The cross-entropy loss of the prediction. """ mse_fn = LossFn(_mse_fn, make_raw_predict_and_loss_fn) """ Mean squared error function. Computes mean squared error loss of the predicted versus known response. Treats multivariate outputs as interchangeable in terms of loss penalty. The function computes .. math:: \\ell_\\textrm{MSE}(\\bar{\\mu}, y) = \\frac{1}{b} \\sum_{i=1}^b (\\bar{\\mu}_i - y)^2. Args: predictions: The predicted response of shape `(batch_count, response_count)`. targets: The expected response of shape `(batch_count, response_count)`. Returns: The mse loss of the prediction. """ lool_fn = LossFn(_lool_fn, make_var_predict_and_loss_fn) """ Leave-one-out likelihood function. Computes leave-one-out likelihood (LOOL) loss of the predicted versus known response. Treats multivariate outputs as interchangeable in terms of loss penalty. The function computes .. math:: \\ell_\\textrm{lool}(\\bar{\\mu}, y \\mid \\bar{\\Sigma}) = \\sum_{i=1}^b \\sum_{j=1}^s \\left ( \\frac{\\bar{\\mu}_i - y}{\\bar{\\Sigma}_{ii}} \\right )_j^2 + \\left ( \\log \\bar{\\Sigma}_{ii} \\right )_j Args: predictions: The predicted response of shape `(batch_count, response_count)`. targets: The expected response of shape `(batch_count, response_count)`. variances: The unscaled variance of the predicted responses of shape `(batch_count, response_count)`. scale: The scale variance scaling parameter of shape `(response_count,)`. Returns: The LOOL loss of the prediction. """ lool_fn_unscaled = LossFn(_lool_fn_unscaled, make_var_predict_and_loss_fn) """ Leave-one-out likelihood function. Computes leave-one-out likelihood (LOOL) loss of the predicted versus known response. Treats multivariate outputs as interchangeable in terms of loss penalty. Unlike lool_fn, does not require scale as an argument. The function computes .. math:: \\ell_\\textrm{lool}(\\bar{\\mu}, y \\mid \\bar{\\Sigma}) = \\sum_{i=1}^b \\frac{(\\bar{\\mu}_i - y)^2}{\\bar{\\Sigma}_{ii}} + \\log \\bar{\\Sigma}_{ii}. Args: predictions: The predicted response of shape `(batch_count, response_count)`. targets: The expected response of shape `(batch_count, response_count)`. variances: The unscaled variance of the predicted responses of shape `(batch_count, response_count)`. Returns: The LOOL loss of the prediction. """ pseudo_huber_fn = LossFn(_pseudo_huber_fn, make_raw_predict_and_loss_fn) """ Pseudo-Huber loss function. Computes a smooth approximation to the Huber loss function, which balances sensitive squared-error loss for relatively small errors and robust-to-outliers absolute loss for larger errors, so that the loss is not overly sensitive to outliers. Uses the form from `wikipedia <https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function>`_. The function computes .. math:: \\ell_\\textrm{Pseudo-Huber}(\\bar{\\mu}, y \\mid \\delta) = \\sum_{i=1}^b \\delta^2 \\left ( \\sqrt{ 1 + \\left ( \\frac{y_i - \\bar{\\mu}_i}{\\delta} \\right )^2 } - 1 \\right ). Args: predictions: The predicted response of shape `(batch_count, response_count)`. targets: The expected response of shape `(batch_count, response_count)`. boundary_scale: The boundary value for the residual beyond which the loss becomes approximately linear. Useful values depend on the scale of the response. Returns: The sum of pseudo-Huber losses of the predictions. """ looph_fn = LossFn(_looph_fn, make_var_predict_and_loss_fn) """ Variance-regularized pseudo-Huber loss function. Computes a smooth approximation to the Huber loss function, similar to :func:`pseudo_huber_fn`, with the addition of both a variance scaling and a additive logarithmic variance regularization term to avoid exploding the variance. The function computes .. math:: \\ell_\\textrm{lool}(\\bar{\\mu}, y \\mid \\delta, \\bar{\\Sigma}) = \\sum_{i=1}^b \\sum_{i=1}^s \\delta^2 \\left ( \\sqrt{ 1 + \\left ( \\frac{y_i - \\bar{\\mu}_i}{\\delta \\bar{\\Sigma}_{ii}} \\right )_j^2 } - 1 \\right ) + \\left ( \\log \\bar{\\Sigma}_{ii} \\right )_j. Args: predictions: The predicted response of shape `(batch_count, response_count)`. targets: The expected response of shape `(batch_count, response_count)`. variances: The unscaled variance of the predicted responses of shape `(batch_count, response_count)`. scale: The scale variance scaling parameter of shape `(response_count,)`. boundary_scale: The boundary value for the residual beyond which the loss becomes approximately linear. Useful values depend on the scale of the response. Returns: The sum of leave-one-out pseudo-Huber losses of the predictions. """