diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index 2b24bc6ff..d22ce6d2b 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -104,6 +104,8 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device, op->out()->desc()->layout()->memory_desc()->memory_space(); switch (targetMemorySpace) { + // This case should only be used when gathering outputs at the end of the + // program case ::tt::target::MemorySpace::System: case ::tt::target::MemorySpace::SystemMMIO: { ::ttnn::Tensor result; @@ -136,14 +138,14 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device, result = updateLayoutAndDataType(result, targetDataTypeTTNN, shouldTilize, false); tensorPool.push_back(result); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); } else if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE) { ::ttnn::Tensor result = updateLayoutAndDataType( inputTensor, targetDataTypeTTNN, false, false); result = ::ttnn::to_memory_config(result, memConfig, std::nullopt); tensorPool.push_back(result); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); } break; } @@ -163,14 +165,14 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device, result = updateLayoutAndDataType(result, targetDataTypeTTNN, shouldTilize, false); tensorPool.push_back(result); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); } else if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE) { ::ttnn::Tensor result = updateLayoutAndDataType( inputTensor, targetDataTypeTTNN, false, false); result = ::ttnn::to_memory_config(result, memConfig, std::nullopt); tensorPool.push_back(result); - 
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); } break; } @@ -181,7 +183,6 @@ static void run(::tt::target::ttnn::EmptyOp const *op, ::ttnn::Device &device, std::unordered_map &liveTensors, std::list<::ttnn::Tensor> &tensorPool) { - ::ttnn::DataType targetDataTypeTTNN = utils::toTTNNDataType( op->out()->desc()->layout()->memory_desc()->data_type()); // TODO: determine layout, hardcoding tile_layout for now @@ -190,6 +191,7 @@ run(::tt::target::ttnn::EmptyOp const *op, ::ttnn::Device &device, utils::toShapeFromFBShape(*op->out()->desc()->shape()))); tensorPool.push_back( ::ttnn::empty(shape, targetDataTypeTTNN, desiredLayout, device)); + // Use try_emplace here so the program output tensor doesn't get overwritten. liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); } @@ -204,7 +206,7 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device, auto &lhs = *liveTensors.at(op->ins()->Get(0)->global_id()); auto &rhs = *liveTensors.at(op->ins()->Get(1)->global_id()); tensorPool.push_back(::ttnn::add(lhs, rhs)); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); break; } case ::tt::target::ttnn::EltwiseOpType::Multiply: { @@ -212,7 +214,7 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device, auto &lhs = *liveTensors.at(op->ins()->Get(0)->global_id()); auto &rhs = *liveTensors.at(op->ins()->Get(1)->global_id()); tensorPool.push_back(::ttnn::multiply(lhs, rhs)); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); break; } case ::tt::target::ttnn::EltwiseOpType::Subtract: { @@ -220,7 +222,7 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device, auto &lhs = *liveTensors.at(op->ins()->Get(0)->global_id()); auto &rhs = 
*liveTensors.at(op->ins()->Get(1)->global_id()); tensorPool.push_back(::ttnn::subtract(lhs, rhs)); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); break; } case ::tt::target::ttnn::EltwiseOpType::GreaterEqual: { @@ -228,7 +230,7 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device, ::ttnn::Tensor &lhs = *liveTensors.at(op->ins()->Get(0)->global_id()); ::ttnn::Tensor &rhs = *liveTensors.at(op->ins()->Get(1)->global_id()); tensorPool.push_back(::ttnn::ge(lhs, rhs)); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); break; } /* Eltwise Unary */ @@ -236,7 +238,7 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device, assert(op->ins()->size() == 1 && "Unsupported number of inputs"); ::ttnn::Tensor &in = *liveTensors.at(op->ins()->Get(0)->global_id()); tensorPool.push_back(::ttnn::relu(in)); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); break; } } @@ -258,7 +260,7 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device, tensorPool.push_back(::ttnn::sum(in, dim_arg, op->keep_dim())); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); break; } case ::tt::target::ttnn::ReductionOpType::Mean: { @@ -272,7 +274,7 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device, tensorPool.push_back(::ttnn::mean(in, dim_arg, op->keep_dim())); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); break; } } @@ -286,7 +288,7 @@ run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::device::Device &device, int32_t dimension = op->dimension(); 
tensorPool.push_back(::ttnn::softmax(in, dimension)); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); } static void @@ -317,7 +319,7 @@ run(::tt::target::ttnn::TransposeOp const *op, ::ttnn::device::Device &device, // work at the moment, we use this temporary solution. auto unsqueezed_input = ::ttnn::unsqueeze_to_4D(in); tensorPool.push_back(::ttnn::permute(unsqueezed_input, dimensionOrder)); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); } // ANCHOR: adding_an_op_matmul_runtime @@ -329,7 +331,7 @@ run(::tt::target::ttnn::MatmulOp const *op, ::ttnn::Device &device, auto &rhs = *liveTensors.at(op->in1()->global_id()); tensorPool.push_back(::ttnn::operations::matmul::matmul( lhs, rhs, std::nullopt, ::tt::operations::primary::Matmul{})); - liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back()); } // ANCHOR_END: adding_an_op_matmul_runtime