Skip to content

Commit

Permalink
#7781: Port ttnn.group_norm to C++
Browse files Browse the repository at this point in the history
  • Loading branch information
xanderchin committed May 17, 2024
1 parent c560a26 commit 6ec36cb
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 69 deletions.
1 change: 1 addition & 0 deletions tt_eager/tensor/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "tt_metal/impl/device/device.hpp"
#include "tt_metal/tt_stl/concepts.hpp"
#include "tt_metal/tt_stl/reflection.hpp"
#include "tt_metal/common/core_coord.h"

namespace tt {

Expand Down
21 changes: 21 additions & 0 deletions ttnn/cpp/pybind11/operations/normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,27 @@ void py_module(py::module& module) {
py::kw_only(),
py::arg("epsilon") = 1e-12,
py::arg("memory_config") = std::nullopt});

ttnn::bind_registered_operation(
module,
ttnn::group_norm,
R"doc(group_norm(input_tensor: ttnn.Tensor, *, num_groups: int, epsilon: float = 1e-12, weight: Optional[ttnn.Tensor] = None, bias: Optional[ttnn.Tensor] = None) -> ttnn.Tensor
Compute group_norm over :attr:`input_tensor`.
)doc",
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::kw_only(),
py::arg("num_groups"),
py::arg("epsilon") = 1e-12,
py::arg("input_mask") = std::nullopt,
py::arg("weight") = std::nullopt,
py::arg("bias") = std::nullopt,
py::arg("memory_config") = std::nullopt,
py::arg("dtype") = std::nullopt,
py::arg("core_grid") = std::nullopt,
py::arg("inplace") = true
}
);
}

} // namespace normalization
Expand Down
78 changes: 78 additions & 0 deletions ttnn/cpp/ttnn/operations/normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp"
#include "tt_dnn/op_library/softmax/softmax_op.hpp"
#include "tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp"
#include "tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.hpp"

namespace ttnn {
namespace operations {
Expand Down Expand Up @@ -161,10 +162,87 @@ struct RMSNorm : tt::operations::primary::LayerNorm {
}
};


struct GroupNorm {
template <typename... Args>
static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
return std::make_tuple(input_tensor);
}

static inline const std::array<ttnn::TensorSchema, 1> input_tensor_schemas() {
return {
ttnn::TensorSchema{
2,
4,
{ttnn::bfloat16},
{ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT},
true,
false,
false,
false}
};
}

static inline ttnn::Tensor execute(
const ttnn::Tensor& input_tensor,
const int num_groups,
const float epsilon,
const std::optional<ttnn::Tensor> & input_mask = std::nullopt,
const std::optional<ttnn::Tensor> & weight = std::nullopt,
const std::optional<ttnn::Tensor> & bias = std::nullopt,
const std::optional<MemoryConfig> & memory_config = std::nullopt,
const std::optional<ttnn::DataType> dtype = std::nullopt,
std::optional<CoreGrid> core_grid = std::nullopt,
std::optional<bool> inplace = std::nullopt
) {

TT_FATAL(core_grid.has_value(), "Automatic determination of grid size not supported");

TT_FATAL(input_tensor.is_sharded(), "Only sharded input tensors supported");

TT_FATAL(input_tensor.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Input tensor cannot be width sharded");

TT_FATAL(input_tensor.get_shape().rank() == 4, "Input tensor must be rank 4");

TT_FATAL(input_tensor.get_shape()[-1] % num_groups == 0, "Number of channels must be divisible by number of groups");

const auto& ts = input_tensor.get_shape();
TT_FATAL((ts[0] * ts[1] * ts[2]) % ttnn::types::TILE_SIZE == 0, "Input tensor dim NHW must be divisible by tile size");

const auto output_dtype = dtype.value_or(input_tensor.get_dtype());

const std::optional<ttnn::Tensor> & gamma = weight.has_value() ? std::optional<ttnn::Tensor>(ttnn::unsqueeze_to_4D(weight.value())) : std::nullopt;
const std::optional<ttnn::Tensor> & beta = bias.has_value() ? std::optional<ttnn::Tensor>(ttnn::unsqueeze_to_4D(bias.value())) : std::nullopt;

const MemoryConfig& dram_memory_config = tt::tt_metal::MemoryConfig{.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED,.buffer_type=tt::tt_metal::BufferType::DRAM};
const MemoryConfig& output_mem_config = memory_config.value_or(dram_memory_config);

const tt::operations::primary::GroupNormShardedMultiCoreProgramConfig& program_config = {
.compute_with_storage_grid_size = core_grid.value().to_CoreCoord(),
.math_fidelity = MathFidelity::HiFi4,
.im_data_format = DataType::BFLOAT16,
.out_data_format = DataType::BFLOAT16,
.inplace = inplace.value_or(false)
};

return tt::operations::primary::groupnorm(
input_tensor,
num_groups,
epsilon,
gamma,
beta,
input_mask,
output_mem_config,
program_config
);
}
};

} // namespace normalization
} // namespace operations

constexpr auto softmax = ttnn::register_operation<ttnn::operations::normalization::Softmax<false>>("ttnn::softmax");
constexpr auto layer_norm = ttnn::register_operation<ttnn::operations::normalization::LayerNorm>("ttnn::layer_norm");
constexpr auto rms_norm = ttnn::register_operation<ttnn::operations::normalization::RMSNorm>("ttnn::rms_norm");
constexpr auto group_norm = ttnn::register_operation<ttnn::operations::normalization::GroupNorm>("ttnn::group_norm");
} // namespace ttnn
3 changes: 3 additions & 0 deletions ttnn/cpp/ttnn/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ struct CoreGrid {
std::size_t y;

CoreGrid(std::size_t x, std::size_t y) : x(x), y(y) {}
CoreCoord to_CoreCoord(){
return CoreCoord(int(x), int(y));
}
};

// This buffer class is compatible with multithreaded runtime (which lives in tt_eager)
Expand Down
70 changes: 1 addition & 69 deletions ttnn/ttnn/operations/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,74 +262,6 @@ def _group_norm_validate_input_tensors(operation_name, input_tensor, *args, weig
)


@ttnn.register_operation(
name="ttnn.group_norm",
validate_input_tensors=_group_norm_validate_input_tensors,
golden_function=_golden_function,
postprocess_golden_function_outputs=_postprocess_golden_function_outputs,
)
def group_norm(
input_tensor: ttnn.Tensor,
*,
num_groups: int,
epsilon: float = 1e-12,
input_mask: Optional[ttnn.Tensor] = None,
weight: Optional[ttnn.Tensor] = None,
bias: Optional[ttnn.Tensor] = None,
memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
dtype: Optional[ttnn.DataType] = None,
core_grid: Optional[Union[ttnn.CoreGrid, ttnn.CoreRange]] = None,
inplace: Optional[bool] = True,
) -> ttnn.Tensor:
r"""
group_norm(input_tensor: ttnn.Tensor, *, num_groups: int, epsilon: float = 1e-12, weight: Optional[ttnn.Tensor] = None, bias: Optional[ttnn.Tensor] = None) -> ttnn.Tensor
Compute group_norm over :attr:`input_tensor`.
"""

if core_grid is not None and not isinstance(core_grid, ttnn.CoreGrid):
raise RuntimeError("core_grid must be a valid CoreGrid object")

if ttnn.is_sharded(input_tensor):
if input_tensor.shape.rank != 4:
raise TypeError("The input tensor rank must equal to 4")

if input_tensor.shape[-1] % num_groups != 0:
raise TypeError("number of channels must be divisible by number of groups")

if ttnn.get_memory_config(input_tensor).memory_layout == ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED:
raise TypeError("Cannot be width sharded")

if (input_tensor.shape[0] * input_tensor.shape[1] * input_tensor.shape[2]) % ttnn.TILE_SIZE != 0:
raise TypeError("input tensor dim NHW must be divisible by tile size")

output_dtype = input_tensor.dtype if dtype is None else dtype

if weight is not None:
weight = ttnn.unsqueeze_to_4D(weight)

if bias is not None:
bias = ttnn.unsqueeze_to_4D(bias)

output_tensor = ttnn.experimental.operations.primary.groupnorm(
input_tensor,
num_groups,
epsilon,
weight,
bias,
input_mask,
output_mem_config=memory_config,
program_config=ttl.operations.primary.GroupNormShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=(core_grid.x, core_grid.y),
out_data_format=output_dtype,
inplace=inplace,
),
)
return output_tensor

else:
raise NotImplementedError

group_norm = ttnn.register_operation(golden_function=_golden_function)(ttnn._ttnn.operations.normalization.group_norm)

__all__ = []

0 comments on commit 6ec36cb

Please sign in to comment.