real-world applications including video analysis, medical imaging
analysis, and natural language processing, where they can analyze data
to extract meaningful features and patterns.

## Convolution Operations in TTNN

### `conv2d`

Applies a 2D convolution over `input_tensor`, a 4D tensor with dimensions ordered as `(batch_size, input_height, input_width, in_channels)`, using the provided `weight_tensor`, with dimensions `(out_channels, in_channels, kernel_height, kernel_width)`, and an optional `bias_tensor`, with dimensions `(1, 1, 1, out_channels)`. It generates an `output_tensor` with dimensions ordered as `(batch_size, output_height, output_width, out_channels)`.
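
The output spatial dimensions follow the standard convolution arithmetic (the same rule PyTorch uses). A quick sketch in plain Python, using the parameter names from the API below:

```python
# Integer convolution output-size arithmetic.
output_height = (input_height + 2 * padding_height
                 - dilation_height * (kernel_height - 1) - 1) // stride_height + 1
output_width = (input_width + 2 * padding_width
                - dilation_width * (kernel_width - 1) - 1) // stride_width + 1
```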

#### Python API

```python
output_tensor = ttnn.conv2d(
input_tensor,
weight_tensor,
in_channels,
out_channels,
device,
bias_tensor,
kernel_size,
stride,
padding,
dilation,
batch_size,
input_height,
input_width,
conv_config,
compute_config,
groups,
memory_config,
return_weights_and_bias=False,
return_output_dim=False,
)
```

Arguments:

* `input_tensor` the 4D activation tensor of shape `(batch_size, input_height, input_width, in_channels)` as `ttnn.Tensor`.
* `weight_tensor` the 4D weight tensor of shape `(out_channels, in_channels // groups, kernel_height, kernel_width)` as `ttnn.Tensor`.
* `bias_tensor` _optional_ bias tensor of shape `(1, 1, 1, out_channels)` as `ttnn.Tensor`.
* `in_channels` the number of input channels as an `int`.
* `out_channels` the number of output channels as an `int`.
* `device` device pointer as `ttnn.Device`.
* `kernel_size` tuple of two ints: `(kernel_height, kernel_width)`.
* `stride` tuple of two ints: `(stride_height, stride_width)`.
* `padding` tuple of two ints: `(padding_height, padding_width)`.
* `dilation` tuple of two ints: `(dilation_height, dilation_width)`.
* `batch_size` an `int`.
* `input_height` an `int`.
* `input_width` an `int`.
* `conv_config` _optional_ structure of configuration parameters of type `Conv2dConfig`. This is described in detail below.
* `compute_config` _optional_ structure of compute configuration parameters of type `DeviceComputeKernelConfig`. This is described in detail below.
* `groups` _optional_ `int` to control the connections between inputs and outputs. Both `in_channels` and `out_channels` should be divisible by `groups`.
* `memory_config` _optional_ output tensor memory configuration. This is described below.
* `return_weights_and_bias` _optional_ `bool` indicating whether to return pre-processed weights and bias tensors on device.
* `return_output_dim` _optional_ `bool` indicating whether to return the output tensor height and width.

#### `Conv2dConfig`

The following are the `conv2d` operation configuration parameters (a short construction example follows the list):

* `dtype = ttnn.bfloat16` input activations data type.
* `weights_dtype = ttnn.bfloat16` weights and bias data type.
* `activation = ""` _optional_ `string`. Activation function to apply. The only currently supported option is `"relu"`.
* `input_channels_alignment = 32` _optional_ `uint32_t`. Alignment value for channels dimension in the input tensor. This is applicable when `in_channels < 32` and should take a value of `16` in those cases, `32` otherwise.
* `deallocate_activation = false` _optional_ bool indicating whether the input activation tensor memory should be deallocated.
* `reallocate_halo_output = false` _optional_ bool indicating if the intermediate tensor generated within the op should be reallocated to reduce memory fragmentation.
* `act_block_h_override = 0` _optional_ `uint32_t` to override the `act_block_h` parameter, which determines the size of the blocks used in computation: smaller values require less memory, while larger values require more memory but are more performant. This argument is ignored when `shard_layout = WIDTH_SHARDED`.
* `act_block_w_div = 1` _optional_ `uint32_t`, the value by which the maximum possible `act_block_w` parameter is divided. This argument is ignored when `shard_layout = HEIGHT_SHARDED` or `BLOCK_SHARDED`.
* `reshard_if_not_optimal = false` _optional_ bool indicating whether the operation can re-shard the input tensor to make it more optimal for performance. If true, `override_sharding_config` should not be set to true.
* `override_sharding_config = false` _optional_ bool indicating if the input sharding config should be overridden with the provided `shard_layout`. If true, `reshard_if_not_optimal` should not be set to true.
* `shard_layout = None` _optional_ `ttnn.TensorMemoryLayout` to specify type of sharding to use.
* `core_grid = None` _optional_ `ttnn.CoreRangeSet` specifying the core grid to use. Applicable only when `override_sharding_config = True`.
* `transpose_shards = true` _optional_ `bool` indicating whether the shards should be distributed in `ROW_MAJOR` order. This is applicable only when not using height sharding.
* `output_layout = ttnn.TILE_LAYOUT` _optional_ `ttnn.Layout` specifying whether the output tensor should be in `TILE` or `ROW_MAJOR` layout.
* `enable_act_double_buffer = false` _optional_ bool to enable activation double buffering.
* `enable_weights_double_buffer = false` _optional_ bool to enable weights double buffering when using block sharding.
* `enable_split_reader = false` _optional_ bool to enable two concurrent reader kernels instead of one.
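
For example, a minimal configuration that overrides a few of these defaults might look like the following (a sketch; the specific values shown are illustrative, not recommendations):

```python
conv_config = ttnn.Conv2dConfig(
    dtype=ttnn.bfloat16,
    weights_dtype=ttnn.bfloat8_b,
    activation="relu",                                       # fuse a ReLU into the conv
    deallocate_activation=True,                              # free the input once consumed
    shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,     # request height sharding
)
```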

#### Compute Config

Architecture-specific device compute kernel configuration, `DeviceComputeKernelConfig`, with the following parameters:

* `math_fidelity = MathFidelity.HiFi4`
* `math_approx_mode = True`
* `dst_full_sync_en = False`

Wormhole and Blackhole specific parameters:

* `fp32_dest_acc_en = false` enable accumulations in fp32.
* `packer_l1_acc = false` enable packer accumulation directly in L1.
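
On Wormhole, for example, such a configuration might be constructed as follows (a sketch assuming the Wormhole-specific `ttnn.WormholeComputeKernelConfig` wrapper; adjust for your architecture):

```python
compute_config = ttnn.WormholeComputeKernelConfig(
    math_fidelity=ttnn.MathFidelity.HiFi4,
    math_approx_mode=True,
    fp32_dest_acc_en=False,   # keep destination accumulation in bfloat16
    packer_l1_acc=False,      # no packer accumulation in L1
)
```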

#### Preparing input tensors

`conv2d` takes a 4D `input_tensor` with dimensions ordered as `(N, H, W, C)` (channels last), and a 4D `weight_tensor` with dimensions `(C_out, C_in // groups, kernel_h, kernel_w)`. The input activation, weight, and bias tensors can be on host or on device. If the weight and bias are on device, they must already have been pre-processed by the `conv2d` op.

Example to prepare the input tensors:

```python
import ttnn
import torch

## activation tensor

input_shape_nchw = [batch_size, in_channels, input_height, input_width]
torch_input_tensor_nchw = torch.randn(input_shape_nchw, dtype=torch.bfloat16)
torch_input_tensor_nhwc = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))

ttnn_input_tensor = ttnn.from_torch(torch_input_tensor_nhwc, ttnn.bfloat16)

## weight tensor

weight_shape = [out_channels, in_channels // groups, kernel_height, kernel_width]
torch_weight_tensor = torch.randn(weight_shape, dtype=torch.bfloat16)

ttnn_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16)

## bias tensor

bias_shape = [1, 1, 1, out_channels]
torch_bias_tensor = torch.randn(bias_shape, dtype=torch.bfloat16)

ttnn_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16)
```
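
These tensors live on host; `conv2d` accepts host tensors and transfers and pre-processes them itself. If you prefer to stage the activation on device beforehand, something like the following should work (a sketch assuming DRAM placement):

```python
# Optional: move the activation to device DRAM ahead of the call.
ttnn_input_tensor = ttnn.to_device(ttnn_input_tensor, device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
```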

#### Calling the operation

```python
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat16,
weights_dtype=ttnn.bfloat16
)

ttnn_output_tensor_on_device = ttnn.conv2d(
input_tensor=ttnn_input_tensor,
weight_tensor=ttnn_weight_tensor,
in_channels=in_channels,
out_channels=out_channels,
device=device,
bias_tensor=ttnn_bias_tensor,
kernel_size=(kernel_h, kernel_w),
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
dilation=(dilation_h, dilation_w),
batch_size=batch_size,
input_height=input_height,
input_width=input_width,
conv_config=conv_config,
)
```
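
The post-processing step below needs the output height and width. One way to obtain them is to pass `return_output_dim=True`, in which case the op also returns the computed output spatial dimensions (a sketch of the expected return structure, based on the flag described above):

```python
# Same arguments as the call above, plus return_output_dim=True.
ttnn_output_tensor_on_device, (out_height, out_width) = ttnn.conv2d(
    input_tensor=ttnn_input_tensor,
    weight_tensor=ttnn_weight_tensor,
    in_channels=in_channels,
    out_channels=out_channels,
    device=device,
    bias_tensor=ttnn_bias_tensor,
    kernel_size=(kernel_h, kernel_w),
    stride=(stride_h, stride_w),
    padding=(pad_h, pad_w),
    dilation=(dilation_h, dilation_w),
    batch_size=batch_size,
    input_height=input_height,
    input_width=input_width,
    conv_config=conv_config,
    return_output_dim=True,
)
```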

#### Output post-processing

```python
ttnn_output_tensor = ttnn.from_device(ttnn_output_tensor_on_device)
torch_output_tensor = ttnn.to_torch(ttnn_output_tensor)
# The op returns the output with (batch, height, width) flattened; restore the NHWC shape.
torch_output_tensor = torch_output_tensor.reshape(batch_size, out_height, out_width, torch_output_tensor.shape[-1])
# Trim any channel padding, then convert back to NCHW order.
torch_output_tensor = torch_output_tensor[:, :, :, :out_channels]
torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2))
```
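
To sanity-check the result, the output can be compared against a PyTorch reference computed from the same host tensors (a sketch; the tolerances are illustrative, and bfloat16 arithmetic will not match bit-exactly):

```python
torch_reference = torch.nn.functional.conv2d(
    torch_input_tensor_nchw,
    torch_weight_tensor,
    bias=torch_bias_tensor.reshape(-1),  # F.conv2d expects a 1D bias of shape (out_channels,)
    stride=(stride_h, stride_w),
    padding=(pad_h, pad_w),
    dilation=(dilation_h, dilation_w),
    groups=groups,
)
# Expect close, but not bit-exact, agreement given bfloat16 math.
assert torch.allclose(torch_output_tensor, torch_reference, atol=1e-1, rtol=1e-1)
```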


### `maxpool2d`
_Coming soon._



## Convolution as Matrix Multiplication
