From 79191b32ce7c17d5612f6a3cdf5943e86b6b125e Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sat, 21 Sep 2024 14:00:32 +0100 Subject: [PATCH 01/16] navis.read_ivscc: basic IVSCC feature extraction --- docs/api.md | 1 + navis/morpho/__init__.py | 3 +- navis/morpho/ivscc.py | 401 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 404 insertions(+), 1 deletion(-) create mode 100644 navis/morpho/ivscc.py diff --git a/docs/api.md b/docs/api.md index 0382bf81..97d05f64 100644 --- a/docs/api.md +++ b/docs/api.md @@ -314,6 +314,7 @@ Functions to analyze morphology. | [`navis.persistence_vectors()`][navis.persistence_vectors] | {{ autosummary("navis.persistence_vectors") }} | | [`navis.strahler_index()`][navis.strahler_index] | {{ autosummary("navis.strahler_index") }} | | [`navis.segment_analysis()`][navis.segment_analysis] | {{ autosummary("navis.segment_analysis") }} | +| [`navis.ivscc_features()`][navis.ivscc_features] | {{ autosummary("navis.ivscc_features") }} | | [`navis.sholl_analysis()`][navis.sholl_analysis] | {{ autosummary("navis.sholl_analysis") }} | | [`navis.tortuosity()`][navis.tortuosity] | {{ autosummary("navis.tortuosity") }} | | [`navis.betweeness_centrality()`][navis.betweeness_centrality] | {{ autosummary("navis.betweeness_centrality") }} | diff --git a/navis/morpho/__init__.py b/navis/morpho/__init__.py index f7d4f919..234f9cea 100644 --- a/navis/morpho/__init__.py +++ b/navis/morpho/__init__.py @@ -26,6 +26,7 @@ from .persistence import (persistence_points, persistence_vectors, persistence_distances) from .fq import form_factor +from .ivscc import ivscc_features __all__ = ['strahler_index', 'bending_flow', 'flow_centrality', 'synapse_flow_centrality', @@ -37,4 +38,4 @@ 'subset_neuron', 'smooth_voxels', 'sholl_analysis', 'persistence_points', 'betweeness_centrality', 'persistence_vectors', 'persistence_distances', 'combine_neurons', - 'segment_analysis', 'form_factor'] + 'segment_analysis', 'form_factor', 'ivscc_features'] diff --git a/navis/morpho/ivscc.py b/navis/morpho/ivscc.py new file mode 100644 index 00000000..4e2bdc73 --- /dev/null +++ b/navis/morpho/ivscc.py @@ -0,0 +1,401 @@ +# This script is part of navis (http://www.github.com/navis-org/navis). +# Copyright (C) 2018 Philipp Schlegel +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. + +import pandas as pd +import numpy as np + + +from abc import ABC, abstractmethod +from scipy.stats import wasserstein_distance +from typing import Union, Sequence + +from .. import config, graph, core +from . import subset_neuron, tortuosity + +# Set up logging +logger = config.get_logger(__name__) + +__all__ = sorted( + [ + "ivscc", + ] +) + +# A mapping of label IDs to compartment names +# Note: anything above 5 is considered "undefined" or "custom" +label_to_comp = { + -1: "root", + 0: "undefined", + 1: "soma", + 2: "axon", + 3: "basal_dendrite", + 4: "apical_dendrite", +} +comp_to_label = {v: k for k, v in label_to_comp.items()} + + +class CompartmentNotFoundError(Exception): + """An exception raised when a compartment is not found.""" + + pass + + +class Features(ABC): + def __init__(self, neuron: "core.TreeNeuron", label=None, verbose=False): + self.neuron = neuron + self.verbose = verbose + + if label is None: + self.label = "" + elif not label.endswith("_"): + self.label = f"{label}_" + else: + self.label = label + + # Make sure the neuron is rooted to the soma (if present) + self.soma = self.neuron.soma + if self.soma is not None: + self.soma_pos = self.neuron.soma_pos[0] + self.soma_radius = self.neuron.nodes.set_index("node_id").loc[ + self.soma, "radius" + ] + + if self.neuron.soma not in self.neuron.root: + self.neuron = self.neuron.reroot(self.neuron.soma) + + # Calculate geodesic distances from leafs to all other nodes (directed) + self.leaf_dists = graph.geodesic_matrix( + self.neuron, self.neuron.leafs.node_id.values, directed=True + ) + # Replace infinities with -1 + self.leaf_dists[self.leaf_dists == float("inf")] = -1 + + self.features = {} + + def record_feature(self, name, value): + """Record a feature.""" + self.features[f"{self.label}{name}"] = value + + @abstractmethod + def extract_features(self): + """Extract features.""" + pass + + +class BasicFeatures(Features): + """Base class for features.""" + + def extract_features(self): + """Extract basic features.""" + self.record_feature( + "extent_y", self.neuron.nodes.y.max() - self.neuron.nodes.y.min() + ) + self.record_feature( + "extent_x", self.neuron.nodes.x.max() - self.neuron.nodes.x.min() + ) + self.record_feature( + "max_branch_order", (self.neuron.nodes.type == "branch").sum() + 1 + ) + self.record_feature("num_nodes", len(self.neuron.nodes)) + self.record_feature("total_length", self.neuron.cable_length) + + if self.soma is None: + if self.verbose: + logger.warning( + f"{self.neuron.id} has no `.soma` attribute, skipping soma-related features." + ) + return + + # x/y bias from soma + # Note: this is absolute for x and relative for y + self.record_feature( + "bias_x", + abs( + (self.neuron.nodes.x.max() - self.soma_pos[0]) + - (self.soma_pos[0] - self.neuron.nodes.x.min()) + ), + ) + self.record_feature( + "bias_y", + (self.neuron.nodes.y.max() - self.soma_pos[1]) + - (self.soma_pos[1] - self.neuron.nodes.y.min()), + ) + + # Distances from soma + self.record_feature( + "max_euclidean_distance", + ( + (self.neuron.nodes[["x", "y", "z"]] - self.soma_pos) + .pow(2) + .sum(axis=1) + .pow(0.5) + .sum() + .max() + ), + ) + self.record_feature( + "max_path_length", + self.leaf_dists.loc[ + self.leaf_dists.index.isin(self.neuron.nodes.node_id) + ].values.max(), + ) + + # Tortuosity + self.record_feature("mean_contraction", tortuosity(self.neuron)) + + # Branching (number of linear segments between branch) + self.record_feature("num_branches", len(self.neuron.small_segments)) + + return self.features + + +class CompartmentFeatures(BasicFeatures): + """Base class for compartment-specific features.""" + + def __init__(self, neuron: "core.TreeNeuron", compartment, verbose=False): + if "label" not in neuron.nodes.columns: + raise ValueError( + f"No 'label' column found in node table for neuron {neuron.id}" + ) + + if ( + compartment not in neuron.nodes.label.values + and comp_to_label.get(compartment, compartment) + not in neuron.nodes.label.values + ): + raise CompartmentNotFoundError( + f"No {compartment} ({comp_to_label.get(compartment, compartment)}) compartments found in neuron {neuron.id}" + ) + + # Initialize the parent class + super().__init__(neuron, label=compartment, verbose=verbose) + + # Now subset the neuron to this compartment + self.neuron = subset_neuron( + self.neuron, + ( + self.neuron.nodes.label.isin( + (compartment, comp_to_label[compartment]) + ).values + ), + ) + + +class AxonFeatures(CompartmentFeatures): + """Extract features from an axon.""" + + def __init__(self, neuron: "core.TreeNeuron", verbose=False): + super().__init__(neuron, "axon", verbose=verbose) + + def extract_features(self): + # Extract basic features via the parent class + super().extract_features() + + # Now deal witha axon-specific features: + + if self.soma is not None: + # Distance between axon root and soma surface + # Note: we're catering for potentially multiple roots here + axon_root_pos = self.neuron.nodes.loc[ + self.neuron.nodes.type == "root", ["x", "y", "z"] + ].values + + # Closest dist between an axon root and the soma + dist = np.linalg.norm(axon_root_pos - self.soma_pos, axis=1).min() + + # Subtract soma radius from the distance + dist -= self.soma_radius + + self.record_feature("exit_distance", dist) + + # Axon theta: The relative radial position of the point where the neurite from which + # the axon derives exits the soma. + + # Get the node where the axon exits the soma + exit_node = self.neuron.nodes.loc[self.neuron.nodes.type == "root"] + + # Get theta + theta = np.arctan2( + exit_node.y.values - self.soma_pos[1], + exit_node.x.values - self.soma_pos[0], + )[0] + self.record_feature("exit_theta", theta) + + return self.features + + +class BasalDendriteFeatures(CompartmentFeatures): + """Extract features from a basal dendrite.""" + + def __init__(self, neuron: "core.TreeNeuron", verbose=False): + super().__init__(neuron, "basal_dendrite", verbose=verbose) + + def extract_features(self): + # Extract basic features via the parent class + super().extract_features() + + # Now deal with basal dendrite-specific features + if self.soma is not None: + # Number of stems sprouting from the soma + # (i.e. number of nodes with a parent that is the soma) + self.record_feature( + "calculate_number_of_stems", (self.neuron.nodes.parent_id == self.soma).sum() + ) + + return self.features + + +class ApicalDendriteFeatures(CompartmentFeatures): + """Extract features from a apical dendrite.""" + + def __init__(self, neuron: "core.TreeNeuron", verbose=False): + super().__init__(neuron, "apical_dendrite", verbose=verbose) + + def extract_features(self): + # Extract basic features via the parent class + super().extract_features() + + return self.features + + +class OverlapFeatures(Features): + """Features that compare two compartments (e.g. overlap).""" + + # Compartments to compare + compartments = ("axon", "basal_dendrite", "apical_dendrite") + + def extract_features(self): + # Iterate over compartments + for c1 in self.compartments: + if c1 in self.neuron.nodes.label.values: + c1_nodes = self.neuron.nodes[self.neuron.nodes.label == c1] + elif comp_to_label.get(c1, c1) in self.neuron.nodes.label.values: + c1_nodes = self.neuron.nodes[ + self.neuron.nodes.label == comp_to_label[c1] + ] + else: + continue + for c2 in self.compartments: + if c1 == c2: + continue + if c2 in self.neuron.nodes.label.values: + c2_nodes = self.neuron.nodes[self.neuron.nodes.label == c2] + elif comp_to_label.get(c2, c2) in self.neuron.nodes.label.values: + c2_nodes = self.neuron.nodes[ + self.neuron.nodes.label == comp_to_label[c2] + ] + else: + continue + + # Calculate % of nodes of a given compartment type above/overlapping/below the + # full y-extent of another compartment type + self.features[f"{c1}_frac_above_{c2}"] = ( + c1_nodes.y > c2_nodes.y.max() + ).sum() / len(c1_nodes) + self.features[f"{c1}_frac_intersect_{c2}"] = ( + (c1_nodes.y >= c2_nodes.y.min()) & (c1_nodes.y <= c2_nodes.y.max()) + ).sum() / len(c1_nodes) + self.features[f"{c1}_frac_below_{c2}"] = ( + c1_nodes.y < c2_nodes.y.min() + ).sum() / len(c1_nodes) + + # Calculate earth mover's distance (EMD) between the two compartments + if f"{c2}_emd_with_{c1}" not in self.features: + self.features[f"{c1}_emd_with_{c2}"] = wasserstein_distance( + c1_nodes.y, c2_nodes.y + ) + + return self.features + + +def ivscc_features( + x: "core.TreeNeuron", features=None, missing_compartments="ignore", verbose=False +) -> Union[float, pd.DataFrame]: + """Calculate IVSCC features for neuron(s). + + Please see the `IVSCC` tutorial for more details. + + Parameters + ---------- + x : TreeNeuron | NeuronList + Neuron(s) to calculate IVCSS for. + features : Sequence[Features], optional + Provide specific features to calculate. + Must be subclasses of `BasicFeatures`. + If `None`, will use default features. + missing_compartments : "ignore" | "skip" | "raise" + What to do if a neuron is missing a compartment + (e.g. no axon or basal dendrite): + - "ignore" (default): ignore that compartment + - "skip": skip the entire neuron + - "raise": raise an exception + + Returns + ------- + ivcss : pd.DataFrame + IVCSS features for the neuron(s). + + """ + + if isinstance(x, core.TreeNeuron): + x = core.NeuronList([x]) + + if features is None: + features = DEFAULT_FEATURES + + data = {} + for n in x: + data[n.id] = {} + for feat in features: + try: + f = feat(n, verbose=verbose) + except CompartmentNotFoundError as e: + if missing_compartments == "ignore": + continue + elif missing_compartments == "skip": + data.pop(n.id) + break + else: + raise e + + data[n.id].update(f.extract_features()) + + return pd.DataFrame(data) + + +def _check_compartments(n, compartments): + """Check if `compartments` are valid.""" + if compartments == "auto": + if "label" not in n.nodes.columns: + return None + return n.nodes.label.unique() + elif compartments is True: + return n.nodes.label.unique() + elif isinstance(compartments, str): + if "label" not in n.nodes.columns or compartments not in n.nodes.label.unique(): + raise ValueError(f"Compartment not present: {compartments}") + return [compartments] + elif isinstance(compartments, Sequence): + if "label" not in n.nodes.columns: + raise ValueError("No 'label' column found in node table.") + for c in compartments: + if c not in n.nodes.label.unique(): + raise ValueError(f"Compartment not present: {c}") + return compartments + elif compartments in (None, False): + return None + + raise ValueError(f"Invalid `compartments`: {compartments}") + + +DEFAULT_FEATURES = [AxonFeatures, BasalDendriteFeatures, ApicalDendriteFeatures, OverlapFeatures] From f21c106d0d178b2b82c0c9eb3eacc2bc0a029e15 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sat, 21 Sep 2024 14:01:25 +0100 Subject: [PATCH 02/16] tortuosity: by default, use segments as-is --- navis/morpho/mmetrics.py | 72 +++++++++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/navis/morpho/mmetrics.py b/navis/morpho/mmetrics.py index d84b822c..d2475117 100644 --- a/navis/morpho/mmetrics.py +++ b/navis/morpho/mmetrics.py @@ -1271,7 +1271,7 @@ def flow_centrality(x: "core.NeuronObject") -> "core.NeuronObject": def tortuosity( x: "core.NeuronObject", - seg_length: Union[int, float, str, Sequence[Union[int, float, str]]] = 10, + seg_length: Optional[Union[int, float, str, Sequence[Union[int, float, str]]]] = None, ) -> Union[float, Sequence[float], pd.DataFrame]: """Calculate tortuosity of a neuron. @@ -1280,19 +1280,10 @@ def tortuosity( `L` (`seg_length`) to the Euclidean distance `R` between its ends. The way this is implemented in `navis`: + For each linear stretch (i.e. segments between branch points, leafs or roots) + we calculate its geodesic length `L` and the Euclidean distance `R` between + its ends. The final tortuosity is the mean of `L / R` across all segments. - 1. Each linear stretch (i.e. between branch points or branch points to a - leaf node) is divided into segments of exactly `seg_length` - geodesic length. Any remainder is skipped. - 2. For each of these segments we divide its geodesic length `L` - (i.e. `seg_length`) by the Euclidean distance `R` between its start and - its end. - 3. The final tortuosity is the mean of `L / R` across all segments. - - Note - ---- - If you want to make sure that segments are as close to length `L` as - possible, consider resampling the neuron using [`navis.resample_skeleton`][]. Parameters ---------- @@ -1300,18 +1291,27 @@ def tortuosity( Neuron to analyze. If MeshNeuron, will generate and use a skeleton representation. seg_length : int | float | str | list thereof, optional - Target segment length(s) `L`. If neuron(s) have their - `.units` set, you can also pass a string such as - "1 micron". `seg_length` must be larger than the - current sampling resolution of the neuron. + Target segment length(s) `L`. If `seg_length` is + provided, each linear segment is further divided into + segments of exactly `seg_length` (geodesic) length + and the tortuosity is calculated for each of these + sub-segments. If `seg_length` is not provided, the + tortuosity is calculated for each linear segment as is. + + If neuron(s) have their `.units` set, you can also + pass a string such as "1 micron". `seg_length` must + be larger than the current sampling resolution of the + neuron. If you want to make sure that segments are as + close to length `L` as possible, consider resampling the + neuron using [`navis.resample_skeleton`][]. Returns ------- tortuosity : float | np.array | pandas.DataFrame If x is NeuronList, will return DataFrame. If x is single TreeNeuron, will return either a - single float (if single seg_length is queried) or a - DataFrame (if multiple seg_lengths are queried). + single float (if no or a single seg_length is queried) + or a DataFrame (if multiple seg_lengths are queried). See Also -------- @@ -1323,7 +1323,11 @@ def tortuosity( -------- >>> import navis >>> n = navis.example_neurons(1) - >>> # Calculate tortuosity with 1 micron seg lengths + >>> # Calculate tortuosity as-is + >>> T = navis.tortuosity(n) + >>> round(T, 3) + 1.074 + >>> # Calculate tortuosity with 1 micron segment lengths >>> T = navis.tortuosity(n, seg_length='1 micron') >>> round(T, 3) 1.054 @@ -1356,6 +1360,34 @@ def tortuosity( if isinstance(seg_length, (list, np.ndarray)): return [tortuosity(x, l) for l in seg_length] + if seg_length is None: + return _tortuosity_simple(x) + else: + return _tortuosity_segmented(x, seg_length) + + +def _tortuosity_simple(x: "core.TreeNeuron") -> float: + """Calculate tortuosity for neuron as-is.""" + # Iterate over segments + locs = x.nodes.set_index("node_id")[["x", "y", "z"]].astype(float) + T_all = [] + for i, seg in enumerate(x.small_segments): + # Get coordinates + coords = locs.loc[seg].values + + # Calculate geodesic distance for this segment + L = np.linalg.norm(np.diff(coords.T), axis=0).sum() + + # Calculate Euclidean distance for this segment + R = np.linalg.norm(coords[0] - coords[-1]) + T = L / R + T_all = np.append(T_all, T) + + return T_all.mean() + + +def _tortuosity_segmented(x: "core.TreeNeuron", seg_length: Union[int, float, str]) -> float: + """Calculate tortuosity for segmented neuron.""" # From here on out seg length is single value seg_length: float = x.map_units(seg_length, on_error="raise") From fb9becfb62e33c6924725bbb1b6814e18c442129 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:47:20 +0100 Subject: [PATCH 03/16] Neurons: implement __add/sub__ magic methods for offsetting coordinates --- navis/core/base.py | 166 ++++++++++++++------------- navis/core/dotprop.py | 35 ++++++ navis/core/mesh.py | 33 +++++- navis/core/neuronlist.py | 2 + navis/core/skeleton.py | 51 +++++++++ navis/core/volumes.py | 236 ++++++++++++++++++++++----------------- navis/core/voxel.py | 34 +++++- 7 files changed, 370 insertions(+), 187 deletions(-) diff --git a/navis/core/base.py b/navis/core/base.py index 3ee3b415..cb0f67bd 100644 --- a/navis/core/base.py +++ b/navis/core/base.py @@ -34,7 +34,7 @@ except ImportError: xxhash = None -__all__ = ['Neuron'] +__all__ = ["Neuron"] # Set up logging logger = config.get_logger(__name__) @@ -45,8 +45,9 @@ pint.Quantity([]) -def Neuron(x: Union[nx.DiGraph, str, pd.DataFrame, 'TreeNeuron', 'MeshNeuron'], - **metadata): +def Neuron( + x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"], **metadata +): """Constructor for Neuron objects. Depending on the input, either a `TreeNeuron` or a `MeshNeuron` is returned. @@ -183,10 +184,10 @@ class BaseNeuron(UnitObject): connectors: Optional[pd.DataFrame] #: Attributes used for neuron summary - SUMMARY_PROPS = ['type', 'name', 'units'] + SUMMARY_PROPS = ["type", "name", "units"] #: Attributes to be used when comparing two neurons. - EQ_ATTRIBUTES = ['name'] + EQ_ATTRIBUTES = ["name"] #: Temporary attributes that need clearing when neuron data changes TEMP_ATTR = ["_memory_usage"] @@ -212,8 +213,8 @@ def __init__(self, **kwargs): def __getattr__(self, key): """Get attribute.""" - if key.startswith('has_'): - key = key[key.index('_') + 1:] + if key.startswith("has_"): + key = key[key.index("_") + 1 :] if hasattr(self, key): data = getattr(self, key) if isinstance(data, pd.DataFrame): @@ -223,7 +224,7 @@ def __getattr__(self, key): return False # This is necessary because np.any does not like strings elif isinstance(data, str): - if data == 'NA' or not data: + if data == "NA" or not data: return False return True elif utils.is_iterable(data) and len(data) > 0: @@ -231,16 +232,16 @@ def __getattr__(self, key): elif data: return True return False - elif key.startswith('n_'): - key = key[key.index('_') + 1:] + elif key.startswith("n_"): + key = key[key.index("_") + 1 :] if hasattr(self, key): data = getattr(self, key, None) if isinstance(data, pd.DataFrame): return data.shape[0] elif utils.is_iterable(data): return len(data) - elif isinstance(data, str) and data == 'NA': - return 'NA' + elif isinstance(data, str) and data == "NA": + return "NA" return None raise AttributeError(f'Attribute "{key}" not found') @@ -284,8 +285,7 @@ def __add__(self, other): """Implement addition.""" if isinstance(other, BaseNeuron): return core.NeuronList([self, other]) - else: - return NotImplemented + return NotImplemented def __imul__(self, other): """Multiplication with assignment (*=).""" @@ -295,28 +295,37 @@ def __itruediv__(self, other): """Division with assignment (/=).""" return self.__truediv__(other, copy=False) + def __iadd__(self, other): + """Addition with assignment (+=).""" + return self.__add__(other, copy=False) + + def __isub__(self, other): + """Subtraction with assignment (-=).""" + return self.__sub__(other, copy=False) + def _repr_html_(self): frame = self.summary().to_frame() - frame.columns = [''] + frame.columns = [""] # return self._gen_svg_thumbnail() + frame._repr_html_() return frame._repr_html_() def _gen_svg_thumbnail(self): """Generate 2D plot for thumbnail.""" import matplotlib.pyplot as plt + # Store some previous states prev_level = logger.getEffectiveLevel() prev_pbar = config.pbar_hide prev_int = plt.isinteractive() plt.ioff() # turn off interactive mode - logger.setLevel('WARNING') + logger.setLevel("WARNING") config.pbar_hide = True fig = plt.figure(figsize=(2, 2)) ax = fig.add_subplot(111) fig, ax = self.plot2d(connectors=False, ax=ax) output = StringIO() - fig.savefig(output, format='svg') + fig.savefig(output, format="svg") if prev_int: plt.ion() # turn on interactive mode @@ -339,9 +348,11 @@ def _clear_temp_attr(self, exclude: list = []) -> None: for a in [at for at in self.TEMP_ATTR if at not in exclude]: try: delattr(self, a) - logger.debug(f'Neuron {self.id} {hex(id(self))}: attribute {a} cleared') + logger.debug(f"Neuron {self.id} {hex(id(self))}: attribute {a} cleared") except AttributeError: - logger.debug(f'Neuron {self.id} at {hex(id(self))}: Unable to clear temporary attribute "{a}"') + logger.debug( + f'Neuron {self.id} at {hex(id(self))}: Unable to clear temporary attribute "{a}"' + ) except BaseException: raise @@ -358,8 +369,10 @@ def _register_attr(self, name, value, summary=True, temporary=False): if isinstance(value, (numbers.Number, str, bool, np.bool_, type(None))): self.SUMMARY_PROPS.append(name) else: - logger.error(f'Attribute "{name}" of type "{type(value)}" ' - 'can not be added to summary') + logger.error( + f'Attribute "{name}" of type "{type(value)}" ' + "can not be added to summary" + ) if temporary: self.TEMP_ATTR.append(name) @@ -386,14 +399,14 @@ def core_md5(self) -> str: MD5 checksum of core data. `None` if no core data. """ - hash = '' + hash = "" for prop in self.CORE_DATA: cols = None # See if we need to parse props into property and columns # e.g. "nodes:node_id,parent_id,x,y,z" - if ':' in prop: - prop, cols = prop.split(':') - cols = cols.split(',') + if ":" in prop: + prop, cols = prop.split(":") + cols = cols.split(",") if hasattr(self, prop): data = getattr(self, prop) @@ -419,9 +432,11 @@ def datatables(self) -> List[str]: @property def extents(self) -> np.ndarray: """Extents of neuron in x/y/z direction (includes connectors).""" - if not hasattr(self, 'bbox'): - raise ValueError('Neuron must implement `.bbox` (bounding box) ' - 'property to calculate extents.') + if not hasattr(self, "bbox"): + raise ValueError( + "Neuron must implement `.bbox` (bounding box) " + "property to calculate extents." + ) bbox = self.bbox return bbox[:, 1] - bbox[:, 0] @@ -432,26 +447,26 @@ def id(self) -> Any: Must be hashable. If not set, will assign a random unique identifier. Can be indexed by using the `NeuronList.idx[]` locator. """ - return getattr(self, '_id', None) + return getattr(self, "_id", None) @id.setter def id(self, value): try: hash(value) except BaseException: - raise ValueError('id must be hashable') + raise ValueError("id must be hashable") self._id = value @property def label(self) -> str: """Label (e.g. for legends).""" # If explicitly set return that label - if getattr(self, '_label', None): + if getattr(self, "_label", None): return self._label # If no label set, produce one from name + id (optional) - name = getattr(self, 'name', None) - id = getattr(self, 'id', None) + name = getattr(self, "name", None) + id = getattr(self, "id", None) # If no name, use type if not name: @@ -465,11 +480,11 @@ def label(self) -> str: try: id = str(id) except BaseException: - id = '' + id = "" # Only use ID if it is not the same as name if id and name != id: - label += f' ({id})' + label += f" ({id})" return label @@ -482,7 +497,7 @@ def label(self, value: str): @property def name(self) -> str: """Neuron name.""" - return getattr(self, '_name', None) + return getattr(self, "_name", None) @name.setter def name(self, value: str): @@ -498,10 +513,9 @@ def connectors(self, v): if isinstance(v, type(None)): self._connectors = None else: - self._connectors = utils.validate_table(v, - required=['x', 'y', 'z'], - rename=True, - restrict=False) + self._connectors = utils.validate_table( + v, required=["x", "y", "z"], rename=True, restrict=False + ) @property def presynapses(self): @@ -510,19 +524,19 @@ def presynapses(self): Requires a "type" column in connector table. Will look for type labels that include "pre" or that equal 0 or "0". """ - if not isinstance(getattr(self, 'connectors', None), pd.DataFrame): - raise ValueError('No connector table found.') + if not isinstance(getattr(self, "connectors", None), pd.DataFrame): + raise ValueError("No connector table found.") # Make an educated guess what presynapses are types = self.connectors["type"].unique() pre = [t for t in types if "pre" in str(t).lower() or t in [0, "0"]] if len(pre) == 0: - logger.debug(f'Unable to find presynapses in types: {types}') + logger.debug(f"Unable to find presynapses in types: {types}") return self.connectors.iloc[0:0] # return empty DataFrame elif len(pre) > 1: - raise ValueError(f'Found ambigous presynapse labels: {pre}') + raise ValueError(f"Found ambigous presynapse labels: {pre}") - return self.connectors[self.connectors['type'] == pre[0]] + return self.connectors[self.connectors["type"] == pre[0]] @property def postsynapses(self): @@ -531,27 +545,25 @@ def postsynapses(self): Requires a "type" column in connector table. Will look for type labels that include "post" or that equal 1 or "1". """ - if not isinstance(getattr(self, 'connectors', None), pd.DataFrame): - raise ValueError('No connector table found.') + if not isinstance(getattr(self, "connectors", None), pd.DataFrame): + raise ValueError("No connector table found.") # Make an educated guess what presynapses are types = self.connectors["type"].unique() post = [t for t in types if "post" in str(t).lower() or t in [1, "1"]] if len(post) == 0: - logger.debug(f'Unable to find postsynapses in types: {types}') + logger.debug(f"Unable to find postsynapses in types: {types}") return self.connectors.iloc[0:0] # return empty DataFrame elif len(post) > 1: - raise ValueError(f'Found ambigous postsynapse labels: {post}') - - return self.connectors[self.connectors['type'] == post[0]] - + raise ValueError(f"Found ambigous postsynapse labels: {post}") + return self.connectors[self.connectors["type"] == post[0]] @property def is_stale(self) -> bool: """Test if temporary attributes might be outdated.""" # If we know we are stale, just return True - if getattr(self, '_stale', False): + if getattr(self, "_stale", False): return True else: # Only check if we believe we are not stale @@ -561,7 +573,7 @@ def is_stale(self) -> bool: @property def is_locked(self): """Test if neuron is locked.""" - return getattr(self, '_lock', 0) > 0 + return getattr(self, "_lock", 0) > 0 @property def type(self) -> str: @@ -578,9 +590,9 @@ def bbox(self) -> np.ndarray: """Bounding box of neuron.""" raise NotImplementedError(f"Bounding box not implemented for {type(self)}.") - def convert_units(self, - to: Union[pint.Unit, str], - inplace: bool = False) -> Optional['BaseNeuron']: + def convert_units( + self, to: Union[pint.Unit, str], inplace: bool = False + ) -> Optional["BaseNeuron"]: """Convert coordinates to different unit. Only works if neuron's `.units` is not dimensionless. @@ -622,19 +634,21 @@ def convert_units(self, # Multiply by conversion factor n *= conv - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n - def copy(self, deepcopy=False) -> 'BaseNeuron': + def copy(self, deepcopy=False) -> "BaseNeuron": """Return a copy of the neuron.""" copy_fn = copy.deepcopy if deepcopy else copy.copy # Attributes not to copy - no_copy = ['_lock'] + no_copy = ["_lock"] # Generate new empty neuron x = self.__class__() # Override with this neuron's data - x.__dict__.update({k: copy_fn(v) for k, v in self.__dict__.items() if k not in no_copy}) + x.__dict__.update( + {k: copy_fn(v) for k, v in self.__dict__.items() if k not in no_copy} + ) return x @@ -645,18 +659,16 @@ def summary(self, add_props=None) -> pd.Series: # Add .id to summary if not a generic UUID if not isinstance(self.id, uuid.UUID): - props.insert(2, 'id') + props.insert(2, "id") if add_props: - props, ix = np.unique(np.append(props, add_props), - return_inverse=True) + props, ix = np.unique(np.append(props, add_props), return_inverse=True) props = props[ix] # This is to catch an annoying "UnitStrippedWarning" with pint with warnings.catch_warnings(): warnings.simplefilter("ignore") - s = pd.Series([getattr(self, at, 'NA') for at in props], - index=props) + s = pd.Series([getattr(self, at, "NA") for at in props], index=props) return s @@ -705,10 +717,11 @@ def plot3d(self, **kwargs): return plot3d(core.NeuronList(self, make_copy=False), **kwargs) - def map_units(self, - units: Union[pint.Unit, str], - on_error: Union[Literal['raise'], - Literal['ignore']] = 'raise') -> Union[int, float]: + def map_units( + self, + units: Union[pint.Unit, str], + on_error: Union[Literal["raise"], Literal["ignore"]] = "raise", + ) -> Union[int, float]: """Convert units to match neuron space. Only works if neuron's `.units` is isometric and not dimensionless. @@ -725,7 +738,7 @@ def map_units(self, See Also -------- - [`navis.to_neuron_space`][] + [`navis.core.to_neuron_space`][] The base function for this method. Examples @@ -744,8 +757,7 @@ def map_units(self, [0.125, 0.125, 0.125] """ - return core.core_utils.to_neuron_space(units, neuron=self, - on_error=on_error) + return core.core_utils.to_neuron_space(units, neuron=self, on_error=on_error) def memory_usage(self, deep=False, estimate=False): """Return estimated memory usage of this neuron. @@ -775,8 +787,8 @@ def memory_usage(self, deep=False, estimate=False): # as possible if hasattr(self, "_memory_usage"): mu = self._memory_usage - if mu['deep'] == deep and mu['estimate'] == estimate: - return mu['size'] + if mu["deep"] == deep and mu["estimate"] == estimate: + return mu["size"] size = 0 if not estimate: @@ -803,8 +815,6 @@ def memory_usage(self, deep=False, estimate=False): else: size += v.dtype.itemsize * v.shape[0] - self._memory_usage = {'deep': deep, - 'estimate': estimate, - 'size': size} + self._memory_usage = {"deep": deep, "estimate": estimate, "size": size} return size diff --git a/navis/core/dotprop.py b/navis/core/dotprop.py index aa4893b7..f1af9c15 100644 --- a/navis/core/dotprop.py +++ b/navis/core/dotprop.py @@ -182,6 +182,41 @@ def __mul__(self, other, copy=True): return n return NotImplemented + def __add__(self, other, copy=True): + """Implement addition for coordinates.""" + if isinstance(other, numbers.Number) or utils.is_iterable(other): + # If a number, consider this an offset for coordinates + n = self.copy() if copy else self + _ = np.add(n.points, other, out=n.points, casting='unsafe') + if n.has_connectors: + n.connectors.loc[:, ['x', 'y', 'z']] += other + + # Force recomputing of KDTree + if hasattr(n, '_tree'): + delattr(n, '_tree') + + return n + # If another neuron, return a list of neurons + elif isinstance(other, BaseNeuron): + return core.NeuronList([self, other]) + return NotImplemented + + def __sub__(self, other, copy=True): + """Implement subtraction for coordinates.""" + if isinstance(other, numbers.Number) or utils.is_iterable(other): + # If a number, consider this an offset for coordinates + n = self.copy() if copy else self + _ = np.subtract(n.points, other, out=n.points, casting='unsafe') + if n.has_connectors: + n.connectors.loc[:, ['x', 'y', 'z']] -= other + + # Force recomputing of KDTree + if hasattr(n, '_tree'): + delattr(n, '_tree') + + return n + return NotImplemented + def __getstate__(self): """Get state (used e.g. for pickling).""" state = {k: v for k, v in self.__dict__.items() if not callable(v)} diff --git a/navis/core/mesh.py b/navis/core/mesh.py index fc75a560..d073b3d4 100644 --- a/navis/core/mesh.py +++ b/navis/core/mesh.py @@ -24,11 +24,11 @@ import skeletor as sk import trimesh as tm -from io import BufferedIOBase from typing import Union, Optional from .. import utils, config, meshes, conversion, graph from .base import BaseNeuron +from .neuronlist import NeuronList from .skeleton import TreeNeuron from .core_utils import temp_property @@ -225,6 +225,35 @@ def __mul__(self, other, copy=True): return n return NotImplemented + def __add__(self, other, copy=True): + """Implement addition for coordinates (vertices, connectors).""" + if isinstance(other, numbers.Number) or utils.is_iterable(other): + n = self.copy() if copy else self + _ = np.add(n.vertices, other, out=n.vertices, casting='unsafe') + if n.has_connectors: + n.connectors.loc[:, ['x', 'y', 'z']] += other + + self._clear_temp_attr() + + return n + # If another neuron, return a list of neurons + elif isinstance(other, BaseNeuron): + return NeuronList([self, other]) + return NotImplemented + + def __sub__(self, other, copy=True): + """Implement subtraction for coordinates (vertices, connectors).""" + if isinstance(other, numbers.Number) or utils.is_iterable(other): + n = self.copy() if copy else self + _ = np.subtract(n.vertices, other, out=n.vertices, casting='unsafe') + if n.has_connectors: + n.connectors.loc[:, ['x', 'y', 'z']] -= other + + self._clear_temp_attr() + + return n + return NotImplemented + @property def bbox(self) -> np.ndarray: """Bounding box (includes connectors).""" @@ -308,7 +337,7 @@ def volume(self) -> float: def skeleton(self) -> 'TreeNeuron': """Skeleton representation of this neuron. - Uses [`navis.mesh2skeleton`][]. + Uses [`navis.conversion.mesh2skeleton`][]. """ if not hasattr(self, '_skeleton'): diff --git a/navis/core/neuronlist.py b/navis/core/neuronlist.py index b20f1136..ab5e6a61 100644 --- a/navis/core/neuronlist.py +++ b/navis/core/neuronlist.py @@ -566,9 +566,11 @@ def append(self, v): >>> nl = navis.example_neurons() >>> len(nl) 5 + >>> # Add a single neuron to the list >>> nl.append(nl[0]) >>> len(nl) 6 + >>> # Add a list of neurons to the list >>> nl.append(nl) >>> len(nl) 12 diff --git a/navis/core/skeleton.py b/navis/core/skeleton.py index dc64bbea..d34f01d3 100644 --- a/navis/core/skeleton.py +++ b/navis/core/skeleton.py @@ -277,6 +277,57 @@ def __mul__(self, other, copy=True): return n return NotImplemented + def __add__(self, other, copy=True): + """Implement addition for coordinates (nodes, connectors).""" + if isinstance(other, numbers.Number) or utils.is_iterable(other): + if utils.is_iterable(other): + # If offset isotropic use only single value + if len(set(other)) == 1: + other == other[0] + elif len(other) != 3: + raise ValueError('Addition by list/array requires 3' + 'multipliers for x/y/z coordinates ' + f'got {len(other)}') + + # If a number, consider this an offset for coordinates + n = self.copy() if copy else self + n.nodes[['x', 'y', 'z']] += other + + # Do the connectors + if n.has_connectors: + n.connectors[['x', 'y', 'z']] += other + + n._clear_temp_attr(exclude=['classify_nodes']) + return n + # If another neuron, return a list of neurons + elif isinstance(other, BaseNeuron): + return core.NeuronList([self, other]) + return NotImplemented + + def __sub__(self, other, copy=True): + """Implement subtraction for coordinates (nodes, connectors).""" + if isinstance(other, numbers.Number) or utils.is_iterable(other): + if utils.is_iterable(other): + # If offset is isotropic use only single value + if len(set(other)) == 1: + other == other[0] + elif len(other) != 3: + raise ValueError('Addition by list/array requires 3' + 'multipliers for x/y/z coordinates ' + f'got {len(other)}') + + # If a number, consider this an offset for coordinates + n = self.copy() if copy else self + n.nodes[['x', 'y', 'z']] -= other + + # Do the connectors + if n.has_connectors: + n.connectors[['x', 'y', 'z']] -= other + + n._clear_temp_attr(exclude=['classify_nodes']) + return n + return NotImplemented + def __getstate__(self): """Get state (used e.g. for pickling).""" state = {k: v for k, v in self.__dict__.items() if not callable(v)} diff --git a/navis/core/volumes.py b/navis/core/volumes.py index dcd33f05..d112b0d1 100644 --- a/navis/core/volumes.py +++ b/navis/core/volumes.py @@ -86,7 +86,7 @@ def __init__( # Trimesh return a navis.Volume instead of a trimesh.Trimesh for f in dir(trimesh.Trimesh): # Don't mess with magic/private methods - if f.startswith('_'): + if f.startswith("_"): continue # Skip properties if not callable(getattr(trimesh.Trimesh, f)): @@ -96,38 +96,40 @@ def __init__( @property def name(self): """Name of this volume.""" - return self.metadata.get('name') + return self.metadata.get("name") @name.setter def name(self, value): - self.metadata['name'] = value + self.metadata["name"] = value @property def color(self): """Color used for plotting.""" - return self.metadata.get('color') + return self.metadata.get("color") @color.setter def color(self, value): - self.metadata['color'] = value + self.metadata["color"] = value @property def id(self): """ID of this volume.""" - return self.metadata.get('id') + return self.metadata.get("id") @id.setter def id(self, value): - self.metadata['id'] = value + self.metadata["id"] = value @classmethod - def from_csv(cls, - vertices: str, - faces: str, - name: Optional[str] = None, - color: Union[str, - Sequence[Union[int, float]]] = (.85, .85, .85, .2), - volume_id: Optional[int] = None, **kwargs) -> 'Volume': + def from_csv( + cls, + vertices: str, + faces: str, + name: Optional[str] = None, + color: Union[str, Sequence[Union[int, float]]] = (0.85, 0.85, 0.85, 0.2), + volume_id: Optional[int] = None, + **kwargs, + ) -> "Volume": """Load volume from csv files containing vertices and faces. Parameters @@ -145,18 +147,19 @@ def from_csv(cls, """ if not os.path.isfile(vertices) or not os.path.isfile(faces): - raise ValueError('File(s) not found.') + raise ValueError("File(s) not found.") - with open(vertices, 'r') as f: + with open(vertices, "r") as f: reader = csv.reader(f, **kwargs) vertices = np.array([r for r in reader]).astype(float) - with open(faces, 'r') as f: + with open(faces, "r") as f: reader = csv.reader(f, **kwargs) faces = np.array([r for r in reader]).astype(int) - return cls(faces=faces, vertices=vertices, name=name, color=color, - volume_id=volume_id) + return cls( + faces=faces, vertices=vertices, name=name, color=color, volume_id=volume_id + ) def to_csv(self, filename: str, **kwargs) -> None: """Save volume as two separated csv files containing vertices and faces. @@ -170,17 +173,17 @@ def to_csv(self, filename: str, **kwargs) -> None: Keyword arguments passed to `csv.reader`. """ - for data, suffix in zip([self.faces, self.vertices], - ['_faces.csv', '_vertices.csv']): - with open(filename + suffix, 'w') as csvfile: + for data, suffix in zip( + [self.faces, self.vertices], ["_faces.csv", "_vertices.csv"] + ): + with open(filename + suffix, "w") as csvfile: writer = csv.writer(csvfile) writer.writerows(data) @classmethod - def from_json(cls, - filename: str, - import_kwargs: Dict = {}, - **init_kwargs) -> 'Volume': + def from_json( + cls, filename: str, import_kwargs: Dict = {}, **init_kwargs + ) -> "Volume": """Load volume from json file containing vertices and faces. Parameters @@ -198,13 +201,12 @@ def from_json(cls, """ if not os.path.isfile(filename): - raise ValueError('File not found.') + raise ValueError("File not found.") - with open(filename, 'r') as f: + with open(filename, "r") as f: data = json.load(f, **import_kwargs) - return cls(faces=data['faces'], - vertices=data['vertices'], **init_kwargs) + return cls(faces=data["faces"], vertices=data["vertices"], **init_kwargs) @classmethod def from_object(cls, obj: Any, **init_kwargs) -> "Volume": @@ -223,16 +225,15 @@ def from_object(cls, obj: Any, **init_kwargs) -> "Volume": navis.Volume """ - if not hasattr(obj, 'vertices') or not hasattr(obj, 'faces'): - raise ValueError('Object must have faces and vertices attributes.') + if not hasattr(obj, "vertices") or not hasattr(obj, "faces"): + raise ValueError("Object must have faces and vertices attributes.") return cls(faces=obj.faces, vertices=obj.vertices, **init_kwargs) @classmethod - def from_file(cls, - filename: str, - import_kwargs: Dict = {}, - **init_kwargs) -> 'Volume': + def from_file( + cls, filename: str, import_kwargs: Dict = {}, **init_kwargs + ) -> "Volume": """Load volume from file. Parameters @@ -253,20 +254,22 @@ def from_file(cls, """ if not os.path.isfile(filename): - raise ValueError('File not found.') + raise ValueError("File not found.") f, ext = os.path.splitext(filename) - if ext == '.json': - return cls.from_json(filename=filename, - import_kwargs=import_kwargs, - **init_kwargs) + if ext == ".json": + return cls.from_json( + filename=filename, import_kwargs=import_kwargs, **init_kwargs + ) try: import trimesh except ImportError: - raise ImportError('Unable to import: trimesh missing - please ' - 'install: "pip install trimesh"') + raise ImportError( + "Unable to import: trimesh missing - please " + 'install: "pip install trimesh"' + ) except BaseException: raise @@ -283,18 +286,18 @@ def to_json(self, filename: str) -> None: Filename to use. """ - with open(filename, 'w') as f: - json.dump({'vertices': self.vertices.tolist(), - 'faces': self.faces.tolist()}, - f) + with open(filename, "w") as f: + json.dump( + {"vertices": self.vertices.tolist(), "faces": self.faces.tolist()}, f + ) @classmethod - def combine(cls, - x: Sequence['Volume'], - name: str = 'comb_vol', - color: Union[str, - Sequence[Union[int, float]]] = (.85, .85, .85, .2) - ) -> 'Volume': + def combine( + cls, + x: Sequence["Volume"], + name: str = "comb_vol", + color: Union[str, Sequence[Union[int, float]]] = (0.85, 0.85, 0.85, 0.2), + ) -> "Volume": """Merge multiple volumes into a single object. Parameters @@ -320,7 +323,7 @@ def combine(cls, x = [x] # type: ignore if False in [isinstance(v, Volume) for v in x]: - raise TypeError('Input must be list of volumes') + raise TypeError("Input must be list of volumes") vertices: np.ndarray = np.empty((0, 3)) faces: List[List[int]] = [] @@ -329,8 +332,7 @@ def combine(cls, for vol in x: offs = len(vertices) vertices = np.append(vertices, vol.vertices, axis=0) - faces += [[f[0] + offs, f[1] + offs, f[2] + offs] - for f in vol.faces] + faces += [[f[0] + offs, f[1] + offs, f[2] + offs] for f in vol.faces] return cls(vertices=vertices, faces=faces, name=name, color=color) @@ -371,28 +373,29 @@ def __repr__(self): """ elements = [] - if hasattr(self, 'name'): + if hasattr(self, "name"): # for Trimesh - elements.append(f'name={self.name}') - if hasattr(self, 'id') and not isinstance(self.id, uuid.UUID): + elements.append(f"name={self.name}") + if hasattr(self, "id") and not isinstance(self.id, uuid.UUID): # for Trimesh - elements.append(f'id={self.id}') - if hasattr(self, 'color'): + elements.append(f"id={self.id}") + elements.append(f"units={self.units}") + if hasattr(self, "color"): # for Trimesh - elements.append(f'color={self.color}') - if hasattr(self, 'vertices'): + elements.append(f"color={self.color}") + if hasattr(self, "vertices"): # for Trimesh and PointCloud - elements.append(f'vertices.shape={self.vertices.shape}') - if hasattr(self, 'faces'): + elements.append(f"vertices.shape={self.vertices.shape}") + if hasattr(self, "faces"): # for Trimesh - elements.append(f'faces.shape={self.faces.shape}') + elements.append(f"faces.shape={self.faces.shape}") return f'' def __truediv__(self, other): """Implement division for vertices.""" if isinstance(other, numbers.Number) or utils.is_iterable(other): n = self.copy() - _ = np.divide(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.divide(n.vertices, other, out=n.vertices, casting="unsafe") return n return NotImplemented @@ -400,17 +403,37 @@ def __mul__(self, other): """Implement multiplication for vertices.""" if isinstance(other, numbers.Number) or utils.is_iterable(other): n = self.copy() - _ = np.multiply(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.multiply(n.vertices, other, out=n.vertices, casting="unsafe") return n return NotImplemented - def resize(self, - x: Union[float, int], - method: Union[Literal['center'], - Literal['centroid'], - Literal['normals'], - Literal['origin']] = 'center', - inplace: bool = False) -> Optional['Volume']: + def __add__(self, other): + """Implement addition for vertices.""" + if isinstance(other, numbers.Number) or utils.is_iterable(other): + n = self.copy() + _ = np.add(n.vertices, other, out=n.vertices, casting="unsafe") + return n + return NotImplemented + + def __sub__(self, other): + """Implement subtraction for vertices.""" + if isinstance(other, numbers.Number) or utils.is_iterable(other): + n = self.copy() + _ = np.subtract(n.vertices, other, out=n.vertices, casting="unsafe") + return n + return NotImplemented + + def resize( + self, + x: Union[float, int], + method: Union[ + Literal["center"], + Literal["centroid"], + Literal["normals"], + Literal["origin"], + ] = "center", + inplace: bool = False, + ) -> Optional["Volume"]: """Resize volume. Parameters @@ -457,25 +480,27 @@ def resize(self, method = method.lower() - perm_methods = ['center', 'origin', 'normals', 'centroid'] + perm_methods = ["center", "origin", "normals", "centroid"] if method not in perm_methods: - raise ValueError(f'Unknown method "{method}". Allowed ' - f'methods: {", ".join(perm_methods)}') + raise ValueError( + f'Unknown method "{method}". Allowed ' + f'methods: {", ".join(perm_methods)}' + ) if not inplace: v = self.copy() else: v = self - if method == 'normals': + if method == "normals": v.vertices = v.vertices + (v.vertex_normals * x) else: # Get the center - if method == 'center': + if method == "center": cn = np.mean(v.vertices, axis=0) - elif method == 'centroid': + elif method == "centroid": cn = v.centroid - elif method == 'origin': + elif method == "origin": cn = np.array([0, 0, 0]) # Get vector from center to each vertex @@ -488,8 +513,8 @@ def resize(self, v.vertices = vec + cn # Make sure to reset any pyoctree data on this volume - if hasattr(v, 'pyoctree'): - delattr(v, 'pyoctree') + if hasattr(v, "pyoctree"): + delattr(v, "pyoctree") if not inplace: return v @@ -517,8 +542,8 @@ def plot3d(self, **kwargs): """ from .. import plotting - if 'color' in kwargs: - self.color = kwargs['color'] + if "color" in kwargs: + self.color = kwargs["color"] return plotting.plot3d(self, **kwargs) @@ -545,19 +570,18 @@ def _outlines_3d(self, view="xy", **kwargs): """ co2d = np.array(self.to_2d(view=view, **kwargs)) - if view in ['xy', 'yx']: + if view in ["xy", "yx"]: third = np.repeat(self.center[2], co2d.shape[0]) - elif view in ['xz', 'zx']: + elif view in ["xz", "zx"]: third = np.repeat(self.center[1], co2d.shape[0]) - elif view in ['yz', 'zy']: + elif view in ["yz", "zy"]: third = np.repeat(self.center[0], co2d.shape[0]) return np.append(co2d, third.reshape(co2d.shape[0], 1), axis=1) - def to_2d(self, - alpha: float = 0.00017, - view: tuple = ('x', 'y'), - invert_y: bool = False) -> Sequence[Union[float, int]]: + def to_2d( + self, alpha: float = 0.00017, view: tuple = ("x", "y"), invert_y: bool = False + ) -> Sequence[Union[float, int]]: """Compute the 2d alpha shape (concave hull) this volume. Uses Scipy Delaunay and shapely. @@ -587,7 +611,7 @@ def add_edge(edges, edge_points, coords, i, j): edges.add((i, j)) edge_points.append(coords[[i, j]]) - accepted_views = ['x', 'z', 'y', '-x', '-z', '-y'] + accepted_views = ["x", "z", "y", "-x", "-z", "-y"] for ax in view: if ax not in accepted_views: @@ -597,14 +621,14 @@ def add_edge(edges, edge_points, coords, i, j): from shapely.ops import unary_union, polygonize # type: ignore import shapely.geometry as geometry # type: ignore except ImportError: - raise ImportError('This function needs the shapely>=1.8.0') + raise ImportError("This function needs the shapely>=1.8.0") coords: np.ndarray - map = {'x': 0, 'y': 1, 'z': 2} + map = {"x": 0, "y": 1, "z": 2} - x_ix = map[view[0].replace('-', '').replace('+', '')] - y_ix = map[view[1].replace('-', '').replace('+', '')] + x_ix = map[view[0].replace("-", "").replace("+", "")] + y_ix = map[view[1].replace("-", "").replace("+", "")] coords = self.vertices[:, [x_ix, y_ix]] @@ -614,14 +638,14 @@ def add_edge(edges, edge_points, coords, i, j): # loop over triangles: # ia, ib, ic = indices of corner points of the triangle # Note that "vertices" property was renamed to "simplices" - for ia, ib, ic in getattr(tri, 'simplices', getattr(tri, 'vertices', [])): + for ia, ib, ic in getattr(tri, "simplices", getattr(tri, "vertices", [])): pa: np.ndarray = coords[ia] # type: ignore pb: np.ndarray = coords[ib] # type: ignore pc: np.ndarray = coords[ic] # type: ignore # Lengths of sides of triangle - a = math.sqrt((pa[0] - pb[0])**2 + (pa[1] - pb[1])**2) # type: ignore - b = math.sqrt((pb[0] - pc[0])**2 + (pb[1] - pc[1])**2) # type: ignore - c = math.sqrt((pc[0] - pa[0])**2 + (pc[1] - pa[1])**2) # type: ignore + a = math.sqrt((pa[0] - pb[0]) ** 2 + (pa[1] - pb[1]) ** 2) # type: ignore + b = math.sqrt((pb[0] - pc[0]) ** 2 + (pb[1] - pc[1]) ** 2) # type: ignore + c = math.sqrt((pc[0] - pa[0]) ** 2 + (pc[1] - pa[1]) ** 2) # type: ignore # Semiperimeter of triangle s = (a + b + c) / 2.0 # Area of triangle by Heron's formula @@ -653,9 +677,11 @@ def validate(self): self.fill_holes() self.fix_normals() if not self.is_volume: - raise utils.VolumeError("Mesh is not a volume " - "(e.g. not watertight, incorrect " - "winding) and could not be fixed.") + raise utils.VolumeError( + "Mesh is not a volume " + "(e.g. not watertight, incorrect " + "winding) and could not be fixed." + ) def _force_volume(f): diff --git a/navis/core/voxel.py b/navis/core/voxel.py index ba28e829..e01f0c54 100644 --- a/navis/core/voxel.py +++ b/navis/core/voxel.py @@ -143,7 +143,7 @@ def __setstate__(self, d): self.__dict__.update(d) def __truediv__(self, other, copy=True): - """Implement division for coordinates (units, connectors).""" + """Implement division for coordinates (units, connectors, offset).""" if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self @@ -165,7 +165,7 @@ def __truediv__(self, other, copy=True): return NotImplemented def __mul__(self, other, copy=True): - """Implement multiplication for coordinates (units, connectors).""" + """Implement multiplication for coordinates (units, connectors, offset).""" if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self @@ -186,6 +186,36 @@ def __mul__(self, other, copy=True): return n return NotImplemented + def __add__(self, other, copy=True): + """Implement addition for coordinates (offset, connectors).""" + if isinstance(other, numbers.Number) or utils.is_iterable(other): + # If a number, consider this an offset for coordinates + n = self.copy() if copy else self + + n.offset = n.offset + other + if n.has_connectors: + n.connectors.loc[:, ['x', 'y', 'z']] += other + + self._clear_temp_attr() + + return n + return NotImplemented + + def __sub__(self, other, copy=True): + """Implement subtraction for coordinates (offset, connectors).""" + if isinstance(other, numbers.Number) or utils.is_iterable(other): + # If a number, consider this an offset for coordinates + n = self.copy() if copy else self + + n.offset = n.offset - other + if n.has_connectors: + n.connectors.loc[:, ['x', 'y', 'z']] -= other + + self._clear_temp_attr() + + return n + return NotImplemented + @property def _base_data_type(self) -> str: """Type of data (grid or voxels) underlying this neuron.""" From 07f32f7e2838647fa1d074f9569559ab2f5dcd24 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:47:57 +0100 Subject: [PATCH 04/16] MeshNeuron: add soma_pos setter --- navis/core/mesh.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/navis/core/mesh.py b/navis/core/mesh.py index d073b3d4..bd8ea79b 100644 --- a/navis/core/mesh.py +++ b/navis/core/mesh.py @@ -358,6 +358,29 @@ def soma(self): """Not implemented for MeshNeurons - use `.soma_pos`.""" raise AttributeError("MeshNeurons have a soma position (`.soma_pos`), not a soma.") + @property + def soma_pos(self): + """X/Y/Z position of the soma. + + Returns `None` if no soma. + """ + return getattr(self, '_soma_pos', None) + + @soma_pos.setter + def soma_pos(self, value): + """Set soma by position.""" + if value is None: + self.soma = None + return + + try: + value = np.asarray(value).astype(np.float64).reshape(3) + except BaseException: + raise ValueError(f'Unable to convert soma position "{value}" ' + f'to numeric (3, ) numpy array.') + + self._soma_pos = value + @property def type(self) -> str: """Neuron type.""" From 99ce0d6d3008f695c5ece3a05796721de39ea683 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:48:48 +0100 Subject: [PATCH 05/16] ivscc_features: add progress bar --- navis/morpho/ivscc.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/navis/morpho/ivscc.py b/navis/morpho/ivscc.py index 4e2bdc73..44984b6e 100644 --- a/navis/morpho/ivscc.py +++ b/navis/morpho/ivscc.py @@ -249,7 +249,8 @@ def extract_features(self): # Number of stems sprouting from the soma # (i.e. number of nodes with a parent that is the soma) self.record_feature( - "calculate_number_of_stems", (self.neuron.nodes.parent_id == self.soma).sum() + "calculate_number_of_stems", + (self.neuron.nodes.parent_id == self.soma).sum(), ) return self.features @@ -319,11 +320,15 @@ def extract_features(self): def ivscc_features( - x: "core.TreeNeuron", features=None, missing_compartments="ignore", verbose=False + x: "core.TreeNeuron", + features=None, + missing_compartments="ignore", + verbose=False, + progress=True, ) -> Union[float, pd.DataFrame]: """Calculate IVSCC features for neuron(s). - Please see the `IVSCC` tutorial for more details. + Please see the `IVSCC` tutorial for details. Parameters ---------- @@ -354,7 +359,9 @@ def ivscc_features( features = DEFAULT_FEATURES data = {} - for n in x: + for n in config.tqdm( + x, desc="Calculating IVSCC features", disable=not progress or config.pbar_hide + ): data[n.id] = {} for feat in features: try: @@ -363,6 +370,8 @@ def ivscc_features( if missing_compartments == "ignore": continue elif missing_compartments == "skip": + if verbose: + print(f"Skipping neuron {n.id}: {e}") data.pop(n.id) break else: @@ -398,4 +407,9 @@ def _check_compartments(n, compartments): raise ValueError(f"Invalid `compartments`: {compartments}") -DEFAULT_FEATURES = [AxonFeatures, BasalDendriteFeatures, ApicalDendriteFeatures, OverlapFeatures] +DEFAULT_FEATURES = [ + AxonFeatures, + BasalDendriteFeatures, + ApicalDendriteFeatures, + OverlapFeatures, +] From 6a598a7551fa5bfb30c27853af89dfffea0a1057 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:49:22 +0100 Subject: [PATCH 06/16] plotting: allow passing categorical palette for numerical data --- navis/plotting/colors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/navis/plotting/colors.py b/navis/plotting/colors.py index 13a80ae0..5e38bd98 100644 --- a/navis/plotting/colors.py +++ b/navis/plotting/colors.py @@ -315,8 +315,8 @@ def vertex_colors(neurons, by, palette, alpha=1, use_alpha=False, vmin=None, vma # First check if data is numerical or categorical is_num = [utils.is_numeric(a, bool_numeric=False, try_convert=False) for a in values] - # If numerical - if all(is_num): + # If numerical and we weren't given a categorical palette + if all(is_num) and not isinstance(palette, dict): # Get min/max values if not vmin: vmin = [np.nanmin(v) for v in values] @@ -365,8 +365,8 @@ def vertex_colors(neurons, by, palette, alpha=1, use_alpha=False, vmin=None, vma colors.append(c) # We don't want to deal with mixed data - elif any(is_num): - raise ValueError('Data appears to be mixed numeric and non-numeric.') + # elif any(is_num): + # raise ValueError('Data appears to be mixed numeric and non-numeric.') else: # Find unique values unique_v = np.unique([v for l in values for v in np.unique(l)]) From 18707dc5a979053104cdc0acf777d5df747c3e78 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:50:03 +0100 Subject: [PATCH 07/16] plot2d: change default method back to 2d --- navis/plotting/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/navis/plotting/settings.py b/navis/plotting/settings.py index 4a476d77..bdb78b1f 100644 --- a/navis/plotting/settings.py +++ b/navis/plotting/settings.py @@ -127,7 +127,7 @@ class Matplotlib2dSettings(BasePlottingSettings): _name = "matplotlib backend" - method: Literal["2d", "3d", "3d_complex"] = "3d" + method: Literal["2d", "3d", "3d_complex"] = "2d" group_neurons: bool = False autoscale: bool = True orthogonal: bool = True From d92aa2e8b36173edee36613fed65546b19d78a24 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:51:16 +0100 Subject: [PATCH 08/16] resample_skeleton: resample arbitrary numeric and non-numerical columns --- navis/sampling/resampling.py | 189 ++++++++++++++++++++--------------- 1 file changed, 111 insertions(+), 78 deletions(-) diff --git a/navis/sampling/resampling.py b/navis/sampling/resampling.py index b14bfeaa..d0d07b30 100644 --- a/navis/sampling/resampling.py +++ b/navis/sampling/resampling.py @@ -19,7 +19,7 @@ import scipy.spatial import scipy.interpolate -from typing import Union, Optional, List, overload +from typing import Union, Optional, List from typing_extensions import Literal from .. import config, core, utils, graph @@ -30,29 +30,12 @@ __all__ = ['resample_skeleton', 'resample_along_axis'] -@overload -def resample_skeleton(x: 'core.TreeNeuron', - resample_to: int, - inplace: bool = False, - method: str = 'linear', - skip_errors: bool = True - ) -> 'core.TreeNeuron': ... - - -@overload -def resample_skeleton(x: 'core.NeuronList', - resample_to: int, - inplace: bool = False, - method: str = 'linear', - skip_errors: bool = True - ) -> 'core.NeuronList': ... - - @utils.map_neuronlist(desc='Resampling', allow_parallel=True) def resample_skeleton(x: 'core.NeuronObject', resample_to: Union[int, str], inplace: bool = False, method: str = 'linear', + map_columns: Optional[list] = None, skip_errors: bool = True ) -> Optional['core.NeuronObject']: """Resample skeleton(s) to given resolution. @@ -85,6 +68,11 @@ def resample_skeleton(x: 'core.NeuronObject', method : str, optional See `scipy.interpolate.interp1d` for possible options. By default, we're using linear interpolation. + map_columns : list of str, optional + Names of additional columns to carry over to the resampled + neuron. Numerical columns will be interpolated according to + `method`. Non-numerical columns will be interpolated + using nearest neighbour interpolation. inplace : bool, optional If True, will modify original neuron. If False, a resampled copy is returned. @@ -127,14 +115,43 @@ def resample_skeleton(x: 'core.NeuronObject', raise TypeError(f'Unable to resample data of type "{type(x)}"') # Map units (non-str are just passed through) - resample_to = x.map_units(resample_to, on_error='raise') + resample_to = x.map_units(resample_to, on_error="raise") if not inplace: x = x.copy() - # Collect some information for later - locs = dict(zip(x.nodes.node_id.values, x.nodes[['x', 'y', 'z']].values)) - radii = dict(zip(x.nodes.node_id.values, x.nodes.radius.values)) + num_cols = ["x", "y", "z", "radius"] + non_num_cols = [] + + if map_columns: + if isinstance(map_columns, str): + map_columns = [map_columns] + + for col in map_columns: + if col in num_cols or col in non_num_cols: + continue + if col not in x.nodes.columns: + raise ValueError(f'Column "{col}" not found in node table') + if pd.api.types.is_numeric_dtype(x.nodes[col].dtype): + num_cols.append(col) + else: + non_num_cols.append(col) + + # Collect coordinates + locs = dict(zip(x.nodes.node_id.values, x.nodes[["x", "y", "z"]].values)) + + # Collect values for all columns + values = { + col: dict(zip(x.nodes.node_id.values, x.nodes[col].values)) + for col in num_cols + non_num_cols + } + + # For categorical columns, we need to translate them to numerical values + cat2num = {} + num2cat = {} + for col in non_num_cols: + cat2num[col] = {c: i for i, c in enumerate(x.nodes[col].unique())} + num2cat[col] = {i: c for c, i in cat2num[col].items()} new_nodes: List = [] max_tn_id = x.nodes.node_id.max() + 1 @@ -146,7 +163,7 @@ def resample_skeleton(x: 'core.NeuronObject', # Get coordinates coords = np.vstack([locs[n] for n in seg]) # Get radii - rad = [radii[tn] for tn in seg] + # rad = [radii[tn] for tn in seg] # Vecs between subsequently measured points vecs = np.diff(coords.T) @@ -156,83 +173,99 @@ def resample_skeleton(x: 'core.NeuronObject', dist = np.insert(dist, 0, 0) # If path is too short, just keep the first and last node - if dist[-1] < resample_to or (method == 'cubic' and len(seg) <= 3): - new_nodes += [[seg[0], seg[-1], - coords[0][0], coords[0][1], coords[0][2], - radii[seg[0]]]] + if dist[-1] < resample_to or (method == "cubic" and len(seg) <= 3): + new_nodes += [ + [seg[0], seg[-1]] + [values[c][seg[0]] for c in num_cols + non_num_cols] + ] continue # Distances (i.e. resolution) of interpolation n_nodes = np.round(dist[-1] / resample_to) new_dist = np.linspace(dist[0], dist[-1], int(n_nodes)) - try: - sampleX = scipy.interpolate.interp1d(dist, coords[:, 0], - kind=method) - sampleY = scipy.interpolate.interp1d(dist, coords[:, 1], - kind=method) - sampleZ = scipy.interpolate.interp1d(dist, coords[:, 2], - kind=method) - sampleR = scipy.interpolate.interp1d(dist, rad, - kind=method) - except ValueError as e: - if skip_errors: - errors += 1 - new_nodes += x.nodes.loc[x.nodes.node_id.isin(seg[:-1]), - ['node_id', 'parent_id', - 'x', 'y', 'z', - 'radius']].values.tolist() - continue - else: - raise e - - # Sample each dim - xnew = sampleX(new_dist) - ynew = sampleY(new_dist) - znew = sampleZ(new_dist) - rnew = sampleR(new_dist) - - # Generate new coordinates - new_coords = np.array([xnew, ynew, znew]).T + samples = {} + # Interpolate numerical columns + for col in num_cols: + try: + samples[col] = scipy.interpolate.interp1d( + dist, [values[col][n] for n in seg], kind=method + ) + except ValueError as e: + if skip_errors: + errors += 1 + new_nodes += x.nodes.loc[ + x.nodes.node_id.isin(seg[:-1]), + ["node_id", "parent_id"] + num_cols + non_num_cols, + ].values.tolist() + continue + else: + raise e + # Interpolate non-numerical columns + for col in non_num_cols: + try: + samples[col] = scipy.interpolate.interp1d( + dist, [cat2num[col][values[col][n]] for n in seg], kind="nearest" + ) + except ValueError as e: + if skip_errors: + errors += 1 + new_nodes += x.nodes.loc[ + x.nodes.node_id.isin(seg[:-1]), + ["node_id", "parent_id"] + num_cols + non_num_cols, + ].values.tolist() + continue + else: + raise e + + # Sample each column + new_values = {} + for col in num_cols: + new_values[col] = samples[col](new_dist) + for col in non_num_cols: + new_values[col] = [num2cat[col][int(samples[col](d))] for d in new_dist] # Generate new ids (start and end node IDs of this segment are kept) - new_ids = np.concatenate((seg[:1], [max_tn_id + i for i in range(len(new_coords) - 2)], seg[-1:])) + new_ids = np.concatenate( + (seg[:1], [max_tn_id + i for i in range(len(new_dist) - 2)], seg[-1:]) + ) # Increase max index max_tn_id += len(new_ids) # Keep track of new nodes - new_nodes += [[tn, pn, co[0], co[1], co[2], r] - for tn, pn, co, r in zip(new_ids[:-1], - new_ids[1:], - new_coords, - rnew)] + new_nodes += [ + [tn, pn] + [new_values[c][i] for c in num_cols + non_num_cols] + for i, (tn, pn) in enumerate(zip(new_ids[:-1], new_ids[1:])) + ] if errors: - logger.warning(f'{errors} ({errors/i:.0%}) segments skipped due to ' - 'errors') + logger.warning(f"{errors} ({errors/i:.0%}) segments skipped due to " "errors") # Add root node(s) - root = x.nodes.loc[x.nodes.node_id.isin(utils.make_iterable(x.root)), - ['node_id', 'parent_id', 'x', 'y', 'z', 'radius']] + root = x.nodes.loc[ + x.nodes.node_id.isin(utils.make_iterable(x.root)), + ["node_id", "parent_id"] + num_cols + non_num_cols, + ] new_nodes += [list(r) for r in root.values] # Generate new nodes dataframe - new_nodes = pd.DataFrame(data=new_nodes, - columns=['node_id', 'parent_id', - 'x', 'y', 'z', 'radius']) + new_nodes = pd.DataFrame( + data=new_nodes, columns=["node_id", "parent_id"] + num_cols + non_num_cols + ) # Convert columns to appropriate dtypes - dtypes = {k: x.nodes[k].dtype for k in ['node_id', 'parent_id', 'x', 'y', 'z', 'radius']} + dtypes = { + k: x.nodes[k].dtype for k in ["node_id", "parent_id"] + num_cols + non_num_cols + } for cols in new_nodes.columns: - new_nodes = new_nodes.astype(dtypes, errors='ignore') + new_nodes = new_nodes.astype(dtypes, errors="ignore") # Remove duplicate nodes (branch points) new_nodes = new_nodes[~new_nodes.node_id.duplicated()] # Generate KDTree - tree = scipy.spatial.cKDTree(new_nodes[['x', 'y', 'z']].values) + tree = scipy.spatial.cKDTree(new_nodes[["x", "y", "z"]].values) # Map soma onto new nodes if required # Note that if `._soma` is a soma detection function we can't tell # how to deal with it. Ideally the new soma node will @@ -241,10 +274,10 @@ def resample_skeleton(x: 'core.NeuronObject', # than one soma is detected now. Also a "label" column in the node # table would be lost at this point. # We will go for the easy option which is to pin the soma at this point. - nodes = x.nodes.set_index('node_id', inplace=False) - if np.any(getattr(x, 'soma')): + nodes = x.nodes.set_index("node_id", inplace=False) + if np.any(getattr(x, "soma")): soma_nodes = utils.make_iterable(x.soma) - old_pos = nodes.loc[soma_nodes, ['x', 'y', 'z']].values + old_pos = nodes.loc[soma_nodes, ["x", "y", "z"]].values # Get nearest neighbours dist, ix = tree.query(old_pos) @@ -266,13 +299,13 @@ def resample_skeleton(x: 'core.NeuronObject', # Map connectors back if necessary if x.has_connectors: # Get position of old synapse-bearing nodes - old_tn_position = nodes.loc[x.connectors.node_id, ['x', 'y', 'z']].values + old_tn_position = nodes.loc[x.connectors.node_id, ["x", "y", "z"]].values # Get nearest neighbours dist, ix = tree.query(old_tn_position) # Map back onto neuron - x.connectors['node_id'] = new_nodes.node_id.values[ix] + x.connectors["node_id"] = new_nodes.node_id.values[ix] # Map tags back if necessary # Expects `tags` to be a dictionary {'tag': [node_id1, node_id2, ...]} @@ -281,7 +314,7 @@ def resample_skeleton(x: 'core.NeuronObject', nodes_to_remap = list({n for l in x.tags.values() for n in l}) # Get position of old tag-bearing nodes - old_tn_position = nodes.loc[nodes_to_remap, ['x', 'y', 'z']].values + old_tn_position = nodes.loc[nodes_to_remap, ["x", "y", "z"]].values # Get nearest neighbours dist, ix = tree.query(old_tn_position) From 8989525bee8f92ed75e3bc8319828172e1a84aec Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:52:50 +0100 Subject: [PATCH 09/16] plotting: use linewidth parameter for skeleton to mesh conversion --- navis/plotting/dd.py | 6 +++++- navis/plotting/k3d/k3d_objects.py | 28 ++++++++++++++++++++-------- navis/plotting/plotly/graph_objs.py | 10 ++++++++-- navis/plotting/vispy/visuals.py | 10 ++++++++-- 4 files changed, 41 insertions(+), 13 deletions(-) diff --git a/navis/plotting/dd.py b/navis/plotting/dd.py index 8630565b..f28ae14e 100644 --- a/navis/plotting/dd.py +++ b/navis/plotting/dd.py @@ -560,7 +560,11 @@ def plot2d( ) fig._radius_warned = True - _neuron = conversion.tree2meshneuron(neuron, warn_missing_radii=False) + _neuron = conversion.tree2meshneuron( + neuron, + warn_missing_radii=False, + radius_scale_factor=settings.get("linewidth", 1), + ) _neuron.connectors = neuron.connectors neuron = _neuron diff --git a/navis/plotting/k3d/k3d_objects.py b/navis/plotting/k3d/k3d_objects.py index 28c0693c..76a7a63c 100644 --- a/navis/plotting/k3d/k3d_objects.py +++ b/navis/plotting/k3d/k3d_objects.py @@ -109,7 +109,9 @@ def neuron2k3d(x, colormap, settings): if isinstance(neuron, core.TreeNeuron) and settings.radius: # Warn once if more than 5% of nodes have missing radii if not _radius_warned: - if ((neuron.nodes.radius.fillna(0).values <= 0).sum() / neuron.n_nodes) > 0.05: + if ( + (neuron.nodes.radius.fillna(0).values <= 0).sum() / neuron.n_nodes + ) > 0.05: logger.warning( "Some skeleton nodes have radius <= 0. This may lead to " "rendering artifacts. Set `radius=False` to plot skeletons " @@ -117,7 +119,11 @@ def neuron2k3d(x, colormap, settings): ) _radius_warned = True - _neuron = conversion.tree2meshneuron(neuron, warn_missing_radii=False) + _neuron = conversion.tree2meshneuron( + neuron, + warn_missing_radii=False, + radius_scale_factor=settings.get("linewidth", 1), + ) _neuron.connectors = neuron.connectors neuron = _neuron @@ -168,17 +174,21 @@ def neuron2k3d(x, colormap, settings): # Add connectors if (settings.connectors or settings.connectors_only) and neuron.has_connectors: if isinstance(settings.connectors, (list, np.ndarray, tuple)): - connectors = neuron.connectors[neuron.connectors.type.isin(settings.connectors)] - elif settings.connectors == 'pre': + connectors = neuron.connectors[ + neuron.connectors.type.isin(settings.connectors) + ] + elif settings.connectors == "pre": connectors = neuron.presynapses - elif settings.connectors == 'post': + elif settings.connectors == "post": connectors = neuron.postsynapses elif isinstance(settings.connectors, str): - connectors = neuron.connectors[neuron.connectors.type == settings.connectors] + connectors = neuron.connectors[ + neuron.connectors.type == settings.connectors + ] else: connectors = neuron.connectors - for j, this_cn in connectors.groupby('type'): + for j, this_cn in connectors.groupby("type"): if isinstance(settings.cn_colors, dict): c = settings.cn_colors.get( j, cn_lay.get(j, {"color": (10, 10, 10)})["color"] @@ -204,7 +214,9 @@ def neuron2k3d(x, colormap, settings): positions=this_cn[["x", "y", "z"]].values, name=cn_label, shader="flat", - point_size=settings.cn_size if settings.cn_size else cn_lay['size'] * 50, + point_size=settings.cn_size + if settings.cn_size + else cn_lay["size"] * 50, color=c, ) ) diff --git a/navis/plotting/plotly/graph_objs.py b/navis/plotting/plotly/graph_objs.py index 9bbbb689..0a371f9c 100644 --- a/navis/plotting/plotly/graph_objs.py +++ b/navis/plotting/plotly/graph_objs.py @@ -138,7 +138,9 @@ def neuron2plotly(x, colormap, settings): if isinstance(neuron, core.TreeNeuron) and settings.radius: # Warn once if more than 5% of nodes have missing radii if not _radius_warned: - if ((neuron.nodes.radius.fillna(0).values <= 0).sum() / neuron.n_nodes) > 0.05: + if ( + (neuron.nodes.radius.fillna(0).values <= 0).sum() / neuron.n_nodes + ) > 0.05: logger.warning( "Some skeleton nodes have radius <= 0. This may lead to " "rendering artifacts. Set `radius=False` to plot skeletons " @@ -146,7 +148,11 @@ def neuron2plotly(x, colormap, settings): ) _radius_warned = True - _neuron = conversion.tree2meshneuron(neuron, warn_missing_radii=False) + _neuron = conversion.tree2meshneuron( + neuron, + warn_missing_radii=False, + radius_scale_factor=settings.get("linewidth", 1), + ) _neuron.connectors = neuron.connectors neuron = _neuron diff --git a/navis/plotting/vispy/visuals.py b/navis/plotting/vispy/visuals.py index 92bcf152..9b756091 100644 --- a/navis/plotting/vispy/visuals.py +++ b/navis/plotting/vispy/visuals.py @@ -266,7 +266,9 @@ def neuron2vispy(x, settings): if isinstance(neuron, core.TreeNeuron) and settings.radius: # Warn once if more than 5% of nodes have missing radii if not _radius_warned: - if ((neuron.nodes.radius.fillna(0).values <= 0).sum() / neuron.n_nodes) > 0.05: + if ( + (neuron.nodes.radius.fillna(0).values <= 0).sum() / neuron.n_nodes + ) > 0.05: logger.warning( "Some skeleton nodes have radius <= 0. This may lead to " "rendering artifacts. Set `radius=False` to plot skeletons " @@ -274,7 +276,11 @@ def neuron2vispy(x, settings): ) _radius_warned = True - _neuron = conversion.tree2meshneuron(neuron, warn_missing_radii=False) + _neuron = conversion.tree2meshneuron( + neuron, + warn_missing_radii=False, + radius_scale_factor=settings.get("linewidth", 1), + ) _neuron.connectors = neuron.connectors neuron = _neuron From 820cc635fb3449f65c5b9aa256b7ae7acf22ce79 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:53:21 +0100 Subject: [PATCH 10/16] plot2d: fix error when passing a non-matplotlib palette --- navis/plotting/dd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/navis/plotting/dd.py b/navis/plotting/dd.py index f28ae14e..058eb6ff 100644 --- a/navis/plotting/dd.py +++ b/navis/plotting/dd.py @@ -1124,7 +1124,7 @@ def _plot_skeleton(neuron, color, ax, settings): ) ax.add_line(this_line) else: - if settings.palette: + if isinstance(settings.palette, str): cmap = plt.get_cmap(settings.palette) else: cmap = DEPTH_CMAP From dc972774e2415c537df1f187b4dc599f16dd62d0 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:54:00 +0100 Subject: [PATCH 11/16] plot2d: fix issue when color_by is name of both neuron AND node property --- navis/plotting/dd.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/navis/plotting/dd.py b/navis/plotting/dd.py index 058eb6ff..245a4c9f 100644 --- a/navis/plotting/dd.py +++ b/navis/plotting/dd.py @@ -370,8 +370,9 @@ def plot2d( # Parse objects (neurons, volumes, points, _) = utils.parse_objects(x) - # Color_by can be a per-node/vertex color, or a per-neuron color - # such as property of the neuron + # Here we check whether `color_by` is a neuron property which we + # want to translate into a single color per neuron, or a + # per node/vertex property which we will parse late color_neurons_by = None if settings.color_by is not None and neurons: if not settings.palette: @@ -380,9 +381,18 @@ def plot2d( "when using `color_by` argument." ) - # Check if this is a neuron property + # Check if this may be a neuron property if isinstance(settings.color_by, str): - if hasattr(neurons[0], settings.color_by): + # Check if this could be a neuron property + has_prop = hasattr(neurons[0], settings.color_by) + + # For TreeNeurons, we also check if it is a node property + # If so, prioritize this. + if isinstance(neurons[0], core.TreeNeuron): + if settings.color_by in neurons[0].nodes.columns: + has_prop = False + + if has_prop: # If it is, use it to color neurons color_neurons_by = [ getattr(neuron, settings.color_by) for neuron in neurons @@ -393,7 +403,7 @@ def plot2d( color_neurons_by = settings.color_by settings.color_by = None - # Generate the colormaps + # Generate the per-neuron colors (neuron_cmap, volumes_cmap) = prepare_colormap( settings.color, neurons=neurons, From a89667301c8fecc469aa1f0cdb00ea858c2e6f85 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:54:19 +0100 Subject: [PATCH 12/16] plot2d: allow passing `soma` as a dict to customize the soma appearance --- navis/plotting/dd.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/navis/plotting/dd.py b/navis/plotting/dd.py index 245a4c9f..14cd6e36 100644 --- a/navis/plotting/dd.py +++ b/navis/plotting/dd.py @@ -84,11 +84,13 @@ def plot2d( Object parameters ----------------- - soma : bool, default=True + soma : bool | dict, default=True Plot soma if one exists. Size of the soma is determined by the neuron's `.soma_radius` property which defaults - to the "radius" column for `TreeNeurons`. + to the "radius" column for `TreeNeurons`. You can also + pass `soma` as a dictionary to customize the appearance + of the soma - for example `soma={"color": "red", "lw": 2, "ec": 1}`. radius : "auto" (default) | bool @@ -561,8 +563,11 @@ def plot2d( if isinstance(neuron, core.TreeNeuron) and settings.radius: # Warn once if more than 5% of nodes have missing radii - if not getattr(fig, '_radius_warned', False): - if ((neuron.nodes.radius.fillna(0).values <= 0).sum() / neuron.n_nodes) > 0.05: + if not getattr(fig, "_radius_warned", False): + if ( + (neuron.nodes.radius.fillna(0).values <= 0).sum() + / neuron.n_nodes + ) > 0.05: logger.warning( "Some skeleton nodes have radius <= 0. This may lead to " "rendering artifacts. Set `radius=False` to plot skeletons " @@ -1192,9 +1197,7 @@ def _plot_skeleton(neuron, color, ax, settings): d = [n.x, n.y, n.z][_get_depth_axis(settings.view)] soma_color = DEPTH_CMAP(settings.norm(d)) - sx, sy = _parse_view2d(np.array([[n.x, n.y, n.z]]), settings.view) - c = mpatches.Circle( - (sx[0], sy[0]), + soma_defaults = dict( radius=r, fill=True, fc=soma_color, @@ -1202,6 +1205,11 @@ def _plot_skeleton(neuron, color, ax, settings): zorder=4, edgecolor="none", ) + if isinstance(settings.soma, dict): + soma_defaults.update(settings.soma) + + sx, sy = _parse_view2d(np.array([[n.x, n.y, n.z]]), settings.view) + c = mpatches.Circle((sx[0], sy[0]), **soma_defaults) ax.add_patch(c) return None, None @@ -1292,14 +1300,17 @@ def _plot_skeleton(neuron, color, ax, settings): x = r * np.outer(np.cos(u), np.sin(v)) + n.x y = r * np.outer(np.sin(u), np.sin(v)) + n.y z = r * np.outer(np.ones(np.size(u)), np.cos(v)) + n.z - surf = ax.plot_surface( - x, - y, - z, + + soma_defaults = dict( color=soma_color, shade=settings.mesh_shade, rasterized=settings.rasterize, ) + if isinstance(settings.soma, dict): + soma_defaults.update(settings.soma) + + surf = ax.plot_surface(x, y, z, **soma_defaults) + if settings.group_neurons: surf.set_gid(neuron.id) From 644ead895613d4e624253f81ca39061090945c16 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:55:02 +0100 Subject: [PATCH 13/16] cable_length: new "mask" parameter --- navis/morpho/mmetrics.py | 46 ++++++++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/navis/morpho/mmetrics.py b/navis/morpho/mmetrics.py index d2475117..dcd8c415 100644 --- a/navis/morpho/mmetrics.py +++ b/navis/morpho/mmetrics.py @@ -165,14 +165,14 @@ def strahler_index( raise ValueError(f'`method` must be "standard" or "greedy", got "{method}"') if utils.fastcore: - x.nodes['strahler_index'] = utils.fastcore.strahler_index( + x.nodes["strahler_index"] = utils.fastcore.strahler_index( x.nodes.node_id.values, x.nodes.parent_id.values, method=method, to_ignore=to_ignore, min_twig_size=min_twig_size, ).astype(np.int16) - x.nodes['strahler_index'] = x.nodes.strahler_index.fillna(1) + x.nodes["strahler_index"] = x.nodes.strahler_index.fillna(1) return x # Find branch, root and end nodes @@ -1688,13 +1688,17 @@ def betweeness_centrality( @utils.map_neuronlist(desc="Cable length", allow_parallel=True) @utils.meshneuron_skeleton(method="pass_through") -def cable_length(x) -> Union[int, float]: +def cable_length(x, mask=None) -> Union[int, float]: """Calculate cable length. Parameters ---------- x : TreeNeuron | MeshNeuron | NeuronList Neuron(s) for which to calculate cable length. + mask : None | boolean array | callable + If provided, will only consider nodes where + `mask` is True. Callable must accept a DataFrame of nodes + and return a boolean array of the same length. Returns ------- @@ -1704,25 +1708,49 @@ def cable_length(x) -> Union[int, float]: """ utils.eval_param(x, name="x", allowed_types=(core.TreeNeuron,)) + nodes = x.nodes + if mask is not None: + if callable(mask): + mask = mask(x.nodes) + + if isinstance(mask, np.ndarray): + if len(mask) != len(x.nodes): + raise ValueError( + f"Length of mask ({len(mask)}) must match number of nodes " + f"({len(x.nodes)})." + ) + else: + raise ValueError( + f"Mask must be callable or boolean array, got {type(mask)}" + ) + + nodes = x.nodes.loc[mask, ['node_id','parent_id', 'x', 'y', 'z']].copy() + + # Set the parent IDs to -1 for nodes that are not in the mask + nodes.loc[~nodes.parent_id.isin(nodes.node_id), "parent_id"] = -1 + + if not len(nodes): + return 0 + # See if we can use fastcore if not utils.fastcore: # The by far fastest way to get the cable length is to work on the node table # Using the igraph representation is about the same speed... if it is already calculated! # However, one problem with the graph representation is that with large neuronlists # it adds a lot to the memory footprint. - not_root = (x.nodes.parent_id >= 0).values - xyz = x.nodes[["x", "y", "z"]].values[not_root] + not_root = (nodes.parent_id >= 0).values + xyz = nodes[["x", "y", "z"]].values[not_root] xyz_parent = ( x.nodes.set_index("node_id") - .loc[x.nodes.parent_id.values[not_root], ["x", "y", "z"]] + .loc[nodes.parent_id.values[not_root], ["x", "y", "z"]] .values ) cable_length = np.sum(np.linalg.norm(xyz - xyz_parent, axis=1)) else: cable_length = utils.fastcore.dag.parent_dist( - x.nodes.node_id.values, - x.nodes.parent_id.values, - x.nodes[["x", "y", "z"]].values, + nodes.node_id.values, + nodes.parent_id.values, + nodes[["x", "y", "z"]].values, root_dist=0, ).sum() From 519f8ab1681db61558e18ce3e4c89da8a9cc25ab Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 15:56:50 +0100 Subject: [PATCH 14/16] prune_twigs: new "mask" parameter --- navis/morpho/manipulation.py | 952 ++++++++++++++++++++--------------- 1 file changed, 534 insertions(+), 418 deletions(-) diff --git a/navis/morpho/manipulation.py b/navis/morpho/manipulation.py index 02586b66..0a19cd93 100644 --- a/navis/morpho/manipulation.py +++ b/navis/morpho/manipulation.py @@ -12,8 +12,8 @@ # GNU General Public License for more details. -""" This module contains functions to analyse and manipulate neuron morphology. -""" +"""This module contains functions to analyse and manipulate neuron morphology.""" + import warnings import pandas as pd @@ -24,7 +24,7 @@ from collections import namedtuple from itertools import combinations from scipy.ndimage import gaussian_filter -from typing import Union, Optional, Sequence, List, Set +from typing import Union, Optional, Sequence, List, Set, Callable from typing_extensions import Literal try: @@ -38,26 +38,40 @@ # Set up logging logger = config.get_logger(__name__) -__all__ = sorted(['prune_by_strahler', 'stitch_skeletons', 'split_axon_dendrite', - 'average_skeletons', 'despike_skeleton', 'guess_radius', - 'smooth_skeleton', 'smooth_voxels', - 'heal_skeleton', 'cell_body_fiber', - 'break_fragments', 'prune_twigs', 'prune_at_depth', - 'drop_fluff', 'combine_neurons']) - -NeuronObject = Union['core.NeuronList', 'core.TreeNeuron'] - - -@utils.map_neuronlist(desc='Pruning', allow_parallel=True) -@utils.meshneuron_skeleton(method='subset') -def cell_body_fiber(x: NeuronObject, - method: Union[Literal['longest_neurite'], - Literal['betweenness']] = 'betweenness', - reroot_soma: bool = True, - heal: bool = True, - threshold: float = 0.95, - inverse: bool = False, - inplace: bool = False): +__all__ = sorted( + [ + "prune_by_strahler", + "stitch_skeletons", + "split_axon_dendrite", + "average_skeletons", + "despike_skeleton", + "guess_radius", + "smooth_skeleton", + "smooth_voxels", + "heal_skeleton", + "cell_body_fiber", + "break_fragments", + "prune_twigs", + "prune_at_depth", + "drop_fluff", + "combine_neurons", + ] +) + +NeuronObject = Union["core.NeuronList", "core.TreeNeuron"] + + +@utils.map_neuronlist(desc="Pruning", allow_parallel=True) +@utils.meshneuron_skeleton(method="subset") +def cell_body_fiber( + x: NeuronObject, + method: Union[Literal["longest_neurite"], Literal["betweenness"]] = "betweenness", + reroot_soma: bool = True, + heal: bool = True, + threshold: float = 0.95, + inverse: bool = False, + inplace: bool = False, +): """Prune neuron to its cell body fiber. Here, "cell body fiber" (CBF) refers to the tract connecting the soma to the @@ -113,29 +127,31 @@ def cell_body_fiber(x: NeuronObject, under the hood for `method='betweeness'`. """ - utils.eval_param(method, 'method', - allowed_values=('longest_neurite', 'betweenness')) + utils.eval_param( + method, "method", allowed_values=("longest_neurite", "betweenness") + ) # The decorator makes sure that at this point we have single neurons if not isinstance(x, core.TreeNeuron): - raise TypeError(f'Expected TreeNeuron(s), got {type(x)}') + raise TypeError(f"Expected TreeNeuron(s), got {type(x)}") if not inplace: x = x.copy() if x.n_trees > 1 and heal: - _ = heal_skeleton(x, method='LEAFS', inplace=True) + _ = heal_skeleton(x, method="LEAFS", inplace=True) # If no branches, just return the neuron - if 'branch' not in x.nodes.type.values: + if "branch" not in x.nodes.type.values: return x if reroot_soma and not isinstance(x.soma, type(None)): x.reroot(x.soma, inplace=True) # Find main branch point - cut = graph.find_main_branchpoint(x, method=method, threshold=threshold, - reroot_soma=False) + cut = graph.find_main_branchpoint( + x, method=method, threshold=threshold, reroot_soma=False + ) # Find the path to root (and account for multiple roots) for r in x.root: @@ -157,14 +173,16 @@ def cell_body_fiber(x: NeuronObject, return x -@utils.map_neuronlist(desc='Pruning', allow_parallel=True) -@utils.meshneuron_skeleton(method='subset') -def prune_by_strahler(x: NeuronObject, - to_prune: Union[int, List[int], range, slice], - inplace: bool = False, - reroot_soma: bool = True, - force_strahler_update: bool = False, - relocate_connectors: bool = False) -> NeuronObject: +@utils.map_neuronlist(desc="Pruning", allow_parallel=True) +@utils.meshneuron_skeleton(method="subset") +def prune_by_strahler( + x: NeuronObject, + to_prune: Union[int, List[int], range, slice], + inplace: bool = False, + reroot_soma: bool = True, + force_strahler_update: bool = False, + relocate_connectors: bool = False, +) -> NeuronObject: """Prune neuron based on [Strahler order](https://en.wikipedia.org/wiki/Strahler_number). Parameters @@ -209,7 +227,7 @@ def prune_by_strahler(x: NeuronObject, """ # The decorator makes sure that at this point we have single neurons if not isinstance(x, core.TreeNeuron): - raise TypeError(f'Expected TreeNeuron(s), got {type(x)}') + raise TypeError(f"Expected TreeNeuron(s), got {type(x)}") # Make a copy if necessary before making any changes neuron = x @@ -219,7 +237,7 @@ def prune_by_strahler(x: NeuronObject, if reroot_soma and not isinstance(neuron.soma, type(None)): neuron.reroot(neuron.soma, inplace=True) - if 'strahler_index' not in neuron.nodes or force_strahler_update: + if "strahler_index" not in neuron.nodes or force_strahler_update: mmetrics.strahler_index(neuron) # Prepare indices @@ -228,8 +246,10 @@ def prune_by_strahler(x: NeuronObject, if isinstance(to_prune, int): if to_prune < 1: - raise ValueError('SI to prune must be positive. Please see docs' - 'for additional options.') + raise ValueError( + "SI to prune must be positive. Please see docs" + "for additional options." + ) to_prune = [to_prune] elif isinstance(to_prune, range): to_prune = list(to_prune) @@ -239,26 +259,31 @@ def prune_by_strahler(x: NeuronObject, # Prepare parent dict if needed later if relocate_connectors: - parent_dict = { - tn.node_id: tn.parent_id for tn in neuron.nodes.itertuples()} + parent_dict = {tn.node_id: tn.parent_id for tn in neuron.nodes.itertuples()} # Avoid setting the nodes as this potentiall triggers a regeneration # of the graph which in turn will raise an error because some nodes might # still have parents that don't exist anymore - neuron._nodes = neuron._nodes[~neuron._nodes.strahler_index.isin(to_prune)].reset_index(drop=True, inplace=False) + neuron._nodes = neuron._nodes[ + ~neuron._nodes.strahler_index.isin(to_prune) + ].reset_index(drop=True, inplace=False) if neuron.has_connectors: if not relocate_connectors: - neuron._connectors = neuron._connectors[neuron._connectors.node_id.isin(neuron._nodes.node_id.values)].reset_index(drop=True, inplace=False) + neuron._connectors = neuron._connectors[ + neuron._connectors.node_id.isin(neuron._nodes.node_id.values) + ].reset_index(drop=True, inplace=False) else: remaining_tns = set(neuron._nodes.node_id.values) - for cn in neuron._connectors[~neuron.connectors.node_id.isin(neuron._nodes.node_id.values)].itertuples(): + for cn in neuron._connectors[ + ~neuron.connectors.node_id.isin(neuron._nodes.node_id.values) + ].itertuples(): this_tn = parent_dict[cn.node_id] while True: if this_tn in remaining_tns: break this_tn = parent_dict[this_tn] - neuron._connectors.loc[cn.Index, 'node_id'] = this_tn + neuron._connectors.loc[cn.Index, "node_id"] = this_tn # Reset indices of node and connector tables (important for igraph!) neuron._nodes.reset_index(inplace=True, drop=True) @@ -268,7 +293,9 @@ def prune_by_strahler(x: NeuronObject, # Theoretically we can end up with disconnected pieces, i.e. with more # than 1 root node -> we have to fix the nodes that lost their parents - neuron._nodes.loc[~neuron._nodes.parent_id.isin(neuron._nodes.node_id.values), 'parent_id'] = -1 + neuron._nodes.loc[ + ~neuron._nodes.parent_id.isin(neuron._nodes.node_id.values), "parent_id" + ] = -1 # Remove temporary attributes neuron._clear_temp_attr() @@ -276,14 +303,16 @@ def prune_by_strahler(x: NeuronObject, return neuron -@utils.map_neuronlist(desc='Pruning', allow_parallel=True) -@utils.meshneuron_skeleton(method='subset') -def prune_twigs(x: NeuronObject, - size: Union[float, str], - exact: bool = False, - inplace: bool = False, - recursive: Union[int, bool, float] = False - ) -> NeuronObject: +@utils.map_neuronlist(desc="Pruning", allow_parallel=True) +@utils.meshneuron_skeleton(method="subset") +def prune_twigs( + x: NeuronObject, + size: Union[float, str], + exact: bool = False, + mask: Optional[Union[Sequence[int], Callable]] = None, + inplace: bool = False, + recursive: Union[int, bool, float] = False, +) -> NeuronObject: """Prune terminal twigs under a given size. By default this function will simply drop all terminal twigs shorter than @@ -301,6 +330,10 @@ def prune_twigs(x: NeuronObject, units, e.g. '5 microns'. exact: bool See notes above. + mask : iterable | callable, optional + Either a boolean mask, a list of node IDs or a callable taking + a neuron as input and returning one of the former. If provided, + only nodes that are in the mask will be considered for pruning. inplace : bool, optional If False, pruning is performed on copy of original neuron which is then returned. @@ -359,39 +392,57 @@ def prune_twigs(x: NeuronObject, """ # The decorator makes sure that at this point we have single neurons if not isinstance(x, core.TreeNeuron): - raise TypeError(f'Expected TreeNeuron(s), got {type(x)}') + raise TypeError(f"Expected TreeNeuron(s), got {type(x)}") # Convert to neuron units - numbers will be passed through - size = x.map_units(size, on_error='raise') + size = x.map_units(size, on_error="raise") if not exact: - return _prune_twigs_simple(x, - size=size, - inplace=inplace, - recursive=recursive) + return _prune_twigs_simple( + x, size=size, inplace=inplace, recursive=recursive, mask=mask + ) else: - return _prune_twigs_precise(x, - size=size, - inplace=inplace) + return _prune_twigs_precise(x, size=size, mask=mask, inplace=inplace) -def _prune_twigs_simple(neuron: 'core.TreeNeuron', - size: float, - inplace: bool = False, - recursive: Union[int, bool, float] = False - ) -> Optional[NeuronObject]: +def _prune_twigs_simple( + neuron: "core.TreeNeuron", + size: float, + inplace: bool = False, + mask: Optional[Union[Sequence[int], Callable]] = None, + recursive: Union[int, bool, float] = False, +) -> Optional[NeuronObject]: """Prune twigs using simple method.""" if not isinstance(neuron, core.TreeNeuron): - raise TypeError(f'Expected Neuron/List, got {type(neuron)}') + raise TypeError(f"Expected Neuron/List, got {type(neuron)}") # If people set recursive=True, assume that they mean float("inf") if isinstance(recursive, bool) and recursive: - recursive = float('inf') + recursive = float("inf") # Make a copy if necessary before making any changes if not inplace: neuron = neuron.copy() + if callable(mask): + mask = mask(neuron) + + if mask is not None: + mask = np.asarray(mask) + + if mask.dtype == bool: + if len(mask) != neuron.n_nodes: + raise ValueError("Mask length must match number of nodes") + mask_nodes = neuron.nodes.node_id.values[mask] + elif mask.dtype in (int, np.int32, np.int64): + mask_nodes = mask + else: + raise TypeError( + f"Mask must be boolean or list of node IDs, got {mask.dtype}" + ) + else: + mask_nodes = None + if utils.fastcore: nodes_to_keep = utils.fastcore.prune_twigs( neuron.nodes.node_id.values, @@ -400,21 +451,30 @@ def _prune_twigs_simple(neuron: 'core.TreeNeuron', weights=utils.fastcore.dag.parent_dist( neuron.nodes.node_id.values, neuron.nodes.parent_id.values, - neuron.nodes[['x', 'y', 'z']].values, - ) + neuron.nodes[["x", "y", "z"]].values, + ), ) + # If mask is given, check if we have to re-add any nodes + # This is a bit cumbersome at the moment - we should add a + # mask feature to the fastcore function + if mask_nodes is not None: + for seg in graph._break_segments(neuron): + # If this segment would be dropped and the first node is not in the mask + # we have to keep the whole segment + if seg[0] not in nodes_to_keep and seg[0] not in mask_nodes: + nodes_to_keep = np.append(nodes_to_keep, seg[1:]) + if len(nodes_to_keep) < neuron.n_nodes: - subset.subset_neuron(neuron, - nodes_to_keep, - inplace=True) + subset.subset_neuron(neuron, nodes_to_keep, inplace=True) if recursive: - recursive -= 1 - prune_twigs(neuron, size=size, inplace=True, recursive=recursive) + prune_twigs( + neuron, size=size, inplace=True, recursive=recursive - 1, mask=mask_nodes + ) else: # Find terminal nodes - leafs = neuron.nodes[neuron.nodes.type == 'end'].node_id.values + leafs = neuron.nodes[neuron.nodes.type == "end"].node_id.values # Find terminal segments segs = graph._break_segments(neuron) @@ -425,35 +485,42 @@ def _prune_twigs_simple(neuron: 'core.TreeNeuron', # Find out which to delete segs_to_delete = segs[seg_lengths <= size] + + # If mask is given, only consider nodes in mask + if mask_nodes is not None: + segs_to_delete = [s for s in segs_to_delete if s[0] in mask_nodes] + if len(segs_to_delete): # Unravel the into list of node IDs -> skip the last parent nodes_to_delete = [n for s in segs_to_delete for n in s[:-1]] # Subset neuron - nodes_to_keep = neuron.nodes[~neuron.nodes.node_id.isin(nodes_to_delete)].node_id.values - subset.subset_neuron(neuron, - nodes_to_keep, - inplace=True) + nodes_to_keep = neuron.nodes[ + ~neuron.nodes.node_id.isin(nodes_to_delete) + ].node_id.values + subset.subset_neuron(neuron, nodes_to_keep, inplace=True) # Go recursive if recursive: - recursive -= 1 - prune_twigs(neuron, size=size, inplace=True, recursive=recursive) + prune_twigs( + neuron, size=size, inplace=True, recursive=recursive - 1, mask=mask_nodes + ) return neuron -def _prune_twigs_precise(neuron: 'core.TreeNeuron', - size: float, - inplace: bool = False, - recursive: Union[int, bool, float] = False - ) -> Optional[NeuronObject]: +def _prune_twigs_precise( + neuron: "core.TreeNeuron", + size: float, + inplace: bool = False, + recursive: Union[int, bool, float] = False, +) -> Optional[NeuronObject]: """Prune twigs using precise method.""" if not isinstance(neuron, core.TreeNeuron): - raise TypeError(f'Expected Neuron/List, got {type(neuron)}') + raise TypeError(f"Expected Neuron/List, got {type(neuron)}") if size <= 0: - raise ValueError('`length` must be > 0') + raise ValueError("`length` must be > 0") # Make a copy if necessary before making any changes if not inplace: @@ -464,26 +531,28 @@ def _prune_twigs_precise(neuron: 'core.TreeNeuron', # Find all nodes that could possibly be within distance to a leaf tree = graph.neuron2KDTree(neuron) - res = tree.query_ball_point(neuron.leafs[['x', 'y', 'z']].values, - r=size) + res = tree.query_ball_point(neuron.leafs[["x", "y", "z"]].values, r=size) candidates = neuron.nodes.node_id.values[np.unique(np.concatenate(res))] # For each node in neuron find out which leafs are directly distal to it # `distal` is a matrix with all nodes in columns and leafs in rows distal = graph.distal_to(neuron, a=leafs, b=candidates) # Turn matrix into dictionary {'node': [leafs, distal, to, it]} - melted = distal.reset_index(drop=False).melt(id_vars='index') + melted = distal.reset_index(drop=False).melt(id_vars="index") melted = melted[melted.value] - melted.groupby('variable')['index'].apply(list) + melted.groupby("variable")["index"].apply(list) # `distal` is now a dictionary for {'node_id': [leaf1, leaf2, ..], ..} - distal = melted.groupby('variable')['index'].apply(list).to_dict() + distal = melted.groupby("variable")["index"].apply(list).to_dict() # For each node find the distance to any leaf - note we are using `length` # as cutoff here # `path_len` is a dict mapping {nodeA: {nodeB: length, ...}, ...} # if nodeB is not in dictionary, it's not within reach - path_len = dict(nx.all_pairs_dijkstra_path_length(neuron.graph.reverse(), - cutoff=size, weight='weight')) + path_len = dict( + nx.all_pairs_dijkstra_path_length( + neuron.graph.reverse(), cutoff=size, weight="weight" + ) + ) # For each leaf in `distal` check if it's within length not_in_length = {k: set(v) - set(path_len[k]) for k, v in distal.items()} @@ -491,18 +560,17 @@ def _prune_twigs_precise(neuron: 'core.TreeNeuron', # For a node to be deleted its PARENT has to be within # `length` to ALL edges that are distal do it in_range = {k for k, v in not_in_length.items() if not any(v)} - nodes_to_keep = neuron.nodes.loc[~neuron.nodes.parent_id.isin(in_range), - 'node_id'].values + nodes_to_keep = neuron.nodes.loc[ + ~neuron.nodes.parent_id.isin(in_range), "node_id" + ].values if len(nodes_to_keep) < neuron.n_nodes: # Subset neuron - subset.subset_neuron(neuron, - nodes_to_keep, - inplace=True) + subset.subset_neuron(neuron, nodes_to_keep, inplace=True) # For each of the new leafs check their shortest distance to the # original leafs to get the remainder - is_new_leaf = (neuron.nodes.type == 'end').values + is_new_leaf = (neuron.nodes.type == "end").values new_leafs = neuron.nodes[is_new_leaf].node_id.values max_len = [max([path_len[l1][l2] for l2 in distal[l1]]) for l1 in new_leafs] @@ -511,10 +579,10 @@ def _prune_twigs_precise(neuron: 'core.TreeNeuron', len_to_prune = size - np.array(max_len) # Get vectors from leafs to their parents - nodes = neuron.nodes.set_index('node_id') - parents = nodes.loc[new_leafs, 'parent_id'].values - loc1 = neuron.leafs[['x', 'y', 'z']].values - loc2 = nodes.loc[parents, ['x', 'y', 'z']].values + nodes = neuron.nodes.set_index("node_id") + parents = nodes.loc[new_leafs, "parent_id"].values + loc1 = neuron.leafs[["x", "y", "z"]].values + loc2 = nodes.loc[parents, ["x", "y", "z"]].values vec = loc1 - loc2 vec_len = np.linalg.norm(vec, axis=1) vec_norm = vec / vec_len.reshape(-1, 1) @@ -527,42 +595,43 @@ def _prune_twigs_precise(neuron: 'core.TreeNeuron', # will be deleted anyway if not all(to_remove): new_loc = loc1 - vec_norm * len_to_prune.reshape(-1, 1) - neuron.nodes.loc[is_new_leaf, ['x', 'y', 'z']] = new_loc.astype( + neuron.nodes.loc[is_new_leaf, ["x", "y", "z"]] = new_loc.astype( neuron.nodes.x.dtype, copy=False ) if any(to_remove): leafs_to_remove = new_leafs[to_remove] - nodes_to_keep = neuron.nodes.loc[~neuron.nodes.node_id.isin(leafs_to_remove), - 'node_id'].values + nodes_to_keep = neuron.nodes.loc[ + ~neuron.nodes.node_id.isin(leafs_to_remove), "node_id" + ].values # Subset neuron - subset.subset_neuron(neuron, - nodes_to_keep, - inplace=True) + subset.subset_neuron(neuron, nodes_to_keep, inplace=True) return neuron -@utils.map_neuronlist(desc='Splitting', allow_parallel=True) -@utils.meshneuron_skeleton(method='split', - include_connectors=True, - copy_properties=['color', 'compartment'], - disallowed_kwargs={'label_only': True}, - heal=True) -def split_axon_dendrite(x: NeuronObject, - metric: Union[Literal['synapse_flow_centrality'], - Literal['flow_centrality'], - Literal['bending_flow'], - Literal['segregation_index']] = 'synapse_flow_centrality', - flow_thresh: float = .9, - split: Union[Literal['prepost'], - Literal['distance']] = 'prepost', - cellbodyfiber: Union[Literal['soma'], - Literal['root'], - bool] = False, - reroot_soma: bool = True, - label_only: bool = False - ) -> 'core.NeuronList': +@utils.map_neuronlist(desc="Splitting", allow_parallel=True) +@utils.meshneuron_skeleton( + method="split", + include_connectors=True, + copy_properties=["color", "compartment"], + disallowed_kwargs={"label_only": True}, + heal=True, +) +def split_axon_dendrite( + x: NeuronObject, + metric: Union[ + Literal["synapse_flow_centrality"], + Literal["flow_centrality"], + Literal["bending_flow"], + Literal["segregation_index"], + ] = "synapse_flow_centrality", + flow_thresh: float = 0.9, + split: Union[Literal["prepost"], Literal["distance"]] = "prepost", + cellbodyfiber: Union[Literal["soma"], Literal["root"], bool] = False, + reroot_soma: bool = True, + label_only: bool = False, +) -> "core.NeuronList": """Split a neuron into axon and dendrite. The result is highly dependent on the method and on your neuron's @@ -665,42 +734,55 @@ def split_axon_dendrite(x: NeuronObject, the axon/dendrite split. """ - COLORS = {'axon': (178, 34, 34), - 'dendrite': (0, 0, 255), - 'cellbodyfiber': (50, 50, 50), - 'linker': (150, 150, 150)} + COLORS = { + "axon": (178, 34, 34), + "dendrite": (0, 0, 255), + "cellbodyfiber": (50, 50, 50), + "linker": (150, 150, 150), + } # The decorator makes sure that at this point we have single neurons if not isinstance(x, core.TreeNeuron): raise TypeError(f'Can only process TreeNeurons, got "{type(x)}"') if not x.has_connectors: - if metric != 'flow_centrality': - raise ValueError('Neuron must have connectors.') - elif split == 'prepost': - raise ValueError('Set `split="distance"` when trying to split neurons ' - 'without connectors.') - - _METRIC = ('synapse_flow_centrality', 'bending_flow', 'segregation_index', - 'flow_centrality') - utils.eval_param(metric, 'metric', allowed_values=_METRIC) - utils.eval_param(split, 'split', allowed_values=('prepost', 'distance')) - utils.eval_param(cellbodyfiber, 'cellbodyfiber', - allowed_values=('soma', 'root', False)) - - if metric == 'flow_centrality': - msg = ("As of navis version 1.4.0, `method='flow_centrality'` " - "uses synapse-independent, morphology-only flow to generate splits." - "Please use `method='synapse_flow_centrality' for " - "synapse-based axon-dendrite splits. " - "This warning will be removed in a future version of navis.") + if metric != "flow_centrality": + raise ValueError("Neuron must have connectors.") + elif split == "prepost": + raise ValueError( + 'Set `split="distance"` when trying to split neurons ' + "without connectors." + ) + + _METRIC = ( + "synapse_flow_centrality", + "bending_flow", + "segregation_index", + "flow_centrality", + ) + utils.eval_param(metric, "metric", allowed_values=_METRIC) + utils.eval_param(split, "split", allowed_values=("prepost", "distance")) + utils.eval_param( + cellbodyfiber, "cellbodyfiber", allowed_values=("soma", "root", False) + ) + + if metric == "flow_centrality": + msg = ( + "As of navis version 1.4.0, `method='flow_centrality'` " + "uses synapse-independent, morphology-only flow to generate splits." + "Please use `method='synapse_flow_centrality' for " + "synapse-based axon-dendrite splits. " + "This warning will be removed in a future version of navis." + ) warnings.warn(msg, DeprecationWarning) logger.warning(msg) if len(x.root) > 1: - raise ValueError(f'Unable to split neuron {x.id}: multiple roots. ' - 'Try `navis.heal_skeleton(x)` to merged ' - 'disconnected fragments.') + raise ValueError( + f"Unable to split neuron {x.id}: multiple roots. " + "Try `navis.heal_skeleton(x)` to merged " + "disconnected fragments." + ) # Make copy, so that we don't screw things up original = x @@ -710,11 +792,11 @@ def split_axon_dendrite(x: NeuronObject, x.reroot(x.soma, inplace=True) FUNCS = { - 'bending_flow': mmetrics.bending_flow, - 'synapse_flow_centrality': mmetrics.synapse_flow_centrality, - 'flow_centrality': mmetrics.flow_centrality, - 'segregation_index': mmetrics.arbor_segregation_index - } + "bending_flow": mmetrics.bending_flow, + "synapse_flow_centrality": mmetrics.synapse_flow_centrality, + "flow_centrality": mmetrics.flow_centrality, + "segregation_index": mmetrics.arbor_segregation_index, + } if metric not in FUNCS: raise ValueError(f'Unknown `metric`: "{metric}"') @@ -733,7 +815,7 @@ def split_axon_dendrite(x: NeuronObject, # The first step is to remove the linker -> that's the bit that connects # the axon and dendrite is_linker = x.nodes[metric] >= x.nodes[metric].max() * flow_thresh - linker = set(x.nodes.loc[is_linker, 'node_id'].values) + linker = set(x.nodes.loc[is_linker, "node_id"].values) # We try to perform processing on the graph to avoid overhead from # (re-)generating neurons @@ -747,17 +829,17 @@ def split_axon_dendrite(x: NeuronObject, # Figure out which one is which axon = set() - if split == 'prepost': + if split == "prepost": # Collect # of pre- and postsynapses on each of the connected components sm = pd.DataFrame() - sm['n_nodes'] = [len(c) for c in cc] + sm["n_nodes"] = [len(c) for c in cc] pre = x.presynapses post = x.postsynapses - sm['n_pre'] = [pre[pre.node_id.isin(c)].shape[0] for c in cc] - sm['n_post'] = [post[post.node_id.isin(c)].shape[0] for c in cc] - sm['prepost_ratio'] = (sm.n_pre / sm.n_post) - sm['frac_post'] = sm.n_post / sm.n_post.sum() - sm['frac_pre'] = sm.n_pre / sm.n_pre.sum() + sm["n_pre"] = [pre[pre.node_id.isin(c)].shape[0] for c in cc] + sm["n_post"] = [post[post.node_id.isin(c)].shape[0] for c in cc] + sm["prepost_ratio"] = sm.n_pre / sm.n_post + sm["frac_post"] = sm.n_post / sm.n_post.sum() + sm["frac_pre"] = sm.n_pre / sm.n_pre.sum() # In theory, we can encounter neurons with either no pre- or no # postsynapses (e.g. sensory neurons). @@ -765,19 +847,21 @@ def split_axon_dendrite(x: NeuronObject, # causes frac_pre/post to be NaN. By filling, we make sure that the # split doesn't fail further down but they might end up missing either # an axon or a dendrite (which may actually be OK?). - sm['frac_post'] = sm['frac_post'].fillna(0) - sm['frac_pre'] = sm['frac_pre'].fillna(0) + sm["frac_post"] = sm["frac_post"].fillna(0) + sm["frac_pre"] = sm["frac_pre"].fillna(0) # Produce the ratio of pre- to postsynapses - sm['frac_prepost'] = (sm.frac_pre / sm.frac_post) + sm["frac_prepost"] = sm.frac_pre / sm.frac_post # Some small side branches might have either no pre- or no postsynapses. # Even if they have synapses: if the total count is low they might be # incorrectly assigned to a compartment. Here, we will make sure that # they are disregarded for now to avoid introducing noise. Instead we # will connect them onto their parent compartment later. - sm.loc[sm[['frac_pre', 'frac_post']].max(axis=1) < 0.01, - ['prepost_ratio', 'frac_prepost']] = np.nan + sm.loc[ + sm[["frac_pre", "frac_post"]].max(axis=1) < 0.01, + ["prepost_ratio", "frac_prepost"], + ] = np.nan logger.debug(sm) # Each fragment is considered separately as either giver or recipient @@ -818,7 +902,7 @@ def split_axon_dendrite(x: NeuronObject, # The CBF is defined as the part of the neuron between the soma (or root) # and the first branch point with sizeable synapse flow cbf = set() - if cellbodyfiber and (np.any(x.soma) or cellbodyfiber == 'root'): + if cellbodyfiber and (np.any(x.soma) or cellbodyfiber == "root"): # To excise the CBF, we subset the neuron to those parts with # no/hardly any flow and find the part that contains the soma no_flow = x.nodes[x.nodes[metric] <= x.nodes[metric].max() * 0.05] @@ -846,60 +930,63 @@ def split_axon_dendrite(x: NeuronObject, # If we have, assign these nodes to the closest node with a compartment if any(miss): # Find the closest nodes with a compartment - m = graph.geodesic_matrix(original, - directed=False, - weight=None, - from_=miss) + m = graph.geodesic_matrix(original, directed=False, weight=None, from_=miss) # Subset geodesic matrix to nodes that have a compartment - nodes_w_comp = original.nodes.node_id.values[~np.isin(original.nodes.node_id.values, miss)] + nodes_w_comp = original.nodes.node_id.values[ + ~np.isin(original.nodes.node_id.values, miss) + ] closest = np.argmin(m.loc[:, nodes_w_comp].values, axis=1) closest_id = nodes_w_comp[closest] linker += m.index.values[np.isin(closest_id, linker)].tolist() - axon += m.index.values[np.isin(closest_id, axon)].tolist() - dendrite += m.index.values[np.isin(closest_id, dendrite)].tolist() - cbf += m.index.values[np.isin(closest_id, cbf)].tolist() + axon += m.index.values[np.isin(closest_id, axon)].tolist() + dendrite += m.index.values[np.isin(closest_id, dendrite)].tolist() + cbf += m.index.values[np.isin(closest_id, cbf)].tolist() # Add labels if label_only: nodes = original.nodes - nodes['compartment'] = None + nodes["compartment"] = None is_linker = nodes.node_id.isin(linker) is_axon = nodes.node_id.isin(axon) is_dend = nodes.node_id.isin(dendrite) is_cbf = nodes.node_id.isin(cbf) - nodes.loc[is_linker, 'compartment'] = 'linker' - nodes.loc[is_dend, 'compartment'] = 'dendrite' - nodes.loc[is_axon, 'compartment'] = 'axon' - nodes.loc[is_cbf, 'compartment'] = 'cellbodyfiber' + nodes.loc[is_linker, "compartment"] = "linker" + nodes.loc[is_dend, "compartment"] = "dendrite" + nodes.loc[is_axon, "compartment"] = "axon" + nodes.loc[is_cbf, "compartment"] = "cellbodyfiber" # Set connector compartments - cmp_map = original.nodes.set_index('node_id').compartment.to_dict() - original.connectors['compartment'] = original.connectors.node_id.map(cmp_map) + cmp_map = original.nodes.set_index("node_id").compartment.to_dict() + original.connectors["compartment"] = original.connectors.node_id.map(cmp_map) # Turn into categorical data - original.nodes['compartment'] = original.nodes.compartment.astype('category') - original.connectors['compartment'] = original.connectors.compartment.astype('category') + original.nodes["compartment"] = original.nodes.compartment.astype("category") + original.connectors["compartment"] = original.connectors.compartment.astype( + "category" + ) return original # Generate the actual splits nl = [] - for label, nodes in zip(['cellbodyfiber', 'dendrite', 'linker', 'axon'], - [cbf, dendrite, linker, axon]): + for label, nodes in zip( + ["cellbodyfiber", "dendrite", "linker", "axon"], [cbf, dendrite, linker, axon] + ): if not len(nodes): continue n = subset.subset_neuron(original, nodes) n.color = COLORS.get(label, (100, 100, 100)) - n._register_attr('compartment', label) + n._register_attr("compartment", label) nl.append(n) return core.NeuronList(nl) -def combine_neurons(*x: Union[Sequence[NeuronObject], 'core.NeuronList'] - ) -> 'core.NeuronObject': +def combine_neurons( + *x: Union[Sequence[NeuronObject], "core.NeuronList"], +) -> "core.NeuronObject": """Combine multiple neurons into one. Parameters @@ -949,10 +1036,10 @@ def combine_neurons(*x: Union[Sequence[NeuronObject], 'core.NeuronList'] # Check that neurons are all of the same type if len(nl.types) > 1: - raise TypeError('Unable to combine neurons of different types') + raise TypeError("Unable to combine neurons of different types") if isinstance(nl[0], core.TreeNeuron): - x = stitch_skeletons(*nl, method='NONE', master='FIRST') + x = stitch_skeletons(*nl, method="NONE", master="FIRST") elif isinstance(nl[0], core.MeshNeuron): x = nl[0].copy() comb = tm.util.concatenate([n.trimesh for n in nl]) @@ -960,8 +1047,10 @@ def combine_neurons(*x: Union[Sequence[NeuronObject], 'core.NeuronList'] x._faces = comb.faces if any(nl.has_connectors): - x._connectors = pd.concat([n.connectors for n in nl], # type: ignore # no stubs for concat - ignore_index=True) + x._connectors = pd.concat( + [n.connectors for n in nl], # type: ignore # no stubs for concat + ignore_index=True, + ) elif isinstance(nl[0], core.Dotprops): x = nl[0].copy() x._points = np.vstack(nl._points) @@ -972,26 +1061,26 @@ def combine_neurons(*x: Union[Sequence[NeuronObject], 'core.NeuronList'] x._alpha = np.hstack(nl.alpha) if any(nl.has_connectors): - x._connectors = pd.concat([n.connectors for n in nl], # type: ignore # no stubs for concat - ignore_index=True) + x._connectors = pd.concat( + [n.connectors for n in nl], # type: ignore # no stubs for concat + ignore_index=True, + ) elif isinstance(nl[0], core.VoxelNeuron): - raise TypeError('Combining VoxelNeuron not (yet) supported') + raise TypeError("Combining VoxelNeuron not (yet) supported") else: - raise TypeError(f'Unable to combine {type(nl[0])}') + raise TypeError(f"Unable to combine {type(nl[0])}") return x -def stitch_skeletons(*x: Union[Sequence[NeuronObject], 'core.NeuronList'], - method: Union[Literal['LEAFS'], - Literal['ALL'], - Literal['NONE'], - Sequence[int]] = 'ALL', - master: Union[Literal['SOMA'], - Literal['LARGEST'], - Literal['FIRST']] = 'SOMA', - max_dist: Optional[float] = None, - ) -> 'core.TreeNeuron': +def stitch_skeletons( + *x: Union[Sequence[NeuronObject], "core.NeuronList"], + method: Union[ + Literal["LEAFS"], Literal["ALL"], Literal["NONE"], Sequence[int] + ] = "ALL", + master: Union[Literal["SOMA"], Literal["LARGEST"], Literal["FIRST"]] = "SOMA", + max_dist: Optional[float] = None, +) -> "core.TreeNeuron": """Stitch multiple skeletons together. Uses minimum spanning tree to determine a way to connect all fragments @@ -1061,8 +1150,8 @@ def stitch_skeletons(*x: Union[Sequence[NeuronObject], 'core.NeuronList'], """ master = str(master).upper() - ALLOWED_MASTER = ('SOMA', 'LARGEST', 'FIRST') - utils.eval_param(master, 'master', allowed_values=ALLOWED_MASTER) + ALLOWED_MASTER = ("SOMA", "LARGEST", "FIRST") + utils.eval_param(master, "master", allowed_values=ALLOWED_MASTER) # Compile list of individual neurons neurons = utils.unpack_neurons(x) @@ -1071,29 +1160,29 @@ def stitch_skeletons(*x: Union[Sequence[NeuronObject], 'core.NeuronList'], nl = core.NeuronList(neurons).copy() if len(nl) < 2: - logger.warning(f'Need at least 2 neurons to stitch, found {len(nl)}') + logger.warning(f"Need at least 2 neurons to stitch, found {len(nl)}") return nl[0] # If no soma, switch to largest - if master == 'SOMA' and not any(nl.has_soma): - master = 'LARGEST' + if master == "SOMA" and not any(nl.has_soma): + master = "LARGEST" # First find master - if master == 'SOMA': + if master == "SOMA": # Pick the first neuron with a soma m_ix = [i for i, n in enumerate(nl) if n.has_soma][0] - elif master == 'LARGEST': + elif master == "LARGEST": # Pick the largest neuron - m_ix = sorted(list(range(len(nl))), - key=lambda x: nl[x].n_nodes, - reverse=True)[0] + m_ix = sorted(list(range(len(nl))), key=lambda x: nl[x].n_nodes, reverse=True)[ + 0 + ] else: # Pick the first neuron m_ix = 0 m = nl[m_ix] # Check if we need to make any node IDs unique - if nl.nodes.duplicated(subset='node_id').sum() > 0: + if nl.nodes.duplicated(subset="node_id").sum() > 0: # Master neuron will not be changed seen_tn: Set[int] = set(m.nodes.node_id) for i, n in enumerate(nl): @@ -1120,17 +1209,21 @@ def stitch_skeletons(*x: Union[Sequence[NeuronObject], 'core.NeuronList'], new_map = dict(zip(non_unique, new_tn)) # Remap node IDs - if no new value, keep the old - n.nodes['node_id'] = n.nodes.node_id.map(lambda x: new_map.get(x, x)) + n.nodes["node_id"] = n.nodes.node_id.map(lambda x: new_map.get(x, x)) if n.has_connectors: - n.connectors['node_id'] = n.connectors.node_id.map(lambda x: new_map.get(x, x)) + n.connectors["node_id"] = n.connectors.node_id.map( + lambda x: new_map.get(x, x) + ) - if getattr(n, 'tags', None) is not None: + if getattr(n, "tags", None) is not None: n.tags = {new_map.get(k, k): v for k, v in n.tags.items()} # type: ignore # Remap parent IDs new_map[None] = -1 # type: ignore - n.nodes['parent_id'] = n.nodes.parent_id.map(lambda x: new_map.get(x, x)).astype(int) + n.nodes["parent_id"] = n.nodes.parent_id.map( + lambda x: new_map.get(x, x) + ).astype(int) # Add new nodes to seen seen_tn = seen_tn | set(new_tn) @@ -1139,96 +1232,100 @@ def stitch_skeletons(*x: Union[Sequence[NeuronObject], 'core.NeuronList'], n._clear_temp_attr() # We will start by simply merging all neurons into one - m._nodes = pd.concat([n.nodes for n in nl], # type: ignore # no stubs for concat - ignore_index=True) + m._nodes = pd.concat( + [n.nodes for n in nl], # type: ignore # no stubs for concat + ignore_index=True, + ) if any(nl.has_connectors): - m._connectors = pd.concat([n.connectors for n in nl], # type: ignore # no stubs for concat - ignore_index=True) + m._connectors = pd.concat( + [n.connectors for n in nl], # type: ignore # no stubs for concat + ignore_index=True, + ) if not m.has_tags or not isinstance(m.tags, dict): m.tags = {} # type: ignore # TreeNeuron has no tags for n in nl: - for k, v in (getattr(n, 'tags', None) or {}).items(): + for k, v in (getattr(n, "tags", None) or {}).items(): m.tags[k] = m.tags.get(k, []) + list(utils.make_iterable(v)) # Reset temporary attributes of our final neuron m._clear_temp_attr() # If this is all we meant to do, return this neuron - if not utils.is_iterable(method) and (method == 'NONE' or method is None): + if not utils.is_iterable(method) and (method == "NONE" or method is None): return m return _stitch_mst(m, nodes=method, inplace=False, max_dist=max_dist) -def _mst_igraph(nl: 'core.NeuronList', - new_edges: pd.DataFrame) -> List[List[int]]: +def _mst_igraph(nl: "core.NeuronList", new_edges: pd.DataFrame) -> List[List[int]]: """Compute edges necessary to connect a fragmented neuron using igraph.""" # Generate a union of all graphs g = nl[0].igraph.disjoint_union(nl[1:].igraph) # We have to manually set the node IDs again - nids = np.concatenate([n.igraph.vs['node_id'] for n in nl]) - g.vs['node_id'] = nids + nids = np.concatenate([n.igraph.vs["node_id"] for n in nl]) + g.vs["node_id"] = nids # Set existing edges to zero weight to make sure they have priority when # calculating the minimum spanning tree - g.es['weight'] = 0 + g.es["weight"] = 0 # If two nodes occupy the same position (e.g. after if fragments are the # result of cutting), they will have a distance of 0. Hence, we won't be # able to simply filter by distance - g.es['new'] = False + g.es["new"] = False # Convert node IDs in new_edges to vertex IDs and add to graph - name2ix = dict(zip(g.vs['node_id'], range(len(g.vs)))) - new_edges['source_ix'] = new_edges.source.map(name2ix) - new_edges['target_ix'] = new_edges.target.map(name2ix) + name2ix = dict(zip(g.vs["node_id"], range(len(g.vs)))) + new_edges["source_ix"] = new_edges.source.map(name2ix) + new_edges["target_ix"] = new_edges.target.map(name2ix) # Add new edges - g.add_edges(new_edges[['source_ix', 'target_ix']].values.tolist()) + g.add_edges(new_edges[["source_ix", "target_ix"]].values.tolist()) # Add edge weight to new edges - g.es[-new_edges.shape[0]:]['weight'] = new_edges.weight.values + g.es[-new_edges.shape[0] :]["weight"] = new_edges.weight.values # Keep track of new edges - g.es[-new_edges.shape[0]:]['new'] = True + g.es[-new_edges.shape[0] :]["new"] = True # Compute the minimum spanning tree - mst = g.spanning_tree(weights='weight') + mst = g.spanning_tree(weights="weight") # Extract the new edges to_add = mst.es.select(new=True) # Convert to node IDs - to_add = [(g.vs[e.source]['node_id'], - g.vs[e.target]['node_id'], - {'weight': e['weight']}) - for e in to_add] + to_add = [ + (g.vs[e.source]["node_id"], g.vs[e.target]["node_id"], {"weight": e["weight"]}) + for e in to_add + ] return to_add -def _mst_nx(nl: 'core.NeuronList', - new_edges: pd.DataFrame) -> List[List[int]]: +def _mst_nx(nl: "core.NeuronList", new_edges: pd.DataFrame) -> List[List[int]]: """Compute edges necessary to connect a fragmented neuron using networkX.""" # Generate a union of all graphs g = nx.union_all([n.graph for n in nl]).to_undirected() # Set existing edges to zero weight to make sure they have priority when # calculating the minimum spanning tree - nx.set_edge_attributes(g, 0, 'weight') + nx.set_edge_attributes(g, 0, "weight") # If two nodes occupy the same position (e.g. after if fragments are the # result of cutting), they will have a distance of 0. Hence, we won't be # able to simply filter by distance - nx.set_edge_attributes(g, False, 'new') + nx.set_edge_attributes(g, False, "new") # Convert new edges in the right format - edges_nx = [(r.source, r.target, {'weight': r.weight, 'new': True}) - for r in new_edges.itertuples()] + edges_nx = [ + (r.source, r.target, {"weight": r.weight, "new": True}) + for r in new_edges.itertuples() + ] # Add edges to union graph g.add_edges_from(edges_nx) @@ -1237,15 +1334,16 @@ def _mst_nx(nl: 'core.NeuronList', edges = nx.minimum_spanning_edges(g) # Edges that need adding are those that were newly added - to_add = [e for e in edges if e[2]['new']] + to_add = [e for e in edges if e[2]["new"]] return to_add -def average_skeletons(x: 'core.NeuronList', - limit: Union[int, str] = 10, - base_neuron: Optional[Union[int, 'core.TreeNeuron']] = None - ) -> 'core.TreeNeuron': +def average_skeletons( + x: "core.NeuronList", + limit: Union[int, str] = 10, + base_neuron: Optional[Union[int, "core.TreeNeuron"]] = None, +) -> "core.TreeNeuron": """Compute an average from a list of skeletons. This is a very simple implementation which may give odd results if used @@ -1288,14 +1386,14 @@ def average_skeletons(x: 'core.NeuronList', raise TypeError(f'Need NeuronList, got "{type(x)}"') if len(x) < 2: - raise ValueError('Need at least 2 neurons to average!') + raise ValueError("Need at least 2 neurons to average!") # Map limit into unit space, if applicable - limit = x[0].map_units(limit, on_error='raise') + limit = x[0].map_units(limit, on_error="raise") # Generate KDTrees for each neuron for n in x: - n.tree = graph.neuron2KDTree(n, tree_type='c', data='nodes') # type: ignore # TreeNeuron has no tree + n.tree = graph.neuron2KDTree(n, tree_type="c", data="nodes") # type: ignore # TreeNeuron has no tree # Set base for average: we will use this neurons nodes to query # the KDTrees @@ -1306,9 +1404,11 @@ def average_skeletons(x: 'core.NeuronList', elif isinstance(base_neuron, type(None)): bn = x[0].copy() else: - raise ValueError(f'Unable to interpret base_neuron of type "{type(base_neuron)}"') + raise ValueError( + f'Unable to interpret base_neuron of type "{type(base_neuron)}"' + ) - base_nodes = bn.nodes[['x', 'y', 'z']].values + base_nodes = bn.nodes[["x", "y", "z"]].values other_neurons = x[[n != bn for n in x]] # Make sure these stay 2-dimensional arrays -> will add a colum for each @@ -1319,18 +1419,17 @@ def average_skeletons(x: 'core.NeuronList', # For each "other" neuron, collect nearest neighbour coordinates for n in other_neurons: - nn_dist, nn_ix = n.tree.query(base_nodes, - k=1, - distance_upper_bound=limit) + nn_dist, nn_ix = n.tree.query(base_nodes, k=1, distance_upper_bound=limit) # Translate indices into coordinates # First, make empty array this_coords = np.zeros((len(nn_dist), 3)) # Set coords without a nearest neighbour within distances to "None" - this_coords[nn_dist == float('inf')] = None + this_coords[nn_dist == float("inf")] = None # Fill in coords of nearest neighbours - this_coords[nn_dist != float( - 'inf')] = n.tree.data[nn_ix[nn_dist != float('inf')]] + this_coords[nn_dist != float("inf")] = n.tree.data[ + nn_ix[nn_dist != float("inf")] + ] # Add coords to base coords base_x = np.append(base_x, this_coords[:, 0:1], axis=1) base_y = np.append(base_y, this_coords[:, 1:2], axis=1) @@ -1349,19 +1448,21 @@ def average_skeletons(x: 'core.NeuronList', mean_z[np.isnan(mean_z)] = base_nodes[np.isnan(mean_z), 2] # Change coordinates accordingly - bn.nodes['x'] = mean_x - bn.nodes['y'] = mean_y - bn.nodes['z'] = mean_z + bn.nodes["x"] = mean_x + bn.nodes["y"] = mean_y + bn.nodes["z"] = mean_z return bn -@utils.map_neuronlist(desc='Despiking', allow_parallel=True) -def despike_skeleton(x: NeuronObject, - sigma: int = 5, - max_spike_length: int = 1, - inplace: bool = False, - reverse: bool = False) -> Optional[NeuronObject]: +@utils.map_neuronlist(desc="Despiking", allow_parallel=True) +def despike_skeleton( + x: NeuronObject, + sigma: int = 5, + max_spike_length: int = 1, + inplace: bool = False, + reverse: bool = False, +) -> Optional[NeuronObject]: r"""Remove spikes in skeleton (e.g. from jumps in image data). For each node A, the Euclidean distance to its next successor (parent) @@ -1404,13 +1505,13 @@ def despike_skeleton(x: NeuronObject, # The decorator makes sure that we have single neurons at this point if not isinstance(x, core.TreeNeuron): - raise TypeError(f'Can only process TreeNeurons, not {type(x)}') + raise TypeError(f"Can only process TreeNeurons, not {type(x)}") if not inplace: x = x.copy() # Index nodes table by node ID - this_nodes = x.nodes.set_index('node_id', inplace=False) + this_nodes = x.nodes.set_index("node_id", inplace=False) segs_to_walk = x.segments @@ -1423,45 +1524,48 @@ def despike_skeleton(x: NeuronObject, # Go over all segments for seg in segs_to_walk: # Get nodes A, B and C of this segment - this_A = this_nodes.loc[seg[:-l - 1]] + this_A = this_nodes.loc[seg[: -l - 1]] this_B = this_nodes.loc[seg[l:-1]] - this_C = this_nodes.loc[seg[l + 1:]] + this_C = this_nodes.loc[seg[l + 1 :]] # Get coordinates - A = this_A[['x', 'y', 'z']].values - B = this_B[['x', 'y', 'z']].values - C = this_C[['x', 'y', 'z']].values + A = this_A[["x", "y", "z"]].values + B = this_B[["x", "y", "z"]].values + C = this_C[["x", "y", "z"]].values # Calculate euclidean distances A->B and A->C dist_AB = np.linalg.norm(A - B, axis=1) dist_AC = np.linalg.norm(A - C, axis=1) # Get the spikes - spikes_ix = np.where(np.divide(dist_AB, dist_AC, where=dist_AC != 0) > sigma)[0] + spikes_ix = np.where( + np.divide(dist_AB, dist_AC, where=dist_AC != 0) > sigma + )[0] spikes = this_B.iloc[spikes_ix] if not spikes.empty: # Interpolate new position(s) between A and C new_positions = A[spikes_ix] + (C[spikes_ix] - A[spikes_ix]) / 2 - this_nodes.loc[spikes.index, ['x', 'y', 'z']] = new_positions + this_nodes.loc[spikes.index, ["x", "y", "z"]] = new_positions # Reassign node table x.nodes = this_nodes.reset_index(drop=False, inplace=False) # The weights in the graph have changed, we need to update that - x._clear_temp_attr(exclude=['segments', 'small_segments', - 'classify_nodes']) + x._clear_temp_attr(exclude=["segments", "small_segments", "classify_nodes"]) return x -@utils.map_neuronlist(desc='Guessing', allow_parallel=True) -def guess_radius(x: NeuronObject, - method: str = 'linear', - limit: Optional[int] = None, - smooth: bool = True, - inplace: bool = False) -> Optional[NeuronObject]: +@utils.map_neuronlist(desc="Guessing", allow_parallel=True) +def guess_radius( + x: NeuronObject, + method: str = "linear", + limit: Optional[int] = None, + smooth: bool = True, + inplace: bool = False, +) -> Optional[NeuronObject]: """Guess radii for skeleton nodes. Uses distance between connectors and nodes to guess radii. Interpolate for @@ -1497,10 +1601,10 @@ def guess_radius(x: NeuronObject, """ # The decorator makes sure that at this point we have single neurons if not isinstance(x, core.TreeNeuron): - raise TypeError(f'Can only process TreeNeurons, not {type(x)}') + raise TypeError(f"Can only process TreeNeurons, not {type(x)}") - if not hasattr(x, 'connectors') or x.connectors.empty: - raise ValueError('Neuron must have connectors!') + if not hasattr(x, "connectors") or x.connectors.empty: + raise ValueError("Neuron must have connectors!") if not inplace: x = x.copy() @@ -1511,59 +1615,58 @@ def guess_radius(x: NeuronObject, # We will be using the index as distance to interpolate. For this we have # to change method 'linear' to 'index' - method = 'index' if method == 'linear' else method + method = "index" if method == "linear" else method # Collect connectors and calc distances cn = x.connectors.copy() # Prepare nodes (add parent_dist for later, set index) - x.nodes['parent_dist'] = mmetrics.parent_dist(x, root_dist=0) - nodes = x.nodes.set_index('node_id', inplace=False) + x.nodes["parent_dist"] = mmetrics.parent_dist(x, root_dist=0) + nodes = x.nodes.set_index("node_id", inplace=False) # For each connector (pre and post), get the X/Y distance to its node - cn_locs = cn[['x', 'y']].values - tn_locs = nodes.loc[cn.node_id.values, - ['x', 'y']].values + cn_locs = cn[["x", "y"]].values + tn_locs = nodes.loc[cn.node_id.values, ["x", "y"]].values dist = np.sqrt(np.sum((tn_locs - cn_locs) ** 2, axis=1).astype(int)) - cn['dist'] = dist + cn["dist"] = dist # Get max distance per node (in case of multiple connectors per # node) - cn_grouped = cn.groupby('node_id').dist.max() + cn_grouped = cn.groupby("node_id").dist.max() # Set undefined radii to None so that they are ignored for interpolation - nodes.loc[nodes.radius <= 0, 'radius'] = None + nodes.loc[nodes.radius <= 0, "radius"] = None # Assign radii to nodes - nodes.loc[cn_grouped.index, 'radius'] = cn_grouped.values.astype( + nodes.loc[cn_grouped.index, "radius"] = cn_grouped.values.astype( nodes.radius.dtype, copy=False ) # Go over each segment and interpolate radii - for s in config.tqdm(x.segments, desc='Interp.', disable=config.pbar_hide, - leave=config.pbar_leave): - + for s in config.tqdm( + x.segments, desc="Interp.", disable=config.pbar_hide, leave=config.pbar_leave + ): # Get this segments radii and parent dist - this_radii = nodes.loc[s, ['radius', 'parent_dist']] - this_radii['parent_dist_cum'] = this_radii.parent_dist.cumsum() + this_radii = nodes.loc[s, ["radius", "parent_dist"]] + this_radii["parent_dist_cum"] = this_radii.parent_dist.cumsum() # Set cumulative distance as index and drop parent_dist - this_radii = this_radii.set_index('parent_dist_cum', - drop=True).drop('parent_dist', - axis=1) + this_radii = this_radii.set_index("parent_dist_cum", drop=True).drop( + "parent_dist", axis=1 + ) # Interpolate missing radii - interp = this_radii.interpolate(method=method, limit_direction='both', - limit=limit) + interp = this_radii.interpolate( + method=method, limit_direction="both", limit=limit + ) if smooth: - interp = interp.rolling(smooth, - min_periods=1).max() + interp = interp.rolling(smooth, min_periods=1).max() - nodes.loc[s, 'radius'] = interp.values + nodes.loc[s, "radius"] = interp.values # Set non-interpolated radii back to -1 - nodes.loc[nodes.radius.isnull(), 'radius'] = -1 + nodes.loc[nodes.radius.isnull(), "radius"] = -1 # Reassign nodes x.nodes = nodes.reset_index(drop=False, inplace=False) @@ -1571,11 +1674,13 @@ def guess_radius(x: NeuronObject, return x -@utils.map_neuronlist(desc='Smoothing', allow_parallel=True) -def smooth_skeleton(x: NeuronObject, - window: int = 5, - to_smooth: list = ['x', 'y', 'z'], - inplace: bool = False) -> NeuronObject: +@utils.map_neuronlist(desc="Smoothing", allow_parallel=True) +def smooth_skeleton( + x: NeuronObject, + window: int = 5, + to_smooth: list = ["x", "y", "z"], + inplace: bool = False, +) -> NeuronObject: """Smooth skeleton(s) using rolling windows. Parameters @@ -1618,26 +1723,28 @@ def smooth_skeleton(x: NeuronObject, """ # The decorator makes sure that at this point we have single neurons if not isinstance(x, core.TreeNeuron): - raise TypeError(f'Can only process TreeNeurons, not {type(x)}') + raise TypeError(f"Can only process TreeNeurons, not {type(x)}") if not inplace: x = x.copy() # Prepare nodes (add parent_dist for later, set index) # mmetrics.parent_dist(x, root_dist=0) - nodes = x.nodes.set_index('node_id', inplace=False).copy() + nodes = x.nodes.set_index("node_id", inplace=False).copy() to_smooth = utils.make_iterable(to_smooth) miss = to_smooth[~np.isin(to_smooth, nodes.columns)] if len(miss): - raise ValueError(f'Column(s) not found in node table: {miss}') + raise ValueError(f"Column(s) not found in node table: {miss}") # Go over each segment and smooth - for s in config.tqdm(x.segments[::-1], desc='Smoothing', - disable=config.pbar_hide, - leave=config.pbar_leave): - + for s in config.tqdm( + x.segments[::-1], + desc="Smoothing", + disable=config.pbar_hide, + leave=config.pbar_leave, + ): # Get this segment's parent distances and get cumsum this_co = nodes.loc[s, to_smooth] @@ -1656,10 +1763,10 @@ def smooth_skeleton(x: NeuronObject, return x -@utils.map_neuronlist(desc='Smoothing', allow_parallel=True) -def smooth_voxels(x: NeuronObject, - sigma: int = 1, - inplace: bool = False) -> NeuronObject: +@utils.map_neuronlist(desc="Smoothing", allow_parallel=True) +def smooth_voxels( + x: NeuronObject, sigma: int = 1, inplace: bool = False +) -> NeuronObject: """Smooth voxel(s) using a Gaussian filter. Parameters @@ -1696,7 +1803,7 @@ def smooth_voxels(x: NeuronObject, """ # The decorator makes sure that at this point we have single neurons if not isinstance(x, core.VoxelNeuron): - raise TypeError(f'Can only process VoxelNeurons, not {type(x)}') + raise TypeError(f"Can only process VoxelNeurons, not {type(x)}") if not inplace: x = x.copy() @@ -1708,9 +1815,11 @@ def smooth_voxels(x: NeuronObject, return x -def break_fragments(x: Union['core.TreeNeuron', 'core.MeshNeuron'], - labels_only: bool = False, - min_size: Optional[int] = None) -> 'core.NeuronList': +def break_fragments( + x: Union["core.TreeNeuron", "core.MeshNeuron"], + labels_only: bool = False, + min_size: Optional[int] = None, +) -> "core.NeuronList": """Break neuron into its connected components. Neurons can consists of several disconnected fragments. This function @@ -1765,7 +1874,7 @@ def break_fragments(x: Union['core.TreeNeuron', 'core.MeshNeuron'], if labels_only: cc_id = {n: i for i, cc in enumerate(comp) for n in cc} if isinstance(x, core.TreeNeuron): - x.nodes['fragment'] = x.nodes.node_id.map(cc_id).astype(str) + x.nodes["fragment"] = x.nodes.node_id.map(cc_id).astype(str) elif isinstance(x, core.MeshNeuron): x.fragments = np.array([cc_id[i] for i in range(x.n_vertices)]).astype(str) return x @@ -1773,23 +1882,26 @@ def break_fragments(x: Union['core.TreeNeuron', 'core.MeshNeuron'], if min_size: comp = [cc for cc in comp if len(cc) >= min_size] - return core.NeuronList([subset.subset_neuron(x, - list(ss), - inplace=False) for ss in config.tqdm(comp, - desc='Breaking', - disable=config.pbar_hide, - leave=config.pbar_leave)]) - - -@utils.map_neuronlist(desc='Healing', allow_parallel=True) -def heal_skeleton(x: 'core.NeuronList', - method: Union[Literal['LEAFS'], - Literal['ALL']] = 'ALL', - max_dist: Optional[float] = None, - min_size: Optional[float] = None, - drop_disc: float = False, - mask: Optional[Sequence] = None, - inplace: bool = False) -> Optional[NeuronObject]: + return core.NeuronList( + [ + subset.subset_neuron(x, list(ss), inplace=False) + for ss in config.tqdm( + comp, desc="Breaking", disable=config.pbar_hide, leave=config.pbar_leave + ) + ] + ) + + +@utils.map_neuronlist(desc="Healing", allow_parallel=True) +def heal_skeleton( + x: "core.NeuronList", + method: Union[Literal["LEAFS"], Literal["ALL"]] = "ALL", + max_dist: Optional[float] = None, + min_size: Optional[float] = None, + drop_disc: float = False, + mask: Optional[Sequence] = None, + inplace: bool = False, +) -> Optional[NeuronObject]: """Heal fragmented skeleton(s). Tries to heal a fragmented skeleton (i.e. a neuron with multiple roots) @@ -1857,7 +1969,7 @@ def heal_skeleton(x: 'core.NeuronList', """ method = str(method).upper() - if method not in ('LEAFS', 'ALL'): + if method not in ("LEAFS", "ALL"): raise ValueError(f'Unknown method "{method}"') # The decorator makes sure that at this point we have single neurons @@ -1865,17 +1977,14 @@ def heal_skeleton(x: 'core.NeuronList', raise TypeError(f'Expected TreeNeuron(s), got "{type(x)}"') if not isinstance(max_dist, type(None)): - max_dist = x.map_units(max_dist, on_error='raise') + max_dist = x.map_units(max_dist, on_error="raise") if not inplace: x = x.copy() - _ = _stitch_mst(x, - nodes=method, - max_dist=max_dist, - min_size=min_size, - mask=mask, - inplace=True) + _ = _stitch_mst( + x, nodes=method, max_dist=max_dist, min_size=min_size, mask=mask, inplace=True + ) # See if we need to drop remaining disconnected fragments if drop_disc: @@ -1888,14 +1997,14 @@ def heal_skeleton(x: 'core.NeuronList', return x -def _stitch_mst(x: 'core.TreeNeuron', - nodes: Union[Literal['LEAFS'], - Literal['ALL'], - list] = 'ALL', - max_dist: Optional[float] = np.inf, - min_size: Optional[float] = None, - mask: Optional[Sequence] = None, - inplace: bool = False) -> Optional['core.TreeNeuron']: +def _stitch_mst( + x: "core.TreeNeuron", + nodes: Union[Literal["LEAFS"], Literal["ALL"], list] = "ALL", + max_dist: Optional[float] = np.inf, + min_size: Optional[float] = None, + mask: Optional[Sequence] = None, + inplace: bool = False, +) -> Optional["core.TreeNeuron"]: """Stitch disconnected neuron using a minimum spanning tree. Parameters @@ -1935,8 +2044,9 @@ def _stitch_mst(x: 'core.TreeNeuron', mask = np.asarray(mask) if mask.dtype == bool: if len(mask) != len(x.nodes): - raise ValueError("Length of boolean mask must match number of " - "nodes in the neuron") + raise ValueError( + "Length of boolean mask must match number of " "nodes in the neuron" + ) mask = x.nodes.node_id.values[mask] # Get connected components @@ -1960,8 +2070,8 @@ def _stitch_mst(x: 'core.TreeNeuron', cc = cc[cc.isin(above)] # Filter to leaf nodes if applicable - if nodes == 'LEAFS': - keep = to_use['type'].isin(['end', 'root']) + if nodes == "LEAFS": + keep = to_use["type"].isin(["end", "root"]) to_use = to_use[keep] cc = cc[keep] @@ -1972,10 +2082,10 @@ def _stitch_mst(x: 'core.TreeNeuron', cc = cc[keep] # Collect fragments - Fragment = namedtuple('Fragment', ['frag_id', 'node_ids', 'kd']) + Fragment = namedtuple("Fragment", ["frag_id", "node_ids", "kd"]) fragments = [] for frag_id, df in to_use.groupby(cc): - kd = KDTree(df[[*'xyz']].values) + kd = KDTree(df[[*"xyz"]].values) fragments.append(Fragment(frag_id, df.node_id.values, kd)) # Sort from big-to-small, so the calculations below use a @@ -2014,30 +2124,36 @@ def _stitch_mst(x: 'core.TreeNeuron', # Add edge from one fragment to another, # but keep track of which fine-grained skeleton # nodes were used to calculate distance. - frag_graph.add_edge(frag_a.frag_id, frag_b.frag_id, - node_a=node_a, node_b=node_b, - distance=dist_ab) + frag_graph.add_edge( + frag_a.frag_id, + frag_b.frag_id, + node_a=node_a, + node_b=node_b, + distance=dist_ab, + ) # Compute inter-fragment MST edges - frag_edges = nx.minimum_spanning_edges(frag_graph, weight='distance', data=True) + frag_edges = nx.minimum_spanning_edges(frag_graph, weight="distance", data=True) # For each inter-fragment edge, add the corresponding # fine-grained edge between skeleton nodes in the original graph. g = x.graph.to_undirected() - to_add = [[e[2]['node_a'], e[2]['node_b']] for e in frag_edges] + to_add = [[e[2]["node_a"], e[2]["node_b"]] for e in frag_edges] g.add_edges_from(to_add) # Rewire based on graph return graph.rewire_skeleton(x, g, inplace=inplace) -@utils.map_neuronlist(desc='Pruning', must_zip=['source'], allow_parallel=True) -@utils.meshneuron_skeleton(method='subset') -def prune_at_depth(x: NeuronObject, - depth: Union[float, int], *, - source: Optional[int] = None, - inplace: bool = False - ) -> Optional[NeuronObject]: +@utils.map_neuronlist(desc="Pruning", must_zip=["source"], allow_parallel=True) +@utils.meshneuron_skeleton(method="subset") +def prune_at_depth( + x: NeuronObject, + depth: Union[float, int], + *, + source: Optional[int] = None, + inplace: bool = False, +) -> Optional[NeuronObject]: """Prune all neurites past a given distance from a source. Parameters @@ -2077,9 +2193,9 @@ def prune_at_depth(x: NeuronObject, """ # The decorator makes sure that at this point we only have single neurons if not isinstance(x, core.TreeNeuron): - raise TypeError(f'Expected TreeNeuron, got {type(x)}') + raise TypeError(f"Expected TreeNeuron, got {type(x)}") - depth = x.map_units(depth, on_error='raise') + depth = x.map_units(depth, on_error="raise") if depth < 0: raise ValueError(f'`depth` must be > 0, got "{depth}"') @@ -2100,12 +2216,12 @@ def prune_at_depth(x: NeuronObject, return x -@utils.map_neuronlist(desc='Pruning', allow_parallel=True) -def drop_fluff(x: Union['core.TreeNeuron', - 'core.MeshNeuron', - 'core.NeuronList'], - keep_size: Optional[float] = None, - inplace: bool = False): +@utils.map_neuronlist(desc="Pruning", allow_parallel=True) +def drop_fluff( + x: Union["core.TreeNeuron", "core.MeshNeuron", "core.NeuronList"], + keep_size: Optional[float] = None, + inplace: bool = False, +): """Remove small disconnected pieces of "fluff". By default, this function will remove all but the largest connected @@ -2138,7 +2254,7 @@ def drop_fluff(x: Union['core.TreeNeuron', (6309, 6037) """ - utils.eval_param(x, name='x', allowed_types=(core.TreeNeuron, core.MeshNeuron)) + utils.eval_param(x, name="x", allowed_types=(core.TreeNeuron, core.MeshNeuron)) G = x.graph # Skeleton graphs are directed @@ -2159,11 +2275,11 @@ def drop_fluff(x: Union['core.TreeNeuron', x = subset.subset_neuron(x, subset=keep, inplace=inplace, keep_disc_cn=True) # See if we need to re-attach any connectors - id_col = 'node_id' if isinstance(x, core.TreeNeuron) else 'vertex_id' + id_col = "node_id" if isinstance(x, core.TreeNeuron) else "vertex_id" if x.has_connectors and id_col in x.connectors: disc = ~x.connectors[id_col].isin(x.graph.nodes).values if any(disc): - xyz = x.connectors.loc[disc, ['x', 'y', 'z']].values + xyz = x.connectors.loc[disc, ["x", "y", "z"]].values x.connectors.loc[disc, id_col] = x.snap(xyz)[0] return x From 06836a04a4d9720dc230bb5d3208ec32a5fd865b Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 16:13:15 +0100 Subject: [PATCH 15/16] MeshNeuron: fix .soma_pos setter --- navis/core/mesh.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/navis/core/mesh.py b/navis/core/mesh.py index bd8ea79b..2dff9ab7 100644 --- a/navis/core/mesh.py +++ b/navis/core/mesh.py @@ -370,7 +370,7 @@ def soma_pos(self): def soma_pos(self, value): """Set soma by position.""" if value is None: - self.soma = None + self._soma_pos = None return try: From 031bc6b20825c33bdf45495b5e7d40180581ac13 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Sun, 22 Sep 2024 16:45:56 +0100 Subject: [PATCH 16/16] prune_twigs: fix `mask` when `precise=True` --- navis/morpho/manipulation.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/navis/morpho/manipulation.py b/navis/morpho/manipulation.py index 0a19cd93..f294e8be 100644 --- a/navis/morpho/manipulation.py +++ b/navis/morpho/manipulation.py @@ -513,6 +513,7 @@ def _prune_twigs_precise( neuron: "core.TreeNeuron", size: float, inplace: bool = False, + mask: Optional[Union[Sequence[int], Callable]] = None, recursive: Union[int, bool, float] = False, ) -> Optional[NeuronObject]: """Prune twigs using precise method.""" @@ -534,17 +535,38 @@ def _prune_twigs_precise( res = tree.query_ball_point(neuron.leafs[["x", "y", "z"]].values, r=size) candidates = neuron.nodes.node_id.values[np.unique(np.concatenate(res))] + if callable(mask): + mask = mask(neuron) + + if mask is not None: + if mask.dtype == bool: + if len(mask) != neuron.n_nodes: + raise ValueError("Mask length must match number of nodes") + mask_nodes = neuron.nodes.node_id.values[mask] + elif mask.dtype in (int, np.int32, np.int64): + mask_nodes = mask + else: + raise TypeError( + f"Mask must be boolean or list of node IDs, got {mask.dtype}" + ) + + candidates = np.intersect1d(candidates, mask_nodes) + + if not len(candidates): + return neuron + # For each node in neuron find out which leafs are directly distal to it # `distal` is a matrix with all nodes in columns and leafs in rows distal = graph.distal_to(neuron, a=leafs, b=candidates) + # Turn matrix into dictionary {'node': [leafs, distal, to, it]} melted = distal.reset_index(drop=False).melt(id_vars="index") melted = melted[melted.value] - melted.groupby("variable")["index"].apply(list) + # `distal` is now a dictionary for {'node_id': [leaf1, leaf2, ..], ..} distal = melted.groupby("variable")["index"].apply(list).to_dict() - # For each node find the distance to any leaf - note we are using `length` + # For each node find the distance to any leaf - note we are using `size` # as cutoff here # `path_len` is a dict mapping {nodeA: {nodeB: length, ...}, ...} # if nodeB is not in dictionary, it's not within reach @@ -571,6 +593,12 @@ def _prune_twigs_precise( # For each of the new leafs check their shortest distance to the # original leafs to get the remainder is_new_leaf = (neuron.nodes.type == "end").values + + # If there is a mask, we have to exclude old leafs which would not have + # been in the mask + if mask is not None: + is_new_leaf = is_new_leaf & np.isin(neuron.nodes.node_id, mask_nodes) + new_leafs = neuron.nodes[is_new_leaf].node_id.values max_len = [max([path_len[l1][l2] for l2 in distal[l1]]) for l1 in new_leafs] @@ -581,7 +609,7 @@ def _prune_twigs_precise( # Get vectors from leafs to their parents nodes = neuron.nodes.set_index("node_id") parents = nodes.loc[new_leafs, "parent_id"].values - loc1 = neuron.leafs[["x", "y", "z"]].values + loc1 = nodes.loc[new_leafs, ["x", "y", "z"]].values loc2 = nodes.loc[parents, ["x", "y", "z"]].values vec = loc1 - loc2 vec_len = np.linalg.norm(vec, axis=1)