From 230d53b9d99c82bb301882d1f20c667b67cd00b1 Mon Sep 17 00:00:00 2001 From: Atul Krishnadas Date: Fri, 22 Nov 2024 14:29:41 -0800 Subject: [PATCH] #15320: sweep expand (#15343) ### Ticket [#15320](https://github.com/tenstorrent/tt-metal/issues/15320) ### Problem description Need to add sweep for expand ### What's changed Added the run function to compare torch and ttnn expand ### Checklist - [ ] Post commit CI passes: --- .../data_movement/expand/expand_pytorch2.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/sweep_framework/sweeps/data_movement/expand/expand_pytorch2.py b/tests/sweep_framework/sweeps/data_movement/expand/expand_pytorch2.py index 22837badf2f..9eef68842af 100644 --- a/tests/sweep_framework/sweeps/data_movement/expand/expand_pytorch2.py +++ b/tests/sweep_framework/sweeps/data_movement/expand/expand_pytorch2.py @@ -309,4 +309,18 @@ def run( *, device, ): - raise Exception("Expand is not supported, TODO: implement via recursive concat with itself") + torch_tensor = torch_random(expand_specs["shape"], -10, 10, dtype=torch.bfloat16) + expanded_tensor = torch_tensor.expand(expand_specs["size"]) + + ttnn_tensor = ttnn.from_torch(torch_tensor, device=device, layout=layout, dtype=dtype) + + start_time = start_measuring_time() + expanded_ttnn_tensor = ttnn.expand(ttnn_tensor, expand_specs["size"]) + e2e_perf = stop_measuring_time(start_time) + + ttnn_output_tensor = ttnn.to_torch(expanded_ttnn_tensor) + + result = check_with_pcc(expanded_tensor, ttnn_output_tensor, 0.999) + + return [result, e2e_perf] + # raise Exception("Expand is not supported, TODO: implement via recursive concat with itself")