Skip to content

Commit

Permalink
add env to control MegascalePjrtClient
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore committed Oct 29, 2024
1 parent e731688 commit 4ceaf18
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,12 @@ def trace_pallas(kernel: Callable,
return trace_pallas_arg_to_payload[hash_key], tensor_args

# Here we ignore the kwargs for execution as most of the time, the kwargs is only used in traced code.
os.environ['USE_SINGLE_SLICE'] = 'true'
ir = jax.jit(
kernel, static_argnums=static_argnums,
static_argnames=static_argnames).lower(*jax_args, **kwargs).compiler_ir()
payload = _extract_backend_config(ir)
os.environ.pop('USE_SINGLE_SLICE', None)

if use_cache:
# if we reach here it means we have a cache miss.
Expand Down

1 comment on commit 4ceaf18

@miladm
Copy link
Collaborator

@miladm miladm commented on 4ceaf18 Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zpcore can we have a PR for this change and explain the problem + how this change is a temp patch to unblock the problem.

Please sign in to comment.