diff --git a/pathwaysutils/__init__.py b/pathwaysutils/__init__.py index 84e6e19..9855df7 100644 --- a/pathwaysutils/__init__.py +++ b/pathwaysutils/__init__.py @@ -14,9 +14,9 @@ """Package of Pathways-on-Cloud utilities.""" import datetime +import logging import os -from absl import logging import jax from pathwaysutils import cloud_logging from pathwaysutils import profiling @@ -24,6 +24,8 @@ from pathwaysutils.persistence import pathways_orbax_handler +logger = logging.getLogger(__name__) + # A new PyPI release will be pushed every time `__version__` is increased. # When changing this, also update the CHANGELOG.md. __version__ = "v0.0.7" @@ -50,7 +52,7 @@ def _is_persistence_enabled(): if _is_pathways_used(): - logging.debug( + logger.debug( "pathwaysutils: Detected Pathways-on-Cloud backend. Applying changes." ) proxy_backend.register_backend_factory() @@ -68,9 +70,9 @@ def _is_persistence_enabled(): try: cloud_logging.setup() except OSError as e: - logging.debug("pathwaysutils: Failed to set up cloud logging.") + logger.debug("pathwaysutils: Failed to set up cloud logging.") else: - logging.debug( + logger.debug( "pathwaysutils: Did not detect Pathways-on-Cloud backend. No changes" " applied." ) diff --git a/pathwaysutils/persistence/pathways_orbax_handler.py b/pathwaysutils/persistence/pathways_orbax_handler.py index 43b20db..cb84e48 100644 --- a/pathwaysutils/persistence/pathways_orbax_handler.py +++ b/pathwaysutils/persistence/pathways_orbax_handler.py @@ -16,15 +16,18 @@ import collections import datetime import functools +import logging import typing from typing import Optional, Sequence -from absl import logging import jax from orbax.checkpoint import future from orbax.checkpoint import type_handlers from pathwaysutils.persistence import helper + +logger = logging.getLogger(__name__) + ParamInfo = type_handlers.ParamInfo SaveArgs = type_handlers.SaveArgs RestoreArgs = type_handlers.RestoreArgs @@ -121,7 +124,7 @@ async def deserialize( mesh_axes.append(sharding.spec) shardings.append(sharding) if arg.global_shape is None or arg.dtype is None: - logging.warning( + logger.warning( 'Shape or dtype not provided for restoration. Provide these' ' properties for improved performance.' ) @@ -180,7 +183,7 @@ def register_pathways_handlers( read_timeout: Optional[datetime.timedelta] = None, ): """Function that must be called before saving or restoring with Pathways.""" - logging.debug( + logger.debug( 'Registering CloudPathwaysArrayHandler (Pathways Persistence API).' ) type_handlers.register_type_handler( diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index 311995a..e8ccbd7 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -14,16 +14,18 @@ """Profiling utilites.""" import dataclasses +import logging import threading import time -from absl import logging import fastapi import jax from jax import numpy as jnp from pathwaysutils import plugin_executable import uvicorn +logger = logging.getLogger(__name__) + class _ProfileState: def __init__(self): @@ -100,7 +102,7 @@ def start_server(port: int): port : The port to start the server on. """ def server_loop(port: int): - logging.debug("Starting JAX profiler server on port %s", port) + logger.debug("Starting JAX profiler server on port %s", port) app = fastapi.FastAPI() @dataclasses.dataclass @@ -110,8 +112,8 @@ class ProfilingConfig: @app.post("/profiling") async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable - logging.debug("Capturing profiling data for %s ms", pc.duration_ms) - logging.debug("Writing profiling data to %s", pc.repository_path) + logger.debug("Capturing profiling data for %s ms", pc.duration_ms) + logger.debug("Writing profiling data to %s", pc.repository_path) jax.profiler.start_trace(pc.repository_path) time.sleep(pc.duration_ms / 1e3) jax.profiler.stop_trace() @@ -156,27 +158,25 @@ def start_trace_patch( create_perfetto_link: bool = False, # pylint: disable=unused-argument create_perfetto_trace: bool = False, # pylint: disable=unused-argument ) -> None: - logging.debug("jax.profile.start_trace patched with pathways' start_trace") + logger.debug("jax.profile.start_trace patched with pathways' start_trace") return start_trace(log_dir) jax.profiler.start_trace = start_trace_patch def stop_trace_patch() -> None: - logging.debug("jax.profile.stop_trace patched with pathways' stop_trace") + logger.debug("jax.profile.stop_trace patched with pathways' stop_trace") return stop_trace() jax.profiler.stop_trace = stop_trace_patch def start_server_patch(port: int): - logging.debug( - "jax.profile.start_server patched with pathways' start_server" - ) + logger.debug("jax.profile.start_server patched with pathways' start_server") return start_server(port) jax.profiler.start_server = start_server_patch def stop_server_patch(): - logging.debug("jax.profile.stop_server patched with pathways' stop_server") + logger.debug("jax.profile.stop_server patched with pathways' stop_server") return stop_server() jax.profiler.stop_server = stop_server_patch