Source code for pytorch_utils.utils

from dataclasses import asdict, dataclass, field
from typing import Mapping, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torchmetrics

from pytorch_utils.exceptions import (
    NotMonotone,
    NotNonDecreasing,
    NotNonIncreasing,
)
from pytorch_utils.logging.loggers import (
    Logger,
    VoidLogger,
    use_loggers,
)

NamedTorchTensors = Mapping[str, torch.Tensor]
NamedTorchModules = Mapping[str, nn.Module]
NamedTorchMetrics = Mapping[str, torchmetrics.Metric]
BatchTorchTensors = Tuple[NamedTorchTensors, torch.Tensor, Optional[torch.Tensor]]
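
# Illustrative sketch (an assumption, not part of the original module): one way a
# batch matching ``BatchTorchTensors`` might be assembled. The interpretation of the
# tuple elements as (named feature tensors, targets, optional sample weights) and the
# feature names "age" and "city" are assumptions for illustration only.
def _example_batch(batch_size: int = 4) -> BatchTorchTensors:
    features: NamedTorchTensors = {
        "age": torch.randn(batch_size),                 # hypothetical numerical feature
        "city": torch.randint(0, 120, (batch_size,)),   # hypothetical categorical feature
    }
    targets = torch.randn(batch_size)
    sample_weights = None  # the third element is optional
    return features, targets, sample_weights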


@dataclass(frozen=True, order=True)
class CategoricalFeatureEmbedding:
    """
    Frozen collection of the parameters describing one categorical feature
    embedding; grouping them simplifies constructor signatures.
    """

    feature_name: str
    nb_distinct_values: int
    embedding_size: int
    logger: Logger = field(default=VoidLogger(), repr=False, compare=False)

    def __post_init__(self):
        # Check attributes
        # ------------------------
        if self.nb_distinct_values < 0:
            raise ValueError("nb_distinct_values must be non-negative")
        if self.embedding_size < 0:
            raise ValueError("embedding_size must be non-negative")

        # Make this class comply with the Loggable protocol
        object.__setattr__(self, "loggers", {"logger": self.logger})

    @use_loggers("logger")
    def log(self, logger: Logger) -> None:
        params = asdict(self)
        del params["logger"]
        logger.log_params(params)
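
# Illustrative usage sketch (an assumption, not part of the original module):
# describing a hypothetical "city" feature with 120 distinct values and turning the
# specification into an ``nn.Embedding`` layer built from its fields.
def _example_categorical_embedding() -> nn.Embedding:
    spec = CategoricalFeatureEmbedding(
        feature_name="city",
        nb_distinct_values=120,
        embedding_size=8,
    )
    # The frozen dataclass only stores the parameters; the layer is created here.
    return nn.Embedding(spec.nb_distinct_values, spec.embedding_size)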
def get_embedding_size(
    nb_categories: int,
    multiplicative_factor: float = 1.6,
    power_exponent: float = 0.56,
    max_size: int = 600,
) -> int:
    """
    Determine an empirically good embedding size (formula taken from fastai:
    https://docs.fast.ai/tabular.model.html).

    Parameters
    ----------
    nb_categories: int
        number of categories
    multiplicative_factor: (float, optional)
        multiplicative constant in the formula. Defaults to 1.6.
    power_exponent: (float, optional)
        exponent applied to ``nb_categories``. Defaults to 0.56.
    max_size: (int, optional)
        maximum embedding size. Defaults to 600.

    Returns
    -------
    int
        embedding size
    """
    if nb_categories > 2:
        return min(round(multiplicative_factor * nb_categories**power_exponent), max_size)
    else:
        return 1
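
# A minimal sanity check of the fastai heuristic (illustrative values, assumed):
# the size grows sub-linearly with the number of categories and is capped at
# ``max_size`` for very large cardinalities.
def _example_embedding_sizes() -> None:
    assert get_embedding_size(2) == 1            # two or fewer categories get size 1
    assert get_embedding_size(10) == 6           # round(1.6 * 10 ** 0.56)
    assert get_embedding_size(1000) == 77        # round(1.6 * 1000 ** 0.56)
    assert get_embedding_size(10**6) == 600      # capped at the default max_size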
def assert_monotone(
    inputs: np.ndarray,
    outputs: np.ndarray,
    non_decreasing: Optional[bool] = None,
    error_message: str = "",
    tol: float = 1e-5,
) -> None:
    """
    Assert that ``outputs`` is a monotone function of ``inputs``.

    ``inputs`` should be a 1-dimensional array. ``outputs`` should be a 1 or
    2-dimensional array; if 2-dimensional, every row is tested to be a monotone
    function of the inputs.

    If ``non_decreasing`` is ``None``, an exception is raised only if the mapping is
    neither non-decreasing nor non-increasing. If ``non_decreasing`` is ``True``, an
    exception is raised only if the mapping is not non-decreasing. If
    ``non_decreasing`` is ``False``, an exception is raised only if the mapping is
    not non-increasing.
    """
    if len(inputs.shape) != 1:
        raise ValueError("inputs should be a 1-dimensional array")
    if len(outputs.shape) not in [1, 2]:
        raise ValueError("outputs should be a 1 or 2-dimensional array")
    if inputs.shape[0] != outputs.shape[-1]:
        raise ValueError("the last dimension of outputs should match the dimension of inputs")

    argsort_indices = np.argsort(inputs)
    # Differences along the last axis, so 1-dimensional outputs are handled as well.
    diffs = np.diff(outputs[..., argsort_indices], n=1, axis=-1)
    is_non_decreasing = np.all(diffs >= -tol)
    is_non_increasing = np.all(diffs <= tol)

    if (non_decreasing is None) and not (is_non_decreasing or is_non_increasing):
        raise NotMonotone(error_message)
    elif (non_decreasing is True) and not is_non_decreasing:
        raise NotNonDecreasing(error_message)
    elif (non_decreasing is False) and not is_non_increasing:
        raise NotNonIncreasing(error_message)
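
# Illustrative usage sketch (assumed data, not part of the original module):
# ``assert_monotone`` sorts the outputs by the inputs before testing, so the points
# do not need to be passed in order.
def _example_assert_monotone() -> None:
    x = np.array([3.0, 0.0, 1.0, 2.0])  # deliberately unsorted inputs
    assert_monotone(x, x**2, non_decreasing=True)       # passes: x**2 is non-decreasing for x >= 0
    assert_monotone(x, -(x**2), non_decreasing=False)   # passes: mirrored mapping is non-increasing

    wiggly = np.array([2.0, 0.0, 3.0, 1.0])  # neither non-decreasing nor non-increasing
    try:
        assert_monotone(x, wiggly, error_message="mapping is not monotone")
    except NotMonotone:
        pass  # expected: the mapping changes direction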