diff --git a/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py b/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py deleted file mode 100644 index 1d98feb58ac6..000000000000 --- a/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py +++ /dev/null @@ -1,87 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import tt_lib - -from models.demos.resnet.tests.test_metal_resnet50 import run_resnet50_inference, run_2cq_model, run_trace_2cq_model -from models.utility_functions import skip_for_wormhole_b0 - - -@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") -@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) -@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) -@pytest.mark.parametrize( - "weights_dtype", - [tt_lib.tensor.DataType.BFLOAT8_B], - ids=["weights_BFLOAT8_B"], -) -@pytest.mark.parametrize( - "activations_dtype", - [tt_lib.tensor.DataType.BFLOAT8_B], - ids=["activations_BFLOAT8_B"], -) -@pytest.mark.parametrize( - "math_fidelity", - [tt_lib.tensor.MathFidelity.LoFi], - ids=["LoFi"], -) -def test_run_resnet50_2cqs_inference( - device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input -): - run_resnet50_inference( - device, - batch_size, - weights_dtype, - activations_dtype, - math_fidelity, - imagenet_sample_input, - run_2cq_model, - ) - - -@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") -@pytest.mark.parametrize( - "device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2, "trace_region_size": 1500000}], indirect=True -) -@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) -@pytest.mark.parametrize( - "weights_dtype", - [tt_lib.tensor.DataType.BFLOAT8_B], - ids=["weights_BFLOAT8_B"], -) -@pytest.mark.parametrize( - "activations_dtype", - [tt_lib.tensor.DataType.BFLOAT8_B], - ids=["activations_BFLOAT8_B"], -) -@pytest.mark.parametrize( - "math_fidelity", - [tt_lib.tensor.MathFidelity.LoFi], - ids=["LoFi"], -) -@pytest.mark.parametrize("enable_async", [True, False]) -def test_run_resnet50_trace_2cqs_inference( - device, - use_program_cache, - batch_size, - weights_dtype, - activations_dtype, - math_fidelity, - imagenet_sample_input, - enable_async, -): - device.enable_async(enable_async) - - run_resnet50_inference( - device, - batch_size, - weights_dtype, - activations_dtype, - math_fidelity, - imagenet_sample_input, - run_trace_2cq_model, - ) - - device.enable_async(False) diff --git a/models/demos/resnet/tests/test_metal_resnet50_performant.py b/models/demos/resnet/tests/test_metal_resnet50_performant.py index 535bac8dc77b..bb3fb53d875e 100644 --- a/models/demos/resnet/tests/test_metal_resnet50_performant.py +++ b/models/demos/resnet/tests/test_metal_resnet50_performant.py @@ -5,7 +5,13 @@ import pytest import tt_lib -from models.demos.resnet.tests.test_metal_resnet50 import run_resnet50_inference, run_model, run_trace_model +from models.demos.resnet.tests.test_metal_resnet50 import ( + run_resnet50_inference, + run_model, + run_trace_model, + run_2cq_model, + run_trace_2cq_model, +) from models.utility_functions import skip_for_wormhole_b0 @@ -83,3 +89,81 @@ def test_run_resnet50_trace_inference( ) device.enable_async(False) + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +def test_run_resnet50_2cqs_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_2cq_model, + ) + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize( + "device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2, "trace_region_size": 1500000}], indirect=True +) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +@pytest.mark.parametrize("enable_async", [True, False]) +def test_run_resnet50_trace_2cqs_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + enable_async, +): + device.enable_async(enable_async) + + run_resnet50_inference( + device, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_trace_2cq_model, + ) + + device.enable_async(False) diff --git a/models/demos/resnet/tests/test_perf_resnet.py b/models/demos/resnet/tests/test_perf_resnet.py index 7a5e86a73f18..9ef7c042e0df 100644 --- a/models/demos/resnet/tests/test_perf_resnet.py +++ b/models/demos/resnet/tests/test_perf_resnet.py @@ -384,3 +384,50 @@ def test_perf_trace_bare_metal( f"resnet50_trace_{mode}", ) device.enable_async(False) + + +@skip_for_wormhole_b0(reason_str="Not tested on single WH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, expected_inference_time, expected_compile_time", + ((20, 0.0042, 16),), +) +def test_perf_2cqs_bare_metal( + device, + use_program_cache, + batch_size, + expected_inference_time, + expected_compile_time, + hf_cat_image_sample_input, +): + run_perf_resnet( + batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, "resnet50_2cqs" + ) + + +@skip_for_wormhole_b0(reason_str="Not tested on single WH") +@pytest.mark.parametrize( + "device_params", [{"l1_small_size": 32768, "num_hw_cqs": 2, "trace_region_size": 1332224}], indirect=True +) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, expected_inference_time, expected_compile_time", + ((20, 0.0042, 16),), +) +def test_perf_trace_2cqs_bare_metal( + device, + use_program_cache, + batch_size, + expected_inference_time, + expected_compile_time, + hf_cat_image_sample_input, +): + run_perf_resnet( + batch_size, + expected_inference_time, + expected_compile_time, + hf_cat_image_sample_input, + device, + "resnet50_trace_2cqs", + ) diff --git a/models/demos/resnet/tests/test_perf_resnet_2cqs.py b/models/demos/resnet/tests/test_perf_resnet_2cqs.py deleted file mode 100644 index 07e0234333d0..000000000000 --- a/models/demos/resnet/tests/test_perf_resnet_2cqs.py +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from models.demos.resnet.tests.test_perf_resnet import run_perf_resnet -from models.utility_functions import skip_for_wormhole_b0 - - -@skip_for_wormhole_b0(reason_str="Not tested on single WH") -@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768, "num_hw_cqs": 2}], indirect=True) -@pytest.mark.models_performance_bare_metal -@pytest.mark.parametrize( - "batch_size, expected_inference_time, expected_compile_time", - ((20, 0.0042, 16),), -) -def test_perf_2cqs_bare_metal( - device, - use_program_cache, - batch_size, - expected_inference_time, - expected_compile_time, - hf_cat_image_sample_input, -): - run_perf_resnet( - batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, "resnet50_2cqs" - ) - - -@skip_for_wormhole_b0(reason_str="Not tested on single WH") -@pytest.mark.parametrize( - "device_params", [{"l1_small_size": 32768, "num_hw_cqs": 2, "trace_region_size": 1332224}], indirect=True -) -@pytest.mark.models_performance_bare_metal -@pytest.mark.parametrize( - "batch_size, expected_inference_time, expected_compile_time", - ((20, 0.0042, 16),), -) -def test_perf_trace_2cqs_bare_metal( - device, - use_program_cache, - batch_size, - expected_inference_time, - expected_compile_time, - hf_cat_image_sample_input, -): - run_perf_resnet( - batch_size, - expected_inference_time, - expected_compile_time, - hf_cat_image_sample_input, - device, - "resnet50_trace_2cqs", - ) diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index 754bcbc9ab1e..b52fa9d3017b 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -17,9 +17,7 @@ run_perf_models_other() { env pytest models/demos/ttnn_falcon7b/tests -m $test_marker - # Separate calls since we can't mix switching between number of cqs env pytest models/demos/resnet/tests/test_perf_resnet.py -m $test_marker - env pytest models/demos/resnet/tests/test_perf_resnet_2cqs.py -m $test_marker env pytest tests/ttnn/integration_tests/whisper/test_performance.py -m $test_marker diff --git a/tests/scripts/single_card/nightly/run_gs_only.sh b/tests/scripts/single_card/nightly/run_gs_only.sh index c5bcc9f97452..aba5fcc2301c 100755 --- a/tests/scripts/single_card/nightly/run_gs_only.sh +++ b/tests/scripts/single_card/nightly/run_gs_only.sh @@ -10,5 +10,3 @@ fi echo "Running model nightly tests for GS only" env pytest models/demos/resnet/tests/test_metal_resnet50_performant.py - -env pytest models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py diff --git a/tests/tt_eager/python_api_testing/trace_testing/misc/test_bert_ops.py b/tests/tt_eager/python_api_testing/trace_testing/misc/test_bert_ops.py index ba355a3e7054..12cacc6b1ad2 100644 --- a/tests/tt_eager/python_api_testing/trace_testing/misc/test_bert_ops.py +++ b/tests/tt_eager/python_api_testing/trace_testing/misc/test_bert_ops.py @@ -37,169 +37,232 @@ ], ) @pytest.mark.parametrize("enable_async", [True, False]) -@pytest.mark.parametrize("device_params", [{"trace_region_size": 34816}], indirect=True) -def test_bert_linear( - device, - fidelity, - in0_sharded, - out_sharded, - in1_in_dram, - M, - K, - N, - activation, - use_program_cache, - function_level_defaults, - enable_async, -): - device.enable_async(enable_async) - has_bias = False - in0_shape = [1, 1, M, K] - in1_shape = [1, 1, K, N] - bias_shape = [1, 1, N] - out_shape = [1, 1, M, N] - grid_size = (12, 8) - # grid_size = (2, 2) - shard_shape = [M // grid_size[0], K // grid_size[1]] # shard height, width - - in0_block_w = K // grid_size[1] // 32 # 16 - in0_block_h = M // grid_size[0] // 32 - out_block_h = M // grid_size[0] // 32 - out_block_w = N // grid_size[1] // 32 - - if out_block_w <= 8: - out_subblock_w = out_block_w - out_subblock_h = 8 // out_subblock_w - else: - out_subblock_h = 1 - out_subblock_w = 8 // out_subblock_h - while out_block_w % out_subblock_w != 0: - out_subblock_w = out_block_w // 2 - - # in0_block_w = K // grid_size[1] // 32 - # out_subblock_w = 4 - # out_subblock_h = 4 - - logger.debug("in0 block w h " + str(in0_block_w * 32) + " " + str(in0_block_h * 32)) - logger.debug("in1 block w h " + str(out_block_w * 32) + " " + str(in0_block_w * 32)) - logger.debug("out block w h " + str(out_block_w * 32) + " " + str(out_block_h * 32)) - logger.debug("out subblock w h " + str(out_subblock_w * 32) + " " + str(out_subblock_h * 32)) - - interleaved_mem_config_L1 = ttl.tensor.MemoryConfig( - memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttl.tensor.BufferType.L1, - ) - interleaved_mem_config_DRAM = ttl.tensor.MemoryConfig( - memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttl.tensor.BufferType.DRAM, - ) - sharded_mem_config = ttl.tensor.MemoryConfig( - memory_layout=ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, - buffer_type=ttl.tensor.BufferType.L1, - ) - - in0 = torch.randn(in0_shape).bfloat16().float() - in1 = torch.randn(in1_shape).bfloat16().float() - bias = torch.randn(bias_shape).bfloat16().float() - in0_t_res = torch2tt_tensor( - in0, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B - ) - - if in1_in_dram: - in1_t = torch2tt_tensor( - in1, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B +class TestBertOpsTrace: + # TODO: Not all ops here take in cq id, only works with 0 for now + def run_bert_linear( + self, + device, + fidelity, + in0_sharded, + out_sharded, + in1_in_dram, + M, + K, + N, + activation, + enable_async, + cq_id, + ): + device.enable_async(enable_async) + has_bias = False + in0_shape = [1, 1, M, K] + in1_shape = [1, 1, K, N] + bias_shape = [1, 1, N] + out_shape = [1, 1, M, N] + grid_size = (12, 8) + # grid_size = (2, 2) + shard_shape = [M // grid_size[0], K // grid_size[1]] # shard height, width + + in0_block_w = K // grid_size[1] // 32 # 16 + in0_block_h = M // grid_size[0] // 32 + out_block_h = M // grid_size[0] // 32 + out_block_w = N // grid_size[1] // 32 + + if out_block_w <= 8: + out_subblock_w = out_block_w + out_subblock_h = 8 // out_subblock_w + else: + out_subblock_h = 1 + out_subblock_w = 8 // out_subblock_h + while out_block_w % out_subblock_w != 0: + out_subblock_w = out_block_w // 2 + + # in0_block_w = K // grid_size[1] // 32 + # out_subblock_w = 4 + # out_subblock_h = 4 + + logger.debug("in0 block w h " + str(in0_block_w * 32) + " " + str(in0_block_h * 32)) + logger.debug("in1 block w h " + str(out_block_w * 32) + " " + str(in0_block_w * 32)) + logger.debug("out block w h " + str(out_block_w * 32) + " " + str(out_block_h * 32)) + logger.debug("out subblock w h " + str(out_subblock_w * 32) + " " + str(out_subblock_h * 32)) + + interleaved_mem_config_L1 = ttl.tensor.MemoryConfig( + memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED, + buffer_type=ttl.tensor.BufferType.L1, + ) + interleaved_mem_config_DRAM = ttl.tensor.MemoryConfig( + memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED, + buffer_type=ttl.tensor.BufferType.DRAM, ) - else: - in1_t = torch2tt_tensor( - in1, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B + sharded_mem_config = ttl.tensor.MemoryConfig( + memory_layout=ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, + buffer_type=ttl.tensor.BufferType.L1, ) - output_mem_config = sharded_mem_config if out_sharded else interleaved_mem_config_L1 - - bias_t = pad_by_zero( - bias, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B - )[0] - - program_config = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=in0_block_w, - out_subblock_h=out_subblock_h, - out_subblock_w=out_subblock_w, - per_core_M=out_block_h, - per_core_N=out_block_w, - transpose_mcast=True, - # transpose_mcast=False, - fused_activation=activation, - ) - - compute_kernel_config = ttl.tensor.GrayskullComputeKernelConfig(math_fidelity=fidelity, math_approx_mode=True) - - trace_loops = 4 - - def run_ops(in0_t_res): - if in0_sharded: - in0_t = ttl.tensor.interleaved_to_sharded( - in0_t_res, - grid_size, - [M // grid_size[0], K // grid_size[1]], - ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, - ttl.tensor.ShardOrientation.COL_MAJOR, - ) - else: - in0_t = ttl.tensor.clone(in0_t_res, interleaved_mem_config_L1) - - if has_bias: - output_t = ttnn.linear( - in0_t, - in1_t, - bias=bias_t, - program_config=program_config, - memory_config=output_mem_config, - compute_kernel_config=compute_kernel_config, + in0 = torch.randn(in0_shape).bfloat16().float() + in1 = torch.randn(in1_shape).bfloat16().float() + bias = torch.randn(bias_shape).bfloat16().float() + in0_t_res = torch2tt_tensor( + in0, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B + ) + + if in1_in_dram: + in1_t = torch2tt_tensor( + in1, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B ) else: - output_t = ttnn.matmul( - in0_t, - in1_t, - program_config=program_config, - memory_config=output_mem_config, - compute_kernel_config=compute_kernel_config, + in1_t = torch2tt_tensor( + in1, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B ) - if out_sharded: - output_t = ttl.tensor.sharded_to_interleaved(output_t, interleaved_mem_config_L1) - return output_t - - # Compile - run_ops(in0_t_res) - # Capture - logger.info("Start Trace capture") - tid = ttl.device.BeginTraceCapture(device, 0) - output_t_res = run_ops(in0_t_res) - ttl.device.EndTraceCapture(device, 0, tid) - logger.info("Trace captured") - - for iter in range(trace_loops): - in0 = torch.randn(in0_shape).bfloat16().float() - in0_t_updated = torch2tt_tensor( - in0, None, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B + + output_mem_config = sharded_mem_config if out_sharded else interleaved_mem_config_L1 + + bias_t = pad_by_zero( + bias, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B + )[0] + + program_config = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=grid_size, + in0_block_w=in0_block_w, + out_subblock_h=out_subblock_h, + out_subblock_w=out_subblock_w, + per_core_M=out_block_h, + per_core_N=out_block_w, + transpose_mcast=True, + # transpose_mcast=False, + fused_activation=activation, ) - ttl.tensor.write_tensor(in0_t_updated, in0_t_res) - logger.info(f"Running iteration {iter}") - ttl.device.ReplayTrace(device, 0, tid, True) - pt_out = in0 @ in1 + compute_kernel_config = ttl.tensor.GrayskullComputeKernelConfig(math_fidelity=fidelity, math_approx_mode=True) + + trace_loops = 4 + + def run_ops(in0_t_res): + if in0_sharded: + in0_t = ttl.tensor.interleaved_to_sharded( + in0_t_res, + grid_size, + [M // grid_size[0], K // grid_size[1]], + ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, + ttl.tensor.ShardOrientation.COL_MAJOR, + ) + else: + in0_t = ttl.tensor.clone(in0_t_res, interleaved_mem_config_L1) - if has_bias: - pt_out = pt_out + bias + if has_bias: + output_t = ttnn.linear( + in0_t, + in1_t, + bias=bias_t, + program_config=program_config, + memory_config=output_mem_config, + compute_kernel_config=compute_kernel_config, + ) + else: + output_t = ttnn.matmul( + in0_t, + in1_t, + program_config=program_config, + memory_config=output_mem_config, + compute_kernel_config=compute_kernel_config, + ) + if out_sharded: + output_t = ttl.tensor.sharded_to_interleaved(output_t, interleaved_mem_config_L1) + return output_t + + # Compile + run_ops(in0_t_res) + # Capture + logger.info("Start Trace capture") + tid = ttl.device.BeginTraceCapture(device, cq_id) + output_t_res = run_ops(in0_t_res) + ttl.device.EndTraceCapture(device, cq_id, tid) + logger.info("Trace captured") + + for iter in range(trace_loops): + in0 = torch.randn(in0_shape).bfloat16().float() + in0_t_updated = torch2tt_tensor( + in0, None, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B + ) + ttl.tensor.write_tensor(in0_t_updated, in0_t_res) + logger.info(f"Running iteration {iter}") + ttl.device.ReplayTrace(device, cq_id, tid, True) - if activation != None: - pt_out = torch.nn.functional.gelu(pt_out) - tt_out = tt2torch_tensor(output_t_res) + pt_out = in0 @ in1 - passing, output = comp_pcc(pt_out, tt_out) - logger.info(output) - assert passing + if has_bias: + pt_out = pt_out + bias - # Done with the trace, can deallocate the buffers now. - ttl.device.ReleaseTrace(device, tid) - device.enable_async(False) + if activation != None: + pt_out = torch.nn.functional.gelu(pt_out) + tt_out = tt2torch_tensor(output_t_res) + + passing, output = comp_pcc(pt_out, tt_out) + logger.info(output) + assert passing + + # Done with the trace, can deallocate the buffers now. + ttl.device.ReleaseTrace(device, tid) + device.enable_async(False) + + @pytest.mark.parametrize("device_params", [{"trace_region_size": 34816}], indirect=True) + def test_bert_linear_1cq_initialized( + self, + device, + fidelity, + in0_sharded, + out_sharded, + in1_in_dram, + M, + K, + N, + activation, + use_program_cache, + function_level_defaults, + enable_async, + ): + self.run_bert_linear( + device, + fidelity, + in0_sharded, + out_sharded, + in1_in_dram, + M, + K, + N, + activation, + enable_async, + 0, + ) + + @pytest.mark.parametrize("cq_id", [0]) + @pytest.mark.parametrize("device_params", [{"trace_region_size": 34816, "num_hw_cqs": 2}], indirect=True) + def test_bert_linear_2cqs_initialized( + self, + device, + fidelity, + in0_sharded, + out_sharded, + in1_in_dram, + M, + K, + N, + activation, + use_program_cache, + function_level_defaults, + enable_async, + cq_id, + ): + self.run_bert_linear( + device, + fidelity, + in0_sharded, + out_sharded, + in1_in_dram, + M, + K, + N, + activation, + enable_async, + cq_id, + )