Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add type transcoding #55

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 110 additions & 6 deletions lib/emlx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@ defmodule EMLX.Backend do
@impl true
def to_binary(tensor, limit) do
EMLX.to_blob(from_nx(tensor), limit)
|> maybe_modify_binary(to_nx_type(to_mlx_type(tensor.type)), tensor.type)
end

@impl true
def from_binary(%T{type: type, shape: shape} = out, binary, backend_options) do
binary
|> maybe_modify_binary(type, to_nx_type(to_mlx_type(type)))
|> EMLX.from_blob(
shape,
to_mlx_type(type),
Expand All @@ -86,6 +88,104 @@ defmodule EMLX.Backend do
|> to_nx(out)
end

defp maybe_modify_binary(binary, type, type), do: binary

defp maybe_modify_binary(binary, {:f, 8}, {:f, 16}) do
for <<byte::8 <- binary>>, into: <<>> do
case read_f8(<<byte::8>>) do
:nan ->
Nx.Type.nan_binary({:f, 16})

:infinity ->
Nx.Type.infinity_binary({:f, 16})

:neg_infinity ->
Nx.Type.neg_infinity_binary({:f, 16})

number ->
<<number::float-native-size(16)>>
end
end
end

defp maybe_modify_binary(binary, {:f, 16}, {:f, 8}) do
for <<float::16 <- binary>>, into: <<>> do
case <<float::16-native>> do
<<float::float-native-16>> -> write_finite_f8(float)
<<0xFC00::16-native>> -> write_non_finite(:neg_infinity, 8)
<<0x7C00::16-native>> -> write_non_finite(:infinity, 8)
_ -> write_non_finite(:nan, 8)
end
end
end

defp maybe_modify_binary(binary, {:f, 64}, {:f, 32}) do
for <<float::64 <- binary>>, into: <<>> do
case <<float::64>> do
<<float::float-native-64>> -> <<float::float-native-size(32)>>
<<0xFFF0000000000000::64-native>> -> write_non_finite(:neg_infinity, 32)
<<0x7FF0000000000000::64-native>> -> write_non_finite(:infinity, 32)
_ -> write_non_finite(:nan, 32)
end
end
end

defp maybe_modify_binary(binary, {:f, 32}, {:f, 64}) do
for <<float::32 <- binary>>, into: <<>> do
case <<float::32>> do
<<float::float-native-32>> -> <<float::float-native-size(64)>>
<<0xFF800000::32-native>> -> write_non_finite(:neg_infinity, 64)
<<0x7F800000::32-native>> -> write_non_finite(:infinity, 64)
_ -> write_non_finite(:nan, 64)
end
end
end

defp maybe_modify_binary(binary, {:u, size}, {:u, 8}) when size in [2, 4] do
for <<bits::integer-native-size(size) <- binary>>, into: <<>> do
<<bits::integer-native-size(8)>>
end
end

defp maybe_modify_binary(binary, {:u, 8}, {:u, size}) when size in [2, 4] do
for <<bits::integer-native-size(8) <- binary>>, into: <<>> do
<<bits::integer-native-size(size)>>
end
end

defp read_f8(<<0xFC::8-native>>), do: :neg_infinity
defp read_f8(<<0x7C::8-native>>), do: :infinity
defp read_f8(<<_sign::1, 31::5, mantissa::2>>) when mantissa != 0, do: :nan

defp read_f8(<<sign::1, exp::5, mantissa::2>>) do
float = :math.pow(2, exp - 15) * (1 + mantissa / 4)

case sign do
0 -> float
_ -> -float
end
end

if System.endianness() == :little do
def write_finite_f8(x) do
binary_part(<<x::float-native-16>>, 0, 1)
end
else
def write_finite_f8(x) do
binary_part(<<x::float-native-16>>, 1, 1)
end
end

for size <- [8, 16, 32, 64] do
def write_non_finite(data, unquote(size)) do
case data do
:infinity -> unquote(Nx.Type.infinity_binary({:f, size}))
:neg_infinity -> unquote(Nx.Type.neg_infinity_binary({:f, size}))
:nan -> unquote(Nx.Type.nan_binary({:f, size}))
end
end
end

@impl true
def slice(
out,
Expand Down Expand Up @@ -202,18 +302,25 @@ defmodule EMLX.Backend do
defp needs_type_conversion?({:u, 8}, :bool), do: true
defp needs_type_conversion?(_, _), do: false

defp to_mlx_type({:u, 2}), do: :uint8
defp to_mlx_type({:u, 4}), do: :uint8
defp to_mlx_type({:u, 8}), do: :uint8
defp to_mlx_type({:u, 16}), do: :uint16
defp to_mlx_type({:u, 32}), do: :uint32
defp to_mlx_type({:u, 64}), do: :uint64
defp to_mlx_type({:s, 2}), do: :int8
defp to_mlx_type({:s, 4}), do: :int8
defp to_mlx_type({:s, 8}), do: :int8
defp to_mlx_type({:s, 16}), do: :int16
defp to_mlx_type({:s, 32}), do: :int32
defp to_mlx_type({:s, 64}), do: :int64
defp to_mlx_type({:f, 8}), do: :float16
defp to_mlx_type({:f, 16}), do: :float16
defp to_mlx_type({:f, 32}), do: :float32
defp to_mlx_type({:f, 64}), do: :float32
defp to_mlx_type({:bf, 16}), do: :bfloat16
defp to_mlx_type({:c, 64}), do: :complex64
defp to_mlx_type({:c, 128}), do: :complex64
defp to_mlx_type(:bool), do: :bool

defp to_nx_type(:uint8), do: {:u, 8}
Expand Down Expand Up @@ -252,16 +359,13 @@ defmodule EMLX.Backend do
{{:u, 8}, {:u, qint}} when qint in [2, 4] ->
:ok

{{:s, 32}, {:u, 16}} ->
:ok

{{:s, 64}, {:u, 32}} ->
{{:f, 16}, {:f, 8}} ->
:ok

{{:s, 64}, {:u, 64}} ->
{{:f, 32}, {:f, 64}} ->
:ok

{{:u, 8}, {:u, 32}} ->
{{:c, 64}, {:c, 128}} ->
:ok

_ when actual_type != expected_type ->
Expand Down
18 changes: 5 additions & 13 deletions test/emlx/nx_doctest_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ defmodule EMLX.Nx.DoctestTest do
cosh: 1,
log10: 1,
acos: 1,
covariance: 3
covariance: 3,
# These fail because we're using different representation types
atan2: 2,
as_type: 2,
from_binary: 3
]

@to_be_fixed [
Expand All @@ -51,18 +55,6 @@ defmodule EMLX.Nx.DoctestTest do
]

@not_supported [
# Does not support f8 (yet?)
tensor: 2,
# Does not support u2 (yet?)
bit_size: 1,
# f64 not supported
from_binary: 3,
# f64 not supported
iota: 2,
# f64 not supported
atan2: 2,
# f64 not supported
as_type: 2,
reduce: 4,
window_reduce: 5,
population_count: 1,
Expand Down