Skip to content

Commit

Permalink
Create scan.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Jan 17, 2024
1 parent f9618ea commit bfd7d84
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions torch_xla/csrc/ops/scan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "torch_xla/csrc/ops/scan.h"

#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace {

xla::Shape NodeOutputShape(const torch::lazy::Value& input) {
xla::Shape input_shape = GetXlaShape(input);
return input_shape;
}

} // namespace

Scan::Scan(const Callable f, const at::Tensor& init, const at::Tensor& xs)
: XlaNode(torch::lazy::OpKind(at::aten::scan), {f, init, xs},
[&]() { return NodeOutputShape(init); }, 2,) {}

torch::lazy::NodePtr Scan::Clone(torch::lazy::OpList operands) const {
return torch::lazy::MakeNode<Scan>(operands.at(0), operands.at(1), operands.at(2));
}

XlaOpVector Map::Lower(LoweringContext* loctx) const {
xla::XlaOp f = loctx->GetOutputOp(operand(0));
xla::XlaOp init = loctx->GetOutputOp(operand(1));
xla::XlaOp xs = loctx->GetOutputOp(operand(2));
return ReturnOps(BuildMap(f, init, xs), loctx);
}

} // namespace torch_xla

0 comments on commit bfd7d84

Please sign in to comment.