Source code for jax_utils.gradient

"""Classes and methods to compute the gradient of functions involving jax arrays transformations"""

from __future__ import annotations

from typing import Any, Generic, Hashable, Protocol, Tuple, TypeVar

import jax_dataclasses as jdc
import optax

from jax_utils.compilation import jit_when_compilation_enabled
from jax_utils.dynamics import JaxDynamics
from jax_utils.jax_tensor import AverageableJaxArrayLike
from jax_utils.vectorization import JaxDataclassNestedConvertibleToAxes, vectorize

State = TypeVar("State")
State_contra = TypeVar("State_contra", contravariant=True)
Action = TypeVar("Action")
Observation_co = TypeVar("Observation_co", covariant=True)
Cost_co = TypeVar("Cost_co", covariant=True, bound=AverageableJaxArrayLike)
OptimizerState = TypeVar("OptimizerState", bound=optax.OptState)
GradientUpdates = TypeVar("GradientUpdates", bound=optax.Updates)
AxisType = TypeVar("AxisType", bound=Hashable)
AxisType_contra = TypeVar("AxisType_contra", contravariant=True, bound=Hashable)


class BaseGradientStep(
    JaxDataclassNestedConvertibleToAxes[AxisType_contra],
    Protocol[AxisType_contra, State_contra, Action, Cost_co, OptimizerState],
):
    """Interface for gradient steps on the :class:`jax_utils.markov_decision_process.Dynamics`
    cost of a Markov Decision Process."""

    # pylint: disable=C0116
    def init_optimizer(self, init_action: Action) -> OptimizerState: ...

    # pylint: disable=C0116
    def compute_gradient(
        self,
        state: State_contra,
        action: Action,
    ) -> Tuple[Cost_co, Action]: ...

    # pylint: disable=C0116
    def update(
        self, action: Action, gradient_value: GradientUpdates, opt_state: OptimizerState
    ) -> Tuple[Action, OptimizerState]: ...

    def __call__(
        self,
        state: State_contra,
        action: Action,
        opt_state: OptimizerState,
    ) -> Tuple[Action, OptimizerState, Cost_co]: ...

@jdc.pytree_dataclass(frozen=True)
class GradientStep(
    BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState],
    Generic[AxisType, State, Action, Cost_co, Observation_co, OptimizerState],
):
    """The gradient step on the :class:`jax_utils.markov_decision_process.Dynamics` cost
    of a Markov Decision Process.

    Args:
        optimizer (jdc.Static[optax.GradientTransformation]): an optimizer (defining the
            stochastic gradient descent variant: RMSProp, Adam, etc.)
        dynamics (JaxDynamics[Any, State, Action, Cost_co, Observation_co]): an MDP dynamics
    """

    optimizer: jdc.Static[optax.GradientTransformation]
    dynamics: JaxDynamics[Any, State, Action, Cost_co, Observation_co]

    @jit_when_compilation_enabled()
    def init_optimizer(self, init_action: Action) -> OptimizerState:
        """Initialize the optimizer

        Args:
            init_action (Action): initial action before optimization starts

        Returns:
            OptimizerState: initialized optimizer state
        """
        return self.optimizer.init(init_action)

    @jit_when_compilation_enabled()
    def compute_gradient(
        self,
        state: State,
        action: Action,
    ) -> Tuple[Cost_co, Action]:
        """Given a state-action pair of the MDP, compute the cost of ``dynamics`` and the
        gradient of that cost w.r.t. the action.

        Args:
            state (State): state of the MDP
            action (Action): action of the MDP

        Returns:
            Tuple[Cost_co, Action]: the cost and the gradient of the cost w.r.t. the action
        """
        return self.dynamics.compute_gradient(
            state=state,
            action=action,
        )

    @jit_when_compilation_enabled()
    def update(
        self, action: Action, gradient_value: GradientUpdates, opt_state: OptimizerState
    ) -> Tuple[Action, OptimizerState]:
        """Update the action and optimizer state

        Args:
            action (Action): action of the MDP
            gradient_value (GradientUpdates): gradient of the action
            opt_state (OptimizerState): optimizer state

        Returns:
            Tuple[Action, OptimizerState]: the new action and optimizer state
        """
        updates, updated_opt_state = self.optimizer.update(gradient_value, opt_state, action)
        updated_action = optax.apply_updates(action, updates)
        return updated_action, updated_opt_state

    @jit_when_compilation_enabled()
    def __call__(
        self,
        state: State,
        action: Action,
        opt_state: OptimizerState,
    ) -> Tuple[Action, OptimizerState, Cost_co]:
        """Compute the gradient of the cost w.r.t. the action and update the action and
        optimizer state.

        Args:
            state (State): state of the MDP
            action (Action): action of the MDP
            opt_state (OptimizerState): optimizer state

        Returns:
            Tuple[Action, OptimizerState, Cost_co]: a new action, optimizer state and the value
                of the cost associated with the previous action
        """
        cost_value, gradient_value = self.compute_gradient(
            state=state,
            action=action,
        )
        updated_action, updated_opt_state = self.update(action, gradient_value, opt_state)
        return updated_action, updated_opt_state, cost_value
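
# Usage sketch (illustrative, not part of the module): building a GradientStep from an optax
# optimizer and an existing JaxDynamics instance, then running a single step. The
# ``pendulum_dynamics``, ``state`` and ``action`` names are assumptions for illustration only.
#
#     step = GradientStep(optimizer=optax.adam(learning_rate=1e-2), dynamics=pendulum_dynamics)
#     opt_state = step.init_optimizer(action)
#     action, opt_state, cost_value = step(state, action, opt_state)
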

@jdc.pytree_dataclass(frozen=True)
class VectorizedGradientStep(BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]):
    """A gradient step that is `vectorized
    <https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html>`_ on a specific axis

    Args:
        gradient_step (BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]): the
            gradient step to be vectorized
        vectorized_axis (jdc.Static[AxisType]): "named" axis to be vectorized
    """

    gradient_step: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]
    vectorized_axis: jdc.Static[AxisType]

    @jit_when_compilation_enabled()
    def init_optimizer(self, init_action: Action) -> OptimizerState:
        return self.gradient_step.init_optimizer(init_action)

    @vectorize()
    def compute_gradient(
        self,
        state: State,
        action: Action,
    ) -> Tuple[Cost_co, Action]:
        return self.gradient_step.compute_gradient(
            state,
            action,
        )

    @jit_when_compilation_enabled()
    def update(
        self, action: Action, gradient_value: GradientUpdates, opt_state: OptimizerState
    ) -> Tuple[Action, OptimizerState]:
        return self.gradient_step.update(action, gradient_value, opt_state)

    @jit_when_compilation_enabled()
    def __call__(
        self,
        state: State,
        action: Action,
        opt_state: OptimizerState,
    ) -> Tuple[Action, OptimizerState, Cost_co]:
        cost_value, gradient_value = self.compute_gradient(state, action)
        updated_action, updated_opt_state = self.update(action, gradient_value, opt_state)
        return updated_action, updated_opt_state, cost_value
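
# Usage sketch (illustrative, not part of the module): wrapping a GradientStep so that
# ``compute_gradient`` is vectorized over the configured axis. The "batch" axis name and the
# batched ``states``/``actions`` pytrees are assumptions for illustration only.
#
#     vectorized_step = VectorizedGradientStep(gradient_step=step, vectorized_axis="batch")
#     opt_state = vectorized_step.init_optimizer(actions)
#     actions, opt_state, cost_values = vectorized_step(states, actions, opt_state)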