Skip to content

Commit

Permalink
feat: add type transcoding (#55)
Browse files Browse the repository at this point in the history
* refactor: remove same-device enforcement

* feat: add type transcoding
  • Loading branch information
polvalente authored Nov 27, 2024
1 parent ce5a975 commit 0be8412
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 19 deletions.
116 changes: 110 additions & 6 deletions lib/emlx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@ defmodule EMLX.Backend do
@impl true
def to_binary(tensor, limit) do
EMLX.to_blob(from_nx(tensor), limit)
|> maybe_modify_binary(to_nx_type(to_mlx_type(tensor.type)), tensor.type)
end

@impl true
def from_binary(%T{type: type, shape: shape} = out, binary, backend_options) do
binary
|> maybe_modify_binary(type, to_nx_type(to_mlx_type(type)))
|> EMLX.from_blob(
shape,
to_mlx_type(type),
Expand All @@ -86,6 +88,104 @@ defmodule EMLX.Backend do
|> to_nx(out)
end

defp maybe_modify_binary(binary, type, type), do: binary

defp maybe_modify_binary(binary, {:f, 8}, {:f, 16}) do
for <<byte::8 <- binary>>, into: <<>> do
case read_f8(<<byte::8>>) do
:nan ->
Nx.Type.nan_binary({:f, 16})

:infinity ->
Nx.Type.infinity_binary({:f, 16})

:neg_infinity ->
Nx.Type.neg_infinity_binary({:f, 16})

number ->
<<number::float-native-size(16)>>
end
end
end

defp maybe_modify_binary(binary, {:f, 16}, {:f, 8}) do
for <<float::16 <- binary>>, into: <<>> do
case <<float::16-native>> do
<<float::float-native-16>> -> write_finite_f8(float)
<<0xFC00::16-native>> -> write_non_finite(:neg_infinity, 8)
<<0x7C00::16-native>> -> write_non_finite(:infinity, 8)
_ -> write_non_finite(:nan, 8)
end
end
end

defp maybe_modify_binary(binary, {:f, 64}, {:f, 32}) do
for <<float::64 <- binary>>, into: <<>> do
case <<float::64>> do
<<float::float-native-64>> -> <<float::float-native-size(32)>>
<<0xFFF0000000000000::64-native>> -> write_non_finite(:neg_infinity, 32)
<<0x7FF0000000000000::64-native>> -> write_non_finite(:infinity, 32)
_ -> write_non_finite(:nan, 32)
end
end
end

defp maybe_modify_binary(binary, {:f, 32}, {:f, 64}) do
for <<float::32 <- binary>>, into: <<>> do
case <<float::32>> do
<<float::float-native-32>> -> <<float::float-native-size(64)>>
<<0xFF800000::32-native>> -> write_non_finite(:neg_infinity, 64)
<<0x7F800000::32-native>> -> write_non_finite(:infinity, 64)
_ -> write_non_finite(:nan, 64)
end
end
end

defp maybe_modify_binary(binary, {:u, size}, {:u, 8}) when size in [2, 4] do
for <<bits::integer-native-size(size) <- binary>>, into: <<>> do
<<bits::integer-native-size(8)>>
end
end

defp maybe_modify_binary(binary, {:u, 8}, {:u, size}) when size in [2, 4] do
for <<bits::integer-native-size(8) <- binary>>, into: <<>> do
<<bits::integer-native-size(size)>>
end
end

defp read_f8(<<0xFC::8-native>>), do: :neg_infinity
defp read_f8(<<0x7C::8-native>>), do: :infinity
defp read_f8(<<_sign::1, 31::5, mantissa::2>>) when mantissa != 0, do: :nan

defp read_f8(<<sign::1, exp::5, mantissa::2>>) do
float = :math.pow(2, exp - 15) * (1 + mantissa / 4)

case sign do
0 -> float
_ -> -float
end
end

if System.endianness() == :little do
def write_finite_f8(x) do
binary_part(<<x::float-native-16>>, 0, 1)
end
else
def write_finite_f8(x) do
binary_part(<<x::float-native-16>>, 1, 1)
end
end

for size <- [8, 16, 32, 64] do
def write_non_finite(data, unquote(size)) do
case data do
:infinity -> unquote(Nx.Type.infinity_binary({:f, size}))
:neg_infinity -> unquote(Nx.Type.neg_infinity_binary({:f, size}))
:nan -> unquote(Nx.Type.nan_binary({:f, size}))
end
end
end

@impl true
def slice(
out,
Expand Down Expand Up @@ -202,18 +302,25 @@ defmodule EMLX.Backend do
defp needs_type_conversion?({:u, 8}, :bool), do: true
defp needs_type_conversion?(_, _), do: false

defp to_mlx_type({:u, 2}), do: :uint8
defp to_mlx_type({:u, 4}), do: :uint8
defp to_mlx_type({:u, 8}), do: :uint8
defp to_mlx_type({:u, 16}), do: :uint16
defp to_mlx_type({:u, 32}), do: :uint32
defp to_mlx_type({:u, 64}), do: :uint64
defp to_mlx_type({:s, 2}), do: :int8
defp to_mlx_type({:s, 4}), do: :int8
defp to_mlx_type({:s, 8}), do: :int8
defp to_mlx_type({:s, 16}), do: :int16
defp to_mlx_type({:s, 32}), do: :int32
defp to_mlx_type({:s, 64}), do: :int64
defp to_mlx_type({:f, 8}), do: :float16
defp to_mlx_type({:f, 16}), do: :float16
defp to_mlx_type({:f, 32}), do: :float32
defp to_mlx_type({:f, 64}), do: :float32
defp to_mlx_type({:bf, 16}), do: :bfloat16
defp to_mlx_type({:c, 64}), do: :complex64
defp to_mlx_type({:c, 128}), do: :complex64
defp to_mlx_type(:bool), do: :bool

defp to_nx_type(:uint8), do: {:u, 8}
Expand Down Expand Up @@ -252,16 +359,13 @@ defmodule EMLX.Backend do
{{:u, 8}, {:u, qint}} when qint in [2, 4] ->
:ok

{{:s, 32}, {:u, 16}} ->
:ok

{{:s, 64}, {:u, 32}} ->
{{:f, 16}, {:f, 8}} ->
:ok

{{:s, 64}, {:u, 64}} ->
{{:f, 32}, {:f, 64}} ->
:ok

{{:u, 8}, {:u, 32}} ->
{{:c, 64}, {:c, 128}} ->
:ok

_ when actual_type != expected_type ->
Expand Down
18 changes: 5 additions & 13 deletions test/emlx/nx_doctest_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ defmodule EMLX.Nx.DoctestTest do
cosh: 1,
log10: 1,
acos: 1,
covariance: 3
covariance: 3,
# These fail because we're using different representation types
atan2: 2,
as_type: 2,
from_binary: 3
]

@to_be_fixed [
Expand All @@ -51,18 +55,6 @@ defmodule EMLX.Nx.DoctestTest do
]

@not_supported [
# Does not support f8 (yet?)
tensor: 2,
# Does not support u2 (yet?)
bit_size: 1,
# f64 not supported
from_binary: 3,
# f64 not supported
iota: 2,
# f64 not supported
atan2: 2,
# f64 not supported
as_type: 2,
reduce: 4,
window_reduce: 5,
population_count: 1,
Expand Down

0 comments on commit 0be8412

Please sign in to comment.