Skip to content

Commit

Permalink
#6654: Moving init for self.compute_kernel_config
Browse files Browse the repository at this point in the history
  • Loading branch information
eyonland authored and tt-rkim committed Mar 27, 2024
1 parent 981503c commit 4882e23
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
19 changes: 9 additions & 10 deletions models/experimental/resnet/tt/ttnn_functional_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,15 @@ def __init__(self, device, torch_model, input_shape, batch_size, act_dtype, weig
self.act_dtype = act_dtype
self.weight_dtype = weight_dtype
self.math_fidelity = math_fidelity
if is_wormhole_b0():
self.compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=math_fidelity,
math_approx_mode=False,
fp32_dest_acc_en=False,
packer_l1_acc=False,
)
else:
self.compute_kernel_config = None

torch_input_tensor = torch.rand(input_shape, dtype=torch.bfloat16)
self.impl = preprocess_model(
Expand Down Expand Up @@ -281,16 +290,6 @@ def __init__(self, device, torch_model, input_shape, batch_size, act_dtype, weig
True,
)

if is_wormhole_b0():
self.compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=math_fidelity,
math_approx_mode=False,
fp32_dest_acc_en=False,
packer_l1_acc=False,
)
else:
self.compute_kernel_config = None

self.input_tensor_height_snapped_to_tile = self.sharded_mem_config.shard_spec.shape[0] * ncores_nhw

def update_ttnn_module_args_resnet50(self, ttnn_module_args):
Expand Down
9 changes: 2 additions & 7 deletions tests/ttnn/integration_tests/resnet/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,14 @@

# SPDX-License-Identifier: Apache-2.0

from loguru import logger
import time

import pytest

from loguru import logger

from tests.ttnn.integration_tests.resnet.test_ttnn_functional_resnet50 import create_test_infra
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report

from models.utility_functions import skip_for_wormhole_b0

import ttnn

from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
from models.utility_functions import (
enable_persistent_kernel_cache,
disable_persistent_kernel_cache,
Expand Down

0 comments on commit 4882e23

Please sign in to comment.