From c36cd1bd383d1774fce6c29a55278ed24cd86cd6 Mon Sep 17 00:00:00 2001
From: Shwetank Singh
Date: Mon, 23 Dec 2024 08:50:05 +0000
Subject: [PATCH] #14179: Prepare conv weights/bias on device for UNet and SD downsample

---
 .../ttnn_functional_downsample_2d_new_conv.py | 35 +++++++++++++++++++
 .../functional_unet/tt/unet_shallow_ttnn.py   | 31 ++++++++++++++++++
 2 files changed, 66 insertions(+)

diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
index 570d2457f1ae..e1864f9acef3 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
@@ -146,6 +146,41 @@ def __call__(
         if self.conv_config_override and "act_block_h" in self.conv_config_override:
             conv_config.act_block_h_override = self.conv_config_override["act_block_h"]
 
+        conv_kwargs = {
+            "input_layout": hidden_states.get_layout(),
+            "in_channels": in_channels,
+            "out_channels": self.out_channels,
+            "batch_size": hidden_states.shape[0],
+            "input_height": hidden_states.shape[1],
+            "input_width": hidden_states.shape[2],
+            "kernel_size": (3, 3),
+            "stride": (self.stride, self.stride),
+            "padding": (1, 1),
+            "dilation": (1, 1),
+            "groups": 1,
+            "device": self.device,
+            "conv_config": conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.conv_weights):
+            self.conv_weights = ttnn.prepare_conv_weights(
+                weight_tensor=self.conv_weights,
+                weights_format="OIHW",
+                input_memory_config=hidden_states.memory_config(),
+                **conv_kwargs,
+            )
+            self.conv_bias = (
+                ttnn.prepare_conv_bias(
+                    bias_tensor=self.conv_bias,
+                    input_memory_config=hidden_states.memory_config(),
+                    **conv_kwargs,
+                )
+                if self.conv_bias is not None
+                else None
+            )
+            self.conv_weights = ttnn.to_device(self.conv_weights, self.device)
+            self.conv_bias = ttnn.to_device(self.conv_bias, self.device) if self.conv_bias is not None else None
+
         [hidden_states, [self.conv_weights, self.conv_bias]] = ttnn.conv2d(
             input_tensor=hidden_states,
             in_channels=self.in_channels,
diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py
index 8a5157d51dc3..bf6903dd75b7 100644
--- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py
+++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py
@@ -147,6 +147,37 @@ def __init__(
         self.bias = ttnn.from_torch(bias, dtype=ttnn.float32, mesh_mapper=mesh_mapper)
 
     def __call__(self, x):
+        conv_kwargs = {
+            "input_layout": x.get_layout(),
+            "in_channels": self.in_channels,
+            "out_channels": self.out_channels,
+            "batch_size": self.batch_size,
+            "input_height": self.input_height,
+            "input_width": self.input_width,
+            "kernel_size": self.kernel_size,
+            "stride": self.stride,
+            "padding": self.padding,
+            "dilation": [1, 1],
+            "groups": self.groups,
+            "device": self.device,
+            "conv_config": self.conv_config,
+        }
+
+        if not ttnn.is_tensor_storage_on_device(self.weight):
+            self.weight = ttnn.prepare_conv_weights(
+                weight_tensor=self.weight,
+                weights_format="OIHW",
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            )
+            self.bias = ttnn.prepare_conv_bias(
+                bias_tensor=self.bias,
+                input_memory_config=x.memory_config(),
+                **conv_kwargs,
+            )
+            self.weight = ttnn.to_device(self.weight, self.device)
+            self.bias = ttnn.to_device(self.bias, self.device) if self.bias is not None else None
+
         x, [self.weight, self.bias] = ttnn.conv2d(
             input_tensor=x,
             weight_tensor=self.weight,
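Both hunks apply the same prepare-once pattern: on the first forward pass the host-resident conv weights and bias are converted to the layout conv2d expects, moved onto the device, and cached on the module, so every later call skips the conversion. A minimal standalone sketch of that pattern, using only the ttnn calls that appear in the hunks above (the function name and argument list here are illustrative, not part of the patch):

    import ttnn

    def prepare_conv_tensors(device, x, weight, bias, conv_kwargs):
        # First call: weight/bias are still host tensors, so lay them out for
        # conv2d and move them onto the device. Later calls find device-resident
        # tensors and skip this block entirely.
        if not ttnn.is_tensor_storage_on_device(weight):
            weight = ttnn.prepare_conv_weights(
                weight_tensor=weight,
                weights_format="OIHW",
                input_memory_config=x.memory_config(),
                **conv_kwargs,
            )
            if bias is not None:
                bias = ttnn.prepare_conv_bias(
                    bias_tensor=bias,
                    input_memory_config=x.memory_config(),
                    **conv_kwargs,
                )
            weight = ttnn.to_device(weight, device)
            bias = ttnn.to_device(bias, device) if bias is not None else None
        return weight, bias  # the caller stores these back on the module

The returned tensors are then passed to ttnn.conv2d exactly as in the hunks above; writing them back to self.conv_weights / self.weight is what turns the preparation into a one-time cost instead of a per-call one.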