From 7a148902cfb2c8ac979635d3e4b6b95d9e795b26 Mon Sep 17 00:00:00 2001 From: Steffen Deusch Date: Thu, 12 Dec 2024 18:29:09 +0100 Subject: [PATCH] use serving name as pg group name The Nx documentation states that it is important to use the same Nx version when using distributed servings. This leads to a problem for people using blue green deployments and trying to upgrade Nx. Without some special handling, you'll often run into situations where both the old and new nodes run in the cluster at the same time. Also, if you previously tried to run different servings on different machines, this wasn't really possible, as Nx would choose a random machine from pg, so each machine in the cluster running Nx would always have to run all servings. By separating the pg groups by serving name, we can run different servings on different nodes while keeping the cluster connected. This also allows us to upgrade Nx in a blue green deployment by encoding the Nx version into the serving name. --- nx/lib/nx/serving.ex | 4 ++-- nx/test/nx/serving_test.exs | 17 +++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/nx/lib/nx/serving.ex b/nx/lib/nx/serving.ex index acb0523448..ab0a524bd7 100644 --- a/nx/lib/nx/serving.ex +++ b/nx/lib/nx/serving.ex @@ -1117,7 +1117,7 @@ defmodule Nx.Serving do end defp distributed_batched_run_with_retries!(name, input, retries) do - case :pg.get_members(Nx.Serving.PG, __MODULE__) do + case :pg.get_members(Nx.Serving.PG, name) do [] -> exit({:noproc, {__MODULE__, :distributed_batched_run, [name, input, [retries: retries]]}}) @@ -1332,7 +1332,7 @@ defmodule Nx.Serving do ) serving_weight = max(1, weight * partitions_count) - :pg.join(Nx.Serving.PG, __MODULE__, List.duplicate(self(), serving_weight)) + :pg.join(Nx.Serving.PG, name, List.duplicate(self(), serving_weight)) for batch_key <- batch_keys do stack_init(batch_key) diff --git a/nx/test/nx/serving_test.exs b/nx/test/nx/serving_test.exs index f6aa54368d..49d72573f0 100644 --- a/nx/test/nx/serving_test.exs +++ b/nx/test/nx/serving_test.exs @@ -1288,7 +1288,8 @@ defmodule Nx.ServingTest do ] Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :multiply, [parent, opts]) - assert_receive {_, :join, Nx.Serving, _} + assert_receive {_, :join, name, _} + assert name == config.test batch = Nx.Batch.concatenate([Nx.tensor([1, 2])]) @@ -1327,14 +1328,16 @@ defmodule Nx.ServingTest do opts2 = Keyword.put(opts, :distribution_weight, 4) Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :multiply, [parent, opts]) - assert_receive {_, :join, Nx.Serving, pids} + assert_receive {_, :join, name, pids} assert length(pids) == 1 + assert name == config.test Node.spawn_link(:"tertiary@127.0.0.1", DistributedServings, :multiply, [parent, opts2]) - assert_receive {_, :join, Nx.Serving, pids} + assert_receive {_, :join, name, pids} assert length(pids) == 4 + assert name == config.test - members = :pg.get_members(Nx.Serving.PG, Nx.Serving) + members = :pg.get_members(Nx.Serving.PG, config.test) assert length(members) == 5 end @@ -1356,7 +1359,8 @@ defmodule Nx.ServingTest do args = [parent, opts] Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :add_five_round_about, args) - assert_receive {_, :join, Nx.Serving, _} + assert_receive {_, :join, name, _} + assert name == config.test batch = Nx.Batch.concatenate([Nx.tensor([1, 2])]) @@ -1412,7 +1416,8 @@ defmodule Nx.ServingTest do ] Node.spawn_link(:"tertiary@127.0.0.1", DistributedServings, :multiply, [parent, opts]) - assert_receive {_, :join, Nx.Serving, _} + assert_receive {_, :join, name, _} + assert name == config.test batch = Nx.Batch.concatenate([Nx.tensor([1, 2])])