
Commit

remove dead code
rdyro committed Nov 29, 2024
1 parent 559e482 commit f198118
Showing 2 changed files with 3 additions and 178 deletions.
108 changes: 3 additions & 105 deletions torch2jax/api.py
@@ -1,25 +1,20 @@
from __future__ import annotations

from typing import Callable, Any
from functools import partial
from inspect import signature

import torch
from torch import Tensor, Size
from torch import Tensor

import jax
from jax import numpy as jnp
from jax import ShapeDtypeStruct
from jax.interpreters import mlir, xla, batching
from jax import core
from jax.extend import ffi

# from jax.abstract_arrays import ShapedArray
from jax.tree_util import PyTreeDef

from .compile import compile_and_import_module
from .lowering_rule import _torch_call_lowering
from .utils import find_unique_id, dtype_t2j, normalize_shapes, warn_once
from .utils import find_unique_id, dtype_t2j, normalize_shapes


def torch2jax_flat(
@@ -45,7 +40,7 @@ def torch2jax_flat(
    """
    # assert example_args is not None or output_shapes is not None or output_shapes_fn is not None
    assert example_args is not None or output_shapes is not None
    cpp_module = compile_and_import_module()
    _ = compile_and_import_module()
    id = find_unique_id()

    def torch_call_fn_(args: list[torch.Tensor]):
@@ -72,103 +67,6 @@ def wrapped_flat_fn(*args_flat):

    return wrapped_flat_fn

    torch_prim = core.Primitive(f"torch_call_{id}")
    torch_prim.multiple_results = True
    torch_prim.def_impl(partial(xla.apply_primitive, torch_prim))

    # inferring shapes #############################################################################
    if output_shapes is not None:
        # call the pytorch function to infer shapes
        assert isinstance(output_shapes, (tuple, list)) and all(
            isinstance(shape, (tuple, list, ShapeDtypeStruct, Size)) or shape is None or hasattr(shape, "shape")
            for shape in output_shapes
        )

        def _torch_call_abstract(*args):
            output_shapes_ = [Size(shape) if isinstance(shape, (list, tuple)) else shape for shape in output_shapes]
            return normalize_shapes(output_shapes_, args)

    else:
        with torch.no_grad():
            out = fn(*example_args)
        assert isinstance(out, (tuple, list, Tensor))
        out = (out,) if isinstance(out, Tensor) else tuple(out)

        def _torch_call_abstract(*args):
            return jax.tree.map(lambda x: ShapeDtypeStruct(x.shape, dtype_t2j(x.dtype)), out)

    torch_prim.def_abstract_eval(_torch_call_abstract)
    # inferring shapes #############################################################################

    # lowering ####################################################################################
    for platform in ["cpu", "gpu"]:
        mlir.register_lowering(
            torch_prim,
            partial(_torch_call_lowering, cpp_module=cpp_module, platform=platform, id=id),
            platform=platform,
        )
    # lowering ####################################################################################

    def torch_call_fn_():
        args = getattr(torch, f"_torch2jax_args_{id:d}")
        out = fn(*args)
        return (out,) if isinstance(out, Tensor) else tuple(out)

    setattr(torch, f"_torch2jax_fn_{id:d}", torch_call_fn_)

    def wrapped_fn(*args):
        return torch_prim.bind(*args)

    def torch_call_batching(args, axes):
        if use_torch_vmap:

            def torch_fn_vmap(*args):
                return torch.vmap(fn, in_dims=axes)(*args)

            assert any(ax is not None for ax in axes)
            batch_size = [arg.shape[ax] for arg, ax in zip(args, axes) if ax is not None][0]
            assert output_shapes is not None
            output_shapes_ = _torch_call_abstract(*args)
            output_shapes_vmap = [
                ShapeDtypeStruct((batch_size,) + tuple(shape.shape), shape.dtype) for shape in output_shapes_
            ]
            outaxes = (0 for _ in output_shapes_vmap)
            return (
                torch2jax_flat(torch_fn_vmap, args, output_shapes=output_shapes_vmap)(*args),
                outaxes,
            )
        else:
            warn_once(
                "You are NOT using PyTorch's functional vmap. " + "This is highly experimental and may be slower."
            )
            assert all(axis is None or axis == 0 for axis in axes)
            if all(axis is None for axis in axes):
                return wrapped_fn(*args)
            n = 0
            for i, axis in enumerate(axes):
                if axis is not None:
                    n = args[i].shape[axis]
                    break
            output_lists, output_struct = None, None
            for i in range(n):
                args_ = [arg if axis is None else arg[i] for arg, axis in zip(args, axes)]
                outputs = wrapped_fn(*args_)
                output_flat, output_struct = jax.tree.flatten(outputs)
                if output_lists is None:
                    output_lists = [[] for _ in output_flat]
                for output_list, output in zip(output_lists, output_flat):
                    output_list.append(output)
            outputs = tuple([jnp.stack(output_list, 0) for output_list in output_lists])
            outputs = jax.tree.unflatten(output_struct, outputs)
            return outputs, jax.tree.unflatten(output_struct, (0 for _ in outputs))

    batching.primitive_batchers[torch_prim] = torch_call_batching

    return wrapped_fn


####################################################################################################


def torch2jax(
    fn: Callable,
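For context, the deleted block above sat after `return wrapped_flat_fn`, so it was unreachable — hence "dead code". It is the legacy manual-primitive path (define a `core.Primitive`, attach an eager impl via `def_impl`, shape inference via `def_abstract_eval`, an MLIR lowering via `mlir.register_lowering`, and a vmap rule in `batching.primitive_batchers`), presumably superseded by the `jax.extend.ffi` import that remains in the header. As a rough, self-contained sketch of that same registration pattern — using a hypothetical doubling op rather than the PyTorch call, so none of the names below come from torch2jax — it would look like:

```python
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import mlir, batching

double_p = core.Primitive("double")


def double(x):
    return double_p.bind(x)


# Eager execution path (what def_impl provided for the torch_call primitive).
double_p.def_impl(lambda x: x * 2)

# Shape/dtype inference, analogous to _torch_call_abstract above.
double_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

# Lowering under jit; the removed code registered a hand-written rule per platform,
# here we simply lower a traceable Python implementation instead.
mlir.register_lowering(double_p, mlir.lower_fun(lambda x: x * 2, multiple_results=False))


# vmap rule, analogous to torch_call_batching above: x arrives with a batch
# dimension at position d, and the output keeps that same batch dimension.
def _double_batching(args, dims):
    (x,), (d,) = args, dims
    return double(x), d


batching.primitive_batchers[double_p] = _double_batching

print(jax.jit(double)(jnp.arange(3.0)))    # [0. 2. 4.]
print(jax.vmap(double)(jnp.ones((2, 4))))  # shape (2, 4), all twos
```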
73 changes: 0 additions & 73 deletions torch2jax/lowering_rule.py

This file was deleted.
