diff --git a/lib/axon/metrics.ex b/lib/axon/metrics.ex index 31962d51..f18502db 100644 --- a/lib/axon/metrics.ex +++ b/lib/axon/metrics.ex @@ -166,7 +166,7 @@ defmodule Axon.Metrics do iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2]) iex> Axon.Metrics.true_positives(y_true, y_pred) #Nx.Tensor< - u64 + u32 1 > """ @@ -198,7 +198,7 @@ defmodule Axon.Metrics do iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2]) iex> Axon.Metrics.false_negatives(y_true, y_pred) #Nx.Tensor< - u64 + u32 3 > """ @@ -230,7 +230,7 @@ defmodule Axon.Metrics do iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2]) iex> Axon.Metrics.true_negatives(y_true, y_pred) #Nx.Tensor< - u64 + u32 1 > """ @@ -262,7 +262,7 @@ defmodule Axon.Metrics do iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2]) iex> Axon.Metrics.false_positives(y_true, y_pred) #Nx.Tensor< - u64 + u32 2 > """ diff --git a/mix.exs b/mix.exs index 6a72dfb7..80e68544 100644 --- a/mix.exs +++ b/mix.exs @@ -2,7 +2,7 @@ defmodule Axon.MixProject do use Mix.Project @source_url "https://github.com/elixir-nx/axon" - @version "0.6.1" + @version "0.7.0" def project do [ @@ -35,9 +35,9 @@ defmodule Axon.MixProject do # Run "mix help deps" to learn about dependencies. defp deps do [ - {:exla, "~> 0.7.0", [only: :test] ++ exla_opts()}, - {:torchx, "~> 0.7.0", [only: :test] ++ torchx_opts()}, - {:nx, "~> 0.6.0 or ~> 0.7.0", nx_opts()}, + {:nx, "~> 0.9", nx_opts()}, + {:exla, "~> 0.9", [only: :test] ++ exla_opts()}, + {:torchx, "~> 0.9", [only: :test] ++ torchx_opts()}, {:ex_doc, "~> 0.23", only: :docs}, {:table_rex, "~> 3.1.1", optional: true}, {:kino, "~> 0.7", optional: true}, diff --git a/mix.lock b/mix.lock index a7ea4911..a83ceb41 100644 --- a/mix.lock +++ b/mix.lock @@ -1,22 +1,23 @@ %{ "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, - "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, - "elixir_make": {:hex, :elixir_make, "0.7.8", "505026f266552ee5aabca0b9f9c229cbb496c689537c9f922f3eb5431157efc7", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "7a71945b913d37ea89b06966e1342c85cfe549b15e6d6d081e8081c493062c07"}, - "ex_doc": {:hex, :ex_doc, "0.31.1", "8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"}, - "exla": {:hex, :exla, "0.7.0", "27fac40a580f0d3816fe3bf35c50dfc2f99597d26ac7e2aca4a3c62b89bb427f", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "d3bfc622deb52cec95efc9d76063891afc7cd33e38eddbb01f3385c53e043c40"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, + "elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"}, + "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, + "exla": {:hex, :exla, "0.9.0", "e048c7a3d33917c214774a7ea1a0c626eb9de01e3fb2423cf9e2b89ef6dada3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.9.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.8.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "cbd30b54992d0da01a5aaee361a3160fc29de05a9f6c3dbcbd1fa04b4aa72302"}, "fss": {:hex, :fss, "0.1.1", "9db2344dbbb5d555ce442ac7c2f82dd975b605b50d169314a20f08ed21e08642", [:mix], [], "hexpm", "78ad5955c7919c3764065b21144913df7515d52e228c09427a004afe9c1a16b0"}, - "kino": {:hex, :kino, "0.12.3", "a5f48a243c60a7ac18ba23869f697b1c775fc7794e8cd55dd248ba33c6fe9445", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "a6dfa3d54ba0edec9ca6e5940154916b381901001f171c85a2d8c67869dbc2d8"}, - "kino_vega_lite": {:hex, :kino_vega_lite, "0.1.11", "d3c2a00b3685b95f91833920d06cc9b1fd7fb293a2663d89affe9aaec16a5b77", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.8", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "5ccd9148ce7cfcc95a137e12596cd8b95b371e9ea107e745bc262c39c5d8d48e"}, - "makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"}, - "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, - "makeup_erlang": {:hex, :makeup_erlang, "0.1.4", "29563475afa9b8a2add1b7a9c8fb68d06ca7737648f28398e04461f008b69521", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f4ed47ecda66de70dd817698a703f8816daa91272e7e45812469498614ae8b29"}, + "kino": {:hex, :kino, "0.14.1", "c499afb1cd0be462feaf0a75c0631aa65aacc545b1c10f431b439b74f104be22", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "090aea1aaa267e42e5ac24ee6bc5ed515aecc0a9edb8619aa4ee839201e704aa"}, + "kino_vega_lite": {:hex, :kino_vega_lite, "0.1.13", "03c00405987a2202e4b8014ee55eb7f5727691b3f13d76a3764f6eeccef45322", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.8", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "00c72bc270e7b9d3c339f726cdab0012fd3f2fc75e36c7548e0f250fe420fa10"}, + "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, + "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, + "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, - "nx": {:hex, :nx, "0.7.0", "cec684cada356e9d268af01daa758882f7372aa952716dbe0369c657abb9e762", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "68edaa48a5841495ecab0dd4cf7b11b2fc0ad809754ae7f82d9c4090b91acf55"}, + "nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"}, + "nx": {:hex, :nx, "0.9.0", "03a622a27d93eaaa2d24ff9b812d9f675cc04eb0340ca3dd065674f3642867d3", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3810a5a90db0654b6e538430c0fb473a22bfc11b3d02ea7834db493cf3f56153"}, "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, "table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"}, "table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"}, - "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, - "torchx": {:hex, :torchx, "0.7.0", "c71fd603b0133ed8709450d82aa3434cbcf485a37c9a68e9ebcce86f5e4fb7f0", [:mix], [{:nx, "~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "a324079c56bb67750b1da16f859d994982bb467020a8c2cba324639552f3adb8"}, - "vega_lite": {:hex, :vega_lite, "0.1.8", "7f6119126ecaf4bc2c1854084370d7091424f5cce4795fbac044eee9963f0752", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "6c8a9271f850612dd8a90de8d1ebd433590ed07ffef76fc2397c240dc04d3fdc"}, - "xla": {:hex, :xla, "0.6.0", "67bb7695efa4a23b06211dc212de6a72af1ad5a9e17325e05e0a87e4c241feb8", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "dd074daf942312c6da87c7ed61b62fb1a075bced157f1cc4d47af2d7c9f44fb7"}, + "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, + "torchx": {:hex, :torchx, "0.9.0", "936cbd32233f89d73700c39b7ef56f94b3f3541db03c90f8ddf6b3fe73260e28", [:mix], [{:nx, "~> 0.9.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "4e057d6b93fc91191957230b2c61c408861b888abdf6a900baf0db4125405505"}, + "vega_lite": {:hex, :vega_lite, "0.1.9", "d7a288665f916181b68d0a3617f1b3611d16a4dcd5fafb51b847b71db1159d4c", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "c6a056e763162198e73ae6dfb46c09753bb0298474410fd085074e1cdcee7418"}, + "xla": {:hex, :xla, "0.8.0", "fef314d085dd3ee16a0816c095239938f80769150e15db16dfaa435553d7cb16", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "739c61c8d93b97e12ba0369d10e76130224c208f1a76ad293e3581f056833e57"}, } diff --git a/test/axon/integration_test.exs b/test/axon/integration_test.exs index 841899ae..fd4cda58 100644 --- a/test/axon/integration_test.exs +++ b/test/axon/integration_test.exs @@ -241,55 +241,6 @@ defmodule Axon.IntegrationTest do end) end - test "gradient accumulation test" do - {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) - - train = - train - |> Stream.map(fn {xs, ys} -> - {xs, one_hot(ys, num_classes: 2)} - end) - |> Enum.to_list() - - [{x_test, _}] = Enum.take(train, 1) - - model = - Axon.input("input") - |> Axon.dense(16) - |> Axon.dropout(rate: 0.1) - |> Axon.dense(2, activation: :softmax) - - ExUnit.CaptureIO.capture_io(fn -> - results = - model - |> Axon.Loop.trainer( - :categorical_cross_entropy, - Polaris.Optimizers.adam(learning_rate: 5.0e-3), - gradient_accumulation_steps: 3 - ) - # TODO: Fix default output transform - |> Map.update(:output_transform, nil, fn _ -> & &1 end) - |> Axon.Loop.metric(:accuracy) - |> Axon.Loop.validate(model, train) - |> Axon.Loop.run(train, Axon.ModelState.empty(), epochs: 10) - - assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = - results - - eval_results = - model - |> Axon.Loop.evaluator() - |> Axon.Loop.metric(:accuracy) - |> Axon.Loop.run(train, model_state) - - assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results - - assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7) - assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) - assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} - end) - end - test "deterministic training test" do {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) @@ -525,8 +476,8 @@ defmodule Axon.IntegrationTest do assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} - assert Nx.type(model_state.data["dense_0"]["kernel"]) == - unquote(Macro.escape(policy)).params + params_policy = unquote(Macro.escape(policy)).params || {:f, 32} + assert Nx.type(model_state.data["dense_0"]["kernel"]) == params_policy end) end @@ -578,8 +529,8 @@ defmodule Axon.IntegrationTest do assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} - assert Nx.type(model_state.data["dense_0"]["kernel"]) == - unquote(Macro.escape(policy)).params + params_policy = unquote(Macro.escape(policy)).params || {:f, 32} + assert Nx.type(model_state.data["dense_0"]["kernel"]) == params_policy end) end end diff --git a/test/axon/loss_scale_test.exs b/test/axon/loss_scale_test.exs index 93fb70c0..29e01637 100644 --- a/test/axon/loss_scale_test.exs +++ b/test/axon/loss_scale_test.exs @@ -244,15 +244,26 @@ defmodule Axon.LossScaleTest do non_finite = Nx.tensor([:infinity, :infinity, :infinity]) - # TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26 - # is fixed - for i <- 0..62, reduce: state do + for i <- 0..99, reduce: state do new_state -> {_, %{loss_scale: loss_scale, counter: counter} = new_state} = adjust_fn.(non_finite, new_state) - expected_new_scale = Nx.max(1, Nx.divide(init_scale, Nx.pow(factor, i + 1))) + # We want to check if init_scale / factor ** (i + 1) is greater than 1. + # If we rely on `i` directly, we run into integer overflow issues. + # Instead, we accumulate the divisor on the reduce. + + scale_divisor = 2 ** (i + 1) + + expected_new_scale = + if scale_divisor >= 2 ** 32 do + Nx.tensor(1) + else + Nx.max(1, Nx.divide(init_scale, scale_divisor)) + end + assert_equal(counter, Nx.tensor(0)) + assert_all_close(loss_scale, expected_new_scale) new_state @@ -277,15 +288,19 @@ defmodule Axon.LossScaleTest do non_finite = Nx.tensor([:infinity, :infinity, :infinity]) - # TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26 - # is fixed - for i <- 0..62, reduce: state do + for i <- 0..99, reduce: state do new_state -> {_, %{loss_scale: loss_scale, counter: counter} = new_state} = adjust_fn.(non_finite, new_state) + scale_divisor = 2 ** (i + 1) + expected_new_scale = - Nx.max(min_loss_scale, Nx.divide(init_scale, Nx.pow(factor, i + 1))) + if scale_divisor >= 2 ** 32 do + Nx.tensor(min_loss_scale) + else + Nx.max(min_loss_scale, Nx.divide(init_scale, scale_divisor)) + end assert_equal(counter, Nx.tensor(0)) assert_all_close(loss_scale, expected_new_scale)