Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove dead code (previous non new ffi lowering) #21

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
remove dead code
  • Loading branch information
rdyro committed Nov 29, 2024
commit 6e0ab8e942b1b172d83260a6e32f087cb761ceb9
108 changes: 3 additions & 105 deletions torch2jax/api.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
from __future__ import annotations

from typing import Callable, Any
from functools import partial
from inspect import signature

import torch
from torch import Tensor, Size
from torch import Tensor

import jax
from jax import numpy as jnp
from jax import ShapeDtypeStruct
from jax.interpreters import mlir, xla, batching
from jax import core
from jax.extend import ffi

# from jax.abstract_arrays import ShapedArray
from jax.tree_util import PyTreeDef

from .compile import compile_and_import_module
from .lowering_rule import _torch_call_lowering
from .utils import find_unique_id, dtype_t2j, normalize_shapes, warn_once
from .utils import find_unique_id, dtype_t2j, normalize_shapes


def torch2jax_flat(
@@ -45,7 +40,7 @@ def torch2jax_flat(
"""
# assert example_args is not None or output_shapes is not None or output_shapes_fn is not None
assert example_args is not None or output_shapes is not None
cpp_module = compile_and_import_module()
_ = compile_and_import_module()
id = find_unique_id()

def torch_call_fn_(args: list[torch.Tensor]):
@@ -72,103 +67,6 @@ def wrapped_flat_fn(*args_flat):

return wrapped_flat_fn

torch_prim = core.Primitive(f"torch_call_{id}")
torch_prim.multiple_results = True
torch_prim.def_impl(partial(xla.apply_primitive, torch_prim))

# inferring shapes #############################################################################
if output_shapes is not None:
# call the pytorch function to infer shapes
assert isinstance(output_shapes, (tuple, list)) and all(
isinstance(shape, (tuple, list, ShapeDtypeStruct, Size)) or shape is None or hasattr(shape, "shape")
for shape in output_shapes
)

def _torch_call_abstract(*args):
output_shapes_ = [Size(shape) if isinstance(shape, (list, tuple)) else shape for shape in output_shapes]
return normalize_shapes(output_shapes_, args)

else:
with torch.no_grad():
out = fn(*example_args)
assert isinstance(out, (tuple, list, Tensor))
out = (out,) if isinstance(out, Tensor) else tuple(out)

def _torch_call_abstract(*args):
return jax.tree.map(lambda x: ShapeDtypeStruct(x.shape, dtype_t2j(x.dtype)), out)

torch_prim.def_abstract_eval(_torch_call_abstract)
# inferring shapes #############################################################################

# lowering ####################################################################################
for platform in ["cpu", "gpu"]:
mlir.register_lowering(
torch_prim,
partial(_torch_call_lowering, cpp_module=cpp_module, platform=platform, id=id),
platform=platform,
)
# lowering ####################################################################################

def torch_call_fn_():
args = getattr(torch, f"_torch2jax_args_{id:d}")
out = fn(*args)
return (out,) if isinstance(out, Tensor) else tuple(out)

setattr(torch, f"_torch2jax_fn_{id:d}", torch_call_fn_)

def wrapped_fn(*args):
return torch_prim.bind(*args)

def torch_call_batching(args, axes):
if use_torch_vmap:

def torch_fn_vmap(*args):
return torch.vmap(fn, in_dims=axes)(*args)

assert any(ax is not None for ax in axes)
batch_size = [arg.shape[ax] for arg, ax in zip(args, axes) if ax is not None][0]
assert output_shapes is not None
output_shapes_ = _torch_call_abstract(*args)
output_shapes_vmap = [
ShapeDtypeStruct((batch_size,) + tuple(shape.shape), shape.dtype) for shape in output_shapes_
]
outaxes = (0 for _ in output_shapes_vmap)
return (
torch2jax_flat(torch_fn_vmap, args, output_shapes=output_shapes_vmap)(*args),
outaxes,
)
else:
warn_once(
"You are NOT using PyTorch's functional vmap. " + "This is highly experimental and may be slower."
)
assert all(axis is None or axis == 0 for axis in axes)
if all(axis is None for axis in axes):
return wrapped_fn(*args)
n = 0
for i, axis in enumerate(axes):
if axis is not None:
n = args[i].shape[axis]
break
output_lists, output_struct = None, None
for i in range(n):
args_ = [arg if axis is None else arg[i] for arg, axis in zip(args, axes)]
outputs = wrapped_fn(*args_)
output_flat, output_struct = jax.tree.flatten(outputs)
if output_lists is None:
output_lists = [[] for _ in output_flat]
for output_list, output in zip(output_lists, output_flat):
output_list.append(output)
outputs = tuple([jnp.stack(output_list, 0) for output_list in output_lists])
outputs = jax.tree.unflatten(output_struct, outputs)
return outputs, jax.tree.unflatten(output_struct, (0 for _ in outputs))

batching.primitive_batchers[torch_prim] = torch_call_batching

return wrapped_fn


####################################################################################################


def torch2jax(
fn: Callable,
73 changes: 0 additions & 73 deletions torch2jax/lowering_rule.py

This file was deleted.