Skip to content

Commit

Permalink
Add jax compatible api (#1207)
Browse files Browse the repository at this point in the history
This PR adds a JAX compatible API, refer issue #1027
  • Loading branch information
sky-2002 authored Nov 27, 2024
1 parent 568f252 commit 5608dd8
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 0 deletions.
23 changes: 23 additions & 0 deletions benchmarks/bench_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
except ImportError:
pass

try:
import jax
import jax.numpy as jnp
except ImportError:
pass


def is_mlx_lm_allowed():
try:
Expand All @@ -18,6 +24,14 @@ def is_mlx_lm_allowed():
return mx.metal.is_available()


def is_jax_allowed():
try:
import jax # noqa: F401
except ImportError:
return False
return True


def get_mock_processor_inputs(array_library, num_tokens=30000):
"""
logits: (4, 30,000 ) dtype=float
Expand All @@ -43,6 +57,13 @@ def get_mock_processor_inputs(array_library, num_tokens=30000):
input_ids = mx.random.randint(
low=0, high=num_tokens, shape=(4, 2048), dtype=mx.int32
)
elif array_library == "jax":
logits = jnp.random.uniform(
key=jax.random.PRNGKey(0), shape=(4, num_tokens), dtype=jnp.float32
)
input_ids = jnp.random.randint(
key=jax.random.PRNGKey(0), low=0, high=num_tokens, shape=(4, 2048)
)
else:
raise ValueError

Expand All @@ -67,6 +88,8 @@ class LogitsProcessorPassthroughBenchmark:
params += ["mlx"]
if torch.cuda.is_available():
params += ["torch_cuda"]
if is_jax_allowed():
params += ["jax"]

def setup(self, array_library):
self.logits_processor = HalvingLogitsProcessor()
Expand Down
21 changes: 21 additions & 0 deletions outlines/processors/base_logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ def is_mlx_array_type(array_type):
return issubclass(array_type, mx.array)


def is_jax_array_type(array_type):
try:
import jaxlib
except ImportError:
return False
return issubclass(array_type, jaxlib.xla_extension.ArrayImpl) or isinstance(
array_type, jaxlib.xla_extension.ArrayImpl
)


class OutlinesLogitsProcessor(Protocol):
"""
Base class for logits processors which normalizes types of logits:
Expand Down Expand Up @@ -101,6 +111,12 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
# https://ml-explore.github.io/mlx/build/html/usage/numpy.html
return torch.from_dlpack(tensor_like)

elif is_jax_array_type(type(tensor_like)):
import jax

torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like))
return torch_tensor

else:
raise TypeError(
"LogitsProcessor must be called with either np.NDArray, "
Expand Down Expand Up @@ -129,6 +145,11 @@ def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
# numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
return mx.array(tensor.float().numpy())

elif is_jax_array_type(target_type):
import jax

return jax.dlpack.from_dlpack(tensor)

else:
raise TypeError(
f"Failed to convert torch tensors to target_type `{target_type}`"
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ enable_incomplete_feature = ["Unpack"]
[[tool.mypy.overrides]]
module = [
"exllamav2.*",
"jax",
"jaxlib",
"jax.numpy",
"jinja2",
"jsonschema.*",
"openai.*",
Expand Down
74 changes: 74 additions & 0 deletions tests/processors/test_base_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import List

import jax.numpy as jnp
import numpy as np
import pytest
import torch

from outlines.processors.base_logits_processor import OutlinesLogitsProcessor

arrays = {
"list": [[1.0, 2.0], [3.0, 4.0]],
"np": np.array([[1, 2], [3, 4]], dtype=np.float32),
"jax": jnp.array([[1, 2], [3, 4]], dtype=jnp.float32),
"torch": torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
}

try:
import mlx.core as mx

arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
except ImportError:
pass

try:
import jax.numpy as jnp

arrays["jax"] = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
except ImportError:
pass


# Mock implementation of the abstract class for testing
class MockLogitsProcessor(OutlinesLogitsProcessor):
def process_logits(
self, input_ids: List[List[int]], logits: torch.Tensor
) -> torch.Tensor:
# For testing purposes, let's just return logits multiplied by 2
return logits * 2


@pytest.fixture
def processor():
"""Fixture for creating an instance of the MockLogitsProcessor."""
return MockLogitsProcessor()


@pytest.mark.parametrize("array_type", arrays.keys())
def test_to_torch(array_type, processor):
data = arrays[array_type]
torch_tensor = processor._to_torch(data)
assert isinstance(torch_tensor, torch.Tensor)
assert torch.allclose(
torch_tensor.cpu(), torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
)


@pytest.mark.parametrize("array_type", arrays.keys())
def test_from_torch(array_type, processor):
torch_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
data = processor._from_torch(torch_tensor, type(arrays[array_type]))
assert isinstance(data, type(arrays[array_type]))
assert np.allclose(data, arrays[array_type])


@pytest.mark.parametrize("array_type", arrays.keys())
def test_call(array_type, processor):
input_ids = arrays[array_type]
logits = arrays[array_type]
processed_logits = processor(input_ids, logits)

assert isinstance(processed_logits, type(arrays[array_type]))
assert np.allclose(
np.array(processed_logits), np.array([[2.0, 4.0], [6.0, 8.0]], dtype=np.float32)
)

0 comments on commit 5608dd8

Please sign in to comment.