diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 6d0f12aa4ae..0eddefc39f3 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1285,8 +1285,6 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true); static const size_t parameter_wrapping_threadshold = runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200); - static const bool using_pjrt = - runtime::sys_util::GetEnvString("PJRT_DEVICE", "").size() > 0; static const bool use_autosharding = ShardingUtil::GetAutoSharding(); LoweringContext lowering_ctx("SyncTensorsGraph", coll.device, po_data->post_order, @@ -1346,7 +1344,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( // TODO(yeounoh) enable wrapping with auto-sharding. bool should_wrap_parameter = (program_shape.parameters_size() >= parameter_wrapping_threadshold) && - using_pjrt && !use_autosharding; + !use_autosharding; if (should_wrap_parameter) { TF_VLOG(3) << "Wrapping graph with " << program_shape.parameters_size() << " parameters. Threadshold = "