jax_utils package
Submodules
jax_utils.common_tensor module
Wrappers around numpy/jax arrays
- check_ndim_in(array: Array | ndarray, allowed_ndims: Iterable[int])
Checks whether the number of dimensions of an array is one of some allowed values.
- Parameters:
array (Array) – NumPy or JAX array
allowed_ndims (Iterable[int]) – numbers of dimensions allowed for array
- Raises:
ValueError – when the number of dimensions of array is not in allowed_ndims.
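A minimal sketch of this check. It assumes only the ndim attribute that NumPy and JAX arrays share; check_ndim_in_sketch is a hypothetical name, not the library's code.

```python
from typing import Iterable

import numpy as np


def check_ndim_in_sketch(array: np.ndarray, allowed_ndims: Iterable[int]) -> None:
    """Hypothetical re-implementation: raise if array.ndim is not an allowed value."""
    allowed = set(allowed_ndims)  # materialize once: the iterable may be a one-shot generator
    if array.ndim not in allowed:
        raise ValueError(
            f"array has {array.ndim} dimension(s), expected one of {sorted(allowed)}"
        )


check_ndim_in_sketch(np.zeros((2, 3)), allowed_ndims=[2, 3])  # passes silently
```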
- is_broadcastable(shape_1: Tuple[int, ...], shape_2: Tuple[int, ...]) → bool
Whether the shapes of two arrays/tensors can be broadcast together.
- Parameters:
shape_1 (Tuple[int, ...]) – shape of the 1st array
shape_2 (Tuple[int, ...]) – shape of the 2nd array
- Returns:
True if and only if the two arrays/tensors can be broadcast together, False otherwise.
- Return type:
bool
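A sketch of the broadcasting test. It assumes the library follows NumPy's rule: align the shapes on their trailing dimensions and require each pair of dimensions to be equal or to contain a 1. The function name is hypothetical; the loop at the end cross-checks it against numpy.broadcast_shapes (NumPy 1.20 or later).

```python
from itertools import zip_longest
from typing import Tuple

import numpy as np


def is_broadcastable_sketch(shape_1: Tuple[int, ...], shape_2: Tuple[int, ...]) -> bool:
    """Apply NumPy's broadcasting rule, comparing dimensions from the right."""
    # Missing leading dimensions behave like size-1 dimensions.
    for dim_1, dim_2 in zip_longest(reversed(shape_1), reversed(shape_2), fillvalue=1):
        if dim_1 != dim_2 and dim_1 != 1 and dim_2 != 1:
            return False
    return True


# Cross-check against NumPy itself.
for s1, s2 in [((3, 1), (1, 4)), ((5,), (4, 5)), ((2, 3), (3, 2))]:
    try:
        np.broadcast_shapes(s1, s2)
        expected = True
    except ValueError:
        expected = False
    assert is_broadcastable_sketch(s1, s2) == expected
```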
- class TensorAxes(initial: AbstractSet[T] | Sequence[T] | Iterable[T] | None = None, tensor_min_dim: int = 0)
Bases: OrderedSet[T]
An ordered set of “named” array/tensor axes.
The position in the set defines the axis index in the tensor.
The first len(self) - tensor_min_dim axes are optional: if absent, they are assumed to be “flattened”.
The leading axes, rather than the last tensor_min_dim axes, are the optional ones in order to follow the logic of array broadcasting, where leading dimensions can be omitted.
- reverse_index(key: T) → int
Computes the index of element key but, rather than returning a non-negative integer like method index, returns a negative integer, e.g., -1 for the last element, -2 for the penultimate element, …
- Parameters:
key (T) – any element present in self (which is an ordered set)
- Returns:
negative integer corresponding to the position of key in self (starting from -1 for the last element and decrementing by 1 for each preceding element)
- Return type:
int
- property mandatory: OrderedSet[T]
Returns: OrderedSet[T]: ordered set of non-optional axes.
- property optional: OrderedSet[T]
Returns: OrderedSet[T]: ordered set of optional axes.
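A minimal sketch of the optional/mandatory split and of reverse_index. It uses a hypothetical list-backed stand-in instead of the OrderedSet base class, so the example has no third-party dependency, and returns plain lists instead of OrderedSet objects.

```python
from typing import Generic, Hashable, Iterable, List, TypeVar

T = TypeVar("T", bound=Hashable)


class TensorAxesSketch(Generic[T]):
    """Hypothetical stand-in for TensorAxes, backed by a list of unique axis names."""

    def __init__(self, initial: Iterable[T] = (), tensor_min_dim: int = 0) -> None:
        self._axes: List[T] = list(dict.fromkeys(initial))  # drop duplicates, keep order
        self.tensor_min_dim = tensor_min_dim

    def __len__(self) -> int:
        return len(self._axes)

    def index(self, key: T) -> int:
        return self._axes.index(key)

    def reverse_index(self, key: T) -> int:
        # -1 for the last axis, -2 for the penultimate one, ...
        return self.index(key) - len(self)

    @property
    def optional(self) -> List[T]:
        # The leading axes may be omitted, as with broadcasting.
        return self._axes[: len(self) - self.tensor_min_dim]

    @property
    def mandatory(self) -> List[T]:
        return self._axes[len(self) - self.tensor_min_dim :]


axes = TensorAxesSketch(["batch", "time", "feature"], tensor_min_dim=1)
print(axes.optional, axes.mandatory, axes.reverse_index("batch"))
# → ['batch', 'time'] ['feature'] -3
```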
- expand_dims_axis(tensor_axes: TensorAxes[AxisType], missing_axes: AbstractSet[AxisType]) → Tuple[int, ...]
Computes the axis argument to pass to numpy.expand_dims(a, axis) when some axes in tensor_axes are not present in the NumPy array a (they are “flattened”).
- Parameters:
tensor_axes (TensorAxes[AxisType]) – some array/tensor axes
missing_axes (AbstractSet[AxisType]) – missing axes in the array/tensor
- Raises:
ValueError – when missing_axes is not a subset of tensor_axes
- Returns:
the axis argument to pass to numpy.expand_dims(a, axis)
- Return type:
Tuple[int, …]
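A sketch of the computation, assuming tensor_axes can be iterated in order and that the result lists the positions of the missing axes among all tensor axes. The function name is hypothetical, and whether the library returns non-negative or negative indices is not stated here. Passing a tuple as the axis argument of numpy.expand_dims requires NumPy 1.18 or later.

```python
from typing import AbstractSet, Hashable, Sequence, Tuple

import numpy as np


def expand_dims_axis_sketch(
    tensor_axes: Sequence[Hashable], missing_axes: AbstractSet[Hashable]
) -> Tuple[int, ...]:
    """Hypothetical version: positions of the missing axes among all tensor axes."""
    unknown = set(missing_axes) - set(tensor_axes)
    if unknown:
        raise ValueError(f"{unknown} are not axes of the tensor")
    return tuple(i for i, axis in enumerate(tensor_axes) if axis in missing_axes)


tensor_axes = ["batch", "time", "feature"]
a = np.zeros((2, 5))  # "time" is flattened: only "batch" and "feature" are present
axis = expand_dims_axis_sketch(tensor_axes, {"time"})
print(axis, np.expand_dims(a, axis).shape)  # → (1,) (2, 1, 5)
```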
- class AverageableArrayLike(*args, **kwargs)
Bases: Protocol[ArrayType_co]
Shared interface of all classes with a mean method returning a scalar array (the mean of the values of the initial array).
Examples of classes implementing this interface: (JAX) NumPy arrays, …
- class Tensor(*args, **kwargs)
Bases: DataclassInstance, AverageableArrayLike, Protocol[ArrayType, AxisType]
A wrapper for NumPy/JAX arrays with two additional features: (1) axes are “named” to facilitate manipulation and debugging (an axis “name” can be a string or any other hashable AxisType); (2) the returned values of the tensor can differ from the actual array given as input at construction, which makes it easy to implement a “change of variables” (a.k.a. “substitution”).
- Parameters:
_tensor_axes (ClassVar[TensorAxes[AxisType]]) – class attribute defining the axes of the array
array (ArrayType) – a (JAX) NumPy array containing all relevant data
- array: ArrayType
- check_array()
Checks the validity of the array attribute at construction. To be overridden if needed.
- getitem_from_axes(axes_keys: Dict[AxisType, Any]) → Self
Analogue of method __getitem__, but where array slicing/indexing is explicitly applied to named axes.
- Parameters:
axes_keys (Dict[AxisType, Any]) – a mapping between axis names and slices/lists of indices/…
- Returns:
a new Tensor of the same type with restricted data
- Return type:
Self
- property values: ArrayType
Override to apply a transformation to self.array (“change of variable”).
For example, to manipulate (JAX) tensors constrained to always have non-negative values in any situation (even if the tensor is updated by a gradient descent procedure, for instance), one can override values to return jax.nn.relu(self.array).
By default, values implements the identity, i.e., returns self.array.
- Returns:
the transformed array
- Return type:
ArrayType
- reverse_values(values: ArrayType) → ArrayType
Inverse of the transformation implemented in the values property.
By default, self.array = self.values, so calling values is equivalent to accessing array.
When the transformation mapping array to values is not the identity, reverse_values should be overridden accordingly.
- Parameters:
values (ArrayType) – the output of property values
- Returns:
the array obtained by applying the inverse transformation of values
- Return type:
ArrayType
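A sketch of a change of variables with a matching reverse_values, using a hypothetical dataclass and NumPy instead of JAX. It swaps the relu mentioned above for exp/log: relu is not invertible on negative inputs, whereas exp is a bijection onto the positive reals, so reverse_values recovers array exactly.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class PositiveTensorSketch:
    """Hypothetical tensor whose values are constrained to be positive (exp/log substitution)."""

    array: np.ndarray  # unconstrained parameters, e.g. updated by gradient descent

    @property
    def values(self) -> np.ndarray:
        return np.exp(self.array)  # always > 0, whatever self.array is

    def reverse_values(self, values: np.ndarray) -> np.ndarray:
        return np.log(values)  # inverse of the transformation in `values`


tensor = PositiveTensorSketch(array=np.array([-3.0, 0.0, 2.0]))
print(tensor.values.min() > 0)  # → True
restored = PositiveTensorSketch(array=tensor.reverse_values(tensor.values))
```

Gradient updates can then act on array without any constraint while values stays positive; reverse_values is what lets one build such a tensor from target values.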
- property dtype: dtype
Gets the data type (dtype) of self.values
- property shape: Tuple[int, ...]
Gets the shape of self.values
- property ndim: int
Gets the number of dimensions of self.values
- property axes: TensorAxes[AxisType]
Possible “named” axes of the tensor
- property actual_axes: TensorAxes[AxisType]
Actual axes of the tensor
- mean_over_axes(axes: AbstractSet[AxisType]) → ArrayType | Any
Averages values along one or several named axes.
- Parameters:
axes (AbstractSet[AxisType]) – axes along which the means are computed
- Returns:
new array containing the mean values
- Return type:
Union[ArrayType, Any]
- sum_over_axes(axes: AbstractSet[AxisType]) → ArrayType | Any
Sums values along one or several named axes.
- Parameters:
axes (AbstractSet[AxisType]) – axes along which the sums are computed
- Returns:
new array containing the summed values
- Return type:
Union[ArrayType, Any]
- has(axis: AxisType) → bool
Whether axis is one of the possible axes of the tensor.
- Parameters:
axis (AxisType) – a “named” axis
- Returns:
True if axis is one of the possible axes of the tensor, False otherwise.
- Return type:
bool
- has_actual(axis: AxisType) → bool
Whether axis is one of the actual axes of the tensor.
- Parameters:
axis (AxisType) – a “named” axis
- Returns:
True if axis is one of the actual axes of the tensor, False otherwise.
- Return type:
bool
- index(axis: AxisType) → int
Index of axis in the tensor, i.e., returns i if and only if axis is the i-th dimension of the tensor.
- Parameters:
axis (AxisType) – a “named” axis
- Returns:
the corresponding index
- Return type:
int
- reverse_index(axis: AxisType) → int
Index of axis in the tensor but in reverse order, i.e., returns -i if and only if axis is the (n - i)-th dimension of the tensor (with n the total number of dimensions).
- Parameters:
axis (AxisType) – a “named” axis
- Returns:
the corresponding index in reverse order
- Return type:
int
- size(axis: AxisType) → int
The size of a given axis. Returns 0 when the axis is “flattened”.
- Parameters:
axis (AxisType) – a “named” axis
- Raises:
ValueError – if axis is not a possible axis
- Returns:
the size of axis (0 when the axis is “flattened”)
- Return type:
int
- is_broadcastable_with(other_tensor_or_shape: Tensor[ArrayType, AxisType] | Tuple[int, ...]) → bool
Whether this tensor can be broadcast together with other_tensor_or_shape.
- Parameters:
other_tensor_or_shape (Union[Tensor[ArrayType, AxisType], Tuple[int, ...]]) – another tensor or a shape
- Returns:
True if and only if this tensor can be broadcast together with other_tensor_or_shape, False otherwise.
- Return type:
bool
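To illustrate the named-axes methods above, here is a hypothetical, NumPy-only class (not the library's implementation). It assumes, following TensorAxes, that an array with fewer dimensions than declared axes holds the last ones, and it resolves axis names to positional indices before calling NumPy's mean and sum.

```python
from dataclasses import dataclass
from typing import AbstractSet, ClassVar, Tuple

import numpy as np


@dataclass
class NamedTensorSketch:
    """Hypothetical named-axes tensor; the leading axes may be absent from `array`."""

    _tensor_axes: ClassVar[Tuple[str, ...]] = ("batch", "time", "feature")
    array: np.ndarray

    @property
    def values(self) -> np.ndarray:
        return self.array  # identity change of variables (the default behavior)

    @property
    def actual_axes(self) -> Tuple[str, ...]:
        # Assumption: the array holds the *last* ndim axes (broadcasting convention).
        return self._tensor_axes[len(self._tensor_axes) - self.values.ndim :]

    def index(self, axis: str) -> int:
        return self.actual_axes.index(axis)

    def reverse_index(self, axis: str) -> int:
        return self.index(axis) - self.values.ndim

    def mean_over_axes(self, axes: AbstractSet[str]) -> np.ndarray:
        return self.values.mean(axis=tuple(self.index(a) for a in axes))

    def sum_over_axes(self, axes: AbstractSet[str]) -> np.ndarray:
        return self.values.sum(axis=tuple(self.index(a) for a in axes))


t = NamedTensorSketch(array=np.arange(6.0).reshape(2, 3))  # axes: ("time", "feature")
print(t.actual_axes, t.reverse_index("time"))  # → ('time', 'feature') -2
print(t.mean_over_axes({"time"}))  # → [1.5 2.5 3.5]
```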
- class RegularizedArrayLikeCost(*args, **kwargs)
Bases: AverageableArrayLike[ArrayType], Protocol[ArrayType]
Interface for “regularized” costs, i.e., costs of the form:
cost + lagrangian_coefficient * regularization
- Parameters:
cost (AverageableArrayLike[ArrayType]) – an array-like collection of costs
regularization (AverageableArrayLike[ArrayType]) – an array-like collection of regularization costs
lagrangian_coefficient (Number) – a non-negative number quantifying the regularization weight
- cost: AverageableArrayLike[ArrayType]
- regularization: AverageableArrayLike[ArrayType]
- lagrangian_coefficient: Number = 1
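A minimal sketch of a concrete class with the three documented fields. It is hypothetical and NumPy-based, and it computes its mean as mean(cost) + lagrangian_coefficient * mean(regularization). By linearity of the mean, this equals the mean of the documented expression when both arrays have the same shape, and it still works when they do not.

```python
from dataclasses import dataclass
from numbers import Number

import numpy as np


@dataclass
class RegularizedCostSketch:
    """Hypothetical concrete class following the RegularizedArrayLikeCost fields."""

    cost: np.ndarray
    regularization: np.ndarray
    lagrangian_coefficient: Number = 1

    def mean(self) -> np.ndarray:
        # Scalar mean of cost + lagrangian_coefficient * regularization.
        return np.mean(self.cost) + self.lagrangian_coefficient * np.mean(self.regularization)


c = RegularizedCostSketch(
    cost=np.array([1.0, 3.0]), regularization=np.array([0.5, 0.5]), lagrangian_coefficient=2
)
print(c.mean())  # → 3.0
```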
jax_utils.compilation module
Classes and methods to jit-compile functions involving JAX array transformations
- class JaxCompilableProtocol(*args, **kwargs)
Bases: Protocol
All classes implementing this interface should implement the is_compilation_enabled property, indicating whether the methods involving JAX arrays should be jit-compiled.
- property is_compilation_enabled: bool
- class BaseJaxCompilable(*args, **kwargs)
Bases: JaxCompilableProtocol, Protocol
Subclassing BaseJaxCompilable makes it easy to enable/disable jit-compilation of methods involving JAX arrays. Use the jit_when_compilation_enabled decorator to compile a method only when is_compilation_enabled is True.
To enable (resp. disable) jit-compilation, one only needs to call method enable_compilation (resp. disable_compilation). By default, jit-compilation is disabled.
- property is_compilation_enabled: bool
- jit_when_compilation_enabled(**jax_jit_args) → Callable[[Callable], Callable]
Parametrized decorator for methods of classes implementing the JaxCompilableProtocol interface. Allows jit-compiling some methods only when compilation is enabled.
- Returns:
decorator with parameters jax_jit_args specified
- Return type:
Callable[[Callable], Callable]
jax_utils.dynamics module
High-level abstractions for decision problems (Markov Decision Processes, etc.) involving JAX array transformations
- class JaxDynamics(*args, **kwargs)
Bases: JaxDataclassNestedConvertibleToAxes[AxisType_contra], Dynamics[State, Action_contra, Cost_co, Observation_co], Protocol[AxisType_contra, State, Action_contra, Cost_co, Observation_co]
A jax_utils.markov_decision_process.Dynamics involving JAX array transformations. The cost of such dynamics can be differentiated w.r.t. the action (see method jax_utils.dynamics.JaxDynamics.compute_gradient()).
- scalar_cost(state: State, action: Action_contra) → Array
Averages the cost associated with the dynamics of an MDP.
- Parameters:
state (State) – a state of the MDP
action (Action_contra) – an action of the MDP
- Returns:
a scalar JAX array corresponding to the mean cost
- Return type:
jnp.ndarray
- compute_gradient(state: State, action: Action_contra) → Tuple[Cost_co, Action_contra]
Computes the cost of a jax_utils.dynamics.JaxDynamics and its gradient w.r.t. the action (the cost of such dynamics can always be differentiated w.r.t. the action).
- Parameters:
state (State) – a state of the MDP
action (Action_contra) – an action of the MDP
- Returns:
the associated cost and its gradient w.r.t. the action
- Return type:
Tuple[Cost_co, Action_contra]
- class VectorizedJaxDynamics(dynamics: JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co], vectorized_axis: AxisType)
Bases: JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co]
A jax_utils.dynamics.JaxDynamics that can be vectorized over some “named” axis.
- Parameters:
dynamics (JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co]) – the jax_utils.dynamics.JaxDynamics to be vectorized
vectorized_axis (jdc.Static[AxisType]) – “named” axis to map over
- dynamics: JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co]
- vectorized_axis: AxisType
- class RegularizedJaxDynamics(dynamics: JaxDynamics[AxisType_contra, State, Action, Cost, Observation], cost_regularizer: CostRegularizer[State, Action, Cost, Observation, RegularizedCost])
Bases: RegularizedDynamics[State, Action, Cost, RegularizedCost, Observation], JaxDynamics[AxisType_contra, State, Action, Cost, Observation]
A jax_utils.dynamics.JaxDynamics with cost regularization. See also jax_utils.markov_decision_process.RegularizedDynamics.
- Parameters:
dynamics (JaxDynamics[AxisType_contra, State, Action, Cost, Observation]) – the jax_utils.dynamics.JaxDynamics to be regularized
cost_regularizer (CostRegularizer[State, Action, Cost, Observation, RegularizedCost]) – a cost regularizer
- dynamics: JaxDynamics[AxisType_contra, State, Action, Cost, Observation]
- cost_regularizer: CostRegularizer[State, Action, Cost, Observation, RegularizedCost]
jax_utils.gradient module
Classes and methods to compute the gradient of functions involving JAX array transformations
- class BaseGradientStep(*args, **kwargs)
Bases: JaxDataclassNestedConvertibleToAxes[AxisType_contra], Protocol[AxisType_contra, State_contra, Action, Cost_co, OptimizerState]
Interface for gradient steps on the jax_utils.markov_decision_process.Dynamics cost of a Markov Decision Process.
- class GradientStep(optimizer: GradientTransformation, dynamics: JaxDynamics[Any, State, Action, Cost_co, Observation_co])
Bases: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState], Generic[AxisType, State, Action, Cost_co, Observation_co, OptimizerState]
The gradient step on the jax_utils.markov_decision_process.Dynamics cost of a Markov Decision Process.
- Parameters:
optimizer (jdc.Static[optax.GradientTransformation]) – an optimizer (defining the stochastic gradient descent variant: RMSProp, Adam, etc.)
dynamics (JaxDynamics[Any, State, Action, Cost_co, Observation_co]) – an MDP dynamics
- optimizer: jdc.Static[optax.GradientTransformation]
- dynamics: JaxDynamics[Any, State, Action, Cost_co, Observation_co]
- init_optimizer(init_action: Action) → OptimizerState
Initializes the optimizer state.
- Parameters:
init_action (Action) – initial action before optimization starts
- Returns:
initialized optimizer state
- Return type:
OptimizerState
- compute_gradient(state: State, action: Action) → Tuple[Cost_co, Action]
Given a state-action pair of the MDP, computes the cost of dynamics and the gradient of this cost w.r.t. the action.
- Parameters:
state (State) – state of the MDP
action (Action) – action of the MDP
- Returns:
the cost and its gradient w.r.t. the action
- Return type:
Tuple[Cost_co, Action]
- update(action: Action, gradient_value: GradientUpdates, opt_state: OptimizerState) → Tuple[Action, OptimizerState]
Updates the action and the optimizer state.
- Parameters:
action (Action) – action of the MDP
gradient_value (GradientUpdates) – gradient of the action
opt_state (OptimizerState) – optimizer state
- Returns:
the new action and optimizer state
- Return type:
Tuple[Action, OptimizerState]
- class VectorizedGradientStep(gradient_step: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState], vectorized_axis: AxisType)
Bases: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]
A gradient step that is vectorized over a specific axis.
- Parameters:
gradient_step (BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]) – the gradient step to be vectorized
vectorized_axis (jdc.Static[AxisType]) – “named” axis to be vectorized
- gradient_step: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]
- vectorized_axis: jdc.Static[AxisType]
jax_utils.jax_tensor module
Wrappers around jax arrays
- class JaxTensor(*args, **kwargs)
Bases: Tensor[Array, AxisType], ConvertibleToAxes[AxisType], Protocol[AxisType]
Interface representing a Tensor with a JAX array.
A JaxTensor is always “convertible to axes” (see the jax_utils.pytree.ConvertibleToAxes interface) using the jax_utils.pytree.ConvertibleToAxes.convert_to_axes() method. When a Tensor is “converted to axes”, the associated array attribute is no longer a JAX array but rather an integer indicating the dimension index of the specified axis (the argument of method convert_to_axes). If axis is flattened or not present, self.array is set to None when converted to axes.
Converting a jax_utils.jax_tensor.JaxTensor to axes is useful to determine the in_axes and out_axes arguments of jax.vmap and jax.pmap.
- Parameters:
array (jnp.ndarray) – a JAX array (or an optional integer when “converted to axes”)
- array: Array
- is_broadcastable_with(other_tensor_or_shape: Tensor[Array, AxisType] | Tuple[int, ...]) → bool
Whether this tensor can be broadcast together with other_tensor_or_shape.
- Parameters:
other_tensor_or_shape (Union[Tensor[ArrayType, AxisType], Tuple[int, ...]]) – another tensor or a shape
- Returns:
True if and only if this tensor can be broadcast together with other_tensor_or_shape, False otherwise.
- Return type:
bool
- classmethod from_flattened_axes(array: Array, flattened_axes: AbstractSet[AxisType], **kwargs) → Self
Constructor for when some axes are “flattened” in array.
This method simply adds the missing dimensions accordingly.
- Parameters:
array (jnp.ndarray) – a JAX array
flattened_axes (AbstractSet[AxisType]) – the named axes that are “flattened”
- Returns:
an instance of class
cls
- Return type:
Self
- convert_to_axes(axis: AxisType | None) → Self
Converts the jax_utils.jax_tensor.JaxTensor object so that it can be passed as the in_axes or out_axes argument of jax.vmap and jax.pmap.
- Parameters:
axis (Optional[AxisType]) – a “named” axis
- Returns:
the same object as self but with array replaced by the index of axis, i.e., if axis corresponds to the i-th dimension of array, then array is replaced by i; if axis is a flattened axis or simply not present, then array is replaced by None
- Return type:
Self
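A minimal sketch of the “convert to axes” idea, with a hypothetical dataclass. It assumes that the array stores the last ndim of the declared axes (the broadcasting convention used by TensorAxes) and treats any other axis as absent. Registering the class as a JAX pytree, which would be needed to pass the result as in_axes to jax.vmap, is left out so that the example only needs NumPy.

```python
from dataclasses import dataclass, replace
from typing import Any, ClassVar, Optional, Tuple

import numpy as np


@dataclass
class AxesConvertibleSketch:
    """Hypothetical JaxTensor-like class; `array` holds an int or None once converted."""

    _tensor_axes: ClassVar[Tuple[str, ...]] = ("batch", "time", "feature")
    array: Any

    def convert_to_axes(self, axis: Optional[str]) -> "AxesConvertibleSketch":
        ndim = np.ndim(self.array)
        # Assumption: the array holds the last `ndim` axes (leading axes may be absent).
        actual_axes = self._tensor_axes[len(self._tensor_axes) - ndim :]
        index = actual_axes.index(axis) if axis in actual_axes else None
        return replace(self, array=index)  # same type, array swapped for an axis index


x = AxesConvertibleSketch(array=np.zeros((4, 3)))  # axes: ("time", "feature")
print(x.convert_to_axes("time").array, x.convert_to_axes("batch").array)  # → 0 None
```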
- class NonNegativeValues(*args, **kwargs)
Bases: Protocol
Interface to be used in combination with jax_utils.jax_tensor.JaxTensor in order to implement JAX tensors that are constrained to take non-negative values.
- Parameters:
array (jnp.ndarray) – a JAX array
- array: Array
- property values: Array
- class NonNegativeBudgetedValues(*args, **kwargs)
Bases: NonNegativeValues, Protocol
Interface to be used in combination with jax_utils.jax_tensor.JaxTensor in order to implement JAX tensors that are constrained both to take non-negative values and to have their sum over the last axis bounded by a maximal “budget” (max_budget).
- Parameters:
array (jnp.ndarray) – a JAX array
max_budget (Number) – a non-negative number representing the maximum value that the sum over the last axis can take
- array: Array
- max_budget: Number
- property values: Array
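The exact transformation behind values is not documented here. One plausible choice, used in this sketch, is a relu (as suggested for Tensor.values) followed by a proportional rescaling of the slices whose sum over the last axis exceeds max_budget. The function name is hypothetical, and plain NumPy stands in for JAX. Proportional rescaling keeps the ratios between entries; a Euclidean projection onto the budget constraint would be another valid choice.

```python
import numpy as np


def budgeted_values_sketch(array: np.ndarray, max_budget: float) -> np.ndarray:
    """Hypothetical `values` transformation: non-negative, last-axis sum <= max_budget."""
    values = np.maximum(array, 0.0)  # relu, as for NonNegativeValues
    totals = values.sum(axis=-1, keepdims=True)
    over_budget = totals > max_budget
    safe_totals = np.where(totals > 0, totals, 1.0)  # avoid dividing by zero
    # Shrink only the slices whose sum exceeds the budget.
    scale = np.where(over_budget, max_budget / safe_totals, 1.0)
    return values * scale


budgeted = budgeted_values_sketch(np.array([[1.0, 3.0, -2.0], [0.2, 0.3, 0.1]]), max_budget=2.0)
# First row: relu gives [1, 3, 0], whose sum 4 exceeds 2, so it is halved to [0.5, 1.5, 0].
# Second row: its sum 0.6 is within the budget, so it is left unchanged.
```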
- class RegularizedJaxCost(cost: JaxCostType, regularization: JaxRegularizedCostType, lagrangian_coefficient: Number = 1)
Bases: RegularizedArrayLikeCost[Array], BaseJaxCompilable, Generic[JaxCostType, JaxRegularizedCostType]
Interface for “regularized” JAX costs, i.e., costs of the form:
cost + lagrangian_coefficient * regularization
- Parameters:
cost (JaxCostType) – a JAX array of costs
regularization (JaxRegularizedCostType) – a JAX array of regularization costs
lagrangian_coefficient (jdc.Static[Number]) – a non-negative number quantifying the regularization weight
- cost: JaxCostType
- regularization: JaxRegularizedCostType
- lagrangian_coefficient: Number = 1
jax_utils.markov_decision_process module
High-level abstractions for decision problems (Markov Decision Processes, etc.)
- class Dynamics(*args, **kwargs)
Bases: Protocol[State, Action_contra, Cost_co, Observation_co]
Interface defining the dynamics of a (Partially Observable) Markov Decision Process.
When an “agent” interacting with the (PO)MDP plays an “action” (a.k.a. “control”) in a given “state”, the (PO)MDP transitions to a new state and the agent observes some signal/feedback in the form of a “cost”/“reward” as well as additional “observations”.
A Dynamics is therefore a callable that maps a state-action pair to a state-cost-observation tuple.
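The Dynamics protocol only fixes the call shape: (state, action) in, (state, cost, observation) out. A minimal sketch, with a hypothetical deterministic transition and a squared-distance cost; the class name and the choice of NumPy arrays for states and actions are assumptions for illustration.

```python
from typing import Tuple

import numpy as np

State = np.ndarray
Action = np.ndarray


class ShiftDynamicsSketch:
    """Hypothetical Dynamics: a callable (state, action) -> (state, cost, observation)."""

    def __init__(self, target: np.ndarray) -> None:
        self.target = target

    def __call__(self, state: State, action: Action) -> Tuple[State, np.ndarray, np.ndarray]:
        next_state = state + action  # deterministic transition
        cost = np.sum((next_state - self.target) ** 2)  # squared distance to a target state
        observation = next_state.copy()  # fully observed in this toy example
        return next_state, cost, observation


dynamics = ShiftDynamicsSketch(target=np.array([1.0, 1.0]))
state, cost, obs = dynamics(np.zeros(2), np.array([0.5, 1.0]))
print(cost)  # → 0.25
```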
- class CostRegularizer(*args, **kwargs)
Bases: Protocol[State_contra, Action_contra, Cost_contra, Observation_contra, Cost_co]
Interface for callables that map any state-action-cost-observation tuple to a new “regularized” cost.
For example, one may want to penalize actions with high norms, etc.
- class RegularizedDynamics(*args, **kwargs)
Bases: Dynamics[State, Action, RegularizedCost, Observation], Protocol[State, Action, Cost, RegularizedCost, Observation]
Interface defining a wrapper around class jax_utils.markov_decision_process.Dynamics that adds a regularization to the cost.
A RegularizedDynamics is itself a jax_utils.markov_decision_process.Dynamics.
- Parameters:
dynamics (Dynamics[State, Action, Cost, Observation]) – callable defining the dynamics of a (PO)MDP
cost_regularizer (CostRegularizer[State, Action, Cost, Observation, RegularizedCost]) – callable transformation defining the regularized cost
- cost_regularizer: CostRegularizer[State, Action, Cost, Observation, RegularizedCost]
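A sketch of the wrapper pattern, with hypothetical names throughout. It assumes the regularizer receives the original state together with the action, cost and observation (the text above does not say whether the current or next state is passed) and uses an L2 action penalty, in line with the CostRegularizer example of penalizing actions with high norms.

```python
from dataclasses import dataclass
from typing import Callable, Tuple

import numpy as np

DynamicsFn = Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]]
RegularizerFn = Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]


@dataclass
class RegularizedDynamicsSketch:
    """Hypothetical wrapper: same transition, cost replaced by a regularized cost."""

    dynamics: DynamicsFn
    cost_regularizer: RegularizerFn

    def __call__(self, state: np.ndarray, action: np.ndarray):
        next_state, cost, observation = self.dynamics(state, action)
        # Assumption: the regularizer sees the original state-action-cost-observation tuple.
        regularized_cost = self.cost_regularizer(state, action, cost, observation)
        return next_state, regularized_cost, observation


def l2_action_penalty(state, action, cost, observation, coefficient=0.1):
    # Penalize actions with high norms.
    return cost + coefficient * np.sum(action**2)


def shift_dynamics(state, action):
    next_state = state + action
    return next_state, np.sum(next_state**2), next_state


regularized = RegularizedDynamicsSketch(shift_dynamics, l2_action_penalty)
_, reg_cost, _ = regularized(np.zeros(2), np.array([1.0, 2.0]))  # 5 + 0.1 * 5
```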
jax_utils.optim module
Classes and methods to optimize functions involving JAX array transformations via gradient descent
- class OptimizationState(iteration: int, state: State, action: Action, cost: Cost | None = None, optimizer_state: OptimizerState | None = None)
Bases: Generic[State, Action, Cost, OptimizerState]
The current state of an iterative optimization procedure (typically a cost minimization) involving the jax_utils.markov_decision_process.Dynamics of a Markov Decision Process.
- Parameters:
iteration (int) – the current iteration step
state (State) – the MDP state
action (Action) – the current MDP action
cost (Optional[Cost], optional) – the current cost associated with the state-action pair. Defaults to None.
optimizer_state (Optional[OptimizerState], optional) – the current optax.OptState. Defaults to None.
- iteration: int
- state: State
- action: Action
- cost: Cost | None = None
- optimizer_state: OptimizerState | None = None
- property scalar_cost: float
Returns the average of self.cost
- Returns:
average cost
- Return type:
float
- class OptimStoppingCondition(*args, **kwargs)
Bases: BaseJaxCompilable, Protocol[State, Action, Cost, OptimizerState]
Interface for all stopping conditions of an iterative optimization procedure (typically a cost minimization) involving the jax_utils.markov_decision_process.Dynamics of a Markov Decision Process.
The stopping condition depends on the OptimizationState provided as input but may also collect information over steps.
- stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) → bool
Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- property nb_iterations_upper_bound: int
An upper bound on the maximal number of iterations. Default is 1e20.
- class OptimStoppingConditionsCombination(*args, **kwargs)
Bases: OptimStoppingCondition[State, Action, Cost, OptimizerState], Protocol
This interface represents the combination of two stopping conditions.
- Parameters:
stopping_condition_1 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
stopping_condition_2 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
- stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- class OptimStoppingConditionIntersection(stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState], stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState])
Bases: OptimStoppingConditionsCombination[State, Action, Cost, OptimizerState]
A combination of two stopping conditions using an “AND” operation.
- Parameters:
stopping_condition_1 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
stopping_condition_2 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
- stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) → bool
Returns a boolean indicating whether the optimization procedure should be stopped, i.e., True if and only if both stopping conditions are met.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- reset() → OptimStoppingConditionIntersection[State, Action, Cost, OptimizerState]
Resets the stopping condition to an initial configuration.
- property nb_iterations_upper_bound: int
An upper bound on the maximal number of iterations. Default is 1e20.
- class OptimStoppingConditionUnion(stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState], stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState])
Bases: OptimStoppingConditionsCombination[State, Action, Cost, OptimizerState]
A combination of two stopping conditions using an “OR” operation.
- Parameters:
stopping_condition_1 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
stopping_condition_2 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
- stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) → bool
Returns a boolean indicating whether the optimization procedure should be stopped, i.e., True if and only if at least one of the two stopping conditions is met.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- reset() → OptimStoppingConditionUnion[State, Action, Cost, OptimizerState]
Resets the stopping condition to an initial configuration.
- property nb_iterations_upper_bound: int
An upper bound on the maximal number of iterations. Default is 1e20.
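A minimal sketch of how two stopping conditions can be combined with “AND”/“OR”, using hypothetical stand-ins for OptimizationState and the concrete conditions. One design point worth noting: both conditions are evaluated before combining the results, because Python's short-circuiting `and`/`or` would otherwise skip the second condition, and a stateful condition (such as MinDeltaActionStoppingCondition, which keeps the previous action) would miss an update. Whether the library does the same is not stated here; the `>=` boundary in the iteration checks is also an assumption.

```python
from dataclasses import dataclass
from typing import Protocol


@dataclass(frozen=True)
class StateSketch:
    """Hypothetical, stripped-down OptimizationState: an iteration and a scalar cost."""

    iteration: int
    cost: float


class StoppingConditionSketch(Protocol):
    def stop(self, optimization_state: StateSketch) -> bool: ...


@dataclass
class MaxIterationsSketch:
    max_iterations: int

    def stop(self, optimization_state: StateSketch) -> bool:
        return optimization_state.iteration >= self.max_iterations


@dataclass
class MinIterationsSketch:
    min_iterations: int

    def stop(self, optimization_state: StateSketch) -> bool:
        # Keep going while the iteration count is below the threshold.
        return optimization_state.iteration >= self.min_iterations


@dataclass
class CostBelowSketch:
    threshold: float

    def stop(self, optimization_state: StateSketch) -> bool:
        return optimization_state.cost < self.threshold


@dataclass
class IntersectionSketch:
    stopping_condition_1: StoppingConditionSketch
    stopping_condition_2: StoppingConditionSketch

    def stop(self, optimization_state: StateSketch) -> bool:
        # Evaluate both conditions so that stateful ones still record this step.
        stop_1 = self.stopping_condition_1.stop(optimization_state)
        stop_2 = self.stopping_condition_2.stop(optimization_state)
        return stop_1 and stop_2  # "AND"


@dataclass
class UnionSketch:
    stopping_condition_1: StoppingConditionSketch
    stopping_condition_2: StoppingConditionSketch

    def stop(self, optimization_state: StateSketch) -> bool:
        stop_1 = self.stopping_condition_1.stop(optimization_state)
        stop_2 = self.stopping_condition_2.stop(optimization_state)
        return stop_1 or stop_2  # "OR"


# Stop after 100 iterations, or earlier once at least 10 iterations ran and the cost is tiny.
condition = UnionSketch(
    MaxIterationsSketch(100),
    IntersectionSketch(MinIterationsSketch(10), CostBelowSketch(1e-3)),
)
```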
- class MaxIterationsStoppingCondition(max_iterations: int)
Bases: OptimStoppingCondition[State, Action, Cost, OptimizerState]
A stopping condition that stops when the number of iterations exceeds a given threshold.
- Parameters:
max_iterations (int) – maximal number of iterations before the stopping condition is raised
- max_iterations: int
- stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) → bool
Returns a boolean indicating whether the optimization procedure should be stopped, i.e., whether the number of iterations exceeds max_iterations.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- property nb_iterations_upper_bound: int
An upper bound on the maximal number of iterations. Default is 1e20.
- class MinIterationsStoppingCondition(min_iterations: int)
Bases: OptimStoppingCondition[State, Action, Cost, OptimizerState]
A stopping condition that continues as long as the number of iterations is below a given threshold.
- Parameters:
min_iterations (int) – minimal number of iterations before the stopping condition is raised
- min_iterations: int
- stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) → bool
Returns a boolean indicating whether the optimization procedure should be stopped, i.e., whether the number of iterations is no longer below min_iterations.
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- class MinDeltaActionStoppingCondition(relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True), absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True))
Bases: OptimStoppingCondition[State, JaxTensorType, Cost, OptimizerState]
Stops when the action stops changing significantly (as defined by the relative and absolute tolerances). The previous action is saved in memory to compare it to the new action.
The delta between the previous and new actions is compared to the absolute and relative tolerances, i.e., the stopping condition is raised when:
all(abs(new_action - previous_action) <= absolute_tolerance + relative_tolerance * maximum(abs(new_action), abs(previous_action)))
- Parameters:
relative_tolerance (jnp.ndarray) – relative tolerance for action variations. Default is 1e-6.
absolute_tolerance (jnp.ndarray) – absolute tolerance for action variations. Default is 1e-6.
- relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
- absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
- stop(optimization_state: OptimizationState[State, JaxTensorType, Cost, OptimizerState]) → bool
Returns a boolean indicating whether the optimization procedure should be stopped, i.e., whether the delta between the previous and new actions is within the tolerances.
- Parameters:
optimization_state (OptimizationState[State, JaxTensorType, Cost, OptimizerState]) – the current state of the optimization procedure
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- reset() → MinDeltaActionStoppingCondition[State, JaxTensorType, Cost, OptimizerState]
Resets the stopping condition to an initial configuration.
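The stopping rule above can be sketched in NumPy as follows. The class is hypothetical and takes the action directly rather than an OptimizationState; returning False on the first call, when no previous action exists, is an assumption. Note that the relative tolerance is scaled by the larger of both magnitudes, which makes the test symmetric, unlike numpy.isclose, which scales by abs(b) only.

```python
from dataclasses import dataclass, field
from typing import Optional

import numpy as np


@dataclass
class MinDeltaActionSketch:
    """Hypothetical stopping condition on successive actions."""

    relative_tolerance: float = 1e-6
    absolute_tolerance: float = 1e-6
    previous_action: Optional[np.ndarray] = field(default=None, repr=False)

    def stop(self, action: np.ndarray) -> bool:
        previous, self.previous_action = self.previous_action, action  # remember the new action
        if previous is None:
            return False  # nothing to compare against on the first call
        bound = self.absolute_tolerance + self.relative_tolerance * np.maximum(
            np.abs(action), np.abs(previous)
        )
        return bool(np.all(np.abs(action - previous) <= bound))

    def reset(self) -> "MinDeltaActionSketch":
        self.previous_action = None
        return self
```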
- class MinDeltaCostStoppingCondition(relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True), absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True), window_length: int = 2)
Bases: OptimStoppingCondition[State, Action, Array, OptimizerState]
Stops when the cost stops decreasing significantly (as defined by the relative and absolute tolerances).
The last window_length cost values are saved in memory and the stopping condition is raised when the delta between the min and max of these values is below a threshold, i.e., when:
all(maximum(last_window_length_costs) - minimum(last_window_length_costs) < absolute_tolerance + relative_tolerance * minimum(last_window_length_costs))
- Parameters:
relative_tolerance (jnp.ndarray) – relative tolerance for cost variations. Default is 1e-6.
absolute_tolerance (jnp.ndarray) – absolute tolerance for cost variations. Default is 1e-6.
window_length (int) – maximal length of the queue storing the past history of costs (costs older than window_length time steps are discarded)
- relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
- absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
- window_length: int = 2
- stop(optimization_state: OptimizationState[State, Action, Array, OptimizerState]) → bool
Returns a boolean indicating whether the optimization procedure should be stopped, i.e., whether the spread of the last window_length costs is within the tolerances.
- Parameters:
optimization_state (OptimizationState[State, Action, Array, OptimizerState]) – the current state of the optimization procedure
- Returns:
True if the optimization procedure should be stopped, False otherwise.
- Return type:
bool
- reset() → MinDeltaCostStoppingCondition[State, Action, OptimizerState]
Resets the stopping condition to an initial configuration.
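A sketch of the windowed rule above using collections.deque with maxlen=window_length, so that costs older than window_length steps are dropped automatically. The class is hypothetical, takes the cost directly instead of an OptimizationState, compares costs elementwise as in the all(...) formula, and assumes the condition is not raised until the window is full.

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Deque

import numpy as np


@dataclass
class MinDeltaCostSketch:
    """Hypothetical stopping condition on a sliding window of costs."""

    relative_tolerance: float = 1e-6
    absolute_tolerance: float = 1e-6
    window_length: int = 2
    costs: Deque[np.ndarray] = field(init=False, repr=False, default_factory=deque)

    def __post_init__(self) -> None:
        self.costs = deque(maxlen=self.window_length)  # older costs fall out automatically

    def stop(self, cost) -> bool:
        self.costs.append(np.asarray(cost, dtype=float))
        if len(self.costs) < self.window_length:
            return False  # window not full yet
        window = np.stack(self.costs)  # shape: (window_length, *cost.shape)
        lowest = window.min(axis=0)
        spread = window.max(axis=0) - lowest
        return bool(np.all(spread < self.absolute_tolerance + self.relative_tolerance * lowest))

    def reset(self) -> "MinDeltaCostSketch":
        self.costs.clear()
        return self
```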
- class CostHistory(initlist=None)
Bases: UserList[AverageableArrayLike[Array]]
List of costs (typically JAX arrays)
- class GradientDescentOptimizationLoop(gradient_step: BaseGradientStep[AxisType, State, Action, Cost, OptimizerState], stopping_condition: OptimStoppingCondition[State, Action, Cost, OptimizerState])
Bases: Generic[AxisType, State, Action, Cost, Observation, OptimizerState]
Gradient descent optimization procedure involving the jax_utils.markov_decision_process.Dynamics of a Markov Decision Process. This class is a callable that repeatedly applies gradient_step until stopping_condition is met.
- Parameters:
gradient_step (BaseGradientStep[AxisType, State, Action, Cost, OptimizerState]) – a gradient step
stopping_condition (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition
- gradient_step: BaseGradientStep[AxisType, State, Action, Cost, OptimizerState]
- stopping_condition: OptimStoppingCondition[State, Action, Cost, OptimizerState]
- resume(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) Tuple[OptimizationState[State, Action, Cost, OptimizerState], OptimizationState[State, Action, Cost, OptimizerState], CostHistory] [source]
Resume the optimization loop from optimization_state (cold start).
- Parameters:
optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – state from which to resume the optimization procedure.
- Returns:
same outputs as GradientDescentOptimizationLoop.__call__()
- Return type:
Tuple[OptimizationState[State, Action, Cost, OptimizerState], OptimizationState[State, Action, Cost, OptimizerState], CostHistory]
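The loop structure described for GradientDescentOptimizationLoop can be sketched in plain Python. The sketch assumes a scalar state, a gradient step that returns the new state together with the cost of the current one, and a stopping condition that only sees the cost history. All names are hypothetical. Unlike resume, this sketch returns only the final state and the cost history.

```python
from typing import Callable, List, Tuple


def optimization_loop(
    x0: float,
    gradient_step: Callable[[float], Tuple[float, float]],
    stopping_condition: Callable[[List[float]], bool],
    max_steps: int = 10_000,
) -> Tuple[float, List[float]]:
    """Apply `gradient_step` until `stopping_condition` is met (or `max_steps`)."""
    x, costs = x0, []
    for _ in range(max_steps):  # safety cap, an assumption of this sketch
        x, cost = gradient_step(x)
        costs.append(cost)
        if stopping_condition(costs):
            break
    return x, costs


def quadratic_step(x: float, learning_rate: float = 0.1) -> Tuple[float, float]:
    """One gradient step on f(x) = (x - 3)^2; returns (new x, f(old x))."""
    cost = (x - 3.0) ** 2
    return x - learning_rate * 2.0 * (x - 3.0), cost


def small_cost_delta(costs: List[float], tolerance: float = 1e-12) -> bool:
    """Stop once two consecutive costs differ by less than `tolerance`."""
    return len(costs) >= 2 and abs(costs[-1] - costs[-2]) < tolerance


x_final, history = optimization_loop(0.0, quadratic_step, small_cost_delta)
```

Keeping the step and the stopping rule as separate callables mirrors the split between gradient_step and stopping_condition in the class above. Either piece can be swapped independently.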
jax_utils.pytree module
Classes and methods to easily map pytrees to axes (e.g., for vectorization)
- class ConvertibleToAxes(*args, **kwargs)[source]
Bases: Protocol[AxisType_contra]
- convert_to_axes(axis: AxisType_contra | None) Self [source]
Returns an object that can be used as the in_axes or out_axes argument of jax.vmap or jax.pmap.
- Parameters:
axis (Optional[AxisType_contra]) – a “named” axis over which to apply vectorization.
- Returns:
same object as self but with all array-like objects replaced by axes
- Return type:
Self
- pytree_to_axes(pytree: Any, vectorized_axis: AxisType, default_axis: int | None = None) Any [source]
Transform all the ConvertibleToAxes leaves of a given pytree to axes by applying method ConvertibleToAxes.convert_to_axes(). This is useful for vectorizing functions involving ConvertibleToAxes objects.
- Parameters:
pytree (Any) – any Python pytree containing ConvertibleToAxes leaves
vectorized_axis (AxisType) – a “named” axis over which to apply vectorization
default_axis (Optional[int], optional) – a default axis value for leaves that are not ConvertibleToAxes. Defaults to None.
- Returns:
same pytree as given in input but where all ConvertibleToAxes leaves are converted to axes.
- Return type:
Any
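The traversal performed by pytree_to_axes can be illustrated with a simplified pure-Python sketch that recurses only into dicts, lists and tuples (JAX itself supports more pytree node types). NamedTensor is a hypothetical leaf class. The real ConvertibleToAxes.convert_to_axes() returns an object of the same type, whereas this sketch's convert_to_axes returns the integer axis index directly.

```python
from typing import Any, Optional, Sequence


class NamedTensor:
    """Hypothetical leaf type exposing a `convert_to_axes` method."""

    def __init__(self, data: Any, axes: Sequence[str]):
        self.data = data
        self.axes = tuple(axes)

    def convert_to_axes(self, axis: Optional[str]) -> Optional[int]:
        # Position of the named axis, or None if absent (i.e., not vectorized).
        return self.axes.index(axis) if axis in self.axes else None


def pytree_to_axes(pytree: Any, vectorized_axis: str,
                   default_axis: Optional[int] = None) -> Any:
    """Replace every leaf by an axis index, recursing into dicts, lists and tuples."""
    if hasattr(pytree, "convert_to_axes"):
        return pytree.convert_to_axes(vectorized_axis)
    if isinstance(pytree, dict):
        return {k: pytree_to_axes(v, vectorized_axis, default_axis)
                for k, v in pytree.items()}
    if isinstance(pytree, (list, tuple)):
        return type(pytree)(pytree_to_axes(v, vectorized_axis, default_axis)
                            for v in pytree)
    return default_axis  # any other leaf, e.g. a plain array


tensor = NamedTensor(data=None, axes=("batch", "time"))
tree = {"x": tensor, "y": [1.0, (tensor, "z")]}
in_axes = pytree_to_axes(tree, "time")
# in_axes == {"x": 1, "y": [None, (1, None)]}
```

The output mirrors the input structure, which is the shape jax.vmap expects for a pytree-valued in_axes argument.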
jax_utils.tranform module
JAX array transformations (scaling, …)
- scale_jax_tensor(tensor: JaxTensor[AxisType], scaling_factor: Array) JaxTensor[AxisType] [source]
Scale the values of a JaxTensor, i.e., multiply values by scaling_factor. Note that property values of JaxTensor differs from attribute array.
- unscale_jax_tensor(scaled_tensor: JaxTensor[AxisType], scaling_factor: Array) JaxTensor[AxisType] [source]
Reverse transformation of function scale_jax_tensor.
- class JaxScaler(scaling_factors: ~jax.Array = Array(1., dtype=float32, weak_type=True), scaling_axes: ~typing.Set[~jax_utils.tranform.AxisType] = <factory>, tensor_types_to_scale: ~typing.Tuple[~typing.Type[~jax_utils.jax_tensor.JaxTensor]] = (<class 'jax_utils.jax_tensor.JaxTensor'>, ), tensor_types_not_to_scale: ~typing.Tuple[~jax_utils.jax_tensor.JaxTensor] = ())[source]
Bases: BaseJaxCompilable, Generic[AxisType]
When applying gradient descent optimization algorithms, it is often helpful to scale all the arrays/tensors involved in the computations. This class makes it easy to scale/unscale jax_utils.jax_tensor.JaxTensor objects, even when the tensors are stored in a nested pytree structure.
- Parameters:
scaling_factors (jnp.ndarray) – multiplicative factors for scaling. Defaults to jnp.array(1.0).
scaling_axes (jdc.Static[Set[AxisType]]) – “named” axes of scaling_factors, so the number of axes should match scaling_factors.ndim. Defaults to jdc.field(default_factory=set).
tensor_types_to_scale (jdc.Static[Tuple[Type[JaxTensor]]]) – object types that should be scaled in a pytree. Defaults to (jax_utils.jax_tensor.JaxTensor,).
tensor_types_not_to_scale (jdc.Static[Tuple[JaxTensor]]) – object types that should not be scaled in a pytree. Defaults to tuple().
- scaling_factors: Array = Array(1., dtype=float32, weak_type=True)
- scaling_axes: Set[AxisType]
- classmethod from_tensor(tensor: ~jax_utils.jax_tensor.JaxTensor[~jax_utils.tranform.AxisType], scaling_axes: ~typing.Set[~jax_utils.tranform.AxisType], tensor_types_to_scale: ~typing.Tuple[~typing.Type[~jax_utils.jax_tensor.JaxTensor]] = (<class 'jax_utils.jax_tensor.JaxTensor'>,), tensor_types_not_to_scale: ~typing.Tuple[~jax_utils.jax_tensor.JaxTensor] = (), factor: float = 1.0) JaxScaler [source]
Alternative constructor that automatically computes the scaling_factors based on a jax_utils.jax_tensor.JaxTensor, the scaling_axes and a target factor.
If my_scaler = JaxScaler.from_tensor(tensor=my_tensor, scaling_axes=my_axes, factor=my_factor), then my_scaler.scale(my_tensor) is a jax_utils.jax_tensor.JaxTensor whose mean_over_axes(axes=my_axes) is a “constant” JAX array with all elements equal to my_factor.
- Parameters:
tensor (JaxTensor[AxisType]) – a jax tensor
scaling_axes (Set[AxisType]) – a set of “named” axes present in tensor
tensor_types_to_scale (jdc.Static[Tuple[Type[JaxTensor]]]) – object types that should be scaled in a pytree. Defaults to (jax_utils.jax_tensor.JaxTensor,).
tensor_types_not_to_scale (jdc.Static[Tuple[JaxTensor]]) – object types that should not be scaled in a pytree. Defaults to tuple().
factor (float) – target value that every element of the mean of the scaled tensor should equal. Defaults to 1.0.
- Returns:
a scaler instance
- Return type:
JaxScaler
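One plausible reading of from_tensor is sketched below in pure Python. It uses a 2-D tensor stored as a list of rows, where the first axis is the only scaling axis. Each row gets its own factor, chosen so that the mean of the scaled row equals the target factor. The function names are hypothetical, and the sketch assumes no row has a zero mean. How the library actually computes scaling_factors may differ.

```python
from typing import List, Sequence


def factors_from_rows(rows: Sequence[Sequence[float]],
                      factor: float = 1.0) -> List[float]:
    """One scaling factor per index of the scaling axis (here: one per row)."""
    return [factor / (sum(row) / len(row)) for row in rows]


def scale_rows(rows: Sequence[Sequence[float]],
               factors: Sequence[float]) -> List[List[float]]:
    """Multiply every element of row i by factors[i] (broadcast along the other axis)."""
    return [[f * value for value in row] for row, f in zip(rows, factors)]


tensor = [[1.0, 3.0], [10.0, 30.0]]
factors = factors_from_rows(tensor, factor=2.0)
scaled = scale_rows(tensor, factors)
row_means = [sum(row) / len(row) for row in scaled]  # both close to 2.0
```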
- scale(pytree: Any) Any [source]
Scales any pytree containing jax_utils.jax_tensor.JaxTensor objects.
- Parameters:
pytree (Any) – pytree to be scaled
- Returns:
the scaled pytree
- Return type:
Any
- unscale(scaled_pytree: Any) Any [source]
Unscales any pytree containing jax_utils.jax_tensor.JaxTensor objects (reverse transformation of method scale).
- Parameters:
scaled_pytree (Any) – pytree to be unscaled
- Returns:
the unscaled pytree
- Return type:
Any
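The sketch below shows, in pure Python, how scale and unscale can walk a nested pytree while honoring tensor_types_to_scale and tensor_types_not_to_scale. It uses a single scalar factor. Tensor, Unscaled and PytreeScaler are hypothetical stand-ins for JaxTensor and JaxScaler, and the traversal only handles dicts, lists and tuples. Giving exclusions precedence over inclusions is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple, Type


@dataclass(frozen=True)
class Tensor:  # hypothetical stand-in for a JaxTensor: a flat tuple of values
    values: Tuple[float, ...]


@dataclass(frozen=True)
class Unscaled(Tensor):  # hypothetical subtype excluded from scaling
    pass


@dataclass(frozen=True)
class PytreeScaler:
    factor: float = 1.0
    types_to_scale: Tuple[Type, ...] = (Tensor,)
    types_not_to_scale: Tuple[Type, ...] = ()

    def _map(self, tree: Any, fn: Callable[[float], float]) -> Any:
        # Exclusions are checked first (an assumption of this sketch).
        if isinstance(tree, self.types_not_to_scale):
            return tree
        if isinstance(tree, self.types_to_scale):
            return type(tree)(tuple(fn(v) for v in tree.values))
        if isinstance(tree, dict):
            return {k: self._map(v, fn) for k, v in tree.items()}
        if isinstance(tree, (list, tuple)):
            return type(tree)(self._map(v, fn) for v in tree)
        return tree  # other leaves are left untouched

    def scale(self, tree: Any) -> Any:
        return self._map(tree, lambda v: v * self.factor)

    def unscale(self, tree: Any) -> Any:
        return self._map(tree, lambda v: v / self.factor)


scaler = PytreeScaler(factor=4.0, types_not_to_scale=(Unscaled,))
tree = {"a": Tensor((1.0, 2.0)), "b": [Unscaled((1.0,)), 3.0]}
scaled = scaler.scale(tree)
# scaled == {"a": Tensor((4.0, 8.0)), "b": [Unscaled((1.0,)), 3.0]}
```

Because unscale applies the inverse operation through the same traversal, scaler.unscale(scaler.scale(tree)) returns the original pytree (exactly here, since the factor is a power of two).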
- expand_scaling_factors(tensor_axes: TensorAxes[AxisType], missing_axes: Set[AxisType]) Array [source]
Expands the dimensions of self.scaling_factors to match the tensor_axes given as input. This method only adds new dimensions of size 1.
- Parameters:
tensor_axes (TensorAxes[AxisType]) – an ordered set of “named” axes
missing_axes (Set[AxisType])
- Returns:
same as self.scaling_factors but with additional dimensions of size 1. The total number of dimensions of the returned array is equal to len(tensor_axes).
- Return type:
jnp.ndarray
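The shape bookkeeping behind expand_scaling_factors can be sketched in pure Python: every axis of tensor_axes that is not a scaling axis receives a dimension of size 1. The function is hypothetical and ignores missing_axes, whose role is not documented here. Because only size-1 dimensions are added, the sketch also assumes the scaling axes already appear in the same relative order as in tensor_axes. The resulting shape could then be passed to jnp.reshape.

```python
from typing import Dict, Sequence, Tuple


def expanded_shape(
    tensor_axes: Sequence[str],
    scaling_axes: Sequence[str],
    factor_shape: Sequence[int],
) -> Tuple[int, ...]:
    """Shape of the scaling factors after inserting size-1 dimensions."""
    sizes: Dict[str, int] = dict(zip(scaling_axes, factor_shape))
    # Only size-1 dimensions may be added, so no transposition is allowed.
    if [axis for axis in tensor_axes if axis in sizes] != list(scaling_axes):
        raise ValueError("scaling axes must appear in tensor_axes, in the same order")
    return tuple(sizes.get(axis, 1) for axis in tensor_axes)


shape = expanded_shape(("batch", "time", "feature"), ("batch", "feature"), (2, 4))
# shape == (2, 1, 4)
```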
jax_utils.typing module
Defines useful types for the package
- class DataclassInstance(*args, **kwargs)[source]
Bases: Protocol
Type for Python dataclasses (see dataclasses.dataclass). This code is copied from _typeshed (but is simpler to import).
- class HashableIndexingOrSlicing(*args, **kwargs)[source]
Bases: DataclassInstance, Hashable, Protocol
Python slices and lists (of indices) are not hashable (slices are hashable only from Python 3.12 onward). This is a shared interface for hashable slices and indices.
- property values: slice | int | List[int]
Property returning index values in the form of a slice, an int (for singletons) or a list of int.
- Returns:
index values
- Return type:
Union[slice, int, List[int]]
- class HashableSlicing(start: int | None = None, stop: int | None = None, step: int | None = None)[source]
Bases: HashableIndexingOrSlicing
Python slices are not hashable for Python versions < 3.12. This class makes it possible to define hashable slices.
- Parameters:
start (Optional[int]) – initial value of the slice (included)
stop (Optional[int]) – last value of the slice (excluded)
step (Optional[int]) – step between values
- start: int | None = None
- stop: int | None = None
- step: int | None = None
- property values: slice
Property returning index values in the form of a slice, an int (for singletons) or a list of int.
- Returns:
index values
- Return type:
Union[slice, int, List[int]]
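A frozen dataclass is one straightforward way to obtain such hashable slices, since a dataclass declared with frozen=True (and the default eq=True) gets a __hash__ derived from its fields. HashableSlicingSketch is a hypothetical re-implementation, not the library's own class.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class HashableSlicingSketch:
    start: Optional[int] = None
    stop: Optional[int] = None
    step: Optional[int] = None

    @property
    def values(self) -> slice:
        # Rebuild a regular (possibly unhashable) slice on demand.
        return slice(self.start, self.stop, self.step)


cache = {HashableSlicingSketch(1, 5, 2): "cached"}
selected = list(range(10))[HashableSlicingSketch(1, 5, 2).values]
# selected == [1, 3]
```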
- class HashableIndexing(*indices: int)[source]
Bases: HashableIndexingOrSlicing
Python lists are not hashable. This class makes it possible to define a hashable list of indices.
- Parameters:
indices (Union[int, Tuple[int]]) – index or list of (integer-valued) indices
- indices: int | Tuple[int]
- property values: int | List[int]
Property returning index values in the form of a slice, an int (for singletons) or a list of int.
- Returns:
index values
- Return type:
Union[slice, int, List[int]]
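A frozen dataclass with its generated __init__ disabled can reproduce the variadic constructor HashableIndexing(*indices). In this hypothetical sketch, a single index is stored as a bare int and several indices as a tuple. This matches the documented indices: int | Tuple[int] attribute and the int | List[int] return type of values.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass(frozen=True, init=False)
class HashableIndexingSketch:
    indices: Union[int, Tuple[int, ...]]

    def __init__(self, *indices: int):
        # A single index is kept as an int, several as a (hashable) tuple.
        value = indices[0] if len(indices) == 1 else tuple(indices)
        object.__setattr__(self, "indices", value)  # bypass the frozen guard

    @property
    def values(self) -> Union[int, List[int]]:
        return self.indices if isinstance(self.indices, int) else list(self.indices)


positions = HashableIndexingSketch(0, 2, 5)
lookup = {positions: "first, third and sixth"}
picked = [list("abcdef")[i] for i in positions.values]
# picked == ["a", "c", "f"]
```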
jax_utils.vectorization module
Classes and methods to easily vectorize jax tensors
- class JaxVectorizableProtocol(*args, **kwargs)[source]
Bases: JaxCompilableProtocol, ConvertibleToAxes[AxisType], Protocol
Interface for classes manipulating JAX arrays. It defines a special attribute vectorized_axis corresponding to the “named” axis over which to apply vectorization.
“Vectorizable” classes are also “compilable” (see interface jax_utils.compilation.JaxCompilableProtocol) and convertible to axes (see interface jax_utils.pytree.ConvertibleToAxes).
- Parameters:
vectorized_axis (AxisType) – “named” axis to be vectorized
- vectorized_axis: AxisType
- vectorize(in_default_axis: int | None = None, out_default_axis: int | None = -1) Callable[[Callable], Callable] [source]
Parametrized decorator for methods of classes implementing the JaxVectorizableProtocol interface. It vectorizes the decorated methods along a specified axis, namely attribute JaxVectorizableProtocol.vectorized_axis (see JaxVectorizableProtocol).
The pytrees of all inputs and outputs of the decorated method are explored, and every JAX array as well as every object implementing the jax_utils.pytree.ConvertibleToAxes interface is considered a leaf of the pytree. When a jax_utils.pytree.ConvertibleToAxes leaf is encountered, method jax_utils.pytree.ConvertibleToAxes.convert_to_axes() is used to determine the dimension to vectorize. When a JAX array leaf or any other leaf is encountered instead, the default value in_default_axis (for input arguments) or out_default_axis (for output arguments) is used to determine the dimension to vectorize.
- Parameters:
in_default_axis (Optional[int], optional) – default index used in the in_axes argument of jax.vmap when a leaf that is not jax_utils.pytree.ConvertibleToAxes is encountered in an input pytree. Defaults to None.
out_default_axis (Optional[int], optional) – default index used in the out_axes argument of jax.vmap when a leaf that is not jax_utils.pytree.ConvertibleToAxes is encountered in an output pytree. Defaults to -1.
- Returns:
decorator with parameters in_default_axis and out_default_axis specified
- Return type:
Callable[[Callable], Callable]
- class JaxDataclassNestedConvertibleToAxes(*args, **kwargs)[source]
Bases: DataclassInstance, ConvertibleToAxes[AxisType_contra], BaseJaxCompilable, Protocol[AxisType_contra]
Interface for dataclasses that can be jit-compiled and are “convertible to axes”, with a concrete implementation of the convert_to_axes method that is specific to dataclasses.
- convert_to_axes(axis: AxisType_contra | None) Self [source]
Concrete implementation of convert_to_axes for dataclasses. All the fields of the dataclass are inspected, and method convert_to_axes is applied to every field that is ConvertibleToAxes.
- Parameters:
axis (Optional[AxisType_contra]) – a “named” axis
- Returns:
same object as self but with array-like fields converted to axes
- Return type:
Self
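The field-by-field inspection performed by this convert_to_axes can be sketched with dataclasses.fields and dataclasses.replace. TaggedArray, Inner, Outer and dataclass_convert_to_axes are hypothetical. Recursing into nested dataclasses and mapping every other field to None (i.e., not vectorized) are assumptions of this sketch rather than documented behavior.

```python
import dataclasses
from typing import Any, Optional, Sequence


class TaggedArray:
    """Hypothetical leaf: array data tagged with "named" axes."""

    def __init__(self, axes: Sequence[str]):
        self.axes = tuple(axes)

    def convert_to_axes(self, axis: Optional[str]) -> Optional[int]:
        return self.axes.index(axis) if axis in self.axes else None


def dataclass_convert_to_axes(obj: Any, axis: Optional[str]) -> Any:
    """Return a copy of dataclass `obj` with every field replaced by an axis."""
    changes = {}
    for field in dataclasses.fields(obj):
        value = getattr(obj, field.name)
        if hasattr(value, "convert_to_axes"):
            changes[field.name] = value.convert_to_axes(axis)
        elif dataclasses.is_dataclass(value):
            changes[field.name] = dataclass_convert_to_axes(value, axis)  # nested
        else:
            changes[field.name] = None  # assumption: not vectorized
    return dataclasses.replace(obj, **changes)


@dataclasses.dataclass
class Inner:
    weights: Any
    label: str = "w"


@dataclasses.dataclass
class Outer:
    x: Any
    inner: Inner


model = Outer(x=TaggedArray(("batch", "time")), inner=Inner(weights=TaggedArray(("time",))))
axes = dataclass_convert_to_axes(model, "time")
# axes == Outer(x=1, inner=Inner(weights=0, label=None))
```

Returning a dataclass of the same type, rather than a plain dict, keeps the result's pytree structure aligned with the original object, which is what an in_axes specification for jax.vmap requires.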