Skip to content

Commit

Permalink
#0: fix core-range coverage (use actual CoreRangeSet instead of bounding box)
Browse files Browse the repository at this point in the history
  • Loading branch information
kpaigwar committed Jan 7, 2025
1 parent ebbdece commit 66ae1a9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,8 @@ void RotaryEmbeddingLlamaFusedQK::validate(const std::vector<Tensor>& input_tens
TT_FATAL(
q_batch_size <= 32,
"Q and K must have batch size less than or equal to 32, due to parallelization over core-grid of 64");
uint32_t q_num_cores = q_input_tensor.shard_spec()->grid.bounding_box().grid_size().x *
q_input_tensor.shard_spec()->grid.bounding_box().grid_size().y;
uint32_t k_num_cores = k_input_tensor.shard_spec()->grid.bounding_box().grid_size().x *
k_input_tensor.shard_spec()->grid.bounding_box().grid_size().y;
uint32_t q_num_cores = q_input_tensor.shard_spec()->grid.num_cores();
uint32_t k_num_cores = k_input_tensor.shard_spec()->grid.num_cores();
TT_FATAL(q_num_cores + k_num_cores <= 64, "Q and K must not exceed max core grid size of 64");

bool is_overlap = q_input_tensor.shard_spec()->grid.intersects(k_input_tensor.shard_spec()->grid);
Expand All @@ -84,8 +82,7 @@ void RotaryEmbeddingLlamaFusedQK::validate(const std::vector<Tensor>& input_tens
"sizes");

// Checks for transformation matrix
uint32_t trans_mat_num_cores = trans_mat.shard_spec()->grid.bounding_box().grid_size().x *
trans_mat.shard_spec()->grid.bounding_box().grid_size().y;
uint32_t trans_mat_num_cores = trans_mat.shard_spec()->grid.num_cores();
TT_FATAL(
trans_mat_num_cores >= (q_num_cores + k_num_cores),
"Transformation matrix is repeated for Q and K must be sharded over core grid of Q and K");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,11 @@ operation::ProgramWithCallbacks rotary_embedding_llama_fused_qk_multi_core_shard
auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] =
get_compute_kernel_config_args(device->arch(), compute_kernel_config);

CoreRange q_cores = q_shard_spec->grid.bounding_box();
uint32_t q_num_cores_x = q_cores.grid_size().x;
uint32_t q_num_cores_y = q_cores.grid_size().y;
CoreRangeSet q_cores = q_shard_spec->grid;

CoreRange k_cores = k_shard_spec->grid.bounding_box();
uint32_t k_num_cores_x = k_cores.grid_size().x;
uint32_t k_num_cores_y = k_cores.grid_size().y;
CoreRangeSet k_cores = k_shard_spec->grid;

CoreRange all_cores = cos_sin_shard_spec->grid.bounding_box();
CoreRangeSet all_cores = cos_sin_shard_spec->grid;

const uint32_t num_q_input_tiles = q_n_heads_t * head_dim_t;
const uint32_t num_q_output_tiles = num_q_input_tiles;
Expand Down

0 comments on commit 66ae1a9

Please sign in to comment.