From 105f37e9014dbaadfd63126e5bc526c9b4c11fa2 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 13 Dec 2024 14:09:55 +0900 Subject: [PATCH 01/12] Invoke MSVC on Windows in mx.compile --- mlx/backend/common/compiled_cpu.cpp | 140 +++++++++++++++++- mlx/backend/common/make_compiled_preamble.ps1 | 2 +- 2 files changed, 135 insertions(+), 7 deletions(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index eb08d070d..4b7e832b0 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -1,11 +1,16 @@ // Copyright © 2023-2024 Apple Inc. -#include +#include #include #include #include #include +#include #include +#include + +#include +#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled_preamble.h" @@ -14,6 +19,110 @@ namespace mlx::core { +#ifdef _MSC_VER + +namespace { + +// Remove trailing whitespaces. +std::string rtrim(const std::string& s) { + return std::string(s.begin(), std::find_if(s.rbegin(), s.rend(), [](auto ch) { + return !std::isspace(ch); + }).base()); +} + +// Split string into array. +std::vector str_split(const std::string& str, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(str); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + +// Return a new vector by transforming its values. +template +std::vector vec_map(const std::vector& v, F&& transform) { + std::vector ret(v.size()); + std::transform(v.begin(), v.end(), ret.begin(), std::forward(transform)); + return ret; +} + +// Join a vector into a string. +template +std::string vec_join(const std::vector& v, const std::string& delimiter) { + if (v.empty()) + return ""; + return std::accumulate( + v.begin() + 1, + v.end(), + v[0], + [&](const std::string& a, const std::string& b) { + return a + delimiter + b; + }); +} + +// Run a command and get its output. +std::string exec(std::string cmd) { + std::unique_ptr pipe( + _popen(cmd.c_str(), "r"), _pclose); + if (!pipe) { + throw std::runtime_error("popen() failed."); + } + char buffer[128]; + std::string result; + while (fgets(buffer, sizeof(buffer), pipe.get())) { + result += buffer; + } + return rtrim(result); +} + +// Get path information about MSVC. +struct VisualStudioInfo { + VisualStudioInfo() { +#ifdef _M_ARM64 + arch = "arm64"; +#else + arch = "x64"; +#endif + // Get path of Visual Studio. + std::string vs_path = exec(fmt::format( + "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" + " -property installationPath", + std::getenv("ProgramFiles(x86)"))); + if (vs_path.empty()) { + throw std::runtime_error("Can not find Visual Studio."); + } + // Read the envs from vcvarsall. + std::string envs = exec(fmt::format( + "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set", + vs_path, + arch)); + for (const std::string& line : str_split(envs, '\n')) { + auto pair = str_split(line, '='); + assert(pair.size() == 2); + if (pair[0] == "LIB") { + libpaths = str_split(pair[1], ';'); + } else if (pair[0] == "VCToolsInstallDir") { + cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", pair[1], arch); + } + } + } + std::string arch; + std::string cl_exe; + std::vector libpaths; +}; + +const VisualStudioInfo& GetVisualStudioInfo() { + static VisualStudioInfo info; + return info; +} + +} // namespace + +#endif // _MSC_VER + struct CompilerCache { struct DLib { DLib(const std::string& libname) { @@ -44,6 +153,7 @@ namespace detail { bool compile_available_for_device(const Device& device) { return true; } + } // namespace detail std::string get_temp_file(const std::string& name) { @@ -107,11 +217,29 @@ void* compile( source_file << source_code; source_file.close(); - std::ostringstream build_command; - build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '" - << source_file_path << "' -o '" << shared_lib_path << "'"; - std::string build_command_str = build_command.str(); - auto return_code = system(build_command_str.c_str()); +#ifdef _MSC_VER + const VisualStudioInfo& info = GetVisualStudioInfo(); + std::string build_command = fmt::format( + "\"" + "\"{0}\" /LD /EHsc /nologo /std:c++17 \"{1}\" /link /out:\"{2}\" {3}" + "\"", + info.cl_exe, + source_file_path, + shared_lib_path, + vec_join( + vec_map( + info.libpaths, + [](const auto& lib) { + return fmt::format("/libpath:\"{0}\"", lib); + }), + " ")); +#else + std::string build_command = fmt::format( + "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'", + source_file_path, + shared_lib_path); +#endif + auto return_code = system(build_command.c_str()); if (return_code) { std::ostringstream msg; msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name diff --git a/mlx/backend/common/make_compiled_preamble.ps1 b/mlx/backend/common/make_compiled_preamble.ps1 index 0b2248b67..18d057453 100644 --- a/mlx/backend/common/make_compiled_preamble.ps1 +++ b/mlx/backend/common/make_compiled_preamble.ps1 @@ -13,7 +13,7 @@ $CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/comp # Otherwise there will be too much empty lines making the result unreadable. $CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' } # Concatenate to string. -$CONTENT = $CONTENT -join '`n' +$CONTENT = $CONTENT -join "`n" # Append extra content. $CONTENT = @" From 5b7b42c868dbe132d87fcb6f745178d8e80c2711 Mon Sep 17 00:00:00 2001 From: Cheng Zhao Date: Fri, 13 Dec 2024 05:19:14 +0000 Subject: [PATCH 02/12] Export kernel symbol on MSVC --- mlx/backend/common/compiled_cpu.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 4b7e832b0..6a1bc2560 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -284,6 +284,11 @@ inline void build_kernel( NodeNamer namer; +#ifdef _MSC_VER + // Export the symbol + os << "__declspec(dllexport) "; +#endif + // Start the kernel os << "void " << kernel_name << "(void** args) {" << std::endl; From dc36c8699adf3207cdad865636a1a6af79451f8b Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 13 Dec 2024 14:37:36 +0900 Subject: [PATCH 03/12] Remove unused template --- mlx/backend/common/compiled_cpu.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 6a1bc2560..a540c71b9 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -50,15 +50,13 @@ std::vector vec_map(const std::vector& v, F&& transform) { } // Join a vector into a string. -template -std::string vec_join(const std::vector& v, const std::string& delimiter) { +std::string vec_join( + const std::vector& v, + const std::string& delimiter) { if (v.empty()) return ""; return std::accumulate( - v.begin() + 1, - v.end(), - v[0], - [&](const std::string& a, const std::string& b) { + v.begin() + 1, v.end(), v[0], [&](const auto& a, const auto& b) { return a + delimiter + b; }); } From c54b5f84c9160f9ff7c3d1c591b369f251d41be3 Mon Sep 17 00:00:00 2001 From: Cheng Zhao Date: Sat, 14 Dec 2024 02:10:03 +0000 Subject: [PATCH 04/12] Parse env pairs in a robust way --- mlx/backend/common/compiled_cpu.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index a540c71b9..451fb6be9 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -98,12 +98,16 @@ struct VisualStudioInfo { vs_path, arch)); for (const std::string& line : str_split(envs, '\n')) { - auto pair = str_split(line, '='); - assert(pair.size() == 2); - if (pair[0] == "LIB") { - libpaths = str_split(pair[1], ';'); - } else if (pair[0] == "VCToolsInstallDir") { - cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", pair[1], arch); + // Each line is in the format "ENV_NAME=values". + auto pos = line.find_first_of('='); + if (pos == std::string::npos || pos == 0 || pos == line.size() - 1) + continue; + std::string name = line.substr(0, pos); + std::string value = line.substr(pos + 1); + if (name == "LIB") { + libpaths = str_split(value, ';'); + } else if (name == "VCToolsInstallDir") { + cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch); } } } From 6c1b37e90fe4b88ce787ba4cbfa40eecf90665c5 Mon Sep 17 00:00:00 2001 From: Cheng Zhao Date: Sat, 14 Dec 2024 02:11:47 +0000 Subject: [PATCH 05/12] No need of cassert --- mlx/backend/common/compiled_cpu.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 451fb6be9..8b0bfe15a 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -1,6 +1,5 @@ // Copyright © 2023-2024 Apple Inc. -#include #include #include #include From 1906cbba55f0fb4ef8c3d02a1573900310e80f2e Mon Sep 17 00:00:00 2001 From: Cheng Zhao Date: Sat, 14 Dec 2024 02:31:12 +0000 Subject: [PATCH 06/12] Remove unnecessary helpers --- mlx/backend/common/compiled_cpu.cpp | 50 +++++++---------------------- 1 file changed, 11 insertions(+), 39 deletions(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 8b0bfe15a..74045ec3b 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include @@ -22,13 +21,6 @@ namespace mlx::core { namespace { -// Remove trailing whitespaces. -std::string rtrim(const std::string& s) { - return std::string(s.begin(), std::find_if(s.rbegin(), s.rend(), [](auto ch) { - return !std::isspace(ch); - }).base()); -} - // Split string into array. std::vector str_split(const std::string& str, char delimiter) { std::vector tokens; @@ -40,26 +32,6 @@ std::vector str_split(const std::string& str, char delimiter) { return tokens; } -// Return a new vector by transforming its values. -template -std::vector vec_map(const std::vector& v, F&& transform) { - std::vector ret(v.size()); - std::transform(v.begin(), v.end(), ret.begin(), std::forward(transform)); - return ret; -} - -// Join a vector into a string. -std::string vec_join( - const std::vector& v, - const std::string& delimiter) { - if (v.empty()) - return ""; - return std::accumulate( - v.begin() + 1, v.end(), v[0], [&](const auto& a, const auto& b) { - return a + delimiter + b; - }); -} - // Run a command and get its output. std::string exec(std::string cmd) { std::unique_ptr pipe( @@ -68,11 +40,13 @@ std::string exec(std::string cmd) { throw std::runtime_error("popen() failed."); } char buffer[128]; - std::string result; + std::string ret; while (fgets(buffer, sizeof(buffer), pipe.get())) { - result += buffer; + ret += buffer; } - return rtrim(result); + // Trim trailing spaces. + ret.erase(std::remove_if(ret.begin(), ret.end(), isspace), ret.end()); + return ret; } // Get path information about MSVC. @@ -220,20 +194,18 @@ void* compile( #ifdef _MSC_VER const VisualStudioInfo& info = GetVisualStudioInfo(); + std::string libpaths; + for (const std::string& lib : info.libpaths) { + libpaths += fmt::format(" /libpath:\"{0}\"", lib); + } std::string build_command = fmt::format( "\"" - "\"{0}\" /LD /EHsc /nologo /std:c++17 \"{1}\" /link /out:\"{2}\" {3}" + "\"{0}\" /LD /EHsc /nologo /std:c++17 \"{1}\" /link /out:\"{2}\"{3}" "\"", info.cl_exe, source_file_path, shared_lib_path, - vec_join( - vec_map( - info.libpaths, - [](const auto& lib) { - return fmt::format("/libpath:\"{0}\"", lib); - }), - " ")); + libpaths); #else std::string build_command = fmt::format( "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'", From 63e7267267dd1e2fb3e2304fb0def0a703b914f9 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 14 Dec 2024 11:42:36 +0900 Subject: [PATCH 07/12] Fix right trim --- mlx/backend/common/compiled_cpu.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 74045ec3b..6c3d52e2d 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -45,7 +45,13 @@ std::string exec(std::string cmd) { ret += buffer; } // Trim trailing spaces. - ret.erase(std::remove_if(ret.begin(), ret.end(), isspace), ret.end()); + ret.erase( + std::find_if( + ret.rbegin(), + ret.rend(), + [](unsigned char ch) { return !std::isspace(ch); }) + .base(), + ret.end()); return ret; } From 8f2822cf1830a5fb67045e44956ce23b0dfa9f0a Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 21 Dec 2024 10:37:52 +0900 Subject: [PATCH 08/12] Move command building to a separate file --- mlx/backend/common/CMakeLists.txt | 3 +- mlx/backend/common/compiled_cpu.cpp | 117 ++------------------------ mlx/backend/common/jit_compiler.cpp | 124 ++++++++++++++++++++++++++++ mlx/backend/common/jit_compiler.h | 17 ++++ 4 files changed, 148 insertions(+), 113 deletions(-) create mode 100644 mlx/backend/common/jit_compiler.cpp create mode 100644 mlx/backend/common/jit_compiler.h diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index a412cbe7b..e32123f43 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -66,5 +66,6 @@ target_sources( if(IOS) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp) else() - target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp) endif() diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 6c3d52e2d..f0d68287e 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -1,109 +1,20 @@ // Copyright © 2023-2024 Apple Inc. +#include #include #include #include #include #include -#include - -#include -#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled_preamble.h" +#include "mlx/backend/common/jit_compiler.h" #include "mlx/device.h" #include "mlx/graph_utils.h" namespace mlx::core { -#ifdef _MSC_VER - -namespace { - -// Split string into array. -std::vector str_split(const std::string& str, char delimiter) { - std::vector tokens; - std::string token; - std::istringstream tokenStream(str); - while (std::getline(tokenStream, token, delimiter)) { - tokens.push_back(token); - } - return tokens; -} - -// Run a command and get its output. -std::string exec(std::string cmd) { - std::unique_ptr pipe( - _popen(cmd.c_str(), "r"), _pclose); - if (!pipe) { - throw std::runtime_error("popen() failed."); - } - char buffer[128]; - std::string ret; - while (fgets(buffer, sizeof(buffer), pipe.get())) { - ret += buffer; - } - // Trim trailing spaces. - ret.erase( - std::find_if( - ret.rbegin(), - ret.rend(), - [](unsigned char ch) { return !std::isspace(ch); }) - .base(), - ret.end()); - return ret; -} - -// Get path information about MSVC. -struct VisualStudioInfo { - VisualStudioInfo() { -#ifdef _M_ARM64 - arch = "arm64"; -#else - arch = "x64"; -#endif - // Get path of Visual Studio. - std::string vs_path = exec(fmt::format( - "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" - " -property installationPath", - std::getenv("ProgramFiles(x86)"))); - if (vs_path.empty()) { - throw std::runtime_error("Can not find Visual Studio."); - } - // Read the envs from vcvarsall. - std::string envs = exec(fmt::format( - "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set", - vs_path, - arch)); - for (const std::string& line : str_split(envs, '\n')) { - // Each line is in the format "ENV_NAME=values". - auto pos = line.find_first_of('='); - if (pos == std::string::npos || pos == 0 || pos == line.size() - 1) - continue; - std::string name = line.substr(0, pos); - std::string value = line.substr(pos + 1); - if (name == "LIB") { - libpaths = str_split(value, ';'); - } else if (name == "VCToolsInstallDir") { - cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch); - } - } - } - std::string arch; - std::string cl_exe; - std::vector libpaths; -}; - -const VisualStudioInfo& GetVisualStudioInfo() { - static VisualStudioInfo info; - return info; -} - -} // namespace - -#endif // _MSC_VER - struct CompilerCache { struct DLib { DLib(const std::string& libname) { @@ -198,27 +109,9 @@ void* compile( source_file << source_code; source_file.close(); -#ifdef _MSC_VER - const VisualStudioInfo& info = GetVisualStudioInfo(); - std::string libpaths; - for (const std::string& lib : info.libpaths) { - libpaths += fmt::format(" /libpath:\"{0}\"", lib); - } - std::string build_command = fmt::format( - "\"" - "\"{0}\" /LD /EHsc /nologo /std:c++17 \"{1}\" /link /out:\"{2}\"{3}" - "\"", - info.cl_exe, - source_file_path, - shared_lib_path, - libpaths); -#else - std::string build_command = fmt::format( - "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'", - source_file_path, - shared_lib_path); -#endif - auto return_code = system(build_command.c_str()); + std::string command = + JitCompiler::build_command(source_file_path, shared_lib_path); + auto return_code = system(command.c_str()); if (return_code) { std::ostringstream msg; msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name diff --git a/mlx/backend/common/jit_compiler.cpp b/mlx/backend/common/jit_compiler.cpp new file mode 100644 index 000000000..9eeb19f54 --- /dev/null +++ b/mlx/backend/common/jit_compiler.cpp @@ -0,0 +1,124 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/common/jit_compiler.h" + +#include + +#include + +namespace mlx::core { + +#ifdef _MSC_VER + +namespace { + +// Split string into array. +std::vector str_split(const std::string& str, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(str); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + +// Run a command and get its output. +std::string exec(const std::string& cmd) { + std::unique_ptr pipe( + _popen(cmd.c_str(), "r"), _pclose); + if (!pipe) { + throw std::runtime_error("popen() failed."); + } + char buffer[128]; + std::string ret; + while (fgets(buffer, sizeof(buffer), pipe.get())) { + ret += buffer; + } + // Trim trailing spaces. + ret.erase( + std::find_if( + ret.rbegin(), + ret.rend(), + [](unsigned char ch) { return !std::isspace(ch); }) + .base(), + ret.end()); + return ret; +} + +// Get path information about MSVC. +struct VisualStudioInfo { + VisualStudioInfo() { +#ifdef _M_ARM64 + arch = "arm64"; +#else + arch = "x64"; +#endif + // Get path of Visual Studio. + std::string vs_path = exec(fmt::format( + "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" + " -property installationPath", + std::getenv("ProgramFiles(x86)"))); + if (vs_path.empty()) { + throw std::runtime_error("Can not find Visual Studio."); + } + // Read the envs from vcvarsall. + std::string envs = exec(fmt::format( + "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set", + vs_path, + arch)); + for (const std::string& line : str_split(envs, '\n')) { + // Each line is in the format "ENV_NAME=values". + auto pos = line.find_first_of('='); + if (pos == std::string::npos || pos == 0 || pos == line.size() - 1) + continue; + std::string name = line.substr(0, pos); + std::string value = line.substr(pos + 1); + if (name == "LIB") { + libpaths = str_split(value, ';'); + } else if (name == "VCToolsInstallDir") { + cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch); + } + } + } + std::string arch; + std::string cl_exe; + std::vector libpaths; +}; + +const VisualStudioInfo& GetVisualStudioInfo() { + static VisualStudioInfo info; + return info; +} + +} // namespace + +#endif // _MSC_VER + +std::string JitCompiler::build_command( + const std::string& source_file_path, + const std::string& shared_lib_path) { +#ifdef _MSC_VER + const VisualStudioInfo& info = GetVisualStudioInfo(); + std::string libpaths; + for (const std::string& lib : info.libpaths) { + libpaths += fmt::format(" /libpath:\"{0}\"", lib); + } + std::string command = fmt::format( + "\"" + "\"{0}\" /LD /EHsc /nologo /std:c++17 \"{1}\" /link /out:\"{2}\"{3}" + "\"", + info.cl_exe, + source_file_path, + shared_lib_path, + libpaths); +#else + std::string command = fmt::format( + "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'", + source_file_path, + shared_lib_path); +#endif + return command; +} + +} // namespace mlx::core diff --git a/mlx/backend/common/jit_compiler.h b/mlx/backend/common/jit_compiler.h new file mode 100644 index 000000000..be93bf0e3 --- /dev/null +++ b/mlx/backend/common/jit_compiler.h @@ -0,0 +1,17 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +#include + +namespace mlx::core { + +class JitCompiler { + public: + // Build a shell command that compiles |source_file_path| to a shared library + // at |shared_lib_path|. + static std::string build_command( + const std::string& source_file_path, + const std::string& shared_lib_path); +}; + +} // namespace mlx::core From 6372c2b79eb667fd661868de6d3f29730d07a915 Mon Sep 17 00:00:00 2001 From: Cheng Zhao Date: Sat, 21 Dec 2024 01:46:56 +0000 Subject: [PATCH 09/12] Missing header --- mlx/backend/common/jit_compiler.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/backend/common/jit_compiler.cpp b/mlx/backend/common/jit_compiler.cpp index 9eeb19f54..91985b4d9 100644 --- a/mlx/backend/common/jit_compiler.cpp +++ b/mlx/backend/common/jit_compiler.cpp @@ -2,6 +2,7 @@ #include "mlx/backend/common/jit_compiler.h" +#include #include #include @@ -104,7 +105,7 @@ std::string JitCompiler::build_command( for (const std::string& lib : info.libpaths) { libpaths += fmt::format(" /libpath:\"{0}\"", lib); } - std::string command = fmt::format( + return fmt::format( "\"" "\"{0}\" /LD /EHsc /nologo /std:c++17 \"{1}\" /link /out:\"{2}\"{3}" "\"", @@ -113,12 +114,11 @@ std::string JitCompiler::build_command( shared_lib_path, libpaths); #else - std::string command = fmt::format( + return fmt::format( "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'", source_file_path, shared_lib_path); #endif - return command; } } // namespace mlx::core From 13aeec251d0e681559ec8b24a12b21a8fc3c70ce Mon Sep 17 00:00:00 2001 From: Cheng Zhao Date: Sat, 21 Dec 2024 02:31:27 +0000 Subject: [PATCH 10/12] Do not pollute cwd with cl.exe --- mlx/backend/common/compiled_cpu.cpp | 22 ++++++++++------------ mlx/backend/common/jit_compiler.cpp | 20 ++++++++++++-------- mlx/backend/common/jit_compiler.h | 10 +++++----- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index f0d68287e..bf4361241 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -48,10 +48,6 @@ bool compile_available_for_device(const Device& device) { } // namespace detail -std::string get_temp_file(const std::string& name) { - return std::filesystem::temp_directory_path().append(name).string(); -} - // Return a pointer to a compiled function void* compile( const std::string& kernel_name, @@ -90,9 +86,10 @@ void* compile( kernel_file_name = kernel_name; } - std::ostringstream shared_lib_name; - shared_lib_name << "lib" << kernel_file_name << ".so"; - auto shared_lib_path = get_temp_file(shared_lib_name.str()); + auto output_dir = std::filesystem::temp_directory_path(); + + std::string shared_lib_name = std::string("lib") + kernel_file_name + ".so"; + auto shared_lib_path = (output_dir / shared_lib_name).string(); bool lib_exists = false; { std::ifstream f(shared_lib_path.c_str()); @@ -101,16 +98,17 @@ void* compile( if (!lib_exists) { // Open source file and write source code to it - std::ostringstream source_file_name; - source_file_name << kernel_file_name << ".cpp"; - auto source_file_path = get_temp_file(source_file_name.str()); + std::string source_file_name = kernel_file_name + ".cpp"; + auto source_file_path = (output_dir / source_file_name).string(); std::ofstream source_file(source_file_path); source_file << source_code; source_file.close(); - std::string command = - JitCompiler::build_command(source_file_path, shared_lib_path); + std::string command = JitCompiler::build_command( + std::filesystem::temp_directory_path(), + source_file_name, + shared_lib_name); auto return_code = system(command.c_str()); if (return_code) { std::ostringstream msg; diff --git a/mlx/backend/common/jit_compiler.cpp b/mlx/backend/common/jit_compiler.cpp index 91985b4d9..27fb9e723 100644 --- a/mlx/backend/common/jit_compiler.cpp +++ b/mlx/backend/common/jit_compiler.cpp @@ -2,8 +2,8 @@ #include "mlx/backend/common/jit_compiler.h" -#include #include +#include #include @@ -97,8 +97,9 @@ const VisualStudioInfo& GetVisualStudioInfo() { #endif // _MSC_VER std::string JitCompiler::build_command( - const std::string& source_file_path, - const std::string& shared_lib_path) { + const std::filesystem::path& dir, + const std::string& source_file_name, + const std::string& shared_lib_name) { #ifdef _MSC_VER const VisualStudioInfo& info = GetVisualStudioInfo(); std::string libpaths; @@ -107,17 +108,20 @@ std::string JitCompiler::build_command( } return fmt::format( "\"" - "\"{0}\" /LD /EHsc /nologo /std:c++17 \"{1}\" /link /out:\"{2}\"{3}" + "cd /D \"{0}\" && " + "\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" " + "/link /out:\"{3}\" {4} >nul" "\"", + dir.string(), info.cl_exe, - source_file_path, - shared_lib_path, + source_file_name, + shared_lib_name, libpaths); #else return fmt::format( "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'", - source_file_path, - shared_lib_path); + (dir / source_file_name).string(), + (dir / shared_lib_name).string()); #endif } diff --git a/mlx/backend/common/jit_compiler.h b/mlx/backend/common/jit_compiler.h index be93bf0e3..b0bf8c0de 100644 --- a/mlx/backend/common/jit_compiler.h +++ b/mlx/backend/common/jit_compiler.h @@ -1,17 +1,17 @@ // Copyright © 2024 Apple Inc. #pragma once -#include +#include namespace mlx::core { class JitCompiler { public: - // Build a shell command that compiles |source_file_path| to a shared library - // at |shared_lib_path|. + // Build a shell command that compiles a source code file to a shared library. static std::string build_command( - const std::string& source_file_path, - const std::string& shared_lib_path); + const std::filesystem::path& dir, + const std::string& source_file_name, + const std::string& shared_lib_name); }; } // namespace mlx::core From 1855e7e9db65855d28770df16a1f8e9a971255d6 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 21 Dec 2024 11:33:56 +0900 Subject: [PATCH 11/12] Simplify str concat --- mlx/backend/common/compiled_cpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index bf4361241..5aeb4c07d 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -88,7 +88,7 @@ void* compile( auto output_dir = std::filesystem::temp_directory_path(); - std::string shared_lib_name = std::string("lib") + kernel_file_name + ".so"; + std::string shared_lib_name = "lib" + kernel_file_name + ".so"; auto shared_lib_path = (output_dir / shared_lib_name).string(); bool lib_exists = false; { From 781990b528a55ba36aefba48d18d80e5700c242b Mon Sep 17 00:00:00 2001 From: Cheng Zhao Date: Sat, 21 Dec 2024 02:35:50 +0000 Subject: [PATCH 12/12] Pass output dir --- mlx/backend/common/compiled_cpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 5aeb4c07d..3bcb82d3f 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -106,7 +106,7 @@ void* compile( source_file.close(); std::string command = JitCompiler::build_command( - std::filesystem::temp_directory_path(), + output_dir, source_file_name, shared_lib_name); auto return_code = system(command.c_str());