Skip to content

Commit

Permalink
#0: Remove trace buffer size from BeginTraceCapture in Resnet
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-aho committed Jun 14, 2024
1 parent e4248b3 commit 428c17e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions models/demos/resnet/tests/test_metal_resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def run_trace_model(device, tt_image, tt_resnet50):
# Compile
tt_resnet50(tt_image_res)
# Trace
tid = tt_lib.device.BeginTraceCapture(device, 0, 1500000)
tid = tt_lib.device.BeginTraceCapture(device, 0)
tt_output_res = tt_resnet50(tt_image_res)
tt_lib.device.EndTraceCapture(device, 0, tid)

Expand Down Expand Up @@ -257,7 +257,7 @@ def run_trace_2cq_model(device, tt_image, tt_resnet50):
reshard_out = tt_lib.tensor.reshard(tt_image_res, reshard_mem_config)
tt_lib.device.RecordEvent(device, 0, op_event)

tid = tt_lib.device.BeginTraceCapture(device, 0, 1500000)
tid = tt_lib.device.BeginTraceCapture(device, 0)
tt_output_res = tt_resnet50(reshard_out, final_out_mem_config=interleaved_dram_mem_config)
reshard_out = tt_lib.tensor.allocate_tensor_on_device(
reshard_out.shape, reshard_out.dtype, reshard_out.layout, device, reshard_mem_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def test_run_resnet50_2cqs_inference(


@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.")
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True)
@pytest.mark.parametrize(
"device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2, "trace_region_size": 1500000}], indirect=True
)
@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"])
@pytest.mark.parametrize(
"weights_dtype",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_run_resnet50_inference(


@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.")
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "trace_region_size": 1500000}], indirect=True)
@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"])
@pytest.mark.parametrize(
"weights_dtype",
Expand Down

0 comments on commit 428c17e

Please sign in to comment.