From 8f2822cf1830a5fb67045e44956ce23b0dfa9f0a Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 21 Dec 2024 10:37:52 +0900 Subject: [PATCH] Move command building to a separate file --- mlx/backend/common/CMakeLists.txt | 3 +- mlx/backend/common/compiled_cpu.cpp | 117 ++------------------------ mlx/backend/common/jit_compiler.cpp | 124 ++++++++++++++++++++++++++++ mlx/backend/common/jit_compiler.h | 17 ++++ 4 files changed, 148 insertions(+), 113 deletions(-) create mode 100644 mlx/backend/common/jit_compiler.cpp create mode 100644 mlx/backend/common/jit_compiler.h diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index a412cbe7b..e32123f43 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -66,5 +66,6 @@ target_sources( if(IOS) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp) else() - target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp) endif() diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 6c3d52e2d..f0d68287e 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -1,109 +1,20 @@ // Copyright © 2023-2024 Apple Inc. +#include #include #include #include #include #include -#include - -#include -#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled_preamble.h" +#include "mlx/backend/common/jit_compiler.h" #include "mlx/device.h" #include "mlx/graph_utils.h" namespace mlx::core { -#ifdef _MSC_VER - -namespace { - -// Split string into array. -std::vector str_split(const std::string& str, char delimiter) { - std::vector tokens; - std::string token; - std::istringstream tokenStream(str); - while (std::getline(tokenStream, token, delimiter)) { - tokens.push_back(token); - } - return tokens; -} - -// Run a command and get its output. -std::string exec(std::string cmd) { - std::unique_ptr pipe( - _popen(cmd.c_str(), "r"), _pclose); - if (!pipe) { - throw std::runtime_error("popen() failed."); - } - char buffer[128]; - std::string ret; - while (fgets(buffer, sizeof(buffer), pipe.get())) { - ret += buffer; - } - // Trim trailing spaces. - ret.erase( - std::find_if( - ret.rbegin(), - ret.rend(), - [](unsigned char ch) { return !std::isspace(ch); }) - .base(), - ret.end()); - return ret; -} - -// Get path information about MSVC. -struct VisualStudioInfo { - VisualStudioInfo() { -#ifdef _M_ARM64 - arch = "arm64"; -#else - arch = "x64"; -#endif - // Get path of Visual Studio. - std::string vs_path = exec(fmt::format( - "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" - " -property installationPath", - std::getenv("ProgramFiles(x86)"))); - if (vs_path.empty()) { - throw std::runtime_error("Can not find Visual Studio."); - } - // Read the envs from vcvarsall. - std::string envs = exec(fmt::format( - "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set", - vs_path, - arch)); - for (const std::string& line : str_split(envs, '\n')) { - // Each line is in the format "ENV_NAME=values". - auto pos = line.find_first_of('='); - if (pos == std::string::npos || pos == 0 || pos == line.size() - 1) - continue; - std::string name = line.substr(0, pos); - std::string value = line.substr(pos + 1); - if (name == "LIB") { - libpaths = str_split(value, ';'); - } else if (name == "VCToolsInstallDir") { - cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch); - } - } - } - std::string arch; - std::string cl_exe; - std::vector libpaths; -}; - -const VisualStudioInfo& GetVisualStudioInfo() { - static VisualStudioInfo info; - return info; -} - -} // namespace - -#endif // _MSC_VER - struct CompilerCache { struct DLib { DLib(const std::string& libname) { @@ -198,27 +109,9 @@ void* compile( source_file << source_code; source_file.close(); -#ifdef _MSC_VER - const VisualStudioInfo& info = GetVisualStudioInfo(); - std::string libpaths; - for (const std::string& lib : info.libpaths) { - libpaths += fmt::format(" /libpath:\"{0}\"", lib); - } - std::string build_command = fmt::format( - "\"" - "\"{0}\" /LD /EHsc /nologo /std:c++17 \"{1}\" /link /out:\"{2}\"{3}" - "\"", - info.cl_exe, - source_file_path, - shared_lib_path, - libpaths); -#else - std::string build_command = fmt::format( - "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'", - source_file_path, - shared_lib_path); -#endif - auto return_code = system(build_command.c_str()); + std::string command = + JitCompiler::build_command(source_file_path, shared_lib_path); + auto return_code = system(command.c_str()); if (return_code) { std::ostringstream msg; msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name diff --git a/mlx/backend/common/jit_compiler.cpp b/mlx/backend/common/jit_compiler.cpp new file mode 100644 index 000000000..9eeb19f54 --- /dev/null +++ b/mlx/backend/common/jit_compiler.cpp @@ -0,0 +1,124 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/common/jit_compiler.h" + +#include + +#include + +namespace mlx::core { + +#ifdef _MSC_VER + +namespace { + +// Split string into array. +std::vector str_split(const std::string& str, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(str); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + +// Run a command and get its output. +std::string exec(const std::string& cmd) { + std::unique_ptr pipe( + _popen(cmd.c_str(), "r"), _pclose); + if (!pipe) { + throw std::runtime_error("popen() failed."); + } + char buffer[128]; + std::string ret; + while (fgets(buffer, sizeof(buffer), pipe.get())) { + ret += buffer; + } + // Trim trailing spaces. + ret.erase( + std::find_if( + ret.rbegin(), + ret.rend(), + [](unsigned char ch) { return !std::isspace(ch); }) + .base(), + ret.end()); + return ret; +} + +// Get path information about MSVC. +struct VisualStudioInfo { + VisualStudioInfo() { +#ifdef _M_ARM64 + arch = "arm64"; +#else + arch = "x64"; +#endif + // Get path of Visual Studio. + std::string vs_path = exec(fmt::format( + "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" + " -property installationPath", + std::getenv("ProgramFiles(x86)"))); + if (vs_path.empty()) { + throw std::runtime_error("Can not find Visual Studio."); + } + // Read the envs from vcvarsall. + std::string envs = exec(fmt::format( + "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set", + vs_path, + arch)); + for (const std::string& line : str_split(envs, '\n')) { + // Each line is in the format "ENV_NAME=values". + auto pos = line.find_first_of('='); + if (pos == std::string::npos || pos == 0 || pos == line.size() - 1) + continue; + std::string name = line.substr(0, pos); + std::string value = line.substr(pos + 1); + if (name == "LIB") { + libpaths = str_split(value, ';'); + } else if (name == "VCToolsInstallDir") { + cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch); + } + } + } + std::string arch; + std::string cl_exe; + std::vector libpaths; +}; + +const VisualStudioInfo& GetVisualStudioInfo() { + static VisualStudioInfo info; + return info; +} + +} // namespace + +#endif // _MSC_VER + +std::string JitCompiler::build_command( + const std::string& source_file_path, + const std::string& shared_lib_path) { +#ifdef _MSC_VER + const VisualStudioInfo& info = GetVisualStudioInfo(); + std::string libpaths; + for (const std::string& lib : info.libpaths) { + libpaths += fmt::format(" /libpath:\"{0}\"", lib); + } + std::string command = fmt::format( + "\"" + "\"{0}\" /LD /EHsc /nologo /std:c++17 \"{1}\" /link /out:\"{2}\"{3}" + "\"", + info.cl_exe, + source_file_path, + shared_lib_path, + libpaths); +#else + std::string command = fmt::format( + "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'", + source_file_path, + shared_lib_path); +#endif + return command; +} + +} // namespace mlx::core diff --git a/mlx/backend/common/jit_compiler.h b/mlx/backend/common/jit_compiler.h new file mode 100644 index 000000000..be93bf0e3 --- /dev/null +++ b/mlx/backend/common/jit_compiler.h @@ -0,0 +1,17 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +#include + +namespace mlx::core { + +class JitCompiler { + public: + // Build a shell command that compiles |source_file_path| to a shared library + // at |shared_lib_path|. + static std::string build_command( + const std::string& source_file_path, + const std::string& shared_lib_path); +}; + +} // namespace mlx::core