Skip to content

Commit

Permalink
[CCL] Fix ccl sweeps failure (#15304)
Browse files Browse the repository at this point in the history
### Ticket
#15287 

### Problem description
Sweep test generation is failing on main.

### What's changed
Fix test generation by adding the missing `tile` parameter to the CCL sweep vectors and threading it through `invalidate_vector` and `run` (passed to `ttnn.Tensor` / `ttnn.from_torch` as `ttnn.Tile(tile)`).

### Checklist
- [ ] Post commit CI passes
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] New/Existing tests provide coverage for changes
  • Loading branch information
Aswinmcw authored Nov 22, 2024
1 parent 06c2663 commit 63ae1d9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
7 changes: 6 additions & 1 deletion tests/sweep_framework/sweeps/ccl/all_gather_n300.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
],
"enable_async": [True],
"num_iters": [1],
"tile": [(32, 32)],
},
}

Expand All @@ -58,6 +59,7 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
test_vector["num_links"],
test_vector["input_dtype"],
test_vector["layout"],
test_vector["tile"],
)
if is_known_failure:
return True, f"Skipping unsupported case {message}."
Expand Down Expand Up @@ -90,6 +92,7 @@ def run(
mem_config,
enable_async,
num_iters,
tile,
*,
device,
) -> list:
Expand All @@ -108,7 +111,9 @@ def run(
input_tensors = torch.chunk(input_tensor, num_devices, dim)
tt_input_tensors = []
for i, t in enumerate(input_tensors):
tt_input_tensors.append(ttnn.Tensor(t, input_dtype).to(layout).to(all_devices[i], mem_config))
tt_input_tensors.append(
ttnn.Tensor(t, input_dtype, {}, ttnn.Tile(tile)).to(layout).to(all_devices[i], mem_config)
)

input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)
for i in range(num_iters):
Expand Down
5 changes: 4 additions & 1 deletion tests/sweep_framework/sweeps/ccl/all_gather_n300_focused.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
],
"enable_async": [True],
"num_iters": [1],
"tile": [(32, 32)],
},
"all_gather_n300_focused_large": {
"num_devices": [2],
Expand All @@ -134,6 +135,7 @@
],
"enable_async": [True],
"num_iters": [1],
"tile": [(32, 32)],
},
}

Expand Down Expand Up @@ -179,6 +181,7 @@ def run(
mem_config,
enable_async,
num_iters,
tile,
*,
device,
) -> list:
Expand All @@ -204,7 +207,7 @@ def run(
input_tensors = torch.chunk(input_tensor, num_devices, dim)
tt_input_tensors = []
for i, t in enumerate(input_tensors):
t = ttnn.from_torch(t, input_dtype, layout=ttnn.Layout.TILE)
t = ttnn.from_torch(t, input_dtype, tile=ttnn.Tile(tile), layout=ttnn.Layout.TILE)
t = t.to(all_devices[i], mem_config)
tt_input_tensors.append(t)

Expand Down
7 changes: 6 additions & 1 deletion tests/sweep_framework/sweeps/ccl/line_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"mem_config": [ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM)],
"enable_async": [True, False],
"num_iters": [1],
"tile": [(32, 32)],
},
}

Expand All @@ -52,6 +53,7 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
test_vector["num_links"],
test_vector["input_dtype"],
test_vector["layout"],
test_vector["tile"],
)
if is_known_failure:
return True, f"Skipping unsupported case {message}."
Expand Down Expand Up @@ -89,6 +91,7 @@ def run(
mem_config,
enable_async,
num_iters,
tile,
*,
device,
) -> list:
Expand All @@ -100,7 +103,9 @@ def run(

input_tensor = torch.rand(input_shape).bfloat16()

ttnn_tensor = ttnn.from_torch(input_tensor, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim))
ttnn_tensor = ttnn.from_torch(
input_tensor, tile=ttnn.Tile(tile), mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim)
)
input_tensor_mesh = ttnn.to_device(ttnn_tensor, t3k_mesh_device)

for i in range(num_iters):
Expand Down

0 comments on commit 63ae1d9

Please sign in to comment.