Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make mx.compile work on Windows #1697

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion mlx/backend/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@ target_sources(
if(IOS)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
endif()
33 changes: 18 additions & 15 deletions mlx/backend/common/compiled_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/backend/common/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"

Expand Down Expand Up @@ -44,11 +45,8 @@ namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail

std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name).string();
}
} // namespace detail

// Return a pointer to a compiled function
void* compile(
Expand Down Expand Up @@ -88,9 +86,10 @@ void* compile(
kernel_file_name = kernel_name;
}

std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_file_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str());
auto output_dir = std::filesystem::temp_directory_path();

std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string();
bool lib_exists = false;
{
std::ifstream f(shared_lib_path.c_str());
Expand All @@ -99,19 +98,18 @@ void* compile(

if (!lib_exists) {
// Open source file and write source code to it
std::ostringstream source_file_name;
source_file_name << kernel_file_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str());
std::string source_file_name = kernel_file_name + ".cpp";
auto source_file_path = (output_dir / source_file_name).string();

std::ofstream source_file(source_file_path);
source_file << source_code;
source_file.close();

std::ostringstream build_command;
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
<< source_file_path << "' -o '" << shared_lib_path << "'";
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
std::string command = JitCompiler::build_command(
output_dir,
source_file_name,
shared_lib_name);
auto return_code = system(command.c_str());
if (return_code) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
Expand Down Expand Up @@ -156,6 +154,11 @@ inline void build_kernel(

NodeNamer namer;

#ifdef _MSC_VER
// Export the symbol
os << "__declspec(dllexport) ";
#endif

// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;

Expand Down
128 changes: 128 additions & 0 deletions mlx/backend/common/jit_compiler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/jit_compiler.h"

#include <sstream>
#include <vector>

#include <fmt/format.h>

namespace mlx::core {

#ifdef _MSC_VER

namespace {

// Split string into array.
std::vector<std::string> str_split(const std::string& str, char delimiter) {
std::vector<std::string> tokens;
std::string token;
std::istringstream tokenStream(str);
while (std::getline(tokenStream, token, delimiter)) {
tokens.push_back(token);
}
return tokens;
}

// Run a command and get its output.
std::string exec(const std::string& cmd) {
std::unique_ptr<FILE, decltype(&_pclose)> pipe(
_popen(cmd.c_str(), "r"), _pclose);
if (!pipe) {
throw std::runtime_error("popen() failed.");
}
char buffer[128];
std::string ret;
while (fgets(buffer, sizeof(buffer), pipe.get())) {
ret += buffer;
}
// Trim trailing spaces.
ret.erase(
std::find_if(
ret.rbegin(),
ret.rend(),
[](unsigned char ch) { return !std::isspace(ch); })
.base(),
ret.end());
return ret;
}

// Get path information about MSVC.
struct VisualStudioInfo {
VisualStudioInfo() {
#ifdef _M_ARM64
arch = "arm64";
#else
arch = "x64";
#endif
// Get path of Visual Studio.
std::string vs_path = exec(fmt::format(
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
" -property installationPath",
std::getenv("ProgramFiles(x86)")));
if (vs_path.empty()) {
throw std::runtime_error("Can not find Visual Studio.");
}
// Read the envs from vcvarsall.
std::string envs = exec(fmt::format(
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
vs_path,
arch));
for (const std::string& line : str_split(envs, '\n')) {
// Each line is in the format "ENV_NAME=values".
auto pos = line.find_first_of('=');
if (pos == std::string::npos || pos == 0 || pos == line.size() - 1)
continue;
std::string name = line.substr(0, pos);
std::string value = line.substr(pos + 1);
if (name == "LIB") {
libpaths = str_split(value, ';');
} else if (name == "VCToolsInstallDir") {
cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
}
}
}
std::string arch;
std::string cl_exe;
std::vector<std::string> libpaths;
};

const VisualStudioInfo& GetVisualStudioInfo() {
static VisualStudioInfo info;
return info;
}

} // namespace

#endif // _MSC_VER

std::string JitCompiler::build_command(
const std::filesystem::path& dir,
const std::string& source_file_name,
const std::string& shared_lib_name) {
#ifdef _MSC_VER
const VisualStudioInfo& info = GetVisualStudioInfo();
std::string libpaths;
for (const std::string& lib : info.libpaths) {
libpaths += fmt::format(" /libpath:\"{0}\"", lib);
}
return fmt::format(
"\""
"cd /D \"{0}\" && "
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
"/link /out:\"{3}\" {4} >nul"
"\"",
dir.string(),
info.cl_exe,
source_file_name,
shared_lib_name,
libpaths);
#else
return fmt::format(
"g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'",
(dir / source_file_name).string(),
(dir / shared_lib_name).string());
#endif
}

} // namespace mlx::core
17 changes: 17 additions & 0 deletions mlx/backend/common/jit_compiler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright © 2024 Apple Inc.
#pragma once

#include <filesystem>

namespace mlx::core {

class JitCompiler {
public:
// Build a shell command that compiles a source code file to a shared library.
static std::string build_command(
const std::filesystem::path& dir,
const std::string& source_file_name,
const std::string& shared_lib_name);
};

} // namespace mlx::core
2 changes: 1 addition & 1 deletion mlx/backend/common/make_compiled_preamble.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ $CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/comp
# Otherwise there will be too much empty lines making the result unreadable.
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
# Concatenate to string.
$CONTENT = $CONTENT -join '`n'
$CONTENT = $CONTENT -join "`n"

# Append extra content.
$CONTENT = @"
Expand Down