MultiSortingComparison and MultiTemplateComparison optimal assignment #2911

Open
florian6973 opened this issue May 27, 2024 · 4 comments
Labels
comparison Related to comparison module

Comments

@florian6973
Contributor

florian6973 commented May 27, 2024

Hi!

I have been reading the literature on Multidimensional Assignment Problems / Entity Matching to understand how multiple sortings or templates are matched, and I realized that the current method does not always seem to return the optimal matching.

To check my hypothesis, I modified the BaseMultiComparison class to create a minimal working example (MWE) of the potential issue, using the example from https://arxiv.org/pdf/2112.03346 (page 3). It can be run as a script:

from collections import OrderedDict
from copy import deepcopy
import numpy as np

class BaseMultiComparison():
    """
    Base class for graph-based multi comparison classes.

    It handles graph operations, comparisons, and agreements.
    """

    def __init__(self):
        import networkx as nx

        # BaseComparison.__init__(
        #     self,
        #     object_list=object_list,
        #     name_list=name_list,
        #     match_score=match_score,
        #     chance_score=chance_score,
        #     verbose=verbose,
        # )
        # self.match_score = 0.3
        self.name_list = ['a', 'b', 'c']   
        self.object_list = ['1', '2', '3']   
        self._verbose = True
    
        self.graph = None
        self.subgraphs = None
        self.clean_graph = None

    def _compute_all(self):
        self._do_comparison()
        self._do_graph()
        self._clean_graph()
        self._do_agreement()

    def _populate_nodes(self):
        for name in self.name_list:
            for unit_id in self.object_list:
                self.graph.add_node((name, unit_id))

    @property
    def units(self):
        return deepcopy(self._new_units)

    def compute_subgraphs(self):
        """
        Computes subgraphs of connected components.
        Returns
        -------
        sg_object_names: list
            List of sorter names for each node in the connected component subgraph
        sg_units: list
            List of unit ids for each node in the connected component subgraph
        """
        if self.clean_graph is not None:
            g = self.clean_graph
        else:
            g = self.graph

        import networkx as nx

        subgraphs = (g.subgraph(c).copy() for c in nx.connected_components(g))
        sg_object_names = []
        sg_units = []
        for sg in subgraphs:
            object_names = []
            unit_names = []
            for node in sg.nodes:
                object_names.append(node[0])
                unit_names.append(node[1])
            sg_object_names.append(object_names)
            sg_units.append(unit_names)
        return sg_object_names, sg_units

    def _do_comparison(
        self,
    ):
        # do pairwise matching
        if self._verbose:
            print("Multicomparison step 1: pairwise comparison")

        # Hungarian matches only: (sorter_1, sorter_2) -> {unit_1: (matched unit_2, agreement score)}
        self.comparisons = {
            ('a', 'b'): {
                '1': ('2', 0.6), '2': ('1', 0.6), '3': ('3', 1.0)
            },
            ('b', 'c'): {
                '1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)
            },
            ('a', 'c'): {
                '1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)
            }
        }

    def _do_graph(self):
        if self._verbose:
            print("Multicomparison step 2: make graph")

        import networkx as nx

        self.graph = nx.Graph()
        # nodes
        self._populate_nodes()

        # edges
        for comp_name, comp in self.comparisons.items():
            for u1 in comp.keys():
                u2 = comp[u1][0]
                if u2 != -1:
                    name_1, name_2 = comp_name
                    node1 = name_1, u1
                    node2 = name_2, u2
                    score = comp[u1][1]
                    self.graph.add_edge(node1, node2, weight=score)

        # the graph is symmetrical
        self.graph = self.graph.to_undirected()

    def _clean_graph(self):
        if self._verbose:
            print("Multicomparison step 3: clean graph")
        clean_graph = self.graph.copy()
        import networkx as nx

        subgraphs = (clean_graph.subgraph(c).copy() for c in nx.connected_components(clean_graph))
        removed_nodes = 0
        for sg in subgraphs:
            object_names = []
            for node in sg.nodes:
                object_names.append(node[0])
            sorters, counts = np.unique(object_names, return_counts=True)

            if np.any(counts > 1):
                for sorter in sorters[counts > 1]:
                    nodes_duplicate = [n for n in sg.nodes if sorter in n]
                    # get edges
                    edges_duplicates = []
                    weights_duplicates = []
                    for n in nodes_duplicate:
                        edges = sg.edges(n, data=True)
                        for e in edges:
                            edges_duplicates.append(e)
                            weights_duplicates.append(e[2]["weight"])

                    # remove extra edges
                    n_edges_to_remove = len(nodes_duplicate) - 1
                    remove_idxs = np.argsort(weights_duplicates)[:n_edges_to_remove]
                    edges_to_remove = np.array(edges_duplicates, dtype=object)[remove_idxs]

                    for edge_to_remove in edges_to_remove:
                        clean_graph.remove_edge(edge_to_remove[0], edge_to_remove[1])
                        sg.remove_edge(edge_to_remove[0], edge_to_remove[1])
                        if self._verbose:
                            print(f"Removed edge: {edge_to_remove}")

                    # remove extra nodes (as a second step to not affect edge removal)
                    for edge_to_remove in edges_to_remove:
                        if edge_to_remove[0] in nodes_duplicate:
                            node_to_remove = edge_to_remove[0]
                        else:
                            node_to_remove = edge_to_remove[1]
                        if node_to_remove in sg.nodes:
                            sg.remove_node(node_to_remove)
                            print(f"Removed node: {node_to_remove}")
                            removed_nodes += 1

        if self._verbose:
            print(f"Removed {removed_nodes} duplicate nodes")
        self.clean_graph = clean_graph

    def _do_agreement(self):
        # extract agreement from graph
        if self._verbose:
            print("Multicomparison step 4: extract agreement from graph")

        self._new_units = {}

        # save new units
        import networkx as nx

        self.subgraphs = [self.clean_graph.subgraph(c).copy() for c in nx.connected_components(self.clean_graph)]
        for new_unit, sg in enumerate(self.subgraphs):
            edges = list(sg.edges(data=True))
            if len(edges) > 0:
                avg_agr = np.mean([d["weight"] for u, v, d in edges])
            else:
                avg_agr = 0
            object_unit_ids = {}
            for node in sg.nodes:
                object_name, unit_name = node
                object_unit_ids[object_name] = unit_name
            # sort dict based on name list
            sorted_object_unit_ids = OrderedDict()
            for name in self.name_list:
                if name in object_unit_ids:
                    sorted_object_unit_ids[name] = object_unit_ids[name]
            self._new_units[new_unit] = {
                "avg_agreement": avg_agr,
                "unit_ids": sorted_object_unit_ids,
                "agreement_number": len(sg.nodes),
            }

# Run the minimal working example and print the resulting units
b = BaseMultiComparison()
b._compute_all()
print(b._new_units)
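
A minimal, dependency-free sketch of why this example is hard for the graph approach: the hypothetical helper `connected_components` below (plain Python standing in for `networkx.connected_components`, not spikeinterface code) adds the same edges as `_do_graph` and groups the nodes. With these pairwise matches, a1, a2, b1, b2, c1 and c2 all end up in a single six-node cycle (a1-b2-c2-a2-b1-c1-a1), so `_clean_graph` has to cut edges inside one component that holds two units from every sorter, and greedily removing the lowest-weight edges does not guarantee that the remaining pieces are complete, optimal triplets.

```python
from collections import defaultdict

# Pairwise matches copied from the MWE: (sorter_1, sorter_2) -> {unit_1: (matched unit_2, score)}
comparisons = {
    ("a", "b"): {"1": ("2", 0.6), "2": ("1", 0.6), "3": ("3", 1.0)},
    ("b", "c"): {"1": ("1", 1.0), "2": ("2", 1.0), "3": ("3", 1.0)},
    ("a", "c"): {"1": ("1", 1.0), "2": ("2", 1.0), "3": ("3", 1.0)},
}


def connected_components(comparisons):
    """Hypothetical helper: group (sorter, unit) nodes linked by any pairwise match."""
    # Undirected adjacency built from the same edges that _do_graph adds
    adjacency = defaultdict(set)
    for (name_1, name_2), matches in comparisons.items():
        for u1, (u2, _score) in matches.items():
            if u2 == -1:  # unmatched units get no edge, as in the MWE
                continue
            adjacency[(name_1, u1)].add((name_2, u2))
            adjacency[(name_2, u2)].add((name_1, u1))

    seen, components = set(), []
    for start in sorted(adjacency):
        if start in seen:
            continue
        # Iterative depth-first search collects one component
        stack, component = [start], set()
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            component.add(node)
            stack.extend(adjacency[node] - seen)
        components.append(sorted(component))
    return components


for component in connected_components(comparisons):
    print(component)
```

This should print two components: the six-node one described above and (a3, b3, c3).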

Do you think my MWE is a correct adaptation of the example from the literature to the spikeinterface framework? If so, have you considered other methods so far, or should we think more about how to solve this issue?

Thanks!

Florent

florian6973 changed the title from "MultiSortingComparison and MultiTemplateComparison optimal assignment issue" to "MultiSortingComparison and MultiTemplateComparison optimal assignment" on May 27, 2024
@zm711
Collaborator

zm711 commented May 29, 2024

@florian6973,

We will take a look at this soon. We are in the middle of a spikeinterface hackathon, but we are super curious about this. The code is a little hard for me to read without a nice diff view. Could you also post the same code with comments on the lines you changed, to make the comparison a bit easier? If we haven't responded by next week, please ping us again!

zm711 added the "comparison" label (Related to comparison module) on May 29, 2024
@florian6973
Contributor Author

florian6973 commented May 29, 2024

Thanks for your reply!

Sure, here are some more details:

  • My goal is to replicate the example from the paper to see whether spikeinterface correctly solves the multiple assignment problem. If we consider A, B, and C as three different sessions or sorters, $(a_1, a_2, a_3, \dots)$ as the corresponding templates/units, and sim as the similarity measure (cosine or any other), we want to find the best matching of templates/units across sessions/sorters.
    [Image: similarity table from the paper's example]
    We can rewrite the table as three agreement matrices computed by pairwise comparisons.
    The matches from the current Hungarian method in spikeinterface are shown in bold.

$$\begin{array}{c|ccc} & b_1 & b_2 & b_3 \\ \hline a_1 & 0.4 & \textbf{0.6} & 0.6 \\ a_2 & \textbf{0.6} & 0.6 & 0.6 \\ a_3 & 0.6 & 0.6 & \textbf{1} \end{array}$$

$$\begin{array}{c|ccc} & c_1 & c_2 & c_3 \\ \hline b_1 & \textbf{1} & 0.1 & 0.1 \\ b_2 & 0.1 & \textbf{1} & 0.1 \\ b_3 & 0.1 & 0.1 & \textbf{1} \end{array}$$

$$\begin{array}{c|ccc} & c_1 & c_2 & c_3 \\ \hline a_1 & \textbf{1} & 0.1 & 0.1 \\ a_2 & 0.1 & \textbf{1} & 0.1 \\ a_3 & 0.1 & 0.1 & \textbf{1} \end{array}$$
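
The bold entries can be reproduced with a maximum-weight one-to-one assignment on each matrix. The sketch below uses exhaustive search over permutations (fine for 3×3) instead of the Hungarian algorithm; `best_pairwise_assignment` is a hypothetical helper, and any thresholding such as match_score is ignored. For realistic sizes one would use a Hungarian-type solver such as `scipy.optimize.linear_sum_assignment(sim, maximize=True)`; because each optimum here is unique, both approaches should give the same matches.

```python
from itertools import permutations

# Similarity matrices from the tables above (rows: first sorter, columns: second sorter)
S_ab = [[0.4, 0.6, 0.6], [0.6, 0.6, 0.6], [0.6, 0.6, 1.0]]
S_bc = [[1.0, 0.1, 0.1], [0.1, 1.0, 0.1], [0.1, 0.1, 1.0]]
S_ac = [[1.0, 0.1, 0.1], [0.1, 1.0, 0.1], [0.1, 0.1, 1.0]]


def best_pairwise_assignment(sim):
    """Maximum-weight one-to-one matching by brute force; perm[i] is the column matched to row i."""
    n = len(sim)
    return max(permutations(range(n)), key=lambda perm: sum(sim[i][perm[i]] for i in range(n)))


for name, sim in [("a-b", S_ab), ("b-c", S_bc), ("a-c", S_ac)]:
    perm = best_pairwise_assignment(sim)
    # Print 1-based (row, column) pairs to match the unit labels in the tables
    print(name, [(i + 1, j + 1) for i, j in enumerate(perm)])
```

Each pairwise problem is solved independently, so nothing forces the three sets of matches to be transitively consistent: the a-b matches are $(a_1, b_2)$ and $(a_2, b_1)$, while the b-c and a-c matches are both diagonal.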

  • From there, I adapted BaseMultiComparison to reflect this particular situation and checked whether we end up with the true optimal matching $(a_1, b_1, c_1)$, $(a_2, b_2, c_2)$, $(a_3, b_3, c_3)$. The diff is below. Note that I do not need the whole comparison matrix, given the way the graph is built, and I assume that match_score is low enough:
    def __init__(self):
        self.name_list = ['a', 'b', 'c']   
        self.object_list = ['1', '2', '3']   

    # def _compare_ij(self, i, j):
    #   raise NotImplementedError

    # def _populate_nodes(self):
    #    raise NotImplementedError

    def _populate_nodes(self):
        for name in self.name_list:
            for unit_id in self.object_list:
                self.graph.add_node((name, unit_id))

#     def _do_comparison(
#         self,
#     ):
#         # do pairwise matching
#         if self._verbose:
#             print("Multicomparison step 1: pairwise comparison")

#         self.comparisons = {}
#         for i in range(len(self.object_list)):
#             for j in range(i + 1, len(self.object_list)):
#                 if self.name_list is not None:
#                     name_i = self.name_list[i]
#                     name_j = self.name_list[j]
#                 else:
#                     name_i = "object i"
#                     name_j = "object j"
#                 if self._verbose:
#                     print(f"  Comparing: {name_i} and {name_j}")
#                 comp = self._compare_ij(i, j)
#                 self.comparisons[(name_i, name_j)] = comp


    def _do_comparison(
        self,
    ):
        # do pairwise matching
        if self._verbose:
            print("Multicomparison step 1: pairwise comparison")

        self.comparisons = {
            ('a', 'b'): {
                '1': ('2', 0.6), '2': ('1', 0.6), '3': ('3', 1.0)
            },
            ('b', 'c'): {
                '1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)
            },
            ('a', 'c'): {
                '1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)
            }
        }

    def _do_graph(self):
# ...
#  for comp_name, comp in self.comparisons.items():
#     for u1 in comp.hungarian_match_12.index.values:
#         u2 = comp.hungarian_match_12[u1]
#         if u2 != -1:
#             name_1, name_2 = comp_name
#             node1 = name_1, u1
#             node2 = name_2, u2
#             score = comp.agreement_scores.loc[u1, u2]
#             self.graph.add_edge(node1, node2, weight=score)

        for comp_name, comp in self.comparisons.items():
            for u1 in comp.keys():
                u2 = comp[u1][0]
                if u2 != -1:
                    name_1, name_2 = comp_name
                    node1 = name_1, u1
                    node2 = name_2, u2
                    score = comp[u1][1]
                    self.graph.add_edge(node1, node2, weight=score)
  • However, with the spikeinterface code we obtain $(a_1, b_1, c_1)$, $(a_2)$, $(b_2, c_2)$ and $(a_3, b_3, c_3)$. This is not even what one would expect, namely $(a_1, b_2, c_2)$, $(a_2, b_1, c_1)$ and $(a_3, b_3, c_3)$.
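
To compare these groupings, one can score a full assignment jointly instead of pair by pair. The sketch below assumes the objective is the sum of pairwise similarities inside each group (one common way to pose the multidimensional assignment problem; the paper's exact formulation may differ) and finds the best complete assignment by exhaustive search, which costs $(n!)^2$ evaluations and is only meant for this toy case. `cluster_score`, `assignment_score`, `best_joint_assignment` and the result variable names are hypothetical.

```python
from itertools import combinations, permutations

# Pairwise similarity matrices from the tables above (0-based unit indices)
SIMS = {
    ("a", "b"): [[0.4, 0.6, 0.6], [0.6, 0.6, 0.6], [0.6, 0.6, 1.0]],
    ("b", "c"): [[1.0, 0.1, 0.1], [0.1, 1.0, 0.1], [0.1, 0.1, 1.0]],
    ("a", "c"): [[1.0, 0.1, 0.1], [0.1, 1.0, 0.1], [0.1, 0.1, 1.0]],
}


def cluster_score(cluster):
    """Assumed objective: sum of pairwise similarities inside one group, e.g. {"a": 0, "b": 1}."""
    return sum(
        SIMS[(s1, s2)][cluster[s1]][cluster[s2]]
        for s1, s2 in combinations(sorted(cluster), 2)
    )


def assignment_score(clusters):
    return sum(cluster_score(c) for c in clusters)


def best_joint_assignment(n=3):
    """Exhaustive search over all ways to form n complete (a, b, c) triplets."""
    best_score, best_clusters = float("-inf"), None
    for p in permutations(range(n)):  # a_i is grouped with b_{p[i]}
        for q in permutations(range(n)):  # a_i is grouped with c_{q[i]}
            clusters = [{"a": i, "b": p[i], "c": q[i]} for i in range(n)]
            score = assignment_score(clusters)
            if score > best_score:
                best_score, best_clusters = score, clusters
    return best_score, best_clusters


best_score, best_clusters = best_joint_assignment()
print(round(best_score, 6), best_clusters)

# Groupings reported in this issue, written with 0-based indices
spikeinterface_result = [{"a": 0, "b": 0, "c": 0}, {"a": 1}, {"b": 1, "c": 1}, {"a": 2, "b": 2, "c": 2}]
expected_by_author = [{"a": 0, "b": 1, "c": 1}, {"a": 1, "b": 0, "c": 0}, {"a": 2, "b": 2, "c": 2}]
print(round(assignment_score(spikeinterface_result), 6), round(assignment_score(expected_by_author), 6))
```

Under this assumed objective, the optimal matching $(a_1, b_1, c_1)$, $(a_2, b_2, c_2)$, $(a_3, b_3, c_3)$ scores 8.0, while both the spikeinterface output and the alternative listed above score 6.4 (a unit left on its own, like $(a_2)$, simply contributes no pairwise terms).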

I hope this is clearer. I am not sure I am fully correct; I was trying to properly understand the multicomparison module, which is why I am asking.

Have a good hackathon :)

By the way, if you are in Boston at some point, we could discuss it in person if needed :)

@zm711
Collaborator

zm711 commented Jun 1, 2024

Hey @florian6973,

Thanks for the well wishes. We could definitely meet at some point; if you're on the Slack, just send me a message. That said, I think @alejoe91 is better placed to look over this one. I didn't work on the initial code, so he knows it much better.

@JoeZiminski
Collaborator

Thanks a lot for this, @florian6973, a super interesting and detailed investigation. I will definitely look into this while working on #2626; please feel free to give any feedback and thoughts on the plan I posted there.
