diff --git a/models/demos/convnet_mnist/tt/convnet_mnist.py b/models/demos/convnet_mnist/tt/convnet_mnist.py
index 1d9ac8acba0..85a5019e8b9 100644
--- a/models/demos/convnet_mnist/tt/convnet_mnist.py
+++ b/models/demos/convnet_mnist/tt/convnet_mnist.py
@@ -35,13 +35,41 @@ def convnet_mnist(
         packer_l1_acc=False,
     )
     x = ttnn.to_layout(input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
+
+    tt_weight = parameters.conv1.weight
+    tt_bias = parameters.conv1.bias
+    conv_kwargs = {
+        "input_layout": x.get_layout(),
+        "in_channels": 1,
+        "out_channels": 32,
+        "batch_size": batch_size,
+        "input_height": input_tensor.shape[1],
+        "input_width": input_tensor.shape[2],
+        "kernel_size": (3, 3),
+        "stride": (1, 1),
+        "padding": (0, 0),
+        "dilation": (1, 1),
+        "groups": 1,
+        "device": device,
+        "conv_config": conv_config,
+    }
+
+    if not ttnn.is_tensor_storage_on_device(tt_weight):
+        tt_weight = ttnn.prepare_conv_weights(
+            weight_tensor=tt_weight,
+            weights_format="OIHW",
+            input_memory_config=ttnn.L1_MEMORY_CONFIG,
+            **conv_kwargs,
+        )
+        tt_weight = ttnn.to_device(tt_weight, device)
+
     x = ttnn.conv2d(
         input_tensor=x,
-        weight_tensor=parameters.conv1.weight,
+        weight_tensor=tt_weight,
         in_channels=1,
         out_channels=32,
         device=device,
-        bias_tensor=parameters.conv1.bias,
+        bias_tensor=tt_bias,
         kernel_size=(3, 3),
         stride=(1, 1),
         padding=(0, 0),
@@ -81,13 +109,40 @@ def convnet_mnist(
         dilation=[1, 1],
     )
 
+    tt_weight = parameters.conv2.weight
+    tt_bias = parameters.conv2.bias
+    conv_kwargs = {
+        "input_layout": x.get_layout(),
+        "in_channels": 32,
+        "out_channels": 64,
+        "batch_size": batch_size,
+        "input_height": 15,
+        "input_width": 15,
+        "kernel_size": (3, 3),
+        "stride": (1, 1),
+        "padding": (0, 0),
+        "dilation": (1, 1),
+        "groups": 1,
+        "device": device,
+        "conv_config": conv_config,
+    }
+
+    if not ttnn.is_tensor_storage_on_device(tt_weight):
+        tt_weight = ttnn.prepare_conv_weights(
+            weight_tensor=tt_weight,
+            weights_format="OIHW",
+            input_memory_config=ttnn.L1_MEMORY_CONFIG,
+            **conv_kwargs,
+        )
+        tt_weight = ttnn.to_device(tt_weight, device)
+
     x, [out_height, out_width] = ttnn.conv2d(
         input_tensor=x,
-        weight_tensor=parameters.conv2.weight,
+        weight_tensor=tt_weight,
         in_channels=32,
         out_channels=64,
         device=device,
-        bias_tensor=parameters.conv2.bias,
+        bias_tensor=tt_bias,
         kernel_size=(3, 3),
         stride=(1, 1),
         padding=(0, 0),
diff --git a/models/demos/segformer/tt/common.py b/models/demos/segformer/tt/common.py
index f2278ad58bf..e1a45263be9 100644
--- a/models/demos/segformer/tt/common.py
+++ b/models/demos/segformer/tt/common.py
@@ -63,6 +63,41 @@ def __call__(self, device, input_tensor):
         if self.act_block_h is not None:
             conv_config.act_block_h_override = self.act_block_h
 
+        conv_kwargs = {
+            "input_layout": input_tensor.get_layout(),
+            "in_channels": input_tensor.shape[3],
+            "out_channels": self.out_channels,
+            "batch_size": input_tensor.shape[0],
+            "input_height": input_tensor.shape[1],
+            "input_width": input_tensor.shape[2],
+            "kernel_size": self.kernel_size,
+            "stride": (self.conv_params[0], self.conv_params[1]),
+            "padding": (self.conv_params[2], self.conv_params[3]),
+            "dilation": (1, 1),
+            "groups": self.groups,
+            "device": device,
+            "conv_config": conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.weights):
+            self.weights = ttnn.prepare_conv_weights(
+                weight_tensor=self.weights,
+                weights_format="OIHW",
+                input_memory_config=input_tensor.memory_config(),
+                **conv_kwargs,
+            )
+            self.bias = (
+                ttnn.prepare_conv_bias(
+                    bias_tensor=self.bias,
+                    input_memory_config=input_tensor.memory_config(),
+                    **conv_kwargs,
+                )
+                if self.bias is not None
+                else None
+            )
+            self.weights = ttnn.to_device(self.weights, device)
+            self.bias = ttnn.to_device(self.bias, device) if self.bias else None
+
         [output_tensor, [_out_height, _out_width]] = ttnn.conv2d(
             input_tensor=input_tensor,
             weight_tensor=self.weights,
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
index a8944b654c3..2751a5cfabb 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
@@ -231,6 +231,47 @@ def __call__(
         # conv1 is 1x1 conv
         logger.debug(f"Running conv1")
         module_input_height = input_height
+        conv_config = ttnn.Conv2dConfig(
+            dtype=self.model_config["ACTIVATIONS_DTYPE"],
+            weights_dtype=self.model_config["WEIGHTS_DTYPE"],
+            activation="relu",
+            shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
+            if height_sharding
+            else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
+            reshard_if_not_optimal=reshard_if_not_optimal,
+            transpose_shards=transpose_shards,
+        )
+        conv_kwargs = {
+            "input_layout": x.get_layout(),
+            "in_channels": self.conv1_input_channels,
+            "out_channels": self.conv1_output_channels,
+            "batch_size": batch_size,
+            "input_height": input_height,
+            "input_width": input_width,
+            "kernel_size": (1, 1),
+            "stride": (1, 1),
+            "padding": (0, 0),
+            "dilation": (1, 1),
+            "groups": 1,
+            "device": device,
+            "conv_config": conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.conv1_weight_tensor):
+            self.conv1_weight_tensor = ttnn.prepare_conv_weights(
+                weight_tensor=self.conv1_weight_tensor,
+                weights_format="OIHW",
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            )
+            self.conv1_bias_tensor = ttnn.prepare_conv_bias(
+                bias_tensor=self.conv1_bias_tensor,
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            ) if self.conv1_bias_tensor is not None else None
+            self.conv1_weight_tensor = ttnn.to_device(self.conv1_weight_tensor, device)
+            self.conv1_bias_tensor = ttnn.to_device(self.conv1_bias_tensor, device) if self.conv1_bias_tensor else None
+
         out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
             input_tensor=x,
             weight_tensor=self.conv1_weight_tensor,
@@ -244,16 +285,7 @@ def __call__(
             batch_size=batch_size,
             input_height=input_height,
             input_width=input_width,
-            conv_config=ttnn.Conv2dConfig(
-                dtype=self.model_config["ACTIVATIONS_DTYPE"],
-                weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                activation="relu",
-                shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
-                if height_sharding
-                else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
-                reshard_if_not_optimal=reshard_if_not_optimal,
-                transpose_shards=transpose_shards,
-            ),
+            conv_config=conv_config,
             compute_config=ttnn.init_device_compute_kernel_config(
                 device.arch(),
                 math_fidelity=self.model_config["MATH_FIDELITY"],
@@ -317,6 +349,54 @@ def __call__(
             reallocate_halo_output = batch_size == 20
 
         logger.debug(f"Running conv2")
+        conv_config = ttnn.Conv2dConfig(
+            dtype=self.model_config["ACTIVATIONS_DTYPE"],
+            weights_dtype=self.model_config["WEIGHTS_DTYPE"],
+            activation="relu",
+            deallocate_activation=True,
+            reallocate_halo_output=reallocate_halo_output,
+            act_block_h_override=act_block_h_override,
+            shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
+            if height_sharding
+            else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
+            reshard_if_not_optimal=reshard_if_not_optimal,
+            transpose_shards=transpose_shards,
+            enable_act_double_buffer=enable_act_double_buffer,
+            enable_weights_double_buffer=True,
+            enable_split_reader=enable_split_reader,
+            enable_subblock_padding=enable_subblock_padding,
+        )
+        conv_kwargs = {
+            "input_layout": x.get_layout(),
+            "in_channels": self.conv2_input_channels,
+            "out_channels": self.conv2_output_channels,
+            "batch_size": batch_size,
+            "input_height": input_height,
+            "input_width": input_width,
+            "kernel_size": (3, 3),
+            "stride": (self.stride, self.stride),
+            "padding": (1, 1),
+            "dilation": (1, 1),
+            "groups": 1,
+            "device": device,
+            "conv_config": conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.conv2_weight_tensor):
+            self.conv2_weight_tensor = ttnn.prepare_conv_weights(
+                weight_tensor=self.conv2_weight_tensor,
+                weights_format="OIHW",
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            )
+            self.conv2_bias_tensor = ttnn.prepare_conv_bias(
+                bias_tensor=self.conv2_bias_tensor,
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            ) if self.conv2_bias_tensor is not None else None
+            self.conv2_weight_tensor = ttnn.to_device(self.conv2_weight_tensor, device)
+            self.conv2_bias_tensor = ttnn.to_device(self.conv2_bias_tensor, device) if self.conv2_bias_tensor else None
+
         out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d(
             input_tensor=out,
             weight_tensor=self.conv2_weight_tensor,
@@ -330,23 +410,7 @@ def __call__(
             batch_size=batch_size,
             input_height=input_height,
             input_width=input_width,
-            conv_config=ttnn.Conv2dConfig(
-                dtype=self.model_config["ACTIVATIONS_DTYPE"],
-                weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                activation="relu",
-                deallocate_activation=True,
-                reallocate_halo_output=reallocate_halo_output,
-                act_block_h_override=act_block_h_override,
-                shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
-                if height_sharding
-                else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
-                reshard_if_not_optimal=reshard_if_not_optimal,
-                transpose_shards=transpose_shards,
-                enable_act_double_buffer=enable_act_double_buffer,
-                enable_weights_double_buffer=True,
-                enable_split_reader=enable_split_reader,
-                enable_subblock_padding=enable_subblock_padding,
-            ),
+            conv_config=conv_config,
             compute_config=ttnn.init_device_compute_kernel_config(
                 device.arch(),
                 math_fidelity=self.model_config["MATH_FIDELITY"],
@@ -373,6 +437,46 @@ def __call__(
 
         # conv3 is 1x1 conv
         logger.debug(f"Running conv3")
+        conv_config = ttnn.Conv2dConfig(
+            dtype=self.model_config["ACTIVATIONS_DTYPE"],
+            weights_dtype=self.model_config["WEIGHTS_DTYPE"],
+            shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
+            if height_sharding
+            else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
+            reshard_if_not_optimal=reshard_if_not_optimal,
+            transpose_shards=transpose_shards,
+        )
+        conv_kwargs = {
+            "input_layout": x.get_layout(),
+            "in_channels": self.conv3_input_channels,
+            "out_channels": self.conv3_output_channels,
+            "batch_size": batch_size,
+            "input_height": input_height,
+            "input_width": input_width,
+            "kernel_size": (1, 1),
+            "stride": (1, 1),
+            "padding": (0, 0),
+            "dilation": (1, 1),
+            "groups": 1,
+            "device": device,
+            "conv_config": conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.conv3_weight_tensor):
+            self.conv3_weight_tensor = ttnn.prepare_conv_weights(
+                weight_tensor=self.conv3_weight_tensor,
+                weights_format="OIHW",
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            )
+            self.conv3_bias_tensor = ttnn.prepare_conv_bias(
+                bias_tensor=self.conv3_bias_tensor,
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            ) if self.conv3_bias_tensor is not None else None
+            self.conv3_weight_tensor = ttnn.to_device(self.conv3_weight_tensor, device)
+            self.conv3_bias_tensor = ttnn.to_device(self.conv3_bias_tensor, device) if self.conv3_bias_tensor else None
+
         out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d(
             input_tensor=out,
             weight_tensor=self.conv3_weight_tensor,
@@ -386,15 +490,7 @@ def __call__(
             batch_size=batch_size,
             input_height=input_height,
             input_width=input_width,
-            conv_config=ttnn.Conv2dConfig(
-                dtype=self.model_config["ACTIVATIONS_DTYPE"],
-                weights_dtype=self.model_config["WEIGHTS_DTYPE"],
-                shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
-                if height_sharding
-                else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
-                reshard_if_not_optimal=reshard_if_not_optimal,
-                transpose_shards=transpose_shards,
-            ),
+            conv_config=conv_config,
             compute_config=ttnn.init_device_compute_kernel_config(
                 device.arch(),
                 math_fidelity=self.model_config["MATH_FIDELITY"],
@@ -742,6 +838,38 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
         logger.debug(f"==== first conv")
 
         # first conv
+        conv_kwargs = {
+            "input_layout": fold_output_tensor.get_layout(),
+            "in_channels": self.conv1_input_channels,
+            "out_channels": self.conv1_output_channels,
+            "batch_size": self.batch_size,
+            "input_height": self.conv1_input_height,
+            "input_width": self.conv1_input_width,
+            "kernel_size": self.conv1_kernel_size,
+            "stride": self.conv1_stride,
+            "padding": self.conv1_padding,
+            "dilation": (1, 1),
+            "groups": 1,
+            "device": device,
+            "conv_config": self.conv1_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.conv1_weight_tensor):
+            self.conv1_weight_tensor = ttnn.prepare_conv_weights(
+                weight_tensor=self.conv1_weight_tensor,
+                weights_format="OIHW",
+                input_memory_config=fold_output_tensor.memory_config(),
+                **conv_kwargs,
+            )
+
+            self.conv1_bias_tensor = ttnn.prepare_conv_bias(
+                bias_tensor=self.conv1_bias_tensor,
+                input_memory_config=fold_output_tensor.memory_config(),
+                **conv_kwargs,
+            ) if self.conv1_bias_tensor is not None else None
+            self.conv1_weight_tensor = ttnn.to_device(self.conv1_weight_tensor, device)
+            self.conv1_bias_tensor = ttnn.to_device(self.conv1_bias_tensor, device) if self.conv1_bias_tensor else None
+
         x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
             input_tensor=fold_output_tensor,
             weight_tensor=self.conv1_weight_tensor,
diff --git a/models/demos/vgg/tt/ttnn_vgg.py b/models/demos/vgg/tt/ttnn_vgg.py
index fe044e07665..cabe60c1e1a 100644
--- a/models/demos/vgg/tt/ttnn_vgg.py
+++ b/models/demos/vgg/tt/ttnn_vgg.py
@@ -115,6 +115,43 @@ def ttnn_vgg16(
             tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT)
             tt_bias = parameters.features[conv_feature_ids[iter_conv_id]].bias
             tt_bias = ttnn.to_layout(ttnn.from_device(tt_bias), layout=ttnn.ROW_MAJOR_LAYOUT)
+
+            conv_kwargs = {
+                "input_layout": tt_x.get_layout(),
+                "in_channels": conv_ttnn_params[iter_conv_id][0],
+                "out_channels": conv_ttnn_params[iter_conv_id][1],
+                "batch_size": batch_size,
+                "input_height": conv_ttnn_params[iter_conv_id][2],
+                "input_width": conv_ttnn_params[iter_conv_id][3],
+                "kernel_size": (3, 3),
+                "stride": (1, 1),
+                "padding": (1, 1),
+                "dilation": (1, 1),
+                "groups": 1,
+                "device": device,
+                "conv_config": conv_config,
+            }
+
+            if not ttnn.is_tensor_storage_on_device(tt_weight):
+                tt_weight = ttnn.prepare_conv_weights(
+                    weight_tensor=tt_weight,
+                    weights_format="OIHW",
+                    input_memory_config=ttnn.L1_MEMORY_CONFIG,
+                    **conv_kwargs,
+                )
+
+                tt_bias = (
+                    ttnn.prepare_conv_bias(
+                        bias_tensor=tt_bias,
+                        input_memory_config=ttnn.L1_MEMORY_CONFIG,
+                        **conv_kwargs,
+                    )
+                    if tt_bias is not None
+                    else None
+                )
+
+                tt_weight = ttnn.to_device(tt_weight, device)
+                tt_bias = ttnn.to_device(tt_bias, device) if tt_bias else None
             # Call ttnn.conv
             conv_op_cache = {}
             [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
@@ -245,6 +282,43 @@ def ttnn_vgg11(
             tt_bias = parameters.features[conv_feature_ids_2[iter_conv_id]].bias
             tt_bias = ttnn.to_layout(ttnn.from_device(tt_bias), layout=ttnn.ROW_MAJOR_LAYOUT)
 
+            conv_kwargs = {
+                "input_layout": tt_x.get_layout(),
+                "in_channels": conv_ttnn_params_2[iter_conv_id][0],
+                "out_channels": conv_ttnn_params_2[iter_conv_id][1],
+                "batch_size": batch_size,
+                "input_height": conv_ttnn_params_2[iter_conv_id][2],
+                "input_width": conv_ttnn_params_2[iter_conv_id][3],
+                "kernel_size": (3, 3),
+                "stride": (1, 1),
+                "padding": (1, 1),
+                "dilation": (1, 1),
+                "groups": 1,
+                "device": device,
+                "conv_config": conv_config,
+            }
+
+            if not ttnn.is_tensor_storage_on_device(tt_weight):
+                tt_weight = ttnn.prepare_conv_weights(
+                    weight_tensor=tt_weight,
+                    weights_format="OIHW",
+                    input_memory_config=ttnn.L1_MEMORY_CONFIG,
+                    **conv_kwargs,
+                )
+
+                tt_bias = (
+                    ttnn.prepare_conv_bias(
+                        bias_tensor=tt_bias,
+                        input_memory_config=ttnn.L1_MEMORY_CONFIG,
+                        **conv_kwargs,
+                    )
+                    if tt_bias is not None
+                    else None
+                )
+
+                tt_weight = ttnn.to_device(tt_weight, device)
+                tt_bias = ttnn.to_device(tt_bias, device) if tt_bias else None
+
             # Call ttnn.conv
             conv_op_cache = {}
             [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
index 570d2457f1a..e1864f9acef 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
@@ -146,6 +146,41 @@ def __call__(
         if self.conv_config_override and "act_block_h" in self.conv_config_override:
             conv_config.act_block_h_override = self.conv_config_override["act_block_h"]
 
+        conv_kwargs = {
+            "input_layout": hidden_states.get_layout(),
+            "in_channels": in_channels,
+            "out_channels": self.out_channels,
+            "batch_size": hidden_states.shape[0],
+            "input_height": hidden_states.shape[1],
+            "input_width": hidden_states.shape[2],
+            "kernel_size": (3, 3),
+            "stride": (self.stride, self.stride),
+            "padding": (1, 1),
+            "dilation": (1, 1),
+            "groups": 1,
+            "device": self.device,
+            "conv_config": conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.conv_weights):
+            self.conv_weights = ttnn.prepare_conv_weights(
+                weight_tensor=self.conv_weights,
+                weights_format="OIHW",
+                input_memory_config=hidden_states.memory_config(),
+                **conv_kwargs,
+            )
+            self.conv_bias = (
+                ttnn.prepare_conv_bias(
+                    bias_tensor=self.conv_bias,
+                    input_memory_config=hidden_states.memory_config(),
+                    **conv_kwargs,
+                )
+                if self.conv_bias is not None
+                else None
+            )
+            self.conv_weights = ttnn.to_device(self.conv_weights, self.device)
+            self.conv_bias = ttnn.to_device(self.conv_bias, self.device) if self.conv_bias is not None else None
+
         [hidden_states, [self.conv_weights, self.conv_bias]] = ttnn.conv2d(
             input_tensor=hidden_states,
             in_channels=self.in_channels,
diff --git a/models/demos/yolov4/ttnn/common.py b/models/demos/yolov4/ttnn/common.py
index 1579f9112f9..d40ba69b3d3 100644
--- a/models/demos/yolov4/ttnn/common.py
+++ b/models/demos/yolov4/ttnn/common.py
@@ -102,6 +102,43 @@ def __call__(self, device, input_tensor):
         if self.act_block_h is not None:
             conv_config.act_block_h_override = self.act_block_h
 
+        conv_kwargs = {
+            "input_layout": input_tensor.get_layout(),
+            "in_channels": self.input_params[3],
+            "out_channels": self.out_channels,
+            "batch_size": self.input_params[0],
+            "input_height": self.input_params[1],
+            "input_width": self.input_params[2],
+            "kernel_size": self.kernel_size,
+            "stride": (self.conv_params[0], self.conv_params[1]),
+            "padding": (self.conv_params[2], self.conv_params[3]),
+            "dilation": (1, 1),
+            "groups": 1,
+            "device": device,
+            "conv_config": conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.weights):
+            self.weights = ttnn.prepare_conv_weights(
+                weight_tensor=self.weights,
+                weights_format="OIHW",
+                input_memory_config=input_tensor.memory_config(),
+                **conv_kwargs,
+            )
+
+            self.bias = (
+                ttnn.prepare_conv_bias(
+                    bias_tensor=self.bias,
+                    input_memory_config=input_tensor.memory_config(),
+                    **conv_kwargs,
+                )
+                if self.bias is not None
+                else None
+            )
+
+            self.weights = ttnn.to_device(self.weights, device)
+            self.bias = ttnn.to_device(self.bias, device) if self.bias else None
+
         output_tensor, [self.weights, self.bias] = ttnn.conv2d(
             input_tensor=input_tensor,
             weight_tensor=self.weights,
diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py
index 8a5157d51dc..ed34523c15e 100644
--- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py
+++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py
@@ -147,6 +147,37 @@ def __init__(
         self.bias = ttnn.from_torch(bias, dtype=ttnn.float32, mesh_mapper=mesh_mapper)
 
     def __call__(self, x):
+        conv_kwargs = {
+            "input_layout": x.get_layout(),
+            "in_channels": self.in_channels,
+            "out_channels": self.out_channels,
+            "batch_size": self.batch_size,
+            "input_height": self.input_height,
+            "input_width": self.input_width,
+            "kernel_size": self.kernel_size,
+            "stride": self.stride,
+            "padding": self.padding,
+            "dilation": [1, 1],
+            "groups": self.groups,
+            "device": self.device,
+            "conv_config": self.conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.weight):
+            self.weight = ttnn.prepare_conv_weights(
+                weight_tensor=self.weight,
+                weights_format="OIHW",
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            )
+            self.bias = ttnn.prepare_conv_bias(
+                bias_tensor=self.bias,
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            )
+            self.weight = ttnn.to_device(self.weight, self.device)
+            self.bias = ttnn.to_device(self.bias, self.device) if self.bias else None
+
         x, [self.weight, self.bias] = ttnn.conv2d(
             input_tensor=x,
             weight_tensor=self.weight,
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
index acd3453ecf5..fc2b28a7265 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
@@ -184,7 +184,7 @@ ParallelConfig determine_parallel_config(
     return pconfig;
 }
 
-static ParallelConfig determine_output_parallel_config(
+ParallelConfig determine_output_parallel_config(
     const ParallelConfig& input_parallel_config,
     const CoreCoord& compute_grid_size,
     uint32_t out_channels,
@@ -796,6 +796,9 @@ void adjust_conv_op_config_for_auto_shard_if_necessary(
     // If the input tensor is already sharded, or the conv_config has a specified shard layout, we don't need to do anything.
     if ((input_memory_config.has_value() && input_memory_config.value().is_sharded()) ||
         conv_config.shard_layout.has_value()) {
+        if (input_memory_config.has_value() && input_memory_config.value().is_sharded() && !conv_config.reshard_if_not_optimal) {
+            conv_config.shard_layout = input_memory_config.value().memory_layout;
+        }
         return;
     }
 
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
index 69ce604a671..3c7193246c3 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
@@ -108,6 +108,12 @@ sliding_window::ParallelConfig determine_parallel_config(
     bool enable_channels_padding,
     bool is_out_tiled=true);
 
+sliding_window::ParallelConfig determine_output_parallel_config(
+    const sliding_window::ParallelConfig& input_parallel_config,
+    const CoreCoord& compute_grid_size,
+    uint32_t out_channels,
+    bool is_mm_conv);
+
 uint32_t get_num_cores_nhw_from_parallel_config(const sliding_window::ParallelConfig& pconfig);
 
 uint32_t get_num_cores_channels_from_parallel_config(const sliding_window::ParallelConfig& pconfig);
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
index 0ba0363a9e6..b4152faad3a 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
@@ -138,16 +138,8 @@ static OptimizedConvBlockConfig get_opt_block_config(
         shard_orientation,
         !use_non_tile_height);
 
-    auto output_parallel_config = parallel_config;
-    if(conv_config.shard_layout.value() == ttnn::TensorMemoryLayout::WIDTH_SHARDED && !mm_conv) {
-        uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y;
-        output_parallel_config = {
-            .grid = num_cores_to_corerangeset( find_closest_largest_divisor(tt::div_up(out_channels, tt::constants::TILE_WIDTH),max_num_cores), compute_grid_size, true),
-            .shard_scheme = ttnn::TensorMemoryLayout::WIDTH_SHARDED,
-            .shard_orientation = parallel_config.shard_orientation
-        };
-        log_debug(tt::LogOp, "Changing width sharded output grid to {}",output_parallel_config.grid);
-    }
+    ParallelConfig output_parallel_config =
+        determine_output_parallel_config(parallel_config, compute_grid_size, out_channels, mm_conv);
 
     uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1;
     auto conv_out_memory_config = create_sharded_memory_config_from_parallel_config(
@@ -195,7 +187,6 @@ std::pair> prepare_conv_weights_biases
     uint32_t input_width,
     const bool parameters_on_device,
     bool is_non_tile_mul_width) {
-    validate_weight_tensor(weight_tensor);
     ttnn::Tensor weight_tensor_;  // tensor to return
     ttnn::Tensor bias_tensor_;
 
@@ -407,7 +398,6 @@ ttnn::Tensor prepare_conv_bias(
     const std::optional& compute_config_) {
     TT_FATAL(!ttnn::is_tensor_on_device_or_multidevice(bias_tensor), "Error: bias tensor must be on host for preparation.");
 
-    const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups);
     const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
     const uint32_t output_width =