From b2fdb9a83ebdbbcfe18a916d3dfd1a39eb8e257f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 9 Dec 2024 19:10:58 -0300 Subject: [PATCH] docs: improve Nx.conv docs on convolution vs correlation --- nx/lib/nx.ex | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index cdd5b36f7b..3162efd10c 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -13007,24 +13007,41 @@ defmodule Nx do > #### Convolution vs Correlation {: .tip} > - > `conv/3` does not perform reversion nor conjugation of the kernel. + > `conv/3` does not perform reversion of the kernel. > This means that if you come from a Signal Processing background, - > you might call this operation "correlation" instead of convolution. + > you might treat it as a cross-correlation operation instead of a convolution. > - > If you need the proper Signal Processing convolution, you can use - > `reverse/2` and `conjugate/1`, like in the example: + > This function is not exactly a cross-correlation function, as it does not + > perform conjugation of the kernel, as is done in [scipy.signal.correlate](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.correlate.html). + > This can be remedied via `Nx.conjugate/1`, as seen below: > > ```elixir - > axes = Nx.axes(kernel) |> Enum.drop(2) - > > kernel = > if Nx.Type.complex?(Nx.type(kernel)) do - > Nx.conjugate(Nx.reverse(kernel, axes: axes)) + > Nx.conjugate(kernel) > else - > Nx.reverse(kernel, axes: axes) + > kernel + > end + > + > Nx.conv(tensor, kernel) + > ``` + > + > If you need the proper Signal Processing convolution, such as the one in + > [scipy.signal.convolve](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.convolve.html), + > you can use `reverse/2`, like in the example: + > + > ```elixir + > reversal_axes = + > case Nx.rank(kernel) do + > 0 -> [] + > 1 -> [1] + > 2 -> [0, 1] + > _ -> Enum.drop(Nx.axes(kernel), 2) > end > - > Nx.conv(img, kernel) + > kernel = Nx.reverse(kernel, axes: reversal_axes) + > + > Nx.conv(tensor, kernel) > ``` ## Examples