diff --git a/nx/guides/advanced/aggregation.livemd b/nx/guides/advanced/aggregation.livemd
index 587db22baca..dd1b568104b 100644
--- a/nx/guides/advanced/aggregation.livemd
+++ b/nx/guides/advanced/aggregation.livemd
@@ -93,7 +93,7 @@ m = ~MAT[
 > ```
 
-First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
+First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_(matrices)), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
 
 ```elixir
 w = ~MAT[
diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex
index 99209dc3add..cdd5b36f7b4 100644
--- a/nx/lib/nx.ex
+++ b/nx/lib/nx.ex
@@ -12919,6 +12919,11 @@
   of summing the element-wise products in the window across
   each input channel.
 
+  > #### Kernel Reflection {: .info}
+  >
+  > See the note at the end of this section for more details
+  > on the convention for kernel reflection and conjugation.
+
   The ranks of both `input` and `kernel` must match. By
   default, both `input` and `kernel` are expected to have shapes
   of the following form:
@@ -13000,6 +13005,28 @@
   in the same way as with `:feature_group_size`, however, the input
   tensor will be split into groups along the batch dimension.
 
+  > #### Convolution vs Correlation {: .tip}
+  >
+  > `conv/3` does not reverse or conjugate the kernel. This means
+  > that if you come from a Signal Processing background, you might
+  > call this operation "correlation" instead of convolution.
+  >
+  > If you need the proper Signal Processing convolution, you can use
+  > `reverse/2` and `conjugate/1`, as in the example:
+  >
+  > ```elixir
+  > axes = Nx.axes(kernel) |> Enum.drop(2)
+  >
+  > kernel =
+  >   if Nx.Type.complex?(Nx.type(kernel)) do
+  >     Nx.conjugate(Nx.reverse(kernel, axes: axes))
+  >   else
+  >     Nx.reverse(kernel, axes: axes)
+  >   end
+  >
+  > Nx.conv(img, kernel)
+  > ```
+
   ## Examples
 
       iex> left = Nx.iota({1, 1, 3, 3})
diff --git a/nx/lib/nx/defn/kernel.ex b/nx/lib/nx/defn/kernel.ex
index 9bfb478237d..ab913ab61fb 100644
--- a/nx/lib/nx/defn/kernel.ex
+++ b/nx/lib/nx/defn/kernel.ex
@@ -1398,8 +1398,8 @@ defmodule Nx.Defn.Kernel do
   ## Named hooks
 
   It is possible to give names to the hooks. This allows them
-  to be defined or overridden by calling `Nx.Defn.jit/2` or
-  `Nx.Defn.stream/2`. Let's see an example:
+  to be defined or overridden by calling `Nx.Defn.jit/2`.
+  Let's see an example:
 
       defmodule Hooks do
         import Nx.Defn
@@ -1437,9 +1437,8 @@ defmodule Nx.Defn.Kernel do
         {add, mult}
       end
 
-  If a hook with the same name is given to `Nx.Defn.jit/2`
-  or `Nx.Defn.stream/2`, then it will override the default
-  callback.
+  If a hook with the same name is given to `Nx.Defn.jit/2`,
+  then it will override the default callback.
 
   ## Hooks and tokens
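As a note outside the diff: the weighted mean described in the `aggregation.livemd` hunk (Hadamard product of tensor and weights, sum everything, divide by the sum of the weights) can be sketched with plain Nx calls. The tensor and weight values below are illustrative, not taken from the guide:

```elixir
# Hypothetical tensor and weights, chosen only to illustrate the steps.
t = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
w = Nx.tensor([[1.0, 1.0], [2.0, 2.0]])

weighted_mean =
  t
  |> Nx.multiply(w)        # element-wise (Hadamard) product
  |> Nx.sum()              # sum of all weighted elements
  |> Nx.divide(Nx.sum(w))  # normalize by the total weight
```

Recent Nx versions also expose `Nx.weighted_mean/3`, which computes the same quantity in one call; the explicit pipeline above just mirrors the guide's step-by-step derivation.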