Skip to content

Commit

Permalink
PR openxla#7849: [XLA:CPU] Add support for cross-process collectives …
Browse files Browse the repository at this point in the history
…using mpi.

Imported from GitHub PR openxla#7849

Mpi collectives as proposed in jax-ml/jax#11182.

I only implemented the inter-process communication and this does not yet support more than 1 threads per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved if one wanted to do it properly.

For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) mpi library at runtime. To wrap and load the desired mpi library one needs compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`.

@hawkinsp
Copybara import of the project:

--
b74bbb9 by Clemens Giuliani <[email protected]>:

add mpi collectives

--
23508eb by Clemens Giuliani <[email protected]>:

add explicit Init and Finalize methods and export them to python

--
bbe5840 by Clemens Giuliani <[email protected]>:

add comment

--
38d1562 by Clemens Giuliani <[email protected]>:

fix windows build

--
201f723 by Clemens Giuliani <[email protected]>:

fmt

--
2784869 by Clemens Giuliani <[email protected]>:

bump xla_extension_version

Merging this change closes openxla#7849

COPYBARA_INTEGRATE_REVIEW=openxla#7849 from inailuig:mpi_collectives 2784869
PiperOrigin-RevId: 620302264
  • Loading branch information
inailuig authored and steeve committed Jul 19, 2024
1 parent 529c00f commit 6d4b0d7
Show file tree
Hide file tree
Showing 11 changed files with 751 additions and 1 deletion.
1 change: 1 addition & 0 deletions third_party/mpitrampoline/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
149 changes: 149 additions & 0 deletions third_party/mpitrampoline/gen.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
diff --git a/gen/gen_decl.py b/gen/gen_decl.py
index 1005b95..696b4e0 100755
--- a/gen/gen_decl.py
+++ b/gen/gen_decl.py
@@ -9,8 +9,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))

from mpi_constants import constants
from mpi_functions import functions
-from mpi_constants_fortran import constants_fortran
-from mpi_functions_fortran import functions_fortran
+# from mpi_constants_fortran import constants_fortran
+# from mpi_functions_fortran import functions_fortran

support_profiling = True
have_weak_symbols = False
@@ -24,7 +24,7 @@ def wrap(line):
lines.append(line)
return "\n".join(lines)

-with open("include/mpi_decl_constants_c.h", "w") as file:
+with open(sys.argv[1], "w") as file:
file.write("// Declare C MPI constants\n")
file.write("\n")
for (tp, nm) in constants:
@@ -32,7 +32,7 @@ with open("include/mpi_decl_constants_c.h", "w") as file:
'mpi_nm': nm}
file.write(Template("extern $mpi_tp MPITRAMPOLINE_CONST $mpi_nm;\n").substitute(subs))

-with open("include/mpi_decl_functions_c.h", "w") as file:
+with open(sys.argv[2], "w") as file:
file.write("// Declare C MPI functions\n")
file.write("\n")
for (tp, nm, args, flags) in functions:
@@ -90,7 +90,7 @@ with open("include/mpi_decl_functions_c.h", "w") as file:
file.write(Template("\n".join(tmpl)).substitute(subs))
file.write("\n")

-with open("include/mpi_decl_constants_fortran.h", "w") as file:
+if False:
file.write("! Declare Fortran MPI constants\n")
file.write("\n")
for (tp, nm) in constants_fortran:
@@ -104,7 +104,7 @@ with open("include/mpi_decl_constants_fortran.h", "w") as file:
file.write("\n".join(map(lambda line: wrap(Template(line).substitute(subs)), tmpl)))
file.write("\n")

-with open("include/mpi_decl_functions_fortran.h", "w") as file:
+if False:
file.write("! Declare Fortran MPI functions\n")
file.write("\n")
for (tp, nm, args) in functions_fortran:
diff --git a/gen/gen_defn.py b/gen/gen_defn.py
index bf31f35..318222e 100755
--- a/gen/gen_defn.py
+++ b/gen/gen_defn.py
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))

from mpi_constants import constants
from mpi_functions import functions
-from mpi_constants_fortran import constants_fortran
-from mpi_functions_fortran import functions_fortran
+# from mpi_constants_fortran import constants_fortran
+# from mpi_functions_fortran import functions_fortran

support_profiling = True
have_weak_symbols = False
replace_sentinels = False

-with open("src/mpi_defn_constants_c.h", "w") as file:
+with open(sys.argv[1], "w") as file:
file.write("// Define C MPI constants")
file.write("\n")
for (tp, nm) in constants:
@@ -24,7 +24,7 @@ with open("src/mpi_defn_constants_c.h", "w") as file:
'mpi_nm': nm}
file.write(Template("$mpi_tp $mpi_nm = ($mpi_tp)0xdeadbeef;\n").substitute(subs))

-with open("src/mpi_defn_functions_c.h", "w") as file:
+with open(sys.argv[2], "w") as file:
file.write("// Define C MPI functions\n")
file.write("\n")
for (tp, nm, args, flags) in functions:
@@ -89,7 +89,7 @@ with open("src/mpi_defn_functions_c.h", "w") as file:
file.write(Template("\n".join(tmpl)).substitute(subs))
file.write("\n")

-with open("src/mpi_defn_constants_fortran.h", "w") as file:
+if False:
file.write("// Define Fortran MPI constants\n")
file.write("\n")
for (tp, nm) in constants_fortran:
@@ -98,7 +98,7 @@ with open("src/mpi_defn_constants_fortran.h", "w") as file:
# Fortran common blocks with `-march=skylake-avx512` are aligned to 64 bytes
file.write(Template("$mpi_tp $abi_nm __attribute__((__aligned__(64))) = (int)0xdeadbeef;\n").substitute(subs))

-with open("src/mpi_defn_functions_fortran.h", "w") as file:
+if False:
file.write("// Define Fortran MPI functions\n")
file.write("\n")
for (tp, nm, args) in functions_fortran:
diff --git a/gen/gen_init.py b/gen/gen_init.py
index 4939261..0e52822 100755
--- a/gen/gen_init.py
+++ b/gen/gen_init.py
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))

from mpi_constants import constants
from mpi_functions import functions
-from mpi_constants_fortran import constants_fortran
-from mpi_functions_fortran import functions_fortran
+# from mpi_constants_fortran import constants_fortran
+# from mpi_functions_fortran import functions_fortran

support_profiling = True
have_weak_symbols = False
replace_sentinels = False

-with open("src/mpi_init_constants_c.h", "w") as file:
+with open(sys.argv[1], "w") as file:
file.write("// Initialize C MPI constants")
file.write("\n")
for (tp, nm) in constants:
@@ -25,7 +25,7 @@ with open("src/mpi_init_constants_c.h", "w") as file:
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm)}
file.write(Template("$mpi_nm = *($mpi_tp const *)get_symbol(handle, \"$abi_nm\");\n").substitute(subs))

-with open("src/mpi_init_functions_c.h", "w") as file:
+with open(sys.argv[2], "w") as file:
file.write("// Initialize C MPI functions\n")
file.write("\n")
for (tp, nm, args, flags) in functions:
@@ -39,7 +39,7 @@ with open("src/mpi_init_functions_c.h", "w") as file:
subs['anm{0}'.format(i)] = anm
file.write(Template("$abi_nm = get_symbol(handle, \"$abi_nm\");\n").substitute(subs))

-with open("src/mpi_init_constants_fortran.h", "w") as file:
+if False:
file.write("// Initialize Fortran MPI constants\n")
file.write("\n")
for (tp, nm) in constants_fortran:
@@ -47,7 +47,7 @@ with open("src/mpi_init_constants_fortran.h", "w") as file:
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm).lower() + "_"}
file.write(Template("$abi_nm = *($abi_tp const*)get_symbol(handle, \"$abi_nm\");\n").substitute(subs))

-with open("src/mpi_init_functions_fortran.h", "w") as file:
+if False:
file.write("// Initialize Fortran MPI functions\n")
file.write("\n")
for (tp, nm, args) in functions_fortran:
135 changes: 135 additions & 0 deletions third_party/mpitrampoline/mpitrampoline.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Description:
# A forwarding MPI implementation that can use any other MPI implementation via an MPI ABI

load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load("@xla//xla:strict.default.bzl", "py_strict_binary")

package(
default_visibility = ["//visibility:public"],
)

licenses(["notice"])

exports_files(["LICENSE.md"])

genrule(
name = "mpi_version",
srcs = [
"CMakeLists.txt",
"include/mpi_version.h.in",
],
outs = ["include/mpi_version.h"],
cmd = """
PROJECT_VERSION=`cat $(location CMakeLists.txt) \
| grep "MPItrampoline VERSION" | awk '{print $$NF}'`
PROJECT_VERSION_MAJOR=`echo $$PROJECT_VERSION | cut -d. -f1`
PROJECT_VERSION_MINOR=`echo $$PROJECT_VERSION | cut -d. -f2`
PROJECT_VERSION_PATCH=`echo $$PROJECT_VERSION | cut -d. -f3`
sed -e "s/@PROJECT_VERSION@/$${PROJECT_VERSION}/" \
-e "s/@PROJECT_VERSION_MAJOR@/$${PROJECT_VERSION_MAJOR}/" \
-e "s/@PROJECT_VERSION_MINOR@/$${PROJECT_VERSION_MINOR}/" \
-e "s/@PROJECT_VERSION_PATCH@/$${PROJECT_VERSION_PATCH}/" \
$(location include/mpi_version.h.in) > $(location include/mpi_version.h)
""",
)

expand_template(
name = "mpi_defaults",
out = "src/mpi_defaults.h",
substitutions = {
"@MPITRAMPOLINE_DEFAULT_DELAY_INIT@": "",
"@MPITRAMPOLINE_DEFAULT_DLOPEN_BINDING@": "",
"@MPITRAMPOLINE_DEFAULT_DLOPEN_MODE@": "",
"@MPITRAMPOLINE_DEFAULT_LIB@": "",
"@MPITRAMPOLINE_DEFAULT_PRELOAD@": "",
"@MPITRAMPOLINE_DEFAULT_VERBOSE@": "",
},
template = "src/mpi_defaults.h.in",
)

py_strict_binary(
name = "gen_decl",
srcs = [
"gen/gen_decl.py",
"mpiabi/mpi_constants.py",
"mpiabi/mpi_functions.py",
],
)

genrule(
name = "decl",
outs = [
"include/mpi_decl_constants_c.h",
"include/mpi_decl_functions_c.h",
],
cmd = "$(location :gen_decl) $(location include/mpi_decl_constants_c.h) \
$(location include/mpi_decl_functions_c.h)",
tools = [":gen_decl"],
)

py_strict_binary(
name = "gen_defn",
srcs = [
"gen/gen_defn.py",
"mpiabi/mpi_constants.py",
"mpiabi/mpi_functions.py",
],
)

genrule(
name = "defn",
outs = [
"include/mpi_defn_constants_c.h",
"include/mpi_defn_functions_c.h",
],
cmd = "$(location :gen_defn) $(location include/mpi_defn_constants_c.h) \
$(location include/mpi_defn_functions_c.h)",
tools = [":gen_defn"],
)

py_strict_binary(
name = "gen_init",
srcs = [
"gen/gen_init.py",
"mpiabi/mpi_constants.py",
"mpiabi/mpi_functions.py",
],
)

genrule(
name = "init",
outs = [
"include/mpi_init_constants_c.h",
"include/mpi_init_functions_c.h",
],
cmd = "$(location :gen_init) $(location include/mpi_init_constants_c.h) \
$(location include/mpi_init_functions_c.h)",
tools = [":gen_init"],
)

cc_library(
name = "mpitrampoline",
srcs = [
"src/mpi.c",
],
hdrs = [
"include/mpi.h",
"include/mpi_decl_constants_c.h",
"include/mpi_decl_functions_c.h",
"include/mpi_defn_constants_c.h",
"include/mpi_defn_functions_c.h",
"include/mpi_init_constants_c.h",
"include/mpi_init_functions_c.h",
"include/mpi_version.h",
"mpiabi/mpiabi.h",
"src/mpi_defaults.h",
],
copts = [
"-fexceptions",
],
includes = [
"include",
"mpiabi",
"src",
],
)
18 changes: 18 additions & 0 deletions third_party/mpitrampoline/workspace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Provides the repository macro to import mpitrampoline."""

load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
"""Imports mpitrampoline."""

MPITRAMPOLINE_COMMIT = "25efb0f7a4cd00ed82bafb8b1a6285fc50d297ed"
MPITRAMPOLINE_SHA256 = "5a36656205c472bdb639bffebb0f014523b32dda0c2cbedd9ce7abfc9e879e84"

tf_http_archive(
name = "mpitrampoline",
sha256 = MPITRAMPOLINE_SHA256,
strip_prefix = "MPItrampoline-{commit}".format(commit = MPITRAMPOLINE_COMMIT),
urls = tf_mirror_urls("https://github.com/eschnett/mpitrampoline/archive/{commit}.tar.gz".format(commit = MPITRAMPOLINE_COMMIT)),
patch_file = ["//third_party/mpitrampoline:gen.patch"],
build_file = "//third_party/mpitrampoline:mpitrampoline.BUILD",
)
2 changes: 2 additions & 0 deletions workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# Import third party repository rules. See go/tfbr-thirdparty.
load("//third_party/dlpack:workspace.bzl", dlpack = "repo")
load("//third_party/gloo:workspace.bzl", gloo = "repo")
load("//third_party/mpitrampoline:workspace.bzl", mpitrampoline = "repo")
load("//third_party/nanobind:workspace.bzl", nanobind = "repo")
load("//third_party/robin_map:workspace.bzl", robin_map = "repo")
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
Expand All @@ -19,6 +20,7 @@ def _initialize_third_party():
""" Load third party repositories. See above load() statements. """
dlpack()
gloo()
mpitrampoline()
nanobind()
robin_map()
stablehlo()
Expand Down
32 changes: 32 additions & 0 deletions xla/pjrt/cpu/BUILD
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
load("@tsl//tsl:tsl.bzl", "if_oss")
load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library")
load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
load("//xla:xla.bzl", "xla_cc_test")
Expand Down Expand Up @@ -286,3 +287,34 @@ cc_library(
"@tsl//tsl/platform:logging",
],
)

cc_library(
name = "mpi_collectives",
srcs = if_oss(["mpi_collectives.cc"]),
hdrs = if_oss(["mpi_collectives.h"]),
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = if_oss([
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"//xla:shape_util",
"//xla:status_macros",
"//xla:types",
"//xla:xla_data_proto_cc",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@mpitrampoline",
]),
)
Loading

0 comments on commit 6d4b0d7

Please sign in to comment.