diff --git a/.gitignore b/.gitignore index d040431..d955053 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,11 @@ models/*/ run_sbatch.sbatch slurm/ scripts/cooper/evaluation_results/ +analysis_results/ scripts/cooper/training/copy_testset.py scripts/rizzoli/upsample_data.py -scripts/cooper/training/find_rec_testset.py \ No newline at end of file +scripts/cooper/training/find_rec_testset.py +scripts/rizzoli/combine_2D_slices.py +scripts/rizzoli/combine_2D_slices_raw.py +scripts/cooper/remove_h5key.py +scripts/cooper/analysis/calc_AZ_area.py \ No newline at end of file diff --git a/scripts/aggregate_data_information.py b/scripts/aggregate_data_information.py index d90ec8c..7086b23 100644 --- a/scripts/aggregate_data_information.py +++ b/scripts/aggregate_data_information.py @@ -12,30 +12,24 @@ stem = "STEM" -def aggregate_vesicle_train_data(roots, test_tomograms, conditions, resolutions): +def aggregate_vesicle_train_data(roots, conditions, resolutions): tomo_names = [] - tomo_vesicles = [] + tomo_vesicles_all, tomo_vesicles_imod = [], [] tomo_condition = [] tomo_resolution = [] tomo_train = [] - for ds, root in roots.items(): - print("Aggregate data for", ds) - train_root = root["train"] - if train_root == "": - test_root = root["test"] - tomograms = sorted(glob(os.path.join(test_root, "2024**", "*.h5"), recursive=True)) - this_test_tomograms = [os.path.basename(tomo) for tomo in tomograms] + def aggregate_split(ds, split_root, split): + if ds.startswith("04"): + tomograms = sorted(glob(os.path.join(split_root, "2024**", "*.h5"), recursive=True)) else: - # This is only the case for 04, which is also nested - tomograms = sorted(glob(os.path.join(train_root, "*.h5"))) - this_test_tomograms = test_tomograms[ds] + tomograms = sorted(glob(os.path.join(split_root, "*.h5"))) assert len(tomograms) > 0, ds this_condition = conditions[ds] this_resolution = resolutions[ds][0] - for tomo_path in tqdm(tomograms): + for tomo_path in tqdm(tomograms, desc=f"Aggregate {split}"): fname = os.path.basename(tomo_path) with h5py.File(tomo_path, "r") as f: try: @@ -43,24 +37,39 @@ def aggregate_vesicle_train_data(roots, test_tomograms, conditions, resolutions) except KeyError: tomo_name = fname - n_label_sets = len(f["labels"]) - if n_label_sets > 2: - print(tomo_path, "contains the following labels:", list(f["labels"].keys())) - seg = f["labels/vesicles"][:] - n_vesicles = len(np.unique(seg)) - 1 + if "labels/vesicles/combined_vesicles" in f: + all_vesicles = f["labels/vesicles/combined_vesicles"][:] + imod_vesicles = f["labels/vesicles/masked_vesicles"][:] + n_vesicles_all = len(np.unique(all_vesicles)) - 1 + n_vesicles_imod = len(np.unique(imod_vesicles)) - 2 + else: + vesicles = f["labels/vesicles"][:] + n_vesicles_all = len(np.unique(vesicles)) - 1 + n_vesicles_imod = n_vesicles_all tomo_names.append(tomo_name) - tomo_vesicles.append(n_vesicles) + tomo_vesicles_all.append(n_vesicles_all) + tomo_vesicles_imod.append(n_vesicles_imod) tomo_condition.append(this_condition) tomo_resolution.append(this_resolution) - tomo_train.append("test" if fname in this_test_tomograms else "train/val") + tomo_train.append(split) + + for ds, root in roots.items(): + print("Aggregate data for", ds) + train_root = root["train"] + if train_root != "": + aggregate_split(ds, train_root, "train/val") + test_root = root["test"] + if test_root != "": + aggregate_split(ds, test_root, "test") df = pd.DataFrame({ "tomogram": tomo_names, "condition": tomo_condition, "resolution": tomo_resolution, "used_for": tomo_train, - 
"vesicle_count": tomo_vesicles, + "vesicle_count_all": tomo_vesicles_all, + "vesicle_count_imod": tomo_vesicles_imod, }) os.makedirs("data_summary", exist_ok=True) @@ -70,15 +79,15 @@ def aggregate_vesicle_train_data(roots, test_tomograms, conditions, resolutions) def vesicle_train_data(): roots = { "01": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/01_hoi_maus_2020_incomplete", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/01_hoi_maus_2020_incomplete", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/01_hoi_maus_2020_incomplete", # noqa }, "02": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/02_hcc_nanogold", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/02_hcc_nanogold", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/02_hcc_nanogold", # noqa }, "03": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/03_hog_cs1sy7", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/03_hog_cs1sy7", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/03_hog_cs1sy7", # noqa }, "04": { @@ -86,44 +95,31 @@ def vesicle_train_data(): "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/04Dataset_for_vesicle_eval/", # noqa }, "05": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/05_stem750_sv_training", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/05_stem750_sv_training", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/05_stem750_sv_training", # noqa }, "07": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/07_hoi_s1sy7_tem250_ihgp", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/07_hoi_s1sy7_tem250_ihgp", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/07_hoi_s1sy7_tem250_ihgp", # noqa }, "09": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/09_stem750_66k", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/09_stem750_66k", # noqa "test": "", }, "10": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/10_tem_single_release", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/10_tem_single_release", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/10_tem_single_release", # noqa }, "11": { - "train": 
"/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/11_tem_multiple_release", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/11_tem_multiple_release", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/11_tem_multiple_release", # noqa }, "12": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/12_chemical_fix_cryopreparation", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/12_chemical_fix_cryopreparation", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/12_chemical_fix_cryopreparation", # noqa }, } - test_tomograms = { - "01": ["tomogram-009.h5", "tomogram-038.h5", "tomogram-049.h5", "tomogram-052.h5", "tomogram-057.h5", "tomogram-060.h5", "tomogram-067.h5", "tomogram-074.h5", "tomogram-076.h5", "tomogram-083.h5", "tomogram-133.h5", "tomogram-136.h5", "tomogram-145.h5", "tomogram-149.h5", "tomogram-150.h5"], # noqa - "02": ["tomogram-004.h5", "tomogram-008.h5"], - "03": ["tomogram-003.h5", "tomogram-004.h5", "tomogram-008.h5",], - "04": [], # all used for test - "05": ["tomogram-003.h5", "tomogram-005.h5",], - "07": ["tomogram-006.h5", "tomogram-017.h5",], - "09": [], # no test data - "10": ["tomogram-001.h5", "tomogram-002.h5", "tomogram-007.h5"], - "11": ["tomogram-001.h5 tomogram-007.h5 tomogram-008.h5"], - "12": ["tomogram-004.h5", "tomogram-021.h5", "tomogram-022.h5",], - } - conditions = { "01": single_ax_tem, "02": dual_ax_tem, @@ -150,7 +146,7 @@ def vesicle_train_data(): "12": (1.554, 1.554, 1.554) } - aggregate_vesicle_train_data(roots, test_tomograms, conditions, resolutions) + aggregate_vesicle_train_data(roots, conditions, resolutions) def aggregate_az_train_data(roots, test_tomograms, conditions, resolutions): @@ -397,6 +393,11 @@ def vesicle_domain_adaptation_data(): "MF_05649_P-09175-E_06.h5", "MF_05646_C-09175-B_001B.h5", "MF_05649_P-09175-E_07.h5", "MF_05649_G-09175-C_001.h5", "MF_05646_C-09175-B_002.h5", "MF_05649_G-09175-C_04.h5", "MF_05649_P-09175-E_05.h5", "MF_05646_C-09175-B_000.h5", "MF_05646_C-09175-B_001.h5" + ], + "frog": [ + "block10U3A_three.h5", "block30UB_one_two.h5", "block30UB_two.h5", "block10U3A_one.h5", + "block184B_one.h5", "block30UB_three.h5", "block10U3A_two.h5", "block30UB_four.h5", + "block30UB_one.h5", "block10U3A_five.h5", ] } @@ -439,13 +440,42 @@ def vesicle_domain_adaptation_data(): aggregate_da(roots, train_tomograms, test_tomograms, resolutions) +def get_n_images_frog(): + root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/rizzoli/extracted/upsampled_by2" + tomos = ["block10U3A_three.h5", "block30UB_one_two.h5", "block30UB_two.h5", "block10U3A_one.h5", + "block184B_one.h5", "block30UB_three.h5", "block10U3A_two.h5", "block30UB_four.h5", + "block30UB_one.h5", "block10U3A_five.h5"] + + n_images = 0 + for tomo in tomos: + path = os.path.join(root, tomo) + with h5py.File(path, "r") as f: + n_images += f["raw"].shape[0] + print(n_images) + + +def get_image_sizes_tem_2d(): + root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/2D_data/maus_2020_tem2d_wt_unt_div14_exported_scaled/good_for_DAtraining/maus_2020_tem2d_wt_unt_div14_exported_scaled" # noqa + tomos = [ + 
"MF_05649_P-09175-E_06.h5", "MF_05646_C-09175-B_001B.h5", "MF_05649_P-09175-E_07.h5", + "MF_05649_G-09175-C_001.h5", "MF_05646_C-09175-B_002.h5", "MF_05649_G-09175-C_04.h5", + "MF_05649_P-09175-E_05.h5", "MF_05646_C-09175-B_000.h5", "MF_05646_C-09175-B_001.h5" + ] + for tomo in tomos: + path = os.path.join(root, tomo) + with h5py.File(path, "r") as f: + print(f["raw"].shape) + + def main(): # active_zone_train_data() # compartment_train_data() # mito_train_data() - # vesicle_train_data() + vesicle_train_data() - vesicle_domain_adaptation_data() + # vesicle_domain_adaptation_data() + # get_n_images_frog() + # get_image_sizes_tem_2d() main() diff --git a/scripts/cooper/AZ_segmentation_h5.py b/scripts/cooper/AZ_segmentation_h5.py new file mode 100644 index 0000000..da694c1 --- /dev/null +++ b/scripts/cooper/AZ_segmentation_h5.py @@ -0,0 +1,173 @@ +import argparse +import h5py +import os +from pathlib import Path + +from tqdm import tqdm +from elf.io import open_file + +from synaptic_reconstruction.inference.AZ import segment_AZ +from synaptic_reconstruction.inference.util import parse_tiling + +def _require_output_folders(output_folder): + #seg_output = os.path.join(output_folder, "segmentations") + seg_output = output_folder + os.makedirs(seg_output, exist_ok=True) + return seg_output + +def get_volume(input_path): + ''' + with h5py.File(input_path) as seg_file: + input_volume = seg_file["raw"][:] + ''' + with open_file(input_path, "r") as f: + + # Try to automatically derive the key with the raw data. + keys = list(f.keys()) + if len(keys) == 1: + key = keys[0] + elif "data" in keys: + key = "data" + elif "raw" in keys: + key = "raw" + + input_volume = f[key][:] + return input_volume + +def run_AZ_segmentation(input_path, output_path, model_path, mask_path, mask_key,tile_shape, halo, key_label, compartment_seg): + tiling = parse_tiling(tile_shape, halo) + print(f"using tiling {tiling}") + input = get_volume(input_path) + + #check if we have a restricting mask for the segmentation + if mask_path is not None: + with open_file(mask_path, "r") as f: + mask = f[mask_key][:] + else: + mask = None + + #check if intersection with compartment is necessary + if compartment_seg is None: + foreground = segment_AZ(input_volume=input, model_path=model_path, verbose=False, tiling=tiling, mask = mask) + intersection = None + else: + with open_file(compartment_seg, "r") as f: + compartment = f["/labels/compartment"][:] + foreground, intersection = segment_AZ(input_volume=input, model_path=model_path, verbose=False, tiling=tiling, mask = mask, compartment=compartment) + + seg_output = _require_output_folders(output_path) + file_name = Path(input_path).stem + seg_path = os.path.join(seg_output, f"{file_name}.h5") + + #check + os.makedirs(Path(seg_path).parent, exist_ok=True) + + print(f"Saving results in {seg_path}") + with h5py.File(seg_path, "a") as f: + if "raw" in f: + print("raw image already saved") + else: + f.create_dataset("raw", data=input, compression="gzip") + + key=f"AZ/segment_from_{key_label}" + if key in f: + print("Skipping", input_path, "because", key, "exists") + else: + f.create_dataset(key, data=foreground, compression="gzip") + + if mask is not None: + if mask_key in f: + print("mask image already saved") + else: + f.create_dataset(mask_key, data = mask, compression = "gzip") + + if intersection is not None: + intersection_key = "AZ/compartment_AZ_intersection" + if intersection_key in f: + print("intersection already saved") + else: + f.create_dataset(intersection_key, data = 
intersection, compression = "gzip") + + + + +def segment_folder(args): + input_files = [] + for root, dirs, files in os.walk(args.input_path): + input_files.extend([ + os.path.join(root, name) for name in files if name.endswith(args.data_ext) + ]) + print(input_files) + pbar = tqdm(input_files, desc="Run segmentation") + for input_path in pbar: + + filename = os.path.basename(input_path) + try: + mask_path = os.path.join(args.mask_path, filename) + except: + print(f"Mask file not found for {input_path}") + mask_path = None + + if args.compartment_seg is not None: + try: + compartment_seg = os.path.join(args.compartment_seg, os.path.splitext(filename)[0] + '.h5') + except: + print(f"compartment file not found for {input_path}") + compartment_seg = None + + run_AZ_segmentation(input_path, args.output_path, args.model_path, mask_path, args.mask_key, args.tile_shape, args.halo, args.key_label, compartment_seg) + +def main(): + parser = argparse.ArgumentParser(description="Segment vesicles in EM tomograms.") + parser.add_argument( + "--input_path", "-i", required=True, + help="The filepath to the mrc file or the directory containing the tomogram data." + ) + parser.add_argument( + "--output_path", "-o", required=True, + help="The filepath to directory where the segmentations will be saved." + ) + parser.add_argument( + "--model_path", "-m", required=True, help="The filepath to the vesicle model." + ) + parser.add_argument( + "--mask_path", help="The filepath to a h5 file with a mask that will be used to restrict the segmentation. Needs to be in combination with mask_key." + ) + parser.add_argument( + "--mask_key", help="Key name that holds the mask segmentation" + ) + parser.add_argument( + "--tile_shape", type=int, nargs=3, + help="The tile shape for prediction. Lower the tile shape if GPU memory is insufficient." + ) + parser.add_argument( + "--halo", type=int, nargs=3, + help="The halo for prediction. Increase the halo to minimize boundary artifacts." + ) + parser.add_argument( + "--key_label", "-k", default = "combined_vesicles", + help="Give the key name for saving the segmentation in h5." + ) + parser.add_argument( + "--data_ext", "-d", default = ".h5", + help="Format extension of data to be segmented, default is .h5." + ) + parser.add_argument( + "--compartment_seg", "-c", + help="Path to compartment segmentation." + "If the compartment segmentation was executed before, this will add a key to output file that stores the intersection between compartment boundary and AZ." 
+ "Maybe need to adjust the compartment key that the segmentation is stored under" + ) + args = parser.parse_args() + + input_ = args.input_path + + if os.path.isdir(input_): + segment_folder(args) + else: + run_AZ_segmentation(input_, args.output_path, args.model_path, args.mask_path, args.mask_key, args.tile_shape, args.halo, args.key_label, args.compartment_seg) + + print("Finished segmenting!") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/cooper/analysis/active_zone_analysis.py b/scripts/cooper/analysis/active_zone_analysis.py index d2234c9..bb13ac5 100644 --- a/scripts/cooper/analysis/active_zone_analysis.py +++ b/scripts/cooper/analysis/active_zone_analysis.py @@ -3,15 +3,22 @@ import h5py import numpy as np +import napari +import pandas as pd from scipy.ndimage import binary_closing from skimage.measure import label from synaptic_reconstruction.ground_truth.shape_refinement import edge_filter +from synaptic_reconstruction.morphology import skeletonize_object +from synaptic_reconstruction.distance_measurements import measure_segmentation_to_object_distances from tqdm import tqdm -ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/20241102_TOMO_DATA_Imig2014/final_Imig2014_seg_autoComp" # noqa +from compute_skeleton_area import calculate_surface_area -OUTPUT_AZ = "./boundary_az" +ROOT = "./imig_data" # noqa +OUTPUT_AZ = "./az_segmentation" + +RESOLUTION = (1.554,) * 3 def filter_az(path): @@ -20,6 +27,7 @@ def filter_az(path): ds = os.path.basename(ds) out_path = os.path.join(OUTPUT_AZ, ds, fname) os.makedirs(os.path.join(OUTPUT_AZ, ds), exist_ok=True) + if os.path.exists(out_path): return @@ -56,11 +64,192 @@ def filter_az(path): f.create_dataset("filtered_az", data=az_filtered, compression="gzip") -def main(): +def filter_all_azs(): files = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True)) - for ff in tqdm(files): + for ff in tqdm(files, desc="Filter AZ segmentations."): filter_az(ff) +def process_az(path, view=True): + key = "thin_az" + + with h5py.File(path, "r") as f: + if key in f and not view: + return + az_seg = f["filtered_az"][:] + + az_thin = skeletonize_object(az_seg) + + if view: + ds, fname = os.path.split(path) + ds = os.path.basename(ds) + raw_path = os.path.join(ROOT, ds, fname) + with h5py.File(raw_path, "r") as f: + raw = f["raw"][:] + v = napari.Viewer() + v.add_image(raw) + v.add_labels(az_seg) + v.add_labels(az_thin) + napari.run() + else: + with h5py.File(path, "a") as f: + f.create_dataset(key, data=az_thin, compression="gzip") + + +# Apply thinning to all active zones to obtain 1d surface. +def process_all_azs(): + files = sorted(glob(os.path.join(OUTPUT_AZ, "**/*.h5"), recursive=True)) + for ff in tqdm(files, desc="Thin AZ segmentations."): + process_az(ff, view=False) + + +def measure_az_area(path): + from skimage import measure + + with h5py.File(path, "r") as f: + seg = f["thin_az"][:] + + # Try via surface mesh. + verts, faces, normals, values = measure.marching_cubes(seg, spacing=RESOLUTION) + surface_area1 = measure.mesh_surface_area(verts, faces) + + # Try via custom function. + surface_area2 = calculate_surface_area(seg, voxel_size=RESOLUTION) + + ds, fname = os.path.split(path) + ds = os.path.basename(ds) + + return pd.DataFrame({ + "Dataset": [ds], + "Tomogram": [fname], + "surface_mesh [nm^2]": [surface_area1], + "surface_custom [nm^2]": [surface_area2], + }) + + +# Measure the AZ surface areas. 
+def measure_all_areas(): + save_path = "./results/area_measurements.xlsx" + if os.path.exists(save_path): + return + + files = sorted(glob(os.path.join(OUTPUT_AZ, "**/*.h5"), recursive=True)) + area_table = [] + for ff in tqdm(files, desc="Measure AZ areas."): + area = measure_az_area(ff) + area_table.append(area) + area_table = pd.concat(area_table) + area_table.to_excel(save_path, index=False) + + manual_results = "/home/pape/Work/my_projects/synaptic-reconstruction/scripts/cooper/debug/surface/manualAZ_surface_area.xlsx" # noqa + manual_results = pd.read_excel(manual_results)[["Dataset", "Tomogram", "manual"]] + comparison_table = pd.merge(area_table, manual_results, on=["Dataset", "Tomogram"], how="inner") + comparison_table.to_excel("./results/area_comparison.xlsx", index=False) + + +def analyze_areas(): + import seaborn as sns + import matplotlib.pyplot as plt + + table = pd.read_excel("./results/area_comparison.xlsx") + + fig, axes = plt.subplots(2) + sns.scatterplot(data=table, x="manual", y="surface_mesh [nm^2]", ax=axes[0]) + sns.scatterplot(data=table, x="manual", y="surface_custom [nm^2]", ax=axes[1]) + plt.show() + + +def measure_distances(ves_path, az_path): + with h5py.File(az_path, "r") as f: + az = f["thin_az"][:] + + with h5py.File(ves_path, "r") as f: + vesicles = f["vesicles/segment_from_combined_vesicles"][:] + + distances, _, _, _ = measure_segmentation_to_object_distances(vesicles, az, resolution=RESOLUTION) + + ds, fname = os.path.split(az_path) + ds = os.path.basename(ds) + + return pd.DataFrame({ + "Dataset": [ds] * len(distances), + "Tomogram": [fname] * len(distances), + "Distance": distances, + }) + + +# Measure the AZ vesicle distances for all vesicles. +def measure_all_distances(): + save_path = "./results/vesicle_az_distances.xlsx" + if os.path.exists(save_path): + return + + ves_files = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True)) + az_files = sorted(glob(os.path.join(OUTPUT_AZ, "**/*.h5"), recursive=True)) + assert len(ves_files) == len(az_files) + + dist_table = [] + for ves_file, az_file in tqdm(zip(ves_files, az_files), total=len(az_files), desc="Measure distances."): + dist = measure_distances(ves_file, az_file) + dist_table.append(dist) + dist_table = pd.concat(dist_table) + + dist_table.to_excel(save_path, index=False) + + +def reformat_distances(): + tab = pd.read_excel("./results/vesicle_az_distances.xlsx") + + munc_ko = {} + munc_ctrl = {} + + snap_ko = {} + snap_ctrl = {} + + for _, row in tab.iterrows(): + ds = row.Dataset + tomo = row.Tomogram + + if ds == "Munc13DKO": + if "CTRL" in tomo: + group = munc_ctrl + else: + group = munc_ko + else: + assert ds == "SNAP25" + if "CTRL" in tomo: + group = snap_ctrl + else: + group = snap_ko + + name = os.path.splitext(tomo)[0] + val = row["Distance [nm]"] + if name in group: + group[name].append(val) + else: + group[name] = [val] + + def save_tab(group, path): + n_ves_max = max(len(v) for v in group.values()) + group = {k: v + [np.nan] * (n_ves_max - len(v)) for k, v in group.items()} + group_tab = pd.DataFrame(group) + group_tab.to_excel(path, index=False) + + os.makedirs("./results/distances_formatted", exist_ok=True) + save_tab(munc_ko, "./results/distances_formatted/munc_ko.xlsx") + save_tab(munc_ctrl, "./results/distances_formatted/munc_ctrl.xlsx") + save_tab(snap_ko, "./results/distances_formatted/snap_ko.xlsx") + save_tab(snap_ctrl, "./results/distances_formatted/snap_ctrl.xlsx") + + +def main(): + # filter_all_azs() + # process_all_azs() + # measure_all_areas() + # 
analyze_areas() + # measure_all_distances() + reformat_distances() + + if __name__ == "__main__": main() diff --git a/scripts/cooper/analysis/compute_skeleton_area.py b/scripts/cooper/analysis/compute_skeleton_area.py new file mode 100644 index 0000000..6fb05d0 --- /dev/null +++ b/scripts/cooper/analysis/compute_skeleton_area.py @@ -0,0 +1,44 @@ +import numpy as np + + +def calculate_surface_area(skeleton, voxel_size=(1.0, 1.0, 1.0)): + """ + Calculate the surface area of a 3D skeletonized object. + + Parameters: + skeleton (3D array): Binary 3D skeletonized array. + voxel_size (tuple): Physical size of voxels (z, y, x). + + Returns: + float: Approximate surface area of the skeleton. + """ + # Define the voxel dimensions + voxel_area = ( + voxel_size[1] * voxel_size[2], # yz-face area + voxel_size[0] * voxel_size[2], # xz-face area + voxel_size[0] * voxel_size[1], # xy-face area + ) + + # Compute the number of exposed faces for each voxel + exposed_faces = 0 + directions = [ + (1, 0, 0), (-1, 0, 0), # x-axis neighbors + (0, 1, 0), (0, -1, 0), # y-axis neighbors + (0, 0, 1), (0, 0, -1), # z-axis neighbors + ] + + # Iterate over all voxels in the skeleton + for z, y, x in np.argwhere(skeleton): + for i, (dz, dy, dx) in enumerate(directions): + neighbor = (z + dz, y + dy, x + dx) + # Check if the neighbor is outside the volume or not part of the skeleton + if ( + 0 <= neighbor[0] < skeleton.shape[0] and + 0 <= neighbor[1] < skeleton.shape[1] and + 0 <= neighbor[2] < skeleton.shape[2] and + skeleton[neighbor] == 1 + ): + continue + exposed_faces += voxel_area[i // 2] + + return exposed_faces diff --git a/scripts/cooper/analysis/run_size_analysis.py b/scripts/cooper/analysis/run_size_analysis.py new file mode 100644 index 0000000..abad440 --- /dev/null +++ b/scripts/cooper/analysis/run_size_analysis.py @@ -0,0 +1,174 @@ +import os +from glob import glob + +import numpy as np +import pandas as pd +import h5py +from tqdm import tqdm +from synaptic_reconstruction.imod.to_imod import convert_segmentation_to_spheres + +DATA_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/segmentation/for_spatial_distribution_analysis/final_Imig2014_seg/" # noqa +PREDICTION_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/segmentation/for_spatial_distribution_analysis/final_Imig2014_seg/" # noqa +RESULT_FOLDER = "./analysis_results/AZ_filtered_autoComp" + +def get_compartment_with_max_overlap(compartments, vesicles): + """ + Given 3D numpy arrays of compartments and vesicles, this function returns a binary mask + of the compartment with the most overlap with vesicles based on the number of overlapping voxels. + + Parameters: + compartments (numpy.ndarray): 3D array of compartment labels. + vesicles (numpy.ndarray): 3D array of vesicle labels or binary mask. + + Returns: + numpy.ndarray: Binary mask of the compartment with the most overlap with vesicles. 
+ """ + + unique_compartments = np.unique(compartments) + if 0 in unique_compartments: + unique_compartments = unique_compartments[unique_compartments != 0] + + max_overlap_count = 0 + best_compartment = None + + # Iterate over each compartment and calculate the overlap with vesicles + for compartment_label in unique_compartments: + compartment_mask = compartments == compartment_label + vesicle_mask = vesicles > 0 + + intersection = np.logical_and(compartment_mask, vesicle_mask) + overlap_count = np.sum(intersection) + + # Track the compartment with the most overlap in terms of voxel count + if overlap_count > max_overlap_count: + max_overlap_count = overlap_count + best_compartment = compartment_label + + final_mask = compartments == best_compartment + + return final_mask + +# We compute the sizes for all vesicles in the MANUALLY ANNOTATED compartment masks. +# We use the same logic in the size computation as for the vesicle extraction to IMOD, +# including the radius correction factor. --> not needed here +# The number of vesicles is automatically computed as the length of the size list. +def compute_sizes_for_all_tomorams_manComp(): + os.makedirs(RESULT_FOLDER, exist_ok=True) + + resolution = (1.554,) * 3 # Change for each dataset #1.554 for Munc and snap #0.8681 for 04 dataset + radius_factor = 1 + estimate_radius_2d = True + dataset_results = {} + + tomograms = sorted(glob(os.path.join(PREDICTION_ROOT, "**/*.h5"), recursive=True)) + for tomo in tqdm(tomograms): + ds_name, fname = os.path.split(tomo) + ds_name = os.path.split(ds_name)[1] + fname = os.path.splitext(fname)[0] + + # Determine if the tomogram is 'CTRL' or 'DKO' + category = "CTRL" if "CTRL" in fname else "DKO" + + if ds_name not in dataset_results: + dataset_results[ds_name] = {'CTRL': {}, 'DKO': {}} + + if fname in dataset_results[ds_name][category]: + continue + + # Load the vesicle segmentation from the predictions. + with h5py.File(tomo, "r") as f: + segmentation = f["/vesicles/segment_from_combined_vesicles"][:] + + input_path = os.path.join(DATA_ROOT, ds_name, f"{fname}.h5") + assert os.path.exists(input_path), input_path + + # Load the compartment mask from the tomogram + with h5py.File(input_path, "r") as f: + mask = f["labels/compartment"][:] + + segmentation[mask == 0] = 0 + _, sizes = convert_segmentation_to_spheres( + segmentation, resolution=resolution, radius_factor=radius_factor, estimate_radius_2d=estimate_radius_2d + ) + + + dataset_results[ds_name][category][fname] = sizes + + # Save each dataset's results into separate CSV files for CTRL and DKO tomograms + for ds_name, categories in dataset_results.items(): + for category, tomogram_data in categories.items(): + sorted_data = dict(sorted(tomogram_data.items())) # Sort by tomogram names + result_df = pd.DataFrame.from_dict(sorted_data, orient='index').transpose() + + output_path = os.path.join(RESULT_FOLDER, f"size_analysis_for_{ds_name}_{category}_rf1.csv") + + # Save the DataFrame to CSV + result_df.to_csv(output_path, index=False) + +# We compute the sizes for all vesicles in the AUTOMATIC SEGMENTED compartment masks. +# We use the same logic in the size computation as for the vesicle extraction to IMOD, +# including the radius correction factor. --> not needed here +# The number of vesicles is automatically computed as the length of the size list. 
+def compute_sizes_for_all_tomorams_autoComp(): + os.makedirs(RESULT_FOLDER, exist_ok=True) + + resolution = (1.554,) * 3 # Change for each dataset #1.554 for Munc and snap #0.8681 for 04 dataset + radius_factor = 1 + estimate_radius_2d = True + dataset_results = {} + + tomograms = sorted(glob(os.path.join(PREDICTION_ROOT, "**/*.h5"), recursive=True)) + for tomo in tqdm(tomograms): + ds_name, fname = os.path.split(tomo) + ds_name = os.path.split(ds_name)[1] + fname = os.path.splitext(fname)[0] + + # Determine if the tomogram is 'CTRL' or 'DKO' + category = "CTRL" if "CTRL" in fname else "DKO" + + if ds_name not in dataset_results: + dataset_results[ds_name] = {'CTRL': {}, 'DKO': {}} + + if fname in dataset_results[ds_name][category]: + continue + + # Load the vesicle segmentation from the predictions. + with h5py.File(tomo, "r") as f: + segmentation = f["/vesicles/segment_from_combined_vesicles"][:] + + input_path = os.path.join(DATA_ROOT, ds_name, f"{fname}.h5") + assert os.path.exists(input_path), input_path + + # Load the compartment mask from the tomogram + with h5py.File(input_path, "r") as f: + compartments = f["/compartments/segment_from_3Dmodel_v2"][:] + mask = get_compartment_with_max_overlap(compartments, segmentation) + + # if more than half of the vesicles (approximation, its checking pixel and not label) would get filtered by mask it means the compartment seg didn't work and thus we won't use the mask + if np.sum(segmentation[mask == 0] > 0) > (0.5 * np.sum(segmentation > 0)): + print(f"using no mask for {tomo}") + else: + segmentation[mask == 0] = 0 + _, sizes = convert_segmentation_to_spheres( + segmentation, resolution=resolution, radius_factor=radius_factor, estimate_radius_2d=estimate_radius_2d + ) + + dataset_results[ds_name][category][fname] = sizes + + # Save each dataset's results into separate CSV files for CTRL and DKO tomograms + for ds_name, categories in dataset_results.items(): + for category, tomogram_data in categories.items(): + sorted_data = dict(sorted(tomogram_data.items())) # Sort by tomogram names + result_df = pd.DataFrame.from_dict(sorted_data, orient='index').transpose() + + output_path = os.path.join(RESULT_FOLDER, f"size_analysis_for_{ds_name}_{category}_rf1.csv") + + # Save the DataFrame to CSV + result_df.to_csv(output_path, index=False) + +def main(): + #compute_sizes_for_all_tomorams_manComp() + compute_sizes_for_all_tomorams_autoComp() + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/analysis/run_spatial_distribution_analysis.py b/scripts/cooper/analysis/run_spatial_distribution_analysis.py new file mode 100644 index 0000000..6943484 --- /dev/null +++ b/scripts/cooper/analysis/run_spatial_distribution_analysis.py @@ -0,0 +1,187 @@ +import os +from glob import glob +import pandas as pd +import h5py +from tqdm import tqdm +from synaptic_reconstruction.distance_measurements import measure_segmentation_to_object_distances +import numpy as np + +DATA_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/segmentation/for_spatial_distribution_analysis/final_Imig2014_seg/" # noqa +PREDICTION_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/segmentation/for_spatial_distribution_analysis/final_Imig2014_seg/" # noqa +RESULT_FOLDER = "./analysis_results/AZ_filtered_autoComp" +AZ_PATH = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/20241102_TOMO_DATA_Imig2014/az_seg_filtered" + + +def get_compartment_with_max_overlap(compartments, vesicles): + """ + Given 3D numpy arrays of 
compartments and vesicles, this function returns a binary mask + of the compartment with the most overlap with vesicles based on the number of overlapping voxels. + + Parameters: + compartments (numpy.ndarray): 3D array of compartment labels. + vesicles (numpy.ndarray): 3D array of vesicle labels or binary mask. + + Returns: + numpy.ndarray: Binary mask of the compartment with the most overlap with vesicles. + """ + + unique_compartments = np.unique(compartments) + if 0 in unique_compartments: + unique_compartments = unique_compartments[unique_compartments != 0] + + max_overlap_count = 0 + best_compartment = None + + # Iterate over each compartment and calculate the overlap with vesicles + for compartment_label in unique_compartments: + compartment_mask = compartments == compartment_label + vesicle_mask = vesicles > 0 + + intersection = np.logical_and(compartment_mask, vesicle_mask) + overlap_count = np.sum(intersection) + + # Track the compartment with the most overlap in terms of voxel count + if overlap_count > max_overlap_count: + max_overlap_count = overlap_count + best_compartment = compartment_label + + final_mask = compartments == best_compartment + + return final_mask + +# We compute the distances for all vesicles in the AUTOMATIC SEGMENTED compartment masks to the AZ. +# We use different resolution, depending on dataset. +# The closest distance is calculated, i.e., the closest point on the outer membrane of the vesicle to the AZ. +def compute_per_vesicle_distance_to_AZ_autoComp(separate_AZseg=False): + + os.makedirs(RESULT_FOLDER, exist_ok=True) + resolution = (1.554,) * 3 # Change for each dataset #1.554 for Munc and snap #0.8681 for 04 dataset + dataset_results = {} + tomograms = sorted(glob(os.path.join(PREDICTION_ROOT, "**/*.h5"), recursive=True)) + + for tomo in tqdm(tomograms): + ds_name, fname = os.path.split(tomo) + ds_name = os.path.split(ds_name)[1] + fname = os.path.splitext(fname)[0] + + # Determine if the tomogram is 'CTRL' or 'DKO' + category = "CTRL" if "CTRL" in fname else "DKO" + + if ds_name not in dataset_results: + dataset_results[ds_name] = {'CTRL': {}, 'DKO': {}} + + if fname in dataset_results[ds_name][category]: + continue + + + # Load the vesicle segmentation from the predictions + with h5py.File(tomo, "r") as f: + segmentation = f["/vesicles/segment_from_combined_vesicles"][:] + + #Check if AZ seg is stored in a different tomo or same + if separate_AZseg: + print(f"using AZ segmentation from {AZ_PATH}") + #Load the AZ segmentations + AZ_path = os.path.join(AZ_PATH, ds_name, f"{fname}.h5") + with h5py.File(AZ_path, "r") as f_AZ: + segmented_object = f_AZ["/thin_az"][:] + else: + segmented_object = f["/AZ/compartment_AZ_intersection"][:] + + #if AZ intersect is small, compartment seg didn't align with AZ so we use the normal AZ and not intersect + if (segmented_object == 0).all() or np.sum(segmented_object == 1) < 2000: + segmented_object = f["/AZ/segment_from_AZmodel_v3"][:] + + input_path = os.path.join(DATA_ROOT, ds_name, f"{fname}.h5") + assert os.path.exists(input_path), input_path + + # Load the compartment mask from the tomogram + with h5py.File(input_path, "r") as f: + compartments = f["/compartments/segment_from_3Dmodel_v2"][:] + mask = get_compartment_with_max_overlap(compartments, segmentation) + + #if more than half of the vesicles (approximation, its checking pixel and not label) would get filtered by mask it means the compartment seg didn't work and thus we won't use the mask + if np.sum(segmentation[mask == 0] > 0) > (0.5 * 
np.sum(segmentation > 0)): + print("using no mask") + else: + segmentation[mask == 0] = 0 + + distances, _, _, _ = measure_segmentation_to_object_distances( + segmentation, segmented_object=segmented_object, resolution=resolution + ) + + # Add distances to the dataset dictionary under the appropriate category + dataset_results[ds_name][category][fname] = distances + + # Save each dataset's results into separate CSV files for CTRL and DKO tomograms + for ds_name, categories in dataset_results.items(): + for category, tomogram_data in categories.items(): + sorted_data = dict(sorted(tomogram_data.items())) # Sort by tomogram names + result_df = pd.DataFrame.from_dict(sorted_data, orient='index').transpose() + output_path = os.path.join(RESULT_FOLDER, f"spatial_distribution_analysis_for_{ds_name}_{category}.csv") + + # Save the DataFrame to CSV + result_df.to_csv(output_path, index=False) + +# We compute the distances for all vesicles in the MANUALLY ANNOTATED compartment masks to the AZ. +# We use different resolution, depending on dataset. +# The closest distance is calculated, i.e., the closest point on the outer membrane of the vesicle to the AZ. +def compute_per_vesicle_distance_to_AZ_manComp(): + os.makedirs(RESULT_FOLDER, exist_ok=True) + + resolution = (1.554,) * 3 # Change for each dataset #1.554 for Munc and snap #0.8681 for 04 dataset + dataset_results = {} + tomograms = sorted(glob(os.path.join(PREDICTION_ROOT, "**/*.h5"), recursive=True)) + + for tomo in tqdm(tomograms): + ds_name, fname = os.path.split(tomo) + ds_name = os.path.split(ds_name)[1] + fname = os.path.splitext(fname)[0] + + # Determine if the tomogram is 'CTRL' or 'DKO' + category = "CTRL" if "CTRL" in fname else "DKO" + + if ds_name not in dataset_results: + dataset_results[ds_name] = {'CTRL': {}, 'DKO': {}} + + if fname in dataset_results[ds_name][category]: + continue + + # Load the vesicle segmentation from the predictions + with h5py.File(tomo, "r") as f: + segmentation = f["/vesicles/segment_from_combined_vesicles"][:] + segmented_object = f["/AZ/compartment_AZ_intersection_manComp"][:] + + input_path = os.path.join(DATA_ROOT, ds_name, f"{fname}.h5") + assert os.path.exists(input_path), input_path + + # Load the compartment mask from the tomogram + with h5py.File(input_path, "r") as f: + mask = f["/labels/compartment"][:] + + segmentation[mask == 0] = 0 + + distances, _, _, _ = measure_segmentation_to_object_distances( + segmentation, segmented_object=segmented_object, resolution=resolution + ) + + # Add distances to the dataset dictionary under the appropriate category + dataset_results[ds_name][category][fname] = distances + + # Save each dataset's results into separate CSV files for CTRL and DKO tomograms + for ds_name, categories in dataset_results.items(): + for category, tomogram_data in categories.items(): + sorted_data = dict(sorted(tomogram_data.items())) # Sort by tomogram names + result_df = pd.DataFrame.from_dict(sorted_data, orient='index').transpose() + output_path = os.path.join(RESULT_FOLDER, f"spatial_distribution_analysis_for_{ds_name}_{category}.csv") + + # Save the DataFrame to CSV + result_df.to_csv(output_path, index=False) +def main(): + compute_per_vesicle_distance_to_AZ_autoComp(separate_AZseg=False) + #compute_per_vesicle_distance_to_AZ_manComp() + + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/compartment_segmentation_h5.py b/scripts/cooper/compartment_segmentation_h5.py new file mode 100644 index 0000000..573ac48 --- /dev/null +++ 
b/scripts/cooper/compartment_segmentation_h5.py @@ -0,0 +1,116 @@ +import argparse +import h5py +import os +from pathlib import Path + +from tqdm import tqdm +from elf.io import open_file + +from synaptic_reconstruction.inference.compartments import segment_compartments +from synaptic_reconstruction.inference.util import parse_tiling + +def _require_output_folders(output_folder): + #seg_output = os.path.join(output_folder, "segmentations") + seg_output = output_folder + os.makedirs(seg_output, exist_ok=True) + return seg_output + +def get_volume(input_path): + + with open_file(input_path, "r") as f: + + # Try to automatically derive the key with the raw data. + keys = list(f.keys()) + if len(keys) == 1: + key = keys[0] + elif "data" in keys: + key = "data" + elif "raw" in keys: + key = "raw" + + input_volume = f[key][:] + return input_volume + +def run_compartment_segmentation(input_path, output_path, model_path, tile_shape, halo, key_label): + tiling = parse_tiling(tile_shape, halo) + print(f"using tiling {tiling}") + input = get_volume(input_path) + + segmentation, prediction = segment_compartments(input_volume=input, model_path=model_path, verbose=False, tiling=tiling, return_predictions=True, scale=[0.25, 0.25, 0.25],boundary_threshold=0.2, postprocess_segments=False) + + seg_output = _require_output_folders(output_path) + file_name = Path(input_path).stem + seg_path = os.path.join(seg_output, f"{file_name}.h5") + + #check + os.makedirs(Path(seg_path).parent, exist_ok=True) + + print(f"Saving results in {seg_path}") + with h5py.File(seg_path, "a") as f: + if "raw" in f: + print("raw image already saved") + else: + f.create_dataset("raw", data=input, compression="gzip") + + key=f"compartments/segment_from_{key_label}" + if key in f: + print("Skipping", input_path, "because", key, "exists") + else: + f.create_dataset(key, data=segmentation, compression="gzip") + f.create_dataset(f"compartment_pred_{key_label}/foreground", data = prediction, compression="gzip") + + + + +def segment_folder(args): + input_files = [] + for root, dirs, files in os.walk(args.input_path): + input_files.extend([ + os.path.join(root, name) for name in files if name.endswith(args.data_ext) + ]) + print(input_files) + pbar = tqdm(input_files, desc="Run segmentation") + for input_path in pbar: + run_compartment_segmentation(input_path, args.output_path, args.model_path, args.tile_shape, args.halo, args.key_label) + +def main(): + parser = argparse.ArgumentParser(description="Segment vesicles in EM tomograms.") + parser.add_argument( + "--input_path", "-i", required=True, + help="The filepath to the mrc file or the directory containing the tomogram data." + ) + parser.add_argument( + "--output_path", "-o", required=True, + help="The filepath to directory where the segmentations will be saved." + ) + parser.add_argument( + "--model_path", "-m", required=True, help="The filepath to the vesicle model." + ) + parser.add_argument( + "--tile_shape", type=int, nargs=3, + help="The tile shape for prediction. Lower the tile shape if GPU memory is insufficient." + ) + parser.add_argument( + "--halo", type=int, nargs=3, + help="The halo for prediction. Increase the halo to minimize boundary artifacts." + ) + parser.add_argument( + "--data_ext", "-d", default=".h5", help="The extension of the tomogram data. By default .h5." + ) + parser.add_argument( + "--key_label", "-k", default = "3Dmodel_v1", + help="Give the key name for saving the segmentation in h5." 
+ ) + args = parser.parse_args() + + input_ = args.input_path + + if os.path.isdir(input_): + segment_folder(args) + else: + run_compartment_segmentation(input_, args.output_path, args.model_path, args.tile_shape, args.halo, args.key_label) + + print("Finished segmenting!") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/cooper/export_mask_to_imod.py b/scripts/cooper/export_mask_to_imod.py index 98b4b2f..4273707 100644 --- a/scripts/cooper/export_mask_to_imod.py +++ b/scripts/cooper/export_mask_to_imod.py @@ -4,19 +4,11 @@ def export_mask_to_imod(args): - # Test script - # write_segmentation_to_imod( - # "synapse-examples/36859_J1_66K_TS_CA3_PS_26_rec_2Kb1dawbp_crop.mrc", - # "synapse-examples/36859_J1_66K_TS_CA3_PS_26_rec_2Kb1dawbp_crop_mitos.tif", - # "synapse-examples/mito.mod" - # ) write_segmentation_to_imod(args.input_path, args.segmentation_path, args.output_path) def main(): parser = argparse.ArgumentParser() - - args = parser.parse_args() parser.add_argument( "-i", "--input_path", required=True, help="The filepath to the mrc file containing the data." diff --git a/scripts/cooper/full_reconstruction/visualize_results.py b/scripts/cooper/full_reconstruction/visualize_results.py index 5e3f596..839626b 100644 --- a/scripts/cooper/full_reconstruction/visualize_results.py +++ b/scripts/cooper/full_reconstruction/visualize_results.py @@ -6,11 +6,14 @@ import numpy as np import pandas as pd +from skimage.filters import gaussian + ROOT = "./04_full_reconstruction" TABLE = "/home/pape/Desktop/sfb1286/mboc_synapse/draft_figures/full_reconstruction.xlsx" # Skip datasets for which all figures were already done. -SKIP_DS = ["20241019_Tomo-eval_MF_Synapse"] +SKIP_DS = ["20241019_Tomo-eval_MF_Synapse", "20241019_Tomo-eval_PS_Synapse"] +# SKIP_DS = [] def _get_name_and_row(path, table): @@ -46,13 +49,12 @@ def visualize_result(path, table): if ds_name in SKIP_DS: return - # if row["Use for vis"].values[0] == "yes": - if row["Use for vis"].values[0] in ("yes", "no"): + if row["Use for Vis"].values[0] == "no": return compartment_ids = _get_compartment_ids(row) # access = np.s_[:] - access = np.s_[::2, ::2, ::2] + access = np.s_[::3, ::3, ::3] with h5py.File(path, "r") as f: raw = f["raw"][access] @@ -60,6 +62,10 @@ def visualize_result(path, table): active_zone = f["labels/active_zone"][access] mitos = f["labels/mitochondria"][access] compartments = f["labels/compartments"][access] + print("Loading done") + + raw = gaussian(raw) + print("Gaussian done") if any(comp_ids is not None for comp_ids in compartment_ids): mask = np.zeros(raw.shape, dtype="bool") @@ -78,12 +84,14 @@ def visualize_result(path, table): mitos[~mask] = 0 compartments = compartments_new + vesicle_ids = np.unique(vesicles)[1:] + v = napari.Viewer() v.add_image(raw) v.add_labels(mitos) - v.add_labels(vesicles) - v.add_labels(compartments) - v.add_labels(active_zone) + v.add_labels(vesicles, colormap={ves_id: "orange" for ves_id in vesicle_ids}) + v.add_labels(compartments, colormap={1: "red", 2: "green", 3: "orange"}) + v.add_labels(active_zone, colormap={1: "blue"}) v.title = f"{ds_name}/{name}" napari.run() @@ -115,6 +123,7 @@ def main(): paths = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True)) table = pd.read_excel(TABLE) for path in paths: + print(path) visualize_result(path, table) # visualize_only_compartment(path, table) diff --git a/scripts/cooper/training/evaluate_AZ.py b/scripts/cooper/training/evaluate_AZ.py new file mode 100644 index 0000000..dbf8d67 --- /dev/null +++ 
b/scripts/cooper/training/evaluate_AZ.py @@ -0,0 +1,146 @@ +import argparse +import os + +import h5py +import pandas as pd +import numpy as np + +from elf.evaluation.dice import dice_score + +def extract_gt_bounding_box(segmentation, gt, halo=[20, 320, 320]): + # Find the bounding box for the ground truth + bb = np.where(gt > 0) + bb = tuple(slice( + max(int(b.min() - ha), 0), # Ensure indices are not below 0 + min(int(b.max() + ha), sh) # Ensure indices do not exceed shape dimensions + ) for b, sh, ha in zip(bb, gt.shape, halo)) + + # Apply the bounding box to both segmentations + segmentation_cropped = segmentation[bb] + gt_cropped = gt[bb] + + return segmentation_cropped, gt_cropped + +def evaluate(labels, segmentation): + assert labels.shape == segmentation.shape + score = dice_score(segmentation, labels) + return score + +def compute_precision(ground_truth, segmentation): + """ + Computes the Precision score for 3D arrays representing the ground truth and segmentation. + + Parameters: + - ground_truth (np.ndarray): 3D binary array where 1 represents the ground truth region. + - segmentation (np.ndarray): 3D binary array where 1 represents the predicted segmentation region. + + Returns: + - precision (float): The precision score, or 0 if the segmentation is empty. + """ + assert ground_truth.shape == segmentation.shape + # Ensure inputs are binary arrays + ground_truth = (ground_truth > 0).astype(np.int32) + segmentation = (segmentation > 0).astype(np.int32) + + # Compute intersection: overlap between segmentation and ground truth + intersection = np.sum(segmentation * ground_truth) + + # Compute total predicted (segmentation region) + total_predicted = np.sum(segmentation) + + # Handle case where there are no predictions + if total_predicted == 0: + return 0.0 # Precision is undefined; returning 0 + + # Calculate precision + precision = intersection / total_predicted + return precision + +def evaluate_file(labels_path, segmentation_path, model_name, crop= False, precision_score=False): + print(f"Evaluate labels {labels_path} and vesicles {segmentation_path}") + + ds_name = os.path.basename(os.path.dirname(labels_path)) + tomo = os.path.basename(labels_path) + + #get the labels and segmentation + with h5py.File(labels_path) as label_file: + gt = label_file["/labels/thin_az"][:] + + with h5py.File(segmentation_path) as seg_file: + segmentation = seg_file["/AZ/thin_az"][:] + + if crop: + print("cropping the annotation and segmentation") + segmentation, gt = extract_gt_bounding_box(segmentation, gt) + + # Evaluate the match of ground truth and segmentation + if precision_score: + precision = compute_precision(gt, segmentation) + else: + dice_score = evaluate(gt, segmentation) + + # Store results + result_folder = "/user/muth9/u12095/synaptic-reconstruction/scripts/cooper/evaluation_results" + os.makedirs(result_folder, exist_ok=True) + result_path = os.path.join(result_folder, f"evaluation_{model_name}_dice_thinpred_thinanno.csv") + print("Evaluation results are saved to:", result_path) + + # Load existing results if the file exists + if os.path.exists(result_path): + results = pd.read_csv(result_path) + else: + results = None + + # Create a new DataFrame for the current evaluation + if precision_score: + res = pd.DataFrame( + [[ds_name, tomo, precision]], columns=["dataset", "tomogram", "precision"] + ) + else: + res = pd.DataFrame( + [[ds_name, tomo, dice_score]], columns=["dataset", "tomogram", "dice_score"] + ) + + # Combine with existing results or initialize with the new results + 
if results is None: + results = res + else: + results = pd.concat([results, res]) + + # Save the results to the CSV file + results.to_csv(result_path, index=False) + +def evaluate_folder(labels_path, segmentation_path, model_name, crop = False, precision_score=False): + print(f"Evaluating folder {segmentation_path}") + print(f"Using labels stored in {labels_path}") + + label_files = os.listdir(labels_path) + vesicles_files = os.listdir(segmentation_path) + + for vesicle_file in vesicles_files: + if vesicle_file in label_files: + + evaluate_file(os.path.join(labels_path, vesicle_file), os.path.join(segmentation_path, vesicle_file), model_name, crop, precision_score) + + + +def main(): + + parser = argparse.ArgumentParser() + parser.add_argument("-l", "--labels_path", required=True) + parser.add_argument("-v", "--segmentation_path", required=True) + parser.add_argument("-n", "--model_name", required=True) + parser.add_argument("--crop", action="store_true", help="Crop around the annotation.") + parser.add_argument("--precision", action="store_true", help="Calculate precision score.") + args = parser.parse_args() + + segmentation_path = args.segmentation_path + if os.path.isdir(segmentation_path): + evaluate_folder(args.labels_path, segmentation_path, args.model_name, args.crop, args.precision) + else: + evaluate_file(args.labels_path, segmentation_path, args.model_name, args.crop, args.precision) + + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/training/evaluation.py b/scripts/cooper/training/evaluation.py index d7aaf6e..68fa863 100644 --- a/scripts/cooper/training/evaluation.py +++ b/scripts/cooper/training/evaluation.py @@ -21,7 +21,7 @@ def summarize_eval(results): table = summary.to_markdown(index=False) print(table) -def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key): +def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key, mask_key = None): print(f"Evaluate labels {labels_path} and vesicles {vesicles_path}") ds_name = os.path.basename(os.path.dirname(labels_path)) @@ -33,11 +33,16 @@ def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key) #vesicles = labels["vesicles"] gt = labels[anno_key][:] + if mask_key is not None: + mask = labels[mask_key][:] + with h5py.File(vesicles_path) as seg_file: segmentation = seg_file["vesicles"] vesicles = segmentation[segment_key][:] - + if mask_key is not None: + gt[mask == 0] = 0 + vesicles[mask == 0] = 0 #evaluate the match of ground truth and vesicles scores = evaluate(gt, vesicles) @@ -65,7 +70,7 @@ def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key) summarize_eval(results) -def evaluate_folder(labels_path, vesicles_path, model_name, segment_key, anno_key): +def evaluate_folder(labels_path, vesicles_path, model_name, segment_key, anno_key, mask_key = None): print(f"Evaluating folder {vesicles_path}") print(f"Using labels stored in {labels_path}") @@ -75,7 +80,7 @@ def evaluate_folder(labels_path, vesicles_path, model_name, segment_key, anno_ke for vesicle_file in vesicles_files: if vesicle_file in label_files: - evaluate_file(os.path.join(labels_path, vesicle_file), os.path.join(vesicles_path, vesicle_file), model_name, segment_key, anno_key) + evaluate_file(os.path.join(labels_path, vesicle_file), os.path.join(vesicles_path, vesicle_file), model_name, segment_key, anno_key, mask_key) @@ -87,13 +92,14 @@ def main(): parser.add_argument("-n", "--model_name", required=True) parser.add_argument("-sk", "--segment_key", 
required=True) parser.add_argument("-ak", "--anno_key", required=True) + parser.add_argument("-m", "--mask_key") args = parser.parse_args() vesicles_path = args.vesicles_path if os.path.isdir(vesicles_path): - evaluate_folder(args.labels_path, vesicles_path, args.model_name, args.segment_key, args.anno_key) + evaluate_folder(args.labels_path, vesicles_path, args.model_name, args.segment_key, args.anno_key, args.mask_key) else: - evaluate_file(args.labels_path, vesicles_path, args.model_name, args.segment_key, args.anno_key) + evaluate_file(args.labels_path, vesicles_path, args.model_name, args.segment_key, args.anno_key, args.mask_key) diff --git a/scripts/cooper/training/filter_AZ.py b/scripts/cooper/training/filter_AZ.py new file mode 100644 index 0000000..78b8ba7 --- /dev/null +++ b/scripts/cooper/training/filter_AZ.py @@ -0,0 +1,67 @@ +import os +import h5py +import numpy as np +from scipy.ndimage import binary_erosion, binary_dilation, label + +def process_labels(label_file_path, erosion_structure=None, dilation_structure=None): + """ + Process the labels: perform erosion, find the largest connected component, + and perform dilation on it. + + Args: + label_file_path (str): Path to the HDF5 file containing the label data. + erosion_structure (ndarray, optional): Structuring element for erosion. + dilation_structure (ndarray, optional): Structuring element for dilation. + + Returns: + None: The processed data is saved back into the HDF5 file under a new key. + """ + with h5py.File(label_file_path, "r+") as label_file: + # Read the ground truth data + gt = label_file["/labels/filtered_az"][:] + + # Perform binary erosion + eroded = binary_erosion(gt, structure=erosion_structure) + + # Label connected components + labeled_array, num_features = label(eroded) + + # Identify the largest connected component + if num_features > 0: + largest_component_label = np.argmax(np.bincount(labeled_array.flat, weights=eroded.flat)[1:]) + 1 + largest_component = (labeled_array == largest_component_label) + else: + largest_component = np.zeros_like(gt, dtype=bool) + + # Perform binary dilation on the largest connected component + dilated = binary_dilation(largest_component, structure=dilation_structure) + + # Save the result back into the HDF5 file + if "labels/erosion_filtered_az" in label_file: + del label_file["labels/erosion_filtered_az"] # Remove if it already exists + label_file.create_dataset("labels/erosion_filtered_az", data=dilated.astype(np.uint8), compression="gzip") + +def process_folder(folder_path, erosion_structure=None, dilation_structure=None): + """ + Process all HDF5 files in a folder. + + Args: + folder_path (str): Path to the folder containing HDF5 files. + erosion_structure (ndarray, optional): Structuring element for erosion. + dilation_structure (ndarray, optional): Structuring element for dilation. 
+ + Returns: + None + """ + for file_name in os.listdir(folder_path): + if file_name.endswith(".h5") or file_name.endswith(".hdf5"): + label_file_path = os.path.join(folder_path, file_name) + print(f"Processing {label_file_path}...") + process_labels(label_file_path, erosion_structure, dilation_structure) + +# Example usage +if __name__ == "__main__": + folder_path = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/training_AZ_v2/postprocessed_AZ/12_chemical_fix_cryopreparation" # Replace with the path to your folder + erosion_structure = np.ones((3, 3, 3)) # Example structuring element + dilation_structure = np.ones((3, 3, 3)) # Example structuring element + process_folder(folder_path, erosion_structure, dilation_structure) diff --git a/scripts/cooper/training/postprocess_AZ.py b/scripts/cooper/training/postprocess_AZ.py new file mode 100644 index 0000000..e2b849e --- /dev/null +++ b/scripts/cooper/training/postprocess_AZ.py @@ -0,0 +1,107 @@ +import os +from glob import glob +import argparse + +import h5py +import numpy as np +from tqdm import tqdm +from scipy.ndimage import binary_closing +from skimage.measure import label +from synaptic_reconstruction.ground_truth.shape_refinement import edge_filter +from synaptic_reconstruction.morphology import skeletonize_object + + + +def filter_az(path, output_path): + """Filter the active zone (AZ) data from the HDF5 file.""" + ds, fname = os.path.split(path) + dataset_name = os.path.basename(ds) + out_file_path = os.path.join(output_path, "postprocessed_AZ", dataset_name, fname) + + os.makedirs(os.path.dirname(out_file_path), exist_ok=True) + + if os.path.exists(out_file_path): + return + + with h5py.File(path, "r") as f: + raw = f["raw"][:] + az = f["AZ/segment_from_AZmodel_v3"][:] + + hmap = edge_filter(raw, sigma=1.0, method="sato", per_slice=True, n_threads=8) + + # Filter the active zone by combining a bunch of things: + # 1. Find a mask with high values in the ridge filter. + threshold_hmap = 0.5 + az_filtered = hmap > threshold_hmap + # 2. Intersect it with the active zone predictions. + az_filtered = np.logical_and(az_filtered, az) + + # Postprocessing of the filtered active zone: + # 1. Apply connected components and only keep the largest component. + az_filtered = label(az_filtered) + ids, sizes = np.unique(az_filtered, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + az_filtered = (az_filtered == ids[np.argmax(sizes)]).astype("uint8") + # 2. Apply binary closing. + az_filtered = np.logical_or(az_filtered, binary_closing(az_filtered, iterations=4)).astype("uint8") + + # Save the result. 
+ with h5py.File(out_file_path, "a") as f: + f.create_dataset("AZ/filtered_az", data=az_filtered, compression="gzip") + + +def process_az(path, view=False): + """Skeletonize the filtered AZ data to obtain a 1D representation.""" + key = "AZ/thin_az" + with h5py.File(path, "r") as f: + if key in f and not view: + return + az_seg = f["AZ/filtered_az"][:] + + az_thin = skeletonize_object(az_seg) + + if view: + import napari + ds, fname = os.path.split(path) + raw_path = os.path.join(ROOT, ds, fname) + with h5py.File(raw_path, "r") as f: + raw = f["raw"][:] + v = napari.Viewer() + v.add_image(raw) + v.add_labels(az_seg) + v.add_labels(az_thin) + napari.run() + else: + with h5py.File(path, "a") as f: + f.create_dataset(key, data=az_thin, compression="gzip") + + +def filter_all_azs(input_path, output_path): + """Apply filtering to all AZ data in the specified directory.""" + files = sorted(glob(os.path.join(input_path, "**/*.h5"), recursive=True)) + for ff in tqdm(files, desc="Filtering AZ segmentations"): + filter_az(ff, output_path) + + +def process_all_azs(output_path): + """Apply skeletonization to all filtered AZ data.""" + files = sorted(glob(os.path.join(output_path, "postprocessed_AZ", "**/*.h5"), recursive=True)) + for ff in tqdm(files, desc="Thinning AZ segmentations"): + process_az(ff, view=False) + + +def main(): + parser = argparse.ArgumentParser(description="Filter and process AZ data.") + parser.add_argument("input_path", type=str, help="Path to the root directory containing datasets.") + parser.add_argument("output_path", type=str, help="Path to the root directory for saving processed data.") + args = parser.parse_args() + + input_path = args.input_path + output_path = args.output_path + + filter_all_azs(input_path, output_path) + process_all_azs(output_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/training/train_AZ.py b/scripts/cooper/training/train_AZ.py index 1468eaf..9d7d283 100644 --- a/scripts/cooper/training/train_AZ.py +++ b/scripts/cooper/training/train_AZ.py @@ -12,7 +12,7 @@ from synaptic_reconstruction.training import semisupervised_training TRAIN_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects" -OUTPUT_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/training_AZ_v1" +OUTPUT_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/training_AZ_v2" def _require_train_val_test_split(datasets): @@ -80,8 +80,11 @@ def get_paths(split, datasets, testset=True): def train(key, ignore_label = None, training_2D = False, testset = True): + os.makedirs(OUTPUT_ROOT, exist_ok=True) + datasets = [ "01_hoi_maus_2020_incomplete", + "04_hoi_stem_examples", "06_hoi_wt_stem750_fm", "12_chemical_fix_cryopreparation" ] @@ -93,7 +96,7 @@ def train(key, ignore_label = None, training_2D = False, testset = True): print(len(val_paths), "tomograms for validation") patch_shape = [48, 256, 256] - model_name=f"3D-AZ-model-v1" + model_name=f"3D-AZ-model-v3" #checking for 2D training if training_2D: @@ -109,11 +112,11 @@ def train(key, ignore_label = None, training_2D = False, testset = True): val_paths=val_paths, label_key=f"/labels/{key}", patch_shape=patch_shape, batch_size=batch_size, - sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=1), + sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=1, p_reject = 0.95), n_samples_train=None, n_samples_val=25, check=check, save_root="/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/AZ_models", - 
n_iterations=int(5e3), + n_iterations=int(5e4), ignore_label= ignore_label, label_transform=torch_em.transform.label.labels_to_binary, out_channels = 1, diff --git a/scripts/cooper/vesicle_segmentation_h5.py b/scripts/cooper/vesicle_segmentation_h5.py index 9c8b1d1..1136f18 100644 --- a/scripts/cooper/vesicle_segmentation_h5.py +++ b/scripts/cooper/vesicle_segmentation_h5.py @@ -34,7 +34,7 @@ def get_volume(input_path): input_volume = f[key][:] return input_volume -def run_vesicle_segmentation(input_path, output_path, model_path, mask_path, mask_key,tile_shape, halo, include_boundary, key_label): +def run_vesicle_segmentation(input_path, output_path, model_path, mask_path, mask_key,tile_shape, halo, include_boundary, key_label, distance_threshold = None): tiling = parse_tiling(tile_shape, halo) print(f"using tiling {tiling}") input = get_volume(input_path) @@ -45,8 +45,17 @@ def run_vesicle_segmentation(input_path, output_path, model_path, mask_path, mas mask = f[mask_key][:] else: mask = None + if distance_threshold is not None: + segmentation, prediction = segment_vesicles( + input_volume=input, model_path=model_path, verbose=False, tiling=tiling, return_predictions=True, + exclude_boundary=not include_boundary, mask = mask, distance_threshold = distance_threshold + ) + else: + segmentation, prediction = segment_vesicles( + input_volume=input, model_path=model_path, verbose=False, tiling=tiling, return_predictions=True, + exclude_boundary=not include_boundary, mask = mask + ) - segmentation, prediction = segment_vesicles(input_volume=input, model_path=model_path, verbose=False, tiling=tiling, return_predictions=True, exclude_boundary=not include_boundary, mask = mask) foreground, boundaries = prediction[:2] seg_output = _require_output_folders(output_path) @@ -84,7 +93,7 @@ def segment_folder(args): input_files = [] for root, dirs, files in os.walk(args.input_path): input_files.extend([ - os.path.join(root, name) for name in files if name.endswith(".h5") + os.path.join(root, name) for name in files if name.endswith(args.data_ext) ]) print(input_files) pbar = tqdm(input_files, desc="Run segmentation") @@ -97,7 +106,10 @@ def segment_folder(args): print(f"Mask file not found for {input_path}") mask_path = None - run_vesicle_segmentation(input_path, args.output_path, args.model_path, mask_path, args.mask_key, args.tile_shape, args.halo, args.include_boundary, args.key_label) + run_vesicle_segmentation( + input_path, args.output_path, args.model_path, mask_path, args.mask_key, + args.tile_shape, args.halo, args.include_boundary, args.key_label, args.distance_threshold + ) def main(): parser = argparse.ArgumentParser(description="Segment vesicles in EM tomograms.") @@ -134,6 +146,14 @@ def main(): "--key_label", "-k", default = "combined_vesicles", help="Give the key name for saving the segmentation in h5." ) + parser.add_argument( + "--distance_threshold", "-t", type=int, + help="Used for distance based segmentation." + ) + parser.add_argument( + "--data_ext", "-d", default = ".h5", + help="Format extension of data to be segmented, default is .h5." 
+ ) args = parser.parse_args() input_ = args.input_path @@ -141,7 +161,7 @@ def main(): if os.path.isdir(input_): segment_folder(args) else: - run_vesicle_segmentation(input_, args.output_path, args.model_path, args.mask_path, args.mask_key, args.tile_shape, args.halo, args.include_boundary, args.key_label) + run_vesicle_segmentation(input_, args.output_path, args.model_path, args.mask_path, args.mask_key, args.tile_shape, args.halo, args.include_boundary, args.key_label, args.distance_threshold) print("Finished segmenting!") diff --git a/scripts/data_summary/active_zone_training_data.xlsx b/scripts/data_summary/active_zone_training_data.xlsx new file mode 100644 index 0000000..b193653 Binary files /dev/null and b/scripts/data_summary/active_zone_training_data.xlsx differ diff --git a/scripts/data_summary/compartment_training_data.xlsx b/scripts/data_summary/compartment_training_data.xlsx new file mode 100644 index 0000000..e141f0b Binary files /dev/null and b/scripts/data_summary/compartment_training_data.xlsx differ diff --git a/scripts/data_summary/vesicle_domain_adaptation_data.xlsx b/scripts/data_summary/vesicle_domain_adaptation_data.xlsx new file mode 100644 index 0000000..8a47219 Binary files /dev/null and b/scripts/data_summary/vesicle_domain_adaptation_data.xlsx differ diff --git a/scripts/data_summary/vesicle_training_data.xlsx b/scripts/data_summary/vesicle_training_data.xlsx new file mode 100644 index 0000000..0f9ee1e Binary files /dev/null and b/scripts/data_summary/vesicle_training_data.xlsx differ diff --git a/scripts/inner_ear/analysis/.gitignore b/scripts/inner_ear/analysis/.gitignore new file mode 100644 index 0000000..cbad005 --- /dev/null +++ b/scripts/inner_ear/analysis/.gitignore @@ -0,0 +1,3 @@ +panels/ +auto_seg_export/ +*.zip diff --git a/scripts/inner_ear/analysis/analyze_distances.py b/scripts/inner_ear/analysis/analyze_distances.py new file mode 100644 index 0000000..c98de9c --- /dev/null +++ b/scripts/inner_ear/analysis/analyze_distances.py @@ -0,0 +1,153 @@ +import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns + +from common import get_all_measurements, get_measurements_with_annotation + + +def _plot_all(distances): + pools = pd.unique(distances["pool"]) + dist_cols = ["ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + + fig, axes = plt.subplots(3, 3) + + # multiple = "stack" + multiple = "layer" + + structures = ["Ribbon", "PD", "Boundary"] + for i, pool in enumerate(pools): + pool_distances = distances[distances["pool"] == pool] + for j, dist_col in enumerate(dist_cols): + ax = axes[i, j] + ax.set_title(f"{pool} to {structures[j]}") + sns.histplot( + data=pool_distances, x=dist_col, hue="approach", multiple=multiple, kde=False, ax=ax + ) + ax.set_xlabel("distance [nm]") + + fig.tight_layout() + plt.show() + + +# We only care about the following distances: +# - MP-V -> PD, AZ (Boundary) +# - Docked-V -> PD, AZ +# - RA-V -> Ribbon +def _plot_selected(distances, save_path=None): + fig, axes = plt.subplots(2, 2) + multiple = "layer" + + if save_path is not None and os.path.exists(save_path): + os.remove(save_path) + + def _plot(pool_name, distance_col, structure_name, ax): + + this_distances = distances[distances["pool"] == pool_name][["tomogram", "approach", distance_col]] + + ax.set_title(f"{pool_name} to {structure_name}") + sns.histplot( + data=this_distances, x=distance_col, hue="approach", multiple=multiple, kde=False, ax=ax + ) + ax.set_xlabel("distance [nm]") + + if save_path 
is not None: + approaches = pd.unique(this_distances["approach"]) + tomo_names = pd.unique(this_distances["tomogram"]) + + tomograms = [] + distance_values = {approach: [] for approach in approaches} + + for tomo in tomo_names: + tomo_dists = this_distances[this_distances["tomogram"] == tomo] + max_vesicles = 0 + for approach in approaches: + n_vesicles = len(tomo_dists[tomo_dists["approach"] == approach].values) + if n_vesicles > max_vesicles: + max_vesicles = n_vesicles + + for approach in approaches: + app_dists = tomo_dists[tomo_dists["approach"] == approach][distance_col].values.tolist() + app_dists = app_dists + [np.nan] * (max_vesicles - len(app_dists)) + distance_values[approach].extend(app_dists) + tomograms.extend([tomo] * max_vesicles) + + save_distances = {"tomograms": tomograms} + save_distances.update(distance_values) + save_distances = pd.DataFrame(save_distances) + + sheet_name = f"{pool_name}_{structure_name}" + if os.path.exists(save_path): + with pd.ExcelWriter(save_path, engine="openpyxl", mode="a") as writer: + save_distances.to_excel(writer, sheet_name=sheet_name, index=False) + else: + save_distances.to_excel(save_path, index=False, sheet_name=sheet_name) + + # NOTE: we over-ride a plot here, should not do this in the actual version + _plot("MP-V", "pd_distance [nm]", "PD", axes[0, 0]) + _plot("MP-V", "boundary_distance [nm]", "AZ Membrane", axes[0, 1]) + _plot("Docked-V", "pd_distance [nm]", "PD", axes[1, 0]) + _plot("Docked-V", "boundary_distance [nm]", "AZ Membrane", axes[1, 0]) + _plot("RA-V", "ribbon_distance [nm]", "Ribbon", axes[1, 1]) + + fig.tight_layout() + plt.show() + + +def for_tomos_with_annotation(plot_all=True): + manual_assignments, semi_automatic_assignments, proofread_assignments = get_measurements_with_annotation() + + manual_distances = manual_assignments[ + ["tomogram", "pool", "ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + ] + manual_distances["approach"] = ["manual"] * len(manual_distances) + + semi_automatic_distances = semi_automatic_assignments[ + ["tomogram", "pool", "ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + ] + semi_automatic_distances["approach"] = ["semi_automatic"] * len(semi_automatic_distances) + + proofread_distances = proofread_assignments[ + ["tomogram", "pool", "ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + ] + proofread_distances["approach"] = ["proofread"] * len(proofread_distances) + + distances = pd.concat([manual_distances, semi_automatic_distances, proofread_distances]) + if plot_all: + distances.to_excel("./results/distances_tomos_with_manual_annotations.xlsx", index=False) + _plot_all(distances) + else: + _plot_selected(distances, save_path="./results/selected_distances_tomos_with_manual_annotations.xlsx") + + +def for_all_tomos(plot_all=True): + semi_automatic_assignments, proofread_assignments = get_all_measurements() + + semi_automatic_distances = semi_automatic_assignments[ + ["tomogram", "pool", "ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + ] + semi_automatic_distances["approach"] = ["semi_automatic"] * len(semi_automatic_distances) + + proofread_distances = proofread_assignments[ + ["tomogram", "pool", "ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + ] + proofread_distances["approach"] = ["proofread"] * len(proofread_distances) + + distances = pd.concat([semi_automatic_distances, proofread_distances]) + if plot_all: + distances.to_excel("./results/distances_all_tomos.xlsx", index=False) + 
_plot_all(distances) + else: + _plot_selected(distances, save_path="./results/selected_distances_all_tomos.xlsx") + + +def main(): + plot_all = False + for_tomos_with_annotation(plot_all=plot_all) + for_all_tomos(plot_all=plot_all) + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/analysis/analyze_vesicle_diameters.py b/scripts/inner_ear/analysis/analyze_vesicle_diameters.py new file mode 100644 index 0000000..1f0b3a0 --- /dev/null +++ b/scripts/inner_ear/analysis/analyze_vesicle_diameters.py @@ -0,0 +1,178 @@ +import os +import sys + +from glob import glob + +import mrcfile +import pandas as pd +from tqdm import tqdm + +from synaptic_reconstruction.imod.export import load_points_from_imodinfo +from synaptic_reconstruction.file_utils import get_data_path + +from common import get_finished_tomos + +sys.path.append("../processing") + + +def aggregate_diameters(data_root, table, save_path, get_tab, include_names, sheet_name): + radius_table = [] + for _, row in tqdm(table.iterrows(), total=len(table), desc="Collect tomo information"): + folder = row["Local Path"] + if folder == "": + continue + + tomo_name = os.path.relpath(folder, os.path.join(data_root, "Electron-Microscopy-Susi/Analyse")) + if ( + tomo_name in ("WT strong stim/Mouse 1/modiolar/1", "WT strong stim/Mouse 1/modiolar/2") and + (row["EM alt vs. Neu"] == "neu") + ): + continue + if tomo_name not in include_names: + continue + + tab_path = get_tab(folder) + if tab_path is None: + continue + + tab = pd.read_excel(tab_path) + this_tab = tab[["pool", "radius [nm]"]] + this_tab.insert(0, "tomogram", [tomo_name] * len(this_tab)) + this_tab.insert(3, "diameter [nm]", this_tab["radius [nm]"] * 2) + radius_table.append(this_tab) + + radius_table = pd.concat(radius_table) + + print("Saving table for", len(radius_table), "vesicles to", save_path, sheet_name) + if os.path.exists(save_path): + with pd.ExcelWriter(save_path, engine="openpyxl", mode="a") as writer: + radius_table.to_excel(writer, sheet_name=sheet_name, index=False) + else: + radius_table.to_excel(save_path, sheet_name=sheet_name, index=False) + + +def aggregate_diameters_imod(data_root, table, save_path, include_names, sheet_name): + radius_table = [] + for _, row in tqdm(table.iterrows(), total=len(table), desc="Collect tomo information"): + folder = row["Local Path"] + if folder == "": + continue + + tomo_name = os.path.relpath(folder, os.path.join(data_root, "Electron-Microscopy-Susi/Analyse")) + tomo_name = os.path.relpath(folder, os.path.join(data_root, "Electron-Microscopy-Susi/Analyse")) + if ( + tomo_name in ("WT strong stim/Mouse 1/modiolar/1", "WT strong stim/Mouse 1/modiolar/2") and + (row["EM alt vs. 
Neu"] == "neu") + ): + continue + if tomo_name not in include_names: + continue + + annotation_folder = os.path.join(folder, "manuell") + if not os.path.exists(annotation_folder): + annotation_folder = os.path.join(folder, "Manuell") + if not os.path.exists(annotation_folder): + continue + + annotations = glob(os.path.join(annotation_folder, "*.mod")) + annotation_file = [ann for ann in annotations if ("vesikel" in ann.lower()) or ("vesicle" in ann.lower())] + if len(annotation_file) != 1: + continue + annotation_file = annotation_file[0] + + tomo_file = get_data_path(folder) + with mrcfile.open(tomo_file) as f: + shape = f.data.shape + resolution = list(f.voxel_size.item()) + resolution = [res / 10 for res in resolution][0] + + try: + _, radii, labels, label_names = load_points_from_imodinfo(annotation_file, shape, resolution=resolution) + except AssertionError: + continue + + this_tab = pd.DataFrame({ + "tomogram": [tomo_name] * len(radii), + "pool": [label_names[label_id] for label_id in labels], + "radius [nm]": radii, + "diameter [nm]": 2 * radii, + }) + radius_table.append(this_tab) + + radius_table = pd.concat(radius_table) + print("Saving table for", len(radius_table), "vesicles to", save_path, sheet_name) + radius_table.to_excel(save_path, index=False, sheet_name=sheet_name) + + man_tomos = pd.unique(radius_table.tomogram) + return man_tomos + + +def get_tab_semi_automatic(folder): + tab_name = "measurements_uncorrected_assignments.xlsx" + res_path = os.path.join(folder, "korrektur", tab_name) + if not os.path.exists(res_path): + res_path = os.path.join(folder, "Korrektur", tab_name) + if not os.path.exists(res_path): + res_path = None + return res_path + + +def get_tab_proofread(folder): + tab_name = "measurements.xlsx" + res_path = os.path.join(folder, "korrektur", tab_name) + if not os.path.exists(res_path): + res_path = os.path.join(folder, "Korrektur", tab_name) + if not os.path.exists(res_path): + res_path = None + return res_path + + +def get_tab_manual(folder): + tab_name = "measurements.xlsx" + res_path = os.path.join(folder, "manuell", tab_name) + if not os.path.exists(res_path): + res_path = os.path.join(folder, "Manuell", tab_name) + if not os.path.exists(res_path): + res_path = None + return res_path + + +def main(): + from parse_table import parse_table, get_data_root + + data_root = get_data_root() + table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Übersicht.xlsx") + table = parse_table(table_path, data_root) + + all_tomos = get_finished_tomos() + + print("All tomograms") + save_path = "./results/vesicle_diameters_all_tomos.xlsx" + aggregate_diameters( + data_root, table, save_path=save_path, get_tab=get_tab_semi_automatic, include_names=all_tomos, + sheet_name="Semi-automatic", + ) + aggregate_diameters( + data_root, table, save_path=save_path, get_tab=get_tab_proofread, include_names=all_tomos, + sheet_name="Proofread", + ) + + print() + print("Tomograms with manual annotations") + # aggregate_diameters(data_root, table, save_path="./results/vesicle_radii_manual.xlsx", get_tab=get_tab_manual) + save_path = "./results/vesicle_diameters_tomos_with_manual_annotations.xlsx" + man_tomos = aggregate_diameters_imod( + data_root, table, save_path=save_path, include_names=all_tomos, sheet_name="Manual", + ) + aggregate_diameters( + data_root, table, save_path=save_path, get_tab=get_tab_semi_automatic, include_names=man_tomos, + sheet_name="Semi-automatic", + ) + aggregate_diameters( + data_root, table, save_path=save_path, get_tab=get_tab_proofread, 
include_names=man_tomos, + sheet_name="Proofread", + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/analysis/analyze_vesicle_pools.py b/scripts/inner_ear/analysis/analyze_vesicle_pools.py new file mode 100644 index 0000000..f27a5c2 --- /dev/null +++ b/scripts/inner_ear/analysis/analyze_vesicle_pools.py @@ -0,0 +1,103 @@ +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + +from common import get_all_measurements, get_measurements_with_annotation + + +def plot_pools(data, errors): + data_for_plot = pd.melt(data, id_vars="Pool", var_name="Method", value_name="Measurement") + + # Plot using seaborn + plt.figure(figsize=(8, 6)) + sns.barplot(data=data_for_plot, x="Pool", y="Measurement", hue="Method") + + # FIXME + # error_for_plot = pd.melt(errors, id_vars="Pool", var_name="Method", value_name="Error") + # # Add error bars manually + # for i, bar in enumerate(plt.gca().patches): + # # Get Standard Deviation for the current bar + # err = error_for_plot.iloc[i % len(error_for_plot)]["Error"] + # bar_x = bar.get_x() + bar.get_width() / 2 + # bar_y = bar.get_height() + # plt.errorbar(bar_x, bar_y, yerr=err, fmt="none", c="black", capsize=4) + + # Customize the chart + plt.title("Different measurements for vesicles per pool") + plt.xlabel("Vesicle Pools") + plt.ylabel("Vesicles per Tomogram") + plt.grid(axis="y", linestyle="--", alpha=0.7) + plt.legend(title="Approaches") + + # Show the plot + plt.tight_layout() + plt.show() + + +def for_tomos_with_annotation(): + manual_assignments, semi_automatic_assignments, proofread_assignments = get_measurements_with_annotation() + + manual_counts = manual_assignments.groupby(["tomogram", "pool"]).size().unstack(fill_value=0) + semi_automatic_counts = semi_automatic_assignments.groupby(["tomogram", "pool"]).size().unstack(fill_value=0) + proofread_counts = proofread_assignments.groupby(["tomogram", "pool"]).size().unstack(fill_value=0) + + manual_stats = manual_counts.agg(["mean", "std"]).transpose().reset_index() + semi_automatic_stats = semi_automatic_counts.agg(["mean", "std"]).transpose().reset_index() + proofread_stats = proofread_counts.agg(["mean", "std"]).transpose().reset_index() + + data = pd.DataFrame({ + "Pool": manual_stats["pool"], + "Semi-automatic": semi_automatic_stats["mean"], + "Proofread": proofread_stats["mean"], + "Manual": manual_stats["mean"], + }) + errors = pd.DataFrame({ + "Pool": manual_stats["pool"], + "Semi-automatic": semi_automatic_stats["std"], + "Proofread": proofread_stats["std"], + "Manual": manual_stats["std"], + }) + + plot_pools(data, errors) + + output_path = "./results/vesicle_pools_tomos_with_manual_annotations.xlsx" + data.to_excel(output_path, index=False, sheet_name="Average") + with pd.ExcelWriter(output_path, engine="openpyxl", mode="a") as writer: + errors.to_excel(writer, sheet_name="StandardDeviation", index=False) + + +def for_all_tomos(): + semi_automatic_assignments, proofread_assignments = get_all_measurements() + + proofread_counts = proofread_assignments.groupby(["tomogram", "pool"]).size().unstack(fill_value=0) + proofread_stats = proofread_counts.agg(["mean", "std"]).transpose().reset_index() + + semi_automatic_counts = semi_automatic_assignments.groupby(["tomogram", "pool"]).size().unstack(fill_value=0) + semi_automatic_stats = semi_automatic_counts.agg(["mean", "std"]).transpose().reset_index() + + data = pd.DataFrame({ + "Pool": proofread_stats["pool"], + "Semi-automatic": semi_automatic_stats["mean"], + "Proofread": proofread_stats["mean"], + }) 
+ errors = pd.DataFrame({ + "Pool": proofread_stats["pool"], + "Semi-automatic": semi_automatic_stats["std"], + "Proofread": proofread_stats["std"], + }) + + plot_pools(data, errors) + + output_path = "./results/vesicle_pools_all_tomos.xlsx" + data.to_excel(output_path, index=False, sheet_name="Average") + with pd.ExcelWriter(output_path, engine="openpyxl", mode="a") as writer: + errors.to_excel(writer, sheet_name="StandardDeviation", index=False) + + +def main(): + for_tomos_with_annotation() + for_all_tomos() + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/analysis/combine_fully_automatic_results.py b/scripts/inner_ear/analysis/combine_fully_automatic_results.py new file mode 100644 index 0000000..54bdbc1 --- /dev/null +++ b/scripts/inner_ear/analysis/combine_fully_automatic_results.py @@ -0,0 +1,69 @@ +import os +import sys + +import pandas as pd + +sys.path.append("..") +sys.path.append("../processing") + + +def combine_fully_auto_results(table, data_root, output_path): + from combine_measurements import combine_results + + val_table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Validierungs-Tabelle-v3.xlsx") + val_table = pd.read_excel(val_table_path) + + results = {} + for _, row in table.iterrows(): + folder = row["Local Path"] + if folder == "": + continue + + row_selection = (val_table.Bedingung == row.Bedingung) &\ + (val_table.Maus == row.Maus) &\ + (val_table["Ribbon-Orientierung"] == row["Ribbon-Orientierung"]) &\ + (val_table["OwnCloud-Unterordner"] == row["OwnCloud-Unterordner"]) + complete_vals = val_table[row_selection]["Fertig!"].values + is_complete = (complete_vals == "ja").all() + if not is_complete: + continue + + micro = row["EM alt vs. Neu"] + + tomo_name = os.path.relpath(folder, os.path.join(data_root, "Electron-Microscopy-Susi/Analyse")) + tab_name = "measurements_uncorrected_assignments.xlsx" + res_path = os.path.join(folder, "korrektur", tab_name) + if not os.path.exists(res_path): + res_path = os.path.join(folder, "Korrektur", tab_name) + assert os.path.exists(res_path), res_path + results[tomo_name] = (res_path, "alt" if micro == "beides" else micro) + + if micro == "beides": + micro = "neu" + + new_root = os.path.join(folder, "neues EM") + if not os.path.exists(new_root): + new_root = os.path.join(folder, "Tomo neues EM") + assert os.path.exists(new_root) + + res_path = os.path.join(new_root, "korrektur", "measurements.xlsx") + if not os.path.exists(res_path): + res_path = os.path.join(new_root, "Korrektur", "measurements.xlsx") + assert os.path.exists(res_path), res_path + results[tomo_name] = (res_path, "alt" if micro == "beides" else micro) + + combine_results(results, output_path, sheet_name="vesicles") + + +def main(): + from parse_table import parse_table, get_data_root + + data_root = get_data_root() + table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Übersicht.xlsx") + table = parse_table(table_path, data_root) + + res_path = "../results/fully_automatic_analysis_results.xlsx" + combine_fully_auto_results(table, data_root, output_path=res_path) + + +main() diff --git a/scripts/inner_ear/analysis/common.py b/scripts/inner_ear/analysis/common.py new file mode 100644 index 0000000..772cd31 --- /dev/null +++ b/scripts/inner_ear/analysis/common.py @@ -0,0 +1,88 @@ +# import os +import sys + +import numpy as np +import pandas as pd + +sys.path.append("../processing") + +from parse_table import get_data_root # noqa + + +def get_finished_tomos(): + # data_root = get_data_root() + # val_table = 
os.path.join(data_root, "Electron-Microscopy-Susi", "Validierungs-Tabelle-v3.xlsx") + + val_table = "/home/pape/Desktop/sfb1286/mboc_synapse/misc/Validierungs-Tabelle-v3-passt.xlsx" + val_table = pd.read_excel(val_table) + + val_table = val_table[val_table["Kommentar 22.11.24"] == "passt"] + n_tomos = len(val_table) + assert n_tomos > 0 + + tomo_names = [] + for _, row in val_table.iterrows(): + name = "/".join([ + row.Bedingung, f"Mouse {int(row.Maus)}", + row["Ribbon-Orientierung"].lower().rstrip("?"), + str(int(row["OwnCloud-Unterordner"]))] + ) + tomo_names.append(name) + + return tomo_names + + +def get_manual_assignments(): + result_path = "../results/20241124_1/fully_manual_analysis_results.xlsx" + results = pd.read_excel(result_path) + return results + + +def get_proofread_assignments(tomograms): + result_path = "../results/20241124_1/automatic_analysis_results.xlsx" + results = pd.read_excel(result_path) + results = results[results["tomogram"].isin(tomograms)] + return results + + +def get_semi_automatic_assignments(tomograms): + result_path = "../results/fully_automatic_analysis_results.xlsx" + results = pd.read_excel(result_path) + results = results[results["tomogram"].isin(tomograms)] + return results + + +def get_measurements_with_annotation(): + manual_assignments = get_manual_assignments() + + # Get the tomos with manual annotations and the ones which are fully done in proofreading. + manual_tomos = pd.unique(manual_assignments["tomogram"]) + finished_tomos = get_finished_tomos() + # Intersect them to get the tomos we are using. + tomos = np.intersect1d(manual_tomos, finished_tomos) + + manual_assignments = manual_assignments[manual_assignments["tomogram"].isin(tomos)] + semi_automatic_assignments = get_semi_automatic_assignments(tomos) + proofread_assignments = get_proofread_assignments(tomos) + + print("Tomograms with manual annotations:", len(tomos)) + return manual_assignments, semi_automatic_assignments, proofread_assignments + + +def get_all_measurements(): + tomos = get_finished_tomos() + print("All tomograms:", len(tomos)) + + semi_automatic_assignments = get_semi_automatic_assignments(tomos) + proofread_assignments = get_proofread_assignments(tomos) + + return semi_automatic_assignments, proofread_assignments + + +def main(): + get_measurements_with_annotation() + get_all_measurements() + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/analysis/export_seg_to_imod.py b/scripts/inner_ear/analysis/export_seg_to_imod.py new file mode 100644 index 0000000..eea4b14 --- /dev/null +++ b/scripts/inner_ear/analysis/export_seg_to_imod.py @@ -0,0 +1,128 @@ +import os +from shutil import copyfile +from subprocess import run + +import imageio.v3 as imageio +import mrcfile +import napari +import numpy as np +import pandas as pd +from elf.io import open_file +from skimage.transform import resize +from synaptic_reconstruction.imod.to_imod import write_segmentation_to_imod, write_segmentation_to_imod_as_points + +out_folder = "./auto_seg_export" +os.makedirs(out_folder, exist_ok=True) + + +def _resize(seg, tomo_path): + with open_file(tomo_path, "r") as f: + shape = f["data"].shape + + if shape != seg.shape: + seg = resize(seg, shape, order=0, anti_aliasing=False, preserve_range=True).astype(seg.dtype) + assert seg.shape == shape + return seg + + +def check_imod(tomo_path, mod_path): + run(["imod", tomo_path, mod_path]) + + +def export_pool(pool_name, pool_seg, tomo_path): + seg_path = f"./auto_seg_export/{pool_name}.tif" + pool_seg = _resize(pool_seg, 
tomo_path) + imageio.imwrite(seg_path, pool_seg, compression="zlib") + + output_path = f"./auto_seg_export/{pool_name}.mod" + write_segmentation_to_imod_as_points(tomo_path, seg_path, output_path, min_radius=5) + + check_imod(tomo_path, output_path) + + +def export_vesicles(folder, tomo_path): + vesicle_pool_path = os.path.join(folder, "Korrektur", "vesicle_pools.tif") + # pool_correction_path = os.path.join(folder, "Korrektur", "pool_correction.tif") + # pool_correction = imageio.imread(pool_correction_path) + + assignment_path = os.path.join(folder, "Korrektur", "measurements.xlsx") + assignments = pd.read_excel(assignment_path) + + vesicles = imageio.imread(vesicle_pool_path) + + pools = {} + for pool_name in pd.unique(assignments.pool): + pool_ids = assignments[assignments.pool == pool_name].id.values + pool_seg = vesicles.copy() + pool_seg[~np.isin(vesicles, pool_ids)] = 0 + pools[pool_name] = pool_seg + + view = False + if view: + v = napari.Viewer() + v.add_labels(vesicles, visible=False) + for pool_name, pool_seg in pools.items(): + v.add_labels(pool_seg, name=pool_name) + napari.run() + else: + for pool_name, pool_seg in pools.items(): + export_pool(pool_name, pool_seg, tomo_path) + + +def export_structure(folder, tomo, name, view=False): + path = os.path.join(folder, "Korrektur", f"{name}.tif") + seg = imageio.imread(path) + seg = _resize(seg, tomo) + + if view: + with open_file(tomo, "r") as f: + raw = f["data"][:] + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(seg) + napari.run() + + return + + seg_path = f"./auto_seg_export/{name}.tif" + imageio.imwrite(seg_path, seg, compression="zlib") + output_path = f"./auto_seg_export/{name}.mod" + write_segmentation_to_imod(tomo, seg_path, output_path) + check_imod(tomo, output_path) + + +def remove_scale(tomo): + new_path = "./auto_seg_export/Emb71M1aGridA1sec1mod7.rec.rec" + if os.path.exists(new_path): + return new_path + + copyfile(tomo, new_path) + + with mrcfile.open(new_path, "r+") as f: + # Set the origin to (0, 0, 0) + f.header.nxstart = 0 + f.header.nystart = 0 + f.header.nzstart = 0 + f.header.origin = (0.0, 0.0, 0.0) + + # Save changes + f.flush() + + return new_path + + +def main(): + folder = "/home/pape/Work/data/moser/em-synapses/Electron-Microscopy-Susi/Analyse/WT strong stim/Mouse 1/modiolar/1" + tomo = os.path.join(folder, "Emb71M1aGridA1sec1mod7.rec.rec") + + tomo = remove_scale(tomo) + + # export_vesicles(folder, tomo) + # export_structure(folder, tomo, "ribbon", view=False) + # export_structure(folder, tomo, "membrane", view=False) + export_structure(folder, tomo, "PD", view=False) + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/analysis/extract_ribbon_stats.py b/scripts/inner_ear/analysis/extract_ribbon_stats.py new file mode 100644 index 0000000..8ee9e12 --- /dev/null +++ b/scripts/inner_ear/analysis/extract_ribbon_stats.py @@ -0,0 +1,36 @@ +import numpy as np +import pandas as pd + + +def main(): + man_path = "../results/20240917_1/fully_manual_analysis_results.xlsx" + auto_path = "../results/20240917_1/automatic_analysis_results.xlsx" + + man_measurements = pd.read_excel(man_path, sheet_name="morphology") + man_measurements = man_measurements[man_measurements.structure == "ribbon"][ + ["tomogram", "surface [nm^2]", "volume [nm^3]"] + ] + + auto_measurements = pd.read_excel(auto_path, sheet_name="morphology") + auto_measurements = auto_measurements[auto_measurements.structure == "ribbon"][ + ["tomogram", "surface [nm^2]", "volume [nm^3]"] + ] + + # save all the automatic 
measurements + auto_measurements.to_excel("./results/ribbon_morphology_auto.xlsx", index=False) + + man_tomograms = pd.unique(man_measurements["tomogram"]) + auto_tomograms = pd.unique(auto_measurements["tomogram"]) + tomos = np.intersect1d(man_tomograms, auto_tomograms) + + man_measurements = man_measurements[man_measurements.tomogram.isin(tomos)] + auto_measurements = auto_measurements[auto_measurements.tomogram.isin(tomos)] + + save_path = "./results/ribbon_morphology_man-v-auto.xlsx" + man_measurements.to_excel(save_path, sheet_name="manual", index=False) + with pd.ExcelWriter(save_path, engine="openpyxl", mode="a") as writer: + auto_measurements.to_excel(writer, sheet_name="auto", index=False) + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/processing/run_analyis.py b/scripts/inner_ear/processing/run_analyis.py index baeade1..ca5ea0b 100644 --- a/scripts/inner_ear/processing/run_analyis.py +++ b/scripts/inner_ear/processing/run_analyis.py @@ -52,7 +52,7 @@ def _load_segmentation(seg_path, tomo_shape): return seg -def compute_distances(segmentation_paths, save_folder, resolution, force, tomo_shape): +def compute_distances(segmentation_paths, save_folder, resolution, force, tomo_shape, use_corrected_vesicles=True): os.makedirs(save_folder, exist_ok=True) vesicles = None @@ -61,9 +61,10 @@ def _require_vesicles(): vesicle_path = segmentation_paths["vesicles"] if vesicles is None: - vesicle_pool_path = os.path.join(os.path.split(save_folder)[0], "vesicle_pools.tif") - if os.path.exists(vesicle_pool_path): - vesicle_path = vesicle_pool_path + if use_corrected_vesicles: + vesicle_pool_path = os.path.join(os.path.split(save_folder)[0], "vesicle_pools.tif") + if os.path.exists(vesicle_pool_path): + vesicle_path = vesicle_pool_path return _load_segmentation(vesicle_path, tomo_shape) else: @@ -334,8 +335,7 @@ def _insert_missing_vesicles(vesicle_path, original_vesicle_path, pool_correctio imageio.imwrite(vesicle_path, vesicles) -# TODO adapt to segmentation without PD -def analyze_folder(folder, version, n_ribbons, force): +def analyze_folder(folder, version, n_ribbons, force, use_corrected_vesicles): data_path = get_data_path(folder) output_folder = os.path.join(folder, "automatisch", f"v{version}") @@ -352,12 +352,20 @@ def analyze_folder(folder, version, n_ribbons, force): correction_folder = _match_correction_folder(folder) if os.path.exists(correction_folder): output_folder = correction_folder - result_path = os.path.join(output_folder, "measurements.xlsx") + + if use_corrected_vesicles: + result_path = os.path.join(output_folder, "measurements.xlsx") + else: + result_path = os.path.join(output_folder, "measurements_uncorrected_assignments.xlsx") + if os.path.exists(result_path) and not force: return print("Analyse the corrected segmentations from", correction_folder) for seg_name in segmentation_names: + if seg_name == "vesicles" and not use_corrected_vesicles: + continue + seg_path = _match_correction_file(correction_folder, seg_name) if os.path.exists(seg_path): @@ -371,7 +379,10 @@ def analyze_folder(folder, version, n_ribbons, force): segmentation_paths[seg_name] = seg_path - result_path = os.path.join(output_folder, "measurements.xlsx") + if use_corrected_vesicles: + result_path = os.path.join(output_folder, "measurements.xlsx") + else: + result_path = os.path.join(output_folder, "measurements_uncorrected_assignments.xlsx") if os.path.exists(result_path) and not force: return @@ -384,21 +395,29 @@ def analyze_folder(folder, version, n_ribbons, force): with 
open_file(data_path, "r") as f: tomo_shape = f["data"].shape - out_distance_folder = os.path.join(output_folder, "distances") + if use_corrected_vesicles: + out_distance_folder = os.path.join(output_folder, "distances") + else: + out_distance_folder = os.path.join(output_folder, "distances_uncorrected") distance_paths, skip = compute_distances( segmentation_paths, out_distance_folder, resolution, force=force, tomo_shape=tomo_shape, + use_corrected_vesicles=use_corrected_vesicles ) if skip: return if force or not os.path.exists(result_path): + + if not use_corrected_vesicles: + pool_correction_path = None + analyze_distances( segmentation_paths, distance_paths, resolution, result_path, tomo_shape, pool_correction_path=pool_correction_path ) -def run_analysis(table, version, force=False, val_table=None): +def run_analysis(table, version, force=False, val_table=None, use_corrected_vesicles=True): for i, row in tqdm(table.iterrows(), total=len(table)): folder = row["Local Path"] if folder == "": @@ -426,19 +445,19 @@ def run_analysis(table, version, force=False, val_table=None): micro = row["EM alt vs. Neu"] if micro == "beides": - analyze_folder(folder, version, n_ribbons, force=force) + analyze_folder(folder, version, n_ribbons, force=force, use_corrected_vesicles=use_corrected_vesicles) folder_new = os.path.join(folder, "Tomo neues EM") if not os.path.exists(folder_new): folder_new = os.path.join(folder, "neues EM") assert os.path.exists(folder_new), folder_new - analyze_folder(folder_new, version, n_ribbons, force=force) + analyze_folder(folder_new, version, n_ribbons, force=force, use_corrected_vesicles=use_corrected_vesicles) elif micro == "alt": - analyze_folder(folder, version, n_ribbons, force=force) + analyze_folder(folder, version, n_ribbons, force=force, use_corrected_vesicles=use_corrected_vesicles) elif micro == "neu": - analyze_folder(folder, version, n_ribbons, force=force) + analyze_folder(folder, version, n_ribbons, force=force, use_corrected_vesicles=use_corrected_vesicles) def main(): @@ -447,13 +466,16 @@ def main(): table = parse_table(table_path, data_root) version = 2 - force = True + force = False + use_corrected_vesicles = False - val_table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Validierungs-Tabelle-v3.xlsx") - val_table = pandas.read_excel(val_table_path) - # val_table = None + # val_table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Validierungs-Tabelle-v3.xlsx") + # val_table = pandas.read_excel(val_table_path) + val_table = None - run_analysis(table, version, force=force, val_table=val_table) + run_analysis( + table, version, force=force, val_table=val_table, use_corrected_vesicles=use_corrected_vesicles + ) if __name__ == "__main__": diff --git a/scripts/inner_ear/training/postprocessing_and_evaluation.py b/scripts/inner_ear/training/postprocessing_and_evaluation.py index 30c9e42..30c1313 100644 --- a/scripts/inner_ear/training/postprocessing_and_evaluation.py +++ b/scripts/inner_ear/training/postprocessing_and_evaluation.py @@ -13,8 +13,8 @@ from train_structure_segmentation import get_train_val_test_split -ROOT = "/home/pape/Work/data/synaptic_reconstruction/moser" -# ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser" +# ROOT = "/home/pape/Work/data/synaptic_reconstruction/moser" +ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser" MODEL_PATH = "/mnt/lustre-emmy-hdd/projects/nim00007/models/synaptic-reconstruction/vesicle-DA-inner_ear-v2" OUTPUT_ROOT = 
"./predictions" @@ -187,8 +187,8 @@ def segment_train_domain(): name = "train_domain" run_vesicle_segmentation(paths, MODEL_PATH, name, is_nested=True) postprocess_structures(paths, name, is_nested=True) - visualize(paths, name, is_nested=True) - results = evaluate(paths, name, is_nested=True, save_path="./results/train_domain_postprocessed.csv") + # visualize(paths, name, is_nested=True) + results = evaluate(paths, name, is_nested=True, save_path="./results/train_domain_postprocessed_v2.csv") print(results) print("Ribbon segmentation:", results["ribbon"].mean(), "+-", results["ribbon"].std()) print("PD segmentation:", results["PD"].mean(), "+-", results["PD"].std()) diff --git a/scripts/inner_ear/training/structure_prediction_and_evaluation.py b/scripts/inner_ear/training/structure_prediction_and_evaluation.py index cb174c7..7ed89a9 100644 --- a/scripts/inner_ear/training/structure_prediction_and_evaluation.py +++ b/scripts/inner_ear/training/structure_prediction_and_evaluation.py @@ -143,10 +143,10 @@ def predict_and_evaluate_train_domain(): print("Run evaluation on", len(paths), "tomos") name = "train_domain" - model_path = "./checkpoints/inner_ear_structure_model" + model_path = "./checkpoints/inner_ear_structure_model_v2" run_prediction(paths, model_path, name, is_nested=True) - evaluate(paths, name, is_nested=True, save_path="./results/train_domain.csv") + evaluate(paths, name, is_nested=True, save_path="./results/train_domain_v2.csv") visualize(paths, name, is_nested=True) @@ -187,9 +187,9 @@ def predict_and_evaluate_rat(): def main(): - # predict_and_evaluate_train_domain() + predict_and_evaluate_train_domain() # predict_and_evaluate_vesicle_pools() - predict_and_evaluate_rat() + # predict_and_evaluate_rat() if __name__ == "__main__": diff --git a/scripts/rizzoli/2D_vesicle_segmentation.py b/scripts/rizzoli/2D_vesicle_segmentation.py index 7974e3b..159be28 100644 --- a/scripts/rizzoli/2D_vesicle_segmentation.py +++ b/scripts/rizzoli/2D_vesicle_segmentation.py @@ -7,6 +7,7 @@ import torch import torch_em import numpy as np +from elf.io import open_file from synaptic_reconstruction.inference.vesicles import segment_vesicles from synaptic_reconstruction.inference.util import parse_tiling @@ -57,7 +58,7 @@ def get_volume(input_path): input_volume = seg_file["raw"][:] return input_volume -def run_vesicle_segmentation(input_path, output_path, model_path, tile_shape, halo, include_boundary, key_label): +def run_vesicle_segmentation(input_path, output_path, model_path, tile_shape, halo, include_boundary, key_label, scale, mask_path, mask_key): tiling = get_2D_tiling() @@ -69,23 +70,41 @@ def run_vesicle_segmentation(input_path, output_path, model_path, tile_shape, ha tiling = parse_tiling(tile_shape, halo) input = get_volume(input_path) + #check if we have a restricting mask for the segmentation + if mask_path is not None: + with open_file(mask_path, "r") as f: + mask = f[mask_key][:] + else: + mask = None + device = "cuda" if torch.cuda.is_available() else "cpu" model = torch_em.util.load_model(checkpoint=model_path, device=device) - def process_slices(input_volume): + def process_slices(input_volume, scale, mask): processed_slices = [] foreground = [] boundaries = [] for z in range(input_volume.shape[0]): slice_ = input_volume[z, :, :] - segmented_slice, prediction_slice = segment_vesicles(input_volume=slice_, model=model, verbose=False, tiling=tiling, return_predictions=True, exclude_boundary=not include_boundary) + #check if we have a restricting mask for the segmentation + if mask is 
not None: + mask_slice = mask[z, :, :] + segmented_slice, prediction_slice = segment_vesicles(input_volume=slice_, model=model, verbose=False, tiling=tiling, return_predictions=True, scale = scale, exclude_boundary=not include_boundary, mask = mask_slice) + else: + segmented_slice, prediction_slice = segment_vesicles(input_volume=slice_, model=model, verbose=False, tiling=tiling, return_predictions=True, scale = scale, exclude_boundary=not include_boundary) + processed_slices.append(segmented_slice) foreground_pred_slice, boundaries_pred_slice = prediction_slice[:2] foreground.append(foreground_pred_slice) boundaries.append(boundaries_pred_slice) return processed_slices, foreground, boundaries - segmentation, foreground, boundaries = process_slices(input) + if input.ndim == 2: + #TODO: check if we have a restricting mask for the segmentation + segmentation, prediction = segment_vesicles(input_volume=input, model=model, verbose=False, tiling=tiling, return_predictions=True, scale = scale, exclude_boundary=not include_boundary) + foreground, boundaries = prediction[:2] + else: + segmentation, foreground, boundaries = process_slices(input, scale, mask) seg_output = _require_output_folders(output_path) file_name = Path(input_path).stem @@ -121,7 +140,11 @@ def segment_folder(args): print(input_files) pbar = tqdm(input_files, desc="Run segmentation") for input_path in pbar: - run_vesicle_segmentation(input_path, args.output_path, args.model_path, args.tile_shape, args.halo, args.include_boundary, args.key_label) + if args.mask_path is not None: + mask_path_for_file = os.path.join(args.mask_path, os.path.basename(input_path)) + else: + mask_path_for_file = None + run_vesicle_segmentation(input_path, args.output_path, args.model_path, args.tile_shape, args.halo, args.include_boundary, args.key_label, args.scale, mask_path_for_file, args.mask_key) def main(): parser = argparse.ArgumentParser(description="Segment vesicles in EM tomograms.") @@ -152,6 +175,16 @@ def main(): "--key_label", "-k", default = "combined_vesicles", help="Give the key name for saving the segmentation in h5." ) + parser.add_argument( + "--scale", "-s", type=float, nargs=2, + help="Scales the input data." + ) + parser.add_argument( + "--mask_path", help="The filepath to a h5 file with a mask that will be used to restrict the segmentation. Needs to be in combination with mask_key." 
+ ) + parser.add_argument( + "--mask_key", help="Key name that holds the mask segmentation" + ) args = parser.parse_args() input_ = args.input_path @@ -159,7 +192,7 @@ def main(): if os.path.isdir(input_): segment_folder(args) else: - run_vesicle_segmentation(input_, args.output_path, args.model_path, args.tile_shape, args.halo, args.include_boundary, args.key_label) + run_vesicle_segmentation(input_, args.output_path, args.model_path, args.tile_shape, args.halo, args.include_boundary, args.key_label, args.scale, args.mask_path, args.mask_key) print("Finished segmenting!") diff --git a/scripts/rizzoli/evaluation_2D.py b/scripts/rizzoli/evaluation_2D.py index 1cae666..5b5bbbd 100644 --- a/scripts/rizzoli/evaluation_2D.py +++ b/scripts/rizzoli/evaluation_2D.py @@ -6,8 +6,13 @@ import numpy as np from elf.evaluation import matching +from skimage.transform import rescale - +def transpose_tomo(tomogram): + data0 = np.swapaxes(tomogram, 0, -1) + data1 = np.fliplr(data0) + transposed_data = np.swapaxes(data1, 0, -1) + return transposed_data def evaluate(labels, vesicles): assert labels.shape == vesicles.shape @@ -54,21 +59,34 @@ def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key) ds_name = os.path.basename(os.path.dirname(labels_path)) tomo = os.path.basename(labels_path) - + use_mask = False #get the labels and vesicles with h5py.File(labels_path) as label_file: labels = label_file["labels"] #vesicles = labels["vesicles"] gt = labels[anno_key][:] + #gt = rescale(gt, scale=0.5, order=0, anti_aliasing=False, preserve_range=True).astype(gt.dtype) + #gt = transpose_tomo(gt) + + if use_mask: + mask = labels["mask"][:] + mask = rescale(mask, scale=0.5, order=0, anti_aliasing=False, preserve_range=True).astype(mask.dtype) + mask = transpose_tomo(mask) with h5py.File(vesicles_path) as seg_file: segmentation = seg_file["vesicles"] vesicles = segmentation[segment_key][:] + if use_mask: + gt[mask == 0] = 0 + vesicles[mask == 0] = 0 - #evaluate the match of ground truth and vesicles - scores = evaluate_slices(gt, vesicles) + #evaluate the match of ground truth and vesicles + if len(vesicles.shape) == 3: + scores = evaluate_slices(gt, vesicles) + else: + scores = evaluate(gt,vesicles) #store results result_folder ="/user/muth9/u12095/synaptic-reconstruction/scripts/cooper/evaluation_results" os.makedirs(result_folder, exist_ok=True) diff --git a/scripts/rizzoli/train_2D_domain_adaptation.py b/scripts/rizzoli/train_2D_domain_adaptation.py index 86eedd1..ac2a28f 100644 --- a/scripts/rizzoli/train_2D_domain_adaptation.py +++ b/scripts/rizzoli/train_2D_domain_adaptation.py @@ -6,11 +6,13 @@ from sklearn.model_selection import train_test_split from synaptic_reconstruction.training.domain_adaptation import mean_teacher_adaptation -TRAIN_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/rizzoli/extracted" -OUTPUT_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/2D_DA_training_rizzoli" +TRAIN_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/2D_data" +OUTPUT_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/2D_DA_training_2Dcooper_v1" def _require_train_val_test_split(datasets): train_ratio, val_ratio, test_ratio = 0.8, 0.1, 0.1 + if len(datasets) < 10: + train_ratio, val_ratio, test_ratio = 0.5, 0.25, 0.25 def _train_val_test_split(names): train, test = train_test_split(names, test_size=1 - train_ratio, shuffle=True) @@ -71,8 +73,11 @@ def get_paths(split, datasets, testset=True): return paths def 
vesicle_domain_adaptation(teacher_model, testset = True): + + os.makedirs(OUTPUT_ROOT, exist_ok=True) + datasets = [ - "upsampled_by2" + "20241021_imig_2014_data_transfer_exported_grouped" ] train_paths = get_paths("train", datasets=datasets, testset=testset) val_paths = get_paths("val", datasets=datasets, testset=testset) @@ -83,11 +88,13 @@ def vesicle_domain_adaptation(teacher_model, testset = True): #adjustable parameters patch_shape = [1, 256, 256] #2D - model_name = "2D-vesicle-DA-rizzoli-v3" + model_name = "2D-vesicle-DA-2Dcooper-imig-v2" model_root = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/models_v2/checkpoints/" checkpoint_path = os.path.join(model_root, teacher_model) + patch_shape = [256, 256] if any("maus" in dataset for dataset in datasets) else [1, 256, 256] + mean_teacher_adaptation( name=model_name, unsupervised_train_paths=train_paths, @@ -97,7 +104,10 @@ def vesicle_domain_adaptation(teacher_model, testset = True): save_root="/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/DA_models", source_checkpoint=checkpoint_path, confidence_threshold=0.75, - n_iterations=int(5e4), + batch_size=8, + n_iterations=int(1.5e4), + n_samples_train=8000, + n_samples_val=50, ) diff --git a/scripts/summarize_data.py b/scripts/summarize_data.py new file mode 100644 index 0000000..66fe321 --- /dev/null +++ b/scripts/summarize_data.py @@ -0,0 +1,175 @@ +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + + +# TODO inner ear train data and mito training data are missing +az_train = pd.read_excel("data_summary/active_zone_training_data.xlsx") +compartment_train = pd.read_excel("data_summary/compartment_training_data.xlsx") +vesicle_train = pd.read_excel("data_summary/vesicle_training_data.xlsx") +vesicle_da = pd.read_excel("data_summary/vesicle_domain_adaptation_data.xlsx", sheet_name="cryo") + + +def training_resolutions(): + res_az = np.round(az_train["resolution"].mean(), 2) + res_compartment = np.round(compartment_train["resolution"].mean(), 2) + res_cryo = np.round(vesicle_da["resolution"].mean(), 2) + res_vesicles = np.round(vesicle_train["resolution"].mean(), 2) + + print("Training resolutions for models:") + print("active_zone:", res_az) + print("compartments:", res_compartment) + # TODO + print("mitochondria:", 1.0) + print("vesicles_2d:", res_vesicles) + print("vesicles_3d:", res_vesicles) + print("vesicles_cryo:", res_cryo) + # TODO inner ear + + +def pie_chart(data, count_col, title): + # Plot the pie chart + plt.figure(figsize=(8, 6)) + wedges, texts, autotexts = plt.pie( + data[count_col], + labels=data["Condition"], + autopct="%1.1f%%", # Display percentages + startangle=90, # Start at the top + colors=plt.cm.Paired.colors[:len(data)], # Optional: Custom color palette + textprops={"fontsize": 14} + ) + + for autot in autotexts: + autot.set_fontsize(18) + + plt.title(title, fontsize=18) + plt.tight_layout() + plt.show() + + +def summarize_vesicle_train_data(): + condition_summary = { + "Condition": [], + "Tomograms": [], + "Vesicles": [], + } + + conditions = pd.unique(vesicle_train.condition) + for condition in conditions: + ctab = vesicle_train[vesicle_train.condition == condition] + n_tomos = len(ctab) + n_vesicles_all = ctab["vesicle_count_all"].sum() + n_vesicles_imod = ctab["vesicle_count_imod"].sum() + print(condition) + print("Tomograms:", n_tomos) + print("All-Vesicles:", n_vesicles_all) + print("Vesicles-From-Manual:", n_vesicles_imod) + print() + condition_summary["Condition"].append(condition) + 
condition_summary["Tomograms"].append(n_tomos) + condition_summary["Vesicles"].append(n_vesicles_all) + condition_summary = pd.DataFrame(condition_summary) + print() + print() + + print("Total:") + print("Tomograms:", len(vesicle_train)) + print("All-Vesicles:", vesicle_train["vesicle_count_all"].sum()) + print("Vesicles-From-Manual:", vesicle_train["vesicle_count_imod"].sum()) + print() + + train_tomos = vesicle_train[vesicle_train.used_for == "train/val"] + print("Training:") + print("Tomograms:", len(train_tomos)) + print("All-Vesicles:", train_tomos["vesicle_count_all"].sum()) + print("Vesicles-From-Manual:", train_tomos["vesicle_count_imod"].sum()) + print() + + test_tomos = vesicle_train[vesicle_train.used_for == "test"] + print("Test:") + print("Tomograms:", len(test_tomos)) + print("All-Vesicles:", test_tomos["vesicle_count_all"].sum()) + print("Vesicles-From-Manual:", test_tomos["vesicle_count_imod"].sum()) + + pie_chart(condition_summary, "Tomograms", "Tomograms per Condition") + pie_chart(condition_summary, "Vesicles", "Vesicles per Condition") + + +def summarize_vesicle_da(): + for name in ("inner_ear", "endbulb", "cryo", "frog", "maus_2d"): + tab = pd.read_excel("data_summary/vesicle_domain_adaptation_data.xlsx", sheet_name=name) + print(name) + print("N-tomograms:", len(tab)) + print("N-test:", (tab["used_for"] == "test").sum()) + print("N-vesicles:", tab["vesicle_count"].sum()) + print() + + +def summarize_az_train(): + conditions = pd.unique(az_train.condition) + print(conditions) + + print("Total:") + print("Tomograms:", len(az_train)) + print("Active Zones:", az_train["az_count"].sum()) + print() + + train_tomos = az_train[az_train.used_for == "train/val"] + print("Training:") + print("Tomograms:", len(train_tomos)) + print("Active Zones:", train_tomos["az_count"].sum()) + print() + + test_tomos = az_train[az_train.used_for == "test"] + print("Test:") + print("Tomograms:", len(test_tomos)) + print("Active Zones:", test_tomos["az_count"].sum()) + + +def summarize_compartment_train(): + conditions = pd.unique(compartment_train.condition) + print(conditions) + + print("Total:") + print("Tomograms:", len(compartment_train)) + print("Compartments:", compartment_train["compartment_count"].sum()) + print() + + train_tomos = compartment_train[compartment_train.used_for == "train/val"] + print("Training:") + print("Tomograms:", len(train_tomos)) + print("Compartments:", train_tomos["compartment_count"].sum()) + print() + + test_tomos = compartment_train[compartment_train.used_for == "test"] + print("Test:") + print("Tomograms:", len(test_tomos)) + print("Compartments:", test_tomos["compartment_count"].sum()) + + +def summarize_inner_ear_data(): + # NOTE: this is not all trainig data, but the data on which we run the analysis + # New tomograms from Sophia. 
+ n_tomos_sophia_tot = 87 + n_tomos_sophia_manual = 33 # noqa + # This is the training data + n_tomos_sophia_train = "" # TODO # noqa + + # Published tomograms + n_tomos_rat = 19 + n_tomos_tether = 3 + n_tomos_ves_pool = 6 + + # 28 + print("Total published:", n_tomos_rat + n_tomos_tether + n_tomos_ves_pool) + # 115 + print("Total:", n_tomos_rat + n_tomos_tether + n_tomos_ves_pool + n_tomos_sophia_tot) + + +# training_resolutions() +# summarize_vesicle_train_data() +# summarize_vesicle_da() +# summarize_az_train() +# summarize_compartment_train() +# summarize_inner_ear_data() +summarize_inner_ear_data() diff --git a/synaptic_reconstruction/imod/to_imod.py b/synaptic_reconstruction/imod/to_imod.py index 7a98469..f97e4f0 100644 --- a/synaptic_reconstruction/imod/to_imod.py +++ b/synaptic_reconstruction/imod/to_imod.py @@ -16,51 +16,60 @@ from tqdm import tqdm -# FIXME how to bring the data to the IMOD axis convention? -def _to_imod_order(data): - # data = np.swapaxes(data, 0, -1) - # data = np.fliplr(data) - # data = np.swapaxes(data, 0, -1) - return data - - +# TODO: this still has some issues with some tomograms that have offset information. +# For now, this occurs for the inner ear data tomograms; it works for Fidi's STEM tomograms. +# Ben's theory is that this might be due to data from JEOL vs. Thermo Fisher microscopes. +# To test this, I can check how it works for data from Maus et al. / Imig et al., which were taken on a JEOL. +# Can also check out the mrc documentation here: https://www.ccpem.ac.uk/mrc_format/mrc2014.php def write_segmentation_to_imod( mrc_path: str, - segmentation_path: str, + segmentation: Union[str, np.ndarray], output_path: str, ) -> None: - """Write a segmentation to a mod file as contours. + """Write a segmentation to a mod file as closed contour objects. Args: - mrc_path: a - segmentation_path: a - output_path: a + mrc_path: The filepath to the mrc file from which the segmentation was derived. + segmentation: The segmentation (either as numpy array or filepath to a .tif file). + output_path: The output path where the mod file will be saved. """ cmd = "imodauto" cmd_path = shutil.which(cmd) assert cmd_path is not None, f"Could not find the {cmd} imod command." + # Load the segmentation from a tif file in case a filepath was passed. + if isinstance(segmentation, str): + assert os.path.exists(segmentation) + segmentation = imageio.imread(segmentation) + + # Binarize the segmentation and flip its axes to match the IMOD axis convention. + segmentation = (segmentation > 0).astype("uint8") + segmentation = np.flip(segmentation, axis=1) + + # Read the voxel size and origin information from the mrc file. assert os.path.exists(mrc_path) - with mrcfile.open(mrc_path, mode="r+") as f: + with mrcfile.open(mrc_path, mode="r") as f: voxel_size = f.voxel_size + nx, ny, nz = f.header.nxstart, f.header.nystart, f.header.nzstart + origin = f.header.origin + # Write the input for imodauto to a temporary mrc file. with tempfile.NamedTemporaryFile(suffix=".mrc") as f: tmp_path = f.name - seg = (imageio.imread(segmentation_path) > 0).astype("uint8") - seg_ = _to_imod_order(seg) - - # import napari - # v = napari.Viewer() - # v.add_image(seg) - # v.add_labels(seg_) - # napari.run() - - mrcfile.new(tmp_path, data=seg_, overwrite=True) + mrcfile.new(tmp_path, data=segmentation, overwrite=true) + # Write the voxel_size and origin information.
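+ # Note: nxstart/nystart/nzstart and the origin describe how the data is offset relative to the
+ # full volume it was cut out from; carrying them over to the temporary file should keep the
+ # contours produced by imodauto aligned with the original tomogram (see the mrc2014 spec linked above).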
with mrcfile.open(tmp_path, mode="r+") as f: f.voxel_size = voxel_size + + f.header.nxstart = nx + f.header.nystart = ny + f.header.nzstart = nz + f.header.origin = (0.0, 0.0, 0.0) if origin is None else origin + f.update_header_from_data() + # Run the command. cmd_list = [cmd, "-E", "1", "-u", tmp_path, output_path] run(cmd_list) @@ -121,6 +130,37 @@ def coords_and_rads(prop): rads = [re[1] for re in res] return np.array(coords), np.array(rads) + def coords_and_rads(prop): + seg_id = prop.label + + bbox = prop.bbox + bb = np.s_[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5]] + mask = segmentation[bb] == seg_id + + if estimate_radius_2d: + dists = np.array([distance_transform_edt(ma, sampling=resolution[1:]) for ma in mask]) + else: + dists = distance_transform_edt(mask, sampling=resolution) + + max_coord = np.unravel_index(np.argmax(dists), mask.shape) + radius = dists[max_coord] * radius_factor + + offset = np.array(bbox[:3]) + coord = np.array(max_coord) + offset + return coord, radius, seg_id + + with futures.ThreadPoolExecutor(num_workers) as tp: + res = list(tqdm( + tp.map(coords_and_rads, props), disable=not verbose, total=len(props), + desc="Compute coordinates and radii" + )) + + coords = [re[0] for re in res] + rads = [re[1] for re in res] + label_indexes = [re[2] for re in res] + return np.array(coords), np.array(rads), np.array(label_indexes) + + def write_points_to_imod( coordinates: np.ndarray, diff --git a/synaptic_reconstruction/inference/AZ.py b/synaptic_reconstruction/inference/AZ.py new file mode 100644 index 0000000..a1c9da8 --- /dev/null +++ b/synaptic_reconstruction/inference/AZ.py @@ -0,0 +1,89 @@ +import time +from typing import Dict, List, Optional, Tuple, Union + +import elf.parallel as parallel +import numpy as np +import torch + +from synaptic_reconstruction.inference.util import get_prediction, _Scaler +from synaptic_reconstruction.inference.postprocessing.postprocess_AZ import find_intersection_boundary + +def _run_segmentation( + foreground, verbose, min_size, + # blocking shapes for parallel computation + block_shape=(128, 256, 256), +): + + # Get the segmentation via connected components of the thresholded foreground. + t0 = time.time() + seg = parallel.label(foreground > 0.5, block_shape=block_shape, verbose=verbose) + if verbose: + print("Compute connected components in", time.time() - t0, "s") + + # size filter + t0 = time.time() + ids, sizes = parallel.unique(seg, return_counts=True, block_shape=block_shape, verbose=verbose) + filter_ids = ids[sizes < min_size] + seg[np.isin(seg, filter_ids)] = 0 + if verbose: + print("Size filter in", time.time() - t0, "s") + seg = np.where(seg > 0, 1, 0) + return seg + +def segment_AZ( + input_volume: np.ndarray, + model_path: Optional[str] = None, + model: Optional[torch.nn.Module] = None, + tiling: Optional[Dict[str, Dict[str, int]]] = None, + min_size: int = 500, + verbose: bool = True, + return_predictions: bool = False, + scale: Optional[List[float]] = None, + mask: Optional[np.ndarray] = None, + compartment: Optional[np.ndarray] = None, +) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """ + Segment the active zone (AZ) in an input volume. + + Args: + input_volume: The input volume to segment. + model_path: The path to the model checkpoint if `model` is not provided. + model: Pre-loaded model. Either `model_path` or `model` is required. + tiling: The tiling configuration for the prediction. + verbose: Whether to print timing information. + scale: The scale factor to use for rescaling the input volume before prediction.
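+ min_size: The minimal size of an AZ object to keep in the segmentation result.
+ return_predictions: Whether to return the network prediction together with the segmentation.
+ compartment: Optional compartment segmentation; if given, the intersection of the AZ with the compartment boundary is returned as well.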
+ mask: An optional mask that is used to restrict the segmentation. + + Returns: + The AZ segmentation as a numpy array, optionally together with the prediction or the compartment intersection. + """ + if verbose: + print("Segmenting AZ in volume of shape", input_volume.shape) + # Create the scaler to handle prediction with a different scaling factor. + scaler = _Scaler(scale, verbose) + input_volume = scaler.scale_input(input_volume) + + # Rescale the mask if it was given and run prediction. + if mask is not None: + mask = scaler.scale_input(mask, is_segmentation=True) + pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose) + + # Run segmentation and rescale the result if necessary. + foreground = pred[0] + + segmentation = _run_segmentation(foreground, verbose=verbose, min_size=min_size) + + # Returning the prediction and the intersection at the same time is not possible at the moment, but we currently do not need the prediction anyway. + if return_predictions: + pred = scaler.rescale_output(pred, is_segmentation=False) + return segmentation, pred + + if compartment is not None: + intersection = find_intersection_boundary(segmentation, compartment) + return segmentation, intersection + + return segmentation + diff --git a/synaptic_reconstruction/inference/compartments.py b/synaptic_reconstruction/inference/compartments.py index a822d9f..701c222 100644 --- a/synaptic_reconstruction/inference/compartments.py +++ b/synaptic_reconstruction/inference/compartments.py @@ -77,6 +77,12 @@ def _segment_compartments_2d( mask = np.logical_or(binary_closing(mask, iterations=4), mask) segmentation[bb][mask] = prop.label + # import napari + # v = napari.Viewer() + # v.add_image(boundaries) + # v.add_labels(segmentation) + # napari.run() + return segmentation @@ -117,6 +123,7 @@ def _segment_compartments_3d( boundary_threshold=0.4, n_slices_exclude=0, min_z_extent=10, + postprocess_segments=False, ): distances = distance_transform_edt(prediction < boundary_threshold).astype("float32") seg_2d = np.zeros(prediction.shape, dtype="uint32") @@ -132,7 +139,8 @@ def _segment_compartments_3d( seg_2d[z] = seg_z seg = _merge_segmentation_3d(seg_2d, min_z_extent) - seg = _postprocess_seg_3d(seg) + if postprocess_segments: + seg = _postprocess_seg_3d(seg) # import napari # v = napari.Viewer() @@ -155,6 +163,9 @@ def segment_compartments( scale: Optional[List[float]] = None, mask: Optional[np.ndarray] = None, n_slices_exclude: int = 0, + boundary_threshold: float = 0.4, + min_z_extent: int = 10, + postprocess_segments: bool = False, **kwargs, ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: """ @@ -194,9 +205,14 @@ def segment_compartments( # We may want to expose some of the parameters here.
t0 = time.time() if input_volume.ndim == 2: - seg = _segment_compartments_2d(pred) + seg = _segment_compartments_2d(pred, boundary_threshold=boundary_threshold) else: - seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude) + seg = _segment_compartments_3d( + pred, + boundary_threshold=boundary_threshold, + n_slices_exclude=n_slices_exclude, + postprocess_segments=postprocess_segments, + ) if verbose: print("Run segmentation in", time.time() - t0, "s") diff --git a/synaptic_reconstruction/inference/postprocessing/postprocess_AZ.py b/synaptic_reconstruction/inference/postprocessing/postprocess_AZ.py new file mode 100644 index 0000000..8ef2cd3 --- /dev/null +++ b/synaptic_reconstruction/inference/postprocessing/postprocess_AZ.py @@ -0,0 +1,35 @@ +import numpy as np +from skimage.segmentation import find_boundaries + +def find_intersection_boundary(segmented_AZ, segmented_compartment): + """ + Find the cumulative intersection of the boundary of each label in segmented_compartment with segmented_AZ. + + Parameters: + segmented_AZ (numpy.ndarray): 3D array representing the active zone (AZ). + segmented_compartment (numpy.ndarray): 3D array representing the compartment, with multiple labels. + + Returns: + numpy.ndarray: 3D array with the cumulative intersection of all boundaries of segmented_compartment labels with segmented_AZ. + """ + # Step 0: Initialize an empty array to accumulate intersections + cumulative_intersection = np.zeros_like(segmented_AZ, dtype=bool) + + # Step 1: Loop through each unique label in segmented_compartment (excluding 0 if it represents background) + labels = np.unique(segmented_compartment) + labels = labels[labels != 0] # Exclude background label (0) if necessary + + for label in labels: + # Step 2: Create a binary mask for the current label + label_mask = (segmented_compartment == label) + + # Step 3: Find the boundary of the current label's compartment + boundary_compartment = find_boundaries(label_mask, mode='outer') + + # Step 4: Find the intersection with the AZ for this label's boundary + intersection = np.logical_and(boundary_compartment, segmented_AZ) + + # Step 5: Accumulate intersections for each label + cumulative_intersection = np.logical_or(cumulative_intersection, intersection) + + return cumulative_intersection.astype(int) # Convert boolean array to int (1 for intersecting points, 0 elsewhere) diff --git a/synaptic_reconstruction/inference/vesicles.py b/synaptic_reconstruction/inference/vesicles.py index 237d95a..4a56b0f 100644 --- a/synaptic_reconstruction/inference/vesicles.py +++ b/synaptic_reconstruction/inference/vesicles.py @@ -49,6 +49,7 @@ def distance_based_vesicle_segmentation( # Get the segmentation via seeded watershed of components in the boundary distances. 
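+ # The distance_threshold parameter below controls the seeds for the watershed: only voxels whose
+ # boundary distance exceeds the threshold are used as seeds, so a larger value presumably gives
+ # smaller, more conservative seeds (better separation of touching vesicles, at the risk of missing
+ # very small ones).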
t0 = time.time() + print(f"Using a distance threshold of {distance_threshold} for the distance-based segmentation") seeds = parallel.label(bd_dist > distance_threshold, block_shape=block_shape, verbose=verbose) if verbose: print("Compute connected components in", time.time() - t0, "s") @@ -129,6 +130,7 @@ def segment_vesicles( min_size: int = 500, verbose: bool = True, distance_based_segmentation: bool = True, + distance_threshold: int = 8, return_predictions: bool = False, scale: Optional[List[float]] = None, exclude_boundary: bool = False, @@ -174,7 +176,7 @@ def segment_vesicles( if distance_based_segmentation: seg = distance_based_vesicle_segmentation( - foreground, boundaries, verbose=verbose, min_size=min_size, **kwargs + foreground, boundaries, verbose=verbose, min_size=min_size, distance_threshold=distance_threshold, **kwargs ) else: seg = simple_vesicle_segmentation( diff --git a/synaptic_reconstruction/morphology.py b/synaptic_reconstruction/morphology.py index 8afea3d..126042f 100644 --- a/synaptic_reconstruction/morphology.py +++ b/synaptic_reconstruction/morphology.py @@ -6,8 +6,11 @@ import numpy as np import pandas as pd -from scipy.ndimage import distance_transform_edt + +from scipy.ndimage import distance_transform_edt, convolve +from skimage.graph import MCP from skimage.measure import regionprops, marching_cubes +from skimage.morphology import skeletonize, medial_axis, label from skimage.segmentation import find_boundaries @@ -87,3 +90,110 @@ def compute_object_morphology(object_, structure_name, resolution=None): "surface [pixel^2]" if resolution is None else "surface [nm^2]": [surface], }) return morphology + + +def _find_endpoints(component): + # Define a 3x3 kernel to count neighbors + kernel = np.ones((3, 3), dtype=int) + neighbor_count = convolve(component.astype(int), kernel, mode="constant", cval=0) + endpoints = np.argwhere((component == 1) & (neighbor_count == 2)) # Degree = 1 + return endpoints + + +def _compute_longest_path(component, endpoints): + # Use the first endpoint as the source + src = tuple(endpoints[0]) + cost = np.where(component, 1, np.inf) # Cost map: 1 for skeleton, inf for background + mcp = MCP(cost) + _, traceback = mcp.find_costs([src]) + + # Use the second endpoint as the destination + dst = tuple(endpoints[-1]) + + # Trace back the path + path = np.zeros_like(component, dtype=bool) + current = dst + + # Extract offsets from the MCP object + offsets = np.array(mcp.offsets) + nrows, ncols = component.shape + + while current != src: + path[current] = True + current_offset_index = traceback[current] + if current_offset_index < 0: + # No valid path found + break + offset = offsets[current_offset_index] + # Move to the predecessor + current = (current[0] - offset[0], current[1] - offset[1]) + # Ensure indices are within bounds + if not (0 <= current[0] < nrows and 0 <= current[1] < ncols): + break + + path[src] = True # Include the source + return path + + +def _prune_skeleton_longest_path(skeleton): + pruned_skeleton = np.zeros_like(skeleton, dtype=bool) + + # Label connected components in the skeleton + labeled_skeleton, num_labels = label(skeleton, return_num=True) + + for label_id in range(1, num_labels + 1): + # Isolate the current connected component + component = (labeled_skeleton == label_id) + + # Find the endpoints of the component + endpoints = _find_endpoints(component) + if len(endpoints) < 2: + continue # Skip if there are no valid endpoints + elif len(endpoints) == 2: # Nothing to prune + pruned_skeleton |= component + continue + + 
# Compute the longest path using MCP + longest_path = _compute_longest_path(component, endpoints) + + # import napari + # v = napari.Viewer() + # v.add_labels(component) + # v.add_labels(longest_path) + # v.add_points(endpoints) + # napari.run() + + pruned_skeleton |= longest_path + + return pruned_skeleton.astype(skeleton.dtype) + + +def skeletonize_object( + segmentation: np.ndarray, + method: str = "skeletonize", + prune: bool = True, + min_prune_size: int = 10, +) -> np.ndarray: + """Skeletonize a 3D object by individually skeletonizing each slice. + + Args: + segmentation: The segmentation to skeletonize. + method: The method for the per-slice skeletonization, either "skeletonize" or "medial_axis". + prune: Whether to prune each 2D skeleton to its longest path. + min_prune_size: The minimal size of connected skeleton pieces to keep after pruning. + + Returns: + The skeletonized segmentation. + """ + assert method in ("skeletonize", "medial_axis") + seg_thin = np.zeros_like(segmentation) + skeletor = skeletonize if method == "skeletonize" else medial_axis + # Parallelize? + for z in range(segmentation.shape[0]): + skeleton = skeletor(segmentation[z]) + + if prune: + skeleton = _prune_skeleton_longest_path(skeleton) + if min_prune_size > 0: + skeleton = label(skeleton) + ids, sizes = np.unique(skeleton, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + skeleton = np.isin(skeleton, ids[sizes >= min_prune_size]) + + seg_thin[z] = skeleton + return seg_thin
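A minimal usage sketch for the new skeletonization helper, assuming `skeletonize_object` is importable from `synaptic_reconstruction.morphology` as added above; the file path and dataset key are placeholders:

import h5py

from synaptic_reconstruction.morphology import skeletonize_object

# Load a segmentation and binarize it, since the per-slice skeletonization expects a single object.
with h5py.File("tomogram.h5", "r") as f:  # placeholder path
    seg = f["labels/compartments"][:] > 0  # placeholder key

# Skeletonize slice by slice; with prune=True each 2D skeleton is reduced to its longest path
# and connected pieces smaller than min_prune_size are discarded.
skeleton = skeletonize_object(seg, method="skeletonize", prune=True, min_prune_size=10)
assert skeleton.shape == seg.shape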