Source code for jax_utils.compilation

"""Classes and methods to jit-compile functions involving JAX array transformations."""

from functools import wraps
from typing import Callable, Protocol

from jax import jit
from typing_extensions import Self


class JaxCompilableProtocol(Protocol):
    """Structural interface for objects whose JAX-array methods may be jit-compiled.

    Any class implementing this interface should provide the property
    ``is_compilation_enabled``, telling whether its methods involving JAX
    arrays should be
    `jit-compiled <https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html>`_.
    """

    @property
    def is_compilation_enabled(self) -> bool:
        """``True`` when methods involving JAX arrays should be jit-compiled."""
        ...
class BaseJaxCompilable(JaxCompilableProtocol, Protocol):
    """Base class providing a switch to turn jit-compilation on and off.

    Subclassing ``BaseJaxCompilable`` makes it easy to enable/disable
    `jit-compilation <https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html>`_
    of methods involving JAX arrays. Use the ``jit_when_compilation_enabled``
    decorator to compile a method only when ``is_compilation_enabled`` is
    ``True`` (``False`` by default).

    To enable (resp. disable) jit-compilation, one only needs to call method
    ``enable_compilation`` (resp. ``disable_compilation``). By default,
    jit-compilation is disabled.

    Note:
        Because ``Protocol`` is listed among the bases, this class is itself a
        protocol and cannot be instantiated directly. Concrete subclasses that
        do not list ``Protocol`` are regular classes and inherit the default
        implementations below.
    """

    @property
    def is_compilation_enabled(self) -> bool:
        """Whether jit-compilation is currently enabled for this instance.

        Returns:
            bool: the value last set by ``enable_compilation`` /
            ``disable_compilation``, or ``False`` if neither has been called.
        """
        # The flag attribute is only created by the toggling methods, so a
        # missing attribute means compilation was never enabled.
        return getattr(self, "_is_compilation_enabled", False)

    def enable_compilation(self) -> Self:
        """Enable jit-compilation of decorated methods on this instance.

        Returns:
            Self: this instance, so calls can be chained.
        """
        # ``object.__setattr__`` bypasses any ``__setattr__`` override in the
        # subclass (such as the one a frozen dataclass installs to forbid
        # assignment), so the flag can be toggled on immutable instances too.
        object.__setattr__(self, "_is_compilation_enabled", True)
        return self

    def disable_compilation(self) -> Self:
        """Disable jit-compilation of decorated methods on this instance.

        Returns:
            Self: this instance, so calls can be chained.
        """
        # Same rationale as in ``enable_compilation``: bypass ``__setattr__``
        # overrides so immutable subclasses can still be toggled.
        object.__setattr__(self, "_is_compilation_enabled", False)
        return self
def jit_when_compilation_enabled(**jax_jit_args) -> Callable[[Callable], Callable]:
    """Parametrized decorator for methods of classes implementing ``JaxCompilableProtocol``.

    Allows to `jit-compile <https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html>`_
    some methods only when compilation is enabled, i.e. when the instance's
    ``is_compilation_enabled`` property is truthy at call time. Otherwise the
    original method is called unchanged.

    The instance (``self``) is passed to the jitted function as its first
    positional argument. It must therefore either be a valid JAX argument
    (e.g. a registered pytree) or be declared static through ``jax_jit_args``
    (e.g. ``static_argnums=0``).

    Args:
        **jax_jit_args: keyword arguments forwarded verbatim to ``jax.jit``
            (e.g. ``static_argnums``, ``static_argnames``, ``donate_argnums``).

    Returns:
        Callable[[Callable], Callable]: decorator with parameters ``jax_jit_args`` specified
    """

    def decorator(func: Callable) -> Callable:
        # The jitted version of ``func`` is built once, on first compiled call,
        # and then reused, instead of calling ``jit`` anew on every invocation.
        # Building it lazily (not at decoration time) keeps any error raised by
        # ``jit`` for invalid ``jax_jit_args`` at call time, as before.
        jitted_func = None

        @wraps(func)
        def wrapper(self: "JaxCompilableProtocol", *args, **kwargs):
            nonlocal jitted_func
            if not self.is_compilation_enabled:
                return func(self, *args, **kwargs)
            if jitted_func is None:
                jitted_func = jit(func, **jax_jit_args)
            return jitted_func(self, *args, **kwargs)

        return wrapper

    return decorator