From ea8e01f80b7e6c563709ac0d0e50ce4ddb862816 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Sat, 18 Nov 2023 14:46:25 +0100 Subject: [PATCH] fix: Trainer usable in interactive mode --- dmlcloud/training/trainer.py | 4 ++-- dmlcloud/training/util.py | 4 ++-- dmlcloud/util/git.py | 6 +++++- dmlcloud/util/project.py | 28 ++++++++++++++++++++++------ 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/dmlcloud/training/trainer.py b/dmlcloud/training/trainer.py index 03c2070..3a3064d 100644 --- a/dmlcloud/training/trainer.py +++ b/dmlcloud/training/trainer.py @@ -176,8 +176,8 @@ def setup_all(self, use_checkpointing=True, use_wandb=True, print_diagnostics=Tr def print_diagnositcs(self): log_delimiter() - logging.info(f'Script path: {script_path()}') - logging.info(f'Project dir: {project_dir()}') + logging.info(f'Script path: {script_path() or "N/A"}') + logging.info(f'Project dir: {project_dir() or "N/A"}') log_git() log_diagnostics(self.device) diff --git a/dmlcloud/training/util.py b/dmlcloud/training/util.py index f049875..f4e36c2 100644 --- a/dmlcloud/training/util.py +++ b/dmlcloud/training/util.py @@ -75,8 +75,8 @@ def log_config(config): def log_git(): - msg = f'Git Hash: {git_hash()}\n' - msg += f'Git Diff:\n{git_diff()}\n' + msg = f'Git Hash: {git_hash() or "N/A"}\n' + msg += f'Git Diff:\n{git_diff() or "N/A"}\n' msg += delimiter() logging.info(msg) diff --git a/dmlcloud/util/git.py b/dmlcloud/util/git.py index 1e499af..4b38dcd 100644 --- a/dmlcloud/util/git.py +++ b/dmlcloud/util/git.py @@ -1,7 +1,9 @@ -from .project import run_in_project +from .project import run_in_project, script_path def git_hash(short=False): + if script_path() is None: + return None if short: process = run_in_project(['git', 'rev-parse', '--short', 'HEAD']) else: @@ -10,5 +12,7 @@ def git_hash(short=False): def git_diff(): + if script_path() is None: + return None process = run_in_project(['git', 'diff', '-U0', '--no-color', 'HEAD']) return process.stdout.decode('utf-8').strip() diff --git a/dmlcloud/util/project.py b/dmlcloud/util/project.py index fb559f0..87f2a26 100644 --- a/dmlcloud/util/project.py +++ b/dmlcloud/util/project.py @@ -34,11 +34,11 @@ def is_setuptools_cli_script(module): def script_path(): """ Returns the path to the script or module that was executed. - If python runs in interactive mode, or if "-c" command line option was used, raises a RuntimeError. + If python runs in interactive mode, or if "-c" command line option was used, returns None. """ main = sys.modules['__main__'] if not hasattr(main, '__file__'): - raise RuntimeError('script_path() is not supported in interactive mode') + return None if is_setuptools_cli_script(main): stack = traceback.extract_stack() @@ -50,20 +50,33 @@ def script_path(): return Path(main.__file__).resolve() +def script_path_available(): + try: + script_path() + return True + except RuntimeError: + return False + + def script_dir(): """ Returns the directory containing the script or module that was executed. - If python runs in interactive mode, or if "-c" command line option was used, then raises RuntimeError. + If python runs in interactive mode, or if "-c" command line option was used, returns None. """ - return script_path().parent - + path = script_path() + if path is None: + return None + else: + return path.parent def project_dir(): """ Returns the top-level directory containing the script or module that was executed. - If python runs in interactive mode, or if "-c" command line option was used, then raises RuntimeError. + If python runs in interactive mode, or if "-c" command line option was used, returns None. """ cur_dir = script_dir() + if cur_dir is None: + return None while (cur_dir / '__init__.py').exists(): cur_dir = cur_dir.parent return cur_dir @@ -72,6 +85,9 @@ def project_dir(): def run_in_project(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, **kwargs): """ Runs a command in the project directory and returns the output. + If python runs in interactive mode, or if "-c" command line option was used, raises RuntimeError. """ cwd = project_dir() + if cwd is None: + raise RuntimeError("Cannot run in project directory: script path not available") return subprocess.run(cmd, cwd=cwd, stdout=stdout, stderr=stderr, **kwargs)