diff --git a/tt_eager/tt_dnn/op_library/softmax/multi_core/softmax_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/softmax/multi_core/softmax_op_multi_core.cpp
index 295ecb9e6c0..7aebae2f470 100644
--- a/tt_eager/tt_dnn/op_library/softmax/multi_core/softmax_op_multi_core.cpp
+++ b/tt_eager/tt_dnn/op_library/softmax/multi_core/softmax_op_multi_core.cpp
@@ -455,7 +455,7 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
     tt::DataFormat out0_cb_data_format = tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype());
     tt::DataFormat im_cb_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b;
-    tt::DataFormat mask_cb_data_format = mask.has_value() ? tt_metal::datatype_to_dataformat_converter(mask.value().get_dtype()) : tt::DataFormat::Float16_b;
+    tt::DataFormat mask_cb_data_format = mask.has_value() ? tt_metal::datatype_to_dataformat_converter(mask->get_dtype()) : tt::DataFormat::Float16_b;
     tt::DataFormat scale_cb_data_format = tt::DataFormat::Float16_b;
     tt::DataFormat scalar_cb_data_format = tt::DataFormat::Float16_b;
@@ -480,7 +480,7 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
     uint32_t mask_H = shape[2];
     if (mask.has_value()) {
-        mask_H = mask.value().get_legacy_shape()[2];
+        mask_H = mask->get_legacy_shape()[2];
     }
     uint32_t mask_Ht = mask_H/TILE_HEIGHT;
     // block
@@ -552,7 +552,7 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
     // reader compile arg
     bool is_dram_mask = 0;
     if (mask.has_value()) {
-        is_dram_mask = mask.value().buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
+        is_dram_mask = mask->buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
     }
     std::vector<uint32_t> reader_compile_time_args = {
         (std::uint32_t) block_wt,
@@ -560,9 +560,9 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
     };
     std::map<string, string> softmax_defines;
     // hw_dims_only_causal_mask does not support RM Layout atm
-    bool use_row_major_kernel = (mask.has_value() and mask.value().get_layout() == Layout::ROW_MAJOR);
+    bool use_row_major_kernel = (mask.has_value() and mask->get_layout() == Layout::ROW_MAJOR);
     if (use_row_major_kernel) {
-        auto mask_stick_size = mask.value().get_legacy_shape()[3] * mask.value().element_size();
+        auto mask_stick_size = mask->get_legacy_shape()[3] * mask->element_size();
         bool mask_stick_size_is_power_of_two = is_power_of_two_at_least_32(mask_stick_size);
         reader_compile_time_args.push_back((std::uint32_t) mask_stick_size_is_power_of_two);
         if (mask_stick_size_is_power_of_two) {
@@ -648,8 +648,8 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
             .set_page_size(CB::c_in2, scale_tile_size);
         cb_in2_id = CreateCircularBuffer(program, all_device_cores, c_in2_config);
         // in3 attn mask
-        if (mask.value().is_sharded()) {
-            auto mask_buffer = mask.value().buffer();
+        if (mask->is_sharded()) {
+            auto mask_buffer = mask->buffer();
             auto c_in3_config = CircularBufferConfig(in3_CB_size, {{CB::c_in3, mask_cb_data_format}})
                 .set_page_size(CB::c_in3, mask_tile_size).set_globally_allocated_address(*mask_buffer);
             cb_in3_id = CreateCircularBuffer( program, all_device_cores, c_in3_config);
@@ -673,7 +673,7 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
     auto cb_intermed1_id = CreateCircularBuffer( program, all_device_cores, c_intermed1_config );

     // Runtime Args
-    uint32_t mask_addr = mask.has_value() ? mask.value().buffer()->address() : 0;
+    uint32_t mask_addr = mask.has_value() ? mask->buffer()->address() : 0;
     union { float f; uint32_t u; } s; s.f = scale.value_or(1.0f); // scale for fused scale-mask-softmax
     uint32_t mask_start_tile_id = 0;
@@ -712,9 +712,9 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
             num_cores_per_batch_index = 0;
             if (mask.has_value()) {
                 if (causal_mask) {
-                    mask_start_tile_id += mask.value().get_legacy_shape()[-1] * mask.value().get_legacy_shape()[-2] / TILE_WIDTH / TILE_HEIGHT;
+                    mask_start_tile_id += mask->get_legacy_shape()[-1] * mask->get_legacy_shape()[-2] / TILE_WIDTH / TILE_HEIGHT;
                 } else {
-                    mask_start_tile_id += use_row_major_kernel ? mask.value().get_legacy_shape()[-2] : mask.value().get_legacy_shape()[-1] / TILE_WIDTH;
+                    mask_start_tile_id += use_row_major_kernel ? mask->get_legacy_shape()[-2] : mask->get_legacy_shape()[-1] / TILE_WIDTH;
                 }
             }
         }
@@ -748,9 +748,9 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
             num_cores_per_batch_index = 0;
             if (mask.has_value()) {
                 if (causal_mask) {
-                    mask_start_tile_id += mask.value().get_legacy_shape()[-1] * mask.value().get_legacy_shape()[-2] / TILE_WIDTH / TILE_HEIGHT;
+                    mask_start_tile_id += mask->get_legacy_shape()[-1] * mask->get_legacy_shape()[-2] / TILE_WIDTH / TILE_HEIGHT;
                 } else {
-                    mask_start_tile_id += use_row_major_kernel ? mask.value().get_legacy_shape()[-2] : mask.value().get_legacy_shape()[-1] / TILE_WIDTH;
+                    mask_start_tile_id += use_row_major_kernel ? mask->get_legacy_shape()[-2] : mask->get_legacy_shape()[-1] / TILE_WIDTH;
                 }
             }
         }
@@ -763,6 +763,7 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
         reader_kernels_id,
         cb_in0_id,
         cb_out0_id,
+        cb_in3_id,
         num_cores,
         grid_size
     ]
@@ -779,12 +780,15 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
         UpdateDynamicCircularBufferAddress(program, cb_in0_id, *in0_buffer);
         UpdateDynamicCircularBufferAddress(program, cb_out0_id, *out_buffer);

+        if (mask_tensor.has_value() && mask_tensor->is_sharded()) {
+            UpdateDynamicCircularBufferAddress(program, cb_in3_id.value(), *mask_tensor->buffer());
+        }
         if (mask_tensor.has_value()) {
             for (uint32_t i = 0; i < num_cores; ++i) {
                 CoreCoord core = {i % grid_size.x, i / grid_size.x};
                 auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core);
-                runtime_args[2] = mask_tensor.value().buffer()->address();
+                runtime_args[2] = mask_tensor->buffer()->address();
             }
         }
     };
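
Beyond the cosmetic `mask.value()` → `mask->` cleanup, the functional change is in the last two hunks: the sharded mask's circular buffer is created with `.set_globally_allocated_address(*mask_buffer)`, so when a cached program is re-run with a mask tensor living at a different address, the override callback must re-point that CB via `UpdateDynamicCircularBufferAddress` (hence `cb_in3_id` joining the lambda capture list). The sketch below is a minimal, self-contained model of this program-cache pattern; `Program`, `Buffer`, `CBHandle`, and `update_dynamic_cb_address` are illustrative stand-ins, not the tt-metal API.

```cpp
// Simplified model of the override-runtime-arguments pattern used above.
// All types and names here are hypothetical stand-ins for the tt-metal API.
#include <cstdint>
#include <iostream>
#include <optional>

struct Buffer { std::uint32_t address; };
using CBHandle = std::uint32_t;

struct Program {
    // cb handle -> currently bound buffer address
    std::uint32_t cb_addresses[4] = {};
};

// Stand-in for UpdateDynamicCircularBufferAddress: re-point a CB at a buffer.
void update_dynamic_cb_address(Program& program, CBHandle cb, const Buffer& buf) {
    program.cb_addresses[cb] = buf.address;
}

int main() {
    Program program;
    CBHandle cb_in0 = 0, cb_out0 = 1;
    // The mask CB only exists if a sharded mask was present at creation time.
    std::optional<CBHandle> cb_in3 = 2;

    // The callback captures the CB handles (including the optional mask CB)
    // so a cache hit can re-bind the cached program to the new buffers.
    auto override_callback = [cb_in0, cb_out0, cb_in3](
            Program& prog, const Buffer& in0, const Buffer& out0,
            const std::optional<Buffer>& mask) {
        update_dynamic_cb_address(prog, cb_in0, in0);
        update_dynamic_cb_address(prog, cb_out0, out0);
        // Without this step, a cache hit would keep the stale mask address.
        if (mask.has_value() && cb_in3.has_value()) {
            update_dynamic_cb_address(prog, cb_in3.value(), mask.value());
        }
    };

    Buffer in0{0x1000}, out0{0x2000}, mask{0x3000};
    override_callback(program, in0, out0, mask);
    std::cout << std::hex << program.cb_addresses[2] << "\n";  // prints 3000
    return 0;
}
```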