Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Runtime refactor to support runtime stitching #448

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {

// Returns the tt::target::DataType associated with `tensor`.
tt::target::DataType getTensorDataType(Tensor tensor);

// Deallocates `tensor`.
// NOTE(review): the exact effect of `force` is decided by the implementation,
// which is not visible from this header — confirm before relying on it.
void deallocateTensor(Tensor tensor, bool force);

// NOTE(review): presumably returns a host-side (CPU) version of `tensor` —
// confirm against the implementation.
Tensor toCpu(Tensor tensor);

// Opens a device handle. By default a single device id (0) is requested and
// `numHWCQs` is left empty.
Device openDevice(std::vector<int> const &deviceIds = {0},
std::vector<std::uint8_t> const &numHWCQs = {});

Expand All @@ -90,13 +94,28 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);

// Overload of submit that takes only input tensors and returns a vector of
// tensors, instead of filling caller-provided outputs and returning an Event.
std::vector<Tensor> submit(Device device, Binary executable,
std::uint32_t programIndex,
std::vector<Tensor> const &inputs);

// Waits on `event`.
// NOTE(review): blocking semantics are defined by the implementation — confirm.
void wait(Event event);

// Runs `program` on `device` with caller-provided input and output tensors.
void runProgram(::ttnn::Device &device,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs);

// Overload of runProgram that takes only inputs and returns runtime Tensors.
std::vector<Tensor> runProgram(::ttnn::Device &device,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs);

// NOTE(review): presumably converts `input` to the layout expected by input
// number `inputIndex` of program `programIndex` in `executable` — confirm.
Tensor toLayout(Device device, Binary executable, std::uint32_t programIndex,
std::uint32_t inputIndex, Tensor const &input);

// Variant of the layout helper that takes the flatbuffer Program directly
// rather than a (Binary, programIndex) pair.
// NOTE(review): exact relationship to toLayout is not visible here — confirm.
Tensor updateProgramTensorLayout(Device device,
::tt::target::ttnn::Program const *program,
std::uint32_t inputIndex, Tensor const &input);
} // namespace tt::runtime::ttnn

#endif
7 changes: 7 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,16 @@ void closeDevice(Device device);
// Submits program `programIndex` of `executable` to `device` with
// caller-provided input and output tensors; returns an Event for the
// submission.
Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);

// Overload that takes only inputs and returns a vector of tensors. Tensors
// returned from submit must be deallocated manually by the caller (see
// Tensor::deallocate).
std::vector<Tensor> submit(Device device, Binary executable,
std::uint32_t programIndex,
std::vector<Tensor> const &inputs);

// Waits on `event`.
// NOTE(review): blocking semantics are backend-defined — confirm.
void wait(Event event);

// Dispatches to the active runtime's toLayout implementation.
// NOTE(review): presumably converts `input` to the layout expected by input
// number `inputIndex` of the selected program — confirm against the backend.
Tensor toLayout(Device device, Binary executable, std::uint32_t programIndex,
std::uint32_t inputIndex, Tensor const &input);

} // namespace tt::runtime

#endif
27 changes: 25 additions & 2 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,37 @@ struct Device : public detail::RuntimeCheckedObjectImpl {
};

// Event handle built on detail::RuntimeCheckedObjectImpl. The helper
// predicates report which runtime the event belongs to, and only answer true
// when the underlying handle is non-null.
struct Event : public detail::RuntimeCheckedObjectImpl {
  using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;

  Event(std::shared_ptr<void> handle, DeviceRuntime runtime)
      : detail::RuntimeCheckedObjectImpl(handle, runtime) {}

  // True when the event matches the TTNN runtime and carries a handle.
  bool isTTNNEvent() const {
    if (!this->matchesRuntime(DeviceRuntime::TTNN)) {
      return false;
    }
    return this->handle.get() != nullptr;
  }

  // True when the event matches the TTMetal runtime and carries a handle.
  bool isTTMetalEvent() const {
    if (!this->matchesRuntime(DeviceRuntime::TTMetal)) {
      return false;
    }
    return this->handle.get() != nullptr;
  }
};

// Tensor handle built on detail::RuntimeCheckedObjectImpl. In addition to the
// runtime-checked handle it keeps a shared pointer to the tensor data and the
// Event associated with the tensor.
struct Tensor : public detail::RuntimeCheckedObjectImpl {
  std::shared_ptr<void> data;
  Event event;

  // Builds a tensor with no associated event: `event` is initialised with a
  // null handle tagged with the same runtime as the tensor.
  Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
         DeviceRuntime runtime)
      : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
        event(Event(nullptr, runtime)) {}

  // Builds a tensor that carries the given `event`.
  Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
         DeviceRuntime runtime, Event event)
      : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
        event(event) {}

  // Users need to manually deallocate tensors returned from submit,
  // as the storage is now owned instead of borrowed.
  void deallocate(bool force = false);

  // NOTE(review): presumably returns a host-side (CPU) version of this
  // tensor; the definition is not visible here — confirm.
  Tensor cpu() const;
};

} // namespace tt::runtime
Expand Down
8 changes: 8 additions & 0 deletions runtime/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ target_include_directories(TTBinary
)
add_dependencies(TTBinary FBS_GENERATION)

# Static library built from types.cpp.
add_library(TTRuntimeTypes STATIC types.cpp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add this static library (TTRuntimeTypes) to the shared lib used by tt-forge (in lib/SharedLib/CMakeLists.txt).

# Public include paths for TTRuntimeTypes: the runtime headers from the source
# tree plus the Target/Common headers that live in the build tree.
target_include_directories(TTRuntimeTypes
  PUBLIC
    ${PROJECT_SOURCE_DIR}/runtime/include
    ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
# TTRuntimeTypes includes headers from the build tree, so it must be ordered
# after FBS_GENERATION. The dependency belongs on this target, not on
# TTBinary, which already declares the same dependency.
add_dependencies(TTRuntimeTypes FBS_GENERATION)

if (TTMLIR_ENABLE_RUNTIME AND (TT_RUNTIME_ENABLE_TTNN OR TT_RUNTIME_ENABLE_TTMETAL))
add_subdirectory(common)
else()
Expand Down
36 changes: 36 additions & 0 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,25 @@ Event submit(Device deviceHandle, Binary executableHandle,
throw std::runtime_error("runtime is not enabled");
}

// Output-returning overload of submit. Forwards to the TTNN backend when TTNN
// is compiled in and is the current runtime. The TTMetal runtime does not
// support this overload and throws. If no compiled-in runtime matches the
// current one, a "runtime is not enabled" error is thrown.
std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
                           std::uint32_t programIndex,
                           std::vector<Tensor> const &inputHandles) {
#ifdef TT_RUNTIME_ENABLE_TTNN
  const bool ttnnSelected = (getCurrentRuntime() == DeviceRuntime::TTNN);
  if (ttnnSelected) {
    return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle,
                                       programIndex, inputHandles);
  }
#endif

#ifdef TT_RUNTIME_ENABLE_TTMETAL
  const bool ttmetalSelected = (getCurrentRuntime() == DeviceRuntime::TTMetal);
  if (ttmetalSelected) {
    throw std::runtime_error("Currently not supported after refactor");
  }
#endif

  // Reached when no compiled-in backend matches the current runtime.
  throw std::runtime_error("runtime is not enabled");
}

void wait(Event event) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
Expand All @@ -206,4 +225,21 @@ void wait(Event event) {
throw std::runtime_error("runtime is not enabled");
}

// Runtime-dispatching toLayout. Forwards to the TTNN backend when TTNN is
// compiled in and is the current runtime. The TTMetal runtime has no
// implementation and throws. If no compiled-in runtime matches the current
// one, a "runtime is not enabled" error is thrown.
Tensor toLayout(Device device, Binary executable, std::uint32_t programIndex,
                std::uint32_t inputIndex, Tensor const &input) {
#ifdef TT_RUNTIME_ENABLE_TTNN
  const bool ttnnSelected = (getCurrentRuntime() == DeviceRuntime::TTNN);
  if (ttnnSelected) {
    return ::tt::runtime::ttnn::toLayout(device, executable, programIndex,
                                         inputIndex, input);
  }
#endif

#ifdef TT_RUNTIME_ENABLE_TTMETAL
  const bool ttmetalSelected = (getCurrentRuntime() == DeviceRuntime::TTMetal);
  if (ttmetalSelected) {
    throw std::runtime_error("Not implemented");
  }
#endif

  // Reached when no compiled-in backend matches the current runtime.
  throw std::runtime_error("runtime is not enabled");
}
} // namespace tt::runtime
Loading
Loading