Skip to content

Commit

Permalink
Remove nop check (#1622)
Browse files Browse the repository at this point in the history
The new submit API does not need to check for nops as it generically
returns tensors based on global id, regardless of if it's an input
tensor or not.
  • Loading branch information
jnie-TT authored Dec 18, 2024
1 parent cf80a1a commit 68a26fb
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 24 deletions.
24 changes: 0 additions & 24 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,25 +235,6 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) {
}
}

// Nop is single input, output tensor where input is returned as output.
static bool isNopProgram(const ::tt::target::ttnn::Program *program) {
return program->inputs()->size() == 1 && program->outputs()->size() == 1 &&
program->inputs()->Get(0)->global_id() ==
program->outputs()->Get(0)->global_id();
}

static ::ttnn::Tensor
handleNopProgram(::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs) {
const ::ttnn::Tensor &input = *inputs[0];
::ttnn::Tensor output =
::ttnn::zeros(input.get_shape(), input.get_dtype(), input.get_layout());
const void *src = ::tt::tt_metal::get_raw_host_data_ptr(input);
void *dst = ::tt::tt_metal::get_raw_host_data_ptr(output);
std::memcpy(dst, src, input.volume() * input.element_size());
return output;
}

namespace legacy {

static bool handleNopProgram(::tt::target::ttnn::Program const *program,
Expand Down Expand Up @@ -331,11 +312,6 @@ std::vector<Tensor> runProgram(::ttnn::MeshDevice &meshDevice,
::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle);
::tt::target::ttnn::Program const *program =
fbb.programs()->Get(programIndex);
if (isNopProgram(program)) {
Tensor out =
utils::createRuntimeTensorFromTTNN(handleNopProgram(program, inputs));
return {out};
}
std::unordered_map<uint32_t, ::ttnn::Tensor *> liveTensors;
std::vector<uint32_t> programInputs;
int inputIndex = 0;
Expand Down
3 changes: 3 additions & 0 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ void memcpy(Tensor dst, Tensor src) {

void deallocateTensor(Tensor &tensor, bool force) {
::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
if (ttnnTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED) {
return;
}
::ttnn::deallocate(ttnnTensor, force);
}

Expand Down

0 comments on commit 68a26fb

Please sign in to comment.