from typing import Any, Literal, Optional, Tuple, Union

import torch
from sklearn.calibration import CalibrationDisplay
from torchmetrics import MeanAbsoluteError, MeanSquaredError
from torchmetrics.classification import BinaryCalibrationError
from torchmetrics.utilities.data import dim_zero_cat

class WeightedMeanSquaredError(MeanSquaredError):
    """
    Analogue of `torchmetrics.MeanSquaredError` but with (optional) sample weights.

    The metric accumulates the weighted sum of squared errors and the sum of
    the weights (or the element count when no weights are given). The final
    reduction is inherited from `MeanSquaredError.compute`.

    Args:
        squared: Stored as ``self.squared`` for the inherited ``compute``
            (presumably ``False`` yields the root of the mean squared error
            -- confirm against the installed torchmetrics version).
        **kwargs: Forwarded to the base metric initializer.
    """

    def __init__(
        self,
        squared: bool = True,
        **kwargs: Any,
    ) -> None:
        # Skips ``MeanSquaredError.__init__`` on purpose and calls the next
        # initializer in the MRO, so the states below are the only ones
        # registered. ``total`` is a float tensor because it may hold a sum
        # of (fractional) sample weights.
        super(MeanSquaredError, self).__init__(**kwargs)
        self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.squared = squared

    def update(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
        sample_weights: Optional[torch.Tensor] = None,
    ) -> None:  # type: ignore
        """Update state with predictions and targets.

        Args:
            preds: Predicted values.
            target: Ground-truth values.
            sample_weights: Optional weights multiplied element-wise with the
                squared errors. Assumed to be broadcastable against
                ``preds - target`` -- verify against callers.
        """
        # ``is_floating_point`` is a method and must be called; integer (or
        # bool) inputs are promoted so the subtraction cannot wrap or fail.
        preds = preds if preds.is_floating_point() else preds.float()
        target = target if target.is_floating_point() else target.float()

        diff = preds - target
        squared_error = diff * diff

        if sample_weights is None:
            # Unweighted: every element counts once.
            self.sum_squared_error += torch.sum(squared_error)
            self.total += target.numel()
        else:
            # Weighted: normalise by the total weight instead of the count.
            self.sum_squared_error += torch.sum(squared_error * sample_weights)
            self.total += torch.sum(sample_weights)
class WeightedMeanAbsoluteError(MeanAbsoluteError):
    """
    Analogue of `torchmetrics.MeanAbsoluteError` but with (optional) sample weights.

    The metric accumulates the weighted sum of absolute errors and the sum of
    the weights (or the element count when no weights are given). The final
    reduction is inherited from `MeanAbsoluteError.compute`.

    Args:
        **kwargs: Forwarded to ``MeanAbsoluteError.__init__``.
    """

    def __init__(
        self,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        # Register the states again after the parent initializer so that both
        # start as float tensors; ``total`` may hold a sum of fractional
        # sample weights.
        self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
        sample_weights: Optional[torch.Tensor] = None,
    ) -> None:  # type: ignore
        """Update state with predictions and targets.

        Args:
            preds: Predicted values.
            target: Ground-truth values.
            sample_weights: Optional weights multiplied element-wise with the
                absolute errors. Assumed to be broadcastable against
                ``preds - target`` -- verify against callers.
        """
        # ``is_floating_point`` is a method and must be called; integer (or
        # bool) inputs are promoted so the subtraction cannot wrap or fail.
        preds = preds if preds.is_floating_point() else preds.float()
        target = target if target.is_floating_point() else target.float()

        abs_error = torch.abs(preds - target)

        if sample_weights is None:
            # Unweighted: every element counts once.
            self.sum_abs_error += torch.sum(abs_error)
            self.total += target.numel()
        else:
            # Weighted: normalise by the total weight instead of the count.
            self.sum_abs_error += torch.sum(abs_error * sample_weights)
            self.total += torch.sum(sample_weights)
class WeightedBinaryCalibrationError(BinaryCalibrationError):
    """
    Analogue of `torchmetrics.classification.BinaryCalibrationError` but with
    (optional) sample weights.

    Confidences and accuracies are accumulated by the parent class; this
    subclass additionally accumulates per-sample weights and uses them when
    binning in ``compute``.

    Args:
        n_bins: Number of equal-width confidence bins on ``[0, 1]``.
        norm: Norm used to aggregate the per-bin gaps: "l1", "l2" or "max".
        ignore_index: Forwarded to the parent class.
        validate_args: Forwarded to the parent class.
        **kwargs: Forwarded to the parent class.
    """

    def __init__(
        self,
        n_bins: int = 15,
        norm: Literal["l1", "l2", "max"] = "l1",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            n_bins=n_bins,
            norm=norm,
            ignore_index=ignore_index,
            validate_args=validate_args,
            **kwargs,
        )
        # One tensor of weights is appended per ``update`` call that passes weights.
        self.add_state("weights", [], dist_reduce_fx="cat")

    def update(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> None:  # type: ignore
        """Update state with predictions, targets and optional sample weights.

        Args:
            preds: Predictions, forwarded to the parent ``update``.
            target: Targets, forwarded to the parent ``update``.
            weights: Optional per-sample weights.

        NOTE(review): weights are stored exactly as passed. They are assumed
        to line up one-to-one with the confidences the parent class stores
        (which may not hold if the parent drops samples via ``ignore_index``),
        and callers should pass weights on every update or on none -- confirm.
        """
        super().update(preds=preds, target=target)
        if weights is not None:
            self.weights.append(weights)  # type: ignore

    def compute(self) -> torch.Tensor:
        """Compute the (weighted) calibration error from the accumulated state."""
        confidences = dim_zero_cat(self.confidences)  # type: ignore
        accuracies = dim_zero_cat(self.accuracies)  # type: ignore
        # ``len`` works for the list state as well as for a tensor state
        # (assumed form after a distributed "cat" reduction), whereas
        # truth-testing a multi-element tensor raises.
        has_weights = len(self.weights) > 0  # type: ignore
        weights = dim_zero_cat(self.weights) if has_weights else None  # type: ignore
        return self._ce_compute(
            confidences=confidences,
            accuracies=accuracies,
            weights=weights,
            bin_boundaries=self.n_bins,
            norm=self.norm,
        )

    def _ce_compute(
        self,
        confidences: torch.Tensor,
        accuracies: torch.Tensor,
        weights: Optional[torch.Tensor],
        bin_boundaries: Union[torch.Tensor, int],
        norm: str = "l1",
        debias: bool = False,
    ) -> torch.Tensor:
        """
        Analogue of `torchmetrics.functional.classification.calibration_error._ce_compute` but with weights.

        Computes the calibration error given the provided bin boundaries and norm.

        Args:
            confidences: The confidence (i.e. predicted prob) of the top1 prediction.
            accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
            weights: Optional per-sample weights; ``None`` weighs all samples equally.
            bin_boundaries: Bin boundaries separating the ``linspace`` from 0 to 1,
                or an int giving the number of equal-width bins.
            norm: Norm function to use when computing calibration error. Defaults to "l1".
            debias: Apply debiasing to L2 norm computation as in
                `Verified Uncertainty Calibration`_. Defaults to False.

        Raises:
            ValueError: If an unsupported norm function is provided.

        Returns:
            Tensor: Calibration error scalar.
        """
        if isinstance(bin_boundaries, int):
            bin_boundaries = torch.linspace(
                0, 1, bin_boundaries + 1, dtype=torch.float, device=confidences.device
            )

        if norm not in {"l1", "l2", "max"}:
            raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ")  # nosec B608

        with torch.no_grad():
            acc_bin, conf_bin, prop_bin = self._binning_bucketize(
                confidences, accuracies, weights, bin_boundaries
            )

        if norm == "l1":
            ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin)
        elif norm == "max":
            ce = torch.max(torch.abs(acc_bin - conf_bin))
        elif norm == "l2":
            ce = torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin)
            # NOTE: debiasing is disabled in the wrapper functions. This implementation
            # differs from that in sklearn. The correction below uses the raw sample
            # count and is therefore not weight-aware.
            if debias:
                # the order here (acc_bin - 1 ) vs (1 - acc_bin) is flipped from
                # the equation in Verified Uncertainty Prediction (Kumar et al 2019)/
                debias_bins = (acc_bin * (acc_bin - 1) * prop_bin) / (
                    prop_bin * accuracies.size()[0] - 1
                )
                # replace nans with zeros if nothing appeared in a bin
                ce += torch.sum(torch.nan_to_num(debias_bins))
            # Keep the zero fallback on the same device/dtype as the result.
            ce = torch.sqrt(ce) if ce > 0 else torch.zeros_like(ce)
        return ce

    def _binning_bucketize(
        self,
        confidences: torch.Tensor,
        accuracies: torch.Tensor,
        weights: Optional[torch.Tensor],
        bin_boundaries: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Analogue of `torchmetrics.functional.classification.calibration_error._binning_bucketize` but with weights.

        Compute calibration bins using ``torch.bucketize``. Use for pytorch >= 1.6.

        Args:
            confidences: The confidence (i.e. predicted prob) of the top1 prediction.
            accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
            weights: Optional per-sample weights; ``None`` weighs all samples equally.
            bin_boundaries: Bin boundaries separating the ``linspace`` from 0 to 1.

        Returns:
            tuple with binned accuracy, binned confidence and binned probabilities
            (each bin's share of the total weight).

        Side effects:
            Stores the three results on ``self`` (``acc_bin_backup``,
            ``conf_bin_backup``, ``prop_bin``) for use by ``calibration_curve``.
        """
        n_bins = len(bin_boundaries) - 1

        # ``scatter_add_`` requires source and destination dtypes to match, so
        # bring accuracies and weights to the dtype of the confidences.
        accuracies = accuracies.to(dtype=confidences.dtype)
        if weights is None:
            weights = torch.ones_like(confidences)
        else:
            weights = weights.to(device=confidences.device, dtype=confidences.dtype)

        acc_bin = torch.zeros(n_bins, device=confidences.device, dtype=confidences.dtype)
        conf_bin = torch.zeros_like(acc_bin)
        weights_bin = torch.zeros_like(acc_bin)

        # Map each confidence to its bin; clamp so that values on the outer
        # boundaries (e.g. exactly 1.0) fall into the first/last bin.
        indices = torch.clamp(
            torch.bucketize(confidences, bin_boundaries, right=True) - 1,
            min=0,
            max=n_bins - 1,
        )

        # Total weight per bin.
        weights_bin.scatter_add_(dim=0, index=indices, src=weights)

        # Weighted mean confidence per bin; empty bins (0 / 0) become 0.
        conf_bin.scatter_add_(dim=0, index=indices, src=confidences * weights)
        conf_bin = torch.nan_to_num(conf_bin / weights_bin)

        # Weighted mean accuracy per bin; empty bins (0 / 0) become 0.
        acc_bin.scatter_add_(dim=0, index=indices, src=accuracies * weights)
        acc_bin = torch.nan_to_num(acc_bin / weights_bin)

        # Share of the total weight that falls into each bin.
        prop_bin = weights_bin / weights_bin.sum()

        # Backup acc_bin, conf_bin, prop_bin to be used in calibration_curve method
        self.acc_bin_backup = acc_bin
        self.conf_bin_backup = conf_bin
        self.prop_bin = prop_bin

        return acc_bin, conf_bin, prop_bin

    def calibration_curve(self, estimator_name: str, pos_label: int = 1) -> CalibrationDisplay:
        """Plot the calibration curve from the bins of the last ``compute`` call.

        Only non-empty bins (positive weight share) are plotted. The bin
        statistics are stored during ``compute``, so ``compute`` must have
        been called before this method.

        Args:
            estimator_name: Name forwarded to ``CalibrationDisplay``.
            pos_label: Positive class label forwarded to ``CalibrationDisplay``.

        Returns:
            The ``figure_`` attribute of the plotted ``CalibrationDisplay``.
            NOTE(review): this is the figure, not the display object the
            return annotation names -- confirm which one callers expect.
        """
        non_empty = self.prop_bin > 0
        display = CalibrationDisplay(
            prob_true=self.acc_bin_backup[non_empty].numpy(force=True),
            prob_pred=self.conf_bin_backup[non_empty].numpy(force=True),
            y_prob=None,
            estimator_name=estimator_name,
            pos_label=pos_label,
        )
        return display.plot().figure_