Skip to content

Commit

Permalink
Merge pull request #70 from OpenGeoscience/postgres-gcc-query
Browse files Browse the repository at this point in the history
Implement GCC algorithm in native Postgres
  • Loading branch information
jjnesbitt authored Oct 1, 2024
2 parents 9c13113 + 1eadf03 commit 3bf0d01
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 25 deletions.
79 changes: 67 additions & 12 deletions uvdat/core/models/networks.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,82 @@
from django.contrib.gis.db import models as geo_models
from django.db import models
import networkx as nx
from django.db import connection, models

from .dataset import Dataset

GCC_QUERY = """
WITH RECURSIVE n as (
-- starting node
SELECT id FROM (
SELECT cnn.id
FROM core_networknode cnn
WHERE
cnn.network_id = %(network_id)s AND
NOT (cnn.id = ANY(%(excluded_nodes)s))
ORDER BY random()
LIMIT 1
) nn
UNION
-- Select the *other* node in the edge
SELECT CASE
WHEN e.to_node_id = n.id
THEN e.from_node_id
ELSE e.to_node_id
END
FROM n
JOIN (
SELECT *
FROM core_networkedge ne
WHERE
ne.network_id = %(network_id)s AND
NOT (
ne.from_node_id = ANY(%(excluded_nodes)s) OR
ne.to_node_id = ANY(%(excluded_nodes)s)
)
) e
ON
e.from_node_id = n.id OR
e.to_node_id = n.id
)
SELECT id FROM n ORDER BY id
;
"""


class Network(models.Model):
dataset = models.ForeignKey(Dataset, on_delete=models.CASCADE, related_name='networks')
category = models.CharField(max_length=25)
metadata = models.JSONField(blank=True, null=True)

def get_graph(self):
from uvdat.core.tasks.networks import get_network_graph
def get_gcc(self, excluded_nodes: list[int]):
total_nodes = NetworkNode.objects.filter(network=self).count()

# This is used to store all the nodes we've already visited,
# starting with the explicitly excluded nodes
cur_excluded_nodes = excluded_nodes.copy()

# Store largest network found so far
gcc = []

with connection.cursor() as cursor:
# If the GCC size is greater than half the network, we know that there's no way to
# find a larger one. If we've exhausted all nodes, also stop searching.
while not (len(gcc) > (total_nodes // 2) or len(cur_excluded_nodes) >= total_nodes):
cursor.execute(
GCC_QUERY,
{
'excluded_nodes': cur_excluded_nodes,
'network_id': self.pk,
},
)
nodes = [x[0] for x in cursor.fetchall()]
if not nodes:
raise Exception('Expected to find nodes but found none')

return get_network_graph(self)
cur_excluded_nodes.extend(nodes)
if len(nodes) > len(gcc):
gcc = nodes

def get_gcc(self, exclude_nodes):
graph = self.get_graph()
graph.remove_nodes_from(exclude_nodes)
if graph.number_of_nodes == 0 or nx.number_connected_components(graph) == 0:
return []
gcc = max(nx.connected_components(graph), key=len)
return list(gcc)
return gcc


class NetworkNode(models.Model):
Expand Down
25 changes: 12 additions & 13 deletions uvdat/core/rest/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet

from uvdat.core.models import Dataset, NetworkEdge, NetworkNode
from uvdat.core.models import Dataset, Network, NetworkEdge, NetworkNode
from uvdat.core.rest.access_control import GuardianFilter, GuardianPermission
from uvdat.core.rest.serializers import (
DatasetSerializer,
Expand Down Expand Up @@ -82,16 +82,15 @@ def gcc(self, request, **kwargs):
exclude_nodes = exclude_nodes.split(',')
exclude_nodes = [int(n) for n in exclude_nodes if len(n)]

# TODO: improve this for datasets with multiple networks;
# this currently returns the gcc for the network with the most excluded nodes
results = []
# Find the GCC for each network in the dataset
network_gccs: list[list[int]] = []
for network in dataset.networks.all():
excluded_node_names = [n.name for n in network.nodes.all() if n.id in exclude_nodes]
gcc = network.get_gcc(exclude_nodes)
results.append(dict(excluded=excluded_node_names, gcc=gcc))
if len(results):
results.sort(key=lambda r: len(r.get('excluded')), reverse=True)
gcc = results[0].get('gcc')
excluded = results[0].get('excluded')
add_gcc_chart_datum(dataset, project_id, excluded, len(gcc))
return HttpResponse(json.dumps(gcc), status=200)
network: Network
network_gccs.append(network.get_gcc(excluded_nodes=exclude_nodes))

# TODO: improve this for datasets with multiple networks.
# This currently returns the gcc for the network with the most excluded nodes
gcc = max(network_gccs, key=len)

add_gcc_chart_datum(dataset, project_id, exclude_nodes, len(gcc))
return Response(gcc, status=200)

0 comments on commit 3bf0d01

Please sign in to comment.