Add support for mlflow #77

Open

wants to merge 305 commits into base: main

Changes from 250 commits

Commits (305)
6685e94
bugfixes
sadamov May 28, 2024
6423fdf
pre_commits
sadamov May 28, 2024
59c4947
Merge remote-tracking branch 'origin/main' into feature_dataset_yaml
sadamov May 31, 2024
4e457ed
config.py is ready for danra
sadamov May 31, 2024
adc592f
streamlined multi-zarr workflow
sadamov Jun 1, 2024
a7bea6b
xarray zarr based data normalization
sadamov Jun 1, 2024
1f7cbe8
adjusted pre-processing scripts to new data config workflow
sadamov Jun 2, 2024
e328152
plotting update with latest get_xy() function
sadamov Jun 2, 2024
cb85cda
making data config more modular
sadamov Jun 2, 2024
eb8c6fb
removing boundaries for now
sadamov Jun 2, 2024
0cfbb33
small updates
sadamov Jun 2, 2024
59d0c8a
improved stats and units retrieval
sadamov Jun 2, 2024
2f6a87a
add GPU-based runner on cirun.io
leifdenby Jun 3, 2024
668dd81
improved zarr-based normalization
sadamov Jun 3, 2024
143cf2a
pdm install with cpu torch
leifdenby Jun 3, 2024
b760915
ensure exec in pdm venv
leifdenby Jun 3, 2024
7797cef
ensure exec in pdm venv
leifdenby Jun 3, 2024
e689650
check version #2
leifdenby Jun 3, 2024
fb8ef23
check version no 3
leifdenby Jun 3, 2024
51b0a0b
check versions
leifdenby Jun 3, 2024
374d032
merge main
sadamov Jun 3, 2024
8fa3ca7
Introduced datetime forcing calculation as seperate script
sadamov Jun 3, 2024
a748903
Fixed order of y and x dims to adhere to #52
sadamov Jun 3, 2024
70425ee
fix for pip install
leifdenby Jun 3, 2024
60110f6
switch cirun instance type
leifdenby Jun 3, 2024
6fff3fc
install py39 on cirun runner
leifdenby Jun 3, 2024
74b4a10
cleanup: boundary_mask, zarr-opening, utils
sadamov Jun 4, 2024
0a041d1
Merge remote-tracking branch 'origin/main' into feature_dataset_yaml
sadamov Jun 4, 2024
8054e9e
change ami image to gpu
leifdenby Jun 4, 2024
39fbf3a
Merge remote-tracking branch 'upstream/main' into maint/deps-in-pypro…
leifdenby Jun 4, 2024
97aeb2e
use cheaper gpu instance
leifdenby Jun 4, 2024
425123c
adapted tests for zarr-analysis data
sadamov Jun 4, 2024
4dcf671
Readme adapted for yaml zarr analysis workflow
sadamov Jun 4, 2024
6d384f0
samller bugfixes and improvements
sadamov Jun 4, 2024
12ff4f2
Added fixed data config file for testing on Danra
sadamov Jun 4, 2024
03f7769
reducing runtime of tests with smaller sample
sadamov Jun 4, 2024
26f069c
download danra data for test and example (streaming not possible)
sadamov Jun 6, 2024
1f1cbcc
bugfixes after real-life testcase
sadamov Jun 6, 2024
b369306
Merge remote-tracking branch 'origin/main' into feature_dataset_yaml
sadamov Jun 6, 2024
0cdc361
organize .zarr in /data
sadamov Jun 6, 2024
23ca7b3
cleanup
sadamov Jun 6, 2024
81422f1
linter
sadamov Jun 6, 2024
124541b
static dataset doesn't have time dim
sadamov Jun 7, 2024
6140fdb
making two complex functions more modular
sadamov Jun 7, 2024
db6a912
chunk dataset by time
sadamov Jun 8, 2024
1aaa8dc
create list first for performance
sadamov Jun 8, 2024
81856b2
converting to_array is very slow
sadamov Jun 8, 2024
b3da818
allow for forcings to not be normalized
sadamov Jun 8, 2024
7ee5398
allow non_normalized_vars to be null
sadamov Jun 8, 2024
4782103
fixed coastlines using new xy_extent function
sadamov Jun 8, 2024
e0ffc5b
Some projections return inverted axes (rotatedPole)
sadamov Jun 9, 2024
c1f43b7
Docstrings added
sadamov Jun 13, 2024
21fd929
wip
leifdenby Jun 26, 2024
c52f98e
npy mllam nearly done
leifdenby Jul 6, 2024
80f3639
minor adjustment
leifdenby Jul 7, 2024
048f8c6
Merge branch 'main' of https://github.com/mllam/neural-lam into maint…
leifdenby Jul 11, 2024
5aaa239
add pooch and tweak pip cicd testing
leifdenby Jul 11, 2024
66c3b03
combine cicd tests with caching
leifdenby Jul 11, 2024
8566b8f
linting
leifdenby Jul 11, 2024
29bd9e5
add pyg dep
leifdenby Jul 11, 2024
bc7f028
set cirun aws region to frankfurt
leifdenby Jul 11, 2024
2070166
adapt image
leifdenby Jul 11, 2024
e4e86e5
set image
leifdenby Jul 11, 2024
1fba8fe
try different image
leifdenby Jul 11, 2024
02b77cf
add pooch to cicd
leifdenby Jul 11, 2024
b481929
add pdm gpu test
leifdenby Jul 16, 2024
bcec472
start work on readme
leifdenby Jul 16, 2024
c5beec9
Merge branch 'maint/deps-in-pyproject-toml' into datastore
leifdenby Jul 16, 2024
e89facc
Merge branch 'main' into maint/refactor-as-package
leifdenby Jul 16, 2024
0b5687a
Merge branch 'main' of https://github.com/mllam/neural-lam into maint…
leifdenby Jul 16, 2024
095fdbc
turn meps testdata download into pytest fixture
leifdenby Jul 16, 2024
49e9bfe
adapt README for package
leifdenby Jul 16, 2024
12cc02b
remove pdm cicd test (will be in separate PR)
leifdenby Jul 16, 2024
b47f50b
remove pdm in gitignore
leifdenby Jul 16, 2024
90d99ca
remove pdm and pyproject files (will be sep PR)
leifdenby Jul 16, 2024
a91eaaa
add pyproject.toml from main
leifdenby Jul 16, 2024
5508cea
clean out tests
leifdenby Jul 16, 2024
5c623c3
fix linting
leifdenby Jul 16, 2024
08ec168
add cli entrypoints import test
leifdenby Jul 16, 2024
d9cf7ba
Merge branch 'maint/refactor-as-package' into datastore
leifdenby Jul 16, 2024
3954f04
tweak cicd pytest execution
leifdenby Jul 16, 2024
f99fdce
Merge branch 'maint/refactor-as-package' into datastore
leifdenby Jul 16, 2024
db9d96f
Update tests/test_mllam_dataset.py
leifdenby Jul 17, 2024
3c864b2
grid-shape ok
leifdenby Jul 17, 2024
1f54b0e
get_vars_names and units
leifdenby Jul 17, 2024
9b88160
get_vars_names and units 2
leifdenby Jul 17, 2024
a9fdad5
test for stats
leifdenby Jul 23, 2024
555154f
get_dataarray test
leifdenby Jul 24, 2024
8b8a77e
get_dataarray test
leifdenby Jul 24, 2024
41f11cd
boundary_mask
leifdenby Jul 24, 2024
a17de0f
get_xy
leifdenby Jul 24, 2024
0a38a7d
remove TrainingSample dataclass
leifdenby Jul 24, 2024
f65f6b5
test for WeatherDataset.__getitem__
leifdenby Jul 24, 2024
a35100e
test for graph creation
leifdenby Jul 24, 2024
cfb0618
more graph creation tests
leifdenby Jul 24, 2024
8698719
check for consistency of num features across splits
leifdenby Jul 24, 2024
3381404
test for single batch from mllam through model
leifdenby Jul 24, 2024
2a6796c
Add init files to expose classes in editable package
joeloskarsson Jul 24, 2024
8f4e0e0
Linting
joeloskarsson Jul 24, 2024
e657abb
working training_step with datastores!
Jul 25, 2024
effc99b
remove superfluous tests
Jul 25, 2024
a047026
fix for dataset length
Jul 25, 2024
d2c62ed
step length should be int
Jul 25, 2024
58f5d99
step length should be int
Jul 25, 2024
64d43a6
training working with mllam datastore!
Jul 25, 2024
07444f8
adapt neural_lam.train_model for datastores
Jul 25, 2024
d1b6fc1
fixes for npy
Jul 25, 2024
6fe19ac
npyfiles datastore complete
leifdenby Jul 26, 2024
fe65a4d
cleanup for datastore examples
leifdenby Jul 26, 2024
e533794
training on ohm with danra!
Jul 26, 2024
640ac05
use mllam-data-prep v0.2.0
Aug 5, 2024
0f16f13
remove py3.12 from pre-commit
Aug 5, 2024
724548e
cleanup
Aug 8, 2024
a1b2037
all tests passing!
Aug 12, 2024
e35958f
use mllam-data-prep v0.3.0
Aug 12, 2024
8b92318
delete requirements.txt
Aug 13, 2024
658836a
remove .DS_Store
Aug 13, 2024
421efed
use tmate in gpu pdm cicd
Aug 13, 2024
05f1e9f
remove requirements
Aug 13, 2024
3afe0e4
update pdm gpu cicd setup to pdm venv on nvme drive
Aug 13, 2024
f3d028b
don't try to use pdm venv in-project
Aug 13, 2024
2c35662
remove tmate
Aug 13, 2024
5f30255
update README with install instructions
Aug 14, 2024
b2b5631
changelog
Aug 14, 2024
c8ae829
update ci/cd badges to include gpu + gpu
Aug 14, 2024
e7cf2c0
Merge pull request #1 from mllam/package_inits
leifdenby Aug 14, 2024
0b72e9d
add pyproject-flake8 to precommit config
Aug 14, 2024
190d1de
use Flake8-pyproject instead
Aug 14, 2024
791af0a
update README
Aug 14, 2024
58fab84
Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores
Aug 14, 2024
dbe2e6d
Merge branch 'maint/refactor-as-package' into maint/deps-in-pyproject…
Aug 14, 2024
eac6e35
Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores
Aug 14, 2024
799d55e
linting fixes
Aug 14, 2024
57bbb81
train only 1 epoch in cicd and print to stdout
Aug 14, 2024
a955cee
log datastore config
Aug 14, 2024
0a79c74
cleanup doctrings
Aug 15, 2024
9f3c014
Merge branch 'maint/refactor-as-package' into datastore
leifdenby Aug 19, 2024
41364a8
Merge branch 'main' of https://github.com/mllam/neural-lam into maint…
leifdenby Aug 19, 2024
3422298
update changelog
leifdenby Aug 19, 2024
689ef69
move dev deps optional dependencies group
leifdenby Aug 20, 2024
9a0d538
update cicd tests to install dev deps
leifdenby Aug 20, 2024
bddfcaf
update readme with new dev deps group
leifdenby Aug 20, 2024
b96cfdc
quote the skip step the install readme
leifdenby Aug 20, 2024
2600dee
remove unused files
leifdenby Aug 20, 2024
65a8074
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Aug 20, 2024
6adf6cc
revert to line length of 80
leifdenby Aug 20, 2024
46b37f8
revert docstring formatting changes
leifdenby Aug 20, 2024
3cd0f8b
pin numpy to <2.0.0
leifdenby Aug 20, 2024
826270a
Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores
leifdenby Aug 20, 2024
4ba22ea
Merge branch 'main' into feat/datastores
leifdenby Aug 20, 2024
1f661c6
fix flake8 linting errors
leifdenby Aug 20, 2024
4838872
Update neural_lam/weather_dataset.py
leifdenby Sep 8, 2024
b59e7e5
Update neural_lam/datastore/multizarr/create_normalization_stats.py
leifdenby Sep 8, 2024
75b1fe7
Update neural_lam/datastore/npyfiles/store.py
leifdenby Sep 8, 2024
7e736cb
Update neural_lam/datastore/npyfiles/store.py
leifdenby Sep 8, 2024
613a7e2
Update neural_lam/datastore/npyfiles/store.py
leifdenby Sep 8, 2024
65e199b
Update tests/test_training.py
leifdenby Sep 8, 2024
4435e26
Update tests/test_datasets.py
leifdenby Sep 8, 2024
4693408
Update README.md
leifdenby Sep 8, 2024
2dfed2c
update README
leifdenby Sep 10, 2024
c3d033d
Merge branch 'main' of https://github.com/mllam/neural-lam into feat/…
leifdenby Sep 10, 2024
4a70268
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Sep 10, 2024
66c663f
column_water -> open_water_fraction
leifdenby Sep 10, 2024
11a7978
fix linting
leifdenby Sep 10, 2024
a41c314
static data same for all splits
leifdenby Sep 10, 2024
6f1efd6
forcing_window_size from args
leifdenby Sep 10, 2024
bacb9ec
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
4a9db4e
only use first ensemble member in datastores
leifdenby Sep 10, 2024
4fc2448
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Sep 10, 2024
bcaa919
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
90bc594
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
5bda935
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
8e7931d
remove all multizarr functionality
leifdenby Sep 10, 2024
6998683
cleanup and test fixes for recent changes
leifdenby Sep 10, 2024
c415008
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Sep 10, 2024
735d324
fix linting
leifdenby Sep 10, 2024
5f2d919
remove multizar example files
leifdenby Sep 10, 2024
5263d2c
normalization -> standardization
leifdenby Sep 10, 2024
ba1bec3
fix import for tests
leifdenby Sep 10, 2024
d04d15e
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
743d7a1
fix coord issues and add datastore example plotting cli
leifdenby Sep 12, 2024
ac10d7d
add lru_cache to get_xy_extent
leifdenby Sep 12, 2024
bf8172a
MLLAMDatastore -> MDPDatastore
leifdenby Sep 12, 2024
90ca400
missed renames for MDPDatastore
leifdenby Sep 12, 2024
154139d
update graph plot for datastores
leifdenby Sep 12, 2024
50ee0b0
use relative import
leifdenby Sep 12, 2024
7dfd570
add long_names and refactor npyfiles create weights
leifdenby Sep 12, 2024
2b45b5a
Update neural_lam/weather_dataset.py
leifdenby Sep 23, 2024
aee0b1c
Update neural_lam/weather_dataset.py
leifdenby Sep 23, 2024
8453c2b
Update neural_lam/models/ar_model.py
leifdenby Sep 27, 2024
7f32557
Update neural_lam/weather_dataset.py
leifdenby Sep 27, 2024
67998b8
read projection from datastore config extra section
leifdenby Sep 27, 2024
ac7e46a
NpyFilesDatastore -> NpyFilesDatastoreMEPS
leifdenby Sep 27, 2024
b7bf506
revert tp training with 1 AR step by default
leifdenby Sep 27, 2024
5df2ecf
add missing kwarg to BaseHiGraphModel.__init__
leifdenby Sep 27, 2024
d4d438f
add missing kwarg to HiLAM.__init__
leifdenby Sep 27, 2024
1889771
add missing kwarg to HiLAMParallel
leifdenby Sep 27, 2024
2c3bbde
check that for enough forecast steps given ar_steps
leifdenby Sep 27, 2024
f0a151b
remove numpy<2.0.0 version cap
leifdenby Sep 27, 2024
f3566b0
tweak print statement working in mdp
Oct 1, 2024
dba94b3
fix missed removed argument from cli
Oct 1, 2024
bca1482
remove wandb config log comment, we log now
Oct 1, 2024
fc973c4
ensure loading from checkpoint during train possible
Oct 1, 2024
9fcf06e
get step_length from datastore in plot_error_map
leifdenby Oct 1, 2024
2bbe666
remove step_legnth attr in ARModel
leifdenby Oct 1, 2024
b41ed2f
remove unused obs_mask arg for vis.plot_prediction
leifdenby Oct 1, 2024
7e46194
ensure no reference to multizarr "data_config"
leifdenby Oct 1, 2024
b57bc7a
introduce neural-lam config
leifdenby Oct 2, 2024
2b30715
include meps neural-lam config example
leifdenby Oct 2, 2024
8e7b2e6
fix extra space typo in BaseDatastore
leifdenby Oct 2, 2024
e0300fb
add check and print of train/test/val split in MDPDatastore
leifdenby Oct 2, 2024
a921e35
add experimental mlflow server support
leifdenby Oct 2, 2024
0f30259
more fixes for mlflow logging support
leifdenby Oct 3, 2024
3fbe2d0
Make wandb work again with pytorch_lightning.logger
khintz Oct 3, 2024
e0284a8
upload of artifact to mlflow works, but instantiates a new experiment
khintz Oct 4, 2024
7eed79b
make mlflow use same experiment run id as pl.logger.MLFlowLogger
khintz Oct 7, 2024
27408f2
logger artifact working for both wandb and mlflow
khintz Oct 7, 2024
e61a9e7
support mlflow system metrics logging
khintz Oct 7, 2024
b53bab5
support model logging for mlflow
khintz Oct 7, 2024
de27e9a
log model
khintz Oct 7, 2024
89d8cde
test system metrics
khintz Nov 13, 2024
54c7ca7
make mlflow work also for eval mode
khintz Nov 15, 2024
a47de0c
dummy prints to identify workflow
khintz Nov 21, 2024
10a4494
update mlflow on eval mode
khintz Nov 21, 2024
427a4b1
Merge branch 'main' into feat/mlflow
khintz Nov 21, 2024
78e874d
inspect plot routines
khintz Nov 25, 2024
5904cbe
identified issue, cleanup next
leifdenby Nov 25, 2024
efe0302
use xarray plot only
leifdenby Nov 26, 2024
a489c2e
don't reraise
leifdenby Nov 26, 2024
242d08b
remove debug plot
leifdenby Nov 26, 2024
c1f706c
remove extent calc used in diagnosing issue
leifdenby Nov 26, 2024
88ec9dc
Test order of dimension in eval plots
khintz Nov 28, 2024
d367cdb
Merge branch 'fix/eval-vis-plots' into feat/mlflow
khintz Nov 28, 2024
90f8918
fix tensors on cpu and plot time index
khintz Nov 28, 2024
53f0ea4
restore tests/test_datasets.py
khintz Nov 29, 2024
cfc249f
cleaning up with focus on linting
khintz Nov 29, 2024
b218c8b
update tests
khintz Nov 29, 2024
1f1aed8
use correct data module for input example
khintz Dec 4, 2024
f3abd47
Merge branch 'main' into feat/mlflow
khintz Dec 4, 2024
47932b5
clean log model function
khintz Dec 9, 2024
98dc5c4
Merge branch 'main' into feat/mlflow
khintz Dec 9, 2024
64971ae
revert bad merge
khintz Dec 9, 2024
010f716
remove unused init for datastore
khintz Dec 9, 2024
2620bd1
set logger url
khintz Dec 9, 2024
75a39e6
change type of default logger_url in config
khintz Dec 9, 2024
9d27a4c
linting
khintz Dec 9, 2024
8f42cd1
fix log_image issue in tests
khintz Dec 9, 2024
b5ebe6f
add entry to changelog
khintz Dec 9, 2024
ae69f3f
remove artifacts from earlier merging/rebase
khintz Dec 11, 2024
6e16035
catch error when aws credentials not set
khintz Dec 11, 2024
3 changes: 3 additions & 0 deletions neural_lam/config.py
@@ -86,6 +86,9 @@ class TrainingConfig:
ManualStateFeatureWeighting, UniformFeatureWeighting
] = dataclasses.field(default_factory=UniformFeatureWeighting)

logger: str = "wandb"
logger_url: str = ""

Comment on lines +89 to +91 (Collaborator):

What do we actually want the config to contain, vs. what should be command-line arguments? I would have thought that the choice of logger would be an argparse flag, similar to the plotting choices. My reasoning is that logging/plotting does not affect the end product (the trained model), whereas all the current options in the config do. But we are not really consistent with this divide either, as there are plenty of argparse options that currently change the model training.

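As an illustration of the alternative raised in this comment (not part of the PR), the logger selection could instead be exposed as argparse flags next to the existing logger settings in `neural_lam/train_model.py`. A minimal sketch — the flag names here are hypothetical:

```python
# Hypothetical sketch of the argparse-based alternative discussed above;
# these flags are not part of this PR and their names are assumptions.
parser.add_argument(
    "--logger",
    type=str,
    default="wandb",
    choices=["wandb", "mlflow"],
    help="Experiment logger to use (default: wandb)",
)
parser.add_argument(
    "--logger_url",
    type=str,
    default="",
    help="Tracking URL of the logging server, e.g. an MLflow tracking URI",
)
```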

@dataclasses.dataclass
class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard):
3 changes: 2 additions & 1 deletion neural_lam/datastore/plot_example.py
@@ -186,4 +186,5 @@ def _parse_dict(arg_str):
selection=selection,
index_selection=index_selection,
)
plt.show()
# plt.show()
plt.savefig("plot_example.png")
71 changes: 50 additions & 21 deletions neural_lam/models/ar_model.py
@@ -7,7 +7,6 @@
import numpy as np
import pytorch_lightning as pl
import torch
import wandb
import xarray as xr

# Local
@@ -96,6 +95,7 @@ def __init__(
# Store constant per-variable std.-dev. weighting
# NOTE that this is the inverse of the multiplicative weighting
# in wMSE/wMAE
# TODO: Do we need param_weights for this?
khintz marked this conversation as resolved.
self.register_buffer(
"per_var_std",
self.diff_std / torch.sqrt(self.feature_weights),
@@ -262,7 +262,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
pred_std_list, dim=1
) # (B, pred_steps, num_grid_nodes, d_f)
else:
pred_std = self.per_var_std # (d_f,)
pred_std = self.diff_std # (d_f,)
khintz marked this conversation as resolved.

return prediction, pred_std

Expand Down Expand Up @@ -539,14 +539,18 @@ def plot_examples(self, batch, n_examples, split, prediction=None):

example_i = self.plotted_examples

wandb.log(
{
f"{var_name}_example_{example_i}": wandb.Image(fig)
for var_name, fig in zip(
self._datastore.get_vars_names("state"), var_figs
)
}
)
for var_name, fig in zip(
self._datastore.get_vars_names("state"), var_figs
):

if isinstance(self.logger, pl.loggers.WandbLogger):
key = f"{var_name}_example_{example_i}"
else:
key = f"{var_name}_example"

if hasattr(self.logger, "log_image"):
self.logger.log_image(key=key, images=[fig], step=t_i)

plt.close(
"all"
) # Close all figs for this time step, saves memory
@@ -555,13 +559,15 @@
torch.save(
pred_slice.cpu(),
os.path.join(
wandb.run.dir, f"example_pred_{self.plotted_examples}.pt"
self.logger.save_dir,
f"example_pred_{self.plotted_examples}.pt",
),
)
torch.save(
target_slice.cpu(),
os.path.join(
wandb.run.dir, f"example_target_{self.plotted_examples}.pt"
self.logger.save_dir,
f"example_target_{self.plotted_examples}.pt",
),
)

@@ -582,16 +588,16 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
datastore=self._datastore,
)
full_log_name = f"{prefix}_{metric_name}"
log_dict[full_log_name] = wandb.Image(metric_fig)
log_dict[full_log_name] = metric_fig

if prefix == "test":
# Save pdf
metric_fig.savefig(
os.path.join(wandb.run.dir, f"{full_log_name}.pdf")
os.path.join(self.logger.save_dir, f"{full_log_name}.pdf")
)
# Save errors also as csv
np.savetxt(
os.path.join(wandb.run.dir, f"{full_log_name}.csv"),
os.path.join(self.logger.save_dir, f"{full_log_name}.csv"),
metric_tensor.cpu().numpy(),
delimiter=",",
)
@@ -639,8 +645,25 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
)
)

# Ensure that log_dict has structure for
# logging as dict(str, plt.Figure)
assert all(
isinstance(key, str) and isinstance(value, plt.Figure)
for key, value in log_dict.items()
)

if self.trainer.is_global_zero and not self.trainer.sanity_checking:
wandb.log(log_dict) # Log all

current_epoch = self.trainer.current_epoch

for key, figure in log_dict.items():
# For other loggers than wandb, add epoch to key
if not isinstance(self.logger, pl.loggers.WandbLogger):
key = f"{key}-{current_epoch}"

if hasattr(self.logger, "log_image"):
self.logger.log_image(key=key, images=[figure])

plt.close("all") # Close all figs

def on_test_epoch_end(self):
@@ -672,9 +695,13 @@ def on_test_epoch_end(self):
)
]

# log all to same wandb key, sequentially
for fig in loss_map_figs:
wandb.log({"test_loss": wandb.Image(fig)})
# log all to same key, sequentially
for i, fig in enumerate(loss_map_figs):
key = "test_loss"
if not isinstance(self.logger, pl.loggers.WandbLogger):
key = f"{key}_{i}"
if hasattr(self.logger, "log_image"):
self.logger.log_image(key=key, images=[fig])

# also make without title and save as pdf
pdf_loss_map_figs = [
@@ -683,14 +710,16 @@
)
for loss_map in mean_spatial_loss
]
pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
pdf_loss_maps_dir = os.path.join(
self.logger.save_dir, "spatial_loss_maps"
)
os.makedirs(pdf_loss_maps_dir, exist_ok=True)
for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
# save mean spatial loss as .pt file also
torch.save(
mean_spatial_loss.cpu(),
os.path.join(wandb.run.dir, "mean_spatial_loss.pt"),
os.path.join(self.logger.save_dir, "mean_spatial_loss.pt"),
)

self.spatial_loss_maps.clear()
145 changes: 135 additions & 10 deletions neural_lam/train_model.py
@@ -5,10 +5,15 @@
from argparse import ArgumentParser

# Third-party
import mlflow

# for logging the model:
import mlflow.pytorch
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities import seed
from loguru import logger
from mlflow.models import infer_signature

# Local
from . import utils
@@ -23,6 +28,110 @@
}


class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
"""
Custom MLFlow logger that adds functionality not present in the default
"""

def __init__(self, experiment_name, tracking_uri):
super().__init__(
experiment_name=experiment_name, tracking_uri=tracking_uri
)
mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
mlflow.log_param("run_id", self.run_id)

@property
def save_dir(self):
"""
Returns the directory where the MLFlow artifacts are saved
"""
return "mlruns"

def log_image(self, key, images, step=None):
"""
Log a matplotlib figure as an image to MLFlow

key: str
Key to log the image under
images: list
List of matplotlib figures to log
step: Union[int, None]
Step to log the image under. If None, logs under the key directly
"""
# Third-party
from PIL import Image

if step is not None:
key = f"{key}_{step}"

# Need to save the image to a temporary file, then log that file
# mlflow.log_image, should do this automatically, but is buggy
temporary_image = f"{key}.png"
images[0].savefig(temporary_image)

img = Image.open(temporary_image)
mlflow.log_image(img, f"{key}.png")

def log_model(self, data_module, model):
input_example = self.create_input_example(data_module)

with torch.no_grad():
model_output = model.common_step(input_example)[
0
] # common_step returns tuple (prediction, target, pred_std, _)

log_model_input_example = {
name: tensor.cpu().numpy()
for name, tensor in zip(
["init_states", "target_states", "forcing", "target_times"],
input_example,
)
}

signature = infer_signature(
log_model_input_example, model_output.cpu().numpy()
)

mlflow.pytorch.log_model(
model,
"model",
signature=signature,
)

def create_input_example(self, data_module):

if data_module.val_dataset is None:
data_module.setup(stage="fit")

data_loader = data_module.train_dataloader()
batch_sample = next(iter(data_loader))
return batch_sample
Comment on lines +81 to +114 (Contributor Author):

log_model, and thereby also create_input_example, is not used, so they can be removed. However, it could be used to log a model if one wishes to do so, but with an input example that is not validated.
I would vote for removing this and revisiting it in another PR if needed.

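If `log_model` were kept, a minimal usage sketch (not part of this PR) might look like the following, assuming `training_logger` is the `CustomMLFlowLogger` built in `main()` and `model`/`data_module` are the objects later passed to `trainer.fit`:

```python
# Hypothetical usage sketch only; log_model is currently unused in this PR.
if isinstance(training_logger, CustomMLFlowLogger):
    # Log the trained model to MLflow with a signature inferred from a
    # single training batch (the input example is not validated).
    training_logger.log_model(data_module, model)
```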


def _setup_training_logger(config, datastore, args, run_name):
if config.training.logger == "wandb":
logger = pl.loggers.WandbLogger(
project=args.wandb_project,
name=run_name,
config=dict(training=vars(args), datastore=datastore._config),
)
elif config.training.logger == "mlflow":
url = config.training.logger_url
if url is None:
raise ValueError(
"MLFlow logger requires a URL to the MLFlow server"
)
logger = CustomMLFlowLogger(
experiment_name=args.wandb_project,
tracking_uri=url,
)
logger.log_hyperparams(
dict(training=vars(args), datastore=datastore._config)
)

return logger


@logger.catch
def main(input_args=None):
"""Main function for training and evaluating models."""
@@ -163,6 +272,12 @@ def main(input_args=None):
help="Number of example predictions to plot during evaluation "
"(default: 1)",
)
parser.add_argument(
"--save_predictions",
action="store_true",
help="If predictions should be saved to disk as a zarr dataset "
"(default: false)",
)

# Logger Settings
parser.add_argument(
@@ -261,24 +376,30 @@
f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"
f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}"
)

training_logger = _setup_training_logger(
config=config, datastore=datastore, args=args, run_name=run_name
)

checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath=f"saved_models/{run_name}",
filename="min_val_loss",
monitor="val_mean_loss",
mode="min",
save_last=True,
)
logger = pl.loggers.WandbLogger(
project=args.wandb_project,
name=run_name,
config=dict(training=vars(args), datastore=datastore._config),
)
trainer = pl.Trainer(
max_epochs=args.epochs,
deterministic=True,
strategy="ddp",
devices=4,
# devices=[1,2],
# devices=[0, 1, 2],
# strategy="auto",
# devices=1, # For eval mode
# num_nodes=1, # For eval mode
accelerator=device_name,
logger=logger,
logger=training_logger,
log_every_n_steps=1,
callbacks=[checkpoint_callback],
check_val_every_n_epoch=args.val_interval,
Expand All @@ -287,11 +408,15 @@ def main(input_args=None):

# Only init once, on rank 0 only
if trainer.global_rank == 0:
utils.init_wandb_metrics(
logger, val_steps=args.val_steps_to_log
) # Do after wandb.init
utils.init_training_logger_metrics(
training_logger, val_steps=args.val_steps_to_log
) # Do after initializing logger
if args.eval:
trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
trainer.test(
model=model,
datamodule=data_module,
ckpt_path=args.load,
)
else:
trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
