From ab8261180cd54ca95c0c34035a5380ade2805afb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Mon, 2 Sep 2024 12:38:14 +0200 Subject: [PATCH] Allow jitted functions to work across nodes --- exla/lib/exla/executable.ex | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index c3efaba2ae..e3494f60bf 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -11,8 +11,19 @@ defmodule EXLA.Executable do @doc """ Runs the given executable with a list of lists as inputs and the given options. + + Works across nodes. """ - def run(%Executable{} = executable, [subinputs | _] = inputs, options \\ []) + def run(executable, inputs, options \\ []) + + def run(%Executable{ref: ref, client: client} = executable, inputs, options) + when node(ref) != node() do + client + |> load(dump(executable)) + |> run(inputs, options) + end + + def run(%Executable{} = executable, [subinputs | _] = inputs, options) when is_list(subinputs) do %{client: client, device_id: device_id, output_typespecs: output_typespecs, ref: ref} = executable @@ -25,17 +36,20 @@ defmodule EXLA.Executable do @doc """ Dumps the executable to a data structure that can be serialized with `term_to_binary`. + + Works across nodes. """ # If you change this function, you must bump the version in EXLA.Defn.Disk. def dump(%Executable{ - ref: executable, + ref: ref, output_typespecs: output_typespecs, num_replicas: num_replicas, num_partitions: num_partitions, device_id: device_id - }) do + }) + when node(ref) == node() do serialized_exec = - executable + ref |> EXLA.NIF.serialize_executable() |> unwrap!() |> IO.iodata_to_binary() @@ -49,6 +63,10 @@ defmodule EXLA.Executable do } end + def dump(%Executable{ref: ref} = executable) do + :erpc.call(node(ref), __MODULE__, :dump, [executable]) + end + @doc """ Loads a previously dumped executable. """