forked from openxla/xla
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR openxla#7849: [XLA:CPU] Add support for cross-process collectives …
…using mpi. Imported from GitHub PR openxla#7849 Mpi collectives as proposed in jax-ml/jax#11182. I only implemented the inter-process communication and this does not yet support more than 1 threads per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved if one wanted to do it properly. For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) mpi library at runtime. To wrap and load the desired mpi library one needs compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`. @hawkinsp Copybara import of the project: -- b74bbb9 by Clemens Giuliani <[email protected]>: add mpi collectives -- 23508eb by Clemens Giuliani <[email protected]>: add explicit Init and Finalize methods and export them to python -- bbe5840 by Clemens Giuliani <[email protected]>: add comment -- 38d1562 by Clemens Giuliani <[email protected]>: fix windows build -- 201f723 by Clemens Giuliani <[email protected]>: fmt -- 2784869 by Clemens Giuliani <[email protected]>: bump xla_extension_version Merging this change closes openxla#7849 COPYBARA_INTEGRATE_REVIEW=openxla#7849 from inailuig:mpi_collectives 2784869 PiperOrigin-RevId: 620302264
- Loading branch information
Showing
11 changed files
with
751 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
diff --git a/gen/gen_decl.py b/gen/gen_decl.py | ||
index 1005b95..696b4e0 100755 | ||
--- a/gen/gen_decl.py | ||
+++ b/gen/gen_decl.py | ||
@@ -9,8 +9,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) | ||
|
||
from mpi_constants import constants | ||
from mpi_functions import functions | ||
-from mpi_constants_fortran import constants_fortran | ||
-from mpi_functions_fortran import functions_fortran | ||
+# from mpi_constants_fortran import constants_fortran | ||
+# from mpi_functions_fortran import functions_fortran | ||
|
||
support_profiling = True | ||
have_weak_symbols = False | ||
@@ -24,7 +24,7 @@ def wrap(line): | ||
lines.append(line) | ||
return "\n".join(lines) | ||
|
||
-with open("include/mpi_decl_constants_c.h", "w") as file: | ||
+with open(sys.argv[1], "w") as file: | ||
file.write("// Declare C MPI constants\n") | ||
file.write("\n") | ||
for (tp, nm) in constants: | ||
@@ -32,7 +32,7 @@ with open("include/mpi_decl_constants_c.h", "w") as file: | ||
'mpi_nm': nm} | ||
file.write(Template("extern $mpi_tp MPITRAMPOLINE_CONST $mpi_nm;\n").substitute(subs)) | ||
|
||
-with open("include/mpi_decl_functions_c.h", "w") as file: | ||
+with open(sys.argv[2], "w") as file: | ||
file.write("// Declare C MPI functions\n") | ||
file.write("\n") | ||
for (tp, nm, args, flags) in functions: | ||
@@ -90,7 +90,7 @@ with open("include/mpi_decl_functions_c.h", "w") as file: | ||
file.write(Template("\n".join(tmpl)).substitute(subs)) | ||
file.write("\n") | ||
|
||
-with open("include/mpi_decl_constants_fortran.h", "w") as file: | ||
+if False: | ||
file.write("! Declare Fortran MPI constants\n") | ||
file.write("\n") | ||
for (tp, nm) in constants_fortran: | ||
@@ -104,7 +104,7 @@ with open("include/mpi_decl_constants_fortran.h", "w") as file: | ||
file.write("\n".join(map(lambda line: wrap(Template(line).substitute(subs)), tmpl))) | ||
file.write("\n") | ||
|
||
-with open("include/mpi_decl_functions_fortran.h", "w") as file: | ||
+if False: | ||
file.write("! Declare Fortran MPI functions\n") | ||
file.write("\n") | ||
for (tp, nm, args) in functions_fortran: | ||
diff --git a/gen/gen_defn.py b/gen/gen_defn.py | ||
index bf31f35..318222e 100755 | ||
--- a/gen/gen_defn.py | ||
+++ b/gen/gen_defn.py | ||
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) | ||
|
||
from mpi_constants import constants | ||
from mpi_functions import functions | ||
-from mpi_constants_fortran import constants_fortran | ||
-from mpi_functions_fortran import functions_fortran | ||
+# from mpi_constants_fortran import constants_fortran | ||
+# from mpi_functions_fortran import functions_fortran | ||
|
||
support_profiling = True | ||
have_weak_symbols = False | ||
replace_sentinels = False | ||
|
||
-with open("src/mpi_defn_constants_c.h", "w") as file: | ||
+with open(sys.argv[1], "w") as file: | ||
file.write("// Define C MPI constants") | ||
file.write("\n") | ||
for (tp, nm) in constants: | ||
@@ -24,7 +24,7 @@ with open("src/mpi_defn_constants_c.h", "w") as file: | ||
'mpi_nm': nm} | ||
file.write(Template("$mpi_tp $mpi_nm = ($mpi_tp)0xdeadbeef;\n").substitute(subs)) | ||
|
||
-with open("src/mpi_defn_functions_c.h", "w") as file: | ||
+with open(sys.argv[2], "w") as file: | ||
file.write("// Define C MPI functions\n") | ||
file.write("\n") | ||
for (tp, nm, args, flags) in functions: | ||
@@ -89,7 +89,7 @@ with open("src/mpi_defn_functions_c.h", "w") as file: | ||
file.write(Template("\n".join(tmpl)).substitute(subs)) | ||
file.write("\n") | ||
|
||
-with open("src/mpi_defn_constants_fortran.h", "w") as file: | ||
+if False: | ||
file.write("// Define Fortran MPI constants\n") | ||
file.write("\n") | ||
for (tp, nm) in constants_fortran: | ||
@@ -98,7 +98,7 @@ with open("src/mpi_defn_constants_fortran.h", "w") as file: | ||
# Fortran common blocks with `-march=skylake-avx512` are aligned to 64 bytes | ||
file.write(Template("$mpi_tp $abi_nm __attribute__((__aligned__(64))) = (int)0xdeadbeef;\n").substitute(subs)) | ||
|
||
-with open("src/mpi_defn_functions_fortran.h", "w") as file: | ||
+if False: | ||
file.write("// Define Fortran MPI functions\n") | ||
file.write("\n") | ||
for (tp, nm, args) in functions_fortran: | ||
diff --git a/gen/gen_init.py b/gen/gen_init.py | ||
index 4939261..0e52822 100755 | ||
--- a/gen/gen_init.py | ||
+++ b/gen/gen_init.py | ||
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) | ||
|
||
from mpi_constants import constants | ||
from mpi_functions import functions | ||
-from mpi_constants_fortran import constants_fortran | ||
-from mpi_functions_fortran import functions_fortran | ||
+# from mpi_constants_fortran import constants_fortran | ||
+# from mpi_functions_fortran import functions_fortran | ||
|
||
support_profiling = True | ||
have_weak_symbols = False | ||
replace_sentinels = False | ||
|
||
-with open("src/mpi_init_constants_c.h", "w") as file: | ||
+with open(sys.argv[1], "w") as file: | ||
file.write("// Initialize C MPI constants") | ||
file.write("\n") | ||
for (tp, nm) in constants: | ||
@@ -25,7 +25,7 @@ with open("src/mpi_init_constants_c.h", "w") as file: | ||
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm)} | ||
file.write(Template("$mpi_nm = *($mpi_tp const *)get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) | ||
|
||
-with open("src/mpi_init_functions_c.h", "w") as file: | ||
+with open(sys.argv[2], "w") as file: | ||
file.write("// Initialize C MPI functions\n") | ||
file.write("\n") | ||
for (tp, nm, args, flags) in functions: | ||
@@ -39,7 +39,7 @@ with open("src/mpi_init_functions_c.h", "w") as file: | ||
subs['anm{0}'.format(i)] = anm | ||
file.write(Template("$abi_nm = get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) | ||
|
||
-with open("src/mpi_init_constants_fortran.h", "w") as file: | ||
+if False: | ||
file.write("// Initialize Fortran MPI constants\n") | ||
file.write("\n") | ||
for (tp, nm) in constants_fortran: | ||
@@ -47,7 +47,7 @@ with open("src/mpi_init_constants_fortran.h", "w") as file: | ||
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm).lower() + "_"} | ||
file.write(Template("$abi_nm = *($abi_tp const*)get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) | ||
|
||
-with open("src/mpi_init_functions_fortran.h", "w") as file: | ||
+if False: | ||
file.write("// Initialize Fortran MPI functions\n") | ||
file.write("\n") | ||
for (tp, nm, args) in functions_fortran: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Description: | ||
# A forwarding MPI implementation that can use any other MPI implementation via an MPI ABI | ||
|
||
load("@bazel_skylib//rules:expand_template.bzl", "expand_template") | ||
load("@xla//xla:strict.default.bzl", "py_strict_binary") | ||
|
||
package( | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
licenses(["notice"]) | ||
|
||
exports_files(["LICENSE.md"]) | ||
|
||
genrule( | ||
name = "mpi_version", | ||
srcs = [ | ||
"CMakeLists.txt", | ||
"include/mpi_version.h.in", | ||
], | ||
outs = ["include/mpi_version.h"], | ||
cmd = """ | ||
PROJECT_VERSION=`cat $(location CMakeLists.txt) \ | ||
| grep "MPItrampoline VERSION" | awk '{print $$NF}'` | ||
PROJECT_VERSION_MAJOR=`echo $$PROJECT_VERSION | cut -d. -f1` | ||
PROJECT_VERSION_MINOR=`echo $$PROJECT_VERSION | cut -d. -f2` | ||
PROJECT_VERSION_PATCH=`echo $$PROJECT_VERSION | cut -d. -f3` | ||
sed -e "s/@PROJECT_VERSION@/$${PROJECT_VERSION}/" \ | ||
-e "s/@PROJECT_VERSION_MAJOR@/$${PROJECT_VERSION_MAJOR}/" \ | ||
-e "s/@PROJECT_VERSION_MINOR@/$${PROJECT_VERSION_MINOR}/" \ | ||
-e "s/@PROJECT_VERSION_PATCH@/$${PROJECT_VERSION_PATCH}/" \ | ||
$(location include/mpi_version.h.in) > $(location include/mpi_version.h) | ||
""", | ||
) | ||
|
||
expand_template( | ||
name = "mpi_defaults", | ||
out = "src/mpi_defaults.h", | ||
substitutions = { | ||
"@MPITRAMPOLINE_DEFAULT_DELAY_INIT@": "", | ||
"@MPITRAMPOLINE_DEFAULT_DLOPEN_BINDING@": "", | ||
"@MPITRAMPOLINE_DEFAULT_DLOPEN_MODE@": "", | ||
"@MPITRAMPOLINE_DEFAULT_LIB@": "", | ||
"@MPITRAMPOLINE_DEFAULT_PRELOAD@": "", | ||
"@MPITRAMPOLINE_DEFAULT_VERBOSE@": "", | ||
}, | ||
template = "src/mpi_defaults.h.in", | ||
) | ||
|
||
py_strict_binary( | ||
name = "gen_decl", | ||
srcs = [ | ||
"gen/gen_decl.py", | ||
"mpiabi/mpi_constants.py", | ||
"mpiabi/mpi_functions.py", | ||
], | ||
) | ||
|
||
genrule( | ||
name = "decl", | ||
outs = [ | ||
"include/mpi_decl_constants_c.h", | ||
"include/mpi_decl_functions_c.h", | ||
], | ||
cmd = "$(location :gen_decl) $(location include/mpi_decl_constants_c.h) \ | ||
$(location include/mpi_decl_functions_c.h)", | ||
tools = [":gen_decl"], | ||
) | ||
|
||
py_strict_binary( | ||
name = "gen_defn", | ||
srcs = [ | ||
"gen/gen_defn.py", | ||
"mpiabi/mpi_constants.py", | ||
"mpiabi/mpi_functions.py", | ||
], | ||
) | ||
|
||
genrule( | ||
name = "defn", | ||
outs = [ | ||
"include/mpi_defn_constants_c.h", | ||
"include/mpi_defn_functions_c.h", | ||
], | ||
cmd = "$(location :gen_defn) $(location include/mpi_defn_constants_c.h) \ | ||
$(location include/mpi_defn_functions_c.h)", | ||
tools = [":gen_defn"], | ||
) | ||
|
||
py_strict_binary( | ||
name = "gen_init", | ||
srcs = [ | ||
"gen/gen_init.py", | ||
"mpiabi/mpi_constants.py", | ||
"mpiabi/mpi_functions.py", | ||
], | ||
) | ||
|
||
genrule( | ||
name = "init", | ||
outs = [ | ||
"include/mpi_init_constants_c.h", | ||
"include/mpi_init_functions_c.h", | ||
], | ||
cmd = "$(location :gen_init) $(location include/mpi_init_constants_c.h) \ | ||
$(location include/mpi_init_functions_c.h)", | ||
tools = [":gen_init"], | ||
) | ||
|
||
cc_library( | ||
name = "mpitrampoline", | ||
srcs = [ | ||
"src/mpi.c", | ||
], | ||
hdrs = [ | ||
"include/mpi.h", | ||
"include/mpi_decl_constants_c.h", | ||
"include/mpi_decl_functions_c.h", | ||
"include/mpi_defn_constants_c.h", | ||
"include/mpi_defn_functions_c.h", | ||
"include/mpi_init_constants_c.h", | ||
"include/mpi_init_functions_c.h", | ||
"include/mpi_version.h", | ||
"mpiabi/mpiabi.h", | ||
"src/mpi_defaults.h", | ||
], | ||
copts = [ | ||
"-fexceptions", | ||
], | ||
includes = [ | ||
"include", | ||
"mpiabi", | ||
"src", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Provides the repository macro to import mpitrampoline.""" | ||
|
||
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") | ||
|
||
def repo(): | ||
"""Imports mpitrampoline.""" | ||
|
||
MPITRAMPOLINE_COMMIT = "25efb0f7a4cd00ed82bafb8b1a6285fc50d297ed" | ||
MPITRAMPOLINE_SHA256 = "5a36656205c472bdb639bffebb0f014523b32dda0c2cbedd9ce7abfc9e879e84" | ||
|
||
tf_http_archive( | ||
name = "mpitrampoline", | ||
sha256 = MPITRAMPOLINE_SHA256, | ||
strip_prefix = "MPItrampoline-{commit}".format(commit = MPITRAMPOLINE_COMMIT), | ||
urls = tf_mirror_urls("https://github.com/eschnett/mpitrampoline/archive/{commit}.tar.gz".format(commit = MPITRAMPOLINE_COMMIT)), | ||
patch_file = ["//third_party/mpitrampoline:gen.patch"], | ||
build_file = "//third_party/mpitrampoline:mpitrampoline.BUILD", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.