diff --git a/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py b/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py index c66e5548834..0b6c9031553 100644 --- a/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py +++ b/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py @@ -13,6 +13,10 @@ TIMEOUT = 70 +# params contains the shape of the first tensor followed by the second tensor +# Note: the shape of the second tensor starts at int(count / 2). It's easiest +# to reason about if both tensors are the same rank, although some other +# combinations may be valid. parameters = { "default": { "params": [ @@ -111,25 +115,269 @@ (9, 768, 768, 640), (920, 256, 256, 256), ], - } + "core_grid": [False], + }, + "gpt": { + "params": [ + (1, 1, 1, 1, 1, 1, 1, 1), + (1, 1, 1, 1, 1, 1, 1, 2304), + (1, 1, 1, 1, 1, 1, 1, 3072), + (1, 1, 1, 1, 1, 1, 1, 65536), + (1, 1, 1, 1, 1, 1, 1, 768), + (1, 1, 1, 1, 1, 1, 1, 96), + (1, 1, 1, 2304, 1, 1, 2304, 1), + (1, 1, 1, 2304, 1, 1, 2304, 65536), + (1, 1, 1, 2304, 1, 1, 2304, 768), + (1, 1, 1, 3072, 1, 1, 3072, 1), + (1, 1, 1, 3072, 1, 1, 3072, 65536), + (1, 1, 1, 3072, 1, 1, 3072, 768), + (1, 1, 1, 65536, 1, 1, 65536, 2304), + (1, 1, 1, 65536, 1, 1, 65536, 3072), + (1, 1, 1, 65536, 1, 1, 65536, 768), + (1, 1, 1, 65536, 1, 1, 65536, 96), + (1, 1, 1, 768, 1, 1, 768, 1), + (1, 1, 1, 768, 1, 1, 768, 1024), + (1, 1, 1, 768, 1, 1, 768, 2304), + (1, 1, 1, 768, 1, 1, 768, 3072), + (1, 1, 1, 768, 1, 1, 768, 65536), + (1, 1, 1, 768, 1, 1, 768, 768), + (1, 1, 1, 768, 1, 1, 768, 96), + (1, 1, 1, 96, 1, 1, 96, 1), + (1, 1, 1, 96, 1, 1, 96, 65536), + (1, 1, 1, 96, 1, 1, 96, 768), + (1, 1, 1024, 768, 1, 1, 768, 1), + (1, 1, 1024, 768, 1, 1, 768, 1024), + (1, 1, 1024, 768, 1, 1, 768, 2304), + (1, 1, 1024, 768, 1, 1, 768, 3072), + (1, 1, 1024, 768, 1, 1, 768, 65536), + (1, 1, 1024, 768, 1, 1, 768, 768), + (1, 1, 1024, 768, 1, 1, 768, 96), + (1, 1, 2304, 1, 1, 1, 1, 1), + (1, 1, 2304, 1, 1, 1, 1, 2304), + (1, 1, 2304, 1, 1, 1, 1, 3072), + (1, 1, 2304, 1, 1, 1, 1, 65536), + (1, 1, 2304, 1, 1, 1, 1, 768), + (1, 1, 2304, 1, 1, 1, 1, 96), + (1, 1, 2304, 65536, 1, 1, 65536, 1), + (1, 1, 2304, 65536, 1, 1, 65536, 2304), + (1, 1, 2304, 65536, 1, 1, 65536, 3072), + (1, 1, 2304, 65536, 1, 1, 65536, 768), + (1, 1, 2304, 65536, 1, 1, 65536, 96), + (1, 1, 2304, 768, 1, 1, 768, 1), + (1, 1, 2304, 768, 1, 1, 768, 1024), + (1, 1, 2304, 768, 1, 1, 768, 2304), + (1, 1, 2304, 768, 1, 1, 768, 3072), + (1, 1, 2304, 768, 1, 1, 768, 65536), + (1, 1, 2304, 768, 1, 1, 768, 768), + (1, 1, 2304, 768, 1, 1, 768, 96), + (1, 1, 3072, 1, 1, 1, 1, 1), + (1, 1, 3072, 1, 1, 1, 1, 2304), + (1, 1, 3072, 1, 1, 1, 1, 3072), + (1, 1, 3072, 1, 1, 1, 1, 65536), + (1, 1, 3072, 1, 1, 1, 1, 768), + (1, 1, 3072, 1, 1, 1, 1, 96), + (1, 1, 3072, 65536, 1, 1, 65536, 1), + (1, 1, 3072, 65536, 1, 1, 65536, 2304), + (1, 1, 3072, 65536, 1, 1, 65536, 3072), + (1, 1, 3072, 65536, 1, 1, 65536, 768), + (1, 1, 3072, 65536, 1, 1, 65536, 96), + (1, 1, 3072, 768, 1, 1, 768, 1), + (1, 1, 3072, 768, 1, 1, 768, 1024), + (1, 1, 3072, 768, 1, 1, 768, 2304), + (1, 1, 3072, 768, 1, 1, 768, 3072), + (1, 1, 3072, 768, 1, 1, 768, 65536), + (1, 1, 3072, 768, 1, 1, 768, 768), + (1, 1, 3072, 768, 1, 1, 768, 96), + (1, 1, 65536, 1, 1, 1, 1, 1), + (1, 1, 65536, 1, 1, 1, 1, 2304), + (1, 1, 65536, 1, 1, 1, 1, 3072), + (1, 1, 65536, 1, 1, 1, 1, 65536), + (1, 1, 65536, 1, 1, 1, 1, 768), + (1, 1, 65536, 1, 1, 1, 1, 96), + (1, 1, 65536, 2304, 1, 1, 2304, 1), + (1, 1, 65536, 2304, 1, 1, 2304, 65536), + (1, 1, 65536, 2304, 1, 1, 2304, 768), + (1, 1, 65536, 3072, 1, 1, 3072, 1), + (1, 1, 65536, 3072, 1, 1, 3072, 65536), + (1, 1, 65536, 3072, 1, 1, 3072, 768), + (1, 1, 65536, 768, 1, 1, 768, 1), + (1, 1, 65536, 768, 1, 1, 768, 1024), + (1, 1, 65536, 768, 1, 1, 768, 2304), + (1, 1, 65536, 768, 1, 1, 768, 3072), + (1, 1, 65536, 768, 1, 1, 768, 65536), + (1, 1, 65536, 768, 1, 1, 768, 768), + (1, 1, 65536, 768, 1, 1, 768, 96), + (1, 1, 65536, 96, 1, 1, 96, 65536), + (1, 1, 65536, 96, 1, 1, 96, 768), + (1, 1, 768, 1, 1, 1, 1, 1), + (1, 1, 768, 1, 1, 1, 1, 2304), + (1, 1, 768, 1, 1, 1, 1, 3072), + (1, 1, 768, 1, 1, 1, 1, 65536), + (1, 1, 768, 1, 1, 1, 1, 768), + (1, 1, 768, 1, 1, 1, 1, 96), + (1, 1, 768, 1024, 1, 1, 1024, 768), + (1, 1, 768, 2304, 1, 1, 2304, 1), + (1, 1, 768, 2304, 1, 1, 2304, 65536), + (1, 1, 768, 2304, 1, 1, 2304, 768), + (1, 1, 768, 3072, 1, 1, 3072, 1), + (1, 1, 768, 3072, 1, 1, 3072, 65536), + (1, 1, 768, 3072, 1, 1, 3072, 768), + (1, 1, 768, 65536, 1, 1, 65536, 1), + (1, 1, 768, 65536, 1, 1, 65536, 2304), + (1, 1, 768, 65536, 1, 1, 65536, 3072), + (1, 1, 768, 65536, 1, 1, 65536, 768), + (1, 1, 768, 65536, 1, 1, 65536, 96), + (1, 1, 768, 768, 1, 1, 768, 1), + (1, 1, 768, 768, 1, 1, 768, 1024), + (1, 1, 768, 768, 1, 1, 768, 2304), + (1, 1, 768, 768, 1, 1, 768, 3072), + (1, 1, 768, 768, 1, 1, 768, 65536), + (1, 1, 768, 768, 1, 1, 768, 768), + (1, 1, 768, 768, 1, 1, 768, 96), + (1, 1, 768, 96, 1, 1, 96, 1), + (1, 1, 768, 96, 1, 1, 96, 65536), + (1, 1, 768, 96, 1, 1, 96, 768), + (1, 1, 96, 1, 1, 1, 1, 1), + (1, 1, 96, 1, 1, 1, 1, 2304), + (1, 1, 96, 1, 1, 1, 1, 3072), + (1, 1, 96, 1, 1, 1, 1, 65536), + (1, 1, 96, 1, 1, 1, 1, 768), + (1, 1, 96, 1, 1, 1, 1, 96), + (1, 1, 96, 65536, 1, 1, 65536, 1), + (1, 1, 96, 65536, 1, 1, 65536, 2304), + (1, 1, 96, 65536, 1, 1, 65536, 3072), + (1, 1, 96, 65536, 1, 1, 65536, 768), + (1, 1, 96, 65536, 1, 1, 65536, 96), + (1, 1, 96, 768, 1, 1, 768, 1), + (1, 1, 96, 768, 1, 1, 768, 1024), + (1, 1, 96, 768, 1, 1, 768, 2304), + (1, 1, 96, 768, 1, 1, 768, 3072), + (1, 1, 96, 768, 1, 1, 768, 65536), + (1, 1, 96, 768, 1, 1, 768, 768), + (1, 1, 96, 768, 1, 1, 768, 96), + (1, 64, 1024, 768, 1, 1, 768, 1), + (1, 64, 1024, 768, 1, 1, 768, 2304), + (1, 64, 1024, 768, 1, 1, 768, 3072), + (1, 64, 1024, 768, 1, 1, 768, 65536), + (1, 64, 1024, 768, 1, 1, 768, 768), + (1, 64, 1024, 768, 1, 1, 768, 96), + (1, 64, 768, 1024, 1, 1, 1024, 768), + (1, 64, 768, 1024, 1, 64, 1024, 768), + (64, 1, 1, 1024, 1, 1, 1024, 768), + (64, 1, 1, 1024, 64, 1, 1024, 1), + (64, 1, 1, 1024, 64, 1, 1024, 2304), + (64, 1, 1, 1024, 64, 1, 1024, 3072), + (64, 1, 1, 1024, 64, 1, 1024, 768), + (64, 1, 1, 1024, 64, 1, 1024, 96), + (64, 1, 1, 768, 1, 1, 768, 1), + (64, 1, 1, 768, 1, 1, 768, 1024), + (64, 1, 1, 768, 1, 1, 768, 2304), + (64, 1, 1, 768, 1, 1, 768, 3072), + (64, 1, 1, 768, 1, 1, 768, 65536), + (64, 1, 1, 768, 1, 1, 768, 768), + (64, 1, 1, 768, 1, 1, 768, 96), + (64, 1, 1, 768, 64, 1, 768, 1), + (64, 1, 1, 768, 64, 1, 768, 1024), + (64, 1, 1024, 1, 1, 1, 1, 2304), + (64, 1, 1024, 1, 1, 1, 1, 3072), + (64, 1, 1024, 1, 1, 1, 1, 768), + (64, 1, 1024, 1, 1, 1, 1, 96), + (64, 1, 1024, 1, 64, 1, 1, 1024), + (64, 1, 1024, 1, 64, 1, 1, 768), + (64, 1, 1024, 2304, 1, 1, 2304, 65536), + (64, 1, 1024, 2304, 1, 1, 2304, 768), + (64, 1, 1024, 2304, 64, 1, 2304, 1024), + (64, 1, 1024, 3072, 1, 1, 3072, 1), + (64, 1, 1024, 3072, 1, 1, 3072, 65536), + (64, 1, 1024, 3072, 1, 1, 3072, 768), + (64, 1, 1024, 768, 1, 1, 768, 1), + (64, 1, 1024, 768, 1, 1, 768, 1024), + (64, 1, 1024, 768, 1, 1, 768, 2304), + (64, 1, 1024, 768, 1, 1, 768, 3072), + (64, 1, 1024, 768, 1, 1, 768, 65536), + (64, 1, 1024, 768, 1, 1, 768, 768), + (64, 1, 1024, 768, 1, 1, 768, 96), + (64, 1, 1024, 768, 64, 1, 768, 1024), + (64, 1, 1024, 96, 1, 1, 96, 65536), + (64, 1, 1024, 96, 1, 1, 96, 768), + (64, 1, 1024, 96, 64, 1, 96, 1024), + (64, 1, 2304, 1024, 1, 1, 1024, 768), + (64, 1, 2304, 1024, 64, 1, 1024, 1), + (64, 1, 2304, 1024, 64, 1, 1024, 2304), + (64, 1, 2304, 1024, 64, 1, 1024, 3072), + (64, 1, 2304, 1024, 64, 1, 1024, 768), + (64, 1, 2304, 1024, 64, 1, 1024, 96), + (64, 1, 3072, 1024, 1, 1, 1024, 768), + (64, 1, 3072, 1024, 64, 1, 1024, 1), + (64, 1, 3072, 1024, 64, 1, 1024, 2304), + (64, 1, 3072, 1024, 64, 1, 1024, 3072), + (64, 1, 3072, 1024, 64, 1, 1024, 768), + (64, 1, 3072, 1024, 64, 1, 1024, 96), + (64, 1, 768, 1, 1, 1, 1, 2304), + (64, 1, 768, 1, 1, 1, 1, 3072), + (64, 1, 768, 1, 1, 1, 1, 768), + (64, 1, 768, 1, 1, 1, 1, 96), + (64, 1, 768, 1, 64, 1, 1, 768), + (64, 1, 768, 1024, 1, 1, 1024, 768), + (64, 1, 768, 1024, 64, 1, 1024, 1), + (64, 1, 768, 1024, 64, 1, 1024, 2304), + (64, 1, 768, 1024, 64, 1, 1024, 3072), + (64, 1, 768, 1024, 64, 1, 1024, 768), + (64, 1, 768, 1024, 64, 1, 1024, 96), + (64, 1, 96, 1024, 1, 1, 1024, 768), + (64, 1, 96, 1024, 64, 1, 1024, 1), + (64, 1, 96, 1024, 64, 1, 1024, 2304), + (64, 1, 96, 1024, 64, 1, 1024, 3072), + (64, 1, 96, 1024, 64, 1, 1024, 768), + (64, 1, 96, 1024, 64, 1, 1024, 96), + (64, 12, 1, 1024, 1, 1, 1024, 768), + (64, 12, 1, 1024, 64, 12, 1024, 1), + (64, 12, 1, 1024, 64, 12, 1024, 1024), + (64, 12, 1, 1024, 64, 12, 1024, 64), + (64, 12, 1024, 1, 1, 1, 1, 1), + (64, 12, 1024, 1, 1, 1, 1, 2304), + (64, 12, 1024, 1, 1, 1, 1, 3072), + (64, 12, 1024, 1, 1, 1, 1, 768), + (64, 12, 1024, 1, 1, 1, 1, 96), + (64, 12, 1024, 1, 64, 12, 1, 1024), + (64, 12, 1024, 1024, 1, 1, 1024, 768), + (64, 12, 1024, 1024, 64, 12, 1024, 1), + (64, 12, 1024, 1024, 64, 12, 1024, 1024), + (64, 12, 1024, 1024, 64, 12, 1024, 64), + (64, 12, 1024, 64, 64, 12, 64, 1024), + (64, 12, 64, 1024, 1, 1, 1024, 768), + (64, 12, 64, 1024, 64, 12, 1024, 1), + (64, 12, 64, 1024, 64, 12, 1024, 1024), + (64, 12, 64, 1024, 64, 12, 1024, 64), + ], + "core_grid": [True, False], + }, } def run( params, + core_grid, *, device, ) -> list: - [in0_h, in0_w, in1_h, in1_w] = params - torch_input_tensor0 = torch.rand([in0_h, in0_w], dtype=torch.float32) - torch_input_tensor1 = torch.rand([in1_h, in1_w], dtype=torch.float32) + if core_grid == False: + grid = None + else: + grid = device.core_grid + count = len(params) + half = int(count / 2) + shape0 = params[0:half] + shape1 = params[half:count] + torch_input_tensor0 = torch.rand(shape0, dtype=torch.float32) + torch_input_tensor1 = torch.rand(shape1, dtype=torch.float32) torch_output_tensor = torch.matmul(torch_input_tensor0, torch_input_tensor1) input_tensor0 = ttnn.from_torch(torch_input_tensor0, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) input_tensor1 = ttnn.from_torch(torch_input_tensor1, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) start_time = start_measuring_time() - output_tensor = ttnn.matmul(input_tensor0, input_tensor1) + output_tensor = ttnn.matmul(input_tensor0, input_tensor1, core_grid=grid) output_tensor = ttnn.to_torch(output_tensor) e2e_perf = stop_measuring_time(start_time) expected_pcc = 0.99