Move creation functions out of numpy/functions
omilyutin-tt committed Dec 3, 2024
1 parent 23b8d7f commit 864df86
Showing 3 changed files with 78 additions and 111 deletions.
59 changes: 57 additions & 2 deletions ttnn/cpp/ttnn/operations/creation.hpp
@@ -100,6 +100,50 @@ Tensor create_scalar(T scalar, DataType data_type, Layout layout, Device* device
    }
}

+template <typename T>
+static Tensor full(
+    uint8_t queue_id,
+    const tt::tt_metal::LegacyShape& shape,
+    T value,
+    const Layout layout,
+    const std::vector<Device*>& devices,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> optional_output_tensor) {
+    constexpr DataType data_type = tt::tt_metal::convert_to_data_type<T>();
+    TensorSpec tensor_spec(
+        shape.logical_shape(),
+        TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout), MemoryConfig{}, shape));
+    auto owned_buffer = tt::tt_metal::owned_buffer::create<T>(tensor_spec.padded_shape().volume());
+    // TODO: 15061 - Generalize the header to support generic vector / view types.
+    std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);
+
+    if (!optional_output_tensor.has_value()) {
+        auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout);
+        if (!devices.empty()) {
+            output = output.to(devices, output_mem_config);
+        }
+        return output;
+    } else {
+        const auto buffers = optional_output_tensor->buffers();
+        const bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr);
+
+        for (auto* buffer : buffers) {
+            if (using_fast_dispatch) {
+                auto& cmd_queue = buffer->device()->command_queue(queue_id);
+                if (CommandQueue::default_mode() == CommandQueue::CommandQueueMode::ASYNC) {
+                    tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.get_ptr(), /*blocking=*/false);
+                } else {
+                    tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.data(), /*blocking=*/false);
+                }
+            } else {
+                tt::tt_metal::detail::WriteToBuffer(*buffer, owned_buffer.get());
+            }
+        }
+
+        return *optional_output_tensor;
+    }
+}
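The device-write path in full() above branches on dispatch mode via an environment variable. A minimal standalone sketch of just that check (illustrative only, not part of this commit):

    // Sketch: fast dispatch is assumed unless TT_METAL_SLOW_DISPATCH_MODE is set,
    // mirroring the check in full() above.
    #include <cstdio>
    #include <cstdlib>

    int main() {
        const bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr);
        std::printf("fast dispatch: %s\n", using_fast_dispatch ? "yes" : "no");
        return 0;
    }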

template <typename T>
inline ttnn::Tensor full_impl(
    uint8_t queue_id,
@@ -122,8 +166,19 @@ inline ttnn::Tensor full_impl(
    MemoryConfig mem_cfg = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config()
                                                              : memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);

-    return numpy::full_impl(
-        queue_id, shape_value, fill_value, dtype_value, layout_value, workers, mem_cfg, optional_output_tensor);
+    auto concrete_full = [&]<typename FillValueType>(FillValueType concrete_fill_value) {
+        return full<FillValueType>(
+            queue_id, shape_value, concrete_fill_value, layout_value, workers, mem_cfg, optional_output_tensor);
+    };
+
+    switch (dtype_value) {
+        case DataType::UINT8: return concrete_full(static_cast<uint8_t>(fill_value));
+        case DataType::UINT16: return concrete_full(static_cast<uint16_t>(fill_value));
+        case DataType::UINT32: return concrete_full(static_cast<uint32_t>(fill_value));
+        case DataType::FLOAT32: return concrete_full(static_cast<float>(fill_value));
+        case DataType::BFLOAT16: return concrete_full(static_cast<::bfloat16>(static_cast<float>(fill_value)));
+        default: TT_THROW("Unsupported DataType!");
+    }
}

template <typename T>
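The new full_impl body uses a C++20 templated lambda plus a switch to map a runtime DataType onto a compile-time element type. A minimal self-contained sketch of that idiom (illustrative names, not from this commit):

    // dispatch.cpp -- sketch of the runtime-enum-to-compile-time-type idiom.
    #include <cstdint>
    #include <cstdio>

    enum class Dtype { UINT8, FLOAT32 };

    int main() {
        // The generic lambda stands in for concrete_full(); it is instantiated
        // once per concrete type chosen by the switch below.
        auto fill_one = [&]<typename T>(T value) { return static_cast<double>(value); };

        Dtype dtype = Dtype::FLOAT32;
        double fill_value = 2.5;
        double result = 0.0;
        switch (dtype) {
            case Dtype::UINT8: result = fill_one(static_cast<uint8_t>(fill_value)); break;
            case Dtype::FLOAT32: result = fill_one(static_cast<float>(fill_value)); break;
        }
        std::printf("filled with %f\n", result);
        return 0;
    }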
111 changes: 2 additions & 109 deletions ttnn/cpp/ttnn/operations/numpy/functions.hpp
@@ -26,113 +26,6 @@ using tt::tt_metal::MemoryConfig;
using tt::tt_metal::OwnedStorage;
using tt::tt_metal::StorageType;
using tt::tt_metal::Tensor;
-namespace detail {
-
-template <typename T>
-constexpr static DataType get_data_type() {
-    if constexpr (std::is_same_v<T, uint8_t>) {
-        return DataType::UINT8;
-    } else if constexpr (std::is_same_v<T, uint16_t>) {
-        return DataType::UINT16;
-    } else if constexpr (std::is_same_v<T, int32_t>) {
-        return DataType::INT32;
-    } else if constexpr (std::is_same_v<T, uint32_t>) {
-        return DataType::UINT32;
-    } else if constexpr (std::is_same_v<T, float>) {
-        return DataType::FLOAT32;
-    } else if constexpr (std::is_same_v<T, ::bfloat16>) {
-        return DataType::BFLOAT16;
-    } else {
-        TT_THROW("Unsupported DataType!");
-    }
-}
-
-template <typename T>
-static Tensor full(
-    uint8_t queue_id,
-    const tt::tt_metal::LegacyShape& shape,
-    T value,
-    const Layout layout,
-    const std::vector<Device*>& devices,
-    const MemoryConfig& output_mem_config,
-    std::optional<Tensor> optional_output_tensor) {
-    constexpr DataType data_type = detail::get_data_type<T>();
-    TensorSpec tensor_spec(
-        shape.logical_shape(),
-        TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout), MemoryConfig{}, shape));
-    auto owned_buffer = tt::tt_metal::owned_buffer::create<T>(tensor_spec.padded_shape().volume());
-    // TODO: 15061 - Generalize the header to support generic vector / view types.
-    std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);
-
-    if (!optional_output_tensor.has_value()) {
-        auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout);
-        if (!devices.empty()) {
-            output = output.to(devices, output_mem_config);
-        }
-        return output;
-    } else {
-        const auto buffers = optional_output_tensor->buffers();
-        const bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr);
-
-        for (auto* buffer : buffers) {
-            if (using_fast_dispatch) {
-                auto& cmd_queue = buffer->device()->command_queue(queue_id);
-                if (CommandQueue::default_mode() == CommandQueue::CommandQueueMode::ASYNC) {
-                    tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.get_ptr(), /*blocking=*/false);
-                } else {
-                    tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.data(), /*blocking=*/false);
-                }
-            } else {
-                tt::tt_metal::detail::WriteToBuffer(*buffer, owned_buffer.get());
-            }
-        }
-
-        return *optional_output_tensor;
-    }
-}
-
-}  // namespace detail
-
-template <typename T>
-static Tensor full_impl(
-    uint8_t queue_id,
-    const tt::tt_metal::LegacyShape& shape,
-    const T value,
-    const DataType data_type,
-    const Layout layout,
-    const std::vector<Device*>& devices,
-    const MemoryConfig& output_mem_config,
-    std::optional<Tensor> optional_output_tensor) {
-    switch (data_type) {
-        case DataType::UINT8: {
-            return detail::full<uint8_t>(
-                queue_id, shape, uint8_t(value), layout, devices, output_mem_config, optional_output_tensor);
-        }
-        case DataType::UINT16: {
-            return detail::full<uint16_t>(
-                queue_id, shape, uint16_t(value), layout, devices, output_mem_config, optional_output_tensor);
-        }
-        case DataType::UINT32: {
-            return detail::full<uint32_t>(
-                queue_id, shape, uint32_t(value), layout, devices, output_mem_config, optional_output_tensor);
-        }
-        case DataType::FLOAT32: {
-            return detail::full<float>(
-                queue_id, shape, float(value), layout, devices, output_mem_config, optional_output_tensor);
-        }
-        case DataType::BFLOAT16: {
-            return detail::full<::bfloat16>(
-                queue_id,
-                shape,
-                ::bfloat16(static_cast<float>(value)),
-                layout,
-                devices,
-                output_mem_config,
-                optional_output_tensor);
-        }
-        default: TT_THROW("Unsupported DataType!");
-    }
-}

template <typename T>
static Tensor arange(
@@ -143,7 +36,7 @@ static Tensor arange(
    Device* device = nullptr,
    const MemoryConfig& output_mem_config = MemoryConfig{
        .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
-    constexpr DataType data_type = detail::get_data_type<T>();
+    constexpr DataType data_type = tt::tt_metal::convert_to_data_type<T>();
    // Current implementation restrictions
    TT_ASSERT(step > 0, "Step must be greater than 0");
    TT_ASSERT(start < stop, "Start must be less than stop");
@@ -628,7 +521,7 @@ static void seed(std::size_t seed) { RANDOM_GENERATOR = std::mt19937(seed); }

template <typename T>
static Tensor uniform(T low, T high, const tt::tt_metal::LegacyShape& shape, const Layout layout = Layout::ROW_MAJOR) {
-    constexpr DataType data_type = detail::get_data_type<T>();
+    constexpr DataType data_type = tt::tt_metal::convert_to_data_type<T>();

    auto owned_buffer = tt::tt_metal::owned_buffer::create<T>(tt::tt_metal::compute_volume(shape));

19 changes: 19 additions & 0 deletions ttnn/cpp/ttnn/tensor/types.hpp
@@ -41,6 +41,25 @@ enum class DataType {
    INVALID = 8,
};

+template <typename T>
+consteval inline DataType convert_to_data_type() {
+    if constexpr (std::is_same_v<T, uint8_t>) {
+        return DataType::UINT8;
+    } else if constexpr (std::is_same_v<T, uint16_t>) {
+        return DataType::UINT16;
+    } else if constexpr (std::is_same_v<T, int32_t>) {
+        return DataType::INT32;
+    } else if constexpr (std::is_same_v<T, uint32_t>) {
+        return DataType::UINT32;
+    } else if constexpr (std::is_same_v<T, float>) {
+        return DataType::FLOAT32;
+    } else if constexpr (std::is_same_v<T, ::bfloat16>) {
+        return DataType::BFLOAT16;
+    } else {
+        static_assert(false, "Unsupported DataType!");
+    }
+}

inline bool is_floating_point(DataType dtype) {
    switch (dtype) {
        case DataType::BFLOAT16:
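Because convert_to_data_type() is consteval, the type-to-DataType mapping can be verified entirely at compile time. A small sketch (the include path is an assumption; it presumes this header is self-contained on the ttnn include path):

    // Hypothetical compile-time checks of the new mapping; illustrative only.
    #include "ttnn/cpp/ttnn/tensor/types.hpp"

    static_assert(tt::tt_metal::convert_to_data_type<float>() == tt::tt_metal::DataType::FLOAT32);
    static_assert(tt::tt_metal::convert_to_data_type<uint16_t>() == tt::tt_metal::DataType::UINT16);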
