Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

New OI experiments #74

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
b83b791
start SPDE branch
maxbeauchamp Mar 14, 2022
43e4406
pull 20220524
Mar 14, 2022
27f99a6
start srv5
Mar 14, 2022
4379b53
merge with spde branch
maxbeauchamp Mar 14, 2022
4001c9c
start srv5
maxbeauchamp Mar 14, 2022
fd14aee
20220524
maxbeauchamp Apr 29, 2022
171daea
New OI directory + related Notebooks
maxbeauchamp May 3, 2022
8804a55
New OI directory + related Notebooks
maxbeauchamp May 3, 2022
36702fd
New OI dir + related Notebooks
maxbeauchamp May 3, 2022
822d0a3
New OI dir + related Notebooks
maxbeauchamp May 3, 2022
9d391fd
generate first profiles
quentinf00 May 16, 2022
47ac4a5
New SPDE improvements
maxbeauchamp May 24, 2022
47cf281
new ose xp
maxbeauchamp Jun 9, 2022
c56bf4a
new ose xp
maxbeauchamp Jun 9, 2022
53b66e7
new xp OSE
maxbeauchamp Jun 17, 2022
1aa6093
new OSE notebooks
maxbeauchamp Jun 18, 2022
98f0b28
new OSE notebooks
maxbeauchamp Jun 18, 2022
073295b
last spde june 2022
maxbeauchamp Jun 29, 2022
33cf7a2
before introducing tau
maxbeauchamp Sep 4, 2022
c0f6bd9
push zay fp osse
maxbeauchamp Sep 21, 2022
817f26a
new xp oi
maxbeauchamp Sep 21, 2022
26678ef
new xp oi
maxbeauchamp Sep 21, 2022
018e106
No bugs on the line
maxbeauchamp Oct 9, 2022
4ab5a10
Finalizing SPDE xp setup
maxbeauchamp Oct 25, 2022
b2fd553
Last before 2023
maxbeauchamp Dec 9, 2022
c54667d
homogenize qg and ssh xp
maxbeauchamp Jan 3, 2023
539403d
march 2023
maxbeauchamp Mar 7, 2023
6fae20b
new spde version
maxbeauchamp Apr 26, 2023
d30f5be
first push from mee-a100
maxbeauchamp May 30, 2023
3842c5e
first push from mee-a100
maxbeauchamp May 30, 2023
20d0084
running simu cond ssh ok
maxbeauchamp Jul 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
57 changes: 57 additions & 0 deletions .bak/launch_xp_spde_gp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/bin/bash
# Batch-evaluate the 4DVarNet-SPDE model over successive test-date windows,
# merge the per-window conditional-simulation outputs into one NetCDF file,
# then prune old lightning_logs version directories.

# Emit a small Python program (piped to `python -` below) that concatenates
# the 'simu' members of two consecutive lightning_logs versions (N and N+1)
# along the trailing 'nsimu' axis and writes tmp_${start}.nc.
# NOTE: the heredoc is unquoted on purpose so ${N} and ${start} expand here.
concat_simu () { cat << HEREDOC
import xarray as xr
import numpy as np
N=${N}
data = xr.open_dataset('dashboard/oi_gp_spde_wlp_ml_loss/lightning_logs/version_'+str(N)+'/maps.nc')
for i in range(1,2):
    new_simu = xr.open_dataset('dashboard/oi_gp_spde_wlp_ml_loss/lightning_logs/version_'+str(N+i)+'/maps.nc')
    # Append the new ensemble members along axis 4 ('nsimu') and relabel the
    # coordinate. NOTE(review): arange(2*(i+1)) assumes 2 members per version
    # directory -- confirm this matches the per-run ensemble size.
    data = data.update({'simu': (('time', 'lat', 'lon', 'daw', 'nsimu'),
                                 np.concatenate((data.simu.values,
                                                 new_simu.simu.values), axis=4)),
                        'nsimu': np.arange(2*(i+1))})
nc = data
nc.to_netcdf('tmp_${start}.nc')
HEREDOC
}

# Test windows: every 4 days from 2013-01-03 (days 0..21), plus 2013-01-28.
ldate=($(seq 0 4 21 | xargs -I {} date -d "2013-01-03 {} days" +%Y-%m-%d) '2013-01-28')
len=${#ldate[@]}
# run all the dates
echo $len
for (( i=1; i<$len; i++ )); do
    start="${ldate[$((i-1))]}"
    end="${ldate[$i]}"
    # Extend each window by 3 days so consecutive windows overlap.
    end=$(date -d "$end + 3 days" '+%Y-%m-%d')
    echo "### RUN DATE ${start} to ${end} ###"
    # run all the simulations: two test runs per window, each producing a new
    # lightning_logs version directory
    for (( j=0; j<2; j++ )); do
        echo "### RUN SIMU $((j*5)) to $(((j+1)*5)) ###"
        CUDA_VISIBLE_DEVICES=0 HYDRA_FULL_ERROR=1 python hydra_main.py xp=mbeaucha/xp_spde/gp_diff_pow_2/oi_gp_spde_wlp_mlloss file_paths=mee_a100_gp entrypoint=test datamodule.test_slices.0._args_.0='"'$start'"' datamodule.test_slices.0._args_.1='"'$end'"' +entrypoint.ckpt_path=../../model_oi_gp_spde_v2.ckpt
    done
    # merge NetCDF: N = index of the first of the two fresh version dirs
    # (version names are 'version_<n>', hence the cut on '_')
    N=`ls dashboard/oi_gp_spde_wlp_ml_loss/lightning_logs/ | tail -n 1 | cut -f2 -d'_'`
    N=$((N-1))
    concat_simu | python -
done
# merge NetCDF: stitch the per-window tmp files along the time dimension
python <<HEREDOC
import xarray as xr
nc = xr.open_mfdataset('tmp*.nc',combine='nested',concat_dim='time')
nc.to_netcdf('/homes/m19beauc/maps_gp_4DvarNet_SPDE.nc')
HEREDOC

# cleaning
rm -rf tmp*.nc
cd dashboard
# clean before: keep only the most recent lightning_logs entry in each xp dir
# (ls -lt sorts newest first; tail -n +3 skips the 'total' line and the newest)
for dir in `ls .` ; do
    if [ -d ${dir}/lightning_logs ] ; then
        echo ${dir}/lightning_logs
        cd ${dir}/lightning_logs
        ls -lt | tail -n +3 | awk '{print $9}' | xargs rm -r
        cd ../..
    fi
done
cd ..
Binary file added .bak/model_oi_osse_spde_v2.ckpt
Binary file not shown.
24 changes: 7 additions & 17 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,6 @@ tb_profile
*err
*out
log
spde.py
scipy_sparse_tools.py
models_spde.py
models_wSPDE.py
merge_ose_osse*
main_wSPDE.slurm
lit_model_4DEnVar.py
lit_model_4DEnVar.py.old
lit_model_En4DVar.py
lit_model_Un4DVar.py
solver_4DEnVar.py
hydra_config/xp/stoch*
hydra_config/xp/4DEnVar_osse_wuc_gf.yaml
hydra_config/training/stoch*
hydra_config/training/4DEnVar_osse_wuc.yaml
hydra_config/params/fourdvarnet_osse_wuc.yaml
hydra_config/params/fourdvarnet_stoch_osse.yaml
checkpoints
__pycache__
tmp
Expand All @@ -50,3 +33,10 @@ instab
lightning_logs_archives
report
trained_cfgs
trials
oi/eval_notebooks/.ipynb_checkpoints
oi/eval_notebooks/__pycache__
oi/logs
postpro/*/*png
out.file
log*
45 changes: 37 additions & 8 deletions dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
_ds.time.attrs["units"] = "seconds since 2012-10-01"
_ds = xr.decode_cf(_ds)
else:
_ds['time'] = pd.to_datetime(_ds.time)
_ds['time'] = pd.to_datetime(_ds.time.values)

# rename latitute/longitude to lat/lon for consistency
rename_coords = {}
Expand Down Expand Up @@ -169,7 +169,6 @@ def __init__(
self.ds = self.ds_reflected.assign_coords(
lon=self.padded_coords['lon'], lat=self.padded_coords['lat']
)

# III) get lon-lat for the final reconstruction
dX = ((slice_win['lon']-strides['lon'])/2)*self.resolution
dY = ((slice_win['lat']-strides['lat'])/2)*self.resolution
Expand All @@ -180,13 +179,15 @@ def __init__(


self.ds = self.ds.transpose("time", "lat", "lon")

if self.interp_na:
self.ds = interpolate_na_2D(self.ds)

if compute:
self.ds = self.ds.compute()

self.ds = self.ds.transpose("time", "lat", "lon")


self.slice_win = slice_win
self.strides = strides or {}
self.ds_size = {
Expand Down Expand Up @@ -257,6 +258,7 @@ def __init__(
aug_train_data=False,
compute=False,
pp='std',
rmv_patches=False
):
super().__init__()
self.use_auto_padding=use_auto_padding
Expand All @@ -265,6 +267,10 @@ def __init__(
self.return_coords = False
self.pp=pp

self.dim_range = dim_range
self.slice_win = slice_win
self.strides = strides

self.gt_ds = XrDataset(
gt_path, gt_var,
slice_win=slice_win,
Expand All @@ -275,8 +281,9 @@ def __init__(
resize_factor=resize_factor,
compute=compute,
auto_padding=use_auto_padding,
interp_na=True,
interp_na=False,
)

self.obs_mask_ds = XrDataset(
obs_mask_path, obs_mask_var,
slice_win=slice_win,
Expand All @@ -287,6 +294,7 @@ def __init__(
resize_factor=resize_factor,
compute=compute,
auto_padding=use_auto_padding,
interp_na=False,
)

self.oi_ds = XrDataset(
Expand All @@ -313,14 +321,16 @@ def __init__(
resize_factor=resize_factor,
compute=compute,
auto_padding=use_auto_padding,
interp_na=True,
interp_na=False,
)
else:
self.sst_ds = None

if self.aug_train_data:
self.perm = np.random.permutation(len(self.obs_mask_ds))

self.rmv_patches = rmv_patches

self.norm_stats = (0, 1)
self.norm_stats_sst = (0, 1)

Expand Down Expand Up @@ -387,13 +397,29 @@ def __getitem__(self, item):
obs_mask_item = ~np.isnan(_obs_item)
obs_item = np.where(~np.isnan(_obs_item), _obs_item, np.zeros_like(_obs_item))

# remove patches from data
'''
if self.rmv_patches==True:
n_patch = 10
s_patch = 10
for i in range(len(obs_item)):
posx = np.random.randint(s_patch,self.slice_win['lon']-s_patch,n_patch)
posy = np.random.randint(s_patch,self.slice_win['lat']-s_patch,n_patch)
ix = np.stack([np.arange(posx[ipatch]-s_patch,posx[ipatch]+s_patch+1) for ipatch in range(n_patch)])
iy = np.stack([np.arange(posy[ipatch]-s_patch,posy[ipatch]+s_patch+1) for ipatch in range(n_patch)])
ix, iy = np.transpose(np.stack([np.meshgrid(ix[ipatch],iy[ipatch]) for ipatch in range(n_patch)]),
(1,0,2,3))
gt_item[i,ix,iy] = 0.
obs_item[i,~ix,~iy] = 0.
obs_mask_item[i,~ix,~iy] = 0.
'''

if self.sst_ds == None:
return oi_item, obs_mask_item, obs_item, gt_item
else:
pp_sst = self.get_pp(self.norm_stats_sst)
_sst_item = pp_sst(self.sst_ds[item % length])
sst_item = np.where(~np.isnan(_sst_item), _sst_item, 0.)

return oi_item, obs_mask_item, obs_item, gt_item, sst_item

class FourDVarNetDataModule(pl.LightningDataModule):
Expand Down Expand Up @@ -423,7 +449,8 @@ def __init__(
dl_kwargs=None,
compute=False,
use_auto_padding=False,
pp='std'
pp='std',
rmv_patches=False
):
super().__init__()
self.resize_factor = resize_factor
Expand Down Expand Up @@ -459,6 +486,7 @@ def __init__(
self.norm_stats = (0, 1)
self.norm_stats_sst = None

self.rmv_patches = rmv_patches

def mean_stds(self, ds):
sum = 0
Expand Down Expand Up @@ -548,6 +576,7 @@ def setup(self, stage=None):
aug_train_data=self.aug_train_data,
compute=self.compute,
pp=self.pp,
rmv_patches=self.rmv_patches
) for sl in self.train_slices])


Expand All @@ -574,6 +603,7 @@ def setup(self, stage=None):
compute=self.compute,
use_auto_padding=self.use_auto_padding,
pp=self.pp,
rmv_patches=False
) for sl in slices]
)
for slices in (self.val_slices, self.test_slices)
Expand Down Expand Up @@ -602,7 +632,6 @@ def val_dataloader(self):
def test_dataloader(self):
return DataLoader(self.test_ds, **{**dict(shuffle=False), **self.dl_kwargs})


if __name__ == '__main__':
"""
Test run for single batch loading and trainer.fit
Expand Down
29 changes: 23 additions & 6 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,44 @@ channels:
- pytorch
- conda-forge
- defaults
- anaconda
#- anaconda
- pyviz
dependencies:
- jupyterlab
- python=3.9
- python=3.10
- pip=21.1
- scikit-learn=0.24.2
- cartopy=0.19.0
- cudatoolkit=11.3
- pytorch=1.12.1
- torchvision=0.13.1
- scikit-learn=1.2.1
- scikit-sparse
- numpy=1.24.2
- netcdf4=1.6.3
- matplotlib=3.7.1
- cartopy=0.21.1
- cartopy_offlinedata=0.2.4
- tabulate=0.8.9
- xrft=0.4.1
- numpy_groupies=0.9.13
- pytorch-lightning=1.6.2
- pytorch-lightning=1.9.4
- pip:
- xarray[complete]==0.21
- holoviews[recommended]
- sparse
- GitPython
- setuptools==57.0.0
- torchmetrics==0.6.0
- --find-links https://data.pyg.org/whl/torch-1.12.0+cu113.html
- eccodes==1.2.0
- omegaconf
- einops
- hydra-submitit-launcher
- opencv-python==4.5.2.52
- opencv-python==4.7.0.72
- hydra-core
- kornia
- pyepsg
- torch-scatter
- torch-sparse
- cupy-cuda113
- torch_tb_profiler
# mamba install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
8 changes: 8 additions & 0 deletions hydra_config/domain/cnatl.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Central North Atlantic domain bounds (degrees).
# Hydra instantiates each entry as builtins.slice(min, max), presumably used
# for coordinate selection, e.g. ds.sel(lat=slice(33., 53.)) -- confirm caller.
lat:
  _target_: builtins.slice
  _args_: [33., 53.]
lon:
  _target_: builtins.slice
  _args_: [-50., -10.]


7 changes: 7 additions & 0 deletions hydra_config/domain/gf_mod.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Modified "gf" domain bounds (degrees); presumably Gulf Stream -- confirm.
# Hydra instantiates each entry as builtins.slice(min, max).
lat:
  _target_: builtins.slice
  _args_: [33., 42.95]
lon:
  _target_: builtins.slice
  _args_: [-65., -55.05]

7 changes: 7 additions & 0 deletions hydra_config/domain/gp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# GP (SPDE diffusion) toy-dataset domain.
# Hydra instantiates each entry as builtins.slice(min, max).
# NOTE(review): bounds are integers 0..100, which look like grid indices
# rather than degrees -- confirm against the GP dataset coordinates.
lat:
  _target_: builtins.slice
  _args_: [0, 100]
lon:
  _target_: builtins.slice
  _args_: [0, 100]

8 changes: 8 additions & 0 deletions hydra_config/domain/osmosis_x.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Extended OSMOSIS domain bounds (degrees).
# Hydra instantiates each entry as builtins.slice(min, max).
lat:
  _target_: builtins.slice
  _args_: [44., 56.]
lon:
  _target_: builtins.slice
  _args_: [-20.5, -10.5]


7 changes: 7 additions & 0 deletions hydra_config/domain/qg.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# QG (quasi-geostrophic) experiment domain.
# Hydra instantiates each entry as builtins.slice(min, max).
# NOTE(review): symmetric +/-2.5 bounds suggest a normalized/idealized grid
# rather than geographic degrees -- confirm against the QG dataset.
lat:
  _target_: builtins.slice
  _args_: [-2.5, 2.5]
lon:
  _target_: builtins.slice
  _args_: [-2.5, 2.5]

3 changes: 3 additions & 0 deletions hydra_config/entrypoint/predict.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Hydra entrypoint config: run FourDVarNetHydraRunner.predict.
# ckpt_path uses OmegaConf's mandatory-value marker (???), so it must be
# supplied at launch, e.g. +entrypoint.ckpt_path=/path/to/model.ckpt.
_target_: hydra_main.FourDVarNetHydraRunner.predict
ckpt_path: ???

2 changes: 1 addition & 1 deletion hydra_config/entrypoint/run.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
_target_: hydra_main.FourDVarNetHydraRunner.run
max_epochs: 200
progress_bar_refresh_rate: 5
#progress_bar_refresh_rate: 5
2 changes: 1 addition & 1 deletion hydra_config/entrypoint/train.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
_target_: hydra_main.FourDVarNetHydraRunner.train
ckpt_path: null
max_epochs: 200
progress_bar_refresh_rate: 5
#progress_bar_refresh_rate: 5
limit_train_batches: 1.0
5 changes: 5 additions & 0 deletions hydra_config/file_paths/dc_gp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# file_paths config for the GP / SPDE-diffusion experiment.
# Observations, ground truth and SPDE parameters all come from the same
# dataset file (different variables within SPDE_diffusion_dataset.nc).
obs_path: /gpfswork/rech/yrf/uba22to/4dvarnet-core/oi/data/SPDE_diffusion_dataset.nc
gt_path: /gpfswork/rech/yrf/uba22to/4dvarnet-core/oi/data/SPDE_diffusion_dataset.nc
spde_params_path: /gpfswork/rech/yrf/uba22to/4dvarnet-core/oi/data/SPDE_diffusion_dataset.nc


5 changes: 5 additions & 0 deletions hydra_config/file_paths/dc_osse.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# file_paths config for the OSSE experiment (NATL60 simulation data):
# oi_path        -- baseline optimal-interpolation SSH maps (swot + 4 nadirs)
# obs_mask_path  -- pseudo-observations (nadir 0d + swot sampling)
# gt_path        -- NATL60 SSH reference (ground truth), year 2013
# sst_path       -- NATL60 SST reference, year 2013
oi_path: /gpfsstore/rech/yrf/commun/NATL60/NATL/oi/ssh_NATL60_swot_4nadir.nc
obs_mask_path: /gpfsstore/rech/yrf/commun/NATL60/NATL/data_new/dataset_nadir_0d_swot.nc
gt_path: /gpfsstore/rech/yrf/commun/NATL60/NATL/ref/NATL60-CJM165_NATL_ssh_y2013.1y.nc
sst_path: /gpfsstore/rech/yrf/commun/NATL60/NATL/ref/NATL60-CJM165_NATL_sst_y2013.1y.nc

3 changes: 3 additions & 0 deletions hydra_config/file_paths/jz.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ cal_data: /gpfswork/rech/yrf/commun/CalData/full_cal_obs.nc
noisy_swot: /gpfswork/rech/yrf/commun/CalData/cal_data_karin_noise_only.nc
natl_ssh: /gpfsstore/rech/yrf/commun/NATL60/NATL/ref/NATL60-CJM165_NATL_ssh_y2013.1y.nc
natl_sst: /gpfsdsstore/projects/rech/yrf/commun/NATL60/NATL/ref/NATL60-CJM165_NATL_sst_y2013.1y.nc
obs_path: /gpfswork/rech/ubn/commun/xp_gp/spde_diffusion_dataset.nc
gt_path: /gpfswork/rech/ubn/commun/xp_gp/spde_diffusion_dataset.nc
spde_params_path: /gpfswork/rech/ubn/commun/xp_gp/spde_diffusion_dataset.nc

oi_swot_4nadir: /gpfsstore/rech/yrf/commun/NATL60/NATL/oi/ssh_NATL60_swot_4nadir.nc
pseudo_obs: /gpfsstore/rech/yrf/commun/NATL60/NATL/data_new/dataset_nadir_0d_swot.nc
Expand Down
Loading