Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support to use customized precompiled binaries #41

Merged
merged 2 commits into from
Nov 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 88 additions & 27 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ defmodule EMLX.MixProject do

defp libmlx_config() do
version = System.get_env("LIBMLX_VERSION", @mlx_version)

features = %{
jit?: to_boolean(System.get_env("LIBMLX_ENABLE_JIT")),
debug?: to_boolean(System.get_env("LIBMLX_ENABLE_DEBUG"))
}

variant = to_variant(features)

%{
Expand All @@ -79,10 +81,10 @@ defmodule EMLX.MixProject do
end

defp to_variant(features) do
[(if features.debug?, do: "debug", else: nil), (if features.jit?, do: "jit", else: nil)]
|> Enum.filter(& &1 != nil)
[if(features.debug?, do: "debug", else: nil), if(features.jit?, do: "jit", else: nil)]
|> Enum.filter(&(&1 != nil))
|> Enum.sort()
|> Enum.map(& "-#{&1}")
|> Enum.map(&"-#{&1}")
|> Enum.join("")
end

Expand Down Expand Up @@ -110,38 +112,36 @@ defmodule EMLX.MixProject do

defp download_and_unarchive(cache_dir, libmlx_config) do
File.mkdir_p!(cache_dir)
libmlx_archive = Path.join(cache_dir, "libmlx-#{libmlx_config.version}#{libmlx_config.variant}.tar.gz")

unless File.exists?(libmlx_archive) do
# Download libmlx

if {:unix, :darwin} != :os.type() do
Mix.raise("No MLX support on non Apple Silicon machines")
end
libmlx_archive =
Path.join(cache_dir, "libmlx-#{libmlx_config.version}#{libmlx_config.variant}.tar.gz")

url =
"https://github.com/cocoa-xu/mlx-build/releases/download/v#{libmlx_config.version}/mlx-arm64-apple-darwin#{libmlx_config.variant}.tar.gz"
libmlx_archive = System.get_env("MLX_ARCHIVE_PATH", libmlx_archive)

sha256_url = "#{url}.sha256"

download!(url, libmlx_archive)
url =
"https://github.com/cocoa-xu/mlx-build/releases/download/v#{libmlx_config.version}/mlx-arm64-apple-darwin#{libmlx_config.variant}.tar.gz"

libmlx_archive_checksum = checksum!(libmlx_archive)
sha256_url = "#{url}.sha256"

data = download!(sha256_url)
checksum = String.split(data, " ", parts: 2, trim: true)
verify_integrity = "sha256=url:#{sha256_url}"

if length(checksum) != 2 do
Mix.raise("Invalid checksum file: #{sha256_url}")
{url, verify_integrity} =
if customized_url = System.get_env("MLX_ARCHIVE_URL") do
verify_integrity = System.get_env("MLX_ARCHIVE_INTEGRITY")
{customized_url, verify_integrity}
else
{url, verify_integrity}
end

expected_checksum = hd(checksum)
unless File.exists?(libmlx_archive) do
# Download libmlx

if expected_checksum != libmlx_archive_checksum do
Mix.raise(
"Checksum mismatch for #{libmlx_archive}. Expected: #{expected_checksum}, got: #{libmlx_archive_checksum}"
)
if {:unix, :darwin} != :os.type() do
Mix.raise("No MLX support on non Apple Silicon machines")
end

download!(url, libmlx_archive)
:ok = maybe_verify_integrity!(verify_integrity, libmlx_archive)
end

# Unpack libmlx and move to the target cache dir
Expand All @@ -157,6 +157,67 @@ defmodule EMLX.MixProject do
:ok
end

defp maybe_verify_integrity!(nil, _libmlx_archive), do: :ok

defp maybe_verify_integrity!(verify_integrity, libmlx_archive) do
{checksum_algo, expected_checksum} = get_checksum_info!(verify_integrity)
libmlx_archive_checksum = checksum!(libmlx_archive, checksum_algo)

if expected_checksum != libmlx_archive_checksum do
Mix.raise(
"Checksum (#{checksum_algo}) mismatch for #{libmlx_archive}. Expected: #{expected_checksum}, got: #{libmlx_archive_checksum}"
)
else
:ok
end
end

@known_checksum_algos [
"sha",
"sha224",
"sha256",
"sha384",
"sha512",
"sha3_224",
"sha3_256",
"sha3_384",
"sha3_512",
"blake2b",
"blake2s",
"ripemd160",
"md4",
"md5"
]

defp get_checksum_info!(verify_integrity) do
case String.split(verify_integrity, "=", parts: 2, trim: true) do
[algo, checksum] when algo in @known_checksum_algos ->
{String.to_existing_atom(algo), get_checksum_value!(checksum)}

_ ->
Mix.raise("Invalid checksum: #{verify_integrity}")
end
end

defp get_checksum_value!("url:" <> url) do
checksum_from_url!(url)
end

defp get_checksum_value!(checksum) do
checksum
end

defp checksum_from_url!(url) do
data = download!(url)
checksum = String.split(data, " ", parts: 2, trim: true)

if length(checksum) == 0 do
Mix.raise("Invalid checksum file: #{url}")
end

hd(checksum)
end

def download!(url, save_as \\ nil) do
url_charlist = String.to_charlist(url)

Expand Down Expand Up @@ -245,10 +306,10 @@ defmodule EMLX.MixProject do
""")
end

defp checksum!(file_path) do
defp checksum!(file_path, algo) do
case File.read(file_path) do
{:ok, content} ->
Base.encode16(:crypto.hash(:sha256, content), case: :lower)
Base.encode16(:crypto.hash(algo, content), case: :lower)

{:error, reason} ->
Mix.raise(
Expand Down
Loading