#0: Revert bcast changes for breaking post-commit main (#9508)
* #0: Revert "#0: Fix ttl.add to tnn.add conversion for falcon40b"

This reverts commit 0d24864.

* #0: Revert "#9472: Optimize sharded bcast op"

This reverts commit 2c7f442.
tt-rkim authored Jun 18, 2024
1 parent 0d24864 commit 3fe8e14
Showing 7 changed files with 46 additions and 87 deletions.
@@ -388,13 +388,13 @@ def test_falcon7b_attnention_sliced(
ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR,
)

mm_slice = ttnn.add(
mm_slice = ttnn.experimental.operations.primary.add(
mm_slice,
attn_mask_slice,
fused_activations=None,
memory_config=height_sharded_memory_config,
output_mem_config=height_sharded_memory_config,
output_dtype=ttnn.experimental.tensor.DataType.BFLOAT16,
output_tensor=mm_slice,
in_place=True,
)

attn_mask_slice.deallocate()
24 changes: 12 additions & 12 deletions models/demos/t3000/falcon40b/tt/falcon_decoder.py
@@ -304,11 +304,11 @@ def fwd_prefill(
# Note that this is only correct in inference when dropout is disabled
for i in range(len(residual)):
output.append(
ttnn.add(
ttnn.experimental.operations.primary.add(
residual[i],
attention_output[i],
memory_config=self.model_config["PARALLEL_ATTN_ADD_OUTPUT_MEMCFG"],
output_tensor=residual[i],
output_mem_config=self.model_config["PARALLEL_ATTN_ADD_OUTPUT_MEMCFG"],
in_place=True,
)
)
attention_output[i].deallocate(True)
@@ -320,11 +320,11 @@ def fwd_prefill(
# dropout_add
# For inference, this is just add
for i in range(len(output)):
output[i] = ttnn.add(
output[i] = ttnn.experimental.operations.primary.add(
output[i],
mlp_output[i],
memory_config=self.model_config["DROPOUT_ADD_OUTPUT_MEMCFG"],
output_tensor=output[i],
output_mem_config=self.model_config["DROPOUT_ADD_OUTPUT_MEMCFG"],
in_place=True,
)

mlp_output[i].deallocate(True)
@@ -421,11 +421,11 @@ def fwd_decode(
# Note that this is only correct in inference when dropout is disabled
for i in range(len(residual)):
output.append(
ttnn.add(
ttnn.experimental.operations.primary.add(
residual[i],
attention_output[i],
memory_config=self.model_config["PARALLEL_ATTN_ADD_OUTPUT_MEMCFG"],
output_tensor=residual[i],
output_mem_config=self.model_config["PARALLEL_ATTN_ADD_OUTPUT_MEMCFG"],
in_place=True,
)
)
attention_output[i].deallocate(True)
@@ -437,11 +437,11 @@ def fwd_decode(
# dropout_add
# For inference, this is just add
for i in range(len(output)):
output[i] = ttnn.add(
output[i] = ttnn.experimental.operations.primary.add(
output[i],
mlp_output[i],
memory_config=self.model_config["DROPOUT_ADD_OUTPUT_MEMCFG"],
output_tensor=output[i],
output_mem_config=self.model_config["DROPOUT_ADD_OUTPUT_MEMCFG"],
in_place=True,
)

mlp_output[i].deallocate(True)

This file was deleted.

@@ -6,13 +6,12 @@
#include "dataflow_api.h"

void kernel_main() {
uint32_t src1_addr = get_arg_val<uint32_t>(0);
uint32_t Ht = get_arg_val<uint32_t>(1);
uint32_t Wt = get_arg_val<uint32_t>(2);
uint32_t offset = get_arg_val<uint32_t>(3);
uint32_t NC = get_arg_val<uint32_t>(4);
uint32_t batch_offset = get_arg_val<uint32_t>(5); //if weight has multiple batches
uint32_t w_blk = get_arg_val<uint32_t>(6);
uint32_t src1_addr = get_arg_val<uint32_t>(0);
uint32_t Ht = get_arg_val<uint32_t>(1);
uint32_t Wt = get_arg_val<uint32_t>(2);
uint32_t offset = get_arg_val<uint32_t>(3);
uint32_t NC = get_arg_val<uint32_t>(4);
uint32_t batch_offset= get_arg_val<uint32_t>(5); //if weight has multiple batches

//constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
constexpr bool src1_is_dram = get_compile_time_arg_val(1) == 1;
@@ -36,17 +35,26 @@ void kernel_main() {
uint32_t l1_write_addr_in0;
uint32_t l1_write_addr_in1;

// TODO: do we really need in1 NC != 1 support?! if not supported here need to validate in1 correctly!

uint32_t i = 0;
cb_push_back(cb_id_in0, Ht * Wt);
for (uint32_t wt = 0; wt < Wt; wt += w_blk) {
cb_reserve_back(cb_id_in1, w_blk);
l1_write_addr_in1 = get_write_ptr(cb_id_in1);
for (uint32_t r = 0; r<w_blk; r++) {
noc_async_read_tile(offset + wt + r, s1, l1_write_addr_in1);
l1_write_addr_in1 += tile_bytes;
for (uint32_t ht = 0; ht < Ht; ht++) {
for (uint32_t wt = 0; wt < Wt; wt++) {
// for each W-tile of the first tensor we push one tile from the second arg tile list
// but we loop the second list around
cb_reserve_back(cb_id_in1, onetile);
l1_write_addr_in1 = get_write_ptr(cb_id_in1);
noc_async_read_tile(offset, s1, l1_write_addr_in1);
noc_async_read_barrier();
cb_push_back(cb_id_in1, onetile);
offset ++;
}


// bcast tensor should be NC1W (actually NC32W padded with 0s in H)
// wrap W around for each h (broadcast)
offset -= Wt;
if(ht % NC == (NC -1)){
offset += batch_offset; //switching to next batch
}
}
noc_async_read_barrier();
cb_push_back(cb_id_in1, w_blk);
}
}
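
For orientation, the short host-side sketch below traces only the index arithmetic of the restored reader loop above (the ht/wt nest): every row-tile re-reads the same Wt broadcast tiles, and after every NC rows the read position advances by batch_offset. The values Ht=4, Wt=3, NC=2, batch_offset=8 are invented for illustration; in the kernel these arrive through get_arg_val and drive noc_async_read_tile.

// Host-side trace of the restored reader's index arithmetic (illustrative values only).
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t Ht = 4, Wt = 3, NC = 2, batch_offset = 8;  // assumed values
    uint32_t offset = 0;

    for (uint32_t ht = 0; ht < Ht; ht++) {
        for (uint32_t wt = 0; wt < Wt; wt++) {
            std::printf("row-tile %u reads bcast tile %u\n", ht, offset);
            offset++;  // one tile of the NC1W bcast tensor per W-tile of the sharded input
        }
        offset -= Wt;  // wrap W around for each row (broadcast along H)
        if (ht % NC == (NC - 1)) {
            offset += batch_offset;  // weight has multiple batches: move to the next batch
        }
    }
    return 0;
}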
@@ -106,7 +106,7 @@ operation::ProgramWithCallbacks bcast_multi_core_h(const Tensor &a, const Tensor
std::map<std::string, std::string> bcast_defines = bcast_op_utils::get_defines(BcastOpDim::H, bcast_math);
auto bcast_kernel_id = tt_metal::CreateKernel(
program,
"tt_eager/tt_dnn/op_library/bcast/kernels/compute/bcast_h_interleaved.cpp",
"tt_eager/tt_dnn/op_library/bcast/kernels/compute/bcast_h.cpp",
all_device_cores,
tt_metal::ComputeConfig{.compile_args = {}, .defines = bcast_defines}
);
@@ -82,10 +82,7 @@ operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b
.set_globally_allocated_address(*output.buffer());
auto out_cb = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config);

uint32_t h_blk = std::min(Ht, 8u);
uint32_t w_blk = std::min(Wt, 8u);

uint32_t num_input_tiles = w_blk;
uint32_t num_input_tiles = (b.get_legacy_shape()[-1] * output.element_size() + TILE_HW - 1)/ TILE_HW;
uint32_t src1_cb_index = CB::c_in1;
tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * aligned_input_tile_nbytes, {{src1_cb_index, act_df}})
.set_page_size(src1_cb_index, aligned_input_tile_nbytes);
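
As a sanity check on the restored num_input_tiles expression above, the sketch below recomputes it with made-up inputs; it is the standard integer ceiling-division idiom (x + d - 1) / d. The 32x32 tile size and 2-byte element size are assumptions for the example, not values taken from this diff.

// Recomputes the reverted num_input_tiles expression with illustrative inputs.
#include <cstdint>
#include <cstdio>

int main() {
    constexpr uint32_t TILE_HW = 32 * 32;   // assumed elements per tile
    uint32_t last_dim = 1024;               // stand-in for b.get_legacy_shape()[-1]
    uint32_t element_size = 2;              // stand-in for output.element_size() (bfloat16)

    // Integer ceiling division: (x + d - 1) / d
    uint32_t num_input_tiles = (last_dim * element_size + TILE_HW - 1) / TILE_HW;
    std::printf("num_input_tiles = %u\n", num_input_tiles);  // prints 2 for these inputs
    return 0;
}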
@@ -110,13 +107,13 @@ operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b
//const char* compute_name = bcast_op_utils::get_compute_name(BcastOpDim::H));
auto bcast_kernel_id = tt_metal::CreateKernel(
program,
"tt_eager/tt_dnn/op_library/bcast/kernels/compute/bcast_h_sharded.cpp",
"tt_eager/tt_dnn/op_library/bcast/kernels/compute/bcast_h.cpp",
all_cores,
tt_metal::ComputeConfig{.compile_args = {}, .defines = bcast_defines}
);

uint32_t ncores_y = ncores / ncores_x;
log_debug("ncores {}, ncores_x {}, Wt {}, Ht {}, h_blk {}, w_blk {}, src0_cb_index {}, src1_cb_index {}, output_cb_index {}, src1_is_dram {}, dst_is_dram {}", ncores, ncores_x, Wt, Ht, h_blk, w_blk, src0_cb_index, src1_cb_index, output_cb_index, src1_is_dram, dst_is_dram);
log_debug("ncores {}, ncores_x {}, Wt {}, Ht {}, src0_cb_index {}, src1_cb_index {}, output_cb_index {}, src1_is_dram {}, dst_is_dram {}", ncores, ncores_x, Wt, Ht, src0_cb_index, src1_cb_index, output_cb_index, src1_is_dram, dst_is_dram);
for (uint32_t i = 0; i < ncores; i++){
CoreCoord core;
uint32_t offset = 0;
@@ -152,8 +149,7 @@ operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b
Wt, // 2
offset, // 3
Ht_per_core, // 4
tile_offset, // 5
w_blk, // 6
tile_offset, //5
}
);

@@ -165,7 +161,6 @@ operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b
NC, // B
Ht, // Ht per batch for block sharded
Wt, // Wt
h_blk, // h block size
}
);
}
@@ -226,9 +221,6 @@ operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b
}
uint32_t tile_offset = Wt * ncores;

uint32_t h_blk = std::min(Ht, 8u);
uint32_t w_blk = std::min(Wt, 8u);

tt_metal::SetRuntimeArgs(
program,
binary_reader_kernel_id,
@@ -240,7 +232,6 @@ operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b
offset, // 3
Ht_per_core, // 4
tile_offset, //5
w_blk, // 6
}
);

@@ -252,7 +243,6 @@ operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b
NC, // B
Ht, // Ht
Wt, // Wt
h_blk, // h block size
}
);
}
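
To tie the two sides of this revert together: the restored reader kernel consumes six runtime arguments, read with get_arg_val<uint32_t>(0) through (5), while the host-side SetRuntimeArgs comments in this file label slots 3-5 as offset, Ht_per_core and tile_offset. The sketch below lists the slots in kernel order with invented values; it illustrates the layout only and is not code from the repository.

// Assumed/illustrative runtime-argument layout for the restored reader kernel.
#include <cstdint>
#include <vector>

int main() {
    std::vector<uint32_t> reader_rt_args = {
        0x10000u,  // 0: src1_addr    -- bcast tensor base address (hypothetical)
        8u,        // 1: Ht
        4u,        // 2: Wt
        0u,        // 3: offset
        2u,        // 4: NC           -- the host-side call fills this slot with Ht_per_core
        4u,        // 5: batch_offset -- the host-side call fills this slot with tile_offset
    };
    (void)reader_rt_args;  // six slots, matching the kernel's get_arg_val reads
    return 0;
}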
