runtime support for get device op (#658)
* #627: runtime support for the get device op; add a device parameter to ops that explicitly use a device
jnie-TT authored Sep 10, 2024
1 parent fbe55df commit 745cec2
Showing 5 changed files with 103 additions and 40 deletions.
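
The core runtime change: op handlers no longer receive a single ::ttnn::Device &, but a pool of devices keyed by chip id, and ops that explicitly use a device resolve it from that pool. Below is a minimal standalone sketch of that lookup primitive; Device is a stand-in for ::ttnn::Device, and the helper mirrors the getDevice added in runtime/lib/ttnn/program.cpp (like the original, it relies on C++20 unordered_map::contains).

#include <cassert>
#include <cstdint>
#include <iostream>
#include <unordered_map>

// Stand-in for ::ttnn::Device; only the id matters for this sketch.
struct Device {
  uint32_t id;
};

// Mirrors the getDevice helper added in program.cpp: resolve a chip id
// against the pool and fail loudly if that device was never registered.
static Device &getDevice(uint32_t deviceId,
                         std::unordered_map<uint32_t, Device *> &devicePool) {
  assert(devicePool.contains(deviceId) && "Device not found in device pool");
  return *devicePool.at(deviceId);
}

int main() {
  Device d0{/*id=*/0};
  std::unordered_map<uint32_t, Device *> devicePool;
  devicePool.try_emplace(d0.id, &d0);

  Device &dev = getDevice(0, devicePool);
  std::cout << "resolved device " << dev.id << "\n";
}
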
4 changes: 3 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -28,7 +28,8 @@ def TTNN_ToMemoryConfigOp : TTNN_Op<"to_memory_config"> {
let description = [{
}];

let arguments = (ins AnyRankedTensor:$input);
let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device);
let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
@@ -324,6 +325,7 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
AnyRankedTensor:$weight,
Optional<AnyRankedTensor>:$bias,
AnyRankedTensor:$output,
TT_Device:$device,
I32Attr:$in_channels,
I32Attr:$out_channels,
I32Attr:$batch_size,
2 changes: 2 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -11,6 +11,7 @@ table GetDeviceOp {

table ToMemoryConfigOp {
in0: tt.target.TensorRef;
device: tt.target.DeviceRef;
out: tt.target.TensorRef;
}

@@ -103,6 +104,7 @@ table Conv2dOp {
weight: tt.target.TensorRef;
bias: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
in_channels: uint32;
out_channels: uint32;
batch_size: uint32;
18 changes: 13 additions & 5 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -75,9 +75,15 @@ class ToLayoutOpConversionPattern
//
ValueRange nonDPSOperands = adaptor.getOperands().drop_back();

if (nonDPSOperands.size() != 1) {
return op->emitOpError(
"Expected exactly one non-DPS operand for toMemoryConfig op");
}
Value nonDPSOperand = nonDPSOperands.front();
auto device = getOrInsertDevice(rewriter, op);
rewriter.replaceOpWithNewOp<ttnn::ToMemoryConfigOp>(
op, this->getTypeConverter()->convertType(op.getType()),
nonDPSOperands);
op, this->getTypeConverter()->convertType(op.getType()), nonDPSOperand,
device);
return success();
}
};
@@ -311,6 +317,7 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
matchAndRewrite(ttir::Conv2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto device = getOrInsertDevice(rewriter, op);
auto kernel_ty =
mlir::cast<RankedTensorType>(adaptor.getWeight().getType());
llvm::ArrayRef<std::int64_t> kernel_shape = kernel_ty.getShape();
@@ -361,9 +368,10 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
rewriter.replaceOpWithNewOp<ttnn::Conv2dOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getWeight(), adaptor.getBias(),
adaptor.getOutput(), in_channels, out_channels, batch_size, input_width,
input_height, kernel_height, kernel_width, stride_height, stride_width,
padding_height, padding_width, dilation_height, dilation_width, groups);
adaptor.getOutput(), device, in_channels, out_channels, batch_size,
input_width, input_height, kernel_height, kernel_width, stride_height,
stride_width, padding_height, padding_width, dilation_height,
dilation_width, groups);
return success();
}
};
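
Both conversion patterns above call getOrInsertDevice(rewriter, op) before building the TTNN op. That helper is not part of this diff; presumably it reuses a device value already materialized in the enclosing region and only creates a new ttnn.get_device when none exists. A standalone sketch of that get-or-insert idiom under that assumption follows (the map stands in for scanning the IR for an existing device value; names here are illustrative, not the pass's actual internals).

#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for the SSA value produced by a ttnn.get_device op.
struct DeviceValue {
  int chipId;
};

// Get-or-insert: return the device value already created for this function,
// otherwise create one (chip 0, matching the single-device assumption) and
// remember it so later ops in the same function reuse it.
DeviceValue getOrInsertDevice(std::unordered_map<std::string, DeviceValue> &perFunc,
                              const std::string &funcName) {
  auto it = perFunc.find(funcName);
  if (it != perFunc.end())
    return it->second;  // reuse the existing device value
  DeviceValue dev{0};   // materialize a new one
  perFunc.emplace(funcName, dev);
  return dev;
}

int main() {
  std::unordered_map<std::string, DeviceValue> perFunc;
  DeviceValue a = getOrInsertDevice(perFunc, "forward");
  DeviceValue b = getOrInsertDevice(perFunc, "forward");  // no second device created
  std::cout << (perFunc.size() == 1 && a.chipId == b.chipId) << "\n";
}
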
14 changes: 11 additions & 3 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -43,7 +43,10 @@ ::tt::target::Dim2dRange toFlatbuffer(CoreRangeAttr coreRange) {

::flatbuffers::Offset<::tt::target::DeviceRef>
createDeviceRef(FlatbufferObjectCache &cache, Value device) {
return ::tt::target::CreateDeviceRef(*cache.fbb, cache.nextGlobalId());
auto deviceType = mlir::cast<DeviceType>(device.getType());
auto chipIds = deviceType.getDesc().getChipIds();
assert(chipIds.size() == 1 && "expected single chip");
return ::tt::target::CreateDeviceRef(*cache.fbb, chipIds[0]);
}

template <typename OpT>
@@ -73,9 +76,11 @@ createOp(FlatbufferObjectCache &cache, ToMemoryConfigOp op) {
constexpr uint64_t kHostAllocatedSize = 0;
auto input =
cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
auto device = getOperandThroughDPSOps(op.getDevice());
auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedAddress, kHostAllocatedSize);
return ::tt::target::ttnn::CreateToMemoryConfigOp(*cache.fbb, input, output);
return ::tt::target::ttnn::CreateToMemoryConfigOp(
*cache.fbb, input, cache.at<::tt::target::DeviceRef>(device), output);
}

::flatbuffers::Offset<::tt::target::ttnn::EmptyOp>
@@ -128,8 +133,11 @@ createOp(FlatbufferObjectCache &cache, Conv2dOp op) {
getOperandThroughDPSOps(op.getBias()));
auto output = cache.at<::tt::target::TensorRef>(
getOperandThroughDPSOps(op.getResult()));

auto device = getOperandThroughDPSOps(op.getDevice());
return ::tt::target::ttnn::CreateConv2dOp(
*cache.fbb, in0, in1, in2, output, op.getInChannels(),
*cache.fbb, in0, in1, in2, output,
cache.at<::tt::target::DeviceRef>(device), op.getInChannels(),
op.getOutChannels(), op.getBatchSize(), op.getInputHeight(),
op.getInputWidth(), op.getKernelHeight(), op.getKernelWidth(),
op.getStrideHeight(), op.getStrideWidth(), op.getPaddingHeight(),
105 changes: 74 additions & 31 deletions runtime/lib/ttnn/program.cpp
@@ -89,6 +89,14 @@ static ::ttnn::DataType getDataType(const ::tt::target::TensorRef *tensorRef) {
tensorRef->desc()->layout()->memory_desc()->data_type());
}

static ::ttnn::Device &
getDevice(const ::tt::target::DeviceRef *deviceRef,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool) {
uint32_t deviceId = deviceRef->global_id();
assert(devicePool.contains(deviceId) && "Device not found in device pool");
return *devicePool.at(deviceId);
}

static CoreRangeSet toCoreRangeSet(
const ::flatbuffers::Vector<const tt::target::Dim2dRange *> *coreRangeSet) {
std::set<CoreRange> coreRanges;
@@ -325,7 +333,8 @@ handleToL1MemoryConfigOp(::ttnn::Device &device,
// TODO(bug #272): right now hardcoding tilize/untilize, should determine with
// tile shape blocked by issue #272
static void run(::tt::target::ttnn::ToMemoryConfigOp const *op,
::ttnn::Device &device, ProgramTensorPool &tensorPool) {
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {

const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in0()->global_id());
assert(isOnHost(inputTensor) or
@@ -347,17 +356,20 @@ static void run(::tt::target::ttnn::ToMemoryConfigOp const *op,
break;
}
case ::tt::target::MemorySpace::DeviceDRAM: {
::ttnn::Device &device = getDevice(op->device(), devicePool);
handleToDramMemoryConfigOp(device, inputTensor, op->out(), tensorPool);
break;
}
case ::tt::target::MemorySpace::DeviceL1: {
::ttnn::Device &device = getDevice(op->device(), devicePool);
handleToL1MemoryConfigOp(device, inputTensor, op->out(), tensorPool);
break;
}
}
}

static void run(::tt::target::ttnn::EmptyOp const *op, ::ttnn::Device &device,
static void run(::tt::target::ttnn::EmptyOp const *op,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
::ttnn::DataType targetDataTypeTTNN = getDataType(op->out());
// TODO(bug #582): ttnn::empty doesn't work properly with tile layout,
@@ -366,6 +378,7 @@ static void run(::tt::target::ttnn::EmptyOp const *op, ::ttnn::Device &device,
auto shape = ::ttnn::Shape(::tt::tt_metal::Shape(
utils::toShapeFromFBShape(*op->out()->desc()->shape())));

::ttnn::Device &device = getDevice(op->device(), devicePool);
::ttnn::Tensor out =
::ttnn::empty(shape, targetDataTypeTTNN, desiredLayout, device);
// use try emplace here so the program output tensor doesn't get overwritten
@@ -468,7 +481,8 @@ static void runEltwiseUnaryWithFastAndApproximateModeOP(
tensorPool.insert_or_assign(op->out()->global_id(), std::move(out));
}

static void run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device,
static void run(::tt::target::ttnn::EltwiseOp const *op,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
switch (op->type()) {
/* Eltwise Binary */
@@ -547,7 +561,8 @@ static void runReductionOp(
}

static void run(::tt::target::ttnn::ReductionOp const *op,
::ttnn::Device &device, ProgramTensorPool &tensorPool) {
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
switch (op->type()) {
case ::tt::target::ttnn::ReductionOpType::Sum: {
runReductionOp(op, tensorPool, ::ttnn::sum);
@@ -581,7 +596,8 @@ static ::ttnn::Tensor invoke_reshape(const ::ttnn::Tensor &tensor,
return ::ttnn::reshape(tensor, vectorToArray<Rank>(shape));
}

static void run(::tt::target::ttnn::ReshapeOp const *op, ::ttnn::Device &device,
static void run(::tt::target::ttnn::ReshapeOp const *op,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
const auto *fbShape = op->shape();
@@ -617,7 +633,8 @@ static void run(::tt::target::ttnn::ReshapeOp const *op, ::ttnn::Device &device,
}

static void run(::tt::target::ttnn::EmbeddingOp const *op,
::ttnn::Device &device, ProgramTensorPool &tensorPool) {
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id());
const ::ttnn::Tensor &weight = tensorPool.at(op->weight()->global_id());
// default params for embedding op
@@ -632,7 +649,8 @@ static void run(::tt::target::ttnn::EmbeddingOp const *op,
tensorPool.insert_or_assign(op->output()->global_id(), std::move(out));
}

static void run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::Device &device,
static void run(::tt::target::ttnn::SoftmaxOp const *op,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
int32_t dimension = op->dimension();
@@ -643,7 +661,8 @@ static void run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::Device &device,
}

static void run(::tt::target::ttnn::TransposeOp const *op,
::ttnn::Device &device, ProgramTensorPool &tensorPool) {
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
int32_t dim0 = op->dim0();
int32_t dim1 = op->dim1();
@@ -673,7 +692,8 @@ static void run(::tt::target::ttnn::TransposeOp const *op,
tensorPool.insert_or_assign(op->out()->global_id(), std::move(out));
}

static void run(::tt::target::ttnn::ConcatOp const *op, ::ttnn::Device &device,
static void run(::tt::target::ttnn::ConcatOp const *op,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
std::vector<::ttnn::Tensor> inputs;
for (const auto &input : *op->inputs()) {
@@ -685,7 +705,8 @@ static void run(::tt::target::ttnn::ConcatOp const *op, ::ttnn::Device &device,
}

// ANCHOR: adding_an_op_matmul_runtime
static void run(::tt::target::ttnn::MatmulOp const *op, ::ttnn::Device &device,
static void run(::tt::target::ttnn::MatmulOp const *op,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id());
const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id());
@@ -701,7 +722,8 @@ static void run(::tt::target::ttnn::MatmulOp const *op, ::ttnn::Device &device,
}
// ANCHOR_END: adding_an_op_matmul_runtime

static void run(::tt::target::ttnn::Conv2dOp const *op, ::ttnn::Device &device,
static void run(::tt::target::ttnn::Conv2dOp const *op,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id());
const ::ttnn::Tensor &weight = tensorPool.at(op->weight()->global_id());
@@ -711,7 +733,7 @@ static void run(::tt::target::ttnn::Conv2dOp const *op, ::ttnn::Device &device,
auto config = ::ttnn::operations::conv::conv2d::Conv2dConfig();
config.dtype = input.dtype();
config.weights_dtype = weight.dtype();

::ttnn::Device &device = getDevice(op->device(), devicePool);
::ttnn::Tensor out =
std::get<0>(::ttnn::operations::conv::conv2d::conv2d<::ttnn::Device>(
input, weight, &device, op->in_channels(), op->out_channels(),
@@ -726,59 +748,78 @@ static void run(::tt::target::ttnn::Conv2dOp const *op, ::ttnn::Device &device,
return;
}

static void run(::tt::target::ttnn::DeallocOp const *op, ::ttnn::Device &device,
static void run(::tt::target::ttnn::DeallocOp const *op,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
bool force = true;
::ttnn::Tensor &tensor = tensorPool.at(op->in()->global_id());
tensor.deallocate(force);
tensorPool.erase(op->in()->global_id());
}

static void run(::tt::target::ttnn::Operation const *op, ::ttnn::Device &device,
ProgramTensorPool &tensorPool) {
static void
run(::tt::target::ttnn::GetDeviceOp const *op,
const std::unordered_map<uint32_t, ::ttnn::Device *> &allDevices,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
const flatbuffers::Vector<uint32_t> *chipIds = op->chip_ids();
assert(chipIds->size() == 1 && "Expected 1 chip id");
for (const uint32_t chipId : *chipIds) {
assert(allDevices.contains(chipId) && "Device not found");
auto [iter, inserted] =
devicePool.try_emplace(chipId, allDevices.at(chipId));
assert(inserted && "Duplicate device");
}
}

static void
run(::tt::target::ttnn::Operation const *op,
const std::unordered_map<uint32_t, ::ttnn::Device *> &allDevices,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
switch (op->type_type()) {
case ::tt::target::ttnn::OpType::GetDeviceOp: {
// TODO(bug #627)
return run(op->type_as_GetDeviceOp(), allDevices, devicePool, tensorPool);
break;
}
case ::tt::target::ttnn::OpType::ToMemoryConfigOp: {
return run(op->type_as_ToMemoryConfigOp(), device, tensorPool);
return run(op->type_as_ToMemoryConfigOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::EmptyOp: {
return run(op->type_as_EmptyOp(), device, tensorPool);
return run(op->type_as_EmptyOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::FullOp: {
// TODO(bug #626)
break;
}
case ::tt::target::ttnn::OpType::EltwiseOp: {
return run(op->type_as_EltwiseOp(), device, tensorPool);
return run(op->type_as_EltwiseOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::MatmulOp: {
return run(op->type_as_MatmulOp(), device, tensorPool);
return run(op->type_as_MatmulOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::ReductionOp: {
return run(op->type_as_ReductionOp(), device, tensorPool);
return run(op->type_as_ReductionOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::EmbeddingOp: {
return run(op->type_as_EmbeddingOp(), device, tensorPool);
return run(op->type_as_EmbeddingOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::SoftmaxOp: {
return run(op->type_as_SoftmaxOp(), device, tensorPool);
return run(op->type_as_SoftmaxOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::TransposeOp: {
return run(op->type_as_TransposeOp(), device, tensorPool);
return run(op->type_as_TransposeOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::Conv2dOp: {
return run(op->type_as_Conv2dOp(), device, tensorPool);
return run(op->type_as_Conv2dOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::ConcatOp: {
return run(op->type_as_ConcatOp(), device, tensorPool);
return run(op->type_as_ConcatOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::ReshapeOp: {
return run(op->type_as_ReshapeOp(), device, tensorPool);
return run(op->type_as_ReshapeOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::DeallocOp: {
return run(op->type_as_DeallocOp(), device, tensorPool);
return run(op->type_as_DeallocOp(), devicePool, tensorPool);
}
default:
throw std::runtime_error("Unsupported operation type");
@@ -810,10 +851,13 @@ void runProgram(::ttnn::Device &device,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs) {
std::unordered_map<std::uint32_t, ::ttnn::Tensor *> liveTensors;

std::unordered_map<std::uint32_t, ::ttnn::Device *> allDevices;
std::unordered_map<std::uint32_t, ::ttnn::Device *> devicePool;
int inputIndex = 0;
assert(program->inputs()->size() == inputs.size());
bool isNop = handleNopProgram(program, inputs, outputs);
// Assuming single device for now until we support multichip
allDevices.try_emplace(device.id(), &device);
for (::tt::target::TensorRef const *input : *program->inputs()) {
auto [iter, inserted] =
liveTensors.try_emplace(input->global_id(), inputs[inputIndex++]);
@@ -827,10 +871,9 @@ void runProgram(::ttnn::Device &device,
liveTensors.try_emplace(output->global_id(), outputs[outputIndex++]);
assert((isNop || inserted) && "Duplicate output tensor");
}

ProgramTensorPool tensorPool(std::move(liveTensors));
for (::tt::target::ttnn::Operation const *op : *program->operations()) {
run(op, device, tensorPool);
run(op, allDevices, devicePool, tensorPool);
}
}
} // namespace tt::runtime::ttnn
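
Putting the runtime pieces together: runProgram seeds allDevices with the single opened device, and the GetDeviceOp handler copies the requested chip ids into devicePool, which later handlers index by the chip id stored in each op's DeviceRef. A condensed standalone sketch of that population step follows; Device and the maps are stand-ins for the ::ttnn types used above.

#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

struct Device { uint32_t id; };  // stand-in for ::ttnn::Device
using DeviceMap = std::unordered_map<uint32_t, Device *>;

// GetDeviceOp handler: move the requested chips from allDevices (everything
// runProgram was given) into devicePool (what later ops may use); a duplicate
// chip id indicates a malformed program.
void runGetDeviceOp(const std::vector<uint32_t> &chipIds,
                    const DeviceMap &allDevices, DeviceMap &devicePool) {
  assert(chipIds.size() == 1 && "Expected 1 chip id");
  for (uint32_t chipId : chipIds) {
    assert(allDevices.contains(chipId) && "Device not found");
    auto [iter, inserted] =
        devicePool.try_emplace(chipId, allDevices.at(chipId));
    (void)iter;
    assert(inserted && "Duplicate device");
  }
}

int main() {
  Device d0{0};
  DeviceMap allDevices, devicePool;
  allDevices.try_emplace(d0.id, &d0);          // single device for now (see runProgram)
  runGetDeviceOp({0}, allDevices, devicePool); // what ttnn.get_device does at runtime
  assert(devicePool.size() == 1 && devicePool.at(0) == &d0);
}
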
