jax_utils package
Submodules
jax_utils.common_tensor module
Wrappers around numpy/jax arrays
- check_ndim_in(array: Array | ndarray, allowed_ndims: Iterable[int])[source]
Checks if the number of dimensions matches some allowed values.
- Parameters:
array (Array | ndarray) – a numpy or JAX array
allowed_ndims (Iterable[int]) – the numbers of dimensions allowed for array
- Raises:
ValueError – when the number of dimensions of array is not in allowed_ndims.
- is_broadcastable(shape_1: Tuple[int, ...], shape_2: Tuple[int, ...]) bool[source]
Whether the shapes of 2 arrays/tensors can be broadcasted.
- Parameters:
shape_1 (Tuple[int, ...]) – shape of the 1st array
shape_2 (Tuple[int, ...]) – shape of the 2nd array
- Returns:
True if and only if the 2 arrays/tensors can be broadcasted, False otherwise.
- Return type:
bool
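As a sketch, a hypothetical re-implementation of `is_broadcastable` following NumPy's broadcasting rule (aligned trailing dimensions must be equal or 1; a missing leading dimension is always compatible):

```python
from typing import Tuple

def is_broadcastable(shape_1: Tuple[int, ...], shape_2: Tuple[int, ...]) -> bool:
    # Compare dimensions from the last axis backwards; zip stops at the
    # shorter shape, so missing leading dimensions are ignored.
    for d1, d2 in zip(reversed(shape_1), reversed(shape_2)):
        if d1 != d2 and d1 != 1 and d2 != 1:
            return False
    return True
```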
- class TensorAxes(initial: AbstractSet[T] | Sequence[T] | Iterable[T] | None = None, tensor_min_dim: int = 0)[source]
Bases: OrderedSet[T]
An ordered set of “named” array/tensor axes.
The position in the set defines the axis index in the tensor.
The first len(self) - tensor_min_dim axes are optional: if absent, they are assumed to be “flattened”. The reason for making the first len(self) - tensor_min_dim axes optional (rather than the last tensor_min_dim axes) is to follow the logic of array broadcasting (leading dimensions can be omitted).
- reverse_index(key: T) int[source]
Computes the index of element key but, rather than returning a non-negative integer like method index, returns a negative integer: -1 for the last element, -2 for the penultimate element, …
- Parameters:
key (T) – any element present in self (which is an ordered set)
- Returns:
negative integer corresponding to the position of key in self (starting from -1 for the last element and decrementing by 1 for each preceding element)
- Return type:
int
- property mandatory: OrderedSet[T]
Returns: OrderedSet[T]: ordered set of non-optional axes.
- property optional: OrderedSet[T]
Returns: OrderedSet[T]: ordered set of optional axes.
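As a sketch of the `reverse_index` semantics on a plain sequence of axis names (a hypothetical standalone helper, not the class method itself):

```python
from typing import Sequence, TypeVar

T = TypeVar("T")

def reverse_index(axes: Sequence[T], key: T) -> int:
    # Negative position of `key`: -1 for the last axis, -2 for the
    # penultimate one, and so on.
    return list(axes).index(key) - len(axes)
```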
- expand_dims_axis(tensor_axes: TensorAxes[AxisType], missing_axes: AbstractSet[AxisType]) Tuple[int, ...][source]
Computes the axis argument to pass to the numpy.expand_dims(a, axis) function when some axes in tensor_axes are not present in the numpy array a (they are “flattened”).
- Parameters:
tensor_axes (TensorAxes[AxisType]) – some array/tensor axes
missing_axes (AbstractSet[AxisType]) – missing axes in the array/tensor
- Raises:
ValueError – when missing_axes is not a subset of tensor_axes
- Returns:
the axis argument to pass to numpy.expand_dims(a, axis)
- Return type:
Tuple[int, …]
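A minimal sketch of how such an axis tuple can be computed and used (a hypothetical re-implementation, assuming the positions of the missing axes within the full axis ordering are the dimensions to expand):

```python
import numpy as np

def expand_dims_axis(tensor_axes, missing_axes):
    # Positions (within the full axis ordering) of the axes that are
    # missing from the array.
    if not set(missing_axes) <= set(tensor_axes):
        raise ValueError("missing_axes must be a subset of tensor_axes")
    return tuple(i for i, ax in enumerate(tensor_axes) if ax in missing_axes)

tensor_axes = ("batch", "time", "feature")
a = np.zeros(5)                      # only "feature" is present
axis = expand_dims_axis(tensor_axes, {"batch", "time"})
expanded = np.expand_dims(a, axis)   # "batch" and "time" become size-1 dims
```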
- class AverageableArrayLike(*args, **kwargs)[source]
Bases: Protocol[ArrayType_co]
Shared interface of all classes with a mean method returning a scalar array (corresponding to the mean of the initial array's values).
Examples of classes implementing this interface: (jax) numpy arrays, …
- class Tensor(*args, **kwargs)[source]
Bases: DataclassInstance, AverageableArrayLike, Protocol[ArrayType, AxisType]
A wrapper for numpy/jax arrays with the following additional features:
- axes are “named” to facilitate manipulation and debugging (an axis “name” can be a string or any other hashable AxisType)
- the returned values of the tensor can differ from the actual array given as input at construction; this makes it easy to implement a “change of variables” (a.k.a. “substitution”)
- Parameters:
_tensor_axes (ClassVar[TensorAxes[AxisType]]) – class attribute defining the axes of the array
array (ArrayType) – a (jax) numpy array containing all relevant data
- array: ArrayType
- check_array()[source]
Checks the validity of the array attribute at construction. To be overridden if needed.
- getitem_from_axes(axes_keys: Dict[AxisType, Any]) Self[source]
Analogue of method __getitem__ but where array slicing/indexing is explicitly applied to named axes.
- Parameters:
axes_keys (Dict[AxisType, Any]) – a mapping between axis names and slices/lists of indices/…
- Returns:
a new Tensor of the same type with restricted data
- Return type:
Self
- property values: ArrayType
Override to apply transformations on self.array (“change of variable”). For example, to manipulate (jax) tensors constrained to always have non-negative values in any situation (even if the tensor gets updated by some gradient descent procedure), one can override values to return jax.nn.relu(self.array). By default, values implements the identity, i.e., returns self.array.
- Returns:
the transformed array
- Return type:
ArrayType
- reverse_values(values: ArrayType) ArrayType[source]
Inverse of the transformation implemented in the values property. By default, self.values returns self.array, so reverse_values is the identity. When the transformation mapping array to values is not the identity, reverse_values should be overridden accordingly.
- Parameters:
values (ArrayType) – the output of property values
- Returns:
the array obtained by applying the inverse transformation to values
- Return type:
ArrayType
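A minimal numpy sketch of the values/reverse_values contract (PositiveTensor is a hypothetical example, not part of the package; exp is used instead of relu because it has an exact inverse):

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class PositiveTensor:
    # Hypothetical Tensor-like wrapper: `array` is unconstrained, while
    # `values` is kept strictly positive through an exp change of variable.
    array: np.ndarray

    @property
    def values(self) -> np.ndarray:
        return np.exp(self.array)        # forward transformation

    def reverse_values(self, values: np.ndarray) -> np.ndarray:
        return np.log(values)            # exact inverse of `values`

t = PositiveTensor(np.array([-2.0, 0.0, 3.0]))
```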
- property dtype: dtype
Get the type of
self.values
- property shape: Tuple[int, ...]
Get the shape of
self.values
- property ndim: int
Get the number of dimensions of
self.values
- property axes: TensorAxes[AxisType]
Possible “named” axes of the tensor
- property actual_axes: TensorAxes[AxisType]
Actual axes of the tensor
- mean_over_axes(axes: AbstractSet[AxisType]) ArrayType | Any[source]
Averages values along one or several named axes.
- Parameters:
axes (AbstractSet[AxisType]) – axes along which the means are computed
- Returns:
new array containing the mean values
- Return type:
Union[ArrayType, Any]
- sum_over_axes(axes: AbstractSet[AxisType]) ArrayType | Any[source]
Sums values along one or several named axes.
- Parameters:
axes (AbstractSet[AxisType]) – axes along which the sums are computed
- Returns:
new array containing the summed values
- Return type:
Union[ArrayType, Any]
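The named-axis reductions above can be sketched with a hypothetical standalone helper that translates axis names into positional indices before reducing:

```python
import numpy as np

def reduce_over_named_axes(array, tensor_axes, axes, reduction=np.mean):
    # Map each named axis to its positional index, then reduce along them.
    positions = tuple(tensor_axes.index(ax) for ax in axes)
    return reduction(array, axis=positions)

tensor_axes = ("batch", "feature")
x = np.arange(6.0).reshape(2, 3)
means = reduce_over_named_axes(x, tensor_axes, ("batch",))
sums = reduce_over_named_axes(x, tensor_axes, ("batch",), reduction=np.sum)
```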
- has(axis: AxisType) bool[source]
Whether axis is one of the possible axes of the tensor.
- Parameters:
axis (AxisType) – a “named” axis
- Returns:
True if axis is one of the possible axes of the tensor, False otherwise.
- Return type:
bool
- has_actual(axis: AxisType) bool[source]
Whether axis is one of the actual axes of the tensor.
- Parameters:
axis (AxisType) – a “named” axis
- Returns:
True if axis is one of the actual axes of the tensor, False otherwise.
- Return type:
bool
- index(axis: AxisType) int[source]
Index of axis in the tensor, i.e., returns i if and only if axis is the i-th dimension of the tensor.
- Parameters:
axis (AxisType) – a “named” axis
- Returns:
the corresponding index
- Return type:
int
- reverse_index(axis: AxisType) int[source]
Index of axis in the tensor but in reverse order, i.e., returns -i if and only if axis is the (n - i)-th dimension of the tensor (with n the total number of dimensions).
- Parameters:
axis (AxisType) – a “named” axis
- Returns:
the corresponding index in reverse order
- Return type:
int
- size(axis: AxisType) int[source]
The size of a given axis. Returns 0 when the axis is “flattened”.
- Parameters:
axis (AxisType) – a “named” axis
- Raises:
ValueError – if axis is not a possible axis
- Returns:
the size of axis (0 when the axis is “flattened”)
- Return type:
int
- is_broadcastable_with(other_tensor_or_shape: Tensor[ArrayType, AxisType] | Tuple[int, ...]) bool[source]
Whether this tensor can be broadcasted with other_tensor_or_shape.
- Parameters:
other_tensor_or_shape (Union[Tensor[ArrayType, AxisType], Tuple[int, ...]]) – another tensor or a shape
- Returns:
True if and only if this tensor can be broadcasted with other_tensor_or_shape, False otherwise.
- Return type:
bool
- class RegularizedArrayLikeCost(*args, **kwargs)[source]
Bases: AverageableArrayLike[ArrayType], Protocol[ArrayType]
Interface for “regularized” costs, i.e., costs of the form: cost + lagrangian_coefficient * regularization
- Parameters:
cost (AverageableArrayLike[ArrayType]) – an array-like collection of costs
regularization (AverageableArrayLike[ArrayType]) – an array-like collection of regularization costs
lagrangian_coefficient (Number) – a non-negative number quantifying the regularization weight
- cost: AverageableArrayLike[ArrayType]
- regularization: AverageableArrayLike[ArrayType]
- lagrangian_coefficient: Number = 1
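A hypothetical concrete implementation of this interface can make the formula explicit (the class below is illustrative, not the package's own):

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class RegularizedCost:
    # mean() averages cost + lagrangian_coefficient * regularization,
    # matching the "regularized" cost form described above.
    cost: np.ndarray
    regularization: np.ndarray
    lagrangian_coefficient: float = 1.0

    def mean(self) -> np.ndarray:
        total = self.cost + self.lagrangian_coefficient * self.regularization
        return total.mean()

r = RegularizedCost(cost=np.array([1.0, 3.0]),
                    regularization=np.array([0.5, 0.5]),
                    lagrangian_coefficient=2.0)
```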
jax_utils.compilation module
Classes and methods to jit-compile functions involving JAX array transformations
- class JaxCompilableProtocol(*args, **kwargs)[source]
Bases: Protocol
All classes implementing this interface should implement property is_compilation_enabled, indicating whether the methods involving JAX arrays should be jit-compiled.
- property is_compilation_enabled: bool
- class BaseJaxCompilable(*args, **kwargs)[source]
Bases: JaxCompilableProtocol, Protocol
Subclassing BaseJaxCompilable makes it easy to enable/disable jit-compilation of methods involving JAX arrays.
Use the with_optional_jax_jit decorator to compile a method only when is_compilation_enabled is True (False by default).
To enable (resp. disable) jit-compilation, call method enable_compilation (resp. disable_compilation). By default, jit-compilation is disabled.
- property is_compilation_enabled: bool
- jit_when_compilation_enabled(**jax_jit_args) Callable[[Callable], Callable][source]
Parametrized decorator for methods of classes implementing the JaxCompilableProtocol interface. Allows jit-compiling some methods only when compilation is enabled.
- Returns:
decorator with parameters jax_jit_args specified
- Return type:
Callable[[Callable], Callable]
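The dispatch logic of such a decorator can be sketched as follows; this is a hypothetical standalone version where `compile_fn` stands in for jax.jit (parametrized by the jit arguments), and `tagging_compiler` is a dummy compiler used only to observe which path was taken:

```python
import functools

def jit_when_compilation_enabled(compile_fn=lambda f: f):
    # Sketch: compile the method once, then dispatch per call based on
    # the instance's is_compilation_enabled flag.
    def decorator(method):
        compiled = compile_fn(method)

        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            if self.is_compilation_enabled:
                return compiled(self, *args, **kwargs)
            return method(self, *args, **kwargs)

        return wrapper
    return decorator

def tagging_compiler(f):
    # Stand-in "compiler" that tags results so the dispatch is visible.
    def compiled(*args, **kwargs):
        return ("compiled", f(*args, **kwargs))
    return compiled

class Model:
    is_compilation_enabled = False   # disabled by default

    @jit_when_compilation_enabled(compile_fn=tagging_compiler)
    def step(self, x):
        return x + 1
```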
jax_utils.dynamics module
High-level abstractions for decision problems (Markov Decision Processes, etc.) involving JAX array transformations
- class JaxDynamics(*args, **kwargs)[source]
Bases: JaxDataclassNestedConvertibleToAxes[AxisType_contra], Dynamics[State, Action_contra, Cost_co, Observation_co], Protocol[AxisType_contra, State, Action_contra, Cost_co, Observation_co]
A jax_utils.markov_decision_process.Dynamics involving JAX array transformations. The cost of such dynamics can be differentiated w.r.t. the action (see method jax_utils.dynamics.JaxDynamics.compute_gradient()).
- scalar_cost(state: State, action: Action_contra) Array[source]
Averages the cost associated with the dynamics of an MDP.
- Parameters:
state (State) – a state of the MDP
action (Action_contra) – an action of the MDP
- Returns:
a scalar Jax array corresponding to the mean cost
- Return type:
jnp.ndarray
- compute_gradient(state: State, action: Action_contra) Tuple[Cost_co, Action_contra][source]
The cost of a jax_utils.dynamics.JaxDynamics can always be differentiated w.r.t. the action.
- Parameters:
state (State) – a state of the MDP
action (Action_contra) – an action of the MDP
- Returns:
the associated cost and its gradient w.r.t. the action
- Return type:
Tuple[Cost_co, Action_contra]
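This pattern can be sketched directly with jax.value_and_grad (assuming JAX is installed; the quadratic `scalar_cost` below is a hypothetical stand-in for JaxDynamics.scalar_cost):

```python
import jax
import jax.numpy as jnp

def scalar_cost(state, action):
    # Hypothetical scalar cost: mean squared distance to the state.
    return jnp.mean((action - state) ** 2)

def compute_gradient(state, action):
    # Differentiate the scalar cost w.r.t. the action (argnums=1),
    # returning both the cost and the gradient, as compute_gradient does.
    return jax.value_and_grad(scalar_cost, argnums=1)(state, action)

state = jnp.zeros(2)
action = jnp.array([1.0, 2.0])
cost, grad = compute_gradient(state, action)
```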
- class VectorizedJaxDynamics(dynamics: JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co], vectorized_axis: Annotated[AxisType, '__jax_dataclasses_static_field__'])[source]
Bases: JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co]
A jax_utils.dynamics.JaxDynamics that can be vectorized over some “named” axis.
- Parameters:
dynamics (JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co]) – the jax_utils.dynamics.JaxDynamics to be vectorized
vectorized_axis (jdc.Static[AxisType]) – “named” axis to map over
- dynamics: JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co]
- vectorized_axis: Annotated[AxisType, '__jax_dataclasses_static_field__']
- class RegularizedJaxDynamics(dynamics: JaxDynamics[AxisType_contra, State, Action, Cost, Observation], cost_regularizer: CostRegularizer[State, Action, Cost, Observation, RegularizedCost])[source]
Bases: RegularizedDynamics[State, Action, Cost, RegularizedCost, Observation], JaxDynamics[AxisType_contra, State, Action, Cost, Observation]
A jax_utils.dynamics.JaxDynamics with cost regularization. See also jax_utils.markov_decision_process.RegularizedDynamics.
- Parameters:
dynamics (JaxDynamics[AxisType_contra, State, Action, Cost, Observation]) – the jax_utils.dynamics.JaxDynamics to be regularized
cost_regularizer (CostRegularizer[State, Action, Cost, Observation, RegularizedCost]) – a cost regularizer
- dynamics: JaxDynamics[AxisType_contra, State, Action, Cost, Observation]
- cost_regularizer: CostRegularizer[State, Action, Cost, Observation, RegularizedCost]
jax_utils.gradient module
Classes and methods to compute the gradient of functions involving JAX array transformations
- class BaseGradientStep(*args, **kwargs)[source]
Bases: JaxDataclassNestedConvertibleToAxes[AxisType_contra], Protocol[AxisType_contra, State_contra, Action, Cost_co, OptimizerState]
Interface for gradient steps on the jax_utils.markov_decision_process.Dynamics cost of a Markov Decision Process.
- class GradientStep(optimizer: Annotated[GradientTransformation, '__jax_dataclasses_static_field__'], dynamics: JaxDynamics[Any, State, Action, Cost_co, Observation_co])[source]
Bases: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState], Generic[AxisType, State, Action, Cost_co, Observation_co, OptimizerState]
The gradient step on the jax_utils.markov_decision_process.Dynamics cost of a Markov Decision Process.
- Parameters:
optimizer (jdc.Static[optax.GradientTransformation]) – an optimizer (defining the stochastic gradient descent variant: RMSProp, Adam, etc.)
dynamics (JaxDynamics[Any, State, Action, Cost_co, Observation_co]) – an MDP dynamics
- optimizer: jdc.Static[optax.GradientTransformation]
- dynamics: JaxDynamics[Any, State, Action, Cost_co, Observation_co]
- init_optimizer(init_action: Action) OptimizerState[source]
Initialize optimizer
- Parameters:
init_action (Action) – initial action before optimization starts
- Returns:
initialized optimizer state
- Return type:
OptimizerState
- compute_gradient(state: State, action: Action) Tuple[Cost_co, Action][source]
Given a state-action pair of the MDP, computes the cost of dynamics and its gradient w.r.t. the action.
- Parameters:
state (State) – state of the MDP
action (Action) – action of the MDP
- Returns:
the cost and its gradient w.r.t. the action
- Return type:
Tuple[Cost_co, Action]
- update(action: Action, gradient_value: GradientUpdates, opt_state: OptimizerState) Tuple[Action, OptimizerState][source]
Update the action and optimizer state
- Parameters:
action (Action) – action of the MDP
gradient_value (GradientUpdates) – gradient of the action
opt_state (OptimizerState) – optimizer state
- Returns:
the new action and optimizer state
- Return type:
Tuple[Action, OptimizerState]
- class VectorizedGradientStep(gradient_step: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState], vectorized_axis: Annotated[AxisType, '__jax_dataclasses_static_field__'])[source]
Bases: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]
A gradient step that is vectorized over a specific axis.
- Parameters:
gradient_step (BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]) – the gradient step to be vectorized
vectorized_axis (jdc.Static[AxisType]) – “named” axis to be vectorized
- gradient_step: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]
- vectorized_axis: jdc.Static[AxisType]
jax_utils.jax_tensor module
Wrappers around jax arrays
- class JaxTensor(*args, **kwargs)[source]
Bases: Tensor[Array, AxisType], ConvertibleToAxes[AxisType], Protocol[AxisType]
Interface representing a Tensor with a JAX array.
A JaxTensor is always “convertible to axes” (see the jax_utils.pytree.ConvertibleToAxes interface) using the jax_utils.pytree.ConvertibleToAxes.convert_to_axes() method. When a Tensor is “converted to axes”, the associated array attribute is no longer a JAX array but an integer indicating the dimension index of the specified axis (argument of method convert_to_axes). If axis is flattened or not present, self.array is set to None when converted to axes.
Converting a jax_utils.jax_tensor.JaxTensor to axes is useful to determine the in_axes and out_axes arguments of jax.vmap and jax.pmap.
- Parameters:
array (jnp.ndarray) – a JAX array (or an optional integer when “converted to axes”)
- array: Array
- is_broadcastable_with(other_tensor_or_shape: Tensor[Array, AxisType] | Tuple[int, ...]) bool[source]
Whether this tensor can be broadcasted with other_tensor_or_shape.
- Parameters:
other_tensor_or_shape (Union[Tensor[ArrayType, AxisType], Tuple[int, ...]]) – another tensor or a shape
- Returns:
True if and only if this tensor can be broadcasted with other_tensor_or_shape, False otherwise.
- Return type:
bool
- classmethod from_flattened_axes(array: Array, flattened_axes: AbstractSet[AxisType], **kwargs) Self[source]
Constructor for when some axes are “flattened” in array. This method simply adds the missing dimensions accordingly.
- Parameters:
array (jnp.ndarray) – a JAX array
flattened_axes (AbstractSet[AxisType]) – the named axes that are “flattened”
- Returns:
an instance of class cls
- Return type:
Self
- convert_to_axes(axis: AxisType | None) Self[source]
Converts the jax_utils.jax_tensor.JaxTensor object so that it can be passed as the in_axes or out_axes argument of jax.vmap and jax.pmap.
- Parameters:
axis (Optional[AxisType]) – a “named” axis
- Returns:
same object as self but with array replaced by the index of axis, i.e., if axis corresponds to the i-th dimension of array, then array is replaced by i; if axis is a flattened axis or simply not present, then array is replaced by None
- Return type:
Self
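The mapping from a named axis to the integer expected by jax.vmap can be sketched as follows (a hypothetical standalone helper, assuming the first len(tensor_axes) - actual_ndim axes are the flattened ones, as described for TensorAxes):

```python
from typing import Hashable, Optional, Sequence

def axis_to_vmap_index(tensor_axes: Sequence[Hashable], actual_ndim: int,
                       axis: Hashable) -> Optional[int]:
    # Leading axes of `tensor_axes` may be flattened, so the positional
    # index is shifted; absent or flattened axes map to None (the value
    # jax.vmap expects for a non-mapped argument).
    if axis not in tensor_axes:
        return None
    n_flattened = len(tensor_axes) - actual_ndim
    index = list(tensor_axes).index(axis) - n_flattened
    return index if index >= 0 else None
```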
- class NonNegativeValues(*args, **kwargs)[source]
Bases: Protocol
Interface to be used in combination with jax_utils.jax_tensor.JaxTensor in order to implement JAX tensors that are constrained to take non-negative values.
- Parameters:
array (jnp.ndarray) – a JAX array
- array: Array
- property values: Array
- class NonNegativeBudgetedValues(*args, **kwargs)[source]
Bases: NonNegativeValues, Protocol
Interface to be used in combination with jax_utils.jax_tensor.JaxTensor in order to implement JAX tensors that are constrained both to take non-negative values and to have their sum over the last axis bounded by a maximal “budget” (max_budget).
- Parameters:
array (jnp.ndarray) – a JAX array
max_budget (Number) – a non-negative number representing the maximum value that the sum over the last axis can take
- array: Array
- max_budget: Number
- property values: Array
- class RegularizedJaxCost(cost: JaxCostType, regularization: JaxRegularizedCostType, lagrangian_coefficient: Annotated[Number, '__jax_dataclasses_static_field__'] = 1)[source]
Bases: RegularizedArrayLikeCost[Array], BaseJaxCompilable, Generic[JaxCostType, JaxRegularizedCostType]
Interface for “regularized” JAX costs, i.e., costs of the form: cost + lagrangian_coefficient * regularization
- Parameters:
cost (JaxCostType) – a JAX array of costs
regularization (JaxRegularizedCostType) – a JAX array of regularization costs
lagrangian_coefficient (jdc.Static[Number]) – a non-negative number quantifying the regularization weight
- cost: JaxCostType
- regularization: JaxRegularizedCostType
- lagrangian_coefficient: Annotated[Number, '__jax_dataclasses_static_field__'] = 1
jax_utils.markov_decision_process module
High-level abstractions for decision problems (Markov Decision processes, etc…)
- class Dynamics(*args, **kwargs)[source]
Bases: Protocol[State, Action_contra, Cost_co, Observation_co]
Interface defining the dynamics of a (Partially Observable) Markov Decision Process.
When an “agent” interacting with the (PO)MDP plays an “action” (a.k.a. “control”) in a given “state”, the (PO)MDP transitions to a new state and the agent observes some signal/feedback in the form of a “cost”/”reward” as well as additional “observations”.
A Dynamics is therefore a callable that maps a state-action pair to a state-cost-observation tuple.
- class CostRegularizer(*args, **kwargs)[source]
Bases: Protocol[State_contra, Action_contra, Cost_contra, Observation_contra, Cost_co]
Interface for callables that map any state-action-cost-observation tuple to a new “regularized” cost.
Example: one may want to penalize actions with high norms, etc.
- class RegularizedDynamics(*args, **kwargs)[source]
Bases: Dynamics[State, Action, RegularizedCost, Observation], Protocol[State, Action, Cost, RegularizedCost, Observation]
Interface defining a wrapper around class jax_utils.markov_decision_process.Dynamics that allows adding a regularization term to the cost.
A RegularizedDynamics is itself a jax_utils.markov_decision_process.Dynamics.
- Parameters:
dynamics (Dynamics[State, Action, Cost, Observation]) – callable defining the dynamics of a (PO)MDP
cost_regularizer (CostRegularizer[State, Action, Cost, Observation, RegularizedCost]) – callable defining a cost transformation
- cost_regularizer: CostRegularizer[State, Action, Cost, Observation, RegularizedCost]
jax_utils.optim module
Classes and methods to optimize functions involving jax arrays transformations via gradient descent
- class OptimizationState(iteration: int, state: State, action: Action, cost: Cost | None = None, optimizer_state: OptimizerState | None = None)[source]
Bases: Generic[State, Action, Cost, OptimizerState]
The current state of an iterative optimization procedure involving the jax_utils.markov_decision_process.Dynamics of a Markov Decision Process (typically a cost minimization).
- Parameters:
iteration (int) – the current iteration step
state (State) – the MDP state
action (Action) – the MDP current action
cost (Optional[Cost], optional) – the current cost associated with the state-action pair. Defaults to None.
optimizer_state (Optional[OptimizerState], optional) – The current
optax.OptState. Defaults to None.
- iteration: int
- state: State
- action: Action
- cost: Cost | None = None
- optimizer_state: OptimizerState | None = None
- property scalar_cost: float
Returns the average of self.cost.
- Returns:
average cost
- Return type:
float
- class OptimStoppingCondition(*args, **kwargs)[source]
Bases: BaseJaxCompilable, Protocol[State, Action, Cost, OptimizerState]
Interface for all stopping conditions of an iterative optimization procedure involving the jax_utils.markov_decision_process.Dynamics of a Markov Decision Process (typically a cost minimization).
The stopping condition depends on an OptimizationState provided as input but may also collect information over steps.
- stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) bool[source]
Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current optimization state
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- property nb_iterations_upper_bound: int
An upper bound on the maximal number of iterations. Default is 1e20.
- class OptimStoppingConditionsCombination(*args, **kwargs)[source]
Bases: OptimStoppingCondition[State, Action, Cost, OptimizerState], Protocol
This interface represents the combination of 2 stopping conditions.
- Parameters:
stopping_condition_1 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
stopping_condition_2 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
- stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- class OptimStoppingConditionIntersection(stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState], stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState])[source]
Bases: OptimStoppingConditionsCombination[State, Action, Cost, OptimizerState]
A combination of 2 stopping conditions using an “AND” operation.
- Parameters:
stopping_condition_1 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
stopping_condition_2 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
- stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) bool[source]
Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current optimization state
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- reset() OptimStoppingConditionIntersection[State, Action, Cost, OptimizerState][source]
Resets the stopping condition to an initial configuration.
- property nb_iterations_upper_bound: int
An upper bound on the maximal number of iterations. Default is 1e20.
- class OptimStoppingConditionUnion(stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState], stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState])[source]
Bases: OptimStoppingConditionsCombination[State, Action, Cost, OptimizerState]
A combination of 2 stopping conditions using an “OR” operation.
- Parameters:
stopping_condition_1 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
stopping_condition_2 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
- stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) bool[source]
Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current optimization state
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- reset() OptimStoppingConditionUnion[State, Action, Cost, OptimizerState][source]
Resets the stopping condition to an initial configuration.
- property nb_iterations_upper_bound: int
An upper bound on the maximal number of iterations. Default is 1e20.
- class MaxIterationsStoppingCondition(max_iterations: int)[source]
Bases: OptimStoppingCondition[State, Action, Cost, OptimizerState]
A stopping condition that stops when the number of iterations exceeds a given threshold.
- Parameters:
max_iterations (int) – maximal number of iterations before the stopping condition is raised
- max_iterations: int
- stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) bool[source]
Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current optimization state
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- property nb_iterations_upper_bound: int
An upper bound on the maximal number of iterations. Default is 1e20.
- class MinIterationsStoppingCondition(min_iterations: int)[source]
Bases: OptimStoppingCondition[State, Action, Cost, OptimizerState]
A stopping condition that keeps the optimization running as long as the number of iterations is below a given threshold.
- Parameters:
min_iterations (int) – minimal number of iterations before the stopping condition is raised
- min_iterations: int
- stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) bool[source]
Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current optimization state
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- class MinDeltaActionStoppingCondition(relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True), absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True))[source]
Bases: OptimStoppingCondition[State, JaxTensorType, Cost, OptimizerState]
Stops when the action stops changing significantly (as defined by the relative & absolute tolerances). The previous action is saved in memory to compare it to the new action.
The delta between the previous and new actions is then compared to the absolute and relative tolerances, i.e., the stopping condition is raised when:
all(abs(new_action - previous_action) <= absolute_tolerance + relative_tolerance * maximum(abs(new_action), abs(previous_action)))
- Parameters:
relative_tolerance (jnp.ndarray) – relative tolerance for action variations. Default is 1e-6.
absolute_tolerance (jnp.ndarray) – absolute tolerance for action variations. Default is 1e-6.
- relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
- absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
- stop(optimization_state: OptimizationState[State, JaxTensorType, Cost, OptimizerState]) bool[source]
Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current optimization state
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- reset() MinDeltaActionStoppingCondition[State, JaxTensorType, Cost, OptimizerState][source]
Resets the stopping condition to an initial configuration.
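The stopping test quoted in the class docstring can be sketched as a hypothetical standalone helper:

```python
import numpy as np

def action_converged(new_action, previous_action,
                     relative_tolerance=1e-6, absolute_tolerance=1e-6):
    # Element-wise delta compared to a mixed absolute/relative threshold,
    # as in the MinDeltaActionStoppingCondition docstring.
    return bool(np.all(
        np.abs(new_action - previous_action)
        <= absolute_tolerance
        + relative_tolerance * np.maximum(np.abs(new_action),
                                          np.abs(previous_action))
    ))
```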
- class MinDeltaCostStoppingCondition(relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True), absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True), window_length: int = 2)[source]
Bases: OptimStoppingCondition[State, Action, Array, OptimizerState]
Stops when the cost stops decreasing significantly (as defined by the relative & absolute tolerances).
The window_length last cost values are saved in memory and the stopping condition is raised when the delta between the min and max of these values is below a threshold, i.e., when:
all(maximum(last_window_length_costs) - minimum(last_window_length_costs) < absolute_tolerance + relative_tolerance * minimum(last_window_length_costs))
- Parameters:
relative_tolerance (jnp.ndarray) – relative tolerance for cost variations. Default is 1e-6.
absolute_tolerance (jnp.ndarray) – absolute tolerance for cost variations. Default is 1e-6.
window_length (int) – maximal length of the queue storing the past history of costs (costs older than window_length time steps are discarded)
- relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
- absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
- window_length: int = 2
- stop(optimization_state: OptimizationState[State, Action, Array, OptimizerState]) bool[source]
Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current optimization state
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- reset() MinDeltaCostStoppingCondition[State, Action, OptimizerState][source]
Resets the stopping condition to an initial configuration.
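The windowed test quoted in the class docstring can be sketched with a bounded deque standing in for the cost history (a hypothetical helper, not the class itself):

```python
from collections import deque
import numpy as np

def cost_converged(cost_window, relative_tolerance=1e-6,
                   absolute_tolerance=1e-6):
    # Spread of the last `window_length` costs compared to a mixed
    # absolute/relative threshold, as in the docstring above.
    costs = np.asarray(cost_window)
    return bool(costs.max() - costs.min()
                < absolute_tolerance + relative_tolerance * costs.min())

window = deque(maxlen=2)   # window_length = 2: older costs are discarded
window.extend([1.0, 0.5])  # cost still decreasing significantly
```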
- class CostHistory(initlist=None)[source]
Bases: UserList[AverageableArrayLike[Array]]
List of costs (typically JAX arrays).
- class GradientDescentOptimizationLoop(gradient_step: BaseGradientStep[AxisType, State, Action, Cost, OptimizerState], stopping_condition: OptimStoppingCondition[State, Action, Cost, OptimizerState])[source]
Bases: Generic[AxisType, State, Action, Cost, Observation, OptimizerState]
Gradient descent optimization procedure involving the jax_utils.markov_decision_process.Dynamics of a Markov Decision Process. This class is a callable that recursively applies gradient_step until a stopping_condition is met.
- Parameters:
gradient_step (BaseGradientStep[AxisType, State, Action, Cost, OptimizerState]) – a gradient step
stopping_condition (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
- gradient_step: BaseGradientStep[AxisType, State, Action, Cost, OptimizerState]
- stopping_condition: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- resume(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) Tuple[OptimizationState[State, Action, Cost, OptimizerState], OptimizationState[State, Action, Cost, OptimizerState], CostHistory][source]
Resume optimization loop from
optimization_state(cold start).- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – state from which to resume the optimization procedure.
- Returns:
same outputs as
GradientDescentOptimizationLoop.__call__()- Return type:
Tuple[OptimizationState[State, Action, Cost, OptimizerState], OptimizationState[State, Action, Cost, OptimizerState], CostHistory]
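The loop's control flow (apply a gradient step, record the cost, stop once the condition fires) can be sketched as follows; the names, the simple delta-based stopping test, and the plain Python loop are illustrative assumptions, not the package's API.

```python
import jax
import jax.numpy as jnp

def gradient_descent_loop(cost_fn, params, learning_rate=0.1, tol=1e-6, max_iters=1000):
    """Repeatedly apply a gradient step until the cost stops decreasing."""
    grad_fn = jax.grad(cost_fn)
    cost_history = [float(cost_fn(params))]
    for _ in range(max_iters):
        params = params - learning_rate * grad_fn(params)  # one gradient step
        cost_history.append(float(cost_fn(params)))
        # simplified stopping condition on the last two costs
        if abs(cost_history[-1] - cost_history[-2]) < tol:
            break
    return params, cost_history

# Minimize a quadratic with minimum at x = (3, 3)
solution, history = gradient_descent_loop(lambda x: jnp.sum((x - 3.0) ** 2), jnp.zeros(2))
```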
jax_utils.pytree module
Classes and methods to easily map pytrees to axes (e.g., for vectorization)
- class ConvertibleToAxes(*args, **kwargs)[source]
Bases:
Protocol[AxisType_contra]- convert_to_axes(axis: AxisType_contra | None) Self[source]
Returns an object that can be used in argument
in_axesorout_axesof jax.vmap or jax.pmap- Parameters:
axis (Optional[AxisType_contra]) – a “named” axis over which to apply vectorization.
- Returns:
same object as
selfbut with all array-like objects replaced by axes- Return type:
Self
- pytree_to_axes(pytree: Any, vectorized_axis: AxisType, default_axis: int | None = None) Any[source]
Transforms all the
ConvertibleToAxesleaves of a given pytree to axes by applying methodConvertibleToAxes.convert_to_axes(). This is useful for vectorizing functions involvingConvertibleToAxesobjects.- Parameters:
pytree (Any) – any Python pytree containing
ConvertibleToAxesleavesvectorized_axis (AxisType) –
a “named” axis over which to apply vectorization
default_axis (Optional[int], optional) – a default axis value for leaves that are not
ConvertibleToAxes. Defaults to None.
- Returns:
same pytree as given in input but where all
ConvertibleToAxesare converted to axes.- Return type:
Any
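A minimal sketch of this leaf-conversion idea, assuming a hypothetical NamedTensorSpec leaf type and using jax.tree_util.tree_map with is_leaf to stop recursion at convertible leaves:

```python
import dataclasses
import jax

@dataclasses.dataclass
class NamedTensorSpec:
    """Stand-in for a ConvertibleToAxes leaf: maps a named axis to a dimension."""
    axis_names: tuple  # e.g. ("batch", "time")

    def convert_to_axes(self, axis):
        # Vectorize over the position of `axis` if present, else do not map it
        return self.axis_names.index(axis) if axis in self.axis_names else None

def pytree_to_axes(pytree, vectorized_axis, default_axis=None):
    """Replace convertible leaves by axis indices; other leaves get default_axis."""
    return jax.tree_util.tree_map(
        lambda leaf: leaf.convert_to_axes(vectorized_axis)
        if isinstance(leaf, NamedTensorSpec) else default_axis,
        pytree,
        is_leaf=lambda leaf: isinstance(leaf, NamedTensorSpec),
    )

axes = pytree_to_axes(
    {"x": NamedTensorSpec(("batch", "time")), "y": 0.0},
    vectorized_axis="batch",
)
```

The resulting pytree of integers and None values has exactly the shape expected by the in_axes argument of jax.vmap.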
jax_utils.tranform module
JAX array transformations (scaling, …)
- scale_jax_tensor(tensor: JaxTensor[AxisType], scaling_factor: Array) JaxTensor[AxisType][source]
Scales the
valuesof aJaxTensor, i.e., multipliesvaluesbyscaling_factor. Note that propertyvaluesinJaxTensordiffers from attributearray.
- unscale_jax_tensor(scaled_tensor: JaxTensor[AxisType], scaling_factor: Array) JaxTensor[AxisType][source]
Reverse transformation of function
scale_jax_tensor.
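The two functions are exact inverses: scaling multiplies by the factor and unscaling divides by it. A minimal stand-alone sketch on raw arrays (the JaxTensor wrapper and the real function signatures are omitted here):

```python
import jax.numpy as jnp

def scale(values, scaling_factor):
    """Multiply values elementwise by the scaling factor."""
    return values * scaling_factor

def unscale(scaled_values, scaling_factor):
    """Reverse transformation of `scale`."""
    return scaled_values / scaling_factor

values = jnp.array([2.0, 4.0])
roundtrip = unscale(scale(values, jnp.array(0.5)), jnp.array(0.5))
```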
- class JaxScaler(scaling_factors: ~jax.Array = Array(1., dtype=float32, weak_type=True), scaling_axes: ~typing.Annotated[~typing.Set[~jax_utils.tranform.AxisType], '__jax_dataclasses_static_field__'] = <factory>, tensor_types_to_scale: ~typing.Annotated[~typing.Tuple[~typing.Type[~jax_utils.jax_tensor.JaxTensor]], '__jax_dataclasses_static_field__'] = (<class 'jax_utils.jax_tensor.JaxTensor'>,), tensor_types_not_to_scale: ~typing.Annotated[~typing.Tuple[~jax_utils.jax_tensor.JaxTensor], '__jax_dataclasses_static_field__'] = ())[source]
Bases:
BaseJaxCompilable,Generic[AxisType]When applying gradient descent optimization algorithms, it is often helpful to scale all the arrays/tensors involved in the computations.
This class makes it easy to scale/unscale
jax_utils.jax_tensor.JaxTensor’s (even when the tensors are stored in a nested pytree structure).- Parameters:
scaling_factors (jnp.ndarray) – Multiplicative factors for scaling. Defaults to
jnp.array(1.0).scaling_axes (jdc.Static[Set[AxisType]]) – “Named” axes of
scaling_factors. Thus, the number of axes should matchscaling_factors.ndim. Defaults tojdc.field(default_factory=set).tensor_types_to_scale (jdc.Static[Tuple[Type[JaxTensor]]]) – Object types that should be scaled in a pytree. Defaults to
jax_utils.jax_tensor.JaxTensor.tensor_types_not_to_scale (jdc.Static[Tuple[JaxTensor]]) – Object types that should not be scaled in a pytree. Defaults to
tuple().
- scaling_factors: Array = Array(1., dtype=float32, weak_type=True)
- scaling_axes: Annotated[Set[AxisType], '__jax_dataclasses_static_field__']
- tensor_types_to_scale: Annotated[Tuple[Type[JaxTensor]], '__jax_dataclasses_static_field__'] = (<class 'jax_utils.jax_tensor.JaxTensor'>,)
- classmethod from_tensor(tensor: ~jax_utils.jax_tensor.JaxTensor[~jax_utils.tranform.AxisType], scaling_axes: ~typing.Set[~jax_utils.tranform.AxisType], tensor_types_to_scale: ~typing.Annotated[~typing.Tuple[~typing.Type[~jax_utils.jax_tensor.JaxTensor]], '__jax_dataclasses_static_field__'] = (<class 'jax_utils.jax_tensor.JaxTensor'>,), tensor_types_not_to_scale: ~typing.Annotated[~typing.Tuple[~jax_utils.jax_tensor.JaxTensor], '__jax_dataclasses_static_field__'] = (), factor: float = 1.0) JaxScaler[source]
Alternative constructor that automatically computes the
scaling_factorsbased on ajax_utils.jax_tensor.JaxTensor, thescaling_axesand a targetfactor.If
my_scaler = JaxScaler.from_tensor(tensor=my_tensor, scaling_axes=my_axes, factor=my_factor)then the scaled tensormy_scaler.scale(my_tensor)will be ajax_utils.jax_tensor.JaxTensorwhosemean_over_axes(axes=my_axes)is equal to a “constant” Jax array with all elements equal tomy_factor.- Parameters:
tensor (JaxTensor[AxisType]) – a jax tensor
scaling_axes (Set[AxisType]) – a set of “named” axes present in tensor
tensor_types_to_scale (jdc.Static[Tuple[Type[JaxTensor]]]) – Object types that should be scaled in a pytree. Defaults to
jax_utils.jax_tensor.JaxTensor.tensor_types_not_to_scale (jdc.Static[Tuple[JaxTensor]]) – Object types that should not be scaled in a pytree. Defaults to
tuple().
- Returns:
a scaler instance
- Return type:
JaxScaler
- scale(pytree: Any) Any[source]
Scales any pytree containing
jax_utils.jax_tensor.JaxTensor’s- Parameters:
pytree (Any) – pytree to be scaled
- Returns:
the scaled pytree
- Return type:
Any
- unscale(scaled_pytree: Any) Any[source]
Unscales any pytree containing
jax_utils.jax_tensor.JaxTensor’s (reverse transformation of methodscale).- Parameters:
scaled_pytree (Any) – pytree to be unscaled
- Returns:
the unscaled pytree
- Return type:
Any
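The scale/unscale pair and the from_tensor-style construction can be illustrated with a simplified scaler on raw arrays (the class and method names here are hypothetical; the real class works on JaxTensor pytrees):

```python
import jax.numpy as jnp

class SimpleScaler:
    """Simplified scaler: factors chosen so the scaled mean equals a target."""

    def __init__(self, scaling_factors):
        self.scaling_factors = scaling_factors

    @classmethod
    def from_array(cls, array, axes, factor=1.0):
        # After scaling, the mean over `axes` equals `factor`
        return cls(factor / jnp.mean(array, axis=axes, keepdims=True))

    def scale(self, array):
        return array * self.scaling_factors

    def unscale(self, scaled):
        return scaled / self.scaling_factors

arr = jnp.array([[1.0, 3.0], [2.0, 6.0]])
scaler = SimpleScaler.from_array(arr, axes=1, factor=1.0)
scaled = scaler.scale(arr)
```

Scaling of this kind keeps gradient magnitudes comparable across tensors, which is the motivation stated above for using it with gradient descent.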
- expand_scaling_factors(tensor_axes: TensorAxes[AxisType], missing_axes: Set[AxisType]) Array[source]
Expands the dimensions of
self.scaling_factorsto match thetensor_axesgiven as inputs.This method only adds new dimensions of size 1.
- Parameters:
tensor_axes (TensorAxes[AxisType]) – a set of “named” axes
missing_axes (Set[AxisType]) – “named” axes of tensor_axes absent from self.scaling_factors
- Returns:
- same as
self.scaling_factorsbut with additional dimensions of size 1. The total number of dimensions of the returned array is equal to
len(tensor_axes).
- same as
- Return type:
jnp.ndarray
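The dimension expansion amounts to inserting size-1 axes at the positions of the missing named axes so the factors broadcast against the target tensor; a sketch with hypothetical helper names:

```python
import jax.numpy as jnp

def expand_factors(factors, tensor_axes, factor_axes):
    """Insert a size-1 dimension for every tensor axis absent from factor_axes."""
    missing_positions = tuple(
        i for i, axis in enumerate(tensor_axes) if axis not in factor_axes
    )
    return jnp.expand_dims(factors, axis=missing_positions)

factors = jnp.array([0.5, 2.0])  # one factor per "feature"
expanded = expand_factors(factors, ("batch", "time", "feature"), {"feature"})
# expanded now broadcasts against any array of shape (batch, time, feature)
```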
jax_utils.typing module
Defines useful types for the package
- class DataclassInstance(*args, **kwargs)[source]
Bases:
ProtocolType for Python dataclasses (see
dataclasses.dataclass).This code is copy-pasted from
_typeshed(but is simpler to import).
- class HashableIndexingOrSlicing(*args, **kwargs)[source]
Bases:
DataclassInstance,Hashable,ProtocolPython slices and lists (of indices) are not hashable (slices became hashable in Python >= 3.12). This is a shared interface for hashable slices and indices.
- property values: slice | int | List[int]
Property returning index values in the form of a slice, an int (for singletons) or a list of int.
- Returns:
index values
- Return type:
Union[slice, int, List[int]]
- class HashableSlicing(start: int | None = None, stop: int | None = None, step: int | None = None)[source]
Bases:
HashableIndexingOrSlicingPython slices are not hashable for Python versions < 3.12. This class makes it possible to define hashable slices.
- Parameters:
start (Optional[int]) – initial value of the slice (included)
stop (Optional[int]) – last value of the slice (excluded)
step (Optional[int]) – step between values
- start: int | None = None
- stop: int | None = None
- step: int | None = None
- property values: slice
Property returning index values in the form of a slice.
- Returns:
index values
- Return type:
slice
- class HashableIndexing(*indices: int)[source]
Bases:
HashableIndexingOrSlicingPython lists are not hashable. This class makes it possible to define hashable lists of indices.
- Parameters:
indices (Union[int, Tuple[int]]) – index or list of (integer-valued) indices
- indices: int | Tuple[int]
- property values: int | List[int]
Property returning index values in the form of an int (for singletons) or a list of int.
- Returns:
index values
- Return type:
Union[int, List[int]]
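A minimal sketch of how such hashable wrappers can be built with frozen dataclasses (these are simplified stand-ins, not the package's exact definitions — in particular the real HashableIndexing takes *indices):

```python
import dataclasses
from typing import List, Optional, Tuple, Union

@dataclasses.dataclass(frozen=True)
class HashableSlicing:
    """Hashable stand-in for a slice (frozen dataclasses are hashable)."""
    start: Optional[int] = None
    stop: Optional[int] = None
    step: Optional[int] = None

    @property
    def values(self) -> slice:
        return slice(self.start, self.stop, self.step)

@dataclasses.dataclass(frozen=True)
class HashableIndexing:
    """Hashable stand-in for an index or list of indices."""
    indices: Tuple[int, ...]

    @property
    def values(self) -> Union[int, List[int]]:
        return self.indices[0] if len(self.indices) == 1 else list(self.indices)

data = list(range(10))
window = HashableSlicing(stop=3)
cache = {window: data[window.values]}  # usable as a dict/cache key
```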
jax_utils.vectorization module
Classes and methods to easily vectorize jax tensors
- class JaxVectorizableProtocol(*args, **kwargs)[source]
Bases:
JaxCompilableProtocol,ConvertibleToAxes[AxisType],ProtocolInterface for classes manipulating JAX arrays. It defines a special attribute
vectorized_axiscorresponding to the “named” axis over which to apply vectorization.“Vectorizable” classes are also “compilable” (see interface
jax_utils.compilation.JaxCompilableProtocol) and convertible to axes (seejax_utils.pytree.ConvertibleToAxesinterface).- Parameters:
vectorized_axis (AxisType) – “named” axis to be vectorized
- vectorized_axis: AxisType
- vectorize(in_default_axis: int | None = None, out_default_axis: int | None = -1) Callable[[Callable], Callable][source]
Parametrized decorator for methods of classes implementing
JaxVectorizableProtocolinterface. It allows one to vectorize the decorated methods according to a specified axis, namely attributeJaxVectorizableProtocol.vectorized_axis(seeJaxVectorizableProtocol).
jax_utils.pytree.ConvertibleToAxesinterface is considered a leaf of the pytree. When ajax_utils.pytree.ConvertibleToAxesleaf is encountered, methodjax_utils.pytree.ConvertibleToAxes.convert_to_axes()is used to determine the dimension to vectorize. When a JAX array leaf or any other leaf is encountered instead, the default valuein_default_axis(for input arguments) orout_default_axis(for output arguments) is used to determine the dimension to vectorize.- Parameters:
in_default_axis (Optional[int], optional) –
default index used in
in_axesargument of method jax.vmap when a leaf that is notjax_utils.pytree.ConvertibleToAxesis encountered in an input pytree. Defaults to None.out_default_axis (Optional[int], optional) –
default index used in
out_axesargument of method jax.vmap when a leaf that is notjax_utils.pytree.ConvertibleToAxesis encountered in an output pytree. Defaults to -1.
- Returns:
decorator with parameters
in_default_axisandout_default_axisspecified- Return type:
Callable[[Callable], Callable]
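The decorator ultimately builds an in_axes/out_axes specification for jax.vmap that mirrors the pytree of arguments; the underlying mechanism can be illustrated directly (the function and data here are made up):

```python
import jax
import jax.numpy as jnp

def weighted_sum(params):
    """Dot product of a weight vector with a shared input vector."""
    return jnp.sum(params["weights"] * params["x"])

batched = {
    "weights": jnp.ones((5, 3)),             # batched along axis 0
    "x": jnp.arange(3, dtype=jnp.float32),   # shared across the batch
}
# in_axes mirrors the argument pytree: map axis 0 of "weights", broadcast "x"
batched_fn = jax.vmap(weighted_sum, in_axes=({"weights": 0, "x": None},))
out = batched_fn(batched)
```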
- class JaxDataclassNestedConvertibleToAxes(*args, **kwargs)[source]
Bases:
DataclassInstance,ConvertibleToAxes[AxisType_contra],BaseJaxCompilable,Protocol[AxisType_contra]Interface for dataclasses that can be jit-compiled and are “convertible to axes”, with a concrete implementation of
convert_to_axesmethod that is specific to dataclasses.- convert_to_axes(axis: AxisType_contra | None) Self[source]
Concrete implementation of
convert_to_axesfor dataclasses. All the fields of the dataclass are inspected and methodconvert_to_axesis applied whenever the field isConvertibleToAxes.- Parameters:
axis (Optional[AxisType_contra]) – a “named” axis
- Returns:
same object as
selfbut with array-like fields converted to axes- Return type:
Self
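A sketch of this field-by-field conversion, assuming hypothetical AxisLeaf and Container dataclasses:

```python
import dataclasses

@dataclasses.dataclass
class AxisLeaf:
    """Convertible field: maps a named axis to a dimension index."""
    axis_names: tuple

    def convert_to_axes(self, axis):
        return self.axis_names.index(axis) if axis in self.axis_names else None

@dataclasses.dataclass
class Container:
    tensor: AxisLeaf
    label: str  # not convertible: left unchanged

    def convert_to_axes(self, axis):
        # Inspect every dataclass field; recurse into convertible ones
        updates = {
            f.name: getattr(self, f.name).convert_to_axes(axis)
            for f in dataclasses.fields(self)
            if hasattr(getattr(self, f.name), "convert_to_axes")
        }
        return dataclasses.replace(self, **updates)

converted = Container(AxisLeaf(("batch", "time")), "demo").convert_to_axes("batch")
```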