Skip to content

Commit

Permalink
fix: Trainer usable in interactive mode
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Nov 18, 2023
1 parent af5ac42 commit ea8e01f
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 11 deletions.
4 changes: 2 additions & 2 deletions dmlcloud/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def setup_all(self, use_checkpointing=True, use_wandb=True, print_diagnostics=Tr

def print_diagnositcs(self):
log_delimiter()
logging.info(f'Script path: {script_path()}')
logging.info(f'Project dir: {project_dir()}')
logging.info(f'Script path: {script_path() or "N/A"}')
logging.info(f'Project dir: {project_dir() or "N/A"}')
log_git()
log_diagnostics(self.device)

Expand Down
4 changes: 2 additions & 2 deletions dmlcloud/training/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def log_config(config):


def log_git():
msg = f'Git Hash: {git_hash()}\n'
msg += f'Git Diff:\n{git_diff()}\n'
msg = f'Git Hash: {git_hash() or "N/A"}\n'
msg += f'Git Diff:\n{git_diff() or "N/A"}\n'
msg += delimiter()
logging.info(msg)

Expand Down
6 changes: 5 additions & 1 deletion dmlcloud/util/git.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .project import run_in_project
from .project import run_in_project, script_path


def git_hash(short=False):
if script_path() is None:
return None
if short:
process = run_in_project(['git', 'rev-parse', '--short', 'HEAD'])
else:
Expand All @@ -10,5 +12,7 @@ def git_hash(short=False):


def git_diff():
if script_path() is None:
return None
process = run_in_project(['git', 'diff', '-U0', '--no-color', 'HEAD'])
return process.stdout.decode('utf-8').strip()
28 changes: 22 additions & 6 deletions dmlcloud/util/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def is_setuptools_cli_script(module):
def script_path():
"""
Returns the path to the script or module that was executed.
If python runs in interactive mode, or if "-c" command line option was used, raises a RuntimeError.
If python runs in interactive mode, or if "-c" command line option was used, returns None.
"""
main = sys.modules['__main__']
if not hasattr(main, '__file__'):
raise RuntimeError('script_path() is not supported in interactive mode')
return None

if is_setuptools_cli_script(main):
stack = traceback.extract_stack()
Expand All @@ -50,20 +50,33 @@ def script_path():
return Path(main.__file__).resolve()


def script_path_available():
try:
script_path()
return True
except RuntimeError:
return False


def script_dir():
"""
Returns the directory containing the script or module that was executed.
If python runs in interactive mode, or if "-c" command line option was used, then raises RuntimeError.
If python runs in interactive mode, or if "-c" command line option was used, returns None.
"""
return script_path().parent

path = script_path()
if path is None:
return None
else:
return path.parent

def project_dir():
"""
Returns the top-level directory containing the script or module that was executed.
If python runs in interactive mode, or if "-c" command line option was used, then raises RuntimeError.
If python runs in interactive mode, or if "-c" command line option was used, returns None.
"""
cur_dir = script_dir()
if cur_dir is None:
return None
while (cur_dir / '__init__.py').exists():
cur_dir = cur_dir.parent
return cur_dir
Expand All @@ -72,6 +85,9 @@ def project_dir():
def run_in_project(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, **kwargs):
"""
Runs a command in the project directory and returns the output.
If python runs in interactive mode, or if "-c" command line option was used, raises RuntimeError.
"""
cwd = project_dir()
if cwd is None:
raise RuntimeError("Cannot run in project directory: script path not available")
return subprocess.run(cmd, cwd=cwd, stdout=stdout, stderr=stderr, **kwargs)

0 comments on commit ea8e01f

Please sign in to comment.