From 19c54072584ae3ba0c5be20c536b3a82c1e42cab Mon Sep 17 00:00:00 2001 From: Tapasvi Patel <133996364+tapspatel@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:13:46 -0500 Subject: [PATCH] #563: Enable mixed ttnn and ttm runtime in ttrt (#565) --- runtime/tools/python/CMakeLists.txt | 2 +- runtime/tools/python/ttrt/common/api.py | 173 ++++++++++++------------ 2 files changed, 89 insertions(+), 86 deletions(-) diff --git a/runtime/tools/python/CMakeLists.txt b/runtime/tools/python/CMakeLists.txt index 791541cbb..84810a4cf 100644 --- a/runtime/tools/python/CMakeLists.txt +++ b/runtime/tools/python/CMakeLists.txt @@ -4,7 +4,7 @@ add_custom_target(ttrt-copy-files ) add_custom_target(ttrt - COMMAND rm -f *.whl + COMMAND rm -f build/*.whl COMMAND TTMLIR_ENABLE_RUNTIME=${TTMLIR_ENABLE_RUNTIME} TT_RUNTIME_ENABLE_TTNN=${TT_RUNTIME_ENABLE_TTNN} TT_RUNTIME_ENABLE_TTMETAL=${TT_RUNTIME_ENABLE_TTMETAL} diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index f2c63a80a..4b2d3e418 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -793,112 +793,115 @@ def _execute(binaries): self.logging.debug(f"setting torch manual seed={self['seed']}") torch.manual_seed(self["seed"]) ttrt.runtime.set_compatible_runtime(binaries[0].fbb) - self.logging.debug(f"opening device id={self.query.device_ids[0]}") device = ttrt.runtime.open_device([self.query.device_ids[0]]) - atexit.register(lambda: ttrt.runtime.close_device(device)) - - for bin in binaries: - self.logging.info(f"evaluating binary={bin.file_path}") - - program_indices = [] - if self["program_index"] == "all": - program_indices.extend(range(bin.get_num_programs())) - else: - program_indices.append(int(self["program_index"])) - for program_index in program_indices: - self.logging.debug( - f"evaluating program={program_index} for binary={bin.file_path}" - ) + try: + for bin in binaries: + self.logging.info(f"evaluating binary={bin.file_path}") - program = bin.get_program(program_index) - program.populate_inputs( - API.Run.TorchInitilizer.get_initilizer(self["init"]) - ) - program.populate_outputs( - API.Run.TorchInitilizer.get_initilizer("zeros") - ) + program_indices = [] + if self["program_index"] == "all": + program_indices.extend(range(bin.get_num_programs())) + else: + program_indices.append(int(self["program_index"])) - total_inputs = [] - total_outputs = [] - for loop in range(self["loops"]): + for program_index in program_indices: self.logging.debug( - f"generating inputs/outputs for loop={loop+1}/{self['loops']} for binary={bin.file_path}" + f"evaluating program={program_index} for binary={bin.file_path}" ) - inputs = [] - outputs = [] - for i in program.input_tensors: - inputs.append( - ttrt.runtime.create_tensor( - i.data_ptr(), - list(i.shape), - list(i.stride()), - i.element_size(), - Binary.Program.to_data_type(i.dtype), - ) + program = bin.get_program(program_index) + program.populate_inputs( + API.Run.TorchInitilizer.get_initilizer(self["init"]) + ) + program.populate_outputs( + API.Run.TorchInitilizer.get_initilizer("zeros") + ) + + total_inputs = [] + total_outputs = [] + for loop in range(self["loops"]): + self.logging.debug( + f"generating inputs/outputs for loop={loop+1}/{self['loops']} for binary={bin.file_path}" ) - for i in program.output_tensors: - outputs.append( - ttrt.runtime.create_tensor( - i.data_ptr(), - list(i.shape), - list(i.stride()), - i.element_size(), - Binary.Program.to_data_type(i.dtype), + inputs = [] + outputs = [] + for i in 
program.input_tensors: + inputs.append( + ttrt.runtime.create_tensor( + i.data_ptr(), + list(i.shape), + list(i.stride()), + i.element_size(), + Binary.Program.to_data_type(i.dtype), + ) ) - ) - total_inputs.append(inputs) - total_outputs.append(outputs) + for i in program.output_tensors: + outputs.append( + ttrt.runtime.create_tensor( + i.data_ptr(), + list(i.shape), + list(i.stride()), + i.element_size(), + Binary.Program.to_data_type(i.dtype), + ) + ) - event = None - for loop in range(self["loops"]): - self.logging.debug( - f"starting loop={loop+1}/{self['loops']} for binary={bin.file_path}" - ) + total_inputs.append(inputs) + total_outputs.append(outputs) - event = ttrt.runtime.submit( - device, - bin.fbb, - program_index, - total_inputs[loop], - total_outputs[loop], - ) + event = None + for loop in range(self["loops"]): + self.logging.debug( + f"starting loop={loop+1}/{self['loops']} for binary={bin.file_path}" + ) - self.logging.debug( - f"finished loop={loop+1}/{self['loops']} for binary={bin.file_path}" - ) + event = ttrt.runtime.submit( + device, + bin.fbb, + program_index, + total_inputs[loop], + total_outputs[loop], + ) - ttrt.runtime.wait(event) + self.logging.debug( + f"finished loop={loop+1}/{self['loops']} for binary={bin.file_path}" + ) - if self["identity"]: - self.logging.debug( - f"checking identity with rtol={self['rtol']} and atol={self['atol']}" - ) + ttrt.runtime.wait(event) - for i, o in zip( - program.input_tensors, program.output_tensors - ): - if not torch.allclose( - i, o, rtol=self["rtol"], atol=self["atol"] + if self["identity"]: + self.logging.debug( + f"checking identity with rtol={self['rtol']} and atol={self['atol']}" + ) + + for i, o in zip( + program.input_tensors, program.output_tensors ): - self.logging.error( - f"Failed: inputs and outputs do not match in binary" - ) - self.logging.error(i - o) + if not torch.allclose( + i, o, rtol=self["rtol"], atol=self["atol"] + ): + self.logging.error( + f"Failed: inputs and outputs do not match in binary" + ) + self.logging.error(i - o) - self.logging.debug(f"input tensors for program={program_index}") - for tensor in program.input_tensors: - self.logging.debug(f"{tensor}\n") + self.logging.debug( + f"input tensors for program={program_index}" + ) + for tensor in program.input_tensors: + self.logging.debug(f"{tensor}\n") - self.logging.debug( - f"output tensors for program={program_index}" - ) - for tensor in program.output_tensors: - self.logging.debug(f"{tensor}\n") + self.logging.debug( + f"output tensors for program={program_index}" + ) + for tensor in program.output_tensors: + self.logging.debug(f"{tensor}\n") + finally: + ttrt.runtime.close_device(device) self.logging.debug(f"executing ttnn binaries") _execute(self.ttnn_binaries)
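
Note (illustrative, not part of the patch): the api.py hunk replaces the atexit-registered device teardown with a try/finally wrapped around the per-binary execution loop, so the device opened via ttrt.runtime.open_device(...) is released even when a program fails mid-run, and ttrt.runtime.set_compatible_runtime(binaries[0].fbb) picks the matching runtime (ttnn or ttmetal) once before any binary executes. A minimal Python sketch of that lifecycle pattern follows; execute_binaries and run_programs are hypothetical stand-ins for the real _execute body, not names from the repository.

    # Sketch of the device-lifecycle pattern adopted by this patch (assumed simplification).
    import ttrt.runtime

    def execute_binaries(binaries, device_id):
        # Select the runtime (ttnn or ttmetal) compatible with the first flatbuffer,
        # mirroring ttrt.runtime.set_compatible_runtime(binaries[0].fbb) in the diff.
        ttrt.runtime.set_compatible_runtime(binaries[0].fbb)

        device = ttrt.runtime.open_device([device_id])
        try:
            for bin in binaries:
                run_programs(device, bin)  # hypothetical helper for the per-program loop
        finally:
            # Always close the device, even if a binary raises;
            # this replaces the earlier atexit.register(...) cleanup.
            ttrt.runtime.close_device(device)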