Skip to content

Commit

Permalink
Merge pull request #47 from JamesPerlman/patch-1
Browse files Browse the repository at this point in the history
Save result from jax.local_device_count
  • Loading branch information
keunhong authored Feb 15, 2022
2 parents ec5b476 + 0097ea3 commit 1a38512
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion nerfies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def general_loss_with_squared_residual(squared_x, alpha, scale):
def shard(xs, device_count=None):
"""Split data into shards for multiple devices along the first dimension."""
if device_count is None:
jax.local_device_count()
device_count = jax.local_device_count()
return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)


Expand Down

0 comments on commit 1a38512

Please sign in to comment.