Skip to content

Commit

Permalink
TTNN Runtime refactor (#713)
Browse files Browse the repository at this point in the history
* Move Ops to their own dedicated directories/files/namespaces.
* Add program context that manages the state of the program (borrowed tensors/devices, intermed tensors generated by ops)
* Add program executor class that executes the program
  • Loading branch information
jnie-TT authored Sep 16, 2024
1 parent 5e48f3a commit 880154e
Show file tree
Hide file tree
Showing 46 changed files with 1,501 additions and 960 deletions.
4 changes: 4 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@
#include "ttnn/operations/embedding/embedding.hpp"
#include "ttnn/operations/matmul/matmul.hpp"
#include "ttnn/operations/normalization/softmax/softmax.hpp"
#include "ttnn/operations/pool/maxpool/max_pool2d.hpp"
#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/types.hpp"

#pragma clang diagnostic pop

#include "tt/runtime/types.h"
Expand Down
6 changes: 3 additions & 3 deletions runtime/lib/ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# Process the operations/ subdirectory first; it defines the TTRuntimeTTNNOps
# target that is linked into TTRuntimeTTNN below.
add_subdirectory(operations)
# Static library built from runtime.cpp and program.cpp.
add_library(TTRuntimeTTNN
STATIC
runtime.cpp
program.cpp
)
# We have to set the C++ standard to 20 because tt-metal requires it
set_property(TARGET TTRuntimeTTNN PROPERTY CXX_STANDARD 20)
# PRIVATE: the AVX/AVX2 flags apply only to this target's own sources.
target_compile_options(TTRuntimeTTNN PRIVATE -mavx -mavx2)
# PUBLIC include paths are also propagated to targets that link this library.
target_include_directories(TTRuntimeTTNN PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
# tt-metal headers are exposed through the build interface only.
target_include_directories(TTRuntimeTTNN PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>")
target_link_libraries(TTRuntimeTTNN PUBLIC TTNN_LIBRARY)
# Make sure TTNN_LIBRARY, tt-metal and FBS_GENERATION are built before this
# target.
add_dependencies(TTRuntimeTTNN TTNN_LIBRARY tt-metal FBS_GENERATION)
# Link the per-op implementations and order their build before this target.
target_link_libraries(TTRuntimeTTNN PUBLIC TTRuntimeTTNNOps)
add_dependencies(TTRuntimeTTNN TTRuntimeTTNNOps)
69 changes: 69 additions & 0 deletions runtime/lib/ttnn/include/tt/runtime/ttnn/types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_TYPES_H
#define TTNN_RUNTIME_TYPES_H

#include "tt/runtime/detail/ttnn.h"

namespace tt::runtime::ttnn {

// Lookup table from a 32-bit id to a ::ttnn::Device pointer.
using DeviceMap = std::unordered_map<uint32_t, ::ttnn::Device *>;
// Lookup table from a 32-bit id to a ::ttnn::Tensor pointer.
using TensorMap = std::unordered_map<uint32_t, ::ttnn::Tensor *>;
// Registry of every tensor a program can reference, keyed by global id.
//
// Tensors supplied at construction are referenced by pointer only. Tensors
// added through try_emplace / insert_or_assign are stored by value inside the
// pool (intermedTensors) and additionally exposed through liveTensors.
struct ProgramTensorPool {
  // Seeds the pool with the tensor pointers in `liveTensors`. The pointed-to
  // tensors are not copied.
  ProgramTensorPool(const TensorMap &liveTensors) : liveTensors(liveTensors) {}

  // liveTensors holds raw pointers into this object's own intermedTensors. A
  // member-wise copy would leave the copy pointing into the source pool's
  // storage, which dangles once the source is destroyed, so copying is
  // disabled. Moving is safe: moving a std::unordered_map transfers its nodes
  // and keeps pointers to the elements valid.
  ProgramTensorPool(const ProgramTensorPool &) = delete;
  ProgramTensorPool &operator=(const ProgramTensorPool &) = delete;
  ProgramTensorPool(ProgramTensorPool &&) = default;
  ProgramTensorPool &operator=(ProgramTensorPool &&) = default;

  // Registers `tensor` under `global_id` unless that id is already live.
  // Returns {iterator into liveTensors, inserted}; when the id already exists
  // the existing entry is left untouched and `inserted` is false.
  auto try_emplace(std::uint32_t global_id, const ::ttnn::Tensor &tensor) {
    auto liveIt = liveTensors.find(global_id);
    if (liveIt != liveTensors.end()) {
      return std::make_pair(liveIt, false);
    }
    assert(!intermedTensors.contains(global_id));
    // Reuse the iterator returned by the insertion instead of looking the key
    // up a second time.
    auto storedIt = intermedTensors.try_emplace(global_id, tensor).first;
    return liveTensors.try_emplace(global_id, &storedIt->second);
  }

  // Stores a copy of `tensor` under `global_id`, replacing any previous entry
  // (including a pointer that was supplied at construction).
  // Returns {iterator into liveTensors, inserted}.
  auto insert_or_assign(std::uint32_t global_id, const ::ttnn::Tensor &tensor) {
    auto storedIt = intermedTensors.insert_or_assign(global_id, tensor).first;
    return liveTensors.insert_or_assign(global_id, &storedIt->second);
  }

  // Returns the live tensor registered under `global_id`. The id must be
  // present.
  ::ttnn::Tensor &at(std::uint32_t global_id) {
    assert(liveTensors.contains(global_id));
    return *liveTensors.at(global_id);
  }

  // Removes a pool-owned tensor. The id must be present in both maps, i.e. it
  // must have been added through try_emplace / insert_or_assign.
  // Returns the number of entries removed from liveTensors.
  size_t erase(std::uint32_t global_id) {
    assert(liveTensors.contains(global_id) &&
           intermedTensors.contains(global_id));
    intermedTensors.erase(global_id);
    return liveTensors.erase(global_id);
  }

  // True when `global_id` refers to a live tensor.
  bool contains(std::uint32_t global_id) const {
    return liveTensors.contains(global_id);
  }

private:
  // A superset of intermedTensors, containing pointers to all tensors created
  // by the program and the input/output tensors passed in by the user
  TensorMap liveTensors;

  // A subset of liveTensors, containing values of any intermediate tensors
  // created by the program
  std::unordered_map<std::uint32_t, ::ttnn::Tensor> intermedTensors;
};

// Mutable state handed to every op's run() while a program executes.
struct ProgramContext {
  // Tensors addressable by global id.
  ProgramTensorPool tensorPool;
  // Every device made available to the program.
  DeviceMap allDevices;
  // Devices selected by the program so far; starts out empty.
  DeviceMap devicePool;

  // Builds the tensor pool from `liveTensors` and copies the device table.
  ProgramContext(const TensorMap &liveTensors, const DeviceMap &allDevices)
      : tensorPool(liveTensors), allDevices(allDevices) {}
};
} // namespace tt::runtime::ttnn

#endif
File renamed without changes.
38 changes: 38 additions & 0 deletions runtime/lib/ttnn/operations/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Source files compiled into the TTRuntimeTTNNOps library: the shared op
# utilities plus one translation unit per op (grouped by category directory).
set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/concat.cpp
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/reshape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/transpose.cpp
${CMAKE_CURRENT_SOURCE_DIR}/deletion/dealloc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/embedding/embedding.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/to_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/to_layout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/to_memory_config.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/normalization/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/pool/maxpool2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduction/reduction.cpp
${CMAKE_CURRENT_SOURCE_DIR}/context/get_device.cpp
)

# Static library holding all op implementations listed above.
add_library(TTRuntimeTTNNOps
STATIC
${TTNN_OPS_SRCS}
)

# Built as C++20, the same standard the TTRuntimeTTNN target is set to.
set_property(TARGET TTRuntimeTTNNOps PROPERTY CXX_STANDARD 20)
# PUBLIC: the AVX/AVX2 flags also propagate to targets that link this library.
target_compile_options(TTRuntimeTTNNOps PUBLIC -mavx -mavx2)
# Include roots for the public runtime headers, the TTNN runtime headers, the
# op-local headers and the generated Target/Common headers.
target_include_directories(TTRuntimeTTNNOps PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/include
${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/operations/include
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
# tt-metal headers are exposed through the build interface only.
target_include_directories(TTRuntimeTTNNOps PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>")
target_link_libraries(TTRuntimeTTNNOps PUBLIC TTNN_LIBRARY)
# Make sure TTNN_LIBRARY, tt-metal and FBS_GENERATION are built before this
# target.
add_dependencies(TTRuntimeTTNNOps TTNN_LIBRARY tt-metal FBS_GENERATION)
23 changes: 23 additions & 0 deletions runtime/lib/ttnn/operations/context/get_device.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "get_device.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::context {
// Executes a GetDeviceOp: for the single chip id named by the op, copies the
// matching device pointer from context.allDevices into context.devicePool.
// Asserts that exactly one chip id is given, that the chip is known, and that
// it has not been added to the pool before.
void run(const ::tt::target::ttnn::GetDeviceOp *op, ProgramContext &context) {
  DeviceMap &selected = context.devicePool;
  DeviceMap &available = context.allDevices;

  const flatbuffers::Vector<uint32_t> *requestedChips = op->chip_ids();
  assert(requestedChips->size() == 1 && "Expected 1 chip id");

  for (const uint32_t chipId : *requestedChips) {
    assert(available.contains(chipId) && "Device not found");
    // Only the insertion flag is needed; it is consumed by the assert below
    // and is otherwise unused when asserts are compiled out.
    [[maybe_unused]] const bool inserted =
        selected.try_emplace(chipId, available.at(chipId)).second;
    assert(inserted && "Duplicate device");
  }
}
} // namespace tt::runtime::ttnn::operations::context
16 changes: 16 additions & 0 deletions runtime/lib/ttnn/operations/context/get_device.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_GET_DEVICE_H
#define TTNN_RUNTIME_GET_DEVICE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::context {
// Executes a GetDeviceOp: adds the device for the op's chip id (taken from
// context.allDevices) to context.devicePool.
void run(const ::tt::target::ttnn::GetDeviceOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::context

#endif
35 changes: 35 additions & 0 deletions runtime/lib/ttnn/operations/conv/conv2d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "conv2d.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::conv {
// Executes a Conv2dOp: looks up the input, weight and optional bias tensors in
// the program's tensor pool, runs ttnn conv2d on the op's device, and stores
// the first element of the result tuple under the op's output global id.
void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context) {
  ProgramTensorPool &tensors = context.tensorPool;
  DeviceMap &devices = context.devicePool;

  const ::ttnn::Tensor &inputTensor = tensors.at(op->input()->global_id());
  const ::ttnn::Tensor &weightTensor = tensors.at(op->weight()->global_id());

  // Bias is optional in the flatbuffer; leave it empty when absent.
  std::optional<::ttnn::Tensor> biasTensor = std::nullopt;
  if (op->bias()) {
    biasTensor = std::make_optional(tensors.at(op->bias()->global_id()));
  }

  // The data types in the conv config follow the input and weight tensors.
  auto convConfig = ::ttnn::operations::conv::conv2d::Conv2dConfig();
  convConfig.dtype = inputTensor.dtype();
  convConfig.weights_dtype = weightTensor.dtype();

  ::ttnn::Device &targetDevice = utils::getDevice(op->device(), devices);

  ::ttnn::Tensor result =
      std::get<0>(::ttnn::operations::conv::conv2d::conv2d<::ttnn::Device>(
          inputTensor, weightTensor, &targetDevice, op->in_channels(),
          op->out_channels(), op->batch_size(), op->input_height(),
          op->input_width(), {op->kernel_height(), op->kernel_width()},
          {op->stride_height(), op->stride_width()},
          {op->padding_height(), op->padding_width()},
          {op->dilation_height(), op->dilation_width()}, op->groups(),
          biasTensor, convConfig));

  tensors.insert_or_assign(op->out()->global_id(), result);
}
} // namespace tt::runtime::ttnn::operations::conv
16 changes: 16 additions & 0 deletions runtime/lib/ttnn/operations/conv/conv2d.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_CONV2D_H
#define TTNN_RUNTIME_CONV2D_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::conv {
// Executes a Conv2dOp: runs conv2d on the op's input, weight and optional bias
// tensors from context.tensorPool and stores the result under the op's output
// global id.
void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::conv

#endif
28 changes: 28 additions & 0 deletions runtime/lib/ttnn/operations/creation/empty.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "empty.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "tt/runtime/ttnn/utils.h"

namespace tt::runtime::ttnn::operations::creation {
// Executes an EmptyOp: creates a tensor with ::ttnn::empty on the op's device
// using the shape and data type of the op's output, and registers it under
// the output's global id if that id is not already live.
void run(const ::tt::target::ttnn::EmptyOp *op, ProgramContext &context) {
  ProgramTensorPool &tensors = context.tensorPool;
  DeviceMap &devices = context.devicePool;
  const auto *outputRef = op->out();

  ::ttnn::DataType outputDataType = utils::getDataType(outputRef);
  // TODO(bug #582): ttnn::empty doesn't work properly with tile layout,
  // using ROW_MAJOR until we fix it
  auto outputLayout = ::ttnn::Layout::ROW_MAJOR;
  auto outputShape = ::ttnn::Shape(
      ::tt::tt_metal::Shape(::tt::runtime::ttnn::utils::toShapeFromFBShape(
          *outputRef->desc()->shape())));

  ::ttnn::Device &targetDevice = utils::getDevice(op->device(), devices);
  ::ttnn::Tensor created = ::ttnn::empty(outputShape, outputDataType,
                                         outputLayout, targetDevice);

  // try_emplace (not insert_or_assign) so that a program output tensor that
  // is already registered under this id doesn't get overwritten
  tensors.try_emplace(outputRef->global_id(), created);
}
} // namespace tt::runtime::ttnn::operations::creation
17 changes: 17 additions & 0 deletions runtime/lib/ttnn/operations/creation/empty.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_EMPTY_H
#define TTNN_RUNTIME_EMPTY_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::creation {

// Executes an EmptyOp: creates a ROW_MAJOR tensor with ::ttnn::empty and adds
// it to context.tensorPool under the op's output global id, keeping any tensor
// already registered under that id.
void run(const ::tt::target::ttnn::EmptyOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::creation

#endif
33 changes: 33 additions & 0 deletions runtime/lib/ttnn/operations/creation/full.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "full.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "tt/runtime/ttnn/utils.h"

namespace tt::runtime::ttnn::operations::creation {
// Executes a FullOp: creates a tensor filled with the op's fill value via
// ::ttnn::full, using the shape, data type and memory config derived from the
// op's output, and stores it under the output's global id (replacing any
// existing entry).
void run(const ::tt::target::ttnn::FullOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.tensorPool;
  // Bind by reference: copying the whole device map on every call is
  // unnecessary, and the other ops access the pool by reference as well.
  DeviceMap &devicePool = context.devicePool;
  ::ttnn::Device &device = utils::getDevice(op->device(), devicePool);
  ::ttnn::DataType outputDataType = utils::getDataType(op->out());
  auto shape = ::ttnn::Shape(
      ::tt::tt_metal::Shape(::tt::runtime::ttnn::utils::toShapeFromFBShape(
          *op->out()->desc()->shape())));
  float fillValue = op->fill_value();
  // TODO(bug #272), determine correct layout by tile shape in the future
  ::ttnn::Layout outputLayout = ::ttnn::Layout::ROW_MAJOR;
  // ::ttnn::full is handed the device and memory config wrapped in optionals.
  std::optional<std::reference_wrapper<::ttnn::Device>> outputDevice =
      std::make_optional(std::ref(device));
  std::optional<::tt::tt_metal::MemoryConfig> outputMemoryConfig =
      std::make_optional(utils::createMemoryConfig(op->out()));

  ::ttnn::Tensor out =
      ::ttnn::full(shape, fillValue, outputDataType, outputLayout, outputDevice,
                   outputMemoryConfig);

  tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::creation
17 changes: 17 additions & 0 deletions runtime/lib/ttnn/operations/creation/full.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_FULL_H
#define TTNN_RUNTIME_FULL_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::creation {

// Executes a FullOp: creates a tensor filled with the op's fill value via
// ::ttnn::full and stores it in context.tensorPool under the op's output
// global id, replacing any existing entry.
void run(const ::tt::target::ttnn::FullOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::creation

#endif
18 changes: 18 additions & 0 deletions runtime/lib/ttnn/operations/data_movement/concat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "concat.h"
#include "tt/runtime/detail/ttnn.h"

namespace tt::runtime::ttnn::operations::data_movement {
// Executes a ConcatOp: gathers the op's input tensors from the program's
// tensor pool, concatenates them along op->dim() with ::ttnn::concat, and
// stores the result under the op's output global id (replacing any existing
// entry).
void run(const ::tt::target::ttnn::ConcatOp *op, ProgramContext &context) {
  std::vector<::ttnn::Tensor> inputs;
  // The number of inputs is known up front; reserve once so the vector does
  // not reallocate while tensors are appended.
  inputs.reserve(op->inputs()->size());
  for (const auto &input : *op->inputs()) {
    inputs.push_back(context.tensorPool.at(input->global_id()));
  }
  int32_t dim = op->dim();
  ::ttnn::Tensor out = ::ttnn::concat(inputs, dim);
  context.tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::data_movement
15 changes: 15 additions & 0 deletions runtime/lib/ttnn/operations/data_movement/concat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_CONCAT_H
#define TTNN_RUNTIME_CONCAT_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::data_movement {
// Executes a ConcatOp: concatenates the op's input tensors along op->dim() via
// ::ttnn::concat and stores the result in context.tensorPool under the op's
// output global id.
void run(const ::tt::target::ttnn::ConcatOp *op, ProgramContext &context);
} // namespace tt::runtime::ttnn::operations::data_movement

#endif
Loading

0 comments on commit 880154e

Please sign in to comment.