From fbede659a8a45f5a097891a24f57fa42b6e3a194 Mon Sep 17 00:00:00 2001
From: jeffhataws <jthuynh@amazon.com>
Date: Mon, 7 Oct 2024 10:09:32 -0700
Subject: [PATCH] Part 1: Introduce multi-node SPMD support for Neuron (#8204)
 (#8224)

Co-authored-by: Rui <179625410+rpsilva-aws@users.noreply.github.com>
---
 torch_xla/runtime.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index e4560df6c70..1946ae05a52 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -253,6 +253,15 @@ def use_spmd(auto: Optional[bool] = False):
     torch_xla._XLAC._xla_set_auto_sharding()
     os.environ["XLA_AUTO_SPMD"] = "1"
 
+  if device_type() == 'NEURON':
+    # In case of Neuron, reset the initialization environment to accommodate SPMD.
+    try:
+      from torch_neuronx.initialization import initialize
+
+      initialize()
+    except ImportError:
+      pass
+
 
 def is_spmd():
   """Returns if SPMD is set for execution."""