diff --git a/docs/source/tt_metal/apis/kernel_apis/compute/matmul_block.rst b/docs/source/tt_metal/apis/kernel_apis/compute/matmul_block.rst index a5d078649cc3..be39d92bd015 100644 --- a/docs/source/tt_metal/apis/kernel_apis/compute/matmul_block.rst +++ b/docs/source/tt_metal/apis/kernel_apis/compute/matmul_block.rst @@ -1,7 +1,7 @@ matmul_block ============ -.. doxygenfunction:: mm_block_init(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, uint32_t out_cb_id = 16, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1) -.. doxygenfunction:: mm_block_init_short(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, uint32_t transpose=0, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1) -.. doxygenfunction:: mm_block_init_short_with_dt(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, uint32_t old_in1_cb_id=2, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1) +.. doxygenfunction:: mm_block_init(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, uint32_t out_cb_id = 16, const uint32_t transpose=0, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1) +.. doxygenfunction:: mm_block_init_short(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, const uint32_t transpose=0, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1) +.. doxygenfunction:: mm_block_init_short_with_dt(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, uint32_t old_in1_cb_id=2, const uint32_t transpose=0, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1) .. doxygenfunction:: matmul_block(uint32_t in0_cb_id, uint32_t in1_cb_id, uint32_t in0_tile_index, uint32_t in1_tile_index, uint32_t idst, const uint32_t transpose, uint32_t ct_dim, uint32_t rt_dim, uint32_t kt_dim) diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py b/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py index 7f153964fec4..b921805046b9 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py +++ b/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py @@ -28,306 +28,300 @@ def generate_input_shapes(): yield [q_len, q_heads, batch_size, K], [batch_size, kv_heads, K, seq_len] -# @pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) -# @pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) -# @pytest.mark.parametrize("out_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) -# def test_attn_matmul(in0_dtype, in1_dtype, out_dtype, device): -# torch.manual_seed(0) - -# for input_shape_a, input_shape_b in generate_input_shapes(): -# input_tensor_a = torch.randn(input_shape_a).bfloat16() -# input_tensor_b = torch.randn(input_shape_b).bfloat16() - -# tt_input_tensor_a = ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device) -# tt_input_tensor_b = ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device) - -# compute_grid_size = device.compute_with_storage_grid_size() - -# tt_output_tensor_on_device = ttl.operations.primary.transformers.attn_matmul( -# tt_input_tensor_a, -# tt_input_tensor_b, -# compute_with_storage_grid_size=ttl.tensor.CoreCoord(compute_grid_size.x, compute_grid_size.y), -# output_mem_config=ttl.tensor.MemoryConfig( -# ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 -# ), -# output_dtype=out_dtype, -# ) -# tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() - -# golden_output_tensor = 
(input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) - -# allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) -# assert allclose, f"FAILED: {output}" - -# @pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32") -# @pytest.mark.parametrize("in_dtype", [ttl.tensor.DataType.FLOAT32, ttl.tensor.DataType.BFLOAT8_B, ttl.tensor.DataType.BFLOAT16]) -# # @pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.FLOAT32,]) -# # @pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.FLOAT32,]) -# # @pytest.mark.parametrize("out_dtype", [ttl.tensor.DataType.FLOAT32,]) -# def test_attn_matmul_fp32(in_dtype, device): -# torch.manual_seed(0) - -# for input_shape_a, input_shape_b in generate_input_shapes(): -# input_tensor_a = torch.randn(input_shape_a).bfloat16() -# input_tensor_b = torch.randn(input_shape_b).bfloat16() - -# tt_input_tensor_a = ttl.tensor.Tensor(input_tensor_a, in_dtype).to(ttl.tensor.Layout.TILE).to(device) -# tt_input_tensor_b = ttl.tensor.Tensor(input_tensor_b, in_dtype).to(ttl.tensor.Layout.TILE).to(device) - -# compute_grid_size = device.compute_with_storage_grid_size() - -# compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( -# math_fidelity=ttl.tensor.MathFidelity.LoFi, -# math_approx_mode=True, -# fp32_dest_acc_en=True, -# packer_l1_acc=False, -# ) - -# tt_output_tensor_on_device = ttl.operations.primary.transformers.attn_matmul( -# tt_input_tensor_a, -# tt_input_tensor_b, -# compute_with_storage_grid_size=ttl.tensor.CoreCoord(compute_grid_size.x, compute_grid_size.y), -# output_mem_config=ttl.tensor.MemoryConfig( -# ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 -# ), -# output_dtype=in_dtype, -# compute_kernel_config=compute_kernel_config -# ) -# tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() - -# golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) - -# allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) -# assert allclose, f"FAILED: {output}" - - -# @pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) -# @pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) -# @pytest.mark.parametrize("out_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) -# def test_attn_matmul_with_program_cache(in0_dtype, in1_dtype, out_dtype, device, use_program_cache): -# torch.manual_seed(0) - -# for input_shape_a, input_shape_b in generate_input_shapes(): -# input_tensor_a = torch.randn(input_shape_a).bfloat16() -# input_tensor_b = torch.randn(input_shape_b).bfloat16() - -# tt_input_tensor_a = ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device) -# tt_input_tensor_b = ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device) - -# compute_grid_size = device.compute_with_storage_grid_size() - -# tt_output_tensor_on_device = ttl.operations.primary.transformers.attn_matmul( -# tt_input_tensor_a, -# tt_input_tensor_b, -# compute_with_storage_grid_size=ttl.tensor.CoreCoord(compute_grid_size.x, compute_grid_size.y), -# output_mem_config=ttl.tensor.MemoryConfig( -# ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 -# ), -# output_dtype=out_dtype, -# ) -# tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() - -# golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) - -# 
allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) -# assert allclose, f"FAILED: {output}" - - -# @pytest.mark.parametrize( -# "shard_orientation", -# (ttl.tensor.ShardOrientation.ROW_MAJOR, ttl.tensor.ShardOrientation.COL_MAJOR), -# ) -# @pytest.mark.parametrize( -# "output_sharded", -# (False, True), -# ) -# @pytest.mark.parametrize( -# "in1_sharded", -# (False, True), -# ) -# @pytest.mark.parametrize( -# "in0_sharded", -# (False, True), -# ) -# @pytest.mark.parametrize( -# "batch, K, seq_len, q_heads, kv_heads", -# ( -# (32, 64, 512 + 96, 32, 2), -# (32, 1024 + 32, 64, 32, 2), -# (32, 64, 128, 16, 1), -# ), -# ) -# def test_group_attn_matmul( -# batch, K, seq_len, q_heads, kv_heads, in0_sharded, in1_sharded, output_sharded, shard_orientation, device -# ): -# torch.manual_seed(0) - -# compute_grid_size = device.compute_with_storage_grid_size() - -# interleaved_mem_config = ttl.tensor.MemoryConfig( -# ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM -# ) - -# # NOTE: Mixed precision is supported as well; but might not have enough space for larger seq_len with BFLOAT16 -# in0_dtype = ttl.tensor.DataType.BFLOAT8_B -# in1_dtype = ttl.tensor.DataType.BFLOAT8_B -# output_dtype = ttl.tensor.DataType.BFLOAT8_B - -# q_len = 1 -# input_shape_a = [q_len, q_heads, batch, K] -# input_shape_b = [batch, kv_heads, K, seq_len] - -# input_tensor_a = torch.randn(input_shape_a).bfloat16() -# input_tensor_b = torch.randn(input_shape_b).bfloat16() - -# tt_input_tensor_a = ( -# ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) -# ) -# tt_input_tensor_b = ( -# ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) -# ) - -# if in0_sharded: -# tt_input_tensor_a = ttl.tensor.interleaved_to_sharded( -# tt_input_tensor_a, -# compute_grid_size, -# [q_len * batch, K], -# ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, -# shard_orientation, -# ) - -# if in1_sharded: -# tt_input_tensor_b = ttl.tensor.interleaved_to_sharded( -# tt_input_tensor_b, -# compute_grid_size, -# [kv_heads * K, seq_len], -# ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, -# shard_orientation, -# ) - -# if output_sharded: -# output_mem_config = ttl.tensor.MemoryConfig( -# memory_layout=ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, -# buffer_type=ttl.tensor.BufferType.L1, -# ) -# else: -# output_mem_config = interleaved_mem_config - -# tt_output_tensor_on_device = ttl.operations.primary.transformers.group_attn_matmul( -# tt_input_tensor_a, -# tt_input_tensor_b, -# compute_with_storage_grid_size=compute_grid_size, -# output_mem_config=output_mem_config, -# output_dtype=output_dtype, -# ) -# if output_sharded: -# tt_output_tensor_on_device = ttl.tensor.sharded_to_interleaved( -# tt_output_tensor_on_device, interleaved_mem_config -# ) - -# tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() - -# input_tensor_a = input_tensor_a.to(torch.float) -# input_tensor_b = torch.repeat_interleave(input_tensor_b.to(torch.float), q_heads // kv_heads, dim=1) -# golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) - -# allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) -# assert allclose, f"FAILED: {output}" - - -# @pytest.mark.parametrize("sharded", [False, True]) -# @pytest.mark.parametrize("output_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) -# @pytest.mark.parametrize("in1_dtype", 
[ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) -# @pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) -# def test_group_attn_matmul_with_program_cache(in0_dtype, in1_dtype, output_dtype, sharded, device, use_program_cache): -# torch.manual_seed(0) - -# compute_grid_size = device.compute_with_storage_grid_size() - -# interleaved_mem_config = ttl.tensor.MemoryConfig( -# ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM -# ) - -# shard_orientation = ttl.tensor.ShardOrientation.COL_MAJOR # Only used if sharded - -# q_len = 1 -# batch = 32 -# num_cache_entries = 0 # Only track cache entries of group_attn_matmul -# # NOTE: Program is cached on out_subblock_w as well, so only seq_len >= 256 (out_subblock_w = 8) will share cache -# # For seq_len < = 256, recompile at worst 8 times. -# for K, seq_len, q_heads, kv_heads in ((96, 512 + 64, 10, 2), (64, 256, 50, 5)): -# input_shape_a = [q_len, q_heads, batch, K] -# input_shape_b = [batch, kv_heads, K, seq_len] - -# input_tensor_a = torch.randn(input_shape_a).bfloat16() -# input_tensor_b = torch.randn(input_shape_b).bfloat16() - -# tt_input_tensor_a = ( -# ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) -# ) -# tt_input_tensor_b = ( -# ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) -# ) - -# if sharded: -# tt_input_tensor_a = ttl.tensor.interleaved_to_sharded( -# tt_input_tensor_a, -# compute_grid_size, -# [q_len * batch, K], -# ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, -# shard_orientation, -# ) - -# tt_input_tensor_b = ttl.tensor.interleaved_to_sharded( -# tt_input_tensor_b, -# compute_grid_size, -# [kv_heads * K, seq_len], -# ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, -# shard_orientation, -# ) - -# output_mem_config = ttl.tensor.MemoryConfig( -# memory_layout=ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, -# buffer_type=ttl.tensor.BufferType.L1, -# ) -# else: -# output_mem_config = interleaved_mem_config - -# num_cache_entries_start = ttl.program_cache.num_entries() -# tt_output_tensor_on_device = ttl.operations.primary.transformers.group_attn_matmul( -# tt_input_tensor_a, -# tt_input_tensor_b, -# compute_with_storage_grid_size=compute_grid_size, -# output_mem_config=output_mem_config, -# output_dtype=output_dtype, -# ) -# num_cache_entries += ttl.program_cache.num_entries() - num_cache_entries_start - -# if sharded: -# tt_output_tensor_on_device = ttl.tensor.sharded_to_interleaved( -# tt_output_tensor_on_device, interleaved_mem_config -# ) - -# tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() - -# input_tensor_a = input_tensor_a.to(torch.float) -# input_tensor_b = torch.repeat_interleave(input_tensor_b.to(torch.float), q_heads // kv_heads, dim=1) -# golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) - -# allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) -# assert allclose, f"FAILED: {output}" - -# assert num_cache_entries == 1 +@pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +@pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +@pytest.mark.parametrize("out_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +def test_attn_matmul(in0_dtype, in1_dtype, out_dtype, device): + torch.manual_seed(0) + + for input_shape_a, 
input_shape_b in generate_input_shapes(): + input_tensor_a = torch.randn(input_shape_a).bfloat16() + input_tensor_b = torch.randn(input_shape_b).bfloat16() + + tt_input_tensor_a = ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device) + tt_input_tensor_b = ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device) + + compute_grid_size = device.compute_with_storage_grid_size() + + tt_output_tensor_on_device = ttl.operations.primary.transformers.attn_matmul( + tt_input_tensor_a, + tt_input_tensor_b, + compute_with_storage_grid_size=ttl.tensor.CoreCoord(compute_grid_size.x, compute_grid_size.y), + output_mem_config=ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 + ), + output_dtype=out_dtype, + ) + tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + + golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) + + allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) + assert allclose, f"FAILED: {output}" @pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32") @pytest.mark.parametrize( - "in_dtype", - [ - ttl.tensor.DataType.BFLOAT16, - ], + "in_dtype", [ttl.tensor.DataType.FLOAT32, ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B] +) +def test_attn_matmul_fp32(in_dtype, device): + torch.manual_seed(0) + + for input_shape_a, input_shape_b in generate_input_shapes(): + input_tensor_a = torch.randn(input_shape_a).bfloat16() + input_tensor_b = torch.randn(input_shape_b).bfloat16() + + tt_input_tensor_a = ttl.tensor.Tensor(input_tensor_a, in_dtype).to(ttl.tensor.Layout.TILE).to(device) + tt_input_tensor_b = ttl.tensor.Tensor(input_tensor_b, in_dtype).to(ttl.tensor.Layout.TILE).to(device) + + compute_grid_size = device.compute_with_storage_grid_size() + + compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( + math_fidelity=ttl.tensor.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) + + tt_output_tensor_on_device = ttl.operations.primary.transformers.attn_matmul( + tt_input_tensor_a, + tt_input_tensor_b, + compute_with_storage_grid_size=ttl.tensor.CoreCoord(compute_grid_size.x, compute_grid_size.y), + output_mem_config=ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 + ), + output_dtype=in_dtype, + compute_kernel_config=compute_kernel_config, + ) + tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + + golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) + + allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) + assert allclose, f"FAILED: {output}" + + +@pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +@pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +@pytest.mark.parametrize("out_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +def test_attn_matmul_with_program_cache(in0_dtype, in1_dtype, out_dtype, device, use_program_cache): + torch.manual_seed(0) + + for input_shape_a, input_shape_b in generate_input_shapes(): + input_tensor_a = torch.randn(input_shape_a).bfloat16() + input_tensor_b = torch.randn(input_shape_b).bfloat16() + + tt_input_tensor_a = ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device) + tt_input_tensor_b = 
ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device) + + compute_grid_size = device.compute_with_storage_grid_size() + + tt_output_tensor_on_device = ttl.operations.primary.transformers.attn_matmul( + tt_input_tensor_a, + tt_input_tensor_b, + compute_with_storage_grid_size=ttl.tensor.CoreCoord(compute_grid_size.x, compute_grid_size.y), + output_mem_config=ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 + ), + output_dtype=out_dtype, + ) + tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + + golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) + + allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) + assert allclose, f"FAILED: {output}" + + +@pytest.mark.parametrize( + "shard_orientation", + (ttl.tensor.ShardOrientation.ROW_MAJOR, ttl.tensor.ShardOrientation.COL_MAJOR), +) +@pytest.mark.parametrize( + "output_sharded", + (False, True), +) +@pytest.mark.parametrize( + "in1_sharded", + (False, True), ) -# @pytest.mark.parametrize("in_dtype", [ttl.tensor.DataType.FLOAT32, ttl.tensor.DataType.BFLOAT16]) +@pytest.mark.parametrize( + "in0_sharded", + (False, True), +) +@pytest.mark.parametrize( + "batch, K, seq_len, q_heads, kv_heads", + ( + (32, 64, 512 + 96, 32, 2), + (32, 1024 + 32, 64, 32, 2), + (32, 64, 128, 16, 1), + ), +) +def test_group_attn_matmul( + batch, K, seq_len, q_heads, kv_heads, in0_sharded, in1_sharded, output_sharded, shard_orientation, device +): + torch.manual_seed(0) + + compute_grid_size = device.compute_with_storage_grid_size() + + interleaved_mem_config = ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM + ) + + # NOTE: Mixed precision is supported as well; but might not have enough space for larger seq_len with BFLOAT16 + in0_dtype = ttl.tensor.DataType.BFLOAT8_B + in1_dtype = ttl.tensor.DataType.BFLOAT8_B + output_dtype = ttl.tensor.DataType.BFLOAT8_B + + q_len = 1 + input_shape_a = [q_len, q_heads, batch, K] + input_shape_b = [batch, kv_heads, K, seq_len] + + input_tensor_a = torch.randn(input_shape_a).bfloat16() + input_tensor_b = torch.randn(input_shape_b).bfloat16() + + tt_input_tensor_a = ( + ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) + ) + tt_input_tensor_b = ( + ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) + ) + + if in0_sharded: + tt_input_tensor_a = ttl.tensor.interleaved_to_sharded( + tt_input_tensor_a, + compute_grid_size, + [q_len * batch, K], + ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + shard_orientation, + ) + + if in1_sharded: + tt_input_tensor_b = ttl.tensor.interleaved_to_sharded( + tt_input_tensor_b, + compute_grid_size, + [kv_heads * K, seq_len], + ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + shard_orientation, + ) + + if output_sharded: + output_mem_config = ttl.tensor.MemoryConfig( + memory_layout=ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + buffer_type=ttl.tensor.BufferType.L1, + ) + else: + output_mem_config = interleaved_mem_config + + tt_output_tensor_on_device = ttl.operations.primary.transformers.group_attn_matmul( + tt_input_tensor_a, + tt_input_tensor_b, + compute_with_storage_grid_size=compute_grid_size, + output_mem_config=output_mem_config, + output_dtype=output_dtype, + ) + if output_sharded: + tt_output_tensor_on_device = ttl.tensor.sharded_to_interleaved( + 
tt_output_tensor_on_device, interleaved_mem_config + ) + + tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + + input_tensor_a = input_tensor_a.to(torch.float) + input_tensor_b = torch.repeat_interleave(input_tensor_b.to(torch.float), q_heads // kv_heads, dim=1) + golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) + + allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) + assert allclose, f"FAILED: {output}" + + +@pytest.mark.parametrize("sharded", [False, True]) +@pytest.mark.parametrize("output_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +@pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +@pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) +def test_group_attn_matmul_with_program_cache(in0_dtype, in1_dtype, output_dtype, sharded, device, use_program_cache): + torch.manual_seed(0) + + compute_grid_size = device.compute_with_storage_grid_size() + + interleaved_mem_config = ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM + ) + + shard_orientation = ttl.tensor.ShardOrientation.COL_MAJOR # Only used if sharded + + q_len = 1 + batch = 32 + num_cache_entries = 0 # Only track cache entries of group_attn_matmul + # NOTE: Program is cached on out_subblock_w as well, so only seq_len >= 256 (out_subblock_w = 8) will share cache + # For seq_len < = 256, recompile at worst 8 times. + for K, seq_len, q_heads, kv_heads in ((96, 512 + 64, 10, 2), (64, 256, 50, 5)): + input_shape_a = [q_len, q_heads, batch, K] + input_shape_b = [batch, kv_heads, K, seq_len] + + input_tensor_a = torch.randn(input_shape_a).bfloat16() + input_tensor_b = torch.randn(input_shape_b).bfloat16() + + tt_input_tensor_a = ( + ttl.tensor.Tensor(input_tensor_a, in0_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) + ) + tt_input_tensor_b = ( + ttl.tensor.Tensor(input_tensor_b, in1_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) + ) + + if sharded: + tt_input_tensor_a = ttl.tensor.interleaved_to_sharded( + tt_input_tensor_a, + compute_grid_size, + [q_len * batch, K], + ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + shard_orientation, + ) + + tt_input_tensor_b = ttl.tensor.interleaved_to_sharded( + tt_input_tensor_b, + compute_grid_size, + [kv_heads * K, seq_len], + ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + shard_orientation, + ) + + output_mem_config = ttl.tensor.MemoryConfig( + memory_layout=ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + buffer_type=ttl.tensor.BufferType.L1, + ) + else: + output_mem_config = interleaved_mem_config + + num_cache_entries_start = ttl.program_cache.num_entries() + tt_output_tensor_on_device = ttl.operations.primary.transformers.group_attn_matmul( + tt_input_tensor_a, + tt_input_tensor_b, + compute_with_storage_grid_size=compute_grid_size, + output_mem_config=output_mem_config, + output_dtype=output_dtype, + ) + num_cache_entries += ttl.program_cache.num_entries() - num_cache_entries_start + + if sharded: + tt_output_tensor_on_device = ttl.tensor.sharded_to_interleaved( + tt_output_tensor_on_device, interleaved_mem_config + ) + + tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + + input_tensor_a = input_tensor_a.to(torch.float) + input_tensor_b = torch.repeat_interleave(input_tensor_b.to(torch.float), q_heads // kv_heads, dim=1) + 
golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) + + allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) + assert allclose, f"FAILED: {output}" + + assert num_cache_entries == 1 + + +@pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32") +@pytest.mark.parametrize("in_dtype", [ttl.tensor.DataType.FLOAT32, ttl.tensor.DataType.BFLOAT16]) @pytest.mark.parametrize( "shard_orientation", (ttl.tensor.ShardOrientation.ROW_MAJOR,), @@ -347,8 +341,8 @@ def generate_input_shapes(): @pytest.mark.parametrize( "batch, K, seq_len, q_heads, kv_heads", ( - # (32, 64, 512 + 96, 32, 2), - # (32, 64 + 32, 64, 32, 2), + (32, 64, 512 + 96, 32, 2), + (32, 64 + 32, 64, 32, 2), (32, 32, 32, 2, 1), ), ) @@ -372,53 +366,8 @@ def test_group_attn_matmul_fp32( input_shape_a = [q_len, q_heads, batch, K] input_shape_b = [batch, kv_heads, K, seq_len] - input_tensor_a = torch.ones(input_shape_a).bfloat16() - input_tensor_b = torch.ones(input_shape_b).bfloat16() - - # total_rows = q_len * q_heads * batch - # incremental_values = torch.arange(1, total_rows + 1).unsqueeze(-1).repeat(1, K) - # incremental_values = incremental_values.bfloat16() - # input_tensor_a = incremental_values.view(q_len, q_heads, batch, K) - - # input_tensor_a = torch.ones(input_shape_a).bfloat16() - # N = input_tensor_a.numel() - # incremental_values = torch.arange(1, N + 1).bfloat16().view(input_shape_a) - # input_tensor_a = incremental_values - - total_rows = batch * kv_heads * K - incremental_values = torch.arange(1, total_rows + 1).unsqueeze(-1).repeat(1, seq_len) - incremental_values = incremental_values.bfloat16() - input_tensor_b = incremental_values.view(batch, kv_heads, K, seq_len) - - # incremental_row = torch.arange(1, seq_len + 1).bfloat16() - # incremental_pattern = incremental_row.unsqueeze(0).repeat(batch * kv_heads * K, 1) - # input_tensor_b = incremental_pattern.view(batch, kv_heads, K, seq_len) - - # increment_k = torch.arange(1, K + 1).unsqueeze(-1) # Generating 1 to K, with unsqueeze to make it a column vector - # increment_k = increment_k.repeat(1, seq_len) # Now, repeat each K value seq_len times horizontally - # increment_k = increment_k.view(-1) # Flatten it - # increment_reshaped = increment_k.unsqueeze(0).unsqueeze(0).repeat(batch, kv_heads, 1, 1) - # input_tensor_b = increment_reshaped.view(batch, kv_heads, K, seq_len).bfloat16() - - base_sequence = torch.arange(1, K + 1).unsqueeze(-1) # Extend to make it a 2D tensor - equal_rows = base_sequence.repeat(1, seq_len) # Repeat each value in the base_sequence seq_len t - snapshot = equal_rows.view(1, 1, K, seq_len) # Note the view to fabricate the design - grand_design = snapshot.repeat(batch, kv_heads, 1, 1) # Branching the vignett - grand_design = grand_design.bfloat16() - input_tensor_b = grand_design - - base_sequence = torch.arange(1, K + 1).unsqueeze(-1) - equal_rows = base_sequence.repeat(1, seq_len) - plate = equal_rows.view(K, seq_len).unsqueeze(0).unsqueeze(0) # Traditional macro - plate = plate.repeat(batch, kv_heads, 1, 1) # Veil in multi-view - for indx in range(batch): - plate[indx] += indx # Galvanizing a flick per the epos - plate = plate.bfloat16() # Excellent in retaining the loom in a planar zephyr - input_tensor_b = plate.view(batch, kv_heads, K, seq_len) - - print(input_tensor_b) - - # print(input_tensor_a) + input_tensor_a = torch.randn(input_shape_a).bfloat16() + input_tensor_b = torch.randn(input_shape_b).bfloat16() tt_input_tensor_a = ( ttl.tensor.Tensor(input_tensor_a, 
in0_dtype).to(ttl.tensor.Layout.TILE).to(device, interleaved_mem_config) @@ -480,6 +429,4 @@ def test_group_attn_matmul_fp32( golden_output_tensor = (input_tensor_a.transpose(0, 2) @ input_tensor_b).transpose(0, 2) allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor) - # print("tt_output_tensor = {}", tt_output_tensor[0][0][0]) - # print("golden_output_tensor = {}", golden_output_tensor[0][0][0]) assert allclose, f"FAILED: {output}" diff --git a/tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp b/tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp index da747de8e151..edf48f4e781d 100644 --- a/tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp +++ b/tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp @@ -374,10 +374,7 @@ inline Tensor matmul (const Tensor &input_tensor_a, const Tensor &input_tensor_b .program_config=matmul_program_config, .output_mem_config=mem_config, .output_dtype=input_tensor_a.dtype(), - .math_fidelity=MathFidelity::HiFi4, - .fp32_dest_acc_en=false, - .math_approx_mode=false, - .packer_l1_acc=false + .compute_kernel_config=kernel_config_val }, {input_tensor_a, input_tensor_b}, {std::nullopt}).at(0); } else { return operation::run_with_autoformat(Matmul{.bcast_batch=true, .output_mem_config=mem_config, .output_dtype=input_tensor_a.dtype(), .compute_kernel_config=kernel_config_val}, {input_tensor_a, input_tensor_b}, {std::nullopt}).at(0); @@ -398,10 +395,7 @@ inline Tensor bmm (const Tensor &input_tensor_a, const Tensor &input_tensor_b .program_config=matmul_program_config, .output_mem_config=mem_config, .output_dtype=input_tensor_a.dtype(), - .math_fidelity=MathFidelity::HiFi4, - .fp32_dest_acc_en=false, - .math_approx_mode=false, - .packer_l1_acc=false + .compute_kernel_config=kernel_config_val }, {input_tensor_a, input_tensor_b}, {std::nullopt}).at(0); } else { return operation::run_with_autoformat(Matmul{.bcast_batch=false, .output_mem_config=mem_config, .output_dtype=input_tensor_a.dtype(), .compute_kernel_config=kernel_config_val}, {input_tensor_a, input_tensor_b}, {std::nullopt}).at(0); diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp index 81eebefba141..0aa42e26b8ed 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp @@ -8,8 +8,6 @@ #include "compute_kernel_api/tilize.h" #include "compute_kernel_api/pack_untilize.h" -#include "compute_kernel_api/bcast.h" -#include "debug/dprint.h" using std::uint32_t; @@ -58,14 +56,8 @@ void MAIN { constexpr uint32_t in0_num_blocks_w = 1; // TODO: Generalize - mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw); - // mm_block_init(cb_in0, cb_in1, cb_intermed0, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w ); - - // UNPACK(( DPRINT << num_rows_in_one_tile << ENDL() )); - // UNPACK(( DPRINT << in1_num_blocks << ENDL() )); - // UNPACK(( DPRINT << in1_num_blocks << ENDL() )); - // UNPACK(( DPRINT << in0_block_num_tiles << ENDL() )); - + // mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw); + mm_block_init(cb_in0, cb_in1, cb_intermed0, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w ); for (uint32_t b = 0; b < batch; b++) { @@ -73,51 +65,29 @@ void MAIN { for (uint32_t in0_block = 0; in0_block < in0_num_blocks_w; in0_block++) { // TODO: Must be 1; generalize to support inner dim blocking cb_wait_front(cb_in0, in0_block_num_tiles); - 
// UNPACK(( DPRINT << TSLICE(cb_in0, 0, SliceRange::h0_32_w31()) << ENDL() ));
-
             for (uint32_t in1_block = 0; in1_block < in1_num_blocks; in1_block++) {
                 uint32_t in0_index_subblock_offset = 0;
                 for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; tile_row_id++) {
                     cb_wait_front(cb_in1, in1_block_num_tiles);
                     cb_pop_front(cb_in1, num_kv_heads_skip);
-
-                    // UNPACK(( DPRINT << TSLICE(cb_in1, 0, SliceRange::h0_32_w31()) << ENDL() ));
-
                     for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) { // TODO: Must be 1; need to review inner dim blocking and the untilizing
                         uint32_t in1_index_subblock_offset = 0;
 
                         tile_regs_acquire();
-                        // Compute output sub-block
-                        // uint32_t dst_index = 0; // start at 0, each call to matmul_block internally increments dst_index
-                        // uint32_t in0_index = in0_index_subblock_offset; // offset into in0 block
-                        // uint32_t in1_index = in1_index_subblock_offset; // offset into in1 block
-                        // // inner dim that we accumualte is the inner dim of in0/in1, which is in0_block_w
-                        // for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) {
-                        //     // matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst
-                        //     // accumulation is done by iterating matmul_block across inner dim
-                        //     // in0_block_w is passed as innder dim (kt) to matmul_block, interally used to stride in0
-                        //     matmul_block(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w);
-                        //     in0_index ++; // stride right by 1
-                        //     in1_index += in1_per_core_w; // to stride down by 1 need to stride by in_per_core_w (should be called in1_block_w)
-                        // }
-
-                        uint32_t dst_index = 0;
-                        uint32_t in0_index_h_offset = 0;
-                        for (uint32_t h = 0; h < out_subblock_h; h++) {
-                            for (uint32_t w = 0; w < out_subblock_w; w++) {
-                                uint32_t in1_index_inner_dim_offset = 0;
-                                for (uint32_t inner_dim = 0; inner_dim < in0_block_w; inner_dim++) {
-                                    uint32_t in0_index = in0_index_subblock_offset + in0_index_h_offset + inner_dim;
-                                    uint32_t in1_index = in1_index_subblock_offset + in1_index_inner_dim_offset + w;
-                                    matmul_tiles(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw);
-                                    in1_index_inner_dim_offset += in1_per_core_w;
-                                }
-                                dst_index++;
-                            }
-                            in0_index_h_offset += in0_block_w;
+                        uint32_t dst_index = 0; // start at 0, each call to matmul_block internally increments dst_index
+                        uint32_t in0_index = in0_index_subblock_offset; // offset into in0 block
+                        uint32_t in1_index = in1_index_subblock_offset; // offset into in1 block
+                        // inner dim that we accumulate is the inner dim of in0/in1, which is in0_block_w
+                        for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) {
+                            // matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst
+                            // accumulation is done by iterating matmul_block across inner dim
+                            // in0_block_w is passed as inner dim (kt) to matmul_block, internally used to stride in0
+                            matmul_block(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w);
+                            in0_index ++; // stride right by 1
+                            in1_index += in1_per_core_w; // to stride down by 1 need to stride by in1_per_core_w (should be called in1_block_w)
                         }
 
                         tile_regs_commit();
@@ -128,20 +98,15 @@ void MAIN {
 
                     // TODO: Review inner dim blocking, untilizing, and in1_num_subblocks > 1 (with pack_untilize, can only untilize up to dst num tiles)
                     // This should normally be inside subblock loop and we pack out out_subblock_num_tiles
-                    // pack_untilize_dst_init_short();
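+                    // Pack DEST out in untilized (row-major) form into cb_intermed0; the packer is restored to tilized output by pack_untilize_uninit() below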
+ pack_untilize_dst_init_short(); cb_reserve_back(cb_intermed0, intermediate_num_tiles); tile_regs_wait(); - // pack_untilize_dst(cb_intermed0); - pack_tile(0, cb_intermed0); + pack_untilize_dst(cb_intermed0); + pack_untilize_uninit(); - // pack_untilize_uninit(); tile_regs_release(); cb_push_back(cb_intermed0, intermediate_num_tiles); - - cb_wait_front(cb_intermed0, intermediate_num_tiles); - UNPACK(( DPRINT << intermediate_num_tiles << ENDL() )); - UNPACK(( DPRINT << TSLICE(cb_intermed0, 0, SliceRange::h0_32_w31()) << ENDL() )); } // 32 tiles loop in0_index_subblock_offset += in0_subblock_num_tiles; @@ -149,13 +114,14 @@ void MAIN { } // in0_num_blocks_w // cb_intermed1 comes from reader; untilized row-major tile + unpack_reconfig_data_format_srca(cb_in1, cb_intermed1); + pack_reconfig_data_format(cb_intermed0, out_cb_id); cb_wait_front(cb_intermed1, out_num_tiles); cb_reserve_back(out_cb_id, out_num_tiles); // tilize CB::intermed1 and write to CB::c_out0 - tilize_init_short_with_dt(cb_in1, cb_intermed1, out_num_tiles); - pack_reconfig_data_format(cb_intermed0, out_cb_id); + tilize_init_short(cb_intermed1, out_num_tiles); tilize_block(cb_intermed1, out_num_tiles, out_cb_id); cb_push_back(out_cb_id, out_num_tiles); diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_transformer_group_attn_matmul.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_transformer_group_attn_matmul.cpp index a9bc2af513fb..8c54a56b8c2d 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_transformer_group_attn_matmul.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_transformer_group_attn_matmul.cpp @@ -4,8 +4,6 @@ #include "dataflow_api.h" -#include "debug/dprint.h" - void kernel_main() { uint32_t i = 0; @@ -33,10 +31,6 @@ void kernel_main() { uint32_t bfloat16_Nt_bytes = get_arg_val(i++); uint32_t bfloat16_last_row_bytes_read = get_arg_val(i++); - DPRINT << "bfloat16_row_bytes " <buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp index 3e3d792fb06e..10fb71828372 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp @@ -54,7 +54,6 @@ operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, co tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.dtype()); tt::DataFormat in1_data_format = tt_metal::datatype_to_dataformat_converter(b.dtype()); tt::DataFormat interm_data_format = fp32_dest_acc_en and in0_data_format == tt::DataFormat::Float32 ? 
tt::DataFormat::Float32 : tt::DataFormat::Float16_b; - // interm_data_format=tt::DataFormat::Float16_b; tt::DataFormat output_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype()); uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format); uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format); @@ -65,8 +64,6 @@ operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, co TT_ASSERT(fp32_dest_acc_en == true, "when inputs/output are in fp32 format, fp32_dest_acc_en must be set"); } - - tt_metal::Buffer *src0_buffer = a.buffer(); tt_metal::Buffer *src1_buffer = b.buffer(); tt_metal::Buffer *dst_buffer = output.buffer(); @@ -98,7 +95,6 @@ operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, co const uint32_t in1_per_core_w = in1_num_subblocks * out_block_w; const uint32_t in1_block_w_tile_bytes = out_subblock_w * in1_single_tile_size; uint32_t ONE_ROW_BFLOAT16_BYTES = fp32_dest_acc_en and in0_data_format == tt::DataFormat::Float32 ? 128 : 64; - // ONE_ROW_BFLOAT16_BYTES = 64; const uint32_t bfloat16_row_bytes = ONE_ROW_BFLOAT16_BYTES * out_block_w; // TODO: Generalize log_debug("in0_block_w: {}", in0_block_w); @@ -143,8 +139,6 @@ operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, co tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(cb0_num_input_tiles * in0_single_tile_size, {{src0_cb_index, in0_data_format}}) .set_page_size(src0_cb_index, in0_single_tile_size).set_globally_allocated_address(*src0_buffer); cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config); - - std::cout << cb0_num_input_tiles << std::endl; } else { uint32_t cb0_num_input_tiles = in0_block_w; // TODO: Generalize; double buffer and add blocking along inner dim if we have Mt > 1 tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(cb0_num_input_tiles * in0_single_tile_size, {{src0_cb_index, in0_data_format}}) @@ -152,8 +146,6 @@ operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, co cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config); } - - // CB for interleaved/sharded KV heads for mcasting; mcasts to same CB // Then, push all KV_HEADS to compute and compute chooses which head to use for matmul uint32_t src1_cb_index = CB::c_in1; diff --git a/tt_eager/tt_lib/csrc/operations/primary/module.hpp b/tt_eager/tt_lib/csrc/operations/primary/module.hpp index a69fd6d392c0..fa2c2de69004 100644 --- a/tt_eager/tt_lib/csrc/operations/primary/module.hpp +++ b/tt_eager/tt_lib/csrc/operations/primary/module.hpp @@ -213,10 +213,6 @@ void py_module(py::module& m_primary) { const MemoryConfig& out_mem_config, std::optional output_dtype, std::optional compute_kernel_config - // const MathFidelity math_fidelity, - // const bool fp32_dest_acc_en, - // const bool math_approx_mode, - // const bool packer_l1_acc ) { return matmul( input_tensor_a, input_tensor_b, bias, program_config, out_mem_config, output_dtype, compute_kernel_config); @@ -251,10 +247,7 @@ void py_module(py::module& m_primary) { const MatmulMultiCoreReuseProgramConfig& program_config, const MemoryConfig& out_mem_config, std::optional output_dtype, - const MathFidelity math_fidelity, - const bool fp32_dest_acc_en, - const bool math_approx_mode, - const bool packer_l1_acc) { + std::optional compute_kernel_config) { return matmul( input_tensor_a, input_tensor_b, @@ -262,10 +255,7 @@ void py_module(py::module& 
m_primary) { program_config, out_mem_config, output_dtype, - math_fidelity, - fp32_dest_acc_en, - math_approx_mode, - packer_l1_acc); + compute_kernel_config); }, py::arg("input_tensor_a").noconvert(), py::arg("input_tensor_b").noconvert(), @@ -274,10 +264,7 @@ void py_module(py::module& m_primary) { py::arg("program_config").noconvert() = MatmulDefaultProgramConfig(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_dtype").noconvert() = std::nullopt, - py::arg("math_fidelity").noconvert() = MathFidelity::LoFi, - py::arg("fp32_dest_acc_en").noconvert() = false, - py::arg("math_approx_mode").noconvert() = true, - py::arg("packer_l1_acc").noconvert() = false, + py::arg("compute_kernel_config").noconvert() = std::nullopt, R"doc( Perform a matrix multiplication ``input_tensor_a x input_tensor_b``. diff --git a/tt_metal/hw/inc/debug/dprint_tile.h b/tt_metal/hw/inc/debug/dprint_tile.h index 77f5755a56aa..cfbd5ea08b91 100644 --- a/tt_metal/hw/inc/debug/dprint_tile.h +++ b/tt_metal/hw/inc/debug/dprint_tile.h @@ -19,10 +19,8 @@ struct SliceRange { static inline SliceRange hw0_32_4() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 4, .w0 = 0, .w1 = 32, .ws = 4 }; } // [0, 0:32] static inline SliceRange h0_w0_32() { return SliceRange{ .h0 = 0, .h1 = 1, .hs = 1, .w0 = 0, .w1 = 32, .ws = 1 }; } - static inline SliceRange h1_w0_32() { return SliceRange{ .h0 = 1, .h1 = 2, .hs = 1, .w0 = 0, .w1 = 32, .ws = 1 }; } // [0:32, 0] static inline SliceRange h0_32_w0() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 1, .w0 = 0, .w1 = 1, .ws = 1 }; } - static inline SliceRange h0_32_w31() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 1, .w0 = 31, .w1 = 32, .ws = 1 }; } // [0:32:1, 1] static inline SliceRange h0_32_w1() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 1, .w0 = 1, .w1 = 2, .ws = 1 }; } // [0:4:1, 0:4:1] diff --git a/tt_metal/include/compute_kernel_api/matmul.h b/tt_metal/include/compute_kernel_api/matmul.h index 53be585b8773..3ad177c6e566 100644 --- a/tt_metal/include/compute_kernel_api/matmul.h +++ b/tt_metal/include/compute_kernel_api/matmul.h @@ -202,7 +202,7 @@ ALWI void matmul_block(uint32_t in0_cb_id, uint32_t in1_cb_id, uint32_t in0_tile * | rt_dim | The row dimension for the output block. | uint32_t | Must be equal to block A row dimension | False | * | kt_dim | The inner dimension. 
| uint32_t | Must be equal to block A column dimension | False | */ -ALWI void mm_block_init_short(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, uint32_t transpose=0, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1) { +ALWI void mm_block_init_short(uint32_t in0_cb_id = 0, uint32_t in1_cb_id = 1, const uint32_t transpose=0, uint32_t ct_dim = 1, uint32_t rt_dim = 1, uint32_t kt_dim = 1) { UNPACK(( llk_unpack_AB_matmul_init(in0_cb_id, in1_cb_id, transpose, ct_dim, rt_dim, kt_dim) )); #ifdef ARCH_GRAYSKULL diff --git a/tt_metal/include/compute_kernel_api/untilize.h b/tt_metal/include/compute_kernel_api/untilize.h index ee19f53b28fb..d483acb4a0de 100644 --- a/tt_metal/include/compute_kernel_api/untilize.h +++ b/tt_metal/include/compute_kernel_api/untilize.h @@ -20,13 +20,13 @@ namespace ckernel { */ ALWI void untilize_init(uint32_t icb, uint32_t ocb = 16) { - MATH(( llk_math_eltwise_unary_datacopy_init(false /*transpose of faces*/, false /*transpose within 16x16 face*/, icb) )); - MATH(( llk_math_pack_sync_init() )); + MATH(( llk_math_eltwise_unary_datacopy_init(false /*transpose of faces*/, false /*transpose within 16x16 face*/, icb) )); + MATH(( llk_math_pack_sync_init() )); - PACK(( llk_pack_hw_configure_disaggregated(ocb) )); + PACK(( llk_pack_hw_configure_disaggregated(ocb) )); PACK(( llk_pack_init(ocb) )); PACK(( llk_setup_outputs() )); - PACK(( llk_pack_dest_init() )); + PACK(( llk_pack_dest_init() )); UNPACK(( llk_setup_operands() )); UNPACK(( llk_unpack_untilize_hw_configure_disaggregated(icb) )); @@ -38,7 +38,7 @@ ALWI void untilize_init(uint32_t icb, uint32_t ocb = 16) */ ALWI void untilize_init_short(uint32_t icb) { - MATH(( llk_math_eltwise_unary_datacopy_init(false /*transpose of faces*/, false /*transpose within 16x16 face*/, icb) )); + MATH(( llk_math_eltwise_unary_datacopy_init(false /*transpose of faces*/, false /*transpose within 16x16 face*/, icb) )); UNPACK(( llk_unpack_untilize_init(icb) )); } @@ -55,20 +55,20 @@ ALWI void untilize_block(uint32_t icb, uint32_t block, uint32_t ocb) // Datacopy for (int reg_id = 0; reg_id < N; reg_id++) { - MATH(( llk_math_eltwise_unary_datacopy(reg_id) )); + MATH(( llk_math_eltwise_unary_datacopy(reg_id) )); } - MATH(( llk_math_dest_section_done() )); + MATH(( llk_math_dest_section_done() )); PACK(( llk_packer_wait_for_math_done() )); // Datacopy for (int reg_id = 0; reg_id < N; reg_id++) { - PACK(( llk_pack(reg_id, ocb) )); + PACK(( llk_pack(reg_id, ocb) )); } // Release dest - PACK(( llk_pack_dest_section_done() )); + PACK(( llk_pack_dest_section_done() )); } }
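
The matmul_block.rst entries above now expose an explicit transpose argument on mm_block_init and its variants. As an illustration only, and not part of the patch, the sketch below shows the intended call pattern: it mirrors the loop that transformer_group_attn_matmul.cpp switches to, but the CB indices, block dimensions, and trimmed include set are assumptions made for the sketch; a real kernel would also pull in the CB and pack helpers it uses.

// Illustrative compute kernel: one (rt_dim x ct_dim) output sub-block accumulated
// over kt_dim inner-dimension tiles with mm_block_init + matmul_block.
#include "compute_kernel_api/matmul.h"  // assumed sufficient for this sketch

namespace NAMESPACE {
void MAIN {
    constexpr uint32_t cb_in0 = 0;   // holds rt_dim x kt_dim input tiles
    constexpr uint32_t cb_in1 = 1;   // holds kt_dim x ct_dim input tiles
    constexpr uint32_t cb_out = 16;
    constexpr uint32_t rt_dim = 1;   // output rows (tiles) held in DEST
    constexpr uint32_t ct_dim = 2;   // output cols (tiles) held in DEST
    constexpr uint32_t kt_dim = 4;   // inner dimension in tiles
    constexpr uint32_t transpose = 0;

    mm_block_init(cb_in0, cb_in1, cb_out, transpose, ct_dim, rt_dim, kt_dim);

    cb_wait_front(cb_in0, rt_dim * kt_dim);
    cb_wait_front(cb_in1, kt_dim * ct_dim);

    tile_regs_acquire();
    uint32_t dst_index = 0;  // matmul_block writes the sub-block into DEST starting here
    uint32_t in0_index = 0;  // steps right along in0's inner dimension
    uint32_t in1_index = 0;  // steps down in1 by one row of ct_dim tiles per iteration
    for (uint32_t kt = 0; kt < kt_dim; ++kt) {
        // Each call accumulates one inner-dim slice of the (rt_dim x ct_dim) outer product.
        matmul_block(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose, ct_dim, rt_dim, kt_dim);
        in0_index += 1;
        in1_index += ct_dim;
    }
    tile_regs_commit();

    cb_reserve_back(cb_out, rt_dim * ct_dim);
    tile_regs_wait();
    for (uint32_t i = 0; i < rt_dim * ct_dim; ++i) {
        pack_tile(i, cb_out);
    }
    tile_regs_release();
    cb_push_back(cb_out, rt_dim * ct_dim);

    cb_pop_front(cb_in0, rt_dim * kt_dim);
    cb_pop_front(cb_in1, kt_dim * ct_dim);
}
}  // namespace NAMESPACE

The same pattern appears in the new kernel above, where ct_dim, rt_dim, and kt_dim are out_subblock_w, out_subblock_h, and in0_block_w, and in1_index advances by in1_per_core_w (the in1 block width in tiles).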
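
On the host side, bmm_op.hpp and the ttl.operations.primary bindings replace the loose math_fidelity / fp32_dest_acc_en / math_approx_mode / packer_l1_acc parameters with a single optional compute_kernel_config. The fragment below is a hedged sketch of what a caller supplies, assuming the C++ WormholeComputeKernelConfig aggregate mirrors the fields of the Python binding exercised in test_attn_matmul_fp32; in bmm_op.hpp itself kernel_config_val is derived from the caller-provided optional, whereas here it is constructed directly for illustration, and the surrounding variables are taken from the matmul() wrapper's context rather than from this patch.

// Fragment, in the spirit of the updated matmul()/bmm() wrappers in bmm_op.hpp.
// Field names mirror ttl.tensor.WormholeComputeKernelConfig in the test above.
WormholeComputeKernelConfig kernel_config_val{
    .math_fidelity = MathFidelity::LoFi,
    .math_approx_mode = true,
    .fp32_dest_acc_en = true,  // required when inputs/output are Float32 (see the TT_ASSERT in multi_core_group_attn_matmul.cpp)
    .packer_l1_acc = false,
};

return operation::run_with_autoformat(
           Matmul{
               .bcast_batch = true,
               .output_mem_config = mem_config,
               .output_dtype = input_tensor_a.dtype(),
               .compute_kernel_config = kernel_config_val},
           {input_tensor_a, input_tensor_b},
           {std::nullopt})
    .at(0);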