diff --git a/pathwaysutils/__init__.py b/pathwaysutils/__init__.py index 3af7ad9..84e6e19 100644 --- a/pathwaysutils/__init__.py +++ b/pathwaysutils/__init__.py @@ -60,6 +60,11 @@ def _is_persistence_enabled(): pathways_orbax_handler.register_pathways_handlers( datetime.timedelta(minutes=10) ) + + # Turn off JAX compilation cache because Pathways handles its own compilation + # cache. + jax.config.update("jax_enable_compilation_cache", False) + try: cloud_logging.setup() except OSError as e: