diff --git a/uvdat/core/models/networks.py b/uvdat/core/models/networks.py
index 235156d2..5bafd566 100644
--- a/uvdat/core/models/networks.py
+++ b/uvdat/core/models/networks.py
@@ -1,27 +1,97 @@
 from django.contrib.gis.db import models as geo_models
-from django.db import models
-import networkx as nx
+from django.db import connection, models
 
 from .dataset import Dataset
 
+# Recursive CTE that walks the edges of one network outward from a single
+# randomly chosen, non-excluded starting node and returns the ids of every
+# node reachable from it, i.e. one connected component. Edges touching an
+# excluded node are ignored. Parameters: ``network_id`` and ``excluded_nodes``.
+GCC_QUERY = """
+WITH RECURSIVE n as (
+    -- starting node
+    SELECT id FROM (
+        SELECT cnn.id
+        FROM core_networknode cnn
+        WHERE
+            cnn.network_id = %(network_id)s AND
+            NOT (cnn.id = ANY(%(excluded_nodes)s))
+        ORDER BY random()
+        LIMIT 1
+    ) nn
+    UNION
+    -- Select the *other* node in the edge
+    SELECT CASE
+        WHEN e.to_node_id = n.id
+        THEN e.from_node_id
+        ELSE e.to_node_id
+    END
+    FROM n
+    JOIN (
+        SELECT *
+        FROM core_networkedge ne
+        WHERE
+            ne.network_id = %(network_id)s AND
+            NOT (
+                ne.from_node_id = ANY(%(excluded_nodes)s) OR
+                ne.to_node_id = ANY(%(excluded_nodes)s)
+            )
+    ) e
+    ON
+        e.from_node_id = n.id OR
+        e.to_node_id = n.id
+)
+SELECT id FROM n ORDER BY id
+;
+"""
+
 
 class Network(models.Model):
     dataset = models.ForeignKey(Dataset, on_delete=models.CASCADE, related_name='networks')
     category = models.CharField(max_length=25)
     metadata = models.JSONField(blank=True, null=True)
 
-    def get_graph(self):
-        from uvdat.core.tasks.networks import get_network_graph
+    def get_gcc(self, excluded_nodes: list[int]) -> list[int]:
+        """Return the node ids of the greatest connected component (GCC).
+
+        Nodes listed in ``excluded_nodes`` (and every edge touching them) are
+        treated as removed from the network. Ids that are duplicated or that
+        do not belong to this network are tolerated.
+        """
+        # Every node visited so far, seeded with the explicitly excluded nodes.
+        # Deduplicated so the list handed to the query stays minimal.
+        visited = list(set(excluded_nodes))
+
+        # Count only this network's own nodes that are still unvisited; the
+        # excluded ids may include duplicates or nodes of other networks, so
+        # the length of the visited list is not a reliable progress measure.
+        remaining = NetworkNode.objects.filter(network=self).exclude(id__in=visited).count()
+
+        # Largest component found so far
+        gcc: list[int] = []
+
+        with connection.cursor() as cursor:
+            # Once the best component is at least as large as the number of
+            # unvisited nodes, no strictly larger component can exist. This
+            # also ends the search when every node has been visited.
+            while remaining > len(gcc):
+                cursor.execute(
+                    GCC_QUERY,
+                    {
+                        'excluded_nodes': visited,
+                        'network_id': self.pk,
+                    },
+                )
+                nodes = [row[0] for row in cursor.fetchall()]
+                if not nodes:
+                    raise Exception('Expected to find nodes but found none')
 
-        return get_network_graph(self)
+                visited.extend(nodes)
+                remaining -= len(nodes)
+                if len(nodes) > len(gcc):
+                    gcc = nodes
 
-    def get_gcc(self, exclude_nodes):
-        graph = self.get_graph()
-        graph.remove_nodes_from(exclude_nodes)
-        if graph.number_of_nodes == 0 or nx.number_connected_components(graph) == 0:
-            return []
-        gcc = max(nx.connected_components(graph), key=len)
-        return list(gcc)
+        return gcc
 
 
 class NetworkNode(models.Model):
diff --git a/uvdat/core/rest/dataset.py b/uvdat/core/rest/dataset.py
index 1ef7eaf9..dbfed9f7 100644
--- a/uvdat/core/rest/dataset.py
+++ b/uvdat/core/rest/dataset.py
@@ -5,7 +5,7 @@
 from rest_framework.response import Response
 from rest_framework.viewsets import ModelViewSet
 
-from uvdat.core.models import Dataset, NetworkEdge, NetworkNode
+from uvdat.core.models import Dataset, Network, NetworkEdge, NetworkNode
 from uvdat.core.rest.access_control import GuardianFilter, GuardianPermission
 from uvdat.core.rest.serializers import (
     DatasetSerializer,
@@ -82,16 +82,16 @@ def gcc(self, request, **kwargs):
         exclude_nodes = exclude_nodes.split(',')
         exclude_nodes = [int(n) for n in exclude_nodes if len(n)]
 
-        # TODO: improve this for datasets with multiple networks;
-        # this currently returns the gcc for the network with the most excluded nodes
-        results = []
+        # Find the GCC for each network in the dataset
+        network_gccs: list[list[int]] = []
         for network in dataset.networks.all():
-            excluded_node_names = [n.name for n in network.nodes.all() if n.id in exclude_nodes]
-            gcc = network.get_gcc(exclude_nodes)
-            results.append(dict(excluded=excluded_node_names, gcc=gcc))
-        if len(results):
-            results.sort(key=lambda r: len(r.get('excluded')), reverse=True)
-            gcc = results[0].get('gcc')
-            excluded = results[0].get('excluded')
-            add_gcc_chart_datum(dataset, project_id, excluded, len(gcc))
-            return HttpResponse(json.dumps(gcc), status=200)
+            network: Network
+            network_gccs.append(network.get_gcc(excluded_nodes=exclude_nodes))
+
+        # TODO: improve this for datasets with multiple networks.
+        # This currently returns the largest GCC found across all of the dataset's
+        # networks, or an empty list when the dataset has no networks.
+        gcc = max(network_gccs, key=len, default=[])
+
+        add_gcc_chart_datum(dataset, project_id, exclude_nodes, len(gcc))
+        return Response(gcc, status=200)