#563: Enable mixed ttnn and ttm runtime in ttrt (#565)
tapspatel authored Aug 30, 2024
1 parent 8cc0f05 commit 19c5407
Showing 2 changed files with 89 additions and 86 deletions.
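In rough terms, the api.py change below binds the runtime that matches each group of flatbuffers before running it and releases the device when the group finishes, which is what lets a single ttrt invocation mix TTNN and TTMetal binaries per the commit title. A minimal sketch of that dispatch pattern, assuming only the ttrt.runtime calls that appear in the diff (set_compatible_runtime, open_device, close_device); run_one_binary is a hypothetical stand-in for the per-program loop, and the ttmetal_binaries group is assumed rather than shown in this excerpt:

import ttrt.runtime

# Sketch only: the dispatch pattern from api.py below, reduced to its shape.
# run_one_binary is a hypothetical stand-in for the per-program loop in the diff;
# the ttmetal_binaries group is assumed and not shown in this excerpt.
def execute_group(binaries, device_id, run_one_binary):
    if not binaries:
        return
    # Bind the runtime (TTNN or TTMetal) that matches this group of flatbuffers.
    ttrt.runtime.set_compatible_runtime(binaries[0].fbb)
    device = ttrt.runtime.open_device([device_id])
    try:
        for bin in binaries:
            run_one_binary(device, bin)
    finally:
        # Release the device even on failure so the next group can re-open it
        # under its own runtime.
        ttrt.runtime.close_device(device)

# execute_group(ttnn_binaries, device_id, run_one_binary)
# execute_group(ttmetal_binaries, device_id, run_one_binary)  # assumed second group

Closing the device in a finally block, rather than via the atexit hook the diff removes, releases it as soon as one group finishes or fails instead of at interpreter exit.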
2 changes: 1 addition & 1 deletion runtime/tools/python/CMakeLists.txt
@@ -4,7 +4,7 @@ add_custom_target(ttrt-copy-files
 )
 
 add_custom_target(ttrt
-  COMMAND rm -f *.whl
+  COMMAND rm -f build/*.whl
   COMMAND TTMLIR_ENABLE_RUNTIME=${TTMLIR_ENABLE_RUNTIME}
   TT_RUNTIME_ENABLE_TTNN=${TT_RUNTIME_ENABLE_TTNN}
   TT_RUNTIME_ENABLE_TTMETAL=${TT_RUNTIME_ENABLE_TTMETAL}
173 changes: 88 additions & 85 deletions runtime/tools/python/ttrt/common/api.py
@@ -793,112 +793,115 @@ def _execute(binaries):
                 self.logging.debug(f"setting torch manual seed={self['seed']}")
                 torch.manual_seed(self["seed"])
+                ttrt.runtime.set_compatible_runtime(binaries[0].fbb)
 
                 self.logging.debug(f"opening device id={self.query.device_ids[0]}")
                 device = ttrt.runtime.open_device([self.query.device_ids[0]])
-                atexit.register(lambda: ttrt.runtime.close_device(device))
 
-                for bin in binaries:
-                    self.logging.info(f"evaluating binary={bin.file_path}")
+                try:
+                    for bin in binaries:
+                        self.logging.info(f"evaluating binary={bin.file_path}")
 
-                    program_indices = []
-                    if self["program_index"] == "all":
-                        program_indices.extend(range(bin.get_num_programs()))
-                    else:
-                        program_indices.append(int(self["program_index"]))
+                        program_indices = []
+                        if self["program_index"] == "all":
+                            program_indices.extend(range(bin.get_num_programs()))
+                        else:
+                            program_indices.append(int(self["program_index"]))
 
-                    for program_index in program_indices:
-                        self.logging.debug(
-                            f"evaluating program={program_index} for binary={bin.file_path}"
-                        )
+                        for program_index in program_indices:
+                            self.logging.debug(
+                                f"evaluating program={program_index} for binary={bin.file_path}"
+                            )
 
-                        program = bin.get_program(program_index)
-                        program.populate_inputs(
-                            API.Run.TorchInitilizer.get_initilizer(self["init"])
-                        )
-                        program.populate_outputs(
-                            API.Run.TorchInitilizer.get_initilizer("zeros")
-                        )
+                            program = bin.get_program(program_index)
+                            program.populate_inputs(
+                                API.Run.TorchInitilizer.get_initilizer(self["init"])
+                            )
+                            program.populate_outputs(
+                                API.Run.TorchInitilizer.get_initilizer("zeros")
+                            )
 
-                        total_inputs = []
-                        total_outputs = []
-                        for loop in range(self["loops"]):
-                            self.logging.debug(
-                                f"generating inputs/outputs for loop={loop+1}/{self['loops']} for binary={bin.file_path}"
-                            )
+                            total_inputs = []
+                            total_outputs = []
+                            for loop in range(self["loops"]):
+                                self.logging.debug(
+                                    f"generating inputs/outputs for loop={loop+1}/{self['loops']} for binary={bin.file_path}"
+                                )
 
-                            inputs = []
-                            outputs = []
-                            for i in program.input_tensors:
-                                inputs.append(
-                                    ttrt.runtime.create_tensor(
-                                        i.data_ptr(),
-                                        list(i.shape),
-                                        list(i.stride()),
-                                        i.element_size(),
-                                        Binary.Program.to_data_type(i.dtype),
-                                    )
-                                )
+                                inputs = []
+                                outputs = []
+                                for i in program.input_tensors:
+                                    inputs.append(
+                                        ttrt.runtime.create_tensor(
+                                            i.data_ptr(),
+                                            list(i.shape),
+                                            list(i.stride()),
+                                            i.element_size(),
+                                            Binary.Program.to_data_type(i.dtype),
+                                        )
+                                    )
 
-                            for i in program.output_tensors:
-                                outputs.append(
-                                    ttrt.runtime.create_tensor(
-                                        i.data_ptr(),
-                                        list(i.shape),
-                                        list(i.stride()),
-                                        i.element_size(),
-                                        Binary.Program.to_data_type(i.dtype),
-                                    )
-                                )
+                                for i in program.output_tensors:
+                                    outputs.append(
+                                        ttrt.runtime.create_tensor(
+                                            i.data_ptr(),
+                                            list(i.shape),
+                                            list(i.stride()),
+                                            i.element_size(),
+                                            Binary.Program.to_data_type(i.dtype),
+                                        )
+                                    )
 
-                            total_inputs.append(inputs)
-                            total_outputs.append(outputs)
+                                total_inputs.append(inputs)
+                                total_outputs.append(outputs)
 
-                        event = None
-                        for loop in range(self["loops"]):
-                            self.logging.debug(
-                                f"starting loop={loop+1}/{self['loops']} for binary={bin.file_path}"
-                            )
+                            event = None
+                            for loop in range(self["loops"]):
+                                self.logging.debug(
+                                    f"starting loop={loop+1}/{self['loops']} for binary={bin.file_path}"
+                                )
 
-                            event = ttrt.runtime.submit(
-                                device,
-                                bin.fbb,
-                                program_index,
-                                total_inputs[loop],
-                                total_outputs[loop],
-                            )
+                                event = ttrt.runtime.submit(
+                                    device,
+                                    bin.fbb,
+                                    program_index,
+                                    total_inputs[loop],
+                                    total_outputs[loop],
+                                )
 
-                            self.logging.debug(
-                                f"finished loop={loop+1}/{self['loops']} for binary={bin.file_path}"
-                            )
+                                self.logging.debug(
+                                    f"finished loop={loop+1}/{self['loops']} for binary={bin.file_path}"
+                                )
 
-                        ttrt.runtime.wait(event)
+                            ttrt.runtime.wait(event)
 
-                        if self["identity"]:
-                            self.logging.debug(
-                                f"checking identity with rtol={self['rtol']} and atol={self['atol']}"
-                            )
+                            if self["identity"]:
+                                self.logging.debug(
+                                    f"checking identity with rtol={self['rtol']} and atol={self['atol']}"
+                                )
 
-                            for i, o in zip(
-                                program.input_tensors, program.output_tensors
-                            ):
-                                if not torch.allclose(
-                                    i, o, rtol=self["rtol"], atol=self["atol"]
-                                ):
-                                    self.logging.error(
-                                        f"Failed: inputs and outputs do not match in binary"
-                                    )
-                                    self.logging.error(i - o)
+                                for i, o in zip(
+                                    program.input_tensors, program.output_tensors
+                                ):
+                                    if not torch.allclose(
+                                        i, o, rtol=self["rtol"], atol=self["atol"]
+                                    ):
+                                        self.logging.error(
+                                            f"Failed: inputs and outputs do not match in binary"
+                                        )
+                                        self.logging.error(i - o)
 
-                        self.logging.debug(f"input tensors for program={program_index}")
-                        for tensor in program.input_tensors:
-                            self.logging.debug(f"{tensor}\n")
+                            self.logging.debug(
+                                f"input tensors for program={program_index}"
+                            )
+                            for tensor in program.input_tensors:
+                                self.logging.debug(f"{tensor}\n")
 
-                        self.logging.debug(
-                            f"output tensors for program={program_index}"
-                        )
-                        for tensor in program.output_tensors:
-                            self.logging.debug(f"{tensor}\n")
+                            self.logging.debug(
+                                f"output tensors for program={program_index}"
+                            )
+                            for tensor in program.output_tensors:
+                                self.logging.debug(f"{tensor}\n")
+                finally:
+                    ttrt.runtime.close_device(device)
 
             self.logging.debug(f"executing ttnn binaries")
             _execute(self.ttnn_binaries)
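For reference, the tensor marshalling inside the loops above hands the runtime a view of each torch tensor: its data pointer, shape, strides, element size, and mapped dtype. A standalone sketch of that conversion, where to_data_type stands in for the Binary.Program.to_data_type mapping used in api.py (its import path is not part of this diff):

import torch
import ttrt.runtime

# Sketch only: wraps a torch tensor for the runtime the way the loops above do.
# to_data_type is assumed to map a torch dtype to the runtime's DataType, as
# Binary.Program.to_data_type does in api.py.
def to_runtime_tensor(t: torch.Tensor, to_data_type):
    return ttrt.runtime.create_tensor(
        t.data_ptr(),        # raw pointer into the torch tensor's storage
        list(t.shape),       # shape as a plain Python list
        list(t.stride()),    # strides, in elements
        t.element_size(),    # bytes per element
        to_data_type(t.dtype),
    )

Since only a raw data pointer is passed, the torch tensors presumably have to outlive the submitted work, which the code above arranges by keeping total_inputs and total_outputs around until after ttrt.runtime.wait(event).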
