Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move op dispatching logic into an Environment class; and use Mode to capture dispatcher instead of tensor. #7009

Merged
merged 3 commits into from
May 8, 2024

Conversation

qihqi
Copy link
Collaborator

@qihqi qihqi commented Apr 30, 2024

This work is motivated by supporting random numbers.

In Jax, we need to explicitly keep track of the PRNG state because Jax is functional. Therefore, we need a global to keep track of this state, and ops that needs it to have access to it. With this there are few global states that we are managing:

  • PRNG
  • Ops registry
  • (potential) configs (currently some DEBUG boolean globals that are sprinkled around that controls debug logging), for example: default device, default dtype etc.

So the thought is why not put them into an environment object and avoid globals.

So the current design is:
op dispatching logic are concentrated in Environement.dispatch method. The 2 modes are only responsible in capturing the op and forward to env.

Most of changes are moves:

  • _ops.py -> jaten.py
  • functions.py -> jtorch.py
  • environments.py -> tensor.py (Tensor and Environment depends on each other).
  • extra.py -> interop.py

@qihqi qihqi force-pushed the hanq_torchxla2 branch 3 times, most recently from 625e106 to f5ffd71 Compare May 1, 2024 17:02
@qihqi qihqi changed the title ops refactor Move op dispatching logic into an Environment class; and use Mode to capture dispatcher instead of tensor. May 1, 2024
@qihqi qihqi force-pushed the hanq_torchxla2 branch from f5ffd71 to 07198d6 Compare May 1, 2024 21:15
@qihqi qihqi requested review from will-cromar and lsy323 May 1, 2024 21:18
@qihqi qihqi marked this pull request as ready for review May 1, 2024 21:19
@qihqi qihqi force-pushed the hanq_torchxla2 branch from 07198d6 to 39d5bd3 Compare May 1, 2024 22:03
@qihqi qihqi requested review from alanwaketan and JackCaoG May 1, 2024 22:04
@qihqi qihqi force-pushed the hanq_torchxla2 branch from 39d5bd3 to 0a50922 Compare May 1, 2024 22:39
TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any]
TorchCallable: TypeAlias = Callable[P, TorchValue]

JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'TorchCallable', Any]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shoud it be 'JaxCallable' here?


P = ParamSpec('P')

TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like these type aliases are repeated from interop.py. Do you think it would be useful to have a shared types module to hold these in one place?

return func.__name__


class Environment:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the environment inherit TorchDispatchMode and TorchFunctionMode and implement __torch_function__ and __torch_dispatch__ directly? That would change the semantic to with torch_xla2.default_env():

Or do we ever expect to use XLAFunctionMode and XLADispatchMode independently of Environment?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making Environment a context manager itself makes sense. However, multiple inheritance won't work as only one mode will activate. on calling of __enter__ with with, Python's method resolution will find the first parent with that method and skip the second parent. I added __enter__ and __exit__ to Environment class instead.

Also helper functions to manipulate those.
"""

_prng_key: jax.random.PRNGKey
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a general design question about how to handle randomness. I had a half-finished addition for a few random ops. In there, I ignored the Environment's PRNG key entirely.

The most important thing for me personally is to preserve as much of torch's semantics as possible. So for a function like randn, we should support the generator argument as the way for a user to pass in a key. We can call Generator.seed and then convert that to a JAX key.

Since we would have to support generator anyway in that case, it occurred to me that we don't actually need an extra way to store the global/default PRNG state -- torch already gives us one through torch.random, and it's already in use in much existing torch code (e.g. everyone is already calling torch.manual_seed). Do we also want to add an additional random state on top of that?

@qihqi qihqi merged commit 825ba0d into master May 8, 2024
2 of 3 checks passed
@qihqi qihqi deleted the hanq_torchxla2 branch May 8, 2024 00:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants