From afa9ab031d6af8d59ffef8898c728964bc215e69 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 6 Dec 2024 22:35:57 -0800 Subject: [PATCH 1/2] docs: improve docs for Nx.conv/3 --- nx/guides/advanced/aggregation.livemd | 2 +- nx/lib/nx.ex | 27 +++++++++++++++++++++++++++ nx/lib/nx/defn/kernel.ex | 9 ++++----- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/nx/guides/advanced/aggregation.livemd b/nx/guides/advanced/aggregation.livemd index 587db22baca..dd1b568104b 100644 --- a/nx/guides/advanced/aggregation.livemd +++ b/nx/guides/advanced/aggregation.livemd @@ -93,7 +93,7 @@ m = ~MAT[ > ``` -First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights. +First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights. ```elixir w = ~MAT[ diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 99209dc3add..819ade9c67b 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -12919,6 +12919,10 @@ defmodule Nx do of summing the element-wise products in the window across each input channel. + > #### Kernel Reflection {: .info} + > See the note at the end of this section for more details + > on the convention for kernel reflection and conjugation. + The ranks of both `input` and `kernel` must match. By default, both `input` and `kernel` are expected to have shapes of the following form: @@ -13000,6 +13004,29 @@ defmodule Nx do in the same way as with `:feature_group_size`, however, the input tensor will be split into groups along the batch dimension. + + > #### Convolution vs Correlation {: .tip} + > + > `conv/3` does not perform reversion nor conjugation of the kernel. + > This means that if you come from a Signal Processing background, + > you might call this operation "correlation" instead of convolution. + > + > If you need the proper Signal Processing convolution, you can use + > `reverse/2` and `conjugate/1`, like in the example: + > + > ```elixir + > axes = Nx.axes(kernel) |> Enum.drop(2) + > + > kernel = + > if Nx.Type.complex?(Nx.type(kernel)) do + > Nx.conjugate(Nx.reverse(kernel, axes: axes)) + > else + > Nx.reverse(kernel, axes: axes) + > end + > + > Nx.conv(img, kernel) + > ``` + ## Examples iex> left = Nx.iota({1, 1, 3, 3}) diff --git a/nx/lib/nx/defn/kernel.ex b/nx/lib/nx/defn/kernel.ex index 9bfb478237d..ab913ab61fb 100644 --- a/nx/lib/nx/defn/kernel.ex +++ b/nx/lib/nx/defn/kernel.ex @@ -1398,8 +1398,8 @@ defmodule Nx.Defn.Kernel do ## Named hooks It is possible to give names to the hooks. This allows them - to be defined or overridden by calling `Nx.Defn.jit/2` or - `Nx.Defn.stream/2`. Let's see an example: + to be defined or overridden by calling `Nx.Defn.jit/2`. + Let's see an example: defmodule Hooks do import Nx.Defn @@ -1437,9 +1437,8 @@ defmodule Nx.Defn.Kernel do {add, mult} end - If a hook with the same name is given to `Nx.Defn.jit/2` - or `Nx.Defn.stream/2`, then it will override the default - callback. + If a hook with the same name is given to `Nx.Defn.jit/2`, + then it will override the default callback. ## Hooks and tokens From 829fe310e4f87dcadd5da249d73d3fe2637406cc Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 7 Dec 2024 19:31:27 -0300 Subject: [PATCH 2/2] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- nx/lib/nx.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 819ade9c67b..cdd5b36f7b4 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -12920,6 +12920,7 @@ defmodule Nx do each input channel. > #### Kernel Reflection {: .info} + > > See the note at the end of this section for more details > on the convention for kernel reflection and conjugation. @@ -13004,7 +13005,6 @@ defmodule Nx do in the same way as with `:feature_group_size`, however, the input tensor will be split into groups along the batch dimension. - > #### Convolution vs Correlation {: .tip} > > `conv/3` does not perform reversion nor conjugation of the kernel.