Skip to content

Commit

Permalink
Merge branch 'basic_tutorial' into actions/black
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink authored Oct 18, 2024
2 parents 60a2364 + 39ed9eb commit 7888713
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
axis=0,
)
if concatenated.shape[0] == 1:
print(
logger.info(
f"Concatenated array has only one channel: {self.name} {concatenated.shape}"
)
return concatenated
Expand Down
2 changes: 1 addition & 1 deletion dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def iterate(self, num_iterations, model, optimizer, device):
"""
t_start_fetch = time.time()

print("Starting iteration!")
logger.debug("Starting iteration!")

for iteration in range(self.iteration, self.iteration + num_iterations):
raw, gt, target, weight, mask = self.next()
Expand Down
2 changes: 1 addition & 1 deletion dacapo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
create_weights_store,
)
from dacapo.experiments import Run
from dacapo.validate import validate_run
from dacapo.validate import validate_run, validate

import torch
from tqdm import tqdm
Expand Down
107 changes: 51 additions & 56 deletions docs/source/notebooks/minimal_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
# # Minimal Tutorial
#

# %% [markdown]
# ## Needed Libraries for this Tutorial
# For the tutorial we will use data from the `skimage` library, and we will use `matplotlib` to visualize the data. You can install these libraries using the following commands:
#
# ```bash
# pip install 'scikit-image[data]'
# pip install matplotlib
# ```

# %% [markdown]
# ## Introduction and overview
#
Expand Down Expand Up @@ -69,7 +78,7 @@

# %% Create some data

import random
# import random

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -141,6 +150,7 @@
)
* 255
)
print("Data saved to cells3d.zarr")


# %% [markdown]
Expand All @@ -159,59 +169,39 @@
# experiments, but is useful for this tutorial.

# %%
from dacapo.experiments.datasplits import TrainValidateDataSplitConfig
from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig
from dacapo.experiments.datasplits.datasets.arrays import (
ZarrArrayConfig,
IntensitiesArrayConfig,
)
from funlib.geometry import Coordinate

datasplit_config = TrainValidateDataSplitConfig(
name="example_datasplit",
train_configs=[
RawGTDatasetConfig(
name="example_dataset",
raw_config=IntensitiesArrayConfig(
name="example_raw_normalized",
source_array_config=ZarrArrayConfig(
name="example_raw",
file_name="cells3d.zarr",
dataset="raw",
),
min=0,
max=255,
),
gt_config=ZarrArrayConfig(
name="example_gt",
file_name="cells3d.zarr",
dataset="mask",
),
)
],
validate_configs=[
RawGTDatasetConfig(
name="example_dataset",
raw_config=IntensitiesArrayConfig(
name="example_raw_normalized",
source_array_config=ZarrArrayConfig(
name="example_raw",
file_name="cells3d.zarr",
dataset="raw",
),
min=0,
max=255,
),
gt_config=ZarrArrayConfig(
name="example_gt",
file_name="cells3d.zarr",
dataset="mask",
),
)
],
)
from dacapo.experiments.datasplits import DataSplitGenerator, DatasetSpec

dataspecs = [
DatasetSpec(
dataset_type="train",
raw_container="cells3d.zarr",
raw_dataset="raw",
gt_container="cells3d.zarr",
gt_dataset="mask",
),
DatasetSpec(
dataset_type="val",
raw_container="cells3d.zarr",
raw_dataset="raw",
gt_container="cells3d.zarr",
gt_dataset="mask",
),
]

datasplit_config = DataSplitGenerator(
name="skimage_tutorial_data",
datasets=dataspecs,
input_resolution=voxel_size,
output_resolution=voxel_size,
targets=["cell"],
).compute()


# %%
datasplit = datasplit_config.datasplit_type(datasplit_config)
viewer = datasplit._neuroglancer()

# %%
config_store.store_datasplit_config(datasplit_config)

# %% [markdown]
Expand All @@ -228,18 +218,20 @@
# note that the clip_distance, tol_distance, and scale_factor are in nm
dist_task_config = DistanceTaskConfig(
name="example_dist",
channels=["mito"],
channels=["cell"],
clip_distance=260 * 10.0,
tol_distance=260 * 10.0,
scale_factor=260 * 20.0,
)
# config_store.delete_task_config(dist_task_config.name)
config_store.store_task_config(dist_task_config)

# an example affinities task configuration
affs_task_config = AffinitiesTaskConfig(
name="example_affs",
neighborhood=[(0, 1, 0), (0, 0, 1)],
)
# config_store.delete_task_config(dist_task_config.name)
config_store.store_task_config(affs_task_config)

# %% [markdown]
Expand Down Expand Up @@ -287,7 +279,7 @@
batch_size=10,
learning_rate=0.0001,
num_data_fetchers=8,
snapshot_interval=100,
snapshot_interval=1000,
min_masked=0.05,
clip_raw=False,
)
Expand All @@ -303,7 +295,7 @@
from dacapo.experiments import RunConfig
from dacapo.experiments.run import Run

iterations = 10000
iterations = 2000
validation_interval = iterations // 4
run_config = RunConfig(
name="example_run",
Expand All @@ -327,13 +319,16 @@

# %%
from dacapo.train import train_run
from dacapo.validate import validate

# from dacapo.validate import validate
from dacapo.experiments.run import Run

from dacapo.store.create_store import create_config_store

config_store = create_config_store()

run = Run(config_store.retrieve_run_config("example_run"))

if __name__ == "__main__":
train_run(run)

Expand Down

0 comments on commit 7888713

Please sign in to comment.