From a7c183126af279b37b63fdde6627ec553cbc4e4e Mon Sep 17 00:00:00 2001
From: Nitika Shanker
Date: Tue, 6 Feb 2024 02:27:01 +0000
Subject: [PATCH] #5153: TTNN support for 1x1 conv as matmul. Enabled TTNN convs in SD Unet. Fixes.

---
 models/demos/resnet/tt/metalResnetBlock50.py  |   2 +
 ...ttnn_functional_basic_transformer_block.py |   4 +-
 ...unctional_cross_attention_down_block_2d.py |   6 +-
 .../tt/ttnn_functional_cross_attn_upblock.py  |  14 +-
 .../tt/ttnn_functional_downblock_2d.py        |  10 +-
 .../tt/ttnn_functional_downsample_2d.py       | 112 ++++--
 .../tt/ttnn_functional_feedforward.py         |   4 +-
 .../tt/ttnn_functional_geglu.py               |   3 +-
 .../tt/ttnn_functional_resnetblock2d.py       | 329 ++++++++++++++----
 .../tt/ttnn_functional_transformer_2d.py      | 182 +++++++---
 ...ttnn_functional_unet_2d_condition_model.py | 158 +++++++--
 ...functional_unet_mid_block_2d_cross_attn.py |   4 +
 .../tt/ttnn_functional_upblock_2d.py          |  11 +-
 .../tt/ttnn_functional_upsample_2d.py         | 107 ++++--
 .../tt/ttnn_functional_upsample_nearest_2d.py |   3 -
 .../tt/ttnn_functional_utility_functions.py   |  75 ++++
 .../unit_testing/test_optimized_conv_v2.py    |   1 +
 ...resnet50_untilize_with_halo_and_conv_v2.py |   1 +
 .../stable_diffusion/test_resnet_block_2d.py  |  12 +-
 .../test_unet_2d_condition_model.py           |   5 +-
 .../ttnn/unit_tests/operations/test_conv2d.py |   4 +-
 .../tt_py_composite_conv.py                   | 133 +++++--
 .../tt_lib/fallback_ops/conversion_wrapper.py |  22 +-
 tt_eager/tt_lib/fallback_ops/fallback_ops.py  |  71 +++-
 .../impl/allocator/algorithms/free_list.cpp   |   2 +-
 ttnn/cpp/pybind11/operations/binary.hpp       |   5 +-
 ttnn/cpp/ttnn/operations/binary.hpp           |  10 +-
 ttnn/ttnn/__init__.py                         |   4 +-
 ttnn/ttnn/device.py                           |   4 +
 ttnn/ttnn/operations/binary.py                |   3 +-
 ttnn/ttnn/operations/core.py                  |  48 ++-
 ttnn/ttnn/operations/unary.py                 |   2 -
 32 files changed, 1071 insertions(+), 280 deletions(-)
 create mode 100644 models/experimental/functional_stable_diffusion/tt/ttnn_functional_utility_functions.py

diff --git a/models/demos/resnet/tt/metalResnetBlock50.py b/models/demos/resnet/tt/metalResnetBlock50.py
index ef18556e730d..e3cbb32ab9e3 100644
--- a/models/demos/resnet/tt/metalResnetBlock50.py
+++ b/models/demos/resnet/tt/metalResnetBlock50.py
@@ -1191,6 +1191,7 @@ def downsample_conv_op_with_formatting(x):
                 output_dtype=model_config["ACTIVATIONS_DTYPE"],
                 math_fidelity=model_config["MATH_FIDELITY"],
                 move_utwh_output=move_utwh_output,
+                deallocate_activation=True,
             )
         else:
             self.conv2 = resnet50_optimized_conv(
@@ -1500,6 +1501,7 @@ def __init__(
                 output_dtype=model_config["ACTIVATIONS_DTYPE"],
                 math_fidelity=model_config["MATH_FIDELITY"],
                 use_shallow_conv_variant=True,
+                deallocate_activation=True,
             )
             self.first_conv_op_params = sliding_window_op_params
         else:
diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_basic_transformer_block.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_basic_transformer_block.py
index c83589a5fe1d..60c8872be8a4 100644
--- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_basic_transformer_block.py
+++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_basic_transformer_block.py
@@ -88,8 +88,8 @@ def basic_transformer_block(
         )
     if use_ada_layer_norm_zero:
         assert False, "AdaLayerNormZero not supported and not used in stable diffusion"
-
-    ff_output = feedforward(config=config, hidden_states=norm_hidden_states, parameters=parameters.ff)
+    # ttnn.dump_device_memory_state(device)
+    ff_output = feedforward(config=config, hidden_states=norm_hidden_states, parameters=parameters.ff,
device=device) hidden_states = ttnn.add(ff_output, hidden_states) diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_cross_attention_down_block_2d.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_cross_attention_down_block_2d.py index f719d26ced4d..f7fd2996097e 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_cross_attention_down_block_2d.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_cross_attention_down_block_2d.py @@ -39,6 +39,7 @@ def cross_attention_down_block_2d( *, parameters, device, + reader_patterns_cache=None, ): output_states = () @@ -54,6 +55,7 @@ def cross_attention_down_block_2d( parameters=resnet, device=device, use_in_shortcut=use_in_shortcut, + reader_patterns_cache=reader_patterns_cache, eps=resnet_eps, groups=resnet_groups, time_embedding_norm=resnet_time_scale_shift, @@ -76,7 +78,8 @@ def cross_attention_down_block_2d( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, device=device, - ) + reader_patterns_cache=reader_patterns_cache, + ) output_states += (hidden_states,) @@ -89,6 +92,7 @@ def cross_attention_down_block_2d( device=device, parameters=parameters.downsamplers[0], use_conv=True, + reader_patterns_cache=reader_patterns_cache, ) output_states += (hidden_states,) return hidden_states, output_states diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_cross_attn_upblock.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_cross_attn_upblock.py index eaad766ff2dc..7334014f7720 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_cross_attn_upblock.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_cross_attn_upblock.py @@ -4,7 +4,7 @@ import torch import ttnn - +from typing import Optional, Dict from models.experimental.functional_stable_diffusion.tt.ttnn_functional_upsample_2d import upsample2d from models.experimental.functional_stable_diffusion.tt.ttnn_functional_resnetblock2d import resnetBlock2D from models.experimental.functional_stable_diffusion.tt.ttnn_functional_transformer_2d import transformer_2d_model @@ -57,6 +57,7 @@ def cross_attention_upblock2d( cross_attention_dim=1280, attn_num_head_channels=1, only_cross_attention: bool = False, + reader_patterns_cache: Optional[Dict] = None, ): for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels @@ -88,6 +89,7 @@ def cross_attention_upblock2d( pre_norm=resnet_pre_norm, non_linearity=resnet_act_fn, device=device, + reader_patterns_cache=reader_patterns_cache, ) if not dual_cross_attention: hidden_states = transformer_2d_model( @@ -110,11 +112,19 @@ def cross_attention_upblock2d( device=device, upcast_attention=upcast_attention, cross_attention_dim=cross_attention_dim, + reader_patterns_cache=reader_patterns_cache, ) else: assert False, "We do not support Dual Transformer2DModel" if add_upsample: - hidden_states = upsample2d(device, hidden_states, parameters.upsamplers[0], out_channels, out_channels) + hidden_states = upsample2d( + device, + hidden_states, + parameters.upsamplers[0], + out_channels, + out_channels, + reader_patterns_cache=reader_patterns_cache, + ) return hidden_states diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_downblock_2d.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_downblock_2d.py index a99b9724cdc0..ada5ae4091ec 100644 --- 
a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_downblock_2d.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_downblock_2d.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 - +import ttnn +import torch +from typing import Optional from models.experimental.functional_stable_diffusion.tt.ttnn_functional_resnetblock2d import resnetBlock2D from models.experimental.functional_stable_diffusion.tt.ttnn_functional_downsample_2d import downsample_2d @@ -25,6 +27,8 @@ def downblock2d( add_downsample=False, downsample_padding=1, parameters=None, + reader_patterns_cache: Optional[dict] = None, + dtype: Optional[ttnn.DataType] = None, ): output_states = () for i in range(num_layers): @@ -45,6 +49,8 @@ def downblock2d( eps=resnet_eps, up=False, down=False, + dtype=dtype, + reader_patterns_cache=reader_patterns_cache, ) hidden_states = resnet @@ -61,6 +67,8 @@ def downblock2d( name="op", parameters=parameters.downsamplers[0], device=device, + dtype=dtype, + reader_patterns_cache=reader_patterns_cache, ) ] diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_downsample_2d.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_downsample_2d.py index 2dbbf5f9da80..288b51578059 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_downsample_2d.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_downsample_2d.py @@ -5,9 +5,14 @@ import ttnn import torch +from typing import Optional import torch.nn as nn from tt_lib.fallback_ops import fallback_ops from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor +from models.experimental.functional_stable_diffusion.tt.ttnn_functional_utility_functions import ( + run_ttnn_conv_with_pre_and_post_tensor_formatting, +) +import math def permute_conv_parameters(weight, bias): @@ -19,6 +24,21 @@ def permute_conv_parameters(weight, bias): return weight, bias +config_override = { + (320, 320, 64, 64): {"act_block_h": 64}, + (640, 640, 32, 32): {"act_block_h": 64}, + (640, 1920, 32, 32): {"act_block_h": 32}, + (640, 1280, 32, 32): {"act_block_h": 32}, + (1280, 1920, 16, 16): {"act_block_h": 32}, + (1280, 1280, 32, 32): {"act_block_h": 32}, + (320, 960, 64, 64): {"act_block_h": 32}, + (640, 960, 32, 32): {"act_block_h": 32}, + (320, 640, 64, 64): {"act_block_h": 32}, + (640, 640, 64, 64): {"act_block_h": 64}, + (640, 320, 64, 64): {"act_block_h": 64}, +} + + def downsample_2d( in_channels, hidden_states, @@ -27,23 +47,61 @@ def downsample_2d( use_conv=False, out_channels=None, padding=1, + reader_patterns_cache: Optional[dict] = None, + dtype: Optional[ttnn.DataType] = None, ): stride = 2 parameters.conv.weight, parameters.conv.bias = permute_conv_parameters(parameters.conv.weight, parameters.conv.bias) - parameters.conv.weight = torch_to_tt_tensor_rm(parameters.conv.weight, device, put_on_device=False) - parameters.conv.bias = torch_to_tt_tensor_rm(parameters.conv.bias, device, put_on_device=False) + conv_on_device = reader_patterns_cache is not None + batch_size = hidden_states.shape[0] + input_height = hidden_states.shape[2] + input_width = hidden_states.shape[3] if use_conv: - conv = fallback_ops.Conv2d( - parameters.conv.weight, - parameters.conv.bias, - in_channels, - out_channels, - kernel_size=3, - stride=stride, - padding=padding, - ) + if conv_on_device: + parameters.conv.bias = torch.reshape(parameters.conv.bias, (1, 1, 1, parameters.conv.bias.shape[-1])) + tt_weight_tensor = 
ttnn.from_torch(parameters.conv.weight, ttnn.float32) + tt_bias_tensor = ttnn.from_torch(parameters.conv.bias, ttnn.float32) + # breakpoint() + out_channels = parameters.conv.weight.shape[0] + in_channels = parameters.conv.weight.shape[1] + conv_config_override = {} + if (out_channels, in_channels, input_height, input_width) in config_override: + conv_config_override = config_override[(out_channels, in_channels, input_height, input_width)] + conv = ttnn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(stride, stride), + padding=(padding, padding), + dtype=ttnn.bfloat8_b, + device=device, + use_1d_systolic_array=True if in_channels < 320 else False, + batch_size=batch_size, + input_height=input_height, + input_width=input_width, + reader_patterns_cache=reader_patterns_cache, + weight=tt_weight_tensor, + bias=tt_bias_tensor, + math_fidelity=ttnn.MathFidelity.LoFi, + weights_dtype=ttnn.bfloat8_b, + conv_blocking_and_parallelization_config_override=conv_config_override, + use_shallow_conv_variant=False, + enable_auto_formatting=True, + ) + else: + parameters.conv.weight = torch_to_tt_tensor_rm(parameters.conv.weight, device, put_on_device=False) + parameters.conv.bias = torch_to_tt_tensor_rm(parameters.conv.bias, device, put_on_device=False) + conv = fallback_ops.Conv2d( + parameters.conv.weight, + parameters.conv.bias, + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=padding, + ) else: assert in_channels == out_channels @@ -58,16 +116,26 @@ def downsample_2d( assert hidden_states.shape[1] == in_channels - hidden_states = ttnn.to_torch(ttnn.from_device(ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT))) - - hidden_states = torch_to_tt_tensor_rm(hidden_states, device) - hidden_states = conv(hidden_states) - hidden_states = tt_to_torch_tensor(hidden_states) - - hidden_states = ttnn.to_device( - ttnn.to_layout(ttnn.from_torch(hidden_states, ttnn.bfloat16), ttnn.TILE_LAYOUT), - device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) + if conv_on_device: + hidden_states = run_ttnn_conv_with_pre_and_post_tensor_formatting( + device, + conv, + hidden_states, + batch_size, + math.ceil(input_height / 2), + math.ceil(input_width / 2), + out_channels, + ) + else: + hidden_states = ttnn.to_torch(ttnn.from_device(ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT))) + hidden_states = torch_to_tt_tensor_rm(hidden_states, device) + hidden_states = conv(hidden_states) + hidden_states = tt_to_torch_tensor(hidden_states) + + hidden_states = ttnn.to_device( + ttnn.to_layout(ttnn.from_torch(hidden_states, ttnn.bfloat16), ttnn.TILE_LAYOUT), + device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) return hidden_states diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_feedforward.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_feedforward.py index b414f1195b98..a11acca6d9c6 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_feedforward.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_feedforward.py @@ -6,8 +6,8 @@ from models.experimental.functional_stable_diffusion.tt.ttnn_functional_geglu import geglu -def feedforward(config, hidden_states, parameters): - act = geglu(config, hidden_states, parameters.net[0]) +def feedforward(config, hidden_states, parameters, device): + act = geglu(config, hidden_states, parameters.net[0], device) output = act @ parameters.net[2].weight output = ttnn.add(output, parameters.net[2].bias, 
memory_config=ttnn.L1_MEMORY_CONFIG) return output diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_geglu.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_geglu.py index 5cf00f5cff88..00a50edfaed7 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_geglu.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_geglu.py @@ -5,8 +5,9 @@ import ttnn -def geglu(config, hidden_states, parameters): +def geglu(config, hidden_states, parameters, device): output = ttnn.matmul(hidden_states, parameters.proj.weight) + output = ttnn.add(output, parameters.proj.bias, memory_config=ttnn.L1_MEMORY_CONFIG) hidden_states, gate = ttnn.split(output, split_size=output.shape[-1] // 2, dim=-1) diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_resnetblock2d.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_resnetblock2d.py index a2bd08b2adf2..ba11707eaefc 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_resnetblock2d.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_resnetblock2d.py @@ -9,7 +9,12 @@ torch_to_tt_tensor_rm, ) import torch -from typing import Optional +from typing import Optional, Dict +from models.experimental.functional_stable_diffusion.tt.ttnn_functional_utility_functions import ( + run_ttnn_conv_with_pre_and_post_tensor_formatting, + pre_process_input, + post_process_output, +) def torch_to_ttnn(input, device, layout=ttnn.TILE_LAYOUT): @@ -35,6 +40,31 @@ def permute_conv_weights(weight, bias): return weight, bias +config_override = { + (320, 320, 64, 64): {"act_block_h": 64}, + (640, 640, 32, 32): {"act_block_h": 64}, + (640, 1920, 32, 32): {"act_block_h": 32}, + (640, 1280, 32, 32): {"act_block_h": 32}, + (1280, 1920, 16, 16): {"act_block_h": 32}, + (1280, 1280, 32, 32): {"act_block_h": 32}, + (320, 960, 64, 64): {"act_block_h": 32}, + (640, 960, 32, 32): {"act_block_h": 32}, + (320, 640, 64, 64): {"act_block_h": 32}, + (640, 320, 64, 64): {"act_block_h": 64}, + (640, 640, 64, 64): {"act_block_h": 32}, +} + +split_chunks = { + (320, 960, 64, 64): 2, + (640, 1920, 32, 32): 3, + (640, 1280, 32, 32): 2, + (640, 960, 32, 32): 2, + (1280, 1920, 16, 16): 3, + (1280, 2560, 8, 8): 2, + (1280, 2560, 16, 16): 2, +} + + def resnetBlock2D( input_tensor, *, @@ -53,14 +83,19 @@ def resnetBlock2D( up=False, down=False, use_in_shortcut: Optional[bool] = None, + reader_patterns_cache=None, + dtype: Optional[ttnn.DataType] = None, ): + convs_on_device = reader_patterns_cache is not None if non_linearity == "mish": assert False, "Mish is not implemented!" 
else: nonlinearity = ttnn.silu out_channels = in_channels if out_channels is None else out_channels + hidden_states = input_tensor + hidden_states = ttnn.group_norm( hidden_states, num_groups=groups, weight=parameters.norm1.weight, bias=parameters.norm1.bias, epsilon=eps ) @@ -74,24 +109,105 @@ def resnetBlock2D( parameters.conv1.weight, parameters.conv1.bias = permute_conv_weights( parameters.conv1.weight, parameters.conv1.bias ) - parameters.conv1.weight = torch_to_tt_tensor_rm(parameters.conv1.weight, device, put_on_device=False) - parameters.conv1.bias = torch_to_tt_tensor_rm(parameters.conv1.bias, device, put_on_device=False) - - # Using fallback Conv2D as we face issue with ttnn.Conv2D - conv1 = fallback_ops.Conv2d( - parameters.conv1.weight, - parameters.conv1.bias, - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - ) - hidden_states = ttnn_to_torch(hidden_states) - hidden_states = torch_to_tt_tensor_rm(hidden_states, device) - hidden_states = conv1(hidden_states) - hidden_states = tt_to_torch_tensor(hidden_states) - hidden_states = torch_to_ttnn(hidden_states, device=device) + if out_channels != parameters.conv1.bias.shape[-1]: + # breakpoint() + out_channels = parameters.conv1.bias.shape[-1] + if in_channels != parameters.conv1.weight.shape[1]: + # breakpoint() + in_channels = parameters.conv1.weight.shape[1] + if convs_on_device: + batch_size = hidden_states.shape[0] + input_height = hidden_states.shape[2] + input_width = hidden_states.shape[3] + parameters.conv1.bias = torch.reshape(parameters.conv1.bias, (1, 1, 1, out_channels)) + conv1_split_chunks = 1 + if (out_channels, in_channels, input_height, input_width) in split_chunks: + conv1_split_chunks = split_chunks[(out_channels, in_channels, input_height, input_width)] + split_input_channels = in_channels // conv1_split_chunks + if conv1_split_chunks == 1: + split_weight_tensors = [parameters.conv1.weight] + else: + split_weight_tensors = torch.split(parameters.conv1.weight, split_input_channels, 1) + conv1s = [] + for i in range(conv1_split_chunks): + tt_weight_tensor = ttnn.from_torch(split_weight_tensors[i], ttnn.float32) + if i == 0: + tt_bias_tensor = ttnn.from_torch(parameters.conv1.bias, ttnn.float32) + else: + # TODO: fix no bias in conv error + torch_bias_zeros_tensor = torch.zeros(parameters.conv1.bias.shape, dtype=torch.bfloat16).float() + tt_bias_tensor = ttnn.from_torch(torch_bias_zeros_tensor, ttnn.float32) + conv1_config_override = {} + if (out_channels, in_channels, input_height, input_width) in config_override: + conv1_config_override = config_override[(out_channels, in_channels, input_height, input_width)] + conv1s.append( + ttnn.Conv2d( + split_input_channels, + out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + dtype=ttnn.bfloat8_b, + device=device, + use_1d_systolic_array=False, + batch_size=batch_size, + input_height=input_height, + input_width=input_width, + reader_patterns_cache=reader_patterns_cache, + weight=tt_weight_tensor, + bias=tt_bias_tensor, + math_fidelity=ttnn.MathFidelity.LoFi, + weights_dtype=ttnn.bfloat8_b, + conv_blocking_and_parallelization_config_override=conv1_config_override, + use_shallow_conv_variant=False, + enable_auto_formatting=True, + ) + ) + # breakpoint() + hidden_states = pre_process_input(device, hidden_states) + if conv1_split_chunks == 1: + hidden_states = [hidden_states] + else: + split_hidden_states = [] + output_tensor_start_width_dim = 0 + output_tensor_end_width_dim = split_input_channels - 1 + for i in 
range(conv1_split_chunks): + split_hidden_states.append( + hidden_states[:, :, :, output_tensor_start_width_dim:output_tensor_end_width_dim] + ) + output_tensor_start_width_dim += split_input_channels + output_tensor_end_width_dim += split_input_channels + hidden_states = split_hidden_states + if conv1_split_chunks == 1: + hidden_states = conv1s[0](hidden_states[0]) + else: + for i in range(conv1_split_chunks): + hidden_states[i] = conv1s[i](hidden_states[i]) + if i != 0: + hidden_states[i] = ttnn.add(hidden_states[i], hidden_states[i - 1]) + ttnn.deallocate(hidden_states[i - 1]) + hidden_states = hidden_states[-1] + + split_hidden_states = [] + else: + parameters.conv1.weight = torch_to_tt_tensor_rm(parameters.conv1.weight, device, put_on_device=False) + parameters.conv1.bias = torch_to_tt_tensor_rm(parameters.conv1.bias, device, put_on_device=False) + # Using fallback Conv2d as we face issue with ttnn.Conv2d + # assert out_channels == parameters.conv1.bias.shape()[-1] + conv1 = fallback_ops.Conv2d( + parameters.conv1.weight, + parameters.conv1.bias, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + hidden_states = ttnn_to_torch(hidden_states) + hidden_states = torch_to_tt_tensor_rm(hidden_states, device) + hidden_states = conv1(hidden_states) + hidden_states = tt_to_torch_tensor(hidden_states) + hidden_states = torch_to_ttnn(hidden_states, device=device) if temb is not None: temb = nonlinearity(temb) @@ -103,66 +219,157 @@ def resnetBlock2D( else: raise ValueError(f"unknown time_embedding_norm : {time_embedding_norm} ") # temb=ttnn.linear(temb,parameters.time_emb_proj.weight,bias=parameters.time_emb_proj.bias) + # breakpoint() temb = ttnn.matmul(temb, parameters.time_emb_proj.weight) temb = ttnn.add(temb, parameters.time_emb_proj.bias) - - temb = ttnn.permute(temb, (2, 3, 0, 1)) + if not convs_on_device: + temb = ttnn.permute(temb, (2, 3, 0, 1)) + else: + # breakpoint() + temb = ttnn.permute(temb, (2, 0, 1, 3)) if temb is not None and time_embedding_norm == "default": + if convs_on_device: + # breakpoint() + hidden_states = ttnn.clone( + hidden_states, memory_config=ttnn.get_memory_config(hidden_states), dtype=ttnn.bfloat16 + ) + hidden_states = ttnn.reshape(hidden_states, (batch_size, 1, input_height * input_width, out_channels)) + temb = ttnn.reshape(temb, (batch_size, 1, 1, out_channels)) + # breakpoint() hidden_states = hidden_states + temb - + if convs_on_device: + hidden_states = post_process_output(device, hidden_states, batch_size, input_height, input_width, out_channels) hidden_states = ttnn.group_norm( hidden_states, num_groups=groups, weight=parameters.norm2.weight, bias=parameters.norm2.bias, epsilon=eps ) hidden_states = nonlinearity(hidden_states) - parameters.conv2.weight, parameters.conv2.bias = permute_conv_weights( parameters.conv2.weight, parameters.conv2.bias ) - parameters.conv2.weight = torch_to_tt_tensor_rm(parameters.conv2.weight, device, put_on_device=False) - parameters.conv2.bias = torch_to_tt_tensor_rm(parameters.conv2.bias, device, put_on_device=False) - # Using fallback Conv2D as we face issue with ttnn.Conv2D - conv2 = fallback_ops.Conv2d( - parameters.conv2.weight, - parameters.conv2.bias, - out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - ) - hidden_states = ttnn_to_torch(hidden_states) - hidden_states = torch_to_tt_tensor_rm(hidden_states, device) - hidden_states = conv2(hidden_states) - hidden_states = tt_to_torch_tensor(hidden_states) - hidden_states = torch_to_ttnn(hidden_states, device=device) + 
if convs_on_device: + batch_size = hidden_states.shape[0] + input_height = hidden_states.shape[2] + input_width = hidden_states.shape[3] + parameters.conv2.bias = torch.reshape(parameters.conv2.bias, (1, 1, 1, out_channels)) + # print("conv2 weight shape=", parameters.conv2.weight.shape) + # print("conv2 bias shape=", parameters.conv2.bias.shape) + tt_weight_tensor = ttnn.from_torch(parameters.conv2.weight, ttnn.float32) + tt_bias_tensor = ttnn.from_torch(parameters.conv2.bias, ttnn.float32) + conv2_config_override = {} + if (out_channels, out_channels, input_height, input_width) in config_override: + conv2_config_override = config_override[(out_channels, out_channels, input_height, input_width)] + assert out_channels == parameters.conv2.weight.shape[0] + assert out_channels == parameters.conv2.weight.shape[1] + conv2 = ttnn.Conv2d( + out_channels, + out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + dtype=ttnn.bfloat8_b, + device=device, + use_1d_systolic_array=False, + batch_size=batch_size, + input_height=input_height, + input_width=input_width, + reader_patterns_cache=reader_patterns_cache, + weight=tt_weight_tensor, + bias=tt_bias_tensor, + math_fidelity=ttnn.MathFidelity.LoFi, + weights_dtype=ttnn.bfloat8_b, + conv_blocking_and_parallelization_config_override=conv2_config_override, + use_shallow_conv_variant=False, + enable_auto_formatting=True, + deallocate_activation=True, + # reallocate_halo_output=(out_channels, out_channels, input_height, input_width) == (640, 640, 64, 64) + ) + + hidden_states = run_ttnn_conv_with_pre_and_post_tensor_formatting( + device, conv2, hidden_states, batch_size, input_height, input_width, out_channels + ) + else: + parameters.conv2.weight = torch_to_tt_tensor_rm(parameters.conv2.weight, device, put_on_device=False) + parameters.conv2.bias = torch_to_tt_tensor_rm(parameters.conv2.bias, device, put_on_device=False) + # Using fallback Conv2d as we face issue with ttnn.Conv2d + conv2 = fallback_ops.Conv2d( + parameters.conv2.weight, + parameters.conv2.bias, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + hidden_states = ttnn_to_torch(hidden_states) + hidden_states = torch_to_tt_tensor_rm(hidden_states, device) + hidden_states = conv2(hidden_states) + hidden_states = tt_to_torch_tensor(hidden_states) + hidden_states = torch_to_ttnn(hidden_states, device=device) use_in_shortcut = in_channels != out_channels if use_in_shortcut is None else use_in_shortcut if use_in_shortcut: parameters.conv_shortcut.weight, parameters.conv_shortcut.bias = permute_conv_weights( parameters.conv_shortcut.weight, parameters.conv_shortcut.bias ) - parameters.conv_shortcut.weight = torch_to_tt_tensor_rm( - parameters.conv_shortcut.weight, device, put_on_device=False - ) - parameters.conv_shortcut.bias = torch_to_tt_tensor_rm( - parameters.conv_shortcut.bias, device, put_on_device=False - ) - # Using fallback Conv2D as we face issue with ttnn.Conv2D - conv_shortcut = fallback_ops.Conv2d( - parameters.conv_shortcut.weight, - parameters.conv_shortcut.bias, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0, - ) - input_tensor = ttnn_to_torch(input_tensor) - input_tensor = torch_to_tt_tensor_rm(input_tensor, device) - input_tensor = conv_shortcut(input_tensor) - input_tensor = tt_to_torch_tensor(input_tensor) - input_tensor = torch_to_ttnn(input_tensor, device=device) + if convs_on_device: + batch_size = input_tensor.shape[0] + input_height = input_tensor.shape[2] + input_width = input_tensor.shape[3] + 
parameters.conv_shortcut.bias = torch.reshape(parameters.conv_shortcut.bias, (1, 1, 1, out_channels)) + tt_weight_tensor = ttnn.from_torch(parameters.conv_shortcut.weight, ttnn.float32) + tt_bias_tensor = ttnn.from_torch(parameters.conv_shortcut.bias, ttnn.float32) + conv_shortcut_config_override = {} + # if (out_channels, in_channels, input_height, input_width) in config_override: + # conv2_config_override = config_override[(out_channels, in_channels, input_height, input_width)] + assert in_channels == parameters.conv_shortcut.weight.shape[1] + assert out_channels == parameters.conv_shortcut.weight.shape[0] + conv_shortcut = ttnn.Conv2d( + in_channels, + out_channels, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + dtype=ttnn.bfloat8_b, + device=device, + use_1d_systolic_array=False, + batch_size=batch_size, + input_height=input_height, + input_width=input_width, + reader_patterns_cache=reader_patterns_cache, + weight=tt_weight_tensor, + bias=tt_bias_tensor, + math_fidelity=ttnn.MathFidelity.LoFi, + weights_dtype=ttnn.bfloat8_b, + conv_blocking_and_parallelization_config_override=conv_shortcut_config_override, + use_shallow_conv_variant=False, + enable_auto_formatting=True, + ) + input_tensor = run_ttnn_conv_with_pre_and_post_tensor_formatting( + device, conv_shortcut, input_tensor, batch_size, input_height, input_width, out_channels + ) + else: + parameters.conv_shortcut.weight = torch_to_tt_tensor_rm( + parameters.conv_shortcut.weight, device, put_on_device=False + ) + parameters.conv_shortcut.bias = torch_to_tt_tensor_rm( + parameters.conv_shortcut.bias, device, put_on_device=False + ) + # Using fallback Conv2d as we face issue with ttnn.Conv2d + conv_shortcut = fallback_ops.Conv2d( + parameters.conv_shortcut.weight, + parameters.conv_shortcut.bias, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + input_tensor = ttnn_to_torch(input_tensor) + input_tensor = torch_to_tt_tensor_rm(input_tensor, device) + input_tensor = conv_shortcut(input_tensor) + input_tensor = tt_to_torch_tensor(input_tensor) + input_tensor = torch_to_ttnn(input_tensor, device=device) output_sc_recip = 1 / output_scale_factor output_tensor = ttnn.add(input_tensor, hidden_states) diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_transformer_2d.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_transformer_2d.py index 244de65612f2..b65b066dac6e 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_transformer_2d.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_transformer_2d.py @@ -4,12 +4,17 @@ import ttnn import torch +from typing import Optional, Dict from tt_lib.fallback_ops import fallback_ops from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor from models.experimental.functional_stable_diffusion.tt.ttnn_functional_basic_transformer_block import ( basic_transformer_block, ) +from models.experimental.functional_stable_diffusion.tt.ttnn_functional_utility_functions import ( + run_ttnn_conv_with_pre_and_post_tensor_formatting, + post_process_output, +) def permute_conv_parameters(weight, bias): @@ -48,7 +53,9 @@ def transformer_2d_model( eps=1e-5, device=None, norm_elementwise_affine: bool = True, + reader_patterns_cache: Optional[Dict] = None, ): + conv_on_device = reader_patterns_cache is not None inner_dim = num_attention_heads * attention_head_dim is_input_continuous = (in_channels is not None) and (patch_size is None) @@ -87,31 +94,69 @@ def 
transformer_2d_model( ) if not use_linear_projection: - hidden_states = ttnn.to_torch(ttnn.from_device(ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT))) - hidden_states = torch_to_tt_tensor_rm(hidden_states, device) - parameters.proj_in.weight, parameters.proj_in.bias = permute_conv_parameters( parameters.proj_in.weight, parameters.proj_in.bias ) - parameters.proj_in.weight = torch_to_tt_tensor_rm(parameters.proj_in.weight, device, put_on_device=False) - parameters.proj_in.bias = torch_to_tt_tensor_rm(parameters.proj_in.bias, device, put_on_device=False) - - proj_in = fallback_ops.Conv2d( - weights=parameters.proj_in.weight, - biases=parameters.proj_in.bias, - in_channels=in_channels, - out_channels=inner_dim, - kernel_size=1, - stride=1, - padding=0, - ) - hidden_states = proj_in(hidden_states) - - hidden_states = tt_to_torch_tensor(hidden_states) - hidden_states = ttnn.to_layout( - ttnn.to_device(ttnn.from_torch(hidden_states, dtype=ttnn.bfloat16), device), - layout=ttnn.TILE_LAYOUT, - ) + if conv_on_device: + batch_size = hidden_states.shape[0] + input_height = hidden_states.shape[2] + input_width = hidden_states.shape[3] + parameters.proj_in.bias = torch.reshape( + parameters.proj_in.bias, (1, 1, 1, parameters.proj_in.bias.shape[-1]) + ) + tt_weight_tensor = ttnn.from_torch(parameters.proj_in.weight, ttnn.float32) + tt_bias_tensor = ttnn.from_torch(parameters.proj_in.bias, ttnn.float32) + out_channels = parameters.proj_in.weight.shape[0] + in_channels = parameters.proj_in.weight.shape[1] + proj_in = ttnn.Conv2d( + in_channels, + out_channels, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + dtype=ttnn.bfloat8_b, + device=device, + use_1d_systolic_array=False, + batch_size=batch_size, + input_height=input_height, + input_width=input_width, + reader_patterns_cache=reader_patterns_cache, + weight=tt_weight_tensor, + bias=tt_bias_tensor, + math_fidelity=ttnn.MathFidelity.LoFi, + weights_dtype=ttnn.bfloat8_b, + conv_blocking_and_parallelization_config_override={}, + use_shallow_conv_variant=False, + enable_auto_formatting=True, + ) + hidden_states = run_ttnn_conv_with_pre_and_post_tensor_formatting( + device, proj_in, hidden_states, batch_size, input_height, input_width, out_channels + ) + else: + hidden_states = ttnn.to_torch(ttnn.from_device(ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT))) + hidden_states = torch_to_tt_tensor_rm(hidden_states, device) + parameters.proj_in.weight = torch_to_tt_tensor_rm( + parameters.proj_in.weight, device, put_on_device=False + ) + parameters.proj_in.bias = torch_to_tt_tensor_rm(parameters.proj_in.bias, device, put_on_device=False) + # assert False + + proj_in = fallback_ops.Conv2d( + weights=parameters.proj_in.weight, + biases=parameters.proj_in.bias, + in_channels=in_channels, + out_channels=inner_dim, + kernel_size=1, + stride=1, + padding=0, + ) + hidden_states = proj_in(hidden_states) + + hidden_states = tt_to_torch_tensor(hidden_states) + hidden_states = ttnn.to_layout( + ttnn.to_device(ttnn.from_torch(hidden_states, dtype=ttnn.bfloat16), device), + layout=ttnn.TILE_LAYOUT, + ) inner_dim = hidden_states.shape[1] @@ -154,36 +199,75 @@ def transformer_2d_model( out_channels = in_channels if out_channels is None else out_channels if is_input_continuous: if not use_linear_projection: - hidden_states = ttnn.to_layout(hidden_states, layout=ttnn.ROW_MAJOR_LAYOUT) - hidden_states = ttnn.reshape(hidden_states, (batch, height, width, inner_dim)) - - hidden_states = ttnn.permute(hidden_states, (0, 3, 1, 2)) - - hidden_states = 
ttnn.to_torch(ttnn.from_device(ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT))) - hidden_states = torch_to_tt_tensor_rm(hidden_states, device) - parameters.proj_out.weight, parameters.proj_out.bias = permute_conv_parameters( parameters.proj_out.weight, parameters.proj_out.bias ) - parameters.proj_out.weight = torch_to_tt_tensor_rm(parameters.proj_out.weight, device, put_on_device=False) - parameters.proj_out.bias = torch_to_tt_tensor_rm(parameters.proj_out.bias, device, put_on_device=False) - - proj_out = fallback_ops.Conv2d( - weights=parameters.proj_out.weight, - biases=parameters.proj_out.bias, - in_channels=inner_dim, - out_channels=in_channels, - kernel_size=1, - stride=1, - padding=0, - ) - hidden_states = proj_out(hidden_states) - - hidden_states = tt_to_torch_tensor(hidden_states) - hidden_states = ttnn.to_layout( - ttnn.to_device(ttnn.from_torch(hidden_states, dtype=ttnn.bfloat16), device), - layout=ttnn.TILE_LAYOUT, - ) + if conv_on_device: + batch_size = batch + input_height = height + input_width = width + parameters.proj_out.bias = torch.reshape( + parameters.proj_out.bias, (1, 1, 1, parameters.proj_out.bias.shape[-1]) + ) + tt_weight_tensor = ttnn.from_torch(parameters.proj_out.weight, ttnn.float32) + tt_bias_tensor = ttnn.from_torch(parameters.proj_out.bias, ttnn.float32) + out_channels = parameters.proj_out.weight.shape[0] + in_channels = parameters.proj_out.weight.shape[1] + proj_out = ttnn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + dtype=ttnn.bfloat8_b, + device=device, + use_1d_systolic_array=False, + batch_size=batch_size, + input_height=input_height, + input_width=input_width, + reader_patterns_cache=reader_patterns_cache, + weight=tt_weight_tensor, + bias=tt_bias_tensor, + math_fidelity=ttnn.MathFidelity.LoFi, + weights_dtype=ttnn.bfloat8_b, + conv_blocking_and_parallelization_config_override={}, + use_shallow_conv_variant=False, + enable_auto_formatting=True, + ) + hidden_states = proj_out(hidden_states) + hidden_states = post_process_output( + device, hidden_states, batch_size, input_height, input_width, out_channels + ) + else: + hidden_states = ttnn.to_layout(hidden_states, layout=ttnn.ROW_MAJOR_LAYOUT) + hidden_states = ttnn.reshape(hidden_states, (batch, height, width, inner_dim)) + + hidden_states = ttnn.permute(hidden_states, (0, 3, 1, 2)) + + hidden_states = ttnn.to_torch(ttnn.from_device(ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT))) + hidden_states = torch_to_tt_tensor_rm(hidden_states, device) + + parameters.proj_out.weight = torch_to_tt_tensor_rm( + parameters.proj_out.weight, device, put_on_device=False + ) + parameters.proj_out.bias = torch_to_tt_tensor_rm(parameters.proj_out.bias, device, put_on_device=False) + + proj_out = fallback_ops.Conv2d( + weights=parameters.proj_out.weight, + biases=parameters.proj_out.bias, + in_channels=inner_dim, + out_channels=in_channels, + kernel_size=1, + stride=1, + padding=0, + ) + hidden_states = proj_out(hidden_states) + + hidden_states = tt_to_torch_tensor(hidden_states) + hidden_states = ttnn.to_layout( + ttnn.to_device(ttnn.from_torch(hidden_states, dtype=ttnn.bfloat16), device), + layout=ttnn.TILE_LAYOUT, + ) else: hidden_states = ttnn.to_device(hidden_states, device) diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_unet_2d_condition_model.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_unet_2d_condition_model.py index 12fdd6022b7e..1807974bbcae 100644 --- 
a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_unet_2d_condition_model.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_unet_2d_condition_model.py @@ -4,6 +4,7 @@ import tt_lib import torch.nn as nn +import math import ttnn from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -26,6 +27,9 @@ ) from models.experimental.functional_stable_diffusion.tt.ttnn_functional_downblock_2d import downblock2d from models.experimental.functional_stable_diffusion.tt.ttnn_functional_upblock_2d import upblock_2d +from models.experimental.functional_stable_diffusion.tt.ttnn_functional_utility_functions import ( + run_ttnn_conv_with_pre_and_post_tensor_formatting, +) def permute_conv_weights(weight, bias): @@ -93,6 +97,8 @@ def UNet2DConditionModel( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", return_dict: bool = True, + reader_patterns_cache: Optional[Dict] = None, + dtype: Optional[ttnn.DataType] = None, ): num_upsamplers = len(block_out_channels) - 1 default_overall_up_factor = 2**num_upsamplers @@ -155,29 +161,64 @@ def UNet2DConditionModel( class_emb = class_embedding(class_labels) emb = emb + class_emb - # params change parameters.conv_in.weight, parameters.conv_in.bias = permute_conv_weights( parameters.conv_in.weight, parameters.conv_in.bias ) - parameters.conv_in.weight = torch_to_tt_tensor_rm(parameters.conv_in.weight, device, put_on_device=False) - parameters.conv_in.bias = torch_to_tt_tensor_rm(parameters.conv_in.bias, device, put_on_device=False) - # params change - # Using fallback Conv2D as we face issue with ttnn.Conv2D - conv_in = fallback_ops.Conv2d( - parameters.conv_in.weight, - parameters.conv_in.bias, - in_channels, - block_out_channels[0], - kernel_size=3, - padding=(1, 1), - ) - - sample = ttnn_to_torch(sample) - sample = torch_to_tt_tensor_rm(sample, device) - sample = conv_in(sample) - sample = tt_to_torch_tensor(sample) - sample = torch_to_ttnn(sample, device=device) + convs_on_device = reader_patterns_cache is not None + if not convs_on_device: + parameters.conv_in.weight = torch_to_tt_tensor_rm(parameters.conv_in.weight, device, put_on_device=False) + parameters.conv_in.bias = torch_to_tt_tensor_rm(parameters.conv_in.bias, device, put_on_device=False) + # params change + # Using fallback Conv2d as we face issue with ttnn.Conv2d + conv_in = fallback_ops.Conv2d( + parameters.conv_in.weight, + parameters.conv_in.bias, + in_channels, + block_out_channels[0], + kernel_size=3, + padding=(1, 1), + ) + sample = ttnn_to_torch(sample) + sample = torch_to_tt_tensor_rm(sample, device) + sample = conv_in(sample) + sample = tt_to_torch_tensor(sample) + sample = torch_to_ttnn(sample, device=device) + else: + # breakpoint() + parameters.conv_in.bias = torch.reshape(parameters.conv_in.bias, (1, 1, 1, parameters.conv_in.bias.shape[-1])) + tt_weight_tensor = ttnn.from_torch(parameters.conv_in.weight, ttnn.float32) + tt_bias_tensor = ttnn.from_torch(parameters.conv_in.bias, ttnn.float32) + # breakpoint() + out_channels = parameters.conv_in.weight.shape[0] + in_channels = parameters.conv_in.weight.shape[1] + batch_size = sample.shape[0] + input_height = sample.shape[2] + input_width = sample.shape[3] + conv_in = ttnn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + dtype=ttnn.bfloat8_b, + device=device, + use_1d_systolic_array=True if in_channels < 320 else False, + batch_size=batch_size, + input_height=input_height, + 
input_width=input_width, + reader_patterns_cache=reader_patterns_cache, + weight=tt_weight_tensor, + bias=tt_bias_tensor, + math_fidelity=ttnn.MathFidelity.LoFi, + weights_dtype=ttnn.bfloat8_b, + conv_blocking_and_parallelization_config_override={}, + use_shallow_conv_variant=False, + enable_auto_formatting=True, + ) + sample = run_ttnn_conv_with_pre_and_post_tensor_formatting( + device, conv_in, sample, batch_size, input_height, input_width, out_channels + ) # con_in completes @@ -220,6 +261,7 @@ def UNet2DConditionModel( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, device=device, + reader_patterns_cache=reader_patterns_cache, # enable convs on device for this module causes failure. Investigating. ) elif down_block_type == "DownBlock2D": sample, res_samples = downblock2d( @@ -237,6 +279,8 @@ def UNet2DConditionModel( downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, device=device, + dtype=dtype, + reader_patterns_cache=reader_patterns_cache, ) else: assert ( @@ -268,6 +312,7 @@ def UNet2DConditionModel( upcast_attention=upcast_attention, parameters=parameters.mid_block, device=device, + reader_patterns_cache=reader_patterns_cache, ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": assert False, "This is not happening" @@ -331,6 +376,7 @@ def UNet2DConditionModel( resnet_time_scale_shift=resnet_time_scale_shift, parameters=parameters.up_blocks[i], device=device, + reader_patterns_cache=reader_patterns_cache, ) elif up_block_type == "UpBlock2D": sample = upblock_2d( @@ -350,6 +396,7 @@ def UNet2DConditionModel( resnet_time_scale_shift=resnet_time_scale_shift, parameters=parameters.up_blocks[i], device=device, + reader_patterns_cache=reader_patterns_cache, ) else: assert ( @@ -370,25 +417,62 @@ def UNet2DConditionModel( parameters.conv_out.weight, parameters.conv_out.bias = permute_conv_weights( parameters.conv_out.weight, parameters.conv_out.bias ) - parameters.conv_out.weight = torch_to_tt_tensor_rm(parameters.conv_out.weight, device, put_on_device=False) - parameters.conv_out.bias = torch_to_tt_tensor_rm(parameters.conv_out.bias, device, put_on_device=False) - # params change - - # Using fallback Conv2D as we face issue with ttnn.Conv2D - conv_out = fallback_ops.Conv2d( - parameters.conv_out.weight, - parameters.conv_out.bias, - block_out_channels[0], - out_channels, - kernel_size=3, - padding=1, - ) - sample = ttnn_to_torch(sample) - sample = torch_to_tt_tensor_rm(sample, device) - sample = conv_out(sample) - sample = tt_to_torch_tensor(sample) - sample = torch_to_ttnn(sample, device=device) + if convs_on_device: + parameters.conv_out.bias = torch.reshape( + parameters.conv_out.bias, (1, 1, 1, parameters.conv_out.bias.shape[-1]) + ) + tt_weight_tensor = ttnn.from_torch(parameters.conv_out.weight, ttnn.float32) + tt_bias_tensor = ttnn.from_torch(parameters.conv_out.bias, ttnn.float32) + out_channels = parameters.conv_out.weight.shape[0] + in_channels = parameters.conv_out.weight.shape[1] + batch_size = sample.shape[0] + input_height = sample.shape[2] + input_width = sample.shape[3] + conv_out = ttnn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + dtype=ttnn.bfloat8_b, + device=device, + use_1d_systolic_array=True, + batch_size=batch_size, + input_height=input_height, + input_width=input_width, + reader_patterns_cache=reader_patterns_cache, + weight=tt_weight_tensor, + bias=tt_bias_tensor, + math_fidelity=ttnn.MathFidelity.LoFi, + 
weights_dtype=ttnn.bfloat8_b, + conv_blocking_and_parallelization_config_override={"act_block_h": 64}, + use_shallow_conv_variant=False, + enable_auto_formatting=True, + deallocate_activation=True, + ) + sample = run_ttnn_conv_with_pre_and_post_tensor_formatting( + device, conv_out, sample, batch_size, input_height, input_width, out_channels + ) + else: + parameters.conv_out.weight = torch_to_tt_tensor_rm(parameters.conv_out.weight, device, put_on_device=False) + parameters.conv_out.bias = torch_to_tt_tensor_rm(parameters.conv_out.bias, device, put_on_device=False) + # params change + + # Using fallback Conv2d as we face issue with ttnn.Conv2d + conv_out = fallback_ops.Conv2d( + parameters.conv_out.weight, + parameters.conv_out.bias, + block_out_channels[0], + out_channels, + kernel_size=3, + padding=1, + ) + sample = ttnn_to_torch(sample) + sample = torch_to_tt_tensor_rm(sample, device) + sample = conv_out(sample) + sample = tt_to_torch_tensor(sample) + sample = torch_to_ttnn(sample, device=device) # con_in completes diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_unet_mid_block_2d_cross_attn.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_unet_mid_block_2d_cross_attn.py index a5f784b534e8..671e7f368155 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_unet_mid_block_2d_cross_attn.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_unet_mid_block_2d_cross_attn.py @@ -30,6 +30,7 @@ def unet_mid_block_2d_cross_attn( dual_cross_attention=False, use_linear_projection=False, upcast_attention=False, + reader_patterns_cache=None, ): has_cross_attention = True @@ -50,6 +51,7 @@ def unet_mid_block_2d_cross_attn( pre_norm=resnet_pre_norm, eps=resnet_eps, use_in_shortcut=None, + reader_patterns_cache=reader_patterns_cache, ) for attn, resnet in zip(parameters.attentions, parameters.resnets[1:]): @@ -73,6 +75,7 @@ def unet_mid_block_2d_cross_attn( device=device, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + reader_patterns_cache=reader_patterns_cache, ) else: assert False, "We do not support Dual Transformer" @@ -92,6 +95,7 @@ def unet_mid_block_2d_cross_attn( pre_norm=resnet_pre_norm, eps=resnet_eps, use_in_shortcut=None, + reader_patterns_cache=reader_patterns_cache, ) return hidden_states diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upblock_2d.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upblock_2d.py index 3bcf4ed77d5c..d0bca3810893 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upblock_2d.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upblock_2d.py @@ -30,6 +30,7 @@ def upblock_2d( base_address=None, temb=None, upsample_size=None, + reader_patterns_cache=None, ): for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels @@ -58,9 +59,17 @@ def upblock_2d( output_scale_factor=output_scale_factor, parameters=parameters.resnets[i], device=device, + reader_patterns_cache=reader_patterns_cache, ) if add_upsample: - hidden_states = upsample2d(device, hidden_states, parameters.upsamplers[0], in_channels, out_channels) + hidden_states = upsample2d( + device, + hidden_states, + parameters.upsamplers[0], + in_channels, + out_channels, + reader_patterns_cache=reader_patterns_cache, + ) return hidden_states diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upsample_2d.py 
b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upsample_2d.py index 31287084fa73..4b3cab445697 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upsample_2d.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upsample_2d.py @@ -12,43 +12,90 @@ from models.experimental.functional_stable_diffusion.tt.ttnn_functional_upsample_nearest_2d import upsample_nearest2d from tt_lib.fallback_ops import fallback_ops +from models.experimental.functional_stable_diffusion.tt.ttnn_functional_utility_functions import ( + run_ttnn_conv_with_pre_and_post_tensor_formatting, +) +config_override = { + (320, 320, 64, 64): {"act_block_h": 64}, + (640, 640, 32, 32): {"act_block_h": 64}, + (640, 1920, 32, 32): {"act_block_h": 32}, + (640, 1280, 32, 32): {"act_block_h": 32}, + (1280, 1920, 16, 16): {"act_block_h": 32}, + (1280, 1280, 32, 32): {"act_block_h": 32}, + (320, 960, 64, 64): {"act_block_h": 32}, + (640, 960, 32, 32): {"act_block_h": 32}, + (320, 640, 64, 64): {"act_block_h": 32}, + (640, 640, 64, 64): {"act_block_h": 64}, +} -def upsample2d( - device, - input, - parameters, - in_channels, - out_channels, - scale_factor=2.0, -): - tt_out = upsample_nearest2d(input, scale_factor) - tt_out = ttnn.from_device(tt_out) - tt_out = ttnn.to_layout(tt_out, ttnn.ROW_MAJOR_LAYOUT) - tt_out = ttnn.to_torch(tt_out) - tt_out = torch_to_tt_tensor_rm(tt_out, device) +def upsample2d(device, input, parameters, in_channels, out_channels, scale_factor=2.0, reader_patterns_cache=None): + conv_on_device = reader_patterns_cache is not None + tt_out = upsample_nearest2d(input, scale_factor) weight = ttnn.to_layout(parameters.conv.weight, layout=ttnn.ROW_MAJOR_LAYOUT) weight = ttnn.to_torch(weight) weight = torch.permute(weight, (2, 3, 0, 1)) bias = ttnn.to_layout(parameters.conv.bias, layout=ttnn.ROW_MAJOR_LAYOUT) bias = ttnn.to_torch(bias) + if conv_on_device: + batch_size = tt_out.shape[0] + input_height = tt_out.shape[2] + input_width = tt_out.shape[3] + out_channels = weight.shape[0] + in_channels = weight.shape[1] + # breakpoint() + bias = torch.reshape(bias, (1, 1, 1, out_channels)) + tt_weight_tensor = ttnn.from_torch(weight, ttnn.float32) + tt_bias_tensor = ttnn.from_torch(bias, ttnn.float32) + conv_config_override = {} + if (out_channels, in_channels, input_height, input_width) in config_override: + conv_config_override = config_override[(out_channels, in_channels, input_height, input_width)] + conv = ttnn.Conv2d( + in_channels, + out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + dtype=ttnn.bfloat8_b, + device=device, + use_1d_systolic_array=False, + batch_size=batch_size, + input_height=input_height, + input_width=input_width, + reader_patterns_cache=reader_patterns_cache, + weight=tt_weight_tensor, + bias=tt_bias_tensor, + math_fidelity=ttnn.MathFidelity.LoFi, + weights_dtype=ttnn.bfloat8_b, + conv_blocking_and_parallelization_config_override=conv_config_override, + use_shallow_conv_variant=False, + enable_auto_formatting=True, + deallocate_activation=True, + ) + tt_out = run_ttnn_conv_with_pre_and_post_tensor_formatting( + device, conv, tt_out, batch_size, input_height, input_width, out_channels + ) + else: + tt_out = ttnn.from_device(tt_out) + tt_out = ttnn.to_layout(tt_out, ttnn.ROW_MAJOR_LAYOUT) + tt_out = ttnn.to_torch(tt_out) + tt_out = torch_to_tt_tensor_rm(tt_out, device) + weight = torch_to_tt_tensor_rm(weight, device, put_on_device=False) + bias = torch_to_tt_tensor_rm(bias, device, put_on_device=False) + + conv = 
fallback_ops.Conv2d( + weight, + bias, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) - weight = torch_to_tt_tensor_rm(weight, device, put_on_device=False) - bias = torch_to_tt_tensor_rm(bias, device, put_on_device=False) - - conv = fallback_ops.Conv2d( - weight, - bias, - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - ) - - tt_out = conv(tt_out) - torch_out = tt_to_torch_tensor(tt_out) - ttnn_out = ttnn.from_torch(torch_out, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) - return ttnn_out + tt_out = conv(tt_out) + torch_out = tt_to_torch_tensor(tt_out) + tt_out = ttnn.from_torch(torch_out, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16) + return tt_out diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upsample_nearest_2d.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upsample_nearest_2d.py index 0e8608acfb85..9be80b7ffb4b 100644 --- a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upsample_nearest_2d.py +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_upsample_nearest_2d.py @@ -12,13 +12,10 @@ def upsample_nearest2d(input, scale_factor=2.0): # up_output = ttnn.repeat_interleave(input, scale_factor, dim=3) # up_output = ttnn.repeat_interleave(up_output, scale_factor, dim=2) - print(f"=============================== input shape: {input.shape}") - ## permute to NHWC input = ttnn.to_layout(input, ttnn.ROW_MAJOR_LAYOUT) input = ttnn.permute(input, (0, 2, 3, 1)) - print(f"=============================== input shape: {input.shape}") up_output = ttnn.upsample(input, scale_factor) ## permute back to NCHW diff --git a/models/experimental/functional_stable_diffusion/tt/ttnn_functional_utility_functions.py b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_utility_functions.py new file mode 100644 index 000000000000..cfe6c9b314f2 --- /dev/null +++ b/models/experimental/functional_stable_diffusion/tt/ttnn_functional_utility_functions.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from tt_lib.fallback_ops import fallback_ops + +import torch +from typing import Optional, Dict + + +def pre_process_input(device, tensor): + tensor = ttnn.to_layout(tensor, ttnn.ROW_MAJOR_LAYOUT) + batch_size = tensor.shape[0] + input_channels = tensor.shape[1] + input_height = tensor.shape[2] + input_width = tensor.shape[3] + tensor = fallback_ops.permute( + tensor.value, (0, 2, 3, 1), output_layout=ttnn.ROW_MAJOR_LAYOUT, output_on_device=False + ) + import math + + assert input_channels == tensor.shape()[3] + padded_input_channels = math.ceil(input_channels / 16) * 16 + if padded_input_channels != input_channels: + tensor = fallback_ops.pad( + tensor, + (0, padded_input_channels - input_channels, 0, 0, 0, 0), + output_layout=ttnn.ROW_MAJOR_LAYOUT, + output_on_device=False, + ) + # Reshape 4d to 2d + tensor = fallback_ops.reshape( + tensor, + 1, + 1, + batch_size * input_height * input_width, + padded_input_channels, + output_layout=ttnn.ROW_MAJOR_LAYOUT, + output_on_device=False, + ) + tensor = ttnn.Tensor(tensor) + tensor = ttnn.to_device(tensor, device) + tensor = ttnn.to_layout(tensor, ttnn.TILE_LAYOUT) + return tensor + + +def post_process_output(device, tensor, batch_size, output_height, output_width, output_channels): + tensor = ttnn.to_layout(tensor, ttnn.ROW_MAJOR_LAYOUT) + tensor = ttnn.from_device(tensor) + assert output_channels == tensor.shape[3] + tensor = fallback_ops.reshape( + tensor.value, + batch_size, + output_height, + output_width, + output_channels, + output_layout=ttnn.ROW_MAJOR_LAYOUT, + output_on_device=False, + ) + tensor = fallback_ops.permute(tensor, (0, 3, 1, 2), output_layout=ttnn.ROW_MAJOR_LAYOUT, output_on_device=False) + tensor = ttnn.Tensor(tensor) + tensor = ttnn.to_layout(tensor, ttnn.TILE_LAYOUT) + tensor = ttnn.to_device(tensor, device) + return tensor + + +def run_ttnn_conv_with_pre_and_post_tensor_formatting( + device, ttnn_conv_op, tensor: ttnn.Tensor, batch_size, output_height, output_width, output_channels +) -> ttnn.Tensor: + tensor = pre_process_input(device, tensor) + # print("Running conv op") + tensor = ttnn_conv_op(tensor) + tensor = post_process_output(device, tensor, batch_size, output_height, output_width, output_channels) + return tensor diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_optimized_conv_v2.py b/tests/tt_eager/python_api_testing/unit_testing/test_optimized_conv_v2.py index 7d4159ee2e76..6b11ffc3740f 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/test_optimized_conv_v2.py +++ b/tests/tt_eager/python_api_testing/unit_testing/test_optimized_conv_v2.py @@ -167,6 +167,7 @@ def test_optimized_conv_v2( weights_dtype=weights_dtype, output_dtype=activations_dtype, math_fidelity=math_fidelity, + deallocate_activation=True, ) conv_input = tt_lib.tensor.Tensor( diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_resnet50_untilize_with_halo_and_conv_v2.py b/tests/tt_eager/python_api_testing/unit_testing/test_resnet50_untilize_with_halo_and_conv_v2.py index 4bfdd33dc330..8b4bb6265609 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/test_resnet50_untilize_with_halo_and_conv_v2.py +++ b/tests/tt_eager/python_api_testing/unit_testing/test_resnet50_untilize_with_halo_and_conv_v2.py @@ -654,6 +654,7 @@ def test_resnet50_conv( output_dtype=activations_dtype, math_fidelity=math_fidelity, use_shallow_conv_variant=(C == 16), + deallocate_activation=True, ) conv_input = tt_lib.tensor.Tensor( diff --git 
a/tests/ttnn/integration_tests/stable_diffusion/test_resnet_block_2d.py b/tests/ttnn/integration_tests/stable_diffusion/test_resnet_block_2d.py index f000d442371e..8e28d6a1ee80 100644 --- a/tests/ttnn/integration_tests/stable_diffusion/test_resnet_block_2d.py +++ b/tests/ttnn/integration_tests/stable_diffusion/test_resnet_block_2d.py @@ -87,7 +87,6 @@ def test_resnet_block_2d_256x256( temb = ttnn.from_torch(temb, ttnn.bfloat16) temb = ttnn.to_layout(temb, ttnn.TILE_LAYOUT) temb = ttnn.to_device(temb, device, memory_config=ttnn.L1_MEMORY_CONFIG) - ttnn_output = resnetBlock2D( input, temb=temb, @@ -117,10 +116,10 @@ def test_resnet_block_2d_256x256( (2, 1280, 8, 8, 2, 1, "down", None), (2, 2560, 8, 8, 0, 0, "up", 1280), (2, 2560, 16, 16, 0, 0, "up", 1280), - (2, 1920, 16, 16, 2, 0, "up", 640), + (2, 1920, 16, 16, 2, 0, "up", 1280), (2, 1920, 32, 32, 2, 0, "up", 640), (2, 1280, 32, 32, 3, 0, "down", None), - (2, 960, 32, 32, 3, 0, "up", 320), + (2, 960, 32, 32, 3, 0, "up", 640), (2, 960, 64, 64, 3, 0, "up", 320), (2, 640, 64, 64, 3, 1, "up", 320), ], @@ -165,12 +164,12 @@ def test_resnet_block_2d_512x512( input = ttnn.from_torch(input, ttnn.bfloat16) input = ttnn.to_layout(input, ttnn.TILE_LAYOUT) - input = ttnn.to_device(input, device, memory_config=ttnn.L1_MEMORY_CONFIG) + input = ttnn.to_device(input, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) temb = ttnn.from_torch(temb, ttnn.bfloat16) temb = ttnn.to_layout(temb, ttnn.TILE_LAYOUT) - temb = ttnn.to_device(temb, device, memory_config=ttnn.L1_MEMORY_CONFIG) - + temb = ttnn.to_device(temb, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + reader_patterns_cache = {} ttnn_output = resnetBlock2D( input, temb=temb, @@ -183,6 +182,7 @@ def test_resnet_block_2d_512x512( output_scale_factor=output_scale_factor, parameters=parameters, device=device, + reader_patterns_cache=reader_patterns_cache, ) ttnn_output = ttnn_to_torch(ttnn_output) assert_with_pcc(torch_output, ttnn_output, pcc=0.99) diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py b/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py index ccc47611a392..2d615e9b91bf 100644 --- a/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py +++ b/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py @@ -141,8 +141,8 @@ def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_ torch_output = model(input, timestep=timestep, encoder_hidden_states=encoder_hidden_states.squeeze(0)).sample input = ttnn.from_torch(input, ttnn.bfloat16) - input = ttnn.to_layout(input, ttnn.TILE_LAYOUT) input = ttnn.to_device(input, device, memory_config=ttnn.L1_MEMORY_CONFIG) + input = ttnn.to_layout(input, ttnn.TILE_LAYOUT, ttnn.bfloat8_b) ttnn_timestep = ttnn.from_torch(ttnn_timestep, ttnn.bfloat16) ttnn_timestep = ttnn.to_layout(ttnn_timestep, ttnn.TILE_LAYOUT) @@ -151,7 +151,7 @@ def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_ encoder_hidden_states = ttnn.from_torch(encoder_hidden_states, ttnn.bfloat16) encoder_hidden_states = ttnn.to_layout(encoder_hidden_states, ttnn.TILE_LAYOUT) encoder_hidden_states = ttnn.to_device(encoder_hidden_states, device, memory_config=ttnn.L1_MEMORY_CONFIG) - + reader_patterns_cache = {} ttnn_output = UNet2DConditionModel( input, timestep=ttnn_timestep, @@ -163,6 +163,7 @@ def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_ parameters=parameters, device=device, config=config, + 
reader_patterns_cache=reader_patterns_cache, ) ttnn_output = ttnn_to_torch(ttnn_output) assert_with_pcc(torch_output, ttnn_output, pcc=0.80) diff --git a/tests/ttnn/unit_tests/operations/test_conv2d.py b/tests/ttnn/unit_tests/operations/test_conv2d.py index 04066f1dda85..323d59bb06f3 100644 --- a/tests/ttnn/unit_tests/operations/test_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_conv2d.py @@ -502,6 +502,8 @@ def test_resnet50_conv_wh( (2, 640, 960, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), (2, 320, 960, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), (2, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), + # 1x1 conv + (2, 320, 960, 64, 64, 1, 1, 1, 1, 0, 0, False, None), ), ) @pytest.mark.parametrize( @@ -535,7 +537,7 @@ def test_sd_conv( config_override, enable_auto_formatting, ): - if input_channels > 1280 or (input_channels > 640 and input_height > 16): + if filter_height > 1 and (input_channels > 1280 or (input_channels > 640 and input_height > 16)): if enable_auto_formatting: pytest.skip("Not running split SD conv with auto formatting") run_conv_with_split( diff --git a/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_composite_conv.py b/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_composite_conv.py index 2f22cdaca072..6698c8b0a8dd 100644 --- a/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_composite_conv.py +++ b/tt_eager/tt_dnn/op_library/sliding_window_op_infra/tt_py_composite_conv.py @@ -297,7 +297,6 @@ def determine_per_core_block_config( assert ( "out_subblock_h" in config_override ), "out_subblock_h must also be provided as override config if out_subblock_w is provided" - conv_blocking_config = ttl.tensor.OptimizedConvBlockConfig( act_block_h_ntiles=act_block_h_ntiles, act_block_w_ntiles=act_block_w_ntiles, @@ -310,6 +309,42 @@ def determine_per_core_block_config( return conv_blocking_config +def determine_1x1conv_as_matmul_config( + conv_parallelization_config, conv_blocking_config, use_1d_systolic_array, fuse_relu +): + if use_1d_systolic_array: + matmul_config = ttl.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig( + compute_with_storage_grid_size=conv_parallelization_config.grid_size, + in0_block_w=conv_blocking_config.act_block_w_ntiles, + out_subblock_h=conv_blocking_config.out_subblock_h_ntiles, + out_subblock_w=conv_blocking_config.out_subblock_w_ntiles, + per_core_M=conv_parallelization_config.per_core_out_matrix_height_ntiles, + per_core_N=conv_parallelization_config.per_core_weight_matrix_width_ntiles, + fuse_batch=True, + fused_activation=ttl.tensor.FusibleActivationWithParam(ttl.tensor.FusibleActivation.RELU) + if fuse_relu + else None, + mcast_in0=False, + ) + else: + assert ( + conv_blocking_config.act_block_w_ntiles % conv_blocking_config.act_c_num_blocks == 0 + ), "Expected act block width to be divisible by act channel num blocks." 
+ matmul_config = ttl.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=conv_parallelization_config.grid_size, + in0_block_w=conv_blocking_config.act_block_w_ntiles // conv_blocking_config.act_c_num_blocks, + out_subblock_h=conv_blocking_config.out_subblock_h_ntiles, + out_subblock_w=conv_blocking_config.out_subblock_w_ntiles, + per_core_M=conv_parallelization_config.per_core_out_matrix_height_ntiles, + per_core_N=conv_parallelization_config.per_core_weight_matrix_width_ntiles, + transpose_mcast=True, + fused_activation=ttl.tensor.FusibleActivationWithParam(ttl.tensor.FusibleActivation.RELU) + if fuse_relu + else None, + ) + return matmul_config + + class TTPyCompositeConv(TTPyOp): config_keys = [ "num_cores_nhw", @@ -418,7 +453,20 @@ def __init__( pad_w = sliding_window_op_params.pad_w filter_height = sliding_window_op_params.window_h filter_width = sliding_window_op_params.window_w - + self.matmul_config = None + self.use_matmul_for_1x1_conv = False + if ( + filter_height == filter_width + and filter_height == 1 + and stride_h == stride_w + and stride_h == 1 + and pad_h == pad_w + and pad_h == 0 + ): + self.use_matmul_for_1x1_conv = True + self.matmul_config = determine_1x1conv_as_matmul_config( + self.opt_conv_parall_conf_auto, self.opt_conv_block_conf_auto, is_1d_systolic, fuse_relu + ) if isinstance(sliding_window_op_params, SlidingWindowOpParams): # populate parallelization params in sliding_window_op_params sliding_window_op_params = SlidingWindowOpParamsWithParallelConfig( @@ -438,6 +486,7 @@ def __init__( self.sliding_window_op_params = sliding_window_op_params self.move_utwh_output = move_utwh_output + self.deallocate_input = deallocate_activation sliding_window_op_params_hash = get_hash_from_sliding_window_op_params(sliding_window_op_params) @@ -455,17 +504,19 @@ def __init__( 1, 1, ] - # set_op_configs populates reader_patterns_cache["conv"][sliding_window_op_params_hash] with conv_reader_indices sharded tensor - self.set_op_configs( - self.device, - sliding_window_op_params_hash, - sliding_window_op_params, - conv_params, - not is_1d_systolic, - reader_patterns_cache["conv"], - ) - assert sliding_window_op_params_hash in reader_patterns_cache["conv"] - conv_reader_indices = reader_patterns_cache["conv"][sliding_window_op_params_hash] + conv_reader_indices = None + if not self.use_matmul_for_1x1_conv: + # set_op_configs populates reader_patterns_cache["conv"][sliding_window_op_params_hash] with conv_reader_indices sharded tensor + self.set_op_configs( + self.device, + sliding_window_op_params_hash, + sliding_window_op_params, + conv_params, + not is_1d_systolic, + reader_patterns_cache["conv"], + ) + assert sliding_window_op_params_hash in reader_patterns_cache["conv"] + conv_reader_indices = reader_patterns_cache["conv"][sliding_window_op_params_hash] self.set_op_weights_biases( weight, @@ -487,10 +538,11 @@ def __init__( move_weights_to_device=move_weights_to_device, ) - # create untilize with halo op - self.tt_py_untilize_with_halo_op = TTPyUntilizeWithHalo( - device, self.sliding_window_op_params, reader_patterns_cache["halo"] - ) + if not self.use_matmul_for_1x1_conv: + # create untilize with halo op + self.tt_py_untilize_with_halo_op = TTPyUntilizeWithHalo( + device, self.sliding_window_op_params, reader_patterns_cache["halo"] + ) # override abstract methods from base class TTPyOp def set_op_configs( @@ -696,28 +748,61 @@ def conv_(activation): ) # assert(output.storage_type() == ttl.tensor.StorageType.DEVICE) + def 
composite_conv_with_deallocate_input(activation): + # assert(activation.layout() == ttl.tensor.Layout.ROW_MAJOR) + utwh_output = self.tt_py_untilize_with_halo_op(activation) + activation.deallocate() + return conv_(utwh_output) + def composite_conv(activation): # assert(activation.layout() == ttl.tensor.Layout.ROW_MAJOR) utwh_output = self.tt_py_untilize_with_halo_op(activation) - if self.deallocate_activation: - activation.deallocate() return conv_(utwh_output) + def composite_conv_with_move_utwh_output_with_deallocate_input(activation): + # assert(activation.layout() == ttl.tensor.Layout.ROW_MAJOR) + utwh_output = self.tt_py_untilize_with_halo_op(activation) + activation.deallocate() + move_output = ttl.tensor.move_sharded(utwh_output) + utwh_output.deallocate() + return conv_(move_output) + def composite_conv_with_move_utwh_output(activation): # assert(activation.layout() == ttl.tensor.Layout.ROW_MAJOR) utwh_output = self.tt_py_untilize_with_halo_op(activation) - if self.deallocate_activation: - activation.deallocate() move_output = ttl.tensor.move_sharded(utwh_output) utwh_output.deallocate() return conv_(move_output) - if self.move_utwh_output: - self.conv = composite_conv_with_move_utwh_output + def conv1x1_as_matmul(activation): + # conv1x1 stride 1 padding 0, use matmul op + output = ttl.operations.primary.matmul( + activation, + weight_on_device, + bias=bias_on_device, + program_config=self.matmul_config, + output_mem_config=activation.memory_config() if output_mem_config is None else output_mem_config, + output_dtype=output_dtype, + math_fidelity=math_fidelity, + ) + return output + + if self.use_matmul_for_1x1_conv: + self.conv = conv1x1_as_matmul + elif self.move_utwh_output: + if self.deallocate_input: + self.conv = composite_conv_with_move_utwh_output_with_deallocate_input + else: + self.conv = composite_conv_with_move_utwh_output else: - self.conv = composite_conv + if self.deallocate_input: + self.conv = composite_conv_with_deallocate_input + else: + self.conv = composite_conv def __call__(self, activation): + # print("Going to run conv with input shape-", self.input_tensor_shape) + # print("with output shape = ", self.conv_output_shape) if self.enable_auto_formatting: activation = self.conv_input_interleaved_to_sharded(activation) activation = self.conv(activation) diff --git a/tt_eager/tt_lib/fallback_ops/conversion_wrapper.py b/tt_eager/tt_lib/fallback_ops/conversion_wrapper.py index 8a18cbba7269..b436cf16b296 100644 --- a/tt_eager/tt_lib/fallback_ops/conversion_wrapper.py +++ b/tt_eager/tt_lib/fallback_ops/conversion_wrapper.py @@ -49,7 +49,11 @@ def custom_pt_tensor_to_str_fn(tensor): def convert_tt_tensor_to_pt_tensor(tt_tensor, output_format): # Update output_format with format of first encountered arg - if output_format.get("device", None) is None and tt_tensor.storage_type() == ttl_tensor.StorageType.DEVICE: + if ( + output_format.get("device", None) is None + and tt_tensor.storage_type() == ttl_tensor.StorageType.DEVICE + and output_format["on_device"] + ): output_format["device"] = tt_tensor.device() if ttl_profiler.get_profiler_flag(): @@ -82,7 +86,9 @@ def convert_pt_tensor_to_tt_tensor(pt_tensor, output_format): else: assert output_format["layout"] == ttl_tensor.Layout.ROW_MAJOR - if isinstance(output_format["device"], ttl_device.Device): + if output_format["on_device"]: + assert "device" in output_format + assert isinstance(output_format["device"], ttl_device.Device) if ( tt_tensor.layout() == ttl_tensor.Layout.TILE or tt_tensor.layout() == 
ttl_tensor.Layout.ROW_MAJOR @@ -135,7 +141,15 @@ def convert_tt_tensors_wrapper(func): def wrap(*args, **kwargs): ttl_tensor.log_external_operation(func, *args, **kwargs) - output_format = {"layout": ttl_tensor.Layout.TILE} + output_format = {} + if "output_on_device" in kwargs: + output_format["on_device"] = kwargs["output_on_device"] + else: + output_format["on_device"] = True + if "output_layout" in kwargs: + output_format["layout"] = kwargs["output_layout"] + else: + output_format["layout"] = ttl_tensor.Layout.TILE if ttl_profiler.get_profiler_flag(): ttl_profiler.start_profiling("fallback_op", ttl_profiler.OpType.python_fallback) @@ -164,7 +178,7 @@ def wrap(*args, **kwargs): new_kwargs = convert_tt_tensors_to_pt_tensors(kwargs, output_format) # Set default output format - if output_format.get("device", None) is None: + if output_format.get("device", None) is None and output_format["on_device"]: output_format["device"] = ttl_device.GetDefaultDevice() outputs = func(*new_args, **new_kwargs) diff --git a/tt_eager/tt_lib/fallback_ops/fallback_ops.py b/tt_eager/tt_lib/fallback_ops/fallback_ops.py index 625915a76c02..a9d82065b358 100644 --- a/tt_eager/tt_lib/fallback_ops/fallback_ops.py +++ b/tt_eager/tt_lib/fallback_ops/fallback_ops.py @@ -46,27 +46,64 @@ def tensor_slice(input: ttl_tensor.Tensor, slices: List[Union[slice, EllipsisTyp @convert_tt_tensors_wrapper -def reshape(input: ttl_tensor.Tensor, N: int, C: int, H: int, W: int) -> ttl_tensor.Tensor: +def reshape( + input: ttl_tensor.Tensor, + N: int, + C: int, + H: int, + W: int, + output_layout: Optional[ttl_tensor.Layout] = ttl_tensor.Layout.TILE, + output_on_device: Optional[bool] = True, +) -> ttl_tensor.Tensor: """ Returns a new ``tt_lib.tensor.Tensor`` with the same data and number of elements as ``input``, but with the specified shape ``[N, C, H, W]``. 
- +------------+-----------------------------------------------+-------------+-----------------+----------+ - | Argument | Description | Data type | Valid range | Required | - +============+===============================================+=============+=================+==========+ - | input | Input tensor | Tensor | | Yes | - +------------+-----------------------------------------------+-------------+-----------------+----------+ - | N | Size of the first dimension of output tensor | int | | Yes | - +------------+-----------------------------------------------+-------------+-----------------+----------+ - | C | Size of the second dimension of output tensor | int | | Yes | - +------------+-----------------------------------------------+-------------+-----------------+----------+ - | H | Size of the third dimension of output tensor | int | | Yes | - +------------+-----------------------------------------------+-------------+-----------------+----------+ - | W | Size of the fourth dimension of output tensor | int | | Yes | - +------------+-----------------------------------------------+-------------+-----------------+----------+ + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + | Argument | Description | Data type | Valid range | Required | + +==================+===============================================+=============+=================+==========+ + | input | Input tensor | Tensor | | Yes | + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + | N | Size of the first dimension of output tensor | int | | Yes | + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + | C | Size of the second dimension of output tensor | int | | Yes | + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + | H | Size of the third dimension of output tensor | int | | Yes | + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + | W | Size of the fourth dimension of output tensor | int | | Yes | + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + | output_layout | Output layout | Layout | default is TILE | No | + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + | output_on_device | Output on device | bool | default is True | No | + +------------------+-----------------------------------------------+-------------+-----------------+----------+ """ return torch.reshape(input, (N, C, H, W)) +@convert_tt_tensors_wrapper +def permute( + input: ttl_tensor.Tensor, + dims: Tuple[int], + output_layout: Optional[ttl_tensor.Layout] = ttl_tensor.Layout.TILE, + output_on_device: Optional[bool] = True, +) -> ttl_tensor.Tensor: + """ + Returns a new ``tt_lib.tensor.Tensor`` with the same data and number of elements as ``input``, but with its dimensions reordered according to ``dims``.
+ + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + | Argument | Description | Data type | Valid range | Required | + +==================+===============================================+=============+=================+==========+ + | input | Input tensor | Tensor | | Yes | + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + | dims | Desired ordering of dimensions | Tuple of int| | Yes | + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + | output_layout | Output layout | Layout | default is TILE | No | + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + | output_on_device | Output on device | bool | default is True | No | + +------------------+-----------------------------------------------+-------------+-----------------+----------+ + """ + return torch.permute(input, dims) + + @convert_tt_tensors_wrapper def chunk(input: ttl_tensor.Tensor, chunks: int, dim: int = 0) -> List[ttl_tensor.Tensor]: """ @@ -224,6 +261,8 @@ def pad( pad: Tuple[int], mode: str = "constant", value: Optional[int] = None, + output_layout: Optional[ttl_tensor.Layout] = ttl_tensor.Layout.TILE, + output_on_device: Optional[bool] = True, ) -> ttl_tensor.Tensor: r""" Pads tensor. @@ -245,6 +284,10 @@ def pad( +------------------+-----------------------------------------------------------+------------------+---------------------------------------------------------------------------+----------+ | value | Fill value for `constant` padding | int | default is 0 | No | +------------------+-----------------------------------------------------------+------------------+---------------------------------------------------------------------------+----------+ + | output_layout | Output layout | Layout | default is TILE | No | + +------------------+-----------------------------------------------------------+------------------+---------------------------------------------------------------------------+----------+ + | output_on_device | Output on device | bool | default is True | No | + +------------------+-----------------------------------------------------------+------------------+---------------------------------------------------------------------------+----------+ """ return torch.nn.functional.pad(input, pad, mode, value) diff --git a/tt_metal/impl/allocator/algorithms/free_list.cpp b/tt_metal/impl/allocator/algorithms/free_list.cpp index ed305e69ef6a..8f10db412304 100644 --- a/tt_metal/impl/allocator/algorithms/free_list.cpp +++ b/tt_metal/impl/allocator/algorithms/free_list.cpp @@ -240,7 +240,7 @@ std::optional FreeList::allocate(uint64_t size_bytes, bool bottom_up, this->update_lowest_occupied_address(allocated_block->address); if (allocated_block->address + this->offset_bytes_ < address_limit) { - TT_THROW("Out of Memory: Cannot allocate at an address below {}", address_limit); + TT_THROW("Out of Memory: Cannot allocate at an address below {}. 
Tried to allocate at {}", address_limit, allocated_block->address + this->offset_bytes_); } return allocated_block->address + this->offset_bytes_; } diff --git a/ttnn/cpp/pybind11/operations/binary.hpp b/ttnn/cpp/pybind11/operations/binary.hpp index 08f1b7855b76..bf8eac75a2ce 100644 --- a/ttnn/cpp/pybind11/operations/binary.hpp +++ b/ttnn/cpp/pybind11/operations/binary.hpp @@ -19,11 +19,12 @@ namespace binary { void py_module(py::module& m_binary) { m_binary.def( "add", - static_cast(&ttnn::operations::binary::add), + static_cast)>(&ttnn::operations::binary::add), py::arg("input_tensor_a"), py::arg("input_tensor_b"), py::kw_only(), - py::arg("memory_config") = DRAM_MEMORY_CONFIG + py::arg("memory_config") = DRAM_MEMORY_CONFIG, + py::arg("dtype") = std::nullopt ); m_binary.def( diff --git a/ttnn/cpp/ttnn/operations/binary.hpp b/ttnn/cpp/ttnn/operations/binary.hpp index 4a27e1ba2558..85da18142dae 100644 --- a/ttnn/cpp/ttnn/operations/binary.hpp +++ b/ttnn/cpp/ttnn/operations/binary.hpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #pragma once - +#include #include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" #include "tt_eager/tensor/tensor_utils.hpp" #include "tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp" @@ -18,7 +18,8 @@ namespace binary { inline ttnn::Tensor add( const ttnn::Tensor& input_tensor_a_arg, const ttnn::Tensor& input_tensor_b_arg, - const tt::tt_metal::MemoryConfig& memory_config) { + const tt::tt_metal::MemoryConfig& memory_config, + std::optional dtype = std::nullopt) { auto&& [input_tensor_a, input_tensor_b] = [](const auto& input_tensor_a_arg, const auto& input_tensor_b_arg) { // Swap tensors if input_tensor_a needs to be broadcasted to input_tensor_b if (tt::tt_metal::compute_volume(input_tensor_a_arg.ttnn_shape()) < @@ -45,6 +46,9 @@ inline ttnn::Tensor add( auto input_tensor_b_4D = ttnn::unsqueeze_to_4D(input_tensor_b); if (height_b == 1 or width_b == 1) { + if (dtype.has_value()) { + TT_THROW("ttnn.add: cannot change dtype when broadcasting"); + } tt::tt_metal::BcastOpDim bcast_op_dim; if (height_b == 1 and width_b == 1) { bcast_op_dim = tt::tt_metal::BcastOpDim::HW; @@ -59,7 +63,7 @@ inline ttnn::Tensor add( input_tensor_a_4D, input_tensor_b_4D, tt::tt_metal::BcastOpMath::ADD, bcast_op_dim, memory_config); return ttnn::reshape(output, original_shape); } else { - auto output = tt::tt_metal::add(input_tensor_a_4D, input_tensor_b_4D, std::nullopt, memory_config); + auto output = tt::tt_metal::add(input_tensor_a_4D, input_tensor_b_4D, std::nullopt, memory_config, dtype); return ttnn::reshape(output, original_shape); } } diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 90a8067e3e91..39ee6e256131 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -52,7 +52,7 @@ def get_bool_env_var(name, default): Tensor, ) -from ttnn.device import Device, open_device, close_device, manage_device +from ttnn.device import Device, open_device, close_device, manage_device, dump_device_memory_state from ttnn.core import ( has_storage_type_of, @@ -93,6 +93,7 @@ def get_bool_env_var(name, default): dump_tensor, unsqueeze_to_4D, squeeze, + clone, ) from ttnn.operations.matmul import ( @@ -158,7 +159,6 @@ def get_bool_env_var(name, default): atanh, logical_not, logit, - clone, signbit, ) diff --git a/ttnn/ttnn/device.py b/ttnn/ttnn/device.py index b89436584c13..6122e034f78a 100644 --- a/ttnn/ttnn/device.py +++ b/ttnn/ttnn/device.py @@ -57,4 +57,8 @@ def manage_device(*, device_id: int): close_device(device) +def dump_device_memory_state(device): 
+ ttl.device.DumpDeviceMemoryState(device) + + __all__ = [] diff --git a/ttnn/ttnn/operations/binary.py b/ttnn/ttnn/operations/binary.py index c311a7f13101..a8a2c2896e21 100644 --- a/ttnn/ttnn/operations/binary.py +++ b/ttnn/ttnn/operations/binary.py @@ -135,6 +135,7 @@ def add( input_tensor_b: Union[ttnn.Tensor, int, float], *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG, + dtype: Optional[ttnn.DataType] = None, ) -> ttnn.Tensor: r""" add(input_tensor_a: ttnn.Tensor, input_tensor_b: Union[ttnn.Tensor, int, float], *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor: @@ -164,7 +165,7 @@ def add( """ input_tensor_a = input_tensor_a.value input_tensor_b = input_tensor_b.value if isinstance(input_tensor_b, ttnn.Tensor) else input_tensor_b - output = ttnn._ttnn.operations.binary.add(input_tensor_a, input_tensor_b, memory_config=memory_config) + output = ttnn._ttnn.operations.binary.add(input_tensor_a, input_tensor_b, memory_config=memory_config, dtype=dtype) return ttnn.Tensor(output) diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 0ff74c4db480..25f2039da9d1 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -630,7 +630,7 @@ def _to_layout_validate_input_tensors(operation_name, input_tensor, *args, **kwa @ttnn.register_operation(name="ttnn.to_layout", validate_input_tensors=_to_layout_validate_input_tensors) -def to_layout(tensor, layout: ttnn.Layout): +def to_layout(tensor, layout: ttnn.Layout, dtype: ttnn.DataType = None): """ to_layout(tensor: ttnn.Tensor, layout: Layout) -> ttnn.Tensor @@ -684,13 +684,18 @@ def requires_padding_change(layout, shape): else: return False + if dtype is not None and (not is_on_device or layout is not ttnn.TILE_LAYOUT): + raise RuntimeError(f"Unsupported datatype conversion to {dtype}") + if not requires_padding_change(layout, tensor.shape): ttl_tensor = tensor.value if is_on_device: if layout == ttnn.ROW_MAJOR_LAYOUT: return ttnn.Tensor(ttl.tensor.untilize(ttl_tensor)) elif layout == ttnn.TILE_LAYOUT: - return ttnn.Tensor(ttl.tensor.tilize(ttl_tensor, output_mem_config=ttl_tensor.memory_config())) + return ttnn.Tensor( + ttl.tensor.tilize(ttl_tensor, output_mem_config=ttl_tensor.memory_config(), output_dtype=dtype) + ) else: raise RuntimeError(f"Unsupported layout: {layout}") else: @@ -764,10 +769,7 @@ def impl(input_tensor): if is_on_device: tensor = ttnn.Tensor( ttl.tensor.tilize_with_val_padding( - tensor.value, - batch_sizes + [padded_height, padded_width], - [0, 0, 0, 0], - 0, + tensor.value, batch_sizes + [padded_height, padded_width], [0, 0, 0, 0], 0, output_dtype=dtype ) ) else: @@ -788,6 +790,40 @@ def impl(tensor): raise RuntimeError(f"Unsupported output layout: {layout}") +def _clone_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): + ttnn.validate_input_tensor( + operation_name, + input_tensor, + ranks=(1, 2, 3, 4, 5), + dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32, ttnn.float32), + layouts=(ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT), + can_be_on_device=True, + can_be_on_cpu=False, + ) + + +@ttnn.register_operation(name="ttnn.clone", validate_input_tensors=_clone_validate_input_tensors) +def clone(tensor, memory_config: ttnn.MemoryConfig, dtype: ttnn.DataType): + """ + clone(tensor: ttnn.Tensor, memory_config: MemoryConfig, dtype: DataType) -> ttnn.Tensor + Clones the tensor by copying it with the given `memory_config`. Also converts the data type to `dtype`.
+ Note: clone does not change the layout of the tensor. + Args: + * :attr:`tensor`: the ttnn.Tensor + * :attr:`memory_config`: the `ttnn` memory config, DRAM_MEMORY_CONFIG or L1_MEMORY_CONFIG. + * :attr:`dtype`: the `ttnn` data type. + Example:: + >>> tensor = ttnn.to_device(ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16, layout=ttnn.TILE_LAYOUT)), device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + >>> output = ttnn.clone(tensor, ttnn.DRAM_MEMORY_CONFIG, ttnn.bfloat8_b) + """ + ttl_tensor = tensor.value + return ttnn.Tensor(ttl.tensor.clone(ttl_tensor, output_mem_config=memory_config, output_dtype=dtype)) + + def _torch_identity(input_tensor): input_tensor = to_torch(input_tensor) return input_tensor.clone() diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py index d1f62c82c2ee..213c59c5eac3 100644 --- a/ttnn/ttnn/operations/unary.py +++ b/ttnn/ttnn/operations/unary.py @@ -38,7 +38,6 @@ def _torch_unary(input_tensor: ttnn.Tensor, **_): "acosh": torch.acosh, "atanh": torch.atanh, "logical_not": torch.logical_not, - "clone": torch.clone, "signbit": torch.signbit, } torch_function = name_to_torch_function[name] @@ -123,7 +122,6 @@ def unary_function( ("acosh", ttl.tensor.acosh), ("atanh", ttl.tensor.atanh), ("logical_not", ttl.tensor.logical_not_unary), - ("clone", ttl.tensor.clone), ("signbit", ttl.tensor.signbit), ]
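For reference on the headline change of this patch: a 1x1 conv with stride 1 and no padding applies the same [C_in, C_out] linear map to every pixel, which is why conv1x1_as_matmul together with pre_process_input/post_process_output can lower it to a single matmul. The host-side PyTorch sketch below is illustrative only and is not part of the patch; the name conv1x1_as_matmul_reference and the variables x, wt, b are made up for the example, with shapes loosely modeled on the 1x1 case added to test_conv2d.py.

import torch

def conv1x1_as_matmul_reference(activation_nchw, weight, bias):
    # activation_nchw: [N, C_in, H, W], weight: [C_out, C_in, 1, 1], bias: [C_out]
    n, c_in, h, w = activation_nchw.shape
    c_out = weight.shape[0]
    # NCHW -> NHWC, then flatten to a [N*H*W, C_in] matrix
    # (roughly what pre_process_input does; it additionally pads C_in to a multiple of 16)
    act_matrix = activation_nchw.permute(0, 2, 3, 1).reshape(n * h * w, c_in)
    # A 1x1 / stride-1 / pad-0 conv is one matmul against the [C_in, C_out] weight matrix
    out_matrix = act_matrix @ weight.reshape(c_out, c_in).t() + bias
    # Fold [N*H*W, C_out] back to NCHW (roughly what post_process_output does)
    return out_matrix.reshape(n, h, w, c_out).permute(0, 3, 1, 2)

x = torch.randn(2, 960, 64, 64)
wt = torch.randn(320, 960, 1, 1)
b = torch.randn(320)
expected = torch.nn.functional.conv2d(x, wt, b, stride=1, padding=0)
assert torch.allclose(conv1x1_as_matmul_reference(x, wt, b), expected, rtol=1e-3, atol=1e-3)

On device, the torch.matmul step corresponds to ttl.operations.primary.matmul driven by the program config built in determine_1x1conv_as_matmul_config, which is why the sliding-window reader indices and the untilize-with-halo op are skipped for this case.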