Skip to content

Commit

Permalink
Minor refactor of pytensor and tensor implementation files.
Browse files Browse the repository at this point in the history
  • Loading branch information
omilyutin-tt committed Dec 18, 2024
1 parent 5d0170e commit 5171d6c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 36 deletions.
17 changes: 7 additions & 10 deletions ttnn/cpp/pybind11/pytensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <chrono>
#include <memory>

#include "tensor.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "tt_metal/graph/graph_tracking.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/tt_stl/overloaded.hpp"
Expand Down Expand Up @@ -150,15 +150,12 @@ Tensor create_tt_tensor_from_py_data(
std::size_t py_data_ptr,
const TensorSpec& tensor_spec,
Device* device,
bool override_enable_borrow,
bool force_disable_borrow,
const std::function<void()>& on_creation_callback,
const std::function<void()>& on_destruction_callback) {
auto layout = tensor_spec.layout();

bool enable_borrow = true;
if (layout != Layout::ROW_MAJOR or override_enable_borrow) {
enable_borrow = false;
}
const bool enable_borrow = layout == Layout::ROW_MAJOR and not force_disable_borrow;

auto data_type = tensor_spec.data_type();
std::size_t num_elements = tensor_spec.logical_shape().volume();
Expand Down Expand Up @@ -256,7 +253,7 @@ Tensor convert_python_tensor_to_tt_tensor(
const std::optional<Tile>& optional_tile,
const MemoryConfig& memory_config,
Device* device,
bool override_enable_borrow = false) {
bool force_disable_borrow = false) {
GraphTracker::instance().track_function_start(
"tt::tt_metal::detail::convert_python_tensor_to_tt_tensor",
py_tensor,
Expand All @@ -265,7 +262,7 @@ Tensor convert_python_tensor_to_tt_tensor(
optional_tile,
memory_config,
device,
override_enable_borrow);
force_disable_borrow);
py::object torch = py::module_::import("torch");
py::object np = py::module_::import("numpy");

Expand Down Expand Up @@ -342,7 +339,7 @@ Tensor convert_python_tensor_to_tt_tensor(
num_elements = py::cast<std::size_t>(contiguous_py_tensor.attr("numel")());
py_data_ptr = py::cast<std::size_t>(contiguous_py_tensor.attr("data_ptr")());
} else if (py::isinstance(py_tensor, np.attr("ndarray"))) {
TT_FATAL(!override_enable_borrow, "Disabling borrowed buffers for numpy tensors is untested!");
TT_FATAL(!force_disable_borrow, "Disabling borrowed buffers for numpy tensors is untested!");

contiguous_py_tensor = np.attr("ascontiguousarray")(py_tensor);

Expand Down Expand Up @@ -429,7 +426,7 @@ Tensor convert_python_tensor_to_tt_tensor(
auto on_creation_callback = [tensor = contiguous_py_tensor] { tensor.inc_ref(); };
auto on_destruction_callback = [tensor = contiguous_py_tensor] { tensor.dec_ref(); };
auto output = create_tt_tensor_from_py_data(
py_data_ptr, tensor_spec, device, override_enable_borrow, on_creation_callback, on_destruction_callback);
py_data_ptr, tensor_spec, device, force_disable_borrow, on_creation_callback, on_destruction_callback);

if (device) {
output = output.to(device, memory_config);
Expand Down
43 changes: 17 additions & 26 deletions ttnn/cpp/ttnn/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
#include <memory>
#include <utility>

#include "common/bfloat16.hpp"
#include "tt_metal/common/bfloat16.hpp"
#include "impl/buffers/buffer_constants.hpp"
#include "tt_metal/tt_stl/overloaded.hpp"
#include "tensor_ops.hpp"
#include "ttnn/tensor/tensor_ops.hpp"
#include "ttnn/tensor/tensor_impl.hpp"
#include "ttnn/tensor/tensor_impl_wrapper.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
Expand All @@ -28,26 +28,9 @@
#include "ttnn/tensor/layout/tensor_layout.hpp"
#include "ttnn/distributed/api.hpp"

using namespace tt::constants;

namespace tt::tt_metal {
namespace {

// Pulls the MemoryConfig out of a Storage variant.
// Device-resident storages (single- or multi-device) carry a real memory
// config; any other alternative (e.g. host-side storage) yields a
// default-constructed MemoryConfig.
MemoryConfig extract_memory_config(const Storage& storage) {
return std::visit(
// Generic visitor: dispatch on the concrete storage alternative held
// by the variant via compile-time type checks.
[](const auto& storage) -> MemoryConfig {
using T = std::decay_t<decltype(storage)>;
if constexpr (std::is_same_v<T, DeviceStorage>) {
return storage.memory_config();
} else if constexpr (std::is_same_v<T, MultiDeviceStorage>) {
return storage.memory_config();
} else {
// Non-device storage has no device memory layout to report.
return MemoryConfig{};
}
},
storage);
}

template <typename T>
Tensor create_owned_tensor_from_span(tt::stl::Span<const T> data, const TensorSpec& spec) {
// TODO: support tilized layouts.
Expand Down Expand Up @@ -154,14 +137,22 @@ void Tensor::TensorAttributes::update_main_thread_ref_count(Device* worker, uint

Tensor::Tensor(
Storage storage, const ttnn::Shape& shape, DataType dtype, Layout layout, const std::optional<Tile>& tile) {
if (tile.has_value()) {
if (tile->get_tile_shape()[0] != TILE_WIDTH or tile->get_tile_shape()[1] != TILE_HEIGHT) {
tt::log_warning(
"only matmul op and ccl all-gather currently supports the customized tile shape: {}",
tile->get_tile_shape());
}
using namespace tt::constants;

if (tile.has_value() and //
(tile->get_tile_shape()[0] != TILE_WIDTH or tile->get_tile_shape()[1] != TILE_HEIGHT)) {
tt::log_warning(
"only matmul op and ccl all-gather currently supports the customized tile shape: {}",
tile->get_tile_shape());
}
auto memory_config = extract_memory_config(storage);

const auto memory_config = std::visit(
tt::stl::overloaded{
[](const DeviceStorage& s) { return s.memory_config(); },
[](const MultiDeviceStorage& s) { return s.memory_config(); },
[](auto&&) { return MemoryConfig{}; }},
storage);

init(
std::move(storage),
TensorSpec(
Expand Down

0 comments on commit 5171d6c

Please sign in to comment.