diff --git a/docs/building_decision_trees.rst b/docs/building_decision_trees.rst index acc28d806..cefce7ef6 100644 --- a/docs/building_decision_trees.rst +++ b/docs/building_decision_trees.rst @@ -51,7 +51,7 @@ The file key names are used below the full file names in the General outputs from component selection ======================================== -New columns in ``selector.component_table`` and the "ICA metrics tsv" file: +New columns in ``selector.component_table_`` and the "ICA metrics tsv" file: - classification: While the decision table is running, there may also be intermediate @@ -63,13 +63,13 @@ New columns in ``selector.component_table`` and the "ICA metrics tsv" file: or a comma separated list of tags. These tags may be useful parameters for visualizing and reviewing results -``selector.cross_component_metrics`` and "ICA cross component metrics json": +``selector.cross_component_metrics_`` and "ICA cross component metrics json": A dictionary of metrics that are each a single value calculated across components, for example, kappa and rho elbows. User or pre-defined scaling factors are also stored here. Any constant that is used in the component classification processes that isn't pre-defined in the decision tree file should be saved here. -``selector.component_status_table`` and "ICA status table tsv": +``selector.component_status_table_`` and "ICA status table tsv": A table where each column lists the classification status of each component after each node was run. Columns are only added for runs where component statuses can change. @@ -210,11 +210,9 @@ that is used to check whether results are plausible & can help avoid mistakes. - necessary_metrics A list of the necessary metrics in the component table that will be used - by the tree. If a metric doesn't exist then this will raise an error instead - of executing a tree. (Depending on future code development, this could - potentially be used to run ``tedana`` by specifying a decision tree and - metrics are calculated based on the contents of this field.) If a necessary - metric isn't used, there will be a warning. + by the tree. This field defines what metrics will be calculated on each ICA + component. If a metric doesn't exist then this will raise an error instead + of executing a tree. If a necessary metric isn't used, there will be a warning. - generated_metrics An optional initial field. It lists metrics that are to be calculated as @@ -315,7 +313,7 @@ are good examples for how to meet these expectations. Create a dictionary called "outputs" that includes key fields that should be recorded. The following line should be at the end of each function to retain the output info: -``selector.nodes[selector.current_node_idx]["outputs"] = outputs`` +``selector.nodes[selector.current_node_idx_]["outputs"] = outputs`` Additional fields can be used to log function-specific information, but the following fields are common and may be used by other parts of the code: @@ -339,7 +337,7 @@ Before any data are touched in the function, there should be an call. This will be useful to gather all metrics a tree will use without requiring a specific dataset. -Existing functions define ``function_name_idx = f"Step {selector.current_node_idx}: [text of function_name]``. +Existing functions define ``function_name_idx = f"Step {selector.current_node_idx_}: [text of function_name]``. This is used in logging and is cleaner to initialize near the top of each function. 
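For orientation, below is a minimal sketch of a hypothetical node function that follows these conventions. The function name ``dec_example_reject_low_kappa``, its ``kappa_threshold`` parameter, and the rejection criterion are invented purely for illustration; the helper functions and ``selector`` attributes are the ones described in this guide, so treat the sketch as a pattern rather than a tested implementation.

.. code-block:: python

    from tedana.selection.selection_utils import (
        change_comptable_classifications,
        log_decision_tree_step,
        selectcomps2use,
    )


    def dec_example_reject_low_kappa(
        selector, decide_comps, kappa_threshold=10, only_used_metrics=False, custom_node_label=""
    ):
        """Reject components whose kappa is below an arbitrary threshold (illustration only)."""
        # Predefine every output that should be logged for this node
        outputs = {
            "decision_node_idx": selector.current_node_idx_,
            "used_metrics": {"kappa"},
            "node_label": None,
            "n_true": None,
            "n_false": None,
        }
        # Report which metrics this node needs without touching any data
        if only_used_metrics:
            return outputs["used_metrics"]

        function_name_idx = f"Step {selector.current_node_idx_}: example_reject_low_kappa"
        outputs["node_label"] = custom_node_label or f"kappa<{kappa_threshold}"

        comps2use = selectcomps2use(selector.component_table_, decide_comps)
        if not comps2use:
            # No components with the requested classifications remain; log and move on
            log_decision_tree_step(function_name_idx, comps2use, decide_comps=decide_comps)
        else:
            decision_boolean = selector.component_table_.loc[comps2use, "kappa"] < kappa_threshold
            # True components are relabeled "rejected"; False components keep their classification
            selector, outputs["n_true"], outputs["n_false"] = change_comptable_classifications(
                selector, "rejected", "nochange", decision_boolean
            )
            log_decision_tree_step(
                function_name_idx,
                comps2use,
                n_true=outputs["n_true"],
                n_false=outputs["n_false"],
                if_true="rejected",
                if_false="nochange",
            )

        # Retain the output info for this node
        selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs
        return selector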
Each function has code that creates a default node label in ``outputs["node_label"]``. @@ -378,7 +376,7 @@ dataframe column that is True or False for the components in ``decide_comps`` ba the function's criteria. That column is an input to :func:`~tedana.selection.selection_utils.change_comptable_classifications`, which will update the component_table classifications, update the classification history -in component_status_table, and update the component classification_tags. Components not +in ``selector.component_status_table_``, and update the component classification_tags. Components not in ``decide_comps`` retain their existing classifications and tags. :func:`~tedana.selection.selection_utils.change_comptable_classifications` also returns and should assign values to @@ -386,7 +384,7 @@ also returns and should assign values to identified as true or false within each function. For calculation functions, the calculated values should be added as a value/key pair to -both ``selector.cross_component_metrics`` and ``outputs``. +both ``selector.cross_component_metrics_`` and ``outputs``. :func:`~tedana.selection.selection_utils.log_decision_tree_step` puts the relevant info from the function call into the program's output log. @@ -395,7 +393,7 @@ Every function should end with: .. code-block:: python - selector.nodes[selector.current_node_idx]["outputs"] = outputs + selector.nodes[selector.current_node_idx_]["outputs"] = outputs return selector functionname.__doc__ = (functionname.__doc__.format(**DECISION_DOCS)) diff --git a/tedana/io.py b/tedana/io.py index 94966793b..b73425554 100644 --- a/tedana/io.py +++ b/tedana/io.py @@ -643,7 +643,7 @@ def writeresults(ts, mask, comptable, mmix, io_generator): ========================================= =========================================== Filename Content ========================================= =========================================== - desc-denoised_bold.nii.gz Denoised time series. + desc-denoised_bold.nii.gz Denoised time series. desc-optcomAccepted_bold.nii.gz High-Kappa time series. (only with verbose) desc-optcomRejected_bold.nii.gz Low-Kappa time series. (only with verbose) diff --git a/tedana/selection/component_selector.py b/tedana/selection/component_selector.py index 4c671dad2..a29183416 100644 --- a/tedana/selection/component_selector.py +++ b/tedana/selection/component_selector.py @@ -224,13 +224,39 @@ def validate_tree(tree): class ComponentSelector: """Load and classify components based on a specified ``tree``.""" - def __init__(self, tree, component_table, cross_component_metrics={}, status_table=None): + def __init__(self, tree): """Initialize the class using the info specified in the json file ``tree``. Parameters ---------- tree : :obj:`str` The named tree or path to a JSON file that defines one. + + Notes + ----- + Initializing the ``ComponentSelector`` confirms tree is valid and + loads all information in the tree json file into ``ComponentSelector``. + """ + self.tree_name = tree + self.tree = load_config(self.tree_name) + + LGR.info("Performing component selection with " + self.tree["tree_id"]) + LGR.info(self.tree.get("info", "")) + RepLGR.info(self.tree.get("report", "")) + + self.necessary_metrics = self.tree["necessary_metrics"] + self.classification_tags = set(self.tree["classification_tags"]) + self.tree["used_metrics"] = set(self.tree.get("used_metrics", [])) + + def select(self, component_table, cross_component_metrics={}, status_table=None): + """Apply the decision tree to data. 
+ + Using the validated tree in ``ComponentSelector`` to run the decision + tree functions to calculate cross_component metrics and classify + each component as accepted or rejected. + + Parameters + ---------- component_table : (C x M) :obj:`pandas.DataFrame` Component metric table. One row for each component, with a column for each metric; the index should be the component number. @@ -244,13 +270,10 @@ def __init__(self, tree, component_table, cross_component_metrics={}, status_tab Notes ----- - Initializing the ``ComponentSelector`` confirms tree is valid and - loads all information in the tree json file into ``ComponentSelector``. - Adds to the ``ComponentSelector``: - - component_status_table: empty dataframe or contents of inputted status_table - - cross_component_metrics: empty dict or contents of inputed values + - ``component_status_table_``: empty dataframe or contents of inputted status_table + - ``cross_component_metrics_``: empty dict or contents of inputed values - used_metrics: empty set Any parameter that is used by a decision tree node function can be passed @@ -270,62 +293,8 @@ def __init__(self, tree, component_table, cross_component_metrics={}, status_tab Required for kundu tree An example initialization with these options would look like - ``selector = ComponentSelector(tree, comptable, n_echos=n_echos, n_vols=n_vols)`` - """ - self.tree_name = tree - - self.__dict__.update(cross_component_metrics) - self.cross_component_metrics = cross_component_metrics - - # Construct an un-executed selector - self.component_table = component_table.copy() - - # To run a decision tree, each component needs to have an initial classification - # If the classification column doesn't exist, create it and label all components - # as unclassified - if "classification" not in self.component_table: - self.component_table["classification"] = "unclassified" - - self.tree = load_config(self.tree_name) - tree_config = self.tree - - LGR.info("Performing component selection with " + tree_config["tree_id"]) - LGR.info(tree_config.get("info", "")) - RepLGR.info(tree_config.get("report", "")) - - self.tree["nodes"] = tree_config["nodes"] - self.necessary_metrics = set(tree_config["necessary_metrics"]) - self.intermediate_classifications = tree_config["intermediate_classifications"] - self.classification_tags = set(tree_config["classification_tags"]) - if "used_metrics" not in self.tree.keys(): - self.tree["used_metrics"] = set() - else: - self.tree["used_metrics"] = set(self.tree["used_metrics"]) - - if status_table is None: - self.component_status_table = self.component_table[ - ["Component", "classification"] - ].copy() - self.component_status_table = self.component_status_table.rename( - columns={"classification": "initialized classification"} - ) - self.start_idx = 0 - else: - # Since a status table exists, we need to skip nodes up to the - # point where the last tree finished - self.start_idx = len(tree_config["nodes"]) - LGR.info(f"Start is {self.start_idx}") - self.component_status_table = status_table + ``selector = selector.select(comptable, n_echos=n_echos, n_vols=n_vols)`` - def select(self): - """Apply the decision tree to data. - - Using the validated tree in ``ComponentSelector`` to run the decision - tree functions to calculate cross_component metrics and classify - each component as accepted or rejected. 
- - Notes - ----- The selection process uses previously calculated parameters stored in `component_table` for each ICA component such as Kappa (a T2* weighting metric), Rho (an S0 weighting metric), and variance explained. If a necessary metric @@ -338,30 +307,63 @@ def select(self): When this is run, multiple elements in `ComponentSelector` will change including: - - component_table: ``classification`` column with ``accepted`` or ``rejected`` labels + - ``component_table_``: ``classification`` column with ``accepted`` or ``rejected`` labels and ``classification_tags`` column with can hold multiple comma-separated labels explaining why a classification happened - - cross_component_metrics: Any values that were calculated based on the metric + - ``cross_component_metrics_``: Any values that were calculated based on the metric values across components or by direct user input - - component_status_table: Contains the classification statuses at each node in + - ``component_status_table_``: Contains the classification statuses at each node in the decision tree - used_metrics: A list of metrics used in the selection process - nodes: The original tree definition with an added ``outputs`` key listing everything that changed in each node - - current_node_idx: The total number of nodes run in ``ComponentSelector`` + - ``current_node_idx_``: The total number of nodes run in ``ComponentSelector`` """ - if "classification_tags" not in self.component_table.columns: - self.component_table["classification_tags"] = "" + self.cross_component_metrics_ = cross_component_metrics + + # Construct an un-executed selector + self.component_table_ = component_table.copy() # this will crash the program with an error message if not all # necessary_metrics are in the comptable confirm_metrics_exist( - self.component_table, self.necessary_metrics, function_name=self.tree_name + self.component_table_, + self.necessary_metrics, + function_name=self.tree_name, ) + # To run a decision tree, each component needs to have an initial classification + # If the classification column doesn't exist, create it and label all components + # as unclassified + if "classification" not in self.component_table_: + self.component_table_["classification"] = "unclassified" + + if status_table is None: + self.component_status_table_ = self.component_table_[ + ["Component", "classification"] + ].copy() + self.component_status_table_ = self.component_status_table_.rename( + columns={"classification": "initialized classification"} + ) + self.start_idx_ = 0 + else: + # Since a status table exists, we need to skip nodes up to the + # point where the last tree finished. Nodes that were executed + # have an output field.
Identify the last node with an output field + tmp_idx = len(self.tree["nodes"]) - 1 + while ("outputs" not in self.tree["nodes"][tmp_idx]) and (tmp_idx > 0): + tmp_idx -= 1 + # start at the first node that does not have an output field + self.start_idx_ = tmp_idx + 1 + LGR.info(f"Start is {self.start_idx_}") + self.component_status_table_ = status_table + + if "classification_tags" not in self.component_table_.columns: + self.component_table_["classification_tags"] = "" + # for each node in the decision tree - for self.current_node_idx, node in enumerate( - self.tree["nodes"][self.start_idx :], start=self.start_idx + for self.current_node_idx_, node in enumerate( + self.tree["nodes"][self.start_idx_ :], start=self.start_idx_ ): # parse the variables to use with the function fcn = getattr(selection_nodes, node["functionname"]) @@ -375,32 +377,29 @@ def select(self): kwargs = self.check_null(kwargs, node["functionname"]) all_params = {**params, **kwargs} else: - kwargs = None + kwargs = {} all_params = {**params} LGR.debug( - f"Step {self.current_node_idx}: Running function {node['functionname']} " + f"Step {self.current_node_idx_}: Running function {node['functionname']} " f"with parameters: {all_params}" ) # run the decision node function - if kwargs is not None: - self = fcn(self, **params, **kwargs) - else: - self = fcn(self, **params) + self = fcn(self, **params, **kwargs) self.tree["used_metrics"].update( - self.tree["nodes"][self.current_node_idx]["outputs"]["used_metrics"] + self.tree["nodes"][self.current_node_idx_]["outputs"]["used_metrics"] ) # log the current counts for all classification labels - log_classification_counts(self.current_node_idx, self.component_table) + log_classification_counts(self.current_node_idx_, self.component_table_) LGR.debug( - f"Step {self.current_node_idx} Full outputs: " - f"{self.tree['nodes'][self.current_node_idx]['outputs']}" + f"Step {self.current_node_idx_} Full outputs: " + f"{self.tree['nodes'][self.current_node_idx_]['outputs']}" ) # move decision columns to end - self.component_table = clean_dataframe(self.component_table) + self.component_table_ = clean_dataframe(self.component_table_) # warning anything called a necessary metric wasn't used and if # anything not called a necessary metric was used self.are_only_necessary_metrics_used() @@ -445,7 +444,7 @@ def check_null(self, params, fcn): for key, val in params.items(): if val is None: try: - params[key] = getattr(self, key) + params[key] = self.cross_component_metrics_[key] except AttributeError: raise ValueError( f"Parameter {key} is required in node {fcn}, but not defined. " @@ -466,9 +465,11 @@ def are_only_necessary_metrics_used(self): If either of these happen, a warning is added to the logger. """ - necessary_metrics = self.necessary_metrics - not_declared = self.tree["used_metrics"] - necessary_metrics - not_used = necessary_metrics - self.tree["used_metrics"] + necessary_metrics = set(self.necessary_metrics).union( + set(self.tree.get("generated_metrics", [])) + ) + not_declared = self.tree["used_metrics"] - set(necessary_metrics) + not_used = set(necessary_metrics) - self.tree["used_metrics"] if len(not_declared) > 0: LGR.warning( f"Decision tree {self.tree_name} used the following metrics that were " @@ -488,11 +489,11 @@ def are_all_components_accepted_or_rejected(self): If any other component classifications remain, log a warning. 
""" - component_classifications = set(self.component_table["classification"].to_list()) + component_classifications = set(self.component_table_["classification"].to_list()) nonfinal_classifications = component_classifications.difference({"accepted", "rejected"}) if nonfinal_classifications: for nonfinal_class in nonfinal_classifications: - numcomp = asarray(self.component_table["classification"] == nonfinal_class).sum() + numcomp = asarray(self.component_table_["classification"] == nonfinal_class).sum() LGR.warning( f"{numcomp} components have a final classification of {nonfinal_class}. " "At the end of the selection process, all components are expected " @@ -500,14 +501,14 @@ def are_all_components_accepted_or_rejected(self): ) @property - def n_comps(self): + def n_comps_(self): """The number of components in the component table.""" - return len(self.component_table) + return len(self.component_table_) @property - def likely_bold_comps(self): + def likely_bold_comps_(self): """A boolean :obj:`pandas.Series` of components that are tagged "Likely BOLD".""" - likely_bold_comps = self.component_table["classification_tags"].copy() + likely_bold_comps = self.component_table_["classification_tags"].copy() for idx in range(len(likely_bold_comps)): if "Likely BOLD" in likely_bold_comps.loc[idx]: likely_bold_comps.loc[idx] = True @@ -516,24 +517,24 @@ def likely_bold_comps(self): return likely_bold_comps @property - def n_likely_bold_comps(self): + def n_likely_bold_comps_(self): """The number of components that are tagged "Likely BOLD".""" - return self.likely_bold_comps.sum() + return self.likely_bold_comps_.sum() @property - def accepted_comps(self): + def accepted_comps_(self): """A boolean :obj:`pandas.Series` of components that are accepted.""" - return self.component_table["classification"] == "accepted" + return self.component_table_["classification"] == "accepted" @property - def n_accepted_comps(self): + def n_accepted_comps_(self): """The number of components that are accepted.""" - return self.accepted_comps.sum() + return self.accepted_comps_.sum() @property - def rejected_comps(self): + def rejected_comps_(self): """A boolean :obj:`pandas.Series` of components that are rejected.""" - return self.component_table["classification"] == "rejected" + return self.component_table_["classification"] == "rejected" def to_files(self, io_generator): """Convert this selector into component files. @@ -543,10 +544,10 @@ def to_files(self, io_generator): io_generator : :obj:`tedana.io.OutputGenerator` The output generator to use for filename generation and saving. 
""" - io_generator.save_file(self.component_table, "ICA metrics tsv") + io_generator.save_file(self.component_table_, "ICA metrics tsv") io_generator.save_file( - self.cross_component_metrics, + self.cross_component_metrics_, "ICA cross component metrics json", ) - io_generator.save_file(self.component_status_table, "ICA status table tsv") + io_generator.save_file(self.component_status_table_, "ICA status table tsv") io_generator.save_file(self.tree, "ICA decision tree json") diff --git a/tedana/selection/selection_nodes.py b/tedana/selection/selection_nodes.py index 170cff14b..5fed5ddde 100644 --- a/tedana/selection/selection_nodes.py +++ b/tedana/selection/selection_nodes.py @@ -82,7 +82,7 @@ def manual_classify( """ # predefine all outputs that should be logged outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "used_metrics": set(), "node_label": None, "n_true": None, @@ -95,7 +95,7 @@ def manual_classify( if_true = new_classification if_false = "nochange" - function_name_idx = f"Step {selector.current_node_idx}: manual_classify" + function_name_idx = f"Step {selector.current_node_idx_}: manual_classify" if custom_node_label: outputs["node_label"] = custom_node_label else: @@ -105,7 +105,7 @@ def manual_classify( if log_extra_info: LGR.info(f"{function_name_idx} {log_extra_info}") - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) if not comps2use: log_decision_tree_step(function_name_idx, comps2use, decide_comps=decide_comps) @@ -132,10 +132,10 @@ def manual_classify( ) if clear_classification_tags: - selector.component_table["classification_tags"] = "" + selector.component_table_["classification_tags"] = "" LGR.info(function_name_idx + " component classification tags are cleared") - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -226,7 +226,7 @@ def dec_left_op_right( """ # predefine all outputs that should be logged outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "used_metrics": set(), "used_cross_component_metrics": set(), "node_label": None, @@ -234,10 +234,10 @@ def dec_left_op_right( "n_false": None, } - function_name_idx = f"Step {selector.current_node_idx}: left_op_right" + function_name_idx = f"Step {selector.current_node_idx_}: left_op_right" # Only select components if the decision tree is being run if not only_used_metrics: - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) def identify_used_metric(val, isnum=False): """ @@ -254,11 +254,11 @@ def identify_used_metric(val, isnum=False): """ orig_val = val if isinstance(val, str): - if val in selector.component_table.columns: + if val in selector.component_table_.columns: outputs["used_metrics"].update([val]) - elif val in selector.cross_component_metrics: + elif val in selector.cross_component_metrics_: outputs["used_cross_component_metrics"].update([val]) - val = selector.cross_component_metrics[val] + val = selector.cross_component_metrics_[val] # If decision tree is being run, then throw errors or messages # if a component doesn't exist. 
If this is just getting a list # of metrics to be used, then don't bring up warnings @@ -266,14 +266,14 @@ def identify_used_metric(val, isnum=False): if not comps2use: LGR.info( f"{function_name_idx}: {val} is neither a metric in " - "selector.component_table nor selector.cross_component_metrics, " + "selector.component_table_ nor selector.cross_component_metrics_, " f"but no components with {decide_comps} remain by this node " "so nothing happens" ) else: raise ValueError( - f"{val} is neither a metric in selector.component_table " - "nor selector.cross_component_metrics" + f"{val} is neither a metric in selector.component_table_ " + "nor selector.cross_component_metrics_" ) if isnum: if not isinstance(val, (int, float)): @@ -391,13 +391,13 @@ def operator_scale_descript(val_scale, val): LGR.info(f"{function_name_idx} {log_extra_info}") confirm_metrics_exist( - selector.component_table, outputs["used_metrics"], function_name=function_name_idx + selector.component_table_, outputs["used_metrics"], function_name=function_name_idx ) def parse_vals(val): """Get the metric values for the selected components or relevant constant.""" if isinstance(val, str): - return selector.component_table.loc[comps2use, val].copy() + return selector.component_table_.loc[comps2use, val].copy() else: return val # should be a fixed number @@ -457,7 +457,7 @@ def parse_vals(val): if_false=if_false, ) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -512,7 +512,7 @@ def dec_variance_lessthan_thresholds( %(used_metrics)s """ outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "used_metrics": {var_metric}, "node_label": None, "n_true": None, @@ -522,7 +522,7 @@ def dec_variance_lessthan_thresholds( if only_used_metrics: return outputs["used_metrics"] - function_name_idx = f"Step {selector.current_node_idx}: variance_lt_thresholds" + function_name_idx = f"Step {selector.current_node_idx_}: variance_lt_thresholds" if custom_node_label: outputs["node_label"] = custom_node_label else: @@ -534,9 +534,9 @@ def dec_variance_lessthan_thresholds( if log_extra_info: LGR.info(f"{function_name_idx} {log_extra_info}") - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) confirm_metrics_exist( - selector.component_table, outputs["used_metrics"], function_name=function_name_idx + selector.component_table_, outputs["used_metrics"], function_name=function_name_idx ) if not comps2use: @@ -550,7 +550,7 @@ def dec_variance_lessthan_thresholds( if_false=outputs["n_false"], ) else: - variance = selector.component_table.loc[comps2use, var_metric] + variance = selector.component_table_.loc[comps2use, var_metric] decision_boolean = variance < single_comp_threshold # if all the low variance components sum above all_comp_threshold # keep removing the highest remaining variance component until @@ -582,7 +582,7 @@ def dec_variance_lessthan_thresholds( if_false=if_false, ) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -616,7 +616,7 @@ def calc_median( %(selector)s %(used_metrics)s """ - function_name_idx = f"Step {selector.current_node_idx}: calc_median" + function_name_idx = f"Step {selector.current_node_idx_}: calc_median" if not isinstance(median_label, str): raise ValueError( 
f"{function_name_idx}: median_label must be a string. It is: {median_label}" @@ -630,7 +630,7 @@ def calc_median( ) outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "node_label": None, label_name: None, "used_metrics": {metric_name}, @@ -640,7 +640,7 @@ def calc_median( if only_used_metrics: return outputs["used_metrics"] - if label_name in selector.cross_component_metrics: + if label_name in selector.cross_component_metrics_: LGR.warning( f"{label_name} already calculated. Overwriting previous value in {function_name_idx}" ) @@ -654,9 +654,9 @@ def calc_median( if log_extra_info: LGR.info(f"{function_name_idx} {log_extra_info}") - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) confirm_metrics_exist( - selector.component_table, outputs["used_metrics"], function_name=function_name_idx + selector.component_table_, outputs["used_metrics"], function_name=function_name_idx ) if not comps2use: @@ -666,13 +666,13 @@ def calc_median( decide_comps=decide_comps, ) else: - outputs[label_name] = np.median(selector.component_table.loc[comps2use, metric_name]) + outputs[label_name] = np.median(selector.component_table_.loc[comps2use, metric_name]) - selector.cross_component_metrics[label_name] = outputs[label_name] + selector.cross_component_metrics_[label_name] = outputs[label_name] log_decision_tree_step(function_name_idx, comps2use, calc_outputs=outputs) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -715,9 +715,9 @@ def calc_kappa_elbow( are called """ outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "node_label": None, - "n_echos": selector.n_echos, + "n_echos": selector.cross_component_metrics_["n_echos"], "used_metrics": {"kappa"}, "calc_cross_comp_metrics": [ "kappa_elbow_kundu", @@ -734,9 +734,9 @@ def calc_kappa_elbow( if only_used_metrics: return outputs["used_metrics"] - function_name_idx = f"Step {selector.current_node_idx}: calc_kappa_elbow" + function_name_idx = f"Step {selector.current_node_idx_}: calc_kappa_elbow" - if ("kappa_elbow_kundu" in selector.cross_component_metrics) and ( + if ("kappa_elbow_kundu" in selector.cross_component_metrics_) and ( "kappa_elbow_kundu" in outputs["calc_cross_comp_metrics"] ): LGR.warning( @@ -744,7 +744,7 @@ def calc_kappa_elbow( f"Overwriting previous value in {function_name_idx}" ) - if "varex_upper_p" in selector.cross_component_metrics: + if "varex_upper_p" in selector.cross_component_metrics_: LGR.warning( f"varex_upper_p already calculated. 
Overwriting previous value in {function_name_idx}" ) @@ -758,9 +758,9 @@ def calc_kappa_elbow( if log_extra_info: LGR.info(f"{function_name_idx} {log_extra_info}") - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) confirm_metrics_exist( - selector.component_table, outputs["used_metrics"], function_name=function_name_idx + selector.component_table_, outputs["used_metrics"], function_name=function_name_idx ) if not comps2use: @@ -775,15 +775,19 @@ def calc_kappa_elbow( outputs["kappa_allcomps_elbow"], outputs["kappa_nonsig_elbow"], outputs["varex_upper_p"], - ) = kappa_elbow_kundu(selector.component_table, selector.n_echos, comps2use=comps2use) - selector.cross_component_metrics["kappa_elbow_kundu"] = outputs["kappa_elbow_kundu"] - selector.cross_component_metrics["kappa_allcomps_elbow"] = outputs["kappa_allcomps_elbow"] - selector.cross_component_metrics["kappa_nonsig_elbow"] = outputs["kappa_nonsig_elbow"] - selector.cross_component_metrics["varex_upper_p"] = outputs["varex_upper_p"] + ) = kappa_elbow_kundu( + selector.component_table_, + selector.cross_component_metrics_["n_echos"], + comps2use=comps2use, + ) + selector.cross_component_metrics_["kappa_elbow_kundu"] = outputs["kappa_elbow_kundu"] + selector.cross_component_metrics_["kappa_allcomps_elbow"] = outputs["kappa_allcomps_elbow"] + selector.cross_component_metrics_["kappa_nonsig_elbow"] = outputs["kappa_nonsig_elbow"] + selector.cross_component_metrics_["varex_upper_p"] = outputs["varex_upper_p"] log_decision_tree_step(function_name_idx, comps2use, calc_outputs=outputs) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -830,7 +834,7 @@ def calc_rho_elbow( for a more detailed explanation of the difference between the kundu and liberal options. 
""" - function_name_idx = f"Step {selector.current_node_idx}: calc_rho_elbow" + function_name_idx = f"Step {selector.current_node_idx_}: calc_rho_elbow" if rho_elbow_type == "kundu": elbow_name = "rho_elbow_kundu" @@ -843,9 +847,9 @@ def calc_rho_elbow( ) outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "node_label": None, - "n_echos": selector.n_echos, + "n_echos": selector.cross_component_metrics_["n_echos"], "calc_cross_comp_metrics": [ elbow_name, "rho_allcomps_elbow", @@ -862,7 +866,7 @@ def calc_rho_elbow( if only_used_metrics: return outputs["used_metrics"] - if (elbow_name in selector.cross_component_metrics) and ( + if (elbow_name in selector.cross_component_metrics_) and ( elbow_name in outputs["calc_cross_comp_metrics"] ): LGR.warning( @@ -879,12 +883,12 @@ def calc_rho_elbow( if log_extra_info: LGR.info(f"{function_name_idx} {log_extra_info}") - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) confirm_metrics_exist( - selector.component_table, outputs["used_metrics"], function_name=function_name_idx + selector.component_table_, outputs["used_metrics"], function_name=function_name_idx ) - subset_comps2use = selectcomps2use(selector, subset_decide_comps) + subset_comps2use = selectcomps2use(selector.component_table_, subset_decide_comps) if not comps2use: log_decision_tree_step( @@ -899,22 +903,22 @@ def calc_rho_elbow( outputs["rho_unclassified_elbow"], outputs["elbow_f05"], ) = rho_elbow_kundu_liberal( - selector.component_table, - selector.n_echos, + selector.component_table_, + selector.cross_component_metrics_["n_echos"], rho_elbow_type=rho_elbow_type, comps2use=comps2use, subset_comps2use=subset_comps2use, ) - selector.cross_component_metrics[elbow_name] = outputs[elbow_name] - selector.cross_component_metrics["rho_allcomps_elbow"] = outputs["rho_allcomps_elbow"] - selector.cross_component_metrics["rho_unclassified_elbow"] = outputs[ + selector.cross_component_metrics_[elbow_name] = outputs[elbow_name] + selector.cross_component_metrics_["rho_allcomps_elbow"] = outputs["rho_allcomps_elbow"] + selector.cross_component_metrics_["rho_unclassified_elbow"] = outputs[ "rho_unclassified_elbow" ] - selector.cross_component_metrics["elbow_f05"] = outputs["elbow_f05"] + selector.cross_component_metrics_["elbow_f05"] = outputs["elbow_f05"] log_decision_tree_step(function_name_idx, comps2use, calc_outputs=outputs) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -978,7 +982,7 @@ def dec_classification_doesnt_exist( """ # predefine all outputs that should be logged outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "used_metrics": set(), "used_cross_comp_metrics": set(), "node_label": None, @@ -989,7 +993,7 @@ def dec_classification_doesnt_exist( if only_used_metrics: return outputs["used_metrics"] - function_name_idx = f"Step {selector.current_node_idx}: classification_doesnt_exist" + function_name_idx = f"Step {selector.current_node_idx_}: classification_doesnt_exist" if custom_node_label: outputs["node_label"] = custom_node_label elif at_least_num_exist == 1: @@ -1009,9 +1013,9 @@ def dec_classification_doesnt_exist( if_true = new_classification if_false = "nochange" - comps2use = selectcomps2use(selector, decide_comps) + comps2use = 
selectcomps2use(selector.component_table_, decide_comps) - do_comps_exist = selectcomps2use(selector, class_comp_exists) + do_comps_exist = selectcomps2use(selector.component_table_, class_comp_exists) if (not comps2use) or (len(do_comps_exist) >= at_least_num_exist): outputs["n_true"] = 0 @@ -1044,7 +1048,7 @@ def dec_classification_doesnt_exist( if_false=if_false, ) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -1093,7 +1097,7 @@ def dec_reclassify_high_var_comps( """ # predefine all outputs that should be logged outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "used_metrics": {"variance explained"}, "used_cross_comp_metrics": {"varex_upper_p"}, "node_label": None, @@ -1104,7 +1108,7 @@ def dec_reclassify_high_var_comps( if only_used_metrics: return outputs["used_metrics"] - function_name_idx = f"Step {selector.current_node_idx}: reclassify_high_var_comps" + function_name_idx = f"Step {selector.current_node_idx_}: reclassify_high_var_comps" if custom_node_label: outputs["node_label"] = custom_node_label else: @@ -1120,13 +1124,13 @@ def dec_reclassify_high_var_comps( if_true = new_classification if_false = "nochange" - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) - if "varex_upper_p" not in selector.cross_component_metrics: + if "varex_upper_p" not in selector.cross_component_metrics_: if not comps2use: LGR.info( f"{function_name_idx}: varex_upper_p is not in " - "selector.cross_component_metrics, but no components with " + "selector.cross_component_metrics_, but no components with " f"{decide_comps} remain by this node so nothing happens" ) else: @@ -1149,13 +1153,13 @@ def dec_reclassify_high_var_comps( else: keep_comps2use = comps2use.copy() for i_loop in range(3): - temp_comptable = selector.component_table.loc[keep_comps2use].sort_values( + temp_comptable = selector.component_table_.loc[keep_comps2use].sort_values( by=["variance explained"], ascending=False ) diff_vals = temp_comptable["variance explained"].diff(-1) diff_vals = diff_vals.fillna(0) keep_comps2use = temp_comptable.loc[ - diff_vals < selector.cross_component_metrics["varex_upper_p"] + diff_vals < selector.cross_component_metrics_["varex_upper_p"] ].index.values # Everything that should be kept as unclassified is False while the few # that are not in keep_comps2use should be True @@ -1179,7 +1183,7 @@ def dec_reclassify_high_var_comps( if_false=if_false, ) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -1214,7 +1218,7 @@ def calc_varex_thresh( num_highest_var_comps : :obj:`str` :obj:`int` percentile can be calculated on the num_highest_var_comps components with the lowest variance. Either input an integer directly or input a string that is - a parameter stored in selector.cross_component_metrics ("num_acc_guess" in + a parameter stored in ``selector.cross_component_metrics_`` ("num_acc_guess" in original decision tree). 
Default=None %(log_extra_info)s %(custom_node_label)s @@ -1225,7 +1229,7 @@ def calc_varex_thresh( %(selector)s %(used_metrics)s """ - function_name_idx = f"Step {selector.current_node_idx}: calc_varex_thresh" + function_name_idx = f"Step {selector.current_node_idx_}: calc_varex_thresh" thresh_label = thresh_label.lower() if thresh_label is None or thresh_label == "": varex_name = "varex_thresh" @@ -1235,7 +1239,7 @@ def calc_varex_thresh( perc_name = f"{thresh_label}_perc" outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "node_label": None, varex_name: None, "num_highest_var_comps": num_highest_var_comps, @@ -1256,25 +1260,25 @@ def calc_varex_thresh( if only_used_metrics: return outputs["used_metrics"] - if varex_name in selector.cross_component_metrics: + if varex_name in selector.cross_component_metrics_: LGR.warning( f"{varex_name} already calculated. Overwriting previous value in {function_name_idx}" ) - if perc_name in selector.cross_component_metrics: + if perc_name in selector.cross_component_metrics_: LGR.warning( f"{perc_name} already calculated. Overwriting previous value in {function_name_idx}" ) - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) confirm_metrics_exist( - selector.component_table, outputs["used_metrics"], function_name=function_name_idx + selector.component_table_, outputs["used_metrics"], function_name=function_name_idx ) if num_highest_var_comps is not None: if isinstance(num_highest_var_comps, str): - if num_highest_var_comps in selector.cross_component_metrics: - num_highest_var_comps = selector.cross_component_metrics[num_highest_var_comps] + if num_highest_var_comps in selector.cross_component_metrics_: + num_highest_var_comps = selector.cross_component_metrics_[num_highest_var_comps] elif not comps2use: # Note: It is possible the comps2use requested for this function # is not empty, but the comps2use requested to calculate @@ -1282,13 +1286,13 @@ def calc_varex_thresh( # used, that's unlikely, but worth a comment. 
LGR.info( f"{function_name_idx}: num_highest_var_comps ( {num_highest_var_comps}) " - "is not in selector.cross_component_metrics, but no components with " + "is not in selector.cross_component_metrics_, but no components with " f"{decide_comps} remain by this node so nothing happens" ) else: raise ValueError( f"{function_name_idx}: num_highest_var_comps ( {num_highest_var_comps}) " - "is not in selector.cross_component_metrics" + "is not in selector.cross_component_metrics_" ) if not isinstance(num_highest_var_comps, int) and comps2use: raise ValueError( @@ -1314,7 +1318,7 @@ def calc_varex_thresh( else: if num_highest_var_comps is None: outputs[varex_name] = scoreatpercentile( - selector.component_table.loc[comps2use, "variance explained"], percentile_thresh + selector.component_table_.loc[comps2use, "variance explained"], percentile_thresh ) else: # Using only the first num_highest_var_comps components sorted to include @@ -1333,17 +1337,19 @@ def calc_varex_thresh( num_highest_var_comps = len(comps2use) sorted_varex = np.flip( - np.sort((selector.component_table.loc[comps2use, "variance explained"]).to_numpy()) + np.sort( + (selector.component_table_.loc[comps2use, "variance explained"]).to_numpy() + ) ) outputs[varex_name] = scoreatpercentile( sorted_varex[:num_highest_var_comps], percentile_thresh ) - selector.cross_component_metrics[varex_name] = outputs[varex_name] + selector.cross_component_metrics_[varex_name] = outputs[varex_name] log_decision_tree_step(function_name_idx, comps2use, calc_outputs=outputs) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -1380,7 +1386,7 @@ def calc_extend_factor( """ outputs = { "used_metrics": set(), - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "node_label": None, "extend_factor": None, "calc_cross_comp_metrics": ["extend_factor"], @@ -1389,9 +1395,9 @@ def calc_extend_factor( if only_used_metrics: return outputs["used_metrics"] - function_name_idx = f"Step {selector.current_node_idx}: calc_extend_factor" + function_name_idx = f"Step {selector.current_node_idx_}: calc_extend_factor" - if "extend_factor" in selector.cross_component_metrics: + if "extend_factor" in selector.cross_component_metrics_: LGR.warning( f"extend_factor already calculated. Overwriting previous value in {function_name_idx}" ) @@ -1405,14 +1411,14 @@ def calc_extend_factor( LGR.info(f"{function_name_idx} {log_extra_info}") outputs["extend_factor"] = get_extend_factor( - n_vols=selector.cross_component_metrics["n_vols"], extend_factor=extend_factor + n_vols=selector.cross_component_metrics_["n_vols"], extend_factor=extend_factor ) - selector.cross_component_metrics["extend_factor"] = outputs["extend_factor"] + selector.cross_component_metrics_["extend_factor"] = outputs["extend_factor"] log_decision_tree_step(function_name_idx, -1, calc_outputs=outputs) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -1458,7 +1464,7 @@ def calc_max_good_meanmetricrank( earlier versions of this code. It might be worth consistently using the same term, but this note will hopefully suffice for now. 
""" - function_name_idx = f"Step {selector.current_node_idx}: calc_max_good_meanmetricrank" + function_name_idx = f"Step {selector.current_node_idx_}: calc_max_good_meanmetricrank" if (metric_suffix is not None) and (metric_suffix != "") and isinstance(metric_suffix, str): metric_name = f"max_good_meanmetricrank_{metric_suffix}" @@ -1466,7 +1472,7 @@ def calc_max_good_meanmetricrank( metric_name = "max_good_meanmetricrank" outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "node_label": None, metric_name: None, "used_metrics": set(), @@ -1476,7 +1482,7 @@ def calc_max_good_meanmetricrank( if only_used_metrics: return outputs["used_metrics"] - if metric_name in selector.cross_component_metrics: + if metric_name in selector.cross_component_metrics_: LGR.warning( "max_good_meanmetricrank already calculated." f"Overwriting previous value in {function_name_idx}" @@ -1490,9 +1496,9 @@ def calc_max_good_meanmetricrank( if log_extra_info: LGR.info(f"{function_name_idx} {log_extra_info}") - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) confirm_metrics_exist( - selector.component_table, outputs["used_metrics"], function_name=function_name_idx + selector.component_table_, outputs["used_metrics"], function_name=function_name_idx ) if not comps2use: @@ -1503,19 +1509,19 @@ def calc_max_good_meanmetricrank( ) else: num_prov_accept = len(comps2use) - if "extend_factor" in selector.cross_component_metrics: - extend_factor = selector.cross_component_metrics["extend_factor"] + if "extend_factor" in selector.cross_component_metrics_: + extend_factor = selector.cross_component_metrics_["extend_factor"] outputs[metric_name] = extend_factor * num_prov_accept else: raise ValueError( f"extend_factor needs to be in cross_component_metrics for {function_name_idx}" ) - selector.cross_component_metrics[metric_name] = outputs[metric_name] + selector.cross_component_metrics_[metric_name] = outputs[metric_name] log_decision_tree_step(function_name_idx, comps2use, calc_outputs=outputs) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -1557,10 +1563,10 @@ def calc_varex_kappa_ratio( This metric sometimes causes issues with high magnitude BOLD responses such as the V1 response to a block-design flashing checkerboard """ - function_name_idx = f"Step {selector.current_node_idx}: calc_varex_kappa_ratio" + function_name_idx = f"Step {selector.current_node_idx_}: calc_varex_kappa_ratio" outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "node_label": None, "kappa_rate": None, "used_metrics": {"kappa", "variance explained"}, @@ -1571,12 +1577,12 @@ def calc_varex_kappa_ratio( if only_used_metrics: return outputs["used_metrics"] - if "kappa_rate" in selector.cross_component_metrics: + if "kappa_rate" in selector.cross_component_metrics_: LGR.warning( f"kappa_rate already calculated. Overwriting previous value in {function_name_idx}" ) - if "varex kappa ratio" in selector.component_table: + if "varex kappa ratio" in selector.component_table_: raise ValueError( "'varex kappa ratio' is already a column in the component_table." 
f"Recalculating in {function_name_idx} can cause problems since these " @@ -1591,9 +1597,9 @@ def calc_varex_kappa_ratio( if log_extra_info: LGR.info(f"{function_name_idx}: {log_extra_info}") - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) confirm_metrics_exist( - selector.component_table, outputs["used_metrics"], function_name=function_name_idx + selector.component_table_, outputs["used_metrics"], function_name=function_name_idx ) if not comps2use: @@ -1604,11 +1610,11 @@ def calc_varex_kappa_ratio( ) else: kappa_rate = ( - np.nanmax(selector.component_table.loc[comps2use, "kappa"]) - - np.nanmin(selector.component_table.loc[comps2use, "kappa"]) + np.nanmax(selector.component_table_.loc[comps2use, "kappa"]) + - np.nanmin(selector.component_table_.loc[comps2use, "kappa"]) ) / ( - np.nanmax(selector.component_table.loc[comps2use, "variance explained"]) - - np.nanmin(selector.component_table.loc[comps2use, "variance explained"]) + np.nanmax(selector.component_table_.loc[comps2use, "variance explained"]) + - np.nanmin(selector.component_table_.loc[comps2use, "variance explained"]) ) outputs["kappa_rate"] = kappa_rate LGR.debug( @@ -1617,21 +1623,21 @@ def calc_varex_kappa_ratio( ) # NOTE: kappa_rate is calculated on a subset of components while # "varex kappa ratio" is calculated for all compnents - selector.component_table["varex kappa ratio"] = ( + selector.component_table_["varex kappa ratio"] = ( kappa_rate - * selector.component_table["variance explained"] - / selector.component_table["kappa"] + * selector.component_table_["variance explained"] + / selector.component_table_["kappa"] ) # Unclear if necessary, but this may clean up a weird issue on passing # references in a data frame. # See longer comment in selection_utils.comptable_classification_changer - selector.component_table = selector.component_table.copy() + selector.component_table_ = selector.component_table_.copy() - selector.cross_component_metrics["kappa_rate"] = outputs["kappa_rate"] + selector.cross_component_metrics_["kappa_rate"] = outputs["kappa_rate"] log_decision_tree_step(function_name_idx, comps2use, calc_outputs=outputs) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector @@ -1685,10 +1691,10 @@ def calc_revised_meanmetricrank_guesses( accepted components calculated as the ratio of ``num_acc_guess`` to ``restrict_factor``. """ - function_name_idx = f"Step {selector.current_node_idx}: calc_revised_meanmetricrank_guesses" + function_name_idx = f"Step {selector.current_node_idx_}: calc_revised_meanmetricrank_guesses" outputs = { - "decision_node_idx": selector.current_node_idx, + "decision_node_idx": selector.current_node_idx_, "node_label": None, "num_acc_guess": None, "conservative_guess": None, @@ -1703,24 +1709,24 @@ def calc_revised_meanmetricrank_guesses( }, "used_cross_component_metrics": {"kappa_elbow_kundu", "rho_elbow_kundu"}, "calc_cross_comp_metrics": ["num_acc_guess", "conservative_guess", "restrict_factor"], - "added_component_table_metrics": [f"d_table_score_node{selector.current_node_idx}"], + "added_component_table_metrics": [f"d_table_score_node{selector.current_node_idx_}"], } if only_used_metrics: return outputs["used_metrics"] - if "num_acc_guess" in selector.cross_component_metrics: + if "num_acc_guess" in selector.cross_component_metrics_: LGR.warning( f"num_acc_guess already calculated. 
Overwriting previous value in {function_name_idx}" ) - if "conservative_guess" in selector.cross_component_metrics: + if "conservative_guess" in selector.cross_component_metrics_: LGR.warning( "conservative_guess already calculated. " f"Overwriting previous value in {function_name_idx}" ) - if "restrict_factor" in selector.cross_component_metrics: + if "restrict_factor" in selector.cross_component_metrics_: LGR.warning( "restrict_factor already calculated. " f"Overwriting previous value in {function_name_idx}" @@ -1728,24 +1734,24 @@ def calc_revised_meanmetricrank_guesses( if not isinstance(restrict_factor, (int, float)): raise ValueError(f"restrict_factor needs to be a number. It is: {restrict_factor}") - if f"d_table_score_node{selector.current_node_idx}" in selector.component_table: + if f"d_table_score_node{selector.current_node_idx_}" in selector.component_table_: raise ValueError( - f"d_table_score_node{selector.current_node_idx} is already a column" + f"d_table_score_node{selector.current_node_idx_} is already a column" f"in the component_table. Recalculating in {function_name_idx} can " "cause problems since these are only calculated on a subset of components" ) - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) confirm_metrics_exist( - selector.component_table, outputs["used_metrics"], function_name=function_name_idx + selector.component_table_, outputs["used_metrics"], function_name=function_name_idx ) for xcompmetric in outputs["used_cross_component_metrics"]: - if xcompmetric not in selector.cross_component_metrics: + if xcompmetric not in selector.cross_component_metrics_: if not comps2use: LGR.info( f"{function_name_idx}: {xcompmetric} is not in " - "selector.cross_component_metrics, but no components with " + "selector.cross_component_metrics_, but no components with " f"{decide_comps} remain by this node so nothing happens" ) else: @@ -1763,9 +1769,9 @@ def calc_revised_meanmetricrank_guesses( if log_extra_info: LGR.info(f"{function_name_idx}: {log_extra_info}") - comps2use = selectcomps2use(selector, decide_comps) + comps2use = selectcomps2use(selector.component_table_, decide_comps) confirm_metrics_exist( - selector.component_table, outputs["used_metrics"], function_name=function_name_idx + selector.component_table_, outputs["used_metrics"], function_name=function_name_idx ) if not comps2use: @@ -1781,46 +1787,48 @@ def calc_revised_meanmetricrank_guesses( [ np.sum( ( - selector.component_table.loc[comps2use, "kappa"] - > selector.cross_component_metrics["kappa_elbow_kundu"] + selector.component_table_.loc[comps2use, "kappa"] + > selector.cross_component_metrics_["kappa_elbow_kundu"] ) & ( - selector.component_table.loc[comps2use, "rho"] - < selector.cross_component_metrics["rho_elbow_kundu"] + selector.component_table_.loc[comps2use, "rho"] + < selector.cross_component_metrics_["rho_elbow_kundu"] ) ), np.sum( - selector.component_table.loc[comps2use, "kappa"] - > selector.cross_component_metrics["kappa_elbow_kundu"] + selector.component_table_.loc[comps2use, "kappa"] + > selector.cross_component_metrics_["kappa_elbow_kundu"] ), ] ) ) outputs["conservative_guess"] = outputs["num_acc_guess"] / outputs["restrict_factor"] - tmp_kappa = selector.component_table.loc[comps2use, "kappa"].to_numpy() - tmp_dice_ft2 = selector.component_table.loc[comps2use, "dice_FT2"].to_numpy() - tmp_signal_m_noise_t = selector.component_table.loc[comps2use, "signal-noise_t"].to_numpy() - tmp_countnoise = 
selector.component_table.loc[comps2use, "countnoise"].to_numpy() - tmp_countsig_ft2 = selector.component_table.loc[comps2use, "countsigFT2"].to_numpy() + tmp_kappa = selector.component_table_.loc[comps2use, "kappa"].to_numpy() + tmp_dice_ft2 = selector.component_table_.loc[comps2use, "dice_FT2"].to_numpy() + tmp_signal_m_noise_t = selector.component_table_.loc[ + comps2use, "signal-noise_t" + ].to_numpy() + tmp_countnoise = selector.component_table_.loc[comps2use, "countnoise"].to_numpy() + tmp_countsig_ft2 = selector.component_table_.loc[comps2use, "countsigFT2"].to_numpy() tmp_d_table_score = generate_decision_table_score( tmp_kappa, tmp_dice_ft2, tmp_signal_m_noise_t, tmp_countnoise, tmp_countsig_ft2 ) - selector.component_table[f"d_table_score_node{selector.current_node_idx}"] = np.NaN - selector.component_table.loc[ - comps2use, f"d_table_score_node{selector.current_node_idx}" + selector.component_table_[f"d_table_score_node{selector.current_node_idx_}"] = np.NaN + selector.component_table_.loc[ + comps2use, f"d_table_score_node{selector.current_node_idx_}" ] = tmp_d_table_score # Unclear if necessary, but this may clean up a weird issue on passing # references in a data frame. # See longer comment in selection_utils.comptable_classification_changer - selector.component_table = selector.component_table.copy() + selector.component_table_ = selector.component_table_.copy() - selector.cross_component_metrics["conservative_guess"] = outputs["conservative_guess"] - selector.cross_component_metrics["num_acc_guess"] = outputs["num_acc_guess"] - selector.cross_component_metrics["restrict_factor"] = outputs["restrict_factor"] + selector.cross_component_metrics_["conservative_guess"] = outputs["conservative_guess"] + selector.cross_component_metrics_["num_acc_guess"] = outputs["num_acc_guess"] + selector.cross_component_metrics_["restrict_factor"] = outputs["restrict_factor"] log_decision_tree_step(function_name_idx, comps2use, calc_outputs=outputs) - selector.tree["nodes"][selector.current_node_idx]["outputs"] = outputs + selector.tree["nodes"][selector.current_node_idx_]["outputs"] = outputs return selector diff --git a/tedana/selection/selection_utils.py b/tedana/selection/selection_utils.py index 4567926d8..aea51ad05 100644 --- a/tedana/selection/selection_utils.py +++ b/tedana/selection/selection_utils.py @@ -14,13 +14,13 @@ ############################################################## -def selectcomps2use(selector, decide_comps): +def selectcomps2use(component_table, decide_comps): """Get a list of component numbers that fit the classification types in ``decide_comps``. 
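For example, assuming a ``component_table`` whose ``classification`` column currently holds "accepted", "rejected", and "unclassified" labels (illustrative values), ``selectcomps2use(component_table, ["unclassified"])`` would return the indices of the unclassified components, while ``selectcomps2use(component_table, "all")`` returns every component index.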
Parameters ---------- - selector : :obj:`~tedana.selection.component_selector.ComponentSelector` - Only uses the component_table in this object + component_table : :obj:`~pandas.DataFrame` + The component_table with metrics and labels for each ICA component decide_comps : :obj:`str` or :obj:`list[str]` or :obj:`list[int]` This is string or a list of strings describing what classifications of components to operate on, using default or intermediate_classification @@ -34,33 +34,31 @@ def selectcomps2use(selector, decide_comps): comps2use : :obj:`list[int]` A list of component indices with classifications included in decide_comps """ - if "classification" not in selector.component_table: - raise ValueError( - "selector.component_table needs a 'classification' column to run selectcomp2suse" - ) + if "classification" not in component_table: + raise ValueError("component_table needs a 'classification' column to run selectcomps2use") if isinstance(decide_comps, (str, int)): decide_comps = [decide_comps] if isinstance(decide_comps, list) and (decide_comps[0] == "all"): - # All components with any string in the classification field - # are set to True - comps2use = list(range(selector.component_table.shape[0])) + # All components with any string in the classification field are set to True + comps2use = list(range(component_table.shape[0])) elif isinstance(decide_comps, list) and all(isinstance(elem, str) for elem in decide_comps): comps2use = [] for didx in range(len(decide_comps)): - newcomps2use = selector.component_table.index[ - selector.component_table["classification"] == decide_comps[didx] + newcomps2use = component_table.index[ + component_table["classification"] == decide_comps[didx] ].tolist() comps2use = list(set(comps2use + newcomps2use)) + elif isinstance(decide_comps, list) and all(isinstance(elem, int) for elem in decide_comps): # decide_comps is already a list of indices - if len(selector.component_table) <= max(decide_comps): + if len(component_table) <= max(decide_comps): raise ValueError( "decide_comps for selectcomps2use is selecting for a component with index" f"{max(decide_comps)} (0 indexing) which is greater than the number " - f"of components: {len(selector.component_table)}" + f"of components: {len(component_table)}" ) elif min(decide_comps) < 0: raise ValueError( @@ -102,8 +100,8 @@ def change_comptable_classifications( Parameters ---------- selector : :obj:`tedana.selection.component_selector.ComponentSelector` - The attributes used are component_table, component_status_table, and - current_node_idx + The attributes used are ``component_table_``, ``component_status_table_``, and + ``current_node_idx_`` if_true, if_false : :obj:`str` If the condition in this step is true or false, give the component the label in this string. Options are 'accepted', 'rejected', @@ -125,12 +123,12 @@ def change_comptable_classifications( Returns ------- selector : :obj:`tedana.selection.component_selector.ComponentSelector` - component_table["classifications"] will reflect any new + ``component_table_["classifications"]`` will reflect any new classifications. - component_status_table will have a new column titled - "Node current_node_idx" that is a copy of the updated classifications + ``component_status_table_`` will have a new column titled + "Node ``current_node_idx_``" that is a copy of the updated classifications column. - component_table["classification_tags"] will be updated to include any + ``component_table_["classification_tags"]`` will be updated to include any new tags. 
Each tag should appear only once in the string and tags will be separated by commas. n_true, n_false : :obj:`int` @@ -158,8 +156,8 @@ def change_comptable_classifications( dont_warn_reclassify=dont_warn_reclassify, ) - selector.component_status_table[f"Node {selector.current_node_idx}"] = ( - selector.component_table["classification"] + selector.component_status_table_[f"Node {selector.current_node_idx_}"] = ( + selector.component_table_["classification"] ) n_true = decision_boolean.sum() @@ -180,8 +178,8 @@ def comptable_classification_changer( Parameters ---------- selector : :obj:`tedana.selection.component_selector.ComponentSelector` - The attributes used are component_table, component_status_table, and - current_node_idx + The attributes used are ``component_table_``, ``component_status_table_``, and + ``current_node_idx_`` boolstate : :obj:`bool` Change classifications only for True or False components in decision_boolean based on this variable @@ -209,12 +207,12 @@ def comptable_classification_changer( ------- selector : :obj:`tedana.selection.component_selector.ComponentSelector` Operates on the True OR False components depending on boolstate - component_table["classifications"] will reflect any new + ``component_table_["classifications"]`` will reflect any new classifications. - component_status_table will have a new column titled - "Node current_node_idx" that is a copy of the updated classifications + ``component_status_table_`` will have a new column titled + "Node ``current_node_idx_``" that is a copy of the updated classifications column. - component_table["classification_tags"] will be updated to include any + component_table_["classification_tags"] will be updated to include any new tags. Each tag should appear only once in the string and tags will be separated by commas. @@ -235,7 +233,7 @@ def comptable_classification_changer( changeidx = decision_boolean.index[np.asarray(decision_boolean) == boolstate] if not changeidx.empty: current_classifications = set( - selector.component_table.loc[changeidx, "classification"].tolist() + selector.component_table_.loc[changeidx, "classification"].tolist() ) if current_classifications.intersection({"accepted", "rejected"}): if not dont_warn_reclassify: @@ -245,11 +243,11 @@ def comptable_classification_changer( ("accepted" in current_classifications) and (classify_if != "accepted") ) or (("rejected" in current_classifications) and (classify_if != "rejected")): LGR.warning( - f"Step {selector.current_node_idx}: Some classifications are" + f"Step {selector.current_node_idx_}: Some classifications are" " changing away from accepted or rejected. Once a component is " "accepted or rejected, it shouldn't be reclassified" ) - selector.component_table.loc[changeidx, "classification"] = classify_if + selector.component_table_.loc[changeidx, "classification"] = classify_if # NOTE: CAUTION: extremely bizarre pandas behavior violates guarantee # that df['COLUMN'] matches the df as a a whole in this case. # We cannot replicate this consistently, but it seems to happen in some @@ -262,22 +260,22 @@ def comptable_classification_changer( # Comment line below to re-introduce original bug. 
For the kundu decision # tree it happens on node 6 which is the first time decide_comps is for # a subset of components - selector.component_table = selector.component_table.copy() + selector.component_table_ = selector.component_table_.copy() if tag_if is not None: # only run if a tag is provided for idx in changeidx: - tmpstr = selector.component_table.loc[idx, "classification_tags"] + tmpstr = selector.component_table_.loc[idx, "classification_tags"] if tmpstr == "" or isinstance(tmpstr, float): tmpset = {tag_if} else: tmpset = set(tmpstr.split(",")) tmpset.update([tag_if]) - selector.component_table.loc[idx, "classification_tags"] = ",".join( + selector.component_table_.loc[idx, "classification_tags"] = ",".join( str(s) for s in tmpset ) else: LGR.info( - f"Step {selector.current_node_idx}: No components fit criterion " + f"Step {selector.current_node_idx_}: No components fit criterion " f"{boolstate} to change classification" ) return selector @@ -316,50 +314,37 @@ def clean_dataframe(component_table): def confirm_metrics_exist(component_table, necessary_metrics, function_name=None): - """ - Confirm that all metrics declared in necessary_metrics are included in comptable. + """Confirm that all metrics declared in necessary_metrics are already included in comptable. Parameters ---------- component_table : (C x M) :obj:`pandas.DataFrame` Component metric table. One row for each component, with a column for each metric. The index should be the component number. - necessary_metrics : :obj:`set` - A set of strings of metric names + necessary_metrics : :obj:`list` + A list of strings of metric names. function_name : :obj:`str` - Text identifying the function name that called this function - - Returns - ------- - metrics_exist : :obj:`bool` - True if all metrics in necessary_metrics are in component_table + Text identifying the function name that called this function. Raises ------ ValueError - If metrics_exist is False then raise an error and end the program + If ``metrics_exist`` is False then raise an error and end the program. - Note - ---- - This doesn't check if there are data in each metric's column, just that - the columns exist. Also, the string in `necessary_metrics` and the - column labels in component_table will only be matched if they're identical. + Notes + ----- + This doesn't check if there are data in each metric's column, just that the columns exist. + Also, the string in ``necessary_metrics`` and the column labels in ``component_table`` will + only be matched if they're identical. """ - missing_metrics = necessary_metrics - set(component_table.columns) - metrics_exist = len(missing_metrics) > 0 - if metrics_exist is True: - if function_name is None: - function_name = "unknown function" - - error_msg = ( - f"Necessary metrics for {function_name}: " - f"{necessary_metrics}. " + missing_metrics = set(necessary_metrics) - set(component_table.columns) + if missing_metrics: + function_name = function_name or "unknown function" + raise ValueError( + f"Necessary metrics for {function_name}: {necessary_metrics}. " f"Comptable metrics: {set(component_table.columns)}. " f"MISSING METRICS: {missing_metrics}." ) - raise ValueError(error_msg) - - return metrics_exist def log_decision_tree_step( @@ -378,7 +363,7 @@ def log_decision_tree_step( ---------- function_name_idx : :obj:`str` The name of the function that should be logged. 
By convention, this
-        be "Step current_node_idx: function_name"
+        should be "Step ``current_node_idx_``: function_name"
     comps2use : :obj:`list[int]` or -1
         A list of component indices that should be used by a function.
         Only used to report no components found if empty and report
diff --git a/tedana/selection/tedica.py b/tedana/selection/tedica.py
index 098099c06..d757a6fc6 100644
--- a/tedana/selection/tedica.py
+++ b/tedana/selection/tedica.py
@@ -3,13 +3,12 @@
 import logging
 
 from tedana.metrics import collect
-from tedana.selection.component_selector import ComponentSelector
 
 LGR = logging.getLogger("GENERAL")
 RepLGR = logging.getLogger("REPORT")
 
 
-def automatic_selection(component_table, n_echos, n_vols, tree="kundu"):
+def automatic_selection(component_table, selector, **kwargs):
     """Classify components based on component table and decision tree type.
 
     Parameters
     ----------
@@ -20,8 +19,6 @@
         The number of echoes in this dataset
     tree : :obj:`str`
         The type of tree to use for the ComponentSelector object. Default="kundu"
-    verbose : :obj:`bool`
-        More verbose logging output if True. Default=False
 
     Returns
     -------
@@ -62,12 +59,7 @@
     )
     component_table["classification_tags"] = ""
 
-    xcomp = {
-        "n_echos": n_echos,
-        "n_vols": n_vols,
-    }
-    selector = ComponentSelector(tree, component_table, cross_component_metrics=xcomp)
-    selector.select()
-    selector.metadata = collect.get_metadata(selector.component_table)
+    selector.select(component_table, cross_component_metrics=kwargs)
+    selector.metadata_ = collect.get_metadata(selector.component_table_)
 
     return selector
diff --git a/tedana/tests/test_component_selector.py b/tedana/tests/test_component_selector.py
index 294fed508..f191df39f 100644
--- a/tedana/tests/test_component_selector.py
+++ b/tedana/tests/test_component_selector.py
@@ -38,12 +38,12 @@ def dicts_to_test(treechoice):
         "missing_req_param": A missing required param in a decision node function
         "missing_function": An undefined decision node function
         "missing_key": A dict missing one of the required keys (report)
+        "null_value": A parameter in one node improperly has a null value
 
     Returns
     -------
     tree : :ojb:`dict`
         A dict that can be input into component_selector.validate_tree
     """
-
     # valid_dict is a simple valid dictionary to test
     # It includes a few things that should trigger warnings, but not errors.
valid_dict = { @@ -174,21 +174,13 @@ def test_minimal(): xcomp = { "n_echos": 3, } - selector = component_selector.ComponentSelector( - "minimal", - sample_comptable(), - cross_component_metrics=xcomp.copy(), - ) - selector.select() + selector = component_selector.ComponentSelector(tree="minimal") + selector.select(component_table=sample_comptable(), cross_component_metrics=xcomp.copy()) # rerun without classification_tags column initialized - selector = component_selector.ComponentSelector( - "minimal", - sample_comptable(), - cross_component_metrics=xcomp.copy(), - ) - selector.component_table = selector.component_table.drop(columns="classification_tags") - selector.select() + selector = component_selector.ComponentSelector(tree="minimal") + temp_comptable = sample_comptable().drop(columns="classification_tags") + selector.select(component_table=temp_comptable, cross_component_metrics=xcomp.copy()) # validate_tree @@ -262,7 +254,7 @@ def test_validate_tree_fails(): def test_check_null_fails(): """Tests to trigger check_null missing parameter error.""" - selector = component_selector.ComponentSelector("minimal", sample_comptable()) + selector = component_selector.ComponentSelector(tree="minimal") selector.tree = dicts_to_test("null_value") params = selector.tree["nodes"][0]["parameters"] @@ -273,18 +265,15 @@ def test_check_null_fails(): def test_check_null_succeeds(): """Tests check_null finds empty parameter in self.""" + selector = component_selector.ComponentSelector(tree="minimal") + selector.tree = dicts_to_test("null_value") # "left" is missing from the function definition in node # but is found as an initialized cross component metric - xcomp = { + # so this should execute successfully + selector.cross_component_metrics_ = { "left": 3, } - selector = component_selector.ComponentSelector( - "minimal", - sample_comptable(), - cross_component_metrics=xcomp, - ) - selector.tree = dicts_to_test("null_value") params = selector.tree["nodes"][0]["parameters"] functionname = selector.tree["nodes"][0]["functionname"] @@ -293,8 +282,8 @@ def test_check_null_succeeds(): def test_are_only_necessary_metrics_used_warning(): """Tests a warning that wasn't triggered in other test workflows.""" - - selector = component_selector.ComponentSelector("minimal", sample_comptable()) + selector = component_selector.ComponentSelector(tree="minimal") + # selector.select(component_table=sample_comptable()) # warning when an element of necessary_metrics was not in used_metrics selector.tree["used_metrics"] = {"A", "B", "C"} @@ -304,23 +293,27 @@ def test_are_only_necessary_metrics_used_warning(): def test_are_all_components_accepted_or_rejected(): """Tests warnings are triggered in are_all_components_accepted_or_rejected.""" - - selector = component_selector.ComponentSelector("minimal", sample_comptable()) - selector.component_table.loc[7, "classification"] = "intermediate1" - selector.component_table.loc[[1, 3, 5], "classification"] = "intermediate2" + selector = component_selector.ComponentSelector(tree="minimal") + selector.select(component_table=sample_comptable(), cross_component_metrics={"n_echos": 3}) + selector.component_table_.loc[7, "classification"] = "intermediate1" + selector.component_table_.loc[[1, 3, 5], "classification"] = "intermediate2" selector.are_all_components_accepted_or_rejected() def test_selector_properties_smoke(): """Tests to confirm properties match expected results.""" - selector = component_selector.ComponentSelector("minimal", sample_comptable()) + # Runs on un-executed 
component table to smoke test three class + # functions that are used to count various types of component + # classifications in the component table + selector = component_selector.ComponentSelector(tree="minimal") + selector.component_table_ = sample_comptable() - assert selector.n_comps == 21 + assert selector.n_comps_ == 21 - # Also runs selector.likely_bold_comps and should need to deal with sets in each field - assert selector.n_likely_bold_comps == 17 + # Also runs selector.likely_bold_comps_ and should need to deal with sets in each field + assert selector.n_likely_bold_comps_ == 17 - assert selector.n_accepted_comps == 17 + assert selector.n_accepted_comps_ == 17 - assert selector.rejected_comps.sum() == 4 + assert selector.rejected_comps_.sum() == 4 diff --git a/tedana/tests/test_selection_nodes.py b/tedana/tests/test_selection_nodes.py index b0a7a78c9..e978edc2f 100644 --- a/tedana/tests/test_selection_nodes.py +++ b/tedana/tests/test_selection_nodes.py @@ -36,19 +36,19 @@ def test_manual_classify_smoke(): ) # There should be 4 selected components and component_status_table should # have a new column "Node 0" - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 4 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 0 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 4 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 0 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ # No components with "NotALabel" classification so nothing selected and no # Node 1 column not created in component_status_table - selector.current_node_idx = 1 + selector.current_node_idx_ = 1 selector = selection_nodes.manual_classify(selector, "NotAClassification", new_classification) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 0 - assert f"Node {selector.current_node_idx}" not in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 0 + assert f"Node {selector.current_node_idx_}" not in selector.component_status_table_ # Changing components from "rejected" to "accepted" and suppressing warning - selector.current_node_idx = 2 + selector.current_node_idx_ = 2 selector = selection_nodes.manual_classify( selector, "rejected", @@ -58,8 +58,8 @@ def test_manual_classify_smoke(): tag="test tag", dont_warn_reclassify=True, ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 4 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 4 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ def test_dec_left_op_right_succeeds(): @@ -96,13 +96,13 @@ def test_dec_left_op_right_succeeds(): ) # scales are set to make sure 3 components are true and 1 is false using # the sample component table - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 3 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 1 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 3 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 1 + assert f"Node 
{selector.current_node_idx_}" in selector.component_status_table_ # No components with "NotALabel" classification so nothing selected and no # Node 1 column is created in component_status_table - selector.current_node_idx = 1 + selector.current_node_idx_ = 1 selector = selection_nodes.dec_left_op_right( selector, "accepted", @@ -112,8 +112,8 @@ def test_dec_left_op_right_succeeds(): "kappa", "rho", ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 0 - assert f"Node {selector.current_node_idx}" not in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 0 + assert f"Node {selector.current_node_idx_}" not in selector.component_status_table_ # Re-initializing selector so that it has components classificated as # "provisional accept" again @@ -128,14 +128,14 @@ def test_dec_left_op_right_succeeds(): "kappa", "test_elbow", ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 3 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 1 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 3 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 1 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ # right is a component_table_metric, left is a cross_component_metric # left also has a left_scale that's a cross component metric selector = sample_selector(options="provclass") - selector.cross_component_metrics["new_cc_metric"] = 1.02 + selector.cross_component_metrics_["new_cc_metric"] = 1.02 selector = selection_nodes.dec_left_op_right( selector, "accepted", @@ -146,9 +146,9 @@ def test_dec_left_op_right_succeeds(): "kappa", left_scale="new_cc_metric", ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 1 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 3 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 1 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 3 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ # left component_table_metric, right is a constant integer value selector = sample_selector(options="provclass") @@ -161,9 +161,9 @@ def test_dec_left_op_right_succeeds(): "kappa", 21, ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 3 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 1 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 3 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 1 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ # right component_table_metric, left is a constant float value selector = sample_selector(options="provclass") @@ -176,9 +176,9 @@ def test_dec_left_op_right_succeeds(): 21.0, "kappa", ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 1 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 3 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert 
selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 1 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 3 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ # Testing combination of two statements. kappa>21 AND rho<14 selector = sample_selector(options="provclass") @@ -194,9 +194,9 @@ def test_dec_left_op_right_succeeds(): op2="<", right2=14, ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 2 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 2 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 2 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 2 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ # Testing combination of three statements. kappa>21 AND rho<14 AND 'variance explained'<5 selector = sample_selector(options="provclass") @@ -215,9 +215,9 @@ def test_dec_left_op_right_succeeds(): op3="<", right3=5, ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 1 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 3 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 1 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 3 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ def test_dec_left_op_right_fails(): @@ -389,27 +389,27 @@ def test_dec_variance_lessthan_thresholds_smoke(): tag_if_true="test true tag", tag_if_false="test false tag", ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 1 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 3 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 1 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 3 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ # No components with "NotALabel" classification so nothing selected and no # Node 1 column not created in component_status_table - selector.current_node_idx = 1 + selector.current_node_idx_ = 1 selector = selection_nodes.dec_variance_lessthan_thresholds( selector, "accepted", "rejected", "NotAClassification" ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 0 - assert f"Node {selector.current_node_idx}" not in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 0 + assert f"Node {selector.current_node_idx_}" not in selector.component_status_table_ # Running without specifying logging text generates internal text selector = sample_selector(options="provclass") selector = selection_nodes.dec_variance_lessthan_thresholds( selector, "accepted", "rejected", decide_comps ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 4 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 
0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 4 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ def test_calc_kappa_elbow(): @@ -436,14 +436,16 @@ def test_calc_kappa_elbow(): "varex_upper_p", } output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_elbow_kundu"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_allcomps_elbow"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_nonsig_elbow"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["varex_upper_p"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_elbow_kundu"] > 0 + assert ( + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_allcomps_elbow"] > 0 + ) + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_nonsig_elbow"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["varex_upper_p"] > 0 # Using a subset of components for decide_comps. selector = selection_nodes.calc_kappa_elbow( @@ -459,14 +461,16 @@ def test_calc_kappa_elbow(): "varex_upper_p", } output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_elbow_kundu"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_allcomps_elbow"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_nonsig_elbow"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["varex_upper_p"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_elbow_kundu"] > 0 + assert ( + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_allcomps_elbow"] > 0 + ) + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_nonsig_elbow"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["varex_upper_p"] > 0 # No components with "NotALabel" classification so nothing selected selector = sample_selector() @@ -475,16 +479,16 @@ def test_calc_kappa_elbow(): # Outputs just the metrics used in this function selector = selection_nodes.calc_kappa_elbow(selector, decide_comps) assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_elbow_kundu"] is None + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_elbow_kundu"] is None ) assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_allcomps_elbow"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_allcomps_elbow"] is None ) assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_nonsig_elbow"] is None + 
selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_nonsig_elbow"] is None ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["varex_upper_p"] is None + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["varex_upper_p"] is None def test_calc_rho_elbow(): @@ -511,16 +515,16 @@ def test_calc_rho_elbow(): "elbow_f05", } output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_elbow_kundu"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_allcomps_elbow"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_elbow_kundu"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_allcomps_elbow"] > 0 assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_unclassified_elbow"] > 0 + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_unclassified_elbow"] > 0 ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["elbow_f05"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["elbow_f05"] > 0 # Standard call to this function using rho_elbow_type="liberal" selector = selection_nodes.calc_rho_elbow( @@ -537,16 +541,16 @@ def test_calc_rho_elbow(): "elbow_f05", } output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_elbow_liberal"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_allcomps_elbow"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_elbow_liberal"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_allcomps_elbow"] > 0 assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_unclassified_elbow"] > 0 + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_unclassified_elbow"] > 0 ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["elbow_f05"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["elbow_f05"] > 0 # Using a subset of components for decide_comps. 
selector = selection_nodes.calc_rho_elbow( @@ -562,16 +566,16 @@ def test_calc_rho_elbow(): "elbow_f05", } output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_elbow_kundu"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_allcomps_elbow"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_elbow_kundu"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_allcomps_elbow"] > 0 assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_unclassified_elbow"] > 0 + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_unclassified_elbow"] > 0 ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["elbow_f05"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["elbow_f05"] > 0 with pytest.raises(ValueError): selection_nodes.calc_rho_elbow(selector, decide_comps, rho_elbow_type="perfect") @@ -582,15 +586,15 @@ def test_calc_rho_elbow(): # Outputs just the metrics used in this function selector = selection_nodes.calc_rho_elbow(selector, decide_comps) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_elbow_kundu"] is None + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_elbow_kundu"] is None assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_allcomps_elbow"] is None + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_allcomps_elbow"] is None ) assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["rho_unclassified_elbow"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["rho_unclassified_elbow"] is None ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["elbow_f05"] is None + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["elbow_f05"] is None def test_calc_median_smoke(): @@ -620,11 +624,11 @@ def test_calc_median_smoke(): ) calc_cross_comp_metrics = {"median_varex"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["median_varex"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["median_varex"] > 0 # repeating standard call and should make a warning because metric_varex already exists selector = selection_nodes.calc_median( @@ -632,7 +636,7 @@ def test_calc_median_smoke(): ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["median_varex"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["median_varex"] > 0 # Log without running if no components of decide_comps 
are in the component table selector = sample_selector() @@ -642,7 +646,7 @@ def test_calc_median_smoke(): metric_name="variance explained", median_label="varex", ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["median_varex"] is None + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["median_varex"] is None # Crashes because median_label is not a string with pytest.raises(ValueError): @@ -694,26 +698,26 @@ def test_dec_classification_doesnt_exist_smoke(): custom_node_label="custom label", tag="test true tag", ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 0 # Lists the number of components in decide_comps in n_false - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 17 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 17 # During normal execution, it will find provionally accepted components # and do nothing so another node isn't created - assert f"Node {selector.current_node_idx}" not in selector.component_status_table + assert f"Node {selector.current_node_idx_}" not in selector.component_status_table_ # No components with "NotALabel" classification so nothing selected and no # Node 1 column not created in component_status_table # Running without specifying logging text generates internal text - selector.current_node_idx = 1 + selector.current_node_idx_ = 1 selector = selection_nodes.dec_classification_doesnt_exist( selector, "accepted", "NotAClassification", class_comp_exists="provisional accept", ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 0 - assert f"Node {selector.current_node_idx}" not in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 0 + assert f"Node {selector.current_node_idx_}" not in selector.component_status_table_ # Other normal state is to change classifications when there are # no components with class_comp_exists. 
Since the component_table @@ -728,9 +732,9 @@ def test_dec_classification_doesnt_exist_smoke(): class_comp_exists="provisional reject", tag="test true tag", ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 17 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 0 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 17 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 0 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ # Standard execution with at_least_num_exist=5 which should trigger the # components don't exist output @@ -745,10 +749,10 @@ def test_dec_classification_doesnt_exist_smoke(): custom_node_label="custom label", tag="test true tag", ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 17 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 17 # Lists the number of components in decide_comps in n_false - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 0 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 0 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ def test_dec_reclassify_high_var_comps(): @@ -782,12 +786,12 @@ def test_dec_reclassify_high_var_comps(): "unclass_highvar", "NotAClassification", ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 0 - assert f"Node {selector.current_node_idx}" not in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 0 + assert f"Node {selector.current_node_idx_}" not in selector.component_status_table_ # Add varex_upper_p to cross component_metrics to run normal test selector = sample_selector(options="unclass") - selector.cross_component_metrics["varex_upper_p"] = 0.97 + selector.cross_component_metrics_["varex_upper_p"] = 0.97 # Standard execution where with all extra logging code and options changed from defaults selection_nodes.dec_reclassify_high_var_comps( @@ -799,18 +803,18 @@ def test_dec_reclassify_high_var_comps(): tag="test true tag", ) # Lists the number of components in decide_comps in n_true or n_false - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 3 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_false"] == 10 - assert f"Node {selector.current_node_idx}" in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 3 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_false"] == 10 + assert f"Node {selector.current_node_idx_}" in selector.component_status_table_ # No components with "NotALabel" classification so nothing selected and no # Node 1 column is created in component_status_table - selector.current_node_idx = 1 + selector.current_node_idx_ = 1 selector = selection_nodes.dec_reclassify_high_var_comps( selector, "unclass_highvar", "NotAClassification" ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["n_true"] == 0 - assert f"Node {selector.current_node_idx}" not in selector.component_status_table + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["n_true"] == 0 + 
assert f"Node {selector.current_node_idx_}" not in selector.component_status_table_ def test_calc_varex_thresh_smoke(): @@ -837,12 +841,12 @@ def test_calc_varex_thresh_smoke(): ) calc_cross_comp_metrics = {"varex_upper_thresh", "upper_perc"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["varex_upper_thresh"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["upper_perc"] == 90 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["varex_upper_thresh"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["upper_perc"] == 90 # Standard call , but thresh_label is "" selector = selection_nodes.calc_varex_thresh( @@ -855,12 +859,12 @@ def test_calc_varex_thresh_smoke(): ) calc_cross_comp_metrics = {"varex_thresh", "perc"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["varex_thresh"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["perc"] == 90 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["varex_thresh"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["perc"] == 90 # Standard call using num_highest_var_comps as an integer selector = selection_nodes.calc_varex_thresh( @@ -872,17 +876,17 @@ def test_calc_varex_thresh_smoke(): ) calc_cross_comp_metrics = {"varex_new_lower_thresh", "new_lower_perc"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["varex_new_lower_thresh"] > 0 + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["varex_new_lower_thresh"] > 0 ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["new_lower_perc"] == 25 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["new_lower_perc"] == 25 # Standard call using num_highest_var_comps as a value in cross_component_metrics - selector.cross_component_metrics["num_acc_guess"] = 10 + selector.cross_component_metrics_["num_acc_guess"] = 10 selector = selection_nodes.calc_varex_thresh( selector, decide_comps, @@ -892,14 +896,14 @@ def test_calc_varex_thresh_smoke(): ) calc_cross_comp_metrics = {"varex_new_lower_thresh", "new_lower_perc"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # 
Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["varex_new_lower_thresh"] > 0 + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["varex_new_lower_thresh"] > 0 ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["new_lower_perc"] == 25 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["new_lower_perc"] == 25 # Raise error if num_highest_var_comps is a string, but not in cross_component_metrics with pytest.raises(ValueError): @@ -921,11 +925,11 @@ def test_calc_varex_thresh_smoke(): num_highest_var_comps="NotACrossCompMetric", ) assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["varex_new_lower_thresh"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["varex_new_lower_thresh"] is None ) # percentile_thresh doesn't depend on components and is assigned - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["new_lower_perc"] == 25 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["new_lower_perc"] == 25 # Raise error if num_highest_var_comps is not an integer with pytest.raises(ValueError): @@ -950,20 +954,20 @@ def test_calc_varex_thresh_smoke(): ) calc_cross_comp_metrics = {"varex_new_lower_thresh", "new_lower_perc"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["varex_new_lower_thresh"] > 0 + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["varex_new_lower_thresh"] > 0 ) - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["new_lower_perc"] == 25 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["new_lower_perc"] == 25 # Run warning logging code to see if any of the cross_component_metrics # already exists and would be over-written selector = sample_selector(options="provclass") - selector.cross_component_metrics["varex_upper_thresh"] = 1 - selector.cross_component_metrics["upper_perc"] = 1 + selector.cross_component_metrics_["varex_upper_thresh"] = 1 + selector.cross_component_metrics_["upper_perc"] = 1 decide_comps = "provisional accept" selector = selection_nodes.calc_varex_thresh( selector, @@ -974,8 +978,8 @@ def test_calc_varex_thresh_smoke(): custom_node_label="custom label", ) assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["varex_upper_thresh"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["upper_perc"] == 90 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["varex_upper_thresh"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["upper_perc"] == 90 # Raise error if percentile_thresh isn't a number selector = sample_selector(options="provclass") @@ -997,10 +1001,10 @@ def test_calc_varex_thresh_smoke(): selector, decide_comps="NotAClassification", thresh_label="upper", percentile_thresh=90 ) assert ( - 
selector.tree["nodes"][selector.current_node_idx]["outputs"]["varex_upper_thresh"] is None + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["varex_upper_thresh"] is None ) # percentile_thresh doesn't depend on components and is assigned - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["upper_perc"] == 90 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["upper_perc"] == 90 def test_calc_extend_factor_smoke(): @@ -1020,27 +1024,27 @@ def test_calc_extend_factor_smoke(): ) calc_cross_comp_metrics = {"extend_factor"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["extend_factor"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["extend_factor"] > 0 # Run warning logging code for if any of the cross_component_metrics # already existed and would be over-written selector = sample_selector() - selector.cross_component_metrics["extend_factor"] = 1.0 + selector.cross_component_metrics_["extend_factor"] = 1.0 selector = selection_nodes.calc_extend_factor(selector) assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["extend_factor"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["extend_factor"] > 0 # Run with extend_factor defined as an input selector = sample_selector() selector = selection_nodes.calc_extend_factor(selector, extend_factor=1.2) assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["extend_factor"] == 1.2 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["extend_factor"] == 1.2 def test_calc_max_good_meanmetricrank_smoke(): @@ -1049,7 +1053,7 @@ def test_calc_max_good_meanmetricrank_smoke(): # Standard use of this function requires some components to be "provisional accept" selector = sample_selector("provclass") # This function requires "extend_factor" to already be defined - selector.cross_component_metrics["extend_factor"] = 2.0 + selector.cross_component_metrics_["extend_factor"] = 2.0 # Outputs just the metrics used in this function {""} used_metrics = selection_nodes.calc_max_good_meanmetricrank( @@ -1066,28 +1070,29 @@ def test_calc_max_good_meanmetricrank_smoke(): ) calc_cross_comp_metrics = {"max_good_meanmetricrank"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["max_good_meanmetricrank"] > 0 + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["max_good_meanmetricrank"] + > 0 ) # Standard call to this function with a user defined metric_suffix selector = sample_selector("provclass") - selector.cross_component_metrics["extend_factor"] = 2.0 + 
selector.cross_component_metrics_["extend_factor"] = 2.0 selector = selection_nodes.calc_max_good_meanmetricrank( selector, "provisional accept", metric_suffix="testsfx" ) calc_cross_comp_metrics = {"max_good_meanmetricrank_testsfx"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"][ + selector.tree["nodes"][selector.current_node_idx_]["outputs"][ "max_good_meanmetricrank_testsfx" ] > 0 @@ -1096,17 +1101,18 @@ def test_calc_max_good_meanmetricrank_smoke(): # Run warning logging code for if any of the cross_component_metrics # already existed and would be over-written selector = sample_selector("provclass") - selector.cross_component_metrics["max_good_meanmetricrank"] = 10 - selector.cross_component_metrics["extend_factor"] = 2.0 + selector.cross_component_metrics_["max_good_meanmetricrank"] = 10 + selector.cross_component_metrics_["extend_factor"] = 2.0 selector = selection_nodes.calc_max_good_meanmetricrank(selector, "provisional accept") calc_cross_comp_metrics = {"max_good_meanmetricrank"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["max_good_meanmetricrank"] > 0 + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["max_good_meanmetricrank"] + > 0 ) # Raise an error if "extend_factor" isn't pre-defined @@ -1116,11 +1122,11 @@ def test_calc_max_good_meanmetricrank_smoke(): # Log without running if no components of decide_comps are in the component table selector = sample_selector() - selector.cross_component_metrics["extend_factor"] = 2.0 + selector.cross_component_metrics_["extend_factor"] = 2.0 selector = selection_nodes.calc_max_good_meanmetricrank(selector, "NotAClassification") assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["max_good_meanmetricrank"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["max_good_meanmetricrank"] is None ) @@ -1146,29 +1152,29 @@ def test_calc_varex_kappa_ratio_smoke(): ) calc_cross_comp_metrics = {"kappa_rate"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_rate"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_rate"] > 0 # Run warning logging code for if any of the cross_component_metrics # already existed and would be over-written selector = sample_selector("provclass") - selector.cross_component_metrics["kappa_rate"] = 10 + selector.cross_component_metrics_["kappa_rate"] = 10 selector = selection_nodes.calc_varex_kappa_ratio(selector, "provisional 
accept") assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_rate"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_rate"] > 0 # Log without running if no components of decide_comps are in the component table selector = sample_selector() selector = selection_nodes.calc_varex_kappa_ratio(selector, "NotAClassification") - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["kappa_rate"] is None + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["kappa_rate"] is None # Raise error if "varex kappa ratio" is already in component_table selector = sample_selector("provclass") - selector.component_table["varex kappa ratio"] = 42 + selector.component_table_["varex kappa ratio"] = 42 with pytest.raises(ValueError): selector = selection_nodes.calc_varex_kappa_ratio(selector, "provisional accept") @@ -1178,8 +1184,8 @@ def test_calc_revised_meanmetricrank_guesses_smoke(): # Standard use of this function requires some components to be "provisional accept" selector = sample_selector("provclass") - selector.cross_component_metrics["kappa_elbow_kundu"] = 19.1 - selector.cross_component_metrics["rho_elbow_kundu"] = 15.2 + selector.cross_component_metrics_["kappa_elbow_kundu"] = 19.1 + selector.cross_component_metrics_["rho_elbow_kundu"] = 15.2 # Outputs just the metrics used in this function {""} used_metrics = selection_nodes.calc_revised_meanmetricrank_guesses( @@ -1205,46 +1211,46 @@ def test_calc_revised_meanmetricrank_guesses_smoke(): ) calc_cross_comp_metrics = {"num_acc_guess", "conservative_guess", "restrict_factor"} output_calc_cross_comp_metrics = set( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["calc_cross_comp_metrics"] + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["calc_cross_comp_metrics"] ) # Confirming the intended metrics are added to outputs and they have non-zero values assert len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["num_acc_guess"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["conservative_guess"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["restrict_factor"] == 2 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["num_acc_guess"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["conservative_guess"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["restrict_factor"] == 2 # Run warning logging code for if any of the cross_component_metrics # already existed and would be over-written selector = sample_selector("provclass") - selector.cross_component_metrics["kappa_elbow_kundu"] = 19.1 - selector.cross_component_metrics["rho_elbow_kundu"] = 15.2 - selector.cross_component_metrics["num_acc_guess"] = 10 - selector.cross_component_metrics["conservative_guess"] = 10 - selector.cross_component_metrics["restrict_factor"] = 5 + selector.cross_component_metrics_["kappa_elbow_kundu"] = 19.1 + selector.cross_component_metrics_["rho_elbow_kundu"] = 15.2 + selector.cross_component_metrics_["num_acc_guess"] = 10 + selector.cross_component_metrics_["conservative_guess"] = 10 + selector.cross_component_metrics_["restrict_factor"] = 5 selector = selection_nodes.calc_revised_meanmetricrank_guesses( selector, ["provisional accept", "provisional reject", "unclassified"] ) assert 
len(output_calc_cross_comp_metrics - calc_cross_comp_metrics) == 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["num_acc_guess"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["conservative_guess"] > 0 - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["restrict_factor"] == 2 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["num_acc_guess"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["conservative_guess"] > 0 + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["restrict_factor"] == 2 # Log without running if no components of decide_comps are in the component table selector = sample_selector() - selector.cross_component_metrics["kappa_elbow_kundu"] = 19.1 - selector.cross_component_metrics["rho_elbow_kundu"] = 15.2 + selector.cross_component_metrics_["kappa_elbow_kundu"] = 19.1 + selector.cross_component_metrics_["rho_elbow_kundu"] = 15.2 selector = selection_nodes.calc_revised_meanmetricrank_guesses(selector, "NotAClassification") - assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["num_acc_guess"] is None + assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["num_acc_guess"] is None assert ( - selector.tree["nodes"][selector.current_node_idx]["outputs"]["conservative_guess"] is None + selector.tree["nodes"][selector.current_node_idx_]["outputs"]["conservative_guess"] is None ) # Raise error if "d_table_score_node0" is already in component_table selector = sample_selector("provclass") - selector.cross_component_metrics["kappa_elbow_kundu"] = 19.1 - selector.cross_component_metrics["rho_elbow_kundu"] = 15.2 - selector.component_table["d_table_score_node0"] = 42 + selector.cross_component_metrics_["kappa_elbow_kundu"] = 19.1 + selector.cross_component_metrics_["rho_elbow_kundu"] = 15.2 + selector.component_table_["d_table_score_node0"] = 42 with pytest.raises(ValueError): selector = selection_nodes.calc_revised_meanmetricrank_guesses( selector, ["provisional accept", "provisional reject", "unclassified"] @@ -1252,8 +1258,8 @@ def test_calc_revised_meanmetricrank_guesses_smoke(): # Raise error if restrict_factor isn't a number selector = sample_selector("provclass") - selector.cross_component_metrics["kappa_elbow_kundu"] = 19.1 - selector.cross_component_metrics["rho_elbow_kundu"] = 15.2 + selector.cross_component_metrics_["kappa_elbow_kundu"] = 19.1 + selector.cross_component_metrics_["rho_elbow_kundu"] = 15.2 with pytest.raises(ValueError): selector = selection_nodes.calc_revised_meanmetricrank_guesses( selector, @@ -1263,7 +1269,7 @@ def test_calc_revised_meanmetricrank_guesses_smoke(): # Raise error if kappa_elbow_kundu isn't in cross_component_metrics selector = sample_selector("provclass") - selector.cross_component_metrics["rho_elbow_kundu"] = 15.2 + selector.cross_component_metrics_["rho_elbow_kundu"] = 15.2 with pytest.raises(ValueError): selector = selection_nodes.calc_revised_meanmetricrank_guesses( selector, ["provisional accept", "provisional reject", "unclassified"] @@ -1272,12 +1278,12 @@ def test_calc_revised_meanmetricrank_guesses_smoke(): # Do not raise error if kappa_elbow_kundu isn't in cross_component_metrics # and there are no components in decide_comps selector = sample_selector("provclass") - selector.cross_component_metrics["rho_elbow_kundu"] = 15.2 + selector.cross_component_metrics_["rho_elbow_kundu"] = 15.2 selector = selection_nodes.calc_revised_meanmetricrank_guesses( selector, 
         selector, decide_comps="NoComponents"
     )
-    assert selector.tree["nodes"][selector.current_node_idx]["outputs"]["num_acc_guess"] is None
+    assert selector.tree["nodes"][selector.current_node_idx_]["outputs"]["num_acc_guess"] is None
     assert (
-        selector.tree["nodes"][selector.current_node_idx]["outputs"]["conservative_guess"] is None
+        selector.tree["nodes"][selector.current_node_idx_]["outputs"]["conservative_guess"] is None
     )
diff --git a/tedana/tests/test_selection_utils.py b/tedana/tests/test_selection_utils.py
index 106387221..465fb8bcc 100644
--- a/tedana/tests/test_selection_utils.py
+++ b/tedana/tests/test_selection_utils.py
@@ -54,8 +54,19 @@ def sample_selector(options=None):
         "n_vols": 201,
         "test_elbow": 21,
     }
-    selector = ComponentSelector(tree, component_table, cross_component_metrics=xcomp)
-    selector.current_node_idx = 0
+    selector = ComponentSelector(tree=tree)
+
+    # Add an un-executed component table, cross component metrics, and status table
+    selector.component_table_ = component_table.copy()
+    selector.cross_component_metrics_ = xcomp
+    selector.component_status_table_ = selector.component_table_[
+        ["Component", "classification"]
+    ].copy()
+    selector.component_status_table_ = selector.component_status_table_.rename(
+        columns={"classification": "initialized classification"}
+    )
+
+    selector.current_node_idx_ = 0
 
     return selector
 
@@ -89,7 +100,7 @@ def test_selectcomps2use_succeeds():
     decide_comps_lengths = [4, 17, 21, 21, 1, 3, 0]
 
     for idx, decide_comps in enumerate(decide_comps_options):
-        comps2use = selection_utils.selectcomps2use(selector, decide_comps)
+        comps2use = selection_utils.selectcomps2use(selector.component_table_, decide_comps)
         assert len(comps2use) == decide_comps_lengths[idx], (
             f"selectcomps2use test should select {decide_comps_lengths[idx]} with "
            f"decide_comps={decide_comps}, but it selected {len(comps2use)}"
@@ -110,11 +121,11 @@ def test_selectcomps2use_fails():
     ]
     for decide_comps in decide_comps_options:
         with pytest.raises(ValueError):
-            selection_utils.selectcomps2use(selector, decide_comps)
+            selection_utils.selectcomps2use(selector.component_table_, decide_comps)
 
-    selector.component_table = selector.component_table.drop(columns="classification")
+    selector.component_table_ = selector.component_table_.drop(columns="classification")
     with pytest.raises(ValueError):
-        selection_utils.selectcomps2use(selector, "all")
+        selection_utils.selectcomps2use(selector.component_table_, "all")
 
 
 def test_comptable_classification_changer_succeeds():
@@ -129,13 +140,13 @@ def validate_changes(expected_classification):
         # check every element that was supposed to change, did change
         changeidx = decision_boolean.index[np.asarray(decision_boolean) == boolstate]
-        new_vals = selector.component_table.loc[changeidx, "classification"]
+        new_vals = selector.component_table_.loc[changeidx, "classification"]
         for val in new_vals:
             assert val == expected_classification
 
     # Change if true
     selector = sample_selector(options="provclass")
-    decision_boolean = selector.component_table["classification"] == "provisional accept"
+    decision_boolean = selector.component_table_["classification"] == "provisional accept"
     boolstate = True
     selector = selection_utils.comptable_classification_changer(
         selector, boolstate, "accepted", decision_boolean, tag_if="testing_tag"
     )
@@ -144,7 +155,7 @@ def validate_changes(expected_classification):
 
     # Run nochange condition
     selector = sample_selector(options="provclass")
-    decision_boolean = selector.component_table["classification"] == "provisional accept"
+    decision_boolean = selector.component_table_["classification"] == "provisional accept"
     selector = selection_utils.comptable_classification_changer(
         selector, boolstate, "nochange", decision_boolean, tag_if="testing_tag"
     )
@@ -152,7 +163,7 @@ def validate_changes(expected_classification):
 
     # Change if false
     selector = sample_selector(options="provclass")
-    decision_boolean = selector.component_table["classification"] != "provisional accept"
+    decision_boolean = selector.component_table_["classification"] != "provisional accept"
     boolstate = False
     selector = selection_utils.comptable_classification_changer(
         selector, boolstate, "rejected", decision_boolean, tag_if="testing_tag1, testing_tag2"
@@ -162,7 +173,7 @@ def validate_changes(expected_classification):
     # Change from accepted to rejected, which should output a warning
     # (test if the warning appears?)
     selector = sample_selector(options="provclass")
-    decision_boolean = selector.component_table["classification"] == "accepted"
+    decision_boolean = selector.component_table_["classification"] == "accepted"
     boolstate = True
     selector = selection_utils.comptable_classification_changer(
         selector, boolstate, "rejected", decision_boolean, tag_if="testing_tag"
@@ -171,7 +182,7 @@ def validate_changes(expected_classification):
     # Change from rejected to accepted and suppress warning
     selector = sample_selector(options="provclass")
-    decision_boolean = selector.component_table["classification"] == "rejected"
+    decision_boolean = selector.component_table_["classification"] == "rejected"
     boolstate = True
     selector = selection_utils.comptable_classification_changer(
         selector,
@@ -191,8 +202,8 @@ def test_change_comptable_classifications_succeeds():
 
     # Given the rho values in the sample table, decision_boolean should have
     # 2 True and 2 False values
-    comps2use = selection_utils.selectcomps2use(selector, "provisional accept")
-    rho = selector.component_table.loc[comps2use, "rho"]
+    comps2use = selection_utils.selectcomps2use(selector.component_table_, "provisional accept")
+    rho = selector.component_table_.loc[comps2use, "rho"]
     decision_boolean = rho < 13.5
 
     selector, n_true, n_false = selection_utils.change_comptable_classifications(
@@ -208,7 +219,7 @@ def test_change_comptable_classifications_succeeds():
     assert n_false == 2
     # check every element that was supposed to change, did change
     changeidx = decision_boolean.index[np.asarray(decision_boolean) == True]  # noqa: E712
-    new_vals = selector.component_table.loc[changeidx, "classification"]
+    new_vals = selector.component_table_.loc[changeidx, "classification"]
     for val in new_vals:
         assert val == "accepted"
 
@@ -254,7 +265,7 @@ def test_log_decision_tree_step_smoke():
     selector = sample_selector()
 
     # Standard run for logging classification changes
-    comps2use = selection_utils.selectcomps2use(selector, "reject")
+    comps2use = selection_utils.selectcomps2use(selector.component_table_, "reject")
     selection_utils.log_decision_tree_step(
         "Step 0: test_function_name",
         comps2use,
@@ -288,7 +299,7 @@ def test_log_decision_tree_step_smoke():
     )
 
     # Logging no components found with a specified classification
-    comps2use = selection_utils.selectcomps2use(selector, "NotALabel")
+    comps2use = selection_utils.selectcomps2use(selector.component_table_, "NotALabel")
     selection_utils.log_decision_tree_step(
         "Step 0: test_function_name",
         comps2use,
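The test hunks above exercise an API change worth noting: ``selection_utils.selectcomps2use`` now receives the component table itself rather than the whole selector object. A minimal sketch of the new call pattern follows; the component names, labels, and tags in the hand-made table are illustrative assumptions, not values taken from this patch.

.. code-block:: python

    import pandas as pd

    from tedana.selection import selection_utils

    # Toy component table; in practice this comes from metrics.collect.generate_metrics
    # or from a saved "ICA metrics tsv" file.
    component_table = pd.DataFrame(
        {
            "Component": ["ICA_00", "ICA_01", "ICA_02"],
            "classification": ["accepted", "rejected", "provisional accept"],
            "classification_tags": ["Likely BOLD", "Unlikely BOLD", ""],
        }
    )

    # selectcomps2use takes the DataFrame (not a ComponentSelector) plus a label,
    # a list of labels, or "all", and returns the indices of matching components.
    comps2use = selection_utils.selectcomps2use(component_table, "provisional accept")

This mirrors what the updated ``sample_selector`` fixture sets up through ``selector.component_table_``.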
diff --git a/tedana/workflows/ica_reclassify.py b/tedana/workflows/ica_reclassify.py
index 8aabde66b..f29d975d5 100644
--- a/tedana/workflows/ica_reclassify.py
+++ b/tedana/workflows/ica_reclassify.py
@@ -351,7 +351,7 @@ def ica_reclassify_workflow(
             in_both.append(a)
 
     if len(in_both) != 0:
-        raise ValueError("The following components were both accepted and rejected: " f"{in_both}")
+        raise ValueError(f"The following components were both accepted and rejected: {in_both}")
 
     # Save command into sh file, if the command-line interface was used
     # TODO: use io_generator to save command
@@ -402,12 +402,7 @@ def ica_reclassify_workflow(
     )
 
     # Make a new selector with the added files
-    selector = selection.component_selector.ComponentSelector(
-        previous_tree_fname,
-        comptable,
-        cross_component_metrics=xcomp,
-        status_table=status_table,
-    )
+    selector = selection.component_selector.ComponentSelector(previous_tree_fname)
 
     if accept:
         selector.add_manual(accept, "accepted")
@@ -415,8 +410,12 @@ def ica_reclassify_workflow(
     if reject:
         selector.add_manual(reject, "rejected")
 
-    selector.select()
-    comptable = selector.component_table
+    selector.select(
+        comptable,
+        cross_component_metrics=xcomp,
+        status_table=status_table,
+    )
+    comptable = selector.component_table_
 
     # NOTE: most of these will be identical to previous, but this makes
     # things easier for programs which will view the data after running.
@@ -440,7 +439,7 @@ def ica_reclassify_workflow(
     # Save component selector and tree
     selector.to_files(io_generator)
 
-    if selector.n_accepted_comps == 0:
+    if selector.n_accepted_comps_ == 0:
         LGR.warning(
             "No accepted components remaining after manual classification! "
             "Please check data and results!"
@@ -449,8 +448,8 @@ def ica_reclassify_workflow(
     mmix_orig = mmix.copy()
     # TODO: make this a function
     if tedort:
-        comps_accepted = selector.accepted_comps
-        comps_rejected = selector.rejected_comps
+        comps_accepted = selector.accepted_comps_
+        comps_rejected = selector.rejected_comps_
         acc_ts = mmix[:, comps_accepted]
         rej_ts = mmix[:, comps_rejected]
         betas = np.linalg.lstsq(acc_ts, rej_ts, rcond=None)[0]
@@ -459,7 +458,7 @@ def ica_reclassify_workflow(
         mmix[:, comps_rejected] = resid
         comp_names = [
            io.add_decomp_prefix(comp, prefix="ica", max_value=comptable.index.max())
-            for comp in range(selector.n_comps)
+            for comp in range(selector.n_comps_)
         ]
         mixing_df = pd.DataFrame(data=mmix, columns=comp_names)
         io_generator.save_file(mixing_df, "ICA orthogonalized mixing tsv")
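The ``ica_reclassify`` hunk above and the ``tedana_workflow`` hunks below follow the same two-step pattern: construct the selector from a tree first, derive the metric list from it, then hand it the data. The sketch below is a rough illustration of that flow for an already-computed metrics table; the wrapper function, the ``"minimal"`` tree name, and the default argument values are illustrative assumptions, not part of this patch.

.. code-block:: python

    from tedana import selection
    from tedana.selection.component_selector import ComponentSelector


    def classify_components(comptable, tree="minimal", n_echos=4, n_vols=200):
        """Sketch: apply a decision tree to an existing ICA metrics table."""
        # Step 1: initializing the selector only validates and loads the tree.
        selector = ComponentSelector(tree)

        # The tree declares which metrics it needs; the figures need a few more.
        necessary_metrics = selector.necessary_metrics
        extra_metrics = ["variance explained", "normalized variance explained", "kappa", "rho"]
        necessary_metrics = sorted(list(set(necessary_metrics + extra_metrics)))
        missing = set(necessary_metrics) - set(comptable.columns)
        if missing:
            raise ValueError(f"Metrics table is missing required columns: {missing}")

        # Step 2: the selector object (not a tree name) is passed to the selection step,
        # which executes the tree and fills the trailing-underscore result attributes.
        selector = selection.automatic_selection(
            comptable, selector, n_echos=n_echos, n_vols=n_vols
        )
        return selector.component_table_, selector.n_likely_bold_comps_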
diff --git a/tedana/workflows/tedana.py b/tedana/workflows/tedana.py
index 975a83fe5..dab1352bd 100644
--- a/tedana/workflows/tedana.py
+++ b/tedana/workflows/tedana.py
@@ -29,6 +29,7 @@
     utils,
 )
 from tedana.bibtex import get_description_references
+from tedana.selection.component_selector import ComponentSelector
 from tedana.stats import computefeats2
 from tedana.workflows.parser_utils import check_tedpca_value, is_valid_file
 
@@ -498,6 +499,9 @@ def tedana_workflow(
     if isinstance(data, str):
         data = [data]
 
+    LGR.info("Initializing and validating component selection tree")
+    selector = ComponentSelector(tree)
+
     LGR.info(f"Loading input data: {[f for f in data]}")
     catd, ref_img = io.load_data(data, n_echos=n_echos)
 
@@ -629,8 +633,8 @@ def tedana_workflow(
     # optimally combine data
     data_oc = combine.make_optcom(catd, tes, masksum_denoise, t2s=t2s_full, combmode=combmode)
 
-    # regress out global signal unless explicitly not desired
     if "gsr" in gscontrol:
+        # regress out global signal
         catd, data_oc = gsc.gscontrol_raw(catd, data_oc, n_echos, io_generator)
         fout = io_generator.save_file(data_oc, "combined img")
 
@@ -668,20 +672,11 @@ def tedana_workflow(
             # Estimate betas and compute selection metrics for mixing matrix
             # generated from dimensionally reduced data using full data (i.e., data
             # with thermal noise)
-            LGR.info("Making second component selection guess from ICA results")
-            required_metrics = [
-                "kappa",
-                "rho",
-                "countnoise",
-                "countsigFT2",
-                "countsigFS0",
-                "dice_FT2",
-                "dice_FS0",
-                "signal-noise_t",
-                "variance explained",
-                "normalized variance explained",
-                "d_table_score",
-            ]
+            necessary_metrics = selector.necessary_metrics
+            # The figures require some metrics that might not be used by the decision tree.
+            extra_metrics = ["variance explained", "normalized variance explained", "kappa", "rho"]
+            necessary_metrics = sorted(list(set(necessary_metrics + extra_metrics)))
+
             comptable = metrics.collect.generate_metrics(
                 catd,
                 data_oc,
@@ -690,10 +685,16 @@ def tedana_workflow(
                 tes,
                 io_generator,
                 "ICA",
-                metrics=required_metrics,
+                metrics=necessary_metrics,
+            )
+            LGR.info("Selecting components from ICA results")
+            selector = selection.automatic_selection(
+                comptable,
+                selector,
+                n_echos=n_echos,
+                n_vols=n_vols,
             )
-            ica_selector = selection.automatic_selection(comptable, n_echos, n_vols, tree=tree)
-            n_likely_bold_comps = ica_selector.n_likely_bold_comps
+            n_likely_bold_comps = selector.n_likely_bold_comps_
             if (n_restarts < maxrestart) and (n_likely_bold_comps == 0):
                 LGR.warning("No BOLD components found. Re-attempting ICA.")
             elif n_likely_bold_comps == 0:
@@ -705,6 +706,9 @@ def tedana_workflow(
             # If we're going to restart, temporarily allow force overwrite
             if keep_restarting:
                 io_generator.overwrite = True
+                # Create a re-initialized selector object if rerunning
+                selector = ComponentSelector(tree)
+
             RepLGR.disabled = True  # Disable the report to avoid duplicate text
         RepLGR.disabled = False  # Re-enable the report after the while loop is escaped
         io_generator.overwrite = overwrite  # Re-enable original overwrite behavior
@@ -713,19 +717,12 @@ def tedana_workflow(
         mixing_file = io_generator.get_name("ICA mixing tsv")
         mmix = pd.read_table(mixing_file).values
 
-        required_metrics = [
-            "kappa",
-            "rho",
-            "countnoise",
-            "countsigFT2",
-            "countsigFS0",
-            "dice_FT2",
-            "dice_FS0",
-            "signal-noise_t",
-            "variance explained",
-            "normalized variance explained",
-            "d_table_score",
-        ]
+        selector = ComponentSelector(tree)
+        necessary_metrics = selector.necessary_metrics
+        # The figures require some metrics that might not be used by the decision tree.
+        extra_metrics = ["variance explained", "normalized variance explained", "kappa", "rho"]
+        necessary_metrics = sorted(list(set(necessary_metrics + extra_metrics)))
+
         comptable = metrics.collect.generate_metrics(
             catd,
             data_oc,
@@ -734,13 +731,13 @@ def tedana_workflow(
             tes,
             io_generator,
             "ICA",
-            metrics=required_metrics,
+            metrics=necessary_metrics,
         )
-        ica_selector = selection.automatic_selection(
+        selector = selection.automatic_selection(
             comptable,
-            n_echos,
-            n_vols,
-            tree=tree,
+            selector,
+            n_echos=n_echos,
+            n_vols=n_vols,
         )
 
     # TODO The ICA mixing matrix should be written out after it is created
@@ -758,7 +755,7 @@ def tedana_workflow(
         io_generator.save_file(betas_oc, "z-scored ICA components img")
 
     # Save component selector and tree
-    ica_selector.to_files(io_generator)
+    selector.to_files(io_generator)
 
     # Save metrics and metadata
     metric_metadata = metrics.collect.get_metadata(comptable)
     io_generator.save_file(metric_metadata, "ICA metrics json")
@@ -775,16 +772,16 @@ def tedana_workflow(
     }
     io_generator.save_file(decomp_metadata, "ICA decomposition json")
 
-    if ica_selector.n_likely_bold_comps == 0:
+    if selector.n_likely_bold_comps_ == 0:
         LGR.warning("No BOLD components detected! Please check data and results!")
 
     # TODO: un-hack separate comptable
-    comptable = ica_selector.component_table
+    comptable = selector.component_table_
 
     mmix_orig = mmix.copy()
     if tedort:
-        comps_accepted = ica_selector.accepted_comps
-        comps_rejected = ica_selector.rejected_comps
+        comps_accepted = selector.accepted_comps_
+        comps_rejected = selector.rejected_comps_
         acc_ts = mmix[:, comps_accepted]
         rej_ts = mmix[:, comps_rejected]
         betas = np.linalg.lstsq(acc_ts, rej_ts, rcond=None)[0]
@@ -793,7 +790,7 @@ def tedana_workflow(
         mmix[:, comps_rejected] = resid
         comp_names = [
             io.add_decomp_prefix(comp, prefix="ICA", max_value=comptable.index.max())
-            for comp in range(ica_selector.n_comps)
+            for comp in range(selector.n_comps_)
         ]
         mixing_df = pd.DataFrame(data=mmix, columns=comp_names)