#5592: Add sharded softmax support for attn mask of size (1, 1, seq_len, seq_len). Provide support for bcast_hw op in-place, in0 and output height sharding. Height-shard the attention sequence for falcon7b prefill.
ppopovic authored and pavlepopovic committed Apr 10, 2024
1 parent c419751 commit 702adca
Showing 18 changed files with 1,216 additions and 314 deletions.
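The user-facing change is easiest to see in the new test below: ttl.operations.primary.bcast gains an in_place flag, and the HW-broadcast path now accepts a height-sharded in0 and output. A minimal sketch of the new call pattern, using the same ttl bindings and helpers as the test; in0_dram and scalar_dram are hypothetical stand-ins for DRAM-interleaved device tensors:

height_sharded_memory_config = ttl.tensor.MemoryConfig(
    memory_layout=ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
    buffer_type=ttl.tensor.BufferType.L1,
)

# Shard a [1, 1, num_cores * 32, 128] activation height-wise, 32x128 per core.
in0_sharded = ttl.tensor.interleaved_to_sharded(
    in0_dram,
    device.compute_with_storage_grid_size(),
    [32, 128],
    ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
    ttl.tensor.ShardOrientation.ROW_MAJOR,
)

# New: HW broadcast-multiply against a [1, 1, 32, 32] scalar tile, computed
# in place on the sharded input; the op hands back the input tensor itself.
out = ttl.operations.primary.bcast(
    in0_sharded,
    scalar_dram,
    ttl.tensor.BcastOpMath.MUL,
    ttl.tensor.BcastOpDim.HW,
    output_mem_config=height_sharded_memory_config,
    in_place=True,
)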


@@ -572,6 +572,95 @@ def test_block_sharded_partial_op(
    assert passing


@pytest.mark.parametrize("num_cores", [64, 1], ids=["multi_core", "single_core"])
@pytest.mark.parametrize("in0_height_sharded", [True, False], ids=["in0_height_sharded", "in0_dram_interleaved"])
@pytest.mark.parametrize("out_height_sharded", [True, False], ids=["out_height_sharded", "out_dram_interleaved"])
@pytest.mark.parametrize("in_place", [True, False], ids=["in_place", "not_in_place"])
def test_bcast_hw(device, num_cores, in0_height_sharded, out_height_sharded, in_place):
    compute_grid_size = device.compute_with_storage_grid_size()
    if num_cores > (compute_grid_size.x * compute_grid_size.y):
        pytest.skip(f"Need {num_cores} cores to run this test but core grid is {compute_grid_size}")

    if in0_height_sharded != out_height_sharded:
        pytest.skip("Currently the bcast HW op supports sharding only when both input and output are sharded")

    scalar_shape = [1, 1, 32, 32]
    in0_shape = [1, 1, num_cores * 32, 128]
    height_shard_spec = [32, 128]

    torch_scalar = torch.randn(scalar_shape).bfloat16().float()
    torch_in0 = torch.randn(in0_shape).bfloat16().float()

    dram_interleaved_memory_config = ttl.tensor.MemoryConfig(
        memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED,
        buffer_type=ttl.tensor.BufferType.DRAM,
    )

    height_sharded_memory_config = ttl.tensor.MemoryConfig(
        memory_layout=ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, buffer_type=ttl.tensor.BufferType.L1
    )

    tt_scalar_dram = torch2tt_tensor(
        torch_scalar, device, tt_memory_config=dram_interleaved_memory_config, tt_dtype=ttl.tensor.DataType.BFLOAT16
    )

    tt_in0_dram = torch2tt_tensor(
        torch_in0, device, tt_memory_config=dram_interleaved_memory_config, tt_dtype=ttl.tensor.DataType.BFLOAT16
    )

    if out_height_sharded:
        out_mem_config = height_sharded_memory_config
    else:
        out_mem_config = dram_interleaved_memory_config

    if in0_height_sharded:
        tt_in0_height_sharded = ttl.tensor.interleaved_to_sharded(
            tt_in0_dram,
            device.compute_with_storage_grid_size(),
            height_shard_spec,
            ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
            ttl.tensor.ShardOrientation.ROW_MAJOR,
        )
        tt_out = ttl.operations.primary.bcast(
            tt_in0_height_sharded,
            tt_scalar_dram,
            ttl.tensor.BcastOpMath.MUL,
            ttl.tensor.BcastOpDim.HW,
            output_mem_config=out_mem_config,
            in_place=in_place,
        )
        tt_in0_height_sharded.deallocate()
    else:
        tt_out = ttl.operations.primary.bcast(
            tt_in0_dram,
            tt_scalar_dram,
            ttl.tensor.BcastOpMath.MUL,
            ttl.tensor.BcastOpDim.HW,
            output_mem_config=out_mem_config,
            in_place=in_place,
        )

    if out_height_sharded:
        tt_out = ttl.tensor.sharded_to_interleaved(tt_out, output_mem_config=dram_interleaved_memory_config)

    # Reference run with both input and output DRAM-interleaved
    tt_out_ref = ttl.operations.primary.bcast(
        tt_in0_dram,
        tt_scalar_dram,
        ttl.tensor.BcastOpMath.MUL,
        ttl.tensor.BcastOpDim.HW,
        output_mem_config=dram_interleaved_memory_config,
        in_place=in_place,
    )

    tt_out_torch = tt2torch_tensor(tt_out)
    tt_ref_torch = tt2torch_tensor(tt_out_ref)

    passing, output = comp_pcc(tt_out_torch, tt_ref_torch)
    logger.info(output)
    assert passing


@pytest.mark.parametrize("H, W, num_cores, num_slices", [[4 * 32, 32 * 32, 64, 2]])
@pytest.mark.parametrize(
    "activations_dtype",
42 changes: 38 additions & 4 deletions tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp
@@ -3,6 +3,8 @@
// SPDX-License-Identifier: Apache-2.0

#include "tt_dnn/op_library/bcast/bcast_op.hpp"
#include "common/assert.hpp"
#include "impl/buffers/buffer.hpp"
#include "tt_metal/tools/profiler/op_profiler.hpp"

#include "tensor/tensor.hpp"
@@ -89,7 +91,24 @@ void EltwiseBinaryBroadcast::validate(const std::vector<Tensor> &input_tensors)
    TT_FATAL(input_tensor_b.get_layout() == Layout::TILE);
    TT_FATAL(input_tensor_a.get_dtype() == input_tensor_b.get_dtype());
    TT_FATAL(input_tensor_a.get_dtype() == tt::tt_metal::DataType::BFLOAT16 || input_tensor_a.get_dtype() == tt::tt_metal::DataType::BFLOAT8_B, "Unsupported data format");
    TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED && input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED && this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Bcast does not currently support sharding");
    if (this->in_place) {
        TT_FATAL(input_tensor_a.memory_config().memory_layout == this->output_mem_config.memory_layout);
        TT_FATAL(input_tensor_a.memory_config().buffer_type == this->output_mem_config.buffer_type);
    }
    if (this->dim != BcastOpDim::HW) {
        TT_FATAL(
            input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED &&
            this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED,
            "Bcast does not currently support input0 sharding, except if dim is HW");
    } else {
        TT_FATAL(
            input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED ||
            input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED,
            "HW bcast in0 supports Height Sharding or Interleaving");
        TT_FATAL(
            input_tensor_a.memory_config().memory_layout == this->output_mem_config.memory_layout,
            "Input and output mem layouts must be the same for bcast HW op!");
    }

    auto batch_size_a = input_shape_a[0];
    auto num_channels_a = input_shape_a[1];
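The tightened validation is observable from Python: an HW bcast whose in0 is height-sharded but whose output is DRAM-interleaved now fails with "Input and output mem layouts must be the same for bcast HW op!". A hypothetical negative test, assuming TT_FATAL surfaces as a RuntimeError through the bindings and reusing the tensors from the sketch near the top:

import pytest

with pytest.raises(RuntimeError):
    ttl.operations.primary.bcast(
        in0_sharded,  # HEIGHT_SHARDED in L1
        scalar_dram,  # DRAM-interleaved scalar tile
        ttl.tensor.BcastOpMath.MUL,
        ttl.tensor.BcastOpDim.HW,
        output_mem_config=dram_interleaved_memory_config,  # mismatched layout
    )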
@@ -119,14 +138,28 @@ std::vector<Shape> EltwiseBinaryBroadcast::compute_output_shapes(const std::vector<Tensor> &input_tensors) const


std::vector<Tensor> EltwiseBinaryBroadcast::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
    if (this->in_place) {
        return {};
    }
    const auto& input_tensor = input_tensors.at(0);
    return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config);
    if (this->output_mem_config.is_sharded()) {
        ShardSpec shard_spec{CoreRangeSet({}), {0, 0}};
        if (input_tensor.memory_config().is_sharded()) {
            // Derive output shard_spec based on input
            shard_spec = input_tensor.shard_spec().value();
        }
        auto mem_config = this->output_mem_config;
        mem_config.shard_spec = shard_spec;
        return {create_sharded_device_tensor(this->compute_output_shapes(input_tensors).at(0), input_tensor.get_dtype(), Layout::TILE, input_tensor.device(), mem_config)};
    } else {
        return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config);
    }
}

operation::ProgramWithCallbacks EltwiseBinaryBroadcast::create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const {
    const auto& input_tensor_a = input_tensors.at(0);
    const auto& input_tensor_b = input_tensors.at(1);
    auto& output_tensor = output_tensors.at(0);
    const auto& output_tensor = this->in_place ? input_tensor_a : output_tensors.at(0);

    auto parallelization_strategy = this->get_parallelization_strategy(input_tensors);

@@ -154,7 +187,8 @@ const operation::Hash EltwiseBinaryBroadcast::compute_program_hash(
        input_tensors.at(0).get_dtype(),
        input_tensors.at(1).memory_config(),
        input_tensors.at(1).get_dtype(),
        bcast_scalar);
        bcast_scalar,
        this->in_place);
}
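Hashing in_place alongside the other attributes keeps the program cache sound: two calls that differ only in this flag now map to distinct cache entries instead of reusing a program compiled for the other mode. A hypothetical back-to-back pair that must compile separately, with in0_sharded and scalar_dram as in the sketch near the top:

out_copy = ttl.operations.primary.bcast(
    in0_sharded, scalar_dram, ttl.tensor.BcastOpMath.MUL, ttl.tensor.BcastOpDim.HW,
    in_place=False,  # allocates a fresh output tensor
)
out_alias = ttl.operations.primary.bcast(
    in0_sharded, scalar_dram, ttl.tensor.BcastOpMath.MUL, ttl.tensor.BcastOpDim.HW,
    in_place=True,  # same shapes and configs, but a distinct program hash
)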

BcastOpParallelizationStrategy EltwiseBinaryBroadcast::get_parallelization_strategy(const std::vector<Tensor> &input_tensors) const {
28 changes: 17 additions & 11 deletions tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
@@ -7,7 +7,6 @@
#include "tensor/tensor.hpp"
#include "tt_dnn/op_library/run_operation.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"

using namespace tt::tt_metal;

@@ -25,32 +24,33 @@ enum class BcastOpParallelizationStrategy { MULTI_CORE_H = 0, MULTI_CORE_W = 1,
operation::ProgramWithCallbacks bcast_single_core(
    const Tensor &input_tensor_a,
    const Tensor &input_tensor_b,
    Tensor &output_tensor,
    const Tensor &output_tensor,
    BcastOpMath bcast_op,
    BcastOpDim bcast_dim);
operation::ProgramWithCallbacks bcast_multi_core_h(
    const Tensor &input_tensor_a,
    const Tensor &input_tensor_b,
    Tensor &output_tensor,
    const Tensor &output_tensor,
    BcastOpMath bcast_op,
    BcastOpDim bcast_dim);
operation::ProgramWithCallbacks bcast_multi_core_w(
    const Tensor &input_tensor_a,
    const Tensor &input_tensor_b,
    Tensor &output_tensor,
    const Tensor &output_tensor,
    BcastOpMath bcast_op,
    BcastOpDim bcast_dim);
operation::ProgramWithCallbacks bcast_multi_core_hw(
    const Tensor &input_tensor_a,
    const Tensor &input_tensor_b,
    Tensor &output_tensor,
    const Tensor &output_tensor,
    BcastOpMath bcast_op,
    BcastOpDim bcast_dim);

struct EltwiseBinaryBroadcast {
    const BcastOpMath math_op;
    const BcastOpDim dim;
    const MemoryConfig output_mem_config;
    const bool in_place;

    void validate(const std::vector<Tensor> &input_tensors) const;
    std::vector<Shape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
@@ -59,9 +59,10 @@ struct EltwiseBinaryBroadcast {
        const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
    BcastOpParallelizationStrategy get_parallelization_strategy(const std::vector<Tensor> &input_tensors) const;

    static constexpr auto attribute_names = std::make_tuple("math_op", "dim", "output_mem_config");
    static constexpr auto attribute_names =
        std::make_tuple("math_op", "dim", "output_mem_config", "in_place");
    const auto attribute_values() const {
        return std::make_tuple(std::cref(this->math_op), std::cref(this->dim), std::cref(this->output_mem_config));
        return std::make_tuple(std::cref(this->math_op), std::cref(this->dim), std::cref(this->output_mem_config), std::cref(this->in_place));
    }

    const operation::Hash compute_program_hash(const std::vector<Tensor> &input_tensors) const;
@@ -107,7 +108,7 @@ inline Tensor bcast(
        }
    }
    return operation::run_with_autoformat(
        EltwiseBinaryBroadcast{bcast_op, bcast_dim, output_mem_config}, {input_tensor_a, input_tensor_b})
        EltwiseBinaryBroadcast{bcast_op, bcast_dim, output_mem_config, false}, {input_tensor_a, input_tensor_b})
        .at(0);
}

@@ -122,9 +123,14 @@ inline Tensor bcast(
    const Tensor &input_tensor_b,
    BcastOpMath bcast_op,
    BcastOpDim bcast_dim,
    const MemoryConfig &mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
    return operation::run(EltwiseBinaryBroadcast{bcast_op, bcast_dim, mem_config}, {input_tensor_a, input_tensor_b})
        .at(0);
    const MemoryConfig &mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
    bool in_place = false) {
    vector<Tensor> output = operation::run(EltwiseBinaryBroadcast{bcast_op, bcast_dim, mem_config, in_place}, {input_tensor_a, input_tensor_b});
    if (in_place) {
        return input_tensor_a;
    } else {
        return output.at(0);
    }
}

} // namespace primary
@@ -16,7 +16,10 @@ void kernel_main() {
    uint32_t Wt = get_arg_val<uint32_t>(11);
    uint32_t nc1 = get_arg_val<uint32_t>(12); // if 1 we expect the bcast tensor to have NC=1 and wrap around in NC

#ifndef IN0_SHARDED
    constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
#endif

    constexpr bool src1_is_dram = get_compile_time_arg_val(1) == 1;

    constexpr uint32_t cb_id_in0 = 0;
@@ -33,14 +36,19 @@ void kernel_main() {
    uint32_t l1_write_addr_in1;

    uint32_t num_tiles = src0_num_tiles;
    uint32_t i = 0;
    uint32_t i1 = 0;

#ifndef IN0_SHARDED
    uint32_t i = 0;
    const InterleavedAddrGenFast<src0_is_dram> s0 = {
        .bank_base_address = src0_addr,
        .page_size = in0_tile_bytes,
        .data_format = in0_data_format
    };
#else
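    // in0 is height-sharded: its tiles already sit in this core's L1 and the
    // circular buffer is backed by the shard, so no NoC reads are needed.
    // Reserving and pushing all tiles at once simply exposes the whole shard
    // to the compute kernel.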
    cb_reserve_back(cb_id_in0, num_tiles);
    cb_push_back(cb_id_in0, num_tiles);
#endif

    const InterleavedAddrGenFast<src1_is_dram> s1 = {
        .bank_base_address = src1_addr,
@@ -59,12 +67,14 @@
    for (uint32_t nc = 0; nc < NC; nc++) {
        for (uint32_t ht = 0; ht < Ht; ht++) {
            for (uint32_t wt = 0; wt < Wt; wt++) {
#ifndef IN0_SHARDED
                cb_reserve_back(cb_id_in0, onetile);
                l1_write_addr_in0 = get_write_ptr(cb_id_in0);
                noc_async_read_tile(i, s0, l1_write_addr_in0);
                noc_async_read_barrier();
                cb_push_back(cb_id_in0, onetile);
                i++; // input tile iterates over NC Ht Wt
#endif

#ifndef BCAST_SCALAR
                // for each H,W-tile of the first tensor we push one tile from the second arg tile list
@@ -14,8 +14,10 @@ void kernel_main() {
    uint32_t curr_id_from_base = get_arg_val<uint32_t>(5);
    uint32_t bcast_id = get_arg_val<uint32_t>(6);

#ifndef IN0_SHARDED
    constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
#endif

    constexpr bool src1_is_dram = get_compile_time_arg_val(1) == 1;

    constexpr uint32_t cb_id_in0 = 0;
@@ -31,11 +33,16 @@ void kernel_main() {
    uint32_t l1_write_addr_in0;
    uint32_t l1_write_addr_in1;

#ifndef IN0_SHARDED
    const InterleavedAddrGenFast<src0_is_dram> s0 = {
        .bank_base_address = src0_addr,
        .page_size = in0_tile_bytes,
        .data_format = in0_data_format
    };
#else
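    // Sharded in0: as in the reader above, the shard is already resident in
    // L1, so just make every tile visible to the compute kernel up front.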
    cb_reserve_back(cb_id_in0, num_tiles);
    cb_push_back(cb_id_in0, num_tiles);
#endif

    const InterleavedAddrGenFast<src1_is_dram> s1 = {
        .bank_base_address = src1_addr,
@@ -53,11 +60,15 @@

    for (uint32_t i = 0; i < num_tiles; i++) {
        uint32_t curr_id = base_start_id_HtWt + curr_id_from_base;

#ifndef IN0_SHARDED
        cb_reserve_back(cb_id_in0, onetile);
        l1_write_addr_in0 = get_write_ptr(cb_id_in0);
        noc_async_read_tile(curr_id, s0, l1_write_addr_in0);
        noc_async_read_barrier();
        cb_push_back(cb_id_in0, onetile);
#endif

        curr_id_from_base++;

#ifndef BCAST_SCALAR
@@ -19,7 +19,7 @@ namespace tt {

namespace tt_metal {

operation::ProgramWithCallbacks bcast_multi_core_h(const Tensor &a, const Tensor &b, Tensor& output, BcastOpMath bcast_math, BcastOpDim bcast_dim) {
operation::ProgramWithCallbacks bcast_multi_core_h(const Tensor &a, const Tensor &b, const Tensor& output, BcastOpMath bcast_math, BcastOpDim bcast_dim) {
    TT_ASSERT(bcast_dim == BcastOpDim::H);

    const auto ashape = a.get_legacy_shape();