From b7f468da0ad8c0a5be59888130e39733050f8d55 Mon Sep 17 00:00:00 2001 From: Akhmed Rakhmati Date: Mon, 15 Apr 2024 21:18:50 +0000 Subject: [PATCH 1/2] #7478: added more zone scopes in tensor methods --- tt_eager/tensor/tensor.cpp | 12 ++++++++++++ tt_metal/tools/profiler/op_profiler.hpp | 8 ++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index b426eba658f..7900d694d93 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -30,6 +30,7 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L tensor_id{std::nullopt}, tensor_attributes(std::make_shared(storage, shape, dtype, layout)), deallocate_through_destructor(false) { + ZoneScoped; std::visit( [&] (auto&& storage) { using StorageType = std::decay_t; @@ -70,6 +71,7 @@ Tensor::Tensor(const Storage storage, const Shape shape, DataType dtype, Layout Tensor(storage, ttnn::Shape{shape}, dtype, layout) {} Tensor::~Tensor() { + ZoneScoped; this->deallocate_through_destructor = true; this->deallocate(); // Decrement main thread ref count for all tensors on device @@ -80,6 +82,7 @@ Tensor::~Tensor() { } void Tensor::deallocate(bool force) { + ZoneScoped; if (this->tensor_attributes.use_count()) { // Check if the attributes didn't get moved to another tensor. // If not, we can deallocate this tensor. @@ -163,6 +166,7 @@ void Tensor::deallocate(bool force) { // Main Thread - Wait for all workers in this tensor to populate the entire tensor void Tensor::wait_for_tensor_data_populated() const { + ZoneScoped; // Stall until all the workers for this tensor // have populated the full tensor for (int i = 0; i < this->tensor_attributes->tensor_populated.size(); i++) { @@ -175,6 +179,7 @@ void Tensor::wait_for_tensor_data_populated() const { // Main Thread - Wait for the first worker in this tensor to populate the global metadata fields void Tensor::wait_for_tensor_metadata_populated() const { + ZoneScoped; // First worker is responsible for updating all metadata fields // Stall until this worker is done while (true) { @@ -198,6 +203,7 @@ void Tensor::set_populated(Device* worker) { } void Tensor::deepcopy(const Tensor& other) { + ZoneScoped; // Wait until the tensor being copied is populated other.wait_for_tensor_data_populated(); // Populate tensor metadata @@ -210,6 +216,7 @@ void Tensor::deepcopy(const Tensor& other) { } void Tensor::populate_buffers_and_metadata(const Tensor& other) { + ZoneScoped; // Similar to deepcopy, but to be applied on a tensor that has an empty storage // container initialized. Require tensor storage to be correctly initialized. 
this->set_shape(other.get_shape()); @@ -230,6 +237,7 @@ void Tensor::populate_buffers_and_metadata(const Tensor& other) { } std::vector Tensor::get_workers(bool blocking) const { + ZoneScoped; // Initialize an empty worker vector (remains empty for host side storage) std::vector workers = {}; @@ -406,6 +414,7 @@ Tensor Tensor::cpu_sharded() const { Tensor Tensor::extract_shard(const CoreCoord & core) const{ + ZoneScoped; auto buffer_page_mapping = generate_buffer_page_mapping(*this->buffer()); auto core_id = buffer_page_mapping.core_to_core_id_.at(core); return this->extract_shard(core_id); @@ -485,11 +494,13 @@ uint32_t Tensor::element_size() const { } Tensor Tensor::reshape(int N, int C, int H, int W) const { + ZoneScoped; auto new_shape = infer_dims_for_reshape(N, C, H, W, this->volume()); return this->reshape(new_shape); } Tensor Tensor::reshape(const Shape& new_shape) const { + ZoneScoped; TT_ASSERT( this->volume() == tt::tt_metal::compute_volume(new_shape), "{} != {}", @@ -504,6 +515,7 @@ Tensor Tensor::reshape(const Shape& new_shape) const { } bool Tensor::is_allocated() const { + ZoneScoped; return std::visit( [](auto&& storage) -> bool { diff --git a/tt_metal/tools/profiler/op_profiler.hpp b/tt_metal/tools/profiler/op_profiler.hpp index 809aafd2b3c..f884d169a84 100644 --- a/tt_metal/tools/profiler/op_profiler.hpp +++ b/tt_metal/tools/profiler/op_profiler.hpp @@ -132,10 +132,10 @@ namespace op_profiler { } auto tensor_shape = tensor.get_legacy_shape(); - ret["shape"]["W"] = tensor_shape[0]; - ret["shape"]["Z"] = tensor_shape[1]; - ret["shape"]["Y"] = tensor_shape[2]; - ret["shape"]["X"] = tensor_shape[3]; + ret["shape"]["W"] = tensor_shape.rank() >= 4 ? tensor_shape[-4] : 1; + ret["shape"]["Z"] = tensor_shape.rank() >= 3 ? tensor_shape[-3] : 1; + ret["shape"]["Y"] = tensor_shape.rank() >= 2 ? 
tensor_shape[-2] : 1; + ret["shape"]["X"] = tensor_shape[-1]; ret["layout"] = fmt::format("{}", magic_enum::enum_name(tensor.get_layout())); ret["dtype"] = fmt::format("{}", magic_enum::enum_name(tensor.get_dtype())); From 8cf2c6799bb88ae56eee34d06a0366b7109bd06b Mon Sep 17 00:00:00 2001 From: Aleks Knezevic Date: Fri, 12 Apr 2024 19:56:03 +0000 Subject: [PATCH 2/2] #0: Removed tilize ops in res block and kept transformer output in L1 --- ...ttnn_functional_basic_transformer_block.py | 1 - .../tt2/ttnn_functional_cross_attention.py | 1 - ...unctional_cross_attention_down_block_2d.py | 2 +- .../tt2/ttnn_functional_cross_attn_upblock.py | 5 - .../tt2/ttnn_functional_resnetblock2d.py | 165 ++++++++---------- .../tt2/ttnn_functional_transformer_2d.py | 14 +- .../ttnn_functional_upsample_nearest_2d.py | 2 + tt_eager/tt_dnn/op_library/move/move_op.hpp | 3 +- ttnn/ttnn/experimental/golden_functions.py | 1 + 9 files changed, 88 insertions(+), 106 deletions(-) diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_basic_transformer_block.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_basic_transformer_block.py index 60c8872be8a..e7d612dfdb4 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_basic_transformer_block.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_basic_transformer_block.py @@ -88,7 +88,6 @@ def basic_transformer_block( ) if use_ada_layer_norm_zero: assert False, "AdaLayerNormZero not supported and not used in stable diffusion" - # ttnn.dump_device_memory_state(device) ff_output = feedforward(config=config, hidden_states=norm_hidden_states, parameters=parameters.ff, device=device) hidden_states = ttnn.add(ff_output, hidden_states) diff --git a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py index 03b0ff25457..e51b245d15b 100644 --- a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py +++ b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py @@ -261,7 +261,6 @@ def time_sharded_attention(self, query, t_key, value, head_size): ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, ) - print(slice.memory_config()) program_config = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig( compute_with_storage_grid_size=grid_size, in0_block_w=2, diff --git a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention_down_block_2d.py b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention_down_block_2d.py index 3e80753573b..1bd7dc78cbb 100644 --- a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention_down_block_2d.py +++ b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention_down_block_2d.py @@ -105,7 +105,7 @@ def __call__( upcast_attention=upcast_attention, ) - output_states += (hidden_states,) + output_states += (ttnn.to_memory_config(hidden_states, ttnn.DRAM_MEMORY_CONFIG),) if add_downsample is not None: hidden_states = self.downsample_2d( diff --git a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attn_upblock.py b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attn_upblock.py index d99e0cfef5c..918bc7982fa 100644 --- 
a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attn_upblock.py +++ b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attn_upblock.py @@ -96,7 +96,6 @@ def __call__( index=-1, ): for i, (resnet, attention) in enumerate(zip(self.resnets, self.attentions)): - ttnn.dump_device_memory_state(self.device, prefix="in_uplock_") res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels @@ -134,11 +133,7 @@ def __call__( ttnn.clone, hidden_states, memory_config=ttnn.get_memory_config(hidden_states), dtype=ttnn.bfloat8_b ) hidden_states = dealloc_input(ttnn.concat, [hidden_states, on_dev_res_hidden_states], dim=3) - # breakpoint() ttnn.deallocate(on_dev_res_hidden_states) - # breakpoint() - # hidden_states = ttnn.reallocate(hidden_states) - # ttnn.dump_device_memory_state(self.device, prefix="after_reallocate_before_resnet") hidden_states = resnet( hidden_states, temb=temb, diff --git a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_resnetblock2d.py b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_resnetblock2d.py index 22b154b05d8..c5b2cc2ffda 100644 --- a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_resnetblock2d.py +++ b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_resnetblock2d.py @@ -16,6 +16,7 @@ pre_process_input, post_process_output, permute_conv_parameters, + weight_to_bfp8, ) import time @@ -209,6 +210,7 @@ def __init__( self.groups = 32 if use_in_shortcut: assert self.conv2.conv.output_sharded_memory_config == self.conv_shortcut.conv.output_sharded_memory_config + ( self.first_gn_expected_input_sharded_memory_config, self.first_group_norm_core_grid, @@ -328,6 +330,8 @@ def __init__( device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) + self.parameters.time_emb_proj.weight = weight_to_bfp8(self.parameters.time_emb_proj.weight) + self.parameters.time_emb_proj.bias = weight_to_bfp8(self.parameters.time_emb_proj.bias) def __call__( self, @@ -356,48 +360,29 @@ def __call__( nonlinearity = ttnn.silu out_channels = in_channels if out_channels is None else out_channels - hidden_states = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT, use_multicore=True) + hidden_states = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT, use_multicore=True) if ttnn.get_memory_config(hidden_states) != self.first_gn_expected_input_sharded_memory_config: hidden_states = ttnn.reshape( hidden_states, (self.conv2.batch_size, 1, self.conv2.input_height * self.conv2.input_width, in_channels) ) hidden_states = ttnn.to_memory_config(hidden_states, self.first_gn_expected_input_sharded_memory_config) - if self.fallback_on_groupnorm: - hidden_states = ttnn.to_memory_config(hidden_states, ttnn.DRAM_MEMORY_CONFIG) - hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT, use_multicore=True) - hidden_states = ttnn.reshape( - hidden_states, (self.conv2.batch_size, self.conv2.input_height, self.conv2.input_width, in_channels) - ) - hidden_states = ttnn.permute(hidden_states, (0, 3, 1, 2)) - hidden_states = ttnn.operations.normalization._fallback_group_norm( - hidden_states, - num_groups=groups, - weight=self.parameters.norm1.weight, - bias=self.parameters.norm1.bias, - epsilon=eps, - ) - hidden_states = pre_process_input(self.device, hidden_states) - hidden_states = ttnn.to_memory_config(hidden_states, ttnn.DRAM_MEMORY_CONFIG) - hidden_states = ttnn.to_layout(hidden_states, 
ttnn.TILE_LAYOUT, use_multicore=True) - else: - hidden_states = ttnn.group_norm( - hidden_states, - num_groups=groups, - input_mask=self.norm1_input_mask, - weight=self.parameters.norm1.weight, - bias=self.parameters.norm1.bias, - epsilon=eps, - memory_config=ttnn.get_memory_config(hidden_states), - core_grid=self.first_group_norm_core_grid, - ) - hidden_states = ttnn.to_memory_config(hidden_states, ttnn.L1_MEMORY_CONFIG) - hidden_states = ttnn.reshape( - hidden_states, - (1, 1, self.conv2.batch_size * self.conv2.input_height * self.conv2.input_width, in_channels), - ) - hidden_states = nonlinearity(hidden_states, memory_config=ttnn.get_memory_config(hidden_states)) + hidden_states = ttnn.group_norm( + hidden_states, + num_groups=groups, + input_mask=self.norm1_input_mask, + weight=self.parameters.norm1.weight, + bias=self.parameters.norm1.bias, + epsilon=eps, + memory_config=ttnn.get_memory_config(hidden_states), + core_grid=self.first_group_norm_core_grid, + dtype=ttnn.bfloat8_b, + ) + hidden_states = ttnn.reshape( + hidden_states, + (1, 1, self.conv2.batch_size * self.conv2.input_height * self.conv2.input_width, in_channels), + ) if up: assert False, "Up block within residual block is not implemented!" @@ -405,14 +390,28 @@ def __call__( assert False, "Down block within residual block is not implemented" conv1_split_chunks = len(self.conv1s) - if conv1_split_chunks > 1: + if conv1_split_chunks == 1: + hidden_states = ttnn.experimental.tensor.sharded_to_interleaved( + hidden_states, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype + ) + hidden_states = ttnn.experimental.tensor.interleaved_to_sharded( + hidden_states, self.conv1s[0].conv.input_sharded_memory_config, hidden_states.dtype + ) + hidden_states = nonlinearity(hidden_states, memory_config=ttnn.get_memory_config(hidden_states)) + hidden_states = self.conv1s[0](hidden_states) + else: split_hidden_states = [] output_tensor_start_width_dim = 0 in_channels = self.parameters.conv1.weight.shape[1] split_input_channels = in_channels // conv1_split_chunks + # unpad sharded causes output mismatch + hidden_states = ttnn.experimental.tensor.sharded_to_interleaved( + hidden_states, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype + ) output_tensor_end_width_dim = split_input_channels for i in range(conv1_split_chunks): + # TODO: Can we replace this with interleaved_to_sharded_partial split_hidden_states.append( ttnn.experimental.tensor.unpad( hidden_states, @@ -423,18 +422,19 @@ def __call__( hidden_states.shape[2] - 1, output_tensor_end_width_dim - 1, ], + # output_mem_config=ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1) ) ) output_tensor_start_width_dim += split_input_channels output_tensor_end_width_dim += split_input_channels - if conv1_split_chunks == 1: - hidden_states = ttnn.to_memory_config(hidden_states, self.conv1s[0].conv.input_sharded_memory_config) - hidden_states = self.conv1s[0](hidden_states) - else: - for i in range(conv1_split_chunks): - split_hidden_states[i] = ttnn.to_memory_config( - split_hidden_states[i], self.conv1s[i].conv.input_sharded_memory_config + split_hidden_states[i] = ttnn.experimental.tensor.interleaved_to_sharded( + split_hidden_states[i], + self.conv1s[i].conv.input_sharded_memory_config, + split_hidden_states[i].dtype, + ) + split_hidden_states[i] = nonlinearity( + split_hidden_states[i], memory_config=ttnn.get_memory_config(split_hidden_states[i]) ) split_hidden_states[i] = self.conv1s[i](split_hidden_states[i]) if i != 0: @@ -461,82 +461,63 @@ def __call__( 
self.parameters.time_emb_proj.weight, bias=self.parameters.time_emb_proj.bias, core_grid=temb.device().core_grid, + dtype=ttnn.bfloat8_b, + memory_config=ttnn.L1_MEMORY_CONFIG, ) - # temb = ttnn.permute(temb, (2, 0, 1, 3)) if temb is not None and time_embedding_norm == "default": hidden_states = ttnn.to_memory_config(hidden_states, ttnn.L1_MEMORY_CONFIG) - hidden_states = ttnn.clone( - hidden_states, memory_config=ttnn.get_memory_config(hidden_states), dtype=ttnn.bfloat16 - ) hidden_states = ttnn.reshape( hidden_states, (self.conv2.batch_size, 1, self.conv2.input_height * self.conv2.input_width, out_channels), ) - hidden_states = hidden_states + temb + hidden_states = ttnn.add(hidden_states, temb, memory_config=ttnn.L1_MEMORY_CONFIG) - # TODO: Reshape happening twice hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT, use_multicore=True) - hidden_states = ttnn.reshape( + hidden_states = ttnn.to_memory_config(hidden_states, self.second_gn_expected_input_sharded_memory_config) + hidden_states = ttnn.group_norm( hidden_states, - (self.conv2.batch_size, 1, self.conv2.input_height * self.conv2.input_width, out_channels), + num_groups=groups, + input_mask=self.norm2_input_mask, + weight=self.parameters.norm2.weight, + bias=self.parameters.norm2.bias, + epsilon=eps, + memory_config=self.second_gn_expected_input_sharded_memory_config, + core_grid=self.second_group_norm_core_grid, + dtype=ttnn.bfloat8_b, ) - if self.fallback_on_groupnorm: - hidden_states = ttnn.to_memory_config(hidden_states, ttnn.L1_MEMORY_CONFIG) - hidden_states = ttnn.reshape( - hidden_states, - (self.conv1s[0].batch_size, self.conv1s[0].input_height, self.conv1s[0].input_width, out_channels), - ) - hidden_states = ttnn.permute(hidden_states, (0, 3, 1, 2)) - hidden_states = ttnn.operations.normalization._fallback_group_norm( - hidden_states, - num_groups=groups, - weight=self.parameters.norm2.weight, - bias=self.parameters.norm2.bias, - epsilon=eps, - ) - - hidden_states = pre_process_input(self.device, hidden_states) - else: - hidden_states = ttnn.to_memory_config(hidden_states, self.second_gn_expected_input_sharded_memory_config) - hidden_states = ttnn.group_norm( - hidden_states, - num_groups=groups, - input_mask=self.norm2_input_mask, - weight=self.parameters.norm2.weight, - bias=self.parameters.norm2.bias, - epsilon=eps, - memory_config=self.second_gn_expected_input_sharded_memory_config, - core_grid=self.second_group_norm_core_grid, - ) - hidden_states = ttnn.to_memory_config(hidden_states, ttnn.L1_MEMORY_CONFIG) hidden_states = ttnn.reshape( hidden_states, (1, 1, self.conv2.batch_size * self.conv2.input_height * self.conv2.input_width, out_channels), ) + + hidden_states = ttnn.experimental.tensor.sharded_to_interleaved( + hidden_states, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype + ) + hidden_states = ttnn.experimental.tensor.interleaved_to_sharded( + hidden_states, self.conv2.conv.input_sharded_memory_config, hidden_states.dtype + ) + hidden_states = nonlinearity(hidden_states, memory_config=ttnn.get_memory_config(hidden_states)) - hidden_states = ttnn.to_memory_config(hidden_states, self.conv2.conv.input_sharded_memory_config) hidden_states = self.conv2(hidden_states) use_in_shortcut = in_channels != out_channels if use_in_shortcut is None else use_in_shortcut if use_in_shortcut: if ttnn.get_memory_config(input_tensor) != self.conv_shortcut.conv.input_sharded_memory_config: - input_tensor = ttnn.to_memory_config(input_tensor, self.conv_shortcut.conv.input_sharded_memory_config) + # TODO: Once reshard fix 
is in, store input tensor in sharded + if input_tensor.memory_config().is_sharded(): + input_tensor = ttnn.experimental.tensor.sharded_to_interleaved( + input_tensor, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype + ) + input_tensor = ttnn.experimental.tensor.interleaved_to_sharded( + input_tensor, self.conv_shortcut.conv.input_sharded_memory_config, hidden_states.dtype + ) input_tensor = self.conv_shortcut(input_tensor) if ttnn.get_memory_config(input_tensor) != ttnn.get_memory_config(hidden_states): input_tensor = ttnn.to_memory_config(input_tensor, ttnn.get_memory_config(hidden_states)) - output_tensor = ttnn.add(input_tensor, hidden_states, memory_config=ttnn.L1_MEMORY_CONFIG) - - if output_scale_factor != 1.0: - assert False # Do we need this? - output_sc_recip = 1 / output_scale_factor - output_sc_recip = ttnn.from_torch( - torch.full([1, 1, 1, 1], output_sc_recip), layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b - ) - output_sc_recip = ttnn.to_device(output_sc_recip, self.device, memory_config=ttnn.L1_MEMORY_CONFIG) - output_tensor = ttnn.mul(output_tensor, output_sc_recip, memory_config=ttnn.DRAM_MEMORY_CONFIG) + output_tensor = ttnn.add(input_tensor, hidden_states, memory_config=hidden_states.memory_config()) ttnn.deallocate(hidden_states) output_tensor = ttnn.reallocate(output_tensor) diff --git a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_transformer_2d.py b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_transformer_2d.py index 0173a1cece8..ad0ceaff306 100644 --- a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_transformer_2d.py +++ b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_transformer_2d.py @@ -16,6 +16,7 @@ pre_process_input, pad_group_norm_weight, permute_conv_parameters, + dealloc_input, ) @@ -225,8 +226,6 @@ def __call__( height = self.input_height width = self.input_width - if ttnn.get_memory_config(hidden_states) != self.proj_in.conv.input_sharded_memory_config: - hidden_states = ttnn.to_memory_config(hidden_states, self.proj_in.conv.input_sharded_memory_config) residual = hidden_states spilled_residual = False if spilled_residual: @@ -317,18 +316,22 @@ def __call__( assert False # hidden_states = ttnn.to_memory_config(hidden_states, self.proj_out.conv.input_sharded_memory_config) hidden_states = self.proj_out(hidden_states) - if spilled_residual: + if ttnn.get_memory_config(residual) != self.proj_out.conv.input_sharded_memory_config: residual = ttnn.to_memory_config(residual, self.proj_out.conv.input_sharded_memory_config) if output_bfloat16: - hidden_states = ttnn.add( + hidden_states = dealloc_input( + ttnn.add, hidden_states, residual, dtype=ttnn.bfloat16, + memory_config=hidden_states.memory_config(), ) else: - hidden_states = ttnn.add( + hidden_states = dealloc_input( + ttnn.add, hidden_states, residual, + memory_config=hidden_states.memory_config(), ) else: hidden_states = ttnn.to_device(hidden_states, self.device) @@ -347,4 +350,5 @@ def __call__( if not return_dict: return (hidden_states,) + hidden_states = ttnn.reallocate(hidden_states) return hidden_states diff --git a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_upsample_nearest_2d.py b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_upsample_nearest_2d.py index 6146401f135..07b80d82528 100644 --- a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_upsample_nearest_2d.py +++ 
b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_upsample_nearest_2d.py @@ -12,5 +12,7 @@ def upsample_nearest2d(input, scale_factor=2.0): # input is in N, 1, HW, C, upsample expects, [N, H, W, C] # set h_scale to 1, w_scale to scale_factor, c_scale to 1 # scale_factor = (1, scale_factor*2, 1) + if input.is_sharded(): + input = ttnn.to_memory_config(input, ttnn.L1_MEMORY_CONFIG) up_output = ttnn.upsample(input, scale_factor) return up_output diff --git a/tt_eager/tt_dnn/op_library/move/move_op.hpp b/tt_eager/tt_dnn/op_library/move/move_op.hpp index ca60eae31d8..7e3fd99f626 100644 --- a/tt_eager/tt_dnn/op_library/move/move_op.hpp +++ b/tt_eager/tt_dnn/op_library/move/move_op.hpp @@ -138,7 +138,8 @@ inline Tensor move_sharded(Tensor& input_tensor, std::optional& me shard_mem_config.shard_spec = shard_spec; auto output_tensor = create_sharded_device_tensor(input_shape, input_dtype, input_layout, input_tensor.device(), shard_mem_config); if (input_tensor.buffer()->address() == output_tensor.buffer()->address()) { - TT_FATAL(false, "No space to move the tensor. Move op's input address == output address. No-op move unsupported."); + tt::log_debug(tt::LogOp, "WARNING: No space to move the tensor. Move op's input address and output address are equal: {}", input_address); + return output_tensor; } MoveOpParallelizationStrategy move_op_parallelization_strategy = MoveOpParallelizationStrategy::MULTI_CORE_SHARDED; auto output = operation::run(Move{output_mem_config, move_op_parallelization_strategy}, {input_tensor, output_tensor}).at(0); diff --git a/ttnn/ttnn/experimental/golden_functions.py b/ttnn/ttnn/experimental/golden_functions.py index 2d0faf50c68..c09b4661bf8 100644 --- a/ttnn/ttnn/experimental/golden_functions.py +++ b/ttnn/ttnn/experimental/golden_functions.py @@ -67,6 +67,7 @@ def _golden_function(tensor, grid_size, shard_spec, num_slices, slice, *args, ** def _nop_golden_function(input_tensor, *args, **kwargs): return input_tensor + ttnn.experimental.tensor.sharded_to_interleaved.golden_function = _nop_golden_function ttnn.experimental.tensor.interleaved_to_sharded.golden_function = _nop_golden_function ttnn.experimental.tensor.reshard.golden_function = _nop_golden_function ttnn.experimental.tensor.tilize.golden_function = _nop_golden_function
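
The op_profiler.hpp hunk in the first commit replaces fixed W/Z/Y/X indexing with rank-aware indexing, so tensors with rank < 4 no longer read out-of-range dimensions; any missing leading dimensions are reported as 1 and X always comes from the last dimension. Below is a minimal Python sketch of that fallback logic for illustration only; the `report_shape` helper is hypothetical and not part of the patch, which implements the same behavior in C++ inside op_profiler.hpp.

    def report_shape(shape):
        # Mirror the rank-aware fallback: W/Z/Y default to 1 when the
        # tensor has fewer than 4 dimensions; X is always the last dim.
        rank = len(shape)
        return {
            "W": shape[-4] if rank >= 4 else 1,
            "Z": shape[-3] if rank >= 3 else 1,
            "Y": shape[-2] if rank >= 2 else 1,
            "X": shape[-1],
        }

    # e.g. report_shape((32, 64)) -> {"W": 1, "Z": 1, "Y": 32, "X": 64}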