Create lazy-style import decorator #1961

Closed · wants to merge 7 commits
src/accelerate/tracking.py (23 additions, 14 deletions)
@@ -23,6 +23,8 @@
 
 import yaml
 
+import accelerate.utils.imports as imports
+
 from .logging import get_logger
 from .state import PartialState
 from .utils import (
@@ -33,6 +35,7 @@
     is_tensorboard_available,
     is_wandb_available,
     listify,
+    require_import,
 )
@@ -42,23 +45,15 @@
     _available_trackers.append(LoggerType.TENSORBOARD)
 
 if is_wandb_available():
-    import wandb
-
     _available_trackers.append(LoggerType.WANDB)
 
 if is_comet_ml_available():
-    from comet_ml import Experiment
-
     _available_trackers.append(LoggerType.COMETML)
 
 if is_aim_available():
-    from aim import Run
-
     _available_trackers.append(LoggerType.AIM)
 
 if is_mlflow_available():
-    import mlflow
-
     _available_trackers.append(LoggerType.MLFLOW)
 
 logger = get_logger(__name__)
@@ -179,13 +174,12 @@ class TensorBoardTracker(GeneralTracker):
     requires_logging_directory = True
 
     @on_main_process
+    @require_import("import torch.utils.tensorboard as tensorboard", "import tensorboardX as tensorboard")
     def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs):
-        if is_tensorboard_available():
-            try:
-                from torch.utils import tensorboard
-            except ModuleNotFoundError:
-                import tensorboardX as tensorboard
         super().__init__()
+        global tensorboard
+        tensorboard = imports.tensorboard
+
         self.run_name = run_name
         self.logging_dir = os.path.join(logging_dir, run_name)
         self.writer = tensorboard.SummaryWriter(self.logging_dir, **kwargs)
@@ -290,8 +284,12 @@ class WandBTracker(GeneralTracker):
     main_process_only = False
 
     @on_main_process
+    @require_import("wandb")
     def __init__(self, run_name: str, **kwargs):
         super().__init__()
+        global wandb
+        wandb = imports.wandb
+
         self.run_name = run_name
         self.run = wandb.init(project=self.run_name, **kwargs)
         logger.debug(f"Initialized WandB project {self.run_name}")
@@ -376,7 +374,6 @@ def log_table(
             step (`int`, *optional*):
                 The run step. If included, the log will be affiliated with this step.
         """
-
         values = {table_name: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
         self.log(values, step=step, **kwargs)
 
@@ -406,8 +403,12 @@ class CometMLTracker(GeneralTracker):
     requires_logging_directory = False
 
     @on_main_process
+    @require_import("from comet_ml import Experiment")
     def __init__(self, run_name: str, **kwargs):
         super().__init__()
+        global Experiment
+        Experiment = imports.Experiment
+
         self.run_name = run_name
         self.writer = Experiment(project_name=run_name, **kwargs)
         logger.debug(f"Initialized CometML project {self.run_name}")
@@ -482,7 +483,11 @@ class AimTracker(GeneralTracker):
     requires_logging_directory = True
 
     @on_main_process
+    @require_import("from aim import Run")
     def __init__(self, run_name: str, logging_dir: Optional[Union[str, os.PathLike]] = ".", **kwargs):
+        global Run
+        Run = imports.Run
+
         self.run_name = run_name
         self.writer = Run(repo=logging_dir, **kwargs)
         self.writer.name = self.run_name
@@ -563,6 +568,7 @@ class MLflowTracker(GeneralTracker):
     requires_logging_directory = False
 
     @on_main_process
+    @require_import("mlflow")
     def __init__(
         self,
         experiment_name: str = None,
@@ -573,6 +579,9 @@ def __init__(
         run_name: Optional[str] = None,
         description: Optional[str] = None,
     ):
+        global mlflow
+        mlflow = imports.mlflow
+
         experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", experiment_name)
         run_id = os.getenv("MLFLOW_RUN_ID", run_id)
         tags = os.getenv("MLFLOW_TAGS", tags)
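Each tracker defers its optional dependency to construction time with the same two-step dance: `require_import` executes the import statement inside `accelerate/utils/imports.py` (the `globals()` call in `_import_module` writes into that module's namespace), and the tracker's `__init__` then rebinds the module into `tracking.py`'s own globals via `global name; name = imports.name`. Below is a minimal, runnable sketch of the pattern, assuming this PR's branch of accelerate is installed; the stdlib's `json` stands in for an optional dependency, and `DemoTracker` is illustrative, not part of this PR.

import accelerate.utils.imports as imports
from accelerate.utils import require_import


class DemoTracker:
    @require_import("json")
    def __init__(self, run_name: str):
        # require_import ran `import json` inside accelerate/utils/imports.py,
        # so the module object lives in *that* module's globals. Rebind it
        # here so the rest of this module can use it as a plain name.
        global json
        json = imports.json
        self.payload = json.dumps({"run": run_name})


tracker = DemoTracker("demo")  # `json` is imported lazily, at construction time
print(tracker.payload)

Nothing is imported when `tracking.py` itself is imported; only processes that actually construct a tracker pay for, and require, the dependency.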
src/accelerate/utils/__init__.py (1 addition, 0 deletions)
@@ -64,6 +64,7 @@
     is_transformers_available,
     is_wandb_available,
     is_xpu_available,
+    require_import,
 )
 from .modeling import (
     calculate_maximum_sizes,
src/accelerate/utils/imports.py (52 additions, 1 deletion)
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import ast
 import importlib
 import importlib.metadata
 import os
 import warnings
 from distutils.util import strtobool
-from functools import lru_cache
+from functools import lru_cache, wraps
 
 import torch
 from packaging import version
@@ -296,3 +297,53 @@ def is_xpu_available(check_device=False):
     except RuntimeError:
         return False
     return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
+def require_import(import_str: str, secondary_import_str: str = None):
+    """
+    Decorator which checks that the module in `import_str` is available, and then imports it, making it available on
+    the global scope.
+
+    Args:
+        import_str (`str`):
+            The import statement to check and execute.
+        secondary_import_str (`str`, *optional*):
+            A secondary import statement to check and execute if `import_str` fails via `ModuleNotFoundError`.
+    """
+
+    def _import_module(parsed_import):
+        name = parsed_import.names[0]
+        # First check `import x` syntax
+        if isinstance(parsed_import, ast.Import):
+            if name.asname is None:
+                globals()[name.name] = importlib.import_module(name.name)
+            else:
+                globals()[name.asname] = importlib.import_module(name.name)
+
+        # Then check for `from x import y` syntax
+        elif isinstance(parsed_import, ast.ImportFrom):
+            globals()[name.name] = importlib.import_module(parsed_import.module)
+
+    def decorator(function):
+        @wraps(function)
+        def inner(*args, **kwargs):
+            nonlocal import_str, secondary_import_str
+            if len(import_str.split(" ")) == 1:
+                import_str = f"import {import_str}"
+            if secondary_import_str is not None and len(secondary_import_str.split(" ")) == 1:
+                secondary_import_str = f"import {secondary_import_str}"
+            try:
+                parsed_import = ast.parse(import_str).body[0]
+                _import_module(parsed_import)
+            except ModuleNotFoundError:
+                if secondary_import_str is not None:
+                    parsed_import = ast.parse(secondary_import_str).body[0]
+                    _import_module(parsed_import)
+                else:
+                    raise
+
+            function(*args, **kwargs)
+
+        return inner
+
+    return decorator
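For reference: bare names such as "wandb" are normalized to `import wandb` before parsing, `import x as y` binds the module under its alias, and if the primary import raises `ModuleNotFoundError` the secondary statement is tried before the error propagates. Note also that the `ast.ImportFrom` branch binds the requested name to the module returned by `importlib.import_module` (so `from aim import Run` binds `Run` to the `aim` module itself, not the attribute). Below is a hedged usage sketch of the two-argument fallback path, mirroring the `TensorBoardTracker` change above; `BoardWriter` is illustrative, not part of this PR, and because `inner` does not propagate the wrapped function's return value the decorator is best suited to `__init__`-style methods called for their side effects.

import accelerate.utils.imports as imports
from accelerate.utils import require_import


class BoardWriter:
    @require_import(
        "import torch.utils.tensorboard as tensorboard",  # preferred import
        "import tensorboardX as tensorboard",  # fallback if the first fails
    )
    def __init__(self, logging_dir: str):
        # Whichever import succeeded was bound as `tensorboard` inside
        # accelerate/utils/imports.py; reach it through that module.
        self.writer = imports.tensorboard.SummaryWriter(logging_dir)


w = BoardWriter("runs/demo")  # ModuleNotFoundError only if both imports fail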