diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 92f6ab7da..a412cbe7b 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -5,13 +5,21 @@ else() set(COMPILER ${CMAKE_CXX_COMPILER}) endif() +if(MSVC) + set(SHELL_EXT ps1) + set(SHELL_CMD powershell -ExecutionPolicy Bypass -File) +else() + set(SHELL_EXT sh) + set(SHELL_CMD /bin/bash) +endif() + add_custom_command( OUTPUT compiled_preamble.cpp COMMAND - /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh + ${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT} ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER} ${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR} - DEPENDS make_compiled_preamble.sh + DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h diff --git a/mlx/backend/common/make_compiled_preamble.ps1 b/mlx/backend/common/make_compiled_preamble.ps1 new file mode 100644 index 000000000..0b2248b67 --- /dev/null +++ b/mlx/backend/common/make_compiled_preamble.ps1 @@ -0,0 +1,38 @@ +# This script generates a C++ function that provides the CPU +# code for use with kernel generation. +# +# Copyright © 2024 Apple Inc. + +$OUTPUT_FILE = $args[0] +$CL = $args[1] +$SRCDIR = $args[2] + +# Get command result as array. +$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h" +# Remove empty lines. +# Otherwise there will be too much empty lines making the result unreadable. +$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' } +# Concatenate to string. +$CONTENT = $CONTENT -join '`n' + +# Append extra content. +$CONTENT = @" +$($CONTENT) +using namespace mlx::core; +using namespace mlx::core::detail; +"@ + +# Convert each char to ASCII code. +# Unlike the unix script that outputs string literal directly, the output from +# MSVC is way too large to be embedded as string and compilation will fail, so +# we store it as static array instead. +$CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0' + +$OUTPUT = @" +const char* get_kernel_preamble() { + static char preamble[] = { $CHARCODES }; + return preamble; +} +"@ + +Set-Content -Path $OUTPUT_FILE -Value $OUTPUT