From 7e5b01a7599d3ba81b1a666b62eaf0996cfcec64 Mon Sep 17 00:00:00 2001
From: Tapasvi Patel <133996364+tapspatel@users.noreply.github.com>
Date: Thu, 19 Dec 2024 17:13:10 -0600
Subject: [PATCH] #341: Added support for step debugger and memory dumps after
 each op invocation (#1616)

Adds two new modes to ttrt (debugger + memory).

```
ttrt run --debugger
```

This starts a pdb debugger after every op invocation in the runtime. Further
support will be added.

```
ttrt run --memory
ttrt perf --memory
```

This dumps memory reports after each op invocation in the runtime. A
memory_report.json file is written, containing one entry per op with its DRAM
and L1 memory usage. This is a global view of the board. (A sketch showing how
to read this report programmatically is appended after the patch.)

```
{
    "loc": "loc(\"/code/tt-mlir/test/python/golden/test_ttir_ops.py:65:id(0)\")",
    "debug_str": "%6 = \"ttnn.add\"(%2, %4, %5) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, f32>, #ttnn.buffer_type>, >>, tensor<64x128xf32, #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, f32>, #ttnn.buffer_type>, >>, tensor<64x128xf32, #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, f32>, #ttnn.buffer_type>, >>) -> tensor<64x128xf32, #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, f32>, #ttnn.buffer_type>, >> loc(\"/code/tt-mlir/test/python/golden/test_ttir_ops.py:65:id(0)\")",
    "dram": {
        "total_allocatable (bytes) : total_allocatable/bank * num_banks": "12884901504",
        "total_allocated (bytes) : total_allocated/bank * num_banks": "3268608",
        "total_free (bytes) : total_allocatable - total_allocated": "12881632896",
        "blocks": [
            {
                "address (bytes)": "32",
                "size (bytes)": "90112",
                "allocated (y/n)": "Y"
            },
            {
                "address (bytes)": "90144",
                "size (bytes)": "90112",
                "allocated (y/n)": "Y"
            },
            {
                "address (bytes)": "180256",
                "size (bytes)": "4096",
                "allocated (y/n)": "N"
            },
            {
                "address (bytes)": "184352",
                "size (bytes)": "90112",
                "allocated (y/n)": "Y"
            },
            {
                "address (bytes)": "274464",
                "size (bytes)": "1073465312",
                "allocated (y/n)": "N"
            },
            {
                "address (bytes)": "1073739776",
                "size (bytes)": "2048",
                "allocated (y/n)": "Y"
            }
        ],
        "total_allocatable (bytes) : per bank": "1073741792",
        "total_allocated (bytes): per bank": "272384",
        "total_free (bytes) : per bank": "1073469408",
        "largest_free_block (bytes) : per bank": "1073465312"
    },
    "l1": {
        "total_allocatable (bytes) : total_allocatable/bank * num_banks": "87504896",
        "total_allocated (bytes) : total_allocated/bank * num_banks": "0",
        "total_free (bytes) : total_allocatable - total_allocated": "87504896",
        "blocks": [
            {
                "address (bytes)": "99104",
                "size (bytes)": "1367264",
                "allocated (y/n)": "N"
            }
        ],
        "total_allocatable (bytes) : per bank": "1367264",
        "total_allocated (bytes): per bank": "0",
        "total_free (bytes) : per bank": "1367264",
        "largest_free_block (bytes) : per bank": "1367264",
        "largest_contiguous_free_block (bytes) : per bank": "1367264"
    }
},
```
---
 runtime/include/tt/runtime/detail/ttmetal.h  |   3 +
 runtime/include/tt/runtime/detail/ttnn.h     |   3 +
 runtime/include/tt/runtime/runtime.h         |   1 +
 runtime/lib/runtime.cpp                      |  16 +
 runtime/lib/ttmetal/runtime.cpp              |  10 +
 runtime/lib/ttnn/runtime.cpp                 |   8 +
 runtime/tools/python/ttrt/common/callback.py | 440 +++++++++++++++++++
 runtime/tools/python/ttrt/common/golden.py   | 196 ---------
 runtime/tools/python/ttrt/common/perf.py     |  20 +
 runtime/tools/python/ttrt/common/run.py      | 151 +++++--
 runtime/tools/python/ttrt/runtime/module.cpp |   3 +-
 11 files changed, 623 insertions(+), 228
deletions(-) create mode 100644 runtime/tools/python/ttrt/common/callback.py delete mode 100644 runtime/tools/python/ttrt/common/golden.py diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index e532ec05f..deff10755 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -9,6 +9,7 @@ #include "distributed/mesh_device.hpp" #include "impl/buffers/circular_buffer.hpp" #include "impl/event/event.hpp" +#include "tt_metal/detail/reports/memory_reporter.hpp" #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/host_api.hpp" @@ -40,6 +41,8 @@ void closeDevice(Device device); void deallocateBuffers(Device device); +void dumpMemoryReport(Device device); + void wait(Event event); void wait(Tensor tensor); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index eac3b0ebb..2310789b6 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -9,6 +9,7 @@ #include "distributed/mesh_device.hpp" #include "host_api.hpp" #include "hostdevcommon/common_values.hpp" +#include "tt_metal/detail/reports/memory_reporter.hpp" #include "ttnn/device.hpp" #include "ttnn/operations/ccl/all_gather/all_gather.hpp" #include "ttnn/operations/conv/conv2d/conv2d.hpp" @@ -90,6 +91,8 @@ void closeDevice(Device device); void deallocateBuffers(Device device); +void dumpMemoryReport(Device device); + void wait(Event event); void wait(Tensor tensor); diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index c3b725e0f..2f278ffc1 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -19,6 +19,7 @@ std::pair getCurrentSystemDesc(); namespace detail { void deallocateBuffers(Device device); +void dumpMemoryReport(Device device); } // namespace detail DeviceRuntime getCurrentRuntime(); diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 1b5b775b0..c25cfed51 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -44,6 +44,22 @@ void deallocateBuffers(Device device) { #endif LOG_FATAL("runtime is not enabled"); } + +void dumpMemoryReport(Device device) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::dumpMemoryReport(device); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::dumpMemoryReport(device); + } +#endif + + LOG_FATAL("runtime is not enabled"); +} } // namespace detail DeviceRuntime getCurrentRuntime() { diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 202965087..68322154d 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -103,6 +103,16 @@ void deallocateBuffers(Device deviceHandle) { } } +void dumpMemoryReport(Device deviceHandle) { + ::tt::tt_metal::distributed::MeshDevice &meshDevice = + deviceHandle.as<::tt::tt_metal::distributed::MeshDevice>( + DeviceRuntime::TTMetal); + + for (::tt::tt_metal::Device *device : meshDevice.get_devices()) { + ::tt::tt_metal::detail::DumpDeviceMemoryState(device); + } +} + void wait(Event event) { Events events = event.as(DeviceRuntime::TTMetal); for (auto e : events) { diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 56a205546..3fd7ba1b9 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -214,6 
+214,14 @@ void deallocateBuffers(Device deviceHandle) { } } +void dumpMemoryReport(Device deviceHandle) { + ::ttnn::MeshDevice &meshDevice = + deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); + for (::ttnn::Device *device : meshDevice.get_devices()) { + ::tt::tt_metal::detail::DumpDeviceMemoryState(device); + } +} + void wait(Event event) { // Nothing to do for ttnn runtime LOG_ASSERT(event.matchesRuntime(DeviceRuntime::TTNN)); diff --git a/runtime/tools/python/ttrt/common/callback.py b/runtime/tools/python/ttrt/common/callback.py new file mode 100644 index 000000000..3864397b0 --- /dev/null +++ b/runtime/tools/python/ttrt/common/callback.py @@ -0,0 +1,440 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import re +from functools import partial +import csv +import json + +from ttrt.common.util import * + + +class CallbackRuntimeConfig: + def __init__( + self, + device=None, + artifact_dir="", + pcc=0.99, + atol=1e-08, + rtol=1e-05, + save_golden_tensors=False, + logging=None, + enable_golden=False, + enable_memory=False, + enable_debugger=False, + golden_report={}, + memory_report={}, + ): + self.device = device + self.artifact_dir = artifact_dir + self.pcc = pcc + self.atol = atol + self.rtol = rtol + self.save_golden_tensors = save_golden_tensors + self.logging = logging + self.enable_golden = enable_golden + self.enable_memory = enable_memory + self.enable_debugger = enable_debugger + self.golden_report = golden_report + self.memory_report = memory_report + self.counter = -1 + + def start_new_callback(self, artifact_dir): + self.artifact_dir = artifact_dir + self.counter = -1 + self.golden_report = {} + self.memory_report = {} + + def callback_counter(self): + self.counter = self.counter + 1 + return self.counter + + def save_golden_report(self, golden_report_path): + with open(golden_report_path, "w") as json_file: + json.dump(self.golden_report, json_file, indent=4) + + self.logging.debug(f"Saved golden report to={golden_report_path}") + + def save_memory_report(self, memory_report_path): + with open(memory_report_path, "w") as json_file: + json.dump(self.memory_report, json_file, indent=4) + + self.logging.debug(f"Saved memory report to={memory_report_path}") + + +""" +-----------------------GOLDEN CALLBACK----------------------- +""" + + +def get_atol_rtol_pcc(golden, calculated): + import numpy as np + import torch + + # Calculate atol and rtol + cal_atol = torch.max(torch.abs(golden - calculated)).item() + cal_rtol = torch.max(torch.abs(golden - calculated) / torch.abs(calculated)).item() + + # Calculate PCC + def get_pcc(golden, calculated): + # Both tensors are nan + if torch.all(torch.isnan(golden)) and torch.all(torch.isnan(calculated)): + logging.debug("Both tensors are 'nan'") + return 1.0 + # Test if either is completely zero + elif torch.any(golden.bool()) != torch.any(calculated.bool()): + return 0.0 + # One tensor is all nan, the other is not + elif torch.all(torch.isnan(golden)) or torch.all(torch.isnan(calculated)): + logging.debug("One tensor is all nan, the other is not.") + return 0.0 + else: + # For now, mask all infs and nans so that we check the rest... 
TODO + golden = golden.clone() + golden[ + torch.logical_or( + torch.isnan(golden), + torch.logical_or(torch.isinf(golden), torch.isneginf(golden)), + ) + ] = 0 + calculated = calculated.clone() + calculated[ + torch.logical_or( + torch.isnan(calculated), + torch.logical_or( + torch.isinf(calculated), torch.isneginf(calculated) + ), + ) + ] = 0 + + if torch.equal(golden, calculated): + return 1.0 + + if golden.dtype == torch.bfloat16: + golden = golden.type(torch.float32) + calculated = calculated.type(torch.float32) + + # Single element case + if golden.numel() == 1: + return float(torch.equal(golden, calculated)) + + # If both tensors are constant + if torch.max(golden) == torch.min(golden) and torch.max( + calculated + ) == torch.min(calculated): + return torch.isclose(torch.max(golden), torch.max(calculated)).item() + + cal_pcc = np.ma.corrcoef( + np.ma.masked_invalid(torch.squeeze(golden).detach().numpy()).flatten(), + np.ma.masked_invalid( + torch.squeeze(calculated).detach().numpy() + ).flatten(), + ) + # Remove correlation coefficient with self (typically always 1.0) + mask = np.ones(cal_pcc.shape, dtype=bool) + np.fill_diagonal(mask, 0) + cal_pcc = np.min(cal_pcc[mask]) + + if isinstance(cal_pcc, np.ma.core.MaskedConstant): + return 1.0 + + return cal_pcc + + cal_pcc = get_pcc(golden, calculated) + + return ( + cal_atol, + cal_rtol, + cal_pcc, + f"Max ATOL Delta: {cal_atol}, Max RTOL Delta: {cal_rtol}, PCC: {cal_pcc}", + ) + + +def golden(callback_runtime_config, binary, program_context, op_context): + import torch + import ttrt.runtime + import ttrt.binary + + logging = callback_runtime_config.logging + logging.debug("executing golden comparison") + + loc = ttrt.runtime.get_op_loc_info(op_context) + + op_golden_tensor = binary.get_debug_info_golden(loc) + op_output_tensor = ttrt.runtime.get_op_output_tensor(op_context, program_context) + + if op_golden_tensor is None: + logging.debug("Golden tensor is None - skipping golden comparison") + return + + if len(op_output_tensor) == 0: + logging.debug("Output tensor is empty - skipping golden comparison") + return + + dtype = ttrt_datatype_to_torch_dtype(op_golden_tensor.dtype) + + golden_tensor_torch = torch.frombuffer(op_golden_tensor, dtype=dtype).flatten() + + output_tensor_torch = torch.tensor(op_output_tensor, dtype=dtype).flatten() + + if callback_runtime_config.save_golden_tensors: + torch.save( + golden_tensor_torch, + f"{callback_runtime_config.artifact_dir}/{loc}_golden.pt", + ) + torch.save( + output_tensor_torch, + f"{callback_runtime_config.artifact_dir}/{loc}_device.pt", + ) + + if golden_tensor_torch.shape != output_tensor_torch.shape: + logging.debug( + "Golden and output tensor shapes do not match - skipping golden comparison" + ) + return + + _, _, cal_pcc, output_str = get_atol_rtol_pcc( + golden_tensor_torch, output_tensor_torch + ) + + logging.debug(f"PCC={cal_pcc}") + logging.debug(output_str) + + results = {} + results["expected_pcc"] = callback_runtime_config.pcc + results["actual_pcc"] = cal_pcc + results["atol"] = callback_runtime_config.atol + results["rtol"] = callback_runtime_config.rtol + results["allclose"] = torch.allclose( + golden_tensor_torch, + output_tensor_torch, + atol=callback_runtime_config.atol, + rtol=callback_runtime_config.rtol, + ) + results["max"] = torch.max( + torch.abs(golden_tensor_torch - output_tensor_torch) + ).item() + results["mean_absolute_error"] = torch.mean( + torch.abs(golden_tensor_torch - output_tensor_torch) + ).item() + results["root_mean_square_error"] = torch.sqrt(
torch.mean((golden_tensor_torch - output_tensor_torch) ** 2) + ).item() + results["cosine_similarity"] = torch.nn.functional.cosine_similarity( + golden_tensor_torch.unsqueeze(0), output_tensor_torch.unsqueeze(0) + ).item() + + callback_runtime_config.golden_report[loc] = results + + +""" +-----------------------MEMORY CALLBACK----------------------- +""" + + +def add_key(dram_memory_usage, l1_memory_usage, current_section, key, value): + if current_section == "DRAM": + dram_memory_usage[key] = value + elif current_section == "L1": + l1_memory_usage[key] = value + + +def parse_detailed_memory_usage_file(dram_memory_usage, l1_memory_usage, file_path): + current_section = None + + with open(file_path, "r") as file: + reader = csv.reader(file) + blocks = [] + + for row in reader: + if not any(row): + continue + + if row[1].strip() == "DRAM": + current_section = "DRAM" + elif row[1].strip() == "L1": + current_section = "L1" + elif "Total" in row[1]: + if row[1].strip() == "Total allocatable (B):": + add_key( + dram_memory_usage, + l1_memory_usage, + current_section, + "total_allocatable (bytes) : total_allocatable/bank * num_banks", + row[2].strip(), + ) + elif row[1].strip() == "Total allocated (B):": + add_key( + dram_memory_usage, + l1_memory_usage, + current_section, + "total_allocated (bytes) : total_allocated/bank * num_banks", + row[2].strip(), + ) + elif row[1].strip() == "Total free (B):": + add_key( + dram_memory_usage, + l1_memory_usage, + current_section, + "total_free (bytes) : total_allocatable - total_allocated", + row[2].strip(), + ) + elif "Blocks" in row[2]: + blocks = [] + else: + block = {} + block["address (bytes)"] = row[3].strip() + block["size (bytes)"] = row[4].strip() + block["allocated (y/n)"] = row[5].strip() + + blocks.append(block) + add_key( + dram_memory_usage, + l1_memory_usage, + current_section, + "blocks", + blocks, + ) + + +def parse_memory_usage_summary_file(dram_memory_usage, l1_memory_usage, file_path): + with open(file_path, "r") as file: + reader = csv.reader(file) + current_section = "DRAM" + + for row in reader: + if not any(row): + continue + + if "Total Allocatable Size" in row[1]: + continue + + add_key( + dram_memory_usage, + l1_memory_usage, + current_section, + "total_allocatable (bytes) : per bank", + row[1].strip(), + ) + add_key( + dram_memory_usage, + l1_memory_usage, + current_section, + "total_allocated (bytes): per bank", + row[2].strip(), + ) + add_key( + dram_memory_usage, + l1_memory_usage, + current_section, + "total_free (bytes) : per bank", + row[3].strip(), + ) + add_key( + dram_memory_usage, + l1_memory_usage, + current_section, + "largest_free_block (bytes) : per bank", + row[4].strip(), + ) + + if current_section == "DRAM": + current_section = "L1" + + +def parse_l1_usage_summary_file(dram_memory_usage, l1_memory_usage, file_path): + with open(file_path, "r") as file: + reader = csv.reader(file) + dram_row = True + + for index, row in enumerate(reader): + if index == 2: + add_key( + dram_memory_usage, + l1_memory_usage, + "L1", + "largest_contiguous_free_block (bytes) : per bank", + row[1].strip(), + ) + + +def parse_memory_csv_files( + detailed_memory_usage_file_path, + memory_usage_summary_file_path, + l1_usage_summary_file_path, +): + dram_memory_usage = {} + l1_memory_usage = {} + + parse_detailed_memory_usage_file( + dram_memory_usage, l1_memory_usage, detailed_memory_usage_file_path + ) + parse_memory_usage_summary_file( + dram_memory_usage, l1_memory_usage, memory_usage_summary_file_path + ) + 
parse_l1_usage_summary_file( + dram_memory_usage, l1_memory_usage, l1_usage_summary_file_path + ) + + return dram_memory_usage, l1_memory_usage + + +def memory(callback_runtime_config, binary, program_context, op_context): + import ttrt.runtime + import ttrt.binary + + device = callback_runtime_config.device + logging = callback_runtime_config.logging + logging.debug("executing memory dump") + loc = ttrt.runtime.get_op_loc_info(op_context) + debug_str = ttrt.runtime.get_op_debug_str(op_context) + + device.dump_memory_report() + memory_dump_dir_path = f"{get_ttrt_metal_home_path()}/generated/reports" + + # read generated memory reports and store in condensed memory_report + dram_memory_usage, l1_memory_usage = parse_memory_csv_files( + f"{memory_dump_dir_path}/detailed_memory_usage.csv", + f"{memory_dump_dir_path}/memory_usage_summary.csv", + f"{memory_dump_dir_path}/l1_usage_summary.csv", + ) + + op_memory_report = {} + op_memory_report["loc"] = loc + op_memory_report["debug_str"] = debug_str + op_memory_report["dram"] = dram_memory_usage + op_memory_report["l1"] = l1_memory_usage + callback_runtime_config.memory_report[ + callback_runtime_config.callback_counter() + ] = op_memory_report + + +""" +-----------------------DEBUGGER CALLBACK----------------------- +""" + + +def debugger(callback_runtime_config, binary, program_context, op_context): + import pdb + + device = callback_runtime_config.device + logging = callback_runtime_config.logging + logging.debug("starting pdb debugger") + pdb.set_trace() + + +def callback(callback_runtime_config, binary, program_context, op_context): + + if callback_runtime_config.enable_golden: + golden(callback_runtime_config, binary, program_context, op_context) + + if callback_runtime_config.enable_memory: + memory(callback_runtime_config, binary, program_context, op_context) + + if callback_runtime_config.enable_debugger: + debugger(callback_runtime_config, binary, program_context, op_context) + + +def get_callback_fn(callback_runtime_config): + return partial(callback, callback_runtime_config) diff --git a/runtime/tools/python/ttrt/common/golden.py b/runtime/tools/python/ttrt/common/golden.py deleted file mode 100644 index 847942615..000000000 --- a/runtime/tools/python/ttrt/common/golden.py +++ /dev/null @@ -1,196 +0,0 @@ -# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -# -# SPDX-License-Identifier: Apache-2.0 - -import re -from functools import partial - -from ttrt.common.util import * - - -class GoldenRuntimeConfig: - def __init__( - self, - atol=1e-08, - rtol=1e-05, - pcc=0.99, - artifact_dir="", - save_golden_tensors=False, - ): - self.artifact_dir = artifact_dir - self.pcc = pcc - self.atol = atol - self.rtol = rtol - self.save_golden_tensors = save_golden_tensors - - -def get_atol_rtol_pcc(golden, calculated): - import numpy as np - import torch - - # Calculate atol and rtol - cal_atol = torch.max(torch.abs(golden - calculated)).item() - cal_rtol = torch.max(torch.abs(golden - calculated) / torch.abs(calculated)).item() - - # Calculate PCC - def get_pcc(golden, calculated): - # Both tensors are nan - if torch.all(torch.isnan(golden)) and torch.all(torch.isnan(calculated)): - print("Both tensors are 'nan'") - return 1.0 - # Test if either is completely zero - elif torch.any(golden.bool()) != torch.any(calculated.bool()): - return 0.0 - # One tensor is all nan, the other is not - elif torch.all(torch.isnan(golden)) or torch.all(torch.isnan(calculated)): - print("One tensor is all nan, the other is not.") - return 0.0 - else: - # For now, 
mask all infs and nans so that we check the rest... TODO - golden = golden.clone() - golden[ - torch.logical_or( - torch.isnan(golden), - torch.logical_or(torch.isinf(golden), torch.isneginf(golden)), - ) - ] = 0 - calculated = calculated.clone() - calculated[ - torch.logical_or( - torch.isnan(calculated), - torch.logical_or( - torch.isinf(calculated), torch.isneginf(calculated) - ), - ) - ] = 0 - - if torch.equal(golden, calculated): - return 1.0 - - if golden.dtype == torch.bfloat16: - golden = golden.type(torch.float32) - calculated = calculated.type(torch.float32) - - # Single element case - if golden.numel() == 1: - return float(torch.equal(golden, calculated)) - - # If both tensors are contant - if torch.max(golden) == torch.min(golden) and torch.max( - calculated - ) == torch.min(calculated): - return torch.isclose(torch.max(golden), torch.max(calculated)).item() - - cal_pcc = np.ma.corrcoef( - np.ma.masked_invalid(torch.squeeze(golden).detach().numpy()).flatten(), - np.ma.masked_invalid( - torch.squeeze(calculated).detach().numpy() - ).flatten(), - ) - # Remove correlation coefficient with self (typically always 1.0) - mask = np.ones(cal_pcc.shape, dtype=bool) - np.fill_diagonal(mask, 0) - cal_pcc = np.min(cal_pcc[mask]) - - if isinstance(cal_pcc, np.ma.core.MaskedConstant): - return 1.0 - - return cal_pcc - - cal_pcc = get_pcc(golden, calculated) - - return ( - cal_atol, - cal_rtol, - cal_pcc, - f"Max ATOL Delta: {cal_atol}, Max RTOL Delta: {cal_rtol}, PCC: {cal_pcc}", - ) - - -def golden_partial_function( - golden_runtime_config, golden_results_data, binary, program_context, op_context -): - import torch - import ttrt.runtime - import ttrt.binary - - print("-----------executing golden comparision-----------") - - try: - loc = ttrt.runtime.get_op_loc_info(op_context) - print(f"found location={loc}") - - op_golden_tensor = binary.get_debug_info_golden(loc) - op_output_tensor = ttrt.runtime.get_op_output_tensor( - op_context, program_context - ) - - if op_golden_tensor is None: - print("Golden tensor is None - skipping golden comparison") - return - - if len(op_output_tensor) == 0: - print("Output tensor is empty - skipping golden comparison") - return - - dtype = ttrt_datatype_to_torch_dtype(op_golden_tensor.dtype) - - golden_tensor_torch = torch.frombuffer(op_golden_tensor, dtype=dtype).flatten() - - output_tensor_torch = torch.tensor(op_output_tensor, dtype=dtype).flatten() - - if golden_runtime_config.save_golden_tensors: - torch.save( - golden_tensor_torch, - f"{golden_runtime_config.artifact_dir}/{loc}_golden.pt", - ) - torch.save( - output_tensor_torch, - f"{golden_runtime_config.artifact_dir}/{loc}_device.pt", - ) - - if golden_tensor_torch.shape != output_tensor_torch.shape: - print( - "Golden and output tensor shapes do not match - skipping golden comparison" - ) - return - - _, _, cal_pcc, output_str = get_atol_rtol_pcc( - golden_tensor_torch, output_tensor_torch - ) - - print(f"PCC={cal_pcc}") - print(output_str) - - results = {} - results["expected_pcc"] = golden_runtime_config.pcc - results["actual_pcc"] = cal_pcc - results["atol"] = golden_runtime_config.atol - results["rtol"] = golden_runtime_config.rtol - results["allclose"] = torch.allclose( - golden_tensor_torch, - output_tensor_torch, - atol=golden_runtime_config.atol, - rtol=golden_runtime_config.rtol, - ) - results["max"] = torch.max( - torch.abs(golden_tensor_torch - output_tensor_torch) - ).item() - results["mean_absolute_error"] = torch.mean( - torch.abs(golden_tensor_torch - output_tensor_torch) - ).item() 
- results["root_mean_square_error"] = torch.sqrt( - torch.mean((golden_tensor_torch - output_tensor_torch) ** 2) - ).item() - results["cosine_similarity"] = torch.nn.functional.cosine_similarity( - golden_tensor_torch.unsqueeze(0), output_tensor_torch.unsqueeze(0) - ).item() - - golden_results_data[loc] = results - - finally: - print("-----------finished executing golden comparision-----------") - - -def get_golden_fn(golden_runtime_config, golden_results_data): - return partial(golden_partial_function, golden_runtime_config, golden_results_data) diff --git a/runtime/tools/python/ttrt/common/perf.py b/runtime/tools/python/ttrt/common/perf.py index 55ee255f9..032036fa4 100644 --- a/runtime/tools/python/ttrt/common/perf.py +++ b/runtime/tools/python/ttrt/common/perf.py @@ -84,6 +84,20 @@ def initialize_api(): choices=None, help="test file to save results to", ) + Perf.register_arg( + name="--disable-golden", + type=bool, + default=False, + choices=[True, False], + help="disable golden comparison for intermediate and output tensors", + ) + Perf.register_arg( + name="--memory", + type=bool, + default=False, + choices=[True, False], + help="dump memory reports after every op execution", + ) Perf.register_arg( name="binary", type=str, @@ -368,6 +382,12 @@ def get_available_port(): command_options = f"--program-index {self['--program-index']} --loops {self['--loops']} --save-artifacts " + if self["--memory"]: + command_options += " --memory " + + if self["--disable-golden"]: + command_options += " --disable-golden " + ttrt_executable_path = shutil.which("ttrt") test_command = ( f"{ttrt_executable_path} run {bin.file_path} {command_options}" diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py index e4cdade85..15c4002b1 100644 --- a/runtime/tools/python/ttrt/common/run.py +++ b/runtime/tools/python/ttrt/common/run.py @@ -6,7 +6,7 @@ from ttrt.common.util import * from ttrt.common.query import Query -from ttrt.common.golden import get_golden_fn, GoldenRuntimeConfig +from ttrt.common.callback import get_callback_fn, CallbackRuntimeConfig class Run: @@ -148,11 +148,11 @@ def initialize_api(): help="test file to save results to", ) Run.register_arg( - name="--golden", + name="--disable-golden", type=bool, - default=True, + default=False, choices=[True, False], - help="run golden comparison for intermediate and output tensors", + help="disable golden comparison for intermediate and output tensors", ) Run.register_arg( name="--save-golden-tensors", @@ -161,6 +161,27 @@ def initialize_api(): choices=[True, False], help="save golden and device tensors that are compared during callback runtime", ) + Run.register_arg( + name="--debugger", + type=bool, + default=False, + choices=[True, False], + help="run step debugger after every op execution", + ) + Run.register_arg( + name="--memory", + type=bool, + default=False, + choices=[True, False], + help="dump memory reports after every op execution (use in conjunction with --save-artifacts)", + ) + Run.register_arg( + name="--check-memory-leak", + type=bool, + default=False, + choices=[True, False], + help="check for memory leaks (use in conjunction with --memory)", + ) Run.register_arg( name="binary", type=str, @@ -367,6 +388,23 @@ def _execute(binaries): self.logging.debug(f"opening devices={self.query.device_ids}") device = ttrt.runtime.open_device(self.query.device_ids) + callback_runtime_config = CallbackRuntimeConfig( + device, + "", + self["--pcc"], + self["--atol"], + self["--rtol"], + 
self["--save-golden-tensors"], + self.logging, + not self["--disable-golden"], + self["--memory"], + self["--debugger"], + ) + + callback_env = ttrt.runtime.DebugHooks.get( + get_callback_fn(callback_runtime_config) + ) + try: for bin in binaries: try: @@ -386,13 +424,20 @@ def _execute(binaries): f"evaluating program={program_index} for binary={bin.file_path}" ) + callback_runtime_config.start_new_callback( + f"{self.artifacts.get_binary_folder_path(bin)}/run/program_{program_index}" + ) + program = bin.get_program(program_index) golden_inputs = [] for i in range(len(program.program["inputs"])): - golden_tensor = bin.fbb.get_debug_info_golden( - f"input_{i}" - ) + golden_tensor = None + + if not self["--disable-golden"]: + golden_tensor = bin.fbb.get_debug_info_golden( + f"input_{i}" + ) if golden_tensor is not None: @@ -448,20 +493,7 @@ def _execute(binaries): total_outputs.append(outputs) event = None - golden_results_data = {} - if self["--golden"]: - callback_env = ttrt.runtime.DebugHooks.get( - get_golden_fn( - GoldenRuntimeConfig( - self["--atol"], - self["--rtol"], - self["--pcc"], - f"{self.artifacts.get_binary_folder_path(bin)}/run/program_{program_index}", - self["--save-golden-tensors"], - ), - golden_results_data, - ) - ) + for loop in range(self["--loops"]): self.logging.debug( f"starting loop={loop+1}/{self['--loops']} for binary={bin.file_path}" @@ -543,18 +575,16 @@ def _execute(binaries): device.deallocate_buffers() # if golden comparison is enabled, check golden results json file to see if test passed - if self["--golden"]: + if not self["--disable-golden"]: if self["--save-artifacts"]: - golden_results_file_path = f"{self.artifacts.get_binary_folder_path(bin)}/run/program_{program_index}/golden_results.json" - - with open( - golden_results_file_path, "w" - ) as json_file: - json.dump( - golden_results_data, json_file, indent=4 - ) + callback_runtime_config.save_golden_report( + f"{self.artifacts.get_binary_folder_path(bin)}/run/program_{program_index}/golden_results.json" + ) - for loc, golden_data in golden_results_data.items(): + for ( + loc, + golden_data, + ) in callback_runtime_config.golden_report.items(): if ( golden_data["actual_pcc"] < golden_data["expected_pcc"] @@ -563,6 +593,65 @@ def _execute(binaries): f"Failed: golden comparison failed for program={program_index}, actual_pcc={golden_data['actual_pcc']} < expected_pcc={golden_data['expected_pcc']}" ) + if self["--memory"]: + if self["--save-artifacts"]: + callback_runtime_config.save_memory_report( + f"{self.artifacts.get_binary_folder_path(bin)}/run/program_{program_index}/memory_results.json" + ) + + if self["--check-memory-leak"]: + num_items = 0 + for ( + key, + value, + ) in callback_runtime_config.memory_report.items(): + num_items += 1 + + if num_items == 0: + self.logging.warning(f"No memory data found") + else: + # query initial memory usage + dram_initial_size = callback_runtime_config.memory_report[ + 0 + ][ + "dram" + ][ + "total_allocated (bytes) : total_allocated/bank * num_banks" + ] + l1_initlal_size = callback_runtime_config.memory_report[ + 0 + ][ + "l1" + ][ + "total_allocated (bytes) : total_allocated/bank * num_banks" + ] + + # query final memory usage and ensure no memory leaks + dram_final_size = callback_runtime_config.memory_report[ + num_items - 1 + ][ + "dram" + ][ + "total_allocated (bytes) : total_allocated/bank * num_banks" + ] + l1_final_size = callback_runtime_config.memory_report[ + num_items - 1 + ][ + "l1" + ][ + "total_allocated (bytes) : total_allocated/bank * 
num_banks" + ] + + if dram_final_size > dram_initial_size: + raise Exception( + "Memory leak detected in DRAM" + ) + + if l1_final_size > l1_initlal_size: + raise Exception( + "Memory leak detected in L1 cache" + ) + except Exception as e: test_result = { "file_path": bin.file_path, diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index bf9d1840b..4c3eb8c69 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -22,7 +22,8 @@ PYBIND11_MODULE(_C, m) { m.doc() = "ttrt.runtime python extension for interacting with the " "Tenstorrent devices"; py::class_(m, "Device") - .def("deallocate_buffers", &tt::runtime::detail::deallocateBuffers); + .def("deallocate_buffers", &tt::runtime::detail::deallocateBuffers) + .def("dump_memory_report", &tt::runtime::detail::dumpMemoryReport); py::class_(m, "Event"); py::class_(m, "Tensor"); py::class_(m, "Layout");