Source code for jax_utils.common_tensor

"""Wrappers around numpy/jax arrays"""

from __future__ import annotations

from numbers import Number
from typing import (
    AbstractSet,
    Any,
    ClassVar,
    Dict,
    Hashable,
    Iterable,
    Optional,
    Protocol,
    Tuple,
    TypeAlias,
    TypeVar,
    Union,
    runtime_checkable,
)

import jax.numpy as jnp
import numpy as np
from ordered_set import OrderedSet, OrderedSetInitializer
from typing_extensions import Self

from jax_utils.typing import DataclassInstance

# Any array handled by this module: either a JAX array or a numpy array.
Array: TypeAlias = Union[jnp.ndarray, np.ndarray]
# A plain Python number (float or int).
Scalar: TypeAlias = Union[float, int]


def check_ndim_in(array: Array, allowed_ndims: Iterable[int]) -> None:
    """Check that the number of dimensions of ``array`` is one of the allowed values.

    Args:
        array (Array): numpy or JAX array
        allowed_ndims (Iterable[int]): numbers of dimensions allowed for ``array``

    Raises:
        ValueError: when the number of dimensions of ``array`` is not present in
            ``allowed_ndims``.
    """
    # Materialise once: a one-shot iterator would otherwise be consumed by the
    # membership test and then show up as an opaque object in the error message.
    allowed = tuple(allowed_ndims)
    if array.ndim not in allowed:
        raise ValueError(
            f"Number of dimensions of array should be in {allowed}, got {array.ndim}"
        )
def is_broadcastable(shape_1: Tuple[int, ...], shape_2: Tuple[int, ...]) -> bool:
    """Tell whether two array/tensor shapes are compatible for
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Args:
        shape_1 (Tuple[int, ...]): shape of the 1st array
        shape_2 (Tuple[int, ...]): shape of the 2nd array

    Returns:
        bool: ``True`` if and only if the 2 shapes can be broadcasted together,
        ``False`` otherwise.
    """
    # Shapes are compared from the trailing dimension backwards; ``zip`` stops at
    # the shorter shape, so extra leading dimensions are never compared.
    trailing_pairs = zip(reversed(shape_1), reversed(shape_2))
    return all(
        size_a == size_b or size_a == 1 or size_b == 1 for size_a, size_b in trailing_pairs
    )
# Type of the elements (axis "names") held by a ``TensorAxes``; must be hashable.
T = TypeVar("T", bound=Hashable)
class TensorAxes(OrderedSet[T]):
    """An ordered set of "named" array/tensor axes.

    The position in the set defines the axis index in the tensor.
    The first ``len(self) - tensor_min_dim`` axes are optional: if absent, they are
    assumed to be "flattened". The reason for not making the last ``tensor_min_dim``
    axes optional instead of the first ``len(self) - tensor_min_dim`` axes is to
    follow the logic of array broadcasting (first dimension can be omitted).

    Args:
        initial (Optional[OrderedSetInitializer[T]]): the axes, in tensor order
        tensor_min_dim (int): number of trailing axes that are mandatory; must lie
            between 0 and the number of axes (both included). Stored as attribute
            ``tensor_min_nb_dim``.

    Raises:
        ValueError: when ``tensor_min_dim`` is negative or exceeds the number of axes
    """

    def __init__(
        self,
        initial: Optional[OrderedSetInitializer[T]] = None,
        tensor_min_dim: int = 0,
    ):
        super().__init__(initial)  # type: ignore
        # The bound is checked after the set is built so that duplicated axes in
        # ``initial`` are not counted twice.
        if tensor_min_dim > len(self):
            raise ValueError(
                "tensor_min_dim should be smaller than or equal to the number of specified axes"
            )
        if tensor_min_dim < 0:
            raise ValueError("tensor_min_dim should be greater than or equal to 0")
        self.tensor_min_nb_dim = tensor_min_dim

    def reverse_index(self, key: T) -> int:
        """Computes the index of element ``key`` but rather than return a non-negative
        integer like method ``index``, it returns a negative integer e.g., -1 for the
        last element, -2 for the penultimate element, ...

        Args:
            key (T): any element present in ``self`` (which is an ordered set)

        Returns:
            int: negative integer corresponding to the position of ``key`` in ``self``
            (starting from the last and decrementing by 1 for each element)
        """
        return -(len(self) - self.index(key))

    @property
    def _first_mandatory_idx(self) -> int:
        """Slice boundary between the optional axes (before) and the mandatory axes
        (from it onwards).

        It is a negative offset from the end when at least one axis is mandatory,
        and ``len(self)`` otherwise so that the mandatory slice is empty.
        """
        return -self.tensor_min_nb_dim if self.tensor_min_nb_dim > 0 else len(self)

    @property
    def mandatory(self) -> OrderedSet[T]:
        """
        Returns:
            OrderedSet[T]: ordered set of non-optional axes.
        """
        return OrderedSet(list(self[self._first_mandatory_idx :]))

    @property
    def optional(self) -> OrderedSet[T]:
        """
        Returns:
            OrderedSet[T]: ordered set of optional axes.
        """
        return OrderedSet(list(self[: self._first_mandatory_idx]))

    def __repr__(self) -> str:
        # One line per axis, each prefixed by its negative index; optional axes are
        # listed first and flagged as such.
        return type(self).__name__ + "\n| ".join(
            [""]
            + [f"{self.reverse_index(axis)} (optional): {axis}" for axis in self.optional]
            + [f"{self.reverse_index(axis)}: {axis}" for axis in self.mandatory]
        )
# Concrete array type (numpy or JAX) used by generic tensor code.
ArrayType = TypeVar("ArrayType", bound=Array)
# Type of the axis "names" of a tensor; must be hashable.
AxisType = TypeVar("AxisType", bound=Hashable)
def expand_dims_axis(
    tensor_axes: TensorAxes[AxisType], missing_axes: AbstractSet[AxisType]
) -> Tuple[int, ...]:
    """Build the ``axis`` argument of ``numpy.expand_dims(a, axis)`` for an array ``a``
    in which some of the axes listed in ``tensor_axes`` are absent ("flattened").

    Args:
        tensor_axes (TensorAxes[AxisType]): some array/tensor axes
        missing_axes (AbstractSet[AxisType]): missing axes in the array/tensor

    Raises:
        ValueError: when ``missing_axes`` is not a subset of ``tensor_axes``

    Returns:
        Tuple[int, ...]: the ``axis`` argument to pass to ``numpy.expand_dims(a, axis)``,
        expressed with negative indices
    """
    if not set(missing_axes).issubset(tensor_axes):
        invalid_names = ", ".join(str(axis) for axis in missing_axes - tensor_axes)
        valid_names = ", ".join(str(axis) for axis in tensor_axes)
        raise ValueError(
            f"The following missing axes are invalid: {invalid_names}. "
            f"Only the following axes are valid: {valid_names}"
        )

    # Visit the axes from the last one to the first one and keep the negative
    # position of every axis reported as missing.
    axis_positions = []
    for axis in tensor_axes[::-1]:
        if axis in missing_axes:
            axis_positions.append(tensor_axes.reverse_index(axis))
    return tuple(axis_positions)
# Covariant array type variable, used to parametrise ``AverageableArrayLike``.
ArrayType_co = TypeVar("ArrayType_co", bound=Array, covariant=True)
class AverageableArrayLike(Protocol[ArrayType_co]):
    """Structural interface shared by every class exposing a ``mean`` method that
    returns a scalar array (the mean of the values of the initial array).

    Examples of classes implementing this interface: (jax) numpy arrays, ...
    """

    def mean(self, *args, **kwargs) -> ArrayType_co:
        """Average the underlying values.

        Returns:
            ArrayType_co: Should return a scalar array
        """
        ...
@runtime_checkable
class Tensor(DataclassInstance, AverageableArrayLike, Protocol[ArrayType, AxisType]):
    """
    A wrapper for numpy/jax arrays with the following additional features:

    - axes are "named" to facilitate manipulation and debugging (an axis "name" could be
      a string or any other hashable ``AxisType``)
    - the returned ``values`` of the tensor can be different from the actual ``array``
      given as input at construction, this allows to easily implement
      "change of variables" (a.k.a., "substitution")

    Args:
        _tensor_axes (ClassVar[TensorAxes[AxisType]]): class attribute defining
            the axes of the ``array``
        array (ArrayType): a (jax) numpy array containing all relevant data
    """

    _tensor_axes: ClassVar[TensorAxes[AxisType]]  # type: ignore[misc]
    array: ArrayType

    # TODO: to facilitate vectorization with vmap, it could be convenient  # pylint: disable=W0511
    # for the user to define an optional `values_axes` attribute (None by default).
    # When not None, `values` property would basically perform a reshape of `array` attribute
    # (permutation of axes). This would allow to put the `vectorized_axis` as the first axis in
    # all input tensors of a vectorized method (see vectorization.py)

    def __post_init__(self):
        # Validate the wrapped array as soon as the instance has been built.
        self.check_array()

    def check_array(self):
        """Check the validity of the ``array`` attribute at construction.

        The number of dimensions of ``array`` must lie between the minimal number of
        dimensions of the axes and the total number of axes (both included).
        To be overridden if needed.

        Raises:
            ValueError: when the number of dimensions of ``array`` is out of range
        """
        check_ndim_in(
            self.array, range(self.axes.tensor_min_nb_dim, len(type(self)._tensor_axes) + 1)
        )

    def __getitem__(self, key):
        """Index/slice ``array`` with ``key`` and wrap the result in a tensor of the
        same type."""
        return type(self)(array=self.array[key])  # type: ignore[call-arg]

    def getitem_from_axes(self, axes_keys: Dict[AxisType, Any]) -> Self:
        """Analogue of method ``__getitem__`` but where array slicing/indexing is
        explicitly applied to named axes.

        Args:
            axes_keys (Dict[AxisType, Any]): a mapping between axes names and
                slices/list of indices/...

        Returns:
            Self: a new ``Tensor`` of the same type with restricted data
        """
        # Axes that are not mentioned in ``axes_keys`` are kept whole.
        array_key = tuple(
            axes_keys[axis] if axis in axes_keys else slice(None, None, None)
            for axis in self.actual_axes
        )
        return self[array_key]

    def __neg__(self) -> Self:
        return self.__class__(array=-self.array)  # type: ignore[call-arg]

    def __abs__(self) -> Self:
        return self.__class__(array=abs(self.array))  # type: ignore[call-arg]

    # The binary operators below act on the raw ``array`` attributes (not on
    # ``values``) and return a tensor of the same type as ``self``.

    def __add__(self, other: Union[Scalar, Array, Tensor[ArrayType, AxisType]]) -> Self:
        operand = other.array if isinstance(other, Tensor) else other
        return self.__class__(array=self.array + operand)  # type: ignore[call-arg]

    def __sub__(self, other: Union[Scalar, Array, Tensor[ArrayType, AxisType]]) -> Self:
        operand = other.array if isinstance(other, Tensor) else other
        return self.__class__(array=self.array - operand)  # type: ignore[call-arg]

    def __mul__(self, other: Union[Scalar, Array, Tensor[ArrayType, AxisType]]) -> Self:
        operand = other.array if isinstance(other, Tensor) else other
        return self.__class__(array=self.array * operand)  # type: ignore[call-arg]

    def __truediv__(self, other: Union[Scalar, Array, Tensor[ArrayType, AxisType]]) -> Self:
        operand = other.array if isinstance(other, Tensor) else other
        return self.__class__(array=self.array / operand)  # type: ignore[call-arg]

    @classmethod
    def _expand_dims_axis(cls, missing_axes: AbstractSet[AxisType]):
        """Shortcut for ``expand_dims_axis`` applied to the axes of this tensor class."""
        return expand_dims_axis(cls._tensor_axes, missing_axes)

    @property
    def values(self) -> ArrayType:
        """Override to apply transformations on ``self.array`` ("change of variable").

        For example, if one wants to manipulate (jax) tensors constrained to always
        have non-negative values in any situation (even if the tensor gets updated by
        some gradient descent procedure for instance), then one can override ``values``
        by returning ``jax.nn.relu(self.array)``.
        By default ``values`` implements the identity i.e., returns ``self.array``.

        Returns:
            ArrayType: the transformed ``array``
        """
        return self.array

    def reverse_values(self, values: ArrayType) -> ArrayType:
        """Inverse of the transformation implemented in ``values`` property.

        By default, ``self.array = self.values`` so calling ``values`` is equivalent to
        calling ``array``. When the transformation mapping ``array`` to ``values`` is
        not the identity, ``reverse_values`` should be overridden accordingly.

        Args:
            values (ArrayType): the output of property ``values``

        Returns:
            ArrayType: the array obtained by applying the inverse transformation of
            ``values``
        """
        return values

    @property
    def dtype(self) -> np.dtype:
        """Get the type of ``self.values``"""
        return self.values.dtype

    @property
    def shape(self) -> Tuple[int, ...]:
        """Get the shape of ``self.values``"""
        return self.values.shape

    @property
    def ndim(self) -> int:
        """Get the number of dimensions of ``self.values``"""
        return self.values.ndim

    @property
    def axes(self) -> TensorAxes[AxisType]:
        """Possible "named" axes of the tensor"""
        return type(self)._tensor_axes

    @property
    def actual_axes(self) -> TensorAxes[AxisType]:
        """Actual axes of the tensor i.e., the last ``len(self.shape)`` possible axes"""
        # Negative indices are used because the omitted ("flattened") axes are always
        # the leading ones.
        return TensorAxes(
            [self.axes[i] for i in range(-len(self.shape), 0)],
            tensor_min_dim=self.axes.tensor_min_nb_dim,
        )

    def mean(self, *args, **kwargs) -> Union[ArrayType, Any]:
        """Compute the mean on ``self.values`` on specified axes"""
        return self.values.mean(*args, **kwargs)

    def mean_over_axes(self, axes: AbstractSet[AxisType]) -> Union[ArrayType, Any]:
        """Averages ``values`` along one or several named axes.

        Named axes that are not actual axes of the tensor are ignored.

        Args:
            axes (AbstractSet[AxisType]): axes along which the means are computed

        Returns:
            Union[ArrayType, Any]: new array containing the mean ``values``
        """
        axes_to_mean_over = tuple(self.index(axis) for axis in self.actual_axes.intersection(axes))
        return self.mean(axis=axes_to_mean_over)

    def sum(self, *args, **kwargs) -> Union[ArrayType, Any]:
        """Get the sum of ``self.values``"""
        return self.values.sum(*args, **kwargs)

    def sum_over_axes(self, axes: AbstractSet[AxisType]) -> Union[ArrayType, Any]:
        """Sums ``values`` along one or several named axes.

        Named axes that are not actual axes of the tensor are ignored.

        Args:
            axes (AbstractSet[AxisType]): axes along which the sums are computed

        Returns:
            Union[ArrayType, Any]: new array containing the summed ``values``
        """
        axes_to_sum_over = tuple(self.index(axis) for axis in self.actual_axes.intersection(axes))
        return self.sum(axis=axes_to_sum_over)

    def has(self, axis: AxisType) -> bool:
        """Whether ``axis`` is one of the possible axes of the tensor.

        Args:
            axis (AxisType): a "named" axis

        Returns:
            bool: ``True`` if ``axis`` is one of the possible axes of the tensor,
            ``False`` otherwise.
        """
        return axis in self.axes

    def has_actual(self, axis: AxisType) -> bool:
        """Whether ``axis`` is one of the actual axes of the tensor.

        Args:
            axis (AxisType): a "named" axis

        Returns:
            bool: ``True`` if ``axis`` is one of the actual axes of the tensor,
            ``False`` otherwise.
        """
        return axis in self.actual_axes

    def index(self, axis: AxisType) -> int:
        """Index of ``axis`` in tensor i.e., returns ``i`` if and only if ``axis`` is
        the ``i``-th dimension of the tensor.

        Args:
            axis (AxisType): a "named" axis

        Returns:
            int: the corresponding index
        """
        return self.actual_axes.index(axis)

    def reverse_index(self, axis: AxisType) -> int:
        """Index of ``axis`` in tensor but in reverse order i.e., returns ``-i`` if and
        only if ``axis`` is the ``(n - i)``-th dimension of the tensor (with ``n`` the
        total number of dimensions).

        Args:
            axis (AxisType): a "named" axis

        Returns:
            int: the corresponding index in reverse order
        """
        return self.actual_axes.reverse_index(axis)

    def size(
        self,
        axis: AxisType,
    ) -> int:
        """The size of a given ``axis``. Returns 0 when the ``axis`` is "flattened".

        Args:
            axis (AxisType): a "named" axis

        Raises:
            ValueError: if ``axis`` is not a possible axis

        Returns:
            int: the size of ``axis`` (0 when the ``axis`` is "flattened")
        """
        if not self.has(axis):
            raise ValueError(
                f"{axis} is not a valid axis. "
                f"Valid axes are: {', '.join(str(a) for a in self.axes)}"
            )
        if not self.has_actual(axis):
            return 0
        reverse_idx = self.reverse_index(axis)
        return self.shape[reverse_idx]

    def is_broadcastable_with(
        self,
        other_tensor_or_shape: Union[Tensor[ArrayType, AxisType], Tuple[int, ...]],
    ) -> bool:
        """Whether this tensor can be
        `broadcasted <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_
        with ``other_tensor_or_shape``.

        Args:
            other_tensor_or_shape (Union[Tensor[ArrayType, AxisType], Tuple[int, ...]]):
                another tensor, or directly a shape given as a tuple

        Returns:
            bool: ``True`` if and only if this tensor can be broadcasted with
            ``other_tensor_or_shape``, ``False`` otherwise.
        """
        return is_broadcastable(
            self.shape,
            other_tensor_or_shape
            if isinstance(other_tensor_or_shape, tuple)
            else other_tensor_or_shape.shape,
        )
class RegularizedArrayLikeCost(AverageableArrayLike[ArrayType], Protocol[ArrayType]):
    """Interface for "regularized" costs i.e., costs of the form:
    ``cost + lagrangian_coefficient * regularization``

    `More about regularization <https://en.wikipedia.org/wiki/Regularization_(mathematics)>`_.

    Args:
        cost (AverageableArrayLike[ArrayType]): an array-like collection of costs
        regularization (AverageableArrayLike[ArrayType]): an array-like collection of
            regularization costs
        lagrangian_coefficient (Number): a non-negative number quantifying
            the regularization weight
    """

    cost: AverageableArrayLike[ArrayType]
    regularization: AverageableArrayLike[ArrayType]
    lagrangian_coefficient: Number = 1  # type: ignore[assignment]

    def mean(self, *args, **kwargs) -> Union[ArrayType, Any]:
        """Mean of the cost plus the weighted mean of the regularization.

        All arguments are forwarded unchanged to the ``mean`` method of both ``cost``
        and ``regularization``.
        """
        mean_cost = self.cost.mean(*args, **kwargs)
        mean_regularization = self.regularization.mean(*args, **kwargs)
        return mean_cost + mean_regularization * self.lagrangian_coefficient