diff --git a/backend/auth/auth.py b/backend/auth/auth.py index 26f0ff6c0..d8852d79e 100644 --- a/backend/auth/auth.py +++ b/backend/auth/auth.py @@ -1,6 +1,4 @@ -from ..database import db, User -from flask_user import SQLAlchemyAdapter -from flask_user import UserManager +from ..database import User from datetime import timezone, datetime from flask_jwt_extended import ( get_jwt, @@ -8,9 +6,16 @@ get_jwt_identity, set_access_cookies, ) -from flask import current_app +from flask import current_app, jsonify -user_manager = UserManager(SQLAlchemyAdapter(db, User)) + +def login_user(email, password): + user = User.get_by_email(email) + if not user or not user.verify_password(password): + return jsonify({"msg": "Bad email or password"}), 401 + + access_token = create_access_token(identity=user.uid) + return jsonify(access_token=access_token) def refresh_token(response): diff --git a/backend/database/core.py b/backend/database/core.py index 9274a2d01..23cf3fde9 100644 --- a/backend/database/core.py +++ b/backend/database/core.py @@ -5,190 +5,163 @@ from `backend.database`. """ import os -from typing import Any, Optional, TypeVar, Type +import json +from typing import Any, Optional, TypeVar, Type, List +from enum import Enum import click import pandas as pd -import psycopg -import psycopg2.errors from flask import abort, current_app from flask.cli import AppGroup, with_appcontext -from flask_sqlalchemy import SQLAlchemy -from psycopg2 import connect -from psycopg2.extensions import connection -from sqlalchemy.exc import ResourceClosedError from werkzeug.utils import secure_filename -from neomodel import config as neo_config, db as neo_db +from neomodel import ( + db, RelationshipTo, + RelationshipFrom, Relationship +) from neo4j import GraphDatabase +from neomodel.exceptions import DoesNotExist from ..config import TestingConfig from ..utils import dev_only -db = SQLAlchemy() -T = TypeVar("T") +T = TypeVar("T", bound="JsonSerializable") -class CrudMixin: - """Mix me into a database model whose CRUD operations you want to expose in - a convenient manner. - """ +class JsonSerializable: + """Mix me into a database model to make it JSON serializable.""" - def create(self: T, refresh: bool = True) -> T: - db.session.add(self) - db.session.commit() - if refresh: - db.session.refresh(self) - return self + def to_dict(self, include_relationships=True): + """ + Convert the node instance into a dictionary. + Args: + include_relationships (bool): Whether to include + relationships in the output. + + Returns: + dict: A dictionary representation of the node. + """ + # Serialize node properties using deflate to handle conversions + node_props = self.deflate(self.__properties__) - def delete(self) -> None: - db.session.delete(self) - db.session.commit() + # Optionally add related nodes + if include_relationships: + for rel_name, rel_manager in self.__all_relationships__().items(): + related_nodes = rel_manager.all() + node_props[rel_name] = [ + node.to_dict(include_relationships=False) + for node in related_nodes + ] - @classmethod - def get(cls: Type[T], id: Any, abort_if_null: bool = True) -> Optional[T]: - obj = db.session.query(cls).get(id) - if obj is None and abort_if_null: - abort(404) - return obj # type: ignore + return node_props + def to_json(self): + """Convert the node instance into a JSON string.""" + return json.dumps(self.to_dict()) -QUERIES_DIR = os.path.abspath( - os.path.join(os.path.dirname(__file__), "queries") -) - - -def execute_query(filename: str) -> Optional[pd.DataFrame]: - """Run SQL from a file. It will return a Pandas DataFrame if it selected - anything; otherwise it will return None. + @classmethod + def from_dict(cls: Type[T], data: dict) -> T: + """ + Creates or updates an instance of the model from a dictionary. + + Args: + data (dict): A dictionary containing data for the model instance. + + Returns: + Instance of the model. + """ + instance = None + + # Handle unique properties to find existing instances + unique_props = { + prop: data.get(prop) + for prop in cls.__all_properties__() if prop in data and data.get( + prop) + } + + if unique_props: + try: + instance = cls.nodes.get(**unique_props) + # Update existing instance + for key, value in data.items(): + if key in instance.__all_properties__(): + setattr(instance, key, value) + except DoesNotExist: + # No existing instance, create a new one + instance = cls(**unique_props) + else: + instance = cls() + + # Set properties + for key, value in data.items(): + if key in instance.__all_properties__(): + setattr(instance, key, value) + + # Handle relationships if they exist in the dictionary + for rel_name, rel_manager in cls.__all_relationships__().items(): + if rel_name in data: + related_nodes = data[rel_name] + if isinstance(related_nodes, list): + # Assume related_nodes is a list of dictionaries + for rel_data in related_nodes: + related_instance = rel_manager.definition[ + 'node_class'].from_dict(rel_data) + getattr(instance, rel_name).connect(related_instance) + else: + # Assume related_nodes is a single dictionary + related_instance = rel_manager.definition[ + 'node_class'].from_dict(related_nodes) + getattr(instance, rel_name).connect(related_instance) + + instance.save() + return instance - I do not recommend you use this function too often. In general, we should be - using the SQLAlchemy ORM. That said, it's a nice convenience, and there are - times when this function is genuinely something you want to run. - """ - with open(os.path.join(QUERIES_DIR, secure_filename(filename))) as f: - query = f.read() - with db.engine.connect() as conn: - res = conn.execute(query) - try: - df = pd.DataFrame(res.fetchall(), columns=res.keys()) - return df - except ResourceClosedError: - return None + @classmethod + def __all_properties__(cls) -> List[str]: + """Get a list of all properties defined in the class.""" + return [prop_name for prop_name in cls.__dict__ if isinstance( + cls.__dict__[prop_name], property)] + @classmethod + def __all_relationships__(cls) -> dict: + """Get all relationships defined in the class.""" + return { + rel_name: rel_manager for rel_name, rel_manager in cls.__dict__.items() + if isinstance( + rel_manager, (RelationshipTo, RelationshipFrom, Relationship)) + } -@click.group("psql", cls=AppGroup) -@with_appcontext -@click.pass_context -def db_cli(ctx: click.Context): - """Collection of database commands.""" - conn = connect( - user=current_app.config["POSTGRES_USER"], - password=current_app.config["POSTGRES_PASSWORD"], - host=current_app.config["POSTGRES_HOST"], - port=current_app.config["PGPORT"], - dbname="postgres", - ) - conn.autocommit = True - ctx.obj = conn + @classmethod + def get(cls: Type[T], uid: Any, abort_if_null: bool = True) -> Optional[T]: + """ + Get a model instance by its UID, returning None if + not found (or aborting). + + Args: + uid: Unique identifier for the node (could be Neo4j internal ID + or custom UUID). + abort_if_null (bool): Whether to abort if the node is not found. + + Returns: + Optional[T]: An instance of the model or None. + """ + obj = cls.nodes.get_or_none(uid=uid) + if obj is None and abort_if_null: + abort(404) + return obj # type: ignore -pass_psql_admin_connection = click.make_pass_decorator(connection) +# Update Enums to work well with NeoModel +class PropertyEnum(Enum): + """Mix me into an Enum to convert the options to a dictionary.""" + @classmethod + def choices(cls): + return {item.value: item.name for item in cls} -@db_cli.command("create") -@click.option( - "--overwrite/--no-overwrite", - default=False, - is_flag=True, - show_default=True, - help="If true, overwrite the database if it exists.", -) -@pass_psql_admin_connection -@click.pass_context -@dev_only -def create_database( - ctx: click.Context, conn: connection, overwrite: bool = False -): - """Create the database from nothing.""" - database = current_app.config["POSTGRES_DB"] - cursor = conn.cursor() - - if overwrite: - cursor.execute( - f"SELECT bool_or(datname = '{database}') FROM pg_database;" - ) - exists = cursor.fetchall()[0][0] - if exists: - ctx.invoke(delete_database) - - try: - cursor.execute(f"CREATE DATABASE {database};") - except (psycopg2.errors.lookup("42P04"), psycopg.errors.DuplicateDatabase): - click.echo(f"Database {database!r} already exists.") - else: - click.echo(f"Created database {database!r}.") - - -@db_cli.command("init") -def init_database(): - """Initialize the database schemas. - - Run this after the database has been created. - """ - database = current_app.config["POSTGRES_DB"] - db.create_all() - click.echo(f"Initialized the database {database!r}.") - - -@db_cli.command("gen-examples") -def gen_examples_command(): - """Generate 2 incident examples in the database.""" - execute_query("example_incidents.sql") - click.echo("Added 2 example incidents to the database.") - - -@db_cli.command("delete") -@click.option( - "--test-db", - "-t", - default=False, - is_flag=True, - help=f"Deletes the database {TestingConfig.POSTGRES_DB!r}.", +QUERIES_DIR = os.path.abspath( + os.path.join(os.path.dirname(__file__), "queries") ) -@pass_psql_admin_connection -@dev_only -def delete_database(conn: connection, test_db: bool): - """Delete the database.""" - if test_db: - database = TestingConfig.POSTGRES_DB - else: - database = current_app.config["POSTGRES_DB"] - - cursor = conn.cursor() - - # Don't validate name for `police_data_test`. - if database != TestingConfig.POSTGRES_DB: - # Make sure we want to do this. - click.echo(f"Are you sure you want to delete database {database!r}?") - click.echo( - "Type in the database name '" - + click.style(database, fg="red") - + "' to confirm" - ) - confirmation = click.prompt("Database name") - if database != confirmation: - click.echo( - "The input does not match. " "The database will not be deleted." - ) - return None - - try: - cursor.execute(f"DROP DATABASE {database};") - except psycopg2.errors.lookup("3D000"): - click.echo(f"Database {database!r} does not exist.") - else: - click.echo(f"Database {database!r} was deleted.") # Neo4j commands @@ -198,10 +171,10 @@ def delete_database(conn: connection, test_db: bool): def neo4j_cli(ctx: click.Context): """Collection of Neo4j database commands.""" neo4j_conn = GraphDatabase.driver( - current_app.config["NEO4J_BOLT_URL"], + current_app.config["GRAPH_NM_URI"], auth=( - current_app.config["NEO4J_USERNAME"], - current_app.config["NEO4J_PASSWORD"], + current_app.config["GRAPH_USER"], + current_app.config["GRAPH_PASSWORD"], ), ) ctx.obj = neo4j_conn @@ -212,7 +185,7 @@ def neo4j_cli(ctx: click.Context): def neo4j_create(): """Create the Neo4j database or ensure it is ready.""" # Example logic to create a constraint or ensure the database is ready - neo_db.cypher_query("CREATE CONSTRAINT ON (n:Node) ASSERT n.uid IS UNIQUE;") + db.cypher_query("CREATE CONSTRAINT ON (n:Node) ASSERT n.uid IS UNIQUE;") click.echo("Neo4j database setup complete.") @@ -220,5 +193,5 @@ def neo4j_create(): @with_appcontext def neo4j_delete(): """Delete all nodes and relationships in the Neo4j database.""" - neo_db.cypher_query("MATCH (n) DETACH DELETE n") + db.cypher_query("MATCH (n) DETACH DELETE n") click.echo("Neo4j database cleared.") diff --git a/backend/database/models/user.py b/backend/database/models/user.py index 0a5176d0e..1a2407876 100644 --- a/backend/database/models/user.py +++ b/backend/database/models/user.py @@ -2,51 +2,15 @@ import bcrypt from backend.database.core import db -from flask_serialize.flask_serialize import FlaskSerialize -# from flask_user import UserMixin -from sqlalchemy.ext.compiler import compiles -# from sqlalchemy.ext.associationproxy import association_proxy -from sqlalchemy.types import String, TypeDecorator -# from ..core import CrudMixin -from backend.database.neo_classes import ExportableNode, PropertyEnum +from backend.database import JsonSerializable, PropertyEnum from neomodel import ( - Relationship, + Relationship, StructuredNode, StringProperty, DateProperty, BooleanProperty, UniqueIdProperty, EmailProperty ) from backend.database.models.partner import PartnerMember -fs_mixin = FlaskSerialize(db) - - -# Creating this class as NOCASE collation is not compatible with ordinary -# SQLAlchemy Strings -class CI_String(TypeDecorator): - """Case-insensitive String subclass definition""" - - impl = String - cache_ok = True - - def __init__(self, length, **kwargs): - if kwargs.get("collate"): - if kwargs["collate"].upper() not in ["BINARY", "NOCASE", "RTRIM"]: - raise TypeError( - "%s is not a valid SQLite collation" % kwargs["collate"] - ) - self.collation = kwargs.pop("collate").upper() - super(CI_String, self).__init__(length=length, **kwargs) - - -@compiles(CI_String, "sqlite") -def compile_ci_string(element, compiler, **kwargs): - base_visit = compiler.visit_string(element, **kwargs) - if element.collation: - return "%s COLLATE %s" % (base_visit, element.collation) - else: - return base_visit - - class UserRole(str, PropertyEnum): PUBLIC = "Public" PASSPORT = "Passport" @@ -65,7 +29,7 @@ def get_value(self): # Define the User data-model. -class User(ExportableNode): +class User(StructuredNode, JsonSerializable): uid = UniqueIdProperty() active = BooleanProperty(default=True) @@ -79,7 +43,8 @@ class User(ExportableNode): first_name = StringProperty(required=True) last_name = StringProperty(required=True) - role = StringProperty(choices=UserRole.choices(), default=UserRole.PUBLIC) + role = StringProperty( + choices=UserRole.choices(), default=UserRole.PUBLIC.value) phone_number = StringProperty() @@ -88,11 +53,29 @@ class User(ExportableNode): 'backend.database.models.partner.Partner', "MEMBER_OF_PARTNER", model=PartnerMember) - def verify_password(self, pw): + def verify_password(self, pw: str) -> bool: + """ + Verify the user's password using bcrypt. + Args: + pw (str): The password to verify. + + Returns: + bool: True if the password is correct, False otherwise. + """ return bcrypt.checkpw(pw.encode("utf8"), self.password.encode("utf8")) - def get_by_email(email): + @classmethod + def get_by_email(cls, email: str) -> "User": + """ + Get a user by their email address. + + Args: + email (str): The user's email. + + Returns: + User: The User instance if found, otherwise None. + """ try: - return User.nodes.get(email=email) - except User.DoesNotExist: + return cls.nodes.get_or_none(email=email) + except cls.DoesNotExist: return None diff --git a/backend/routes/agencies.py b/backend/routes/agencies.py index 382e4734c..b3af02dd1 100644 --- a/backend/routes/agencies.py +++ b/backend/routes/agencies.py @@ -7,7 +7,6 @@ from backend.database.models.user import UserRole from flask import Blueprint, abort, request from flask_jwt_extended.view_decorators import jwt_required -from sqlalchemy.exc import DataError from pydantic import BaseModel from ..database import Agency, db diff --git a/backend/schemas.py b/backend/schemas.py index 2f9fae12d..82216877f 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -1,16 +1,9 @@ from __future__ import annotations import textwrap -from typing import Any, Dict, List, Optional -from pydantic import BaseModel, root_validator -from pydantic.main import ModelMetaclass -from pydantic_sqlalchemy import sqlalchemy_to_pydantic from spectree import SecurityScheme, SpecTree from spectree.models import Server -from sqlalchemy.ext.declarative import DeclarativeMeta -from .database import User -from .database.models.partner import PartnerMember, MemberRole spec = SpecTree( "flask", @@ -74,19 +67,3 @@ ), ], ) - - -def validate(auth=True, **kwargs): - if not auth: - # Disable security for the route - kwargs["security"] = {} - - return spec.validate(**kwargs) - - -def schema_create(model_type: DeclarativeMeta, **kwargs) -> ModelMetaclass: - return sqlalchemy_to_pydantic(model_type, exclude="id", **kwargs) - - -def schema_get(model_type: DeclarativeMeta, **kwargs) -> ModelMetaclass: - return sqlalchemy_to_pydantic(model_type, **kwargs) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 06cdd9b79..c25aebe36 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,6 +1,5 @@ -import psycopg.errors -import psycopg2.errors import pytest +from unittest.mock import patch, MagicMock from backend.api import create_app from backend.auth import user_manager from backend.config import TestingConfig @@ -26,37 +25,35 @@ example_password = "my_password" -@pytest.fixture(scope="session") -def database(): - cfg = TestingConfig() - janitor = DatabaseJanitor( - user=cfg.POSTGRES_USER, - host=cfg.POSTGRES_HOST, - port=cfg.PGPORT, - dbname=cfg.POSTGRES_DB, - version=16.3, - password=cfg.POSTGRES_PASSWORD, - ) +@pytest.fixture +def mock_cypher_query(): + with patch('neomodel.db.cypher_query') as mock_query: + # Set default behavior for the mock query + mock_query.return_value = ([], None) + yield mock_query - try: - janitor.init() - except (psycopg2.errors.lookup("42P04"), psycopg.errors.DuplicateDatabase): - pass - yield +@pytest.fixture +def mock_neomodel_save(): + with patch('neomodel.StructuredNode.save') as mock_save: + yield mock_save - janitor.drop() +@pytest.fixture +def mock_neomodel_delete(): + with patch('neomodel.StructuredNode.delete') as mock_delete: + yield mock_delete -@pytest.fixture(scope="session") -def app(database): - app = create_app(config="testing") - # The app should be ready! Provide the app instance here. - # Use the app context to make testing easier. - # The main time where providing app context can cause false positives is - # when testing CLI commands that don't pass the app context. - with app.app_context(): - yield app + +@pytest.fixture +def mock_neo4j_driver(): + with patch('neo4j.GraphDatabase.driver') as mock_driver: + mock_session = MagicMock() + mock_driver.return_value.session.return_value = mock_session + + # Set default behavior for session.run + mock_session.run.return_value = MagicMock() + yield mock_session.run @pytest.fixture @@ -64,7 +61,7 @@ def client(app): return app.test_client() -@pytest.fixture +""" @pytest.fixture def example_user(db_session): user = User( email=example_email, @@ -382,7 +379,7 @@ def contributor_access_token(client, partner_publisher): }, ) assert res.status_code == 200 - return res.json["access_token"] + return res.json["access_token"] """ @pytest.fixture @@ -390,14 +387,14 @@ def cli_runner(app): return app.test_cli_runner() -@pytest.fixture(scope="session") -def _db(app): - """See this: +# @pytest.fixture(scope="session") +# def _db(app): +# """See this: - https://github.com/jeancochrane/pytest-flask-sqlalchemy +# https://github.com/jeancochrane/pytest-flask-sqlalchemy - Basically, this '_db' fixture is required for the above extension to work. - We use this extension to allow for easy testing of the database. - """ - db.create_all() - yield db +# Basically, this '_db' fixture is required for the above extension to work. +# We use this extension to allow for easy testing of the database. +# """ +# db.create_all() +# yield db