#14179: unet and sd
shwetankTT committed Dec 23, 2024
1 parent de63abd commit c36cd1b
Showing 2 changed files with 68 additions and 0 deletions.
@@ -146,6 +146,41 @@ def __call__(
        if self.conv_config_override and "act_block_h" in self.conv_config_override:
            conv_config.act_block_h_override = self.conv_config_override["act_block_h"]

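        # Conv parameters shared by the prepare_conv_weights/prepare_conv_bias calls below.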
        conv_kwargs = {
            "input_layout": hidden_states.get_layout(),
            "in_channels": in_channels,
            "out_channels": self.out_channels,
            "batch_size": hidden_states.shape[0],
            "input_height": hidden_states.shape[1],
            "input_width": hidden_states.shape[2],
            "kernel_size": (3, 3),
            "stride": (self.stride, self.stride),
            "padding": (1, 1),
            "dilation": (1, 1),
            "groups": 1,
            "device": self.device,
            "conv_config": conv_config,
        }

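        # First call only: the weights/bias still live on host, so convert them to the
        # form conv2d expects and cache them on device for subsequent calls.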
        if not ttnn.is_tensor_storage_on_device(self.conv_weights):
            self.conv_weights = ttnn.prepare_conv_weights(
                weight_tensor=self.conv_weights,
                weights_format="OIHW",
                input_memory_config=hidden_states.memory_config(),
                **conv_kwargs,
            )
            self.conv_bias = (
                ttnn.prepare_conv_bias(
                    bias_tensor=self.conv_bias,
                    input_memory_config=hidden_states.memory_config(),
                    **conv_kwargs,
                )
                if self.conv_bias is not None
                else None
            )
            self.conv_weights = ttnn.to_device(self.conv_weights, self.device)
            if self.conv_bias is not None:
                self.conv_bias = ttnn.to_device(self.conv_bias, self.device)

        [hidden_states, [self.conv_weights, self.conv_bias]] = ttnn.conv2d(
            input_tensor=hidden_states,
            in_channels=self.in_channels,
models/experimental/functional_unet/tt/unet_shallow_ttnn.py (33 additions, 0 deletions)
@@ -147,6 +147,39 @@ def __init__(
        self.bias = ttnn.from_torch(bias, dtype=ttnn.float32, mesh_mapper=mesh_mapper)

    def __call__(self, x):
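        # Conv parameters shared by the weight/bias preparation calls below.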
        conv_kwargs = {
            "input_layout": x.get_layout(),
            "in_channels": self.in_channels,
            "out_channels": self.out_channels,
            "batch_size": self.batch_size,
            "input_height": self.input_height,
            "input_width": self.input_width,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "padding": self.padding,
            "dilation": [1, 1],
            "groups": self.groups,
            "device": self.device,
            "conv_config": self.conv_config,
        }

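        # Prepare host-side weights/bias once and move them to device; later calls
        # find them already on device and skip this block.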
        if not ttnn.is_tensor_storage_on_device(self.weight):
            self.weight = ttnn.prepare_conv_weights(
                weight_tensor=self.weight,
                weights_format="OIHW",
                input_memory_config=x.memory_config(),
                **conv_kwargs,
            )
            self.bias = ttnn.prepare_conv_bias(
                bias_tensor=self.bias,
                input_memory_config=x.memory_config(),
                **conv_kwargs,
            )
            self.weight = ttnn.to_device(self.weight, self.device)
            self.bias = ttnn.to_device(self.bias, self.device) if self.bias is not None else None

        x, [self.weight, self.bias] = ttnn.conv2d(
            input_tensor=x,
            weight_tensor=self.weight,
