From 41d7e3429ef8307b7eb07bc5306696d1085679d6 Mon Sep 17 00:00:00 2001 From: Abhinav Sarje Date: Wed, 26 Jun 2024 21:51:22 +0000 Subject: [PATCH] #0: added transpose shards arg, but whb0 with RM has some flipped grid issue --- .../ttnn_functional_resnet50_new_conv_api.py | 605 ++---------------- 1 file changed, 69 insertions(+), 536 deletions(-) diff --git a/models/experimental/resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/experimental/resnet/tt/ttnn_functional_resnet50_new_conv_api.py index 537f87bae26..15b979ae638 100644 --- a/models/experimental/resnet/tt/ttnn_functional_resnet50_new_conv_api.py +++ b/models/experimental/resnet/tt/ttnn_functional_resnet50_new_conv_api.py @@ -152,6 +152,7 @@ def run_downsample_if_req( conv_op_cache, reshard_if_not_optimal=False, height_sharding=None, + transpose_shards=True, ): if self.downsample: logger.debug(f"Running downsample") @@ -176,6 +177,7 @@ def run_downsample_if_req( deallocate_activation=True, reallocate_halo_output=not (is_wormhole_b0() and batch_size == 16), reshard_if_not_optimal=reshard_if_not_optimal, + transpose_shards=transpose_shards, ), conv_op_cache=conv_op_cache, ) @@ -196,6 +198,7 @@ def __call__( reshard_if_not_optimal=False, height_sharding=None, eltwise_binary_out_in_place=True, + transpose_shards=True, ): logger.debug( f"==== Running {batch_size}, {input_height}, {input_width}, {self.conv1_input_channels}, {self.conv1_output_channels}" @@ -244,6 +247,7 @@ def __call__( activation="relu", height_sharding=height_sharding, reshard_if_not_optimal=reshard_if_not_optimal, + transpose_shards=transpose_shards, ), conv_op_cache=conv_op_cache, ) @@ -284,7 +288,15 @@ def __call__( out = ttnn.reallocate(out) x = ttnn.reallocate(x_rm) ds_out = self.run_downsample_if_req( - x, device, batch_size, input_height, input_width, conv_op_cache, reshard_if_not_optimal, height_sharding + x, + device, + batch_size, + input_height, + input_width, + conv_op_cache, + reshard_if_not_optimal, + height_sharding, + transpose_shards=transpose_shards, ) reallocate_halo_output = batch_size == 20 @@ -312,6 +324,7 @@ def __call__( act_block_h_override=act_block_h_override, height_sharding=height_sharding, reshard_if_not_optimal=reshard_if_not_optimal, + transpose_shards=transpose_shards, ), conv_op_cache=conv_op_cache, ) @@ -349,6 +362,7 @@ def __call__( math_fidelity=self.model_config["MATH_FIDELITY"], height_sharding=height_sharding, reshard_if_not_optimal=reshard_if_not_optimal, + transpose_shards=transpose_shards, ), conv_op_cache=conv_op_cache, ) @@ -367,7 +381,15 @@ def __call__( else reshard_if_not_optimal ) ds_out = self.run_downsample_if_req( - x, device, batch_size, input_height, input_width, conv_op_cache, ds_reshard, height_sharding + x, + device, + batch_size, + input_height, + input_width, + conv_op_cache, + ds_reshard, + height_sharding, + transpose_shards=transpose_shards, ) assert ttnn.get_memory_config(out) == ttnn.get_memory_config( @@ -376,7 +398,6 @@ def __call__( if eltwise_binary_out_in_place: # underscore version is in_place = True - # out = ttnn.add_(out, ds_out, activations=["relu"], memory_config=ttnn.get_memory_config(out)) out = ttnn.add_( out, ds_out, @@ -385,7 +406,6 @@ def __call__( ) else: out = ttnn.add( - # out, ds_out, activations=["relu"], memory_config=ttnn.L1_MEMORY_CONFIG out, ds_out, activations=[ttnn.UnaryWithParam(ttnn.UnaryOpType.RELU)], @@ -569,526 +589,6 @@ def __call__(self, input_tensor, device, batch_size, ops_parallel_config) -> ttn input_tensor, device, batch_size, ops_parallel_config, {} if not 
ops_parallel_config else self.conv_op_cache ) - # def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> ttnn.Tensor: - # ## copy input to device sharded directly - # # x = ttnn.to_device(input_tensor, device=self.device, memory_config=self.conv1.conv.input_sharded_memory_config) - # conv_op_cache = {} - # if is_wormhole_b0(): - # if batch_size == 16: - # act_block_h_override = 1568 - # elif batch_size == 20: - # act_block_h_override = 640 - # else: - # act_block_h_override = 0 - # x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( - # input_tensor=input_tensor, - # weight_tensor=self.conv1_weight_tensor, - # in_channels=self.conv1_input_channels, - # out_channels=self.conv1_output_channels, - # device=device, - # bias_tensor=self.conv1_bias_tensor, - # kernel_size=(4, 4), - # stride=(1, 1), - # padding=(0, 0), - # batch_size=self.batch_size, - # input_height=115, - # input_width=115, - # conv_config=ttnn.Conv2dConfig( - # dtype=self.model_config["ACTIVATIONS_DTYPE"], - # weights_dtype=self.model_config["WEIGHTS_DTYPE"], - # math_fidelity=self.model_config["MATH_FIDELITY"], - # activation="relu", - # deallocate_activation=True, - # input_channels_alignment=16 if not is_wormhole_b0() else 32, - # act_block_h_override=act_block_h_override, - # ), - # conv_op_cache=conv_op_cache, - # ) - # # Relu is fused with conv1 - - # if self.batch_size == 20: - # x = ttnn.reallocate(x) - - # if is_wormhole_b0() and self.batch_size == 20: - # # TODO: fix the need to do the reshard here - # x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG) - # x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) - # x = ttnn.to_memory_config(x, self.max_pool.max_pool.input_sharded_memory_config) - # x = self.max_pool(x) - # x_height = 56 - # x_width = 56 - # x = ttnn.reshape(x, (1, 1, x_height * x_width * self.batch_size, 64)) - - # if is_wormhole_b0(): - # # TODO: fix the need to do the reshard here - # # (memory_layout=TensorMemoryLayout::HEIGHT_SHARDED;buffer_type=BufferType::L1;shard_spec=tt::tt_metal::ShardSpec(grid={[(x=0;y=0) - (x=7;y=6)]}; shape={1120; 64}; orientation=ShardOrientation::ROW_MAJOR; halo=false))'} - # mem_config = ttnn.create_sharded_memory_config_( - # ttnn.Shape([self.batch_size * x_height * x_width, 64]), - # ttnn.CoreGrid(x=8, y=7), - # ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - # ttnn.ShardOrientation.ROW_MAJOR, - # tile_layout=True, - # ) - # x = ttnn.to_memory_config(x, mem_config) - # x = ttnn.to_layout(x, ttnn.TILE_LAYOUT, dtype=self.model_config["ACTIVATIONS_DTYPE"]) - - # if self.batch_size == 20 and not is_wormhole_b0(): - # x = ttnn.reallocate(x) - - # logger.debug(f"==== Running layer 1 module 1") - # layer1_module1_input_shape = [ - # x.get_legacy_shape()[0], - # x.get_legacy_shape()[1], - # x.get_legacy_shape()[2], - # x.get_legacy_shape()[3], - # ] - # if is_wormhole_b0() and self.batch_size == 20: - # x, x_height, x_width = self.layer1_module1( - # x, - # device, - # batch_size, - # x_height, - # x_width, - # conv_op_cache, - # reshard_if_not_optimal=True, - # height_sharding=True, - # ) - # else: - # x, x_height, x_width = self.layer1_module1(x, device, batch_size, x_height, x_width, conv_op_cache) - # x_memory_config = ttnn.get_memory_config(x) - # ops_parallel_config["layer1_module1_input"] = ttnn.create_sharded_memory_config_( - # layer1_module1_input_shape, - # x_memory_config.shard_spec.grid, - # x_memory_config.memory_layout, - # x_memory_config.shard_spec.orientation, - # tile_layout=True, - # ) - # logger.debug(f"==== 
Running layer 1 module 2") - # x, x_height, x_width = self.layer1_module2(x, device, batch_size, x_height, x_width, conv_op_cache) - # logger.debug(f"==== Running layer 1 module 3") - # x, x_height, x_width = self.layer1_module3(x, device, batch_size, x_height, x_width, conv_op_cache) - # if self.batch_size == 20 and is_wormhole_b0(): - # x = ttnn.reallocate(x) - - # logger.debug(f"==== Running layer 2 module 1") - # layer2_module1_input_shape = [ - # x.get_legacy_shape()[0], - # x.get_legacy_shape()[1], - # x.get_legacy_shape()[2], - # x.get_legacy_shape()[3], - # ] - # if is_wormhole_b0() and self.batch_size == 20: - # x, x_height, x_width = self.layer2_module1( - # x, - # device, - # batch_size, - # x_height, - # x_width, - # conv_op_cache, - # reshard_if_not_optimal=True if not is_wormhole_b0() else False, - # height_sharding=True, - # ) - # else: - # x, x_height, x_width = self.layer2_module1(x, device, batch_size, x_height, x_width, conv_op_cache) - - # x_memory_config = ttnn.get_memory_config(x) - # ops_parallel_config["layer2_module1_input"] = ttnn.create_sharded_memory_config_( - # layer2_module1_input_shape, - # x_memory_config.shard_spec.grid, - # x_memory_config.memory_layout, - # x_memory_config.shard_spec.orientation, - # tile_layout=True, - # ) - - # logger.debug(f"==== Running layer 2 module 2") - # x, x_height, x_width = self.layer2_module2(x, device, batch_size, x_height, x_width, conv_op_cache) - # logger.debug(f"==== Running layer 2 module 3") - # x, x_height, x_width = self.layer2_module3(x, device, batch_size, x_height, x_width, conv_op_cache) - # logger.debug(f"==== Running layer 2 module 4") - # x, x_height, x_width = self.layer2_module4(x, device, batch_size, x_height, x_width, conv_op_cache) - - # logger.debug(f"==== Running layer 3 module 1") - # layer3_module1_input_shape = [ - # x.get_legacy_shape()[0], - # x.get_legacy_shape()[1], - # x.get_legacy_shape()[2], - # x.get_legacy_shape()[3], - # ] - # x, x_height, x_width = self.layer3_module1( - # x, device, batch_size, x_height, x_width, conv_op_cache, reshard_if_not_optimal=True, height_sharding=False - # ) - # x_memory_config = ttnn.get_memory_config(x) - # ops_parallel_config["layer3_module1_input"] = ttnn.create_sharded_memory_config_( - # layer3_module1_input_shape, - # x_memory_config.shard_spec.grid, - # x_memory_config.memory_layout, - # x_memory_config.shard_spec.orientation, - # tile_layout=True, - # ) - # logger.debug(f"==== Running layer 3 module 2") - # x, x_height, x_width = self.layer3_module2(x, device, batch_size, x_height, x_width, conv_op_cache) - # logger.debug(f"==== Running layer 3 module 3") - # x, x_height, x_width = self.layer3_module3(x, device, batch_size, x_height, x_width, conv_op_cache) - # logger.debug(f"==== Running layer 3 module 4") - # x, x_height, x_width = self.layer3_module4(x, device, batch_size, x_height, x_width, conv_op_cache) - # logger.debug(f"==== Running layer 3 module 5") - # x, x_height, x_width = self.layer3_module5(x, device, batch_size, x_height, x_width, conv_op_cache) - # logger.debug(f"==== Running layer 3 module 6") - # x, x_height, x_width = self.layer3_module6( - # x, - # device, - # batch_size, - # x_height, - # x_width, - # conv_op_cache, - # eltwise_binary_out_in_place=False, - # ) - - # layer4_module1_input_shape = [ - # x.get_legacy_shape()[0], - # x.get_legacy_shape()[1], - # x.get_legacy_shape()[2], - # x.get_legacy_shape()[3], - # ] - # logger.debug(f"==== Running layer 4 module 1") - # x, x_height, x_width = self.layer4_module1( - # x, device, 
batch_size, x_height, x_width, conv_op_cache, reshard_if_not_optimal=True, height_sharding=False - # ) - # x_memory_config = ttnn.get_memory_config(x) - # ops_parallel_config["layer4_module1_input"] = ttnn.create_sharded_memory_config_( - # layer4_module1_input_shape, - # x_memory_config.shard_spec.grid, - # x_memory_config.memory_layout, - # x_memory_config.shard_spec.orientation, - # tile_layout=True, - # ) - # logger.debug(f"==== Running layer 4 module 2") - # x, x_height, x_width = self.layer4_module2(x, device, batch_size, x_height, x_width, conv_op_cache) - # logger.debug(f"==== Running layer 4 module 3") - # x, x_height, x_width = self.layer4_module3(x, device, batch_size, x_height, x_width, conv_op_cache) - - # unpadded_shape = x.shape_without_padding() - # x = ttnn.experimental.tensor.untilize_with_unpadding( - # x, - # (unpadded_shape[0] - 1, unpadded_shape[1] - 1, unpadded_shape[2] - 1, unpadded_shape[3] - 1), - # ttnn.L1_MEMORY_CONFIG, - # ) - # x = ttnn.reshape( - # x, - # ( - # self.batch_size, - # x.get_legacy_shape()[1], - # (int)(x.get_legacy_shape()[2] / self.batch_size), - # x.get_legacy_shape()[3], - # ), - # ) - - # grid_size = (8, 4) - # shard_grid = ttnn.experimental.tensor.CoreRangeSet( - # { - # ttnn.experimental.tensor.CoreRange( - # ttnn.experimental.tensor.CoreCoord(0, 0), - # ttnn.experimental.tensor.CoreCoord(grid_size[0] - 1, grid_size[1] - 1), - # ) - # } - # ) - # shard_shape = [ - # x.volume() // x.get_legacy_shape()[-1], - # x.get_legacy_shape()[-1] // (grid_size[0] * grid_size[1]), - # ] - # shard_spec = ttnn.experimental.tensor.ShardSpec( - # shard_grid, shard_shape, ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, False - # ) - # width_sharded_mem_config = ttnn.types.MemoryConfig( - # ttnn.types.TensorMemoryLayout.WIDTH_SHARDED, ttnn.types.BufferType.L1, shard_spec - # ) - # x = ttnn.to_memory_config(x, width_sharded_mem_config) - # unpadded_shape = x.get_legacy_shape() - # padded_shape = [ - # unpadded_shape[0], - # unpadded_shape[1], - # _nearest_32(unpadded_shape[2]), - # _nearest_32(unpadded_shape[3]), - # ] - # x = ttnn.experimental.tensor.tilize_with_val_padding( - # x, - # padded_shape, - # 0, - # output_mem_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - # output_dtype=self.model_config["ACTIVATIONS_DTYPE"], - # ) - - # x = self.avgpool(x, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG) - - # unpadded_shape_end = [ - # x.get_legacy_shape()[0] - 1, - # x.get_legacy_shape()[1] - 1, - # 1 - 1, - # x.get_legacy_shape()[3] - 1, - # ] - # x = ttnn.experimental.tensor.untilize_with_unpadding( - # x, unpadded_shape_end, output_mem_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG - # ) - - # x = ttnn.reshape( - # x, (1, x.get_legacy_shape()[1], self.batch_size * x.get_legacy_shape()[2], x.get_legacy_shape()[3]) - # ) - - # unpadded_shape = x.get_legacy_shape() - # padded_shape = [ - # unpadded_shape[0], - # unpadded_shape[1], - # _nearest_32(unpadded_shape[2]), - # _nearest_32(unpadded_shape[3]), - # ] - - # x = ttnn.experimental.tensor.tilize_with_val_padding( - # x, - # padded_shape, - # 0, - # output_mem_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - # output_dtype=self.model_config["ACTIVATIONS_DTYPE"], - # ) - - # x = self.fc(x) - # desired_shape = list(x.shape_without_padding()) - # desired_shape[-1] = 1000 - # x = ttnn.experimental.tensor.untilize_with_unpadding( - # x, - # (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1), - # ttnn.L1_MEMORY_CONFIG, - # ) - # x = ttnn.reshape( - # x, - # ( - # 
self.batch_size, - # x.get_legacy_shape()[1], - # (int)(x.get_legacy_shape()[2] / self.batch_size), - # x.get_legacy_shape()[3], - # ), - # ) - # # for _, tensor in conv_op_cache["reader_patterns_cache"]["conv"].items(): - # # ttnn.deallocate(tensor) - # # for _, halo_tensors in conv_op_cache["reader_patterns_cache"]["halo"].items(): - # # for tensor in halo_tensors.values(): - # # if isinstance(tensor, ttnn.Tensor): - # # ttnn.deallocate(tensor) - # return x - - # def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cache) -> ttnn.Tensor: - # ## copy input to device sharded directly - # # x = ttnn.to_device(input_tensor, device=self.device, memory_config=self.conv1.conv.input_sharded_memory_config) - # if is_wormhole_b0(): - # if batch_size == 16: - # act_block_h_override = 1568 - # elif batch_size == 20: - # act_block_h_override = 640 - # else: - # act_block_h_override = 0 - # x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( - # input_tensor=input_tensor, - # weight_tensor=self.conv1_weight_tensor, - # in_channels=self.conv1_input_channels, - # out_channels=self.conv1_output_channels, - # device=device, - # bias_tensor=self.conv1_bias_tensor, - # kernel_size=(4, 4), - # stride=(1, 1), - # padding=(0, 0), - # batch_size=self.batch_size, - # input_height=115, - # input_width=115, - # conv_config=ttnn.Conv2dConfig( - # dtype=self.model_config["ACTIVATIONS_DTYPE"], - # weights_dtype=self.model_config["WEIGHTS_DTYPE"], - # math_fidelity=self.model_config["MATH_FIDELITY"], - # activation="relu", - # deallocate_activation=True, - # input_channels_alignment=16 if not is_wormhole_b0() else 32, - # act_block_h_override=act_block_h_override, - # ), - # conv_op_cache=conv_op_cache, - # ) - # # Relu is fused with conv1 - - # if self.batch_size == 20: - # x = ttnn.reallocate(x) - - # if is_wormhole_b0() and self.batch_size == 20: - # # TODO: fix the need to do the reshard here - # x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG) - # x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) - # x = ttnn.to_memory_config(x, self.max_pool.max_pool.input_sharded_memory_config) - # x = self.max_pool(x) - # x_height = 56 - # x_width = 56 - # x = ttnn.reshape(x, (1, 1, x_height * x_width * self.batch_size, 64)) - - # if is_wormhole_b0(): - # # TODO: fix the need to do the reshard here - # mem_config = ttnn.create_sharded_memory_config_( - # ttnn.Shape([self.batch_size * x_height * x_width, 64]), - # ttnn.CoreGrid(x=8, y=7), - # ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - # ttnn.ShardOrientation.ROW_MAJOR, - # tile_layout=True, - # ) - # x = ttnn.to_memory_config(x, mem_config) - # # x = ttnn.to_memory_config(x, self.layer1_module1.conv1.conv.input_sharded_memory_config) - # x = ttnn.to_layout(x, ttnn.TILE_LAYOUT, dtype=self.model_config["ACTIVATIONS_DTYPE"]) - - # if self.batch_size == 20 and not is_wormhole_b0(): - # x = ttnn.reallocate(x) - - # if is_wormhole_b0() and batch_size == 20: - # x = ttnn.to_memory_config(x, ops_parallel_config["layer1_module1_input"]) - # x, x_height, x_width = self.layer1_module1(x, device, batch_size, x_height, x_width, conv_op_cache) - # x, x_height, x_width = self.layer1_module2(x, device, batch_size, x_height, x_width, conv_op_cache) - # x, x_height, x_width = self.layer1_module3(x, device, batch_size, x_height, x_width, conv_op_cache) - # if self.batch_size == 20 and is_wormhole_b0(): - # x = ttnn.reallocate(x) - # x = ttnn.to_memory_config(x, ops_parallel_config["layer2_module1_input"]) - - # x, x_height, x_width 
= self.layer2_module1(x, device, batch_size, x_height, x_width, conv_op_cache) - # x, x_height, x_width = self.layer2_module2(x, device, batch_size, x_height, x_width, conv_op_cache) - # x, x_height, x_width = self.layer2_module3(x, device, batch_size, x_height, x_width, conv_op_cache) - # x, x_height, x_width = self.layer2_module4(x, device, batch_size, x_height, x_width, conv_op_cache) - - # # do reshard before layer3 - # x = ttnn.to_memory_config(x, ops_parallel_config["layer3_module1_input"]) - # x, x_height, x_width = self.layer3_module1(x, device, batch_size, x_height, x_width, conv_op_cache) - # x, x_height, x_width = self.layer3_module2(x, device, batch_size, x_height, x_width, conv_op_cache) - # x, x_height, x_width = self.layer3_module3(x, device, batch_size, x_height, x_width, conv_op_cache) - # x, x_height, x_width = self.layer3_module4(x, device, batch_size, x_height, x_width, conv_op_cache) - # x, x_height, x_width = self.layer3_module5(x, device, batch_size, x_height, x_width, conv_op_cache) - # x, x_height, x_width = self.layer3_module6( - # x, - # device, - # batch_size, - # x_height, - # x_width, - # conv_op_cache, - # eltwise_binary_out_in_place=False, - # ) - - # # do reshard before layer4 - # x = ttnn.to_memory_config(x, ops_parallel_config["layer4_module1_input"]) - # x, x_height, x_width = self.layer4_module1( - # x, - # device, - # batch_size, - # x_height, - # x_width, - # conv_op_cache, - # reshard_if_not_optimal=(is_wormhole_b0() and batch_size == 16), - # ) - # x, x_height, x_width = self.layer4_module2(x, device, batch_size, x_height, x_width, conv_op_cache) - # x, x_height, x_width = self.layer4_module3(x, device, batch_size, x_height, x_width, conv_op_cache) - - # unpadded_shape = x.shape_without_padding() - # x = ttnn.experimental.tensor.untilize_with_unpadding( - # x, - # (unpadded_shape[0] - 1, unpadded_shape[1] - 1, unpadded_shape[2] - 1, unpadded_shape[3] - 1), - # ttnn.L1_MEMORY_CONFIG, - # ) - - # x = ttnn.reshape( - # x, - # ( - # self.batch_size, - # x.get_legacy_shape()[1], - # (int)(x.get_legacy_shape()[2] / self.batch_size), - # x.get_legacy_shape()[3], - # ), - # ) - - # grid_size = (8, 4) - # shard_grid = ttnn.experimental.tensor.CoreRangeSet( - # { - # ttnn.experimental.tensor.CoreRange( - # ttnn.experimental.tensor.CoreCoord(0, 0), - # ttnn.experimental.tensor.CoreCoord(grid_size[0] - 1, grid_size[1] - 1), - # ) - # } - # ) - # shard_shape = [ - # x.volume() // x.get_legacy_shape()[-1], - # x.get_legacy_shape()[-1] // (grid_size[0] * grid_size[1]), - # ] - # shard_spec = ttnn.experimental.tensor.ShardSpec( - # shard_grid, shard_shape, ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, False - # ) - # width_sharded_mem_config = ttnn.types.MemoryConfig( - # ttnn.types.TensorMemoryLayout.WIDTH_SHARDED, ttnn.types.BufferType.L1, shard_spec - # ) - # x = ttnn.to_memory_config(x, width_sharded_mem_config) - # unpadded_shape = x.get_legacy_shape() - # padded_shape = [ - # unpadded_shape[0], - # unpadded_shape[1], - # _nearest_32(unpadded_shape[2]), - # _nearest_32(unpadded_shape[3]), - # ] - # x = ttnn.experimental.tensor.tilize_with_val_padding( - # x, - # padded_shape, - # 0, - # output_mem_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - # output_dtype=self.model_config["ACTIVATIONS_DTYPE"], - # ) - - # x = self.avgpool(x, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG) - - # unpadded_shape_end = [ - # x.get_legacy_shape()[0] - 1, - # x.get_legacy_shape()[1] - 1, - # 1 - 1, - # x.get_legacy_shape()[3] - 1, - # ] - # x = 
ttnn.experimental.tensor.untilize_with_unpadding( - # x, unpadded_shape_end, output_mem_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG - # ) - - # x = ttnn.reshape( - # x, (1, x.get_legacy_shape()[1], self.batch_size * x.get_legacy_shape()[2], x.get_legacy_shape()[3]) - # ) - - # unpadded_shape = x.get_legacy_shape() - # padded_shape = [ - # unpadded_shape[0], - # unpadded_shape[1], - # _nearest_32(unpadded_shape[2]), - # _nearest_32(unpadded_shape[3]), - # ] - - # x = ttnn.experimental.tensor.tilize_with_val_padding( - # x, - # padded_shape, - # 0, - # output_mem_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - # output_dtype=self.model_config["ACTIVATIONS_DTYPE"], - # ) - - # x = self.fc(x) - # desired_shape = list(x.shape_without_padding()) - # desired_shape[-1] = 1000 - # x = ttnn.experimental.tensor.untilize_with_unpadding( - # x, - # (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1), - # ttnn.L1_MEMORY_CONFIG, - # ) - # x = ttnn.reshape( - # x, - # ( - # self.batch_size, - # x.get_legacy_shape()[1], - # (int)(x.get_legacy_shape()[2] / self.batch_size), - # x.get_legacy_shape()[3], - # ), - # ) - - # return x - ## merged runs (first and optimized) def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cache={}) -> ttnn.Tensor: is_first_run = False @@ -1098,6 +598,10 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac else: logger.debug(f"==== Optimized run") + transpose_shards = True + # if is_wormhole_b0(): + # transpose_shards = False + if is_wormhole_b0(): if batch_size == 16: act_block_h_override = 1568 @@ -1105,6 +609,7 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac act_block_h_override = 640 else: act_block_h_override = 0 + x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, @@ -1126,6 +631,7 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac deallocate_activation=True, input_channels_alignment=16 if not is_wormhole_b0() else 32, act_block_h_override=act_block_h_override, + transpose_shards=transpose_shards, ), conv_op_cache=conv_op_cache, ) @@ -1148,7 +654,7 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac # TODO: fix the need to do the reshard here mem_config = ttnn.create_sharded_memory_config_( ttnn.Shape([self.batch_size * x_height * x_width, 64]), - ttnn.CoreGrid(x=8, y=7), + ttnn.CoreGrid(x=7, y=8), ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.ShardOrientation.ROW_MAJOR, tile_layout=True, @@ -1180,6 +686,7 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac conv_op_cache, reshard_if_not_optimal=reshard, height_sharding=height_shard, + transpose_shards=transpose_shards, ) if is_first_run: @@ -1193,10 +700,14 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac ) logger.debug(f"==== Running layer 1 module 2") - x, x_height, x_width = self.layer1_module2(x, device, batch_size, x_height, x_width, conv_op_cache) + x, x_height, x_width = self.layer1_module2( + x, device, batch_size, x_height, x_width, conv_op_cache, transpose_shards=transpose_shards + ) logger.debug(f"==== Running layer 1 module 3") - x, x_height, x_width = self.layer1_module3(x, device, batch_size, x_height, x_width, conv_op_cache) + x, x_height, x_width = self.layer1_module3( + x, device, batch_size, x_height, x_width, conv_op_cache, 
transpose_shards=transpose_shards + ) if self.batch_size == 20 and is_wormhole_b0(): x = ttnn.reallocate(x) @@ -1222,6 +733,7 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac conv_op_cache, reshard_if_not_optimal=reshard, height_sharding=height_shard, + transpose_shards=transpose_shards, ) if is_first_run: @@ -1235,13 +747,19 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac ) logger.debug(f"==== Running layer 2 module 2") - x, x_height, x_width = self.layer2_module2(x, device, batch_size, x_height, x_width, conv_op_cache) + x, x_height, x_width = self.layer2_module2( + x, device, batch_size, x_height, x_width, conv_op_cache, transpose_shards=transpose_shards + ) logger.debug(f"==== Running layer 2 module 3") - x, x_height, x_width = self.layer2_module3(x, device, batch_size, x_height, x_width, conv_op_cache) + x, x_height, x_width = self.layer2_module3( + x, device, batch_size, x_height, x_width, conv_op_cache, transpose_shards=transpose_shards + ) logger.debug(f"==== Running layer 2 module 4") - x, x_height, x_width = self.layer2_module4(x, device, batch_size, x_height, x_width, conv_op_cache) + x, x_height, x_width = self.layer2_module4( + x, device, batch_size, x_height, x_width, conv_op_cache, transpose_shards=transpose_shards + ) layer3_module1_input_shape = ttnn.Shape(x.get_legacy_shape()) @@ -1263,6 +781,7 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac conv_op_cache, reshard_if_not_optimal=reshard, height_sharding=height_shard, + transpose_shards=transpose_shards, ) if is_first_run: @@ -1276,16 +795,24 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac ) logger.debug(f"==== Running layer 3 module 2") - x, x_height, x_width = self.layer3_module2(x, device, batch_size, x_height, x_width, conv_op_cache) + x, x_height, x_width = self.layer3_module2( + x, device, batch_size, x_height, x_width, conv_op_cache, transpose_shards=transpose_shards + ) logger.debug(f"==== Running layer 3 module 3") - x, x_height, x_width = self.layer3_module3(x, device, batch_size, x_height, x_width, conv_op_cache) + x, x_height, x_width = self.layer3_module3( + x, device, batch_size, x_height, x_width, conv_op_cache, transpose_shards=transpose_shards + ) logger.debug(f"==== Running layer 3 module 4") - x, x_height, x_width = self.layer3_module4(x, device, batch_size, x_height, x_width, conv_op_cache) + x, x_height, x_width = self.layer3_module4( + x, device, batch_size, x_height, x_width, conv_op_cache, transpose_shards=transpose_shards + ) logger.debug(f"==== Running layer 3 module 5") - x, x_height, x_width = self.layer3_module5(x, device, batch_size, x_height, x_width, conv_op_cache) + x, x_height, x_width = self.layer3_module5( + x, device, batch_size, x_height, x_width, conv_op_cache, transpose_shards=transpose_shards + ) logger.debug(f"==== Running layer 3 module 6") x, x_height, x_width = self.layer3_module6( @@ -1296,6 +823,7 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac x_width, conv_op_cache, eltwise_binary_out_in_place=False, + transpose_shards=transpose_shards, ) layer4_module1_input_shape = ttnn.Shape(x.get_legacy_shape()) @@ -1321,6 +849,7 @@ def run(self, input_tensor, device, batch_size, ops_parallel_config, conv_op_cac conv_op_cache, reshard_if_not_optimal=reshard, height_sharding=height_shard, + transpose_shards=transpose_shards, ) if is_first_run: @@ -1334,10 +863,14 @@ def run(self, input_tensor, device, batch_size, 
ops_parallel_config, conv_op_cac ) logger.debug(f"==== Running layer 4 module 2") - x, x_height, x_width = self.layer4_module2(x, device, batch_size, x_height, x_width, conv_op_cache) + x, x_height, x_width = self.layer4_module2( + x, device, batch_size, x_height, x_width, conv_op_cache, transpose_shards=transpose_shards + ) logger.debug(f"==== Running layer 4 module 3") - x, x_height, x_width = self.layer4_module3(x, device, batch_size, x_height, x_width, conv_op_cache) + x, x_height, x_width = self.layer4_module3( + x, device, batch_size, x_height, x_width, conv_op_cache, transpose_shards=transpose_shards + ) unpadded_shape = x.shape_without_padding() x = ttnn.experimental.tensor.untilize_with_unpadding(
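Note (illustrative sketch, not part of the patch): the change above threads a new transpose_shards flag from run() through the bottleneck-block __call__ / run_downsample_if_req calls into ttnn.Conv2dConfig. The snippet below shows that flow for a single convolution. The names x, conv_weight, conv_bias, model_config, device, batch_size and conv_op_cache, as well as the 3x3 / 56x56 / 64-channel shapes, are assumed placeholders; only the parameter names mirror the calls in this patch.

import ttnn

# Assumed to exist already: device, batch_size, conv_op_cache, a model_config dict with
# ACTIVATIONS_DTYPE / WEIGHTS_DTYPE / MATH_FIDELITY, and tensors x (activations), conv_weight, conv_bias.
transpose_shards = True  # knob added by this patch; the commented-out lines above hint it may need to be False on Wormhole B0 until the flipped-grid issue is resolved

out, out_height, out_width, conv_weight, conv_bias = ttnn.conv2d(
    input_tensor=x,
    weight_tensor=conv_weight,
    bias_tensor=conv_bias,
    in_channels=64,       # placeholder channel count
    out_channels=64,      # placeholder channel count
    device=device,
    kernel_size=(3, 3),
    stride=(1, 1),
    padding=(1, 1),
    batch_size=batch_size,
    input_height=56,      # placeholder spatial size
    input_width=56,       # placeholder spatial size
    conv_config=ttnn.Conv2dConfig(
        dtype=model_config["ACTIVATIONS_DTYPE"],
        weights_dtype=model_config["WEIGHTS_DTYPE"],
        math_fidelity=model_config["MATH_FIDELITY"],
        activation="relu",
        transpose_shards=transpose_shards,  # forwarded the same way the updated block calls now do
    ),
    conv_op_cache=conv_op_cache,
)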