#7051: resblock ttnn submodule for yolov4
1 parent d91aa25 · commit 54abc54
Showing 3 changed files with 178 additions and 0 deletions.
models/experimental/functional_yolov4/reference/resblock.py (28 additions)
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch.nn as nn


class ResBlock(nn.Module):
    def __init__(self, ch, nblocks=1, shortcut=True):
        super().__init__()
        self.shortcut = shortcut
        self.module_list = nn.ModuleList()
        for i in range(nblocks):
            conv1 = nn.Conv2d(ch, ch, 1, 1, 0, bias=False)
            bn1 = nn.BatchNorm2d(ch)
            relu = nn.ReLU(inplace=True)
            conv2 = nn.Conv2d(ch, ch, 3, 1, 1, bias=False)
            bn2 = nn.BatchNorm2d(ch)
            resblock_one = nn.ModuleList([conv1, bn1, relu, conv2, bn2, relu])
            self.module_list.append(resblock_one)

    def forward(self, x):
        for module in self.module_list:
            h = x
            for res in module:
                h = res(h)
            x = x + h if self.shortcut else h
        return x
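For orientation, a minimal smoke check of the reference block above (the channel count, spatial size, and block count here are illustrative, not from the commit). Each stage applies conv1 (1x1, stride 1, pad 0) → bn1 → relu → conv2 (3x3, stride 1, pad 1) → bn2 → relu, then adds the stage input back when shortcut=True, so shapes are preserved end to end:

import torch

block = ResBlock(ch=64, nblocks=2, shortcut=True).eval()
x = torch.randn(1, 64, 32, 32)
assert block(x).shape == x.shape  # every stage is shape-preserving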
models/experimental/functional_yolov4/tt/ttnn_resblock.py (26 additions)
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn


class TtResBlock:
    def __init__(self, parameters, nblocks, shortcut) -> None:
        self.shortcut = shortcut
        self.nblocks = nblocks
        self.module_list = []
        for i in range(nblocks):
            conv1 = parameters[f"resblock_{i}_conv1"]
            conv2 = parameters[f"resblock_{i}_conv2"]
            resblock_one = [conv1, conv2]
            self.module_list.append(resblock_one)

    def __call__(self, device, input_tensor):
        input_tensor = input_tensor.to(device, self.module_list[0][0].conv.input_sharded_memory_config)
        for i in range(self.nblocks):
            output_tensor_h = input_tensor
            output_tensor_1 = self.module_list[i][0](output_tensor_h)
            output_tensor_h = self.module_list[i][1](output_tensor_1)
            input_tensor = ttnn.add(input_tensor, output_tensor_h) if self.shortcut else output_tensor_h
        return ttnn.from_device(input_tensor)
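Unlike the reference module, TtResBlock carries only two convs per block: during preprocessing (see the test below) each BatchNorm2d is folded into the preceding conv's weights via ttnn.model_preprocessing.fold_batch_norm2d_into_conv2d, and the ReLU is fused in as the conv's activation. A minimal sketch of the standard folding identity, shown only to explain why no bn/relu ops appear above (fold_bn_into_conv is a hypothetical helper, not the library function):

import torch

def fold_bn_into_conv(conv, bn):
    # per-output-channel scale: gamma / sqrt(running_var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    folded_weight = conv.weight * scale.reshape(-1, 1, 1, 1)
    # the convs in ResBlock have bias=False, so the folded bias is beta - mean * scale
    folded_bias = bn.bias - bn.running_mean * scale
    return folded_weight, folded_bias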
tests/ttnn/integration_tests/yolov4/test_ttnn_resblock.py (124 additions)
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest

from ttnn.model_preprocessing import preprocess_model, preprocess_conv2d, fold_batch_norm2d_into_conv2d
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_wormhole_b0
from models.experimental.functional_yolov4.reference.resblock import ResBlock
from models.experimental.functional_yolov4.tt.ttnn_resblock import TtResBlock
import ttnn


def update_ttnn_module_args(ttnn_module_args):
    ttnn_module_args["use_1d_systolic_array"] = ttnn_module_args.in_channels < 256


def create_custom_preprocessor(device):
    def custom_preprocessor(model, name, ttnn_module_args):
        parameters = {}
        if isinstance(model, ResBlock):
            for i, block in enumerate(model.module_list):
                conv1 = block[0]
                bn1 = block[1]
                conv2 = block[3]
                bn2 = block[4]
                ttnn_module_args[f"resblock_{i}_conv1"] = ttnn_module_args["0"]
                ttnn_module_args[f"resblock_{i}_conv1"]["math_fidelity"] = ttnn.MathFidelity.LoFi
                ttnn_module_args[f"resblock_{i}_conv1"]["dtype"] = ttnn.bfloat8_b
                ttnn_module_args[f"resblock_{i}_conv1"]["weights_dtype"] = ttnn.bfloat8_b
                ttnn_module_args[f"resblock_{i}_conv1"]["activation"] = "relu"  # Fuse relu with conv1
                ttnn_module_args[f"resblock_{i}_conv1"]["deallocate_activation"] = True
                ttnn_module_args[f"resblock_{i}_conv1"]["conv_blocking_and_parallelization_config_override"] = None

                weight1, bias1 = fold_batch_norm2d_into_conv2d(conv1, bn1)
                update_ttnn_module_args(ttnn_module_args[f"resblock_{i}_conv1"])
                parameters[f"resblock_{i}_conv1"], _ = preprocess_conv2d(
                    weight1, bias1, ttnn_module_args[f"resblock_{i}_conv1"], return_parallel_config=True
                )

                ttnn_module_args[f"resblock_{i}_conv2"] = ttnn_module_args["3"]
                ttnn_module_args[f"resblock_{i}_conv2"]["math_fidelity"] = ttnn.MathFidelity.LoFi
                ttnn_module_args[f"resblock_{i}_conv2"]["dtype"] = ttnn.bfloat8_b
                ttnn_module_args[f"resblock_{i}_conv2"]["weights_dtype"] = ttnn.bfloat8_b
ttnn_module_args[f"resblock_{i}_conv2"]["activation"] = "relu" # Fuse relu with conv1 | ||
ttnn_module_args[f"resblock_{i}_conv2"]["deallocate_activation"] = True | ||
ttnn_module_args[f"resblock_{i}_conv2"]["conv_blocking_and_parallelization_config_override"] = None | ||
# Preprocess the convolutional layer | ||
weight2, bias2 = fold_batch_norm2d_into_conv2d(conv2, bn2) | ||
update_ttnn_module_args(ttnn_module_args[f"resblock_{i}_conv2"]) | ||
parameters[f"resblock_{i}_conv2"], _ = preprocess_conv2d( | ||
weight2, bias2, ttnn_module_args[f"resblock_{i}_conv2"], return_parallel_config=True | ||
) | ||
|
||
return parameters | ||
|
||
return custom_preprocessor | ||
|
||
|
||
@pytest.mark.parametrize("device_l1_small_size", [32768], indirect=True) | ||
@skip_for_wormhole_b0() | ||
def test_resblock(device, reset_seeds, model_location_generator): | ||
model_path = model_location_generator("models", model_subdir="Yolo") | ||
weights_pth = str(model_path / "yolov4.pth") | ||
state_dict = torch.load(weights_pth) | ||
resblock_state_dict = {k: v for k, v in state_dict.items() if (k.startswith("down2.resblock."))} | ||
|
||
if not resblock_state_dict: | ||
raise ValueError("No parameters found in resblock_state_dict") | ||
|
||
torch_model = ResBlock(ch=64, nblocks=2, shortcut=True) | ||
for layer in torch_model.children(): | ||
print(layer) | ||
|
||
new_state_dict = {} | ||
keys = [name for name, parameter in torch_model.state_dict().items()] | ||
values = [parameter for name, parameter in resblock_state_dict.items()] | ||
|
||
for i in range(len(keys)): | ||
new_state_dict[keys[i]] = values[i] | ||
|
||
torch_model.load_state_dict(new_state_dict) | ||
torch_model.eval() | ||
|
||
# Input tensor for testing | ||
torch_input_tensor = torch.randn(1, 64, 32, 32) # Sample input tensor | ||
torch_output_tensor = torch_model(torch_input_tensor) | ||
|
||
# Preprocess the model for TTNN | ||
reader_patterns_cache = {} | ||
parameters = preprocess_model( | ||
initialize_model=lambda: torch_model, | ||
run_model=lambda model: model(torch_input_tensor), | ||
custom_preprocessor=create_custom_preprocessor(device), | ||
reader_patterns_cache=reader_patterns_cache, | ||
device=device, | ||
) | ||
|
||
# Convert the model to TTNN | ||
ttnn_model = TtResBlock(parameters, nblocks=2, shortcut=True) | ||
|
||
# Convert input tensor to TTNN format | ||
input_shape = torch_input_tensor.shape | ||
input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1)) | ||
|
||
input_tensor = input_tensor.reshape( | ||
input_tensor.shape[0], 1, input_tensor.shape[1] * input_tensor.shape[2], input_tensor.shape[3] | ||
) | ||
|
||
input_tensor = ttnn.from_torch(input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) | ||
|
||
# Apply TTNN model | ||
output_tensor = ttnn_model(device, input_tensor) | ||
|
||
# Convert output tensor back to Torch format | ||
output_tensor = ttnn.to_torch(output_tensor) | ||
output_tensor = output_tensor.reshape(1, 32, 32, 64) | ||
output_tensor = torch.permute(output_tensor, (0, 3, 1, 2)) | ||
output_tensor = output_tensor.to(torch_input_tensor.dtype) | ||
|
||
# Assertion | ||
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99) |
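For reference, the layout conversion the test wraps around the TTNN call is a lossless round trip: NCHW is permuted to channels-last and flattened to [N, 1, H*W, C] on the way in, then reshaped and permuted back on the way out. A standalone check of that identity, using the same 1x64x32x32 sample shape as above:

import torch

x = torch.randn(1, 64, 32, 32)                           # NCHW
flat = x.permute(0, 2, 3, 1).reshape(1, 1, 32 * 32, 64)  # [N, 1, H*W, C]
back = flat.reshape(1, 32, 32, 64).permute(0, 3, 1, 2)   # NCHW again
assert torch.equal(x, back)  # the conversion itself loses nothing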