Skip to content

Commit

Permalink
#17532: Add relevant tests to matmul trace sweeps (#15850)
Browse files Browse the repository at this point in the history
### Ticket
Link to Github Issue #15732

### Problem description
Need tests for various shapes to verify that they pass after making
changes

### What's changed
Create the tests based on a trace of a run of the model

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12247713568
- [ ] Blackhole Post commit (if applicable) N/A
- [ ] Model regression CI testing passes (if applicable) N/A
- [ ] Device performance regression CI testing passes (if applicable)
N/A
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests passes N/A
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
bbradelTT authored Dec 10, 2024
1 parent 4bcc79b commit 969c680
Showing 1 changed file with 253 additions and 5 deletions.
258 changes: 253 additions & 5 deletions tests/sweep_framework/sweeps/matmul/short/matmul_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

TIMEOUT = 70

# params contains the shape of the first tensor followed by the second tensor
# Note: the shape of the second tensor starts at int(count / 2). It's easiest
# to reason about if both tensors are the same rank, although some other
# combinations may be valid.
parameters = {
"default": {
"params": [
Expand Down Expand Up @@ -111,25 +115,269 @@
(9, 768, 768, 640),
(920, 256, 256, 256),
],
}
"core_grid": [False],
},
"gpt": {
"params": [
(1, 1, 1, 1, 1, 1, 1, 1),
(1, 1, 1, 1, 1, 1, 1, 2304),
(1, 1, 1, 1, 1, 1, 1, 3072),
(1, 1, 1, 1, 1, 1, 1, 65536),
(1, 1, 1, 1, 1, 1, 1, 768),
(1, 1, 1, 1, 1, 1, 1, 96),
(1, 1, 1, 2304, 1, 1, 2304, 1),
(1, 1, 1, 2304, 1, 1, 2304, 65536),
(1, 1, 1, 2304, 1, 1, 2304, 768),
(1, 1, 1, 3072, 1, 1, 3072, 1),
(1, 1, 1, 3072, 1, 1, 3072, 65536),
(1, 1, 1, 3072, 1, 1, 3072, 768),
(1, 1, 1, 65536, 1, 1, 65536, 2304),
(1, 1, 1, 65536, 1, 1, 65536, 3072),
(1, 1, 1, 65536, 1, 1, 65536, 768),
(1, 1, 1, 65536, 1, 1, 65536, 96),
(1, 1, 1, 768, 1, 1, 768, 1),
(1, 1, 1, 768, 1, 1, 768, 1024),
(1, 1, 1, 768, 1, 1, 768, 2304),
(1, 1, 1, 768, 1, 1, 768, 3072),
(1, 1, 1, 768, 1, 1, 768, 65536),
(1, 1, 1, 768, 1, 1, 768, 768),
(1, 1, 1, 768, 1, 1, 768, 96),
(1, 1, 1, 96, 1, 1, 96, 1),
(1, 1, 1, 96, 1, 1, 96, 65536),
(1, 1, 1, 96, 1, 1, 96, 768),
(1, 1, 1024, 768, 1, 1, 768, 1),
(1, 1, 1024, 768, 1, 1, 768, 1024),
(1, 1, 1024, 768, 1, 1, 768, 2304),
(1, 1, 1024, 768, 1, 1, 768, 3072),
(1, 1, 1024, 768, 1, 1, 768, 65536),
(1, 1, 1024, 768, 1, 1, 768, 768),
(1, 1, 1024, 768, 1, 1, 768, 96),
(1, 1, 2304, 1, 1, 1, 1, 1),
(1, 1, 2304, 1, 1, 1, 1, 2304),
(1, 1, 2304, 1, 1, 1, 1, 3072),
(1, 1, 2304, 1, 1, 1, 1, 65536),
(1, 1, 2304, 1, 1, 1, 1, 768),
(1, 1, 2304, 1, 1, 1, 1, 96),
(1, 1, 2304, 65536, 1, 1, 65536, 1),
(1, 1, 2304, 65536, 1, 1, 65536, 2304),
(1, 1, 2304, 65536, 1, 1, 65536, 3072),
(1, 1, 2304, 65536, 1, 1, 65536, 768),
(1, 1, 2304, 65536, 1, 1, 65536, 96),
(1, 1, 2304, 768, 1, 1, 768, 1),
(1, 1, 2304, 768, 1, 1, 768, 1024),
(1, 1, 2304, 768, 1, 1, 768, 2304),
(1, 1, 2304, 768, 1, 1, 768, 3072),
(1, 1, 2304, 768, 1, 1, 768, 65536),
(1, 1, 2304, 768, 1, 1, 768, 768),
(1, 1, 2304, 768, 1, 1, 768, 96),
(1, 1, 3072, 1, 1, 1, 1, 1),
(1, 1, 3072, 1, 1, 1, 1, 2304),
(1, 1, 3072, 1, 1, 1, 1, 3072),
(1, 1, 3072, 1, 1, 1, 1, 65536),
(1, 1, 3072, 1, 1, 1, 1, 768),
(1, 1, 3072, 1, 1, 1, 1, 96),
(1, 1, 3072, 65536, 1, 1, 65536, 1),
(1, 1, 3072, 65536, 1, 1, 65536, 2304),
(1, 1, 3072, 65536, 1, 1, 65536, 3072),
(1, 1, 3072, 65536, 1, 1, 65536, 768),
(1, 1, 3072, 65536, 1, 1, 65536, 96),
(1, 1, 3072, 768, 1, 1, 768, 1),
(1, 1, 3072, 768, 1, 1, 768, 1024),
(1, 1, 3072, 768, 1, 1, 768, 2304),
(1, 1, 3072, 768, 1, 1, 768, 3072),
(1, 1, 3072, 768, 1, 1, 768, 65536),
(1, 1, 3072, 768, 1, 1, 768, 768),
(1, 1, 3072, 768, 1, 1, 768, 96),
(1, 1, 65536, 1, 1, 1, 1, 1),
(1, 1, 65536, 1, 1, 1, 1, 2304),
(1, 1, 65536, 1, 1, 1, 1, 3072),
(1, 1, 65536, 1, 1, 1, 1, 65536),
(1, 1, 65536, 1, 1, 1, 1, 768),
(1, 1, 65536, 1, 1, 1, 1, 96),
(1, 1, 65536, 2304, 1, 1, 2304, 1),
(1, 1, 65536, 2304, 1, 1, 2304, 65536),
(1, 1, 65536, 2304, 1, 1, 2304, 768),
(1, 1, 65536, 3072, 1, 1, 3072, 1),
(1, 1, 65536, 3072, 1, 1, 3072, 65536),
(1, 1, 65536, 3072, 1, 1, 3072, 768),
(1, 1, 65536, 768, 1, 1, 768, 1),
(1, 1, 65536, 768, 1, 1, 768, 1024),
(1, 1, 65536, 768, 1, 1, 768, 2304),
(1, 1, 65536, 768, 1, 1, 768, 3072),
(1, 1, 65536, 768, 1, 1, 768, 65536),
(1, 1, 65536, 768, 1, 1, 768, 768),
(1, 1, 65536, 768, 1, 1, 768, 96),
(1, 1, 65536, 96, 1, 1, 96, 65536),
(1, 1, 65536, 96, 1, 1, 96, 768),
(1, 1, 768, 1, 1, 1, 1, 1),
(1, 1, 768, 1, 1, 1, 1, 2304),
(1, 1, 768, 1, 1, 1, 1, 3072),
(1, 1, 768, 1, 1, 1, 1, 65536),
(1, 1, 768, 1, 1, 1, 1, 768),
(1, 1, 768, 1, 1, 1, 1, 96),
(1, 1, 768, 1024, 1, 1, 1024, 768),
(1, 1, 768, 2304, 1, 1, 2304, 1),
(1, 1, 768, 2304, 1, 1, 2304, 65536),
(1, 1, 768, 2304, 1, 1, 2304, 768),
(1, 1, 768, 3072, 1, 1, 3072, 1),
(1, 1, 768, 3072, 1, 1, 3072, 65536),
(1, 1, 768, 3072, 1, 1, 3072, 768),
(1, 1, 768, 65536, 1, 1, 65536, 1),
(1, 1, 768, 65536, 1, 1, 65536, 2304),
(1, 1, 768, 65536, 1, 1, 65536, 3072),
(1, 1, 768, 65536, 1, 1, 65536, 768),
(1, 1, 768, 65536, 1, 1, 65536, 96),
(1, 1, 768, 768, 1, 1, 768, 1),
(1, 1, 768, 768, 1, 1, 768, 1024),
(1, 1, 768, 768, 1, 1, 768, 2304),
(1, 1, 768, 768, 1, 1, 768, 3072),
(1, 1, 768, 768, 1, 1, 768, 65536),
(1, 1, 768, 768, 1, 1, 768, 768),
(1, 1, 768, 768, 1, 1, 768, 96),
(1, 1, 768, 96, 1, 1, 96, 1),
(1, 1, 768, 96, 1, 1, 96, 65536),
(1, 1, 768, 96, 1, 1, 96, 768),
(1, 1, 96, 1, 1, 1, 1, 1),
(1, 1, 96, 1, 1, 1, 1, 2304),
(1, 1, 96, 1, 1, 1, 1, 3072),
(1, 1, 96, 1, 1, 1, 1, 65536),
(1, 1, 96, 1, 1, 1, 1, 768),
(1, 1, 96, 1, 1, 1, 1, 96),
(1, 1, 96, 65536, 1, 1, 65536, 1),
(1, 1, 96, 65536, 1, 1, 65536, 2304),
(1, 1, 96, 65536, 1, 1, 65536, 3072),
(1, 1, 96, 65536, 1, 1, 65536, 768),
(1, 1, 96, 65536, 1, 1, 65536, 96),
(1, 1, 96, 768, 1, 1, 768, 1),
(1, 1, 96, 768, 1, 1, 768, 1024),
(1, 1, 96, 768, 1, 1, 768, 2304),
(1, 1, 96, 768, 1, 1, 768, 3072),
(1, 1, 96, 768, 1, 1, 768, 65536),
(1, 1, 96, 768, 1, 1, 768, 768),
(1, 1, 96, 768, 1, 1, 768, 96),
(1, 64, 1024, 768, 1, 1, 768, 1),
(1, 64, 1024, 768, 1, 1, 768, 2304),
(1, 64, 1024, 768, 1, 1, 768, 3072),
(1, 64, 1024, 768, 1, 1, 768, 65536),
(1, 64, 1024, 768, 1, 1, 768, 768),
(1, 64, 1024, 768, 1, 1, 768, 96),
(1, 64, 768, 1024, 1, 1, 1024, 768),
(1, 64, 768, 1024, 1, 64, 1024, 768),
(64, 1, 1, 1024, 1, 1, 1024, 768),
(64, 1, 1, 1024, 64, 1, 1024, 1),
(64, 1, 1, 1024, 64, 1, 1024, 2304),
(64, 1, 1, 1024, 64, 1, 1024, 3072),
(64, 1, 1, 1024, 64, 1, 1024, 768),
(64, 1, 1, 1024, 64, 1, 1024, 96),
(64, 1, 1, 768, 1, 1, 768, 1),
(64, 1, 1, 768, 1, 1, 768, 1024),
(64, 1, 1, 768, 1, 1, 768, 2304),
(64, 1, 1, 768, 1, 1, 768, 3072),
(64, 1, 1, 768, 1, 1, 768, 65536),
(64, 1, 1, 768, 1, 1, 768, 768),
(64, 1, 1, 768, 1, 1, 768, 96),
(64, 1, 1, 768, 64, 1, 768, 1),
(64, 1, 1, 768, 64, 1, 768, 1024),
(64, 1, 1024, 1, 1, 1, 1, 2304),
(64, 1, 1024, 1, 1, 1, 1, 3072),
(64, 1, 1024, 1, 1, 1, 1, 768),
(64, 1, 1024, 1, 1, 1, 1, 96),
(64, 1, 1024, 1, 64, 1, 1, 1024),
(64, 1, 1024, 1, 64, 1, 1, 768),
(64, 1, 1024, 2304, 1, 1, 2304, 65536),
(64, 1, 1024, 2304, 1, 1, 2304, 768),
(64, 1, 1024, 2304, 64, 1, 2304, 1024),
(64, 1, 1024, 3072, 1, 1, 3072, 1),
(64, 1, 1024, 3072, 1, 1, 3072, 65536),
(64, 1, 1024, 3072, 1, 1, 3072, 768),
(64, 1, 1024, 768, 1, 1, 768, 1),
(64, 1, 1024, 768, 1, 1, 768, 1024),
(64, 1, 1024, 768, 1, 1, 768, 2304),
(64, 1, 1024, 768, 1, 1, 768, 3072),
(64, 1, 1024, 768, 1, 1, 768, 65536),
(64, 1, 1024, 768, 1, 1, 768, 768),
(64, 1, 1024, 768, 1, 1, 768, 96),
(64, 1, 1024, 768, 64, 1, 768, 1024),
(64, 1, 1024, 96, 1, 1, 96, 65536),
(64, 1, 1024, 96, 1, 1, 96, 768),
(64, 1, 1024, 96, 64, 1, 96, 1024),
(64, 1, 2304, 1024, 1, 1, 1024, 768),
(64, 1, 2304, 1024, 64, 1, 1024, 1),
(64, 1, 2304, 1024, 64, 1, 1024, 2304),
(64, 1, 2304, 1024, 64, 1, 1024, 3072),
(64, 1, 2304, 1024, 64, 1, 1024, 768),
(64, 1, 2304, 1024, 64, 1, 1024, 96),
(64, 1, 3072, 1024, 1, 1, 1024, 768),
(64, 1, 3072, 1024, 64, 1, 1024, 1),
(64, 1, 3072, 1024, 64, 1, 1024, 2304),
(64, 1, 3072, 1024, 64, 1, 1024, 3072),
(64, 1, 3072, 1024, 64, 1, 1024, 768),
(64, 1, 3072, 1024, 64, 1, 1024, 96),
(64, 1, 768, 1, 1, 1, 1, 2304),
(64, 1, 768, 1, 1, 1, 1, 3072),
(64, 1, 768, 1, 1, 1, 1, 768),
(64, 1, 768, 1, 1, 1, 1, 96),
(64, 1, 768, 1, 64, 1, 1, 768),
(64, 1, 768, 1024, 1, 1, 1024, 768),
(64, 1, 768, 1024, 64, 1, 1024, 1),
(64, 1, 768, 1024, 64, 1, 1024, 2304),
(64, 1, 768, 1024, 64, 1, 1024, 3072),
(64, 1, 768, 1024, 64, 1, 1024, 768),
(64, 1, 768, 1024, 64, 1, 1024, 96),
(64, 1, 96, 1024, 1, 1, 1024, 768),
(64, 1, 96, 1024, 64, 1, 1024, 1),
(64, 1, 96, 1024, 64, 1, 1024, 2304),
(64, 1, 96, 1024, 64, 1, 1024, 3072),
(64, 1, 96, 1024, 64, 1, 1024, 768),
(64, 1, 96, 1024, 64, 1, 1024, 96),
(64, 12, 1, 1024, 1, 1, 1024, 768),
(64, 12, 1, 1024, 64, 12, 1024, 1),
(64, 12, 1, 1024, 64, 12, 1024, 1024),
(64, 12, 1, 1024, 64, 12, 1024, 64),
(64, 12, 1024, 1, 1, 1, 1, 1),
(64, 12, 1024, 1, 1, 1, 1, 2304),
(64, 12, 1024, 1, 1, 1, 1, 3072),
(64, 12, 1024, 1, 1, 1, 1, 768),
(64, 12, 1024, 1, 1, 1, 1, 96),
(64, 12, 1024, 1, 64, 12, 1, 1024),
(64, 12, 1024, 1024, 1, 1, 1024, 768),
(64, 12, 1024, 1024, 64, 12, 1024, 1),
(64, 12, 1024, 1024, 64, 12, 1024, 1024),
(64, 12, 1024, 1024, 64, 12, 1024, 64),
(64, 12, 1024, 64, 64, 12, 64, 1024),
(64, 12, 64, 1024, 1, 1, 1024, 768),
(64, 12, 64, 1024, 64, 12, 1024, 1),
(64, 12, 64, 1024, 64, 12, 1024, 1024),
(64, 12, 64, 1024, 64, 12, 1024, 64),
],
"core_grid": [True, False],
},
}


def run(
params,
core_grid,
*,
device,
) -> list:
[in0_h, in0_w, in1_h, in1_w] = params
torch_input_tensor0 = torch.rand([in0_h, in0_w], dtype=torch.float32)
torch_input_tensor1 = torch.rand([in1_h, in1_w], dtype=torch.float32)
if core_grid == False:
grid = None
else:
grid = device.core_grid
count = len(params)
half = int(count / 2)
shape0 = params[0:half]
shape1 = params[half:count]
torch_input_tensor0 = torch.rand(shape0, dtype=torch.float32)
torch_input_tensor1 = torch.rand(shape1, dtype=torch.float32)
torch_output_tensor = torch.matmul(torch_input_tensor0, torch_input_tensor1)

input_tensor0 = ttnn.from_torch(torch_input_tensor0, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor1 = ttnn.from_torch(torch_input_tensor1, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)

start_time = start_measuring_time()
output_tensor = ttnn.matmul(input_tensor0, input_tensor1)
output_tensor = ttnn.matmul(input_tensor0, input_tensor1, core_grid=grid)
output_tensor = ttnn.to_torch(output_tensor)
e2e_perf = stop_measuring_time(start_time)
expected_pcc = 0.99
Expand Down

0 comments on commit 969c680

Please sign in to comment.