diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 5f07f09086d..e0acf954d88 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1957,10 +1957,10 @@ def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, trai
   Args:
     input: Input data (N, C, ...)
+    weight: Scaling factor (gamma) (C,), can be None
+    bias: Shift factor (beta) (C,), can be None
     running_mean: Running mean of input (C,)
     running_var: Running variance of input (C,)
-    weight: Optional scaling factor (gamma) (C,)
-    bias: Optional shift factor (beta) (C,)
     training: Whether to perform training-time or inference-time batch norm
     momentum: Momentum factor for updating running mean and variance
     eps: Small constant added to the variance to avoid division by zero
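
Context for the reorder: the docstring now lists the arguments in the same order as the function signature shown in the hunk header (input, weight, bias, running_mean, running_var, training, ...), with weight and bias documented as possibly None. The sketch below is not the jaten.py implementation; it is a minimal jax.numpy illustration of the inference path for an (N, C, ...) input with None-able affine parameters, and the function name batch_norm_inference_sketch is hypothetical.

    import jax.numpy as jnp

    def batch_norm_inference_sketch(input, weight, bias, running_mean,
                                    running_var, eps=1e-5):
      """Illustrative inference-mode batch norm; weight/bias may be None."""
      # Reshape the (C,) statistics so they broadcast against (N, C, ...).
      shape = (1, -1) + (1,) * (input.ndim - 2)
      mean = running_mean.reshape(shape)
      var = running_var.reshape(shape)
      out = (input - mean) / jnp.sqrt(var + eps)
      if weight is not None:   # optional scale (gamma)
        out = out * weight.reshape(shape)
      if bias is not None:     # optional shift (beta)
        out = out + bias.reshape(shape)
      return out

    # Example: 4 samples, 3 channels, 8x8 spatial dims, no affine parameters.
    x = jnp.ones((4, 3, 8, 8))
    y = batch_norm_inference_sketch(x, None, None, jnp.zeros(3), jnp.ones(3))

The training path (batch statistics plus momentum-based running-stat updates) is omitted here; the point is only how the reordered, None-able weight and bias arguments are consumed.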