Skip to content

Commit

Permalink
#391: use insert_or_assign due to EmptyOp occupying output slot in live tensors. (#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT authored Aug 17, 2024
1 parent 298afb5 commit b03e161
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device,
op->out()->desc()->layout()->memory_desc()->memory_space();

switch (targetMemorySpace) {
// This case should only be used when gathering outputs at the end of the
// program
case ::tt::target::MemorySpace::System:
case ::tt::target::MemorySpace::SystemMMIO: {
::ttnn::Tensor result;
Expand Down Expand Up @@ -136,14 +138,14 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device,
result = updateLayoutAndDataType(result, targetDataTypeTTNN, shouldTilize,
false);
tensorPool.push_back(result);
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
} else if (inputTensor.storage_type() ==
::tt::tt_metal::StorageType::DEVICE) {
::ttnn::Tensor result = updateLayoutAndDataType(
inputTensor, targetDataTypeTTNN, false, false);
result = ::ttnn::to_memory_config(result, memConfig, std::nullopt);
tensorPool.push_back(result);
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
}
break;
}
Expand All @@ -163,14 +165,14 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device,
result = updateLayoutAndDataType(result, targetDataTypeTTNN, shouldTilize,
false);
tensorPool.push_back(result);
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
} else if (inputTensor.storage_type() ==
::tt::tt_metal::StorageType::DEVICE) {
::ttnn::Tensor result = updateLayoutAndDataType(
inputTensor, targetDataTypeTTNN, false, false);
result = ::ttnn::to_memory_config(result, memConfig, std::nullopt);
tensorPool.push_back(result);
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
}
break;
}
Expand All @@ -181,7 +183,6 @@ static void
run(::tt::target::ttnn::EmptyOp const *op, ::ttnn::Device &device,
std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
std::list<::ttnn::Tensor> &tensorPool) {

::ttnn::DataType targetDataTypeTTNN = utils::toTTNNDataType(
op->out()->desc()->layout()->memory_desc()->data_type());
// TODO: determine layout, hardcoding tile_layout for now
Expand All @@ -190,6 +191,7 @@ run(::tt::target::ttnn::EmptyOp const *op, ::ttnn::Device &device,
utils::toShapeFromFBShape(*op->out()->desc()->shape())));
tensorPool.push_back(
::ttnn::empty(shape, targetDataTypeTTNN, desiredLayout, device));
// Use try_emplace here so that an existing entry (the program output tensor) doesn't get overwritten
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
}

Expand All @@ -204,39 +206,39 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device,
auto &lhs = *liveTensors.at(op->ins()->Get(0)->global_id());
auto &rhs = *liveTensors.at(op->ins()->Get(1)->global_id());
tensorPool.push_back(::ttnn::add(lhs, rhs));
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::EltwiseOpType::Multiply: {
assert(op->ins()->size() == 2 && "Unsupported number of inputs");
auto &lhs = *liveTensors.at(op->ins()->Get(0)->global_id());
auto &rhs = *liveTensors.at(op->ins()->Get(1)->global_id());
tensorPool.push_back(::ttnn::multiply(lhs, rhs));
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::EltwiseOpType::Subtract: {
assert(op->ins()->size() == 2 && "Unsupported number of inputs");
auto &lhs = *liveTensors.at(op->ins()->Get(0)->global_id());
auto &rhs = *liveTensors.at(op->ins()->Get(1)->global_id());
tensorPool.push_back(::ttnn::subtract(lhs, rhs));
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::EltwiseOpType::GreaterEqual: {
assert(op->ins()->size() == 2 && "Unsupported number of inputs");
::ttnn::Tensor &lhs = *liveTensors.at(op->ins()->Get(0)->global_id());
::ttnn::Tensor &rhs = *liveTensors.at(op->ins()->Get(1)->global_id());
tensorPool.push_back(::ttnn::ge(lhs, rhs));
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
/* Eltwise Unary */
case ::tt::target::ttnn::EltwiseOpType::Relu: {
assert(op->ins()->size() == 1 && "Unsupported number of inputs");
::ttnn::Tensor &in = *liveTensors.at(op->ins()->Get(0)->global_id());
tensorPool.push_back(::ttnn::relu(in));
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
}
Expand All @@ -258,7 +260,7 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device,

tensorPool.push_back(::ttnn::sum(in, dim_arg, op->keep_dim()));

liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::ReductionOpType::Mean: {
Expand All @@ -272,7 +274,7 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device,

tensorPool.push_back(::ttnn::mean(in, dim_arg, op->keep_dim()));

liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
}
Expand All @@ -286,7 +288,7 @@ run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::device::Device &device,
int32_t dimension = op->dimension();

tensorPool.push_back(::ttnn::softmax(in, dimension));
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
}

static void
Expand Down Expand Up @@ -317,7 +319,7 @@ run(::tt::target::ttnn::TransposeOp const *op, ::ttnn::device::Device &device,
// work at the moment, we use this temporary solution.
auto unsqueezed_input = ::ttnn::unsqueeze_to_4D(in);
tensorPool.push_back(::ttnn::permute(unsqueezed_input, dimensionOrder));
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
}

// ANCHOR: adding_an_op_matmul_runtime
Expand All @@ -329,7 +331,7 @@ run(::tt::target::ttnn::MatmulOp const *op, ::ttnn::Device &device,
auto &rhs = *liveTensors.at(op->in1()->global_id());
tensorPool.push_back(::ttnn::operations::matmul::matmul(
lhs, rhs, std::nullopt, ::tt::operations::primary::Matmul{}));
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
}
// ANCHOR_END: adding_an_op_matmul_runtime

Expand Down

0 comments on commit b03e161

Please sign in to comment.