diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp
index 81cfdfb4f42..40b9790a121 100644
--- a/torch_xla/csrc/aten_autograd_ops.cpp
+++ b/torch_xla/csrc/aten_autograd_ops.cpp
@@ -253,5 +253,18 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
   return grad;
 }
 
+TORCH_LIBRARY_FRAGMENT(xla, m) {
+  m.def(
+      "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
+      "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
+      torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));
+
+  m.def(
+      "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
+      "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
+      "-> Tensor",
+      torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
+}
+
 } // namespace aten_autograd_ops
 } // namespace torch_xla
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index e07cf45e224..c0481e4bca2 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -37,20 +37,6 @@ namespace torch_xla {
 // to define a library of operators in the namespace. Used to define a new set
 // of custom operators that do not already exist in PyTorch.
 TORCH_LIBRARY_FRAGMENT(xla, m) {
-  m.def(
-      "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
-      "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
-      torch::dispatch(
-          c10::DispatchKey::XLA,
-          TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));
-
-  m.def(
-      "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
-      "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
-      "-> Tensor",
-      torch::dispatch(
-          c10::DispatchKey::XLA,
-          TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
   m.def(
       "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
       "tile_assignment, int[][] group_assignment, int[][] replication_groups, "
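
Note: since this change only relocates the TORCH_LIBRARY_FRAGMENT registrations (and drops the now-redundant namespace qualification inside TORCH_FN), the ops stay registered under the `xla` namespace and should remain reachable from Python via `torch.ops.xla`. A minimal call-site sketch, assuming a standard torch_xla install with a configured XLA device; the tensor shape and pooling parameters here are illustrative only:

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()  # requires an available XLA backend (CPU/TPU/GPU)
    x = torch.randn(1, 3, 8, 8, device=device)

    # Positional arguments follow the registered schema:
    # (self, kernel_size, stride, padding, dilation, ceil_mode)
    out = torch.ops.xla.max_pool2d_forward(x, [2, 2], [2, 2], [0, 0], [1, 1], False)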