0.1.0a1
Contents:
jax_utils
AI Helpers Jax Utils
Index
Index
A
|
B
|
C
|
D
|
E
|
F
|
G
|
H
|
I
|
J
|
L
|
M
|
N
|
O
|
P
|
R
|
S
|
T
|
U
|
V
|
W
A
absolute_tolerance (MinDeltaActionStoppingCondition attribute)
(MinDeltaCostStoppingCondition attribute)
action (OptimizationState attribute)
actual_axes (Tensor property)
array (JaxTensor attribute)
(NonNegativeBudgetedValues attribute)
(NonNegativeValues attribute)
(Tensor attribute)
AverageableArrayLike (class in jax_utils.common_tensor)
axes (Tensor property)
B
BaseGradientStep (class in jax_utils.gradient)
BaseJaxCompilable (class in jax_utils.compilation)
C
check_array() (Tensor method)
check_ndim_in() (in module jax_utils.common_tensor)
compute_gradient() (BaseGradientStep method)
(GradientStep method)
(JaxDynamics method)
(VectorizedGradientStep method)
convert_to_axes() (ConvertibleToAxes method)
(JaxDataclassNestedConvertibleToAxes method)
(JaxTensor method)
ConvertibleToAxes (class in jax_utils.pytree)
cost (OptimizationState attribute)
(RegularizedArrayLikeCost attribute)
(RegularizedJaxCost attribute)
cost_regularizer (RegularizedDynamics attribute)
(RegularizedJaxDynamics attribute)
CostHistory (class in jax_utils.optim)
CostRegularizer (class in jax_utils.markov_decision_process)
D
DataclassInstance (class in jax_utils.typing)
disable_compilation() (BaseJaxCompilable method)
(OptimStoppingConditionsCombination method)
dtype (Tensor property)
Dynamics (class in jax_utils.markov_decision_process)
dynamics (GradientStep attribute)
(RegularizedDynamics attribute)
(RegularizedJaxDynamics attribute)
(VectorizedJaxDynamics attribute)
E
enable_compilation() (BaseJaxCompilable method)
(OptimStoppingConditionsCombination method)
expand_dims_axis() (in module jax_utils.common_tensor)
expand_scaling_factors() (JaxScaler method)
F
from_flattened_axes() (JaxTensor class method)
from_tensor() (JaxScaler class method)
G
getitem_from_axes() (Tensor method)
gradient_step (GradientDescentOptimizationLoop attribute)
(VectorizedGradientStep attribute)
GradientDescentOptimizationLoop (class in jax_utils.optim)
GradientStep (class in jax_utils.gradient)
H
has() (Tensor method)
has_actual() (Tensor method)
HashableIndexing (class in jax_utils.typing)
HashableIndexingOrSlicing (class in jax_utils.typing)
HashableSlicing (class in jax_utils.typing)
I
index() (Tensor method)
indices (HashableIndexing attribute)
init_optimizer() (BaseGradientStep method)
(GradientStep method)
(VectorizedGradientStep method)
is_broadcastable() (in module jax_utils.common_tensor)
is_broadcastable_with() (JaxTensor method)
(Tensor method)
is_compilation_enabled (BaseJaxCompilable property)
(JaxCompilableProtocol property)
iteration (OptimizationState attribute)
J
jax_utils
module
jax_utils.common_tensor
module
jax_utils.compilation
module
jax_utils.dynamics
module
jax_utils.gradient
module
jax_utils.jax_tensor
module
jax_utils.markov_decision_process
module
jax_utils.optim
module
jax_utils.pytree
module
jax_utils.tranform
module
jax_utils.typing
module
jax_utils.vectorization
module
JaxCompilableProtocol (class in jax_utils.compilation)
JaxDataclassNestedConvertibleToAxes (class in jax_utils.vectorization)
JaxDynamics (class in jax_utils.dynamics)
JaxScaler (class in jax_utils.tranform)
JaxTensor (class in jax_utils.jax_tensor)
JaxVectorizableProtocol (class in jax_utils.vectorization)
jit_when_compilation_enabled() (in module jax_utils.compilation)
L
lagrangian_coefficient (RegularizedArrayLikeCost attribute)
(RegularizedJaxCost attribute)
M
mandatory (TensorAxes property)
max_budget (NonNegativeBudgetedValues attribute)
max_iterations (MaxIterationsStoppingCondition attribute)
MaxIterationsStoppingCondition (class in jax_utils.optim)
mean() (AverageableArrayLike method)
(RegularizedArrayLikeCost method)
(RegularizedJaxCost method)
(Tensor method)
mean_over_axes() (Tensor method)
min_iterations (MinIterationsStoppingCondition attribute)
MinDeltaActionStoppingCondition (class in jax_utils.optim)
MinDeltaCostStoppingCondition (class in jax_utils.optim)
MinIterationsStoppingCondition (class in jax_utils.optim)
module
jax_utils
jax_utils.common_tensor
jax_utils.compilation
jax_utils.dynamics
jax_utils.gradient
jax_utils.jax_tensor
jax_utils.markov_decision_process
jax_utils.optim
jax_utils.pytree
jax_utils.tranform
jax_utils.typing
jax_utils.vectorization
N
nb_iterations_upper_bound (MaxIterationsStoppingCondition property)
(OptimStoppingCondition property)
(OptimStoppingConditionIntersection property)
(OptimStoppingConditionUnion property)
ndim (Tensor property)
NonNegativeBudgetedValues (class in jax_utils.jax_tensor)
NonNegativeValues (class in jax_utils.jax_tensor)
O
OptimizationState (class in jax_utils.optim)
optimizer (GradientStep attribute)
optimizer_state (OptimizationState attribute)
OptimStoppingCondition (class in jax_utils.optim)
OptimStoppingConditionIntersection (class in jax_utils.optim)
OptimStoppingConditionsCombination (class in jax_utils.optim)
OptimStoppingConditionUnion (class in jax_utils.optim)
optional (TensorAxes property)
P
plot_scalar_costs() (CostHistory method)
pytree_to_axes() (in module jax_utils.pytree)
R
regularization (RegularizedArrayLikeCost attribute)
(RegularizedJaxCost attribute)
RegularizedArrayLikeCost (class in jax_utils.common_tensor)
RegularizedDynamics (class in jax_utils.markov_decision_process)
RegularizedJaxCost (class in jax_utils.jax_tensor)
RegularizedJaxDynamics (class in jax_utils.dynamics)
relative_tolerance (MinDeltaActionStoppingCondition attribute)
(MinDeltaCostStoppingCondition attribute)
reset() (MinDeltaActionStoppingCondition method)
(MinDeltaCostStoppingCondition method)
(OptimStoppingCondition method)
(OptimStoppingConditionIntersection method)
(OptimStoppingConditionUnion method)
resume() (GradientDescentOptimizationLoop method)
reverse_index() (Tensor method)
(TensorAxes method)
reverse_values() (NonNegativeBudgetedValues method)
(NonNegativeValues method)
(Tensor method)
S
scalar_cost (OptimizationState property)
scalar_cost() (JaxDynamics method)
scalar_costs() (CostHistory method)
scale() (JaxScaler method)
scale_jax_tensor() (in module jax_utils.tranform)
scaling_axes (JaxScaler attribute)
scaling_factors (JaxScaler attribute)
shape (Tensor property)
size() (Tensor method)
start (HashableSlicing attribute)
state (OptimizationState attribute)
step (HashableSlicing attribute)
stop (HashableSlicing attribute)
stop() (MaxIterationsStoppingCondition method)
(MinDeltaActionStoppingCondition method)
(MinDeltaCostStoppingCondition method)
(MinIterationsStoppingCondition method)
(OptimStoppingCondition method)
(OptimStoppingConditionIntersection method)
(OptimStoppingConditionUnion method)
stopping_condition (GradientDescentOptimizationLoop attribute)
stopping_condition_1 (OptimStoppingConditionIntersection attribute)
(OptimStoppingConditionsCombination attribute)
(OptimStoppingConditionUnion attribute)
stopping_condition_2 (OptimStoppingConditionIntersection attribute)
(OptimStoppingConditionsCombination attribute)
(OptimStoppingConditionUnion attribute)
sum() (Tensor method)
sum_over_axes() (Tensor method)
T
Tensor (class in jax_utils.common_tensor)
tensor_types_not_to_scale (JaxScaler attribute)
tensor_types_to_scale (JaxScaler attribute)
TensorAxes (class in jax_utils.common_tensor)
U
unscale() (JaxScaler method)
unscale_jax_tensor() (in module jax_utils.tranform)
update() (BaseGradientStep method)
(GradientStep method)
(VectorizedGradientStep method)
V
values (HashableIndexing property)
(HashableIndexingOrSlicing property)
(HashableSlicing property)
(NonNegativeBudgetedValues property)
(NonNegativeValues property)
(Tensor property)
vectorize() (in module jax_utils.vectorization)
vectorized_axis (JaxVectorizableProtocol attribute)
(VectorizedGradientStep attribute)
(VectorizedJaxDynamics attribute)
VectorizedGradientStep (class in jax_utils.gradient)
VectorizedJaxDynamics (class in jax_utils.dynamics)
W
window_length (MinDeltaCostStoppingCondition attribute)