# Source code for jax_utils.vectorization

"""Classes and methods to easily vectorize jax tensors"""

from __future__ import annotations

from functools import wraps
from typing import Callable, Hashable, Optional, Protocol, TypeVar, runtime_checkable

import jax_dataclasses as jdc
import optax
from jax import vmap
from typing_extensions import Self

from jax_utils.compilation import (
    BaseJaxCompilable,
    JaxCompilableProtocol,
    jit_when_compilation_enabled,
)
from jax_utils.pytree import ConvertibleToAxes, pytree_to_axes
from jax_utils.jax_tensor import AverageableJaxArrayLike
from jax_utils.typing import DataclassInstance

# Invariant type variables.
State = TypeVar("State")
Action = TypeVar("Action")
AxisType = TypeVar("AxisType", bound=Hashable)
OptimizerState = TypeVar("OptimizerState", bound=optax.OptState)

# Contravariant type variables (intended for argument positions).
Action_contra = TypeVar("Action_contra", contravariant=True)
AxisType_contra = TypeVar("AxisType_contra", contravariant=True, bound=Hashable)

# Covariant type variables (intended for return positions).
Observation_co = TypeVar("Observation_co", covariant=True)
Cost_co = TypeVar("Cost_co", covariant=True, bound=AverageableJaxArrayLike)


@runtime_checkable
class JaxVectorizableProtocol(
    JaxCompilableProtocol,
    ConvertibleToAxes[AxisType],
    Protocol,
):
    """Protocol for classes that manipulate JAX arrays and can be vectorized.

    Implementations expose ``vectorized_axis``: the "named" axis along which
    `vectorization <https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html>`_
    is applied.

    Every vectorizable class is also "compilable" (it satisfies
    :class:`jax_utils.compilation.JaxCompilableProtocol`). It is also
    convertible to axes (it satisfies
    :class:`jax_utils.pytree.ConvertibleToAxes`).

    Attributes:
        vectorized_axis (AxisType): "named" axis to be vectorized
    """

    vectorized_axis: AxisType
def vectorize(
    in_default_axis: Optional[int] = None,
    out_default_axis: Optional[int] = -1,
) -> Callable[[Callable], Callable]:
    """Parametrized decorator for methods of :class:`JaxVectorizableProtocol` classes.

    It allows to
    `vectorize <https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html>`_
    the decorated methods according to a specified axis, namely attribute
    :attr:`JaxVectorizableProtocol.vectorized_axis` (see
    :class:`JaxVectorizableProtocol`).

    The `pytrees <https://jax.readthedocs.io/en/latest/pytrees.html>`_ of all
    inputs and outputs of the decorated method are explored. Every JAX array,
    as well as every class implementing the
    :class:`jax_utils.pytree.ConvertibleToAxes` interface, is considered a leaf
    of the pytree.

    When a :class:`jax_utils.pytree.ConvertibleToAxes` leaf is encountered,
    method :meth:`jax_utils.pytree.ConvertibleToAxes.convert_to_axes` is used
    to determine the dimension to vectorize. When a JAX array leaf or any other
    leaf is encountered instead, the default value is used to determine the
    dimension to vectorize: ``in_default_axis`` for input arguments and
    ``out_default_axis`` for output arguments.

    Keyword arguments are vectorized like positional ones. Their axes are taken
    from the computed ``in_axes``, rather than from ``jax.vmap``'s default of
    mapping keyword arguments over axis 0.

    Vectorized functions are cached per instance. The cache key is made of the
    decorated method, the keyword-argument names and the computed input axes.

    Args:
        in_default_axis (Optional[int], optional): default index used in the
            ``in_axes`` argument of
            `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`_
            when a leaf that is not :class:`jax_utils.pytree.ConvertibleToAxes`
            is encountered in an input pytree. Defaults to None.
        out_default_axis (Optional[int], optional): default index used in the
            ``out_axes`` argument of
            `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`_
            when a leaf that is not :class:`jax_utils.pytree.ConvertibleToAxes`
            is encountered in an output pytree. Defaults to -1.

    Returns:
        Callable[[Callable], Callable]: decorator with parameters
        ``in_default_axis`` and ``out_default_axis`` specified
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(self: JaxVectorizableProtocol, *args, **kwargs):
            # Lazily attach a per-instance cache; ``object.__setattr__`` is used
            # so that frozen dataclasses can carry it too.
            if not hasattr(self, "_vmap_functions"):
                object.__setattr__(self, "_vmap_functions", {})
            vmap_functions = getattr(self, "_vmap_functions")

            # jax.vmap only applies ``in_axes`` to positional arguments, so
            # keyword values are passed positionally after the regular args.
            kwarg_names = tuple(kwargs)
            flat_args = (self,) + tuple(args) + tuple(kwargs.values())
            n_positional = len(flat_args) - len(kwarg_names)

            in_axes = pytree_to_axes(
                flat_args,
                self.vectorized_axis,
                default_axis=in_default_axis,
            )

            # The cache is shared by every decorated method of this instance,
            # so the key must identify the method and the keyword names too.
            cache_key = (func, kwarg_names, in_axes)
            if cache_key not in vmap_functions:

                def positional_func(*all_args):
                    # Rebuild the original call from the flattened arguments.
                    return func(
                        *all_args[:n_positional],
                        **dict(zip(kwarg_names, all_args[n_positional:])),
                    )

                # One unvectorized call is used to discover the output pytree
                # structure needed to compute ``out_axes``.
                outputs = func(self, *args, **kwargs)
                out_axes = pytree_to_axes(
                    outputs, self.vectorized_axis, default_axis=out_default_axis
                )
                vmap_functions[cache_key] = vmap(
                    positional_func, in_axes=in_axes, out_axes=out_axes
                )
            return vmap_functions[cache_key](*flat_args)

        return jit_when_compilation_enabled()(wrapper)

    return decorator
@runtime_checkable
class JaxDataclassNestedConvertibleToAxes(
    DataclassInstance,
    ConvertibleToAxes[AxisType_contra],
    BaseJaxCompilable,
    Protocol[AxisType_contra],
):
    """Interface for jit-compilable dataclasses that are "convertible to axes".

    It supplies a concrete, dataclass-specific ``convert_to_axes``
    implementation, which delegates to the dataclass fields.
    """

    def convert_to_axes(self, axis: Optional[AxisType_contra]) -> Self:
        """Concrete ``convert_to_axes`` for dataclasses.

        Every field of the dataclass is inspected. A field whose value is
        ``ConvertibleToAxes`` is replaced by the result of calling its own
        ``convert_to_axes(axis)``. Every other field keeps its original value.

        Args:
            axis (Optional[AxisType_contra]): a "named" axis

        Returns:
            Self: a copy of ``self``, made with ``jdc.copy_and_mutate``, whose
            convertible fields have been converted to axes
        """
        # validate=False: the replacement values are axes rather than arrays.
        # Presumably this skips jax_dataclasses' shape/dtype validation on
        # assignment -- confirm against the jax_dataclasses version in use.
        with jdc.copy_and_mutate(self, validate=False) as converted:
            for dataclass_field in jdc.fields(self):
                name = dataclass_field.name
                value = getattr(self, name)
                if not isinstance(value, ConvertibleToAxes):
                    continue
                setattr(converted, name, value.convert_to_axes(axis))
        return converted