Additional fixes to metal runtime bootstrap #290 (#291)
- Make inputs per-device (multi-device inputs still need to be worked out)
- Create common functions for creating metal Buffers and CoreRanges from flatbuffer types
- Implement wait API + events and update ttrt to use it (see the usage sketch below)
- Bug fix: logical_grid_size -> compute_with_storage_grid_size; the latter accurately gives the post-harvested grid shape
- Automatically move inputs/outputs during submit
nsmithtt authored Aug 7, 2024
1 parent 47a7453 commit 2e687dc
Showing 11 changed files with 268 additions and 137 deletions.
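For context, a minimal sketch of the flow this commit enables in a client such as ttrt, assuming the tt::runtime API surface visible in the diffs below (submit returning an Event, the new blocking wait); the function and setup are illustrative, not the actual ttrt code:

#include "tt/runtime/runtime.h"

// Illustrative only: dispatch one program and block on completion.
void runOnce(tt::runtime::Device device, tt::runtime::Binary binary,
             std::vector<tt::runtime::Tensor> const &inputs,
             std::vector<tt::runtime::Tensor> const &outputs) {
  // submit() now moves inputs/outputs to and from the device automatically
  // and returns an Event rather than completing synchronously.
  tt::runtime::Event event =
      tt::runtime::submit(device, binary, /*programIndex=*/0, inputs, outputs);
  // New in this commit: wait() blocks until the event signals.
  tt::runtime::wait(event);
}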
2 changes: 2 additions & 0 deletions include/ttmlir/Target/TTMetal/binary.fbs
@@ -5,6 +5,8 @@ include "command.fbs";
namespace tt.target.metal;

table DeviceProgram {
inputs: [TensorRef];
outputs: [TensorRef];
command_queues: [CommandQueue];
}
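A short sketch of how a consumer reads the relocated fields; this mirrors the getProgramInputs change in runtime/lib/binary.cpp further down (getBinary and the single-device assumption come from that file):

// Per-device inputs now sit one level deeper in the schema.
auto const *program = getBinary(binary)->programs()->Get(programIndex);
assert(program->device_programs()->size() == 1 &&
       "Currently only one device is supported");
for (auto const *input : *program->device_programs()->Get(0)->inputs()) {
  // input is a ::tt::target::TensorRef carrying desc() and address()
}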

3 changes: 2 additions & 1 deletion lib/Dialect/TTMetal/Transforms/SerializeToBinary.cpp
@@ -222,7 +222,8 @@ class TTMetalSerializeToBinary

std::vector<::flatbuffers::Offset<::tt::target::metal::DeviceProgram>>
devicePrograms = {
::tt::target::metal::CreateDeviceProgramDirect(fbb, &commandQueues),
::tt::target::metal::CreateDeviceProgramDirect(
fbb, &cqBuilder.inputs, &cqBuilder.outputs, &commandQueues),
};

std::vector<::flatbuffers::Offset<::tt::target::metal::Program>> programs =
100 changes: 91 additions & 9 deletions runtime/include/tt/runtime/detail/ttmetal.h
@@ -64,15 +64,97 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex,

void wait(Event event);

std::shared_ptr<::tt::tt_metal::Event> executeCommandQueue(
::tt::tt_metal::Device *device, ::tt::target::metal::CommandQueue const *cq,
std::size_t cq_id,
std::vector<
std::pair<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>>> const
&inputs,
std::vector<
std::pair<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>>> const
&outputs);
using InputBuffer =
std::tuple<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>,
std::shared_ptr<::tt::tt_metal::Event>>;

using OutputBuffer =
std::tuple<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>>;

std::shared_ptr<::tt::tt_metal::Event>
executeCommandQueue(::tt::tt_metal::Device *device,
::tt::target::metal::CommandQueue const *cq,
std::size_t cq_id, std::vector<InputBuffer> const &inputs,
std::vector<OutputBuffer> const &outputs);

// Utils

inline CoreRangeSet toCoreRangeSet(
::flatbuffers::Vector<tt::target::Dim2dRange const *> const *coreRangeSet) {
std::set<CoreRange> coreRanges;
for (::tt::target::Dim2dRange const *coreRange : *coreRangeSet) {
CoreCoord start(coreRange->loc().x(), coreRange->loc().y());
// End is inclusive
CoreCoord end(coreRange->loc().x() + coreRange->size().x() - 1,
coreRange->loc().y() + coreRange->size().y() - 1);
coreRanges.emplace(start, end);
}
return CoreRangeSet(coreRanges);
}
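The -1 matters: tt-metal CoreRange endpoints are inclusive, while the flatbuffer Dim2dRange carries an origin plus a size (the copy of this helper removed from command_queue.cpp below lacked the adjustment). A worked example with assumed values:

// Dim2dRange loc=(0,0), size=(8,8) covers the 8x8 block of cores
// (0,0) through (7,7), so the inclusive end is loc + size - 1:
CoreCoord start(0, 0);
CoreCoord end(0 + 8 - 1, 0 + 8 - 1); // (7,7), not (8,8)
CoreRange range(start, end);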

#pragma clang diagnostic push
// Needed to construct ShardedBufferConfig
#pragma clang diagnostic ignored "-Wc++20-designator"

inline std::shared_ptr<::tt::tt_metal::Buffer>
createBufferFromTensorRef(::tt::tt_metal::Device *device,
::tt::target::TensorRef const *tensorRef) {
::tt::target::TensorDesc const *tensorDesc = tensorRef->desc();
::tt::target::LayoutDesc const *layout = tensorDesc->layout();
CoreRangeSet coreRangeSet = toCoreRangeSet(layout->core_range_set());
auto shardRank = layout->memory_desc()->shape()->size();
::tt::target::Dim2d const *tile_shape = layout->memory_desc()->tile_shape();
std::array<uint32_t, 2> shardShape;
shardShape[1] =
layout->memory_desc()->shape()->Get(shardRank - 1) * tile_shape->x();
shardShape[0] = tile_shape->y();
for (unsigned i = 0; i < shardRank - 1; ++i) {
shardShape[0] *= layout->memory_desc()->shape()->Get(i);
}
ShardSpec shardSpec(coreRangeSet, shardShape);
std::array<uint32_t, 2> pageShape = {static_cast<uint32_t>(tile_shape->y()),
shardShape[1]};

auto tensorRank = layout->stride()->size();
assert(tensorRank >= 2);
auto innerDim = layout->stride()->Get(tensorRank - 2);
assert((layout->stride()->Get(0) * tensorDesc->shape()->Get(0)) %
(pageShape[0] * innerDim) ==
0);
assert(innerDim % pageShape[1] == 0);
std::array<uint32_t, 2> tensorShape = {
(layout->stride()->Get(0) * tensorDesc->shape()->Get(0)) /
(pageShape[0] * innerDim),
innerDim / pageShape[1],
};

ShardSpecBuffer shardSpecBuffer(shardSpec, pageShape, tensorShape);
assert(layout->memory_desc()->memory_space() ==
::tt::target::MemorySpace::DeviceDRAM ||
layout->memory_desc()->memory_space() ==
::tt::target::MemorySpace::DeviceL1);
BufferType bufferType = layout->memory_desc()->memory_space() ==
::tt::target::MemorySpace::DeviceDRAM
? BufferType::DRAM
: BufferType::L1;
uint64_t pageSize =
pageShape[0] * pageShape[1] * 4; // FIXME: Hardcoded data type size
uint64_t size = tensorShape[0] * tensorShape[1] * pageSize;
auto shardedBufferConfig = ShardedBufferConfig{
.device = device,
.size = size,
.page_size = pageSize,
.buffer_type = bufferType,
.buffer_layout = TensorMemoryLayout::BLOCK_SHARDED,
.shard_parameters = shardSpecBuffer,
};
std::shared_ptr<::tt::tt_metal::Buffer> buffer =
::tt::tt_metal::CreateBuffer(shardedBufferConfig);
assert(tensorRef->address());
buffer->set_address(tensorRef->address());
return buffer;
}
#pragma clang diagnostic pop
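To make the shard/page arithmetic above concrete, a worked example with assumed values: 32x32 tiles, fp32, a 128x256-element tensor whose per-core shard is a 2x4 grid of tiles:

// shardShape  = {2 * 32, 4 * 32}                      = {64, 128} elements
// pageShape   = {32, 128}                               one tile-row of the shard
// innerDim    = stride[rank - 2]                        = 256
// tensorShape = {(256 * 128) / (32 * 256), 256 / 128}   = {4, 2} pages
// pageSize    = 32 * 128 * 4 bytes                      = 16 KiB (fp32 hardcoded)
// size        = 4 * 2 * 16 KiB                          = 128 KiB = 128 * 256 * 4 bytes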

} // namespace tt::runtime::ttmetal

7 changes: 5 additions & 2 deletions runtime/lib/binary.cpp
@@ -138,7 +138,9 @@ std::vector<TensorDesc> getProgramInputs(Flatbuffer binary,
std::uint32_t programIndex) {
std::vector<TensorDesc> inputs;
auto const *program = getBinary(binary)->programs()->Get(programIndex);
for (auto const *input : *program->inputs()) {
assert(program->device_programs()->size() == 1 &&
"Currently only one device is supported");
for (auto const *input : *program->device_programs()->Get(0)->inputs()) {
TensorDesc desc;
desc.shape = {input->desc()->shape()->begin(),
input->desc()->shape()->end()};
@@ -156,7 +158,8 @@ std::vector<TensorDesc> getProgramOutputs(Flatbuffer binary,
std::uint32_t programIndex) {
std::vector<TensorDesc> outputs;
auto const *program = getBinary(binary)->programs()->Get(programIndex);
for (auto const *output : *program->outputs()) {
assert(program->device_programs()->size() == 1 &&
       "Currently only one device is supported");
for (auto const *output : *program->device_programs()->Get(0)->outputs()) {
TensorDesc desc;
desc.shape = {output->desc()->shape()->begin(),
output->desc()->shape()->end()};
10 changes: 9 additions & 1 deletion runtime/lib/runtime.cpp
@@ -83,6 +83,14 @@ Event submit(Device deviceHandle, Binary executableHandle,
#endif
}

void wait(Event) { throw std::runtime_error("Not implemented"); }
void wait(Event event) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
return ::tt::runtime::ttnn::wait(event);
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
return ::tt::runtime::ttmetal::wait(event);
#else
throw std::runtime_error("runtime is not enabled");
#endif
}
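The ttmetal-side wait is not in the hunks shown here; a plausible minimal form, assuming the opaque runtime Event wraps a std::shared_ptr<::tt::tt_metal::Event> and using tt-metal's host-side EventSynchronize (an assumed sketch, not the committed implementation):

// Assumed sketch: unwrap the opaque handle and block on it host-side.
void wait(Event event) {
  auto metalEvent =
      std::static_pointer_cast<::tt::tt_metal::Event>(event.handle);
  ::tt::tt_metal::EventSynchronize(metalEvent);
}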

} // namespace tt::runtime
116 changes: 26 additions & 90 deletions runtime/lib/ttmetal/command_queue.cpp
@@ -11,27 +11,20 @@
#include "ttmlir/Target/TTMetal/Target.h"
#include "ttmlir/Version.h"

// Needed to construct ShardedBufferConfig
#pragma clang diagnostic ignored "-Wc++20-designator"

namespace tt::runtime::ttmetal {

struct CQExecutor {
::tt::tt_metal::Device *device;
std::vector<std::shared_ptr<::tt::tt_metal::Event>> initEvents;
std::unordered_map<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>>
buffers;
std::unordered_map<std::uint32_t, std::shared_ptr<::tt::tt_metal::Event>>
events;
::tt::tt_metal::CommandQueue *cq;

CQExecutor(
::tt::tt_metal::Device *device, std::size_t cq_id,
std::vector<std::pair<std::uint32_t,
std::shared_ptr<::tt::tt_metal::Buffer>>> const
&inputs,
std::vector<std::pair<std::uint32_t,
std::shared_ptr<::tt::tt_metal::Buffer>>> const
&outputs);
CQExecutor(::tt::tt_metal::Device *device, std::size_t cq_id,
std::vector<InputBuffer> const &inputs,
std::vector<OutputBuffer> const &outputs);

std::shared_ptr<::tt::tt_metal::Event>
execute(::tt::target::metal::CommandQueue const *commandQueue);
@@ -49,28 +42,33 @@ struct CQExecutor {
void execute(::tt::target::metal::FinishCommand const *command);
};

CQExecutor::CQExecutor(
::tt::tt_metal::Device *device, std::size_t cq_id,
std::vector<
std::pair<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>>> const
&inputs,
std::vector<
std::pair<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>>> const
&outputs)
CQExecutor::CQExecutor(::tt::tt_metal::Device *device, std::size_t cq_id,
std::vector<InputBuffer> const &inputs,
std::vector<OutputBuffer> const &outputs)
: device(device) {
for (std::size_t i = 0; i < inputs.size(); ++i) {
buffers[inputs[i].first] = inputs[i].second;
auto [global_id, buffer, event] = inputs[i];
buffers[global_id] = buffer;
if (event) {
initEvents.push_back(event);
}
}

for (std::size_t i = 0; i < outputs.size(); ++i) {
buffers[outputs[i].first] = outputs[i].second;
auto [global_id, buffer] = outputs[i];
buffers[global_id] = buffer;
}

cq = &device->command_queue(cq_id);
}
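The structured bindings above unpack the InputBuffer/OutputBuffer aliases introduced in ttmetal.h; a sketch of how a caller might populate them (caller-side names assumed, not shown in this diff):

// Tie an input buffer to the event of the submit that produced it, so
// this queue waits before reading; resident inputs pass a null event.
InputBuffer pending{tensorRef->global_id(), buffer, producerEvent};
InputBuffer ready{tensorRef->global_id(), buffer, nullptr};
OutputBuffer out{outputRef->global_id(), outputBuffer};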

std::shared_ptr<::tt::tt_metal::Event>
CQExecutor::execute(::tt::target::metal::CommandQueue const *commandQueue) {
for (auto const &event : initEvents) {
::tt::tt_metal::EnqueueWaitForEvent(*cq, event);
}
initEvents.clear();

for (::tt::target::metal::Command const *command :
*commandQueue->commands()) {
execute(command);
@@ -134,18 +132,6 @@ void CQExecutor::execute(::tt::target::metal::Command const *command) {
}
}

static CoreRangeSet toCoreRangeSet(
::flatbuffers::Vector<tt::target::Dim2dRange const *> const *coreRangeSet) {
std::set<CoreRange> coreRanges;
for (::tt::target::Dim2dRange const *coreRange : *coreRangeSet) {
CoreCoord start(coreRange->loc().x(), coreRange->loc().y());
CoreCoord end(coreRange->loc().x() + coreRange->size().x(),
coreRange->loc().y() + coreRange->size().y());
coreRanges.emplace(start, end);
}
return CoreRangeSet(coreRanges);
}

static void writeFile(std::string const &fileName, char const *data,
std::size_t size) {
std::ofstream file(fileName);
@@ -256,54 +242,8 @@

void CQExecutor::execute(
::tt::target::metal::CreateBufferCommand const *command) {
::tt::target::LayoutDesc const *layout = command->ref()->desc()->layout();
CoreRangeSet coreRangeSet = toCoreRangeSet(layout->core_range_set());
auto shardRank = layout->memory_desc()->shape()->size();
std::array<uint32_t, 2> shardShape;
shardShape[1] = layout->memory_desc()->shape()->Get(shardRank - 1) *
layout->memory_desc()->tile_shape()->x();
shardShape[0] = layout->memory_desc()->tile_shape()->y();
for (unsigned i = 0; i < shardRank - 1; ++i) {
shardShape[0] *= layout->memory_desc()->shape()->Get(i);
}
ShardSpec shardSpec(coreRangeSet, shardShape);

auto tensorRank = layout->stride()->size();
std::array<uint32_t, 2> tensorShape;
assert(layout->stride()->size() >= 2);
tensorShape[1] = layout->stride()->Get(tensorRank - 2);
tensorShape[0] =
layout->stride()->Get(0) * command->ref()->desc()->shape()->Get(0);

auto pageShape = shardShape;
ShardSpecBuffer shardSpecBuffer(shardSpec, pageShape, tensorShape);

uint64_t gridVolume = 1;
for (auto dim2dRange : *layout->core_range_set()) {
gridVolume *= dim2dRange->size().x() * dim2dRange->size().y();
}

assert(layout->memory_desc()->memory_space() ==
::tt::target::MemorySpace::DeviceDRAM ||
layout->memory_desc()->memory_space() ==
::tt::target::MemorySpace::DeviceL1);
BufferType bufferType = layout->memory_desc()->memory_space() ==
::tt::target::MemorySpace::DeviceDRAM
? BufferType::DRAM
: BufferType::L1;
uint64_t size = gridVolume * layout->memory_desc()->size();
auto shardedBufferConfig = ShardedBufferConfig{
.device = device,
.size = size,
.page_size = size,
.buffer_type = bufferType,
.buffer_layout = TensorMemoryLayout::HEIGHT_SHARDED,
.shard_parameters = shardSpecBuffer,
};
std::shared_ptr<::tt::tt_metal::Buffer> buffer =
::tt::tt_metal::CreateBuffer(shardedBufferConfig);
buffer->set_address(command->ref()->address());
buffers[command->ref()->global_id()] = buffer;
buffers[command->ref()->global_id()] =
createBufferFromTensorRef(device, command->ref());
}

void CQExecutor::execute(
@@ -352,15 +292,11 @@ void CQExecutor::execute(::tt::target::metal::FinishCommand const *) {
::tt::tt_metal::Finish(*cq);
}

std::shared_ptr<::tt::tt_metal::Event> executeCommandQueue(
::tt::tt_metal::Device *device,
::tt::target::metal::CommandQueue const *commandQueue, std::size_t cq_id,
std::vector<
std::pair<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>>> const
&inputs,
std::vector<
std::pair<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>>> const
&outputs) {
std::shared_ptr<::tt::tt_metal::Event>
executeCommandQueue(::tt::tt_metal::Device *device,
::tt::target::metal::CommandQueue const *commandQueue,
std::size_t cq_id, std::vector<InputBuffer> const &inputs,
std::vector<OutputBuffer> const &outputs) {
CQExecutor executor(device, cq_id, inputs, outputs);
return executor.execute(commandQueue);
}
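A hedged sketch of the call site, as the ttmetal submit path presumably invokes it (surrounding setup assumed):

// Run the serialized command queue on hardware cq 0 and keep the
// completion event so a later wait() can block on it.
std::shared_ptr<::tt::tt_metal::Event> done = executeCommandQueue(
    device, fbCommandQueue, /*cq_id=*/0, inputBuffers, outputBuffers);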