diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp index 2a6ff9c64c3..2ac2fda5fb6 100644 --- a/tt_eager/tensor/types.hpp +++ b/tt_eager/tensor/types.hpp @@ -17,6 +17,7 @@ #include "tt_metal/impl/device/device.hpp" #include "tt_metal/tt_stl/concepts.hpp" #include "tt_metal/tt_stl/reflection.hpp" +#include "tt_metal/common/core_coord.h" namespace tt { diff --git a/ttnn/cpp/pybind11/operations/normalization.hpp b/ttnn/cpp/pybind11/operations/normalization.hpp index 22efdeccca6..955dbe3e88a 100644 --- a/ttnn/cpp/pybind11/operations/normalization.hpp +++ b/ttnn/cpp/pybind11/operations/normalization.hpp @@ -72,6 +72,27 @@ void py_module(py::module& module) { py::kw_only(), py::arg("epsilon") = 1e-12, py::arg("memory_config") = std::nullopt}); + + ttnn::bind_registered_operation( + module, + ttnn::group_norm, + R"doc(group_norm(input_tensor: ttnn.Tensor, *, num_groups: int, epsilon: float = 1e-12, weight: Optional[ttnn.Tensor] = None, bias: Optional[ttnn.Tensor] = None) -> ttnn.Tensor + Compute group_norm over :attr:`input_tensor`. + )doc", + ttnn::pybind_arguments_t{ + py::arg("input_tensor"), + py::kw_only(), + py::arg("num_groups"), + py::arg("epsilon") = 1e-12, + py::arg("input_mask") = std::nullopt, + py::arg("weight") = std::nullopt, + py::arg("bias") = std::nullopt, + py::arg("memory_config") = std::nullopt, + py::arg("dtype") = std::nullopt, + py::arg("core_grid") = std::nullopt, + py::arg("inplace") = true + } + ); } } // namespace normalization diff --git a/ttnn/cpp/ttnn/operations/normalization.hpp b/ttnn/cpp/ttnn/operations/normalization.hpp index aa9cce6bf39..01c4735ad75 100644 --- a/ttnn/cpp/ttnn/operations/normalization.hpp +++ b/ttnn/cpp/ttnn/operations/normalization.hpp @@ -7,6 +7,7 @@ #include "tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" #include "tt_dnn/op_library/softmax/softmax_op.hpp" #include "tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp" +#include "tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.hpp" namespace ttnn { namespace operations { @@ -161,10 +162,87 @@ struct RMSNorm : tt::operations::primary::LayerNorm { } }; + +struct GroupNorm { + template + static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { + return std::make_tuple(input_tensor); + } + + static inline const std::array input_tensor_schemas() { + return { + ttnn::TensorSchema{ + 2, + 4, + {ttnn::bfloat16}, + {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, + true, + false, + false, + false} + }; + } + + static inline ttnn::Tensor execute( + const ttnn::Tensor& input_tensor, + const int num_groups, + const float epsilon, + const std::optional & input_mask = std::nullopt, + const std::optional & weight = std::nullopt, + const std::optional & bias = std::nullopt, + const std::optional & memory_config = std::nullopt, + const std::optional dtype = std::nullopt, + std::optional core_grid = std::nullopt, + std::optional inplace = std::nullopt + ) { + + TT_FATAL(core_grid.has_value(), "Automatic determination of grid size not supported"); + + TT_FATAL(input_tensor.is_sharded(), "Only sharded input tensors supported"); + + TT_FATAL(input_tensor.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Input tensor cannot be width sharded"); + + TT_FATAL(input_tensor.get_shape().rank() == 4, "Input tensor must be rank 4"); + + TT_FATAL(input_tensor.get_shape()[-1] % num_groups == 0, "Number of channels must be divisible by number of groups"); + + const auto& ts = input_tensor.get_shape(); + TT_FATAL((ts[0] * ts[1] * ts[2]) % ttnn::types::TILE_SIZE == 0, "Input tensor dim NHW must be divisible by tile size"); + + const auto output_dtype = dtype.value_or(input_tensor.get_dtype()); + + const std::optional & gamma = weight.has_value() ? std::optional(ttnn::unsqueeze_to_4D(weight.value())) : std::nullopt; + const std::optional & beta = bias.has_value() ? std::optional(ttnn::unsqueeze_to_4D(bias.value())) : std::nullopt; + + const MemoryConfig& dram_memory_config = tt::tt_metal::MemoryConfig{.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED,.buffer_type=tt::tt_metal::BufferType::DRAM}; + const MemoryConfig& output_mem_config = memory_config.value_or(dram_memory_config); + + const tt::operations::primary::GroupNormShardedMultiCoreProgramConfig& program_config = { + .compute_with_storage_grid_size = core_grid.value().to_CoreCoord(), + .math_fidelity = MathFidelity::HiFi4, + .im_data_format = DataType::BFLOAT16, + .out_data_format = DataType::BFLOAT16, + .inplace = inplace.value_or(false) + }; + + return tt::operations::primary::groupnorm( + input_tensor, + num_groups, + epsilon, + gamma, + beta, + input_mask, + output_mem_config, + program_config + ); + } +}; + } // namespace normalization } // namespace operations constexpr auto softmax = ttnn::register_operation>("ttnn::softmax"); constexpr auto layer_norm = ttnn::register_operation("ttnn::layer_norm"); constexpr auto rms_norm = ttnn::register_operation("ttnn::rms_norm"); +constexpr auto group_norm = ttnn::register_operation("ttnn::group_norm"); } // namespace ttnn diff --git a/ttnn/cpp/ttnn/types.hpp b/ttnn/cpp/ttnn/types.hpp index 2710ac0048e..af802e53853 100644 --- a/ttnn/cpp/ttnn/types.hpp +++ b/ttnn/cpp/ttnn/types.hpp @@ -47,6 +47,9 @@ struct CoreGrid { std::size_t y; CoreGrid(std::size_t x, std::size_t y) : x(x), y(y) {} + CoreCoord to_CoreCoord(){ + return CoreCoord(int(x), int(y)); + } }; // This buffer class is compatible with multithreaded runtime (which lives in tt_eager) diff --git a/ttnn/ttnn/operations/normalization.py b/ttnn/ttnn/operations/normalization.py index e0ff7a8ade0..d9077450956 100644 --- a/ttnn/ttnn/operations/normalization.py +++ b/ttnn/ttnn/operations/normalization.py @@ -262,74 +262,6 @@ def _group_norm_validate_input_tensors(operation_name, input_tensor, *args, weig ) -@ttnn.register_operation( - name="ttnn.group_norm", - validate_input_tensors=_group_norm_validate_input_tensors, - golden_function=_golden_function, - postprocess_golden_function_outputs=_postprocess_golden_function_outputs, -) -def group_norm( - input_tensor: ttnn.Tensor, - *, - num_groups: int, - epsilon: float = 1e-12, - input_mask: Optional[ttnn.Tensor] = None, - weight: Optional[ttnn.Tensor] = None, - bias: Optional[ttnn.Tensor] = None, - memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG, - dtype: Optional[ttnn.DataType] = None, - core_grid: Optional[Union[ttnn.CoreGrid, ttnn.CoreRange]] = None, - inplace: Optional[bool] = True, -) -> ttnn.Tensor: - r""" - group_norm(input_tensor: ttnn.Tensor, *, num_groups: int, epsilon: float = 1e-12, weight: Optional[ttnn.Tensor] = None, bias: Optional[ttnn.Tensor] = None) -> ttnn.Tensor - - Compute group_norm over :attr:`input_tensor`. - - """ - - if core_grid is not None and not isinstance(core_grid, ttnn.CoreGrid): - raise RuntimeError("core_grid must be a valid CoreGrid object") - - if ttnn.is_sharded(input_tensor): - if input_tensor.shape.rank != 4: - raise TypeError("The input tensor rank must equal to 4") - - if input_tensor.shape[-1] % num_groups != 0: - raise TypeError("number of channels must be divisible by number of groups") - - if ttnn.get_memory_config(input_tensor).memory_layout == ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED: - raise TypeError("Cannot be width sharded") - - if (input_tensor.shape[0] * input_tensor.shape[1] * input_tensor.shape[2]) % ttnn.TILE_SIZE != 0: - raise TypeError("input tensor dim NHW must be divisible by tile size") - - output_dtype = input_tensor.dtype if dtype is None else dtype - - if weight is not None: - weight = ttnn.unsqueeze_to_4D(weight) - - if bias is not None: - bias = ttnn.unsqueeze_to_4D(bias) - - output_tensor = ttnn.experimental.operations.primary.groupnorm( - input_tensor, - num_groups, - epsilon, - weight, - bias, - input_mask, - output_mem_config=memory_config, - program_config=ttl.operations.primary.GroupNormShardedMultiCoreProgramConfig( - compute_with_storage_grid_size=(core_grid.x, core_grid.y), - out_data_format=output_dtype, - inplace=inplace, - ), - ) - return output_tensor - - else: - raise NotImplementedError - +group_norm = ttnn.register_operation(golden_function=_golden_function)(ttnn._ttnn.operations.normalization.group_norm) __all__ = []