First Iteration/Prototype: Runtime refactor to support runtime stitching
jnie-TT committed Sep 8, 2024
1 parent 27f80f8 commit 608ed89
Showing 8 changed files with 271 additions and 49 deletions.
15 changes: 15 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
@@ -90,12 +90,27 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);

std::vector<Tensor> submit(Device device, Binary executable,
std::uint32_t programIndex,
std::vector<Tensor> const &inputs);

void wait(Event event);

void runProgram(::ttnn::Device &device,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs);

std::vector<Tensor> runProgram(::ttnn::Device &device,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs);

Tensor toLayout(Device device, Binary executable, std::uint32_t programIndex,
std::uint32_t inputIndex, Tensor const &input);

Tensor updateProgramTensorLayout(Device device,
::tt::target::ttnn::Program const *program,
std::uint32_t inputIndex, Tensor const &input);

} // namespace tt::runtime::ttnn

7 changes: 7 additions & 0 deletions runtime/include/tt/runtime/runtime.h
@@ -51,9 +51,16 @@ void closeDevice(Device device);
Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);

std::vector<Tensor> submit(Device device, Binary executable,
std::uint32_t programIndex,
std::vector<Tensor> const &inputs);

void wait(Event event);

Tensor toLayout(Device device, Binary executable, std::uint32_t programIndex,
std::uint32_t inputIndex, Tensor const &input);

} // namespace tt::runtime

#endif
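
The tensor-returning submit overload together with toLayout is what enables runtime stitching at this level: the outputs of one program can be converted to the layout the next program expects and fed straight back in as inputs. Below is a minimal sketch of that flow, assuming a hypothetical binary containing two programs; the function name, program indices, and variable names are illustrative only and not part of this change:

#include <cstdint>
#include <vector>

#include "tt/runtime/runtime.h"

// Sketch: run program 0, re-lay-out its outputs for program 1, run program 1.
std::vector<tt::runtime::Tensor>
stitchTwoPrograms(tt::runtime::Device device, tt::runtime::Binary binary,
                  std::vector<tt::runtime::Tensor> const &firstInputs) {
  // Program 0 returns tensors that own their storage.
  std::vector<tt::runtime::Tensor> intermediates =
      tt::runtime::submit(device, binary, /*programIndex=*/0, firstInputs);

  // Convert each intermediate to the layout/memory config that program 1
  // expects for the corresponding input slot.
  std::vector<tt::runtime::Tensor> secondInputs;
  for (std::uint32_t i = 0; i < intermediates.size(); ++i) {
    secondInputs.push_back(tt::runtime::toLayout(device, binary,
                                                 /*programIndex=*/1, i,
                                                 intermediates[i]));
  }

  // Program 1 consumes the converted tensors directly.
  return tt::runtime::submit(device, binary, /*programIndex=*/1, secondInputs);
}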
26 changes: 24 additions & 2 deletions runtime/include/tt/runtime/types.h
@@ -111,14 +111,36 @@ struct Device : public detail::RuntimeCheckedObjectImpl {
};

struct Event : public detail::RuntimeCheckedObjectImpl {
using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;
Event(std::shared_ptr<void> handle, DeviceRuntime runtime)
: detail::RuntimeCheckedObjectImpl(handle, runtime) {}

bool isTTNNEvent() const {
return this->matchesRuntime(DeviceRuntime::TTNN) and this->handle.get();
}

bool isTTMetalEvent() const {
return this->matchesRuntime(DeviceRuntime::TTMetal) and this->handle.get();
}
};

struct Tensor : public detail::RuntimeCheckedObjectImpl {
std::shared_ptr<void> data;
Event event;

Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
DeviceRuntime runtime)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {}
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
event(Event(nullptr, runtime)) {}

Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
DeviceRuntime runtime, Event event)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
event(event) {}

// Users need to manually deallocate tensors returned from submit,
// as their storage is now owned rather than borrowed
void deallocate();
};

} // namespace tt::runtime
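
The deallocate note above is the ownership contract for the new flow: tensors handed back by the tensor-returning submit own their storage, so the caller releases them explicitly. A short sketch of that responsibility, with the binary handle, program index, and input tensors assumed to already exist:

#include <vector>

#include "tt/runtime/runtime.h"

// Sketch: run one program and release its outputs once they are consumed.
void runAndRelease(tt::runtime::Device device, tt::runtime::Binary binary,
                   std::vector<tt::runtime::Tensor> const &inputs) {
  std::vector<tt::runtime::Tensor> outputs =
      tt::runtime::submit(device, binary, /*programIndex=*/0, inputs);

  // ... read back or copy out the output data here ...

  // Storage is owned, not borrowed, so the caller frees it when done.
  for (tt::runtime::Tensor &out : outputs) {
    out.deallocate();
  }
}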
8 changes: 8 additions & 0 deletions runtime/lib/CMakeLists.txt
@@ -24,6 +24,14 @@ target_include_directories(TTBinary
)
add_dependencies(TTBinary FBS_GENERATION)

add_library(TTRuntimeTypes STATIC types.cpp)
target_include_directories(TTRuntimeTypes
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
add_dependencies(TTRuntimeTypes FBS_GENERATION)

if (TTMLIR_ENABLE_RUNTIME AND (TT_RUNTIME_ENABLE_TTNN OR TT_RUNTIME_ENABLE_TTMETAL))
add_subdirectory(common)
else()
36 changes: 36 additions & 0 deletions runtime/lib/runtime.cpp
@@ -191,6 +191,25 @@ Event submit(Device deviceHandle, Binary executableHandle,
throw std::runtime_error("runtime is not enabled");
}

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle,
programIndex, inputHandles);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
throw std::runtime_error("Currently not supported after refactor");
}
#endif

throw std::runtime_error("runtime is not enabled");
}

void wait(Event event) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
@@ -206,4 +225,21 @@ void wait(Event event) {
throw std::runtime_error("runtime is not enabled");
}

Tensor toLayout(Device device, Binary executable, std::uint32_t programIndex,
std::uint32_t inputIndex, Tensor const &input) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::toLayout(device, executable, programIndex,
inputIndex, input);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
throw std::runtime_error("Not implemented");
}
#endif

throw std::runtime_error("runtime is not enabled");
}
} // namespace tt::runtime
165 changes: 118 additions & 47 deletions runtime/lib/ttnn/program.cpp
@@ -63,6 +63,10 @@ class ProgramTensorPool {
return liveTensors.contains(global_id);
}

size_t size() const {
return liveTensors.size();
}

private:
// A superset of intermedTensors, containing all tensors created by the
// program and the input/output tensors passed in by the user
@@ -89,6 +93,20 @@ static ::ttnn::DataType getDataType(const ::tt::target::TensorRef *tensorRef) {
tensorRef->desc()->layout()->memory_desc()->data_type());
}
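
// Wraps a ::ttnn::Tensor in the runtime's type-erased Tensor handle; host
// tensors expose their raw host data pointer (borrowed, not owned), while
// device tensors carry a null data pointer.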

static Tensor toTypeErasedTensor(const ::ttnn::Tensor &tensor) {
std::shared_ptr<::ttnn::Tensor> tensorHandle = std::make_shared<::ttnn::Tensor>(tensor);
void *dataPtr = isOnHost(*tensorHandle) ? ::tt::tt_metal::get_raw_host_data_ptr(*tensorHandle) : nullptr;
return Tensor(tensorHandle, ::tt::runtime::utils::unsafe_borrow_shared(dataPtr), DeviceRuntime::TTNN);
}

static void tensorMemcpy(::ttnn::Tensor &dst, ::ttnn::Tensor &src) {
assert(isOnHost(src) and dst.storage_type() == ::tt::tt_metal::StorageType::BORROWED);
void *srcDataPtr = ::tt::tt_metal::get_raw_host_data_ptr(src);
void *dstDataPtr = ::tt::tt_metal::get_raw_host_data_ptr(dst);
std::uint32_t size = src.volume() * src.element_size();
std::memcpy(dstDataPtr, srcDataPtr, size);
}

static CoreRangeSet toCoreRangeSet(
const ::flatbuffers::Vector<const tt::target::Dim2dRange *> *coreRangeSet) {
std::set<CoreRange> coreRanges;
@@ -214,10 +232,9 @@ updateLayoutAndDataType(const ::ttnn::Tensor &inputTensor,
return outputTensor;
}

static void
static ::ttnn::Tensor
handleToHostMemoryConfigOp(const ::ttnn::Tensor &inputTensor,
const ::tt::target::TensorRef *outputTensorRef,
ProgramTensorPool &tensorPool) {
const ::tt::target::TensorRef *outputTensorRef) {
::ttnn::Tensor result;
::ttnn::DataType targetDataTypeTTNN = getDataType(outputTensorRef);
bool shouldTilize, shouldUntilize;
@@ -232,24 +249,13 @@ handleToHostMemoryConfigOp(const ::ttnn::Tensor &inputTensor,
result = updateLayoutAndDataType(inputTensor.cpu(), targetDataTypeTTNN,
shouldTilize, shouldUntilize);
}
// copy the output to the output tensor if it exists
if (tensorPool.contains(outputTensorRef->global_id())) {
::ttnn::Tensor &outputTensor = tensorPool.at(outputTensorRef->global_id());
void *src = ::tt::tt_metal::get_raw_host_data_ptr(result);
void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor);
std::uint32_t size = result.volume() * result.element_size();
std::memcpy(dst, src, size);
} else {
tensorPool.insert_or_assign(outputTensorRef->global_id(),
std::move(result));
}
return result;
}

static void
static ::ttnn::Tensor
handleToDramMemoryConfigOp(::ttnn::Device &device,
const ::ttnn::Tensor &inputTensor,
const ::tt::target::TensorRef *outputTensorRef,
ProgramTensorPool &tensorPool) {
const ::tt::target::TensorRef *outputTensorRef) {
::ttnn::DataType targetDataTypeTTNN = getDataType(outputTensorRef);
::tt::tt_metal::MemoryConfig targetMemoryConfig =
createMemoryConfig(outputTensorRef);
@@ -266,24 +272,23 @@ handleToDramMemoryConfigOp(::ttnn::Device &device,
result = ::ttnn::to_device(result, &device, targetMemoryConfig);
result = updateLayoutAndDataType(result, targetDataTypeTTNN, shouldTilize,
shouldUntilize);
tensorPool.insert_or_assign(outputTensorRef->global_id(),
std::move(result));
return result;
} else if (isOnDevice(inputTensor)) {
shouldTilize = false;
shouldUntilize = false;
::ttnn::Tensor result = updateLayoutAndDataType(
inputTensor, targetDataTypeTTNN, shouldTilize, shouldUntilize);
result = ::ttnn::to_memory_config(result, targetMemoryConfig, std::nullopt);
tensorPool.insert_or_assign(outputTensorRef->global_id(),
std::move(result));
return result;
} else {
throw std::runtime_error("Unsupported input tensor storage type");
}
}

static void
static ::ttnn::Tensor
handleToL1MemoryConfigOp(::ttnn::Device &device,
const ::ttnn::Tensor &inputTensor,
const ::tt::target::TensorRef *outputTensorRef,
ProgramTensorPool &tensorPool) {
const ::tt::target::TensorRef *outputTensorRef) {
::ttnn::DataType targetDataTypeTTNN = getDataType(outputTensorRef);
::tt::tt_metal::MemoryConfig targetMemoryConfig =
createMemoryConfig(outputTensorRef);
@@ -309,53 +314,65 @@ handleToL1MemoryConfigOp(::ttnn::Device &device,
result =
::ttnn::to_memory_config(result, targetMemoryConfig, std::nullopt);
}
tensorPool.insert_or_assign(outputTensorRef->global_id(),
std::move(result));
return result;
} else if (isOnDevice(inputTensor)) {
shouldTilize = false;
shouldUntilize = false;
::ttnn::Tensor result = updateLayoutAndDataType(
inputTensor, targetDataTypeTTNN, shouldTilize, shouldUntilize);
result = ::ttnn::to_memory_config(result, targetMemoryConfig, std::nullopt);
tensorPool.insert_or_assign(outputTensorRef->global_id(),
std::move(result));
return result;
} else {
throw std::runtime_error("Unsupported input tensor storage type");
}
}

// TODO(bug #272): right now hardcoding tilize/untilize, should determine with
// tile shape blocked by issue #272
static void run(::tt::target::ttnn::ToMemoryConfigOp const *op,
::ttnn::Device &device, ProgramTensorPool &tensorPool) {

const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in0()->global_id());
assert(isOnHost(inputTensor) or
isOnDevice(inputTensor) && "Unsupported storage type");

const ::tt::target::Dim2d *targetTileShape =
op->out()->desc()->layout()->memory_desc()->tile_shape();
assert(utils::isValidTileShape(targetTileShape) && "Invalid tile shape");

static ::ttnn::Tensor updateTensorMemoryConfig(::ttnn::Device &device,
const ::ttnn::Tensor &inputTensor,
const ::tt::target::TensorRef *outputTensorRef) {
const ::tt::target::MemoryDesc *targetMemoryDesc =
outputTensorRef->desc()->layout()->memory_desc();
const ::tt::target::MemorySpace targetMemorySpace =
op->out()->desc()->layout()->memory_desc()->memory_space();
targetMemoryDesc->memory_space();

switch (targetMemorySpace) {
// This case should only be used when gathering outputs at the end of the
// program
case ::tt::target::MemorySpace::System:
case ::tt::target::MemorySpace::SystemMMIO: {
handleToHostMemoryConfigOp(inputTensor, op->out(), tensorPool);
return handleToHostMemoryConfigOp(inputTensor, outputTensorRef);
break;
}
case ::tt::target::MemorySpace::DeviceDRAM: {
handleToDramMemoryConfigOp(device, inputTensor, op->out(), tensorPool);
return handleToDramMemoryConfigOp(device, inputTensor, outputTensorRef);
break;
}
case ::tt::target::MemorySpace::DeviceL1: {
handleToL1MemoryConfigOp(device, inputTensor, op->out(), tensorPool);
return handleToL1MemoryConfigOp(device, inputTensor, outputTensorRef);
break;
}
}
}
// TODO(bug #272): right now hardcoding tilize/untilize, should determine with
// tile shape blocked by issue #272
static void run(::tt::target::ttnn::ToMemoryConfigOp const *op,
::ttnn::Device &device, ProgramTensorPool &tensorPool) {

const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in0()->global_id());
assert(isOnHost(inputTensor) or
isOnDevice(inputTensor) && "Unsupported storage type");

const ::tt::target::Dim2d *targetTileShape =
op->out()->desc()->layout()->memory_desc()->tile_shape();
assert(utils::isValidTileShape(targetTileShape) && "Invalid tile shape");

::ttnn::Tensor result = updateTensorMemoryConfig(device, inputTensor, op->out());
// copy the result into the user-provided output tensor if it exists and is borrowed
if (tensorPool.contains(op->out()->global_id()) and tensorPool.at(op->out()->global_id()).storage_type() == ::tt::tt_metal::StorageType::BORROWED) {
tensorMemcpy(tensorPool.at(op->out()->global_id()), result);
} else {
tensorPool.insert_or_assign(op->out()->global_id(),
std::move(result));
}
}

static void run(::tt::target::ttnn::EmptyOp const *op, ::ttnn::Device &device,
ProgramTensorPool &tensorPool) {
@@ -837,4 +854,58 @@ void runProgram(::ttnn::Device &device,
run(op, device, tensorPool);
}
}
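
// Tensor-returning runProgram overload added for runtime stitching: converts
// the user inputs to the layout/memory config each program input expects,
// executes every operation against the tensor pool, then converts the program
// outputs and returns them as type-erased tensors that the caller must
// eventually deallocate.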

std::vector<Tensor> runProgram(::ttnn::Device &device,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs) {

ProgramTensorPool tensorPool({});
int inputIndex = 0;

// convert inputs to the desired layout/memory config
for (::tt::target::TensorRef const *inputRef : *program->inputs()) {
const ::ttnn::Tensor *inputTensor = inputs[inputIndex++];
::ttnn::Tensor updatedInputTensor = updateTensorMemoryConfig(device, *inputTensor, inputRef);
auto [iter, inserted] = tensorPool.try_emplace(inputRef->global_id(), std::move(updatedInputTensor));
assert(inserted && "Duplicate input tensor");
}

for (::tt::target::ttnn::Operation const *op : *program->operations()) {
run(op, device, tensorPool);
}

// convert outputs to the desired layout/memory config
// then convert them to type erased tensors and return
std::vector<Tensor> outputs;
for (::tt::target::TensorRef const *outputRef : *program->outputs()) {
size_t outputId = outputRef->global_id();
assert(tensorPool.contains(outputId) &&
"Program output tensor not found in tensorPool");
const ::ttnn::Tensor &outputTensor = tensorPool.at(outputId);
::ttnn::Tensor updatedOutputTensor = updateTensorMemoryConfig(device, outputTensor, outputRef);
outputs.push_back(toTypeErasedTensor(updatedOutputTensor));
}

return outputs;
}
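
// Converts a single user-supplied input tensor to the layout/memory config
// expected by the given program input slot and returns it as a type-erased
// runtime tensor.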

Tensor updateProgramTensorLayout(Device device,
::tt::target::ttnn::Program const *program,
std::uint32_t inputIndex,
Tensor const &input) {
TT_FATAL(inputIndex < program->inputs()->size(),
"Input index {} out of range {}", inputIndex,
program->inputs()->size());
const ::tt::target::TensorRef *inputRef = program->inputs()->Get(inputIndex);

::ttnn::Device &ttnnDevice = device.as<::ttnn::Device>(DeviceRuntime::TTNN);
const ::ttnn::Tensor &ttnnInput =
input.as<::ttnn::Tensor>(DeviceRuntime::TTNN);

::ttnn::Tensor result =
updateTensorMemoryConfig(ttnnDevice, ttnnInput, inputRef);

return toTypeErasedTensor(result);
}

} // namespace tt::runtime::ttnn