Skip to content

Commit

Permalink
TTNN Runtime workaround to support NOP program (input returned as output) #426 (#488)
Browse files Browse the repository at this point in the history

- Discussed a bit offline, other more complicated solutions could
   exist including new ops, but this is simplest for now, and simplifies
   some future runtime refactoring that might occur.

 - Added original testcase generated from XLA that exposed this
   and added CHECK attributes inline
  • Loading branch information
kmabeeTT authored Aug 26, 2024
1 parent bd3e430 commit 2a4287d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
22 changes: 21 additions & 1 deletion runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,25 @@ run(::tt::target::ttnn::Operation const *op, ::ttnn::Device &device,
}
}

// Nop is single input, output tensor where input is returned as output.
//
// Detects the degenerate program whose single output is its single input
// (same flatbuffer global_id) and services it directly by copying the input
// tensor's host data into the output tensor, bypassing op execution.
//
// @param program  Flatbuffer program descriptor (inputs/outputs metadata).
// @param inputs   Host-side input tensors, parallel to program->inputs().
// @param outputs  Host-side output tensors, parallel to program->outputs().
// @return true if the program was a NOP and has been fully handled;
//         false if normal execution should proceed.
bool handleNopProgram(::tt::target::ttnn::Program const *program,
                      std::vector<::ttnn::Tensor *> const &inputs,
                      std::vector<::ttnn::Tensor *> const &outputs) {

  bool is_nop = program->inputs()->size() == 1 &&
                program->outputs()->size() == 1 &&
                program->inputs()->Get(0)->global_id() ==
                    program->outputs()->Get(0)->global_id();

  if (is_nop) {
    void *src = ::tt::tt_metal::get_raw_host_data_ptr(*inputs.at(0));
    void *dst = ::tt::tt_metal::get_raw_host_data_ptr(*outputs.at(0));
    // Skip the copy when the caller aliased input and output to the same
    // buffer: overlapping memcpy is undefined behavior, and there is
    // nothing to do anyway.
    if (src != dst) {
      // Compute the byte count in size_t to avoid 32-bit overflow on
      // large tensors (volume * element_size can exceed uint32_t).
      std::size_t size = static_cast<std::size_t>(outputs.at(0)->volume()) *
                         outputs.at(0)->element_size();
      std::memcpy(dst, src, size);
    }
  }
  return is_nop;
}

void runProgram(::ttnn::Device &device,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs,
Expand All @@ -434,6 +453,7 @@ void runProgram(::ttnn::Device &device,
int inputIndex = 0;
assert(program->inputs()->size() == inputs.size() &&
"Mismatch between program inputs and input tensors");
bool is_nop = handleNopProgram(program, inputs, outputs);
for (::tt::target::TensorRef const *input : *program->inputs()) {
auto [iter, inserted] =
liveTensors.try_emplace(input->global_id(), inputs[inputIndex++]);
Expand All @@ -446,7 +466,7 @@ void runProgram(::ttnn::Device &device,
for (::tt::target::TensorRef const *output : *program->outputs()) {
auto [iter, inserted] =
liveTensors.try_emplace(output->global_id(), outputs[outputIndex++]);
assert(inserted && "Duplicate output tensor");
assert(is_nop || inserted && "Duplicate output tensor");
}

for (::tt::target::ttnn::Operation const *op : *program->operations()) {
Expand Down
11 changes: 11 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_nop.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Regression test for NOP programs (#426): a function that returns its input
// unchanged. The TTNN runtime must handle the case where a program's output
// tensor is its input tensor (same global_id) — see handleNopProgram in
// runtime/lib/ttnn/program.cpp.
// RUN: ttmlir-opt --ttir-load-system-desc="path=%system_desc_path%" --ttir-implicit-device --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn

// Original testcase was generated from XLA; the body is an identity function.
module @jit_convert_element_type attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x2xf32> {mhlo.layout_mode = "default"}) -> (tensor<2x2xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // Even with no compute ops, lowering must still emit the device
    // open/close bracket around the (empty) program body.
    // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
    // CHECK: "ttnn.close_device"[[C:.*]]
    return %arg0 : tensor<2x2xf32>
  }
}

0 comments on commit 2a4287d

Please sign in to comment.