From ca1808320f544a84b988ee6657995a49f367243f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Wed, 4 Sep 2024 12:59:22 +0200 Subject: [PATCH] Quantized integers to Torchx --- torchx/lib/torchx/backend.ex | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index 1f29d02c17..6e04ec877f 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -223,6 +223,14 @@ defmodule Torchx.Backend do for <>, into: <<>>, do: <> end + defp maybe_pad_binary(bin, {:u, size}) when size in [2, 4] do + for <>, into: <<>>, do: <> + end + + defp maybe_pad_binary(bin, {:s, size}) when size in [2, 4] do + for <>, into: <<>>, do: <> + end + defp maybe_pad_binary(bin, _), do: bin ## Shape @@ -1746,6 +1754,12 @@ defmodule Torchx.Backend do current_type = Torchx.scalar_type(device_ref) |> from_torch_type() case {current_type, type} do + {{:s, 8}, {:s, qint}} when qint in [2, 4] -> + :ok + + {{:u, 8}, {:u, qint}} when qint in [2, 4] -> + :ok + {{:s, 32}, {:u, 16}} -> :ok