Skip to content

Commit

Permalink
Deprecate ~V and ~M in favor of ~VEC and ~MAT
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed May 13, 2024
1 parent 7c36e06 commit 510e689
Show file tree
Hide file tree
Showing 13 changed files with 319 additions and 297 deletions.
4 changes: 2 additions & 2 deletions exla/test/exla/backend_test.exs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defmodule EXLA.BackendTest do
use EXLA.Case, async: true

import Nx, only: [sigil_V: 2]
import Nx, only: [sigil_VEC: 2]

setup do
Nx.default_backend(EXLA.Backend)
Expand Down Expand Up @@ -192,7 +192,7 @@ defmodule EXLA.BackendTest do
end

test "conjugate" do
assert inspect(Nx.conjugate(~V[1 2-0i 3+0i 0-i 0-2i])) =~
assert inspect(Nx.conjugate(~VEC[1 2-0i 3+0i 0-i 0-2i])) =~
"1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i"
end
end
36 changes: 18 additions & 18 deletions exla/test/exla/defn/expr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -815,17 +815,17 @@ defmodule EXLA.Defn.ExprTest do
test "fft" do
assert_all_close(
fft(Nx.tensor([1, 1, 0, 0]), length: 5),
~V[2.0+0.0i 1.3090-0.9511i 0.1909-0.5877i 0.1909+0.5877i 1.3090+0.9510i]
~VEC[2.0+0.0i 1.3090-0.9511i 0.1909-0.5877i 0.1909+0.5877i 1.3090+0.9510i]
)

assert_all_close(
fft(Nx.tensor([1, 1, 0, 0, 2, 3]), length: 4),
~V[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i]
~VEC[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i]
)

assert_all_close(
fft(Nx.tensor([1, 1, 0]), length: :power_of_two),
~V[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i]
~VEC[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i]
)
end

Expand All @@ -847,12 +847,12 @@ defmodule EXLA.Defn.ExprTest do
length: :power_of_two
),
Nx.stack([
~M[
~MAT[
2 1.0-1.0i 0 1.0+1.0i
1 1 1 1
1 -1i -1 1i
],
~M[
~MAT[
1 -1i -1 1i
1 1 1 1
2 1.0-1.0i 0 1.0+1.0i
Expand All @@ -877,12 +877,12 @@ defmodule EXLA.Defn.ExprTest do
length: 4
),
Nx.stack([
~M[
~MAT[
2 1.0-1.0i 0 1.0+1.0i
1 1 1 1
1 -1i -1 1i
],
~M[
~MAT[
1 -1i -1 1i
1 1 1 1
2 1.0-1.0i 0 1.0+1.0i
Expand All @@ -907,12 +907,12 @@ defmodule EXLA.Defn.ExprTest do
length: 4
),
Nx.stack([
~M[
~MAT[
2 1.0-1.0i 0 1.0+1.0i
1 1 1 1
1 -1i -1 1i
],
~M[
~MAT[
1 -1i -1 1i
1 1 1 1
2 1.0-1.0i 0 1.0+1.0i
Expand All @@ -923,19 +923,19 @@ defmodule EXLA.Defn.ExprTest do

test "ifft" do
assert_all_close(
ifft(~V[5 5 5 5 5],
ifft(~VEC[5 5 5 5 5],
length: 5
),
Nx.tensor([5, 0, 0, 0, 0])
)

assert_all_close(
ifft(~V[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i 5 6], length: 4),
ifft(~VEC[2.0+0.0i 1.0-1.0i 0.0+0.0i 1.0+1.0i 5 6], length: 4),
Nx.tensor([1, 1, 0, 0])
)

assert_all_close(
ifft(~V[2 0 0], length: :power_of_two),
ifft(~VEC[2 0 0], length: :power_of_two),
Nx.tensor([0.5, 0.5, 0.5, 0.5])
)
end
Expand All @@ -944,12 +944,12 @@ defmodule EXLA.Defn.ExprTest do
assert_all_close(
ifft(
Nx.stack([
~M[
~MAT[
2 1.0-1.0i 0 1.0+1.0i
1 1 1 1
1 -1i -1 1i
],
~M[
~MAT[
1 -1i -1 1i
1 1 1 1
2 1.0-1.0i 0 1.0+1.0i
Expand Down Expand Up @@ -988,12 +988,12 @@ defmodule EXLA.Defn.ExprTest do
length: 4
),
Nx.stack([
~M[
~MAT[
2 1.0+1.0i 0 1.0-1.0i
1 1 1 1
1 1i -1 -1i
],
~M[
~MAT[
1 1i -1 -1i
1 1 1 1
2 1.0+1.0i 0 1.0-1.0i
Expand All @@ -1018,12 +1018,12 @@ defmodule EXLA.Defn.ExprTest do
length: 4
),
Nx.stack([
~M[
~MAT[
2 1.0+1.0i 0 1.0-1.0i
1 1 1 1
1 1i -1 -1i
],
~M[
~MAT[
1 1i -1 -1i
1 1 1 1
2 1.0+1.0i 0 1.0-1.0i
Expand Down
86 changes: 43 additions & 43 deletions exla/test/exla/defn/vectorize_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ defmodule EXLA.Defn.VectorizeTest do

test "simple if" do
# this tests the case where we have a single vectorized predicate
pred = Nx.vectorize(~V[0 1 0], :pred)
pred = Nx.vectorize(~VEC[0 1 0], :pred)

assert_equal(vectorized_if(pred, 1, 2, pid: self()), Nx.vectorize(~V[2 1 2], :pred))
assert_equal(vectorized_if(pred, 1, 2, pid: self()), Nx.vectorize(~VEC[2 1 2], :pred))

assert_received {:vectorization_test, t, clause: "if"}
assert_equal(t, Nx.tensor(1))
Expand All @@ -195,12 +195,12 @@ defmodule EXLA.Defn.VectorizeTest do

test "simple cond" do
# this tests the case where we have a two vectorized predicates
pred1 = Nx.vectorize(~V[1 0 0], :pred)
pred2 = Nx.vectorize(~V[0 0 0], :pred)
pred1 = Nx.vectorize(~VEC[1 0 0], :pred)
pred2 = Nx.vectorize(~VEC[0 0 0], :pred)

assert_equal(
vectorized_cond(pred1, 1, pred2, 2, 3, pid: self()),
Nx.vectorize(~V[1 3 3], :pred)
Nx.vectorize(~VEC[1 3 3], :pred)
)

assert_received {:vectorization_test, t, clause: "clause_1"}
Expand All @@ -211,20 +211,20 @@ defmodule EXLA.Defn.VectorizeTest do
end

test "if with container result" do
pred1 = Nx.vectorize(~V[2 0 0], :pred)
pred1 = Nx.vectorize(~VEC[2 0 0], :pred)

result =
vectorized_if(
pred1,
{1, 2, 3},
{7, 8, Nx.vectorize(~V[9 10 11], :x)},
{7, 8, Nx.vectorize(~VEC[9 10 11], :x)},
pid: self()
)

assert_equal(result, {
Nx.vectorize(~V[1 7 7], :pred),
Nx.vectorize(~V[2 8 8], :pred),
Nx.vectorize(~M[
Nx.vectorize(~VEC[1 7 7], :pred),
Nx.vectorize(~VEC[2 8 8], :pred),
Nx.vectorize(~MAT[
3 3 3
9 10 11
9 10 11
Expand All @@ -248,8 +248,8 @@ defmodule EXLA.Defn.VectorizeTest do
end

test "only executes selected branches" do
t = Nx.vectorize(~V[1], :pred)
f = Nx.vectorize(~V[0], :pred)
t = Nx.vectorize(~VEC[1], :pred)
f = Nx.vectorize(~VEC[0], :pred)

assert = fn res, val, clause ->
t = Nx.tensor(val)
Expand All @@ -267,74 +267,74 @@ defmodule EXLA.Defn.VectorizeTest do

test "1 vectorized pred in the beginning" do
assert_equal(
cond4(Nx.vectorize(~V[0 1], :pred), 10, 0, 20, 0, 30, 40),
Nx.vectorize(~V[40 10], :pred)
cond4(Nx.vectorize(~VEC[0 1], :pred), 10, 0, 20, 0, 30, 40),
Nx.vectorize(~VEC[40 10], :pred)
)

assert_equal(
cond4(Nx.vectorize(~V[0 0], :pred), 10, 1, 20, 0, 30, 40),
Nx.vectorize(~V[20 20], :pred)
cond4(Nx.vectorize(~VEC[0 0], :pred), 10, 1, 20, 0, 30, 40),
Nx.vectorize(~VEC[20 20], :pred)
)

assert_equal(
cond4(Nx.vectorize(~V[0 0], :pred), 10, 0, 20, 1, 30, 40),
Nx.vectorize(~V[30 30], :pred)
cond4(Nx.vectorize(~VEC[0 0], :pred), 10, 0, 20, 1, 30, 40),
Nx.vectorize(~VEC[30 30], :pred)
)

assert_equal(
cond4(Nx.vectorize(~V[0 0], :pred), 10, 0, 20, 0, 30, 40),
Nx.vectorize(~V[40 40], :pred)
cond4(Nx.vectorize(~VEC[0 0], :pred), 10, 0, 20, 0, 30, 40),
Nx.vectorize(~VEC[40 40], :pred)
)
end

test "1 vectorized pred in the second but not last position" do
assert_equal(
cond4(0, 10, Nx.vectorize(~V[0 1], :pred), 20, 0, 30, 40),
Nx.vectorize(~V[40 20], :pred)
cond4(0, 10, Nx.vectorize(~VEC[0 1], :pred), 20, 0, 30, 40),
Nx.vectorize(~VEC[40 20], :pred)
)

assert_equal(
cond4(1, 10, Nx.vectorize(~V[0 1], :pred), 20, 0, 30, 40),
Nx.vectorize(~V[10 10], :pred)
cond4(1, 10, Nx.vectorize(~VEC[0 1], :pred), 20, 0, 30, 40),
Nx.vectorize(~VEC[10 10], :pred)
)

assert_equal(
cond4(0, 10, Nx.vectorize(~V[0 0], :pred), 20, 1, 30, 40),
Nx.vectorize(~V[30 30], :pred)
cond4(0, 10, Nx.vectorize(~VEC[0 0], :pred), 20, 1, 30, 40),
Nx.vectorize(~VEC[30 30], :pred)
)

assert_equal(
cond4(0, 10, Nx.vectorize(~V[0 0], :pred), 20, 0, 30, 40),
Nx.vectorize(~V[40 40], :pred)
cond4(0, 10, Nx.vectorize(~VEC[0 0], :pred), 20, 0, 30, 40),
Nx.vectorize(~VEC[40 40], :pred)
)
end

test "1 vectorized pred in the last position" do
assert_equal(
cond4(0, 10, 0, 20, Nx.vectorize(~V[0 1], :pred), 30, 40),
Nx.vectorize(~V[40 30], :pred)
cond4(0, 10, 0, 20, Nx.vectorize(~VEC[0 1], :pred), 30, 40),
Nx.vectorize(~VEC[40 30], :pred)
)

assert_equal(
cond4(1, 10, 0, 20, Nx.vectorize(~V[0 1], :pred), 30, 40),
Nx.vectorize(~V[10 10], :pred)
cond4(1, 10, 0, 20, Nx.vectorize(~VEC[0 1], :pred), 30, 40),
Nx.vectorize(~VEC[10 10], :pred)
)

assert_equal(
cond4(0, 10, 1, 20, Nx.vectorize(~V[0 1], :pred), 30, 40),
Nx.vectorize(~V[20 20], :pred)
cond4(0, 10, 1, 20, Nx.vectorize(~VEC[0 1], :pred), 30, 40),
Nx.vectorize(~VEC[20 20], :pred)
)

assert_equal(
cond4(0, 10, 0, 20, Nx.vectorize(~V[0 0], :pred), 30, 40),
Nx.vectorize(~V[40 40], :pred)
cond4(0, 10, 0, 20, Nx.vectorize(~VEC[0 0], :pred), 30, 40),
Nx.vectorize(~VEC[40 40], :pred)
)
end

test "2 vectorized preds with different axes" do
assert_equal(
cond4(Nx.vectorize(~V[0 1 0], :pred1), 10, Nx.vectorize(~V[1 0], :pred2), 20, 0, 30, 40),
Nx.vectorize(~M[
cond4(Nx.vectorize(~VEC[0 1 0], :pred1), 10, Nx.vectorize(~VEC[1 0], :pred2), 20, 0, 30, 40),
Nx.vectorize(~MAT[
20 40
10 10
20 40
Expand All @@ -345,15 +345,15 @@ defmodule EXLA.Defn.VectorizeTest do
test "2 vectorized preds with different axes + clauses that match either" do
assert_equal(
cond4(
Nx.vectorize(~V[0 1 0], :pred1),
Nx.vectorize(~V[10 100], :pred2),
Nx.vectorize(~V[1 0], :pred2),
Nx.vectorize(~V[20 200 2000], :pred1),
Nx.vectorize(~VEC[0 1 0], :pred1),
Nx.vectorize(~VEC[10 100], :pred2),
Nx.vectorize(~VEC[1 0], :pred2),
Nx.vectorize(~VEC[20 200 2000], :pred1),
0,
30,
40
),
Nx.vectorize(~M[
Nx.vectorize(~MAT[
20 40
10 100
2000 40
Expand Down
10 changes: 5 additions & 5 deletions nx/guides/advanced/aggregation.livemd
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ max_y = Nx.reduce_max(m, axes: [:y])
Let's consider another example with [Nx.weighted_mean](https://hexdocs.pm/nx/Nx.html#weighted_mean/3). It supports full-tensor and per axis operations. We display how to compute the _weighted mean aggregate_ of a matrix with the example below of a 2D tensor of shape `{2,2}` labeled `m`:

```elixir
m = ~M[
m = ~MAT[
1 2
3 4
]
Expand All @@ -96,7 +96,7 @@ m = ~M[
First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](<https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the>), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.

```elixir
w = ~M[
w = ~MAT[
10 20
30 40
]
Expand All @@ -121,7 +121,7 @@ man_w_avg = (1 * 10 + 2 * 20 + 3 * 30 + 4 * 40) / (10 + 20 + 30 + 40)
The weighted mean can be computed _per axis_. Let's compute it along the _first_ axis (`axes: [0]`): you calculate "by column", so you aggregate/reduce along the first axis:

```elixir
w = ~M[
w = ~MAT[
10 20
30 40
]
Expand All @@ -148,7 +148,7 @@ man_w_avg_x = [(1 * 10 + 3 * 30) / (10 + 30), (2 * 20 + 4 * 40) / (20 + 40)]
We calculate weighted mean of a square matrix along the _second_ axis (`axes: [1]`): you calculate per row, so you aggregate/reduce along the second axis.

```elixir
w = ~M[
w = ~MAT[
10 20
30 40
]
Expand Down Expand Up @@ -816,7 +816,7 @@ Nx.argmin(t, axis: 3)
You have the `:tie_break` option to decide how to operate with you have several occurences of the result. It defaults to `tie_break: :low`.

```elixir
t4 = ~V[2 0 0 0 1]
t4 = ~VEC[2 0 0 0 1]

%{
argmin_with_default: Nx.argmin(t4) |> Nx.to_number(),
Expand Down
Loading

0 comments on commit 510e689

Please sign in to comment.