Source code for jax_utils.tranform

"""Jax arrays transformations (scaling, ...)"""

from __future__ import annotations

from typing import Any, Generic, Hashable, Set, Tuple, Type, TypeVar

import jax.numpy as jnp
import jax_dataclasses as jdc
from jax.tree_util import tree_map

from jax_utils.compilation import BaseJaxCompilable, jit_when_compilation_enabled
from jax_utils.jax_tensor import JaxTensor
from jax_utils.common_tensor import TensorAxes, expand_dims_axis

# Type variable for the "named" axis labels of a tensor. It is bound to Hashable
# because axes are collected in sets (e.g. ``JaxScaler.scaling_axes``) and combined
# with set difference (``actual_axes - scaling_axes``).
AxisType = TypeVar("AxisType", bound=Hashable)


def scale_jax_tensor(
    tensor: JaxTensor[AxisType], scaling_factor: jnp.ndarray
) -> JaxTensor[AxisType]:
    """Return a copy of ``tensor`` whose ``values`` are multiplied by ``scaling_factor``.

    The multiplication is applied to the ``values`` property, not to the raw ``array``
    attribute (the two differ for a ``JaxTensor``). The product is converted back into
    the ``array`` representation with ``reverse_values`` before being stored.

    Args:
        tensor (JaxTensor[AxisType]): jax tensor to be scaled
        scaling_factor (jnp.ndarray): the scaling factors to apply for scaling

    Returns:
        JaxTensor[AxisType]: scaled jax tensor
    """
    with jdc.copy_and_mutate(tensor, validate=True) as result:
        multiplied = result.values * scaling_factor
        result.array = result.reverse_values(multiplied)
    return result
def unscale_jax_tensor(
    scaled_tensor: JaxTensor[AxisType], scaling_factor: jnp.ndarray
) -> JaxTensor[AxisType]:
    """Reverse transformation of function ``scale_jax_tensor``.

    Divides the ``values`` of ``scaled_tensor`` by ``scaling_factor``. The quotient is
    converted back into the ``array`` representation with ``reverse_values`` and stored
    on a copy made by ``jdc.copy_and_mutate``.

    Args:
        scaled_tensor (JaxTensor[AxisType]): jax tensor to be unscaled
        scaling_factor (jnp.ndarray): the scaling factors applied when scaling the
            original jax tensor

    Returns:
        JaxTensor[AxisType]: unscaled jax tensor
    """
    with jdc.copy_and_mutate(scaled_tensor, validate=True) as unscaled_tensor:
        # Use reverse_values of the copy being mutated, exactly as scale_jax_tensor does,
        # so both directions of the transformation follow the same code path.
        unscaled_tensor.array = unscaled_tensor.reverse_values(
            unscaled_tensor.values / scaling_factor
        )
    return unscaled_tensor
@jdc.pytree_dataclass(frozen=True)
class JaxScaler(BaseJaxCompilable, Generic[AxisType]):
    """Scale/unscale the :class:`jax_utils.jax_tensor.JaxTensor`'s of a pytree.

    When applying gradient descent optimization algorithms, it is often helpful to scale
    all the arrays/tensors involved in the computations. This class makes it easy to
    scale/unscale :class:`jax_utils.jax_tensor.JaxTensor`'s, even when the tensors are
    stored in a nested pytree structure.

    Args:
        scaling_factors (jnp.ndarray): Multiplicative factors for scaling. Defaults to
            ``jnp.array(1.0)`` (a fresh array is created for each instance).
        scaling_axes (jdc.Static[Set[AxisType]]): "Named" axes of ``scaling_factors``.
            The number of axes must match the number of dimensions of
            ``scaling_factors``. Defaults to an empty set.
        tensor_types_to_scale (jdc.Static[Tuple[Type[JaxTensor], ...]]): Object types
            that should be scaled in a pytree. Defaults to
            ``(jax_utils.jax_tensor.JaxTensor,)``.
        tensor_types_not_to_scale (jdc.Static[Tuple[Type[JaxTensor], ...]]): Object
            types that should not be scaled in a pytree. A leaf that is an instance of
            one of these types is left untouched even if it is also an instance of a type
            in ``tensor_types_to_scale`` (useful to exclude subclasses). Defaults to
            ``tuple()``.

    Raises:
        ValueError: if ``len(scaling_axes)`` differs from the number of dimensions of
            ``scaling_factors``, or if ``tensor_types_to_scale`` and
            ``tensor_types_not_to_scale`` have a type in common.
    """

    # A default_factory rather than a shared ``jnp.array(1.0)`` instance: jax arrays are
    # unhashable, which ``dataclasses`` rejects as a "mutable default" on Python >= 3.11.
    # This also avoids allocating a device array when the module is imported.
    scaling_factors: jnp.ndarray = jdc.field(default_factory=lambda: jnp.array(1.0))
    scaling_axes: jdc.Static[Set[AxisType]] = jdc.field(default_factory=set)
    tensor_types_to_scale: jdc.Static[Tuple[Type[JaxTensor], ...]] = (JaxTensor,)
    tensor_types_not_to_scale: jdc.Static[Tuple[Type[JaxTensor], ...]] = tuple()

    def __post_init__(self) -> None:
        # jnp.ndim also works on Python scalars / NumPy arrays, not only on jax arrays.
        if jnp.ndim(self.scaling_factors) != len(self.scaling_axes):
            raise ValueError(
                "the number of scaling_axes should match the number of dimensions of scaling_factors"
            )
        if not set(self.tensor_types_to_scale).isdisjoint(self.tensor_types_not_to_scale):
            raise ValueError(
                "tensor_types_to_scale and tensor_types_not_to_scale should be disjoint"
            )

    @classmethod
    def from_tensor(
        cls,
        tensor: JaxTensor[AxisType],
        scaling_axes: Set[AxisType],
        tensor_types_to_scale: jdc.Static[Tuple[Type[JaxTensor], ...]] = (JaxTensor,),
        tensor_types_not_to_scale: jdc.Static[Tuple[Type[JaxTensor], ...]] = tuple(),
        factor: float = 1.0,
    ) -> JaxScaler:
        """Build a scaler whose ``scaling_factors`` are derived from ``tensor``.

        Let ``other_axes = tensor.actual_axes - scaling_axes``. The factors are
        ``factor / max(tensor.mean_over_axes(axes=other_axes), 1)``, an array indexed by
        ``scaling_axes``.

        Consequently, wherever that mean is at least 1,
        ``scaler.scale(tensor).mean_over_axes(axes=other_axes)`` equals ``factor``.
        Wherever the mean is below 1 (including zero or negative means), the divisor is
        clamped to 1 and the scaling factor is simply ``factor``.

        Args:
            tensor (JaxTensor[AxisType]): a jax tensor
            scaling_axes (Set[AxisType]): a set of "named" axes present in ``tensor``;
                they become the axes of the resulting ``scaling_factors``
            tensor_types_to_scale (jdc.Static[Tuple[Type[JaxTensor], ...]]): Object
                types that should be scaled in a pytree. Defaults to
                :class:`jax_utils.jax_tensor.JaxTensor`.
            tensor_types_not_to_scale (jdc.Static[Tuple[Type[JaxTensor], ...]]): Object
                types that should not be scaled in a pytree. Defaults to ``tuple()``.
            factor (float): target mean of the scaled tensor over the non-scaling axes.
                Defaults to ``1.0``.

        Returns:
            JaxScaler: a scaler instance
        """
        axes_to_mean_over = tensor.actual_axes - scaling_axes
        mean_values = tensor.mean_over_axes(axes=axes_to_mean_over)
        return cls(
            # The divisor is clamped at 1, so means below 1 do not amplify the factor.
            scaling_factors=factor / jnp.maximum(mean_values, 1),
            scaling_axes=scaling_axes,
            tensor_types_to_scale=tensor_types_to_scale,
            tensor_types_not_to_scale=tensor_types_not_to_scale,
        )

    def _should_scale(self, leaf: Any) -> bool:
        """Tell whether ``leaf`` is one of the tensor types this scaler transforms."""
        return isinstance(leaf, self.tensor_types_to_scale) and not isinstance(
            leaf, self.tensor_types_not_to_scale
        )

    def _factors_for(self, tensor: JaxTensor[AxisType]) -> jnp.ndarray:
        """Return ``scaling_factors`` with size-1 dims inserted for ``tensor``'s other axes."""
        return self.expand_scaling_factors(
            tensor_axes=tensor.actual_axes,
            missing_axes=tensor.actual_axes - self.scaling_axes,
        )

    @jit_when_compilation_enabled()
    def scale(self, pytree: Any) -> Any:
        """Scale every :class:`jax_utils.jax_tensor.JaxTensor` of a pytree.

        Leaves of the types selected by ``tensor_types_to_scale`` /
        ``tensor_types_not_to_scale`` are scaled; every other leaf is returned unchanged.

        Args:
            pytree (Any): pytree to be scaled

        Returns:
            Any: the scaled pytree
        """
        return tree_map(
            lambda x: scale_jax_tensor(x, self._factors_for(x)) if self._should_scale(x) else x,
            pytree,
            # Stop at JaxTensor instances so they are handled as whole leaves.
            is_leaf=lambda x: isinstance(x, JaxTensor),
        )

    @jit_when_compilation_enabled()
    def unscale(self, scaled_pytree: Any) -> Any:
        """Unscale every :class:`jax_utils.jax_tensor.JaxTensor` of a pytree.

        This is the reverse transformation of method ``scale``. Leaves that ``scale``
        would not touch are returned unchanged.

        Args:
            scaled_pytree (Any): pytree to be unscaled

        Returns:
            Any: the unscaled pytree
        """
        return tree_map(
            lambda x: unscale_jax_tensor(x, self._factors_for(x)) if self._should_scale(x) else x,
            scaled_pytree,
            # Stop at JaxTensor instances so they are handled as whole leaves.
            is_leaf=lambda x: isinstance(x, JaxTensor),
        )

    @jit_when_compilation_enabled()
    def expand_scaling_factors(
        self, tensor_axes: TensorAxes[AxisType], missing_axes: Set[AxisType]
    ) -> jnp.ndarray:
        """Expand the dimensions of ``self.scaling_factors`` to match ``tensor_axes``.

        This method only adds new dimensions of size 1. Their positions are computed by
        ``expand_dims_axis``.

        Args:
            tensor_axes (TensorAxes[AxisType]): the "named" axes of the target tensor
            missing_axes (Set[AxisType]): the axes of ``tensor_axes`` absent from
                ``self.scaling_axes``, i.e. where a size-1 dimension is inserted

        Returns:
            jnp.ndarray: same as ``self.scaling_factors`` but with (empty) additional
            dimensions. The total number of dimensions of the returned array is equal
            to ``len(tensor_axes)``.
        """
        return jnp.expand_dims(
            self.scaling_factors,
            axis=expand_dims_axis(tensor_axes=tensor_axes, missing_axes=missing_axes),
        )