Skip to content

Commit

Permalink
IFRT prototype (#5677)
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar authored and golechwierowicz committed Jan 12, 2024
1 parent bb4f5f6 commit 3571f2d
Show file tree
Hide file tree
Showing 10 changed files with 1,223 additions and 146 deletions.
129 changes: 100 additions & 29 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ cc_library(
":computation_client",
":env_vars",
":pjrt_computation_client",
":ifrt_computation_client",
"@tsl//tsl/platform:stacktrace",
],
)
Expand Down Expand Up @@ -70,6 +71,36 @@ cc_library(
],
)

cc_library(
name = "ifrt_computation_client",
srcs = [
"ifrt_computation_client.cc",
],
hdrs = [
"ifrt_computation_client.h",
],
deps = [
":computation_client",
":debug_macros",
":env_vars",
":initialize_pjrt",
":operation_manager",
":stablehlo_helper",
":tf_logging",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
"@xla//xla/pjrt/distributed",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/python/ifrt",
"@xla//xla/python/pjrt_ifrt",
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/platform/cloud:gcs_file_system",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

cc_library(
name = "pjrt_computation_client",
srcs = [
Expand All @@ -83,6 +114,7 @@ cc_library(
":debug_macros",
":env_hash",
":env_vars",
":initialize_pjrt",
":operation_manager",
":profiler",
":stablehlo_helper",
Expand All @@ -94,11 +126,7 @@ cc_library(
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
"@xla//xla/pjrt/distributed",
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
"@xla//xla/service:gpu_plugin",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt:tfrt_cpu_pjrt_client",
"@xla//xla/pjrt:pjrt_c_api_client",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/platform/cloud:gcs_file_system",
Expand Down Expand Up @@ -165,6 +193,25 @@ cc_test(
],
)

cc_library(
name = "initialize_pjrt",
srcs = ["initialize_pjrt.cc"],
hdrs = ["initialize_pjrt.h"],
deps = [
":debug_macros",
":env_hash",
":env_vars",
":profiler",
":sys_util",
":tf_logging",
":xla_coordinator",
"@xla//xla/service:gpu_plugin",
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
"@xla//xla/pjrt:tfrt_cpu_pjrt_client",
"@xla//xla/pjrt:pjrt_c_api_client",
],
)

cc_library(
name = "metrics_analysis",
srcs = ["metrics_analysis.cc"],
Expand Down Expand Up @@ -410,28 +457,52 @@ ptxla_cc_test(
],
)

# disable for now since it is flaky on the upstream test.
# ptxla_cc_test(
# name = "pjrt_computation_client_test",
# srcs = ["pjrt_computation_client_test.cc"],
# deps = [
# ":computation_client",
# ":pjrt_computation_client",
# ":tensor_source",
# "@xla//xla:literal",
# "@xla//xla:literal_util",
# "@xla//xla:shape_util",
# "@xla//xla:status",
# "@xla//xla:statusor",
# "@xla//xla/client:xla_builder",
# "@xla//xla/client:xla_computation",
# "@xla//xla/tests:literal_test_util",
# "@xla//xla/tools:hlo_module_loader",
# "@tsl//tsl/lib/core:status_test_util",
# "@tsl//tsl/platform:env",
# "@tsl//tsl/platform:errors",
# "@tsl//tsl/platform:logging",
# "@tsl//tsl/platform:test",
# "@tsl//tsl/platform:test_main",
# ],
# )
ptxla_cc_test(
name = "pjrt_computation_client_test",
srcs = ["pjrt_computation_client_test.cc"],
deps = [
":computation_client",
":pjrt_computation_client",
":tensor_source",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla:shape_util",
"@xla//xla:status",
"@xla//xla:statusor",
"@xla//xla/client:xla_builder",
"@xla//xla/client:xla_computation",
"@xla//xla/tests:literal_test_util",
"@xla//xla/tools:hlo_module_loader",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)

ptxla_cc_test(
name = "ifrt_computation_client_test",
srcs = ["ifrt_computation_client_test.cc"],
deps = [
":computation_client",
":ifrt_computation_client",
":tensor_source",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla:shape_util",
"@xla//xla:status",
"@xla//xla:statusor",
"@xla//xla/client:xla_builder",
"@xla//xla/client:xla_computation",
"@xla//xla/tests:literal_test_util",
"@xla//xla/tools:hlo_module_loader",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ class ComputationClient {
static int64_t GetDeviceOrdinal(const std::string& device);

protected:
static constexpr auto spmd_device_str = "SPMD:0";

// Metrics common to all client interfaces.
static metrics::Metric* TransferToServerMetric();
static metrics::Metric* TransferToServerTransformMetric();
Expand Down
Loading

0 comments on commit 3571f2d

Please sign in to comment.