diff --git a/models/demos/segformer/tt/common.py b/models/demos/segformer/tt/common.py
index 5f52fe0e5072..dae5ecc6bbe9 100644
--- a/models/demos/segformer/tt/common.py
+++ b/models/demos/segformer/tt/common.py
@@ -40,12 +40,8 @@ def __call__(self, device, input_tensor):
         conv_config = ttnn.Conv2dConfig(
             dtype=self.dtype,
             weights_dtype=ttnn.bfloat16,
-            math_fidelity=ttnn.MathFidelity.LoFi,
             activation=self.activation,
             shard_layout=self.shard_layout,
-            math_approx_mode_enabled=True,
-            fp32_dest_acc_enabled=False,
-            packer_l1_accum_enabled=False,
             input_channels_alignment=16 if input_tensor.shape[3] < 16 else 32,
             transpose_shards=False,
             reshard_if_not_optimal=self.reshard,
@@ -54,6 +50,12 @@ def __call__(self, device, input_tensor):
             enable_act_double_buffer=True,
             enable_split_reader=False,
         )
+        compute_config = ttnn.GetComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.LoFi,
+            math_approx_mode=True,
+            fp32_dest_acc_en=False,
+            packer_l1_acc=False,
+        )

         if self.act_block_h is not None:
             conv_config.act_block_h_override = self.act_block_h
@@ -71,6 +73,7 @@ def __call__(self, device, input_tensor):
             input_height=input_tensor.shape[1],
             input_width=input_tensor.shape[2],
             conv_config=conv_config,
+            compute_config=compute_config,
             groups=self.groups,
         )

diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py
index a1c0496f2118..05b4c45a91df 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py
@@ -181,12 +181,12 @@ def run_downsample_if_req(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 shard_layout=shard_layout,
                 deallocate_activation=True,
                 reallocate_halo_output=True,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         ttnn.deallocate(x)
@@ -228,13 +228,13 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
                 if height_sharding
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -291,7 +291,6 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 reallocate_halo_output=reallocate_halo_output,
@@ -301,6 +300,7 @@ def __call__(
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -322,12 +322,12 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
                 if height_sharding
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -538,12 +538,12 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 input_channels_alignment=16 if not is_wormhole_b0() else 32,
                 act_block_h_override=act_block_h_override,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         # Relu is fused with conv1
@@ -844,12 +844,12 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 input_channels_alignment=16 if not is_wormhole_b0() else 32,
                 act_block_h_override=act_block_h_override,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         # Relu is fused with conv1
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py
index 5c0750003c16..f99e08a0f78b 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py
@@ -178,7 +178,6 @@ def run_downsample_if_req(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
                 if height_sharding
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
@@ -186,6 +185,7 @@ def run_downsample_if_req(
                 reallocate_halo_output=True,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         ttnn.deallocate(x)
@@ -225,13 +225,13 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
                 if height_sharding
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -286,7 +286,6 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 reallocate_halo_output=reallocate_halo_output,
@@ -296,6 +295,7 @@ def __call__(
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -317,12 +317,12 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
                 if height_sharding
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -532,12 +532,12 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 input_channels_alignment=16 if not is_wormhole_b0() else 32,
                 act_block_h_override=act_block_h_override,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         # Relu is fused with conv1
@@ -835,12 +835,12 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 input_channels_alignment=16 if not is_wormhole_b0() else 32,
                 act_block_h_override=act_block_h_override,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         # Relu is fused with conv1
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py
index b6643d55d4a2..b64e2fdf4ba3 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py
@@ -179,7 +179,6 @@ def run_downsample_if_req(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
                 if height_sharding
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
@@ -187,6 +186,7 @@ def run_downsample_if_req(
                 reallocate_halo_output=True,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         ttnn.deallocate(x)
@@ -226,13 +226,13 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
                 if height_sharding
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -288,7 +288,6 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 reallocate_halo_output=reallocate_halo_output,
@@ -298,6 +297,7 @@ def __call__(
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -319,12 +319,12 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
                 if height_sharding
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                 reshard_if_not_optimal=reshard_if_not_optimal,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -534,12 +534,12 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 input_channels_alignment=16 if not is_wormhole_b0() else 32,
                 act_block_h_override=act_block_h_override,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         # Relu is fused with conv1
@@ -865,12 +865,12 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 input_channels_alignment=16 if not is_wormhole_b0() else 32,
                 act_block_h_override=act_block_h_override,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         # Relu is fused with conv1
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py
index 45d93ebf6859..f357270768b0 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py
@@ -179,7 +179,6 @@ def run_downsample_if_req(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
                 if height_sharding
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
@@ -188,6 +187,7 @@ def run_downsample_if_req(
                 reshard_if_not_optimal=reshard_if_not_optimal,
                 transpose_shards=height_sharding,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         ttnn.deallocate(x)
@@ -232,7 +232,6 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
                 if height_sharding
@@ -240,6 +239,7 @@ def __call__(
                 reshard_if_not_optimal=reshard_if_not_optimal,
                 transpose_shards=height_sharding,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -337,7 +337,6 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 reallocate_halo_output=reallocate_halo_output,
@@ -348,6 +347,7 @@ def __call__(
                 reshard_if_not_optimal=reshard_if_not_optimal,
                 transpose_shards=height_sharding,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -369,13 +369,13 @@ def __call__(
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
                 if height_sharding
                 else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
                 reshard_if_not_optimal=reshard_if_not_optimal,
                 transpose_shards=height_sharding,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )

@@ -597,13 +597,13 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 reallocate_halo_output=True,
                 input_channels_alignment=16 if not is_wormhole_b0() else 32,
                 act_block_h_override=act_block_h_override,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         # Relu is fused with conv1
@@ -931,12 +931,12 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
             conv_config=ttnn.Conv2dConfig(
                 dtype=self.model_config["ACTIVATIONS_DTYPE"],
                 weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                math_fidelity=self.model_config["MATH_FIDELITY"],
                 activation="relu",
                 deallocate_activation=True,
                 input_channels_alignment=16 if not is_wormhole_b0() else 32,
                 act_block_h_override=act_block_h_override,
             ),
+            compute_config=ttnn.GetComputeKernelConfig(math_fidelity=self.model_config["MATH_FIDELITY"]),
             conv_op_cache=conv_op_cache,
         )
         # Relu is fused with conv1
diff --git a/models/demos/vgg/tests/test_perf_vgg.py b/models/demos/vgg/tests/test_perf_vgg.py
index f687e217cba6..a15a21c6b930 100644
--- a/models/demos/vgg/tests/test_perf_vgg.py
+++ b/models/demos/vgg/tests/test_perf_vgg.py
@@ -79,17 +79,6 @@ def test_vgg(
         "ACTIVATIONS_DTYPE": act_dtype,
     }

-    conv_config = ttnn.Conv2dConfig(
-        dtype=model_config["ACTIVATIONS_DTYPE"],
-        weights_dtype=model_config["WEIGHTS_DTYPE"],
-        math_fidelity=model_config["MATH_FIDELITY"],
-        activation="relu",
-        deallocate_activation=True,
-        input_channels_alignment=16,
-        act_block_h_override=0,
-        transpose_shards=True,
-    )
-
     torch_batched_tensor = torch_input_tensor_nchw.repeat(batch_size, 1, 1, 1)
     torch_input_tensor = torch.permute(torch_batched_tensor, (0, 2, 3, 1))
     tt_batched_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)
diff --git a/models/demos/vgg/tt/ttnn_vgg.py b/models/demos/vgg/tt/ttnn_vgg.py
index 4cb986c27304..9a309ec2b644 100644
--- a/models/demos/vgg/tt/ttnn_vgg.py
+++ b/models/demos/vgg/tt/ttnn_vgg.py
@@ -90,10 +90,6 @@ def ttnn_vgg16(
         conv_config = ttnn.Conv2dConfig(
             dtype=model_config["ACTIVATIONS_DTYPE"],
             weights_dtype=model_config["WEIGHTS_DTYPE"],
-            math_fidelity=model_config["MATH_FIDELITY"],
-            math_approx_mode_enabled=True,
-            fp32_dest_acc_enabled=False,
-            packer_l1_accum_enabled=False,
             activation="relu",
             deallocate_activation=False,
             input_channels_alignment=32,
@@ -106,6 +102,12 @@ def ttnn_vgg16(
             reshard_if_not_optimal=True,
             enable_weights_double_buffer=True,
         )
+        compute_config = ttnn.GetComputeKernelConfig(
+            math_fidelity=model_config["MATH_FIDELITY"],
+            math_approx_mode=True,
+            fp32_dest_acc_en=False,
+            packer_l1_acc=False,
+        )

         tt_weight = parameters.features[conv_feature_ids[iter_conv_id]].weight
         tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT)
@@ -126,6 +128,7 @@ def ttnn_vgg16(
            input_height=conv_ttnn_params[iter_conv_id][2],
            input_width=conv_ttnn_params[iter_conv_id][3],
            conv_config=conv_config,
+           compute_config=compute_config,
            conv_op_cache=conv_op_cache,
         )
         tt_x = ttnn.from_device(tt_output_tensor_on_device)
@@ -213,9 +216,6 @@ def ttnn_vgg11(
         conv_config = ttnn.Conv2dConfig(
             dtype=model_config["ACTIVATIONS_DTYPE"],
             weights_dtype=model_config["WEIGHTS_DTYPE"],
-            math_fidelity=model_config["MATH_FIDELITY"],
-            math_approx_mode_enabled=True,
-            fp32_dest_acc_enabled=True,
             activation="relu",
             deallocate_activation=False,
             input_channels_alignment=32,
@@ -227,7 +227,11 @@ def ttnn_vgg11(
             ),
             enable_weights_double_buffer=True,
         )
-
+        compute_config = ttnn.GetComputeKernelConfig(
+            math_fidelity=model_config["MATH_FIDELITY"],
+            math_approx_mode=True,
+            fp32_dest_acc_en=True,
+        )
         tt_weight = parameters.features[conv_feature_ids_2[iter_conv_id]].weight
         tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT)
         tt_bias = parameters.features[conv_feature_ids_2[iter_conv_id]].bias
@@ -248,6 +252,7 @@ def ttnn_vgg11(
            input_height=conv_ttnn_params_2[iter_conv_id][2],
            input_width=conv_ttnn_params_2[iter_conv_id][3],
            conv_config=conv_config,
+           compute_config=compute_config,
            conv_op_cache=conv_op_cache,
         )
         tt_x = ttnn.from_device(tt_output_tensor_on_device)
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
index 3635026d8094..6d54c7b48d8b 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
@@ -99,11 +99,7 @@ def __call__(
         conv_config = ttnn.Conv2dConfig(
             dtype=ttnn.bfloat8_b,
             weights_dtype=ttnn.bfloat8_b,
-            math_fidelity=ttnn.MathFidelity.LoFi,
             activation="",
-            math_approx_mode_enabled=True,
-            fp32_dest_acc_enabled=True,
-            packer_l1_accum_enabled=False,
             shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
             if self.in_channels < 320
             else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
@@ -111,6 +107,12 @@ def __call__(
             transpose_shards=False,
             reshard_if_not_optimal=True,
         )
+        compute_config = ttnn.GetComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.LoFi,
+            math_approx_mode=True,
+            fp32_dest_acc_en=True,
+            packer_l1_acc=False,
+        )

         if self.conv_config_override and "act_block_h" in self.conv_config_override:
             conv_config.act_block_h_override = self.conv_config_override["act_block_h"]
@@ -128,6 +130,7 @@ def __call__(
             weight_tensor=self.conv_weights,
             bias_tensor=self.conv_bias,
             conv_config=conv_config,
+            compute_config=compute_config,
             conv_op_cache=conv_cache,
         )
         # hidden_states = run_ttnn_conv_with_pre_and_post_tensor_formatting(