diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 68c3cf21c..26b14654c 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -28,7 +28,8 @@ def TTNN_ToMemoryConfigOp : TTNN_Op<"to_memory_config"> { let description = [{ }]; - let arguments = (ins AnyRankedTensor:$input); + let arguments = (ins AnyRankedTensor:$input, + TT_Device:$device); let results = (outs AnyRankedTensor:$result); let hasVerifier = 1; @@ -324,6 +325,7 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> { AnyRankedTensor:$weight, Optional:$bias, AnyRankedTensor:$output, + TT_Device:$device, I32Attr:$in_channels, I32Attr:$out_channels, I32Attr:$batch_size, diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index 0dc5ca96c..99443598c 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -11,6 +11,7 @@ table GetDeviceOp { table ToMemoryConfigOp { in0: tt.target.TensorRef; + device: tt.target.DeviceRef; out: tt.target.TensorRef; } @@ -103,6 +104,7 @@ table Conv2dOp { weight: tt.target.TensorRef; bias: tt.target.TensorRef; out: tt.target.TensorRef; + device: tt.target.DeviceRef; in_channels: uint32; out_channels: uint32; batch_size: uint32; diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 9a62b2280..dd2e413aa 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -75,9 +75,15 @@ class ToLayoutOpConversionPattern // ValueRange nonDPSOperands = adaptor.getOperands().drop_back(); + if (nonDPSOperands.size() != 1) { + return op->emitOpError( + "Expected exactly one non-DPS operand for toMemoryConfig op"); + } + Value nonDPSOperand = nonDPSOperands.front(); + auto device = getOrInsertDevice(rewriter, op); rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType()), - nonDPSOperands); + op, this->getTypeConverter()->convertType(op.getType()), nonDPSOperand, + device); return success(); } }; @@ -311,6 +317,7 @@ class Conv2dOpConversionPattern : public OpConversionPattern { matchAndRewrite(ttir::Conv2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + auto device = getOrInsertDevice(rewriter, op); auto kernel_ty = mlir::cast(adaptor.getWeight().getType()); llvm::ArrayRef kernel_shape = kernel_ty.getShape(); @@ -361,9 +368,10 @@ class Conv2dOpConversionPattern : public OpConversionPattern { rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), adaptor.getInput(), adaptor.getWeight(), adaptor.getBias(), - adaptor.getOutput(), in_channels, out_channels, batch_size, input_width, - input_height, kernel_height, kernel_width, stride_height, stride_width, - padding_height, padding_width, dilation_height, dilation_width, groups); + adaptor.getOutput(), device, in_channels, out_channels, batch_size, + input_width, input_height, kernel_height, kernel_width, stride_height, + stride_width, padding_height, padding_width, dilation_height, + dilation_width, groups); return success(); } }; diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 42406d971..888065c6e 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -43,7 +43,10 @@ ::tt::target::Dim2dRange toFlatbuffer(CoreRangeAttr coreRange) { ::flatbuffers::Offset<::tt::target::DeviceRef> createDeviceRef(FlatbufferObjectCache &cache, Value device) { - return ::tt::target::CreateDeviceRef(*cache.fbb, cache.nextGlobalId()); + auto deviceType = mlir::cast(device.getType()); + auto chipIds = deviceType.getDesc().getChipIds(); + assert(chipIds.size() == 1 && "expected single chip"); + return ::tt::target::CreateDeviceRef(*cache.fbb, chipIds[0]); } template @@ -73,9 +76,11 @@ createOp(FlatbufferObjectCache &cache, ToMemoryConfigOp op) { constexpr uint64_t kHostAllocatedSize = 0; auto input = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); + auto device = getOperandThroughDPSOps(op.getDevice()); auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize); - return ::tt::target::ttnn::CreateToMemoryConfigOp(*cache.fbb, input, output); + return ::tt::target::ttnn::CreateToMemoryConfigOp( + *cache.fbb, input, cache.at<::tt::target::DeviceRef>(device), output); } ::flatbuffers::Offset<::tt::target::ttnn::EmptyOp> @@ -128,8 +133,11 @@ createOp(FlatbufferObjectCache &cache, Conv2dOp op) { getOperandThroughDPSOps(op.getBias())); auto output = cache.at<::tt::target::TensorRef>( getOperandThroughDPSOps(op.getResult())); + + auto device = getOperandThroughDPSOps(op.getDevice()); return ::tt::target::ttnn::CreateConv2dOp( - *cache.fbb, in0, in1, in2, output, op.getInChannels(), + *cache.fbb, in0, in1, in2, output, + cache.at<::tt::target::DeviceRef>(device), op.getInChannels(), op.getOutChannels(), op.getBatchSize(), op.getInputHeight(), op.getInputWidth(), op.getKernelHeight(), op.getKernelWidth(), op.getStrideHeight(), op.getStrideWidth(), op.getPaddingHeight(), diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index 4ece6f642..0e14c3dc0 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -89,6 +89,14 @@ static ::ttnn::DataType getDataType(const ::tt::target::TensorRef *tensorRef) { tensorRef->desc()->layout()->memory_desc()->data_type()); } +static ::ttnn::Device & +getDevice(const ::tt::target::DeviceRef *deviceRef, + std::unordered_map &devicePool) { + uint32_t deviceId = deviceRef->global_id(); + assert(devicePool.contains(deviceId) && "Device not found in device pool"); + return *devicePool.at(deviceId); +} + static CoreRangeSet toCoreRangeSet( const ::flatbuffers::Vector *coreRangeSet) { std::set coreRanges; @@ -325,7 +333,8 @@ handleToL1MemoryConfigOp(::ttnn::Device &device, // TODO(bug #272): right now hardcoding tilize/untilize, should determine with // tile shape blocked by issue #272 static void run(::tt::target::ttnn::ToMemoryConfigOp const *op, - ::ttnn::Device &device, ProgramTensorPool &tensorPool) { + std::unordered_map &devicePool, + ProgramTensorPool &tensorPool) { const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in0()->global_id()); assert(isOnHost(inputTensor) or @@ -347,17 +356,20 @@ static void run(::tt::target::ttnn::ToMemoryConfigOp const *op, break; } case ::tt::target::MemorySpace::DeviceDRAM: { + ::ttnn::Device &device = getDevice(op->device(), devicePool); handleToDramMemoryConfigOp(device, inputTensor, op->out(), tensorPool); break; } case ::tt::target::MemorySpace::DeviceL1: { + ::ttnn::Device &device = getDevice(op->device(), devicePool); handleToL1MemoryConfigOp(device, inputTensor, op->out(), tensorPool); break; } } } -static void run(::tt::target::ttnn::EmptyOp const *op, ::ttnn::Device &device, +static void run(::tt::target::ttnn::EmptyOp const *op, + std::unordered_map &devicePool, ProgramTensorPool &tensorPool) { ::ttnn::DataType targetDataTypeTTNN = getDataType(op->out()); // TODO(bug #582): ttnn::empty doesn't work properly with tile layout, @@ -366,6 +378,7 @@ static void run(::tt::target::ttnn::EmptyOp const *op, ::ttnn::Device &device, auto shape = ::ttnn::Shape(::tt::tt_metal::Shape( utils::toShapeFromFBShape(*op->out()->desc()->shape()))); + ::ttnn::Device &device = getDevice(op->device(), devicePool); ::ttnn::Tensor out = ::ttnn::empty(shape, targetDataTypeTTNN, desiredLayout, device); // use try emplace here so the program output tensor doesn't get overwritten @@ -468,7 +481,8 @@ static void runEltwiseUnaryWithFastAndApproximateModeOP( tensorPool.insert_or_assign(op->out()->global_id(), std::move(out)); } -static void run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device, +static void run(::tt::target::ttnn::EltwiseOp const *op, + std::unordered_map &devicePool, ProgramTensorPool &tensorPool) { switch (op->type()) { /* Eltwise Binary */ @@ -547,7 +561,8 @@ static void runReductionOp( } static void run(::tt::target::ttnn::ReductionOp const *op, - ::ttnn::Device &device, ProgramTensorPool &tensorPool) { + std::unordered_map &devicePool, + ProgramTensorPool &tensorPool) { switch (op->type()) { case ::tt::target::ttnn::ReductionOpType::Sum: { runReductionOp(op, tensorPool, ::ttnn::sum); @@ -581,7 +596,8 @@ static ::ttnn::Tensor invoke_reshape(const ::ttnn::Tensor &tensor, return ::ttnn::reshape(tensor, vectorToArray(shape)); } -static void run(::tt::target::ttnn::ReshapeOp const *op, ::ttnn::Device &device, +static void run(::tt::target::ttnn::ReshapeOp const *op, + std::unordered_map &devicePool, ProgramTensorPool &tensorPool) { const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id()); const auto *fbShape = op->shape(); @@ -617,7 +633,8 @@ static void run(::tt::target::ttnn::ReshapeOp const *op, ::ttnn::Device &device, } static void run(::tt::target::ttnn::EmbeddingOp const *op, - ::ttnn::Device &device, ProgramTensorPool &tensorPool) { + std::unordered_map &devicePool, + ProgramTensorPool &tensorPool) { const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id()); const ::ttnn::Tensor &weight = tensorPool.at(op->weight()->global_id()); // default params for embedding op @@ -632,7 +649,8 @@ static void run(::tt::target::ttnn::EmbeddingOp const *op, tensorPool.insert_or_assign(op->output()->global_id(), std::move(out)); } -static void run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::Device &device, +static void run(::tt::target::ttnn::SoftmaxOp const *op, + std::unordered_map &devicePool, ProgramTensorPool &tensorPool) { const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id()); int32_t dimension = op->dimension(); @@ -643,7 +661,8 @@ static void run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::Device &device, } static void run(::tt::target::ttnn::TransposeOp const *op, - ::ttnn::Device &device, ProgramTensorPool &tensorPool) { + std::unordered_map &devicePool, + ProgramTensorPool &tensorPool) { const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id()); int32_t dim0 = op->dim0(); int32_t dim1 = op->dim1(); @@ -673,7 +692,8 @@ static void run(::tt::target::ttnn::TransposeOp const *op, tensorPool.insert_or_assign(op->out()->global_id(), std::move(out)); } -static void run(::tt::target::ttnn::ConcatOp const *op, ::ttnn::Device &device, +static void run(::tt::target::ttnn::ConcatOp const *op, + std::unordered_map &devicePool, ProgramTensorPool &tensorPool) { std::vector<::ttnn::Tensor> inputs; for (const auto &input : *op->inputs()) { @@ -685,7 +705,8 @@ static void run(::tt::target::ttnn::ConcatOp const *op, ::ttnn::Device &device, } // ANCHOR: adding_an_op_matmul_runtime -static void run(::tt::target::ttnn::MatmulOp const *op, ::ttnn::Device &device, +static void run(::tt::target::ttnn::MatmulOp const *op, + std::unordered_map &devicePool, ProgramTensorPool &tensorPool) { const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id()); const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id()); @@ -701,7 +722,8 @@ static void run(::tt::target::ttnn::MatmulOp const *op, ::ttnn::Device &device, } // ANCHOR_END: adding_an_op_matmul_runtime -static void run(::tt::target::ttnn::Conv2dOp const *op, ::ttnn::Device &device, +static void run(::tt::target::ttnn::Conv2dOp const *op, + std::unordered_map &devicePool, ProgramTensorPool &tensorPool) { const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id()); const ::ttnn::Tensor &weight = tensorPool.at(op->weight()->global_id()); @@ -711,7 +733,7 @@ static void run(::tt::target::ttnn::Conv2dOp const *op, ::ttnn::Device &device, auto config = ::ttnn::operations::conv::conv2d::Conv2dConfig(); config.dtype = input.dtype(); config.weights_dtype = weight.dtype(); - + ::ttnn::Device &device = getDevice(op->device(), devicePool); ::ttnn::Tensor out = std::get<0>(::ttnn::operations::conv::conv2d::conv2d<::ttnn::Device>( input, weight, &device, op->in_channels(), op->out_channels(), @@ -726,7 +748,8 @@ static void run(::tt::target::ttnn::Conv2dOp const *op, ::ttnn::Device &device, return; } -static void run(::tt::target::ttnn::DeallocOp const *op, ::ttnn::Device &device, +static void run(::tt::target::ttnn::DeallocOp const *op, + std::unordered_map &devicePool, ProgramTensorPool &tensorPool) { bool force = true; ::ttnn::Tensor &tensor = tensorPool.at(op->in()->global_id()); @@ -734,51 +757,69 @@ static void run(::tt::target::ttnn::DeallocOp const *op, ::ttnn::Device &device, tensorPool.erase(op->in()->global_id()); } -static void run(::tt::target::ttnn::Operation const *op, ::ttnn::Device &device, - ProgramTensorPool &tensorPool) { +static void +run(::tt::target::ttnn::GetDeviceOp const *op, + const std::unordered_map &allDevices, + std::unordered_map &devicePool, + ProgramTensorPool &tensorPool) { + const flatbuffers::Vector *chipIds = op->chip_ids(); + assert(chipIds->size() == 1 && "Expected 1 chip id"); + for (const uint32_t chipId : *chipIds) { + assert(allDevices.contains(chipId) && "Device not found"); + auto [iter, inserted] = + devicePool.try_emplace(chipId, allDevices.at(chipId)); + assert(inserted && "Duplicate device"); + } +} + +static void +run(::tt::target::ttnn::Operation const *op, + const std::unordered_map &allDevices, + std::unordered_map &devicePool, + ProgramTensorPool &tensorPool) { switch (op->type_type()) { case ::tt::target::ttnn::OpType::GetDeviceOp: { - // TODO(bug #627) + return run(op->type_as_GetDeviceOp(), allDevices, devicePool, tensorPool); break; } case ::tt::target::ttnn::OpType::ToMemoryConfigOp: { - return run(op->type_as_ToMemoryConfigOp(), device, tensorPool); + return run(op->type_as_ToMemoryConfigOp(), devicePool, tensorPool); } case ::tt::target::ttnn::OpType::EmptyOp: { - return run(op->type_as_EmptyOp(), device, tensorPool); + return run(op->type_as_EmptyOp(), devicePool, tensorPool); } case ::tt::target::ttnn::OpType::FullOp: { // TODO(bug #626) break; } case ::tt::target::ttnn::OpType::EltwiseOp: { - return run(op->type_as_EltwiseOp(), device, tensorPool); + return run(op->type_as_EltwiseOp(), devicePool, tensorPool); } case ::tt::target::ttnn::OpType::MatmulOp: { - return run(op->type_as_MatmulOp(), device, tensorPool); + return run(op->type_as_MatmulOp(), devicePool, tensorPool); } case ::tt::target::ttnn::OpType::ReductionOp: { - return run(op->type_as_ReductionOp(), device, tensorPool); + return run(op->type_as_ReductionOp(), devicePool, tensorPool); } case ::tt::target::ttnn::OpType::EmbeddingOp: { - return run(op->type_as_EmbeddingOp(), device, tensorPool); + return run(op->type_as_EmbeddingOp(), devicePool, tensorPool); } case ::tt::target::ttnn::OpType::SoftmaxOp: { - return run(op->type_as_SoftmaxOp(), device, tensorPool); + return run(op->type_as_SoftmaxOp(), devicePool, tensorPool); } case ::tt::target::ttnn::OpType::TransposeOp: { - return run(op->type_as_TransposeOp(), device, tensorPool); + return run(op->type_as_TransposeOp(), devicePool, tensorPool); } case ::tt::target::ttnn::OpType::Conv2dOp: { - return run(op->type_as_Conv2dOp(), device, tensorPool); + return run(op->type_as_Conv2dOp(), devicePool, tensorPool); } case ::tt::target::ttnn::OpType::ConcatOp: { - return run(op->type_as_ConcatOp(), device, tensorPool); + return run(op->type_as_ConcatOp(), devicePool, tensorPool); case ::tt::target::ttnn::OpType::ReshapeOp: { - return run(op->type_as_ReshapeOp(), device, tensorPool); + return run(op->type_as_ReshapeOp(), devicePool, tensorPool); } case ::tt::target::ttnn::OpType::DeallocOp: { - return run(op->type_as_DeallocOp(), device, tensorPool); + return run(op->type_as_DeallocOp(), devicePool, tensorPool); } default: throw std::runtime_error("Unsupported operation type"); @@ -810,10 +851,13 @@ void runProgram(::ttnn::Device &device, std::vector<::ttnn::Tensor *> const &inputs, std::vector<::ttnn::Tensor *> const &outputs) { std::unordered_map liveTensors; - + std::unordered_map allDevices; + std::unordered_map devicePool; int inputIndex = 0; assert(program->inputs()->size() == inputs.size()); bool isNop = handleNopProgram(program, inputs, outputs); + // Assuming single device for now until we support multichip + allDevices.try_emplace(device.id(), &device); for (::tt::target::TensorRef const *input : *program->inputs()) { auto [iter, inserted] = liveTensors.try_emplace(input->global_id(), inputs[inputIndex++]); @@ -827,10 +871,9 @@ void runProgram(::ttnn::Device &device, liveTensors.try_emplace(output->global_id(), outputs[outputIndex++]); assert((isNop || inserted) && "Duplicate output tensor"); } - ProgramTensorPool tensorPool(std::move(liveTensors)); for (::tt::target::ttnn::Operation const *op : *program->operations()) { - run(op, device, tensorPool); + run(op, allDevices, devicePool, tensorPool); } } } // namespace tt::runtime::ttnn