Add memcpy(void *, Tensor) API (#1519)
jnie-TT authored Dec 6, 2024
1 parent 04d4d51 commit f4dd5d9
Showing 6 changed files with 88 additions and 15 deletions.
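
The commit adds a second memcpy overload that copies a tensor's data into a caller-provided raw pointer, alongside the existing tensor-to-tensor overload. Not part of the commit: a minimal C++ usage sketch of the new overload. The readBack helper, the buffer sizing, and the include path are assumptions; only the ::tt::runtime::memcpy(void *, Tensor) signature comes from the headers changed below.

// Hypothetical usage sketch; only memcpy(void *, Tensor) itself is
// introduced by this commit.
#include <cstddef>
#include <cstdint>
#include <vector>

#include "tt/runtime/runtime.h" // assumed include path for the header changed below

// Copies a tensor's data into a host buffer owned by the caller.
// volume and elementSize are assumed known to the caller (product of the
// tensor's dims and bytes per element); the destination must hold
// volume * elementSize bytes, matching what the TTNN implementation copies.
std::vector<std::uint8_t> readBack(::tt::runtime::Tensor tensor,
                                   std::size_t volume,
                                   std::size_t elementSize) {
  std::vector<std::uint8_t> buffer(volume * elementSize);
  ::tt::runtime::memcpy(buffer.data(), tensor);
  return buffer;
}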
2 changes: 2 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
@@ -103,6 +103,8 @@ Tensor toLayout(Tensor tensor, Device device, Layout layout);
Layout getLayout(Binary executableHandle, std::uint32_t programIndex,
std::uint32_t inputIndex);

void memcpy(void *dst, Tensor src);

void memcpy(Tensor dst, Tensor src);

void deallocateTensor(Tensor &tensor, bool force = false);
2 changes: 2 additions & 0 deletions runtime/include/tt/runtime/runtime.h
@@ -87,6 +87,8 @@ Tensor toLayout(Tensor tensor, Device device, Layout layout);
Layout getLayout(Binary executableHandle, std::uint32_t programIndex,
std::uint32_t inputIndex);

void memcpy(void *dst, Tensor src);

void memcpy(Tensor dst, Tensor src);

void deallocateTensor(Tensor &tensor, bool force = false);
15 changes: 15 additions & 0 deletions runtime/lib/runtime.cpp
@@ -323,6 +323,21 @@ Layout getLayout(Binary executableHandle, std::uint32_t programIndex,
  LOG_FATAL("runtime is not enabled");
}

void memcpy(void *dst, Tensor src) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
    return ::tt::runtime::ttnn::memcpy(dst, src);
  }
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
    LOG_FATAL("not implemented");
  }
#endif
  LOG_FATAL("runtime is not enabled");
}

void memcpy(Tensor dst, Tensor src) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
14 changes: 12 additions & 2 deletions runtime/lib/ttnn/runtime.cpp
@@ -296,6 +296,17 @@ Layout getLayout(Binary executableHandle, std::uint32_t programIndex,
DeviceRuntime::TTNN);
}

void memcpy(void *dst, Tensor src) {
  const ::ttnn::Tensor &srcTensor = src.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
  if (utils::isOnHost(srcTensor.storage_type())) {
    const void *srcPtr = ::tt::tt_metal::get_raw_host_data_ptr(srcTensor);
    size_t size = srcTensor.volume() * srcTensor.element_size();
    std::memcpy(dst, srcPtr, size);
  } else {
    ::tt::tt_metal::memcpy(dst, srcTensor);
  }
}

void memcpy(Tensor dst, Tensor src) {
  ::ttnn::Tensor &dstTensor = dst.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
  const ::ttnn::Tensor &srcTensor = src.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
@@ -304,11 +315,10 @@ void memcpy(Tensor dst, Tensor src) {
"Input output tensor size mismatch in memcpy: ",
srcTensor.volume(), " * ", srcTensor.element_size(),
" != ", dstTensor.volume(), " * ", dstTensor.element_size());

if (utils::isOnHost(srcTensor.storage_type()) and
utils::isOnHost(dstTensor.storage_type())) {
void *dstPtr = ::tt::tt_metal::get_raw_host_data_ptr(dstTensor);
void *srcPtr = ::tt::tt_metal::get_raw_host_data_ptr(srcTensor);
const void *srcPtr = ::tt::tt_metal::get_raw_host_data_ptr(srcTensor);
size_t size = srcTensor.volume() * srcTensor.element_size();
std::memcpy(dstPtr, srcPtr, size);
} else {
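
When the source tensor is host-resident, the implementation above copies from the raw host data pointer with std::memcpy; otherwise it hands the copy to ::tt::tt_metal::memcpy. In both cases the destination must hold volume * element_size bytes. Not part of the commit: a small sketch of that sizing for the (64, 128) bfloat16 case exercised by the test below (bfloat16 is 2 bytes per element).

#include <cstddef>

// Destination-size contract: size = volume * element_size.
// Shape and dtype mirror the (64, 128) bfloat16 case in the test below.
constexpr std::size_t kRows = 64;
constexpr std::size_t kCols = 128;
constexpr std::size_t kBytesPerElement = 2; // bfloat16
constexpr std::size_t kRequiredBytes = kRows * kCols * kBytesPerElement;
static_assert(kRequiredBytes == 16384, "dst must provide at least 16 KiB");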
62 changes: 49 additions & 13 deletions runtime/test/python/ttnn/test_runtime_api.py
@@ -43,7 +43,50 @@ def test_to_layout(helper: Helper, shape, dtype, request):
ttrt.runtime.memcpy(runtime_output_tensor, host_tensor)
ttrt.runtime.deallocate_tensor(host_tensor, force=True)

lambda: assert_pcc(torch_input_tensor, torch_result_tensor, threshold=0.999)
assert_pcc(torch_input_tensor, torch_result_tensor, threshold=0.99)
helper.teardown()


@pytest.mark.parametrize("shape", [(64, 128)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_memcpy_to_pointer(helper: Helper, shape, dtype, request):
    helper.initialize(request.node.name)
    helper.check_constraints()
    runtime_dtype = Binary.Program.to_data_type(dtype)
    torch_result_tensor = torch.zeros(shape, dtype=dtype)

    # Device to host
    torch_input_tensor = torch.randn(shape, dtype=dtype)
    runtime_input_tensor = ttrt.runtime.create_tensor(
        torch_input_tensor.data_ptr(),
        list(torch_input_tensor.shape),
        list(torch_input_tensor.stride()),
        torch_input_tensor.element_size(),
        runtime_dtype,
    )
    device_layout = ttrt.runtime.testing.get_dram_interleaved_row_major_layout(
        runtime_dtype
    )
    with DeviceContext([helper.query.device_ids[0]]) as device:
        device_tensor = ttrt.runtime.to_layout(
            runtime_input_tensor, device, device_layout
        )
        ttrt.runtime.memcpy(torch_result_tensor.data_ptr(), device_tensor)
        ttrt.runtime.deallocate_tensor(device_tensor, force=True)

    assert_pcc(torch_input_tensor, torch_result_tensor, threshold=0.99)

    # Host to host
    torch_input_tensor2 = torch.randn(shape, dtype=dtype)
    host_tensor = ttrt.runtime.create_tensor(
        torch_input_tensor2.data_ptr(),
        list(torch_input_tensor2.shape),
        list(torch_input_tensor2.stride()),
        torch_input_tensor2.element_size(),
        runtime_dtype,
    )
    ttrt.runtime.memcpy(torch_result_tensor.data_ptr(), host_tensor)
    assert_pcc(torch_input_tensor2, torch_result_tensor, threshold=0.99)
    helper.teardown()


@@ -80,12 +123,12 @@ def test_create_tensor_memcpy(helper: Helper, shape, dtype, request):
list(torch_input_tensor.stride()),
torch_input_tensor.element_size(),
)
# Copy from host to device container
ttrt.runtime.memcpy(device_tensor, runtime_input_tensor)
host_tensor = ttrt.runtime.to_host(device_tensor, untilize=True)
# Copy from device to host
ttrt.runtime.memcpy(runtime_output_tensor, device_tensor)
ttrt.runtime.deallocate_tensor(device_tensor, force=True)
ttrt.runtime.memcpy(runtime_output_tensor, host_tensor)
ttrt.runtime.deallocate_tensor(host_tensor, force=True)
lambda: assert_pcc(torch_input_tensor, torch_result_tensor, threshold=0.999)
assert_pcc(torch_input_tensor, torch_result_tensor, threshold=0.99)
helper.teardown()


@@ -145,14 +188,7 @@ def test_runtime_stitching_eltwise_binary_op_chain(helper: Helper, request):
]
),
)
runtime_result_tensor = ttrt.runtime.create_tensor(
torch_result_tensor.data_ptr(),
list(torch_result_tensor.shape),
list(torch_result_tensor.stride()),
torch_result_tensor.element_size(),
Binary.Program.to_data_type(torch_result_tensor.dtype),
)
ttrt.runtime.memcpy(runtime_result_tensor, activations)
ttrt.runtime.memcpy(torch_result_tensor.data_ptr(), activations)
golden = (
(inputs_torch[0] + inputs_torch[1]).mul(inputs_torch[1]).sub(inputs_torch[1])
)
8 changes: 8 additions & 0 deletions runtime/tools/python/ttrt/runtime/module.cpp
@@ -157,6 +157,14 @@ PYBIND11_MODULE(_C, m) {
"Get the debug string of the op");
m.def("get_op_loc_info", &tt::runtime::getOpLocInfo,
"Get the location info of the op");
m.def(
"memcpy",
[](std::uintptr_t dst, ::tt::runtime::Tensor src) {
void *dstPtr = reinterpret_cast<void *>(dst);
::tt::runtime::memcpy(dstPtr, src);
},
py::arg("dst"), py::arg("src"),
"Copy the data from src tensor to dst pointer");
m.def(
"memcpy",
[](::tt::runtime::Tensor dst, ::tt::runtime::Tensor src) {
