#15877: added support for subcoregrid in sdpa decode (#15927)
### Ticket
[Link to Github Issue](#15877)

### Problem description
SDPA decode can now run on sub-core grids when `sub_core_grids` is passed in the SDPA Program Config.
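
A minimal usage sketch of the new option (the grid coordinates and chunk sizes below are illustrative values lifted from the test parametrization, not requirements):

```python
import ttnn

# Restrict the op to two column ranges of the worker grid (illustrative shape).
sub_core_grids = ttnn.CoreRangeSet(
    [
        ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
        ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
    ]
)

program_config = ttnn.SDPAProgramConfig(
    compute_with_storage_grid_size=(8, 4),  # logical grid; must fit inside sub_core_grids
    sub_core_grids=sub_core_grids,          # new optional field added by this commit
    q_chunk_size=32,                        # illustrative chunk sizes
    k_chunk_size=128,
    exp_approx_mode=False,
)
```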
kpaigwar authored Dec 12, 2024
1 parent 279ef8b commit ee62a86
Showing 4 changed files with 151 additions and 29 deletions.
@@ -321,11 +321,21 @@ def run_test_sdpa_decode_single_iter(
sharded_out=False,
start_indices=None,
causal=True,
start_core=ttnn.CoreCoord(0, 0),
sub_core_grids=None,
):
compute_grid_size = device.compute_with_storage_grid_size()
if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y:
pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}")

if sub_core_grids is None:
if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y:
pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}")
else:
unharvested_grid_size = (7, 10)
if compute_grid_size.x > unharvested_grid_size[0] or compute_grid_size.y > unharvested_grid_size[1]:
pytest.skip(f"Need {unharvested_grid_size} grid size to run this test but core grid is {compute_grid_size}")
if grid_size[0] * grid_size[1] > sub_core_grids.num_cores():
pytest.skip(
f"Need {grid_size[0]*grid_size[1]} grid size to run this test but core grid is {sub_core_grids.num_cores()}"
)
padded_num_heads = nearest_pow_2(nearest_n(nh, n=32))
torch.manual_seed(1234)

@@ -346,7 +356,14 @@ def run_test_sdpa_decode_single_iter(
)
dram_memcfg = ttnn.DRAM_MEMORY_CONFIG

shard_grid = ttnn.CoreRangeSet({num_to_corerange(b)})
if sub_core_grids is None:
shard_grid = ttnn.CoreRangeSet({num_to_corerange(b)})
compute_sub_core_grids = None
else:
shard_grid = ttnn.num_cores_to_corerangeset_in_subcoregrids(start_core, b, sub_core_grids, row_wise=True)
compute_sub_core_grids = ttnn.num_cores_to_corerangeset_in_subcoregrids(
start_core, grid_size[0] * grid_size[1], sub_core_grids, row_wise=True
)
shard_spec = ttnn.ShardSpec(shard_grid, (padded_num_heads, d), ttnn.ShardOrientation.ROW_MAJOR, False)

height_sharded_memcfg = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec)
@@ -364,6 +381,7 @@ def run_test_sdpa_decode_single_iter(
k_chunk_size = get_chunk_size(max_start_idx + 1, s)
program_config = ttnn.SDPAProgramConfig(
compute_with_storage_grid_size=grid_size,
sub_core_grids=compute_sub_core_grids,
q_chunk_size=padded_num_heads,
k_chunk_size=k_chunk_size,
exp_approx_mode=False,
@@ -904,6 +922,75 @@ def test_sdpa_decode_sharded(device, b, nh, nkv, s, d, dtype, grid_size, q_dtype
)


@skip_for_blackhole("Unsupported on BH, see #12349")
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.parametrize("device_params", [{"dispatch_core_axis": ttnn.DispatchCoreAxis.COL}], indirect=True)
@pytest.mark.parametrize(
"dtype, q_dtype",
[
[ttnn.bfloat8_b, ttnn.bfloat16],
],
ids=[
"bfp8_cache_bf16_act",
],
)
@pytest.mark.parametrize(
"b, nh, nkv, s, d, grid_size",
(
[8, 8, 1, 2048, 128, (8, 4)],
[8, 8, 1, 256, 128, (8, 4)],
), # Llama2-70B
)
@pytest.mark.parametrize(
"start_core, sub_core_grids",
[
(
ttnn.CoreCoord(1, 0),
ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
]
),
),
],
)
def test_sdpa_decode_sharded_on_subcoregrids(
device, use_program_cache, b, nh, nkv, s, d, dtype, grid_size, q_dtype, start_core, sub_core_grids
):
run_test_sdpa_decode_single_iter(
device,
b,
nh,
nkv,
s,
d,
dtype,
grid_size,
q_dtype,
sharded_in=True,
sharded_out=True,
start_core=start_core,
sub_core_grids=sub_core_grids,
)
run_test_sdpa_decode_single_iter(
device,
b,
nh,
nkv,
s,
d,
dtype,
grid_size,
q_dtype,
sharded_in=True,
sharded_out=True,
start_core=start_core,
sub_core_grids=sub_core_grids,
)
assert device.num_program_cache_entries() == 1


@skip_for_blackhole("Unsupported on BH, see #12349")
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.skip("Skipping Perf Test in CI")
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/operations/transformer/sdpa_config.hpp
@@ -11,6 +11,7 @@ namespace ttnn::operations::transformer {

struct SDPAProgramConfig {
CoreCoord compute_with_storage_grid_size;
std::optional<CoreRangeSet> sub_core_grids;
std::size_t q_chunk_size;
std::size_t k_chunk_size;
std::optional<bool> exp_approx_mode;
@@ -122,9 +122,20 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
CoreCoord grid_size = program_config.has_value() ? program_config->compute_with_storage_grid_size
: device->compute_with_storage_grid_size();

auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1});
uint32_t num_cores_available = grid_size.x * grid_size.y;

CoreRangeSet core_grid;
bool on_subcoregrid = false;
if (program_config.has_value() && program_config->sub_core_grids.has_value()) {
core_grid = program_config->sub_core_grids.value();
TT_FATAL(
core_grid.num_cores() == num_cores_available,
"Number of cores in sub_core_grids must match the number of cores available");
on_subcoregrid = true;
} else {
core_grid = CoreRangeSet(std::vector{CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1})});
}

uint32_t num_cores_in_grid = device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y;
TT_FATAL(num_cores_available <= num_cores_in_grid, "Expected number of cores available to be less than or equal to the number of cores in the grid, got {} and {}", num_cores_available, num_cores_in_grid);
TT_FATAL(num_cores_available >= B, "Expect number of cores available to be greater or equal to batch size, got {} and {}", num_cores_available, B);
@@ -154,32 +165,53 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
// h_worker2) head_reducer2 to head_reducerk then send the result to head_reducer1, which is also the batch_output1
std::vector<CoreCoord> core_group;
std::vector<CoreCoord> core_group_idle;
if (is_q_sharded || is_output_sharded) {
int reducer_idx = 0;
int worker_idx = num_output_cores;

for (int i = 0; i < num_cores_available; ++i) {
CoreCoord core;
if (i % num_cores_per_batch == 0 && reducer_idx < num_output_cores) {
core = {reducer_idx % grid_size.x, reducer_idx / grid_size.x};
reducer_idx++;
} else {
core = {worker_idx % grid_size.x, worker_idx / grid_size.x};
worker_idx++;
}
if (i < num_active_cores) {
core_group.push_back(core);
} else {
core_group_idle.push_back(core);
if (on_subcoregrid) {
if (is_q_sharded || is_output_sharded) {
auto cores_vec = corerange_to_cores(core_grid, num_cores_available, true);
int reducer_idx = 0;
int worker_idx = num_output_cores;
for (int i = 0; i < num_cores_available; ++i) {
if (i % num_cores_per_batch == 0 && reducer_idx < num_output_cores) {
i < num_active_cores ? core_group.push_back(cores_vec[reducer_idx])
: core_group_idle.push_back(cores_vec[reducer_idx]);
reducer_idx++;
} else {
i < num_active_cores ? core_group.push_back(cores_vec[worker_idx])
: core_group_idle.push_back(cores_vec[worker_idx]);
worker_idx++;
}
}
} else {
TT_FATAL(false, "We only support SDPA on subcoregrids with sharded Q and sharded output");
}
} else {
for (int i = 0; i < num_cores_available; ++i) {
CoreCoord core = {i % grid_size.x, i / grid_size.x};
if (i < num_active_cores) {
core_group.push_back(core);
} else {
core_group_idle.push_back(core);
if (is_q_sharded || is_output_sharded) {
int reducer_idx = 0;
int worker_idx = num_output_cores;

for (int i = 0; i < num_cores_available; ++i) {
CoreCoord core;
if (i % num_cores_per_batch == 0 && reducer_idx < num_output_cores) {
core = {reducer_idx % grid_size.x, reducer_idx / grid_size.x};
reducer_idx++;
} else {
core = {worker_idx % grid_size.x, worker_idx / grid_size.x};
worker_idx++;
}
if (i < num_active_cores) {
core_group.push_back(core);
} else {
core_group_idle.push_back(core);
}
}
} else {
for (int i = 0; i < num_cores_available; ++i) {
CoreCoord core = {i % grid_size.x, i / grid_size.x};
if (i < num_active_cores) {
core_group.push_back(core);
} else {
core_group_idle.push_back(core);
}
}
}
}
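
The reducer/worker interleaving in the sub-core-grid branch above can be hard to follow in diff form. The following is a minimal Python sketch of the same assignment order, not the real implementation; the core list and the batch/core counts are hypothetical stand-ins for `corerange_to_cores(core_grid, num_cores_available, row_wise=True)` and the values computed by the program factory:

```python
# Hypothetical 50-core sub-core grid, listed row-wise as corerange_to_cores would return it.
cores_vec = [(x, y) for y in range(10) for x in (1, 2, 3, 5, 6)]

num_cores_available = 32   # grid_size.x * grid_size.y (hypothetical)
num_cores_per_batch = 4    # cores cooperating on one batch (hypothetical)
num_output_cores = 8       # one reducer/output core per batch (hypothetical)
num_active_cores = 32      # cores that actually receive work (hypothetical)

core_group, core_group_idle = [], []
reducer_idx, worker_idx = 0, num_output_cores
for i in range(num_cores_available):
    if i % num_cores_per_batch == 0 and reducer_idx < num_output_cores:
        core = cores_vec[reducer_idx]   # reducers occupy the front of the sub-core range set
        reducer_idx += 1
    else:
        core = cores_vec[worker_idx]    # workers follow after the reducer block
        worker_idx += 1
    (core_group if i < num_active_cores else core_group_idle).append(core)
```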
4 changes: 3 additions & 1 deletion ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp
@@ -21,13 +21,15 @@ namespace py = pybind11;
void py_module(py::module& module) {
py::class_<SDPAProgramConfig>(module, "SDPAProgramConfig")
.def(
py::init<CoreCoord, std::size_t, std::size_t, std::optional<bool>>(),
py::init<CoreCoord, std::optional<CoreRangeSet>, std::size_t, std::size_t, std::optional<bool>>(),
py::kw_only(),
py::arg("compute_with_storage_grid_size"),
py::arg("sub_core_grids") = std::nullopt,
py::arg("q_chunk_size").noconvert(),
py::arg("k_chunk_size").noconvert(),
py::arg("exp_approx_mode") = std::nullopt)
.def_readwrite("compute_with_storage_grid_size", &SDPAProgramConfig::compute_with_storage_grid_size)
.def_readwrite("sub_core_grids", &SDPAProgramConfig::sub_core_grids)
.def_readwrite("q_chunk_size", &SDPAProgramConfig::q_chunk_size)
.def_readwrite("k_chunk_size", &SDPAProgramConfig::k_chunk_size)
.def_readwrite("exp_approx_mode", &SDPAProgramConfig::exp_approx_mode);
