Merge branch 'main' into asaigal/async_falcon_rebased
tt-asaigal authored Apr 16, 2024
2 parents 0e98d33 + 8cf2c67 commit 432e93b
Showing 11 changed files with 104 additions and 110 deletions.
@@ -88,7 +88,6 @@ def basic_transformer_block(
)
if use_ada_layer_norm_zero:
assert False, "AdaLayerNormZero not supported and not used in stable diffusion"
# ttnn.dump_device_memory_state(device)
ff_output = feedforward(config=config, hidden_states=norm_hidden_states, parameters=parameters.ff, device=device)

hidden_states = ttnn.add(ff_output, hidden_states)
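The line removed in this hunk was a commented-out call to ttnn's device-memory dump helper. As a rough sketch (ttnn.dump_device_memory_state itself appears in other hunks of this PR; the open/close calls around it are assumptions for a self-contained snippet), it can be invoked directly when chasing L1 allocation failures:

import ttnn

device = ttnn.open_device(device_id=0)
# Print the device's current memory allocation state; useful when an op in the
# diffusion pipeline fails with an L1 out-of-memory error.
ttnn.dump_device_memory_state(device)
ttnn.close_device(device)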
@@ -261,7 +261,6 @@ def time_sharded_attention(self, query, t_key, value, head_size):
ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR,
)
print(slice.memory_config())
program_config = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig(
compute_with_storage_grid_size=grid_size,
in0_block_w=2,
@@ -105,7 +105,7 @@ def __call__(
upcast_attention=upcast_attention,
)

output_states += (hidden_states,)
output_states += (ttnn.to_memory_config(hidden_states, ttnn.DRAM_MEMORY_CONFIG),)

if add_downsample is not None:
hidden_states = self.downsample_2d(
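The change above spills each saved down-block activation to DRAM before appending it to output_states, so the stashed copies stop pinning L1 while later layers run. A minimal sketch of the pattern, assuming hidden_states is an L1-resident activation and output_states is the tuple being accumulated as in the surrounding code:

import ttnn

output_states = ()
# Hand the saved copy to interleaved DRAM; the L1-resident tensor stays usable
# by the next op, and the DRAM copy is what the up-blocks consume later.
output_states += (ttnn.to_memory_config(hidden_states, ttnn.DRAM_MEMORY_CONFIG),)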
@@ -96,7 +96,6 @@ def __call__(
index=-1,
):
for i, (resnet, attention) in enumerate(zip(self.resnets, self.attentions)):
ttnn.dump_device_memory_state(self.device, prefix="in_uplock_")
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels

@@ -134,11 +133,7 @@ def __call__(
ttnn.clone, hidden_states, memory_config=ttnn.get_memory_config(hidden_states), dtype=ttnn.bfloat8_b
)
hidden_states = dealloc_input(ttnn.concat, [hidden_states, on_dev_res_hidden_states], dim=3)
# breakpoint()
ttnn.deallocate(on_dev_res_hidden_states)
# breakpoint()
# hidden_states = ttnn.reallocate(hidden_states)
# ttnn.dump_device_memory_state(self.device, prefix="after_reallocate_before_resnet")
hidden_states = resnet(
hidden_states,
temb=temb,
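The hunk above re-clones the running activation in bfloat8_b, concatenates the stashed down-block residual onto it, and frees the residual immediately. A condensed sketch of that sequence, assuming hidden_states and on_dev_res_hidden_states are tile-layout activations already on the device, and leaving out the dealloc_input wrapper (which additionally frees each op's inputs):

import ttnn

# Re-clone the activation in bfloat8_b to shrink its L1 footprint before the concat.
hidden_states = ttnn.clone(
    hidden_states,
    memory_config=ttnn.get_memory_config(hidden_states),
    dtype=ttnn.bfloat8_b,
)
# Concatenate the residual along the channel dimension (activations are kept as
# [N, 1, H*W, C] in this model).
hidden_states = ttnn.concat([hidden_states, on_dev_res_hidden_states], dim=3)
# The residual copy is dead after the concat, so release its buffer right away.
ttnn.deallocate(on_dev_res_hidden_states)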
@@ -16,6 +16,7 @@
pre_process_input,
post_process_output,
permute_conv_parameters,
weight_to_bfp8,
)
import time

@@ -209,6 +210,7 @@ def __init__(
self.groups = 32
if use_in_shortcut:
assert self.conv2.conv.output_sharded_memory_config == self.conv_shortcut.conv.output_sharded_memory_config

(
self.first_gn_expected_input_sharded_memory_config,
self.first_group_norm_core_grid,
@@ -328,6 +330,8 @@ def __init__(
device=device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
self.parameters.time_emb_proj.weight = weight_to_bfp8(self.parameters.time_emb_proj.weight)
self.parameters.time_emb_proj.bias = weight_to_bfp8(self.parameters.time_emb_proj.bias)

def __call__(
self,
@@ -356,63 +360,58 @@ def __call__(
nonlinearity = ttnn.silu

out_channels = in_channels if out_channels is None else out_channels
hidden_states = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT, use_multicore=True)

hidden_states = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT, use_multicore=True)
if ttnn.get_memory_config(hidden_states) != self.first_gn_expected_input_sharded_memory_config:
hidden_states = ttnn.reshape(
hidden_states, (self.conv2.batch_size, 1, self.conv2.input_height * self.conv2.input_width, in_channels)
)
hidden_states = ttnn.to_memory_config(hidden_states, self.first_gn_expected_input_sharded_memory_config)

if self.fallback_on_groupnorm:
hidden_states = ttnn.to_memory_config(hidden_states, ttnn.DRAM_MEMORY_CONFIG)
hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT, use_multicore=True)
hidden_states = ttnn.reshape(
hidden_states, (self.conv2.batch_size, self.conv2.input_height, self.conv2.input_width, in_channels)
)
hidden_states = ttnn.permute(hidden_states, (0, 3, 1, 2))
hidden_states = ttnn.operations.normalization._fallback_group_norm(
hidden_states,
num_groups=groups,
weight=self.parameters.norm1.weight,
bias=self.parameters.norm1.bias,
epsilon=eps,
)
hidden_states = pre_process_input(self.device, hidden_states)
hidden_states = ttnn.to_memory_config(hidden_states, ttnn.DRAM_MEMORY_CONFIG)
hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT, use_multicore=True)
else:
hidden_states = ttnn.group_norm(
hidden_states,
num_groups=groups,
input_mask=self.norm1_input_mask,
weight=self.parameters.norm1.weight,
bias=self.parameters.norm1.bias,
epsilon=eps,
memory_config=ttnn.get_memory_config(hidden_states),
core_grid=self.first_group_norm_core_grid,
)
hidden_states = ttnn.to_memory_config(hidden_states, ttnn.L1_MEMORY_CONFIG)
hidden_states = ttnn.reshape(
hidden_states,
(1, 1, self.conv2.batch_size * self.conv2.input_height * self.conv2.input_width, in_channels),
)
hidden_states = nonlinearity(hidden_states, memory_config=ttnn.get_memory_config(hidden_states))
hidden_states = ttnn.group_norm(
hidden_states,
num_groups=groups,
input_mask=self.norm1_input_mask,
weight=self.parameters.norm1.weight,
bias=self.parameters.norm1.bias,
epsilon=eps,
memory_config=ttnn.get_memory_config(hidden_states),
core_grid=self.first_group_norm_core_grid,
dtype=ttnn.bfloat8_b,
)
hidden_states = ttnn.reshape(
hidden_states,
(1, 1, self.conv2.batch_size * self.conv2.input_height * self.conv2.input_width, in_channels),
)

if up:
assert False, "Up block within residual block is not implemented!"
elif down:
assert False, "Down block within residual block is not implemented"

conv1_split_chunks = len(self.conv1s)
if conv1_split_chunks > 1:
if conv1_split_chunks == 1:
hidden_states = ttnn.experimental.tensor.sharded_to_interleaved(
hidden_states, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype
)
hidden_states = ttnn.experimental.tensor.interleaved_to_sharded(
hidden_states, self.conv1s[0].conv.input_sharded_memory_config, hidden_states.dtype
)
hidden_states = nonlinearity(hidden_states, memory_config=ttnn.get_memory_config(hidden_states))
hidden_states = self.conv1s[0](hidden_states)
else:
split_hidden_states = []
output_tensor_start_width_dim = 0
in_channels = self.parameters.conv1.weight.shape[1]
split_input_channels = in_channels // conv1_split_chunks

# unpad sharded causes output mismatch
hidden_states = ttnn.experimental.tensor.sharded_to_interleaved(
hidden_states, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype
)
output_tensor_end_width_dim = split_input_channels
for i in range(conv1_split_chunks):
# TODO: Can we replace this with interleaved_to_sharded_partial
split_hidden_states.append(
ttnn.experimental.tensor.unpad(
hidden_states,
@@ -423,18 +422,19 @@
hidden_states.shape[2] - 1,
output_tensor_end_width_dim - 1,
],
# output_mem_config=ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1)
)
)
output_tensor_start_width_dim += split_input_channels
output_tensor_end_width_dim += split_input_channels

if conv1_split_chunks == 1:
hidden_states = ttnn.to_memory_config(hidden_states, self.conv1s[0].conv.input_sharded_memory_config)
hidden_states = self.conv1s[0](hidden_states)
else:
for i in range(conv1_split_chunks):
split_hidden_states[i] = ttnn.to_memory_config(
split_hidden_states[i], self.conv1s[i].conv.input_sharded_memory_config
split_hidden_states[i] = ttnn.experimental.tensor.interleaved_to_sharded(
split_hidden_states[i],
self.conv1s[i].conv.input_sharded_memory_config,
split_hidden_states[i].dtype,
)
split_hidden_states[i] = nonlinearity(
split_hidden_states[i], memory_config=ttnn.get_memory_config(split_hidden_states[i])
)
split_hidden_states[i] = self.conv1s[i](split_hidden_states[i])
if i != 0:
@@ -461,82 +461,63 @@
self.parameters.time_emb_proj.weight,
bias=self.parameters.time_emb_proj.bias,
core_grid=temb.device().core_grid,
dtype=ttnn.bfloat8_b,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
# temb = ttnn.permute(temb, (2, 0, 1, 3))

if temb is not None and time_embedding_norm == "default":
hidden_states = ttnn.to_memory_config(hidden_states, ttnn.L1_MEMORY_CONFIG)
hidden_states = ttnn.clone(
hidden_states, memory_config=ttnn.get_memory_config(hidden_states), dtype=ttnn.bfloat16
)
hidden_states = ttnn.reshape(
hidden_states,
(self.conv2.batch_size, 1, self.conv2.input_height * self.conv2.input_width, out_channels),
)
hidden_states = hidden_states + temb
hidden_states = ttnn.add(hidden_states, temb, memory_config=ttnn.L1_MEMORY_CONFIG)

# TODO: Reshape happening twice
hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT, use_multicore=True)
hidden_states = ttnn.reshape(
hidden_states = ttnn.to_memory_config(hidden_states, self.second_gn_expected_input_sharded_memory_config)
hidden_states = ttnn.group_norm(
hidden_states,
(self.conv2.batch_size, 1, self.conv2.input_height * self.conv2.input_width, out_channels),
num_groups=groups,
input_mask=self.norm2_input_mask,
weight=self.parameters.norm2.weight,
bias=self.parameters.norm2.bias,
epsilon=eps,
memory_config=self.second_gn_expected_input_sharded_memory_config,
core_grid=self.second_group_norm_core_grid,
dtype=ttnn.bfloat8_b,
)
if self.fallback_on_groupnorm:
hidden_states = ttnn.to_memory_config(hidden_states, ttnn.L1_MEMORY_CONFIG)
hidden_states = ttnn.reshape(
hidden_states,
(self.conv1s[0].batch_size, self.conv1s[0].input_height, self.conv1s[0].input_width, out_channels),
)
hidden_states = ttnn.permute(hidden_states, (0, 3, 1, 2))
hidden_states = ttnn.operations.normalization._fallback_group_norm(
hidden_states,
num_groups=groups,
weight=self.parameters.norm2.weight,
bias=self.parameters.norm2.bias,
epsilon=eps,
)

hidden_states = pre_process_input(self.device, hidden_states)
else:
hidden_states = ttnn.to_memory_config(hidden_states, self.second_gn_expected_input_sharded_memory_config)
hidden_states = ttnn.group_norm(
hidden_states,
num_groups=groups,
input_mask=self.norm2_input_mask,
weight=self.parameters.norm2.weight,
bias=self.parameters.norm2.bias,
epsilon=eps,
memory_config=self.second_gn_expected_input_sharded_memory_config,
core_grid=self.second_group_norm_core_grid,
)
hidden_states = ttnn.to_memory_config(hidden_states, ttnn.L1_MEMORY_CONFIG)
hidden_states = ttnn.reshape(
hidden_states,
(1, 1, self.conv2.batch_size * self.conv2.input_height * self.conv2.input_width, out_channels),
)

hidden_states = ttnn.experimental.tensor.sharded_to_interleaved(
hidden_states, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype
)
hidden_states = ttnn.experimental.tensor.interleaved_to_sharded(
hidden_states, self.conv2.conv.input_sharded_memory_config, hidden_states.dtype
)

hidden_states = nonlinearity(hidden_states, memory_config=ttnn.get_memory_config(hidden_states))

hidden_states = ttnn.to_memory_config(hidden_states, self.conv2.conv.input_sharded_memory_config)
hidden_states = self.conv2(hidden_states)
use_in_shortcut = in_channels != out_channels if use_in_shortcut is None else use_in_shortcut

if use_in_shortcut:
if ttnn.get_memory_config(input_tensor) != self.conv_shortcut.conv.input_sharded_memory_config:
input_tensor = ttnn.to_memory_config(input_tensor, self.conv_shortcut.conv.input_sharded_memory_config)
# TODO: Once reshard fix is in, store input tensor in sharded
if input_tensor.memory_config().is_sharded():
input_tensor = ttnn.experimental.tensor.sharded_to_interleaved(
input_tensor, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype
)
input_tensor = ttnn.experimental.tensor.interleaved_to_sharded(
input_tensor, self.conv_shortcut.conv.input_sharded_memory_config, hidden_states.dtype
)
input_tensor = self.conv_shortcut(input_tensor)

if ttnn.get_memory_config(input_tensor) != ttnn.get_memory_config(hidden_states):
input_tensor = ttnn.to_memory_config(input_tensor, ttnn.get_memory_config(hidden_states))
output_tensor = ttnn.add(input_tensor, hidden_states, memory_config=ttnn.L1_MEMORY_CONFIG)

if output_scale_factor != 1.0:
assert False # Do we need this?
output_sc_recip = 1 / output_scale_factor
output_sc_recip = ttnn.from_torch(
torch.full([1, 1, 1, 1], output_sc_recip), layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b
)
output_sc_recip = ttnn.to_device(output_sc_recip, self.device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_tensor = ttnn.mul(output_tensor, output_sc_recip, memory_config=ttnn.DRAM_MEMORY_CONFIG)
output_tensor = ttnn.add(input_tensor, hidden_states, memory_config=hidden_states.memory_config())

ttnn.deallocate(hidden_states)
output_tensor = ttnn.reallocate(output_tensor)
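The hunks in this file converge on ttnn's sharded group norm followed by SiLU, then a bounce through interleaved L1 to reach the next conv's expected sharding. A condensed sketch of that sequence, using only calls visible in the hunks above and assuming hidden_states already sits in the group norm's expected sharded memory config, with groups, eps, norm_input_mask, core_grid, parameters, and conv_input_sharded_config standing in for the block's precomputed attributes:

import ttnn

hidden_states = ttnn.group_norm(
    hidden_states,
    num_groups=groups,  # 32 in this model
    input_mask=norm_input_mask,  # precomputed mask for the sharded layout
    weight=parameters.norm1.weight,
    bias=parameters.norm1.bias,
    epsilon=eps,
    memory_config=ttnn.get_memory_config(hidden_states),  # keep the same sharded config for the output
    core_grid=core_grid,
    dtype=ttnn.bfloat8_b,  # keep the normalized activation in bfp8
)
hidden_states = ttnn.silu(hidden_states, memory_config=ttnn.get_memory_config(hidden_states))

# Reshard for the following conv by bouncing through interleaved L1 rather than
# resharding directly, the pattern these hunks switch to.
hidden_states = ttnn.experimental.tensor.sharded_to_interleaved(
    hidden_states, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype
)
hidden_states = ttnn.experimental.tensor.interleaved_to_sharded(
    hidden_states, conv_input_sharded_config, hidden_states.dtype
)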
@@ -16,6 +16,7 @@
pre_process_input,
pad_group_norm_weight,
permute_conv_parameters,
dealloc_input,
)


@@ -225,8 +226,6 @@ def __call__(
height = self.input_height
width = self.input_width

if ttnn.get_memory_config(hidden_states) != self.proj_in.conv.input_sharded_memory_config:
hidden_states = ttnn.to_memory_config(hidden_states, self.proj_in.conv.input_sharded_memory_config)
residual = hidden_states
spilled_residual = False
if spilled_residual:
@@ -317,18 +316,22 @@ def __call__(
assert False
# hidden_states = ttnn.to_memory_config(hidden_states, self.proj_out.conv.input_sharded_memory_config)
hidden_states = self.proj_out(hidden_states)
if spilled_residual:
if ttnn.get_memory_config(residual) != self.proj_out.conv.input_sharded_memory_config:
residual = ttnn.to_memory_config(residual, self.proj_out.conv.input_sharded_memory_config)
if output_bfloat16:
hidden_states = ttnn.add(
hidden_states = dealloc_input(
ttnn.add,
hidden_states,
residual,
dtype=ttnn.bfloat16,
memory_config=hidden_states.memory_config(),
)
else:
hidden_states = ttnn.add(
hidden_states = dealloc_input(
ttnn.add,
hidden_states,
residual,
memory_config=hidden_states.memory_config(),
)
else:
hidden_states = ttnn.to_device(hidden_states, self.device)
@@ -347,4 +350,5 @@

if not return_dict:
return (hidden_states,)
hidden_states = ttnn.reallocate(hidden_states)
return hidden_states
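dealloc_input, imported at the top of this file and used for the residual adds above, is a repo helper whose body is not part of this diff. A hypothetical stand-in, assuming it does nothing more than run the op and then free its input tensors, might look like:

import ttnn

def dealloc_input_sketch(op, *args, **kwargs):
    # Hypothetical sketch only: run the op, then release every input tensor,
    # including tensors nested inside a list (as in the ttnn.concat usage in
    # the up-block hunk of this PR).
    output = op(*args, **kwargs)
    for arg in args:
        tensors = arg if isinstance(arg, (list, tuple)) else [arg]
        for tensor in tensors:
            ttnn.deallocate(tensor)
    return output

Used this way, the residual add above frees both operands as soon as the sum is materialized, which keeps peak L1 usage down between the output projection and the return.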
@@ -12,5 +12,7 @@ def upsample_nearest2d(input, scale_factor=2.0):
# input is in N, 1, HW, C, upsample expects, [N, H, W, C]
# set h_scale to 1, w_scale to scale_factor, c_scale to 1
# scale_factor = (1, scale_factor*2, 1)
if input.is_sharded():
input = ttnn.to_memory_config(input, ttnn.L1_MEMORY_CONFIG)
up_output = ttnn.upsample(input, scale_factor)
return up_output
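A short usage sketch of the guarded helper above, assuming hidden_states is a sharded activation produced by the preceding up-block; the new is_sharded() check converts it to interleaved L1 before ttnn.upsample runs:

# hidden_states: [N, 1, H*W, C] activation, possibly height-sharded
upsampled = upsample_nearest2d(hidden_states, scale_factor=2.0)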