Commit
remove using_pjrt in xla_graph_executor (pytorch#6768)
JackCaoG authored Mar 27, 2024
1 parent f083e10 commit d6fb539
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -1285,8 +1285,6 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
       runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true);
   static const size_t parameter_wrapping_threadshold =
       runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
-  static const bool using_pjrt =
-      runtime::sys_util::GetEnvString("PJRT_DEVICE", "").size() > 0;
   static const bool use_autosharding = ShardingUtil::GetAutoSharding();
   LoweringContext lowering_ctx("SyncTensorsGraph", coll.device,
                                po_data->post_order,
@@ -1346,7 +1344,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
   // TODO(yeounoh) enable wrapping with auto-sharding.
   bool should_wrap_parameter =
       (program_shape.parameters_size() >= parameter_wrapping_threadshold) &&
-      using_pjrt && !use_autosharding;
+      !use_autosharding;
   if (should_wrap_parameter) {
     TF_VLOG(3) << "Wrapping graph with " << program_shape.parameters_size()
                << " parameters. Threadshold = "
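Note (not part of the commit): the removed using_pjrt flag checked whether PJRT_DEVICE was set; dropping it presumably reflects that PJRT is the only runtime this code path now targets, so parameter wrapping depends only on the parameter count threshold and autosharding. The following standalone C++ sketch illustrates the resulting decision; GetEnvInt here is a stand-in for runtime::sys_util::GetEnvInt, and the num_parameters / use_autosharding values are illustrative placeholders for program_shape.parameters_size() and ShardingUtil::GetAutoSharding().

// Standalone sketch of the post-commit wrapping decision (assumptions noted above).
#include <cstdlib>
#include <cstring>
#include <iostream>

// Stand-in for runtime::sys_util::GetEnvInt: read an integer env var with a default.
static long GetEnvInt(const char* name, long defval) {
  const char* val = std::getenv(name);
  return (val != nullptr && std::strlen(val) > 0) ? std::atol(val) : defval;
}

int main() {
  static const long parameter_wrapping_threadshold =
      GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);

  long num_parameters = 4000;     // placeholder for program_shape.parameters_size()
  bool use_autosharding = false;  // placeholder for ShardingUtil::GetAutoSharding()

  // The PJRT_DEVICE check removed by this commit no longer participates here.
  bool should_wrap_parameter =
      (num_parameters >= parameter_wrapping_threadshold) && !use_autosharding;

  std::cout << "should_wrap_parameter = " << should_wrap_parameter << std::endl;
  return 0;
}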
