From 969c680cf9e8b63d95ff2aac89b2ced5a872905e Mon Sep 17 00:00:00 2001 From: Borys Bradel <164946524+bbradelTT@users.noreply.github.com> Date: Tue, 10 Dec 2024 09:26:45 -0500 Subject: [PATCH] #17532: Add relevant tests to matmul trace sweeps (#15850) ### Ticket Link to Github Issue #15732 ### Problem description Need tests for various shapes to verify that they pass after making changes ### What's changed Create the tests based on a trace of a run of the model ### Checklist - [x] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/12247713568 - [ ] Blackhole Post commit (if applicable) N/A - [ ] Model regression CI testing passes (if applicable) N/A - [ ] Device performance regression CI testing passes (if applicable) N/A - [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests passes N/A - [x] New/Existing tests provide coverage for changes --- .../sweeps/matmul/short/matmul_traces.py | 258 +++++++++++++++++- 1 file changed, 253 insertions(+), 5 deletions(-) diff --git a/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py b/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py index c66e5548834..0b6c9031553 100644 --- a/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py +++ b/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py @@ -13,6 +13,10 @@ TIMEOUT = 70 +# params contains the shape of the first tensor followed by the second tensor +# Note: the shape of the second tensor starts at int(count / 2). It's easiest +# to reason about if both tensors are the same rank, although some other +# combinations may be valid. parameters = { "default": { "params": [ @@ -111,25 +115,269 @@ (9, 768, 768, 640), (920, 256, 256, 256), ], - } + "core_grid": [False], + }, + "gpt": { + "params": [ + (1, 1, 1, 1, 1, 1, 1, 1), + (1, 1, 1, 1, 1, 1, 1, 2304), + (1, 1, 1, 1, 1, 1, 1, 3072), + (1, 1, 1, 1, 1, 1, 1, 65536), + (1, 1, 1, 1, 1, 1, 1, 768), + (1, 1, 1, 1, 1, 1, 1, 96), + (1, 1, 1, 2304, 1, 1, 2304, 1), + (1, 1, 1, 2304, 1, 1, 2304, 65536), + (1, 1, 1, 2304, 1, 1, 2304, 768), + (1, 1, 1, 3072, 1, 1, 3072, 1), + (1, 1, 1, 3072, 1, 1, 3072, 65536), + (1, 1, 1, 3072, 1, 1, 3072, 768), + (1, 1, 1, 65536, 1, 1, 65536, 2304), + (1, 1, 1, 65536, 1, 1, 65536, 3072), + (1, 1, 1, 65536, 1, 1, 65536, 768), + (1, 1, 1, 65536, 1, 1, 65536, 96), + (1, 1, 1, 768, 1, 1, 768, 1), + (1, 1, 1, 768, 1, 1, 768, 1024), + (1, 1, 1, 768, 1, 1, 768, 2304), + (1, 1, 1, 768, 1, 1, 768, 3072), + (1, 1, 1, 768, 1, 1, 768, 65536), + (1, 1, 1, 768, 1, 1, 768, 768), + (1, 1, 1, 768, 1, 1, 768, 96), + (1, 1, 1, 96, 1, 1, 96, 1), + (1, 1, 1, 96, 1, 1, 96, 65536), + (1, 1, 1, 96, 1, 1, 96, 768), + (1, 1, 1024, 768, 1, 1, 768, 1), + (1, 1, 1024, 768, 1, 1, 768, 1024), + (1, 1, 1024, 768, 1, 1, 768, 2304), + (1, 1, 1024, 768, 1, 1, 768, 3072), + (1, 1, 1024, 768, 1, 1, 768, 65536), + (1, 1, 1024, 768, 1, 1, 768, 768), + (1, 1, 1024, 768, 1, 1, 768, 96), + (1, 1, 2304, 1, 1, 1, 1, 1), + (1, 1, 2304, 1, 1, 1, 1, 2304), + (1, 1, 2304, 1, 1, 1, 1, 3072), + (1, 1, 2304, 1, 1, 1, 1, 65536), + (1, 1, 2304, 1, 1, 1, 1, 768), + (1, 1, 2304, 1, 1, 1, 1, 96), + (1, 1, 2304, 65536, 1, 1, 65536, 1), + (1, 1, 2304, 65536, 1, 1, 65536, 2304), + (1, 1, 2304, 65536, 1, 1, 65536, 3072), + (1, 1, 2304, 65536, 1, 1, 65536, 768), + (1, 1, 2304, 65536, 1, 1, 65536, 96), + (1, 1, 2304, 768, 1, 1, 768, 1), + (1, 1, 2304, 768, 1, 1, 768, 1024), + (1, 1, 2304, 768, 1, 1, 768, 2304), + (1, 1, 2304, 768, 1, 1, 768, 3072), + (1, 1, 2304, 768, 1, 1, 768, 65536), + (1, 1, 2304, 768, 1, 1, 768, 768), + (1, 1, 2304, 768, 1, 1, 768, 96), + (1, 1, 3072, 1, 1, 1, 1, 1), + (1, 1, 3072, 1, 1, 1, 1, 2304), + (1, 1, 3072, 1, 1, 1, 1, 3072), + (1, 1, 3072, 1, 1, 1, 1, 65536), + (1, 1, 3072, 1, 1, 1, 1, 768), + (1, 1, 3072, 1, 1, 1, 1, 96), + (1, 1, 3072, 65536, 1, 1, 65536, 1), + (1, 1, 3072, 65536, 1, 1, 65536, 2304), + (1, 1, 3072, 65536, 1, 1, 65536, 3072), + (1, 1, 3072, 65536, 1, 1, 65536, 768), + (1, 1, 3072, 65536, 1, 1, 65536, 96), + (1, 1, 3072, 768, 1, 1, 768, 1), + (1, 1, 3072, 768, 1, 1, 768, 1024), + (1, 1, 3072, 768, 1, 1, 768, 2304), + (1, 1, 3072, 768, 1, 1, 768, 3072), + (1, 1, 3072, 768, 1, 1, 768, 65536), + (1, 1, 3072, 768, 1, 1, 768, 768), + (1, 1, 3072, 768, 1, 1, 768, 96), + (1, 1, 65536, 1, 1, 1, 1, 1), + (1, 1, 65536, 1, 1, 1, 1, 2304), + (1, 1, 65536, 1, 1, 1, 1, 3072), + (1, 1, 65536, 1, 1, 1, 1, 65536), + (1, 1, 65536, 1, 1, 1, 1, 768), + (1, 1, 65536, 1, 1, 1, 1, 96), + (1, 1, 65536, 2304, 1, 1, 2304, 1), + (1, 1, 65536, 2304, 1, 1, 2304, 65536), + (1, 1, 65536, 2304, 1, 1, 2304, 768), + (1, 1, 65536, 3072, 1, 1, 3072, 1), + (1, 1, 65536, 3072, 1, 1, 3072, 65536), + (1, 1, 65536, 3072, 1, 1, 3072, 768), + (1, 1, 65536, 768, 1, 1, 768, 1), + (1, 1, 65536, 768, 1, 1, 768, 1024), + (1, 1, 65536, 768, 1, 1, 768, 2304), + (1, 1, 65536, 768, 1, 1, 768, 3072), + (1, 1, 65536, 768, 1, 1, 768, 65536), + (1, 1, 65536, 768, 1, 1, 768, 768), + (1, 1, 65536, 768, 1, 1, 768, 96), + (1, 1, 65536, 96, 1, 1, 96, 65536), + (1, 1, 65536, 96, 1, 1, 96, 768), + (1, 1, 768, 1, 1, 1, 1, 1), + (1, 1, 768, 1, 1, 1, 1, 2304), + (1, 1, 768, 1, 1, 1, 1, 3072), + (1, 1, 768, 1, 1, 1, 1, 65536), + (1, 1, 768, 1, 1, 1, 1, 768), + (1, 1, 768, 1, 1, 1, 1, 96), + (1, 1, 768, 1024, 1, 1, 1024, 768), + (1, 1, 768, 2304, 1, 1, 2304, 1), + (1, 1, 768, 2304, 1, 1, 2304, 65536), + (1, 1, 768, 2304, 1, 1, 2304, 768), + (1, 1, 768, 3072, 1, 1, 3072, 1), + (1, 1, 768, 3072, 1, 1, 3072, 65536), + (1, 1, 768, 3072, 1, 1, 3072, 768), + (1, 1, 768, 65536, 1, 1, 65536, 1), + (1, 1, 768, 65536, 1, 1, 65536, 2304), + (1, 1, 768, 65536, 1, 1, 65536, 3072), + (1, 1, 768, 65536, 1, 1, 65536, 768), + (1, 1, 768, 65536, 1, 1, 65536, 96), + (1, 1, 768, 768, 1, 1, 768, 1), + (1, 1, 768, 768, 1, 1, 768, 1024), + (1, 1, 768, 768, 1, 1, 768, 2304), + (1, 1, 768, 768, 1, 1, 768, 3072), + (1, 1, 768, 768, 1, 1, 768, 65536), + (1, 1, 768, 768, 1, 1, 768, 768), + (1, 1, 768, 768, 1, 1, 768, 96), + (1, 1, 768, 96, 1, 1, 96, 1), + (1, 1, 768, 96, 1, 1, 96, 65536), + (1, 1, 768, 96, 1, 1, 96, 768), + (1, 1, 96, 1, 1, 1, 1, 1), + (1, 1, 96, 1, 1, 1, 1, 2304), + (1, 1, 96, 1, 1, 1, 1, 3072), + (1, 1, 96, 1, 1, 1, 1, 65536), + (1, 1, 96, 1, 1, 1, 1, 768), + (1, 1, 96, 1, 1, 1, 1, 96), + (1, 1, 96, 65536, 1, 1, 65536, 1), + (1, 1, 96, 65536, 1, 1, 65536, 2304), + (1, 1, 96, 65536, 1, 1, 65536, 3072), + (1, 1, 96, 65536, 1, 1, 65536, 768), + (1, 1, 96, 65536, 1, 1, 65536, 96), + (1, 1, 96, 768, 1, 1, 768, 1), + (1, 1, 96, 768, 1, 1, 768, 1024), + (1, 1, 96, 768, 1, 1, 768, 2304), + (1, 1, 96, 768, 1, 1, 768, 3072), + (1, 1, 96, 768, 1, 1, 768, 65536), + (1, 1, 96, 768, 1, 1, 768, 768), + (1, 1, 96, 768, 1, 1, 768, 96), + (1, 64, 1024, 768, 1, 1, 768, 1), + (1, 64, 1024, 768, 1, 1, 768, 2304), + (1, 64, 1024, 768, 1, 1, 768, 3072), + (1, 64, 1024, 768, 1, 1, 768, 65536), + (1, 64, 1024, 768, 1, 1, 768, 768), + (1, 64, 1024, 768, 1, 1, 768, 96), + (1, 64, 768, 1024, 1, 1, 1024, 768), + (1, 64, 768, 1024, 1, 64, 1024, 768), + (64, 1, 1, 1024, 1, 1, 1024, 768), + (64, 1, 1, 1024, 64, 1, 1024, 1), + (64, 1, 1, 1024, 64, 1, 1024, 2304), + (64, 1, 1, 1024, 64, 1, 1024, 3072), + (64, 1, 1, 1024, 64, 1, 1024, 768), + (64, 1, 1, 1024, 64, 1, 1024, 96), + (64, 1, 1, 768, 1, 1, 768, 1), + (64, 1, 1, 768, 1, 1, 768, 1024), + (64, 1, 1, 768, 1, 1, 768, 2304), + (64, 1, 1, 768, 1, 1, 768, 3072), + (64, 1, 1, 768, 1, 1, 768, 65536), + (64, 1, 1, 768, 1, 1, 768, 768), + (64, 1, 1, 768, 1, 1, 768, 96), + (64, 1, 1, 768, 64, 1, 768, 1), + (64, 1, 1, 768, 64, 1, 768, 1024), + (64, 1, 1024, 1, 1, 1, 1, 2304), + (64, 1, 1024, 1, 1, 1, 1, 3072), + (64, 1, 1024, 1, 1, 1, 1, 768), + (64, 1, 1024, 1, 1, 1, 1, 96), + (64, 1, 1024, 1, 64, 1, 1, 1024), + (64, 1, 1024, 1, 64, 1, 1, 768), + (64, 1, 1024, 2304, 1, 1, 2304, 65536), + (64, 1, 1024, 2304, 1, 1, 2304, 768), + (64, 1, 1024, 2304, 64, 1, 2304, 1024), + (64, 1, 1024, 3072, 1, 1, 3072, 1), + (64, 1, 1024, 3072, 1, 1, 3072, 65536), + (64, 1, 1024, 3072, 1, 1, 3072, 768), + (64, 1, 1024, 768, 1, 1, 768, 1), + (64, 1, 1024, 768, 1, 1, 768, 1024), + (64, 1, 1024, 768, 1, 1, 768, 2304), + (64, 1, 1024, 768, 1, 1, 768, 3072), + (64, 1, 1024, 768, 1, 1, 768, 65536), + (64, 1, 1024, 768, 1, 1, 768, 768), + (64, 1, 1024, 768, 1, 1, 768, 96), + (64, 1, 1024, 768, 64, 1, 768, 1024), + (64, 1, 1024, 96, 1, 1, 96, 65536), + (64, 1, 1024, 96, 1, 1, 96, 768), + (64, 1, 1024, 96, 64, 1, 96, 1024), + (64, 1, 2304, 1024, 1, 1, 1024, 768), + (64, 1, 2304, 1024, 64, 1, 1024, 1), + (64, 1, 2304, 1024, 64, 1, 1024, 2304), + (64, 1, 2304, 1024, 64, 1, 1024, 3072), + (64, 1, 2304, 1024, 64, 1, 1024, 768), + (64, 1, 2304, 1024, 64, 1, 1024, 96), + (64, 1, 3072, 1024, 1, 1, 1024, 768), + (64, 1, 3072, 1024, 64, 1, 1024, 1), + (64, 1, 3072, 1024, 64, 1, 1024, 2304), + (64, 1, 3072, 1024, 64, 1, 1024, 3072), + (64, 1, 3072, 1024, 64, 1, 1024, 768), + (64, 1, 3072, 1024, 64, 1, 1024, 96), + (64, 1, 768, 1, 1, 1, 1, 2304), + (64, 1, 768, 1, 1, 1, 1, 3072), + (64, 1, 768, 1, 1, 1, 1, 768), + (64, 1, 768, 1, 1, 1, 1, 96), + (64, 1, 768, 1, 64, 1, 1, 768), + (64, 1, 768, 1024, 1, 1, 1024, 768), + (64, 1, 768, 1024, 64, 1, 1024, 1), + (64, 1, 768, 1024, 64, 1, 1024, 2304), + (64, 1, 768, 1024, 64, 1, 1024, 3072), + (64, 1, 768, 1024, 64, 1, 1024, 768), + (64, 1, 768, 1024, 64, 1, 1024, 96), + (64, 1, 96, 1024, 1, 1, 1024, 768), + (64, 1, 96, 1024, 64, 1, 1024, 1), + (64, 1, 96, 1024, 64, 1, 1024, 2304), + (64, 1, 96, 1024, 64, 1, 1024, 3072), + (64, 1, 96, 1024, 64, 1, 1024, 768), + (64, 1, 96, 1024, 64, 1, 1024, 96), + (64, 12, 1, 1024, 1, 1, 1024, 768), + (64, 12, 1, 1024, 64, 12, 1024, 1), + (64, 12, 1, 1024, 64, 12, 1024, 1024), + (64, 12, 1, 1024, 64, 12, 1024, 64), + (64, 12, 1024, 1, 1, 1, 1, 1), + (64, 12, 1024, 1, 1, 1, 1, 2304), + (64, 12, 1024, 1, 1, 1, 1, 3072), + (64, 12, 1024, 1, 1, 1, 1, 768), + (64, 12, 1024, 1, 1, 1, 1, 96), + (64, 12, 1024, 1, 64, 12, 1, 1024), + (64, 12, 1024, 1024, 1, 1, 1024, 768), + (64, 12, 1024, 1024, 64, 12, 1024, 1), + (64, 12, 1024, 1024, 64, 12, 1024, 1024), + (64, 12, 1024, 1024, 64, 12, 1024, 64), + (64, 12, 1024, 64, 64, 12, 64, 1024), + (64, 12, 64, 1024, 1, 1, 1024, 768), + (64, 12, 64, 1024, 64, 12, 1024, 1), + (64, 12, 64, 1024, 64, 12, 1024, 1024), + (64, 12, 64, 1024, 64, 12, 1024, 64), + ], + "core_grid": [True, False], + }, } def run( params, + core_grid, *, device, ) -> list: - [in0_h, in0_w, in1_h, in1_w] = params - torch_input_tensor0 = torch.rand([in0_h, in0_w], dtype=torch.float32) - torch_input_tensor1 = torch.rand([in1_h, in1_w], dtype=torch.float32) + if core_grid == False: + grid = None + else: + grid = device.core_grid + count = len(params) + half = int(count / 2) + shape0 = params[0:half] + shape1 = params[half:count] + torch_input_tensor0 = torch.rand(shape0, dtype=torch.float32) + torch_input_tensor1 = torch.rand(shape1, dtype=torch.float32) torch_output_tensor = torch.matmul(torch_input_tensor0, torch_input_tensor1) input_tensor0 = ttnn.from_torch(torch_input_tensor0, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) input_tensor1 = ttnn.from_torch(torch_input_tensor1, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) start_time = start_measuring_time() - output_tensor = ttnn.matmul(input_tensor0, input_tensor1) + output_tensor = ttnn.matmul(input_tensor0, input_tensor1, core_grid=grid) output_tensor = ttnn.to_torch(output_tensor) e2e_perf = stop_measuring_time(start_time) expected_pcc = 0.99