Skip to content

Commit

Permalink
feat: only report versions of imported packages => faster startup, mo…
Browse files Browse the repository at this point in the history
…re relevant diagnostics
  • Loading branch information
sehoffmann committed Apr 3, 2024
1 parent 67fe555 commit daa9d08
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
22 changes: 4 additions & 18 deletions dmlcloud/util/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import dmlcloud
from . import slurm
from .git import git_hash
from .thirdparty import try_get_version
from .thirdparty import is_imported, ML_MODULES, try_get_version


class IORedirector:
Expand Down Expand Up @@ -138,23 +138,9 @@ def general_diagnostics() -> str:
except (FileNotFoundError, IndexError):
pass

msg += f' - torch: {torch.__version__}\n'
if try_get_version('torchvision'):
msg += f' - torchvision: {try_get_version("torchvision")}\n'
if try_get_version('torchtext'):
msg += f' - torchtext: {try_get_version("torchtext")}\n'
if try_get_version('torchaudio'):
msg += f' - torchaudio: {try_get_version("torchaudio")}\n'
if try_get_version('einops'):
msg += f' - einops: {try_get_version("einops")}\n'
if try_get_version('numpy'):
msg += f' - numpy: {try_get_version("numpy")}\n'
if try_get_version('pandas'):
msg += f' - pandas: {try_get_version("pandas")}\n'
if try_get_version('xarray'):
msg += f' - xarray: {try_get_version("xarray")}\n'
if try_get_version('sklearn'):
msg += f' - sklearn: {try_get_version("sklearn")}\n'
for module_name in ML_MODULES:
if is_imported(module_name):
msg += f' - {module_name}: {try_get_version(module_name)}\n'

if 'SLURM_JOB_ID' in os.environ:
msg += '* SLURM:\n'
Expand Down
18 changes: 18 additions & 0 deletions dmlcloud/util/thirdparty.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,26 @@
import importlib
import sys
from types import ModuleType
from typing import Optional


ML_MODULES = [
'torch',
'torchvision',
'torchtext',
'torchaudio',
'einops',
'numpy',
'pandas',
'xarray',
'sklearn',
]


def is_imported(name: str) -> bool:
return name in sys.modules


def try_import(name: str) -> Optional[ModuleType]:
try:
return importlib.import_module(name)
Expand Down

0 comments on commit daa9d08

Please sign in to comment.