diff --git a/ivy/data_classes/array/experimental/norms.py b/ivy/data_classes/array/experimental/norms.py index 6d8be11765f95..7263a83b205cb 100644 --- a/ivy/data_classes/array/experimental/norms.py +++ b/ivy/data_classes/array/experimental/norms.py @@ -75,8 +75,8 @@ def l2_normalize( def batch_norm( self: Union[ivy.NativeArray, ivy.Array], - mean: Union[ivy.NativeArray, ivy.Array], - variance: Union[ivy.NativeArray, ivy.Array], + mean: Optional[Union[ivy.NativeArray, ivy.Array]], + variance: Optional[Union[ivy.NativeArray, ivy.Array]], /, *, offset: Optional[Union[ivy.NativeArray, ivy.Array]] = None, @@ -145,8 +145,8 @@ def batch_norm( def instance_norm( self: Union[ivy.NativeArray, ivy.Array], - mean: Union[ivy.NativeArray, ivy.Array], - variance: Union[ivy.NativeArray, ivy.Array], + mean: Optional[Union[ivy.NativeArray, ivy.Array]], + variance: Optional[Union[ivy.NativeArray, ivy.Array]], /, *, offset: Optional[Union[ivy.NativeArray, ivy.Array]] = None, diff --git a/ivy/data_classes/container/experimental/norms.py b/ivy/data_classes/container/experimental/norms.py index 72d8e3d7485da..452a753930a0d 100644 --- a/ivy/data_classes/container/experimental/norms.py +++ b/ivy/data_classes/container/experimental/norms.py @@ -246,8 +246,8 @@ def l2_normalize( @staticmethod def static_batch_norm( x: Union[ivy.Array, ivy.NativeArray, ivy.Container], - mean: Union[ivy.NativeArray, ivy.Array, ivy.Container], - variance: Union[ivy.NativeArray, ivy.Array, ivy.Container], + mean: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]], + variance: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]], /, *, offset: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]] = None, @@ -342,8 +342,8 @@ def static_batch_norm( def batch_norm( self: Union[ivy.Array, ivy.NativeArray, ivy.Container], - mean: Union[ivy.NativeArray, ivy.Array, ivy.Container], - variance: Union[ivy.NativeArray, ivy.Array, ivy.Container], + mean: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]], + variance: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]], /, *, offset: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]] = None, @@ -438,8 +438,8 @@ def batch_norm( @staticmethod def static_instance_norm( x: Union[ivy.Array, ivy.NativeArray, ivy.Container], - mean: Union[ivy.NativeArray, ivy.Array, ivy.Container], - variance: Union[ivy.NativeArray, ivy.Array, ivy.Container], + mean: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]], + variance: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]], /, *, offset: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]] = None, @@ -532,8 +532,8 @@ def static_instance_norm( def instance_norm( self: Union[ivy.Array, ivy.NativeArray, ivy.Container], - mean: Union[ivy.NativeArray, ivy.Array, ivy.Container], - variance: Union[ivy.NativeArray, ivy.Array, ivy.Container], + mean: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]], + variance: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]], /, *, offset: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]] = None, diff --git a/ivy/functional/backends/tensorflow/experimental/norms.py b/ivy/functional/backends/tensorflow/experimental/norms.py index ef4c5d7d5e8b6..b3d8deffdfad4 100644 --- a/ivy/functional/backends/tensorflow/experimental/norms.py +++ b/ivy/functional/backends/tensorflow/experimental/norms.py @@ -69,8 +69,8 @@ def local_response_norm( @with_unsupported_dtypes({"2.15.0 and below": ("float16", "bfloat16")}, backend_version) def batch_norm( x: Union[tf.Tensor, tf.Variable], - mean: Union[tf.Tensor, tf.Variable], - variance: Union[tf.Tensor, tf.Variable], + mean: Optional[Union[tf.Tensor, tf.Variable]], + variance: Optional[Union[tf.Tensor, tf.Variable]], /, *, scale: Optional[Union[tf.Tensor, tf.Variable]] = None, @@ -103,9 +103,15 @@ def batch_norm( dims = (0, *range(1, xdims - 1)) mean = tf.math.reduce_mean(x, axis=dims) variance = tf.math.reduce_variance(x, axis=dims) - runningmean = (1 - momentum) * runningmean + momentum * mean - runningvariance = (1 - momentum) * runningvariance + momentum * variance * n / ( - n - 1 + runningmean = ( + ((1 - momentum) * runningmean + momentum * mean) + if runningmean is not None + else runningmean + ) + runningvariance = ( + (1 - momentum) * runningvariance + momentum * variance * n / (n - 1) + if runningvariance is not None + else runningvariance ) inv = 1.0 / tf.math.sqrt(variance + eps) @@ -126,8 +132,8 @@ def batch_norm( def instance_norm( x: Union[tf.Tensor, tf.Variable], - mean: Union[tf.Tensor, tf.Variable], - variance: Union[tf.Tensor, tf.Variable], + mean: Optional[Union[tf.Tensor, tf.Variable]] = None, + variance: Optional[Union[tf.Tensor, tf.Variable]] = None, /, *, scale: Optional[Union[tf.Tensor, tf.Variable]] = None, @@ -161,8 +167,8 @@ def instance_norm( C = x.shape[-1] S = x.shape[0:-2] x = tf.reshape(x, (1, *S, N * C)) - mean = tf.tile(mean, [N]) - variance = tf.tile(variance, [N]) + mean = tf.tile(mean, [N]) if mean is not None else mean + variance = tf.tile(variance, [N]) if variance is not None else variance if scale is not None: scale = tf.tile(scale, [N]) if offset is not None: @@ -187,10 +193,21 @@ def instance_norm( xnormalized, perm=(xdims - 2, *range(0, xdims - 2), xdims - 1) ) + runningmean = ( + tf.reduce_mean(tf.reshape(runningmean, (N, C)), axis=0) + if runningmean is not None + else runningmean + ) + runningvariance = ( + tf.reduce_mean(tf.reshape(runningvariance, (N, C)), axis=0) + if runningvariance is not None + else runningvariance + ) + return ( xnormalized, - tf.reduce_mean(tf.reshape(runningmean, (N, C)), axis=0), - tf.reduce_mean(tf.reshape(runningvariance, (N, C)), axis=0), + runningmean, + runningvariance, ) diff --git a/ivy/functional/backends/torch/experimental/norms.py b/ivy/functional/backends/torch/experimental/norms.py index e17a2c42a8e62..4a1fe219b575b 100644 --- a/ivy/functional/backends/torch/experimental/norms.py +++ b/ivy/functional/backends/torch/experimental/norms.py @@ -59,8 +59,8 @@ def local_response_norm( @with_unsupported_dtypes({"2.2 and below": ("bfloat16", "float16")}, backend_version) def batch_norm( x: torch.Tensor, - mean: torch.Tensor, - variance: torch.Tensor, + mean: Optional[torch.Tensor], + variance: Optional[torch.Tensor], /, *, scale: Optional[torch.Tensor] = None, @@ -74,8 +74,8 @@ def batch_norm( xdims = x.ndim if data_format == "NSC": x = torch.permute(x, dims=(0, xdims - 1, *range(1, xdims - 1))) - runningmean = mean.detach().clone() - runningvariance = variance.detach().clone() + runningmean = mean.detach().clone() if mean is not None else mean + runningvariance = variance.detach().clone() if variance is not None else variance xnormalized = torch.nn.functional.batch_norm( x, runningmean, @@ -94,8 +94,8 @@ def batch_norm( batch_norm.partial_mixed_handler = ( lambda x, mean, variance, scale=None, offset=None, **kwargs: ( x.ndim > 1 - and mean.ndim == 1 - and variance.ndim == 1 + and (mean is None or mean.ndim == 1) + and (variance is None or variance.ndim == 1) and (scale is None or scale.ndim == 1) and (offset is None or offset.ndim == 1) ) @@ -105,8 +105,8 @@ def batch_norm( @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, backend_version) def instance_norm( x: torch.Tensor, - mean: torch.Tensor, - variance: torch.Tensor, + mean: Optional[torch.Tensor] = None, + variance: Optional[torch.Tensor] = None, /, *, scale: Optional[torch.Tensor] = None, @@ -117,8 +117,8 @@ def instance_norm( data_format: Optional[str] = "NSC", out: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - runningmean = mean.clone() - runningvariance = variance.clone() + runningmean = mean.clone() if mean is not None else mean + runningvariance = variance.clone() if variance is not None else variance # reshape from N, *S, C to N, C, *S xdims = x.ndim if data_format == "NSC": @@ -140,10 +140,10 @@ def instance_norm( instance_norm.partial_mixed_handler = ( - lambda x, mean, variance, scale=None, offset=None, **kwargs: ( + lambda x, mean=None, variance=None, scale=None, offset=None, **kwargs: ( x.ndim > 1 - and mean.ndim == 1 - and variance.ndim == 1 + and (mean is None or mean.ndim == 1) + and (variance is None or variance.ndim == 1) and (scale is None or scale.ndim == 1) and (offset is None or offset.ndim == 1) ) diff --git a/ivy/functional/frontends/torch/nn/functional/norms.py b/ivy/functional/frontends/torch/nn/functional/norms.py index 7418d9f6f8f9b..6799701378154 100644 --- a/ivy/functional/frontends/torch/nn/functional/norms.py +++ b/ivy/functional/frontends/torch/nn/functional/norms.py @@ -15,8 +15,8 @@ @to_ivy_arrays_and_back def batch_norm( input, - running_mean, - running_var, + running_mean=None, + running_var=None, weight=None, bias=None, training=False, @@ -35,8 +35,10 @@ def batch_norm( data_format="NCS", ) if training: - ivy.inplace_update(running_mean, mean) - ivy.inplace_update(running_var, var) + if running_mean is not None: + ivy.inplace_update(running_mean, mean) + if running_var is not None: + ivy.inplace_update(running_var, var) return normalized @@ -68,8 +70,8 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): @to_ivy_arrays_and_back def instance_norm( input, - running_mean, - running_var, + running_mean=None, + running_var=None, weight=None, bias=None, use_input_stats=False, @@ -87,8 +89,10 @@ def instance_norm( momentum=momentum, data_format="NCS", ) - ivy.inplace_update(running_mean, mean) - ivy.inplace_update(running_var, var) + if running_mean is not None: + ivy.inplace_update(running_mean, mean) + if running_var is not None: + ivy.inplace_update(running_var, var) return normalized diff --git a/ivy/functional/ivy/experimental/norms.py b/ivy/functional/ivy/experimental/norms.py index 5418c3fb50df8..84e48f38b2090 100644 --- a/ivy/functional/ivy/experimental/norms.py +++ b/ivy/functional/ivy/experimental/norms.py @@ -193,8 +193,8 @@ def local_response_norm( @handle_array_function def batch_norm( x: Union[ivy.NativeArray, ivy.Array], - mean: Union[ivy.NativeArray, ivy.Array], - variance: Union[ivy.NativeArray, ivy.Array], + mean: Optional[Union[ivy.NativeArray, ivy.Array]], + variance: Optional[Union[ivy.NativeArray, ivy.Array]], /, *, offset: Optional[Union[ivy.NativeArray, ivy.Array]] = None, @@ -270,9 +270,15 @@ def batch_norm( dims = (0, *range(1, xdims - 1)) mean = ivy.mean(x, axis=dims) variance = ivy.var(x, axis=dims) - runningmean = (1 - momentum) * runningmean + momentum * mean - runningvariance = (1 - momentum) * runningvariance + momentum * variance * n / ( - n - 1 + runningmean = ( + (1 - momentum) * runningmean + momentum * mean + if runningmean is not None + else runningmean + ) + runningvariance = ( + (1 - momentum) * runningvariance + momentum * variance * n / (n - 1) + if runningvariance is not None + else runningvariance ) inv = 1.0 / ivy.sqrt(variance + eps) offset = 0 if offset is None else offset @@ -313,8 +319,8 @@ def batch_norm( @handle_array_function def instance_norm( x: Union[ivy.NativeArray, ivy.Array], - mean: Union[ivy.NativeArray, ivy.Array], - variance: Union[ivy.NativeArray, ivy.Array], + mean: Optional[Union[ivy.NativeArray, ivy.Array]], + variance: Optional[Union[ivy.NativeArray, ivy.Array]], /, *, offset: Optional[Union[ivy.NativeArray, ivy.Array]] = None, @@ -387,8 +393,8 @@ def instance_norm( C = x.shape[-1] S = x.shape[0:-2] x = x.reshape((1, *S, N * C)) - mean = ivy.tile(mean, N) - variance = ivy.tile(variance, N) + mean = ivy.tile(mean, N) if mean is not None else mean + variance = ivy.tile(variance, N) if variance is not None else variance if scale is not None: scale = ivy.tile(scale, N) if offset is not None: @@ -414,8 +420,16 @@ def instance_norm( xnormalized, axes=(xdims - 2, *range(0, xdims - 2), xdims - 1) ) - runningmean = runningmean.reshape((N, C)).mean(axis=0) - runningvariance = runningvariance.reshape((N, C)).mean(axis=0) + runningmean = ( + runningmean.reshape((N, C)).mean(axis=0) + if runningmean is not None + else runningmean + ) + runningvariance = ( + runningvariance.reshape((N, C)).mean(axis=0) + if runningvariance is not None + else runningvariance + ) if ivy.exists(out): xnormalized = ivy.inplace_update(out[0], xnormalized)