# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT
"""
Loss Function Handling
MuyGPyS includes predefined loss functions and convenience functions for
selecting them for use in optimization.
"""
from typing import Callable
import MuyGPyS._src.math as mm
from MuyGPyS._src.optimize.loss import (
_mse_fn,
_cross_entropy_fn,
_lool_fn,
_lool_fn_unscaled,
_pseudo_huber_fn,
_looph_fn,
)
from MuyGPyS.optimize.utils import _switch_on_loss_method
def get_loss_func(loss_method: str) -> Callable:
    """
    Select a loss function based upon string key.

    Currently supports strings `"log"` or `"cross-entropy"` for
    :func:`cross_entropy_fn` and `"mse"` for :func:`mse_fn`. The selector can
    also return :func:`lool_fn`, :func:`pseudo_huber_fn`, and
    :func:`looph_fn`; the string keys that map to those functions are defined
    by `_switch_on_loss_method`.

    Args:
        loss_method:
            The string key indicating which loss function to return.

    Returns:
        The loss function Callable.

    Raises:
        NotImplementedError:
            Unrecognized strings will result in an error.
    """
    # Each candidate is wrapped in a zero-argument lambda so that the switch
    # helper can invoke only the branch that matches `loss_method`. The
    # positional order must match the order `_switch_on_loss_method` expects.
    return _switch_on_loss_method(
        loss_method,
        lambda: cross_entropy_fn,
        lambda: mse_fn,
        lambda: lool_fn,
        lambda: pseudo_huber_fn,
        lambda: looph_fn,
    )
def cross_entropy_fn(
    predictions: mm.ndarray,
    targets: mm.ndarray,
    *,
    ll_eps: float = 1e-6,
) -> float:
    """
    Cross entropy function.

    Computes the cross entropy loss of the predicted versus known response.
    Transforms `predictions` to be row-stochastic, and ensures that `targets`
    contains no negative elements.

    @NOTE[bwp] I don't remember why we hard-coded eps=1e-6. Might need to
    revisit.

    The epsilon is exposed as the keyword-only argument `ll_eps`. Its default
    is the previously hard-coded value, so existing two-argument calls behave
    exactly as before.

    Args:
        predictions:
            The predicted response of shape `(batch_count, response_count)`.
        targets:
            The expected response of shape `(batch_count, response_count)`.
        ll_eps:
            Epsilon forwarded unchanged to the backend implementation as its
            `ll_eps` keyword. Presumably a tolerance that keeps the logarithm
            finite; confirm against `_cross_entropy_fn`. Keyword-only so that
            a stray third positional argument (e.g. variances intended for a
            different loss function) raises instead of being silently
            interpreted as an epsilon.

    Returns:
        The cross-entropy loss of the prediction.
    """
    return _cross_entropy_fn(predictions, targets, ll_eps=ll_eps)
def mse_fn(
    predictions: mm.ndarray,
    targets: mm.ndarray,
) -> float:
    """
    Mean squared error function.

    Computes mean squared error loss of the predicted versus known response.
    Treats multivariate outputs as interchangeable in terms of loss penalty.
    The function computes

    .. math::
        l(f(x), y) = \\frac{1}{b} \\sum_{i=1}^b (f(x_i) - y_i)^2

    Args:
        predictions:
            The predicted response of shape `(batch_count, response_count)`.
        targets:
            The expected response of shape `(batch_count, response_count)`.

    Returns:
        The mse loss of the prediction.
    """
    return _mse_fn(predictions, targets)
def lool_fn(
    predictions: mm.ndarray,
    targets: mm.ndarray,
    variances: mm.ndarray,
    sigma_sq: mm.ndarray,
) -> float:
    """
    Leave-one-out likelihood function.

    Evaluates the leave-one-out likelihood (LOOL) loss of a set of predictions
    against the known responses. Multivariate outputs are treated as
    interchangeable in terms of loss penalty. The function computes

    .. math::
        l(f(x), y \\mid \\sigma) = \\sum_{i=1}^b \\sum_{j=1}^s
        \\frac{(f(x_i) - y_i)^2}{\\sigma_j} + \\log \\sigma_j

    Args:
        predictions:
            The predicted response of shape `(batch_count, response_count)`.
        targets:
            The expected response of shape `(batch_count, response_count)`.
        variances:
            The unscaled variance of the predicted responses of shape
            `(batch_count, response_count)`.
        sigma_sq:
            The sigma_sq variance scaling parameter of shape
            `(response_count,)`.

    Returns:
        The LOOL loss of the prediction.
    """
    # All of the numerical work is delegated to the backend implementation.
    lool_loss = _lool_fn(predictions, targets, variances, sigma_sq)
    return lool_loss
def lool_fn_unscaled(
    predictions: mm.ndarray, targets: mm.ndarray, variances: mm.ndarray
) -> float:
    """
    Leave-one-out likelihood function without variance scaling.

    Evaluates the leave-one-out likelihood (LOOL) loss of a set of predictions
    against the known responses. Multivariate outputs are treated as
    interchangeable in terms of loss penalty. Unlike :func:`lool_fn`, this
    variant does not take a `sigma_sq` argument. The function computes

    .. math::
        l(f(x), y \\mid \\sigma) = \\sum_{i=1}^b
        \\frac{(f(x_i) - y_i)^2}{\\sigma} + \\log \\sigma

    Args:
        predictions:
            The predicted response of shape `(batch_count, response_count)`.
        targets:
            The expected response of shape `(batch_count, response_count)`.
        variances:
            The unscaled variance of the predicted responses of shape
            `(batch_count, response_count)`.

    Returns:
        The LOOL loss of the prediction.
    """
    # All of the numerical work is delegated to the backend implementation.
    unscaled_loss = _lool_fn_unscaled(predictions, targets, variances)
    return unscaled_loss
def pseudo_huber_fn(
    predictions: mm.ndarray, targets: mm.ndarray, boundary_scale: float = 1.5
) -> float:
    """
    Pseudo-Huber loss function.

    Computes a smooth approximation to the Huber loss function, which balances
    sensitive squared-error loss for relatively small errors and
    robust-to-outliers absolute loss for larger errors, so that the loss is not
    overly sensitive to outliers. Uses the form from
    `Wikipedia <https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function>`_.
    The function computes

    .. math::
        l(f(x), y \\mid \\delta) = \\delta^2 \\sum_{i=1}^b \\left ( \\sqrt{
        1 + \\left ( \\frac{y_i - f(x_i)}{\\delta} \\right )^2
        } - 1 \\right )

    Args:
        predictions:
            The predicted response of shape `(batch_count, response_count)`.
        targets:
            The expected response of shape `(batch_count, response_count)`.
        boundary_scale:
            The boundary value for the residual beyond which the loss becomes
            approximately linear. Useful values depend on the scale of the
            response.

    Returns:
        The sum of pseudo-Huber losses of the predictions.
    """
    return _pseudo_huber_fn(predictions, targets, boundary_scale=boundary_scale)
def looph_fn(
    predictions: mm.ndarray,
    targets: mm.ndarray,
    variances: mm.ndarray,
    sigma_sq: mm.ndarray,
    boundary_scale: float = 1.5,
) -> float:
    """
    Variance-regularized pseudo-Huber loss function.

    Computes a smooth approximation to the Huber loss function, similar to
    :func:`pseudo_huber_fn`, with the addition of both a variance scaling and
    an additive logarithmic variance regularization term to avoid exploding
    the variance. The function computes

    .. math::
        l(f(x), y \\mid \\delta) = \\delta^2 \\sum_{i=1}^b \\left ( \\sqrt{
        1 + \\left ( \\frac{y_i - f(x_i)}{\\sigma_i \\delta} \\right )^2
        } - 1 \\right ) + \\log \\sigma_i

    Args:
        predictions:
            The predicted response of shape `(batch_count, response_count)`.
        targets:
            The expected response of shape `(batch_count, response_count)`.
        variances:
            The unscaled variance of the predicted responses of shape
            `(batch_count, response_count)`.
        sigma_sq:
            The sigma_sq variance scaling parameter of shape
            `(response_count,)`.
        boundary_scale:
            The boundary value for the residual beyond which the loss becomes
            approximately linear. Useful values depend on the scale of the
            response.

    Returns:
        The sum of variance-regularized pseudo-Huber losses of the
        predictions.
    """
    return _looph_fn(
        predictions, targets, variances, sigma_sq, boundary_scale=boundary_scale
    )