Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IFRT prototype #5677

Merged
merged 33 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
fdd13af
Start IFRT prototype
will-cromar Oct 5, 2023
c7e48b4
remove comment
will-cromar Oct 5, 2023
1312465
formatting
will-cromar Oct 5, 2023
a7ae062
basic sharding
will-cromar Oct 11, 2023
835487f
Add `xla::OpSharding` back as source of truth
will-cromar Oct 12, 2023
a821a4d
wrapping and unwrapping sharded data
will-cromar Oct 12, 2023
2e8842b
ExecuteReplicated
will-cromar Oct 12, 2023
d75c040
[revert later] try resharding
will-cromar Oct 18, 2023
5544fc0
Revert "[revert later] try resharding"
will-cromar Oct 19, 2023
03481e0
reassemble sharded outputs
will-cromar Oct 19, 2023
f2f8334
fix output devices
will-cromar Oct 19, 2023
c16ea62
cleanup
will-cromar Oct 19, 2023
4d8418a
fix rebase issues
will-cromar Nov 28, 2023
e5a13a2
fix const plumbing
will-cromar Nov 28, 2023
155a811
shared_ptr to unique_ptr
will-cromar Nov 28, 2023
b1238da
parallelize input/output handling
will-cromar Nov 29, 2023
472b9d4
fix concurrency issues
will-cromar Nov 29, 2023
a6544b7
remove some commented out code
will-cromar Nov 29, 2023
5f6b4af
formatting
will-cromar Dec 1, 2023
c0e1ce8
for coordinator init
will-cromar Dec 1, 2023
1812afa
remove extra `std::move`s
will-cromar Dec 1, 2023
135e68d
unit test
will-cromar Dec 1, 2023
8fce005
tune parallelfors
will-cromar Dec 1, 2023
86883ad
pjrt -> ifrt
will-cromar Dec 13, 2023
07a9efb
fix rebasing issues
will-cromar Dec 13, 2023
3bfaa3a
formatting
will-cromar Dec 13, 2023
a71a9d5
fix timer and comp env hash
will-cromar Dec 13, 2023
0d2c842
formatting
will-cromar Dec 13, 2023
5b26aa4
remove dead code
will-cromar Dec 13, 2023
0df165a
Move SPMD string constant
will-cromar Dec 13, 2023
723191f
fix compile error
will-cromar Dec 13, 2023
9ebc616
format
will-cromar Dec 13, 2023
cc4c93d
remove SPMD from hash
will-cromar Dec 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 100 additions & 29 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ cc_library(
":computation_client",
":env_vars",
":pjrt_computation_client",
":ifrt_computation_client",
"@tsl//tsl/platform:stacktrace",
],
)
Expand Down Expand Up @@ -70,6 +71,36 @@ cc_library(
],
)

cc_library(
name = "ifrt_computation_client",
srcs = [
"ifrt_computation_client.cc",
],
hdrs = [
"ifrt_computation_client.h",
],
deps = [
":computation_client",
":debug_macros",
":env_vars",
":initialize_pjrt",
":operation_manager",
":stablehlo_helper",
":tf_logging",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
"@xla//xla/pjrt/distributed",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/python/ifrt",
"@xla//xla/python/pjrt_ifrt",
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/platform/cloud:gcs_file_system",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

cc_library(
name = "pjrt_computation_client",
srcs = [
Expand All @@ -83,6 +114,7 @@ cc_library(
":debug_macros",
":env_hash",
":env_vars",
":initialize_pjrt",
":operation_manager",
":profiler",
":stablehlo_helper",
Expand All @@ -94,11 +126,7 @@ cc_library(
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
"@xla//xla/pjrt/distributed",
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
"@xla//xla/service:gpu_plugin",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt:tfrt_cpu_pjrt_client",
"@xla//xla/pjrt:pjrt_c_api_client",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/platform/cloud:gcs_file_system",
Expand Down Expand Up @@ -165,6 +193,25 @@ cc_test(
],
)

cc_library(
name = "initialize_pjrt",
srcs = ["initialize_pjrt.cc"],
hdrs = ["initialize_pjrt.h"],
deps = [
":debug_macros",
":env_hash",
":env_vars",
":profiler",
":sys_util",
":tf_logging",
":xla_coordinator",
"@xla//xla/service:gpu_plugin",
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
"@xla//xla/pjrt:tfrt_cpu_pjrt_client",
"@xla//xla/pjrt:pjrt_c_api_client",
],
)

cc_library(
name = "metrics_analysis",
srcs = ["metrics_analysis.cc"],
Expand Down Expand Up @@ -410,28 +457,52 @@ ptxla_cc_test(
],
)

# disable for now since it is flaky on the upstream test.
# ptxla_cc_test(
# name = "pjrt_computation_client_test",
# srcs = ["pjrt_computation_client_test.cc"],
# deps = [
# ":computation_client",
# ":pjrt_computation_client",
# ":tensor_source",
# "@xla//xla:literal",
# "@xla//xla:literal_util",
# "@xla//xla:shape_util",
# "@xla//xla:status",
# "@xla//xla:statusor",
# "@xla//xla/client:xla_builder",
# "@xla//xla/client:xla_computation",
# "@xla//xla/tests:literal_test_util",
# "@xla//xla/tools:hlo_module_loader",
# "@tsl//tsl/lib/core:status_test_util",
# "@tsl//tsl/platform:env",
# "@tsl//tsl/platform:errors",
# "@tsl//tsl/platform:logging",
# "@tsl//tsl/platform:test",
# "@tsl//tsl/platform:test_main",
# ],
# )
ptxla_cc_test(
name = "pjrt_computation_client_test",
srcs = ["pjrt_computation_client_test.cc"],
deps = [
":computation_client",
":pjrt_computation_client",
":tensor_source",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla:shape_util",
"@xla//xla:status",
"@xla//xla:statusor",
"@xla//xla/client:xla_builder",
"@xla//xla/client:xla_computation",
"@xla//xla/tests:literal_test_util",
"@xla//xla/tools:hlo_module_loader",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)

ptxla_cc_test(
name = "ifrt_computation_client_test",
srcs = ["ifrt_computation_client_test.cc"],
deps = [
":computation_client",
":ifrt_computation_client",
":tensor_source",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla:shape_util",
"@xla//xla:status",
"@xla//xla:statusor",
"@xla//xla/client:xla_builder",
"@xla//xla/client:xla_computation",
"@xla//xla/tests:literal_test_util",
"@xla//xla/tools:hlo_module_loader",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ class ComputationClient {
static int64_t GetDeviceOrdinal(const std::string& device);

protected:
static constexpr auto spmd_device_str = "SPMD:0";

// Metrics common to all client interfaces.
static metrics::Metric* TransferToServerMetric();
static metrics::Metric* TransferToServerTransformMetric();
Expand Down
Loading