#10860: Split Conv2dConfig into Compute config (#15164)
### Ticket
#10860 

### Problem description
`Conv2dConfig` contains compute-kernel-specific arguments that should instead be passed via a `DeviceComputeKernelConfig`.
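
For context, a condensed sketch of the pre-change style that the diffs below remove, with compute-kernel settings mixed into `Conv2dConfig` alongside unrelated layout options (field values are copied from the `convnet_mnist` diff in this commit and are illustrative only):

```python
import ttnn

# Old style (before this change): compute-kernel settings sit inside
# Conv2dConfig next to dtype/sharding options. The four fields marked below
# are the ones this commit removes from Conv2dConfig.
conv_config = ttnn.Conv2dConfig(
    dtype=ttnn.bfloat16,
    weights_dtype=ttnn.bfloat16,
    shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
    math_fidelity=ttnn.MathFidelity.LoFi,  # compute-kernel setting
    math_approx_mode_enabled=True,         # compute-kernel setting
    fp32_dest_acc_enabled=False,           # compute-kernel setting
    packer_l1_accum_enabled=False,         # compute-kernel setting
)
```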

### What's changed
Removed the compute-kernel arguments from `Conv2dConfig`. `conv2d` now takes an additional `compute_config` argument. Helper functions for creating a `DeviceComputeKernelConfig` from model code are exposed via pybind.
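
A minimal sketch of the new calling convention, condensed from the `convnet_mnist` diff below; it assumes an open `device` handle and an input tensor `x`, and elides the remaining `ttnn.conv2d` arguments (weights, channels, kernel size, stride, padding), so it is a usage sketch rather than a standalone script:

```python
import ttnn

# Conv2dConfig now carries only tensor/layout options.
conv_config = ttnn.Conv2dConfig(
    dtype=ttnn.bfloat16,
    weights_dtype=ttnn.bfloat16,
    shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
)

# Compute-kernel settings are built per architecture with the newly exposed
# pybind helper and passed separately as a DeviceComputeKernelConfig.
compute_config = ttnn.init_device_compute_kernel_config(
    device.arch(),
    math_fidelity=ttnn.MathFidelity.LoFi,
    math_approx_mode=True,
    fp32_dest_acc_en=False,
    packer_l1_acc=False,
)

# conv2d takes the new compute_config argument alongside conv_config.
[x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
    input_tensor=x,
    # ... other conv2d arguments elided ...
    conv_config=conv_config,
    compute_config=compute_config,
    conv_op_cache={},
)
```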

### Checklist
- [x] Post commit CI [passes](https://github.com/tenstorrent/tt-metal/actions/runs/11893281772)
- [x] Model regression CI testing [passes](https://github.com/tenstorrent/tt-metal/actions/runs/11948195753)
- [x] Demo tests [pass](https://github.com/tenstorrent/tt-metal/actions/runs/11935742809)
- [x] Device performance regression CI testing [passes](https://github.com/tenstorrent/tt-metal/actions/runs/11936504399)
- [x] Nightly CI run: [no new regressions](https://github.com/tenstorrent/tt-metal/actions/runs/12005424271)
sankarmanoj-tt authored Dec 9, 2024
1 parent 7768e89 commit 612744d
Showing 42 changed files with 441 additions and 245 deletions.
13 changes: 8 additions & 5 deletions models/demos/convnet_mnist/tt/convnet_mnist.py
@@ -19,19 +19,21 @@ def convnet_mnist(
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat16,
weights_dtype=ttnn.bfloat16,
math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
math_approx_mode_enabled=True,
fp32_dest_acc_enabled=False,
packer_l1_accum_enabled=False,
input_channels_alignment=32,
transpose_shards=False,
reshard_if_not_optimal=True,
deallocate_activation=True,
reallocate_halo_output=True,
)

compute_config = ttnn.init_device_compute_kernel_config(
device.arch(),
math_fidelity=ttnn.MathFidelity.LoFi,
math_approx_mode=True,
fp32_dest_acc_en=False,
packer_l1_acc=False,
)
x = ttnn.to_layout(input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
[x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
input_tensor=x,
@@ -47,6 +49,7 @@ def convnet_mnist(
input_height=input_tensor.shape[1],
input_width=input_tensor.shape[2],
conv_config=conv_config,
compute_config=compute_config,
conv_op_cache={},
debug=True,
groups=1,
3 changes: 2 additions & 1 deletion models/demos/llama3/tt/multimodal/llama_conv2d_patch.py
@@ -79,7 +79,8 @@ def __init__(
mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
)

self.compute_kernel_config = ttnn.WormholeComputeKernelConfig(
self.compute_kernel_config = ttnn.init_device_compute_kernel_config(
mesh_device.arch(),
math_fidelity=ttnn.MathFidelity.HiFi2,
math_approx_mode=True,
fp32_dest_acc_en=True,
12 changes: 8 additions & 4 deletions models/demos/segformer/tt/common.py
@@ -40,12 +40,8 @@ def __call__(self, device, input_tensor):
conv_config = ttnn.Conv2dConfig(
dtype=self.dtype,
weights_dtype=ttnn.bfloat16,
math_fidelity=ttnn.MathFidelity.LoFi,
activation=self.activation,
shard_layout=self.shard_layout,
math_approx_mode_enabled=True,
fp32_dest_acc_enabled=False,
packer_l1_accum_enabled=False,
input_channels_alignment=16 if input_tensor.shape[3] < 16 else 32,
transpose_shards=False,
reshard_if_not_optimal=self.reshard,
@@ -54,6 +50,13 @@ def __call__(self, device, input_tensor):
enable_act_double_buffer=True,
enable_split_reader=False,
)
compute_config = ttnn.init_device_compute_kernel_config(
device.arch(),
math_fidelity=ttnn.MathFidelity.LoFi,
math_approx_mode=True,
fp32_dest_acc_en=False,
packer_l1_acc=False,
)
if self.act_block_h is not None:
conv_config.act_block_h_override = self.act_block_h

@@ -71,6 +74,7 @@ def __call__(self, device, input_tensor):
input_height=input_tensor.shape[1],
input_width=input_tensor.shape[2],
conv_config=conv_config,
compute_config=compute_config,
groups=self.groups,
)

@@ -183,12 +183,14 @@ def run_downsample_if_req(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=shard_layout,
deallocate_activation=True,
reallocate_halo_output=True,
reshard_if_not_optimal=reshard_if_not_optimal,
),
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
),
conv_op_cache=conv_op_cache,
)
ttnn.deallocate(x)
@@ -230,13 +232,15 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
),
conv_op_cache=conv_op_cache,
)

@@ -293,7 +297,6 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
@@ -303,6 +306,9 @@
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
),
conv_op_cache=conv_op_cache,
)

@@ -324,12 +330,14 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
),
conv_op_cache=conv_op_cache,
)

@@ -562,12 +570,14 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h_override=act_block_h_override,
),
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
),
conv_op_cache=conv_op_cache,
)
# Relu is fused with conv1
@@ -873,12 +883,14 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h_override=act_block_h_override,
),
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
),
conv_op_cache=conv_op_cache,
)
# Relu is fused with conv1
@@ -176,15 +176,13 @@ def run_downsample_if_req(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
deallocate_activation=True,
reallocate_halo_output=not (is_wormhole_b0() and batch_size == 16),
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
packer_l1_accum_enabled=packer_l1_accum_enabled,
enable_act_double_buffer=enable_act_double_buffer
if height_sharding
else True
@@ -194,6 +192,11 @@
enable_split_reader=enable_split_reader,
enable_subblock_padding=enable_subblock_padding,
),
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(),
math_fidelity=self.model_config["MATH_FIDELITY"],
packer_l1_acc=packer_l1_accum_enabled,
),
conv_op_cache=conv_op_cache,
)
ttnn.deallocate(x)
@@ -242,14 +245,17 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
packer_l1_accum_enabled=packer_l1_acc,
),
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(),
math_fidelity=self.model_config["MATH_FIDELITY"],
packer_l1_acc=packer_l1_acc,
),
conv_op_cache=conv_op_cache,
)
@@ -323,7 +329,6 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
@@ -333,12 +338,16 @@
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
packer_l1_accum_enabled=packer_l1_acc,
enable_act_double_buffer=enable_act_double_buffer,
enable_weights_double_buffer=True,
enable_split_reader=enable_split_reader,
enable_subblock_padding=enable_subblock_padding,
),
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(),
math_fidelity=self.model_config["MATH_FIDELITY"],
packer_l1_acc=packer_l1_acc,
),
conv_op_cache=conv_op_cache,
)

@@ -374,13 +383,16 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
packer_l1_accum_enabled=packer_l1_acc,
),
compute_config=ttnn.init_device_compute_kernel_config(
device.arch(),
math_fidelity=self.model_config["MATH_FIDELITY"],
packer_l1_acc=packer_l1_acc,
),
conv_op_cache=conv_op_cache,
)
@@ -569,19 +581,22 @@ def __init__(
self.conv1_config = ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=dealloc_input,
input_channels_alignment=input_channels_alignment,
act_block_h_override=act_block_h_override,
transpose_shards=self.transpose_shards,
packer_l1_accum_enabled=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_split_reader=True if whb0_and_b16 or not is_wormhole_b0() else False,
enable_subblock_padding=False,
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
reshard_if_not_optimal=False,
)
self.conv1_compute_config = ttnn.init_device_compute_kernel_config(
device.arch(),
math_fidelity=self.model_config["MATH_FIDELITY"],
packer_l1_acc=True if whb0_and_b16 else False,
)
if whb0_and_b16:
# Issue #13145: Temp workaround for Galaxy to avoid hangs
if type(device) == ttnn.MeshDevice and device.get_num_devices() > 8:
@@ -733,6 +748,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
input_height=self.conv1_input_height,
input_width=self.conv1_input_width,
conv_config=self.conv1_config,
compute_config=self.conv1_compute_config,
conv_op_cache=conv_op_cache,
)
# Relu is fused with conv1
