diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp
index e53f63bbece..00c8473bfa6 100644
--- a/tt_eager/tensor/types.hpp
+++ b/tt_eager/tensor/types.hpp
@@ -17,6 +17,7 @@
 #include "tt_metal/impl/device/device.hpp"
 #include "tt_metal/tt_stl/concepts.hpp"
 #include "tt_metal/tt_stl/reflection.hpp"
+#include "tt_metal/common/core_coord.h"
 
 namespace tt {
 
diff --git a/ttnn/cpp/pybind11/operations/normalization.hpp b/ttnn/cpp/pybind11/operations/normalization.hpp
index 4b774d1a36e..19c7ad61847 100644
--- a/ttnn/cpp/pybind11/operations/normalization.hpp
+++ b/ttnn/cpp/pybind11/operations/normalization.hpp
@@ -40,6 +40,22 @@ Compute layer_norm over :attr:`input_tensor`.
         R"doc(
 Compute rms_norm over :attr:`input_tensor`.
         )doc");
+
+    module.def("group_norm", &group_norm,
+        py::arg("input_tensor"),
+        py::kw_only(),
+        py::arg("num_groups"),
+        py::arg("epsilon") = 1e-12,
+        py::arg("input_mask") = std::nullopt,
+        py::arg("weight") = std::nullopt,
+        py::arg("bias") = std::nullopt,
+        py::arg("memory_config") = std::nullopt,
+        py::arg("dtype") = std::nullopt,
+        py::arg("core_grid") = std::nullopt,
+        py::arg("inplace") = true,
+        R"doc(
+Compute group_norm over :attr:`input_tensor`.
+        )doc");
 }
 
 }  // namespace normalization
diff --git a/ttnn/cpp/ttnn/operations/normalization.hpp b/ttnn/cpp/ttnn/operations/normalization.hpp
index 7764d02d207..95d8f95e5ca 100644
--- a/ttnn/cpp/ttnn/operations/normalization.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization.hpp
@@ -5,6 +5,7 @@
 #pragma once
 
 #include "tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp"
+#include "tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.hpp"
 
 namespace ttnn {
 namespace operations {
@@ -40,6 +41,60 @@ inline ttnn::Tensor rms_norm(
     return tt::operations::primary::rmsnorm(input_tensor, epsilon, std::optional(weight), std::nullopt, dram_memory_config);
 }
 
+inline ttnn::Tensor group_norm(
+    const ttnn::Tensor& input_tensor,
+    const int num_groups,
+    const float epsilon,
+    const std::optional<ttnn::Tensor>& input_mask = std::nullopt,
+    const std::optional<ttnn::Tensor>& weight = std::nullopt,
+    const std::optional<ttnn::Tensor>& bias = std::nullopt,
+    const std::optional<MemoryConfig>& memory_config = std::nullopt,
+    const std::optional<ttnn::DataType> dtype = std::nullopt,
+    std::optional<ttnn::CoreGrid> core_grid = std::nullopt,
+    std::optional<bool> inplace = std::nullopt
+) {
+
+    TT_FATAL(core_grid.has_value(), "Automatic determination of grid size not supported");
+
+    TT_FATAL(input_tensor.is_sharded(), "Only sharded input tensors supported");
+
+    TT_FATAL(input_tensor.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Input tensor cannot be width sharded");
+
+    TT_FATAL(input_tensor.get_shape().rank() == 4, "Input tensor must be rank 4");
+
+    TT_FATAL(input_tensor.get_shape()[-1] % num_groups == 0, "Number of channels must be divisible by number of groups");
+
+    const auto& ts = input_tensor.get_shape();
+    TT_FATAL((ts[0] * ts[1] * ts[2]) % ttnn::types::TILE_SIZE == 0, "Input tensor dim NHW must be divisible by tile size");
+
+    const auto output_dtype = dtype.value_or(input_tensor.get_dtype());
+
+    const std::optional<ttnn::Tensor>& gamma = weight.has_value() ? std::optional(ttnn::unsqueeze_to_4D(weight.value())) : std::nullopt;
+    const std::optional<ttnn::Tensor>& beta = bias.has_value() ? std::optional(ttnn::unsqueeze_to_4D(bias.value())) : std::nullopt;
+
+    const MemoryConfig& dram_memory_config = tt::tt_metal::MemoryConfig{.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED,.buffer_type=tt::tt_metal::BufferType::DRAM};
+    const MemoryConfig& output_mem_config = memory_config.value_or(dram_memory_config);
+
+    const tt::operations::primary::GroupNormShardedMultiCoreProgramConfig& program_config = {
+        .compute_with_storage_grid_size = core_grid.value().to_CoreCoord(),
+        .math_fidelity = MathFidelity::HiFi4,
+        .im_data_format = DataType::BFLOAT16,
+        .out_data_format = output_dtype,
+        .inplace = inplace.value_or(false)
+    };
+
+    return tt::operations::primary::groupnorm(
+        input_tensor,
+        num_groups,
+        epsilon,
+        gamma,
+        beta,
+        input_mask,
+        output_mem_config,
+        program_config
+    );
+}
+
 }  // namespace normalization
 }  // namespace operations
 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/types.hpp b/ttnn/cpp/ttnn/types.hpp
index b5c579db3b3..5f4ef17f713 100644
--- a/ttnn/cpp/ttnn/types.hpp
+++ b/ttnn/cpp/ttnn/types.hpp
@@ -46,6 +46,9 @@ struct CoreGrid {
     std::size_t y;
 
     CoreGrid(std::size_t x, std::size_t y) : x(x), y(y) {}
+    CoreCoord to_CoreCoord() const {
+        return CoreCoord(int(x), int(y));
+    }
 };
 
 static std::ostream &operator<<(std::ostream &os, const CoreGrid &core_grid) {
diff --git a/ttnn/ttnn/operations/normalization.py b/ttnn/ttnn/operations/normalization.py
index 4f67c00a082..658d8f41a9d 100644
--- a/ttnn/ttnn/operations/normalization.py
+++ b/ttnn/ttnn/operations/normalization.py
@@ -332,74 +332,8 @@ def _group_norm_validate_input_tensors(operation_name, input_tensor, *args, weig
     )
 
 
-@ttnn.register_operation(
-    name="ttnn.group_norm",
-    validate_input_tensors=_group_norm_validate_input_tensors,
-    golden_function=_golden_function,
-    postprocess_golden_function_outputs=_postprocess_golden_function_outputs,
+group_norm = ttnn.register_operation(name="ttnn.group_norm", is_cpp_function=True, golden_function=_golden_function)(
+    ttnn._ttnn.operations.normalization.group_norm
 )
-def group_norm(
-    input_tensor: ttnn.Tensor,
-    *,
-    num_groups: int,
-    epsilon: float = 1e-12,
-    input_mask: Optional[ttnn.Tensor] = None,
-    weight: Optional[ttnn.Tensor] = None,
-    bias: Optional[ttnn.Tensor] = None,
-    memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
-    dtype: Optional[ttnn.DataType] = None,
-    core_grid: Optional[Union[ttnn.CoreGrid, ttnn.CoreRange]] = None,
-    inplace: Optional[bool] = True,
-) -> ttnn.Tensor:
-    r"""
-    group_norm(input_tensor: ttnn.Tensor, *, num_groups: int, epsilon: float = 1e-12, weight: Optional[ttnn.Tensor] = None, bias: Optional[ttnn.Tensor] = None) -> ttnn.Tensor
-
-    Compute group_norm over :attr:`input_tensor`.
-
-    """
-
-    if core_grid is not None and not isinstance(core_grid, ttnn.CoreGrid):
-        raise RuntimeError("core_grid must be a valid CoreGrid object")
-
-    if ttnn.is_sharded(input_tensor):
-        if input_tensor.shape.rank != 4:
-            raise TypeError("The input tensor rank must equal to 4")
-
-        if input_tensor.shape[-1] % num_groups != 0:
-            raise TypeError("number of channels must be divisible by number of groups")
-
-        if ttnn.get_memory_config(input_tensor).memory_layout == ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED:
-            raise TypeError("Cannot be width sharded")
-
-        if (input_tensor.shape[0] * input_tensor.shape[1] * input_tensor.shape[2]) % ttnn.TILE_SIZE != 0:
-            raise TypeError("input tensor dim NHW must be divisible by tile size")
-
-        output_dtype = input_tensor.dtype if dtype is None else dtype
-
-        if weight is not None:
-            weight = ttnn.unsqueeze_to_4D(weight)
-
-        if bias is not None:
-            bias = ttnn.unsqueeze_to_4D(bias)
-
-        output_tensor = ttnn.experimental.operations.primary.groupnorm(
-            input_tensor,
-            num_groups,
-            epsilon,
-            weight,
-            bias,
-            input_mask,
-            output_mem_config=memory_config,
-            program_config=ttl.operations.primary.GroupNormShardedMultiCoreProgramConfig(
-                compute_with_storage_grid_size=(core_grid.x, core_grid.y),
-                out_data_format=output_dtype,
-                inplace=inplace,
-            ),
-        )
-        return output_tensor
-
-    else:
-        raise NotImplementedError
-
 
 __all__ = []
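
Usage sketch (not part of the diff): a minimal Python example of calling the new
ttnn.group_norm binding registered above. The tensor shape, the 4x4 grid, and the
block-sharding parameters are illustrative assumptions, not values taken from this
change; the op only requires a sharded (non-width-sharded) rank-4 input whose
N*H*W is a multiple of the tile size and whose channel count is divisible by
num_groups.

import torch
import ttnn

device = ttnn.open_device(device_id=0)  # assumes a single attached device

# Illustrative shape: N=1, H=1, W=512, C=256, normalized over 32 groups.
torch_input = torch.rand((1, 1, 512, 256), dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# group_norm only accepts sharded inputs (see the TT_FATAL checks above),
# so move the tensor into a block-sharded memory config first.
sharded_config = ttnn.create_sharded_memory_config(
    torch_input.shape,
    core_grid=ttnn.CoreGrid(y=4, x=4),
    strategy=ttnn.ShardStrategy.BLOCK,
)
input_tensor = ttnn.to_memory_config(input_tensor, sharded_config)

# core_grid is mandatory: the C++ path refuses to pick a grid automatically.
output_tensor = ttnn.group_norm(
    input_tensor,
    num_groups=32,
    core_grid=ttnn.CoreGrid(y=4, x=4),
    inplace=False,
)

ttnn.close_device(device)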