Skip to content

Commit

Permalink
Move maxpool to the original place
Browse files Browse the repository at this point in the history
  • Loading branch information
alanwaketan committed Feb 2, 2024
1 parent bb075e3 commit cef020f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
13 changes: 13 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,5 +253,18 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
return grad;
}

TORCH_LIBRARY_FRAGMENT(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
}

} // namespace aten_autograd_ops
} // namespace torch_xla
14 changes: 0 additions & 14 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,6 @@ namespace torch_xla {
// to define a library of operators in the namespace. Used to define a new set
// of custom operators that do not already exist in PyTorch.
TORCH_LIBRARY_FRAGMENT(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
m.def(
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
"tile_assignment, int[][] group_assignment, int[][] replication_groups, "
Expand Down

0 comments on commit cef020f

Please sign in to comment.