-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Move Ops to their own dedicated directories/files/namespaces. * Add program context that manages the state of the program (borrowed tensors/devices, intermed tensors generated by ops) * Add program executor class that executes the program
- Loading branch information
Showing
46 changed files
with
1,501 additions
and
960 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,15 @@ | ||
add_subdirectory(operations) | ||
add_library(TTRuntimeTTNN | ||
STATIC | ||
runtime.cpp | ||
program.cpp | ||
) | ||
# We have to set the C++ standard to 20 because tt-metal requires it | ||
set_property(TARGET TTRuntimeTTNN PROPERTY CXX_STANDARD 20) | ||
target_compile_options(TTRuntimeTTNN PRIVATE -mavx -mavx2) | ||
target_include_directories(TTRuntimeTTNN PUBLIC | ||
${PROJECT_SOURCE_DIR}/runtime/include | ||
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common | ||
) | ||
target_include_directories(TTRuntimeTTNN PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>") | ||
target_link_libraries(TTRuntimeTTNN PUBLIC TTNN_LIBRARY) | ||
add_dependencies(TTRuntimeTTNN TTNN_LIBRARY tt-metal FBS_GENERATION) | ||
target_link_libraries(TTRuntimeTTNN PUBLIC TTRuntimeTTNNOps) | ||
add_dependencies(TTRuntimeTTNN TTRuntimeTTNNOps) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTNN_RUNTIME_TYPES_H | ||
#define TTNN_RUNTIME_TYPES_H | ||
|
||
#include "tt/runtime/detail/ttnn.h" | ||
|
||
namespace tt::runtime::ttnn { | ||
|
||
using DeviceMap = std::unordered_map<uint32_t, ::ttnn::Device *>; | ||
using TensorMap = std::unordered_map<uint32_t, ::ttnn::Tensor *>; | ||
struct ProgramTensorPool { | ||
ProgramTensorPool(const TensorMap &liveTensors) : liveTensors(liveTensors) {} | ||
|
||
auto try_emplace(std::uint32_t global_id, const ::ttnn::Tensor &tensor) { | ||
auto it = liveTensors.find(global_id); | ||
if (it != liveTensors.end()) { | ||
return std::make_pair(it, false); | ||
} | ||
assert(!intermedTensors.contains(global_id)); | ||
intermedTensors.try_emplace(global_id, tensor); | ||
return liveTensors.try_emplace(global_id, &intermedTensors.at(global_id)); | ||
} | ||
|
||
auto insert_or_assign(std::uint32_t global_id, const ::ttnn::Tensor &tensor) { | ||
intermedTensors.insert_or_assign(global_id, tensor); | ||
return liveTensors.insert_or_assign(global_id, | ||
&intermedTensors.at(global_id)); | ||
} | ||
|
||
::ttnn::Tensor &at(std::uint32_t global_id) { | ||
assert(liveTensors.contains(global_id)); | ||
return *liveTensors.at(global_id); | ||
} | ||
|
||
size_t erase(std::uint32_t global_id) { | ||
assert(liveTensors.contains(global_id) && | ||
intermedTensors.contains(global_id)); | ||
intermedTensors.erase(global_id); | ||
return liveTensors.erase(global_id); | ||
} | ||
|
||
bool contains(std::uint32_t global_id) const { | ||
return liveTensors.contains(global_id); | ||
} | ||
|
||
private: | ||
// A superset of intermedTensors, containing pointers to all tensors created | ||
// by the program and the input/output tensors passed in by the user | ||
TensorMap liveTensors; | ||
|
||
// A subset of liveTensors, containing values of any intermediate tensors | ||
// created by the program | ||
std::unordered_map<std::uint32_t, ::ttnn::Tensor> intermedTensors; | ||
}; | ||
|
||
struct ProgramContext { | ||
ProgramTensorPool tensorPool; | ||
DeviceMap allDevices; | ||
DeviceMap devicePool; | ||
|
||
ProgramContext(const TensorMap &liveTensors, const DeviceMap &allDevices) | ||
: tensorPool(ProgramTensorPool(liveTensors)), allDevices(allDevices) {} | ||
}; | ||
} // namespace tt::runtime::ttnn | ||
|
||
#endif |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
set(TTNN_OPS_SRCS | ||
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/concat.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/reshape.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/transpose.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/deletion/dealloc.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/embedding/embedding.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/layout/to_device.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/layout/to_layout.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/layout/to_memory_config.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/matmul/matmul.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/normalization/softmax.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/pool/maxpool2d.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/reduction/reduction.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/context/get_device.cpp | ||
) | ||
|
||
add_library(TTRuntimeTTNNOps | ||
STATIC | ||
${TTNN_OPS_SRCS} | ||
) | ||
|
||
set_property(TARGET TTRuntimeTTNNOps PROPERTY CXX_STANDARD 20) | ||
target_compile_options(TTRuntimeTTNNOps PUBLIC -mavx -mavx2) | ||
target_include_directories(TTRuntimeTTNNOps PUBLIC | ||
${PROJECT_SOURCE_DIR}/runtime/include | ||
${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/include | ||
${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/operations/include | ||
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common | ||
) | ||
target_include_directories(TTRuntimeTTNNOps PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>") | ||
target_link_libraries(TTRuntimeTTNNOps PUBLIC TTNN_LIBRARY) | ||
add_dependencies(TTRuntimeTTNNOps TTNN_LIBRARY tt-metal FBS_GENERATION) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "get_device.h" | ||
#include "tt/runtime/detail/ttnn.h" | ||
#include "tt/runtime/ttnn/operations/utils.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::context { | ||
void run(const ::tt::target::ttnn::GetDeviceOp *op, ProgramContext &context) { | ||
DeviceMap &devicePool = context.devicePool; | ||
DeviceMap &allDevices = context.allDevices; | ||
const flatbuffers::Vector<uint32_t> *chipIds = op->chip_ids(); | ||
assert(chipIds->size() == 1 && "Expected 1 chip id"); | ||
for (const uint32_t chipId : *chipIds) { | ||
assert(allDevices.contains(chipId) && "Device not found"); | ||
auto [iter, inserted] = | ||
devicePool.try_emplace(chipId, allDevices.at(chipId)); | ||
assert(inserted && "Duplicate device"); | ||
} | ||
} | ||
} // namespace tt::runtime::ttnn::operations::context |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTNN_RUNTIME_GET_DEVICE_H | ||
#define TTNN_RUNTIME_GET_DEVICE_H | ||
|
||
#include "tt/runtime/ttnn/types.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::context { | ||
void run(const ::tt::target::ttnn::GetDeviceOp *op, ProgramContext &context); | ||
|
||
} // namespace tt::runtime::ttnn::operations::context | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "conv2d.h" | ||
#include "tt/runtime/detail/ttnn.h" | ||
#include "tt/runtime/ttnn/operations/utils.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::conv { | ||
void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context) { | ||
ProgramTensorPool &tensorPool = context.tensorPool; | ||
DeviceMap &devicePool = context.devicePool; | ||
const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id()); | ||
const ::ttnn::Tensor &weight = tensorPool.at(op->weight()->global_id()); | ||
std::optional<::ttnn::Tensor> bias = | ||
op->bias() ? std::make_optional(tensorPool.at(op->bias()->global_id())) | ||
: std::nullopt; | ||
auto config = ::ttnn::operations::conv::conv2d::Conv2dConfig(); | ||
config.dtype = input.dtype(); | ||
config.weights_dtype = weight.dtype(); | ||
::ttnn::Device &device = utils::getDevice(op->device(), devicePool); | ||
::ttnn::Tensor out = | ||
std::get<0>(::ttnn::operations::conv::conv2d::conv2d<::ttnn::Device>( | ||
input, weight, &device, op->in_channels(), op->out_channels(), | ||
op->batch_size(), op->input_height(), op->input_width(), | ||
{op->kernel_height(), op->kernel_width()}, | ||
{op->stride_height(), op->stride_width()}, | ||
{op->padding_height(), op->padding_width()}, | ||
{op->dilation_height(), op->dilation_width()}, op->groups(), bias, | ||
config)); | ||
|
||
tensorPool.insert_or_assign(op->out()->global_id(), out); | ||
} | ||
} // namespace tt::runtime::ttnn::operations::conv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTNN_RUNTIME_CONV2D_H | ||
#define TTNN_RUNTIME_CONV2D_H | ||
|
||
#include "tt/runtime/ttnn/types.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::conv { | ||
void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context); | ||
|
||
} // namespace tt::runtime::ttnn::operations::conv | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "empty.h" | ||
#include "tt/runtime/detail/ttnn.h" | ||
#include "tt/runtime/ttnn/operations/utils.h" | ||
#include "tt/runtime/ttnn/utils.h" | ||
|
||
namespace tt::runtime::ttnn::operations::creation { | ||
void run(const ::tt::target::ttnn::EmptyOp *op, ProgramContext &context) { | ||
ProgramTensorPool &tensorPool = context.tensorPool; | ||
DeviceMap &devicePool = context.devicePool; | ||
::ttnn::DataType targetDataTypeTTNN = utils::getDataType(op->out()); | ||
// TODO(bug #582): ttnn::empty doesn't work properly with tile layout, | ||
// using ROW_MAJOR until we fix it | ||
auto desiredLayout = ::ttnn::Layout::ROW_MAJOR; | ||
auto shape = ::ttnn::Shape( | ||
::tt::tt_metal::Shape(::tt::runtime::ttnn::utils::toShapeFromFBShape( | ||
*op->out()->desc()->shape()))); | ||
|
||
::ttnn::Device &device = utils::getDevice(op->device(), devicePool); | ||
::ttnn::Tensor out = | ||
::ttnn::empty(shape, targetDataTypeTTNN, desiredLayout, device); | ||
// use try emplace here so the program output tensor doesn't get overwritten | ||
tensorPool.try_emplace(op->out()->global_id(), out); | ||
} | ||
} // namespace tt::runtime::ttnn::operations::creation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTNN_RUNTIME_EMPTY_H | ||
#define TTNN_RUNTIME_EMPTY_H | ||
|
||
#include "tt/runtime/ttnn/types.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::creation { | ||
|
||
void run(const ::tt::target::ttnn::EmptyOp *op, ProgramContext &context); | ||
|
||
} // namespace tt::runtime::ttnn::operations::creation | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "full.h" | ||
#include "tt/runtime/detail/ttnn.h" | ||
#include "tt/runtime/ttnn/operations/utils.h" | ||
#include "tt/runtime/ttnn/utils.h" | ||
|
||
namespace tt::runtime::ttnn::operations::creation { | ||
void run(const ::tt::target::ttnn::FullOp *op, ProgramContext &context) { | ||
ProgramTensorPool &tensorPool = context.tensorPool; | ||
DeviceMap devicePool = context.devicePool; | ||
::ttnn::Device &device = utils::getDevice(op->device(), devicePool); | ||
::ttnn::DataType outputDataType = utils::getDataType(op->out()); | ||
auto shape = ::ttnn::Shape( | ||
::tt::tt_metal::Shape(::tt::runtime::ttnn::utils::toShapeFromFBShape( | ||
*op->out()->desc()->shape()))); | ||
float fillValue = op->fill_value(); | ||
// TODO(bug #272), determine correct layout by tile shape in the future | ||
::ttnn::Layout outputLayout = ::ttnn::Layout::ROW_MAJOR; | ||
std::optional<std::reference_wrapper<::ttnn::Device>> outputDevice = | ||
std::make_optional(std::ref(device)); | ||
std::optional<::tt::tt_metal::MemoryConfig> outputMemoryConfig = | ||
std::make_optional(utils::createMemoryConfig(op->out())); | ||
|
||
::ttnn::Tensor out = | ||
::ttnn::full(shape, fillValue, outputDataType, outputLayout, outputDevice, | ||
outputMemoryConfig); | ||
|
||
tensorPool.insert_or_assign(op->out()->global_id(), out); | ||
} | ||
} // namespace tt::runtime::ttnn::operations::creation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTNN_RUNTIME_FULL_H | ||
#define TTNN_RUNTIME_FULL_H | ||
|
||
#include "tt/runtime/ttnn/types.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::creation { | ||
|
||
void run(const ::tt::target::ttnn::FullOp *op, ProgramContext &context); | ||
|
||
} // namespace tt::runtime::ttnn::operations::creation | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "concat.h" | ||
#include "tt/runtime/detail/ttnn.h" | ||
|
||
namespace tt::runtime::ttnn::operations::data_movement { | ||
void run(const ::tt::target::ttnn::ConcatOp *op, ProgramContext &context) { | ||
std::vector<::ttnn::Tensor> inputs; | ||
for (const auto &input : *op->inputs()) { | ||
inputs.push_back(context.tensorPool.at(input->global_id())); | ||
} | ||
int32_t dim = op->dim(); | ||
::ttnn::Tensor out = ::ttnn::concat(inputs, dim); | ||
context.tensorPool.insert_or_assign(op->out()->global_id(), out); | ||
} | ||
} // namespace tt::runtime::ttnn::operations::data_movement |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTNN_RUNTIME_CONCAT_H | ||
#define TTNN_RUNTIME_CONCAT_H | ||
|
||
#include "tt/runtime/ttnn/types.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::data_movement { | ||
void run(const ::tt::target::ttnn::ConcatOp *op, ProgramContext &context); | ||
} // namespace tt::runtime::ttnn::operations::data_movement | ||
|
||
#endif |
Oops, something went wrong.