diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..7d363f7d 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,6 @@ +- bump: patch + changes: + changed: + - updated the user_profile endpoints to use blueprints instead of endpoints. + fixed: + - updated the user profile endpoint to resist injection attacks on update. diff --git a/policyengine_api/api.py b/policyengine_api/api.py index 684bb3f9..4b2f5809 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -21,6 +21,7 @@ ) from policyengine_api.routes.tracer_analysis_routes import tracer_analysis_bp from policyengine_api.routes.metadata_routes import metadata_bp +from policyengine_api.routes.user_profile_routes import user_profile_bp from .endpoints import ( get_home, @@ -32,9 +33,6 @@ set_user_policy, get_user_policy, update_user_policy, - set_user_profile, - get_user_profile, - update_user_profile, get_simulations, ) @@ -100,11 +98,7 @@ get_user_policy ) -app.route("//user-profile", methods=["POST"])(set_user_profile) - -app.route("//user-profile", methods=["GET"])(get_user_profile) - -app.route("//user-profile", methods=["PUT"])(update_user_profile) +app.register_blueprint(user_profile_bp) app.route("/simulations", methods=["GET"])(get_simulations) diff --git a/policyengine_api/endpoints/__init__.py b/policyengine_api/endpoints/__init__.py index 0b7bd51e..7f2d4bf6 100644 --- a/policyengine_api/endpoints/__init__.py +++ b/policyengine_api/endpoints/__init__.py @@ -12,9 +12,4 @@ update_user_policy, ) -from .user_profile import ( - set_user_profile, - get_user_profile, - update_user_profile, -) from .simulation import get_simulations diff --git a/policyengine_api/endpoints/user_profile.py b/policyengine_api/endpoints/user_profile.py deleted file mode 100644 index e600bafa..00000000 --- a/policyengine_api/endpoints/user_profile.py +++ /dev/null @@ -1,231 +0,0 @@ -from flask import Response, request -from policyengine_api.utils.payload_validators import validate_country -from policyengine_api.data import database -import json - - -@validate_country -def set_user_profile(country_id: str) -> dict: - """ - Creates a new user_profile - """ - - payload = request.json - primary_country = country_id - auth0_id = payload.pop("auth0_id") - username = payload.pop("username", None) - user_since = payload.pop("user_since") - - try: - row = database.query( - f"SELECT * FROM user_profiles WHERE auth0_id = ?", - (auth0_id,), - ).fetchone() - if row is not None: - response = dict( - status="error", - message=f"User with auth0_id {auth0_id} already exists", - ) - return Response( - json.dumps(response), - status=403, - mimetype="application/json", - ) - except Exception as e: - return Response( - json.dumps( - { - "message": f"Internal database error: {e}; please try again later." - } - ), - status=500, - mimetype="application/json", - ) - - try: - # Unfortunately, it's not possible to use RETURNING - # with SQLite3 without rewriting the PolicyEngineDatabase - # object or implementing a true ORM, thus the double query - database.query( - f"INSERT INTO user_profiles (primary_country, auth0_id, username, user_since) VALUES (?, ?, ?, ?)", - (primary_country, auth0_id, username, user_since), - ) - - row = database.query( - f"SELECT * FROM user_profiles WHERE auth0_id = ?", (auth0_id,) - ).fetchone() - - except Exception as e: - return Response( - json.dumps( - { - "message": f"Internal database error: {e}; please try again later." - } - ), - status=500, - mimetype="application/json", - ) - - response_body = dict( - status="ok", - message="Record created successfully", - result=dict( - user_id=row["user_id"], - primary_country=row["primary_country"], - username=row["username"], - user_since=row["user_since"], - ), - ) - - return Response( - json.dumps(response_body), - status=201, - mimetype="application/json", - ) - - -@validate_country -def get_user_profile(country_id: str) -> dict: - """ - Get a user profile in one of two ways: by auth0_id, - which returns all data, and by user_id, which returns - all data except auth0_id - """ - - if len(request.args) != 1: - return Response( - json.dumps( - { - "message": f"Improperly formed request: {len(request.args)} args passed, when 1 is required" - } - ), - status=400, - mimetype="application/json", - ) - - label = "" - value = None - if request.args.get("auth0_id"): - label = "auth0_id" - value = request.args.get("auth0_id") - elif request.args.get("user_id"): - label = "user_id" - value = request.args.get("user_id") - else: - return Response( - json.dumps( - { - "message": "Improperly formed request: auth0_id or user_id must be provided" - } - ), - status=400, - mimetype="application/json", - ) - - try: - row = database.query( - f"SELECT * FROM user_profiles WHERE {label} = ?", (value,) - ).fetchone() - - if row is None: - return Response( - json.dumps( - { - "status": "ok", - "message": "No user found", - "result": None, - } - ), - status=404, - mimetype="application/json", - ) - - readable_row = dict(row) - # Delete auth0_id value if querying from user_id, as that value - # is a more private attribute than all others - if label == "user_id": - del readable_row["auth0_id"] - - except Exception as e: - return Response( - json.dumps( - { - "message": f"Internal database error: {e}; please try again later." - } - ), - status=500, - mimetype="application/json", - ) - - response_body = dict( - status="ok", - message=f"User #{readable_row['user_id']} found successfully", - result=readable_row, - ) - - return Response( - json.dumps(response_body), - status=200, - mimetype="application/json", - ) - - -@validate_country -def update_user_profile(country_id: str) -> dict: - """ - Update any part of a user_profile, given a user_id, - except the auth0_id value; any attempt to edit this - will assume malicious intent and 403 - """ - - # Construct the relevant UPDATE request - setter_array = [] - args = [] - payload = request.json - - # This must be popped before all others to ensure - # it is not added as an item to modify - user_id = payload.pop("user_id") - - for key in payload: - if key == "auth0_id": - return Response( - json.dumps( - { - "message": "Unauthorized attempt to modify auth0_id parameter; request denied" - } - ), - status=403, - mimetype="application/json", - ) - setter_array.append(f"{key} = ?") - args.append(payload[key]) - setter_phrase = ", ".join(setter_array) - - args.append(user_id) - sql_request = f"UPDATE user_profiles SET {setter_phrase} WHERE user_id = ?" - - try: - database.query(sql_request, (tuple(args))) - except Exception as e: - return Response( - json.dumps( - { - "message": f"Internal database error: {e}; please try again later." - } - ), - status=500, - mimetype="application/json", - ) - - response_body = dict( - status="ok", - message=f"User profile #{user_id} updated successfully", - result=dict(user_id=user_id), - ) - - return Response( - json.dumps(response_body), - status=200, - mimetype="application/json", - ) diff --git a/policyengine_api/routes/user_profile_routes.py b/policyengine_api/routes/user_profile_routes.py new file mode 100644 index 00000000..d859629c --- /dev/null +++ b/policyengine_api/routes/user_profile_routes.py @@ -0,0 +1,135 @@ +from flask import Blueprint, Response, request +from policyengine_api.utils.payload_validators import validate_country +from policyengine_api.data import database +import json +from policyengine_api.services.user_service import UserService +from werkzeug.exceptions import BadRequest, NotFound + +user_profile_bp = Blueprint("user_profile", __name__) +user_service = UserService() + + +@user_profile_bp.route("//user-profile", methods=["POST"]) +@validate_country +def set_user_profile(country_id: str) -> Response: + """ + Creates a new user_profile + """ + + payload = request.json + if payload is None: + raise BadRequest("Payload missing from request") + + auth0_id = payload.pop("auth0_id") + username = payload.pop("username", None) + user_since = payload.pop("user_since") + + created, row = user_service.create_profile( + primary_country=country_id, + auth0_id=auth0_id, + username=username, + user_since=user_since, + ) + + response = dict( + status="ok", + message="Record created successfully" if created else "Record exists", + result=dict( + user_id=row["user_id"], + primary_country=row["primary_country"], + username=row["username"], + user_since=row["user_since"], + ), + ) + return Response( + json.dumps(response), + status=201 if created else 200, + mimetype="application/json", + ) + + +@user_profile_bp.route("//user-profile", methods=["GET"]) +@validate_country +def get_user_profile(country_id: str) -> Response: + auth0_id = request.args.get("auth0_id") + user_id = request.args.get("user_id") + + if (auth0_id is None) and (user_id is None): + raise BadRequest("auth0_id or user_id must be provided") + + row = ( + user_service.get_profile(user_id=user_id) + if auth0_id is None + else user_service.get_profile(auth0_id=auth0_id) + ) + + if row is None: + raise NotFound("No such user") + + readable_row = dict(row) + # Delete auth0_id value if querying from user_id, as that value + # is a more private attribute than all others + if auth0_id is None: + del readable_row["auth0_id"] + + response_body = dict( + status="ok", + message=f"User #{readable_row['user_id']} found successfully", + result=readable_row, + ) + + return Response( + json.dumps(response_body), + status=200, + mimetype="application/json", + ) + + +@user_profile_bp.route("//user-profile", methods=["PUT"]) +@validate_country +def update_user_profile(country_id: str) -> Response: + """ + Update any part of a user_profile, given a user_id, + except the auth0_id value; any attempt to edit this + will assume malicious intent and 403 + """ + + # Construct the relevant UPDATE request + setter_array = [] + args = [] + payload = request.json + + if payload is None: + raise BadRequest("No user data provided in request") + + # TODO: we should validate the payload + # to ensure type safety https://github.com/PolicyEngine/policyengine-api/issues/2054 + user_id = payload.pop("user_id") + username = payload.pop("username", None) + primary_country = payload.pop("primary_country", None) + user_since = payload.pop("user_since", None) + + if user_id is None: + raise BadRequest("Payload must include user_id") + + updated = user_service.update_profile( + user_id=user_id, + primary_country=primary_country, + username=username, + user_since=user_since, + ) + + if not updated: + raise NotFound("No such user id") + + response_body = dict( + status="ok", + message=f"User profile #{user_id} updated successfully", + result=dict(user_id=user_id), + ) + + return Response( + json.dumps(response_body), + status=200, + mimetype="application/json", + ) diff --git a/policyengine_api/services/user_service.py b/policyengine_api/services/user_service.py new file mode 100644 index 00000000..f7c26d4f --- /dev/null +++ b/policyengine_api/services/user_service.py @@ -0,0 +1,71 @@ +import json +from typing import Any +from policyengine_api.data import database + + +class UserService: + def create_profile( + self, + primary_country: str, + auth0_id: str, + username: str | None, + user_since: str, + ) -> tuple[bool, Any]: + """ + returns true if a new record was created and false otherwise. + """ + # TODO: this is not written as an atomic operation. This will cause intermittent errors + # in some cases + # https://github.com/PolicyEngine/policyengine-api/issues/2058 to resolve after + # this refactor. + row = self.get_profile(auth0_id=auth0_id) + if row is not None: + return False, row + # Unfortunately, it's not possible to use RETURNING + # with SQLite3 without rewriting the PolicyEngineDatabase + # object or implementing a true ORM, thus the double query + database.query( + f"INSERT INTO user_profiles (primary_country, auth0_id, username, user_since) VALUES (?, ?, ?, ?)", + (primary_country, auth0_id, username, user_since), + ) + + row = self.get_profile(auth0_id=auth0_id) + + return (True, row) + + def get_profile( + self, auth0_id: str | None = None, user_id: str | None = None + ) -> Any | None: + key = "user_id" if auth0_id is None else "auth0_id" + value = user_id if auth0_id is None else auth0_id + if value is None: + raise ValueError("you must specify either auth0_id or user_id") + row = database.query( + f"SELECT * FROM user_profiles WHERE {key} = ?", + (value,), + ).fetchone() + + return row + + def update_profile( + self, + user_id: str, + primary_country: str | None, + username: str | None, + user_since: str, + ) -> bool: + fields = dict( + primary_country=primary_country, + username=username, + user_since=user_since, + ) + if self.get_profile(user_id=user_id) is None: + return False + + with_values = [key for key in fields if fields[key] is not None] + fields_update = ",".join([f'"{key}" = ?' for key in with_values]) + query = f"UPDATE user_profiles SET {fields_update} WHERE user_id = ?" + values = [fields[key] for key in with_values] + [user_id] + + database.query(query, (tuple(values))) + return True diff --git a/tests/python/test_user_profile.py b/tests/python/test_user_profile.py index 52c29f09..197fdb69 100644 --- a/tests/python/test_user_profile.py +++ b/tests/python/test_user_profile.py @@ -77,7 +77,7 @@ def test_set_and_get_record(self, rest_client): malicious_updated_profile = { **updated_profile, - "auth0_id": self.auth0_id, + "auth0_id": "BOGUS" } res = rest_client.put( @@ -85,7 +85,14 @@ def test_set_and_get_record(self, rest_client): ) return_object = json.loads(res.text) - assert res.status_code == 403 + assert res.status_code == 200 + + row = database.query( + f"SELECT * FROM user_profiles WHERE username = ?", + (test_username, ), + ).fetchone() + + assert row["auth0_id"] == self.auth0_id database.query( f"DELETE FROM user_profiles WHERE user_id = ? AND auth0_id = ? AND primary_country = ?",