#695: Runtime multi-device bootstrap - Replace all uses of Device with MeshDevice (#726)
jnie-TT authored Sep 24, 2024
1 parent 4e8cf3c commit ae94cb1
Showing 33 changed files with 251 additions and 133 deletions.
5 changes: 3 additions & 2 deletions runtime/include/tt/runtime/detail/ttmetal.h
@@ -63,8 +63,9 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {

tt::target::DataType getTensorDataType(Tensor tensor);

Device openDevice(std::vector<int> const &deviceIds = {0},
std::vector<uint8_t> const &numHWCQs = {});
size_t getNumAvailableDevices();

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1);

void closeDevice(Device device);

11 changes: 7 additions & 4 deletions runtime/include/tt/runtime/detail/ttnn.h
@@ -42,6 +42,9 @@
#pragma clang diagnostic ignored "-Wc99-extensions"

#define FMT_HEADER_ONLY
#include "host_api.hpp"
#include "hostdevcommon/common_values.hpp"
#include "impl/device/mesh_device.hpp"
#include "ttnn/device.hpp"
#include "ttnn/operations/conv/conv2d/conv2d.hpp"
#include "ttnn/operations/copy.hpp"
@@ -59,7 +62,6 @@
#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/types.hpp"

#pragma clang diagnostic pop

#include "tt/runtime/types.h"
@@ -84,8 +86,9 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {

tt::target::DataType getTensorDataType(Tensor tensor);

Device openDevice(std::vector<int> const &deviceIds = {0},
std::vector<std::uint8_t> const &numHWCQs = {});
size_t getNumAvailableDevices();

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1);

void closeDevice(Device device);

@@ -97,7 +100,7 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex,

void wait(Event event);

void runProgram(::ttnn::Device &device,
void runProgram(::ttnn::MeshDevice &meshDevice,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs);
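runProgram now receives the whole mesh rather than a single ::ttnn::Device. A sketch of the expected call site, assuming the TTNN Device handle wraps a ::ttnn::MeshDevice and is recovered with Device::as<>, mirroring the TTMetal backend later in this diff (the TTNN submit body is not shown in this excerpt, and the variable names are placeholders):

::ttnn::MeshDevice &meshDevice =
    deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN);  // assumed unwrapping
runProgram(meshDevice, program, inputTensors, outputTensors);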
5 changes: 3 additions & 2 deletions runtime/include/tt/runtime/runtime.h
@@ -43,8 +43,9 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {

tt::target::DataType getTensorDataType(Tensor tensor);

Device openDevice(std::vector<int> const &deviceIds = {0},
std::vector<std::uint8_t> const &numHWCQs = {});
size_t getNumAvailableDevices();

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1);

void closeDevice(Device device);

4 changes: 0 additions & 4 deletions runtime/include/tt/runtime/types.h
@@ -104,10 +104,6 @@ struct Binary : public Flatbuffer {

struct Device : public detail::RuntimeCheckedObjectImpl {
using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;

template <typename T> static Device borrow(T &object, DeviceRuntime runtime) {
return Device(utils::unsafe_borrow_shared(&object), runtime);
}
};

struct Event : public detail::RuntimeCheckedObjectImpl {
18 changes: 16 additions & 2 deletions runtime/lib/runtime.cpp
@@ -138,8 +138,22 @@ tt::target::DataType getTensorDataType(Tensor tensor) {
throw std::runtime_error("runtime is not enabled");
}

Device openDevice(std::vector<int> const &deviceIds,
std::vector<std::uint8_t> const &numHWCQs) {
size_t getNumAvailableDevices() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getNumAvailableDevices();
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getNumAvailableDevices();
}
#endif
throw std::runtime_error("runtime is not enabled");
}

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs);
52 changes: 30 additions & 22 deletions runtime/lib/ttmetal/runtime.cpp
@@ -15,7 +15,7 @@ namespace tt::runtime::ttmetal {
using ::tt::runtime::DeviceRuntime;
constexpr inline std::size_t kHostBufferCommandQueueId = 0;
using Events = std::vector<std::shared_ptr<::tt::tt_metal::Event>>;
using DeviceMesh = std::vector<::tt::tt_metal::Device *>;
using DeviceList = std::vector<::tt::tt_metal::Device *>;
using MetalTensor =
std::variant<TensorDesc, std::shared_ptr<::tt::tt_metal::Buffer>>;

@@ -58,29 +58,33 @@ tt::target::DataType getTensorDataType(Tensor tensor) {
return ::tt::target::DataType::Float32;
}

Device openDevice(std::vector<int> const &deviceIds,
std::vector<std::uint8_t> const &numHWCQs) {
assert(numHWCQs.empty() || numHWCQs.size() == deviceIds.size());
std::shared_ptr<DeviceMesh> deviceMesh = std::make_shared<DeviceMesh>();
int i = 0;
for (int deviceId : deviceIds) {
uint8_t num_hw_cqs = numHWCQs.empty() ? 1 : numHWCQs[i];
deviceMesh->push_back(CreateDevice(deviceId, num_hw_cqs));
++i;
}
return Device(static_pointer_cast<void>(deviceMesh), DeviceRuntime::TTMetal);
size_t getNumAvailableDevices() {
return ::tt::tt_metal::GetNumAvailableDevices();
}

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) {
assert(deviceIds.size() && "No devices specified");
::tt::tt_metal::MeshShape grid = std::make_pair(1, deviceIds.size());
std::shared_ptr<::tt::tt_metal::MeshDevice> meshDevice =
::tt::tt_metal::MeshDevice::create(
grid, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, numHWCQs,
::tt::tt_metal::DispatchCoreType::WORKER);

return Device(std::static_pointer_cast<void>(meshDevice),
DeviceRuntime::TTMetal);
}

void closeDevice(Device device) {
DeviceMesh &deviceMesh = device.as<DeviceMesh>(DeviceRuntime::TTMetal);
for (::tt::tt_metal::Device *device : deviceMesh) {
::tt::tt_metal::CloseDevice(device);
}
::tt::tt_metal::MeshDevice &ttmetalMeshDevice =
device.as<::tt::tt_metal::MeshDevice>(DeviceRuntime::TTMetal);
ttmetalMeshDevice.close_devices();
}

void deallocateBuffers(Device deviceHandle) {
DeviceMesh &deviceMesh = deviceHandle.as<DeviceMesh>(DeviceRuntime::TTMetal);
for (::tt::tt_metal::Device *device : deviceMesh) {
::tt::tt_metal::MeshDevice &meshDevice =
deviceHandle.as<::tt::tt_metal::MeshDevice>(DeviceRuntime::TTMetal);

for (::tt::tt_metal::Device *device : meshDevice.get_devices()) {
device->deallocate_buffers();
}
}
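Because the grid is built as std::make_pair(1, deviceIds.size()), every opened chip lands in row 0 of the mesh, so N chips occupy coordinates (0, 0) through (0, N - 1). A rough sketch of walking that layout, assuming the num_cols/get_device accessors used on the TTNN side of this diff are available on this MeshDevice as well:

::tt::tt_metal::MeshDevice &mesh =
    device.as<::tt::tt_metal::MeshDevice>(DeviceRuntime::TTMetal);
for (int col = 0; col < mesh.num_cols(); ++col) {
  auto *chip = mesh.get_device(/*row=*/0, col);  // physical chip behind column col
  (void)chip;                                    // e.g. chip->id()
}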
@@ -166,13 +170,17 @@ Event submit(Device deviceHandle, Binary executableHandle,
::tt::target::metal::TTMetalBinary const &fbb = *getBinary(executableHandle);
::tt::target::metal::Program const *program =
fbb.programs()->Get(programIndex);
DeviceMesh &deviceMesh = deviceHandle.as<DeviceMesh>(DeviceRuntime::TTMetal);
assert(deviceMesh.size() == 1 && "Only one device is supported for now");
::tt::tt_metal::MeshDevice &meshDevice =
deviceHandle.as<::tt::tt_metal::MeshDevice>(DeviceRuntime::TTMetal);
DeviceList allDevices = meshDevice.get_devices();
assert(allDevices.size() > 0 && "Unexpected empty device mesh");
DeviceList deviceList = {allDevices[0]};
assert(deviceList.size() == 1 && "Only one device is supported for now");
std::shared_ptr<Events> events = std::make_shared<Events>();
assert(program->device_programs()->size() == deviceMesh.size() &&
assert(program->device_programs()->size() == deviceList.size() &&
"Device programs size mismatch");
for (std::size_t i = 0; i < program->device_programs()->size(); ++i) {
::tt::tt_metal::Device *device = deviceMesh[i];
::tt::tt_metal::Device *device = deviceList[i];

ZoneScoped;
std::string zoneName = "submit_" + std::string(program->name()->c_str()) +
48 changes: 42 additions & 6 deletions runtime/lib/ttnn/include/tt/runtime/ttnn/types.h
@@ -9,7 +9,6 @@

namespace tt::runtime::ttnn {

using DeviceMap = std::unordered_map<uint32_t, ::ttnn::Device *>;
using TensorMap = std::unordered_map<uint32_t, ::ttnn::Tensor *>;
struct ProgramTensorPool {
ProgramTensorPool(const TensorMap &liveTensors) : liveTensors(liveTensors) {}
@@ -56,13 +55,50 @@ struct ProgramTensorPool {
std::unordered_map<std::uint32_t, ::ttnn::Tensor> intermedTensors;
};

struct ProgramContext {
class ProgramContext {
public:
ProgramContext(const TensorMap &liveTensors, ::ttnn::MeshDevice *meshDevice)
: tensorPool(ProgramTensorPool(liveTensors)), meshDevice(meshDevice) {}

const ::ttnn::MeshDevice &getMeshDevice() const {
assert(meshDevice && "Mesh device not initialized");
return *meshDevice;
}

::ttnn::MeshDeviceView &getMeshView(uint32_t globalId) {
assert(meshViews.contains(globalId) &&
"Mesh view with global id not initialized");
return *(meshViews.at(globalId));
}

ProgramTensorPool &getTensorPool() { return tensorPool; }

void addMeshView(uint32_t globalId,
std::unique_ptr<::ttnn::MeshDeviceView> view) {
assert(not meshViews.contains(globalId) &&
"Mesh view with globalId already set");
meshViews.try_emplace(globalId, std::move(view));
}

::ttnn::Device &getDeviceFromView(uint32_t globalId, int deviceId) {
assert(meshViews.contains(globalId) && "Mesh view not initialized");
::tt::tt_metal::Coordinate deviceCoord =
meshViews.at(globalId)->find_device(deviceId);
return *(
meshViews.at(globalId)->get_device(deviceCoord.row, deviceCoord.col));
}

private:
ProgramTensorPool tensorPool;
DeviceMap allDevices;
DeviceMap devicePool;

ProgramContext(const TensorMap &liveTensors, const DeviceMap &allDevices)
: tensorPool(ProgramTensorPool(liveTensors)), allDevices(allDevices) {}
// Contains all devices borrowed from the user that are available to the
// program
::ttnn::MeshDevice *meshDevice = nullptr;

// Contains various views of meshDevice that is used by the program
// Will be populated by get_device ops
std::unordered_map<uint32_t, std::unique_ptr<::ttnn::MeshDeviceView>>
meshViews;
};
} // namespace tt::runtime::ttnn
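ProgramContext is now the broker between ops and hardware: runProgram constructs it with the live tensor map and the user's mesh, GetDeviceOp registers a MeshDeviceView under its output's global id, and later ops resolve a concrete chip from that view. A condensed sketch of that flow using only the accessors declared above (liveTensors, meshDevice, meshView, and viewGlobalId are hypothetical placeholders):

ProgramContext context(liveTensors, &meshDevice);
context.addMeshView(viewGlobalId, std::move(meshView));       // done by GetDeviceOp
::ttnn::Device &device =
    context.getDeviceFromView(viewGlobalId, /*deviceId=*/0);  // done by e.g. conv2d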

56 changes: 46 additions & 10 deletions runtime/lib/ttnn/operations/context/get_device.cpp
@@ -8,16 +8,52 @@
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::context {
void run(const ::tt::target::ttnn::GetDeviceOp *op, ProgramContext &context) {
DeviceMap &devicePool = context.devicePool;
DeviceMap &allDevices = context.allDevices;
const flatbuffers::Vector<uint32_t> *chipIds = op->chip_ids();
assert(chipIds->size() == 1 && "Expected 1 chip id");
for (const uint32_t chipId : *chipIds) {
assert(allDevices.contains(chipId) && "Device not found");
auto [iter, inserted] =
devicePool.try_emplace(chipId, allDevices.at(chipId));
assert(inserted && "Duplicate device");

static std::pair<::tt::tt_metal::Coordinate, ::tt::tt_metal::Coordinate>
deriveMeshViewCoordinates(const ::ttnn::MeshDevice &meshDevice,
const std::unordered_set<uint32_t> &desiredDeviceIds,
const ::tt::target::Dim2d *meshViewShape) {
::tt::tt_metal::Coordinate topLeft, bottomRight;
for (int row = 0; row < meshDevice.num_rows(); row++) {
for (int col = 0; col < meshDevice.num_cols(); col++) {
const ::ttnn::Device *currDevice = meshDevice.get_device(row, col);
if (desiredDeviceIds.contains(currDevice->id())) {
topLeft.row = row;
topLeft.col = col;
// coords are inclusive when constructing mesh view
bottomRight.row = topLeft.row + meshViewShape->y() - 1;
bottomRight.col = topLeft.col + meshViewShape->x() - 1;
return std::make_pair(topLeft, bottomRight);
}
}
}
throw std::runtime_error("Device not found in mesh for get device op");
}

static std::unique_ptr<::ttnn::MeshDeviceView>
constructMeshView(const ::ttnn::MeshDevice &meshDevice,
const std::unordered_set<uint32_t> &desiredDeviceIds,
const ::tt::target::Dim2d *meshViewShape) {
// Carve out a mesh view from MeshDevice
auto [topLeft, bottomRight] =
deriveMeshViewCoordinates(meshDevice, desiredDeviceIds, meshViewShape);

return std::make_unique<::ttnn::MeshDeviceView>(meshDevice, topLeft,
bottomRight);
}

void run(const ::tt::target::ttnn::GetDeviceOp *op, ProgramContext &context) {
const ::ttnn::MeshDevice &meshDevice = context.getMeshDevice();
const ::tt::target::Dim2d *meshViewShape = op->mesh();
assert(meshViewShape->y() == 1 &&
"Expected 1xN mesh shape for get device op");
const ::flatbuffers::Vector<uint32_t> *deviceIds = op->chip_ids();
std::unordered_set<uint32_t> desiredDeviceIds(deviceIds->begin(),
deviceIds->end());
assert(desiredDeviceIds.size() == deviceIds->size() &&
"Duplicate device ids in get device op");
std::unique_ptr<::ttnn::MeshDeviceView> meshView =
constructMeshView(meshDevice, desiredDeviceIds, meshViewShape);
context.addMeshView(op->out()->global_id(), std::move(meshView));
}
} // namespace tt::runtime::ttnn::operations::context
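Since the op asserts a 1xN view shape, deriveMeshViewCoordinates amounts to locating the first requested chip in row-major order and extending the view meshViewShape->x() columns to the right. A worked example with assumed chip ids (illustrative numbers only):

// Assume a 1x4 parent mesh whose columns hold chips {0, 1, 2, 3}, and a
// GetDeviceOp with chip_ids = {2, 3} and a 1x2 view shape (y = 1, x = 2).
// The row-major scan first hits chip 2 at (row 0, col 2), so:
//   topLeft     = (0, 2)
//   bottomRight = (0 + 1 - 1, 2 + 2 - 1) = (0, 3)   // coordinates are inclusive
// The resulting MeshDeviceView covers exactly chips 2 and 3.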
8 changes: 5 additions & 3 deletions runtime/lib/ttnn/operations/conv/conv2d.cpp
@@ -9,8 +9,11 @@

namespace tt::runtime::ttnn::operations::conv {
void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.tensorPool;
DeviceMap &devicePool = context.devicePool;
ProgramTensorPool &tensorPool = context.getTensorPool();
// TODO (jnie): Update this once we support multi device tensors
// Investigate how to handle multi device in conv2d
::ttnn::Device &device =
context.getDeviceFromView(op->device()->global_id(), 0);
const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id());
const ::ttnn::Tensor &weight = tensorPool.at(op->weight()->global_id());
std::optional<::ttnn::Tensor> bias =
@@ -19,7 +22,6 @@ void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context) {
auto config = ::ttnn::operations::conv::conv2d::Conv2dConfig();
config.dtype = utils::getDataType(op->input());
config.weights_dtype = utils::getDataType(op->weight());
::ttnn::Device &device = utils::getDevice(op->device(), devicePool);
::ttnn::Tensor out =
std::get<0>(::ttnn::operations::conv::conv2d::conv2d<::ttnn::Device>(
input, weight, &device, op->in_channels(), op->out_channels(),
16 changes: 8 additions & 8 deletions runtime/lib/ttnn/operations/creation/empty.cpp
@@ -9,10 +9,11 @@

namespace tt::runtime::ttnn::operations::creation {
void run(const ::tt::target::ttnn::EmptyOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.tensorPool;
DeviceMap &devicePool = context.devicePool;
::ttnn::DataType dtype =
::tt::runtime::ttnn::utils::toTTNNDataType(op->dtype());
ProgramTensorPool &tensorPool = context.getTensorPool();
// TODO (jnie): Update this once we support multi device tensors
::ttnn::Device &device =
context.getDeviceFromView(op->device()->global_id(), 0);
::ttnn::DataType dtype = utils::getDataType(op->out());
// TODO(bug #582): ttnn::empty doesn't work properly with tile layout,
// using ROW_MAJOR until we fix it
::ttnn::Layout layout __attribute__((unused)) =
@@ -22,13 +23,12 @@ void run(const ::tt::target::ttnn::EmptyOp *op, ProgramContext &context) {
::tt::runtime::ttnn::utils::toShapeFromFBShape(
*op->out()->desc()->shape())));

const tt::target::DeviceRef *device = op->device();
const tt::target::DeviceRef *deviceRef = op->device();
::ttnn::Tensor out;
if (device) {
if (deviceRef) {
::ttnn::MemoryConfig memoryConfig =
utils::createMemoryConfig(op->memcfg(), op->out());
out = ::ttnn::empty(shape, dtype, layout,
&utils::getDevice(device, devicePool), memoryConfig);
out = ::ttnn::empty(shape, dtype, layout, &device, memoryConfig);
} else {
out = ::ttnn::zeros(shape, dtype, layout);
}
7 changes: 4 additions & 3 deletions runtime/lib/ttnn/operations/creation/full.cpp
@@ -9,9 +9,10 @@

namespace tt::runtime::ttnn::operations::creation {
void run(const ::tt::target::ttnn::FullOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.tensorPool;
DeviceMap devicePool = context.devicePool;
::ttnn::Device &device = utils::getDevice(op->device(), devicePool);
ProgramTensorPool &tensorPool = context.getTensorPool();
// TODO (jnie): Update this once we support multi device tensors
::ttnn::Device &device =
context.getDeviceFromView(op->device()->global_id(), 0);
::ttnn::DataType outputDataType = utils::getDataType(op->out());
auto shape = ::ttnn::Shape(::tt::tt_metal::LegacyShape(
::tt::runtime::ttnn::utils::toShapeFromFBShape(
5 changes: 3 additions & 2 deletions runtime/lib/ttnn/operations/data_movement/concat.cpp
@@ -7,12 +7,13 @@

namespace tt::runtime::ttnn::operations::data_movement {
void run(const ::tt::target::ttnn::ConcatOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
std::vector<::ttnn::Tensor> inputs;
for (const auto &input : *op->inputs()) {
inputs.push_back(context.tensorPool.at(input->global_id()));
inputs.push_back(tensorPool.at(input->global_id()));
}
int32_t dim = op->dim();
::ttnn::Tensor out = ::ttnn::concat(inputs, dim);
context.tensorPool.insert_or_assign(op->out()->global_id(), out);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::data_movement