diff --git a/test/test_persistent_cache.py b/test/test_persistent_cache.py index 469ceca75dd..1683d15ddbb 100644 --- a/test/test_persistent_cache.py +++ b/test/test_persistent_cache.py @@ -69,6 +69,18 @@ def _spmd_replicated_test(tmpdir, metrics): _assert_correctness_and_metrics(t, xt, metrics) +def _spmd_explicitly_replicated_test(tmpdir, metrics): + xr.initialize_cache(tmpdir) + xr.use_spmd() + t = torch.randn(16) + xt = t.to(xm.xla_device()) + + n_dev = xr.global_runtime_device_count() + mesh = xs.Mesh(range(n_dev), (n_dev,)) + xs.mark_sharding(xt, mesh, (None,)) + _assert_correctness_and_metrics(t, xt, metrics) + + def _spmd_sharded_test(tmpdir, metrics): xr.initialize_cache(tmpdir) xr.use_spmd() @@ -124,6 +136,20 @@ def test_persistent_cache_mp(self): def test_persistent_cache(self, test_fn): self._run_test(_test_spawn, test_fn) + @absltest.skipUnless(xr.device_type() == 'TPU', 'TPU required for SPMD') + @run_with_tmpdir + def test_replicated_spmd_hash(self, tmpdir): + # The hash should differ between replicated SPMD and the single-device test. + _test_spawn(_spmd_explicitly_replicated_test, (tmpdir, { + 'PersistentCacheMiss': 1, + 'PersistentCacheHit': None + })) + + _test_spawn(_single_device_test, (tmpdir, { + 'PersistentCacheMiss': 1, + 'PersistentCacheHit': None + })) + if __name__ == '__main__': test = absltest.main() diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 871760d4802..adece361e82 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -89,8 +89,6 @@ torch::lazy::hash_t hash_comp_env( std::shared_ptr client, std::vector& ordered_devices) { torch::lazy::hash_t hash = hash::HashXlaEnvVars(); - // Whether or not SPMD mode is active should influence the hash. - hash = torch::lazy::HashCombine(hash, UseVirtualDevice()); auto topology_desc = client->GetTopologyDescription(); if (topology_desc.ok()) { // Some backends support a topology description which provides a better @@ -244,6 +242,7 @@ PjRtComputationClient::PjRtComputationClient() { std::string device_str = PjRtDeviceToString(device); string_to_device_.emplace(device_str, device); } + comp_env_hash_ = hash_comp_env(client_, ordered_devices); auto tracked_devices = GetLocalDevices(); tracked_devices.emplace_back(spmd_device_str);