jax_utils
- jax_utils package
  - Submodules
    - jax_utils.common_tensor module
      - check_ndim_in()
      - is_broadcastable()
      - TensorAxes
      - expand_dims_axis()
      - AverageableArrayLike
      - Tensor
        - Tensor.array
        - Tensor.check_array()
        - Tensor.getitem_from_axes()
        - Tensor.values
        - Tensor.reverse_values()
        - Tensor.dtype
        - Tensor.shape
        - Tensor.ndim
        - Tensor.axes
        - Tensor.actual_axes
        - Tensor.mean()
        - Tensor.mean_over_axes()
        - Tensor.sum()
        - Tensor.sum_over_axes()
        - Tensor.has()
        - Tensor.has_actual()
        - Tensor.index()
        - Tensor.reverse_index()
        - Tensor.size()
        - Tensor.is_broadcastable_with()
      - RegularizedArrayLikeCost
    - jax_utils.compilation module
    - jax_utils.dynamics module
    - jax_utils.gradient module
    - jax_utils.jax_tensor module
    - jax_utils.markov_decision_process module
    - jax_utils.optim module
      - OptimizationState
      - OptimStoppingCondition
      - OptimStoppingConditionsCombination
      - OptimStoppingConditionIntersection
      - OptimStoppingConditionUnion
      - MaxIterationsStoppingCondition
      - MinIterationsStoppingCondition
      - MinDeltaActionStoppingCondition
      - MinDeltaCostStoppingCondition
      - CostHistory
      - GradientDescentOptimizationLoop
    - jax_utils.pytree module
    - jax_utils.tranform module
    - jax_utils.typing module
    - jax_utils.vectorization module
  - Module contents