Add weight option to Nx.Serving for static load balancing (#1348)
Co-authored-by: Niklas Kunz <[email protected]>
Co-authored-by: José Valim <[email protected]>
3 people authored Oct 24, 2023
1 parent a2f58b6 commit 94037ab
Showing 2 changed files with 55 additions and 6 deletions.
36 changes: 30 additions & 6 deletions nx/lib/nx/serving.ex
@@ -190,11 +190,16 @@ defmodule Nx.Serving do
   the same code and applications. It is only required that they run the
   same `Nx` version.

-  The load balancing between servings is done randomly, however, the number
-  of partitions are considered if the `partitions: true` option is also given.
+  The load balancing between servings is done randomly by default, however,
+  the number of partitions are considered if the `partitions: true` option is also given.
   For example, if you have a node with 2 GPUs and another with 4, the latter
   will receive the double of requests compared to the former.

+  Furthermore, the load balancing allows for assigning weights to servings.
+  Similarly to the number of partitions, when running a serving with `distribution_weight: 1`
+  and another one with `distribution_weight: 2`, the latter will receive double the requests
+  compared to the former.
+
   `batched_run/3` receives an optional `distributed_preprocessing` callback as
   third argument for preprocessing the input for distributed requests. When
   using libraries like EXLA or Torchx, the tensor is often allocated in memory
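The weighting described in the docstring above can be illustrated without Nx. The following is a minimal sketch, assuming the dispatcher picks uniformly at random from the list of group members (the docstring only says the balancing is "done randomly"). `WeightSketch`, `members/1`, and `shares/1` are hypothetical names for illustration and are not part of Nx.

```elixir
defmodule WeightSketch do
  # Hypothetical helper, not an Nx function: each serving appears in the
  # member list once per unit of weight.
  def members(weights) do
    Enum.flat_map(weights, fn {name, weight} -> List.duplicate(name, weight) end)
  end

  # Expected share of requests per serving, assuming a uniform random pick
  # over the member list built above.
  def shares(weights) do
    members = members(weights)
    total = length(members)

    members
    |> Enum.frequencies()
    |> Map.new(fn {name, count} -> {name, count / total} end)
  end
end

# A serving with weight 2 occupies two slots, so it is picked twice as often.
IO.inspect(WeightSketch.members(a: 1, b: 2))
# prints [:a, :b, :b]
IO.inspect(WeightSketch.shares(a: 1, b: 2))
```

Under that assumption, serving `b` is expected to receive two thirds of the requests and serving `a` one third, which matches the "double the requests" wording in the docstring.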
Expand Down Expand Up @@ -388,7 +393,8 @@ defmodule Nx.Serving do
:partitions,
:shutdown,
:hibernate_after,
:spawn_opt
:spawn_opt,
:distribution_weight
]

@doc """
Expand Down Expand Up @@ -873,6 +879,13 @@ defmodule Nx.Serving do
The number of partitions will be determined according to your compiler
and for which host it is compiling. See the module docs for more information
* `:distribution_weight` - weight used for load balancing when running
a distributed serving. Defaults to `1`.
If it is set to a higher number `w`, the serving process will receive,
on average, `w` times the number of requests compared to the
default. Note that the weight is multiplied with the number of
partitions, if partitioning is enabled.
* `:shutdown` - the maximum time for the serving to shutdown. This will
block until the existing computation finishes (defaults to `30_000`ms)
Expand Down Expand Up @@ -901,11 +914,18 @@ defmodule Nx.Serving do
partitions = Keyword.get(opts, :partitions, false)
batch_keys = Keyword.get(opts, :batch_keys, [:default])
batch_timeout = Keyword.get(opts, :batch_timeout, 100)
weight = Keyword.get(opts, :distribution_weight, 1)
process_options = Keyword.take(opts, [:name, :hibernate_after, :spawn_opt])

unless is_integer(weight) do
raise ArgumentError, ":distribution_weight must be an integer"
end

supervisor = Module.concat(name, "Supervisor")
task_supervisor = Module.concat(name, "TaskSupervisor")
arg = {name, serving, partitions, batch_keys, batch_size, batch_timeout, task_supervisor}

arg =
{name, serving, partitions, batch_keys, batch_size, batch_timeout, task_supervisor, weight}

children = [
{Task.Supervisor, name: task_supervisor},
Expand Down Expand Up @@ -1280,7 +1300,10 @@ defmodule Nx.Serving do
@timeout_message {__MODULE__, :timeout}

@impl true
def init({name, serving, partitions?, batch_keys, batch_size, batch_timeout, task_supervisor}) do
def init(
{name, serving, partitions?, batch_keys, batch_size, batch_timeout, task_supervisor,
weight}
) do
Process.flag(:trap_exit, true)
partitions_opts = serving_partitions(serving, partitions?)
partitions_count = length(partitions_opts)
@@ -1300,7 +1323,8 @@ defmodule Nx.Serving do
       }
     )

-    :pg.join(Nx.Serving.PG, __MODULE__, List.duplicate(self(), partitions_count))
+    serving_weight = max(1, weight * partitions_count)
+    :pg.join(Nx.Serving.PG, __MODULE__, List.duplicate(self(), serving_weight))

     for batch_key <- batch_keys do
       stack_init(batch_key)
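The number of times a serving joins the `:pg` group follows from the `max(1, weight * partitions_count)` expression in the hunk above. A minimal sketch of that arithmetic, using a hypothetical `EffectiveWeightSketch.effective_weight/2` helper that is not part of Nx:

```elixir
defmodule EffectiveWeightSketch do
  # Hypothetical helper, not an Nx function: mirrors the expression
  # `max(1, weight * partitions_count)` from the init/1 change.
  def effective_weight(weight, partitions_count)
      when is_integer(weight) and is_integer(partitions_count) do
    max(1, weight * partitions_count)
  end
end

# Default weight, single partition: one group entry.
IO.inspect(EffectiveWeightSketch.effective_weight(1, 1))
# prints 1

# Weight 2 on a serving with 4 partitions: eight group entries.
IO.inspect(EffectiveWeightSketch.effective_weight(2, 4))
# prints 8

# Zero or negative integers pass the is_integer/1 check in start_link
# and are clamped to a single entry here.
IO.inspect(EffectiveWeightSketch.effective_weight(0, 4))
# prints 1
IO.inspect(EffectiveWeightSketch.effective_weight(-3, 2))
# prints 1
```

One consequence of the clamp is that a weight of `0` does not remove a serving from rotation; it still joins the group once.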
25 changes: 25 additions & 0 deletions nx/test/nx/serving_test.exs
@@ -1293,6 +1293,31 @@ defmodule Nx.ServingTest do
       assert tensor == Nx.tensor([6])
     end

+    @tag :distributed
+    @tag :capture_log
+    test "spawns distributed tasks over the network with different weights", config do
+      parent = self()
+
+      opts = [
+        name: config.test,
+        batch_size: 2,
+        shutdown: 1000,
+        distribution_weight: 1
+      ]
+      opts2 = Keyword.put(opts, :distribution_weight, 4)
+
+      Node.spawn_link(:"[email protected]", DistributedServings, :multiply, [parent, opts])
+      assert_receive {_, :join, Nx.Serving, pids}
+      assert length(pids) == 1
+
+      Node.spawn_link(:"[email protected]", DistributedServings, :multiply, [parent, opts2])
+      assert_receive {_, :join, Nx.Serving, pids}
+      assert length(pids) == 4
+
+      members = :pg.get_members(Nx.Serving.PG, Nx.Serving)
+      assert length(members) == 5
+    end
+
     @tag :distributed
     @tag :capture_log
     test "spawns distributed tasks over the network with streaming", config do
