Skip to content

Commit

Permalink
#7781: Port ttnn.group_norm to C++
Browse files Browse the repository at this point in the history
  • Loading branch information
xanderchin committed May 8, 2024
1 parent 3b13c10 commit 834f4de
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 68 deletions.
1 change: 1 addition & 0 deletions tt_eager/tensor/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "tt_metal/impl/device/device.hpp"
#include "tt_metal/tt_stl/concepts.hpp"
#include "tt_metal/tt_stl/reflection.hpp"
#include "tt_metal/common/core_coord.h"

namespace tt {

Expand Down
16 changes: 16 additions & 0 deletions ttnn/cpp/pybind11/operations/normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,22 @@ Compute layer_norm over :attr:`input_tensor`.
R"doc(
Compute rms_norm over :attr:`input_tensor`.
)doc");

module.def("group_norm", &group_norm,
py::arg("input_tensor"),
py::kw_only(),
py::arg("num_groups"),
py::arg("epsilon") = 1e-12,
py::arg("input_mask") = std::nullopt,
py::arg("weight") = std::nullopt,
py::arg("bias") = std::nullopt,
py::arg("memory_config") = std::nullopt,
py::arg("dtype") = std::nullopt,
py::arg("core_grid") = std::nullopt,
py::arg("inplace") = true,
R"doc(
Compute group_norm over :attr:`input_tensor`.
)doc");
}

} // namespace normalization
Expand Down
55 changes: 55 additions & 0 deletions ttnn/cpp/ttnn/operations/normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include "tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp"
#include "tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.hpp"

namespace ttnn {
namespace operations {
Expand Down Expand Up @@ -40,6 +41,60 @@ inline ttnn::Tensor rms_norm(
return tt::operations::primary::rmsnorm(input_tensor, epsilon, std::optional<const ttnn::Tensor>(weight), std::nullopt, dram_memory_config);
}

inline ttnn::Tensor group_norm(
const ttnn::Tensor& input_tensor,
const int num_groups,
const float epsilon,
const std::optional<ttnn::Tensor> & input_mask = std::nullopt,
const std::optional<ttnn::Tensor> & weight = std::nullopt,
const std::optional<ttnn::Tensor> & bias = std::nullopt,
const std::optional<MemoryConfig> & memory_config = std::nullopt,
const std::optional<ttnn::DataType> dtype = std::nullopt,
std::optional<CoreGrid> core_grid = std::nullopt,
std::optional<bool> inplace = std::nullopt
) {

TT_FATAL(core_grid.has_value(), "Automatic determination of grid size not supported");

TT_FATAL(input_tensor.is_sharded(), "Only sharded input tensors supported");

TT_FATAL(input_tensor.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Input tensor cannot be width sharded");

TT_FATAL(input_tensor.get_shape().rank() == 4, "Input tensor must be rank 4");

TT_FATAL(input_tensor.get_shape()[-1] % num_groups == 0, "Number of channels must be divisible by number of groups");

const auto& ts = input_tensor.get_shape();
TT_FATAL((ts[0] * ts[1] * ts[2]) % ttnn::types::TILE_SIZE == 0, "Input tensor dim NHW must be divisible by tile size");

const auto output_dtype = dtype.value_or(input_tensor.get_dtype());

const std::optional<ttnn::Tensor> & gamma = weight.has_value() ? std::optional<ttnn::Tensor>(ttnn::unsqueeze_to_4D(weight.value())) : std::nullopt;
const std::optional<ttnn::Tensor> & beta = bias.has_value() ? std::optional<ttnn::Tensor>(ttnn::unsqueeze_to_4D(bias.value())) : std::nullopt;

const MemoryConfig& dram_memory_config = tt::tt_metal::MemoryConfig{.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED,.buffer_type=tt::tt_metal::BufferType::DRAM};
const MemoryConfig& output_mem_config = memory_config.value_or(dram_memory_config);

const tt::operations::primary::GroupNormShardedMultiCoreProgramConfig& program_config = {
.compute_with_storage_grid_size = core_grid.value().to_CoreCoord(),
.math_fidelity = MathFidelity::HiFi4,
.im_data_format = DataType::BFLOAT16,
.out_data_format = DataType::BFLOAT16,
.inplace = inplace.value_or(false)
};

return tt::operations::primary::groupnorm(
input_tensor,
num_groups,
epsilon,
gamma,
beta,
input_mask,
output_mem_config,
program_config
);
}

} // namespace normalization
} // namespace operations
} // namespace ttnn
3 changes: 3 additions & 0 deletions ttnn/cpp/ttnn/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ struct CoreGrid {
std::size_t y;

CoreGrid(std::size_t x, std::size_t y) : x(x), y(y) {}
CoreCoord to_CoreCoord(){
return CoreCoord(int(x), int(y));
}
};

static std::ostream &operator<<(std::ostream &os, const CoreGrid &core_grid) {
Expand Down
70 changes: 2 additions & 68 deletions ttnn/ttnn/operations/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,74 +332,8 @@ def _group_norm_validate_input_tensors(operation_name, input_tensor, *args, weig
)


@ttnn.register_operation(
name="ttnn.group_norm",
validate_input_tensors=_group_norm_validate_input_tensors,
golden_function=_golden_function,
postprocess_golden_function_outputs=_postprocess_golden_function_outputs,
group_norm = ttnn.register_operation(name="ttnn.group_norm", is_cpp_function=True, golden_function=_golden_function)(
ttnn._ttnn.operations.normalization.group_norm
)
def group_norm(
input_tensor: ttnn.Tensor,
*,
num_groups: int,
epsilon: float = 1e-12,
input_mask: Optional[ttnn.Tensor] = None,
weight: Optional[ttnn.Tensor] = None,
bias: Optional[ttnn.Tensor] = None,
memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
dtype: Optional[ttnn.DataType] = None,
core_grid: Optional[Union[ttnn.CoreGrid, ttnn.CoreRange]] = None,
inplace: Optional[bool] = True,
) -> ttnn.Tensor:
r"""
group_norm(input_tensor: ttnn.Tensor, *, num_groups: int, epsilon: float = 1e-12, weight: Optional[ttnn.Tensor] = None, bias: Optional[ttnn.Tensor] = None) -> ttnn.Tensor
Compute group_norm over :attr:`input_tensor`.
"""

if core_grid is not None and not isinstance(core_grid, ttnn.CoreGrid):
raise RuntimeError("core_grid must be a valid CoreGrid object")

if ttnn.is_sharded(input_tensor):
if input_tensor.shape.rank != 4:
raise TypeError("The input tensor rank must equal to 4")

if input_tensor.shape[-1] % num_groups != 0:
raise TypeError("number of channels must be divisible by number of groups")

if ttnn.get_memory_config(input_tensor).memory_layout == ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED:
raise TypeError("Cannot be width sharded")

if (input_tensor.shape[0] * input_tensor.shape[1] * input_tensor.shape[2]) % ttnn.TILE_SIZE != 0:
raise TypeError("input tensor dim NHW must be divisible by tile size")

output_dtype = input_tensor.dtype if dtype is None else dtype

if weight is not None:
weight = ttnn.unsqueeze_to_4D(weight)

if bias is not None:
bias = ttnn.unsqueeze_to_4D(bias)

output_tensor = ttnn.experimental.operations.primary.groupnorm(
input_tensor,
num_groups,
epsilon,
weight,
bias,
input_mask,
output_mem_config=memory_config,
program_config=ttl.operations.primary.GroupNormShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(core_grid.x, core_grid.y),
out_data_format=output_dtype,
inplace=inplace,
),
)
return output_tensor

else:
raise NotImplementedError


__all__ = []

0 comments on commit 834f4de

Please sign in to comment.