From eef7bb4e9db7d58d78137018636f7df3b0a214cb Mon Sep 17 00:00:00 2001 From: Zhenghui Jin <69359374+barry-jin@users.noreply.github.com> Date: Thu, 1 Aug 2024 10:26:42 -0700 Subject: [PATCH] Using dummy store when user skips barrier to save open file descriptors (#6834) --- configuration.yaml | 8 ++++++++ torch_xla/_internal/rendezvous.py | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/configuration.yaml b/configuration.yaml index bf7b786fb1a..c2f8c8290bd 100644 --- a/configuration.yaml +++ b/configuration.yaml @@ -216,6 +216,14 @@ variables: - Compiler cache size for the op by op executor. type: int default_value: 2048 + XLA_USE_DUMMY_STORE: + description: + - If set to true, and user skips store based barrier by + setting TORCH_DIST_INIT_BARRIER=0, the `pjrt_rendezvous_handler` + will create a DummyStore to replace TCPStore to save open file + descriptors. + type: bool + default_value: false device_variables: TPU_NUM_DEVICES: description: diff --git a/torch_xla/_internal/rendezvous.py b/torch_xla/_internal/rendezvous.py index a4e8dc20fb2..28828dfbc9d 100644 --- a/torch_xla/_internal/rendezvous.py +++ b/torch_xla/_internal/rendezvous.py @@ -13,6 +13,12 @@ _store_lock = threading.Lock() +class DummyStore(dist.Store): + + def __init__(self, *args, **kwargs): + super().__init__() + + def pjrt_rendezvous_handler(url: str, timeout: datetime.timedelta = ..., **kwargs): @@ -34,7 +40,14 @@ def pjrt_rendezvous_handler(url: str, with _store_lock: global _store if not _store: - if xu.getenv_as('TORCHELASTIC_USE_AGENT_STORE', str) == 'True': + # Create DummyStore when user skips store based barrier by setting TORCH_DIST_INIT_BARRIER=0 + # and enables XLA_USE_DUMMY_STORE=1. It's safe to do so because store created by _pjrt_rendezvous_handler + # is only used as a barrier in process groups. If store is needed, user can set XLA_USE_DUMMY_STORE=0 to + # use TCPStore. + if xu.getenv_as('TORCH_DIST_INIT_BARRIER', int, 1) == 0 and xu.getenv_as( + 'XLA_USE_DUMMY_STORE', int, 0) == 1: + _store = DummyStore() + elif xu.getenv_as('TORCHELASTIC_USE_AGENT_STORE', str) == 'True': attempt = xu.getenv_as('TORCHELASTIC_RESTART_COUNT', int, defval=0) tcp_store = dist.TCPStore( master_ip, master_port, xr.process_count(), is_master=False)