jax_utils package

Submodules

jax_utils.common_tensor module

Wrappers around numpy/jax arrays

check_ndim_in(array: Array | ndarray, allowed_ndims: Iterable[int])[source]

Checks if the number of dimensions matches some allowed values.

Parameters:
  • array (Array) – numpy or JAX array

  • allowed_ndims (Iterable[int]) – number of dimensions allowed for array

Raises:

ValueError – when the number of dimensions of array is not present in allowed_ndims.

is_broadcastable(shape_1: Tuple[int, ...], shape_2: Tuple[int, ...]) bool[source]

Whether two arrays/tensors with the given shapes can be broadcast together.

Parameters:
  • shape_1 (Tuple[int, ...]) – shape of the 1st array

  • shape_2 (Tuple[int, ...]) – shape of the 2nd array

Returns:

True if and only if the two shapes can be broadcast together, False otherwise.

Return type:

bool
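
The rule being checked is standard NumPy broadcasting: trailing dimensions are compared pairwise and each pair must be equal or contain a 1. A minimal sketch of such a check (illustrative only, not necessarily the library’s implementation):

   from typing import Tuple

   def is_broadcastable_sketch(shape_1: Tuple[int, ...], shape_2: Tuple[int, ...]) -> bool:
       # Compare trailing dimensions pairwise; missing leading dimensions
       # are implicitly compatible (standard NumPy broadcasting rule).
       return all(
           d1 == d2 or d1 == 1 or d2 == 1
           for d1, d2 in zip(reversed(shape_1), reversed(shape_2))
       )

   assert is_broadcastable_sketch((3, 1, 5), (4, 5))   # -> broadcast shape (3, 4, 5)
   assert not is_broadcastable_sketch((3, 2), (3, 4))  # 2 vs 4 is incompatible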

class TensorAxes(initial: AbstractSet[T] | Sequence[T] | Iterable[T] | None = None, tensor_min_dim: int = 0)[source]

Bases: OrderedSet[T]

An ordered set of “named” array/tensor axes.

The position in the set defines the axis index in the tensor.

The first len(self) - tensor_min_dim axes are optional: if absent, they are assumed to be “flattened”.

Making the first len(self) - tensor_min_dim axes optional (rather than the last tensor_min_dim ones) follows the logic of array broadcasting, in which leading dimensions can be omitted.

reverse_index(key: T) int[source]

Computes the index of element key but, rather than returning a non-negative integer like method index, returns a negative integer: -1 for the last element, -2 for the penultimate element, etc.

Parameters:

key (T) – any element present in self (which is an ordered set)

Returns:

negative integer corresponding to the position of key in self

(starting from the last and decrementing by 1 for each element)

Return type:

int

property mandatory: OrderedSet[T]

Returns: OrderedSet[T]: ordered set of non-optional axes.

property optional: OrderedSet[T]

Returns: OrderedSet[T]: ordered set of optional axes.
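
A small usage sketch, assuming TensorAxes is importable from jax_utils.common_tensor and string axis names are used; with tensor_min_dim=2, only the leading axis is optional:

   axes = TensorAxes(["batch", "time", "feature"], tensor_min_dim=2)

   axes.reverse_index("feature")  # -1: last axis
   axes.reverse_index("batch")    # -3
   list(axes.optional)            # ["batch"]: the first len(axes) - 2 axes
   list(axes.mandatory)           # ["time", "feature"]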

expand_dims_axis(tensor_axes: TensorAxes[AxisType], missing_axes: AbstractSet[AxisType]) Tuple[int, ...][source]

Compute the axis argument to pass to the numpy.expand_dims(a, axis) function when some axes in tensor_axes are not present in the numpy array a (they are “flattened”).

Parameters:
  • tensor_axes (TensorAxes[AxisType]) – some array/tensor axes

  • missing_axes (AbstractSet[AxisType]) – missing axes in the array/tensor

Raises:

ValueError – when missing_axes is not a subset of tensor_axes

Returns:

the axis argument to pass to numpy.expand_dims(a, axis)

Return type:

Tuple[int, …]
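
For instance, restoring a flattened leading axis might look as follows (a sketch; names are illustrative and the expected output follows the documented semantics):

   import numpy as np

   axes = TensorAxes(["batch", "time", "feature"])
   array = np.zeros((10, 4))                 # "batch" is flattened: only (time, feature)
   axis = expand_dims_axis(axes, {"batch"})  # expected: (0,)
   expanded = np.expand_dims(array, axis)    # shape (1, 10, 4)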

class AverageableArrayLike(*args, **kwargs)[source]

Bases: Protocol[ArrayType_co]

Shared interface of all classes with a mean method returning a scalar array (corresponding to the mean value of the initial array).

Examples of classes implementing this interface: (jax) numpy arrays, etc.

mean(*args, **kwargs) ArrayType_co[source]
Returns:

Should return a scalar array

Return type:

ArrayType_co

class Tensor(*args, **kwargs)[source]

Bases: DataclassInstance, AverageableArrayLike, Protocol[ArrayType, AxisType]

A wrapper for numpy/jax arrays with the following additional features:

  • axes are “named” to facilitate manipulation and debugging (an axis “name” could be a string or any other hashable AxisType)

  • the values returned by the tensor can differ from the actual array given as input at construction; this makes it easy to implement “change of variables” (a.k.a., “substitution”)

Parameters:
  • _tensor_axes (ClassVar[TensorAxes[AxisType]]) – class attribute defining the axes of the array

  • array (ArrayType) – a (jax) numpy array containing all relevant data

array: ArrayType
check_array()[source]

Check the validity of the array attribute at construction. To be overridden if needed.

getitem_from_axes(axes_keys: Dict[AxisType, Any]) Self[source]

Analogue of method __getitem__ but where array slicing/indexing is explicitly applied to named axes.

Parameters:

axes_keys (Dict[AxisType, Any]) – a mapping between axes names and slices/list of indices/…

Returns:

a new Tensor of the same type with restricted data

Return type:

Self
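
A short sketch, assuming tensor is a concrete Tensor with “batch” and “time” among its axes:

   # Keep rows 0 and 2 of the "batch" axis and the first 10 "time" steps.
   sub_tensor = tensor.getitem_from_axes({"batch": [0, 2], "time": slice(0, 10)})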

property values: ArrayType

Override to apply transformations on self.array (“change of variable”).

For example, if one wants to manipulate (jax) tensors constrained to always have non-negative values in any situation (even if the tensor gets updated by some gradient descent procedure for instance), then one can override values by returning jax.nn.relu(self.array).

By default values implements the identity i.e., returns self.array.

Returns:

the transformed array

Return type:

ArrayType
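
A runnable stand-in illustrating the relu override suggested above (this minimal dataclass is not the library’s Tensor, just a sketch of the pattern):

   import dataclasses
   import jax
   import jax.numpy as jnp

   @dataclasses.dataclass
   class NonNegativeSketch:
       array: jnp.ndarray

       @property
       def values(self) -> jnp.ndarray:
           # "Change of variable": the public values stay non-negative even
           # if gradient updates push the underlying array below zero.
           return jax.nn.relu(self.array)

   t = NonNegativeSketch(jnp.array([-1.0, 2.0]))
   print(t.values)  # [0. 2.]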

reverse_values(values: ArrayType) ArrayType[source]

Inverse of the transformation implemented in values property.

By default, values is the identity (self.values returns self.array), so reverse_values is the identity as well.

When the transformation mapping array to values is not the identity, reverse_values should be overridden accordingly.

Parameters:

values (ArrayType) – the output of property values

Returns:

the array obtained by applying the inverse transformation to values

Return type:

ArrayType

property dtype: dtype

Get the dtype of self.values

property shape: Tuple[int, ...]

Get the shape of self.values

property ndim: int

Get the number of dimensions of self.values

property axes: TensorAxes[AxisType]

Possible “named” axes of the tensor

property actual_axes: TensorAxes[AxisType]

Actual axes of the tensor

mean(*args, **kwargs) ArrayType | Any[source]

Compute the mean of self.values over the specified axes

mean_over_axes(axes: AbstractSet[AxisType]) ArrayType | Any[source]

Averages values along one or several named axes.

Parameters:

axes (AbstractSet[AxisType]) – axes along which the means are computed

Returns:

new array containing the mean values

Return type:

Union[ArrayType, Any]

sum(*args, **kwargs) ArrayType | Any[source]

Compute the sum of self.values over the specified axes

sum_over_axes(axes: AbstractSet[AxisType]) ArrayType | Any[source]

Sums values along one or several named axes.

Parameters:

axes (AbstractSet[AxisType]) – axes along which the sums are computed

Returns:

new array containing the summed values

Return type:

Union[ArrayType, Any]

has(axis: AxisType) bool[source]

Whether axis is one of the possible axes of the tensor.

Parameters:

axis (AxisType) – a “named” axis

Returns:

True if axis is one of the possible axes of the tensor, False otherwise.

Return type:

bool

has_actual(axis: AxisType) bool[source]

Whether axis is one of the actual axes of the tensor.

Parameters:

axis (AxisType) – a “named” axis

Returns:

True if axis is one of the actual axes of the tensor, False otherwise.

Return type:

bool

index(axis: AxisType) int[source]

Index of axis in tensor i.e., returns i if and only if axis is the i-th dimension of the tensor.

Parameters:

axis (AxisType) – a “named” axis

Returns:

the corresponding index

Return type:

int

reverse_index(axis: AxisType) int[source]

Index of axis in the tensor but in reverse order i.e., returns -i if and only if axis is the (n - i)-th dimension of the tensor (with n the total number of dimensions).

Parameters:

axis (AxisType) – a “named” axis

Returns:

the corresponding index in reverse order

Return type:

int

size(axis: AxisType) int[source]

The size of a given axis. Returns 0 when the axis is “flattened”.

Parameters:

axis (AxisType) – a “named” axis

Raises:

ValueError – if axis is not a possible axis

Returns:

the size of axis (0 when the axis is “flattened”)

Return type:

int

is_broadcastable_with(other_tensor_or_shape: Tensor[ArrayType, AxisType] | Tuple[int, ...]) bool[source]

Whether this tensor can be broadcast with other_tensor_or_shape.

Parameters:

other_tensor_or_shape (Union[Tensor[ArrayType, AxisType], Tuple[int, ...]]) – another tensor or a shape

Returns:

True if and only if this tensor can be broadcast with other_tensor_or_shape, False otherwise.

Return type:

bool

class RegularizedArrayLikeCost(*args, **kwargs)[source]

Bases: AverageableArrayLike[ArrayType], Protocol[ArrayType]

Interface for “regularized” costs i.e., costs of the form: cost + lagrangian_coefficient * regularization

More about regularization.

Parameters:
  • cost (AverageableArrayLike[ArrayType]) – an array-like collection of costs

  • regularization (AverageableArrayLike[ArrayType]) – an array-like collection of regularization costs

  • lagrangian_coefficient (Number) – a non-negative number quantifying the regularization weight

cost: AverageableArrayLike[ArrayType]
regularization: AverageableArrayLike[ArrayType]
lagrangian_coefficient: Number = 1
mean(*args, **kwargs) ArrayType | Any[source]
Returns:

Should return a scalar array

Return type:

Union[ArrayType, Any]

jax_utils.compilation module

Classes and methods to jit-compile functions involving JAX array transformations

class JaxCompilableProtocol(*args, **kwargs)[source]

Bases: Protocol

All classes implementing this interface should implement property is_compilation_enabled indicating whether the methods involving JAX arrays should be jit-compiled.

property is_compilation_enabled: bool
class BaseJaxCompilable(*args, **kwargs)[source]

Bases: JaxCompilableProtocol, Protocol

Subclassing BaseJaxCompilable makes it easy to enable/disable jit-compilation of methods involving JAX arrays.

Use the jit_when_compilation_enabled decorator to compile a method only when is_compilation_enabled is True (False by default).

To enable (resp. disable) jit-compilation, one only needs to call method enable_compilation (resp. disable_compilation). By default, jit-compilation is disabled.

property is_compilation_enabled: bool
enable_compilation() Self[source]
disable_compilation() Self[source]
jit_when_compilation_enabled(**jax_jit_args) Callable[[Callable], Callable][source]

Parametrized decorator for methods of classes implementing the JaxCompilableProtocol interface. It allows jit-compiling some methods only when compilation is enabled.

Returns:

decorator with parameters jax_jit_args specified

Return type:

Callable[[Callable], Callable]
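
A usage sketch, assuming BaseJaxCompilable can be subclassed directly and jit_when_compilation_enabled is importable from jax_utils.compilation:

   import jax.numpy as jnp

   class Model(BaseJaxCompilable):
       @jit_when_compilation_enabled()
       def step(self, x: jnp.ndarray) -> jnp.ndarray:
           return x * 2.0

   model = Model().enable_compilation()  # chaining works since Self is returned
   y = model.step(jnp.ones(3))           # jit-compiled because compilation is enabled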

jax_utils.dynamics module

High-level abstractions for decision problems (Markov Decision Processes, etc.) involving JAX array transformations

class JaxDynamics(*args, **kwargs)[source]

Bases: JaxDataclassNestedConvertibleToAxes[AxisType_contra], Dynamics[State, Action_contra, Cost_co, Observation_co], Protocol[AxisType_contra, State, Action_contra, Cost_co, Observation_co]

A jax_utils.markov_decision_process.Dynamics involving JAX array transformations. The cost of such dynamics can be differentiated w.r.t. the action (see method jax_utils.dynamics.JaxDynamics.compute_gradient()).

scalar_cost(state: State, action: Action_contra) Array[source]

Averages the cost associated with the dynamics of an MDP.

Parameters:
  • state (State) – a state of the MDP

  • action (Action_contra) – an action of the MDP

Returns:

a scalar Jax array corresponding to the mean cost

Return type:

jnp.ndarray

compute_gradient(state: State, action: Action_contra) Tuple[Cost_co, Action_contra][source]

The cost of a jax_utils.dynamics.JaxDynamics can always be differentiated w.r.t. the action.

Parameters:
  • state (State) – a state of the MDP

  • action (Action_contra) – an action of the MDP

Returns:

the associated cost and the gradient of the cost w.r.t. the action

Return type:

Tuple[Cost_co, Action_contra]

class VectorizedJaxDynamics(dynamics: JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co], vectorized_axis: AxisType)[source]

Bases: JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co]

A jax_utils.dynamics.JaxDynamics that can be vectorized over some “named” axis.

Parameters:
  • dynamics (JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co]) – the jax_utils.dynamics.JaxDynamics to be vectorized

  • vectorized_axis (jdc.Static[AxisType]) – “named” axis to map over

dynamics: JaxDynamics[AxisType, State, Action_contra, Cost_co, Observation_co]
vectorized_axis: AxisType
class RegularizedJaxDynamics(dynamics: JaxDynamics[AxisType_contra, State, Action, Cost, Observation], cost_regularizer: CostRegularizer[State, Action, Cost, Observation, RegularizedCost])[source]

Bases: RegularizedDynamics[State, Action, Cost, RegularizedCost, Observation], JaxDynamics[AxisType_contra, State, Action, Cost, Observation]

A jax_utils.dynamics.JaxDynamics with cost regularization. See also jax_utils.markov_decision_process.RegularizedDynamics.

Parameters:
  • dynamics (JaxDynamics[AxisType_contra, State, Action, Cost, Observation]) – the dynamics to be regularized

  • cost_regularizer (CostRegularizer[State, Action, Cost, Observation, RegularizedCost]) – callable defining the cost transformation

dynamics: JaxDynamics[AxisType_contra, State, Action, Cost, Observation]
cost_regularizer: CostRegularizer[State, Action, Cost, Observation, RegularizedCost]

jax_utils.gradient module

Classes and methods to compute the gradient of functions involving JAX array transformations

class BaseGradientStep(*args, **kwargs)[source]

Bases: JaxDataclassNestedConvertibleToAxes[AxisType_contra], Protocol[AxisType_contra, State_contra, Action, Cost_co, OptimizerState]

Interface for gradient steps on the jax_utils.markov_decision_process.Dynamics cost of a Markov Decision Process.

init_optimizer(init_action: Action) OptimizerState[source]
compute_gradient(state: State_contra, action: Action) Tuple[Cost_co, Action][source]
update(action: Action, gradient_value: GradientUpdates, opt_state: OptimizerState) Tuple[Action, OptimizerState][source]
class GradientStep(optimizer: GradientTransformation, dynamics: JaxDynamics[Any, State, Action, Cost_co, Observation_co])[source]

Bases: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState], Generic[AxisType, State, Action, Cost_co, Observation_co, OptimizerState]

The gradient step on the jax_utils.markov_decision_process.Dynamics cost of a Markov Decision Process.

Parameters:
  • optimizer (jdc.Static[optax.GradientTransformation]) – an optimizer (defining the stochastic gradient descent variant: RMSProp, Adam, etc.)

  • dynamics (JaxDynamics[Any, State, Action, Cost_co, Observation_co]) – an MDP dynamics

optimizer: jdc.Static[optax.GradientTransformation]
dynamics: JaxDynamics[Any, State, Action, Cost_co, Observation_co]
init_optimizer(init_action: Action) OptimizerState[source]

Initialize optimizer

Parameters:

init_action (Action) – initial action before optimization starts

Returns:

initialized optimizer state

Return type:

OptimizerState

compute_gradient(state: State, action: Action) Tuple[Cost_co, Action][source]

Given a state-action pair of the MDP, computes the cost and gradient of the cost of a dynamics w.r.t. the action.

Parameters:
  • state (State) – state of the MDP

  • action (Action) – action of the MDP

Returns:

cost and gradient of the action

Return type:

Tuple[Cost_co, Action]

update(action: Action, gradient_value: GradientUpdates, opt_state: OptimizerState) Tuple[Action, OptimizerState][source]

Update the action and optimizer state

Parameters:
  • action (Action) – action of the MDP

  • gradient_value (GradientUpdates) – gradient of the action

  • opt_state (OptimizerState) – optimizer state

Returns:

the new action and optimizer state

Return type:

Tuple[Action, OptimizerState]
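
The three methods are meant to be chained in the usual optimizer loop. A sketch, assuming dynamics, state and action are defined for the MDP at hand:

   import optax

   gradient_step = GradientStep(optimizer=optax.adam(1e-2), dynamics=dynamics)
   opt_state = gradient_step.init_optimizer(action)
   for _ in range(100):
       cost, grad = gradient_step.compute_gradient(state, action)
       action, opt_state = gradient_step.update(action, grad, opt_state)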

class VectorizedGradientStep(gradient_step: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState], vectorized_axis: AxisType)[source]

Bases: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]

A gradient step that is vectorized on a specific axis

Parameters:
  • gradient_step (BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]) – the gradient step to be vectorized

  • vectorized_axis (jdc.Static[AxisType]) – “named” axis to be vectorized

gradient_step: BaseGradientStep[AxisType, State, Action, Cost_co, OptimizerState]
vectorized_axis: jdc.Static[AxisType]
init_optimizer(init_action: Action) OptimizerState[source]
compute_gradient(state: State, action: Action) Tuple[Cost_co, Action][source]
update(action: Action, gradient_value: GradientUpdates, opt_state: OptimizerState) Tuple[Action, OptimizerState][source]

jax_utils.jax_tensor module

Wrappers around jax arrays

class JaxTensor(*args, **kwargs)[source]

Bases: Tensor[Array, AxisType], ConvertibleToAxes[AxisType], Protocol[AxisType]

Interface representing a Tensor with a JAX array.

A JaxTensor is always “convertible to axes” (see jax_utils.pytree.ConvertibleToAxes interface) using jax_utils.pytree.ConvertibleToAxes.convert_to_axes() method. When a Tensor is “converted to axes”, the associated array attribute is no longer a JAX array but rather an integer indicating the dimension index of the specified axis (argument of method convert_to_axes). If axis is flattened or not present, self.array is set to None when converted to axes.

Converting a jax_utils.jax_tensor.JaxTensor to axes is useful to determine the in_axes and out_axes arguments of jax.vmap and jax.pmap.

Parameters:

array (jnp.ndarray) – a JAX array (or an optional integer when “converted to axes”)

array: Array
is_broadcastable_with(other_tensor_or_shape: Tensor[Array, AxisType] | Tuple[int, ...]) bool[source]

Whether this tensor can be broadcast with other_tensor_or_shape.

Parameters:

other_tensor_or_shape (Union[Tensor[ArrayType, AxisType], Tuple[int, ...]]) – another tensor or a shape

Returns:

True if and only if this tensor can be broadcast with other_tensor_or_shape, False otherwise.

Return type:

bool

classmethod from_flattened_axes(array: Array, flattened_axes: AbstractSet[AxisType], **kwargs) Self[source]

Constructor when some axes are “flattened” in array.

This method simply adds missing dimensions accordingly.

Parameters:
  • array (jnp.ndarray) – a JAX array

  • flattened_axes (AbstractSet[AxisType]) – the named axes that are “flattened”

Returns:

an instance of class cls

Return type:

Self

convert_to_axes(axis: AxisType | None) Self[source]

Convert the jax_utils.jax_tensor.JaxTensor object so that it can be passed as the in_axes or out_axes arguments of jax.vmap and jax.pmap.

Parameters:

axis (Optional[AxisType]) – a “named” axis

Returns:

same object as self but with array replaced by the index of axis: if axis corresponds to the i-th dimension of array, then array is replaced by i; if axis is a flattened axis or simply not present, then array is replaced by None

Return type:

Self

class NonNegativeValues(*args, **kwargs)[source]

Bases: Protocol

Interface to be used in combination with jax_utils.jax_tensor.JaxTensor in order to implement JAX tensors that are constrained to take non-negative values.

Parameters:

array (jnp.ndarray) – a JAX array

array: Array
property values: Array
reverse_values(values: Array) Array[source]
class NonNegativeBudgetedValues(*args, **kwargs)[source]

Bases: NonNegativeValues, Protocol

Interface to be used in combination with jax_utils.jax_tensor.JaxTensor in order to implement JAX tensors that are constrained to both take non-negative values and have their sum over the last axis bounded by a maximal “budget” (max_budget).

Parameters:
  • array (jnp.ndarray) – a JAX array

  • max_budget (Number) – a non-negative number representing the maximum value that the sum over the last axis can take

array: Array
max_budget: Number
property values: Array
reverse_values(values: Array) Array[source]
class RegularizedJaxCost(cost: JaxCostType, regularization: JaxRegularizedCostType, lagrangian_coefficient: Number = 1)[source]

Bases: RegularizedArrayLikeCost[Array], BaseJaxCompilable, Generic[JaxCostType, JaxRegularizedCostType]

Interface for “regularized” JAX costs i.e., costs of the form: cost + lagrangian_coefficient * regularization

Parameters:
  • cost (JaxCostType) – a JAX array of costs

  • regularization (JaxRegularizedCostType) – a JAX array of regularization costs

  • lagrangian_coefficient (jdc.Static[Number]) – a non-negative number quantifying the regularization weight

cost: JaxCostType
regularization: JaxRegularizedCostType
lagrangian_coefficient: Number = 1
mean(*args, **kwargs) Array[source]
Returns:

Should return a scalar array

Return type:

jnp.ndarray
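
A small sketch, assuming plain JAX arrays are accepted for cost and regularization (they satisfy AverageableArrayLike since they expose a mean method):

   import jax.numpy as jnp

   reg_cost = RegularizedJaxCost(
       cost=jnp.array([1.0, 3.0]),
       regularization=jnp.array([0.5, 0.5]),
       lagrangian_coefficient=0.1,
   )
   # Per the documented form cost + lagrangian_coefficient * regularization,
   # the mean should be 2.0 + 0.1 * 0.5 = 2.05.
   reg_cost.mean()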

jax_utils.markov_decision_process module

High-level abstractions for decision problems (Markov Decision processes, etc…)

class Dynamics(*args, **kwargs)[source]

Bases: Protocol[State, Action_contra, Cost_co, Observation_co]

Interface defining the dynamics of a (Partially Observable) Markov Decision Process.

When an “agent” interacting with the (PO)MDP plays an “action” (a.k.a. “control”) in a given “state”, the (PO)MDP transitions to a new state and the agent observes some signal/feedback in the form of a “cost”/”reward” as well as additional “observations”.

A Dynamics is therefore a callable that maps a state-action pair to a state-cost-observation tuple.
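
A minimal sketch of such a callable (the transition, cost and observation here are illustrative):

   import jax.numpy as jnp

   def linear_dynamics(state: jnp.ndarray, action: jnp.ndarray):
       # (state, action) -> (next state, cost, observation)
       next_state = state + action
       cost = jnp.sum(action ** 2)   # quadratic control cost
       observation = next_state      # fully observed MDP
       return next_state, cost, observation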

class CostRegularizer(*args, **kwargs)[source]

Bases: Protocol[State_contra, Action_contra, Cost_contra, Observation_contra, Cost_co]

Interface for callables that map any state-action-cost-observation tuple to a new “regularized” cost.

More about regularization.

Example: one may want to penalize actions with large norms.

class RegularizedDynamics(*args, **kwargs)[source]

Bases: Dynamics[State, Action, RegularizedCost, Observation], Protocol[State, Action, Cost, RegularizedCost, Observation]

Interface defining a wrapper around class jax_utils.markov_decision_process.Dynamics that allows to add a regularization to the cost.

A RegularizedDynamics is itself a jax_utils.markov_decision_process.Dynamics.

More about regularization.

Parameters:
  • dynamics (Dynamics[State, Action, Cost, Observation]) – callable defining the dynamics of a (PO)MDP

  • cost_regularizer (CostRegularizer[State, Action, Cost, Observation, RegularizedCost]) – callable defining a cost transformation

dynamics: Dynamics[State, Action, Cost, Observation]
cost_regularizer: CostRegularizer[State, Action, Cost, Observation, RegularizedCost]

jax_utils.optim module

Classes and methods to optimize functions involving JAX array transformations via gradient descent

class OptimizationState(iteration: int, state: State, action: Action, cost: Cost | None = None, optimizer_state: OptimizerState | None = None)[source]

Bases: Generic[State, Action, Cost, OptimizerState]

The current state of an iterative optimization procedure involving the jax_utils.markov_decision_process.Dynamics of a Markov Decision Process (typically a cost minimization).

Parameters:
  • iteration (int) – the current iteration step

  • state (State) – the MDP state

  • action (Action) – the MDP current action

  • cost (Optional[Cost], optional) – the current cost associated with the state-action pair. Defaults to None.

  • optimizer_state (Optional[OptimizerState], optional) – The current optax.OptState. Defaults to None.

iteration: int
state: State
action: Action
cost: Cost | None = None
optimizer_state: OptimizerState | None = None
property scalar_cost: float

Returns the average self.cost

Returns:

average cost

Return type:

float

class OptimStoppingCondition(*args, **kwargs)[source]

Bases: BaseJaxCompilable, Protocol[State, Action, Cost, OptimizerState]

Interface for all stopping conditions of an iterative optimization procedure involving the jax_utils.markov_decision_process.Dynamics of a Markov Decision Process (typically a cost minimization).

The stopping condition depends on an OptimizationState provided as input but may also collect information over steps.

stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) bool[source]

Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.

Parameters:

optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure

Returns:

True if the optimization procedure should be stopped, False otherwise.

Return type:

bool

reset() Self[source]

Resets the stopping condition to an initial configuration.

property nb_iterations_upper_bound: int

An upper bound on the maximal number of iterations. Default is 1e20.

class OptimStoppingConditionsCombination(*args, **kwargs)[source]

Bases: OptimStoppingCondition[State, Action, Cost, OptimizerState], Protocol

This interface represents the combination of 2 stopping conditions.

Parameters:
  • stopping_condition_1 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – the first stopping condition to combine

  • stopping_condition_2 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – the second stopping condition to combine

stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState]
stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState]
enable_compilation() Self[source]
disable_compilation() Self[source]
class OptimStoppingConditionIntersection(stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState], stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState])[source]

Bases: OptimStoppingConditionsCombination[State, Action, Cost, OptimizerState]

A combination of 2 stopping conditions using an “AND” operation.

Parameters:
  • stopping_condition_1 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – the first condition of the “AND” combination

  • stopping_condition_2 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – the second condition of the “AND” combination

stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState]
stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState]
stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) bool[source]

Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.

Parameters:

optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure

Returns:

True if the optimization procedure should be stopped, False otherwise.

Return type:

bool

reset() OptimStoppingConditionIntersection[State, Action, Cost, OptimizerState][source]

Resets the stopping condition to an initial configuration.

property nb_iterations_upper_bound: int

An upper bound on the maximal number of iterations. Default is 1e20.

class OptimStoppingConditionUnion(stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState], stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState])[source]

Bases: OptimStoppingConditionsCombination[State, Action, Cost, OptimizerState]

A combination of 2 stopping conditions using an “OR” operation.

Parameters:
  • stopping_condition_1 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – the first condition of the “OR” combination

  • stopping_condition_2 (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – the second condition of the “OR” combination

stopping_condition_1: OptimStoppingCondition[State, Action, Cost, OptimizerState]
stopping_condition_2: OptimStoppingCondition[State, Action, Cost, OptimizerState]
stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) bool[source]

Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.

Parameters:

optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure

Returns:

True if the optimization procedure should be stopped, False otherwise.

Return type:

bool

reset() OptimStoppingConditionUnion[State, Action, Cost, OptimizerState][source]

Resets the stopping condition to an initial configuration.

property nb_iterations_upper_bound: int

An upper bound on the maximal number of iterations. Default is 1e20.

class MaxIterationsStoppingCondition(max_iterations: int)[source]

Bases: OptimStoppingCondition[State, Action, Cost, OptimizerState]

A stopping condition that stops when the number of iterations exceeds a given threshold.

Parameters:

max_iterations (int) – maximal number of iterations before the stopping condition is raised

max_iterations: int
stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) bool[source]

Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.

Parameters:

optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure

Returns:

True if the optimization procedure should be stopped, False otherwise.

Return type:

bool

property nb_iterations_upper_bound: int

An upper bound on the maximal number of iterations. Default is 1e20.

class MinIterationsStoppingCondition(min_iterations: int)[source]

Bases: OptimStoppingCondition[State, Action, Cost, OptimizerState]

A stopping condition that continues as long as the number of iterations is below a given threshold.

Parameters:

min_iterations (int) – minimal number of iterations before the stopping condition is raised

min_iterations: int
stop(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) bool[source]

Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.

Parameters:

optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure

Returns:

True if the optimization procedure should be stopped, False otherwise.

Return type:

bool

class MinDeltaActionStoppingCondition(relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True), absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True))[source]

Bases: OptimStoppingCondition[State, JaxTensorType, Cost, OptimizerState]

Stops when the action stops significantly changing (as defined by relative & absolute tolerance). The previous action is saved in memory to compare it to the new action.

The delta between the previous and new action is then compared to the absolute and relative tolerances i.e., the stopping condition is raised when: all(abs(new_action - previous_action) <= absolute_tolerance + relative_tolerance * maximum(abs(new_action), abs(previous_action)))

Parameters:
  • relative_tolerance (jnp.ndarray) – relative tolerance for action variations. Default is 1e-6.

  • absolute_tolerance (jnp.ndarray) – absolute tolerance for action variations. Default is 1e-6.

relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
stop(optimization_state: OptimizationState[State, JaxTensorType, Cost, OptimizerState]) bool[source]

Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.

Parameters:

optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure

Returns:

True if the optimization procedure should be stopped, False otherwise.

Return type:

bool

reset() MinDeltaActionStoppingCondition[State, JaxTensorType, Cost, OptimizerState][source]

Resets the stopping condition to an initial configuration.

class MinDeltaCostStoppingCondition(relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True), absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True), window_length: int = 2)[source]

Bases: OptimStoppingCondition[State, Action, Array, OptimizerState]

Stops when the cost stops significantly decreasing (as defined by relative & absolute tolerance).

The window_length last cost values are saved in memory and the stopping condition is raised when the delta between the min and max of these values is below a threshold i.e., when: all(maximum(last_window_length_costs) - minimum(last_window_length_costs) < absolute_tolerance + relative_tolerance * minimum(last_window_length_costs))

Parameters:
  • relative_tolerance (jnp.ndarray) – relative tolerance for cost variations. Default is 1e-6.

  • absolute_tolerance (jnp.ndarray) – absolute tolerance for cost variations. Default is 1e-6.

  • window_length (int) – maximal length of the queue storing the past history of costs (costs older than window_length time steps are discarded)

relative_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
absolute_tolerance: Array = Array(1.e-06, dtype=float32, weak_type=True)
window_length: int = 2
stop(optimization_state: OptimizationState[State, Action, Array, OptimizerState]) bool[source]

Returns a boolean indicating whether the optimization procedure should be stopped. This method should be overridden in every concrete class.

Parameters:

optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – the current state of the optimization procedure

Returns:

True if the optimization procedure should be stopped, False otherwise.

Return type:

bool

reset() MinDeltaCostStoppingCondition[State, Action, OptimizerState][source]

Resets the stopping condition to an initial configuration.

class CostHistory(initlist=None)[source]

Bases: UserList[AverageableArrayLike[Array]]

List of costs (typically jax arrays)

scalar_costs() List[float][source]
Returns:

list of cost means

Return type:

List[float]

plot_scalar_costs()[source]

Plots cost history

class GradientDescentOptimizationLoop(gradient_step: BaseGradientStep[AxisType, State, Action, Cost, OptimizerState], stopping_condition: OptimStoppingCondition[State, Action, Cost, OptimizerState])[source]

Bases: Generic[AxisType, State, Action, Cost, Observation, OptimizerState]

Gradient descent optimization procedure involving the jax_utils.markov_decision_process.Dynamics of a Markov Decision Process. This class is a callable that iteratively applies gradient_step until stopping_condition is met.

Parameters:
  • gradient_step (BaseGradientStep[AxisType, State, Action, Cost, OptimizerState]) – a gradient step

  • stopping_condition (OptimStoppingCondition[State, Action, Cost, OptimizerState]) – a stopping condition

gradient_step: BaseGradientStep[AxisType, State, Action, Cost, OptimizerState]
stopping_condition: OptimStoppingCondition[State, Action, Cost, OptimizerState]
resume(optimization_state: OptimizationState[State, Action, Cost, OptimizerState]) Tuple[OptimizationState[State, Action, Cost, OptimizerState], OptimizationState[State, Action, Cost, OptimizerState], CostHistory][source]

Resume the optimization loop from optimization_state (warm start).

Parameters:

optimization_state (OptimizationState[State, Action, Cost, OptimizerState]) – state from which to resume the optimization procedure.

Returns:

same outputs as GradientDescentOptimizationLoop.__call__()

Return type:

Tuple[OptimizationState[State, Action, Cost, OptimizerState], OptimizationState[State, Action, Cost, OptimizerState], CostHistory]
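
A wiring sketch, reusing the GradientStep sketch from jax_utils.gradient and assuming initial state and action are defined for the MDP at hand:

   loop = GradientDescentOptimizationLoop(
       gradient_step=gradient_step,
       stopping_condition=MaxIterationsStoppingCondition(max_iterations=500),
   )
   initial = OptimizationState(iteration=0, state=state, action=action)
   # The tuple layout follows the documented return type; the respective
   # roles of the two returned optimization states are as in __call__().
   out_state_1, out_state_2, cost_history = loop.resume(initial)
   cost_history.plot_scalar_costs()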

jax_utils.pytree module

Classes and methods to easily map pytrees to axes (e.g., for vectorization)

class ConvertibleToAxes(*args, **kwargs)[source]

Bases: Protocol[AxisType_contra]

convert_to_axes(axis: AxisType_contra | None) Self[source]

Returns an object that can be used in argument in_axes or out_axes of jax.vmap or jax.pmap

Parameters:

axis (Optional[AxisType_contra]) – a “named” axis over which to apply vectorization.

Returns:

same object as self but with all array-like objects replaced by axes

Return type:

Self

pytree_to_axes(pytree: Any, vectorized_axis: AxisType, default_axis: int | None = None) Any[source]

Transform all the ConvertibleToAxes leaves of a given pytree to axes by applying method ConvertibleToAxes.convert_to_axes(). This is useful for vectorizing functions involving ConvertibleToAxes objects.

Parameters:
  • pytree (Any) – any Python pytree containing ConvertibleToAxes leaves

  • vectorized_axis (AxisType) – a “named” axis over which to apply vectorization

  • default_axis (Optional[int], optional) – a default axis value for leaves that are not ConvertibleToAxes. Defaults to None.

Returns:

same pytree as given in input but where all ConvertibleToAxes leaves are converted to axes.

Return type:

Any
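
A sketch, assuming inputs is a pytree containing ConvertibleToAxes leaves (e.g., JaxTensor instances) and fn is a function over such pytrees:

   import jax

   # ConvertibleToAxes leaves are converted to their axis index; other
   # leaves fall back to default_axis (here None, i.e., not mapped).
   in_axes = pytree_to_axes(inputs, vectorized_axis="batch", default_axis=None)
   outputs = jax.vmap(fn, in_axes=(in_axes,))(inputs)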

jax_utils.tranform module

Jax arrays transformations (scaling, …)

scale_jax_tensor(tensor: JaxTensor[AxisType], scaling_factor: Array) JaxTensor[AxisType][source]

Scale the values of a JaxTensor i.e., multiply values by scaling_factor. Note that property values in a JaxTensor may differ from attribute array.

Parameters:
  • tensor (JaxTensor[AxisType]) – jax tensor to be scaled

  • scaling_factor (jnp.ndarray) – the scaling factors to apply for scaling

Returns:

scaled jax tensor

Return type:

JaxTensor[AxisType]

unscale_jax_tensor(scaled_tensor: JaxTensor[AxisType], scaling_factor: Array) JaxTensor[AxisType][source]

Reverse transformation of function scale_jax_tensor.

Parameters:
  • scaled_tensor (JaxTensor[AxisType]) – jax tensor to be unscaled

  • scaling_factor (jnp.ndarray) – the scaling factors applied when scaling the original jax tensor

Returns:

unscaled jax tensor

Return type:

JaxTensor[AxisType]

class JaxScaler(scaling_factors: Array = Array(1., dtype=float32, weak_type=True), scaling_axes: Set[AxisType] = <factory>, tensor_types_to_scale: Tuple[Type[JaxTensor]] = (<class 'jax_utils.jax_tensor.JaxTensor'>,), tensor_types_not_to_scale: Tuple[JaxTensor] = ())[source]

Bases: BaseJaxCompilable, Generic[AxisType]

When applying gradient descent optimization algorithms, it is often helpful to scale all the arrays/tensors involved in the computations.

This class makes it easy to scale/unscale jax_utils.jax_tensor.JaxTensor’s (even when the tensors are stored in a nested pytree structure).

Parameters:
  • scaling_factors (jnp.ndarray) – Multiplicative factors for scaling. Defaults to jnp.array(1.0).

  • scaling_axes (jdc.Static[Set[AxisType]]) – “Named” axes of scaling_factors. The number of axes should match scaling_factors.ndim. Defaults to an empty set (jdc.field(default_factory=set)).

  • tensor_types_to_scale (jdc.Static[Tuple[Type[JaxTensor]]]) – Object types that should be scaled in a pytree. Defaults to (jax_utils.jax_tensor.JaxTensor,).

  • tensor_types_not_to_scale (jdc.Static[Tuple[JaxTensor]]) – Object types that should not be scaled in a pytree. Defaults to tuple().

scaling_factors: Array = Array(1., dtype=float32, weak_type=True)
scaling_axes: Set[AxisType]
tensor_types_to_scale: Tuple[Type[JaxTensor]] = (<class 'jax_utils.jax_tensor.JaxTensor'>,)
tensor_types_not_to_scale: Tuple[JaxTensor] = ()
classmethod from_tensor(tensor: JaxTensor[AxisType], scaling_axes: Set[AxisType], tensor_types_to_scale: Tuple[Type[JaxTensor]] = (<class 'jax_utils.jax_tensor.JaxTensor'>,), tensor_types_not_to_scale: Tuple[JaxTensor] = (), factor: float = 1.0) JaxScaler[source]

Alternative constructor that automatically computes the scaling_factors based on a jax_utils.jax_tensor.JaxTensor, the scaling_axes and a target factor.

If my_scaler = JaxScaler.from_tensor(tensor=my_tensor, scaling_axes=my_axes, factor=my_factor) then my_scaler.scale(my_tensor) will be a jax_utils.jax_tensor.JaxTensor with my_tensor.mean_over_axes(axes=my_axes) being equal to a “constant” Jax array with all elements equal to my_factor.

Parameters:
  • tensor (JaxTensor[AxisType]) – a jax tensor

  • scaling_axes (Set[AxisType]) – a set of “named” axes present in tensor

  • tensor_types_to_scale (jdc.Static[Tuple[Type[JaxTensor]]]) – Object types that should be scaled in a pytree. Defaults to (jax_utils.jax_tensor.JaxTensor,).

  • tensor_types_not_to_scale (jdc.Static[Tuple[JaxTensor]]) – Object types that should not be scaled in a pytree. Defaults to tuple().

Returns:

a scaler instance

Return type:

JaxScaler

scale(pytree: Any) Any[source]

Scales any pytree containing jax_utils.jax_tensor.JaxTensor’s

Parameters:

pytree (Any) – pytree to be scaled

Returns:

the scaled pytree

Return type:

Any

unscale(scaled_pytree: Any) Any[source]

Unscales any pytree containing jax_utils.jax_tensor.JaxTensor’s (reverse transformation of method scale).

Parameters:

scaled_pytree (Any) – pytree to be unscaled

Returns:

the unscaled pytree

Return type:

Any
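
A usage sketch, assuming my_tensor is a concrete JaxTensor with a “time” axis:

   scaler = JaxScaler.from_tensor(tensor=my_tensor, scaling_axes={"time"}, factor=1.0)
   scaled = scaler.scale(my_tensor)   # means over the "time" axis become ~1.0
   restored = scaler.unscale(scaled)  # reverse transformation of scale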

expand_scaling_factors(tensor_axes: TensorAxes[AxisType], missing_axes: Set[AxisType]) Array[source]

Expands the dimensions of self.scaling_factors to match the tensor_axes given as inputs.

This method only adds new dimensions of size 1.

Parameters:
  • tensor_axes (TensorAxes[AxisType]) – a set of “named” axes

  • missing_axes (Set[AxisType]) – the “named” axes of tensor_axes missing from self.scaling_factors

Returns:

same as self.scaling_factors but with additional dimensions of size 1. The total number of dimensions of the returned array is equal to len(tensor_axes).

Return type:

jnp.ndarray

jax_utils.typing module

Define useful types for package

class DataclassInstance(*args, **kwargs)[source]

Bases: Protocol

Type for Python dataclasses (see dataclasses.dataclass).

This code is copy-pasted from _typeshed (but is simpler to import).

class HashableIndexingOrSlicing(*args, **kwargs)[source]

Bases: DataclassInstance, Hashable, Protocol

Python slices and lists of indices are not hashable (slices became hashable in Python 3.12). This is a shared interface for hashable slices and indices.

property values: slice | int | List[int]

Property returning index values in the form of a slice, an int (for singletons) or a list of int.

Returns:

index values

Return type:

Union[slice, int, List[int]]

class HashableSlicing(start: int | None = None, stop: int | None = None, step: int | None = None)[source]

Bases: HashableIndexingOrSlicing

Python slices are not hashable for Python versions < 3.12. This class allows defining hashable slices.

Parameters:
  • start (Optional[int]) – initial value of the slice (included)

  • stop (Optional[int]) – last value of the slice (excluded)

  • step (Optional[int]) – step between values

start: int | None = None
stop: int | None = None
step: int | None = None
property values: slice

Property returning index values in the form of a slice, an int (for singletons) or a list of int.

Returns:

index values

Return type:

Union[slice, int, List[int]]
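
A small sketch of the intended round-trip (hashability makes these usable as dict keys or as jit-static arguments):

   s = HashableSlicing(start=0, stop=10, step=2)
   s.values       # slice(0, 10, 2)
   {s: "cached"}  # hashable, unlike a raw slice on Python < 3.12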

class HashableIndexing(*indices: int)[source]

Bases: HashableIndexingOrSlicing

Python lists are not hashable. This class allows defining a hashable list of indices.

Parameters:

indices (Union[int, Tuple[int]]) – index or list of (integer-valued) indices

indices: int | Tuple[int]
property values: int | List[int]

Property returning index values in the form of a slice, an int (for singletons) or a list of int.

Returns:

index values

Return type:

Union[slice, int, List[int]]

jax_utils.vectorization module

Classes and methods to easily vectorize jax tensors

class JaxVectorizableProtocol(*args, **kwargs)[source]

Bases: JaxCompilableProtocol, ConvertibleToAxes[AxisType], Protocol

Interface for classes manipulating JAX arrays. It defines a special attribute vectorized_axis corresponding to the “named” axis over which to apply vectorization.

“Vectorizable” classes are also “compilable” (see interface jax_utils.compilation.JaxCompilableProtocol) and convertible to axes (see jax_utils.pytree.ConvertibleToAxes interface).

Parameters:

vectorized_axis (AxisType) – “named” axis to be vectorized

vectorized_axis: AxisType
vectorize(in_default_axis: int | None = None, out_default_axis: int | None = -1) Callable[[Callable], Callable][source]

Parametrized decorator for methods of classes implementing the JaxVectorizableProtocol interface. It allows vectorizing the decorated methods over a specified axis, namely attribute JaxVectorizableProtocol.vectorized_axis (see JaxVectorizableProtocol).

The pytrees of all inputs and outputs of the decorated method are explored and every JAX array as well as every class implementing the jax_utils.pytree.ConvertibleToAxes interface is considered a leaf of the pytree. When a jax_utils.pytree.ConvertibleToAxes leaf is encountered, method jax_utils.pytree.ConvertibleToAxes.convert_to_axes() is used to determine the dimension to vectorize. When a JAX array leaf or any other leaf is encountered instead, the default value in_default_axis (for input arguments) or out_default_axis (for output arguments) is used to determine the dimension to vectorize.

Parameters:
  • in_default_axis (Optional[int], optional) – default index used in the in_axes argument of jax.vmap when a leaf that is not jax_utils.pytree.ConvertibleToAxes is encountered in an input pytree. Defaults to None.

  • out_default_axis (Optional[int], optional) – default index used in the out_axes argument of jax.vmap when a leaf that is not jax_utils.pytree.ConvertibleToAxes is encountered in an output pytree. Defaults to -1.

Returns:

decorator with parameters in_default_axis and out_default_axis specified

Return type:

Callable[[Callable], Callable]
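
A rough sketch of decorating a method, assuming a concrete class that implements JaxVectorizableProtocol (here via BaseJaxCompilable) with vectorized_axis set to “batch”:

   class BatchedModel(BaseJaxCompilable):
       vectorized_axis = "batch"

       @vectorize(in_default_axis=None, out_default_axis=-1)
       def forward(self, tensor):
           # Inside the vectorized call, jax.vmap has mapped away the
           # "batch" dimension of every ConvertibleToAxes argument.
           return tensor.mean()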

class JaxDataclassNestedConvertibleToAxes(*args, **kwargs)[source]

Bases: DataclassInstance, ConvertibleToAxes[AxisType_contra], BaseJaxCompilable, Protocol[AxisType_contra]

Interface for dataclasses that can be jit-compiled and are “convertible to axes”, with a concrete implementation of convert_to_axes method that is specific to dataclasses.

convert_to_axes(axis: AxisType_contra | None) Self[source]

Concrete implementation of convert_to_axes for dataclasses. All the fields of the dataclass are inspected and method convert_to_axes is applied whenever the field is ConvertibleToAxes.

Parameters:

axis (Optional[AxisType_contra]) – a “named” axis

Returns:

same object as self but with array-like fields converted to axes

Return type:

Self

Module contents