From 2d73a5fd1900ed14cda015f08e77a9e29f205cef Mon Sep 17 00:00:00 2001
From: Matthias Guenther
Date: Fri, 18 Oct 2024 09:27:27 -0700
Subject: [PATCH] Fix support for `torch.var_mean`. (#8275)

Co-authored-by: mrguenther
---
 experimental/torch_xla2/test/test_ops.py        |  1 -
 experimental/torch_xla2/torch_xla2/ops/jaten.py | 16 +++++++++++-----
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index c8881304bb1..aabd2a11352 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -92,7 +92,6 @@
     "unfold_copy",
     "unfold",
     "unravel_index",
-    "var_mean",
    "nanmean",
     "nn.functional.upsample_bilinear",
     "randint",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index dbf04b872ff..659901b74dc 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -3027,11 +3027,17 @@ def _aten_to_dtype_layout(
 
 # Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False
 @op(torch.ops.aten.var_mean.correction)
-def _aten_var_mean_correction(self, dim=None, correction=None, keepdim=False):
-  return (
-      jnp.var(self, axis=dim, ddof=correction, keepdims=keepdim),
-      jnp.mean(self, dim, keepdims=keepdim),
-  )
+def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False):
+  # The internal API technically has a default `correction` argument of `None`,
+  # but the public API has a default argument of 1. Therefore, we simply set our
+  # default argument to 1. However, since the argument is officially supposed to
+  # be nullable, we still need to check for `None` per the API contract.
+  if correction is None:
+    correction = 1
+  mean = jnp.mean(tensor, axis=dim, keepdims=keepdim)
+  # TODO: Pass in the `mean=mean` argument once `jax.numpy.var` supports it.
+  var = jnp.var(tensor, axis=dim, ddof=correction, keepdims=keepdim)
+  return var, mean
 
 
 @op(torch.ops.aten.scalar_tensor)