
pmap inside jit #5681

Closed
cgarciae opened this issue Feb 8, 2021 · 3 comments
Labels
question Questions for the JAX team

Comments

@cgarciae
Collaborator
cgarciae commented Feb 8, 2021

Here w is being captured by g, so it works in terms of shapes, but it's unclear if/how w is being distributed to each device.

import jax

@jax.jit
def f(w, x):
    @jax.pmap
    def g(x):
        return w * x

    return g(x)

What is actually happening here?

Edit
Assume:

x.shape == (device, batch, d)
w.shape == (batch, d)
@jakevdp
Collaborator
jakevdp commented Feb 9, 2021

You can see how this is passed to XLA using jax.make_jaxpr:

import jax.numpy as jnp

x = jnp.ones((1, 2, 3))
w = jnp.ones((2, 3))
jax.make_jaxpr(f)(w, x)
{ lambda  ; a b.
  let c = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a b.
                                 let c = xla_pmap[ axis_name=<axis 0x7f4c49bbad08>
                                                   axis_size=1
                                                   backend=None
                                                   call_jaxpr={ lambda  ; a b.
                                                                let c = mul a b
                                                                in (c,) }
                                                   devices=None
                                                   donated_invars=(False, False)
                                                   global_arg_shapes=(None,)
                                                   global_axis_size=None
                                                   in_axes=(None, 0)
                                                   name=g
                                                   out_axes=(0,) ] a b
                                 in (c,) }
                    device=None
                    donated_invars=(False, False)
                    name=f ] a b
  in (c,) }

In particular, you can see that both a (which represents w) and b (which represents x) are passed as arguments to pmap, with in_axes=(None, 0). I believe this means that the values of w are replicated on each device in order to perform the computation, similar to what you would get if you had defined the function like this:

@jax.jit
def f(w, x):
  @jax.partial(jax.pmap, in_axes=(None, 0))
  def g(w, x):
    return w * x
  return g(w, x)

which generates a nearly identical jaxpr.
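As a quick sanity check, the two formulations can be compared directly. This is a sketch assuming any JAX backend (on CPU, `jax.device_count()` is typically 1); the names `f_closure` and `f_explicit` are made up for the comparison:

```python
# Sketch: the closed-over-w version and the explicit in_axes=(None, 0)
# version of f produce the same result.
from functools import partial

import jax
import jax.numpy as jnp

n = jax.device_count()

@jax.jit
def f_closure(w, x):
    @jax.pmap
    def g(x):
        return w * x  # w captured by closure -> in_axes=None for it
    return g(x)

@jax.jit
def f_explicit(w, x):
    @partial(jax.pmap, in_axes=(None, 0))
    def g(w, x):
        return w * x  # w passed explicitly, replicated across devices
    return g(w, x)

x = jnp.ones((n, 2, 3))
w = 2.0 * jnp.ones((2, 3))
out_closure = f_closure(w, x)
out_explicit = f_explicit(w, x)
```

Both calls trace to nearly the same jaxpr, so equal outputs are expected.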

@jakevdp jakevdp added the question Questions for the JAX team label Feb 9, 2021
@skye
Member
skye commented Feb 10, 2021

Also note that calling pmap inside of jit is usually not what you want! pmap already compiles your function the same way jit does, and furthermore, adding the extra jit can often cause performance issues (see #2926 for more info on why).
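Following that advice, the example can be written with pmap alone (a sketch, assuming a setup where `x`'s leading axis equals the device count; no outer jit is needed since pmap compiles the function itself):

```python
# Sketch: apply pmap directly; it already JIT-compiles g.
from functools import partial

import jax
import jax.numpy as jnp

n = jax.device_count()

@partial(jax.pmap, in_axes=(None, 0))
def g(w, x):
    return w * x  # w is replicated to each device; x is split along axis 0

x = jnp.ones((n, 2, 3))
w = jnp.ones((2, 3))
out = g(w, x)
```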

@cgarciae
Collaborator Author
cgarciae commented Feb 10, 2021

Thanks @jakevdp, your answer was really helpful in understanding a bit more about jax tracing in general!

@skye thanks for the tip! Luckily I got a warning message about this when actually running some test code on a TPU that pointed towards this :)
