Skip to content

Commit

Permalink
jaxdecomp proto (#21)
Browse files Browse the repository at this point in the history
* adding example of distributed solution

* put back old functgion

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimised bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <[email protected]>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Shared operation in fourrier space now take inverted sharding axis for
slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge hugos LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formating

* get local shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* forgot particles in multi host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose snapshots number with MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <[email protected]>

* Put plotting utils in package

* By default use absoulute painting with

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Add aboucaud comments

* JAX version is 0.4.35 until Diffrax new release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <[email protected]>
Co-authored-by: Francois Lanusse <[email protected]>
Co-authored-by: Wassim KABALAN <[email protected]>
Co-authored-by: Hugo Simonfroy <[email protected]>
Former-commit-id: 8c2e823
  • Loading branch information
5 people authored Dec 20, 2024
1 parent a0a7927 commit df8602b
Show file tree
Hide file tree
Showing 26 changed files with 1,884 additions and 447 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/formatting.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Code Formatting

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
- name: Install dependencies
run: |
python -m pip install --upgrade pip isort
python -m pip install pre-commit
- name: Run pre-commit
run: python -m pre_commit run --all-files
45 changes: 45 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: Tests

on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
run_tests:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10" , "3.11" , "3.12"]

steps:
- name: Checkout Source
uses: actions/[email protected]

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
sudo apt-get install -y libopenmpi-dev
python -m pip install --upgrade pip
pip install jax==0.4.35
pip install numpy setuptools cython wheel
pip install git+https://github.com/MP-Gadget/pfft-python
pip install git+https://github.com/MP-Gadget/pmesh
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
pip install .[test]
- name: Run Single Device Tests
run: |
cd tests
pytest -v -m "not distributed"
- name: Run Distributed tests
run: |
pytest -v -m distributed
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ __pypackages__/
celerybeat-schedule
celerybeat.pid


out
traces
*.npy
*.out
# SageMath parsed files
*.sage.py

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ repos:
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
name: isort (python)
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
<!-- ALL-CONTRIBUTORS-BADGE:END -->
JAX-powered Cosmological Particle-Mesh N-body Solver

**This project is currently in an early design phase. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)**

## Goals

Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:
- Keep implementation simple and readable, in pure NumPy API
- Transparent distribution using builtin `xmap`
- Any order forward and backward automatic differentiation
- Support automated batching using `vmap`
- Compatibility with external optimizer libraries like `optax`
- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with`JAX v0.4.35`


## Open development and use

Expand All @@ -23,6 +22,10 @@ Current expectations are:
- Everyone is welcome to contribute, and can join the JOSS publication (until it is submitted to the journal).
- Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they *need* to extend authorship to all jaxpm developers.

## Getting Started

To dive into JaxPM’s capabilities, please explore the **notebook section** for detailed tutorials and examples on various setups, from single-device simulations to multi-host configurations. You can find the notebooks' [README here](notebooks/README.md) for a structured guide through each tutorial.


## Contributors ✨

Expand Down
52 changes: 0 additions & 52 deletions design.md

This file was deleted.

96 changes: 0 additions & 96 deletions dev/test_pfft.py

This file was deleted.

68 changes: 0 additions & 68 deletions dev/test_script.py

This file was deleted.

Loading

0 comments on commit df8602b

Please sign in to comment.