Skip to content

Commit

Permalink
fix wandb logging
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Nov 12, 2024
1 parent 4e2d509 commit 56dd759
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 7 deletions.
8 changes: 8 additions & 0 deletions spotiflow/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@ def get_args() -> argparse.Namespace:
default="tensorboard",
help="Logger to use for monitoring training. Defaults to 'tensorboard'.",
)
train_args.add_argument(
"--smart-crop",
type=str2bool,
required=False,
default=False,
help="Use smart cropping for training. Defaults to False.",
)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -306,6 +313,7 @@ def main():
"pos_weight": args.pos_weight,
"num_train_samples":args.train_samples,
"finetuned_from": args.finetune_from,
"smart_crop": args.smart_crop,
},
)
log.info("Done!")
Expand Down
8 changes: 6 additions & 2 deletions spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from types import SimpleNamespace
from typing import Callable, Literal, Optional, Sequence, Tuple, Union

import datetime
import dask.array as da
import lightning.pytorch as pl
import numpy as np
Expand Down Expand Up @@ -467,9 +468,10 @@ def fit(
]

if logger == "tensorboard":
logger = pl.loggers.TensorBoardLogger(save_dir=save_dir)
logger = pl.loggers.TensorBoardLogger(save_dir=save_dir, name=f"spotiflow-{datetime.datetime.now().strftime('%Y%m%d_%H%M')}")
elif logger == "wandb":
logger = pl.loggers.WandbLogger(save_dir=save_dir)
Path(save_dir/"wandb").mkdir(parents=True, exist_ok=True)
logger = pl.loggers.WandbLogger(save_dir=save_dir, project="spotiflow", name=f"{datetime.datetime.now().strftime('%Y%m%d_%H%M')}")
else:
if logger != "none":
log.warning(f"Logger {logger} not implemented. Using no logger.")
Expand Down Expand Up @@ -1067,6 +1069,8 @@ def predict(
s_src_corr[:actual_n_dims]
]
points.append(p)
del out, img_t, tile, y_tile, p
torch.cuda.empty_cache()

if scale is not None and scale != 1:
y = zoom(y, (1.0 / scale, 1.0 / scale), order=1)
Expand Down
2 changes: 1 addition & 1 deletion spotiflow/model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def log_images(self):
self.logger.log_image(
key="flow",
images=[
0.5 * (1 + np.squeeze(v, axis=0).transpose(1, 2, 0))
0.5 * (1 + v.transpose(1, 2, 0))
for v in self._valid_flows[:n_images_to_log]
],
step=self.global_step,
Expand Down
10 changes: 6 additions & 4 deletions spotiflow/utils/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from tqdm.auto import tqdm
from dataclasses import dataclass, fields
from scipy.ndimage import map_coordinates

FWHM_CONSTANT = 2 * np.sqrt(2 * np.log(2))

Expand Down Expand Up @@ -100,11 +101,12 @@ def _estimate_params_single2(
y_range = np.arange(-window, window + 1)
y, x = np.meshgrid(y_range, x_range, indexing="ij")

# Crop around the spot
region = image[
# Crop around the spot with interpolation
y_indices, x_indices = np.mgrid[
center[0] - window : center[0] + window + 1,
center[1] - window : center[1] + window + 1,
center[1] - window : center[1] + window + 1
]
region = map_coordinates(image, [y_indices, x_indices], order=3, mode='reflect')

try:
mi, ma = np.min(region), np.max(region)
Expand Down Expand Up @@ -226,7 +228,7 @@ def estimate_params(
peak_range (np.ndarray): peak range of the spots
"""
img = np.pad(img, window, mode="reflect")
centers = np.asarray(centers).astype(int) + window
centers = np.asarray(centers) + window
if max_workers == 1:
params = tuple(
_estimate_params_single(
Expand Down

0 comments on commit 56dd759

Please sign in to comment.