Skip to content

Commit

Permalink
#12253: Implement Batch norm operation for inference mode (#16432)
Browse files Browse the repository at this point in the history
### List of Github issues

- #12253 
- #15449

### Problem description
To implement batch normalization as device op

### What's changed
Work done : 

- Provided support for Inference mode

<img width="1254" alt="Screenshot 2025-01-10 at 11 59 49 AM"
src="https://github.com/user-attachments/assets/061ccaca-8b7e-44c9-a0f3-e1c87c908403"
/>

### Checklist
- [x] [All post-commit
testsI](https://github.com/tenstorrent/tt-metal/actions/runs/12729863177)
- [x] [Blackhole Post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/12729863432)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
- [Link to
test](https://github.com/tenstorrent/tt-metal/actions/runs/12732676848)
- [x] [(Single-card) Demo
tests](https://github.com/tenstorrent/tt-metal/actions/runs/12729864298/job/35486219235)
- same as main
- [x] [(Single-card) Device perf
regressions](https://github.com/tenstorrent/tt-metal/actions/runs/12729865036)
- same as main
- [x] [Single-card Model perf
tests](https://github.com/tenstorrent/tt-metal/actions/runs/12729865358)
- same as main
  • Loading branch information
VirdhatchaniKN authored Jan 12, 2025
1 parent 880f700 commit 35c7145
Show file tree
Hide file tree
Showing 15 changed files with 1,128 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ Normalization
ttnn.group_norm
ttnn.layer_norm
ttnn.rms_norm
ttnn.batch_norm


Moreh Operations
Expand Down
41 changes: 41 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,32 @@
)


def data_gen_with_range_batch_norm(
input_shapes,
low,
high,
device,
is_input=False,
required_grad=False,
):
assert high > low, "Incorrect range provided"
torch.manual_seed(213919)
channels = input_shapes[1]
size = input_shapes if is_input else channels
pt_tensor = torch.rand(size, requires_grad=required_grad).bfloat16() * (high - low) + low
reshaped_tensor = pt_tensor
if not is_input:
reshaped_tensor = pt_tensor.view(1, channels, 1, 1)
tt_tensor = ttnn.from_torch(
reshaped_tensor,
device=device,
layout=ttnn.TILE_LAYOUT,
dtype=ttnn.bfloat16,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
return pt_tensor, tt_tensor


def data_gen_pt_tt(input_shapes, device, required_grad=False):
torch.manual_seed(213919)
pt_tensor = torch.randn(input_shapes, requires_grad=required_grad).bfloat16()
Expand Down Expand Up @@ -107,6 +133,21 @@ def compare_results(tt_tensor, golden_tensor, pcc=0.99):
return status


def compare_results_batch_norm(tt_tensor, golden_tensor, pcc=0.99):
status = True
for i in range(len(tt_tensor)):
tt_out_tensor = tt_tensor[i]
pt_out_tensor = golden_tensor[i]
comp_pass, comp_out = comparison_funcs.comp_pcc(pt_out_tensor, tt_out_tensor, pcc=pcc)
comp_all, comp_out_res = comparison_funcs.comp_allclose(pt_out_tensor, tt_out_tensor, atol=4, rtol=1e-1)
logger.debug(comp_pass)
logger.debug(comp_all)
logger.debug(comp_out)
logger.debug(comp_out_res)
status = status & comp_pass & comp_all
return status


def compare_pcc(tt_tensor, golden_tensor, pcc=0.99):
status = True
for i in range(len(tt_tensor)):
Expand Down
98 changes: 98 additions & 0 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
data_gen_with_range_batch_norm,
compare_results_batch_norm,
)
from itertools import product


@pytest.mark.parametrize(
"input_shapes",
[
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
torch.Size([4, 4, 32, 32]),
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
torch.Size([4, 4, 23, 23]),
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
torch.Size([3, 1, 64, 120]),
torch.Size([3, 2, 64, 120]),
],
)
@pytest.mark.parametrize("training", [False])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
def test_batch_norm(input_shapes, training, weight, bias, eps, device):
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
mean_data, mean_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (not training) else (None, None)
)
var_data, var_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (not training) else (None, None)
)
weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None)
bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None)

tt_output_tensor_on_device = ttnn.batch_norm(
input_tensor,
running_mean=mean_tensor,
running_var=var_tensor,
training=training,
eps=eps,
weight=weight_tensor,
bias=bias_tensor,
)
tt_output = ttnn.to_torch(tt_output_tensor_on_device)
# ttnn.set_printoptions(profile="full")
# print("TT result : ", tt_output, tt_output.shape)
# torch.set_printoptions(precision=5, sci_mode=False)
torch_result = torch.nn.functional.batch_norm(
input=in_data,
running_mean=mean_data,
running_var=var_data,
weight=weight_data,
bias=bias_data,
training=training,
eps=eps,
)
# print("Torch result : ",torch_result)
comp_pass = compare_results_batch_norm([tt_output], [torch_result])
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
[
torch.Size([3, 2, 32, 32]),
],
)
@pytest.mark.parametrize("mem_layout", [ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.TensorMemoryLayout.HEIGHT_SHARDED])
def test_batch_norm_program_cache_and_default(input_shapes, mem_layout, device):
N, H, W, C = input_shapes
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
mean_data, mean_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device)
var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device)

grid_size = ttnn.CoreGrid(y=1, x=8)
grid_coord = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1)
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)})
shard_shape = N * H * W // grid_size.x, C // grid_size.y
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.COL_MAJOR, False)
sharded_mem_config = ttnn.MemoryConfig(mem_layout, ttnn.types.BufferType.L1, shard_spec)

if mem_layout is not ttnn.TensorMemoryLayout.INTERLEAVED:
pytest.xfail("Input tensors to batch norm must be interleaved")

tt_output_tensor_on_device = ttnn.batch_norm(
input_tensor, running_mean=mean_tensor, running_var=var_tensor, memory_config=sharded_mem_config
)
tt_output = ttnn.to_torch(tt_output_tensor_on_device)
torch_result = torch.nn.functional.batch_norm(input=in_data, running_mean=mean_data, running_var=var_data)
comp_pass = compare_results_batch_norm([tt_output], [torch_result])
assert comp_pass
4 changes: 4 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,10 @@ set(TTNN_OP_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/matmul_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/groupnorm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/groupnorm_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp
Expand Down
31 changes: 31 additions & 0 deletions ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "batch_norm.hpp"

#include "device/batch_norm_device_operation.hpp"

using namespace tt::tt_metal;

namespace ttnn::operations::normalization {

Tensor BatchNorm::invoke(
const Tensor& input,
std::optional<Tensor> running_mean,
std::optional<Tensor> running_var,
const bool training,
const float eps,
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config) {
// TODO: Implementation for training mode is in progress
TT_FATAL((!training), "Support currently provided for inference mode only");
TT_FATAL(
(running_mean.has_value() && running_var.has_value()),
"running_mean and running_var must be defined in evaluation mode");
return ttnn::prim::batch_norm(
input, running_mean.value(), running_var.value(), eps, weight, bias, output, memory_config);
}
} // namespace ttnn::operations::normalization
27 changes: 27 additions & 0 deletions ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include "ttnn/decorators.hpp"
#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"

namespace ttnn {
namespace operations::normalization {
struct BatchNorm {
static Tensor invoke(
const Tensor& input,
std::optional<Tensor> running_mean = std::nullopt,
std::optional<Tensor> running_var = std::nullopt,
const bool training = false,
const float eps = 1e-05,
std::optional<Tensor> weight = std::nullopt,
std::optional<Tensor> bias = std::nullopt,
std::optional<Tensor> output = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt);
};
} // namespace operations::normalization

constexpr auto batch_norm =
ttnn::register_operation_with_auto_launch_op<"ttnn::batch_norm", ttnn::operations::normalization::BatchNorm>();
} // namespace ttnn
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "batch_norm_pybind.hpp"

#include "batch_norm.hpp"

#include "pybind11/decorators.hpp"
namespace py = pybind11;
namespace ttnn::operations::normalization::detail {
void bind_batch_norm_operation(pybind11::module& module) {
ttnn::bind_registered_operation(
module,
ttnn::batch_norm,
R"doc(
Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be must be tilized and interleaved. Currently support is provided for inference mode only.
Args:
input_tensor (ttnn.Tensor): the input tensor of shape `[N, C, H, W]`.
Keyword args:
eps (float, optional): Epsilon value. Defaults to `1e-05`.
running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`.
bias (ttnn.Tensor, optional): the bias or beta value of shape `[1, C, 1, 1]`. Defaults to `None`.
training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode).
output (ttnn.Tensor, optional): Preallocated output tensor to store batch norm result of shape `[N, C, H, W]`. Defaults to `None`.
memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`.
Returns:
ttnn.Tensor: the output tensor.
)doc",
ttnn::pybind_arguments_t{
py::arg("input"),
py::kw_only(),
py::arg("running_mean") = std::nullopt,
py::arg("running_var") = std::nullopt,
py::arg("training") = false,
py::arg("eps") = 1e-05,
py::arg("weight") = std::nullopt,
py::arg("bias") = std::nullopt,
py::arg("output") = std::nullopt,
py::arg("memory_config") = std::nullopt

});
}
} // namespace ttnn::operations::normalization::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "pybind11/pybind_fwd.hpp"

namespace py = pybind11;

namespace ttnn::operations::normalization::detail {
void bind_batch_norm_operation(pybind11::module& module);
}
Loading

0 comments on commit 35c7145

Please sign in to comment.