-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
### List of Github issues - #12253 - #15449 ### Problem description To implement batch normalization as device op ### What's changed Work done : - Provided support for Inference mode <img width="1254" alt="Screenshot 2025-01-10 at 11 59 49 AM" src="https://github.com/user-attachments/assets/061ccaca-8b7e-44c9-a0f3-e1c87c908403" /> ### Checklist - [x] [All post-commit testsI](https://github.com/tenstorrent/tt-metal/actions/runs/12729863177) - [x] [Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/runs/12729863432) - [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) - [Link to test](https://github.com/tenstorrent/tt-metal/actions/runs/12732676848) - [x] [(Single-card) Demo tests](https://github.com/tenstorrent/tt-metal/actions/runs/12729864298/job/35486219235) - same as main - [x] [(Single-card) Device perf regressions](https://github.com/tenstorrent/tt-metal/actions/runs/12729865036) - same as main - [x] [Single-card Model perf tests](https://github.com/tenstorrent/tt-metal/actions/runs/12729865358) - same as main
- Loading branch information
1 parent
880f700
commit 35c7145
Showing
15 changed files
with
1,128 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -444,6 +444,7 @@ Normalization | |
ttnn.group_norm | ||
ttnn.layer_norm | ||
ttnn.rms_norm | ||
ttnn.batch_norm | ||
|
||
|
||
Moreh Operations | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
import pytest | ||
import ttnn | ||
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( | ||
data_gen_with_range_batch_norm, | ||
compare_results_batch_norm, | ||
) | ||
from itertools import product | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[ | ||
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])), | ||
torch.Size([4, 4, 32, 32]), | ||
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3])), | ||
torch.Size([4, 4, 23, 23]), | ||
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])), | ||
torch.Size([3, 1, 64, 120]), | ||
torch.Size([3, 2, 64, 120]), | ||
], | ||
) | ||
@pytest.mark.parametrize("training", [False]) | ||
@pytest.mark.parametrize("weight", [True, False]) | ||
@pytest.mark.parametrize("bias", [True, False]) | ||
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05]) | ||
def test_batch_norm(input_shapes, training, weight, bias, eps, device): | ||
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True) | ||
mean_data, mean_tensor = ( | ||
data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (not training) else (None, None) | ||
) | ||
var_data, var_tensor = ( | ||
data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (not training) else (None, None) | ||
) | ||
weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None) | ||
bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None) | ||
|
||
tt_output_tensor_on_device = ttnn.batch_norm( | ||
input_tensor, | ||
running_mean=mean_tensor, | ||
running_var=var_tensor, | ||
training=training, | ||
eps=eps, | ||
weight=weight_tensor, | ||
bias=bias_tensor, | ||
) | ||
tt_output = ttnn.to_torch(tt_output_tensor_on_device) | ||
# ttnn.set_printoptions(profile="full") | ||
# print("TT result : ", tt_output, tt_output.shape) | ||
# torch.set_printoptions(precision=5, sci_mode=False) | ||
torch_result = torch.nn.functional.batch_norm( | ||
input=in_data, | ||
running_mean=mean_data, | ||
running_var=var_data, | ||
weight=weight_data, | ||
bias=bias_data, | ||
training=training, | ||
eps=eps, | ||
) | ||
# print("Torch result : ",torch_result) | ||
comp_pass = compare_results_batch_norm([tt_output], [torch_result]) | ||
assert comp_pass | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[ | ||
torch.Size([3, 2, 32, 32]), | ||
], | ||
) | ||
@pytest.mark.parametrize("mem_layout", [ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.TensorMemoryLayout.HEIGHT_SHARDED]) | ||
def test_batch_norm_program_cache_and_default(input_shapes, mem_layout, device): | ||
N, H, W, C = input_shapes | ||
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True) | ||
mean_data, mean_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) | ||
var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device) | ||
|
||
grid_size = ttnn.CoreGrid(y=1, x=8) | ||
grid_coord = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1) | ||
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)}) | ||
shard_shape = N * H * W // grid_size.x, C // grid_size.y | ||
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.COL_MAJOR, False) | ||
sharded_mem_config = ttnn.MemoryConfig(mem_layout, ttnn.types.BufferType.L1, shard_spec) | ||
|
||
if mem_layout is not ttnn.TensorMemoryLayout.INTERLEAVED: | ||
pytest.xfail("Input tensors to batch norm must be interleaved") | ||
|
||
tt_output_tensor_on_device = ttnn.batch_norm( | ||
input_tensor, running_mean=mean_tensor, running_var=var_tensor, memory_config=sharded_mem_config | ||
) | ||
tt_output = ttnn.to_torch(tt_output_tensor_on_device) | ||
torch_result = torch.nn.functional.batch_norm(input=in_data, running_mean=mean_data, running_var=var_data) | ||
comp_pass = compare_results_batch_norm([tt_output], [torch_result]) | ||
assert comp_pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
31 changes: 31 additions & 0 deletions
31
ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "batch_norm.hpp" | ||
|
||
#include "device/batch_norm_device_operation.hpp" | ||
|
||
using namespace tt::tt_metal; | ||
|
||
namespace ttnn::operations::normalization { | ||
|
||
Tensor BatchNorm::invoke( | ||
const Tensor& input, | ||
std::optional<Tensor> running_mean, | ||
std::optional<Tensor> running_var, | ||
const bool training, | ||
const float eps, | ||
std::optional<Tensor> weight, | ||
std::optional<Tensor> bias, | ||
std::optional<Tensor> output, | ||
const std::optional<MemoryConfig>& memory_config) { | ||
// TODO: Implementation for training mode is in progress | ||
TT_FATAL((!training), "Support currently provided for inference mode only"); | ||
TT_FATAL( | ||
(running_mean.has_value() && running_var.has_value()), | ||
"running_mean and running_var must be defined in evaluation mode"); | ||
return ttnn::prim::batch_norm( | ||
input, running_mean.value(), running_var.value(), eps, weight, bias, output, memory_config); | ||
} | ||
} // namespace ttnn::operations::normalization |
27 changes: 27 additions & 0 deletions
27
ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
#include "ttnn/decorators.hpp" | ||
#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" | ||
|
||
namespace ttnn { | ||
namespace operations::normalization { | ||
struct BatchNorm { | ||
static Tensor invoke( | ||
const Tensor& input, | ||
std::optional<Tensor> running_mean = std::nullopt, | ||
std::optional<Tensor> running_var = std::nullopt, | ||
const bool training = false, | ||
const float eps = 1e-05, | ||
std::optional<Tensor> weight = std::nullopt, | ||
std::optional<Tensor> bias = std::nullopt, | ||
std::optional<Tensor> output = std::nullopt, | ||
const std::optional<MemoryConfig>& memory_config = std::nullopt); | ||
}; | ||
} // namespace operations::normalization | ||
|
||
constexpr auto batch_norm = | ||
ttnn::register_operation_with_auto_launch_op<"ttnn::batch_norm", ttnn::operations::normalization::BatchNorm>(); | ||
} // namespace ttnn |
54 changes: 54 additions & 0 deletions
54
ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "batch_norm_pybind.hpp" | ||
|
||
#include "batch_norm.hpp" | ||
|
||
#include "pybind11/decorators.hpp" | ||
namespace py = pybind11; | ||
namespace ttnn::operations::normalization::detail { | ||
void bind_batch_norm_operation(pybind11::module& module) { | ||
ttnn::bind_registered_operation( | ||
module, | ||
ttnn::batch_norm, | ||
R"doc( | ||
Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be must be tilized and interleaved. Currently support is provided for inference mode only. | ||
Args: | ||
input_tensor (ttnn.Tensor): the input tensor of shape `[N, C, H, W]`. | ||
Keyword args: | ||
eps (float, optional): Epsilon value. Defaults to `1e-05`. | ||
running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`. | ||
running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`. | ||
weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`. | ||
bias (ttnn.Tensor, optional): the bias or beta value of shape `[1, C, 1, 1]`. Defaults to `None`. | ||
training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode). | ||
output (ttnn.Tensor, optional): Preallocated output tensor to store batch norm result of shape `[N, C, H, W]`. Defaults to `None`. | ||
memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`. | ||
Returns: | ||
ttnn.Tensor: the output tensor. | ||
)doc", | ||
ttnn::pybind_arguments_t{ | ||
py::arg("input"), | ||
py::kw_only(), | ||
py::arg("running_mean") = std::nullopt, | ||
py::arg("running_var") = std::nullopt, | ||
py::arg("training") = false, | ||
py::arg("eps") = 1e-05, | ||
py::arg("weight") = std::nullopt, | ||
py::arg("bias") = std::nullopt, | ||
py::arg("output") = std::nullopt, | ||
py::arg("memory_config") = std::nullopt | ||
|
||
}); | ||
} | ||
} // namespace ttnn::operations::normalization::detail |
13 changes: 13 additions & 0 deletions
13
ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "pybind11/pybind_fwd.hpp" | ||
|
||
namespace py = pybind11; | ||
|
||
namespace ttnn::operations::normalization::detail { | ||
void bind_batch_norm_operation(pybind11::module& module); | ||
} |
Oops, something went wrong.