#5233: added ability to fold batch_norm2d into conv2d
arakhmati committed Feb 12, 2024
1 parent c7ca8c0 commit de3de8c
Showing 5 changed files with 253 additions and 110 deletions.
models/experimental/functional_resnet/tt/ttnn_functional_resnet.py
@@ -21,11 +21,11 @@ def __init__(
def __call__(self, x):
identity = x

# Relu and bn1 are fused with conv1
out = self.conv1(x)
# out = self.bn1(out)

# Bn2 is fused with conv2
out = self.conv2(out)
# out = self.bn2(out)

if self.downsample is not None:
identity = self.downsample(x)
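The bn1/bn2 calls can be commented out above because, in eval mode, batch norm reduces to a per-channel affine transform that the preceding convolution can absorb. The commit does not show the body of fold_batch_norm2d_into_conv2d, so the following is only a sketch of the identity it presumably applies (the name and signature here are assumptions):

```python
import torch

def fold_batch_norm2d_into_conv2d_sketch(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d):
    """Hypothetical re-derivation; the real helper lives in ttnn.model_preprocessing."""
    # Eval-mode batch norm: bn(y) = (y - running_mean) / sqrt(running_var + eps) * gamma + beta,
    # i.e. a per-output-channel scale and shift that folds into the conv's weight and bias.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    folded_weight = conv.weight * scale[:, None, None, None]
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    folded_bias = (bias - bn.running_mean) * scale + bn.bias
    return folded_weight, folded_bias
```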
35 changes: 28 additions & 7 deletions tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet.py
@@ -6,14 +6,27 @@

import torch

from ttnn.model_preprocessing import preprocess_model
from ttnn.model_preprocessing import preprocess_model, preprocess_conv2d, fold_batch_norm2d_into_conv2d

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_wormhole_b0

from models.experimental.functional_resnet.tt import ttnn_functional_resnet


def custom_preprocessor(model, name, ttnn_module_args):
parameters = {}
if isinstance(model, BasicBlock):
ttnn_module_args.conv1["activation"] = "relu" # Fuse relu with conv1

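# Fold each batch norm's eval-mode statistics into the matching conv's weight and bias before handing them to preprocess_conv2d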
conv1_weight, conv1_bias = fold_batch_norm2d_into_conv2d(model.conv1, model.bn1)
conv2_weight, conv2_bias = fold_batch_norm2d_into_conv2d(model.conv2, model.bn2)

parameters["conv1"] = preprocess_conv2d(conv1_weight, conv1_bias, ttnn_module_args.conv1)
parameters["conv2"] = preprocess_conv2d(conv2_weight, conv2_bias, ttnn_module_args.conv2)
return parameters


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> torch.nn.Conv2d:
"""3x3 convolution with padding"""
return torch.nn.Conv2d(
@@ -56,22 +69,22 @@ def __init__(
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
# self.bn1 = norm_layer(planes)
self.bn1 = norm_layer(planes)
self.relu = torch.nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
# self.bn2 = norm_layer(planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride

def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x

out = self.conv1(x)
# out = self.bn1(out)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
# out = self.bn2(out)
out = self.bn2(out)

if self.downsample is not None:
identity = self.downsample(x)
@@ -86,7 +99,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def test_basic_block(device):
torch.manual_seed(0)

torch_model = BasicBlock(inplanes=64, planes=64, stride=1)
torch_model = BasicBlock(inplanes=64, planes=64, stride=1).eval()

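# Replace every float tensor in the state dict (weights and batch norm running stats) with random values so folding is exercised with non-default statistics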
new_state_dict = {}
for name, parameter in torch_model.state_dict().items():
if isinstance(parameter, torch.FloatTensor):
new_state_dict[name] = torch.rand_like(parameter)
torch_model.load_state_dict(new_state_dict)

torch_input_tensor = torch.rand((8, 64, 56, 56), dtype=torch.float32)
torch_output_tensor = torch_model(torch_input_tensor)
@@ -95,6 +114,7 @@ def test_basic_block(device):
parameters = preprocess_model(
initialize_model=lambda: torch_model,
run_model=lambda model: model(torch_input_tensor),
custom_preprocessor=custom_preprocessor,
reader_patterns_cache=reader_patterns_cache,
device=device,
)
@@ -111,7 +131,7 @@ def test_basic_block(device):
def test_basic_block_with_downsample(device):
torch.manual_seed(0)

torch_model = BasicBlock(inplanes=64, planes=64, stride=1, downsample=conv1x1(64, 64, 1))
torch_model = BasicBlock(inplanes=64, planes=64, stride=1, downsample=conv1x1(64, 64, 1)).eval()

torch_input_tensor = torch.rand((8, 64, 56, 56), dtype=torch.float32)
torch_output_tensor = torch_model(torch_input_tensor)
@@ -120,6 +140,7 @@ def test_basic_block_with_downsample(device):
parameters = preprocess_model(
initialize_model=lambda: torch_model,
run_model=lambda model: model(torch_input_tensor),
custom_preprocessor=custom_preprocessor,
reader_patterns_cache=reader_patterns_cache,
device=device,
)
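A quick numerical sanity check for the fold is to compare the folded convolution against the original conv + bn pair in eval mode. This is not part of the commit; it assumes, as the calls above suggest, that fold_batch_norm2d_into_conv2d returns plain weight and bias tensors:

```python
import torch
from ttnn.model_preprocessing import fold_batch_norm2d_into_conv2d

conv = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
bn = torch.nn.BatchNorm2d(64).eval()
bn.running_mean.uniform_()         # non-default statistics make the check meaningful
bn.running_var.uniform_(0.5, 1.5)

weight, bias = fold_batch_norm2d_into_conv2d(conv, bn)
folded = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True)
folded.weight = torch.nn.Parameter(weight)
folded.bias = torch.nn.Parameter(bias)

torch_input = torch.rand(1, 64, 56, 56)
assert torch.allclose(bn(conv(torch_input)), folded(torch_input), atol=1e-4)
```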
99 changes: 98 additions & 1 deletion tests/ttnn/unit_tests/test_model_preprocessing.py
@@ -8,7 +8,12 @@
import torchvision

import ttnn
from ttnn.model_preprocessing import preprocess_model, preprocess_model_parameters
from ttnn.model_preprocessing import (
preprocess_model,
preprocess_model_parameters,
preprocess_conv2d,
fold_batch_norm2d_into_conv2d,
)

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_wormhole_b0
@@ -360,6 +365,98 @@ def functional_ttnn(input_tensor, parameters):
assert_with_pcc(torch_output_tensor, output_tensor)


@skip_for_wormhole_b0()
@pytest.mark.parametrize("use_conv_bias", [True, False])
def test_conv2d_with_batch_norm2d(device, use_conv_bias):
torch.manual_seed(0)

class TorchModule(torch.nn.Module):
def __init__(
self,
in_planes: int,
out_planes: int,
use_conv_bias: bool,
stride: int = 1,
groups: int = 1,
dilation: int = 1,
) -> None:
super().__init__()
self.conv1 = torch.nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=use_conv_bias,
dilation=dilation,
)
self.bn1 = torch.nn.BatchNorm2d(out_planes)

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
output_tensor = self.conv1(input_tensor)
output_tensor = self.bn1(output_tensor)
return output_tensor

def custom_preprocessor(model, name, ttnn_module_args):
parameters = {}
if isinstance(model, TorchModule):
conv1_weight, conv1_bias = fold_batch_norm2d_into_conv2d(model.conv1, model.bn1)
parameters["conv1"] = preprocess_conv2d(conv1_weight, conv1_bias, ttnn_module_args.conv1)
return parameters

torch_model = TorchModule(in_planes=64, out_planes=64, use_conv_bias=use_conv_bias).eval()

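# Randomize parameters and batch norm running statistics so the folded conv differs measurably from the raw conv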
new_state_dict = {}
for name, parameter in torch_model.state_dict().items():
if isinstance(parameter, torch.FloatTensor):
new_state_dict[name] = torch.rand_like(parameter)
torch_model.load_state_dict(new_state_dict)

torch_input_tensor = torch.rand((8, 64, 56, 56), dtype=torch.float32)
torch_output_tensor = torch_model(torch_input_tensor)

reader_patterns_cache = {}
parameters = preprocess_model(
initialize_model=lambda: torch_model,
run_model=lambda model: model(torch_input_tensor),
custom_preprocessor=custom_preprocessor,
reader_patterns_cache=reader_patterns_cache,
device=device,
)

class TTNNBasicBlock:
def __init__(
self,
parameters,
) -> None:
self.conv1 = parameters.conv1

def __call__(self, input_tensor):
output_tensor = self.conv1(input_tensor)
return output_tensor

def torch_call(self, torch_input_tensor):
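# ttnn convolutions consume NHWC tensors, so permute from torch's NCHW layout before moving to device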
input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1))
input_tensor = ttnn.from_torch(input_tensor, dtype=ttnn.bfloat16)

input_tensor = self.conv1.copy_input_to_device(input_tensor)
output_tensor = self(input_tensor)
output_tensor = self.conv1.copy_output_from_device(output_tensor)

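# Bring the result back to torch and restore the original NCHW shape and dtype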
output_tensor = ttnn.to_torch(output_tensor)
output_tensor = torch.permute(output_tensor, (0, 3, 1, 2))
output_tensor = torch.reshape(output_tensor, torch_input_tensor.shape)
output_tensor = output_tensor.to(torch_input_tensor.dtype)
return output_tensor

ttnn_model = TTNNBasicBlock(parameters)

output_tensor = ttnn_model.torch_call(torch_input_tensor)

assert_with_pcc(torch_output_tensor, output_tensor)


@skip_for_wormhole_b0()
def test_resnet():
torch.manual_seed(0)
6 changes: 4 additions & 2 deletions ttnn/ttnn/dot_access.py
@@ -10,12 +10,14 @@ class DotAccessDict(dict):
__delattr__ = dict.__delitem__


def make_dot_access_dict(dictionary: Union[dict, DotAccessDict]) -> DotAccessDict:
def make_dot_access_dict(dictionary: Union[dict, DotAccessDict], *, ignore_types=None) -> DotAccessDict:
if isinstance(dictionary, DotAccessDict):
return dictionary
elif ignore_types is not None and isinstance(dictionary, ignore_types):
return dictionary
preprocessed_dictionary = {}
for key, value in dictionary.items():
if isinstance(value, dict):
value = make_dot_access_dict(value)
value = make_dot_access_dict(value, ignore_types=ignore_types)
preprocessed_dictionary[key] = value
return DotAccessDict(preprocessed_dictionary)
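
The new ignore_types hook stops the recursion at dict subclasses that should be passed through untouched (presumably the preprocessed parameter objects introduced elsewhere in this commit). A usage sketch, with a hypothetical Config subclass standing in for such a type:

```python
from ttnn.dot_access import make_dot_access_dict

class Config(dict):  # hypothetical dict subclass that should not be converted
    pass

nested = {"a": {"b": 1}, "c": Config(d=2)}
result = make_dot_access_dict(nested, ignore_types=(Config,))

assert result.a.b == 1                  # plain nested dicts become dot-accessible
assert isinstance(result["c"], Config)  # ignored type is returned unchanged
```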
[diff for the fifth changed file did not load]