From 5b04331a431e60048ce625b4449ba63c3869ccf3 Mon Sep 17 00:00:00 2001 From: Brian Liu Date: Tue, 3 Dec 2024 16:16:42 +0000 Subject: [PATCH] #13127: Use tensor spec in conversion between python and tt tensors - Significant changes in ttnn/cpp/pybind11/pytensor.cpp: * Use tensor spec in create_owned_tensor * Add conversion between ROW_MAJOR and TILE layouts for ttnn.Tensor(...)/tensor.to(...) APIs ** For ttnn.Tensor(python_tensor, ...), handling is now internal and not through .to(layout) ** For ttnn.Tensor(float_vector, ...), use .to(layout) to convert to TILE if needed ** Make tilize, tilize_to_list, and untilize python utility functions no-ops and mark as deprecated * Add analogous create_row_major_owned_buffer from tensor buffer ** Commonize handling of BFLOAT8_B/BFLOAT4_B as float tensors/buffers ** Always use OwnedBuffer if conversion to/from TILE layout is required * Automatically deduce python dtype from owned buffers instead of mapping based on tt dtype * Set defaults for pybound init so it is more usable * Invert meaning of enable_borrow (now called override_enable_borrow) ** Make enable_borrow internal to create_tt_tensor_from_py_data - Update tensor init documentation and sample code for tile arg and creating tensors on device - Add memory_config() to TensorSpec - Commonize tt_dtype_to_torch_dtype and tt_dtype_to_np_dtype dicts across ttnn unit tests - Add test for host side tensor conversion in tests/ttnn/unit_tests/tensor/test_tensor_conversion.py - Add new tests/ttnn/unit_tests/tensor/test_tensor_creation.py tests * Coverage for directly creating device tensors with ttnn.Tensor(...) * Coverage for API parity between ttnn.from_device/ttnn.to_device and ttnn.Tensor(...)/tensor.to(...) --- models/utility_functions.py | 100 +---- .../unit_testing/misc/test_indexed_fill.py | 10 +- .../unit_testing/misc/test_non_zero.py | 10 +- .../unit_testing/misc/test_sharded_tensor.py | 7 +- .../tensor/test_tensor_conversion.py | 40 +- .../unit_tests/tensor/test_tensor_creation.py | 122 ++++++ .../tensor/test_tensor_serialization.py | 10 +- tests/ttnn/unit_tests/test_print_tensor.py | 10 +- tests/ttnn/utils_for_testing.py | 27 ++ ttnn/cpp/pybind11/pytensor.cpp | 371 ++++++++++++------ ttnn/cpp/ttnn/tensor/tensor_spec.hpp | 1 + ttnn/tt_lib/fused_ops/softmax.py | 2 +- ttnn/tt_lib/utils.py | 100 +---- 13 files changed, 444 insertions(+), 366 deletions(-) create mode 100644 tests/ttnn/unit_tests/tensor/test_tensor_creation.py diff --git a/models/utility_functions.py b/models/utility_functions.py index 2b652f81542..f13fd48d8ca 100644 --- a/models/utility_functions.py +++ b/models/utility_functions.py @@ -15,6 +15,8 @@ from ttnn.device import Arch +from typing_extensions import deprecated + ### Math operations ### def _nearest_32(x): @@ -430,108 +432,22 @@ def convert_act_2d_matrix(activation, kernel_y, kernel_x, stride_y, stride_x, pa ### Tilizing / Untilizing ### +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def tilize(x): - """ - This function tilizes a tensor. The last two tensor dims must be divisible by 32, after which this function - produces row major tiles and creates faces. The output of this function is a flattened list that - we can send to the device. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. 
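As the commit message above describes, layout handling is now internal to ttnn.Tensor(...), so the host-side tilize() helper (made a deprecated no-op below) is no longer needed before moving data to device. A minimal sketch of the new flow, mirroring the updated docstring example later in this patch (assumes a device is available at device_id 0):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)
    py_tensor = torch.rand((1, 1, 32, 32), dtype=torch.bfloat16)
    # Conversion to TILE layout happens inside the constructor; no tilize() call is required.
    tt_tensor = ttnn.Tensor(py_tensor, ttnn.bfloat16, device, ttnn.TILE_LAYOUT)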
- """ - nearest_32 = _nearest_32 - - assert isinstance( - x, (torch.Tensor, np.ndarray) - ), "Input to this function must be an instance of torch.Tensor or np.array" - assert len(x.shape) == 4, "Only 4D tensors suppported" - assert (x.shape[-2] % 32) == 0 and ( - x.shape[-1] % 32 - ) == 0, "The last two dimensions of the tensor must be divisible by 32" - - if isinstance(x, torch.Tensor): - ret = torch.zeros(np.prod(x.shape)) - else: - ret = np.zeros(np.prod(x.shape)) - - idx = 0 - for B in range(x.shape[0]): - for C in range(x.shape[1]): - for H in range(0, x.shape[2], 32): - for W in range(0, x.shape[3], 32): - unfaced_tile = x[B, C, H : H + 32, W : W + 32] - - face0 = unfaced_tile[:16, :16] - face1 = unfaced_tile[:16, 16:] - face2 = unfaced_tile[16:, :16] - face3 = unfaced_tile[16:, 16:] - - for face in (face0, face1, face2, face3): - ret[idx : idx + 256] = face.reshape(-1) - idx += 256 - - return ret.reshape(x.shape) + return x +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def tilize_to_list(x): """ - Tilize a PyTorch and then return the values as a flat list. The last two - tensor dims must be divisible by 32, after which this function produces row - major tiles and creates faces. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. + Returns a flattened list of the tensor """ - return tilize(x).reshape(-1).tolist() +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def untilize(x): - """ - This function untilizes a tensor to row major format. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. 
- """ - nearest_32 = _nearest_32 - - assert isinstance(x, (torch.Tensor, np.ndarray)), "Input to this function must be an instance of torch.Tensor" - assert len(x.shape) == 4, "Only 4D tensors suppported" - assert (x.shape[-2] % 32) == 0 and ( - x.shape[-1] % 32 - ) == 0, "The last two dimensions of the tensor must be divisible by 32" - - if isinstance(x, torch.Tensor): - ret = torch.zeros(x.shape, dtype=x.dtype) - else: - ret = np.zeros(x.shape, dtype=x.dtype) - - for B in range(x.shape[0]): - for C in range(x.shape[1]): - x_hw = x[B, C, :].reshape(-1) - hw = 0 - for h in range(0, x.shape[2], 32): - for w in range(0, x.shape[3], 32): - f_tile = x_hw[hw : hw + 256].reshape(16, 16) - ret[B, C, h : h + 16, w : w + 16] = f_tile - - f_tile = x_hw[hw + 256 : hw + 512].reshape(16, 16) - ret[B, C, h : h + 16, w + 16 : w + 32] = f_tile - - f_tile = x_hw[hw + 512 : hw + 768].reshape(16, 16) - ret[B, C, h + 16 : h + 32, w : w + 16] = f_tile - - f_tile = x_hw[hw + 768 : hw + 1024].reshape(16, 16) - ret[B, C, h + 16 : h + 32, w + 16 : w + 32] = f_tile - hw += 1024 # traverse tiles in RM-order - - return ret + return x ### Measuring accuracy and other metrics ### diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_indexed_fill.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_indexed_fill.py index 4245a35c3c2..3044f6bbb89 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_indexed_fill.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_indexed_fill.py @@ -9,15 +9,7 @@ import ttnn import torch import numpy as np - - -tt_dtype_to_torch_dtype = { - ttnn.uint16: torch.int16, - ttnn.uint32: torch.int32, - ttnn.float32: torch.float, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, -} +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype @pytest.mark.parametrize( diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py index e672856c3e2..b280d8e0b66 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py @@ -10,15 +10,7 @@ import torch import numpy as np import ttnn - - -tt_dtype_to_torch_dtype = { - ttnn.uint16: torch.int16, - ttnn.uint32: torch.int32, - ttnn.float32: torch.float, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, -} +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype @pytest.mark.parametrize( diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded_tensor.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded_tensor.py index 050099d62d6..1c19b8137e6 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded_tensor.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded_tensor.py @@ -11,14 +11,9 @@ import ttnn from models.utility_functions import get_debug_tensor +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype from enum import Enum -tt_dtype_to_torch_dtype = { - ttnn.uint32: torch.int32, - ttnn.uint16: torch.int16, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, -} TILE_WIDTH = 32 TILE_HEIGHT = 32 diff --git a/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py b/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py index 63442308831..2fff322de44 100644 --- a/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py +++ b/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py 
@@ -11,29 +11,10 @@ import numpy as np import ttnn - -tt_dtype_to_torch_dtype = { - ttnn.uint8: torch.uint8, - ttnn.uint16: torch.int16, - ttnn.uint32: torch.int32, - ttnn.int32: torch.int32, - ttnn.float32: torch.float, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, - ttnn.bfloat4_b: torch.float, -} - -tt_dtype_to_np_dtype = { - ttnn.uint8: np.ubyte, - ttnn.uint16: np.int16, - ttnn.uint32: np.int32, - ttnn.int32: np.int32, - ttnn.float32: np.float32, - ttnn.bfloat8_b: np.float32, - ttnn.bfloat4_b: np.float32, -} +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype, tt_dtype_to_np_dtype +@pytest.mark.parametrize("convert_to_device", [True, False]) @pytest.mark.parametrize( "tt_dtype", [ @@ -49,7 +30,7 @@ ) @pytest.mark.parametrize("shape", [(2, 3, 64, 96)]) @pytest.mark.parametrize("python_lib", [torch, np]) -def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device): +def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, convert_to_device, device): torch.manual_seed(0) if python_lib == torch: @@ -64,7 +45,7 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device): elif python_lib == np: if tt_dtype == ttnn.bfloat16: - pytest.skip("ttnn.bloat16 dtype is not supported yet for numpy tensors!") + pytest.skip("ttnn.bfloat16 dtype is not supported yet for numpy tensors!") dtype = tt_dtype_to_np_dtype[tt_dtype] if dtype in {np.ubyte, np.int16, np.int32}: @@ -82,8 +63,9 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device): assert tt_tensor.storage_type() == ttnn.StorageType.BORROWED assert tt_tensor.layout == ttnn.ROW_MAJOR_LAYOUT - tt_tensor = tt_tensor.to(device) - tt_tensor = tt_tensor.cpu() + if convert_to_device: + tt_tensor = tt_tensor.to(device) + tt_tensor = tt_tensor.cpu() if python_lib == torch: py_tensor_after_round_trip = tt_tensor.to_torch() @@ -123,6 +105,7 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device): } +@pytest.mark.parametrize("convert_to_device", [True, False]) @pytest.mark.parametrize( "python_dtype_str", [ @@ -137,7 +120,7 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device): ) @pytest.mark.parametrize("shape", [(2, 3, 64, 96)]) @pytest.mark.parametrize("python_lib", [torch, np]) -def test_tensor_conversion_with_python_dtype(python_lib, shape, python_dtype_str, device): +def test_tensor_conversion_with_python_dtype(python_lib, shape, python_dtype_str, convert_to_device, device): torch.manual_seed(0) if python_lib == torch: @@ -165,8 +148,9 @@ def test_tensor_conversion_with_python_dtype(python_lib, shape, python_dtype_str tt_tensor = ttnn.Tensor(py_tensor) assert tt_tensor.storage_type() == ttnn.StorageType.BORROWED - tt_tensor = tt_tensor.to(device) - tt_tensor = tt_tensor.cpu() + if convert_to_device: + tt_tensor = tt_tensor.to(device) + tt_tensor = tt_tensor.cpu() if python_lib == torch: py_tensor_after_round_trip = tt_tensor.to_torch() diff --git a/tests/ttnn/unit_tests/tensor/test_tensor_creation.py b/tests/ttnn/unit_tests/tensor/test_tensor_creation.py new file mode 100644 index 00000000000..f0615abba97 --- /dev/null +++ b/tests/ttnn/unit_tests/tensor/test_tensor_creation.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import os +import pathlib + +import torch +import numpy as np + +import ttnn +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype + + +@pytest.mark.parametrize( + "layout", + [ + ttnn.ROW_MAJOR_LAYOUT, + ttnn.TILE_LAYOUT, + ], +) +@pytest.mark.parametrize( + "tt_dtype", + [ + ttnn.uint8, + ttnn.uint16, + ttnn.uint32, + ttnn.int32, + ttnn.float32, + ttnn.bfloat16, + ttnn.bfloat8_b, + ttnn.bfloat4_b, + ], +) +@pytest.mark.parametrize("shape", [(2, 3, 64, 96)]) +def test_tensor_creation(shape, tt_dtype, layout, device): + torch.manual_seed(0) + + dtype = tt_dtype_to_torch_dtype[tt_dtype] + + if dtype in {torch.uint8, torch.int16, torch.int32}: + py_tensor = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, shape, dtype=dtype) + else: + py_tensor = torch.rand(shape, dtype=dtype) + + tt_tensor = ttnn.Tensor(py_tensor, tt_dtype, device, layout) + + tt_tensor = tt_tensor.cpu() + + py_tensor_after_round_trip = tt_tensor.to_torch() + + assert py_tensor.dtype == py_tensor_after_round_trip.dtype + assert py_tensor.shape == py_tensor_after_round_trip.shape + + allclose_kwargs = {} + if tt_dtype == ttnn.bfloat8_b: + allclose_kwargs = dict(atol=1e-2) + elif tt_dtype == ttnn.bfloat4_b: + allclose_kwargs = dict(atol=0.2) + + passing = torch.allclose(py_tensor, py_tensor_after_round_trip, **allclose_kwargs) + assert passing + + +@pytest.mark.parametrize( + "layout", + [ + ttnn.ROW_MAJOR_LAYOUT, + ttnn.TILE_LAYOUT, + ], +) +@pytest.mark.parametrize( + "tt_dtype", + [ + ttnn.uint8, + ttnn.uint16, + ttnn.uint32, + ttnn.int32, + ttnn.float32, + ttnn.bfloat16, + ttnn.bfloat8_b, + ttnn.bfloat4_b, + ], +) +@pytest.mark.parametrize("shape", [(2, 3, 64, 96)]) +def test_tensor_creation_api_parity(shape, tt_dtype, layout, device): + torch.manual_seed(0) + + if tt_dtype in (ttnn.bfloat8_b, ttnn.bfloat4_b) and layout == ttnn.ROW_MAJOR_LAYOUT: + pytest.skip("{} is only valid for ttnn.TILE_LAYOUT!".format(tt_dtype)) + + dtype = tt_dtype_to_torch_dtype[tt_dtype] + + if dtype in {torch.uint8, torch.int16, torch.int32}: + py_tensor = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, shape, dtype=dtype) + else: + py_tensor = torch.rand(shape, dtype=dtype) + + tt_tensor_1 = ttnn.Tensor(py_tensor, tt_dtype, device, layout) + tt_tensor_2 = ttnn.from_torch(py_tensor, tt_dtype, device=device, layout=layout) + + tt_tensor_1 = tt_tensor_1.cpu() + tt_tensor_2 = tt_tensor_2.cpu() + + py_tensor_after_round_trip_1 = tt_tensor_1.to_torch() + py_tensor_after_round_trip_2 = tt_tensor_2.to_torch() + py_tensor_after_round_trip_3 = ttnn.to_torch(tt_tensor_1) + py_tensor_after_round_trip_4 = ttnn.to_torch(tt_tensor_2) + + allclose_kwargs = {} + if tt_dtype == ttnn.bfloat8_b: + allclose_kwargs = dict(atol=1e-2) + elif tt_dtype == ttnn.bfloat4_b: + allclose_kwargs = dict(atol=0.2) + + passing = torch.allclose(py_tensor, py_tensor_after_round_trip_1, **allclose_kwargs) + passing = torch.allclose(py_tensor, py_tensor_after_round_trip_2, **allclose_kwargs) + passing = torch.allclose(py_tensor, py_tensor_after_round_trip_3, **allclose_kwargs) + passing = torch.allclose(py_tensor, py_tensor_after_round_trip_4, **allclose_kwargs) + assert passing diff --git a/tests/ttnn/unit_tests/tensor/test_tensor_serialization.py b/tests/ttnn/unit_tests/tensor/test_tensor_serialization.py index 1db497c0843..a56dde83d19 100644 --- a/tests/ttnn/unit_tests/tensor/test_tensor_serialization.py +++ b/tests/ttnn/unit_tests/tensor/test_tensor_serialization.py @@ -11,15 +11,7 
@@ import numpy as np import ttnn - -tt_dtype_to_torch_dtype = { - ttnn.uint16: torch.int16, - ttnn.uint32: torch.int32, - ttnn.float32: torch.float, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, - ttnn.bfloat4_b: torch.float, -} +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype @pytest.mark.parametrize("shape", [(2, 3, 64, 96)]) diff --git a/tests/ttnn/unit_tests/test_print_tensor.py b/tests/ttnn/unit_tests/test_print_tensor.py index 66254f7d363..90f1ecd5157 100644 --- a/tests/ttnn/unit_tests/test_print_tensor.py +++ b/tests/ttnn/unit_tests/test_print_tensor.py @@ -7,14 +7,8 @@ import torch import ttnn +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype -ttnn_dtype_to_torch_dtype = { - ttnn.uint16: torch.int16, - ttnn.uint32: torch.int32, - ttnn.float32: torch.float, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, -} GOLDEN_TENSOR_STRINGS = { ( @@ -77,7 +71,7 @@ def test_print(device, dtype, layout, profile, deallocate): ttnn.set_printoptions(profile=profile) - torch_dtype = ttnn_dtype_to_torch_dtype[dtype] + torch_dtype = tt_dtype_to_torch_dtype[dtype] shape = (2, 16, 64, 32) if torch_dtype in {torch.int16, torch.int32}: diff --git a/tests/ttnn/utils_for_testing.py b/tests/ttnn/utils_for_testing.py index fb083a681ff..92849b32e57 100644 --- a/tests/ttnn/utils_for_testing.py +++ b/tests/ttnn/utils_for_testing.py @@ -10,6 +10,33 @@ from models.utility_functions import comp_pcc, comp_equal, divup, roundup from typing import Tuple +import ttnn +import torch +import numpy as np + + +# Dictionaries for converting dtypes +tt_dtype_to_torch_dtype = { + ttnn.uint8: torch.uint8, + ttnn.uint16: torch.int16, + ttnn.uint32: torch.int32, + ttnn.int32: torch.int32, + ttnn.float32: torch.float, + ttnn.bfloat16: torch.bfloat16, + ttnn.bfloat8_b: torch.float, + ttnn.bfloat4_b: torch.float, +} + +tt_dtype_to_np_dtype = { + ttnn.uint8: np.ubyte, + ttnn.uint16: np.int16, + ttnn.uint32: np.int32, + ttnn.int32: np.int32, + ttnn.float32: np.float32, + ttnn.bfloat8_b: np.float32, + ttnn.bfloat4_b: np.float32, +} + def construct_pcc_assert_message(message, expected_pytorch_result, actual_pytorch_result): messages = [] diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 48a360fb3cb..17de2f3493e 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -66,17 +66,17 @@ void log_external_operation( #endif template -Tensor create_owned_tensor( - T* data_ptr, - size_t num_elements, - tt::stl::Span shape, - DataType data_type, - Layout layout, - const std::optional& optional_tile = std::nullopt) { - auto data = std::vector(data_ptr, data_ptr + num_elements); +Tensor create_owned_tensor(T* data_ptr, const ttnn::TensorSpec& tensor_spec) { + std::size_t num_elements = tensor_spec.logical_shape().volume(); + auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); + + if (tensor_spec.layout() == Layout::TILE) { + data = tensor_impl::convert_layout_row_major_to_tile(tensor_spec.physical_shape(), tensor_spec.tile(), buffer); + buffer = owned_buffer::create(std::move(data)); + } auto storage = OwnedStorage{std::move(buffer)}; - return Tensor(std::move(storage), shape, data_type, layout, optional_tile); + return Tensor(std::move(storage), tensor_spec); } OwnedBuffer create_owned_buffer_from_vector_of_floats(std::vector&& data, DataType data_type) { @@ -138,7 +138,7 @@ Tensor convert_float_vector_to_tt_tensor( return tensor; } auto owned_buffer = 
create_owned_buffer_from_vector_of_floats(std::move(data), data_type); - auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); + auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, Layout::ROW_MAJOR, tile).to(layout); if (device) { return tensor.to(device, memory_config.value_or(MemoryConfig{})); } @@ -146,23 +146,30 @@ Tensor convert_float_vector_to_tt_tensor( } Tensor create_tt_tensor_from_py_data( - std::size_t num_elements, std::size_t py_data_ptr, - const ttnn::SmallVector& shape, - const DataType data_type, - const std::optional& optional_tile, - bool enable_borrow, - const std::function& on_creation_callback = [] {}, - const std::function& on_destruction_callback = [] {}) { + const TensorSpec& tensor_spec, + Device* device, + bool override_enable_borrow, + const std::function& on_creation_callback, + const std::function& on_destruction_callback) { + auto layout = tensor_spec.layout(); + + bool enable_borrow = true; + if (layout != Layout::ROW_MAJOR or override_enable_borrow) { + enable_borrow = false; + } + + auto data_type = tensor_spec.data_type(); + std::size_t num_elements = tensor_spec.logical_shape().volume(); switch (data_type) { case DataType::UINT8: { auto data_ptr = reinterpret_cast(py_data_ptr); if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } case DataType::UINT16: { @@ -170,9 +177,9 @@ Tensor create_tt_tensor_from_py_data( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } case DataType::INT32: { @@ -180,9 +187,9 @@ Tensor create_tt_tensor_from_py_data( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } case DataType::UINT32: { @@ -190,9 +197,9 @@ Tensor create_tt_tensor_from_py_data( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } case DataType::FLOAT32: { @@ -200,9 +207,9 @@ Tensor create_tt_tensor_from_py_data( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), 
on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } // TODO: This is not supported for numpy @@ -211,27 +218,28 @@ Tensor create_tt_tensor_from_py_data( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } case DataType::BFLOAT8_B: case DataType::BFLOAT4_B: { auto data_ptr = reinterpret_cast(py_data_ptr); - auto data = std::vector(data_ptr, data_ptr + num_elements); - auto buffer = owned_buffer::create(std::move(data)); - auto tile = optional_tile.value_or(Tile()); - auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) - .to(Layout::TILE); - auto output_float_data = owned_buffer::get_as(tensor).get(); + auto float_tensor_spec = TensorSpec( + tensor_spec.logical_shape(), + TensorLayout(DataType::FLOAT32, tensor_spec.page_config(), tensor_spec.memory_config())); + auto float_tensor = create_owned_tensor(data_ptr, float_tensor_spec); + + auto tile = tensor_spec.tensor_layout().get_page_config().get_tile(); + auto output_float_data = owned_buffer::get_as(float_tensor).get(); auto output_packed_data = data_type == DataType::BFLOAT8_B ? 
pack_fp32_vec_as_bfp8_tiles( output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile) : pack_fp32_vec_as_bfp4_tiles( output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor(std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile); + return Tensor(std::move(OwnedStorage{std::move(output_buffer)}), tensor_spec); } default: { TT_THROW("Unsupported DataType: {}", data_type); @@ -242,16 +250,26 @@ Tensor create_tt_tensor_from_py_data( Tensor convert_python_tensor_to_tt_tensor( const py::handle& py_tensor, - std::optional optional_data_type = std::nullopt, - const std::optional& optional_tile = std::nullopt, - bool enable_borrow = true) { + std::optional optional_data_type, + std::optional optional_layout, + const std::optional& optional_tile, + const MemoryConfig& memory_config, + Device* device, + bool override_enable_borrow = false) { GraphTracker::instance().track_function_start( - "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor", py_tensor, optional_data_type, enable_borrow); + "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor", + py_tensor, + optional_data_type, + optional_layout, + optional_tile, + memory_config, + device, + override_enable_borrow); py::object torch = py::module_::import("torch"); py::object np = py::module_::import("numpy"); auto py_dtype = py_tensor.attr("dtype"); - auto shape = py::cast>(py_tensor.attr("shape")); + auto shape = ttnn::SimpleShape(py::cast>(py_tensor.attr("shape"))); DataType data_type; @@ -323,7 +341,7 @@ Tensor convert_python_tensor_to_tt_tensor( num_elements = py::cast(contiguous_py_tensor.attr("numel")()); py_data_ptr = py::cast(contiguous_py_tensor.attr("data_ptr")()); } else if (py::isinstance(py_tensor, np.attr("ndarray"))) { - TT_FATAL(enable_borrow, "Owned storage for numpy tensors is untested!"); + TT_FATAL(!override_enable_borrow, "Disabling borrowed buffers for numpy tensors is untested!"); contiguous_py_tensor = np.attr("ascontiguousarray")(py_tensor); @@ -386,17 +404,35 @@ Tensor convert_python_tensor_to_tt_tensor( TT_THROW("The argument must be of type torch.Tensor or numpy.ndarray!"); } + // TODO: Remove check of num_elements from python against volume of ttnn::SimpleShape + TT_FATAL( + num_elements == shape.volume(), + "Number of elements from python tensor {} must match volume of shape {}!", + num_elements, + shape.volume()); + + Layout layout = optional_layout.value_or(Layout::ROW_MAJOR); + if (data_type == DataType::BFLOAT8_B or data_type == DataType::BFLOAT4_B) { + if (optional_layout.has_value() and optional_layout.value() != Layout::TILE) { + log_warning( + tt::LogAlways, + "Tensor layout must be Layout::TILE for bfloat8_b or bfloat4_b! 
Tensor layout will be {} instead of " + "the requested {}!", + Layout::TILE, + optional_layout.value()); + } + layout = Layout::TILE; + } + + auto tensor_spec = TensorSpec(shape, TensorLayout(data_type, PageConfig(layout, optional_tile), memory_config)); auto on_creation_callback = [tensor = contiguous_py_tensor] { tensor.inc_ref(); }; auto on_destruction_callback = [tensor = contiguous_py_tensor] { tensor.dec_ref(); }; auto output = create_tt_tensor_from_py_data( - num_elements, - py_data_ptr, - shape, - data_type, - optional_tile, - enable_borrow, - on_creation_callback, - on_destruction_callback); + py_data_ptr, tensor_spec, device, override_enable_borrow, on_creation_callback, on_destruction_callback); + + if (device) { + output = output.to(device, memory_config); + } output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; @@ -411,7 +447,8 @@ Tensor convert_python_tensors_to_tt_tensors( "tt::tt_metal::detail::convert_python_tensors_to_tt_tensors", tensor_shards, data_type, strategy); std::vector tt_shards; for (const auto& shard : tensor_shards) { - tt_shards.push_back(detail::convert_python_tensor_to_tt_tensor(shard, data_type, tile, false)); + tt_shards.push_back(detail::convert_python_tensor_to_tt_tensor( + shard, data_type, Layout::ROW_MAJOR, tile, MemoryConfig{}, nullptr, true)); } std::vector host_owned_buffers; std::vector host_owned_shapes; @@ -432,15 +469,68 @@ Tensor convert_python_tensors_to_tt_tensors( return output; } -std::pair, DataType> get_buffer_and_dtype_from_tensor( - const Tensor& tt_tensor) { +template +owned_buffer::Buffer create_row_major_owned_buffer( + owned_buffer::Buffer owned_buffer, const ttnn::TensorSpec& tensor_spec) { + if (tensor_spec.layout() == Layout::TILE) { + auto data = tensor_impl::convert_layout_tile_to_row_major( + tensor_spec.physical_shape(), tensor_spec.tile(), owned_buffer); + return owned_buffer::create(std::move(data)); + } + return owned_buffer; +} + +std::variant get_host_buffer_from_tensor(const Tensor& tt_tensor) { TT_ASSERT(tt_tensor.storage_type() == StorageType::OWNED or tt_tensor.storage_type() == StorageType::BORROWED); - auto buffer = std::visit( - [](auto&& storage) -> std::variant { + const auto& tensor_spec = tt_tensor.get_tensor_spec(); + return std::visit( + [&tensor_spec, &tt_tensor](auto&& storage) -> std::variant { using T = std::decay_t; if constexpr (std::is_same_v) { - return storage.buffer; + auto tt_dtype = tensor_spec.data_type(); + switch (tt_dtype) { + case DataType::UINT8: { + return create_row_major_owned_buffer( + owned_buffer::get_as(storage.buffer), tensor_spec); + } + case DataType::UINT16: { + return create_row_major_owned_buffer( + owned_buffer::get_as(storage.buffer), tensor_spec); + } + case DataType::INT32: { + return create_row_major_owned_buffer( + owned_buffer::get_as(storage.buffer), tensor_spec); + } + case DataType::UINT32: { + return create_row_major_owned_buffer( + owned_buffer::get_as(storage.buffer), tensor_spec); + } + case DataType::FLOAT32: { + return create_row_major_owned_buffer(owned_buffer::get_as(storage.buffer), tensor_spec); + } + case DataType::BFLOAT16: { + return create_row_major_owned_buffer( + owned_buffer::get_as<::bfloat16>(storage.buffer), tensor_spec); + } + case DataType::BFLOAT8_B: + case DataType::BFLOAT4_B: { + const auto& tile = tensor_spec.tile(); + auto uint32_data = owned_buffer::get_as(storage.buffer).get(); + auto float_unpacked_data = + tt_dtype == DataType::BFLOAT8_B + ? 
unpack_bfp8_tiles_into_float_vec( + uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile) + : unpack_bfp4_tiles_into_float_vec( + uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); + auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); + return create_row_major_owned_buffer(input_float_buffer, tensor_spec); + } + default: { + TT_THROW("Unsupported DataType: {}", tt_dtype); + break; + } + } } else if constexpr (std::is_same_v) { TT_THROW("Device tensor cannot be converted to torch"); } else if constexpr (std::is_same_v) { @@ -456,52 +546,64 @@ std::pair, DataType> get_buffer_and_dt } }, tt_tensor.get_storage()); - - const auto tile = tt_tensor.get_tensor_spec().tile(); - auto tt_dtype = tt_tensor.get_dtype(); - if (tt_dtype == DataType::BFLOAT8_B || tt_dtype == DataType::BFLOAT4_B) { - TT_ASSERT( - std::holds_alternative(buffer), - "Unexpected type {}", - tt::stl::get_active_type_name_in_variant(buffer)); - auto uint32_data = std::get>(std::get(buffer)).get(); - auto float_unpacked_data = - tt_dtype == DataType::BFLOAT8_B - ? unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile) - : unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); - auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); - auto float_tensor = Tensor( - OwnedStorage{input_float_buffer}, - tt_tensor.get_shape(), - DataType::FLOAT32, - tt_tensor.get_layout(), - tile) - .to(Layout::ROW_MAJOR); - auto output_float_data = owned_buffer::get_as(float_tensor).get(); - buffer = owned_buffer::create(std::move(output_float_data)); - tt_dtype = DataType::FLOAT32; - } - - return {buffer, tt_dtype}; } py::object convert_tt_tensor_to_torch_tensor(const Tensor& tt_tensor) { GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_tt_tensor_to_torch_tensor", tt_tensor); - auto [buffer, buffer_dtype] = get_buffer_and_dtype_from_tensor(tt_tensor); + auto buffer = get_host_buffer_from_tensor(tt_tensor); py::object torch = py::module_::import("torch"); auto frombuffer = torch.attr("frombuffer"); - const auto tt_dtype_to_torch_dtype = std::map{ - {DataType::UINT8, torch.attr("uint8")}, - {DataType::UINT16, torch.attr("int16")}, // TODO(arakhmati): add DataType::INT16 - {DataType::INT32, torch.attr("int32")}, - {DataType::UINT32, torch.attr("int32")}, // TODO(arakhmati): add DataType::INT32 - {DataType::FLOAT32, torch.attr("float32")}, - {DataType::BFLOAT16, torch.attr("bfloat16")}, - }; - auto torch_dtype = tt_dtype_to_torch_dtype.at(buffer_dtype); + auto torch_dtype = [&]() { + if (std::holds_alternative(buffer)) { + return std::visit( + [&torch](auto& owned_buffer) -> py::object { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + return torch.attr("uint8"); + } else if constexpr (std::is_same_v>) { + return torch.attr("int16"); + } else if constexpr (std::is_same_v>) { + return torch.attr("int32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("int32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("float32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("bfloat16"); + } else { + static_assert(tt::stl::concepts::always_false_v, "Unsupported buffer!"); + } + }, + std::get(buffer)); + + } else if (std::holds_alternative(buffer)) { + return std::visit( + [&torch](auto& borrowed_buffer) -> py::object { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + return torch.attr("uint8"); 
+ } else if constexpr (std::is_same_v>) { + return torch.attr("int16"); + } else if constexpr (std::is_same_v>) { + return torch.attr("int32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("int32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("float32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("bfloat16"); + } else { + static_assert(tt::stl::concepts::always_false_v, "Unsupported buffer!"); + } + }, + std::get(buffer)); + } else { + TT_THROW("Only OwnedBuffer or BorrowedBuffer is supported for converting to python buffers!"); + } + }(); auto shape = tt_tensor.get_legacy_shape(); auto torch_shape = std::vector(std::begin(shape), std::end(shape)); @@ -527,19 +629,59 @@ py::object convert_tt_tensor_to_torch_tensor(const Tensor& tt_tensor) { py::object convert_tt_tensor_to_numpy_tensor(const Tensor& tt_tensor) { GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_tt_tensor_to_numpy_tensor", tt_tensor); - auto [buffer, buffer_dtype] = get_buffer_and_dtype_from_tensor(tt_tensor); + auto buffer = get_host_buffer_from_tensor(tt_tensor); py::object np = py::module_::import("numpy"); auto frombuffer = np.attr("frombuffer"); - const auto tt_dtype_to_np_dtype = std::map{ - {DataType::UINT8, np.attr("ubyte")}, - {DataType::UINT16, np.attr("int16")}, // TODO(arakhmati): add DataType::INT16 - {DataType::INT32, np.attr("int32")}, - {DataType::UINT32, np.attr("int32")}, // TODO(arakhmati): add DataType::INT32 - {DataType::FLOAT32, np.attr("float32")}, - }; - auto np_dtype = tt_dtype_to_np_dtype.at(buffer_dtype); + auto np_dtype = [&]() { + if (std::holds_alternative(buffer)) { + return std::visit( + [&np](auto& owned_buffer) -> py::object { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + return np.attr("ubyte"); + } else if constexpr (std::is_same_v>) { + return np.attr("int16"); + } else if constexpr (std::is_same_v>) { + return np.attr("int32"); + } else if constexpr (std::is_same_v>) { + return np.attr("int32"); + } else if constexpr (std::is_same_v>) { + return np.attr("float32"); + } else if constexpr (std::is_same_v>) { + TT_THROW("Bfloat16 is not supported for numpy!"); + } else { + static_assert(tt::stl::concepts::always_false_v, "Unsupported buffer!"); + } + }, + std::get(buffer)); + + } else if (std::holds_alternative(buffer)) { + return std::visit( + [&np](auto& borrowed_buffer) -> py::object { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + return np.attr("ubyte"); + } else if constexpr (std::is_same_v>) { + return np.attr("int16"); + } else if constexpr (std::is_same_v>) { + return np.attr("int32"); + } else if constexpr (std::is_same_v>) { + return np.attr("int32"); + } else if constexpr (std::is_same_v>) { + return np.attr("float32"); + } else if constexpr (std::is_same_v>) { + TT_THROW("Bfloat16 is not supported for numpy!"); + } else { + static_assert(tt::stl::concepts::always_false_v, "Unsupported buffer!"); + } + }, + std::get(buffer)); + } else { + TT_THROW("Only OwnedBuffer or BorrowedBuffer is supported for converting to python buffers!"); + } + }(); auto shape = tt_tensor.get_legacy_shape(); auto np_shape = std::vector(std::begin(shape), std::end(shape)); @@ -842,7 +984,8 @@ void pytensor_module(py::module& m_tensor) { if (py::isinstance(tensor)) { return detail::convert_python_tensors_to_tt_tensors(tensor, data_type, tile, strategy); } - return detail::convert_python_tensor_to_tt_tensor(tensor, data_type, tile); + return detail::convert_python_tensor_to_tt_tensor( + 
tensor, data_type, std::nullopt, tile, MemoryConfig{}, nullptr); }), py::arg("tensor"), py::arg("data_type") = std::nullopt, @@ -857,6 +1000,8 @@ void pytensor_module(py::module& m_tensor) { +--------------+------------------------+ | data_type | TT Tensor data type | +--------------+------------------------+ + | tile | TT Tile Spec | + +--------------+------------------------+ Example of creating a TT Tensor that uses torch.Tensor's storage as its own storage: @@ -872,16 +1017,15 @@ void pytensor_module(py::module& m_tensor) { Layout layout, const MemoryConfig& mem_config, const std::optional& tile) { - auto tensor = detail::convert_python_tensor_to_tt_tensor(python_tensor, data_type, tile); - auto layout_tensor = tensor.to(layout); - return layout_tensor.to(device, mem_config); + return detail::convert_python_tensor_to_tt_tensor( + python_tensor, data_type, layout, tile, mem_config, device); }), py::arg("tensor"), py::arg("data_type") = std::nullopt, - py::arg("device").noconvert(), - py::arg("layout").noconvert(), - py::arg("mem_config").noconvert(), - py::arg("tile") = std::nullopt, + py::arg("device") = nullptr, + py::arg("layout").noconvert() = Layout::ROW_MAJOR, + py::arg("mem_config").noconvert() = MemoryConfig{}, + py::arg("tile").noconvert() = std::nullopt, py::return_value_policy::move, R"doc( +--------------+------------------------+ @@ -897,14 +1041,17 @@ void pytensor_module(py::module& m_tensor) { +--------------+------------------------+ | mem_config | TT memory_config | +--------------+------------------------+ + | tile | TT Tile Spec | + +--------------+------------------------+ - Example of creating a TT Tensor that uses torch.Tensor's storage as its own storage: + Example of creating a TT Tensor from numpy tensor: .. code-block:: python + device = ttnn.open_device(device_id=0) py_tensor = np.zeros((1, 1, 32, 32)) - ttnn.Tensor(py_tensor) + ttnn.Tensor(py_tensor, ttnn.bfloat16, device, ttnn.TILE_LAYOUT) )doc") .def_property_readonly("shape", [](const Tensor& self) { return self.get_shape(); }) .def_property_readonly("dtype", [](const Tensor& self) { return self.get_dtype(); }) diff --git a/ttnn/cpp/ttnn/tensor/tensor_spec.hpp b/ttnn/cpp/ttnn/tensor/tensor_spec.hpp index 125b3bb719f..172e0d881f5 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_spec.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_spec.hpp @@ -28,6 +28,7 @@ class TensorSpec final { DataType data_type() const { return tensor_layout_.get_data_type(); } Layout layout() const { return tensor_layout_.get_layout(); } PageConfig page_config() const { return tensor_layout_.get_page_config(); } + const MemoryConfig& memory_config() const { return tensor_layout_.get_memory_config(); } const ttnn::SimpleShape& padded_shape() const { return cached_padded_shape_; } const Size& physical_shape() const { return cached_physical_shape_; } ttnn::Shape shape() const { return ttnn::Shape(logical_shape_.view(), cached_padded_shape_.view()); } diff --git a/ttnn/tt_lib/fused_ops/softmax.py b/ttnn/tt_lib/fused_ops/softmax.py index f5b2f5fceb4..904b4cea008 100644 --- a/ttnn/tt_lib/fused_ops/softmax.py +++ b/ttnn/tt_lib/fused_ops/softmax.py @@ -42,7 +42,7 @@ def ref_stable_softmax(x): if __name__ == "__main__": - device = ttnn.open_device(0) + device = ttnn.open_device(device_id=0) H, W = 64, 96 torch.manual_seed(123) diff --git a/ttnn/tt_lib/utils.py b/ttnn/tt_lib/utils.py index 9883666b81f..a61f9759464 100644 --- a/ttnn/tt_lib/utils.py +++ b/ttnn/tt_lib/utils.py @@ -8,6 +8,8 @@ import torch import numpy as np +from typing_extensions import 
deprecated + def _nearest_32(x): return math.ceil(x / 32) * 32 @@ -134,108 +136,22 @@ def convert_act_2d_matrix(activation, kernel_y, kernel_x, stride_y, stride_x, pa return ret.reshape(ret_shape) +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def tilize(x): - """ - This function tilizes a tensor. The last two tensor dims must be divisible by 32, after which this function - produces row major tiles and creates faces. The output of this function is a flattened list that - we can send to the device. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. - """ - nearest_32 = _nearest_32 - - assert isinstance( - x, (torch.Tensor, np.ndarray) - ), "Input to this function must be an instance of torch.Tensor or np.array" - assert len(x.shape) == 4, "Only 4D tensors suppported" - assert (x.shape[-2] % 32) == 0 and ( - x.shape[-1] % 32 - ) == 0, "The last two dimensions of the tensor must be divisible by 32" - - if isinstance(x, torch.Tensor): - ret = torch.zeros(np.prod(x.shape)) - else: - ret = np.zeros(np.prod(x.shape)) - - idx = 0 - for B in range(x.shape[0]): - for C in range(x.shape[1]): - for H in range(0, x.shape[2], 32): - for W in range(0, x.shape[3], 32): - unfaced_tile = x[B, C, H : H + 32, W : W + 32] - - face0 = unfaced_tile[:16, :16] - face1 = unfaced_tile[:16, 16:] - face2 = unfaced_tile[16:, :16] - face3 = unfaced_tile[16:, 16:] - - for face in (face0, face1, face2, face3): - ret[idx : idx + 256] = face.reshape(-1) - idx += 256 - - return ret.reshape(x.shape) + return x +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def tilize_to_list(x): """ - Tilize a PyTorch and then return the values as a flat list. The last two - tensor dims must be divisible by 32, after which this function produces row - major tiles and creates faces. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. + Returns a flattened list of the tensor """ - return tilize(x).reshape(-1).tolist() +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def untilize(x): - """ - This function untilizes a tensor to row major format. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. 
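Because the buffer read-out path in pytensor.cpp now converts TILE data back to row-major (create_row_major_owned_buffer above), untilize() (made a deprecated no-op below) is likewise unnecessary when pulling data back to torch. A minimal round-trip sketch following the new test_tensor_creation test (assumes a device at device_id 0):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)
    py_tensor = torch.rand((2, 3, 64, 96), dtype=torch.bfloat16)
    tt_tensor = ttnn.Tensor(py_tensor, ttnn.bfloat16, device, ttnn.TILE_LAYOUT)
    # to_torch() returns row-major data; no untilize() step is needed.
    py_back = tt_tensor.cpu().to_torch()
    assert torch.allclose(py_tensor, py_back)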
- """ - nearest_32 = _nearest_32 - - assert isinstance(x, (torch.Tensor, np.ndarray)), "Input to this function must be an instance of torch.Tensor" - assert len(x.shape) == 4, "Only 4D tensors suppported" - assert (x.shape[-2] % 32) == 0 and ( - x.shape[-1] % 32 - ) == 0, "The last two dimensions of the tensor must be divisible by 32" - - if isinstance(x, torch.Tensor): - ret = torch.zeros(x.shape) - else: - ret = np.zeros(x.shape) - - for B in range(x.shape[0]): - for C in range(x.shape[1]): - x_hw = x[B, C, :].reshape(-1) - hw = 0 - for h in range(0, x.shape[2], 32): - for w in range(0, x.shape[3], 32): - f_tile = x_hw[hw : hw + 256].reshape(16, 16) - ret[B, C, h : h + 16, w : w + 16] = f_tile - - f_tile = x_hw[hw + 256 : hw + 512].reshape(16, 16) - ret[B, C, h : h + 16, w + 16 : w + 32] = f_tile - - f_tile = x_hw[hw + 512 : hw + 768].reshape(16, 16) - ret[B, C, h + 16 : h + 32, w : w + 16] = f_tile - - f_tile = x_hw[hw + 768 : hw + 1024].reshape(16, 16) - ret[B, C, h + 16 : h + 32, w + 16 : w + 32] = f_tile - hw += 1024 # traverse tiles in RM-order - - return ret + return x def print_diff_argmax(a, b, annotation=""):