Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: improve docs for Nx.conv/3 #1564

Merged
merged 2 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nx/guides/advanced/aggregation.livemd
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ m = ~MAT[
>
```

First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](<https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the>), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.

```elixir
w = ~MAT[
Expand Down
27 changes: 27 additions & 0 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12919,6 +12919,11 @@ defmodule Nx do
of summing the element-wise products in the window across
each input channel.

> #### Kernel Reflection {: .info}
polvalente marked this conversation as resolved.
Show resolved Hide resolved
>
> See the note at the end of this section for more details
> on the convention for kernel reflection and conjugation.

The ranks of both `input` and `kernel` must match. By
default, both `input` and `kernel` are expected to have shapes
of the following form:
Expand Down Expand Up @@ -13000,6 +13005,28 @@ defmodule Nx do
in the same way as with `:feature_group_size`, however, the input
tensor will be split into groups along the batch dimension.

> #### Convolution vs Correlation {: .tip}
>
> `conv/3` does not perform reversion nor conjugation of the kernel.
> This means that if you come from a Signal Processing background,
> you might call this operation "correlation" instead of convolution.
>
> If you need the proper Signal Processing convolution, you can use
> `reverse/2` and `conjugate/1`, like in the example:
>
> ```elixir
> axes = Nx.axes(kernel) |> Enum.drop(2)
>
> kernel =
> if Nx.Type.complex?(Nx.type(kernel)) do
> Nx.conjugate(Nx.reverse(kernel, axes: axes))
> else
> Nx.reverse(kernel, axes: axes)
> end
>
> Nx.conv(img, kernel)
> ```

## Examples

iex> left = Nx.iota({1, 1, 3, 3})
Expand Down
9 changes: 4 additions & 5 deletions nx/lib/nx/defn/kernel.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1398,8 +1398,8 @@ defmodule Nx.Defn.Kernel do
## Named hooks

It is possible to give names to the hooks. This allows them
to be defined or overridden by calling `Nx.Defn.jit/2` or
`Nx.Defn.stream/2`. Let's see an example:
to be defined or overridden by calling `Nx.Defn.jit/2`.
Let's see an example:

defmodule Hooks do
import Nx.Defn
Expand Down Expand Up @@ -1437,9 +1437,8 @@ defmodule Nx.Defn.Kernel do
{add, mult}
end

If a hook with the same name is given to `Nx.Defn.jit/2`
or `Nx.Defn.stream/2`, then it will override the default
callback.
If a hook with the same name is given to `Nx.Defn.jit/2`,
then it will override the default callback.

## Hooks and tokens

Expand Down
Loading