Skip to content

Commit

Permalink
#393: Added check API in ttrt to check against flatbuffers and system…
Browse files Browse the repository at this point in the history
… descriptors (#431)
  • Loading branch information
tapspatel authored Aug 18, 2024
1 parent 21afdcc commit 150e466
Show file tree
Hide file tree
Showing 2 changed files with 282 additions and 4 deletions.
19 changes: 17 additions & 2 deletions docs/src/ttrt.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ ttrt --help
ttrt read
ttrt run
ttrt query
ttrt perf (coming soon)
ttrt check (coming soon)
ttrt perf
ttrt check
```

## Command Line
Expand Down Expand Up @@ -77,6 +77,7 @@ Note: It's required to be on a system with silicon and to have a runtime enabled
ttrt query --help
ttrt query --save-artifacts
ttrt query --clean-artifacts
ttrt query /dir/of/flatbuffers
ttrt query --save-artifacts --log-file ttrt.log
```

Expand All @@ -98,6 +99,20 @@ ttrt perf /dir/of/flatbuffers --loops 10
ttrt perf /dir/of/flatbuffers --log-file ttrt.log
```

### check
Check a binary file or a directory of binary files against a system desc (by default, uses the host machine)
Note: It's required to be on a system with silicon and to have a runtime enabled build `-DTTMLIR_ENABLE_RUNTIME=ON`.

```bash
ttrt check --help
ttrt check out.ttnn
ttrt check out.ttnn --system-desc /path/to/system_desc.ttsys
ttrt check out.ttnn --clean-artifacts
ttrt check out.ttnn --save-artifacts
ttrt check out.ttnn --log-file ttrt.log
ttrt check /dir/of/flatbuffers --system-desc /dir/of/system_desc
```

## ttrt as a python package

The other way to use the APIs under ttrt is importing it as a library. This allows the user to use it in custom scripts.
Expand Down
267 changes: 265 additions & 2 deletions runtime/tools/python/ttrt/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,49 @@ def initialize_apis():
help=api["help"],
)

# register all check arguments
API.Check.register_arg(
name="--clean-artifacts",
type=bool,
default=False,
choices=[True, False],
help="clean all artifacts from previous runs",
)
API.Check.register_arg(
name="--save-artifacts",
type=bool,
default=False,
choices=[True, False],
help="save all artifacts during run",
)
API.Check.register_arg(
name="binary",
type=str,
default="",
choices=None,
help="flatbuffer binary file",
)
API.Check.register_arg(
name="--log-file",
type=str,
default="",
choices=None,
help="log file to dump ttrt output to",
)
API.Check.register_arg(
name="--system-desc",
type=str,
default="",
choices=None,
help="system desc to check against",
)

# register apis
API.register_api(API.Query)
API.register_api(API.Read)
API.register_api(API.Run)
API.register_api(API.Perf)
API.register_api(API.Check)

@staticmethod
def register_api(api_class):
Expand Down Expand Up @@ -703,6 +741,9 @@ def check_constraints(self):
self["binary"]
)

self.logging.debug(f"ttnn_binary_paths={ttnn_binary_paths}")
self.logging.debug(f"ttmetal_binary_paths={ttmetal_binary_paths}")

for path in ttnn_binary_paths:
bin = Binary(self.logger, self.file_manager, path)
if not bin.check_version():
Expand All @@ -720,8 +761,6 @@ def check_constraints(self):

self.ttnn_binaries.append(bin)

self.logging.debug(f"finished checking constraints for run API")

for path in ttmetal_binary_paths:
bin = Binary(self.logger, self.file_manager, path)
if not bin.check_version():
Expand All @@ -739,6 +778,8 @@ def check_constraints(self):

self.ttmetal_binaries.append(bin)

self.logging.debug(f"finished checking constraints for run API")

def execute(self):
self.logging.debug(f"executing run API")

Expand Down Expand Up @@ -1047,11 +1088,15 @@ def check_constraints(self):
assert self.file_manager.check_file_exists(
self.tracy_csvexport_tool_path
), f"perf tool={self.tracy_csvexport_tool_path} does not exist - rebuild using perf mode"

ttnn_binary_paths = self.file_manager.find_ttnn_binary_paths(self["binary"])
ttmetal_binary_paths = self.file_manager.find_ttmetal_binary_paths(
self["binary"]
)

self.logging.debug(f"ttnn_binary_paths={ttnn_binary_paths}")
self.logging.debug(f"ttmetal_binary_paths={ttmetal_binary_paths}")

for path in ttnn_binary_paths:
bin = Binary(self.logger, self.file_manager, path)
if not bin.check_version():
Expand Down Expand Up @@ -1344,3 +1389,221 @@ def generate_subparser(subparsers):
)

return perf_parser

class Check:
registered_args = {}
api_only_arg = []

def __init__(self, args={}, logging=None, artifacts=None):
for name, attributes in API.Check.registered_args.items():
name = name if not name.startswith("-") else name.lstrip("-")
name = name.replace("-", "_")

if type(args) == dict:
if name in args.keys():
self[name] = args[name]
else:
self[name] = attributes["default"]
else:
self[name] = getattr(args, name)

self.logger = logging if logging != None else Logger(self["log_file"])
self.logging = self.logger.get_logger()
self.globals = Globals(self.logger)
self.file_manager = FileManager(self.logger)
self.artifacts = (
artifacts
if artifacts != None
else Artifacts(self.logger, self.file_manager)
)
self.query = API.Query({}, self.logger, self.artifacts)
self.ttnn_binaries = []
self.ttmetal_binaries = []
self.system_desc_binaries = []

def preprocess(self):
self.logging.debug(f"preprocessing check API")

if self["clean_artifacts"]:
self.artifacts.clean_artifacts()

if self["save_artifacts"]:
self.artifacts.create_artifacts()

self.logging.debug(f"finished preprocessing check API")

def check_constraints(self):
self.logging.debug(f"checking constraints for check API")

ttsys_binary_paths = self.file_manager.find_ttsys_binary_paths(
self["system_desc"]
)
ttnn_binary_paths = self.file_manager.find_ttnn_binary_paths(self["binary"])
ttmetal_binary_paths = self.file_manager.find_ttmetal_binary_paths(
self["binary"]
)

self.logging.debug(f"ttsys_binary_paths={ttsys_binary_paths}")
self.logging.debug(f"ttnn_binary_paths={ttnn_binary_paths}")
self.logging.debug(f"ttmetal_binary_paths={ttmetal_binary_paths}")

for path in ttsys_binary_paths:
bin = SystemDesc(self.logger, self.file_manager, path)
if bin.check_version():
self.system_desc_binaries.append(bin)

for path in ttnn_binary_paths:
bin = Binary(self.logger, self.file_manager, path)
if not bin.check_version():
continue

self.ttnn_binaries.append(bin)

for path in ttmetal_binary_paths:
bin = Binary(self.logger, self.file_manager, path)
if not bin.check_version():
continue

self.ttmetal_binaries.append(bin)

self.logging.debug(f"finished checking constraints for check API")

def execute(self):
self.logging.debug(f"executing check API")

def _execute(binaries, system_desc_to_check):
if len(binaries) == 0:
self.logging.warning(f"no binaries found to run - returning early")
return

for bin in binaries:
if system_desc_to_check != None:
if (
bin.fbb_dict["system_desc"]
!= system_desc_to_check["system_desc"]
):
self.logging.info(
f"system desc for device did not match flatbuffer: {bin.file_path}"
)
else:
self.logging.info(
f"system desc for device matched flatbuffer: {bin.file_path}"
)
else:
for desc in self.system_desc_binaries:
if (
bin.fbb_dict["system_desc"]
!= desc.fbb_dict["system_desc"]
):
self.logging.info(
f"system desc for: {desc.file_path} did not match flatbuffer: {bin.file_path}"
)
else:
self.logging.info(
f"system desc for: {desc.file_path} matched flatbuffer: {bin.file_path}"
)

system_desc_to_check = None
if self["system_desc"] == "" or len(self.system_desc_binaries) == 0:
self.logging.warning(
"no system descriptor file provided - querying from host machine"
)
self.query()
system_desc_to_check = self.query.get_system_desc_as_dict()

self.logging.debug(f"executing ttnn binaries")
_execute(self.ttnn_binaries, system_desc_to_check)
self.logging.debug(f"finished executing ttnn binaries")

self.logging.debug(f"executing ttmetal binaries")
_execute(self.ttmetal_binaries, system_desc_to_check)
self.logging.debug(f"finished executing ttmetal binaries")

self.logging.debug(f"finished executing check API")

def postprocess(self):
self.logging.debug(f"postprocessing check API")

if self["save_artifacts"]:
for bin in self.ttnn_binaries:
self.artifacts.save_binary(bin)

for bin in self.ttmetal_binaries:
self.artifacts.save_binary(bin)

self.logging.debug(f"finished postprocessing check API")

def __str__(self):
pass

def __getitem__(self, key):
return getattr(self, key)

def __setitem__(self, key, value):
setattr(self, key, value)

def __call__(self):
self.logging.debug(f"starting check API")

self.preprocess()
self.check_constraints()
self.execute()
self.postprocess()

self.logging.debug(f"finished check API")

@staticmethod
def register_arg(name, type, default, choices, help, api_only=True):
API.Check.registered_args[name] = {
"type": type,
"default": default,
"choices": choices,
"help": help,
}

if api_only:
API.Check.api_only_arg.append(name)

@staticmethod
def get_upstream_apis():
upstream_apis = []
for arg_name, arg_value in API.Check.registered_args.items():
if arg_name not in API.Check.api_only_arg:
upstream_apis.append(
{
"name": arg_name,
"type": arg_value["type"],
"default": arg_value["default"],
"choices": arg_value["choices"],
"help": arg_value["help"],
}
)

return upstream_apis

@staticmethod
def generate_subparser(subparsers):
check_parser = subparsers.add_parser(
"check", help="check a flatbuffer binary against a system desc file"
)
check_parser.set_defaults(api=API.Check)

for name, attributes in API.Check.registered_args.items():
if name == "binary":
check_parser.add_argument(f"{name}", help=attributes["help"])
elif attributes["type"] == bool:
check_parser.add_argument(
f"{name}",
action="store_true",
help=attributes["help"],
)
else:
check_parser.add_argument(
f"{name}",
type=attributes["type"],
default=attributes["default"],
choices=attributes["choices"],
help=attributes["help"],
)

return check_parser

0 comments on commit 150e466

Please sign in to comment.