Skip to content

Commit

Permalink
#5153: TNN Support for 1x1 Conv as matmul. Enabled TNN convs in SD
Browse files Browse the repository at this point in the history
Unet. Fixes.
  • Loading branch information
tt-nshanker committed Feb 21, 2024
1 parent 69893cb commit a7c1831
Show file tree
Hide file tree
Showing 32 changed files with 1,071 additions and 280 deletions.
2 changes: 2 additions & 0 deletions models/demos/resnet/tt/metalResnetBlock50.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,7 @@ def downsample_conv_op_with_formatting(x):
output_dtype=model_config["ACTIVATIONS_DTYPE"],
math_fidelity=model_config["MATH_FIDELITY"],
move_utwh_output=move_utwh_output,
deallocate_activation=True,
)
else:
self.conv2 = resnet50_optimized_conv(
Expand Down Expand Up @@ -1500,6 +1501,7 @@ def __init__(
output_dtype=model_config["ACTIVATIONS_DTYPE"],
math_fidelity=model_config["MATH_FIDELITY"],
use_shallow_conv_variant=True,
deallocate_activation=True,
)
self.first_conv_op_params = sliding_window_op_params
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def basic_transformer_block(
)
if use_ada_layer_norm_zero:
assert False, "AdaLayerNormZero not supported and not used in stable diffusion"

ff_output = feedforward(config=config, hidden_states=norm_hidden_states, parameters=parameters.ff)
# ttnn.dump_device_memory_state(device)
ff_output = feedforward(config=config, hidden_states=norm_hidden_states, parameters=parameters.ff, device=device)

hidden_states = ttnn.add(ff_output, hidden_states)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def cross_attention_down_block_2d(
*,
parameters,
device,
reader_patterns_cache=None,
):
output_states = ()

Expand All @@ -54,6 +55,7 @@ def cross_attention_down_block_2d(
parameters=resnet,
device=device,
use_in_shortcut=use_in_shortcut,
reader_patterns_cache=reader_patterns_cache,
eps=resnet_eps,
groups=resnet_groups,
time_embedding_norm=resnet_time_scale_shift,
Expand All @@ -76,7 +78,8 @@ def cross_attention_down_block_2d(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
device=device,
)
reader_patterns_cache=reader_patterns_cache,
)

output_states += (hidden_states,)

Expand All @@ -89,6 +92,7 @@ def cross_attention_down_block_2d(
device=device,
parameters=parameters.downsamplers[0],
use_conv=True,
reader_patterns_cache=reader_patterns_cache,
)
output_states += (hidden_states,)
return hidden_states, output_states
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
import ttnn

from typing import Optional, Dict
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_upsample_2d import upsample2d
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_resnetblock2d import resnetBlock2D
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_transformer_2d import transformer_2d_model
Expand Down Expand Up @@ -57,6 +57,7 @@ def cross_attention_upblock2d(
cross_attention_dim=1280,
attn_num_head_channels=1,
only_cross_attention: bool = False,
reader_patterns_cache: Optional[Dict] = None,
):
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
Expand Down Expand Up @@ -88,6 +89,7 @@ def cross_attention_upblock2d(
pre_norm=resnet_pre_norm,
non_linearity=resnet_act_fn,
device=device,
reader_patterns_cache=reader_patterns_cache,
)
if not dual_cross_attention:
hidden_states = transformer_2d_model(
Expand All @@ -110,11 +112,19 @@ def cross_attention_upblock2d(
device=device,
upcast_attention=upcast_attention,
cross_attention_dim=cross_attention_dim,
reader_patterns_cache=reader_patterns_cache,
)
else:
assert False, "We do not support Dual Transformer2DModel"

if add_upsample:
hidden_states = upsample2d(device, hidden_states, parameters.upsamplers[0], out_channels, out_channels)
hidden_states = upsample2d(
device,
hidden_states,
parameters.upsamplers[0],
out_channels,
out_channels,
reader_patterns_cache=reader_patterns_cache,
)

return hidden_states
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

# SPDX-License-Identifier: Apache-2.0


import ttnn
import torch
from typing import Optional
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_resnetblock2d import resnetBlock2D
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_downsample_2d import downsample_2d

Expand All @@ -25,6 +27,8 @@ def downblock2d(
add_downsample=False,
downsample_padding=1,
parameters=None,
reader_patterns_cache: Optional[dict] = None,
dtype: Optional[ttnn.DataType] = None,
):
output_states = ()
for i in range(num_layers):
Expand All @@ -45,6 +49,8 @@ def downblock2d(
eps=resnet_eps,
up=False,
down=False,
dtype=dtype,
reader_patterns_cache=reader_patterns_cache,
)

hidden_states = resnet
Expand All @@ -61,6 +67,8 @@ def downblock2d(
name="op",
parameters=parameters.downsamplers[0],
device=device,
dtype=dtype,
reader_patterns_cache=reader_patterns_cache,
)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@

import ttnn
import torch
from typing import Optional
import torch.nn as nn
from tt_lib.fallback_ops import fallback_ops
from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_utility_functions import (
run_ttnn_conv_with_pre_and_post_tensor_formatting,
)
import math


def permute_conv_parameters(weight, bias):
Expand All @@ -19,6 +24,21 @@ def permute_conv_parameters(weight, bias):
return weight, bias


config_override = {
(320, 320, 64, 64): {"act_block_h": 64},
(640, 640, 32, 32): {"act_block_h": 64},
(640, 1920, 32, 32): {"act_block_h": 32},
(640, 1280, 32, 32): {"act_block_h": 32},
(1280, 1920, 16, 16): {"act_block_h": 32},
(1280, 1280, 32, 32): {"act_block_h": 32},
(320, 960, 64, 64): {"act_block_h": 32},
(640, 960, 32, 32): {"act_block_h": 32},
(320, 640, 64, 64): {"act_block_h": 32},
(640, 640, 64, 64): {"act_block_h": 64},
(640, 320, 64, 64): {"act_block_h": 64},
}


def downsample_2d(
in_channels,
hidden_states,
Expand All @@ -27,23 +47,61 @@ def downsample_2d(
use_conv=False,
out_channels=None,
padding=1,
reader_patterns_cache: Optional[dict] = None,
dtype: Optional[ttnn.DataType] = None,
):
stride = 2

parameters.conv.weight, parameters.conv.bias = permute_conv_parameters(parameters.conv.weight, parameters.conv.bias)
parameters.conv.weight = torch_to_tt_tensor_rm(parameters.conv.weight, device, put_on_device=False)
parameters.conv.bias = torch_to_tt_tensor_rm(parameters.conv.bias, device, put_on_device=False)
conv_on_device = reader_patterns_cache is not None
batch_size = hidden_states.shape[0]
input_height = hidden_states.shape[2]
input_width = hidden_states.shape[3]

if use_conv:
conv = fallback_ops.Conv2d(
parameters.conv.weight,
parameters.conv.bias,
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
)
if conv_on_device:
parameters.conv.bias = torch.reshape(parameters.conv.bias, (1, 1, 1, parameters.conv.bias.shape[-1]))
tt_weight_tensor = ttnn.from_torch(parameters.conv.weight, ttnn.float32)
tt_bias_tensor = ttnn.from_torch(parameters.conv.bias, ttnn.float32)
# breakpoint()
out_channels = parameters.conv.weight.shape[0]
in_channels = parameters.conv.weight.shape[1]
conv_config_override = {}
if (out_channels, in_channels, input_height, input_width) in config_override:
conv_config_override = config_override[(out_channels, in_channels, input_height, input_width)]
conv = ttnn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(stride, stride),
padding=(padding, padding),
dtype=ttnn.bfloat8_b,
device=device,
use_1d_systolic_array=True if in_channels < 320 else False,
batch_size=batch_size,
input_height=input_height,
input_width=input_width,
reader_patterns_cache=reader_patterns_cache,
weight=tt_weight_tensor,
bias=tt_bias_tensor,
math_fidelity=ttnn.MathFidelity.LoFi,
weights_dtype=ttnn.bfloat8_b,
conv_blocking_and_parallelization_config_override=conv_config_override,
use_shallow_conv_variant=False,
enable_auto_formatting=True,
)
else:
parameters.conv.weight = torch_to_tt_tensor_rm(parameters.conv.weight, device, put_on_device=False)
parameters.conv.bias = torch_to_tt_tensor_rm(parameters.conv.bias, device, put_on_device=False)
conv = fallback_ops.Conv2d(
parameters.conv.weight,
parameters.conv.bias,
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
)

else:
assert in_channels == out_channels
Expand All @@ -58,16 +116,26 @@ def downsample_2d(

assert hidden_states.shape[1] == in_channels

hidden_states = ttnn.to_torch(ttnn.from_device(ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT)))

hidden_states = torch_to_tt_tensor_rm(hidden_states, device)
hidden_states = conv(hidden_states)
hidden_states = tt_to_torch_tensor(hidden_states)

hidden_states = ttnn.to_device(
ttnn.to_layout(ttnn.from_torch(hidden_states, ttnn.bfloat16), ttnn.TILE_LAYOUT),
device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
if conv_on_device:
hidden_states = run_ttnn_conv_with_pre_and_post_tensor_formatting(
device,
conv,
hidden_states,
batch_size,
math.ceil(input_height / 2),
math.ceil(input_width / 2),
out_channels,
)
else:
hidden_states = ttnn.to_torch(ttnn.from_device(ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT)))
hidden_states = torch_to_tt_tensor_rm(hidden_states, device)
hidden_states = conv(hidden_states)
hidden_states = tt_to_torch_tensor(hidden_states)

hidden_states = ttnn.to_device(
ttnn.to_layout(ttnn.from_torch(hidden_states, ttnn.bfloat16), ttnn.TILE_LAYOUT),
device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)

return hidden_states
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from models.experimental.functional_stable_diffusion.tt.ttnn_functional_geglu import geglu


def feedforward(config, hidden_states, parameters):
act = geglu(config, hidden_states, parameters.net[0])
def feedforward(config, hidden_states, parameters, device):
act = geglu(config, hidden_states, parameters.net[0], device)
output = act @ parameters.net[2].weight
output = ttnn.add(output, parameters.net[2].bias, memory_config=ttnn.L1_MEMORY_CONFIG)
return output
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import ttnn


def geglu(config, hidden_states, parameters):
def geglu(config, hidden_states, parameters, device):
output = ttnn.matmul(hidden_states, parameters.proj.weight)

output = ttnn.add(output, parameters.proj.bias, memory_config=ttnn.L1_MEMORY_CONFIG)

hidden_states, gate = ttnn.split(output, split_size=output.shape[-1] // 2, dim=-1)
Expand Down
Loading

0 comments on commit a7c1831

Please sign in to comment.