Skip to content

Commit

Permalink
Refactor db results
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderWatzinger committed Nov 29, 2024
1 parent 1ad1839 commit 3e998b2
Show file tree
Hide file tree
Showing 11 changed files with 46 additions and 45 deletions.
2 changes: 1 addition & 1 deletion openatlas/database/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def check_single_type_duplicates(ids: list[int]) -> list[int]:
HAVING COUNT(*) > 1;
""",
{'ids': tuple(ids)})
return [row['domain_id'] for row in g.cursor.fetchall()]
return [row[0] for row in list(g.cursor)]


def get_orphaned_subunits() -> list[dict[str, Any]]:
Expand Down
20 changes: 10 additions & 10 deletions openatlas/database/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_overview_counts(classes: list[str]) -> dict[str, int]:
GROUP BY openatlas_class_name;
""",
{'classes': tuple(classes)})
return {row['name']: row['count'] for row in g.cursor.fetchall()}
return {row['name']: row['count'] for row in list(g.cursor)}


def get_overview_counts_by_type(
Expand All @@ -98,7 +98,7 @@ def get_overview_counts_by_type(
GROUP BY openatlas_class_name;
""",
{'ids': tuple(ids), 'classes': tuple(classes)})
return {row['name']: row['count'] for row in g.cursor.fetchall()}
return {row['name']: row['count'] for row in list(g.cursor)}


def get_latest(classes: list[str], limit: int) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -317,7 +317,7 @@ def get_file_info() -> dict[int, dict[str, Any]]:
row['entity_id']: {
'public': row['public'],
'license_holder': row['license_holder'],
'creator': row['creator']} for row in g.cursor.fetchall()}
'creator': row['creator']} for row in list(g.cursor)}


def get_subunits_without_super(classes: list[str]) -> list[int]:
Expand All @@ -329,7 +329,7 @@ def get_subunits_without_super(classes: list[str]) -> list[int]:
WHERE e.openatlas_class_name IN %(classes)s;
""",
{'classes': tuple(classes)})
return [row['id'] for row in g.cursor.fetchall()]
return [row[0] for row in list(g.cursor)]


def get_roots(
Expand Down Expand Up @@ -380,7 +380,7 @@ def get_roots(
return {
row['start_node']: {
'id': row['top_level'],
'name': row['name']} for row in g.cursor.fetchall()}
'name': row['name']} for row in list(g.cursor)}


def get_linked_entities_recursive(
Expand All @@ -404,7 +404,7 @@ def get_linked_entities_recursive(
) SELECT {first} FROM items;
""",
{'id_': id_, 'code': tuple(codes) if codes else ''})
return [row[first] for row in g.cursor.fetchall()]
return [row[0] for row in list(g.cursor)]


def get_links_of_entities(
Expand Down Expand Up @@ -461,23 +461,23 @@ def delete_reference_system_links(entity_id: int) -> None:
def get_linked_entities(id_: int, codes: list[str]) -> list[int]:
g.cursor.execute(
"""
SELECT range_id AS result_id
SELECT range_id
FROM model.link
WHERE domain_id = %(id_)s AND property_code IN %(codes)s;
""",
{'id_': id_, 'codes': tuple(codes)})
return [row['result_id'] for row in g.cursor.fetchall()]
return [row[0] for row in list(g.cursor)]


def get_linked_entities_inverse(id_: int, codes: list[str]) -> list[int]:
g.cursor.execute(
"""
SELECT domain_id AS result_id
SELECT domain_id
FROM model.link
WHERE range_id = %(id_)s AND property_code IN %(codes)s;
""",
{'id_': id_, 'codes': tuple(codes)})
return [row['result_id'] for row in g.cursor.fetchall()]
return [row[0] for row in list(g.cursor)]


def delete_links_by_codes(
Expand Down
6 changes: 3 additions & 3 deletions openatlas/database/gis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ast
from collections import defaultdict
from typing import Any, Optional
from typing import Any

from flask import g

Expand Down Expand Up @@ -63,7 +63,7 @@ def get_geometry_dict(row: dict[str, Any]) -> dict[str, Any]:
return geometry


def get_centroids_by_id(id_: int) -> Optional[list[dict[str, Any]]]:
def get_centroids_by_id(id_: int) -> list[dict[str, Any]]:
g.cursor.execute(
"""
SELECT
Expand All @@ -86,7 +86,7 @@ def get_centroids_by_id(id_: int) -> Optional[list[dict[str, Any]]]:
for row in list(g.cursor):
if data := get_centroid_dict(row):
geometries.append(data)
return geometries or None
return geometries


def get_centroids_by_ids(ids: list[int]) -> defaultdict[int, list[Any]]:
Expand Down
9 changes: 5 additions & 4 deletions openatlas/database/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_project_by_id(id_: int) -> dict[str, Any]:
return g.cursor.fetchone()


def get_project_by_name(name: str) -> Optional[dict[str, Any]]:
def get_project_by_name(name: str) -> dict[str, Any]:
g.cursor.execute(
f'{SQL} WHERE p.name = %(name)s GROUP BY p.id;',
{'name': name})
Expand All @@ -51,11 +51,12 @@ def delete_project(id_: int) -> None:
def check_origin_ids(project_id: int, origin_ids: list[str]) -> list[str]:
g.cursor.execute(
"""
SELECT origin_id FROM import.entity
SELECT origin_id
FROM import.entity
WHERE project_id = %(project_id)s AND origin_id IN %(ids)s;
""",
{'project_id': project_id, 'ids': tuple(set(origin_ids))})
return [row['origin_id'] for row in g.cursor.fetchall()]
return [row[0] for row in list(g.cursor)]


def get_id_from_origin_id(project_id: int, origin_id: str) -> list[str]:
Expand All @@ -75,7 +76,7 @@ def check_duplicates(class_: str, names: list[str]) -> list[str]:
WHERE openatlas_class_name = %(class_)s AND LOWER(name) IN %(names)s;
""",
{'class_': class_, 'names': tuple(names)})
return [row['name'] for row in g.cursor.fetchall()]
return [row[0] for row in list(g.cursor)]


def update_project(id_: int, name: str, description: Optional[str]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion openatlas/database/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_entity_ids_by_type_ids(type_ids: list[int]) -> list[int]:
ORDER BY id;
""",
{'type_ids': tuple(type_ids)})
return [row[0] for row in g.cursor.fetchall()]
return [row[0] for row in list(g.cursor)]


def delete_(id_: int) -> None:
Expand Down
2 changes: 1 addition & 1 deletion openatlas/database/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ def get_object_mapping() -> dict[int, int]:
JOIN model.entity e2 ON l.range_id = e2.id
AND e.openatlas_class_name = 'place';
""")
return {row['range_id']: row['id'] for row in g.cursor.fetchall()}
return {row['range_id']: row['id'] for row in list(g.cursor)}
2 changes: 1 addition & 1 deletion openatlas/database/openatlas_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def get_class_count() -> dict[str, int]:
LEFT JOIN model.entity e ON oc.name = e.openatlas_class_name
GROUP BY oc.name;
""")
return {row['name']: row['count'] for row in g.cursor.fetchall()}
return {row['name']: row['count'] for row in list(g.cursor)}


def get_classes() -> list[dict[str, Any]]:
Expand Down
2 changes: 1 addition & 1 deletion openatlas/database/overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_by_id(id_: int) -> dict[str, Any]:
WHERE id = %(id)s;
""",
{'id': id_})
return dict(g.cursor.fetchone())
return g.cursor.fetchone()


def remove(id_: int) -> None:
Expand Down
2 changes: 1 addition & 1 deletion openatlas/database/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
def get_settings(cursor: Optional[DictCursor] = None) -> dict[str, str]:
cursor = cursor or g.cursor
cursor.execute('SELECT name, value FROM web.settings;')
return {row['name']: row['value'] for row in cursor.fetchall()}
return {row['name']: row['value'] for row in list(cursor)}


def update(field_name: str, value: Any) -> None:
Expand Down
4 changes: 2 additions & 2 deletions openatlas/database/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_types(with_count: bool) -> list[dict[str, Any]]:
ORDER BY e.name;
"""
g.cursor.execute(sql)
return [dict(row) for row in g.cursor.fetchall()]
return list(g.cursor)


def get_hierarchies() -> list[dict[str, Any]]:
Expand All @@ -61,7 +61,7 @@ def get_hierarchies() -> list[dict[str, Any]]:
SELECT id, name, category, multiple, directional, required
FROM web.hierarchy;
""")
return [dict(row) for row in g.cursor.fetchall()]
return list(g.cursor)


def set_required(id_: int) -> None:
Expand Down
40 changes: 20 additions & 20 deletions openatlas/database/user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any

from flask import g

Expand Down Expand Up @@ -71,7 +71,7 @@ def update_language(user_id: int, value: str) -> None:

def get_all() -> list[dict[str, Any]]:
g.cursor.execute(f'{SQL} ORDER BY username;')
return [dict(row) for row in g.cursor.fetchall()]
return list(g.cursor)


def get_bookmarks(user_id: int) -> list[int]:
Expand All @@ -82,40 +82,40 @@ def get_bookmarks(user_id: int) -> list[int]:
WHERE user_id = %(user_id)s;
""",
{'user_id': user_id})
return [row['entity_id'] for row in g.cursor.fetchall()]
return [row[0] for row in list(g.cursor)]


def get_by_id(user_id: int) -> Optional[dict[str, Any]]:
def get_by_id(user_id: int) -> dict[str, Any]:
g.cursor.execute(f'{SQL} WHERE u.id = %(id)s;', {'id': user_id})
return dict(g.cursor.fetchone()) if g.cursor.rowcount else None
return g.cursor.fetchone()


def get_by_reset_code(code: str) -> Optional[dict[str, Any]]:
def get_by_reset_code(code: str) -> dict[str, Any]:
g.cursor.execute(
f'{SQL} WHERE u.password_reset_code = %(code)s;',
{'code': code})
return dict(g.cursor.fetchone()) if g.cursor.rowcount else None
return g.cursor.fetchone()


def get_by_email(email: str) -> Optional[dict[str, Any]]:
def get_by_email(email: str) -> dict[str, Any]:
g.cursor.execute(
f'{SQL} WHERE LOWER(u.email) = LOWER(%(email)s);',
{'email': email})
return dict(g.cursor.fetchone()) if g.cursor.rowcount else None
return g.cursor.fetchone()


def get_by_username(username: str) -> Optional[dict[str, Any]]:
def get_by_username(username: str) -> dict[str, Any]:
g.cursor.execute(
f'{SQL} WHERE LOWER(u.username) = LOWER(%(username)s);',
{'username': username})
return dict(g.cursor.fetchone()) if g.cursor.rowcount else None
return g.cursor.fetchone()


def get_by_unsubscribe_code(code: str) -> Optional[dict[str, Any]]:
def get_by_unsubscribe_code(code: str) -> dict[str, Any]:
g.cursor.execute(
f'{SQL} WHERE u.unsubscribe_code = %(code)s;',
{'code': code})
return dict(g.cursor.fetchone()) if g.cursor.rowcount else None
return g.cursor.fetchone()


def get_activities(
Expand Down Expand Up @@ -143,7 +143,7 @@ def get_activities(
'id': user_id,
'action': action,
'entity_id': entity_id})
return g.cursor.fetchall()
return list(g.cursor)


def get_created_entities_count(user_id: int) -> int:
Expand Down Expand Up @@ -192,7 +192,7 @@ def delete(id_: int) -> None:

def get_users_for_form() -> list[tuple[int, str]]:
g.cursor.execute('SELECT id, username FROM web.user ORDER BY username;')
return [(row['id'], row['username']) for row in g.cursor.fetchall()]
return [(row['id'], row['username']) for row in list(g.cursor)]


def insert_bookmark(user_id: int, entity_id: int) -> None:
Expand Down Expand Up @@ -221,7 +221,7 @@ def get_settings(user_id: int) -> list[dict[str, Any]]:
WHERE user_id = %(user_id)s;
""",
{'user_id': user_id})
return [dict(row) for row in g.cursor.fetchall()]
return list(g.cursor)


def get_notes_by_entity_id(
Expand All @@ -235,7 +235,7 @@ def get_notes_by_entity_id(
AND (public IS TRUE or user_id = %(user_id)s);
""",
{'entity_id': entity_id, 'user_id': user_id})
return [dict(row) for row in g.cursor.fetchall()]
return list(g.cursor)


def get_notes_by_user_id(user_id: int) -> list[dict[str, Any]]:
Expand All @@ -246,7 +246,7 @@ def get_notes_by_user_id(user_id: int) -> list[dict[str, Any]]:
WHERE user_id = %(user_id)s;
""",
{'user_id': user_id})
return [dict(row) for row in g.cursor.fetchall()]
return list(g.cursor)


def get_note_by_id(id_: int) -> dict[str, Any]:
Expand All @@ -257,7 +257,7 @@ def get_note_by_id(id_: int) -> dict[str, Any]:
WHERE id = %(id)s;
""",
{'id': id_})
return dict(g.cursor.fetchone())
return g.cursor.fetchone()


def insert_note(
Expand Down Expand Up @@ -302,4 +302,4 @@ def get_user_entities(id_: int) -> list[int]:
AND l.action = 'insert';
""",
{'user_id': id_})
return [row['id'] for row in g.cursor.fetchall()]
return [row[0] for row in list(g.cursor)]

0 comments on commit 3e998b2

Please sign in to comment.