jit(pmap(f)) causes inefficient behavior #2926
Comments
Can the performance issues here be resolved? There are many use cases where a library calls `pmap` internally. In this case, one would want to `jit` the surrounding computation as well. Additionally, when one package depends on another package (e.g. how numpyro depends on jaxns), the upper-level package is restricted in the way it is allowed to use the lower-level one.
I'm not sure we can avoid the overhead in this situation, because the outer `jit` fixes the device placement of the whole computation. One thing to note about this pattern in general: all of the primitives inside a `jit` run on the devices specified (implicitly or explicitly) for that `jit`, so a subset of them can't be restricted to a subset of those devices. Do you have a concrete example where this is causing you trouble?
I didn't know about that last fact you mention. So when you jit a function, you implicitly or explicitly specify devices for all of its primitives to run on, and thus no subset of primitives can be restricted to run on a subset of those devices. If I understand that correctly, it makes sense. I would be very curious to know what the fundamental reasons are for disallowing primitives running on different devices.

Re your question: I have an example where running pmap inside a jit is failing altogether, but I wanted to understand pmap inside a jit better before posting about it. Perhaps this error is related to the 'inefficient behaviour' mentioned in this post, since sharding is mentioned in the traceback. The traceback I'm getting is:

```
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/api.py", line 398, in f_jitted
return cpp_jitted_f(context, *args, **kwargs)
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/api.py", line 289, in cache_miss
out_flat = xla.xla_call(
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/core.py", line 1275, in bind
return call_bind(self, fun, *args, **params)
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/core.py", line 1266, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/core.py", line 1278, in process
return trace.process_call(self, fun, tracers, params)
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/core.py", line 631, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 580, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/linear_util.py", line 260, in memoized_fun
ans = call(fun, *args)
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 708, in _xla_callable
out_nodes = jaxpr_subcomp(
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 452, in jaxpr_subcomp
ans = rule(c, axis_env, extend_name_stack(name_stack, eqn.primitive.name),
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 331, in _while_loop_translation_rule
new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 452, in jaxpr_subcomp
ans = rule(c, axis_env, extend_name_stack(name_stack, eqn.primitive.name),
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 746, in _cond_translation_rule
branch_computations = [
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 747, in <listcomp>
make_computation(f'branch_{i}', jaxpr, op_shape)
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 739, in make_computation
outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 460, in jaxpr_subcomp
ans = rule(c, axis_env, in_nodes,
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1242, in _pmap_translation_rule
outs = [_xla_unshard(c, aval, new_env, out_axis, shard, backend=backend)
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1242, in <listcomp>
outs = [_xla_unshard(c, aval, new_env, out_axis, shard, backend=backend)
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1288, in _xla_unshard
xla.axis_groups(axis_env, axis_env.names[-1]))
File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 512, in axis_groups
assert not ragged
AssertionError
Process finished with exit code 1
```
Since the original issue has been addressed by #3426 and there is a warning for jit of pmap now, I'll close this issue. Please feel free to reopen or file a new issue!
Why do I get this warning when I'm not using `jit` explicitly? Is it because I'm using `fori_loop`?
@tavin Yes, that's why! The control flow combinators like `fori_loop` apply `jit` to their function arguments automatically.
@mattjj OK, good to know! The pmap man page is quite vocal about automatically jitting, but the fori_loop man page doesn't mention it. Thanks for the confirmation.
Good idea - I updated some of the documentation in #10757.
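(Illustrative snippet, not from the thread: what "automatically jitted" means in practice is that `fori_loop` traces its body with abstract values instead of running it eagerly on every iteration.)

```python
import jax.numpy as jnp
from jax import lax

def body(i, x):
    # This print runs at trace time (typically once) and shows a tracer
    # rather than concrete values: the body is compiled, not executed
    # eagerly once per loop step.
    print("tracing body with x =", x)
    return x + 1.0

print(lax.fori_loop(0, 5, body, jnp.zeros(3)))
```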
@Joshuaalbert I'm having the same issue, where I can't avoid pmapping inside a jit (specifically, inside a scan loop). Did you ever find a solution, by any chance?
@carlosgmartin Yes, it's possible to write JAX code that is efficient when it comes to composing `jit` and `pmap`. A pattern that might be useful is this (note this is not runnable code and should be read as pseudocode):

```python
# The work that will be done on each device
def inner_loop(state):
    ...

# An iterative algorithm where each iteration has two steps:
# 1) distribute work; 2) collect results onto each device locally to do
# some work, e.g. determine the stopping condition.
def single_algorithm_thread(state):
    done = False
    while not done:
        local_product = inner_loop(state)
        aggregated_product = all_gather(local_product, 'i')  # collect along the broadcasted axis
        done = is_done(aggregated_product)
        state = make_next_state(state, local_product)
    return state, local_product

# Map the algorithm over devices with pmap
def step_of_larger_algorithm(state):
    parallel_algorithm = pmap(single_algorithm_thread, 'i')
    chunked_state = add_leading_dim(state)
    chunked_output, chunked_product = parallel_algorithm(chunked_state)
    output = remove_leading_dim(chunked_output)
    product = remove_leading_dim(chunked_product)
    return output, product  # you may only need output, not the product of the intermediate steps

# JIT-compile a sequence of pmap'ed steps of the larger algorithm
@jit
def big_algorithm(state):
    state = step_of_larger_algorithm(state)
    # do something with state
    ...
    # run more steps using pmap
    state = another_step_of_larger_algorithm(state)
    return state
```

What is going on is that you're composing your big algorithm, which you'd like to jit-compile, into a sequence of steps in which you use pmap to distribute work. Each step can collect data from all the other devices locally so that it can do something, e.g. determine a stopping condition requiring knowledge of the products on all devices. This sequence can be efficiently jit-compiled. There is one important thing to keep in mind: make the pmap'ed components stateless. Make sure that all inputs to pmap'd functions are passed in as arguments and not caught by closure. Also, try to reduce the size of the objects being collected with `all_gather`.

In summary, try to break up your algorithm into a sequence of pmap'able steps, don't let arrays be caught from external scope, and focus on making inter-device communication as light-weight as possible.
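(For illustration, here is a minimal runnable sketch of that shape, not from the original comment: one stateless pmap'd step whose only cross-device communication is a single-scalar `psum`, composed under an outer `jit`. The function names are made up for the example.)

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()

# One pmap'able step: every input is an explicit argument (nothing captured
# by closure), and the only collective is a single-scalar psum.
@functools.partial(jax.pmap, axis_name='i')
def step(state):
    local = jnp.sum(state)        # per-device work
    total = lax.psum(local, 'i')  # light-weight cross-device communication
    return state + total

@jax.jit
def big_algorithm(state):
    state = step(state)  # pmap'd step inside jit
    state = step(state)  # another pmap'd step
    return state

print(big_algorithm(jnp.ones((n, 4))))
```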
We actually have a new and somewhat experimental solution to composing `jit` with per-device parallel code: `shmap` (`shard_map`). We are tentatively thinking we may be able to replace `pmap` with it eventually.
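(A small sketch of what that composition can look like, not from the comment above; `shard_map` lives under `jax.experimental` in JAX versions from around the time of this thread, and the import path may differ in later releases.)

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('i',))

def per_device(x):
    local = jnp.sum(x)                # per-device work on this shard
    total = jax.lax.psum(local, 'i')  # collective over the mesh axis
    return x + total

# shard_map composes directly with jit: the function runs per shard.
f = jax.jit(shard_map(per_device, mesh=mesh, in_specs=P('i'), out_specs=P('i')))
print(f(jnp.ones((jax.device_count(), 4))))
```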
@Joshuaalbert @hawkinsp Thank you for your comments. The structure of the program I'm dealing with is described here: #15693. It performs OpenAI ES, which requires only that each device send a single scalar to all other devices on each step. If you have any specific advice for that pattern/situation, feel free to comment there. I'd really appreciate it!

As an aside, I dream of a compiler that is powerful enough to let users focus solely on the semantics of a program (what is to be computed), while the compiler figures out how to distribute the computation efficiently over a set of available resources (how it is to be computed), with no more manual parallelization plumbing.

Edit: Came across the following: Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning (repo).
|
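(As an illustration of how light that communication can be, my example rather than the thread's: each device contributes one scalar and receives all of them via `all_gather` inside `pmap`.)

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()

# Each device contributes a single scalar (e.g. a fitness value) and
# receives the scalars from all devices.
@functools.partial(jax.pmap, axis_name='i')
def exchange(score):
    return lax.all_gather(score, 'i')  # shape (n,) on every device

print(exchange(jnp.arange(n, dtype=jnp.float32)))  # result shape (n, n)
```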
@hawkinsp and @carlosgmartin One thing to keep in mind is that shmap doesn't seem to have good support for non-static loops. So it's fine with loops whose trip count is fixed at trace time, but not with loops whose length depends on runtime values.
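(To illustrate the distinction, my example rather than part of the comment above: `lax.scan` runs for a trip count known at trace time, while `lax.while_loop` iterates until a runtime-dependent condition fails, which is the non-static case.)

```python
import jax.numpy as jnp
from jax import lax

# Static loop: the number of iterations (length=10) is known when tracing.
def scan_body(carry, _):
    return carry + 1.0, None

total, _ = lax.scan(scan_body, jnp.float32(0.0), xs=None, length=10)

# Non-static loop: the number of iterations depends on runtime values.
total2 = lax.while_loop(lambda x: x < 10.0, lambda x: x + 1.0, jnp.float32(0.0))
```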
Combining `jit` with `pmap` produces some undesirable and surprising behaviors. As one example, any lazy intermediate constants used by the function get instantiated and copied to every device. For instance:
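(The code snippet from the original report is not preserved in this copy; the following is a hypothetical sketch of the kind of function being described, where a large constant used by the function ends up materialized on every device. The names and the ~2 GB size are illustrative only.)

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
big_const = jnp.ones((2**29,), dtype=jnp.float32)  # ~2 GB, size chosen for illustration

@jax.jit
@jax.pmap
def f(x):
    # big_const is pulled in as a constant of the computation; under
    # jit-of-pmap it gets instantiated and copied to every device.
    return x + jnp.sum(big_const)

f(jnp.zeros((n, 4)))
```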
This causes 2GB of data to be allocated on each device (and, right now, if this is the only computation you run, this can be verified by looking at `list(list(jax.pxla.parallel_callable.__closure__[1].cell_contents.items())[0][1].values())[-1][0].__closure__[0].cell_contents`, but that might break).

Relatedly, the `jit` causes the return value to be copied back to a single host instead of staying as a ShardedDeviceArray.

Ideally, adding `jit` would not make behavior worse. But having a warning when such a situation occurs would also be useful here, since `pmap` on its own does the right thing.