#5389: updated attention_softmax to be a C++ operation #8333

Merged 1 commit on May 10, 2024
32 changes: 32 additions & 0 deletions ttnn/cpp/pybind11/operations/transformer.hpp
@@ -26,10 +26,39 @@ void py_module(py::module& module) {

Args:
* :attr:`input_tensor`: Input Tensor

Keyword Args:
* :attr:`memory_config`: Memory config of the output tensor; defaults to input_tensor.memory_config() if None
)doc",
ttnn::pybind_arguments_t{py::arg("input_tensor"), py::kw_only(), py::arg("memory_config") = std::nullopt});

ttnn::bind_registered_operation(
module,
ttnn::transformer::attention_softmax,
R"doc(attention_softmax(tensor: ttnn.Tensor, *, head_size: Optional[int] = None, attention_mask: Optional[ttnn.Tensor] = None, program_config: Optional[SoftmaxProgramConfig] = SoftmaxDefaultProgramConfig(), causal_mask: bool = False, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor

Divides :attr:`tensor` by the square root of :attr:`head_size`, adds :attr:`attention_mask` (optionally) and computes softmax.

Args:
* :attr:`tensor`: Input Tensor

Keyword Args:
* :attr:`head_size`: Size of each attention head, used to scale the input by 1/sqrt(head_size)
* :attr:`attention_mask`: Attention Mask
* :attr:`program_config`: Program config for the softmax operation
* :attr:`causal_mask`: Whether the attention mask is causal
* :attr:`memory_config`: Memory config of the output tensor; defaults to input_tensor.memory_config()
)doc",
ttnn::pybind_arguments_t{
py::arg("tensor"),
py::kw_only(),
py::arg("head_size") = std::nullopt,
py::arg("attention_mask") = std::nullopt,
py::arg("program_config").noconvert() =
tt::operations::primary::transformers::SoftmaxDefaultProgramConfig{},
py::arg("causal_mask") = false,
py::arg("memory_config") = std::nullopt});

ttnn::bind_registered_operation(
module,
ttnn::transformer::attention_softmax_,
@@ -39,9 +68,12 @@ void py_module(py::module& module) {

Args:
* :attr:`tensor`: Input Tensor

Keyword Args:
* :attr:`head_size`: Size of each attention head, used to scale the input by 1/sqrt(head_size)
* :attr:`attention_mask`: Attention Mask
* :attr:`program_config`: Program config for the softmax operation
* :attr:`causal_mask`: Whether the attention mask is causal
* :attr:`memory_config`: Memory config of the output tensor; defaults to input_tensor.memory_config()
)doc",
ttnn::pybind_arguments_t{
12 changes: 10 additions & 2 deletions ttnn/cpp/pybind11/operations/unary.hpp
@@ -52,8 +52,15 @@ void bind_unary(py::module& module, const unary_operation_t& operation) {

template <typename unary_operation_t>
void bind_unary_with_bool_parameter_set_to_false_by_default(py::module& module, const unary_operation_t& operation) {
std::string parameter_description;
if (operation.name() == "exp") {
parameter_description = "Use fast and approximate mode";
} else {
TT_THROW("Unknown name!");
}

auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, *, parameter: float, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
R"doc({0}(input_tensor: ttnn.Tensor, *, parameter: bool = False, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor

Applies {0} to :attr:`input_tensor` element-wise.

@@ -64,7 +71,7 @@ void bind_unary_with_bool_parameter_set_to_false_by_default(py::module& module,
* :attr:`input_tensor`

Keyword Args:
* :attr:`parameter` (float): Parameter for the operation.
* :attr:`parameter` (bool): {2}.
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.

Example::
@@ -73,6 +80,7 @@ void bind_unary_with_bool_parameter_set_to_false_by_default(py::module& module,
>>> output = {1}(tensor, parameter=True)
)doc",
operation.name(),
operation.python_fully_qualified_name(),
parameter_description);

bind_registered_operation(
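As a usage sketch of the binding generated above, assuming `exp` is the operation bound through the `parameter_description` branch (per the docstring, `parameter` toggles the fast/approximate mode; the tensor setup is illustrative):

>>> import torch, ttnn
>>> device = ttnn.open_device(device_id=0)
>>> tensor = ttnn.from_torch(torch.rand(32, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
>>> output = ttnn.exp(tensor, parameter=True)  # parameter=True enables fast and approximate mode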
30 changes: 23 additions & 7 deletions ttnn/cpp/ttnn/operations/transformer.hpp
@@ -60,12 +60,12 @@ inline std::tuple<Tensor, Tensor, Tensor> reshape_outputs_of_split_query_key_val
} // namespace detail

inline std::tuple<Tensor, Tensor, Tensor> split_query_key_value_and_split_heads(
const Tensor &input_tensor,
const std::optional<Tensor> &input_tensor_kv,
const Tensor& input_tensor,
const std::optional<Tensor>& input_tensor_kv,
const uint32_t num_heads,
const std::optional<uint32_t> num_kv_heads,
const bool transpose_key,
const std::optional<MemoryConfig> &memory_config) {
const std::optional<MemoryConfig>& memory_config) {
const auto input_shape = input_tensor.get_shape();
TT_FATAL(input_shape.rank() == 3, "Input Tensor must have strictly 3 dimensions!");
TT_FATAL(input_tensor.get_layout() == tt::tt_metal::Layout::TILE,"Input Tensor must be in a TILE_LAYOUT!");
@@ -81,7 +81,13 @@ inline std::tuple<Tensor, Tensor, Tensor> split_query_key_value_and_split_heads(
auto head_size = qkv_heads_times_head_dim / (num_heads + (num_kv_heads.value() * 2));
auto padded_head_size = qkv_heads_times_head_dim_padded / (num_heads + (num_kv_heads.value() * 2));

TT_FATAL(head_size % TILE_SIZE == 0, fmt::format("Head size {} must be a multiple of tile size {}! Update the preceding matmul to have the padding in the weights!", head_size, TILE_WIDTH));
TT_FATAL(
head_size % TILE_SIZE == 0,
fmt::format(
"Head size {} must be a multiple of tile size {}! Update the preceding matmul to have the padding in "
"the weights!",
head_size,
TILE_WIDTH));
TT_FATAL(padded_head_size == head_size, fmt::format("Head size {} cannot have tile padding", head_size));

const auto input_4d = input_tensor.reshape(
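To make the head-size check above concrete, a small sketch with illustrative numbers (not taken from this PR):

# Fused QKV width of 1536 with num_heads=8 and num_kv_heads=2:
# each KV head contributes twice (once for K, once for V), so
head_size = 1536 // (8 + 2 * 2)  # = 128
# 128 is a multiple of the 32-wide tile, so the TT_FATAL above passes;
# a width such as 1520 would give head_size 126 and trip the check.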
@@ -257,20 +263,28 @@ struct AttentionSoftmax : public tt::operations::primary::Softmax {

static ttnn::Tensor execute(
const ttnn::Tensor& input_tensor,
const std::optional<int>& head_size = std::nullopt,
const std::optional<int>& head_size_arg = std::nullopt,
const std::optional<const ttnn::Tensor>& attention_mask = std::nullopt,
const tt::operations::primary::transformers::SoftmaxProgramConfig& program_config =
tt::operations::primary::transformers::SoftmaxDefaultProgramConfig{},
const std::optional<bool> causal_mask = false,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt) {
TT_FATAL(attention_mask.has_value(), "Cannot apply divide by sqrt(head_size) using in-place version!");
auto head_size = head_size_arg.has_value() ? 1.0 / sqrt(head_size_arg.value()) : 1.0;
if constexpr (in_place) {
TT_FATAL(attention_mask.has_value(), "Cannot apply divide by sqrt(head_size) using in-place version!");
} else {
if (not attention_mask.has_value()) {
auto output_tensor = ttnn::multiply(input_tensor, head_size);
return tt::tt_metal::softmax(output_tensor, memory_config.value_or(input_tensor.memory_config()));
}
}

std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt;
auto kernel_config_val = init_device_compute_kernel_config(
input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false);
auto output_tensor = operation::run(
AttentionSoftmax{
head_size.has_value() ? 1.0 / sqrt(head_size.value()) : 1.0,
head_size,
in_place,
memory_config.value_or(input_tensor.memory_config()),
program_config,
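Putting the branches above together, a hedged torch sketch of what the non-in-place path computes; the function name is illustrative and mirrors the docstring semantics, not the device kernels:

import torch

def attention_softmax_reference(scores, head_size=None, attention_mask=None):
    # mirrors `head_size_arg.has_value() ? 1.0 / sqrt(head_size_arg.value()) : 1.0` above
    scale = 1.0 / head_size ** 0.5 if head_size is not None else 1.0
    scores = scores * scale
    if attention_mask is not None:
        # the device op fuses this via scale_mask_softmax; plain add + softmax here
        scores = scores + attention_mask
    return torch.softmax(scores, dim=-1)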
@@ -291,6 +305,8 @@ namespace transformer {
constexpr auto concatenate_heads =
ttnn::register_operation<ttnn::operations::transformer::ConcatenateHeads>("ttnn::transfomer::concatenate_heads");

constexpr auto attention_softmax = ttnn::register_operation<ttnn::operations::transformer::AttentionSoftmax<false>>(
"ttnn::transfomer::attention_softmax");
constexpr auto attention_softmax_ = ttnn::register_operation<ttnn::operations::transformer::AttentionSoftmax<true>>(
"ttnn::transfomer::attention_softmax_");
} // namespace transformer
66 changes: 2 additions & 64 deletions ttnn/ttnn/operations/transformer.py
@@ -125,71 +125,9 @@ def _golden_function(input_tensor: ttnn.Tensor, *, head_size: int, attention_mas
return torch.softmax(input_tensor, -1)


def _attention_softmax_validate_input_tensors(operation_name, input_tensor, *args, attention_mask, **kwargs):
ttnn.validate_input_tensor(
operation_name,
input_tensor,
ranks=(4,),
dtypes=(ttnn.bfloat16, ttnn.bfloat8_b),
layouts=(ttnn.TILE_LAYOUT,),
can_be_on_device=True,
can_be_on_cpu=False,
)
ttnn.validate_input_tensor(
operation_name,
attention_mask,
ranks=(4,),
dtypes=(ttnn.bfloat16, ttnn.bfloat8_b),
layouts=(ttnn.TILE_LAYOUT,),
can_be_on_device=True,
can_be_on_cpu=False,
is_optional=True,
)


@ttnn.register_operation(
name="ttnn.transformer.attention_softmax",
validate_input_tensors=_attention_softmax_validate_input_tensors,
attention_softmax = ttnn.register_operation(
golden_function=_golden_function,
)
def attention_softmax(
input_tensor: ttnn.Tensor,
*,
head_size: Optional[int],
attention_mask: Optional[ttnn.Tensor],
memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
program_config: Optional[
ttl.operations.primary.transformers.SoftmaxProgramConfig
] = ttl.operations.primary.transformers.SoftmaxDefaultProgramConfig(),
) -> ttnn.Tensor:
"""
attention_softmax(input_tensor: ttnn.Tensor, *, head_size: int, attention_mask: Optional[ttnn.Tensor], memory_config: MemoryConfig = DRAM_MEMORY_CONFIG) -> ttnn.Tensor

Divides :attr:`input_tensor` by the square root of :attr:`head_size`, adds :attr:`attention_mask` (optionally) and computes softmax

Args:
* :attr:`input_tensor`: Input Tensor
* :attr:`head_size`: Number of heads
* :attr:`attention_mask`: Attention Mask
* :attr:`memory_config`: Memory Config of the output tensor

"""
if head_size is not None:
scaler = 1 / (head_size**0.5)
else:
scaler = 1.0

if attention_mask is not None:
output_tensor = ttl.tensor.scale_mask_softmax(
input_tensor,
scaler,
attention_mask,
output_mem_config=memory_config,
)
else:
scaled_input_tensor = input_tensor * scaler
output_tensor = ttl.tensor.softmax(scaled_input_tensor, output_mem_config=memory_config)
return output_tensor
)(ttnn._ttnn.operations.transformer.attention_softmax)


attention_softmax_ = ttnn.register_operation(
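With the Python wrapper now a thin registration around the C++ op, a hedged sketch of checking it against the torch golden path; the helper name and tolerance are illustrative, not part of this PR:

import torch
import ttnn

def check_attention_softmax(torch_scores, torch_mask, head_size, device):
    # golden path: scale by 1/sqrt(head_size), add mask, softmax over the last dim
    expected = torch.softmax(torch_scores / head_size ** 0.5 + torch_mask, dim=-1)
    scores = ttnn.from_torch(torch_scores, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    mask = ttnn.from_torch(torch_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    actual = ttnn.transformer.attention_softmax(scores, head_size=head_size, attention_mask=mask)
    return torch.allclose(expected, ttnn.to_torch(actual).to(expected.dtype), atol=2e-2)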