From e19edf1c74a7d55cd8c6924d4c70298b1e7d87f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Mon, 2 Sep 2024 20:22:53 +0200 Subject: [PATCH] Discard cache on init_params computation --- lib/axon.ex | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/lib/axon.ex b/lib/axon.ex index fa73645a..52f4ba6d 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -587,7 +587,7 @@ defmodule Axon do end @doc """ - Implements an or else (e.g. an Elixir ||) + Implements an or else (e.g. an Elixir ||) """ @doc type: :special def or_else(%Axon{} = a, %Axon{} = b, opts \\ []) do @@ -3771,7 +3771,7 @@ defmodule Axon do as input and returns a function that replaces or rewrites the given node. For example, you can define a simple rewriter which replaces the `:relu` layers with `:tanh` layers: - + tanh_rewriter = fn [%Axon{} = x], _output -> Axon.relu(x) end @@ -3926,13 +3926,16 @@ defmodule Axon do end @doc """ - Compiles the given model to `{init_fn, predict_fn}`. + Compiles the given model to `{init_params, predict_fn}`. This function will compile a model specialized to the given input shapes and types. This is useful for avoiding the overhead of long compilations at program runtime. You must provide template inputs which match the expected shapes and types of inputs at - execution time. + execution time. Depending on the Nx compiler, such as EXLA v0.9.1+, + both `init_params` the `predict_fn` can be sent across nodes, as + long the node that owns them keeps a reference to the underlying + resources. This function makes use of the built-in `Nx.Defn.compile/3`. Note that passing inputs which differ in shape or type from the templates @@ -3946,7 +3949,12 @@ defmodule Axon do def compile(model, template, init_params \\ Axon.ModelState.empty(), opts \\ []) when is_list(opts) do {init_fn, predict_fn} = build(model, opts) - init_params = Nx.Defn.jit_apply(init_fn, [template, Axon.ModelState.new(init_params)], opts) + model_state = Axon.ModelState.new(init_params) + + # If there is a disk cache, we only want it to apply to the predict function + init_opts = if is_binary(opts[:cache]), do: Keyword.delete(opts, :cache), else: opts + init_params = Nx.Defn.jit_apply(init_fn, [template, model_state], init_opts) + predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template], opts) {init_params, predict_compiled_fn} end