Skip to content

Commit

Permalink
Added --artifact-dir flag to TTRT CLI (#629)
Browse files Browse the repository at this point in the history
* Added --artifact-dir flag to TTRT CLI

* Small bug fix when calling ttrt run upstream

* Hotfix for Upstream TTRT Run call
  • Loading branch information
vprajapati-tt authored Sep 12, 2024
1 parent 235d441 commit 82eae0c
Showing 1 changed file with 59 additions and 6 deletions.
65 changes: 59 additions & 6 deletions runtime/tools/python/ttrt/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def initialize_apis():
choices=None,
help="log file to dump ttrt output to",
)
API.Query.register_arg(
name="--artifact-dir",
type=str,
default="",
choices=None,
help="--save-artifacts flag must be set, provides a directory path to save artifacts to",
)

# register all read arguments
API.Read.register_arg(
Expand Down Expand Up @@ -96,6 +103,13 @@ def initialize_apis():
choices=None,
help="log file to dump ttrt output to",
)
API.Read.register_arg(
name="--artifact-dir",
type=str,
default="",
choices=None,
help="--save-artifacts flag must be set, provides a directory path to save artifacts to",
)

# register all run arguments
API.Run.register_arg(
Expand Down Expand Up @@ -182,6 +196,14 @@ def initialize_apis():
choices=None,
help="log file to dump ttrt output to",
)
API.Run.register_arg(
name="--artifact-dir",
type=str,
default="",
choices=None,
help="--save-artifacts flag must be set, provides a directory path to save artifacts to",
api_only=False,
)

# register all perf arguments
API.Perf.register_arg(
Expand Down Expand Up @@ -265,6 +287,13 @@ def initialize_apis():
choices=None,
help="system desc to check against",
)
API.Check.register_arg(
name="--artifact-dir",
type=str,
default="",
choices=None,
help="--save-artifacts flag must be set, provides a directory path to save artifacts to",
)

# register apis
API.register_api(API.Query)
Expand Down Expand Up @@ -310,7 +339,11 @@ def __init__(self, args={}, logging=None, artifacts=None):
self.artifacts = (
artifacts
if artifacts != None
else Artifacts(self.logger, self.file_manager)
else Artifacts(
self.logger,
self.file_manager,
artifacts_folder_path=self["artifact_dir"],
)
)
self.system_desc = None
self.device_ids = None
Expand Down Expand Up @@ -465,7 +498,11 @@ def __init__(self, args={}, logging=None, artifacts=None):
self.artifacts = (
artifacts
if artifacts != None
else Artifacts(self.logger, self.file_manager)
else Artifacts(
self.logger,
self.file_manager,
artifacts_folder_path=self["artifact_dir"],
)
)
self.read_action_functions = {}
self.ttnn_binaries = []
Expand Down Expand Up @@ -714,7 +751,11 @@ def __init__(self, args={}, logging=None, artifacts=None):
self.artifacts = (
artifacts
if artifacts != None
else Artifacts(self.logger, self.file_manager)
else Artifacts(
self.logger,
self.file_manager,
artifacts_folder_path=self["artifact_dir"],
)
)
self.query = API.Query({}, self.logger, self.artifacts)
self.ttnn_binaries = []
Expand Down Expand Up @@ -1063,7 +1104,11 @@ def __init__(self, args={}, logging=None, artifacts=None):
self.artifacts = (
artifacts
if artifacts != None
else Artifacts(self.logger, self.file_manager)
else Artifacts(
self.logger,
self.file_manager,
artifacts_folder_path=self["artifact_dir"],
)
)
self.query = API.Query({}, self.logger, self.artifacts)
self.ttnn_binaries = []
Expand Down Expand Up @@ -1311,7 +1356,11 @@ def _execute(binaries):
if self[name]:
command_options += f" {api['name']} "
else:
command_options += f" {api['name']} {self[name]} "
command_options += f" {api['name']} "
if isinstance(self[name], str) and not self[name]:
command_options += f'"{self[name]}" '
else:
command_options += f"{self[name]} "

library_link_path = self.globals.get_ld_path(
f"{self.globals.get_ttmetal_home_path()}"
Expand Down Expand Up @@ -1466,7 +1515,11 @@ def __init__(self, args={}, logging=None, artifacts=None):
self.artifacts = (
artifacts
if artifacts != None
else Artifacts(self.logger, self.file_manager)
else Artifacts(
self.logger,
self.file_manager,
artifacts_folder_dir=self["artifact_dir"],
)
)
self.query = API.Query({}, self.logger, self.artifacts)
self.ttnn_binaries = []
Expand Down

0 comments on commit 82eae0c

Please sign in to comment.