diff --git a/torch_xla/csrc/runtime/env_vars.cc b/torch_xla/csrc/runtime/env_vars.cc index 42040a9cca5..a7b0fdd74d6 100644 --- a/torch_xla/csrc/runtime/env_vars.cc +++ b/torch_xla/csrc/runtime/env_vars.cc @@ -19,6 +19,9 @@ const char* const kEnvNeuronLibraryPath = "NEURON_LIBRARY_PATH"; const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR"; const char* const kEnvPjRtLocalProcessCount = "PJRT_LOCAL_PROCESS_COUNT"; const char* const kEnvPjRtLocalRank = "PJRT_LOCAL_PROCESS_RANK"; +const char* const kEnvPjrtAllocatorCudaAsync = "PJRT_ALLOCATOR_CUDA_ASYNC"; +const char* const kEnvPjrtAllocatorPreallocate = "PJRT_ALLOCATOR_PREALLOCATE"; +const char* const kEnvPjrtAllocatorFraction = "PJRT_ALLOCATOR_FRACTION"; } // namespace env } // namespace runtime diff --git a/torch_xla/csrc/runtime/env_vars.h b/torch_xla/csrc/runtime/env_vars.h index e54ba8f72cd..bc8a6fbc667 100644 --- a/torch_xla/csrc/runtime/env_vars.h +++ b/torch_xla/csrc/runtime/env_vars.h @@ -29,6 +29,9 @@ extern const char* const kEnvNeuronLibraryPath; extern const char* const kEnvPjrtDistServiceAddr; extern const char* const kEnvPjRtLocalProcessCount; extern const char* const kEnvPjRtLocalRank; +extern const char* const kEnvPjrtAllocatorCudaAsync; +extern const char* const kEnvPjrtAllocatorPreallocate; +extern const char* const kEnvPjrtAllocatorFraction; } // namespace env } // namespace runtime diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index c003f4f9706..94d504e1714 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -62,6 +62,23 @@ xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { return xla::ShapeUtil::DeviceShapeToHostShape(shape); } +xla::GpuAllocatorConfig GetGpuAllocatorConfig() { + auto allocator_config = xla::GpuAllocatorConfig{}; + if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() && + sys_util::GetEnvString(env::kEnvPjrtAllocatorPreallocate, "").empty() && + sys_util::GetEnvString(env::kEnvPjrtAllocatorFraction, "").empty()) { + return allocator_config; + } + if (sys_util::GetEnvBool(env::kEnvPjrtAllocatorCudaAsync, false)) { + allocator_config.kind = xla::GpuAllocatorConfig::Kind::kCudaAsync; + } + allocator_config.preallocate = + sys_util::GetEnvBool(env::kEnvPjrtAllocatorPreallocate, true); + allocator_config.memory_fraction = + sys_util::GetEnvDouble(env::kEnvPjrtAllocatorFraction, 0.75); + return allocator_config; +} + } // namespace std::string PjRtComputationClient::PjRtDeviceToString( @@ -141,7 +158,7 @@ PjRtComputationClient::PjRtComputationClient() { << global_process_rank << ", num_nodes=" << global_world_size; client_ = std::move(xla::GetStreamExecutorGpuClient( /*asynchronous=*/async, - /*allocator_config=*/xla::GpuAllocatorConfig{}, + /*allocator_config=*/GetGpuAllocatorConfig(), /*node_id=*/global_process_rank, /*num_nodes=*/global_world_size, /*allowed_devices=*/allowed_devices,