Commit
#14179: resnet fix
shwetankTT committed Dec 23, 2024
1 parent ee1e6cb commit de63abd
Showing 1 changed file with 167 additions and 36 deletions.
203 changes: 167 additions & 36 deletions models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
@@ -231,6 +231,48 @@ def __call__(
# conv1 is 1x1 conv
logger.debug(f"Running conv1")
module_input_height = input_height
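# conv1: build the Conv2dConfig and the shared conv_kwargs once; the same kwargs feed both weight preparation and the conv2d call below.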
conv_config = ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
)
conv_kwargs = {
"input_layout": x.get_layout(),
"in_channels": self.conv1_input_channels,
"out_channels": self.conv1_output_channels,
"batch_size": batch_size,
"input_height": input_height,
"input_width": input_width,
"kernel_size": (1, 1),
"stride": (1, 1),
"padding": (0, 0),
"dilation": (1, 1),
"groups": 1,
"device": device,
"conv_config": conv_config,
}

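# On the first call the conv1 weights/bias are still host tensors; prepare them once and cache them on device.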
if not ttnn.is_tensor_storage_on_device(self.conv1_weight_tensor):
logger.debug("Preparing conv1 weights")
self.conv1_weight_tensor = ttnn.prepare_conv_weights(
weight_tensor=self.conv1_weight_tensor,
weights_format="OIHW",
input_memory_config=x.memory_config(),
**conv_kwargs,
)
if self.conv1_bias_tensor is not None:
self.conv1_bias_tensor = ttnn.prepare_conv_bias(
bias_tensor=self.conv1_bias_tensor,
input_memory_config=x.memory_config(),
**conv_kwargs,
)
self.conv1_weight_tensor = ttnn.to_device(self.conv1_weight_tensor, device)
self.conv1_bias_tensor = ttnn.to_device(self.conv1_bias_tensor, device) if self.conv1_bias_tensor else None

out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.conv1_weight_tensor,
@@ -244,16 +286,7 @@ def __call__(
batch_size=batch_size,
input_height=input_height,
input_width=input_width,
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
),
conv_config=conv_config,
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(),
math_fidelity=self.model_config["MATH_FIDELITY"],
@@ -317,6 +350,55 @@ def __call__(

reallocate_halo_output = batch_size == 20
logger.debug(f"Running conv2")
conv_config = ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
activation="relu",
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
act_block_h_override=act_block_h_override,
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
enable_act_double_buffer=enable_act_double_buffer,
enable_weights_double_buffer=True,
enable_split_reader=enable_split_reader,
enable_subblock_padding=enable_subblock_padding,
)
conv_kwargs = {
"input_layout": x.get_layout(),
"in_channels": self.conv2_input_channels,
"out_channels": self.conv2_output_channels,
"batch_size": batch_size,
"input_height": input_height,
"input_width": input_width,
"kernel_size": (3, 3),
"stride": (self.stride, self.stride),
"padding": (1, 1),
"dilation": (1, 1),
"groups": 1,
"device": device,
"conv_config": conv_config,
}

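# Same first-call preparation and device caching for the conv2 weights/bias.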
if not ttnn.is_tensor_storage_on_device(self.conv2_weight_tensor):
logger.debug("Preparing conv2 weights")
self.conv2_weight_tensor = ttnn.prepare_conv_weights(
weight_tensor=self.conv2_weight_tensor,
weights_format="OIHW",
input_memory_config=x.memory_config(),
**conv_kwargs,
)
if self.conv2_bias_tensor is not None:
self.conv2_bias_tensor = ttnn.prepare_conv_bias(
bias_tensor=self.conv2_bias_tensor,
input_memory_config=x.memory_config(),
**conv_kwargs,
)
self.conv2_weight_tensor = ttnn.to_device(self.conv2_weight_tensor, device)
self.conv2_bias_tensor = ttnn.to_device(self.conv2_bias_tensor, device) if self.conv2_bias_tensor else None

out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv2_weight_tensor,
@@ -330,23 +412,7 @@ def __call__(
batch_size=batch_size,
input_height=input_height,
input_width=input_width,
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
activation="relu",
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
act_block_h_override=act_block_h_override,
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
enable_act_double_buffer=enable_act_double_buffer,
enable_weights_double_buffer=True,
enable_split_reader=enable_split_reader,
enable_subblock_padding=enable_subblock_padding,
),
conv_config=conv_config,
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(),
math_fidelity=self.model_config["MATH_FIDELITY"],
@@ -373,6 +439,47 @@ def __call__(

# conv3 is 1x1 conv
logger.debug(f"Running conv3")
conv_config = ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
)
conv_kwargs = {
"input_layout": x.get_layout(),
"in_channels": self.conv3_input_channels,
"out_channels": self.conv3_output_channels,
"batch_size": batch_size,
"input_height": input_height,
"input_width": input_width,
"kernel_size": (1, 1),
"stride": (1, 1),
"padding": (0, 0),
"dilation": (1, 1),
"groups": 1,
"device": device,
"conv_config": conv_config,
}

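# Same first-call preparation and device caching for the conv3 weights/bias.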
if not ttnn.is_tensor_storage_on_device(self.conv3_weight_tensor):
logger.debug("Preparing conv3 weights")
self.conv3_weight_tensor = ttnn.prepare_conv_weights(
weight_tensor=self.conv3_weight_tensor,
weights_format="OIHW",
input_memory_config=x.memory_config(),
**conv_kwargs,
)
if self.conv3_bias_tensor is not None:
self.conv3_bias_tensor = ttnn.prepare_conv_bias(
bias_tensor=self.conv3_bias_tensor,
input_memory_config=x.memory_config(),
**conv_kwargs,
)
self.conv3_weight_tensor = ttnn.to_device(self.conv3_weight_tensor, device)
self.conv3_bias_tensor = ttnn.to_device(self.conv3_bias_tensor, device) if self.conv3_bias_tensor else None

out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv3_weight_tensor,
@@ -386,15 +493,7 @@ def __call__(
batch_size=batch_size,
input_height=input_height,
input_width=input_width,
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
),
conv_config=conv_config,
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(),
math_fidelity=self.model_config["MATH_FIDELITY"],
@@ -742,6 +841,38 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
logger.debug(f"==== first conv")

# first conv
conv_kwargs = {
"input_layout": fold_output_tensor.get_layout(),
"in_channels": self.conv1_input_channels,
"out_channels": self.conv1_output_channels,
"batch_size": self.batch_size,
"input_height": self.conv1_input_height,
"input_width": self.conv1_input_width,
"kernel_size": self.conv1_kernel_size,
"stride": self.conv1_stride,
"padding": self.conv1_padding,
"dilation": (1, 1),
"groups": 1,
"device": device,
"conv_config": self.conv1_config,
}

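# Prepare the first conv's weights/bias once and keep them on device for later runs.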
if not ttnn.is_tensor_storage_on_device(self.conv1_weight_tensor):
self.conv1_weight_tensor = ttnn.prepare_conv_weights(
weight_tensor=self.conv1_weight_tensor,
weights_format="OIHW",
input_memory_config=fold_output_tensor.memory_config(),
**conv_kwargs,
)

if self.conv1_bias_tensor is not None:
self.conv1_bias_tensor = ttnn.prepare_conv_bias(
bias_tensor=self.conv1_bias_tensor,
input_memory_config=fold_output_tensor.memory_config(),
**conv_kwargs,
)
self.conv1_weight_tensor = ttnn.to_device(self.conv1_weight_tensor, device)
self.conv1_bias_tensor = ttnn.to_device(self.conv1_bias_tensor, device) if self.conv1_bias_tensor else None

x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=fold_output_tensor,
weight_tensor=self.conv1_weight_tensor,
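Every convolution touched by this commit follows the same pattern: build the Conv2dConfig and a shared conv_kwargs dict up front, prepare the host-side weight and bias tensors once with ttnn.prepare_conv_weights / ttnn.prepare_conv_bias, cache them on device, and hand the prepared tensors to ttnn.conv2d. The sketch below illustrates that pattern in isolation; it reuses only the ttnn calls visible in this diff, and the helper name run_cached_conv and its arguments are illustrative, not part of the committed code.

import ttnn


def run_cached_conv(x, weight_tensor, bias_tensor, conv_kwargs, device):
    # Sketch of the prepare-once pattern used throughout this commit.
    # conv_kwargs is the same dict built in the diff (in/out channels, kernel,
    # stride, padding, batch/input dims, device and conv_config).
    if not ttnn.is_tensor_storage_on_device(weight_tensor):
        # First call: the tensors still live on host, so prepare and cache them.
        weight_tensor = ttnn.prepare_conv_weights(
            weight_tensor=weight_tensor,
            weights_format="OIHW",
            input_memory_config=x.memory_config(),
            **conv_kwargs,
        )
        if bias_tensor is not None:
            bias_tensor = ttnn.prepare_conv_bias(
                bias_tensor=bias_tensor,
                input_memory_config=x.memory_config(),
                **conv_kwargs,
            )
        weight_tensor = ttnn.to_device(weight_tensor, device)
        bias_tensor = ttnn.to_device(bias_tensor, device) if bias_tensor else None

    # Same call shape as in the diff: the prepared tensors come back so the
    # caller can keep them for the next invocation.
    out, [out_height, out_width], [weight_tensor, bias_tensor] = ttnn.conv2d(
        input_tensor=x,
        weight_tensor=weight_tensor,
        bias_tensor=bias_tensor,
        in_channels=conv_kwargs["in_channels"],
        out_channels=conv_kwargs["out_channels"],
        kernel_size=conv_kwargs["kernel_size"],
        stride=conv_kwargs["stride"],
        padding=conv_kwargs["padding"],
        batch_size=conv_kwargs["batch_size"],
        input_height=conv_kwargs["input_height"],
        input_width=conv_kwargs["input_width"],
        device=device,
        conv_config=conv_kwargs["conv_config"],
    )
    return out, (out_height, out_width), weight_tensor, bias_tensor

Because the guard checks where the weight tensor is stored, only the first forward pass pays the preparation cost; subsequent calls reuse the cached device tensors.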
