diff --git a/ivy/functional/frontends/torch/nn/functional/convolution_functions.py b/ivy/functional/frontends/torch/nn/functional/convolution_functions.py index 89b4bcf3c7f67..12707287a3f64 100644 --- a/ivy/functional/frontends/torch/nn/functional/convolution_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/convolution_functions.py @@ -46,9 +46,11 @@ def _conv_transpose( output_padding=0, groups=1, dilation=1, + filter_format="channel_first", ): dims = len(input.shape) - 2 - weight = ivy.permute_dims(weight, axes=(*range(2, dims + 2), 0, 1)) + if filter_format == "channel_first": + weight = ivy.permute_dims(weight, axes=(*range(2, dims + 2), 0, 1)) for i in range(dims): weight = ivy.flip(weight, axis=i) padding, output_padding, stride, dilation = map( @@ -185,6 +187,7 @@ def conv_transpose1d( output_padding=output_padding, groups=groups, dilation=dilation, + filter_format="channel_first", )