diff --git a/dev/jaxdecomp.py b/dev/jaxdecomp.py deleted file mode 100644 index ddb19e5..0000000 --- a/dev/jaxdecomp.py +++ /dev/null @@ -1,69 +0,0 @@ -import argparse - -import jax -import numpy as np - -# Setting up distributed jax -jax.distributed.initialize() -rank = jax.process_index() -size = jax.process_count() - -import jax.numpy as jnp -import jax_cosmo as jc -from jax.experimental import mesh_utils -from jax.sharding import Mesh - -from jaxpm.painting import cic_paint -from jaxpm.pm import linear_field, lpt - -mesh_shape = [256, 256, 256] -box_size = [256., 256., 256.] -snapshots = jnp.linspace(0.1, 1., 2) - - -@jax.jit -def run_simulation(omega_c, sigma8, seed): - # Create a cosmology - cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) - - # Create a small function to generate the matter power spectrum - k = jnp.logspace(-4, 1, 128) - pk = jc.power.linear_matter_power( - jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) - pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk - ).reshape(x.shape) - - # Create initial conditions - initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed) - - # Initialize particle displacements - dx, p, f = lpt(cosmo, initial_conditions, 1.0) - - field = cic_paint(jnp.zeros_like(initial_conditions), dx) - return field - - -def main(args): - # Setting up distributed random numbers - master_key = jax.random.PRNGKey(42) - key = jax.random.split(master_key, size)[rank] - - # Create computing mesh and sharding information - devices = mesh_utils.create_device_mesh((2, 2)) - mesh = Mesh(devices.T, axis_names=('x', 'y')) - - # Run the simulation on the compute mesh - with mesh: - field = run_simulation(0.32, 0.8, key) - - print('done') - np.save(f'field_{rank}.npy', field.addressable_data(0)) - - # Closing distributed jax - jax.distributed.shutdown() - - -if __name__ == '__main__': - parser = argparse.ArgumentParser("Distributed LPT N-body simulation.") - args = parser.parse_args() - main(args) diff --git a/dev/test_pfft.py b/dev/test_pfft.py deleted file mode 100644 index 5a956d8..0000000 --- a/dev/test_pfft.py +++ /dev/null @@ -1,96 +0,0 @@ -# Can be executed with: -# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py -from functools import partial - -import jax -import jax.lax as lax -import jax.numpy as jnp -import numpy as np -from jax.experimental.maps import Mesh, xmap -from jax.experimental.pjit import PartitionSpec, pjit - -jax.distributed.initialize() - -cube_size = 2048 - - -@partial(xmap, - in_axes=[...], - out_axes=['x', 'y', ...], - axis_sizes={ - 'x': cube_size, - 'y': cube_size - }, - axis_resources={ - 'x': 'nx', - 'y': 'ny', - 'key_x': 'nx', - 'key_y': 'ny' - }) -def pnormal(key): - return jax.random.normal(key, shape=[cube_size]) - - -@partial(xmap, - in_axes={ - 0: 'x', - 1: 'y' - }, - out_axes=['x', 'y', ...], - axis_resources={ - 'x': 'nx', - 'y': 'ny' - }) -@jax.jit -def pfft3d(mesh): - # [x, y, z] - mesh = jnp.fft.fft(mesh) # Transform on z - mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x] - mesh = jnp.fft.fft(mesh) # Transform on x - mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y] - mesh = jnp.fft.fft(mesh) # Transform on y - # [z, x, y] - return mesh - - -@partial(xmap, - in_axes={ - 0: 'x', - 1: 'y' - }, - out_axes=['x', 'y', ...], - axis_resources={ - 'x': 'nx', - 'y': 'ny' - }) -@jax.jit -def pifft3d(mesh): - # [z, x, y] - mesh = jnp.fft.ifft(mesh) # Transform on y - mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x] - mesh = jnp.fft.ifft(mesh) # Transform on x - mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z] - mesh = jnp.fft.ifft(mesh) # Transform on z - # [x, y, z] - return mesh - - -key = jax.random.PRNGKey(42) -# keys = jax.random.split(key, 4).reshape((2,2,2)) - -# We reshape all our devices to the mesh shape we want -devices = np.array(jax.devices()).reshape((2, 4)) - -with Mesh(devices, ('nx', 'ny')): - mesh = pnormal(key) - kmesh = pfft3d(mesh) - kmesh.block_until_ready() - -# jax.profiler.start_trace("tensorboard") -# with Mesh(devices, ('nx', 'ny')): -# mesh = pnormal(key) -# kmesh = pfft3d(mesh) -# kmesh.block_until_ready() -# jax.profiler.stop_trace() - -print('Done') diff --git a/dev/test_script.py b/dev/test_script.py deleted file mode 100644 index 4f3ca06..0000000 --- a/dev/test_script.py +++ /dev/null @@ -1,68 +0,0 @@ -# Start this script with: -# mpirun -np 4 python test_script.py -import os - -os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' -import jax -import jax.lax as lax -import jax.numpy as jnp -import matplotlib.pylab as plt -import numpy as np -import tensorflow_probability as tfp -from jax.experimental.maps import mesh, xmap -from jax.experimental.pjit import PartitionSpec, pjit - -tfp = tfp.substrates.jax -tfd = tfp.distributions - - -def cic_paint(mesh, positions): - """ Paints positions onto mesh - mesh: [nx, ny, nz] - positions: [npart, 3] - """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], - [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) - - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - - dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), - inserted_window_dims=(0, 1, 2), - scatter_dims_to_operand_dims=(0, 1, - 2)) - mesh = lax.scatter_add( - mesh, - neighboor_coords.reshape([-1, 8, 3]).astype('int32'), - kernel.reshape([-1, 8]), dnums) - return mesh - - -# And let's draw some points from some 3D distribution -dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.], - scale_identity_multiplier=3.) -pos = dist.sample(1e4, seed=jax.random.PRNGKey(0)) - -f = pjit(lambda x: cic_paint(x, pos), - in_axis_resources=PartitionSpec('x', 'y', 'z'), - out_axis_resources=None) - -devices = np.array(jax.devices()).reshape((2, 2, 1)) - -# Let's import the mesh -m = jnp.zeros([32, 32, 32]) - -with mesh(devices, ('x', 'y', 'z')): - # Shard the mesh, I'm not sure this is absolutely necessary - m = pjit(lambda x: x, - in_axis_resources=None, - out_axis_resources=PartitionSpec('x', 'y', 'z'))(m) - - # Apply the sharded CiC function - res = f(m) - -plt.imshow(res.sum(axis=2)) -plt.show()