-
Notifications
You must be signed in to change notification settings - Fork 486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move op dispatching logic into an Environment
class; and use Mode to capture dispatcher instead of tensor.
#7009
Conversation
625e106
to
f5ffd71
Compare
Environment
class; and use Mode to capture dispatcher instead of tensor.
TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any] | ||
TorchCallable: TypeAlias = Callable[P, TorchValue] | ||
|
||
JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'TorchCallable', Any] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shoud it be 'JaxCallable'
here?
|
||
P = ParamSpec('P') | ||
|
||
TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like these type aliases are repeated from interop.py
. Do you think it would be useful to have a shared types
module to hold these in one place?
return func.__name__ | ||
|
||
|
||
class Environment: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can the environment inherit TorchDispatchMode
and TorchFunctionMode
and implement __torch_function__
and __torch_dispatch__
directly? That would change the semantic to with torch_xla2.default_env():
Or do we ever expect to use XLAFunctionMode
and XLADispatchMode
independently of Environment
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making Environment a context manager itself makes sense. However, multiple inheritance won't work as only one mode will activate. on calling of __enter__
with with
, Python's method resolution will find the first parent with that method and skip the second parent. I added __enter__
and __exit__
to Environment class instead.
Also helper functions to manipulate those. | ||
""" | ||
|
||
_prng_key: jax.random.PRNGKey |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a general design question about how to handle randomness. I had a half-finished addition for a few random ops. In there, I ignored the Environment's PRNG key entirely.
The most important thing for me personally is to preserve as much of torch
's semantics as possible. So for a function like randn
, we should support the generator
argument as the way for a user to pass in a key. We can call Generator.seed
and then convert that to a JAX key.
Since we would have to support generator
anyway in that case, it occurred to me that we don't actually need an extra way to store the global/default PRNG state -- torch
already gives us one through torch.random
, and it's already in use in much existing torch
code (e.g. everyone is already calling torch.manual_seed
). Do we also want to add an additional random state on top of that?
This work is motivated by supporting random numbers.
In Jax, we need to explicitly keep track of the PRNG state because Jax is functional. Therefore, we need a global to keep track of this state, and ops that needs it to have access to it. With this there are few global states that we are managing:
So the thought is why not put them into an environment object and avoid globals.
So the current design is:
op dispatching logic are concentrated in Environement.dispatch method. The 2 modes are only responsible in capturing the op and forward to env.
Most of changes are moves: