docs: improve docs for Nx.conv/3 (#1564)
Co-authored-by: José Valim <[email protected]>
polvalente and josevalim authored Dec 7, 2024
1 parent ed7a3b1 commit cb7fed4
Showing 3 changed files with 32 additions and 6 deletions.
2 changes: 1 addition & 1 deletion nx/guides/advanced/aggregation.livemd
@@ -93,7 +93,7 @@ m = ~MAT[
>
```

-First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](<https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the>), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
+First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
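
As a minimal sketch of the weighted aggregation described in the changed line above (using a hypothetical 2x2 tensor `m` and matching weights `w`, not the ones defined in the guide):

```elixir
# Element-wise (Hadamard) product of the tensor and the weights,
# summed over all entries and divided by the sum of the weights.
m = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
w = Nx.tensor([[0.5, 1.0], [1.0, 0.5]])

weighted_mean = Nx.divide(Nx.sum(Nx.multiply(m, w)), Nx.sum(w))
```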

```elixir
w = ~MAT[
27 changes: 27 additions & 0 deletions nx/lib/nx.ex
@@ -12919,6 +12919,11 @@ defmodule Nx do
of summing the element-wise products in the window across
each input channel.
> #### Kernel Reflection {: .info}
>
> See the note at the end of this section for more details
> on the convention for kernel reflection and conjugation.
The ranks of both `input` and `kernel` must match. By
default, both `input` and `kernel` are expected to have shapes
of the following form:
@@ -13000,6 +13005,28 @@
in the same way as with `:feature_group_size`, however, the input
tensor will be split into groups along the batch dimension.
> #### Convolution vs Correlation {: .tip}
>
> `conv/3` performs neither reversal nor conjugation of the kernel.
> This means that if you come from a Signal Processing background,
> you might call this operation "correlation" instead of convolution.
>
> If you need the proper Signal Processing convolution, you can use
> `reverse/2` and `conjugate/1`, as in the following example:
>
> ```elixir
> axes = Nx.axes(kernel) |> Enum.drop(2)
>
> kernel =
>   if Nx.Type.complex?(Nx.type(kernel)) do
>     Nx.conjugate(Nx.reverse(kernel, axes: axes))
>   else
>     Nx.reverse(kernel, axes: axes)
>   end
>
> Nx.conv(img, kernel)
> ```
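
For a concrete, real-valued sketch of the difference described in the tip above (the shapes `{1, 1, 5}` and `{1, 1, 3}` are arbitrary, the default stride and valid padding are assumed, and with a real kernel the conjugation step is a no-op):

```elixir
input = Nx.iota({1, 1, 5}, type: :f32)    # [[[0.0, 1.0, 2.0, 3.0, 4.0]]]
kernel = Nx.tensor([[[1.0, 0.0, -1.0]]])

# What conv/3 computes: the kernel is used as-is (cross-correlation),
# so each window yields x[i] - x[i + 2], i.e. -2.0 for this input.
Nx.conv(input, kernel)

# Signal Processing convolution: reverse the kernel's spatial axes first,
# so each window yields x[i + 2] - x[i], i.e. 2.0 for this input.
axes = Nx.axes(kernel) |> Enum.drop(2)
Nx.conv(input, Nx.reverse(kernel, axes: axes))
```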
## Examples
iex> left = Nx.iota({1, 1, 3, 3})
9 changes: 4 additions & 5 deletions nx/lib/nx/defn/kernel.ex
@@ -1398,8 +1398,8 @@ defmodule Nx.Defn.Kernel do
## Named hooks
It is possible to give names to the hooks. This allows them
-to be defined or overridden by calling `Nx.Defn.jit/2` or
-`Nx.Defn.stream/2`. Let's see an example:
+to be defined or overridden by calling `Nx.Defn.jit/2`.
+Let's see an example:
defmodule Hooks do
import Nx.Defn
@@ -1437,9 +1437,8 @@ defmodule Nx.Defn.Kernel do
{add, mult}
end
-If a hook with the same name is given to `Nx.Defn.jit/2`
-or `Nx.Defn.stream/2`, then it will override the default
-callback.
+If a hook with the same name is given to `Nx.Defn.jit/2`,
+then it will override the default callback.
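
As a sketch of such an override (assuming the truncated `Hooks` example above defines a two-arity function `add_and_mult/2` containing a hook named `:tensor_added`; both names are placeholders here):

```elixir
# Override the :tensor_added hook at JIT time via the :hooks option.
fun =
  Nx.Defn.jit(&Hooks.add_and_mult/2,
    hooks: %{tensor_added: fn t -> IO.inspect(t, label: "added") end}
  )

fun.(Nx.tensor([1, 2]), Nx.tensor([3, 4]))
```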
## Hooks and tokens
