diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
index a8944b654c37..7a2f400d4377 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
@@ -231,6 +231,48 @@ def __call__(
         # conv1 is 1x1 conv
         logger.debug(f"Running conv1")
         module_input_height = input_height
+        conv_config = ttnn.Conv2dConfig(
+            dtype=self.model_config["ACTIVATIONS_DTYPE"],
+            weights_dtype=self.model_config["WEIGHTS_DTYPE"],
+            activation="relu",
+            shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
+            if height_sharding
+            else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
+            reshard_if_not_optimal=reshard_if_not_optimal,
+            transpose_shards=transpose_shards,
+        )
+        conv_kwargs = {
+            "input_layout": x.get_layout(),
+            "in_channels": self.conv1_input_channels,
+            "out_channels": self.conv1_output_channels,
+            "batch_size": batch_size,
+            "input_height": input_height,
+            "input_width": input_width,
+            "kernel_size": (1, 1),
+            "stride": (1, 1),
+            "padding": (0, 0),
+            "dilation": (1, 1),
+            "groups": 1,
+            "device": device,
+            "conv_config": conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.conv1_weight_tensor):
+            logger.debug("Preparing conv1 weights")
+            self.conv1_weight_tensor = ttnn.prepare_conv_weights(
+                weight_tensor=self.conv1_weight_tensor,
+                weights_format="OIHW",
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            )
+            self.conv1_bias_tensor = ttnn.prepare_conv_bias(
+                bias_tensor=self.conv1_bias_tensor,
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            ) if self.conv1_bias_tensor is not None else None
+            self.conv1_weight_tensor = ttnn.to_device(self.conv1_weight_tensor, device)
+            self.conv1_bias_tensor = ttnn.to_device(self.conv1_bias_tensor, device) if self.conv1_bias_tensor else None
+
         out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
             input_tensor=x,
             weight_tensor=self.conv1_weight_tensor,
@@ -244,16 +286,7 @@ def __call__(
             batch_size=batch_size,
             input_height=input_height,
             input_width=input_width,
-            conv_config=ttnn.Conv2dConfig(
-                dtype=self.model_config["ACTIVATIONS_DTYPE"],
-                weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                activation="relu",
-                shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
-                if height_sharding
-                else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
-                reshard_if_not_optimal=reshard_if_not_optimal,
-                transpose_shards=transpose_shards,
-            ),
+            conv_config=conv_config,
             compute_config=ttnn.init_device_compute_kernel_config(
                 device.arch(),
                 math_fidelity=self.model_config["MATH_FIDELITY"],
@@ -317,6 +350,55 @@ def __call__(
 
         reallocate_halo_output = batch_size == 20
         logger.debug(f"Running conv2")
+        conv_config = ttnn.Conv2dConfig(
+            dtype=self.model_config["ACTIVATIONS_DTYPE"],
+            weights_dtype=self.model_config["WEIGHTS_DTYPE"],
+            activation="relu",
+            deallocate_activation=True,
+            reallocate_halo_output=reallocate_halo_output,
+            act_block_h_override=act_block_h_override,
+            shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
+            if height_sharding
+            else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
+            reshard_if_not_optimal=reshard_if_not_optimal,
+            transpose_shards=transpose_shards,
+            enable_act_double_buffer=enable_act_double_buffer,
+            enable_weights_double_buffer=True,
+            enable_split_reader=enable_split_reader,
+            enable_subblock_padding=enable_subblock_padding,
+        )
+        conv_kwargs = {
+            "input_layout": x.get_layout(),
+            "in_channels": self.conv2_input_channels,
+            "out_channels": self.conv2_output_channels,
+            "batch_size": batch_size,
+            "input_height": input_height,
+            "input_width": input_width,
+            "kernel_size": (3, 3),
+            "stride": (self.stride, self.stride),
+            "padding": (1, 1),
+            "dilation": (1, 1),
+            "groups": 1,
+            "device": device,
+            "conv_config": conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.conv2_weight_tensor):
+            logger.debug("Preparing conv2 weights")
+            self.conv2_weight_tensor = ttnn.prepare_conv_weights(
+                weight_tensor=self.conv2_weight_tensor,
+                weights_format="OIHW",
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            )
+            self.conv2_bias_tensor = ttnn.prepare_conv_bias(
+                bias_tensor=self.conv2_bias_tensor,
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            ) if self.conv2_bias_tensor is not None else None
+            self.conv2_weight_tensor = ttnn.to_device(self.conv2_weight_tensor, device)
+            self.conv2_bias_tensor = ttnn.to_device(self.conv2_bias_tensor, device) if self.conv2_bias_tensor else None
+
         out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d(
             input_tensor=out,
             weight_tensor=self.conv2_weight_tensor,
@@ -330,23 +412,7 @@ def __call__(
             batch_size=batch_size,
             input_height=input_height,
             input_width=input_width,
-            conv_config=ttnn.Conv2dConfig(
-                dtype=self.model_config["ACTIVATIONS_DTYPE"],
-                weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                activation="relu",
-                deallocate_activation=True,
-                reallocate_halo_output=reallocate_halo_output,
-                act_block_h_override=act_block_h_override,
-                shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
-                if height_sharding
-                else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
-                reshard_if_not_optimal=reshard_if_not_optimal,
-                transpose_shards=transpose_shards,
-                enable_act_double_buffer=enable_act_double_buffer,
-                enable_weights_double_buffer=True,
-                enable_split_reader=enable_split_reader,
-                enable_subblock_padding=enable_subblock_padding,
-            ),
+            conv_config=conv_config,
             compute_config=ttnn.init_device_compute_kernel_config(
                 device.arch(),
                 math_fidelity=self.model_config["MATH_FIDELITY"],
@@ -373,6 +439,47 @@ def __call__(
 
         # conv3 is 1x1 conv
         logger.debug(f"Running conv3")
+        conv_config = ttnn.Conv2dConfig(
+            dtype=self.model_config["ACTIVATIONS_DTYPE"],
+            weights_dtype=self.model_config["WEIGHTS_DTYPE"],
+            shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
+            if height_sharding
+            else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
+            reshard_if_not_optimal=reshard_if_not_optimal,
+            transpose_shards=transpose_shards,
+        )
+        conv_kwargs = {
+            "input_layout": x.get_layout(),
+            "in_channels": self.conv3_input_channels,
+            "out_channels": self.conv3_output_channels,
+            "batch_size": batch_size,
+            "input_height": input_height,
+            "input_width": input_width,
+            "kernel_size": (1, 1),
+            "stride": (1, 1),
+            "padding": (0, 0),
+            "dilation": (1, 1),
+            "groups": 1,
+            "device": device,
+            "conv_config": conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.conv3_weight_tensor):
+            logger.debug("Preparing conv3 weights")
+            self.conv3_weight_tensor = ttnn.prepare_conv_weights(
+                weight_tensor=self.conv3_weight_tensor,
+                weights_format="OIHW",
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            )
+            self.conv3_bias_tensor = ttnn.prepare_conv_bias(
+                bias_tensor=self.conv3_bias_tensor,
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            ) if self.conv3_bias_tensor is not None else None
+            self.conv3_weight_tensor = ttnn.to_device(self.conv3_weight_tensor, device)
+            self.conv3_bias_tensor = ttnn.to_device(self.conv3_bias_tensor, device) if self.conv3_bias_tensor else None
+
         out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d(
             input_tensor=out,
             weight_tensor=self.conv3_weight_tensor,
@@ -386,15 +493,7 @@ def __call__(
             batch_size=batch_size,
             input_height=input_height,
             input_width=input_width,
-            conv_config=ttnn.Conv2dConfig(
-                dtype=self.model_config["ACTIVATIONS_DTYPE"],
-                weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
-                if height_sharding
-                else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
-                reshard_if_not_optimal=reshard_if_not_optimal,
-                transpose_shards=transpose_shards,
-            ),
+            conv_config=conv_config,
             compute_config=ttnn.init_device_compute_kernel_config(
                 device.arch(),
                 math_fidelity=self.model_config["MATH_FIDELITY"],
@@ -742,6 +841,38 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
 
         logger.debug(f"==== first conv")
         # first conv
+        conv_kwargs = {
+            "input_layout": fold_output_tensor.get_layout(),
+            "in_channels": self.conv1_input_channels,
+            "out_channels": self.conv1_output_channels,
+            "batch_size": self.batch_size,
+            "input_height": self.conv1_input_height,
+            "input_width": self.conv1_input_width,
+            "kernel_size": self.conv1_kernel_size,
+            "stride": self.conv1_stride,
+            "padding": self.conv1_padding,
+            "dilation": (1, 1),
+            "groups": 1,
+            "device": device,
+            "conv_config": self.conv1_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.conv1_weight_tensor):
+            self.conv1_weight_tensor = ttnn.prepare_conv_weights(
+                weight_tensor=self.conv1_weight_tensor,
+                weights_format="OIHW",
+                input_memory_config=fold_output_tensor.memory_config(),
+                **conv_kwargs,
+            )
+
+            self.conv1_bias_tensor = ttnn.prepare_conv_bias(
+                bias_tensor=self.conv1_bias_tensor,
+                input_memory_config=fold_output_tensor.memory_config(),
+                **conv_kwargs,
+            ) if self.conv1_bias_tensor is not None else None
+            self.conv1_weight_tensor = ttnn.to_device(self.conv1_weight_tensor, device)
+            self.conv1_bias_tensor = ttnn.to_device(self.conv1_bias_tensor, device) if self.conv1_bias_tensor else None
+
         x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
             input_tensor=fold_output_tensor,
             weight_tensor=self.conv1_weight_tensor,
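For reference, the host-side preparation pattern that the diff repeats for conv1, conv2, and conv3 can be expressed as one helper. The sketch below is illustrative only and is not part of the patch: the function name prepare_conv_params and its argument list are hypothetical, while the ttnn calls themselves (is_tensor_storage_on_device, prepare_conv_weights, prepare_conv_bias, to_device) are exactly the ones used above, assuming conv_kwargs is built the same way as in the diff.

    # Hypothetical helper mirroring the per-conv preparation blocks in the diff.
    import ttnn

    def prepare_conv_params(device, weight, bias, input_tensor, conv_kwargs):
        # Only prepare once: skip if the weights already live on the device.
        if not ttnn.is_tensor_storage_on_device(weight):
            weight = ttnn.prepare_conv_weights(
                weight_tensor=weight,
                weights_format="OIHW",
                input_memory_config=input_tensor.memory_config(),
                **conv_kwargs,
            )
            # Bias is optional; prepare it only when present.
            if bias is not None:
                bias = ttnn.prepare_conv_bias(
                    bias_tensor=bias,
                    input_memory_config=input_tensor.memory_config(),
                    **conv_kwargs,
                )
            # Move the prepared tensors onto the device for ttnn.conv2d.
            weight = ttnn.to_device(weight, device)
            bias = ttnn.to_device(bias, device) if bias is not None else None
        return weight, bias

Each call site would then pass the returned tensors to ttnn.conv2d together with the same conv_kwargs and conv_config it used for preparation.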