Skip to content

Commit

Permalink
Don't explcitly include virtual device in hash (#6148)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 authored Dec 14, 2023
1 parent f9dc824 commit d2e0676
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
26 changes: 26 additions & 0 deletions test/test_persistent_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ def _spmd_replicated_test(tmpdir, metrics):
_assert_correctness_and_metrics(t, xt, metrics)


def _spmd_explicitly_replicated_test(tmpdir, metrics):
xr.initialize_cache(tmpdir)
xr.use_spmd()
t = torch.randn(16)
xt = t.to(xm.xla_device())

n_dev = xr.global_runtime_device_count()
mesh = xs.Mesh(range(n_dev), (n_dev,))
xs.mark_sharding(xt, mesh, (None,))
_assert_correctness_and_metrics(t, xt, metrics)


def _spmd_sharded_test(tmpdir, metrics):
xr.initialize_cache(tmpdir)
xr.use_spmd()
Expand Down Expand Up @@ -124,6 +136,20 @@ def test_persistent_cache_mp(self):
def test_persistent_cache(self, test_fn):
self._run_test(_test_spawn, test_fn)

@absltest.skipUnless(xr.device_type() == 'TPU', 'TPU required for SPMD')
@run_with_tmpdir
def test_replicated_spmd_hash(self, tmpdir):
# The hash should differ between replicated SPMD and the single-device test.
_test_spawn(_spmd_explicitly_replicated_test, (tmpdir, {
'PersistentCacheMiss': 1,
'PersistentCacheHit': None
}))

_test_spawn(_single_device_test, (tmpdir, {
'PersistentCacheMiss': 1,
'PersistentCacheHit': None
}))


if __name__ == '__main__':
test = absltest.main()
Expand Down
3 changes: 1 addition & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ torch::lazy::hash_t hash_comp_env(
std::shared_ptr<xla::PjRtClient> client,
std::vector<xla::PjRtDevice*>& ordered_devices) {
torch::lazy::hash_t hash = hash::HashXlaEnvVars();
// Whether or not SPMD mode is active should influence the hash.
hash = torch::lazy::HashCombine(hash, UseVirtualDevice());
auto topology_desc = client->GetTopologyDescription();
if (topology_desc.ok()) {
// Some backends support a topology description which provides a better
Expand Down Expand Up @@ -244,6 +242,7 @@ PjRtComputationClient::PjRtComputationClient() {
std::string device_str = PjRtDeviceToString(device);
string_to_device_.emplace(device_str, device);
}
comp_env_hash_ = hash_comp_env(client_, ordered_devices);

auto tracked_devices = GetLocalDevices();
tracked_devices.emplace_back(spmd_device_str);
Expand Down

0 comments on commit d2e0676

Please sign in to comment.