#5389: added ttnn.cast
arakhmati committed May 17, 2024
1 parent 129181f commit ec3a8dd
Showing 11 changed files with 201 additions and 44 deletions.
39 changes: 39 additions & 0 deletions tests/ttnn/unit_tests/test_to_dtype.py
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize("height", [32])
@pytest.mark.parametrize("width", [32])
@pytest.mark.parametrize("from_dtype", [ttnn.float32, ttnn.bfloat16])
@pytest.mark.parametrize("to_dtype", [ttnn.bfloat16, ttnn.float32, ttnn.bfloat8_b])
def test_to_dtype(height, width, from_dtype, to_dtype):
torch_input_tensor = torch.rand((height, width), dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_tensor)
assert input_tensor.layout == ttnn.ROW_MAJOR_LAYOUT

input_tensor = ttnn.to_dtype(input_tensor, from_dtype)
assert input_tensor.dtype == from_dtype
assert input_tensor.layout == ttnn.ROW_MAJOR_LAYOUT
assert tuple(input_tensor.shape) == (height, width)

output_tensor = ttnn.to_dtype(input_tensor, to_dtype)
assert output_tensor.dtype == to_dtype
if to_dtype == ttnn.bfloat8_b:
assert output_tensor.layout == ttnn.TILE_LAYOUT
else:
assert output_tensor.layout == ttnn.ROW_MAJOR_LAYOUT
assert tuple(output_tensor.shape) == (height, width)

output_tensor = ttnn.to_torch(output_tensor).to(torch_input_tensor.dtype)

assert_with_pcc(torch_input_tensor, output_tensor)
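
As a quick orientation, here is a minimal standalone sketch of the flow this test parametrizes (the shape and dtypes are illustrative, not part of the commit):

import torch
import ttnn

torch_input = torch.rand((32, 32), dtype=torch.bfloat16)

host_tensor = ttnn.from_torch(torch_input)                # host tensor in ROW_MAJOR_LAYOUT
bfp8_tensor = ttnn.to_dtype(host_tensor, ttnn.bfloat8_b)  # block-float target forces TILE_LAYOUT
assert bfp8_tensor.layout == ttnn.TILE_LAYOUT

# bfloat8_b shares one exponent per block, so the round trip is lossy;
# compare by correlation (as assert_with_pcc does above) rather than exact equality
round_trip = ttnn.to_torch(bfp8_tensor).to(torch.bfloat16)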
8 changes: 4 additions & 4 deletions tt_eager/tensor/host_buffer/functions.hpp
@@ -182,9 +182,9 @@ borrowed_buffer::Buffer<T> get_as(Tensor& tensor) {
[](auto&& storage) -> borrowed_buffer::Buffer<T> {
using StorageType = std::decay_t<decltype(storage)>;
if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
return get_as<T>(storage.buffer);
return host_buffer::get_as<T>(storage.buffer);
} else if constexpr (std::is_same_v<StorageType, BorrowedStorage>) {
return get_as<T>(storage.buffer);
return host_buffer::get_as<T>(storage.buffer);
} else {
TT_THROW("Tensor must have OwnedStorage or BorrowedStorage");
}
@@ -198,9 +198,9 @@ const borrowed_buffer::Buffer<T> get_as(const Tensor& tensor) {
[](auto&& storage) -> const borrowed_buffer::Buffer<T> {
using StorageType = std::decay_t<decltype(storage)>;
if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
return get_as<T>(storage.buffer);
return host_buffer::get_as<T>(storage.buffer);
} else if constexpr (std::is_same_v<StorageType, BorrowedStorage>) {
return get_as<T>(storage.buffer);
return host_buffer::get_as<T>(storage.buffer);
} else {
TT_THROW("Tensor must have OwnedStorage or BorrowedStorage");
}
51 changes: 30 additions & 21 deletions ttnn/cpp/pybind11/operations/core.hpp
@@ -87,6 +87,19 @@ void py_module(py::module& module) {

module.def("deallocate", &ttnn::operations::core::deallocate, py::arg("tensor"), py::arg("force") = true);

module.def(
"reallocate",
[](ttnn::Tensor& input_tensor, const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt)
-> ttnn::Tensor { return reallocate(input_tensor, memory_config); },
py::arg("tensor"),
py::arg("memory_config") = std::nullopt,
R"doc(
Deallocates device tensor and returns a reallocated tensor
Args:
* :attr:`input_tensor`: Input Tensor
)doc");

bind_registered_operation(
module,
ttnn::to_memory_config,
@@ -106,29 +119,25 @@ void py_module(py::module& module) {
>>> tensor = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
>>> tensor = ttnn.to_memory_config(tensor, memory_config)
)doc",
ttnn::pybind_overload_t{
[](const std::decay_t<decltype(ttnn::to_memory_config)> self,
const ttnn::Tensor& tensor,
const ttnn::MemoryConfig& memory_config,
const std::optional<ttnn::DataType>& dtype) -> ttnn::Tensor {
return self(tensor, memory_config, dtype);
},
py::arg("tensor"),
py::arg("memory_config"),
py::arg("dtype") = std::nullopt});
ttnn::pybind_arguments_t{py::arg("tensor"), py::arg("memory_config"), py::arg("dtype") = std::nullopt});

module.def(
"reallocate",
[](ttnn::Tensor& input_tensor, const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt)
-> ttnn::Tensor { return reallocate(input_tensor, memory_config); },
py::arg("tensor"),
py::arg("memory_config") = std::nullopt,
R"doc(
Deallocates device tensor and returns a reallocated tensor
bind_registered_operation(
module,
ttnn::to_dtype,
R"doc(to_dtype(tensor: ttnn.Tensor, dtype: DataType = None) -> ttnn.Tensor
Args:
* :attr:`input_tensor`: Input Tensor
)doc");
Converts a tensor to the desired dtype
Args:
* :attr:`tensor`: the ttnn.Tensor
* :attr:`dtype`: `ttnn` data type.
Example::
>>> tensor = ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16))
>>> tensor = ttnn.to_dtype(tensor, dtype=ttnn.uint16)
)doc",
ttnn::pybind_arguments_t{py::arg("tensor"), py::arg("dtype")});

bind_registered_operation(
module,
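
For context on the relocated reallocate binding, a usage sketch based only on the signature shown above; the device setup, shape, and dtype here are assumptions, not part of this change:

import torch
import ttnn

device = ttnn.open_device(device_id=0)
tensor = ttnn.to_device(ttnn.from_torch(torch.randn((32, 32), dtype=torch.bfloat16)), device)

# deallocates the device tensor and returns a reallocated one;
# memory_config is optional and defaults to std::nullopt on the C++ side
tensor = ttnn.reallocate(tensor)

ttnn.close_device(device)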
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/decorators.hpp
@@ -283,7 +283,7 @@ struct operation_without_validation_t {
template <typename... args_t>
static auto input_tensors_to_validate(args_t&&... args) {
return std::make_tuple();
};
}

template <typename... args_t>
static auto execute(args_t&&... args) {
111 changes: 105 additions & 6 deletions ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp
@@ -19,7 +19,7 @@ namespace core {

namespace detail {

inline Tensor convert_to_cpp_dtypes(const Tensor& input_tensor) {
inline Tensor convert_to_cpp_supported_dtype(const Tensor& input_tensor) {
auto input_dtype = input_tensor.get_dtype();

auto buffer = std::visit(
@@ -55,8 +55,6 @@ inline Tensor convert_to_cpp_dtypes(const Tensor& input_tensor) {
unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false);
buffer = owned_buffer::create<float>(std::move(float_unpacked_data));
input_dtype = DataType::FLOAT32;
} else {
TT_THROW("Unsupported input data type");
}

return std::visit(
@@ -77,6 +75,107 @@
buffer);
}

template <typename NewT, typename OldT>
inline std::vector<NewT> cast(const borrowed_buffer::Buffer<OldT>& input_buffer) {
std::vector<NewT> output_vector(input_buffer.size());
for (auto index = 0; index < input_buffer.size(); ++index) {
auto convert_value = [](auto&& value) {
if constexpr (std::is_same_v<OldT, ::bfloat16>) {
return value.to_float();
} else if constexpr (std::is_same_v<NewT, ::bfloat16>) {
return static_cast<float>(value);
} else {
return value;
}
};
auto value = input_buffer[index];
output_vector[index] = static_cast<NewT>(convert_value(value));
}
return output_vector;
}

template <typename T>
Tensor create_owned_tensor(std::vector<T>&& data, const Shape& shape, DataType data_type, Layout layout) {
auto buffer = owned_buffer::create(std::move(data));
auto storage = OwnedStorage{std::move(buffer)};
return Tensor(std::move(storage), shape, data_type, layout);
}

template <typename T>
inline Tensor create_tensor_from_buffer(
const borrowed_buffer::Buffer<T>& input_buffer, const Shape& shape, const DataType& dtype) {
switch (dtype) {
case DataType::UINT16: {
auto data = cast<uint16_t, T>(input_buffer);
return create_owned_tensor(std::move(data), shape, dtype, Layout::ROW_MAJOR);
}
case DataType::INT32: {
auto data = cast<int32_t, T>(input_buffer);
return create_owned_tensor(std::move(data), shape, dtype, Layout::ROW_MAJOR);
}
case DataType::UINT32: {
auto data = cast<uint32_t, T>(input_buffer);
return create_owned_tensor(std::move(data), shape, dtype, Layout::ROW_MAJOR);
}
case DataType::FLOAT32: {
auto data = cast<float, T>(input_buffer);
return create_owned_tensor(std::move(data), shape, dtype, Layout::ROW_MAJOR);
}
case DataType::BFLOAT16: {
auto data = cast<::bfloat16, T>(input_buffer);
return create_owned_tensor(std::move(data), shape, dtype, Layout::ROW_MAJOR);
}
case DataType::BFLOAT8_B: {
auto data = cast<float, T>(input_buffer);
auto uint32_vector = pack_fp32_vec_as_bfp8_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false);
auto buffer = owned_buffer::create<uint32_t>(std::move(uint32_vector));
auto storage = OwnedStorage{std::move(buffer)};
return Tensor(std::move(storage), shape, dtype, Layout::ROW_MAJOR).to(ttnn::TILE_LAYOUT);
}
case DataType::BFLOAT4_B: {
auto data = cast<float, T>(input_buffer);
auto uint32_vector = pack_fp32_vec_as_bfp4_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false);
auto buffer = owned_buffer::create<uint32_t>(std::move(uint32_vector));
auto storage = OwnedStorage{std::move(buffer)};
return Tensor(std::move(storage), shape, dtype, Layout::ROW_MAJOR).to(ttnn::TILE_LAYOUT);
}
default: {
TT_THROW(fmt::format("Unsupported DataType: {}", dtype));
break;
}
}
}

inline Tensor convert_to_dtype(const Tensor& input_tensor, const DataType& dtype) {
auto input_dtype = input_tensor.get_dtype();

switch (input_dtype) {
case DataType::UINT16: {
auto buffer = host_buffer::get_as<uint16_t>(input_tensor);
return create_tensor_from_buffer(buffer, input_tensor.get_shape(), dtype);
}
case DataType::INT32: {
auto buffer = host_buffer::get_as<int32_t>(input_tensor);
return create_tensor_from_buffer(buffer, input_tensor.get_shape(), dtype);
}
case DataType::UINT32: {
auto buffer = host_buffer::get_as<uint32_t>(input_tensor);
return create_tensor_from_buffer(buffer, input_tensor.get_shape(), dtype);
}
case DataType::FLOAT32: {
auto buffer = host_buffer::get_as<float>(input_tensor);
return create_tensor_from_buffer(buffer, input_tensor.get_shape(), dtype);
}
case DataType::BFLOAT16: {
auto buffer = host_buffer::get_as<::bfloat16>(input_tensor);
return create_tensor_from_buffer(buffer, input_tensor.get_shape(), dtype);
}
default: TT_THROW(fmt::format("Unsupported DataType: {}", input_dtype)); break;
}

return input_tensor;
}

} // namespace detail

struct ToDtype {
@@ -95,7 +194,7 @@
template <typename... Args>
static auto input_tensors_to_validate(const Tensor& tensor_arg, Args&&... args) {
return std::make_tuple(tensor_arg);
};
}

// TODO: Move to cpp once we merge with tt_eager
static Tensor execute(const ttnn::Tensor& input_tensor, const ttnn::DataType& dtype) {
@@ -110,8 +209,8 @@ struct ToDtype {
TT_THROW("Only ROW_MAJOR_LAYOUT is supported");
}

auto output_tensor = input_tensor;
return output_tensor;
auto intermediate_tensor = detail::convert_to_cpp_supported_dtype(input_tensor);
return detail::convert_to_dtype(intermediate_tensor, dtype);
};
};

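
The host-side path added in this file is easiest to read as two stages: convert_to_cpp_supported_dtype first unpacks block-float inputs (BFLOAT8_B, BFLOAT4_B) to FLOAT32, and convert_to_dtype then casts element-wise to the requested dtype, re-packing and tilizing when the target is itself a block-float format. The sketch below imitates that behavior with torch dtypes purely as an illustration; it is not the implementation, it skips the tile-packing step, and the helper name is made up:

import torch

def to_dtype_host_reference(values: torch.Tensor, target: torch.dtype) -> torch.Tensor:
    # stage 1: normalize to a C++-representable element type
    # (stands in for unpacking bfp8/bfp4 tiles into a float32 vector)
    intermediate = values.to(torch.float32)
    # stage 2: element-wise cast to the requested dtype (the cast<NewT, OldT> helper);
    # block-float targets would additionally be packed back into tiles here
    return intermediate.to(target)

x = torch.rand((2, 4), dtype=torch.bfloat16)
y = to_dtype_host_reference(x, torch.int32)  # e.g. BFLOAT16 -> INT32 goes through float32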
@@ -43,7 +43,7 @@ struct ToMemoryConfig {
template <typename... Args>
static auto input_tensors_to_validate(const Tensor& tensor_arg, Args&&... args) {
return std::make_tuple(tensor_arg);
};
}

// TODO: Move to cpp once we merge with tt_eager
static Tensor execute(
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/operations/normalization.hpp
@@ -50,7 +50,7 @@ struct Softmax {
}
};

struct LayerNorm : tt::operations::primary::LayerNorm {
struct LayerNorm {
static inline const std::array<ttnn::TensorSchema, 4> input_tensor_schemas() {
return {
ttnn::TensorSchema{
@@ -123,7 +123,7 @@
}
};

struct RMSNorm : tt::operations::primary::LayerNorm {
struct RMSNorm {
static inline const std::array<ttnn::TensorSchema, 2> input_tensor_schemas() {
return {
ttnn::TensorSchema{
14 changes: 8 additions & 6 deletions ttnn/cpp/ttnn/operations/transformer.hpp
@@ -236,7 +236,7 @@ struct ConcatenateHeads : public tt::tt_metal::NlpConcatHeads {
};

template <bool in_place>
struct AttentionSoftmax : public tt::operations::primary::Softmax {
struct ExecuteAttentionSoftmax {
static inline const std::array<TensorSchema, 2> input_tensor_schemas() {
return {
ttnn::TensorSchema{4, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false},
@@ -274,7 +274,7 @@
auto kernel_config_val = init_device_compute_kernel_config(
input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false);
auto output_tensor = operation::run(
AttentionSoftmax{
tt::operations::primary::Softmax{
head_size,
in_place,
memory_config.value_or(input_tensor.memory_config()),
@@ -301,10 +301,12 @@ constexpr auto split_query_key_value_and_split_heads =
constexpr auto concatenate_heads =
ttnn::register_operation<ttnn::operations::transformer::ConcatenateHeads>("ttnn::transfomer::concatenate_heads");

constexpr auto attention_softmax = ttnn::register_operation<ttnn::operations::transformer::AttentionSoftmax<false>>(
"ttnn::transfomer::attention_softmax");
constexpr auto attention_softmax_ = ttnn::register_operation<ttnn::operations::transformer::AttentionSoftmax<true>>(
"ttnn::transfomer::attention_softmax_");
constexpr auto attention_softmax =
ttnn::register_operation<ttnn::operations::transformer::ExecuteAttentionSoftmax<false>>(
"ttnn::transfomer::attention_softmax");
constexpr auto attention_softmax_ =
ttnn::register_operation<ttnn::operations::transformer::ExecuteAttentionSoftmax<true>>(
"ttnn::transfomer::attention_softmax_");
} // namespace transformer

} // namespace ttnn
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/operations/unary.hpp
@@ -36,7 +36,7 @@ inline const std::array<ttnn::TensorSchema, 1> input_tensor_schemas() {
template <typename... Args>
inline auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
return std::make_tuple(input_tensor);
};
}

inline Tensor execute(
const Tensor& input_tensor,
@@ -53,7 +53,7 @@ inline Tensor execute(
} // namespace detail

template <UnaryOpType unary_op_type>
struct Unary : public EltwiseUnary {
struct Unary {
static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }

template <typename... Args>
@@ -68,7 +68,7 @@ struct Unary : public EltwiseUnary {
};

template <UnaryOpType unary_op_type>
struct UnaryWithFastAndApproximateMode : public EltwiseUnary {
struct UnaryWithFastAndApproximateMode {
static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }

template <typename... Args>
1 change: 1 addition & 0 deletions ttnn/ttnn/__init__.py
@@ -250,6 +250,7 @@ def manage_config(name, value):
to_device,
from_device,
to_layout,
to_dtype,
reshape,
to_memory_config,
deallocate,
7 changes: 7 additions & 0 deletions ttnn/ttnn/operations/core.py
@@ -491,6 +491,13 @@ def _golden_function(tensor, *args, **kwargs):
to_layout = ttnn.register_operation(golden_function=_golden_function)(ttnn._ttnn.operations.core.to_layout)


def _golden_function(tensor, *args, **kwargs):
return tensor


to_dtype = ttnn.register_operation(golden_function=_golden_function)(ttnn._ttnn.operations.core.to_dtype)


def _clone_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
ttnn.validate_input_tensor(
operation_name,
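
Because dtype conversion is expected to preserve values up to precision, the golden function registered for to_dtype above is the identity. A hypothetical check in the same spirit as the unit test (names and values are illustrative):

import torch
import ttnn

torch_input = torch.rand((32, 32), dtype=torch.bfloat16)

output = ttnn.to_torch(ttnn.to_dtype(ttnn.from_torch(torch_input), ttnn.float32))

# golden reference is just the input; bfloat16 -> float32 widening is exact
reference = torch_input.to(torch.float32)
assert torch.allclose(reference, output)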
