From 20f82a746cb0306b51cab7697d2a3e7009d56838 Mon Sep 17 00:00:00 2001
From: yugaoT
Date: Wed, 26 Jun 2024 21:21:02 +0000
Subject: [PATCH] #0: add pybinds for softmax apis, change python tests to ttnn

---
 docs/source/ttnn/ttnn/dependencies/tt_lib.rst | 6 +-
 .../bert/tt/ttnn_optimized_sharded_bert.py | 2 +-
 ...n_matmuls_and_bmms_with_mixed_precision.py | 14 +-
 models/demos/falcon7b/tt/falcon_attention.py | 14 +-
 models/demos/falcon7b/tt/model_config.py | 2 +-
 models/demos/metal_BERT_large_11/tt/mha.py | 6 +-
 .../metal_BERT_large_11/tt/model_config.py | 4 +-
 .../t3000/falcon40b/tt/falcon_attention.py | 2 +-
 .../demos/t3000/falcon40b/tt/model_config.py | 8 +-
 .../t3000/falcon40b/tt/ops/falcon_softmax.py | 8 +-
 .../llama2_70b/scripts/model_config_n150.py | 4 +-
 .../tests/unit_tests/test_attn_sdpa.py | 4 +-
 .../tt/llama_attention_optimized.py | 2 +-
 .../demos/t3000/llama2_70b/tt/model_config.py | 4 +-
 .../t3000/mixtral8x7b/tt/mixtral_attention.py | 2 +-
 .../t3000/mixtral8x7b/tt/model_config.py | 2 +-
 .../tt2/ttnn_functional_cross_attention.py | 10 +-
 .../test_bert_large_fused_softmax.py | 4 +-
 .../tt/ttnn_optimized_sharded_vit.py | 2 +-
 .../tt/ttnn_optimized_sharded_vit_backup.py | 2 +-
 .../experimental/llama/tt/llama_attention.py | 2 +-
 .../llama2_70b/scripts/model_config_n150.py | 4 +-
 .../tests/unit_tests/test_attn_sdpa.py | 4 +-
 .../tt/llama_attention_optimized.py | 2 +-
 .../llama2_70b/tt/model_config.py | 8 +-
 .../mistral/tt/mistral_attention.py | 2 +-
 .../nanogpt/tt/nanogpt_attention.py | 4 +-
 models/experimental/t5/tt/t5_attention.py | 2 +-
 tests/tt_eager/profiling/ops_for_profiling.py | 8 +-
 .../pytests/tt_dnn/test_softmax_sharded.py | 8 +-
 .../sweep_tests/tt_lib_ops.py | 4 +-
 .../misc/test_single_core_fused_ops.py | 4 +-
 .../unit_testing/misc/test_softmax.py | 144 ++--
 .../unit_testing/misc/test_softmax_sharded.py | 218 ++---
 .../test_sharded_attention.py | 30 +-
 ttnn/CMakeLists.txt | 2 +
 ttnn/cpp/pybind11/operations/__init__.hpp | 2 +-
 .../cpp/pybind11/operations/normalization.hpp | 101 ---
 ttnn/cpp/ttnn/operations/normalization.hpp | 246 ------
 .../normalization/groupnorm/groupnorm.hpp | 92 ++
 .../groupnorm/groupnorm_pybind.hpp | 42 +
 .../normalization/layernorm/layernorm.hpp | 130 +++
 .../layernorm/layernorm_pybind.hpp | 52 ++
 .../normalization/normalization_pybind.hpp | 31 +
 .../device/kernels/compute/softmax.cpp | 223 +++++
 .../kernels/compute/softmax_sharded.cpp | 177 ++++
 ...d_unary_sharded_sm_causal_mask_hw_dims.cpp | 62 ++
 .../dataflow/reader_unary_interleaved_sm.cpp | 142 ++++
 .../dataflow/reader_unary_sharded_sm.cpp | 88 ++
 .../reader_unary_sharded_sm_rm_mask.cpp | 65 ++
 ..._unary_interleaved_start_id_blocked_sm.cpp | 89 ++
 .../multi_core/softmax_op_multi_core.cpp | 796 ++++++++++++++++++
 .../softmax/device/softmax_op.cpp | 219 +++++
 .../softmax/device/softmax_op.hpp | 131 +++
 .../normalization/softmax/softmax.hpp | 167 ++++
 .../normalization/softmax/softmax_pybind.hpp | 247 ++++++
 ttnn/ttnn/__init__.py | 7 +
 ttnn/ttnn/experimental/golden_functions.py | 2 +-
 ttnn/ttnn/operations/normalization.py | 21 +
 59 files changed, 3017 insertions(+), 663 deletions(-)
 delete mode 100644 ttnn/cpp/pybind11/operations/normalization.hpp
 delete mode 100644 ttnn/cpp/ttnn/operations/normalization.hpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/groupnorm/groupnorm.hpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/groupnorm/groupnorm_pybind.hpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/layernorm/layernorm.hpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/layernorm/layernorm_pybind.hpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/normalization_pybind.hpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/readed_unary_sharded_sm_causal_mask_hw_dims.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/reader_unary_interleaved_sm.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/reader_unary_sharded_sm.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/reader_unary_sharded_sm_rm_mask.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/writer_unary_interleaved_start_id_blocked_sm.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/device/multi_core/softmax_op_multi_core.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/softmax.hpp
 create mode 100644 ttnn/cpp/ttnn/operations/normalization/softmax/softmax_pybind.hpp

diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
index 1891974ed9d2..3b705b3c89a7 100644
--- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
+++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -255,8 +255,6 @@ autofunction:: tt_lib.operations.primary.matmul
 
 .. autofunction:: tt_lib.operations.primary.add_layernorm
 
-.. autofunction:: tt_lib.operations.primary.softmax_in_place
-
 .. autofunction:: tt_lib.operations.primary.moreh_softmax
 
 .. autofunction:: tt_lib.operations.primary.moreh_softmax_backward
@@ -269,8 +267,6 @@ autofunction:: tt_lib.operations.primary.matmul
 
 .. autofunction:: tt_lib.operations.primary.moreh_logsoftmax_backward
 
-.. autofunction:: tt_lib.operations.primary.transformers.scale_mask_softmax_in_place
-
 .. autofunction:: tt_lib.operations.primary.moreh_mean
 
 .. autofunction:: tt_lib.operations.primary.moreh_mean_backward
@@ -414,7 +410,7 @@ Tensor elementwise operations
 .. autofunction:: tt_lib.tensor.unary_remainder
 
 .. autofunction:: tt_lib.tensor.remainder
-
+
 .. autofunction:: tt_lib.tensor.unary_fmod
 
 ..
autofunction:: tt_lib.tensor.fmod diff --git a/models/demos/bert/tt/ttnn_optimized_sharded_bert.py b/models/demos/bert/tt/ttnn_optimized_sharded_bert.py index 560deab00439..3ed897baa567 100644 --- a/models/demos/bert/tt/ttnn_optimized_sharded_bert.py +++ b/models/demos/bert/tt/ttnn_optimized_sharded_bert.py @@ -83,7 +83,7 @@ def update_model_config(config, batch_size): block_w=4, inplace=True, ), - "softmax_program_config": ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + "softmax_program_config": ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(core_grid.x, core_grid.y), subblock_w=6, block_h=24, diff --git a/models/demos/falcon7b/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py b/models/demos/falcon7b/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py index 0e36ca763d8e..50d6f963f64b 100644 --- a/models/demos/falcon7b/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py +++ b/models/demos/falcon7b/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py @@ -402,7 +402,7 @@ def test_falcon7b_attnention_sliced( subblock_w = 1 if seq_len == 2048: subblock_w = 8 - softmax_program_config = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, subblock_w=subblock_w, block_h=mm_output_height_shard_spec[0] // 32, @@ -669,14 +669,14 @@ def test_falcon7b_attention_softmax_sequence( subblock_w = 1 if seq_len == 2048: subblock_w = 8 - softmax_program_config = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, subblock_w=subblock_w, block_h=mm_output_height_shard_spec[0] // 32, block_w=mm_output_height_shard_spec[1] // 32, ) - mm_slice = ttnn.experimental.operations.primary.transformers.scale_causal_mask_hw_dims_softmax_in_place( + mm_slice = ttnn.scale_causal_mask_hw_dims_softmax_in_place( mm_slice, scalar_value, attention_masks_per_slice[i], @@ -725,11 +725,11 @@ def test_falcon7b_attention_softmax_sequence( reference_query_layer, reference_key_layer_transposed, memory_config=dram_interleaved_memory_config ) - attn_weights = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place( + attn_weights = ttnn.scale_mask_softmax_in_place( attn_weights, scalar_value, attention_mask_proper_dim, - program_config=ttnn.experimental.operations.primary.transformers.SoftmaxDefaultProgramConfig(), + program_config=ttnn.SoftmaxDefaultProgramConfig(), is_causal_mask=True, compute_kernel_config=compute_kernel_config, ) @@ -876,14 +876,14 @@ def test_softmax(device, num_cores, seq_len): ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, ) - softmax_program_config = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, subblock_w=1, block_h=height_shard_spec[0] // 32, block_w=height_shard_spec[1] // 32, ) - input_slice = ttnn.experimental.operations.primary.transformers.scale_causal_mask_hw_dims_softmax_in_place( + input_slice = ttnn.scale_causal_mask_hw_dims_softmax_in_place( input_slice, scalar_value, tt_attention_masks_per_slice[i], diff --git a/models/demos/falcon7b/tt/falcon_attention.py 
b/models/demos/falcon7b/tt/falcon_attention.py index eb5566e7e3f7..fff5c6e378e8 100644 --- a/models/demos/falcon7b/tt/falcon_attention.py +++ b/models/demos/falcon7b/tt/falcon_attention.py @@ -322,9 +322,7 @@ def forward( ### SOFTMAX ### ############### for i in range(self.num_devices): - attn_weights[i] = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place( - attn_weights[i] - ) + attn_weights[i] = ttnn.scale_mask_softmax_in_place(attn_weights[i]) ###################### ### V CACHE UPDATE ### @@ -499,7 +497,7 @@ def _optimized_forward( ) ### SOFTMAX ### mm_slices = [ - ttnn.experimental.operations.primary.transformers.scale_causal_mask_hw_dims_softmax_in_place( + ttnn.scale_causal_mask_hw_dims_softmax_in_place( mm_slices[device_id], self.scalar_for_optimized_prefill, attention_mask[device_id][i], @@ -842,19 +840,17 @@ def forward( ### SOFTMAX ### ############### for i in range(self.num_devices): - attn_weights[i] = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place( - attn_weights[i] - ) + attn_weights[i] = ttnn.scale_mask_softmax_in_place(attn_weights[i]) else: ############### ### SOFTMAX ### ############### for i in range(self.num_devices): - attn_weights[i] = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place( + attn_weights[i] = ttnn.scale_mask_softmax_in_place( attn_weights[i], scale=self.scale, mask=attention_mask[i], - program_config=ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + program_config=ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 4), subblock_w=1, block_h=self.padded_local_heads // 32, diff --git a/models/demos/falcon7b/tt/model_config.py b/models/demos/falcon7b/tt/model_config.py index 46ec97f9f6f8..199540a30b4d 100644 --- a/models/demos/falcon7b/tt/model_config.py +++ b/models/demos/falcon7b/tt/model_config.py @@ -403,7 +403,7 @@ def set_prefill_config(model_config, seq_len, dram_memcfg): model_config[ "SOFTMAX_OPTIMIZED_PROGCFG" - ] = lambda grid_size, subblock_w, block_h, block_w: ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + ] = lambda grid_size, subblock_w, block_h, block_w: ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, subblock_w=subblock_w, block_h=block_h, diff --git a/models/demos/metal_BERT_large_11/tt/mha.py b/models/demos/metal_BERT_large_11/tt/mha.py index 57bd55a2ddd0..f8d218f06ab8 100644 --- a/models/demos/metal_BERT_large_11/tt/mha.py +++ b/models/demos/metal_BERT_large_11/tt/mha.py @@ -84,9 +84,7 @@ def op3_bmm(Q_heads, K_T_heads): ) return qkt - softmax_program_config = model_config.get( - "OP4_SOFTMAX_CONFIG", tt_lib.operations.primary.transformers.SoftmaxDefaultProgramConfig() - ) + softmax_program_config = model_config.get("OP4_SOFTMAX_CONFIG", ttnn.SoftmaxDefaultProgramConfig()) def op4_scale_mask_softmax(qkt, attention_mask): # Attention scores computation @@ -95,7 +93,7 @@ def op4_scale_mask_softmax(qkt, attention_mask): # No-op reshapes are handled within pre-softmax (op 7) and post-softmax bmms (op 9) shape = qkt.get_legacy_shape() qkt = qkt.reshape(shape[0], 1, shape[1] * shape[2], shape[3]) - attention_scores = tt_lib.operations.primary.transformers.scale_mask_softmax_in_place( + attention_scores = ttnn.scale_mask_softmax_in_place( qkt, freciprocal_of_sqrt_hidden_dim, attention_mask, program_config=softmax_program_config ) attention_scores = attention_scores.reshape(shape) diff --git 
a/models/demos/metal_BERT_large_11/tt/model_config.py b/models/demos/metal_BERT_large_11/tt/model_config.py index 569b71e519c6..d739ac9372b9 100644 --- a/models/demos/metal_BERT_large_11/tt/model_config.py +++ b/models/demos/metal_BERT_large_11/tt/model_config.py @@ -244,7 +244,7 @@ def get_model_config(batch, device_grid_size, model_config_str): transpose_mcast=False, fused_activation=None, ), - "OP4_SOFTMAX_CONFIG": tt_lib.operations.primary.transformers.SoftmaxDefaultProgramConfig(), + "OP4_SOFTMAX_CONFIG": ttnn.SoftmaxDefaultProgramConfig(), "OP8_LAYERNORM_CONFIG": tt_lib.operations.primary.LayerNormDefaultProgramConfig(), "OP11_LAYERNORM_CONFIG": tt_lib.operations.primary.LayerNormDefaultProgramConfig(), } @@ -397,7 +397,7 @@ def get_model_config(batch, device_grid_size, model_config_str): block_w=4, inplace=True, ), - "OP4_SOFTMAX_CONFIG": tt_lib.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + "OP4_SOFTMAX_CONFIG": ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, subblock_w=6, block_h=24, diff --git a/models/demos/t3000/falcon40b/tt/falcon_attention.py b/models/demos/t3000/falcon40b/tt/falcon_attention.py index 426674f999cb..05657519698b 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_attention.py +++ b/models/demos/t3000/falcon40b/tt/falcon_attention.py @@ -635,7 +635,7 @@ def fwd_decode( softmax_progcfg = self.model_config["SOFTMAX_PROGCFG"] softmax_progcfg.block_w = padded_layer_past_len // 32 - attn_weights = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place( + attn_weights = ttnn.scale_mask_softmax_in_place( attn_weights, self.scalar, attention_mask, diff --git a/models/demos/t3000/falcon40b/tt/model_config.py b/models/demos/t3000/falcon40b/tt/model_config.py index 8967a5d2cad2..46dd6e1b6fc8 100644 --- a/models/demos/t3000/falcon40b/tt/model_config.py +++ b/models/demos/t3000/falcon40b/tt/model_config.py @@ -526,9 +526,7 @@ def get_decode_model_config(model_config_str, input_shape, num_devices): ) model_config["K_TRANSPOSED_OUTPUT_MEMCFG"] = HEIGHT_SHARDED_MEMCFG model_config["PRE_SOFTMAX_MM_OUTPUT_MEMCFG"] = HEIGHT_SHARDED_MEMCFG - model_config[ - "SOFTMAX_PROGCFG" - ] = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + model_config["SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 2), subblock_w=1, block_h=row_height // 32, @@ -835,9 +833,7 @@ def get_prefill_model_config(model_config_str, input_shape, num_devices): fused_activation=None, mcast_in0=False, ) - model_config[ - "SOFTMAX_PROGCFG" - ] = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + model_config["SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=attention_mm_grid_size, subblock_w=1, block_h=attetnion_mm_M, diff --git a/models/demos/t3000/falcon40b/tt/ops/falcon_softmax.py b/models/demos/t3000/falcon40b/tt/ops/falcon_softmax.py index 8506013fa48e..1201425293b1 100644 --- a/models/demos/t3000/falcon40b/tt/ops/falcon_softmax.py +++ b/models/demos/t3000/falcon40b/tt/ops/falcon_softmax.py @@ -28,24 +28,24 @@ def __call__( if self.is_sharded: # Subtract max value from activation before softmax - out = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place( + out = ttnn.scale_mask_softmax_in_place( x, self.scalar, attention_mask, program_config=softmax_progcfg, # output_mem_config=self.model_config["DEFAULT_MEMCFG"], - # 
program_config=ttnn.experimental.operations.primary.transformers.SoftmaxDefaultProgramConfig(), + # program_config=ttnn.SoftmaxDefaultProgramConfig(), is_causal_mask=True, ) else: # Subtract max value from activation before softmax - out = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place( + out = ttnn.scale_mask_softmax_in_place( x, self.scalar, attention_mask, # program_config=softmax_progcfg, # output_mem_config=self.model_config["DEFAULT_MEMCFG"], - program_config=ttnn.experimental.operations.primary.transformers.SoftmaxDefaultProgramConfig(), + program_config=ttnn.SoftmaxDefaultProgramConfig(), is_causal_mask=True, ) diff --git a/models/demos/t3000/llama2_70b/scripts/model_config_n150.py b/models/demos/t3000/llama2_70b/scripts/model_config_n150.py index aa731a234a92..caaad7875815 100644 --- a/models/demos/t3000/llama2_70b/scripts/model_config_n150.py +++ b/models/demos/t3000/llama2_70b/scripts/model_config_n150.py @@ -680,14 +680,14 @@ def get_model_config(model_config_str, num_devices=1, all_gather=True): model_config["K_TRANSPOSED_OUTPUT_MEMCFG"] = HEIGHT_SHARDED_MEMCFG model_config["PRE_SOFTMAX_MM_OUTPUT_MEMCFG"] = HEIGHT_SHARDED_MEMCFG if num_devices == 4: - model_config["SOFTMAX_PROGCFG"] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + model_config["SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 2), subblock_w=1, block_h=1, block_w=1, # Dynamic ) elif num_devices == 8: - model_config["SOFTMAX_PROGCFG"] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + model_config["SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 1), subblock_w=1, block_h=1, diff --git a/models/demos/t3000/llama2_70b/tests/unit_tests/test_attn_sdpa.py b/models/demos/t3000/llama2_70b/tests/unit_tests/test_attn_sdpa.py index 67891d71cbdc..e8d8810f5c87 100644 --- a/models/demos/t3000/llama2_70b/tests/unit_tests/test_attn_sdpa.py +++ b/models/demos/t3000/llama2_70b/tests/unit_tests/test_attn_sdpa.py @@ -74,7 +74,7 @@ def scale_mask_softmax_decomposed(self, attn, scale, attn_mask): # attn_mask = torch2tt_tensor(attn_mask, self.device) attn = ttnn.add(attn, attn_mask) - attn = tt_lib.tensor.softmax(attn) + attn = ttnn.softmax(attn) return attn def forward(self, xq, keys, values, attn_mask): @@ -95,7 +95,7 @@ def forward(self, xq, keys, values, attn_mask): # TODO: This op expects attn_mask to be sharded such that each core has 1 head # This is illegal on single chip since we need 8x8 coregrid to shard # 64 heads on. Until we fracture on multi-chip, we can't use this op. 
- # attn = tt_lib.operations.primary.transformers.scale_mask_softmax_in_place( + # attn = ttnn.scale_mask_softmax_in_place( # attn, # 1 / math.sqrt(self.head_dim), # attn_mask, diff --git a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py index cc4c51bcd429..a12f1c1895e3 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py @@ -311,7 +311,7 @@ def attn_mqa( softmax_progcfg = self.model_config["BATCHED_SOFTMAX_PROGCFG"] softmax_progcfg.block_w = padded_layer_past_len // 32 - attn_weights = tt_lib.operations.primary.transformers.scale_mask_softmax_in_place( + attn_weights = ttnn.scale_mask_softmax_in_place( attn_weights, self.scale, attn_masks, diff --git a/models/demos/t3000/llama2_70b/tt/model_config.py b/models/demos/t3000/llama2_70b/tt/model_config.py index dae1027a9c23..b6afba584663 100644 --- a/models/demos/t3000/llama2_70b/tt/model_config.py +++ b/models/demos/t3000/llama2_70b/tt/model_config.py @@ -542,7 +542,7 @@ def get_model_config( if llm_mode == "decode": model_config[ "BATCHED_SOFTMAX_PROGCFG" - ] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + ] = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 4), # In-place softmax on 32 cores sharded on batch dim subblock_w=1, block_h=shard_height // 32, @@ -551,7 +551,7 @@ def get_model_config( else: model_config[ "BATCHED_SOFTMAX_PROGCFG" - ] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + ] = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 4 if seq_len == 128 else 8), subblock_w=1, block_h=32 // 32, # 128 * 8 // 32 cores // TILE_SIZE diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py index 87b768068b06..de93b2abb12b 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py @@ -270,7 +270,7 @@ def forward( # Softmax and scaling - attn_1B4P = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place( + attn_1B4P = ttnn.scale_mask_softmax_in_place( attn_1B4P, self.scale, attn_mask_1B4P, diff --git a/models/demos/t3000/mixtral8x7b/tt/model_config.py b/models/demos/t3000/mixtral8x7b/tt/model_config.py index dd660e49636a..28089b4de294 100644 --- a/models/demos/t3000/mixtral8x7b/tt/model_config.py +++ b/models/demos/t3000/mixtral8x7b/tt/model_config.py @@ -180,7 +180,7 @@ def __init__(self, device=None, instruct=False, dummy_weights=False): ) self.model_config["ATTN_BATCHED_SOFTMAX_PROGCFG"] = cached_lambda( - lambda padded_layer_past_len: ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + lambda padded_layer_past_len: ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 4), # In-place softmax on 32 cores sharded on batch dim subblock_w=1, block_h=1, # Shard_height // 32, diff --git a/models/demos/wormhole/stable_diffusion/tt2/ttnn_functional_cross_attention.py b/models/demos/wormhole/stable_diffusion/tt2/ttnn_functional_cross_attention.py index 1ff72cccb302..836ecf9a117d 100644 --- a/models/demos/wormhole/stable_diffusion/tt2/ttnn_functional_cross_attention.py +++ b/models/demos/wormhole/stable_diffusion/tt2/ttnn_functional_cross_attention.py @@ -324,9 +324,7 @@ def __init__(self, device, parameters, seq_len): mcast_in0=False, ) - 
self.program_configs[ - "tsa_softmax" - ] = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + self.program_configs["tsa_softmax"] = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=self.tsa_grid_size, subblock_w=1, block_h=mm_output_height_shard_spec[0] // 32, @@ -393,7 +391,7 @@ def time_sharded_attention(self, query, t_key, value, head_size, attn_type): use_mask = False if use_mask: - mm_slice = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place( + mm_slice = ttnn.scale_mask_softmax_in_place( mm_slice, 1 / math.sqrt(head_size), attention_mask, @@ -523,7 +521,7 @@ def sharded_attention(self, query, key, value, head_size, attn_type): output_mem_config, ) # attention_scores = ttnn.experimental.tensor.move_sharded(attention_scores) - softmax_program_config = ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=compute_with_storage_grid_size, subblock_w=1, block_h=height_per_core // 32, @@ -531,7 +529,7 @@ def sharded_attention(self, query, key, value, head_size, attn_type): ) use_mask = attn_type == "cross" if use_mask: - attention_scores = ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place( + attention_scores = ttnn.scale_mask_softmax_in_place( attention_scores, 1 / math.sqrt(head_size), attention_mask, diff --git a/models/experimental/bert_large_performant/unit_tests/fused_ops/test_bert_large_fused_softmax.py b/models/experimental/bert_large_performant/unit_tests/fused_ops/test_bert_large_fused_softmax.py index 3b9f46a8b9a8..f668416a87d4 100644 --- a/models/experimental/bert_large_performant/unit_tests/fused_ops/test_bert_large_fused_softmax.py +++ b/models/experimental/bert_large_performant/unit_tests/fused_ops/test_bert_large_fused_softmax.py @@ -89,11 +89,11 @@ def run_softmax_tests(dev, test_id, batch, dtype, in0_mem_config): logger.info("Running scale_mask_softmax") torch_scale, tt_scale = generate_recip_tensor(dev, 0.5 + random.random()) torch_attn_mask, tt_attn_mask = generate_attn_mask(N, C, W, dev, -4.2 * 1, dtype, in0_mem_config) - t1_fused = ttl.operations.primary.transformers.scale_mask_softmax_in_place(t0, tt_scale, tt_attn_mask) + t1_fused = ttnn.scale_mask_softmax_in_place(t0, tt_scale, tt_attn_mask) ref_sm = ref_scale_mask_softmax(torch_scale, torch_attn_mask, x) elif test_id == 1: logger.info("Running softmax") - t1_fused = ttl.operations.primary.softmax_in_place(t0) + t1_fused = ttnn.softmax_in_place(t0) ref_sm = ref_stable_softmax(x) else: assert False diff --git a/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit.py b/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit.py index 05eaafe4c21c..f6090f15e63a 100644 --- a/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit.py +++ b/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit.py @@ -135,7 +135,7 @@ def update_model_config(config, batch_size): # out_data_format=ttnn.bfloat8_b, inplace=False, ), - "softmax_program_config": ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + "softmax_program_config": ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(core_grid.x, core_grid.y), subblock_w=7, block_h=7, diff --git a/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit_backup.py 
b/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit_backup.py index aaeb459ce7b7..2168facf1505 100644 --- a/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit_backup.py +++ b/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit_backup.py @@ -135,7 +135,7 @@ def update_model_config(config, batch_size): # out_data_format=ttnn.bfloat8_b, inplace=False, ), - "softmax_program_config": ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + "softmax_program_config": ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(core_grid.x, core_grid.y), subblock_w=7, block_h=7, diff --git a/models/experimental/llama/tt/llama_attention.py b/models/experimental/llama/tt/llama_attention.py index 7ff85829dbb7..2b16cfbe8161 100644 --- a/models/experimental/llama/tt/llama_attention.py +++ b/models/experimental/llama/tt/llama_attention.py @@ -274,7 +274,7 @@ def forward( attn_weights = pad_by_zero(attn_weights, self.device)[0] value_states = torch_to_tt_tensor_rm(value_states, self.device) - attn_weights = tt_lib.operations.primary.softmax_in_place(attn_weights) + attn_weights = ttnn.softmax_in_place(attn_weights) attn_output = tt_lib.tensor.bmm(attn_weights, value_states) if attn_output.get_legacy_shape() != [bsz, self.num_heads, q_len, self.head_dim]: diff --git a/models/experimental/llama2_70b/scripts/model_config_n150.py b/models/experimental/llama2_70b/scripts/model_config_n150.py index aa731a234a92..caaad7875815 100644 --- a/models/experimental/llama2_70b/scripts/model_config_n150.py +++ b/models/experimental/llama2_70b/scripts/model_config_n150.py @@ -680,14 +680,14 @@ def get_model_config(model_config_str, num_devices=1, all_gather=True): model_config["K_TRANSPOSED_OUTPUT_MEMCFG"] = HEIGHT_SHARDED_MEMCFG model_config["PRE_SOFTMAX_MM_OUTPUT_MEMCFG"] = HEIGHT_SHARDED_MEMCFG if num_devices == 4: - model_config["SOFTMAX_PROGCFG"] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + model_config["SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 2), subblock_w=1, block_h=1, block_w=1, # Dynamic ) elif num_devices == 8: - model_config["SOFTMAX_PROGCFG"] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + model_config["SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 1), subblock_w=1, block_h=1, diff --git a/models/experimental/llama2_70b/tests/unit_tests/test_attn_sdpa.py b/models/experimental/llama2_70b/tests/unit_tests/test_attn_sdpa.py index 669d8e46258b..2022eb991a5d 100644 --- a/models/experimental/llama2_70b/tests/unit_tests/test_attn_sdpa.py +++ b/models/experimental/llama2_70b/tests/unit_tests/test_attn_sdpa.py @@ -74,7 +74,7 @@ def scale_mask_softmax_decomposed(self, attn, scale, attn_mask): # attn_mask = torch2tt_tensor(attn_mask, self.device) attn = ttnn.add(attn, attn_mask) - attn = tt_lib.tensor.softmax(attn) + attn = ttnn.softmax(attn) return attn def forward(self, xq, keys, values, attn_mask): @@ -95,7 +95,7 @@ def forward(self, xq, keys, values, attn_mask): # TODO: This op expects attn_mask to be sharded such that each core has 1 head # This is illegal on single chip since we need 8x8 coregrid to shard # 64 heads on. Until we fracture on multi-chip, we can't use this op. 
- # attn = tt_lib.operations.primary.transformers.scale_mask_softmax_in_place( + # attn = ttnn.scale_mask_softmax_in_place( # attn, # 1 / math.sqrt(self.head_dim), # attn_mask, diff --git a/models/experimental/llama2_70b/tt/llama_attention_optimized.py b/models/experimental/llama2_70b/tt/llama_attention_optimized.py index 610ec37e4871..e2154f22acce 100644 --- a/models/experimental/llama2_70b/tt/llama_attention_optimized.py +++ b/models/experimental/llama2_70b/tt/llama_attention_optimized.py @@ -311,7 +311,7 @@ def attn_mqa( softmax_progcfg = self.model_config["BATCHED_SOFTMAX_PROGCFG"] softmax_progcfg.block_w = padded_layer_past_len // 32 - attn_weights = tt_lib.operations.primary.transformers.scale_mask_softmax_in_place( + attn_weights = ttnn.scale_mask_softmax_in_place( attn_weights, self.scale, attn_masks, diff --git a/models/experimental/llama2_70b/tt/model_config.py b/models/experimental/llama2_70b/tt/model_config.py index d460bf5dc765..d54cddd8cf3f 100644 --- a/models/experimental/llama2_70b/tt/model_config.py +++ b/models/experimental/llama2_70b/tt/model_config.py @@ -542,18 +542,14 @@ def get_model_config( k_chunk_size=k_chunk_size, ) if llm_mode == "decode": - model_config[ - "BATCHED_SOFTMAX_PROGCFG" - ] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + model_config["BATCHED_SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 4), # In-place softmax on 32 cores sharded on batch dim subblock_w=1, block_h=shard_height // 32, block_w=1, # Dynamic ) else: - model_config[ - "BATCHED_SOFTMAX_PROGCFG" - ] = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + model_config["BATCHED_SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 4 if seq_len == 128 else 8), subblock_w=1, block_h=32 // 32, # 128 * 8 // 32 cores // TILE_SIZE diff --git a/models/experimental/mistral/tt/mistral_attention.py b/models/experimental/mistral/tt/mistral_attention.py index b26b8428f1dc..ad0e025f6b27 100644 --- a/models/experimental/mistral/tt/mistral_attention.py +++ b/models/experimental/mistral/tt/mistral_attention.py @@ -219,7 +219,7 @@ def forward( if self.args.FALLBACK_SOFTMAX: scores = fallback_ops.softmax(scores, dim=-1) else: - scores = tt_lib.tensor.softmax(scores, output_mem_config=self.args.out_mem_config) + scores = ttnn.softmax(scores, output_mem_config=self.args.out_mem_config) output = tt_lib.tensor.bmm( scores, value, output_mem_config=self.args.out_mem_config ) # (bs, n_local_heads, slen, head_dim) diff --git a/models/experimental/nanogpt/tt/nanogpt_attention.py b/models/experimental/nanogpt/tt/nanogpt_attention.py index 4269f1794d87..171f8f8457f5 100644 --- a/models/experimental/nanogpt/tt/nanogpt_attention.py +++ b/models/experimental/nanogpt/tt/nanogpt_attention.py @@ -112,9 +112,7 @@ def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor: tt_att = torch_to_tt_tensor_rm(att, self.device, put_on_device=False) - tt_att = tt_lib.tensor.softmax( - tt_att - ) # Using tt_lib.tensor.softmax reduces pcc from 0.99 to 0.98 for whole model + tt_att = ttnn.softmax(tt_att) # Using ttnn.softmax reduces pcc from 0.99 to 0.98 for whole model tt_y = tt_lib.tensor.bmm(tt_att, v) diff --git a/models/experimental/t5/tt/t5_attention.py b/models/experimental/t5/tt/t5_attention.py index ef704e015f1a..2635f55a335e 100644 --- a/models/experimental/t5/tt/t5_attention.py +++ b/models/experimental/t5/tt/t5_attention.py @@ -548,7 +548,7 @@ def project(hidden_states, 
proj_weights, key_value_states, past_key_value): scores = ttnn.add(scores, position_bias, memory_config=self.mem_config) # attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) - attn_weights = tt_lib.operations.primary.softmax_in_place(scores) + attn_weights = ttnn.softmax_in_place(scores) # Dropout is not used in inference # attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # (batch_size, n_heads, seq_length, key_length) diff --git a/tests/tt_eager/profiling/ops_for_profiling.py b/tests/tt_eager/profiling/ops_for_profiling.py index aa1237967d71..78f34b165f5d 100644 --- a/tests/tt_eager/profiling/ops_for_profiling.py +++ b/tests/tt_eager/profiling/ops_for_profiling.py @@ -305,7 +305,7 @@ def primary_moreh_logsoftmax_backward_3(x, y): def primary_scale_mask_softmax_in_place(x, y): - tt_lib.operations.primary.transformers.scale_mask_softmax_in_place(x, scale=3.3, mask=y) + ttnn.scale_mask_softmax_in_place(x, scale=3.3, mask=y) def scale_mask_softmax_in_place_shape_func(input_shape): @@ -1046,7 +1046,7 @@ def angle_bw(x, y): }, { "op": primary_scale_mask_softmax_in_place, - "name": "tt_lib.operations.primary.transformers.scale_mask_softmax_in_place", + "name": "ttnn.scale_mask_softmax_in_place", "shape_func": scale_mask_softmax_in_place_shape_func, }, { @@ -2359,8 +2359,8 @@ def primary_moreh_norm_3(x): "name": "tt_lib.tensor.round_bw", }, { - "op": tt_lib.operations.primary.softmax_in_place, - "name": "tt_lib.operations.primary.softmax_in_place", + "op": ttnn.softmax_in_place, + "name": "ttnn.softmax_in_place", }, { "op": primary_moreh_softmax_0, diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_softmax_sharded.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_softmax_sharded.py index d2d927eb78c4..aef27313bbe7 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_softmax_sharded.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_softmax_sharded.py @@ -7,6 +7,8 @@ import pytest import math +import ttnn + import tt_lib as ttl from tt_lib.utils import ( pad_weight, @@ -114,7 +116,7 @@ def test_softmax(device, in_dtype, causal_mask, grid_size, seq_len, scale_mask): subblock_w = i break - program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, subblock_w=subblock_w, block_h=block_h, @@ -122,11 +124,11 @@ def test_softmax(device, in_dtype, causal_mask, grid_size, seq_len, scale_mask): ) if scale_mask: - tt_output_sharded = ttl.operations.primary.transformers.scale_mask_softmax_in_place( + tt_output_sharded = ttnn.scale_mask_softmax_in_place( in1_t_shard, scale, attention_mask_t, program_config=program_config, is_causal_mask=causal_mask ) else: - tt_output_sharded = ttl.operations.primary.softmax_in_place(in1_t_shard, program_config=program_config) + tt_output_sharded = ttnn.softmax_in_place(in1_t_shard, program_config=program_config) tt_output = ttl.tensor.sharded_to_interleaved(tt_output_sharded, in0_mem_config) tt_output_tensor = tt_output.cpu().to_torch().float() diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py index 1939e5c8a1f5..b7a54f51b50d 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -364,7 +364,7 @@ def 
eltwise_gelu( @setup_host_and_device def eltwise_softmax_in_place(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs): t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) - t1 = ttl.operations.primary.softmax_in_place(t0) + t1 = ttnn.softmax_in_place(t0) return tt2torch_tensor(t1) @@ -385,7 +385,7 @@ def eltwise_scale_mask_softmax_in_place( t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1]) - t2 = ttl.operations.primary.transformers.scale_mask_softmax_in_place(t0, scale, t1) + t2 = ttnn.scale_mask_softmax_in_place(t0, scale, t1) return tt2torch_tensor(t2) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_single_core_fused_ops.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_single_core_fused_ops.py index 60ef109ba83c..525c87fd245b 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_single_core_fused_ops.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_single_core_fused_ops.py @@ -6,6 +6,8 @@ import pytest from loguru import logger +import ttnn + import tt_lib as ttl from tt_lib.utils import ( is_close, @@ -22,7 +24,7 @@ def test_softmax(shape, device): torch.manual_seed(1234) x = torch.randn(shape).bfloat16().float() xt = ttl.tensor.Tensor(x, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.TILE).to(device) - xtt = ttl.operations.primary.softmax_in_place(xt) + xtt = ttnn.softmax_in_place(xt) tt_got_back = xtt.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_softmax.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_softmax.py index bdd59c1baff7..f247f3db5fed 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_softmax.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_softmax.py @@ -7,6 +7,8 @@ import pytest import math +import ttnn + import tt_lib as ttl from tt_lib.utils import ( pad_weight, @@ -21,34 +23,34 @@ @pytest.mark.parametrize( "dtype", - (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.FLOAT32), + (ttnn.bfloat16, ttnn.float32), ids=["bfloat16", "float"], ) @pytest.mark.parametrize("inplace", [True, False]) def test_softmax(device, inplace, dtype): - if is_grayskull() and dtype == ttl.tensor.DataType.FLOAT32: + if is_grayskull() and dtype == ttnn.float32: pytest.skip("Skipping float32 tests on Grayskull") torch.manual_seed(0) - sm_op = ttl.operations.primary.softmax_in_place if inplace else ttl.tensor.softmax + sm_op = ttnn.softmax_in_place if inplace else ttnn.softmax input_shapes = [(3, 64, 128, 96), (1, 64, 32, 32)] for input_shape in input_shapes: input_tensor = torch.randn(input_shape).bfloat16() - tt_input_tensor = ttl.tensor.Tensor(input_tensor, dtype).to(ttl.tensor.Layout.TILE).to(device) + tt_input_tensor = ttnn.from_torch(input_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device) if not is_grayskull(): - if dtype == ttl.tensor.DataType.FLOAT32: - compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( - math_fidelity=ttl.tensor.MathFidelity.HiFi4, + if dtype == ttnn.float32: + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, math_approx_mode=False, fp32_dest_acc_en=True, ) else: - compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( - math_fidelity=ttl.tensor.MathFidelity.HiFi4, + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + 
math_fidelity=ttnn.MathFidelity.HiFi4, math_approx_mode=False, fp32_dest_acc_en=False, ) @@ -56,7 +58,9 @@ def test_softmax(device, inplace, dtype): tt_output_tensor_on_device = sm_op( tt_input_tensor, compute_kernel_config=compute_kernel_config if not is_grayskull() else None ) - tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + tt_output_tensor = ttnn.to_layout(tt_output_tensor_on_device, ttnn.ROW_MAJOR_LAYOUT) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) golden_output_tensor = torch.softmax(input_tensor, dim=-1) print_diff_argmax(tt_output_tensor, golden_output_tensor) @@ -69,18 +73,18 @@ def test_softmax(device, inplace, dtype): @pytest.mark.parametrize("inplace", [True, False]) def test_softmax_with_program_cache(device, use_program_cache, inplace): torch.manual_seed(0) - sm_op = ttl.operations.primary.softmax_in_place if inplace else ttl.tensor.softmax + sm_op = ttnn.softmax_in_place if inplace else ttnn.softmax input_shapes = [(3, 64, 128, 96), (1, 64, 32, 32)] for input_shape in input_shapes: input_tensor = torch.randn(input_shape).bfloat16() - tt_input_tensor = ( - ttl.tensor.Tensor(input_tensor, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.TILE).to(device) - ) + tt_input_tensor = ttnn.from_torch(input_tensor, layout=ttnn.TILE_LAYOUT, device=device) tt_output_tensor_on_device = sm_op(tt_input_tensor) - tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + tt_output_tensor = ttnn.to_layout(tt_output_tensor_on_device, ttnn.ROW_MAJOR_LAYOUT) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) golden_output_tensor = torch.softmax(input_tensor, dim=-1) print_diff_argmax(tt_output_tensor, golden_output_tensor) @@ -89,29 +93,26 @@ def test_softmax_with_program_cache(device, use_program_cache, inplace): assert allclose, f"FAILED: {output}" -@pytest.mark.parametrize( - "cb_dtype", - (ttl.tensor.DataType.BFLOAT16,), - ids=["BFLOAT16"], -) @pytest.mark.parametrize( "in_dtype", - (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B), - ids=["BFLOAT16", "BFLOAT8_B"], + (ttnn.bfloat16, ttnn.bfloat8_b), + ids=["bfloat16", "bfloat8_b"], ) @pytest.mark.parametrize("inplace", [True, False]) -def test_softmax_mix_precision(device, inplace, in_dtype, cb_dtype): +def test_softmax_mix_precision(device, inplace, in_dtype): torch.manual_seed(0) - sm_op = ttl.operations.primary.softmax_in_place if inplace else ttl.tensor.softmax + sm_op = ttnn.softmax_in_place if inplace else ttnn.softmax input_shapes = [(3, 64, 128, 96), (1, 64, 32, 32)] for input_shape in input_shapes: input_tensor = torch.randn(input_shape).bfloat16() - tt_input_tensor = ttl.tensor.Tensor(input_tensor, in_dtype).to(ttl.tensor.Layout.TILE).to(device) + tt_input_tensor = ttnn.from_torch(input_tensor, dtype=in_dtype, layout=ttnn.TILE_LAYOUT, device=device) tt_output_tensor_on_device = sm_op(tt_input_tensor) - tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + tt_output_tensor = ttnn.to_layout(tt_output_tensor_on_device, ttnn.ROW_MAJOR_LAYOUT) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) golden_output_tensor = torch.softmax(input_tensor, dim=-1) print_diff_argmax(tt_output_tensor, golden_output_tensor) @@ -132,7 +133,7 @@ def test_softmax_mix_precision(device, inplace, in_dtype, cb_dtype): ) 
@pytest.mark.parametrize( "in0_mem_config", - (ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),), + (ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),), ids=[ "in0_DRAM", ], @@ -140,14 +141,14 @@ def test_softmax_mix_precision(device, inplace, in_dtype, cb_dtype): @pytest.mark.parametrize( "in_dtype", ( - ttl.tensor.DataType.FLOAT32, - ttl.tensor.DataType.BFLOAT16, - ttl.tensor.DataType.BFLOAT8_B, + ttnn.float32, + ttnn.bfloat16, + ttnn.bfloat8_b, ), - ids=["FLOAT32", "BFLOAT16", "BFLOAT8_B"], + ids=["float32", "bfloat16", "bfloat8_b"], ) def test_scale_mask_softmax_inplace(device, in_dtype, in0_mem_config, causal_mask, seq_len): - if is_grayskull() and in_dtype == ttl.tensor.DataType.FLOAT32: + if is_grayskull() and in_dtype == ttnn.float32: pytest.skip("Skipping float32 tests on Grayskull") torch.manual_seed(0) @@ -162,57 +163,42 @@ def test_scale_mask_softmax_inplace(device, in_dtype, in0_mem_config, causal_mas hidden_dim = 1024 num_heads = 16 - # scale = 1.0 scale = 1 / math.sqrt(hidden_dim // num_heads) + mask_dtype = ttnn.float32 if in_dtype == ttnn.float32 else ttnn.bfloat16 + if causal_mask == False: attention_mask = torch.rand(batch, 1, 32, seq_len) mask = torch.rand_like(attention_mask) < 0.2 attention_mask[mask] = float("-inf") - attention_mask32 = tilize_to_list(pad_weight(attention_mask)) - attention_mask_t = ttl.tensor.Tensor( - attention_mask32, - [batch, 1, 32, seq_len], - # ttl.tensor.DataType.BFLOAT16, - ttl.tensor.DataType.FLOAT32 if in_dtype == ttl.tensor.DataType.FLOAT32 else ttl.tensor.DataType.BFLOAT16, - ttl.tensor.Layout.TILE, - device, - ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), - ) + attention_mask_t = ttnn.from_torch(attention_mask, dtype=mask_dtype, layout=ttnn.TILE_LAYOUT, device=device) else: - # attention_mask = torch.zeros(batch, 1, seq_len, seq_len) attention_mask = torch.rand(batch, 1, seq_len, seq_len) mask = torch.rand_like(attention_mask) < 0.2 attention_mask[mask] = float("-inf") - attention_mask32 = tilize_to_list(pad_weight(attention_mask)) - attention_mask_t = ttl.tensor.Tensor( - attention_mask32, - [batch, 1, seq_len, seq_len], - # ttl.tensor.DataType.BFLOAT16, - ttl.tensor.DataType.FLOAT32 if in_dtype == ttl.tensor.DataType.FLOAT32 else ttl.tensor.DataType.BFLOAT16, - ttl.tensor.Layout.TILE, - device, - ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), - ) + attention_mask = pad_weight(attention_mask) + attention_mask_t = ttnn.from_torch(attention_mask, dtype=mask_dtype, layout=ttnn.TILE_LAYOUT, device=device) input_tensor = torch.randn(input_shape).bfloat16().float() - in1_t = torch2tt_tensor(input_tensor, device, tt_memory_config=in0_mem_config, tt_dtype=in_dtype) + in1_t = ttnn.from_torch( + input_tensor, dtype=in_dtype, layout=ttnn.TILE_LAYOUT, device=device, memory_config=in0_mem_config + ) if not is_grayskull(): - if in_dtype == ttl.tensor.DataType.FLOAT32: - compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( - math_fidelity=ttl.tensor.MathFidelity.HiFi4, + if in_dtype == ttnn.float32: + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, math_approx_mode=False, fp32_dest_acc_en=True, ) else: - compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( - math_fidelity=ttl.tensor.MathFidelity.HiFi4, + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, 
math_approx_mode=False, fp32_dest_acc_en=False, ) - tt_output = ttl.operations.primary.transformers.scale_mask_softmax_in_place( + tt_output = ttnn.scale_mask_softmax_in_place( in1_t, scale, attention_mask_t, @@ -220,9 +206,9 @@ def test_scale_mask_softmax_inplace(device, in_dtype, in0_mem_config, causal_mas compute_kernel_config=compute_kernel_config if not is_grayskull() else None, ) - tt_output_tensor = tt_output.cpu().to_torch().float() - tt_output_tensor = torch.Tensor(tt_output_tensor).reshape(input_shape) - tt_output_tensor = untilize(tt_output_tensor) + tt_output_tensor = ttnn.to_layout(tt_output, ttnn.ROW_MAJOR_LAYOUT) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) if causal_mask == False: attention_mask = attention_mask.reshape(batch, 1, 32, seq_len)[:, :, 0, :] @@ -243,15 +229,15 @@ def test_scale_mask_softmax_inplace(device, in_dtype, in0_mem_config, causal_mas @pytest.mark.parametrize( "in0_mem_config", - (ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),), + (ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),), ids=[ "in0_DRAM", ], ) @pytest.mark.parametrize( "in_dtype", - (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B), - ids=["BFLOAT16", "BFLOAT8_B"], + (ttnn.bfloat16, ttnn.bfloat8_b), + ids=["bfloat16", "bfloat8_b"], ) def test_scale_mask_softmax(device, in_dtype, in0_mem_config): torch.manual_seed(0) @@ -261,32 +247,24 @@ def test_scale_mask_softmax(device, in_dtype, in0_mem_config): batch = grid_size[0] num_cores_r = grid_size[1] input_shape = (batch, 1, num_cores_r * fuse_head * 384, 384) - M = input_shape[2] - K = input_shape[3] * batch hidden_dim = 1024 num_heads = 16 scale = 1 / math.sqrt(hidden_dim // num_heads) attention_mask = torch.rand(batch, 1, 32, 384) attention_mask = (attention_mask > 0.5).float() - attention_mask32 = tilize_to_list(pad_weight(attention_mask)) - attention_mask_t = ttl.tensor.Tensor( - attention_mask32, - [batch, 1, 32, 384], - ttl.tensor.DataType.BFLOAT16, - ttl.tensor.Layout.TILE, - device, - ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), - ) + attention_mask_t = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) input_tensor = torch.randn(input_shape).bfloat16().float() - in1_t = torch2tt_tensor(input_tensor, device, tt_memory_config=in0_mem_config, tt_dtype=in_dtype) + in1_t = ttnn.from_torch( + input_tensor, dtype=in_dtype, layout=ttnn.TILE_LAYOUT, device=device, memory_config=in0_mem_config + ) - tt_output = ttl.tensor.scale_mask_softmax(in1_t, scale, attention_mask_t) + tt_output = ttnn.scale_mask_softmax(in1_t, scale, attention_mask_t) - tt_output_tensor = tt_output.cpu().to_torch().float() - tt_output_tensor = torch.Tensor(tt_output_tensor).reshape(input_shape) - tt_output_tensor = untilize(tt_output_tensor) + tt_output_tensor = ttnn.to_layout(tt_output, ttnn.ROW_MAJOR_LAYOUT) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) attention_mask = attention_mask.reshape(batch, 1, 32, 384) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_softmax_sharded.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_softmax_sharded.py index 96cc713cb45c..eaff5e39a8b6 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_softmax_sharded.py +++ 
b/tests/tt_eager/python_api_testing/unit_testing/misc/test_softmax_sharded.py @@ -7,6 +7,8 @@ import pytest import math +import ttnn + import tt_lib as ttl from tt_lib.utils import ( pad_weight, @@ -22,19 +24,19 @@ @pytest.mark.parametrize("device_params", [{"l1_small_size": 8192}], indirect=True) @pytest.mark.parametrize( "in0_mem_config", - (ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),), + (ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),), ids=[ "in0_DRAM", ], ) @pytest.mark.parametrize( "in_dtype", - (ttl.tensor.DataType.BFLOAT8_B,), - ids=["BFLOAT8_B"], + (ttnn.bfloat8_b,), + ids=["bfloat8_b"], ) def test_softmax_causal_mask(device, in_dtype, in0_mem_config): torch.manual_seed(0) - sm_op = ttl.operations.primary.transformers.scale_mask_softmax_in_place + sm_op = ttnn.scale_mask_softmax_in_place fuse_head = 2 @@ -53,27 +55,20 @@ def test_softmax_causal_mask(device, in_dtype, in0_mem_config): attention_mask = torch.rand(batch, 1, 384, 768) attention_mask = (attention_mask > 0.5).float() - attention_mask32 = tilize_to_list(pad_weight(attention_mask)) - attention_mask_t = ttl.tensor.Tensor( - attention_mask32, - [batch, 1, 384, 768], - ttl.tensor.DataType.BFLOAT16, - ttl.tensor.Layout.TILE, - device, - ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), - ) + attention_mask_t = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) input_tensor = torch.randn(input_shape).bfloat16().float() - in1_t = torch2tt_tensor(input_tensor, device, tt_memory_config=in0_mem_config, tt_dtype=in_dtype) - in1_t_shard = ttl.tensor.interleaved_to_sharded( - in1_t, - grid_size, - [fuse_head * 384, 768], - ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, - ttl.tensor.ShardOrientation.COL_MAJOR, + in1_t = ttnn.from_torch( + input_tensor, dtype=in_dtype, layout=ttnn.TILE_LAYOUT, device=device, memory_config=in0_mem_config ) - - program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + grid_coord = ttnn.CoreCoord(grid_size[0] - 1, grid_size[1] - 1) + shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)}) + shard_shape = [fuse_head * 384, 768] + shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.COL_MAJOR, False) + sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec) + in1_t_shard = ttnn.to_memory_config(in1_t, sharded_mem_config) + + program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, subblock_w=8, block_h=12 * fuse_head, @@ -82,10 +77,9 @@ def test_softmax_causal_mask(device, in_dtype, in0_mem_config): tt_output_sharded = sm_op(in1_t_shard, scale, attention_mask_t, program_config=program_config, is_causal_mask=True) - tt_output = ttl.tensor.sharded_to_interleaved(tt_output_sharded, in0_mem_config) - tt_output_tensor = tt_output.cpu().to_torch().float() - tt_output_tensor = torch.Tensor(tt_output_tensor).reshape(input_shape) - tt_output_tensor = untilize(tt_output_tensor) + tt_output_tensor = ttnn.to_layout(tt_output_sharded, ttnn.ROW_MAJOR_LAYOUT, memory_config=in0_mem_config) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) attention_mask = attention_mask.repeat(1, 1, fuse_head, 1) @@ -108,7 +102,7 @@ def test_softmax_causal_mask(device, in_dtype, in0_mem_config): ) @pytest.mark.parametrize( 
"in0_mem_config", - (ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),), + (ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),), ids=[ "in0_DRAM", ], @@ -116,17 +110,17 @@ def test_softmax_causal_mask(device, in_dtype, in0_mem_config): @pytest.mark.parametrize( "in_dtype", ( - ttl.tensor.DataType.FLOAT32, - ttl.tensor.DataType.BFLOAT8_B, + ttnn.float32, + ttnn.bfloat8_b, ), - ids=["FLOAT32", "BFLOAT8_B"], + ids=["float32", "bfloat8_b"], ) def test_softmax(device, in_dtype, in0_mem_config, causal_mask): - if is_grayskull() and in_dtype == ttl.tensor.DataType.FLOAT32: + if is_grayskull() and in_dtype == ttnn.float32: pytest.skip("Skipping float32 tests on Grayskull") torch.manual_seed(0) - sm_op = ttl.operations.primary.transformers.scale_mask_softmax_in_place + sm_op = ttnn.scale_mask_softmax_in_place fuse_head = 2 @@ -147,41 +141,27 @@ def test_softmax(device, in_dtype, in0_mem_config, causal_mask): # attention_mask = torch.zeros(1, 1, 1, 384 * batch) attention_mask = torch.rand(batch, 1, 1, 384) attention_mask = (attention_mask > 0.5).float() - attention_mask32 = tilize_to_list(pad_weight(attention_mask)) - attention_mask_t = ttl.tensor.Tensor( - attention_mask32, - [batch, 1, 32, 384], - ttl.tensor.DataType.BFLOAT16, - ttl.tensor.Layout.TILE, - device, - ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), - ) + attention_mask32 = pad_weight(attention_mask) + attention_mask_t = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) else: attention_mask = torch.rand(batch, 1, 384, 384) attention_mask = (attention_mask > 0.5).float() - attention_mask32 = tilize_to_list(pad_weight(attention_mask)) - attention_mask_t = ttl.tensor.Tensor( - attention_mask32, - [batch, 1, 384, 384], - ttl.tensor.DataType.BFLOAT16, - ttl.tensor.Layout.TILE, - device, - ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), - ) + attention_mask_t = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) input_tensor = torch.randn(input_shape).bfloat16().float() - in1_t = torch2tt_tensor(input_tensor, device, tt_memory_config=in0_mem_config, tt_dtype=in_dtype) - in1_t_shard = ttl.tensor.interleaved_to_sharded( - in1_t, - grid_size, - [fuse_head * 384, 384], - ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, - ttl.tensor.ShardOrientation.COL_MAJOR, + in1_t = ttnn.from_torch( + input_tensor, dtype=in_dtype, layout=ttnn.TILE_LAYOUT, device=device, memory_config=in0_mem_config ) - - program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + grid_coord = ttnn.CoreCoord(grid_size[0] - 1, grid_size[1] - 1) + shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)}) + shard_shape = [fuse_head * 384, 384] + shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.COL_MAJOR, False) + sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec) + in1_t_shard = ttnn.to_memory_config(in1_t, sharded_mem_config) + + program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, - subblock_w=4 if in_dtype == ttl.tensor.DataType.FLOAT32 else 6, + subblock_w=4 if in_dtype == ttnn.float32 else 6, block_h=12 * fuse_head, block_w=12, ) @@ -190,10 +170,9 @@ def test_softmax(device, in_dtype, in0_mem_config, causal_mask): in1_t_shard, scale, 
attention_mask_t, program_config=program_config, is_causal_mask=causal_mask ) - tt_output = ttl.tensor.sharded_to_interleaved(tt_output_sharded, in0_mem_config) - tt_output_tensor = tt_output.cpu().to_torch().float() - tt_output_tensor = torch.Tensor(tt_output_tensor).reshape(input_shape) - tt_output_tensor = untilize(tt_output_tensor) + tt_output_tensor = ttnn.to_layout(tt_output_sharded, ttnn.ROW_MAJOR_LAYOUT, memory_config=in0_mem_config) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) if causal_mask == False: attention_mask = attention_mask.reshape(batch, 1, 1, 384) @@ -219,22 +198,22 @@ def test_softmax(device, in_dtype, in0_mem_config, causal_mask): ) @pytest.mark.parametrize( "in0_mem_config", - (ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),), + (ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),), ids=[ "in0_DRAM", ], ) @pytest.mark.parametrize( "in_dtype", - (ttl.tensor.DataType.FLOAT32, ttl.tensor.DataType.BFLOAT8_B), - ids=["FLOAT32", "BFLOAT8_B"], + (ttnn.float32, ttnn.bfloat8_b), + ids=["float32", "bfloat8_b"], ) def test_scale_mask_softmax_rm(device, in_dtype, in0_mem_config, causal_mask): - if is_grayskull() and in_dtype == ttl.tensor.DataType.FLOAT32: + if is_grayskull() and in_dtype == ttnn.float32: pytest.skip("Skipping float32 tests on Grayskull") torch.manual_seed(0) - sm_op = ttl.operations.primary.transformers.scale_mask_softmax_in_place + sm_op = ttnn.scale_mask_softmax_in_place fuse_head = 1 @@ -257,39 +236,29 @@ def test_scale_mask_softmax_rm(device, in_dtype, in0_mem_config, causal_mask): attention_mask = torch.rand(batch, 1, 1, 384) attention_mask = (attention_mask > 0.5).float() attention_mask = attention_mask.reshape(batch, 1, -1, 32) - attention_mask_t = ttl.tensor.Tensor( - attention_mask, - # ttl.tensor.DataType.BFLOAT16, - ttl.tensor.DataType.FLOAT32 if in_dtype == ttl.tensor.DataType.FLOAT32 else ttl.tensor.DataType.BFLOAT16, - ).to(device, ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1)) + attention_mask_t = ttnn.from_torch( + attention_mask, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device + ) else: # attention_mask = torch.zeros(batch, 1, 384, 384) attention_mask = torch.rand(batch, 1, 384, 384) attention_mask = (attention_mask > 0.5).float() - attention_mask32 = tilize_to_list(pad_weight(attention_mask)) - attention_mask_t = ttl.tensor.Tensor( - attention_mask32, - [batch, 1, 384, 384], - # ttl.tensor.DataType.BFLOAT16, - ttl.tensor.DataType.FLOAT32 if in_dtype == ttl.tensor.DataType.FLOAT32 else ttl.tensor.DataType.BFLOAT16, - ttl.tensor.Layout.TILE, - device, - ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), - ) + attention_mask_t = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) input_tensor = torch.randn(input_shape).bfloat16().float() - in1_t = torch2tt_tensor(input_tensor, device, tt_memory_config=in0_mem_config, tt_dtype=in_dtype) - in1_t_shard = ttl.tensor.interleaved_to_sharded( - in1_t, - grid_size, - [fuse_head * 384, 384], - ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, - ttl.tensor.ShardOrientation.ROW_MAJOR, + in1_t = ttnn.from_torch( + input_tensor, dtype=in_dtype, layout=ttnn.TILE_LAYOUT, device=device, memory_config=in0_mem_config ) - - program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + grid_coord = 
ttnn.CoreCoord(grid_size[0] - 1, grid_size[1] - 1) + shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)}) + shard_shape = [fuse_head * 384, 384] + shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False) + sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec) + in1_t_shard = ttnn.to_memory_config(in1_t, sharded_mem_config) + + program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, - subblock_w=4 if in_dtype == ttl.tensor.DataType.FLOAT32 else 6, + subblock_w=4 if in_dtype == ttnn.float32 else 6, block_h=12 * fuse_head, block_w=12, ) @@ -298,10 +267,9 @@ def test_scale_mask_softmax_rm(device, in_dtype, in0_mem_config, causal_mask): in1_t_shard, scale, attention_mask_t, program_config=program_config, is_causal_mask=causal_mask ) - tt_output = ttl.tensor.sharded_to_interleaved(tt_output_sharded, in0_mem_config) - tt_output_tensor = tt_output.cpu().to_torch().float() - tt_output_tensor = torch.Tensor(tt_output_tensor).reshape(input_shape) - tt_output_tensor = untilize(tt_output_tensor) + tt_output_tensor = ttnn.to_layout(tt_output_sharded, ttnn.ROW_MAJOR_LAYOUT, memory_config=in0_mem_config) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) if causal_mask == False: attention_mask = attention_mask.reshape(batch, 1, 1, 384) @@ -322,24 +290,24 @@ def test_scale_mask_softmax_rm(device, in_dtype, in0_mem_config, causal_mask): @pytest.mark.parametrize("device_params", [{"l1_small_size": 8192}], indirect=True) @pytest.mark.parametrize( "shard_orient", - [ttl.tensor.ShardOrientation.COL_MAJOR, ttl.tensor.ShardOrientation.ROW_MAJOR], + [ttnn.ShardOrientation.COL_MAJOR, ttnn.ShardOrientation.ROW_MAJOR], ids=["CM", "RM"], ) @pytest.mark.parametrize( "in0_mem_config", - (ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),), + (ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),), ids=[ "in0_DRAM", ], ) @pytest.mark.parametrize( "in_dtype", - (ttl.tensor.DataType.FLOAT32, ttl.tensor.DataType.BFLOAT8_B), - ids=["FLOAT32", "BFLOAT8_B"], + (ttnn.float32, ttnn.bfloat8_b), + ids=["float32", "bfloat8_b"], ) def test_softmax_with_sharded_mask(device, in_dtype, in0_mem_config, shard_orient): torch.manual_seed(0) - sm_op = ttl.operations.primary.transformers.scale_mask_softmax_in_place + sm_op = ttnn.scale_mask_softmax_in_place grid_size = (8, 4) input_shape = (1, 32, 32, 1024) @@ -351,35 +319,26 @@ def test_softmax_with_sharded_mask(device, in_dtype, in0_mem_config, shard_orien attention_mask = torch.rand(1, 32, 32, 1024) attention_mask = (attention_mask > 0.5).float() attention_mask = torch.where(attention_mask == 1, torch.tensor(0.0), torch.tensor(-float("inf"))) - attention_mask_t = torch2tt_tensor( - attention_mask, - device, - tt_memory_config=in0_mem_config, - tt_dtype=( - ttl.tensor.DataType.FLOAT32 if in_dtype == ttl.tensor.DataType.FLOAT32 else ttl.tensor.DataType.BFLOAT16 - ), - ) - attention_mask_t_shard = ttl.tensor.interleaved_to_sharded( - attention_mask_t, - grid_size, - [M, K], - ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, - shard_orient, + mask_dtype = ttnn.float32 if in_dtype == ttnn.float32 else ttnn.bfloat16 + attention_mask_t = ttnn.from_torch( + attention_mask, dtype=mask_dtype, layout=ttnn.TILE_LAYOUT, device=device, memory_config=in0_mem_config ) + grid_coord = ttnn.CoreCoord(grid_size[0] - 1, 
grid_size[1] - 1) + shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)}) + shard_shape = [M, K] + shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, shard_orient, False) + sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec) + attention_mask_t_shard = ttnn.to_memory_config(attention_mask_t, sharded_mem_config) input_tensor = torch.randn(input_shape).bfloat16().float() - in1_t = torch2tt_tensor(input_tensor, device, tt_memory_config=in0_mem_config, tt_dtype=in_dtype) - in1_t_shard = ttl.tensor.interleaved_to_sharded( - in1_t, - grid_size, - [M, K], - ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, - shard_orient, + in1_t = ttnn.from_torch( + input_tensor, dtype=in_dtype, layout=ttnn.TILE_LAYOUT, device=device, memory_config=in0_mem_config ) + in1_t_shard = ttnn.to_memory_config(in1_t, sharded_mem_config) - program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, - subblock_w=4 if in_dtype == ttl.tensor.DataType.FLOAT32 else 8, + subblock_w=4 if in_dtype == ttnn.float32 else 8, block_h=1, block_w=32, ) @@ -388,10 +347,9 @@ def test_softmax_with_sharded_mask(device, in_dtype, in0_mem_config, shard_orien in1_t_shard, scale, attention_mask_t_shard, program_config=program_config, is_causal_mask=True ) - tt_output = ttl.tensor.sharded_to_interleaved(tt_output_sharded, in0_mem_config) - tt_output_tensor = tt_output.cpu().to_torch().float() - tt_output_tensor = torch.Tensor(tt_output_tensor).reshape(input_shape) - tt_output_tensor = untilize(tt_output_tensor) + tt_output_tensor = ttnn.to_layout(tt_output_sharded, ttnn.ROW_MAJOR_LAYOUT, memory_config=in0_mem_config) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) golden_output_tensor = input_tensor * scale + attention_mask golden_output_tensor = torch.softmax(golden_output_tensor, dim=-1) diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_sharded_attention.py b/tests/ttnn/integration_tests/stable_diffusion/test_sharded_attention.py index b484e45d8b9a..b215b2220fe5 100644 --- a/tests/ttnn/integration_tests/stable_diffusion/test_sharded_attention.py +++ b/tests/ttnn/integration_tests/stable_diffusion/test_sharded_attention.py @@ -180,7 +180,7 @@ def test_time_sharded_attnention_hwb( ) mm_slice = ttl.tensor.move_sharded(mm_slice) - softmax_program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, subblock_w=1, block_h=mm_output_height_shard_spec[0] // 32, @@ -188,7 +188,7 @@ def test_time_sharded_attnention_hwb( ) # print(program_config) - mm_slice = ttl.operations.primary.softmax_in_place(mm_slice, program_config=softmax_program_config) + mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) # mmt = tt2torch_tensor(mm_slice) # passed, message = comp_pcc(mmt, attn_weights_torch_sm[:, i * heads_per_slice : (i + 1) * heads_per_slice, :, :]) # print(message) @@ -365,14 +365,14 @@ def test_time_sharded_attnention( k_slice.deallocate() slice.deallocate() - softmax_program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, subblock_w=1, 
block_h=mm_output_height_shard_spec[0] // 32, block_w=mm_output_height_shard_spec[1] // 32, ) - mm_slice = ttl.operations.primary.softmax_in_place(mm_slice, program_config=softmax_program_config) + mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( compute_with_storage_grid_size=grid_size, @@ -418,7 +418,7 @@ def test_time_sharded_attnention( attn_weights = ttl.tensor.bmm( reference_query_layer, reference_key_layer_transposed, output_mem_config=dram_interleaved_memory_config ) - attn_weights = ttl.operations.primary.softmax_in_place(attn_weights) + attn_weights = ttnn.softmax_in_place(attn_weights) attn_weights = ttl.tensor.bmm(attn_weights, reference_value_layer, output_mem_config=dram_interleaved_memory_config) attn_weights_torch = tt2torch_tensor(attn_weights) @@ -566,23 +566,23 @@ def test_cross_attnention( mm_slice, output_mem_config, ) - softmax_program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 8), subblock_w=1, block_h=32, block_w=3, ) - mm_slice = ttl.operations.primary.softmax_in_place(mm_slice, program_config=softmax_program_config) + mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) mm_slice = ttl.tensor.reshard(mm_slice, orig_mem_config) else: - softmax_program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, subblock_w=1, block_h=seq_len // 32, block_w=kv_len // 32, ) - mm_slice = ttl.operations.primary.softmax_in_place(mm_slice, program_config=softmax_program_config) + mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) v_sharded = ttl.tensor.interleaved_to_sharded( reference_value_layer, @@ -751,13 +751,13 @@ def test_attention( ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, ttl.tensor.ShardOrientation.ROW_MAJOR, ) - softmax_program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 8), subblock_w=1, block_h=height_per_core // 32, block_w=seq_len // 32, ) - mm_slice = ttl.operations.primary.softmax_in_place(mm_slice, program_config=softmax_program_config) + mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) mm_slice = ttl.tensor.sharded_to_interleaved(mm_slice, l1_interleaved_memory_config) mm_slice = ttl.tensor.interleaved_to_sharded( mm_slice, @@ -781,23 +781,23 @@ def test_attention( mm_slice, output_mem_config, ) - softmax_program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=(8, 8), subblock_w=1, block_h=height_per_core // 32, block_w=seq_len // 32, ) - mm_slice = ttl.operations.primary.softmax_in_place(mm_slice, program_config=softmax_program_config) + mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) mm_slice = ttl.tensor.reshard(mm_slice, orig_mem_config) else: - softmax_program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( compute_with_storage_grid_size=grid_size, subblock_w=1, 
block_h=seq_len // 32, block_w=seq_len // 32, ) print(softmax_program_config) - mm_slice = ttl.operations.primary.softmax_in_place(mm_slice, program_config=softmax_program_config) + mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) v_sharded = ttl.tensor.interleaved_to_sharded( reference_value_layer, diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 6d3b1549c428..7b2f0e5c9771 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -12,6 +12,8 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/device/unary_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/softmax/device/multi_core/softmax_op_multi_core.cpp ) add_library(ttnn_lib OBJECT ${TTNN_SRCS}) diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index e16c04ceb7ef..4b34239b639b 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -16,7 +16,6 @@ #include "pybind11/operations/kv_cache.hpp" #include "pybind11/operations/matmul.hpp" #include "pybind11/operations/maxpool2d.hpp" -#include "pybind11/operations/normalization.hpp" #include "pybind11/operations/pool.hpp" #include "pybind11/operations/copy.hpp" #include "pybind11/operations/ternary.hpp" @@ -24,6 +23,7 @@ #include "ttnn/operations/eltwise/binary/binary_pybind.hpp" #include "ttnn/operations/eltwise/unary/unary_pybind.hpp" +#include "ttnn/operations/normalization/normalization_pybind.hpp" #include "ttnn/operations/reduction/reduction_pybind.hpp" #include "ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp" diff --git a/ttnn/cpp/pybind11/operations/normalization.hpp b/ttnn/cpp/pybind11/operations/normalization.hpp deleted file mode 100644 index 28b773c46de5..000000000000 --- a/ttnn/cpp/pybind11/operations/normalization.hpp +++ /dev/null @@ -1,101 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include -#include - -#include "ttnn/cpp/pybind11/decorators.hpp" -#include "ttnn/operations/normalization.hpp" - -namespace py = pybind11; - -namespace { - MemoryConfig dram_memory_config = tt::tt_metal::MemoryConfig{.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED,.buffer_type=tt::tt_metal::BufferType::DRAM}; -} - -namespace ttnn { -namespace operations { -namespace normalization { -void py_module(py::module& module) { - ttnn::bind_registered_operation( - module, - ttnn::softmax, - R"doc(softmax(input_tensor: ttnn.Tensor, dim: int, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor - - Compute softmax over :attr:`input_tensor` along :attr:`dim`. - - Args: - * :attr:`input_tensor`: the input tensor - * :attr:`dim`: the dimension along which to compute softmax. - - Keyword Args: - * :attr:`memory_config`: the memory configuration for the output tensor. If not provided, the memory configuration of the input tensor is used. 
- - Example: - - >>> tensor = ttnn.to_device(ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16)), device) - >>> output = ttnn.softmax(tensor, -1) - >>> print(output[0, 0, 0, :3]) - ttnn.Tensor([ 0.0310059, 0.0310059, 0.0310059], dtype=bfloat16 ) - )doc", - ttnn::pybind_arguments_t{ - py::arg("input_tensor"), py::arg("dim"), py::kw_only(), py::arg("memory_config") = std::nullopt}); - - ttnn::bind_registered_operation( - module, - ttnn::layer_norm, - R"doc(rms_norm(input_tensor: ttnn.Tensor, epsilon: float = 1e-12, weight: Optional[ttnn.Tensor] = None, bias: Optional[ttnn.Tensor] = None, residual_input_tensor: Optional[ttnn.Tensor] = None, memory_config: Optional[ttnn.MemoryConfig] = None, program_config: Optional[ttnn.ProgramConfig] = None) -> ttnn.Tensor - Compute layer_norm over :attr:`input_tensor`. - )doc", - ttnn::pybind_arguments_t{ - py::arg("input_tensor"), - py::kw_only(), - py::arg("epsilon") = 1e-12, - py::arg("weight") = std::nullopt, - py::arg("bias") = std::nullopt, - py::arg("residual_input_tensor") = std::nullopt, - py::arg("memory_config") = std::nullopt, - py::arg("program_config") = std::nullopt}); - - ttnn::bind_registered_operation( - module, - ttnn::rms_norm, - R"doc(rms_norm(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, epsilon: float = 1e-12, Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor - Compute rms_norm over :attr:`input_tensor`. - )doc", - ttnn::pybind_arguments_t{ - py::arg("input_tensor"), - py::arg("weight"), - py::kw_only(), - py::arg("epsilon") = 1e-12, - py::arg("memory_config") = std::nullopt}); - - ttnn::bind_registered_operation( - module, - ttnn::group_norm, - R"doc(group_norm(input_tensor: ttnn.Tensor, *, num_groups: int, epsilon: float = 1e-12, weight: Optional[ttnn.Tensor] = None, bias: Optional[ttnn.Tensor] = None) -> ttnn.Tensor - Compute group_norm over :attr:`input_tensor`. - )doc", - ttnn::pybind_arguments_t{ - py::arg("input_tensor"), - py::kw_only(), - py::arg("num_groups"), - py::arg("epsilon") = 1e-12, - py::arg("input_mask") = std::nullopt, - py::arg("weight") = std::nullopt, - py::arg("bias") = std::nullopt, - py::arg("memory_config") = std::nullopt, - py::arg("dtype") = std::nullopt, - py::arg("core_grid") = std::nullopt, - py::arg("inplace") = true, - py::arg("output_layout") = std::nullopt - } - ); -} - -} // namespace normalization -} // namespace operations -} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/normalization.hpp b/ttnn/cpp/ttnn/operations/normalization.hpp deleted file mode 100644 index 4d7524ce4e8f..000000000000 --- a/ttnn/cpp/ttnn/operations/normalization.hpp +++ /dev/null @@ -1,246 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" -#include "tt_dnn/op_library/softmax/softmax_op.hpp" -#include "tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.hpp" -#include "tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp" - -namespace ttnn { -namespace operations { -namespace normalization { - -template -struct Softmax { - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}}; - } - - template - static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... 
args) { - return std::forward_as_tuple(input_tensor); - } - - static ttnn::Tensor execute_on_worker_thread( - const ttnn::Tensor& input_tensor, - const int dim_arg, - const std::optional& memory_config = std::nullopt) { - auto input_shape = input_tensor.get_shape(); - auto rank = input_shape.size(); - auto dim = dim_arg; - if (dim < 0) { - dim = rank + dim; - } - - auto input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor); - auto is_tile_padded = input_tensor.get_shape()[-2] != input_tensor.get_shape().with_tile_padding()[-2] or - input_tensor.get_shape()[-1] != input_tensor.get_shape().with_tile_padding()[-1]; - if (dim == rank - 1) { - auto output_tensor = - tt::tt_metal::softmax(input_tensor_4D, memory_config.value_or(input_tensor.memory_config())); - return ttnn::reshape(output_tensor, input_shape); - } else { - auto dim_4D = dim + 4 - rank; - auto output_tensor = tt::operations::primary::moreh_softmax(input_tensor_4D, dim_4D); - return ttnn::reshape(output_tensor, input_shape); - } - } -}; - -struct LayerNorm { - static inline const std::array input_tensor_schemas() { - return { - ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT}, - true, - false, - false, - false}, - ttnn::TensorSchema{ - 1, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - false, - false, - true}, - ttnn::TensorSchema{ - 1, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - false, - false, - true}, - ttnn::TensorSchema{ - 1, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - false, - false, - true}}; - } - - template - static auto input_tensors_to_validate( - const Tensor& input_tensor, - float epsilon = 1e-12, - const std::optional& weight = std::nullopt, - const std::optional& bias = std::nullopt, - const std::optional& residual_input_tensor = std::nullopt, - Args&&... 
args) { - return std::forward_as_tuple(input_tensor, weight, bias, residual_input_tensor); - } - - static inline ttnn::Tensor execute_on_worker_thread( - const ttnn::Tensor& input_tensor, - float epsilon = 1e-12, - const std::optional& weight = std::nullopt, - const std::optional& bias = std::nullopt, - const std::optional& residual_input_tensor = std::nullopt, - const std::optional& memory_config_arg = std::nullopt, - const std::optional& program_config_arg = std::nullopt) { - const LayerNormProgramConfig& program_config = program_config_arg.value_or(LayerNormDefaultProgramConfig{}); - - auto memory_config = memory_config_arg.value_or(input_tensor.memory_config()); - if (residual_input_tensor.has_value()) { - return tt::operations::primary::add_layernorm( - input_tensor, residual_input_tensor.value(), epsilon, weight, bias, memory_config, program_config); - } else { - return tt::operations::primary::layernorm( - input_tensor, epsilon, weight, bias, memory_config, program_config); - } - } -}; - -struct RMSNorm { - static inline const std::array input_tensor_schemas() { - return { - ttnn::TensorSchema{ - 2, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT}, - true, - false, - false, - false}, - ttnn::TensorSchema{ - 1, - 4, - {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, - {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, - true, - false, - false, - false}}; - } - - template - static auto input_tensors_to_validate(const Tensor& input_tensor, const Tensor& weight, Args&&... args) { - return std::forward_as_tuple(input_tensor, weight); - } - - static inline ttnn::Tensor execute_on_worker_thread( - const ttnn::Tensor& input_tensor, - const ttnn::Tensor& weight, - float epsilon = 1e-12, - const std::optional& memory_config_arg = std::nullopt) { - auto memory_config = memory_config_arg.value_or(input_tensor.memory_config()); - return tt::operations::primary::rmsnorm(input_tensor, epsilon, weight, std::nullopt, memory_config); - } -}; - -struct GroupNorm { - template - static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... 
args) { - return std::forward_as_tuple(input_tensor); - } - - static inline const std::array input_tensor_schemas() { - return {ttnn::TensorSchema{ - 2, 4, {ttnn::bfloat16}, {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}}; - } - - static inline ttnn::Tensor execute_on_worker_thread( - const ttnn::Tensor& input_tensor, - const int num_groups, - const float epsilon, - const std::optional& input_mask = std::nullopt, - const std::optional& weight = std::nullopt, - const std::optional& bias = std::nullopt, - const std::optional& memory_config = std::nullopt, - const std::optional dtype = std::nullopt, - std::optional core_grid = std::nullopt, - std::optional inplace = std::nullopt, - std::optional output_layout = std::nullopt) { - if (input_tensor.get_layout() == Layout::TILE and inplace.has_value()) { - TT_FATAL(inplace == false, "Tile layour does not support inplace tensors"); - } - if (output_layout.has_value() and inplace.has_value()) { - if (output_layout != input_tensor.get_layout()) { - TT_FATAL(inplace == false, "cannot inplace tensors when layout are different"); - } - } - TT_FATAL(core_grid.has_value(), "Automatic determination of grid size not supported"); - - TT_FATAL(input_tensor.is_sharded(), "Only sharded input tensors supported"); - - TT_FATAL( - input_tensor.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, - "Input tensor cannot be width sharded"); - - TT_FATAL(input_tensor.get_shape().rank() == 4, "Input tensor must be rank 4"); - - TT_FATAL( - input_tensor.get_shape()[-1] % num_groups == 0, "Number of channels must be divisible by number of groups"); - - const auto& ts = input_tensor.get_shape(); - TT_FATAL( - (ts[0] * ts[1] * ts[2]) % ttnn::types::TILE_SIZE == 0, - "Input tensor dim NHW must be divisible by tile size"); - - const auto output_dtype = dtype.value_or(input_tensor.get_dtype()); - - const std::optional& gamma = - weight.has_value() ? std::optional(ttnn::unsqueeze_to_4D(weight.value())) : std::nullopt; - const std::optional& beta = - bias.has_value() ? 
std::optional(ttnn::unsqueeze_to_4D(bias.value())) : std::nullopt; - - const MemoryConfig& dram_memory_config = tt::tt_metal::MemoryConfig{ - .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED, - .buffer_type = tt::tt_metal::BufferType::DRAM}; - const MemoryConfig& output_mem_config = memory_config.value_or(dram_memory_config); - - const tt::operations::primary::GroupNormShardedMultiCoreProgramConfig& program_config = { - .compute_with_storage_grid_size = core_grid.value().to_CoreCoord(), - .math_fidelity = MathFidelity::HiFi4, - .im_data_format = DataType::BFLOAT16, - .out_data_format = DataType::BFLOAT16, - .inplace = inplace.value_or(false), - .output_layout = output_layout.value_or(input_tensor.get_layout())}; - - return tt::operations::primary::groupnorm( - input_tensor, num_groups, epsilon, gamma, beta, input_mask, output_mem_config, program_config); - } -}; - -} // namespace normalization -} // namespace operations - -constexpr auto softmax = ttnn::register_operation>("ttnn::softmax"); -constexpr auto layer_norm = ttnn::register_operation("ttnn::layer_norm"); -constexpr auto rms_norm = ttnn::register_operation("ttnn::rms_norm"); -constexpr auto group_norm = ttnn::register_operation("ttnn::group_norm"); -} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/normalization/groupnorm/groupnorm.hpp b/ttnn/cpp/ttnn/operations/normalization/groupnorm/groupnorm.hpp new file mode 100644 index 000000000000..492284343990 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/groupnorm/groupnorm.hpp @@ -0,0 +1,92 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.hpp" + +namespace ttnn { +namespace operations { +namespace normalization { + +struct GroupNorm { + template + static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... 
args) { + return std::forward_as_tuple(input_tensor); + } + + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 2, 4, {ttnn::bfloat16}, {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}}; + } + + static inline ttnn::Tensor execute_on_worker_thread( + const ttnn::Tensor& input_tensor, + const int num_groups, + const float epsilon, + const std::optional& input_mask = std::nullopt, + const std::optional& weight = std::nullopt, + const std::optional& bias = std::nullopt, + const std::optional& memory_config = std::nullopt, + const std::optional dtype = std::nullopt, + std::optional core_grid = std::nullopt, + std::optional inplace = std::nullopt, + std::optional output_layout = std::nullopt) { + if (input_tensor.get_layout() == Layout::TILE and inplace.has_value()) { + TT_FATAL(inplace == false, "Tile layour does not support inplace tensors"); + } + if (output_layout.has_value() and inplace.has_value()) { + if (output_layout != input_tensor.get_layout()) { + TT_FATAL(inplace == false, "cannot inplace tensors when layout are different"); + } + } + TT_FATAL(core_grid.has_value(), "Automatic determination of grid size not supported"); + + TT_FATAL(input_tensor.is_sharded(), "Only sharded input tensors supported"); + + TT_FATAL( + input_tensor.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, + "Input tensor cannot be width sharded"); + + TT_FATAL(input_tensor.get_shape().rank() == 4, "Input tensor must be rank 4"); + + TT_FATAL( + input_tensor.get_shape()[-1] % num_groups == 0, "Number of channels must be divisible by number of groups"); + + const auto& ts = input_tensor.get_shape(); + TT_FATAL( + (ts[0] * ts[1] * ts[2]) % ttnn::types::TILE_SIZE == 0, + "Input tensor dim NHW must be divisible by tile size"); + + const auto output_dtype = dtype.value_or(input_tensor.get_dtype()); + + const std::optional& gamma = + weight.has_value() ? std::optional(ttnn::unsqueeze_to_4D(weight.value())) : std::nullopt; + const std::optional& beta = + bias.has_value() ? std::optional(ttnn::unsqueeze_to_4D(bias.value())) : std::nullopt; + + const MemoryConfig& dram_memory_config = tt::tt_metal::MemoryConfig{ + .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED, + .buffer_type = tt::tt_metal::BufferType::DRAM}; + const MemoryConfig& output_mem_config = memory_config.value_or(dram_memory_config); + + const tt::operations::primary::GroupNormShardedMultiCoreProgramConfig& program_config = { + .compute_with_storage_grid_size = core_grid.value().to_CoreCoord(), + .math_fidelity = MathFidelity::HiFi4, + .im_data_format = DataType::BFLOAT16, + .out_data_format = DataType::BFLOAT16, + .inplace = inplace.value_or(false), + .output_layout = output_layout.value_or(input_tensor.get_layout())}; + + return tt::operations::primary::groupnorm( + input_tensor, num_groups, epsilon, gamma, beta, input_mask, output_mem_config, program_config); + } +}; + +} // namespace normalization +} // namespace operations + +constexpr auto group_norm = ttnn::register_operation("ttnn::group_norm"); + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/normalization/groupnorm/groupnorm_pybind.hpp b/ttnn/cpp/ttnn/operations/normalization/groupnorm/groupnorm_pybind.hpp new file mode 100644 index 000000000000..6a28ad37dbf7 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/groupnorm/groupnorm_pybind.hpp @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "groupnorm.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::normalization::detail { + +void bind_normalization_group_norm_operation(py::module& module) { + + ttnn::bind_registered_operation( + module, + ttnn::group_norm, + R"doc(group_norm(input_tensor: ttnn.Tensor, *, num_groups: int, epsilon: float = 1e-12, weight: Optional[ttnn.Tensor] = None, bias: Optional[ttnn.Tensor] = None) -> ttnn.Tensor + Compute group_norm over :attr:`input_tensor`. + )doc", + ttnn::pybind_arguments_t{ + py::arg("input_tensor"), + py::kw_only(), + py::arg("num_groups"), + py::arg("epsilon") = 1e-12, + py::arg("input_mask") = std::nullopt, + py::arg("weight") = std::nullopt, + py::arg("bias") = std::nullopt, + py::arg("memory_config") = std::nullopt, + py::arg("dtype") = std::nullopt, + py::arg("core_grid") = std::nullopt, + py::arg("inplace") = true, + py::arg("output_layout") = std::nullopt + } + ); +} + +} // namespace ttnn::operations::normalization::detail diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/layernorm.hpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/layernorm.hpp new file mode 100644 index 000000000000..d46225955940 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/layernorm.hpp @@ -0,0 +1,130 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp" + +namespace ttnn { +namespace operations { +namespace normalization { + +struct LayerNorm { + static inline const std::array input_tensor_schemas() { + return { + ttnn::TensorSchema{ + 2, + 4, + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, + {ttnn::TILE_LAYOUT}, + true, + false, + false, + false}, + ttnn::TensorSchema{ + 1, + 4, + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, + {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, + true, + false, + false, + true}, + ttnn::TensorSchema{ + 1, + 4, + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, + {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, + true, + false, + false, + true}, + ttnn::TensorSchema{ + 1, + 4, + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, + {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, + true, + false, + false, + true}}; + } + + template + static auto input_tensors_to_validate( + const Tensor& input_tensor, + float epsilon = 1e-12, + const std::optional& weight = std::nullopt, + const std::optional& bias = std::nullopt, + const std::optional& residual_input_tensor = std::nullopt, + Args&&... 
args) { + return std::forward_as_tuple(input_tensor, weight, bias, residual_input_tensor); + } + + static inline ttnn::Tensor execute_on_worker_thread( + const ttnn::Tensor& input_tensor, + float epsilon = 1e-12, + const std::optional& weight = std::nullopt, + const std::optional& bias = std::nullopt, + const std::optional& residual_input_tensor = std::nullopt, + const std::optional& memory_config_arg = std::nullopt, + const std::optional& program_config_arg = std::nullopt) { + const LayerNormProgramConfig& program_config = program_config_arg.value_or(LayerNormDefaultProgramConfig{}); + + auto memory_config = memory_config_arg.value_or(input_tensor.memory_config()); + if (residual_input_tensor.has_value()) { + return tt::operations::primary::add_layernorm( + input_tensor, residual_input_tensor.value(), epsilon, weight, bias, memory_config, program_config); + } else { + return tt::operations::primary::layernorm( + input_tensor, epsilon, weight, bias, memory_config, program_config); + } + } +}; + +struct RMSNorm { + static inline const std::array input_tensor_schemas() { + return { + ttnn::TensorSchema{ + 2, + 4, + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, + {ttnn::TILE_LAYOUT}, + true, + false, + false, + false}, + ttnn::TensorSchema{ + 1, + 4, + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, + {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, + true, + false, + false, + false}}; + } + + template + static auto input_tensors_to_validate(const Tensor& input_tensor, const Tensor& weight, Args&&... args) { + return std::forward_as_tuple(input_tensor, weight); + } + + static inline ttnn::Tensor execute_on_worker_thread( + const ttnn::Tensor& input_tensor, + const ttnn::Tensor& weight, + float epsilon = 1e-12, + const std::optional& memory_config_arg = std::nullopt) { + auto memory_config = memory_config_arg.value_or(input_tensor.memory_config()); + return tt::operations::primary::rmsnorm(input_tensor, epsilon, weight, std::nullopt, memory_config); + } +}; + +} // namespace normalization +} // namespace operations + +constexpr auto layer_norm = ttnn::register_operation("ttnn::layer_norm"); +constexpr auto rms_norm = ttnn::register_operation("ttnn::rms_norm"); + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/layernorm_pybind.hpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/layernorm_pybind.hpp new file mode 100644 index 000000000000..57447a68acfc --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/layernorm_pybind.hpp @@ -0,0 +1,52 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "layernorm.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::normalization::detail { + +void bind_normalization_layer_norm_operation(py::module& module) { + + ttnn::bind_registered_operation( + module, + ttnn::layer_norm, + R"doc(rms_norm(input_tensor: ttnn.Tensor, epsilon: float = 1e-12, weight: Optional[ttnn.Tensor] = None, bias: Optional[ttnn.Tensor] = None, residual_input_tensor: Optional[ttnn.Tensor] = None, memory_config: Optional[ttnn.MemoryConfig] = None, program_config: Optional[ttnn.ProgramConfig] = None) -> ttnn.Tensor + Compute layer_norm over :attr:`input_tensor`. 
+ )doc", + ttnn::pybind_arguments_t{ + py::arg("input_tensor"), + py::kw_only(), + py::arg("epsilon") = 1e-12, + py::arg("weight") = std::nullopt, + py::arg("bias") = std::nullopt, + py::arg("residual_input_tensor") = std::nullopt, + py::arg("memory_config") = std::nullopt, + py::arg("program_config") = std::nullopt}); +} + +void bind_normalization_rms_norm_operation(py::module& module) { + + ttnn::bind_registered_operation( + module, + ttnn::rms_norm, + R"doc(rms_norm(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, epsilon: float = 1e-12, Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor + Compute rms_norm over :attr:`input_tensor`. + )doc", + ttnn::pybind_arguments_t{ + py::arg("input_tensor"), + py::arg("weight"), + py::kw_only(), + py::arg("epsilon") = 1e-12, + py::arg("memory_config") = std::nullopt}); +} + +} // namespace ttnn::operations::normalization::detail diff --git a/ttnn/cpp/ttnn/operations/normalization/normalization_pybind.hpp b/ttnn/cpp/ttnn/operations/normalization/normalization_pybind.hpp new file mode 100644 index 000000000000..3b3155400d72 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/normalization_pybind.hpp @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" + +#include "softmax/softmax_pybind.hpp" +#include "layernorm/layernorm_pybind.hpp" +#include "groupnorm/groupnorm_pybind.hpp" + +namespace ttnn::operations::normalization { + +void py_module(py::module& module) { + + detail::bind_normalization_softmax_program_config_operation(module); + detail::bind_normalization_softmax_operation(module); + detail::bind_normalization_scale_mask_softmax_operation(module); + detail::bind_normalization_softmax_in_place_operation(module); + detail::bind_normalization_scale_mask_softmax_in_place_operation(module); + detail::bind_normalization_scale_causal_mask_hw_dims_softmax_in_place_operation(module); + detail::bind_normalization_layer_norm_operation(module); + detail::bind_normalization_rms_norm_operation(module); + detail::bind_normalization_group_norm_operation(module); +} + +} // namespace ttnn::operations::normalization diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp new file mode 100644 index 000000000000..326bb51bd56b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp @@ -0,0 +1,223 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#define REDUCE_OP PoolType::SUM +#define REDUCE_DIM ReduceDim::REDUCE_ROW + +#include "compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/tile_move_copy.h" +#include "compute_kernel_api/bcast.h" +#include "compute_kernel_api/softmax.h" +#include "compute_kernel_api/reduce.h" + +// #include "debug/dprint.h" + +ALWI void ACQ() { acquire_dst(tt::DstMode::Half); } +ALWI void REL() { release_dst(tt::DstMode::Half); } + +// for scale+mask+softmax: +// bcast HW (mul by 1 tile) example: ( [2,1,1024,64] * [1,1,32,32] ) +// bcast add H example: ( [2,1,1024,64] + [2,1,32,64] ) (bcast W -> H) +// Note that the attention mask will not fit in L1 for the entire tensor +// The buffer for the att mask is currently sized as (1t,Wt) so we only reuse it for one HtWt-sized batch of x +// then read another Wt tiles of mask for the next batch + +namespace NAMESPACE { +void MAIN { + + const uint32_t NCHt = get_arg_val(0); + const uint32_t Ht = get_arg_val(1); + const uint32_t Wt = get_arg_val(2); + const uint32_t ndst = get_arg_val(3); + const uint32_t start_ht = get_arg_val(4); + const uint32_t mask_padded_data = get_arg_val(5); + binary_op_init_common(tt::CB::c_in0, tt::CB::c_in2, tt::CB::c_intermed0); + + constexpr uint32_t onetile = 1; + // reserve one tile for zeros on cb_in2 + // We only do the reserve for the intermediates once and use pack_tile + // So effectively these are used as pre-allocated arrays + // Note that the entire W dimension must fit in the intermed0 CB for this kernel to be correct + constexpr auto cb_bcast_scaler = tt::CB::c_in2; + constexpr auto cb_fused_scale = tt::CB::c_in3; + constexpr auto cb_fused_attn = tt::CB::c_in4; + constexpr auto cb_mask_padded = tt::CB::c_in5; + constexpr auto cb_exps = tt::CB::c_intermed0; + constexpr auto cb_scale_mask = tt::CB::c_intermed3; + constexpr auto cb_recipsumexps = tt::CB::c_intermed1; + constexpr auto cb_in0 = tt::CB::c_in0; + constexpr auto cb_out0 = tt::CB::c_out0; + + + cb_wait_front(cb_bcast_scaler, 1); // comes from the reader + + #if FUSED_SCALE_MASK + cb_wait_front(cb_fused_scale, 1); + #endif + + constexpr int dst0 = 0; + uint32_t ht = start_ht; + bool wait_mask = true; + for (uint32_t ncht = 0; ncht < NCHt; ncht++) { + #if FUSED_SCALE_MASK + unpack_reconfig_data_format(cb_in0, cb_fused_scale); + pack_reconfig_data_format(cb_scale_mask); + mul_tiles_bcast_scalar_init_short(); + for (uint32_t wt = 0; wt < Wt; wt+=ndst) { + // apply fused scale [*= 1/sqrt(...)] + ACQ(); + cb_wait_front(cb_in0, ndst); + cb_reserve_back(cb_scale_mask, ndst); + for (uint32_t wt8 = 0; wt8 < ndst; wt8++) { + mul_tiles_bcast_scalar(cb_in0, cb_fused_scale, wt8, 0, wt8); // mul bcast-HW -> DST[wt8] + pack_tile(wt8, cb_scale_mask); // reuse exps buffer + } + cb_push_back(cb_scale_mask, ndst); + cb_pop_front(cb_in0, ndst); + REL(); + } + unpack_reconfig_data_format(cb_scale_mask, cb_fused_attn); + + exp_tile_init(); + #ifdef CAUSAL_MASK + add_tiles_init(); + #else + add_bcast_rows_init_short(); + #endif + for (uint32_t wt = 0; wt < Wt; wt+=ndst) { + ACQ(); + cb_wait_front(cb_scale_mask, ndst); + #ifdef CAUSAL_MASK + cb_wait_front(cb_fused_attn, wt+ndst); // cumulative wait for up to Wt tiles + for (uint32_t wt8 = 0; wt8 < ndst; wt8++) { + add_tiles(cb_scale_mask, cb_fused_attn, wt8, wt+wt8, wt8); // tile *= 1/(sum(exp(x))) + } + #else + if (wait_mask) { + cb_wait_front(cb_fused_attn, wt+ndst); // cumulative wait for up to Wt tiles, only at first ht + } + + for (uint32_t wt8 = 0; wt8 < ndst; 
wt8++) { + add_tiles_bcast_rows(cb_scale_mask, cb_fused_attn, wt8, wt+wt8, wt8); // tile *= 1/(sum(exp(x))) + } + #endif + cb_pop_front(cb_scale_mask, ndst); + cb_reserve_back(cb_exps, ndst); + for (uint32_t wt8 = 0; wt8 < ndst; wt8++) { + exp_tile(wt8); // exp on DST[0] + pack_tile(wt8, cb_exps); // reuse the exps buffer again, this time in a circular manner + } + cb_push_back(cb_exps, ndst); + REL(); + } + #ifdef CAUSAL_MASK + cb_pop_front(cb_fused_attn, Wt); + #else + if (wait_mask) { + wait_mask = false; + } + ht++; + if (ht == Ht) { + cb_pop_front(cb_fused_attn, Wt); + ht = 0; + wait_mask = true; + } + #endif // CAUSAL_MASK + + unpack_reconfig_data_format(cb_exps, cb_bcast_scaler); + #else + unpack_reconfig_data_format(cb_in0, cb_in0); + pack_reconfig_data_format(cb_exps); + copy_tile_to_dst_init_short(); // need to copy from CB to DST to be able to run sfpu math + exp_tile_init(); + if (mask_padded_data) { + for (uint32_t wt = 0; wt < Wt; wt+=ndst) { + ACQ(); + cb_wait_front(cb_in0, ndst); + for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) { + if (wt == (Wt - ndst) && (wt8 == ndst - 1)) { + unpack_reconfig_data_format(cb_in0, cb_mask_padded); + add_bcast_rows_init_short(); + cb_wait_front(cb_mask_padded, 1); + add_tiles_bcast_rows(cb_in0, cb_mask_padded, wt8, 0, wt8); + } else { + copy_tile(cb_in0, wt8, wt8); // copy from c_in[0] to DST[0] + } + } + cb_pop_front(cb_in0, ndst); + + cb_reserve_back(cb_exps, ndst); + for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) { + exp_tile(wt8); // exp on DST[0] + pack_tile(wt8, cb_exps); // DST[0]->cb_id[wt] + } + cb_push_back(cb_exps, ndst); + REL(); + } + + } else { + for (uint32_t wt = 0; wt < Wt; wt+=ndst) { + ACQ(); + cb_wait_front(cb_in0, ndst); + for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) { + copy_tile(cb_in0, wt8, wt8); // copy from c_in[0] to DST[0] + } + cb_pop_front(cb_in0, ndst); + + cb_reserve_back(cb_exps, ndst); + for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) { + exp_tile(wt8); // exp on DST[0] + pack_tile(wt8, cb_exps); // DST[0]->cb_id[wt] + } + cb_push_back(cb_exps, ndst); + REL(); + } + } + + unpack_reconfig_data_format(cb_exps, cb_bcast_scaler); + #endif + + ACQ(); + cb_reserve_back(cb_recipsumexps, onetile); + reduce_init_delta(REDUCE_OP, REDUCE_DIM); + for (uint32_t wt = 0; wt < Wt; wt++) { + cb_wait_front(cb_exps, wt+1); // must be a cumulative wait for correctness + constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB + reduce_tile(cb_exps, cb_bcast_scaler, wt, bcast_scaler0, dst0); + } + reduce_revert_delta(); + recip_tile_init(); + recip_tile(dst0); // DST[0] = 1/sum(exp(x)) + pack_tile(dst0, cb_recipsumexps); + cb_push_back(cb_recipsumexps, 1); + + REL(); + + cb_wait_front(cb_recipsumexps, 1); // will reuse Wt times for bcast + + unpack_reconfig_data_format(cb_exps, cb_recipsumexps); + pack_reconfig_data_format(cb_out0); + // now cb_sumexps has exp tiles, need to multiply by our DST[2] + // by now we already did a umulative wait for Wt tiles in cb_exps + mul_bcast_cols_init_short(); + for (uint32_t wt = 0; wt < Wt; wt += ndst) { + ACQ(); + cb_reserve_back(cb_out0, ndst); + for (uint32_t wt8 = 0; wt8 < ndst; wt8++) { + // wt+wt8 since we pop Wt after the entire loop + mul_tiles_bcast(cb_exps, cb_recipsumexps, wt+wt8, 0, wt8); // tile *= 1/(sum(exp(x))) + pack_tile(wt8, cb_out0); + } + cb_push_back(cb_out0, ndst); + REL(); + } + cb_pop_front(cb_recipsumexps, 1); + cb_pop_front(cb_exps, Wt); + } // NCHt loop + //cb_pop_front(cb_bcast_scaler, 1); // we don't actually have to do this + //cb_pop_front(cb_fused_scale, 1); // we 
don't actually have to do this +} +} diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp new file mode 100644 index 000000000000..303ae3386b1e --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp @@ -0,0 +1,177 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#define REDUCE_OP PoolType::SUM +#define REDUCE_DIM ReduceDim::REDUCE_ROW + +#include "compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/tile_move_copy.h" +#include "compute_kernel_api/bcast.h" +#include "compute_kernel_api/softmax.h" +#include "compute_kernel_api/reduce.h" + +ALWI void ACQ() { acquire_dst(tt::DstMode::Half); } +ALWI void REL() { release_dst(tt::DstMode::Half); } + +namespace NAMESPACE { +void MAIN { + + constexpr uint32_t block_h = get_compile_time_arg_val(0); + constexpr uint32_t block_w = get_compile_time_arg_val(1); + constexpr uint32_t subblock_w = get_compile_time_arg_val(2); + constexpr uint32_t num_subblocks_w = get_compile_time_arg_val(3); + + binary_op_init_common(tt::CB::c_in0, tt::CB::c_in1, tt::CB::c_intermed0); + + constexpr auto cb_in0 = tt::CB::c_in0; + constexpr auto cb_bcast_scaler = tt::CB::c_in1; + constexpr auto cb_fused_scale = tt::CB::c_in2; + constexpr auto cb_fused_attn = tt::CB::c_in3; + constexpr auto cb_exps = tt::CB::c_intermed0; + constexpr auto cb_recipsumexps = tt::CB::c_intermed1; + constexpr auto cb_scale_mask = tt::CB::c_intermed2; + constexpr auto cb_out0 = tt::CB::c_out0; + + constexpr int dst0 = 0; + int index_subblock_w_offset = 0; + int index = 0; + + for (uint32_t i = 0; i < block_h; i++) { + #if FUSED_SCALE_MASK + // fused scale + unpack_reconfig_data_format(cb_in0, cb_fused_scale); + pack_reconfig_data_format(cb_scale_mask); + cb_wait_front(cb_fused_scale, 1); + // UNPACK(( DPRINT << TSLICE(cb_fused_scale, 0, SliceRange::h0_w0_32()) << ENDL() )); + mul_tiles_bcast_scalar_init_short(); + index_subblock_w_offset = 0; + for (uint32_t j = 0; j < num_subblocks_w; j++) { + ACQ(); + cb_reserve_back(cb_scale_mask, subblock_w); + for (uint32_t w = 0; w < subblock_w; w++) { + index = w + index_subblock_w_offset; + mul_tiles_bcast_scalar(cb_in0, cb_fused_scale, index, 0, w); + pack_tile(w, cb_scale_mask); + } + cb_push_back(cb_scale_mask, subblock_w); + REL(); + index_subblock_w_offset += subblock_w; + } + cb_pop_front(cb_in0, block_w); + unpack_reconfig_data_format(cb_scale_mask, cb_fused_attn); + + // fused attn + cb_wait_front(cb_scale_mask, block_w); + + #ifndef SHARDED_CAUSAL_MASK + cb_wait_front(cb_fused_attn, block_w); + #endif + + index_subblock_w_offset = 0; + + #ifdef CAUSAL_MASK + add_tiles_init(); + #else + add_bcast_rows_init_short(); + #endif + + exp_tile_init(); + for (uint32_t j = 0; j < num_subblocks_w; j++) { + ACQ(); + #ifdef CAUSAL_MASK + for (uint32_t w = 0; w < subblock_w; w++) { + index = w + index_subblock_w_offset; + add_tiles(cb_scale_mask, cb_fused_attn, index, index, w); + } + #else + for (uint32_t w = 0; w < subblock_w; w++) { + index = w + index_subblock_w_offset; + add_tiles_bcast_rows(cb_scale_mask, cb_fused_attn, index, index, w); + } + #endif + cb_reserve_back(cb_exps, subblock_w); + for (uint32_t w = 0; w < subblock_w; w++) { + exp_tile(w); + pack_tile(w, cb_exps); + } + cb_push_back(cb_exps, subblock_w); + REL(); + index_subblock_w_offset += subblock_w; + } + 
cb_pop_front(cb_scale_mask, block_w); + + #ifdef CAUSAL_MASK + cb_pop_front(cb_fused_attn, block_w); + #endif + unpack_reconfig_data_format(cb_exps, cb_bcast_scaler); + + #else + unpack_reconfig_data_format(cb_in0, cb_in0); + pack_reconfig_data_format(cb_exps); + // exp(x) + index_subblock_w_offset = 0; + copy_tile_to_dst_init_short(); + exp_tile_init(); + for (uint32_t j = 0; j < num_subblocks_w; j++) { + ACQ(); + for (uint32_t w = 0; w < subblock_w; w++) { + index = w + index_subblock_w_offset; + copy_tile(cb_in0, index, w); + } + cb_reserve_back(cb_exps, subblock_w); + for (uint32_t w = 0; w < subblock_w; w++) { + exp_tile(w); + pack_tile(w, cb_exps); + } + cb_push_back(cb_exps, subblock_w); + REL(); + index_subblock_w_offset += subblock_w; + } + cb_pop_front(cb_in0, block_w); + unpack_reconfig_data_format(cb_exps, cb_bcast_scaler); + #endif // FUSED_SCALE_MASK + + // sum(exp(x)) + ACQ(); + reduce_init_delta(REDUCE_OP, REDUCE_DIM); + cb_wait_front(cb_exps, block_w); + cb_wait_front(cb_bcast_scaler, 1); + cb_reserve_back(cb_recipsumexps, 1); + for (uint32_t w = 0; w < block_w; w++) { + constexpr uint32_t bcast_scaler0 = 0; + reduce_tile(cb_exps, cb_bcast_scaler, w, bcast_scaler0, dst0); + } + reduce_revert_delta(); + recip_tile_init(); + recip_tile(dst0); + pack_tile(dst0, cb_recipsumexps); + cb_push_back(cb_recipsumexps, 1); + REL(); + + // exp(x) / (sum(exp(x))) + unpack_reconfig_data_format(cb_exps, cb_recipsumexps); + pack_reconfig_data_format(cb_out0); + cb_wait_front(cb_recipsumexps, 1); + mul_bcast_cols_init_short(); + index_subblock_w_offset = 0; + for (uint32_t j = 0; j < num_subblocks_w; j++) { + ACQ(); + cb_reserve_back(cb_out0, subblock_w); + for (uint32_t w = 0; w < subblock_w; w++) { + index = w + index_subblock_w_offset; + mul_tiles_bcast(cb_exps, cb_recipsumexps, index, 0, w); + pack_tile(w, cb_out0); + } + cb_push_back(cb_out0, subblock_w); + REL(); + index_subblock_w_offset += subblock_w; + } + cb_pop_front(cb_recipsumexps, 1); + cb_pop_front(cb_exps, block_w); + } + +} +} diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/readed_unary_sharded_sm_causal_mask_hw_dims.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/readed_unary_sharded_sm_causal_mask_hw_dims.cpp new file mode 100644 index 000000000000..f1694ce1193b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/readed_unary_sharded_sm_causal_mask_hw_dims.cpp @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include "tt_eager/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp" +#include "tt_eager/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp" + +// HW-bcast scale for fused scale-attn-softmax +FORCE_INLINE void generate_inv_sqrt_hw_bcast_tile() { + constexpr auto cb_fused_scale = tt::CB::c_in2; + uint32_t u = get_arg_val(1); + cb_reserve_back(cb_fused_scale, 1); + auto ptr = reinterpret_cast(get_write_ptr(cb_fused_scale)); + ptr[0] = u >> 16; + cb_push_back(cb_fused_scale, 1); +} + +void kernel_main() { + constexpr uint32_t cb_reduce_scaler = tt::CB::c_in1; + const uint32_t reduce_scaler = get_arg_val(0); + + constexpr uint32_t block_wt = get_compile_time_arg_val(0); + constexpr bool is_dram_mask = get_compile_time_arg_val(1) == 1; + + const uint32_t mask_addr = get_arg_val(2); + const uint32_t mask_start_tile_id = get_arg_val(3); + uint32_t mask_num_tiles = get_arg_val(4); + + constexpr uint32_t cb_attn = tt::CB::c_in3; + uint32_t mask_tile_bytes = get_tile_size(cb_attn); + const DataFormat mask_data_format = get_dataformat(cb_attn); + uint32_t mask_id = mask_start_tile_id; + + const InterleavedAddrGenFast addr_mask = { + .bank_base_address = mask_addr, .page_size = mask_tile_bytes, .data_format = mask_data_format}; + + constexpr auto cb_fused_scale = tt::CB::c_in2; + const uint32_t pre_scale = get_arg_val(1); + generate_bcast_unary_scalar(cb_fused_scale, pre_scale); + + constexpr uint32_t block_ht = get_compile_time_arg_val(4); + for (uint32_t h = 0; h < block_ht; h++) { + cb_reserve_back(cb_attn, block_wt); + uint32_t l1_write_addr = get_write_ptr(cb_attn); + for (uint32_t w = 0; w < block_wt; w++) { + noc_async_read_tile(mask_id, addr_mask, l1_write_addr); + l1_write_addr += mask_tile_bytes; + ++mask_id; + + if (h == 0 && w == 0) { + generate_reduce_scaler(cb_reduce_scaler, reduce_scaler); + } + } + noc_async_read_barrier(); + + cb_push_back(cb_attn, block_wt); + if (mask_id == mask_num_tiles) { + mask_id = 0; + } + } +} diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/reader_unary_interleaved_sm.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/reader_unary_interleaved_sm.cpp new file mode 100644 index 000000000000..100cc2a6fb19 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/reader_unary_interleaved_sm.cpp @@ -0,0 +1,142 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include "tt_eager/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp" +#include "tt_eager/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp" + +void kernel_main() { + + const uint32_t src_addr = get_arg_val(0); + const uint32_t blk = get_arg_val(1); + const uint32_t num_blks = get_arg_val(3); // same arg index as in reader_unary and in reader_unary_transpose_wh_8bank + const uint32_t tile_offset = get_arg_val(4); + const uint32_t Wt = get_arg_val(5); + + constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1; + constexpr uint32_t cb_id_in0 = 0, cb_id_in1 = 1; + + // ublocks size defined in tiles + constexpr uint32_t onetile = 1; + uint32_t src0_tile_bytes = get_tile_size(cb_id_in0); + const DataFormat src0_data_format = get_dataformat(cb_id_in0); + + #if FUSED_SCALE_MASK + uint32_t Ht = get_arg_val(6); + uint32_t mask_addr = get_arg_val(7); + uint32_t start_ht = get_arg_val(8); + uint32_t start_mask_id = get_arg_val(9); + constexpr bool mask_is_dram = get_compile_time_arg_val(1) == 1; + + constexpr uint32_t cb_id_attn = 4; + uint32_t mask_tile_bytes = get_tile_size(cb_id_attn); + const DataFormat mask_data_format = get_dataformat(cb_id_attn); + + const InterleavedAddrGenFast addr_mask = { + .bank_base_address = mask_addr, + .page_size = mask_tile_bytes, + .data_format = mask_data_format + }; + + #if CAUSAL_MASK + constexpr uint32_t num_tiles_causal_mask = get_compile_time_arg_val(2); + uint32_t mask_start_ht = get_arg_val(11); + uint32_t mask_offset = get_arg_val(12); + + uint32_t mask_id_offset = mask_offset; + uint32_t mask_ht = mask_start_ht; + #endif + + uint32_t ht = start_ht; + uint32_t mask_id = start_mask_id; + bool read_mask = true; + constexpr auto cb_fused_scale = tt::CB::c_in3; + const uint32_t pre_scale = get_arg_val(2); + generate_bcast_unary_scalar(cb_fused_scale, pre_scale); + #endif + + const InterleavedAddrGenFast src_a = { + .bank_base_address = src_addr, + .page_size = src0_tile_bytes, + .data_format = src0_data_format + }; + + + // TODO(AP): cleanup, probably with named args/param pack/reflection. + { + constexpr uint32_t cb_in_2 = 2; + const uint32_t reduce_scaler = get_arg_val(10); + generate_reduce_scaler(cb_in_2, reduce_scaler); + } + + // read a ublock of tiles from src to CB, and then push the ublock to unpacker + uint32_t i_tile = 0; + uint32_t curr_tile = tile_offset; + for (uint32_t i = 0; i num_tiles) ? 
num_tiles - i : blk; + cb_reserve_back(cb_id_in0, rem); + uint32_t l1_write_addr = get_write_ptr(cb_id_in0); + + for (uint32_t r = 0; r(1); + cb_reserve_back(cb_fused_scale, 1); + auto ptr = reinterpret_cast(get_write_ptr(cb_fused_scale)); + ptr[0] = u>>16; + cb_push_back(cb_fused_scale, 1); +} + +void kernel_main() { + + constexpr uint32_t cb_reduce_scaler = tt::CB::c_in1; + const uint32_t reduce_scaler = get_arg_val(0); + + #if FUSED_SCALE_MASK + constexpr uint32_t block_wt = get_compile_time_arg_val(0); + constexpr bool is_dram_mask = get_compile_time_arg_val(1) == 1; + const uint32_t mask_addr = get_arg_val(2); + const uint32_t mask_start_tile_id = get_arg_val(3); + + constexpr uint32_t cb_attn = tt::CB::c_in3; + uint32_t mask_tile_bytes = get_tile_size(cb_attn); + const DataFormat mask_data_format = get_dataformat(cb_attn); + uint32_t mask_id = mask_start_tile_id; + + const InterleavedAddrGenFast addr_mask = { + .bank_base_address = mask_addr, + .page_size = mask_tile_bytes, + .data_format = mask_data_format + }; + + constexpr auto cb_fused_scale = tt::CB::c_in2; + const uint32_t pre_scale = get_arg_val(1); + generate_bcast_unary_scalar(cb_fused_scale, pre_scale); + + #if defined(CAUSAL_MASK) && !defined(SHARDED_CAUSAL_MASK) + + constexpr uint32_t fused_head = get_compile_time_arg_val(4); + constexpr uint32_t mask_block_ht = get_compile_time_arg_val(6); + + for (uint32_t f = 0; f(2); + const uint32_t mask_start_tile_id = get_arg_val(3); + + constexpr uint32_t cb_attn = tt::CB::c_in3; + uint32_t mask_tile_bytes = get_tile_size(cb_attn); + + #define stick_size_is_pow2 get_compile_time_arg_val(2) == 1 + #if (stick_size_is_pow2) + constexpr uint32_t log_base_2_of_page_size = get_compile_time_arg_val(3); + #else + constexpr uint32_t page_size = get_compile_time_arg_val(3); + #endif + #if (stick_size_is_pow2) + const InterleavedPow2AddrGen addr_mask = { + .bank_base_address = mask_addr, + .log_base_2_of_page_size = log_base_2_of_page_size + }; + #else + const InterleavedAddrGen addr_mask = { + .bank_base_address = mask_addr, + .page_size = page_size + }; + #endif + + constexpr auto cb_fused_scale = tt::CB::c_in2; + const uint32_t pre_scale = get_arg_val(1); + generate_bcast_unary_scalar(cb_fused_scale, pre_scale); + + constexpr uint32_t FLOAT32_DTYPE = get_compile_time_arg_val(4); + uint32_t mask_read_tile_face_bytes = FLOAT32_DTYPE ? 64 : 32; + uint32_t mask_read_tile_offset_bytes = FLOAT32_DTYPE ? 1024 : 512; + + cb_reserve_back(cb_attn, block_wt); + uint32_t l1_write_addr = get_write_ptr(cb_attn); + for (uint32_t w = 0; w(0); + generate_reduce_scaler(cb_reduce_scaler, reduce_scaler); + } +} diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/writer_unary_interleaved_start_id_blocked_sm.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/writer_unary_interleaved_start_id_blocked_sm.cpp new file mode 100644 index 000000000000..885c558c41f8 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/writer_unary_interleaved_start_id_blocked_sm.cpp @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" + +// #include "debug/dprint.h" + +// H-bcast mask +FORCE_INLINE void generate_bcast_row_mask(const uint32_t cb_id, const uint32_t num_datum_padded, const uint32_t val) { + const uint32_t mask_val = val>>16; + cb_reserve_back(cb_id, 1); + volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(get_write_ptr(cb_id)); + + if (num_datum_padded > 16) { + uint32_t num_datum_unpadded_f1 = 32 - num_datum_padded; + uint32_t idx = 0; + for (uint32_t j = 0; j < num_datum_unpadded_f1; ++j) { // first face + ptr[idx + j] = 0; + } + for (uint32_t j = num_datum_unpadded_f1; j < 16; ++j) { // first face + ptr[idx + j] = mask_val; + } + + idx = 1 << 8; + for (uint32_t j = 0; j < 16; ++j) { // second face + ptr[idx + j] = mask_val; + } + } else { + uint32_t num_datum_unpadded_f2 = 16 - num_datum_padded; + uint32_t idx = 0; + for (uint32_t j = 0; j < 16; ++j) { // first face + ptr[idx + j] = 0; + } + + idx = 1 << 8; + for (uint32_t j = 0; j < num_datum_unpadded_f2; ++j) { // second face + ptr[idx + j] = 0; + } + for (uint32_t j = num_datum_unpadded_f2; j < 16; ++j) { // second face + ptr[idx + j] = mask_val; + } + } + + cb_push_back(cb_id, 1); +} + +void kernel_main() { + const uint32_t dst_addr = get_arg_val(0); + const uint32_t num_tiles = get_arg_val(1); + const uint32_t tile_offset = get_arg_val(2); + const uint32_t blk = get_arg_val(3); + + constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; + + + constexpr uint32_t cb_id_out0 = 16; + constexpr uint32_t onetile = 1; + const uint32_t tile_bytes = get_tile_size(cb_id_out0); + const DataFormat data_format = get_dataformat(cb_id_out0); + + constexpr uint32_t cb_id_mask = tt::CB::c_in5; + const uint32_t mask_padded_data = get_arg_val(4); + const uint32_t num_datum_padded = get_arg_val(5); + const uint32_t val_to_pad = get_arg_val(6); + if (mask_padded_data) { + generate_bcast_row_mask(cb_id_mask, num_datum_padded, val_to_pad); + } + + const InterleavedAddrGenFast s = { + .bank_base_address = dst_addr, + .page_size = tile_bytes, + .data_format = data_format + }; + + uint32_t tile_id = tile_offset; + for (uint32_t i = 0; i + +using namespace tt::constants; +namespace ttnn::operations::normalization { + +inline bool is_dram(const Tensor& input_tensor) { return input_tensor.memory_config().buffer_type == BufferType::DRAM; } +inline bool is_dram(const std::optional input_tensor) { + return input_tensor.has_value() ? 
is_dram(input_tensor.value()) : true; +} +inline bool is_dram(const Buffer* b) { return b->buffer_type() == BufferType::DRAM; } + +// implementation of softmax with optional scale/mask (see the header for input_tensor more detailed description) +operation::ProgramWithCallbacks scale_mask_softmax_multi_core( + const Tensor &input_tensor, + const Tensor &output_tensor, + const std::optional mask, + std::optional scale, + bool causal_mask, + DeviceComputeKernelConfig compute_kernel_config +) { + + const auto shape = input_tensor.get_legacy_shape(); + uint32_t W = shape[-1], H = (input_tensor.volume() / (shape[0] * shape[-1])), NC = shape[0]; + uint32_t HW = H*W; + + bool mask_padded_data = false; + uint32_t num_datum_padded = 0; + const auto shape_unpadded = input_tensor.get_shape(); + uint32_t W_unpadded = shape_unpadded[-1]; + if (W > W_unpadded) { + mask_padded_data = true; + num_datum_padded = W - W_unpadded; + } + + uint32_t Wt = W/TILE_WIDTH; + uint32_t Ht = H/TILE_HEIGHT; + + uint32_t mask_H = H; + if (mask.has_value()) { + mask_H = mask.value().get_legacy_shape()[2]; + } + uint32_t mask_Ht = mask_H/TILE_HEIGHT; + + Program program = CreateProgram(); + + // This should allocate input_tensor DRAM buffer on the device + Device *device = input_tensor.device(); + + tt::DataFormat in0_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + uint32_t in0_tile_size = tt::tt_metal::detail::TileSize(in0_cb_data_format); + + MathFidelity math_fidelity; + bool math_approx_mode; + bool fp32_dest_acc_en; + + std::visit([&](auto&& compute_kernel_config) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + TT_ASSERT(device->arch() == ARCH::GRAYSKULL, "kernel config is not for graykull"); + math_fidelity = compute_kernel_config.math_fidelity; + math_approx_mode = compute_kernel_config.math_approx_mode; + fp32_dest_acc_en = false; + } else if constexpr (std::is_same_v) { + TT_ASSERT(device->arch() == ARCH::WORMHOLE_B0, "kernel config is not for wormhole_b0"); + math_fidelity = compute_kernel_config.math_fidelity; + math_approx_mode = compute_kernel_config.math_approx_mode; + fp32_dest_acc_en = in0_cb_data_format == tt::DataFormat::Float32 ? true : compute_kernel_config.fp32_dest_acc_en; + } else { + TT_FATAL("arch not supported"); + } + + }, compute_kernel_config); + + tt::DataFormat scalar_cb_data_format = tt::DataFormat::Float16_b; + uint32_t scalar_tile_size = tt::tt_metal::detail::TileSize(scalar_cb_data_format); + + tt::DataFormat out0_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype()); + uint32_t out0_tile_size = tt::tt_metal::detail::TileSize(out0_cb_data_format); + + tt::DataFormat mask_cb_data_format = mask.has_value() ? tt::tt_metal::datatype_to_dataformat_converter(mask.value().get_dtype()) : tt::DataFormat::Float16_b; + uint32_t mask_tile_size = tt::tt_metal::detail::TileSize(mask_cb_data_format); + + tt::DataFormat im_cb_data_format = fp32_dest_acc_en ? 
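+    // The intermediate format resolved on this line follows the destination-accumulation
+    // mode chosen above: Float32 intermediates when fp32_dest_acc_en is set (it is forced
+    // on for Float32 inputs on WORMHOLE_B0 in the std::visit above), Float16_b otherwise.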
tt::DataFormat::Float32 : tt::DataFormat::Float16_b; + uint32_t im_tile_size = tt::tt_metal::detail::TileSize(im_cb_data_format); + + tt::log_debug("in0_cb_data_format: {}", in0_cb_data_format); + tt::log_debug("out0_cb_data_format: {}", out0_cb_data_format); + tt::log_debug("mask_cb_data_format: {}", mask_cb_data_format); + tt::log_debug("im_cb_data_format: {}", im_cb_data_format); + tt::log_debug("math_fidelity: {}", math_fidelity); + tt::log_debug("math_approx_mode: {}", math_approx_mode); + tt::log_debug("fp32_dest_acc_en: {}", fp32_dest_acc_en); + + auto src0_buffer = input_tensor.buffer(); + auto out0_buffer = output_tensor.buffer(); + + uint32_t num_tiles = input_tensor.volume()/TILE_HW; + + uint32_t block_size = fp32_dest_acc_en ? find_max_divisor(Wt, 4) : find_max_divisor(Wt, 8); + + // These tile capacity counts for CBs need to match the number of tiles expected by the kernel (softmax.cpp) + uint32_t in0_t = block_size*2; + uint32_t out0_t = block_size*2; + uint32_t im1_t = 1; // 1/sum(exp(x)) + uint32_t in2_t = 1; // scaler for reduce coming from reader + uint32_t in3_t = 1; // 1/sqrt() scaler tile cb for fused scale/mask/softmax variant + uint32_t in4_t = tt::div_up(Wt, block_size)*block_size; // attention mask (N,C,32,W) - Wt is reused for each Ht, NC is cycled + uint32_t in5_t = 1; + + // cb_exps - keeps exps in tt::CB in L1 to avoid recomputing + uint32_t im0_t = block_size*tt::div_up(Wt, block_size); + TT_ASSERT(im0_t == Wt); + + // used for buffering scale-mask + // can't easily reuse im0_t because cumulative wait for Wt needs to have Wt tiles contiguous free + uint32_t im3_t = block_size*(tt::div_up(Wt, block_size)+1); + TT_ASSERT(im3_t == Wt+block_size); + + TT_ASSERT(Wt % block_size == 0); + TT_ASSERT((block_size != -1) && "Wt must be divisible by one of the numbers in the range from 8 to 1."); + TT_ASSERT(im0_t % block_size == 0 && "Size of cb must be divisible by the size of block used by the reader and compute kernel."); + TT_ASSERT(out0_t % block_size == 0 && "Size of cb must be divisible by the size of block used by the reader and compute kernel."); + TT_ASSERT(in4_t % block_size == 0); + TT_ASSERT(W <= TILE_WIDTH*im0_t && "W exceeds the maximum supported size of tile buffer (kernel limitation right now)."); + + uint32_t num_tile_rows = NC * Ht; + auto grid_size = device->compute_with_storage_grid_size(); + auto all_device_cores = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); + auto [num_cores, all_cores, core_group_1, core_group_2, num_tile_rows_per_core_group_1, num_tile_rows_per_core_group_2] = split_work_to_cores(grid_size, num_tile_rows, true); + + bool src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + bool out0_is_dram = out0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = { + // interleaved accessor args + src0_is_dram + }; + if (mask.has_value()) { + bool mask_is_dram = mask.value().buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? 
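+    // As a worked example of the CB sizing above (hypothetical shape, not taken from this
+    // change): for W = 2048 we get Wt = 64 and block_size = find_max_divisor(64, 8) = 8,
+    // so in0_t = out0_t = 16 (double-buffered blocks), im0_t = div_up(64, 8) * 8 = 64 = Wt,
+    // im3_t = 8 * (div_up(64, 8) + 1) = 72 = Wt + block_size, and in4_t = 64, which is
+    // exactly what the TT_ASSERTs above require.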
1 : 0; + reader_compile_time_args.push_back(mask_is_dram); + } + if (causal_mask) { + uint32_t num_tiles_causal_mask = mask.value().get_legacy_shape()[-1] * mask.value().get_legacy_shape()[-2] / TILE_WIDTH / TILE_HEIGHT; + reader_compile_time_args.push_back(num_tiles_causal_mask); + } + + std::vector writer_compile_time_args = {// interleaved accessor args + out0_is_dram}; + std::map softmax_defines; + if (mask.has_value()) { + softmax_defines["FUSED_SCALE_MASK"] = "1"; + } + if (causal_mask) { + softmax_defines["CAUSAL_MASK"] = "1"; + } + auto reader_kernels_id = CreateKernel( + program, "ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/reader_unary_interleaved_sm.cpp", all_device_cores, + tt::tt_metal::ReaderDataMovementConfig( + reader_compile_time_args, + softmax_defines + )); + + auto writer_kernels_id = CreateKernel( + program, "ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/writer_unary_interleaved_start_id_blocked_sm.cpp", all_device_cores, + tt::tt_metal::WriterDataMovementConfig( + writer_compile_time_args, + softmax_defines + )); + + // for broadcasting in H direction we need to + // NCHt, Nt, Wt + // if wtpc < Ht then since we pass tpc to the kernel as Ht, the broadcasts should be correct + // if wtpc >= Ht then tpc should be a multiple of Ht + + softmax_defines["EXP_APPROX"] = math_approx_mode ? "1" : "0"; + auto softmax_kernels_id = CreateKernel( + program, "ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp", all_device_cores, + tt::tt_metal::ComputeConfig{ + .math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .math_approx_mode = math_approx_mode, + .compile_args = {}, + .defines = softmax_defines + }); + + // Create circular buffers + // see softmax.cpp for which buffers are needed + + auto c_in0_config = CircularBufferConfig(in0_t * in0_tile_size, {{tt::CB::c_in0, in0_cb_data_format}}).set_page_size(tt::CB::c_in0, in0_tile_size); + auto cb_in0_id = CreateCircularBuffer( program, all_device_cores, c_in0_config); + auto c_out0_config = CircularBufferConfig(out0_t * out0_tile_size, {{tt::CB::c_out0, out0_cb_data_format}}).set_page_size(tt::CB::c_out0, out0_tile_size); + auto cb_out0_id = CreateCircularBuffer( program, all_device_cores, c_out0_config ); + auto c_intermed1_config = CircularBufferConfig(im1_t * im_tile_size, {{tt::CB::c_intermed1, im_cb_data_format}}).set_page_size(tt::CB::c_intermed1, im_tile_size); + auto cb_intermed1_id = CreateCircularBuffer( program, all_device_cores, c_intermed1_config ); + auto c_in2_config = CircularBufferConfig(in2_t * scalar_tile_size, {{tt::CB::c_in2, scalar_cb_data_format}}).set_page_size(tt::CB::c_in2, scalar_tile_size); + auto cb_in2_id = CreateCircularBuffer( program, all_device_cores, c_in2_config ); + auto c_intermed0_config = CircularBufferConfig(im0_t * im_tile_size, {{tt::CB::c_intermed0, im_cb_data_format}}).set_page_size(tt::CB::c_intermed0, im_tile_size); + auto cb_intermed0_id = CreateCircularBuffer( program, all_device_cores, c_intermed0_config ); + std::optional cb_intermed3_id; + std::optional cb_in3_id; + std::optional cb_in4_id; + std::optional cb_in5_id; + if (mask.has_value()) { + CircularBufferConfig c_intermed3_config = CircularBufferConfig(im3_t * im_tile_size, {{tt::CB::c_intermed3, im_cb_data_format}}).set_page_size(tt::CB::c_intermed3, im_tile_size); + cb_intermed3_id = CreateCircularBuffer( program, all_device_cores, c_intermed3_config ); + CircularBufferConfig c_in3_config = CircularBufferConfig(in3_t * 
scalar_tile_size, {{tt::CB::c_in3, scalar_cb_data_format}}).set_page_size(tt::CB::c_in3, scalar_tile_size); + cb_in3_id = CreateCircularBuffer( program, all_device_cores, c_in3_config ); + CircularBufferConfig c_in4_config = CircularBufferConfig(in4_t * mask_tile_size, {{tt::CB::c_in4, mask_cb_data_format}}).set_page_size(tt::CB::c_in4, mask_tile_size); + cb_in4_id = CreateCircularBuffer( program, all_device_cores, c_in4_config); + } + CircularBufferConfig c_in5_config = CircularBufferConfig(in5_t * mask_tile_size, {{tt::CB::c_in5, mask_cb_data_format}}).set_page_size(tt::CB::c_in5, mask_tile_size); + cb_in5_id = CreateCircularBuffer( program, all_device_cores, c_in5_config); + + uint32_t src_addr = src0_buffer->address(); + uint32_t mask_addr = mask.has_value() ? mask.value().buffer()->address() : 0; + uint32_t out_addr = out0_buffer->address(); + + uint32_t curr_row = 0; + union { float f; uint32_t u; } s; s.f = scale.value_or(1.0f); // scale for fused scale-mask-softmax + for (uint32_t i = 0; i < grid_size.x * grid_size.y; ++i) { + CoreCoord core = {i % grid_size.x, i / grid_size.x}; + if (i >= num_cores) { + SetRuntimeArgs(program, reader_kernels_id, core, { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }); // [8]=1.0f is scaler + SetRuntimeArgs(program, softmax_kernels_id, core, { 0, 0, 0, 0, 0, 0 }); + SetRuntimeArgs(program, writer_kernels_id, core, { 0, 0, 0, 0, 0, 0, 0 }); + continue; + } + uint32_t num_tile_rows_per_core = 0; + if (core_group_1.core_coord_in_core_ranges(core)) { + num_tile_rows_per_core = num_tile_rows_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + num_tile_rows_per_core = num_tile_rows_per_core_group_2; + } else { + TT_ASSERT(false, "Core not in specified core ranges"); + } + + uint32_t tile_offset = curr_row * Wt; + uint32_t curr_ht = curr_row % Ht; + uint32_t mask_curr_ht = curr_ht % mask_Ht; // the start offset for causal mask + uint32_t mask_offset = curr_row / Ht * mask_Ht * Wt; // causal mask batch offset + uint32_t mask_id = causal_mask ? 
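+        // The scale is handed to the reader as raw IEEE-754 bits via the union above; the
+        // kernel keeps only the top 16 bits (a bfloat16) when it builds the HW-bcast scale
+        // tile. Likewise, the 0x3f803f80 runtime arg below is two bfloat16 1.0 values
+        // (0x3f80) packed into one 32-bit word and is consumed as the reduce scaler.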
(mask_curr_ht * Wt + mask_offset) : (curr_row / Ht * Wt); // causal mask start offset + causal mask batch offset + + if (causal_mask) { + SetRuntimeArgs(program, reader_kernels_id, core, { src_addr, block_size, s.u, num_tile_rows_per_core, tile_offset, Wt, Ht, mask_addr, curr_ht, mask_id, 0x3f803f80, mask_curr_ht, mask_offset }); // [10]=1.0f is scaler + } else { + SetRuntimeArgs(program, reader_kernels_id, core, { src_addr, block_size, s.u, num_tile_rows_per_core, tile_offset, Wt, Ht, mask_addr, curr_ht, mask_id, 0x3f803f80 }); // [10]=1.0f is scaler + } + + SetRuntimeArgs(program, softmax_kernels_id, core, { num_tile_rows_per_core, Ht, Wt, block_size, curr_ht, mask_padded_data }); + + SetRuntimeArgs(program, writer_kernels_id, core, { out_addr, num_tile_rows_per_core * Wt, tile_offset, block_size, mask_padded_data, num_datum_padded, 0xFF00FF00}); + + curr_row += num_tile_rows_per_core; + } + + auto override_runtime_arguments_callback = [ + reader_kernels_id, + writer_kernels_id, + softmax_kernels_id, + grid_size, + scalar_tile_size, + in0_tile_size, + im_tile_size, + out0_tile_size, + mask_tile_size, + cb_in0_id, + cb_out0_id, + cb_intermed1_id, + cb_in2_id, + cb_intermed0_id, + cb_intermed3_id, + cb_in3_id, + cb_in4_id, + causal_mask + ] + ( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors + ) { + + const auto scale = static_cast(operation)->scale; + + auto src_buffer_address = input_tensors.at(0).buffer()->address(); + auto mask_buffer_address = optional_input_tensors.at(0).has_value() ? optional_input_tensors.at(0).value().buffer()->address() : 0; + auto dst_buffer_address = output_tensors.size() == 1 ? output_tensors.at(0).buffer()->address() : src_buffer_address; + + const auto shape = input_tensors.at(0).get_legacy_shape(); + uint32_t W = shape[-1], H = (input_tensors.at(0).volume() / (shape[0] * shape[-1])), NC = shape[0]; + uint32_t HW = H*W; + + uint32_t Wt = W/TILE_WIDTH; + uint32_t Ht = H/TILE_HEIGHT; + + bool mask_padded_data = false; + uint32_t num_datum_padded = 0; + const auto shape_unpadded = input_tensors.at(0).get_shape(); + uint32_t W_unpadded = shape_unpadded[-1]; + if (W > W_unpadded) { + mask_padded_data = true; + num_datum_padded = W - W_unpadded; + } + + int32_t num_tiles = input_tensors.at(0).volume()/TILE_HW; + uint32_t block_size = find_max_divisor(Wt, 8); + + // These tile capacity counts for CBs need to match the number of tiles expected by the kernel (softmax.cpp) + uint32_t in0_t = block_size*2; + uint32_t out0_t = block_size*2; + uint32_t im1_t = 1; // 1/sum(exp(x)) + uint32_t in2_t = 1; // scaler for reduce coming from reader + uint32_t in3_t = 1; // 1/sqrt() scaler tile cb for fused scale/mask/softmax variant + uint32_t in4_t = tt::div_up(Wt, block_size)*block_size; // attention mask (N,C,32,W) - Wt is reused for each Ht, NC is cycled + + // cb_exps - keeps exps in tt::CB in L1 to avoid recomputing + uint32_t im0_t = block_size*tt::div_up(Wt, block_size); + TT_ASSERT(im0_t == Wt); + + // used for buffering scale-mask + // can't easily reuse im0_t because cumulative wait for Wt needs to have Wt tiles contiguous free + uint32_t im3_t = block_size*(tt::div_up(Wt, block_size)+1); + TT_ASSERT(im3_t == Wt+block_size); + + TT_ASSERT(Wt % block_size == 0); + TT_ASSERT((block_size != -1) && "Wt must be divisible by one of the numbers in the range from 8 to 1."); + TT_ASSERT(im0_t % block_size == 0 && "Size of cb must be divisible by the size of block 
used by the reader and compute kernel."); + TT_ASSERT(out0_t % block_size == 0 && "Size of cb must be divisible by the size of block used by the reader and compute kernel."); + TT_ASSERT(in4_t % block_size == 0); + TT_ASSERT(W <= TILE_WIDTH*im0_t && "W exceeds the maximum supported size of tile buffer (kernel limitation right now)."); + + uint32_t NCHt = NC*Ht; + uint32_t num_tile_rows = NC * Ht; + auto all_device_cores = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); + auto [num_cores, all_cores, core_group_1, core_group_2, num_tile_rows_per_core_group_1, num_tile_rows_per_core_group_2] = split_work_to_cores(grid_size, num_tile_rows, true); + + UpdateCircularBufferTotalSize(program, cb_in0_id, in0_t * in0_tile_size); + UpdateCircularBufferTotalSize(program, cb_out0_id, out0_t * out0_tile_size); + UpdateCircularBufferTotalSize(program, cb_intermed1_id, im1_t * im_tile_size); + UpdateCircularBufferTotalSize(program, cb_in2_id, in2_t * scalar_tile_size); + UpdateCircularBufferTotalSize(program, cb_intermed0_id, im0_t * im_tile_size); + + if (optional_input_tensors.at(0).has_value()) { + UpdateCircularBufferTotalSize(program, cb_intermed3_id.value(), im3_t * im_tile_size); + UpdateCircularBufferTotalSize(program, cb_in3_id.value(), in3_t * scalar_tile_size); + UpdateCircularBufferTotalSize(program, cb_in4_id.value(), in4_t * mask_tile_size); + } + + uint32_t curr_row = 0; + union { float f; uint32_t u; } s; s.f = scale.value_or(1.0f); // scale for fused scale-mask-softmax + for (uint32_t i = 0; i < grid_size.x * grid_size.y; ++i) { + CoreCoord core = {i % grid_size.x, i / grid_size.x}; + if (i >= num_cores) { + SetRuntimeArgs(program, reader_kernels_id, core, { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }); // [8]=1.0f is scaler + SetRuntimeArgs(program, softmax_kernels_id, core, { 0, 0, 0, 0, 0, 0 }); + SetRuntimeArgs(program, writer_kernels_id, core, { 0, 0, 0, 0, 0, 0, 0}); + continue; + } + + uint32_t num_tile_rows_per_core = 0; + if (core_group_1.core_coord_in_core_ranges(core)) { + num_tile_rows_per_core = num_tile_rows_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + num_tile_rows_per_core = num_tile_rows_per_core_group_2; + } else { + TT_ASSERT(false, "Core not in specified core ranges"); + } + + uint32_t tile_offset = curr_row * Wt; + uint32_t curr_ht = curr_row % Ht; + uint32_t mask_curr_ht = curr_ht % Wt; // the start offset for causal mask + uint32_t mask_offset = curr_row / Ht * Wt * Wt; // causal mask batch offset + uint32_t mask_id = causal_mask ? 
(mask_curr_ht * Wt + mask_offset) : (curr_row / Ht * Wt); // causal mask start offset + causal mask batch offset + + if (causal_mask) { + SetRuntimeArgs(program, reader_kernels_id, core, { src_buffer_address, block_size, s.u, num_tile_rows_per_core, tile_offset, Wt, Ht, mask_buffer_address, curr_ht, mask_id, 0x3f803f80, mask_curr_ht, mask_offset }); // [10]=1.0f is scaler + } else { + SetRuntimeArgs(program, reader_kernels_id, core, { src_buffer_address, block_size, s.u, num_tile_rows_per_core, tile_offset, Wt, Ht, mask_buffer_address, curr_ht, mask_id, 0x3f803f80 }); // [10]=1.0f is scaler + } + + SetRuntimeArgs(program, softmax_kernels_id, core, { num_tile_rows_per_core, Ht, Wt, block_size, curr_ht, mask_padded_data }); + + SetRuntimeArgs(program, writer_kernels_id, core, { dst_buffer_address, num_tile_rows_per_core * Wt, tile_offset, block_size, mask_padded_data, num_datum_padded, 0xFF00FF00}); + + curr_row += num_tile_rows_per_core; + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} // scale_mask_softmax_multi_core + +// implementation of softmax with optional scale/mask (see the header for input_tensor more detailed description) +operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core( + const Tensor &input_tensor, + const Tensor &output_tensor, + const std::optional mask, + std::optional scale, + bool causal_mask, + bool hw_dims_only_causal_mask, + CoreCoord grid_size, + uint32_t subblock_wt, + uint32_t block_ht, + uint32_t block_wt, + DeviceComputeKernelConfig compute_kernel_config +) { + //////////////////////////////////////////////////////////////////////////// + // Device Setup + //////////////////////////////////////////////////////////////////////////// + Device *device = input_tensor.device(); + + // convert data format + tt::DataFormat in0_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + + MathFidelity math_fidelity; + bool math_approx_mode; + bool fp32_dest_acc_en; + + std::visit([&](auto&& compute_kernel_config) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + TT_ASSERT(device->arch() == ARCH::GRAYSKULL, "kernel config is not for graykull"); + math_fidelity = compute_kernel_config.math_fidelity; + math_approx_mode = compute_kernel_config.math_approx_mode; + fp32_dest_acc_en = false; + } else if constexpr (std::is_same_v) { + TT_ASSERT(device->arch() == ARCH::WORMHOLE_B0, "kernel config is not for wormhole_b0"); + math_fidelity = compute_kernel_config.math_fidelity; + math_approx_mode = compute_kernel_config.math_approx_mode; + fp32_dest_acc_en = in0_cb_data_format == tt::DataFormat::Float32 ? true : compute_kernel_config.fp32_dest_acc_en; + if (fp32_dest_acc_en) + TT_FATAL(subblock_wt <= 4, "in fp32 mode, subblock width must be smaller/equal than 4"); + } else { + TT_FATAL("arch not supported"); + } + + }, compute_kernel_config); + + tt::DataFormat out0_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype()); + tt::DataFormat im_cb_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b; + tt::DataFormat mask_cb_data_format = mask.has_value() ? 
tt::tt_metal::datatype_to_dataformat_converter(mask->get_dtype()) : tt::DataFormat::Float16_b; + tt::DataFormat scale_cb_data_format = tt::DataFormat::Float16_b; + tt::DataFormat scalar_cb_data_format = tt::DataFormat::Float16_b; + + tt::log_debug("in0_cb_data_format: {}", in0_cb_data_format); + tt::log_debug("out0_cb_data_format: {}", out0_cb_data_format); + tt::log_debug("mask_cb_data_format: {}", mask_cb_data_format); + tt::log_debug("im_cb_data_format: {}", im_cb_data_format); + tt::log_debug("scale_cb_data_format: {}", im_cb_data_format); + tt::log_debug("scalar_cb_data_format: {}", im_cb_data_format); + tt::log_debug("math_fidelity: {}", math_fidelity); + tt::log_debug("math_approx_mode: {}", math_approx_mode); + tt::log_debug("fp32_dest_acc_en: {}", fp32_dest_acc_en); + + // tensor shape + const auto shard_orient = input_tensor.shard_spec().value().orientation; + const auto shape = input_tensor.get_legacy_shape(); + uint32_t M = shape[2] * shape[0]; + uint32_t K = shape[3] * shape[1]; + uint32_t Mt = M / TILE_WIDTH; + uint32_t Kt = K / TILE_WIDTH; + uint32_t num_cores_per_batch = (shape[1] * shape[2] * shape[3]) / (input_tensor.shard_spec().value().shape[0] * input_tensor.shard_spec().value().shape[1]); + + uint32_t mask_H = shape[2]; + if (mask.has_value()) { + mask_H = mask->get_legacy_shape()[2]; + } + uint32_t mask_Ht = mask_H/TILE_HEIGHT; + // block + uint32_t block_w = block_wt * TILE_WIDTH; + uint32_t block_h = block_ht * TILE_WIDTH; + uint32_t num_subblocks_w = block_wt / subblock_wt; + + // single tile sizes + uint32_t im_tile_size = tt::tt_metal::detail::TileSize(im_cb_data_format); + uint32_t in0_tile_size = tt::tt_metal::detail::TileSize(in0_cb_data_format); + uint32_t out0_tile_size = tt::tt_metal::detail::TileSize(out0_cb_data_format); + uint32_t mask_tile_size = tt::tt_metal::detail::TileSize(mask_cb_data_format); + uint32_t scale_tile_size = tt::tt_metal::detail::TileSize(scale_cb_data_format); + uint32_t scalar_tile_size = tt::tt_metal::detail::TileSize(scalar_cb_data_format); + // in out buffer + auto src0_buffer = input_tensor.buffer(); + auto out0_buffer = output_tensor.buffer(); + // num tiles + uint32_t num_tiles = input_tensor.volume()/TILE_HW; + + + //////////////////////////////////////////////////////////////////////////// + // Parameters Setup + //////////////////////////////////////////////////////////////////////////// + // block size for in0 (tensor a) + uint32_t in0_CB_size = block_wt * block_ht * in0_tile_size; + // scaler for reduce coming from reader + uint32_t in1_CB_size = 1 * scalar_tile_size; + // 1/sqrt() scaler tile cb for fused scale/mask/softmax variant + uint32_t in2_CB_size = 1 * scale_tile_size; + // attention mask + uint32_t in3_CB_size; + if (causal_mask) { + if (mask.value().is_sharded()) { + in3_CB_size = block_wt * block_ht * mask_tile_size; + } else { + in3_CB_size = block_wt * mask_tile_size; + if (!hw_dims_only_causal_mask) { + // For some reason, if we have hw_dims_causal_mask version, single buffering is up to ~20% faster + // Then double buffering CB3. 
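+                // As a rough worked example (hypothetical shard sizes, not taken from this
+                // change): with 2048-byte bfloat16 tiles and block_ht = 8, block_wt = 12,
+                // the sharded in0/out CBs above take 8 * 12 * 2048 B = 192 KiB each, while
+                // this interleaved causal-mask CB holds only block_wt tiles, i.e. 24 KiB
+                // single-buffered or 48 KiB once doubled below.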
+ in3_CB_size *= 2; + } + } + } else { + in3_CB_size = block_wt * mask_tile_size; + } + // cb_exps - keeps exps in tt::CB in L1 to avoid recomputing + uint32_t im0_CB_size = block_wt * im_tile_size; + // 1/sum(exp(x)) + uint32_t im1_CB_size = 1 * im_tile_size; + // attn mask im + uint32_t im2_CB_size = block_wt * im_tile_size; + // output buffer size + uint32_t out_CB_size = block_wt * block_ht * out0_tile_size; + + //////////////////////////////////////////////////////////////////////////// + // Application Setup + //////////////////////////////////////////////////////////////////////////// + Program program = CreateProgram(); + // define core ranges + uint32_t start_core_x = 0; + uint32_t start_core_y = 0; + uint32_t num_cores_c = grid_size.x; + uint32_t num_cores_r = grid_size.y; + uint32_t num_cores = num_cores_c * num_cores_r; + CoreRange all_device_cores( + {(std::size_t) start_core_x, (std::size_t) start_core_y}, + {(std::size_t) start_core_x + num_cores_c - 1, (std::size_t) start_core_y + num_cores_r - 1}); + // reader compile arg + bool is_dram_mask = 0; + if (mask.has_value()) { + is_dram_mask = mask->buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + } + std::vector reader_compile_time_args = { + (std::uint32_t) block_wt, + (std::uint32_t) is_dram_mask + }; + std::map softmax_defines; + // hw_dims_only_causal_mask does not support RM Layout atm + bool use_row_major_kernel = (mask.has_value() and mask->get_layout() == Layout::ROW_MAJOR); + if (use_row_major_kernel) { + auto mask_stick_size = mask->get_legacy_shape()[3] * mask->element_size(); + bool mask_stick_size_is_power_of_two = is_power_of_two_at_least_32(mask_stick_size); + reader_compile_time_args.push_back((std::uint32_t) mask_stick_size_is_power_of_two); + if (mask_stick_size_is_power_of_two) { + uint32_t mask_log2_stick_size = (std::uint32_t)log2(mask_stick_size); + reader_compile_time_args.push_back((std::uint32_t) mask_log2_stick_size); + } else { + reader_compile_time_args.push_back(mask_stick_size); + } + } else { + reader_compile_time_args.push_back(0); + reader_compile_time_args.push_back(0); + } + if (causal_mask) { + if (!hw_dims_only_causal_mask) { + reader_compile_time_args.push_back((std::uint32_t) block_ht / mask_Ht); // fused head + } else { + reader_compile_time_args.push_back((std::uint32_t) block_ht); + } + } + reader_compile_time_args.push_back((std::uint32_t) (mask_cb_data_format == tt::DataFormat::Float32)); // mask float32 + reader_compile_time_args.push_back((std::uint32_t) mask_Ht); + + if (mask.has_value()) { + softmax_defines["FUSED_SCALE_MASK"] = "1"; + } + if (causal_mask) { + softmax_defines["CAUSAL_MASK"] = "1"; + if (mask.value().is_sharded()) + softmax_defines["SHARDED_CAUSAL_MASK"] = "1"; + } + std::string reader_kernel_path; + if (use_row_major_kernel) { + reader_kernel_path = "ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/reader_unary_sharded_sm_rm_mask.cpp"; + } else if (!hw_dims_only_causal_mask) { + reader_kernel_path = "ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/reader_unary_sharded_sm.cpp"; + } else { + reader_kernel_path = "ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/dataflow/readed_unary_sharded_sm_causal_mask_hw_dims.cpp"; + } + auto reader_kernels_id = CreateKernel( + program, + reader_kernel_path, + all_device_cores, + tt::tt_metal::ReaderDataMovementConfig( + reader_compile_time_args, + softmax_defines + )); + // compute kernel compile time args + std::vector compute_compile_time_args = { 
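+        // The sharded compute kernel is configured purely at compile time by the block and
+        // subblock geometry; num_subblocks_w = block_wt / subblock_wt was computed above,
+        // so a hypothetical block_wt = 12 with subblock_wt = 4 yields three width subblocks.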
+ block_ht, + block_wt, + subblock_wt, + num_subblocks_w, + }; + softmax_defines["EXP_APPROX"] = math_approx_mode ? "1" : "0"; + auto softmax_kernels_id = CreateKernel( + program, "ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp", all_device_cores, + tt::tt_metal::ComputeConfig{ + .math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .math_approx_mode = math_approx_mode, + .compile_args = compute_compile_time_args, + .defines = softmax_defines + }); + + // Create circular buffers + // in0 sharded + auto c_in0_config = CircularBufferConfig(in0_CB_size, {{tt::CB::c_in0, in0_cb_data_format}}) + .set_page_size(tt::CB::c_in0, in0_tile_size).set_globally_allocated_address(*src0_buffer); + auto cb_in0_id = CreateCircularBuffer(program, all_device_cores, c_in0_config); + // in1 scalar + auto c_in1_config = CircularBufferConfig(in1_CB_size, {{tt::CB::c_in1, scalar_cb_data_format}}) + .set_page_size(tt::CB::c_in1, scalar_tile_size); + auto cb_in1_id = CreateCircularBuffer(program, all_device_cores, c_in1_config); + // in2 in3 attn scale mask + std::optional cb_intermed2_id; + std::optional cb_in2_id; + std::optional cb_in3_id; + if (mask.has_value()) { + // im2 + auto c_intermed2_config = CircularBufferConfig(im2_CB_size, {{tt::CB::c_intermed2, im_cb_data_format}}) + .set_page_size(tt::CB::c_intermed2, im_tile_size); + cb_intermed2_id = CreateCircularBuffer( program, all_device_cores, c_intermed2_config ); + // in2 scale + auto c_in2_config = CircularBufferConfig(in2_CB_size, {{tt::CB::c_in2, scale_cb_data_format}}) + .set_page_size(tt::CB::c_in2, scale_tile_size); + cb_in2_id = CreateCircularBuffer(program, all_device_cores, c_in2_config); + // in3 attn mask + if (mask->is_sharded()) { + auto mask_buffer = mask->buffer(); + auto c_in3_config = CircularBufferConfig(in3_CB_size, {{tt::CB::c_in3, mask_cb_data_format}}) + .set_page_size(tt::CB::c_in3, mask_tile_size).set_globally_allocated_address(*mask_buffer); + cb_in3_id = CreateCircularBuffer( program, all_device_cores, c_in3_config); + } else { + auto c_in3_config = CircularBufferConfig(in3_CB_size, {{tt::CB::c_in3, mask_cb_data_format}}) + .set_page_size(tt::CB::c_in3, mask_tile_size); + cb_in3_id = CreateCircularBuffer( program, all_device_cores, c_in3_config); + } + } + // out + auto c_out0_config = CircularBufferConfig(out_CB_size, {{tt::CB::c_out0, out0_cb_data_format}}) + .set_page_size(tt::CB::c_out0, out0_tile_size).set_globally_allocated_address(*out0_buffer);; + auto cb_out0_id = CreateCircularBuffer( program, all_device_cores, c_out0_config ); + // im0 for exp(x) + auto c_intermed0_config = CircularBufferConfig(im0_CB_size, {{tt::CB::c_intermed0, im_cb_data_format}}) + .set_page_size(tt::CB::c_intermed0, im_tile_size); + auto cb_intermed0_id = CreateCircularBuffer( program, all_device_cores, c_intermed0_config ); + // im1 for 1/sum(exp(x)) + auto c_intermed1_config = CircularBufferConfig(im1_CB_size, {{tt::CB::c_intermed1, im_cb_data_format}}) + .set_page_size(tt::CB::c_intermed1, im_tile_size); + auto cb_intermed1_id = CreateCircularBuffer( program, all_device_cores, c_intermed1_config ); + + // Runtime Args + uint32_t mask_addr = mask.has_value() ? 
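+    // Note that c_in0, c_out0 (and c_in3 for a sharded mask) were created with
+    // set_globally_allocated_address above, so they alias the tensor shard buffers in L1
+    // instead of staging extra copies; the override callback at the end of this function
+    // only needs to refresh those base addresses and the mask address runtime arg.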
mask->buffer()->address() : 0; + union { float f; uint32_t u; } s; s.f = scale.value_or(1.0f); // scale for fused scale-mask-softmax + uint32_t mask_start_tile_id = 0; + + uint32_t num_tiles_in_attn_mask = 0; + uint32_t num_tiles_of_attn_mask_needed_per_core = 0; + if (hw_dims_only_causal_mask) { + num_tiles_in_attn_mask = mask.value().get_legacy_shape()[-1] * mask.value().get_legacy_shape()[-2] / TILE_HW; + num_tiles_of_attn_mask_needed_per_core = block_ht * block_wt; + } + uint32_t num_cores_per_batch_index = 0; + + if (shard_orient == ShardOrientation::COL_MAJOR) { + for(int core_idx_x = 0; core_idx_x < num_cores_c; core_idx_x++) { + for(int core_idx_y = 0; core_idx_y < num_cores_r; core_idx_y++) { + CoreCoord core = {(std::size_t) start_core_x + core_idx_x, (std::size_t) start_core_y + core_idx_y}; + + // reader args + std::vector reader_args; + reader_args.push_back(0x3f803f80); + reader_args.push_back(s.u); + reader_args.push_back(mask_addr); + reader_args.push_back(mask_start_tile_id); + if (hw_dims_only_causal_mask) { + reader_args.push_back(num_tiles_in_attn_mask); + } + + tt::tt_metal::SetRuntimeArgs(program, reader_kernels_id, core, reader_args); + + num_cores_per_batch_index ++; + + if (hw_dims_only_causal_mask) { + uint32_t mask_tile_id_end = (mask_start_tile_id + num_tiles_of_attn_mask_needed_per_core) % num_tiles_in_attn_mask; + mask_start_tile_id = mask_tile_id_end; + } else { + if (num_cores_per_batch_index == num_cores_per_batch) { + num_cores_per_batch_index = 0; + if (mask.has_value()) { + if (causal_mask) { + mask_start_tile_id += mask->get_legacy_shape()[-1] * mask->get_legacy_shape()[-2] / TILE_WIDTH / TILE_HEIGHT; + } else { + mask_start_tile_id += use_row_major_kernel ? mask->get_legacy_shape()[-2] : mask->get_legacy_shape()[-1] / TILE_WIDTH; + } + } + } + } + } + } + } else { + for(int core_idx_y = 0; core_idx_y < num_cores_r; core_idx_y++) { + for(int core_idx_x = 0; core_idx_x < num_cores_c; core_idx_x++) { + CoreCoord core = {(std::size_t) start_core_x + core_idx_x, (std::size_t) start_core_y + core_idx_y}; + + // reader args + std::vector reader_args; + reader_args.push_back(0x3f803f80); + reader_args.push_back(s.u); + reader_args.push_back(mask_addr); + reader_args.push_back(mask_start_tile_id); + if (hw_dims_only_causal_mask) { + reader_args.push_back(num_tiles_in_attn_mask); + } + + tt::tt_metal::SetRuntimeArgs(program, reader_kernels_id, core, reader_args); + + num_cores_per_batch_index ++; + + if (hw_dims_only_causal_mask) { + uint32_t mask_tile_id_end = (mask_start_tile_id + num_tiles_of_attn_mask_needed_per_core) % num_tiles_in_attn_mask; + mask_start_tile_id = mask_tile_id_end; + } else { + if (num_cores_per_batch_index == num_cores_per_batch) { + num_cores_per_batch_index = 0; + if (mask.has_value()) { + if (causal_mask) { + mask_start_tile_id += mask->get_legacy_shape()[-1] * mask->get_legacy_shape()[-2] / TILE_WIDTH / TILE_HEIGHT; + } else { + mask_start_tile_id += use_row_major_kernel ? mask->get_legacy_shape()[-2] : mask->get_legacy_shape()[-1] / TILE_WIDTH; + } + } + } + } + } + } + } + + auto override_runtime_arguments_callback = [ + reader_kernels_id, + cb_in0_id, + cb_out0_id, + cb_in3_id, + num_cores, + grid_size + ] + ( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors + ) { + auto in0_buffer = input_tensors.at(0).buffer(); + auto &mask_tensor = optional_input_tensors.at(0); + auto out_buffer = output_tensors.size() == 1 ? 
output_tensors.at(0).buffer() : in0_buffer; + + UpdateDynamicCircularBufferAddress(program, cb_in0_id, *in0_buffer); + UpdateDynamicCircularBufferAddress(program, cb_out0_id, *out_buffer); + if (mask_tensor.has_value() && mask_tensor->is_sharded()) { + UpdateDynamicCircularBufferAddress(program, cb_in3_id.value(), *mask_tensor->buffer()); + } + + if (mask_tensor.has_value()) { + for (uint32_t i = 0; i < num_cores; ++i) { + CoreCoord core = {i % grid_size.x, i / grid_size.x}; + auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core); + runtime_args[2] = mask_tensor->buffer()->address(); + } + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} // scale_mask_softmax_sharded_multi_core + +} // namespace ttnn::operations::normalization diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp new file mode 100644 index 000000000000..ffb5e673e3c7 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "softmax_op.hpp" +#include "tt_metal/common/assert.hpp" +#include "common/base_types.hpp" +#include "tensor/types.hpp" +#include "tt_eager/tt_dnn/op_library/math.hpp" +#include "tt_eager/tt_dnn/op_library/work_split.hpp" +#include "tt_dnn/op_library/run_operation.hpp" + +#include "tt_metal/host_api.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/common/math.hpp" +#include "tt_metal/detail/util.hpp" + +#include +#include + +using uint32_t = std::uint32_t; +using namespace tt::constants; + +namespace ttnn::operations::normalization { + +void Softmax::validate(const std::vector &input_tensors, const std::vector>& optional_input_tensors) const { + TT_FATAL(input_tensors.size() == 1 and optional_input_tensors.size() <= 1, "Must have 1 or 2 input tensors"); + auto& input_tensor = input_tensors.at(0); + TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to softmax need to be on device!"); + TT_FATAL(input_tensor.buffer() != nullptr , "Operands to softmax need to be allocated in buffers on device!"); + TT_FATAL((input_tensor.get_layout() == Layout::TILE), "Inputs to softmax must be tilized"); + TT_FATAL(input_tensor.get_dtype() == DataType::FLOAT32 || input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::BFLOAT8_B); + if (optional_input_tensors.size() == 1) { + if (optional_input_tensors.at(0).has_value()) { + auto& mask = optional_input_tensors.at(0).value(); + TT_FATAL(mask.storage_type() == StorageType::DEVICE, "Operands to softmax need to be on device!"); + TT_FATAL(input_tensor.device() == mask.device()); + if (mask.is_sharded()) { // sharded mask + TT_FATAL(mask.get_layout() == Layout::TILE); + TT_FATAL(mask.get_legacy_shape() == input_tensor.get_legacy_shape()); + } else { + if (mask.get_layout() == Layout::ROW_MAJOR) { + tt::tt_metal::Shape expected_shape = {mask.get_legacy_shape()[0], 1, input_tensor.get_legacy_shape()[-1] / TILE_WIDTH, TILE_WIDTH}; + TT_FATAL(mask.get_legacy_shape() == expected_shape); + } + for (uint32_t i = 1; i < input_tensor.get_legacy_shape().rank() - 2; i++) { + TT_FATAL(mask.get_legacy_shape()[i] == 1); + } + } + + std::visit( + [&](const auto& program_config) { + using ProgramConfigType = std::decay_t; + if constexpr ( + std::is_same_v + ) { + 
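+                    // This branch validates the default (interleaved) program config; the
+                    // branch below handles the sharded multi-core config and additionally
+                    // checks the shard geometry against the compute grid.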
TT_FATAL(input_tensor.get_legacy_shape()[0] == mask.get_legacy_shape()[0]); + TT_FATAL(!this->is_scale_causal_mask_hw_dims_softmax); + } else if constexpr ( + std::is_same_v + ) { + const auto shape = input_tensor.get_legacy_shape(); + uint32_t M = input_tensor.volume() / shape[-1]; + uint32_t K = shape[-1]; + + TT_FATAL(M % TILE_HEIGHT == 0, "M must be divisible by tile height."); + TT_FATAL(K % TILE_WIDTH == 0, "K must be divisible by tile width."); + TT_FATAL(program_config.block_w % program_config.subblock_w == 0, "block_w must be divisible by subblock_w."); + TT_FATAL(program_config.block_w * TILE_WIDTH == shape[3], "shard width must equal to input tensor shape[3]!"); + TT_FATAL(this->inplace); + if (!this->is_scale_causal_mask_hw_dims_softmax) { + // grid + auto num_cores_c = program_config.compute_with_storage_grid_size.x; + auto num_cores_r = program_config.compute_with_storage_grid_size.y; + // check dims + TT_FATAL(M * K / ((program_config.block_w * program_config.block_h) * TILE_HW) == num_cores_r * num_cores_c, "number of shards must equal to number of cores. M = {}, K = {}, block_w = {}, block_h = {}, num_cores = {}", M, K, program_config.block_w, program_config.block_h, num_cores_r * num_cores_c); + } else { + TT_FATAL(this->is_causal_mask); + TT_FATAL(mask.get_layout() == Layout::TILE); + TT_FATAL(mask.is_sharded() == false); + TT_FATAL(input_tensor.get_layout() == Layout::TILE); + TT_FATAL(input_tensor.is_sharded()); + TT_FATAL(input_tensor.shard_spec()->orientation == ShardOrientation::ROW_MAJOR); + TT_FATAL(this->scale.has_value()); + } + } + }, + this->program_config + ); + } else { + TT_FATAL(not this->scale.has_value()); + } + } else { + TT_FATAL(not this->scale.has_value()); + TT_FATAL(not this->is_scale_causal_mask_hw_dims_softmax); + } +} + +std::vector Softmax::compute_output_shapes(const std::vector& input_tensors) const { + return {input_tensors.at(0).get_legacy_shape()}; +} + +std::vector Softmax::create_output_tensors(const std::vector& input_tensors) const { + if (this->inplace) { + return {input_tensors.at(0)}; + } else { + return operation::generic_create_output_tensors(*this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config); + } +} + +operation::ProgramWithCallbacks Softmax::create_program( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + std::vector &output_tensors +) const { + auto& input_tensor = input_tensors.at(0); + auto& output_tensor = output_tensors.at(0); + const auto& mask = optional_input_tensors.at(0); + // bool causal_mask = mask.has_value() ? 
mask.value().get_legacy_shape()[-2] == mask.value().get_legacy_shape()[-1] : false; + bool causal_mask = this->is_causal_mask; + + return std::visit( + [&](const auto& program_config) -> operation::ProgramWithCallbacks { + using ProgramConfigType = std::decay_t; + if constexpr ( + std::is_same_v + ) { + return scale_mask_softmax_sharded_multi_core( + input_tensor, + output_tensor, + mask, + this->scale, + causal_mask, + this->is_scale_causal_mask_hw_dims_softmax, + program_config.compute_with_storage_grid_size, + program_config.subblock_w, + program_config.block_h, + program_config.block_w, + this->compute_kernel_config); + } + else { + return scale_mask_softmax_multi_core(input_tensor, output_tensor, mask, this->scale, causal_mask, this->compute_kernel_config); + } + }, + this->program_config + ); +} + +const operation::Hash Softmax::compute_program_hash( + const std::vector &input_tensors, + const std::vector>& optional_input_tensors) const { + return operation::hash_operation( + std::get(input_tensors.at(0).storage()).memory_config(), + input_tensors.at(0).dtype(), + optional_input_tensors.at(0).has_value() ? std::optional{std::get(optional_input_tensors.at(0).value().storage()).memory_config()} + : std::nullopt, + optional_input_tensors.at(0).has_value() ? std::optional{optional_input_tensors.at(0).value().dtype()} + : std::nullopt, + this->output_mem_config); +} + +Tensor softmax_in_place(Tensor& input_tensor, const SoftmaxProgramConfig& program_config, std::optional compute_kernel_config) { + return scale_mask_softmax_in_place(input_tensor, std::nullopt, std::nullopt, program_config, false, compute_kernel_config); +} + +Tensor scale_mask_softmax_in_place(Tensor& input_tensor, std::optional scale, std::optional mask, const SoftmaxProgramConfig& program_config, const bool is_causal_mask, std::optional compute_kernel_config) { + std::vector dummy_output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + operation::launch_op( + [scale, mask, program_config, is_causal_mask, compute_kernel_config] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + auto& input_tensor = input_tensors.at(0); + auto& mask = optional_input_tensors.at(0); + auto kernel_config_val = init_device_compute_kernel_config(input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false); + return operation::run(Softmax{.scale=scale, .inplace=true, .output_mem_config=input_tensor.memory_config(), .program_config=program_config, .is_causal_mask=is_causal_mask, .compute_kernel_config=kernel_config_val}, {input_tensor}, {mask}); + }, {input_tensor}, dummy_output_tensors, {mask}); + return input_tensor; +} + +Tensor scale_causal_mask_hw_dims_softmax_in_place(Tensor& input_tensor, std::optional scale, std::optional mask, const SoftmaxProgramConfig& program_config, std::optional compute_kernel_config) { + std::vector dummy_output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + operation::launch_op( + [scale, mask, program_config, compute_kernel_config](const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + auto& input_tensor = input_tensors.at(0); + auto& mask = optional_input_tensors.at(0); + auto kernel_config_val = init_device_compute_kernel_config(input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false); + return 
operation::run(Softmax{.scale=scale, .inplace=true, .output_mem_config=input_tensor.memory_config(), .program_config=program_config, .is_causal_mask=true, .compute_kernel_config=kernel_config_val, .is_scale_causal_mask_hw_dims_softmax=true}, {input_tensor}, {mask}); + }, {input_tensor}, dummy_output_tensors, {mask}); + return input_tensor; +} + +Tensor softmax(const Tensor& input_tensor, const MemoryConfig& output_mem_config, std::optional compute_kernel_config) { + return scale_mask_softmax(input_tensor, std::nullopt, std::nullopt, output_mem_config, false, compute_kernel_config); +} + +Tensor scale_mask_softmax(const Tensor& input_tensor, std::optional scale, std::optional mask, const MemoryConfig& output_mem_config, const bool is_causal_mask, std::optional compute_kernel_config) { + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + operation::launch_with_autoformat( + [scale, mask, output_mem_config, is_causal_mask, compute_kernel_config] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + auto& input_tensor = input_tensors.at(0); + auto& mask = optional_input_tensors.at(0); + tt::tt_metal::Shape input_pad_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape()); + FormatParams input_format_params = {.pad_shape=input_pad_shape, .pad_value=-std::numeric_limits::infinity(), .target_layout=Layout::TILE}; + std::optional mask_format_params = std::nullopt; + if (mask.has_value()) { + TT_FATAL(input_tensor.get_legacy_shape()[-1] == mask.value().get_legacy_shape()[-1]); + TT_FATAL(input_tensor.get_legacy_shape()[0] == mask.value().get_legacy_shape()[0]); + TT_FATAL(mask.value().get_legacy_shape()[-2] == 1 or mask.value().get_legacy_shape()[-2] == TILE_HEIGHT); + for (uint32_t i = 1; i < input_tensor.get_legacy_shape().rank() - 2; i++) { + TT_FATAL(mask.value().get_legacy_shape()[i] == 1); + } + tt::tt_metal::Shape mask_pad_shape = AutoFormat::pad_to_tile_shape(mask.value().get_legacy_shape()); + mask_format_params = {.pad_shape=mask_pad_shape, .pad_value=-std::numeric_limits::infinity(), .target_layout=Layout::TILE}; + } + auto kernel_config_val = init_device_compute_kernel_config(input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false); + return operation::run_with_autoformat(Softmax{.scale=scale, .inplace=false, .output_mem_config=output_mem_config, .is_causal_mask=is_causal_mask, .compute_kernel_config=kernel_config_val}, {input_tensor}, {input_format_params}, {Layout::TILE}, {mask}, {mask_format_params}); + }, {input_tensor}, output_tensors, {mask}); + return output_tensors.at(0); +} + +} // namespace ttnn::operations::normalization diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp new file mode 100644 index 000000000000..6f600e1ff215 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp @@ -0,0 +1,131 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "common/base_types.hpp" +#include "common/core_coord.h" +#include "tensor/types.hpp" +#include "tt_eager/tensor/tensor.hpp" +#include "tt_dnn/op_library/operation.hpp" +#include "tt_dnn/op_library/run_operation.hpp" +#include "tt_dnn/op_library/compute_kernel_config.hpp" + +namespace ttnn::operations::normalization { + +struct SoftmaxDefaultProgramConfig{ + tt::stl::reflection::Attributes attributes() const { return {}; }; +}; +struct SoftmaxShardedMultiCoreProgramConfig { + CoreCoord compute_with_storage_grid_size; + std::size_t subblock_w; + std::size_t block_h; + std::size_t block_w; + + tt::stl::reflection::Attributes attributes() const { + return { + {"compute_with_storage_grid_size", compute_with_storage_grid_size}, + {"subblock_w", subblock_w}, + {"block_h", block_h}, + {"block_w", block_w}, + }; + }; +}; + +using SoftmaxProgramConfig = std::variant< + SoftmaxDefaultProgramConfig, + SoftmaxShardedMultiCoreProgramConfig +>; + +struct Softmax { + const std::optional scale; + const bool inplace; + const MemoryConfig output_mem_config; + const SoftmaxProgramConfig program_config; + const bool is_causal_mask; + const DeviceComputeKernelConfig compute_kernel_config; + const bool is_scale_causal_mask_hw_dims_softmax; + + void validate(const std::vector &input_tensors, const std::vector>& optional_input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector create_output_tensors(const std::vector &input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + std::vector &output_tensors + ) const; + + static constexpr auto attribute_names = std::forward_as_tuple( + "scale", + "inplace", + "output_mem_config", + "program_config", + "is_causal_mask", + "compute_kernel_config", + "is_scale_causal_mask_hw_dims_softmax"); + + const auto attribute_values() const { + return std::forward_as_tuple( + this->scale, + this->inplace, + this->output_mem_config, + this->program_config, + this->is_causal_mask, + this->compute_kernel_config, + this->is_scale_causal_mask_hw_dims_softmax); + }; + + const operation::Hash compute_program_hash( + const std::vector &input_tensors, + const std::vector>& optional_input_tensors) const; +}; + +operation::ProgramWithCallbacks scale_mask_softmax_multi_core( + const Tensor &input_tensor, + const Tensor &output_tensor, + const std::optional mask, + std::optional scale, + bool causal_mask, + DeviceComputeKernelConfig compute_kernel_config +); + +// hw_dims_only_causal_mask - represents if the causal mask is of shape [1, 1, h, w] +// valid only if causal_mask == true, and is interleaved +operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core( + const Tensor &input_tensor, + const Tensor &output_tensor, + const std::optional mask, + std::optional scale, + bool causal_mask, + bool hw_dims_only_causal_mask, + CoreCoord grid_size, + uint32_t subblock_wt, + uint32_t block_ht, + uint32_t block_wt, + DeviceComputeKernelConfig compute_kernel_config +); + +// softmax +Tensor softmax(const Tensor& input_tensor, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::optional compute_kernel_config = std::nullopt); +// const ref prevents in-place +Tensor softmax_in_place(Tensor& input_tensor, const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{}, std::optional compute_kernel_config = 
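+// The program_config parameter taken by the in-place entry points is a variant:
+// SoftmaxDefaultProgramConfig selects the interleaved multi-core path, while a sharded run
+// might use something like the following (hypothetical values, shown only as a sketch):
+//   SoftmaxShardedMultiCoreProgramConfig{
+//       .compute_with_storage_grid_size = CoreCoord{8, 8},
+//       .subblock_w = 4,
+//       .block_h = 4,
+//       .block_w = 12};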
std::nullopt); + +// computes +// tmp1 = bcast_hw_mul(scale, x) ; shape of scale is [1,1,32,32] +// tmp2 = bcast_add_w->h(tmp1, mask) ; shape of attn mask is [1,N,32,W] +// y = softmax(tmp2) ; r=result +// If scale == 0.0f then just y = softmax(x) is computed +Tensor scale_mask_softmax_in_place(Tensor& input_tensor, std::optional scale = std::nullopt, std::optional mask = std::nullopt, const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{}, const bool is_causal_mask = false, std::optional compute_kernel_config = std::nullopt); + +// Experimental feature. Does the same same as above, with the following assumptions: +// 1. Input must be sharded +// 2. Scale must exist +// 3. Attention mask must be interleaved and be of this shape [1, 1, H, W] +// 4. Causal mask argument is set to true. +Tensor scale_causal_mask_hw_dims_softmax_in_place(Tensor& input_tensor, std::optional scale, std::optional mask, const SoftmaxProgramConfig& program_config = SoftmaxShardedMultiCoreProgramConfig{}, std::optional compute_kernel_config = std::nullopt); + +Tensor scale_mask_softmax(const Tensor& input_tensor, std::optional scale, std::optional mask, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, const bool is_causal_mask = false, std::optional compute_kernel_config = std::nullopt); + +} // namespace ttnn::operations::normalization diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.hpp b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.hpp new file mode 100644 index 000000000000..bc96f6960219 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.hpp @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" +// #include "tt_dnn/op_library/softmax/softmax_op.hpp" +#include "device/softmax_op.hpp" + +namespace ttnn { +namespace operations::normalization { + +struct ExecuteSoftmax { + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 2, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}}; + } + + template + static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... 
args) { + return std::forward_as_tuple(input_tensor); + } + + // softmax + static ttnn::Tensor execute_on_worker_thread( + const ttnn::Tensor& input_tensor, + const int dim_arg, + const std::optional& memory_config = std::nullopt, + const std::optional compute_kernel_config = std::nullopt) { + auto input_shape = input_tensor.get_shape(); + auto rank = input_shape.size(); + auto dim = dim_arg; + if (dim < 0) { + dim = rank + dim; + } + + auto input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor); + if (dim == rank - 1) { + auto output_tensor = + ttnn::operations::normalization::softmax(input_tensor_4D, memory_config.value_or(input_tensor.memory_config()), compute_kernel_config); + return ttnn::reshape(output_tensor, input_shape); + } else { + auto dim_4D = dim + 4 - rank; + auto output_tensor = tt::operations::primary::moreh_softmax(input_tensor_4D, dim_4D); + return ttnn::reshape(output_tensor, input_shape); + } + } +}; + +struct ExecuteScaleMaskSoftmax { + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 2, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}}; + } + + template + static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { + return std::forward_as_tuple(input_tensor); + } + + // scale_mask_softmax + static ttnn::Tensor execute_on_worker_thread( + const ttnn::Tensor& input_tensor, + const std::optional scale = std::nullopt, + const std::optional mask = std::nullopt, + const std::optional& memory_config = std::nullopt, + const bool is_causal_mask = false, + const std::optional compute_kernel_config = std::nullopt) { + auto input_shape = input_tensor.get_shape(); + + auto input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor); + auto output_tensor = + ttnn::operations::normalization::scale_mask_softmax(input_tensor_4D, scale, mask, memory_config.value_or(input_tensor.memory_config()), is_causal_mask, compute_kernel_config); + return ttnn::reshape(output_tensor, input_shape); + } +}; + +struct ExecuteSoftmaxInPlace { + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 2, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}}; + } + + template + static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { + return std::forward_as_tuple(input_tensor); + } + + // softmax_in_place + static ttnn::Tensor execute_on_worker_thread( + const ttnn::Tensor& input_tensor, + const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{}, + const std::optional compute_kernel_config = std::nullopt) { + auto input_shape = input_tensor.get_shape(); + + auto input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor); + auto output_tensor = + ttnn::operations::normalization::softmax_in_place(input_tensor_4D, program_config, compute_kernel_config); + return ttnn::reshape(output_tensor, input_shape); + } +}; + +struct ExecuteScaleMaskSoftmaxInPlace { + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 2, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}}; + } + + template + static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... 
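ExecuteSoftmax above dispatches on the softmax dimension: the last dimension goes to the fused normalization::softmax kernel, while any other dimension is unsqueezed to 4D and routed to moreh_softmax. From Python the dispatch is transparent; a minimal usage sketch, where the device id, shape, and dtype are placeholders:

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)  # placeholder device id

    x = ttnn.from_torch(
        torch.randn(1, 1, 64, 32, dtype=torch.bfloat16),
        layout=ttnn.TILE_LAYOUT,
        device=device,
    )

    y_last = ttnn.softmax(x, dim=-1)    # last dim: fused softmax kernel
    y_height = ttnn.softmax(x, dim=-2)  # other dims: moreh_softmax fallback per the C++ dispatch

    ttnn.close_device(device)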
args) { + return std::forward_as_tuple(input_tensor); + } + + // scale_mask_softmax_in_place + static ttnn::Tensor execute_on_worker_thread( + const ttnn::Tensor& input_tensor, + const std::optional scale = std::nullopt, + const std::optional mask = std::nullopt, + const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{}, + const bool is_causal_mask = false, + const std::optional compute_kernel_config = std::nullopt) { + auto input_shape = input_tensor.get_shape(); + + auto input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor); + auto output_tensor = + ttnn::operations::normalization::scale_mask_softmax_in_place(input_tensor_4D, scale, mask, program_config, is_causal_mask, compute_kernel_config); + return ttnn::reshape(output_tensor, input_shape); + } +}; + +struct ExecuteScaleCausalMaskHWSoftmaxInPlace { + static inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 2, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}}; + } + + template + static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) { + return std::forward_as_tuple(input_tensor); + } + + // scale_causal_mask_hw_dims_softmax_in_place + static ttnn::Tensor execute_on_worker_thread( + const ttnn::Tensor& input_tensor, + const std::optional scale = std::nullopt, + const std::optional mask = std::nullopt, + const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{}, + const std::optional compute_kernel_config = std::nullopt) { + auto input_shape = input_tensor.get_shape(); + + auto input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor); + auto output_tensor = + ttnn::operations::normalization::scale_causal_mask_hw_dims_softmax_in_place(input_tensor_4D, scale, mask, program_config, compute_kernel_config); + return ttnn::reshape(output_tensor, input_shape); + } +}; + +} // namespace operations::normalization + +constexpr auto softmax = ttnn::register_operation("ttnn::softmax"); +constexpr auto scale_mask_softmax = ttnn::register_operation("ttnn::scale_mask_softmax"); +constexpr auto softmax_in_place = ttnn::register_operation("ttnn::softmax_in_place"); +constexpr auto scale_mask_softmax_in_place = ttnn::register_operation("ttnn::scale_mask_softmax_in_place"); +constexpr auto scale_causal_mask_hw_dims_softmax_in_place = ttnn::register_operation("ttnn::scale_causal_mask_hw_dims_softmax_in_place"); + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/softmax_pybind.hpp b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax_pybind.hpp new file mode 100644 index 000000000000..171774d41313 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax_pybind.hpp @@ -0,0 +1,247 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "softmax.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::normalization::detail { + +void bind_normalization_softmax_program_config_operation(py::module& module) { + py::class_(module, "SoftmaxProgramConfig").def(py::init<>()); + + py::class_(module, "SoftmaxDefaultProgramConfig") + .def(py::init<>()); + + py::class_(module, "SoftmaxShardedMultiCoreProgramConfig") + .def( + py::init(), + py::kw_only(), + py::arg("compute_with_storage_grid_size"), + py::arg("subblock_w").noconvert(), + py::arg("block_h").noconvert(), + py::arg("block_w").noconvert() + ) + .def_readwrite("block_w", &SoftmaxShardedMultiCoreProgramConfig::block_w); +} + +void bind_normalization_softmax_operation(py::module& module) { + + auto doc = + R"doc(softmax(input_tensor: ttnn.Tensor, dim: int, memory_config: Optional[ttnn.MemoryConfig] = None, compute_kernel_config: Optional[DeviceComputeKernelConfig]) -> ttnn.Tensor + + Compute softmax over :attr:`input_tensor` along :attr:`dim`. + + Args: + * :attr:`input_tensor`: the input tensor + * :attr:`dim`: the dimension along which to compute softmax. + + Keyword Args: + * :attr:`memory_config`: the memory configuration for the output tensor. If not provided, the memory configuration of the input tensor is used. + * :attr:`compute_kernel_config`: the compute kernel configuration for the op. If not provided, the default configuration of the op is used. + + Example: + + >>> tensor = ttnn.to_device(ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16)), device) + >>> output = ttnn.softmax(tensor, -1) + >>> print(output[0, 0, 0, :3]) + ttnn.Tensor([ 0.0310059, 0.0310059, 0.0310059], dtype=bfloat16 ) + )doc"; + + using OperationType = decltype(ttnn::softmax); + + ttnn::bind_registered_operation( + module, + ttnn::softmax, + doc, + ttnn::pybind_overload_t{ + [] (const OperationType& self, + const ttnn::Tensor& input_tensor, + const int8_t dim, + const std::optional& memory_config, + const std::optional& compute_kernel_config) { + return self(input_tensor, dim, memory_config, compute_kernel_config); + }, + py::arg("input_tensor").noconvert(), + py::arg("dim") = -1, + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("compute_kernel_config").noconvert() = std::nullopt}); +} + +void bind_normalization_scale_mask_softmax_operation(py::module& module) { + + auto doc = + R"doc(scale_mask_softmax(input_tensor: ttnn.Tensor, scale: Optional[float] = None, mask: Optional[ttnn.Tensor] = None, memory_config: Optional[ttnn.MemoryConfig] = None, is_causal_mask: Optional[bool] = False, compute_kernel_config: Optional[DeviceComputeKernelConfig]) -> ttnn.Tensor + + Compute fused scale->attention_mask->softmax operation over :attr:`input_tensor` on the last dim. + + Args: + * :attr:`input_tensor`: the input tensor + * :attr:`scale`: the scale to be multiplied with input tensor + * :attr:`mask`: the input mask tensor to be applied to input tensor + + Keyword Args: + * :attr:`memory_config`: the memory configuration for the output tensor. If not provided, the memory configuration of the input tensor is used. + * :attr:`is_causal_mask`: determines whether the mask tensor is causal or not. If not provided, non-causal mask will be used. + * :attr:`compute_kernel_config`: the compute kernel configuration for the op. If not provided, the default configuration of the op is used. 
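The sharded program config bound near the top of softmax_pybind.hpp is constructed from Python with keyword-only arguments. A sketch with placeholder values: block sizes are in tiles and must match how the input is actually sharded, and the (x, y) tuple is assumed to convert to the expected CoreCoord.

    import ttnn

    # illustrative values only; grid and block sizes depend on the real shard spec
    program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
        compute_with_storage_grid_size=(8, 4),
        subblock_w=4,
        block_h=12,
        block_w=4,
    )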
+ + Example: + + )doc"; + + using OperationType = decltype(ttnn::scale_mask_softmax); + + ttnn::bind_registered_operation( + module, + ttnn::scale_mask_softmax, + doc, + ttnn::pybind_overload_t{ + [] (const OperationType& self, + const ttnn::Tensor& input_tensor, + const std::optional scale, + const std::optional mask, + const std::optional& memory_config, + const bool is_causal_mask, + const std::optional& compute_kernel_config) { + return self(input_tensor, scale, mask, memory_config, is_causal_mask, compute_kernel_config); + }, + py::arg("input_tensor").noconvert(), + py::arg("scale").noconvert() = std::nullopt, + py::arg("mask").noconvert() = std::nullopt, + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("is_causal_mask") = false, + py::arg("compute_kernel_config") = std::nullopt}); +} + +void bind_normalization_softmax_in_place_operation(py::module& module) { + + auto doc = + R"doc(softmax_in_place(input_tensor: ttnn.Tensor, program_config: Optional[SoftmaxProgramConfig], compute_kernel_config: Optional[DeviceComputeKernelConfig]) -> ttnn.Tensor + + Compute softmax over :attr:`input_tensor` along the last dim, input and output tensor are in-placed on the same L1 address. + + Args: + * :attr:`input_tensor`: the input tensor + + Keyword Args: + * :attr:`program_config`: the program configuration for op. If not provided, SoftmaxDefaultProgramConfig is used. + * :attr:`compute_kernel_config`: the compute kernel configuration for the op. If not provided, the default configuration of the op is used. + + Example: + + )doc"; + + using OperationType = decltype(ttnn::softmax_in_place); + + ttnn::bind_registered_operation( + module, + ttnn::softmax_in_place, + doc, + ttnn::pybind_overload_t{ + [] (const OperationType& self, + const ttnn::Tensor& input_tensor, + const SoftmaxProgramConfig& program_config, + const std::optional& compute_kernel_config) { + return self(input_tensor, program_config, compute_kernel_config); + }, + py::arg("input_tensor").noconvert(), + py::kw_only(), + py::arg("program_config") = SoftmaxDefaultProgramConfig{}, + py::arg("compute_kernel_config") = std::nullopt}); +} + +void bind_normalization_scale_mask_softmax_in_place_operation(py::module& module) { + + auto doc = + R"doc(softmax_in_place(input_tensor: ttnn.Tensor, scale: Optional[float] = None, mask: Optional[ttnn.Tensor] = None, program_config: Optional[SoftmaxProgramConfig], compute_kernel_config: Optional[DeviceComputeKernelConfig]) -> ttnn.Tensor + + Compute fused scale->attention_mask->softmax over :attr:`input_tensor` along the last dim, input and output tensor are in-placed on the same L1 address. + + Args: + * :attr:`input_tensor`: the input tensor + + Keyword Args: + * :attr:`program_config`: the program configuration for op. If not provided, SoftmaxDefaultProgramConfig is used. + * :attr:`compute_kernel_config`: the compute kernel configuration for the op. If not provided, the default configuration of the op is used. 
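A usage sketch for the fused mask variant bound above. The shapes, scale value, and causal-mask construction are illustrative assumptions; the exact mask shape and layout requirements depend on the program config and on is_causal_mask.

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)  # placeholder device id

    # illustrative [batch, heads, seq, seq] attention scores and a [1, 1, seq, seq] causal mask
    scores = ttnn.from_torch(
        torch.randn(1, 16, 64, 64, dtype=torch.bfloat16),
        layout=ttnn.TILE_LAYOUT,
        device=device,
    )
    causal_mask = ttnn.from_torch(
        torch.triu(torch.full((1, 1, 64, 64), -1e9), diagonal=1).to(torch.bfloat16),
        layout=ttnn.TILE_LAYOUT,
        device=device,
    )

    # fused scale -> mask -> softmax over the last dimension
    probs = ttnn.scale_mask_softmax(scores, scale=0.125, mask=causal_mask, is_causal_mask=True)

    ttnn.close_device(device)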
+ + Example: + + )doc"; + + using OperationType = decltype(ttnn::scale_mask_softmax_in_place); + + ttnn::bind_registered_operation( + module, + ttnn::scale_mask_softmax_in_place, + doc, + ttnn::pybind_overload_t{ + [] (const OperationType& self, + const ttnn::Tensor& input_tensor, + const std::optional scale, + const std::optional mask, + const SoftmaxProgramConfig& program_config, + const bool is_causal_mask, + const std::optional& compute_kernel_config) { + return self(input_tensor, scale, mask, program_config, is_causal_mask, compute_kernel_config); + }, + py::arg("input_tensor").noconvert(), + py::arg("scale").noconvert() = std::nullopt, + py::arg("mask").noconvert() = std::nullopt, + py::kw_only(), + py::arg("program_config") = SoftmaxDefaultProgramConfig{}, + py::arg("is_causal_mask") = false, + py::arg("compute_kernel_config") = std::nullopt}); +} + + +void bind_normalization_scale_causal_mask_hw_dims_softmax_in_place_operation(py::module& module) { + + auto doc = + R"doc(scale_causal_mask_hw_dims_softmax_in_place(input_tensor: ttnn.Tensor, scale: Optional[float] = None, mask: Optional[ttnn.Tensor] = None, program_config: Optional[SoftmaxProgramConfig], compute_kernel_config: Optional[DeviceComputeKernelConfig]) -> ttnn.Tensor + + Compute fused scale->attention_mask->softmax over :attr:`input_tensor` along the last dim, input and output tensor are in-placed on the same L1 address. + + Args: + * :attr:`input_tensor`: the input tensor + + Keyword Args: + * :attr:`program_config`: the program configuration for op. If not provided, SoftmaxDefaultProgramConfig is used. + * :attr:`compute_kernel_config`: the compute kernel configuration for the op. If not provided, the default configuration of the op is used. + + Example: + + )doc"; + + using OperationType = decltype(ttnn::scale_causal_mask_hw_dims_softmax_in_place); + + ttnn::bind_registered_operation( + module, + ttnn::scale_causal_mask_hw_dims_softmax_in_place, + doc, + ttnn::pybind_overload_t{ + [] (const OperationType& self, + const ttnn::Tensor& input_tensor, + const std::optional scale, + const std::optional mask, + const SoftmaxProgramConfig& program_config, + const std::optional& compute_kernel_config) { + return self(input_tensor, scale, mask, program_config, compute_kernel_config); + }, + py::arg("input_tensor").noconvert(), + py::arg("scale").noconvert() = std::nullopt, + py::arg("mask").noconvert() = std::nullopt, + py::kw_only(), + py::arg("program_config") = SoftmaxDefaultProgramConfig{}, + py::arg("compute_kernel_config") = std::nullopt}); +} + +} // namespace ttnn::operations::normalization::detail diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index c78718071e51..1b7311c4957a 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -478,6 +478,13 @@ def manage_config(name, value): from ttnn.operations.normalization import ( softmax, + scale_mask_softmax, + softmax_in_place, + scale_mask_softmax_in_place, + scale_causal_mask_hw_dims_softmax_in_place, + SoftmaxProgramConfig, + SoftmaxDefaultProgramConfig, + SoftmaxShardedMultiCoreProgramConfig, layer_norm, rms_norm, group_norm, diff --git a/ttnn/ttnn/experimental/golden_functions.py b/ttnn/ttnn/experimental/golden_functions.py index 7dc74309faa9..9048d1559370 100644 --- a/ttnn/ttnn/experimental/golden_functions.py +++ b/ttnn/ttnn/experimental/golden_functions.py @@ -93,7 +93,7 @@ def _golden_function(input_tensor, scalar, attention_mask, *args, **kwargs): ret = torch.softmax(input_tensor, dim=-1) return ret - 
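The two in-place variants bound above differ mainly in their preconditions: scale_mask_softmax_in_place handles the general case, while scale_causal_mask_hw_dims_softmax_in_place additionally requires a sharded input, a scale, and an interleaved [1, 1, H, W] causal mask, per the header comments earlier in this patch. A sketch under those assumptions; `scores_sharded`, `causal_mask`, and `program_config` (built as shown after the program-config bindings) are assumed to exist.

    import ttnn

    # general in-place fused scale -> mask -> softmax; result reuses the input's L1 allocation
    ttnn.scale_mask_softmax_in_place(
        scores_sharded,
        scale=0.125,
        mask=causal_mask,
        program_config=program_config,
        is_causal_mask=True,
    )

    # experimental variant: sharded input, mandatory scale, interleaved [1, 1, H, W] causal mask
    ttnn.scale_causal_mask_hw_dims_softmax_in_place(
        scores_sharded,
        scale=0.125,
        mask=causal_mask,
        program_config=program_config,
    )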
attach_golden(ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place, _golden_function) + attach_golden(ttnn.scale_mask_softmax_in_place, _golden_function) def _golden_function(input_tensor, *args, **kwargs): import torch diff --git a/ttnn/ttnn/operations/normalization.py b/ttnn/ttnn/operations/normalization.py index d9077450956f..a73a01f1e8d1 100644 --- a/ttnn/ttnn/operations/normalization.py +++ b/ttnn/ttnn/operations/normalization.py @@ -22,6 +22,27 @@ def _golden_function(input_tensor: ttnn.Tensor, dim: int, **_): golden_function=_golden_function, )(ttnn._ttnn.operations.normalization.softmax) +softmax_in_place = ttnn.register_operation( + golden_function=_golden_function, +)(ttnn._ttnn.operations.normalization.softmax_in_place) + +scale_mask_softmax_in_place = ttnn.register_operation( + golden_function=_golden_function, +)(ttnn._ttnn.operations.normalization.scale_mask_softmax_in_place) + +scale_mask_softmax = ttnn.register_operation( + golden_function=_golden_function, +)(ttnn._ttnn.operations.normalization.scale_mask_softmax) + +scale_causal_mask_hw_dims_softmax_in_place = ttnn.register_operation( + golden_function=_golden_function, +)(ttnn._ttnn.operations.normalization.scale_causal_mask_hw_dims_softmax_in_place) + + +SoftmaxProgramConfig = ttnn._ttnn.operations.normalization.SoftmaxProgramConfig +SoftmaxDefaultProgramConfig = ttnn._ttnn.operations.normalization.SoftmaxDefaultProgramConfig +SoftmaxShardedMultiCoreProgramConfig = ttnn._ttnn.operations.normalization.SoftmaxShardedMultiCoreProgramConfig + def _golden_function( input_tensor: ttnn.Tensor, *, epsilon=1e-12, residual_input_tensor=None, weight=None, bias=None, **_
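The registrations above reuse the torch-based softmax golden function for every variant. A quick host-side check against that golden model, independent of the comparison-mode machinery, might look like the following; the device id, shape, and tolerance are assumptions.

    import torch
    import ttnn

    torch_input = torch.randn(1, 1, 64, 32, dtype=torch.bfloat16)
    expected = torch.softmax(torch_input, dim=-1)

    device = ttnn.open_device(device_id=0)  # placeholder device id
    tt_input = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
    tt_output = ttnn.to_torch(ttnn.softmax(tt_input, dim=-1))
    ttnn.close_device(device)

    # bfloat16 softmax on device will not match exactly; a loose tolerance is assumed
    assert torch.allclose(expected.float(), tt_output.float(), atol=2e-2)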