Skip to content

Commit

Permalink
Allow users to have specific write access to certain endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
jackleland committed Dec 3, 2024
1 parent fb1a469 commit 9f0c215
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 37 deletions.
13 changes: 10 additions & 3 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,13 @@ services:
image: local/psat-server-web
# Run the tests, using the root user to avoid permission issues
command: >
bash -c "python manage.py makemigrations --noinput &&
python manage.py test
bash -c "
python manage.py makemigrations --noinput &&
mysql -h db -u root -p${MYSQL_ROOT_PASSWORD} < sql/init.sql &&
cd ../schema &&
mysql -h db -u root -p${MYSQL_ROOT_PASSWORD} ${MYSQL_TEST_DATABASE} < create_schema.sql &&
cd ../atlas &&
python manage.py test --keepdb --noinput
|| exit $?"
volumes:
# Mount the code directories into the image to allow for live code changes
Expand All @@ -95,6 +100,8 @@ services:
- ./psat_server_web/atlas/atlas:/app/psat_server_web/atlas/atlas
- ./psat_server_web/atlas/accounts:/app/psat_server_web/atlas/accounts
- ./psat_server_web/atlas/tests:/app/psat_server_web/atlas/tests
- ./psat_server_web/schema:/app/psat_server_web/schema
- ./docker/init.sql:/app/psat_server_web/atlas/sql/init.sql
ports:
- 8087:8087
depends_on:
Expand Down Expand Up @@ -126,7 +133,7 @@ services:
- DJANGO_NAMESERVER_API_URL=''
- DJANGO_LASAIR_TOKEN=${DJANGO_LASAIR_TOKEN}
- DJANGO_DUSTMAP_LOCATION=/tmp/dustmap
- DJANGO_LOG_LEVEL=DEBUG
- DJANGO_LOG_LEVEL=ERROR
- API_TOKEN_EXPIRY=10
- DJANGO_PANSTARRS_TOKEN=${PANSTARRS_TOKEN}
- DJANGO_PANSTARRS_BASE_URL=${PANSTARRS_BASE_URL}
6 changes: 6 additions & 0 deletions docker/init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- drop and create database for use in testing
DROP DATABASE IF EXISTS `atlas_test`;
CREATE DATABASE `atlas_test`;

-- -- create user and grant rights
-- GRANT ALL ON atlas_test.* TO 'atlas'@'%';
6 changes: 5 additions & 1 deletion psat_server_web/atlas/accounts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ class GroupProfile(models.Model):
on_delete=models.CASCADE,
related_name='profile'
)
api_write_access = models.BooleanField(
default=False,
help_text='Does the group have write access to the API?'
)
token_expiration_time = models.DurationField(
help_text='in days, default 1 day (24*60*60 seconds)',
default=timedelta(days=1)
)
)
description = models.TextField(
blank=True,
help_text='What is the group for?'
Expand Down
9 changes: 7 additions & 2 deletions psat_server_web/atlas/atlasapi/authentication.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from re import match
import logging

from django.conf import settings
from django.utils.timezone import now
from rest_framework.authtoken.models import Token
from rest_framework.authentication import TokenAuthentication
from rest_framework.exceptions import AuthenticationFailed

logger = logging.getLogger(__name__)

class ExpiringTokenAuthentication(TokenAuthentication):
"""
Token authentication using the ExpiringToken model, which has an expiry
Expand All @@ -23,8 +26,9 @@ def authenticate_credentials(self, key):
try:
group_profile = user.groups.first().profile
except AttributeError:
# TODO: Log this error?
raise AuthenticationFailed('Could not authenticate: Group has no profile. Please contact administrator.')
msg = 'Could not authenticate: Group has no profile. Please contact administrator.'
logger.error(msg)
raise AuthenticationFailed(msg)
token_expiration_time = group_profile.token_expiration_time.total_seconds()
else:
# Otherwise use the default expiration time
Expand All @@ -33,6 +37,7 @@ def authenticate_credentials(self, key):
# Calculate the token's age and compare it to the expiration setting
token_age = (now() - token.created).total_seconds()
if token_age > token_expiration_time:
logger.warning(f'User {user} attempted to use an expired token.')
raise AuthenticationFailed('Token has expired.')

return user, token
Expand Down
39 changes: 36 additions & 3 deletions psat_server_web/atlas/atlasapi/permissions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,45 @@
import logging

from rest_framework.permissions import BasePermission, SAFE_METHODS

class IsApprovedUser(BasePermission):

logger = logging.getLogger(__name__)

class HasReadAccess(BasePermission):
def has_permission(self, request, view):
# Allow all safe methods (GET, OPTIONS, HEAD)
if request.method in SAFE_METHODS:
return True

# Allow POST if the user is authenticated
return (request.user
and request.user.is_authenticated)


class HasWriteAccess(BasePermission):
def has_permission(self, request, view):
# Allow all safe methods (GET, OPTIONS, HEAD)
if request.method in SAFE_METHODS:
return True

# Only allow POST if the user is authenticated, active, and staff
write_fl = False
user = request.user
# Retrieve the user's group and get the api write access flag from the
# group profile
if user.groups.exists():
try:
group_profile = user.groups.first().profile
write_fl = group_profile.api_write_access
except AttributeError:
# If the group has no profile, then there's something wrong with
# the database. This should be fixed by an administrator, but
# we don't need to block the user from accessing the API.
msg = 'Could not authorise based on group: Group has no profile.'
logger.error(msg)
write_fl = False

# Only allow POST to write endpoints if the user is authenticated and is
# either in a writeable group or is a staff member
return (request.user
and request.user.is_authenticated
and request.user.is_staff)
and (write_fl or request.user.is_staff))
31 changes: 16 additions & 15 deletions psat_server_web/atlas/atlasapi/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ObjectDetectionListSerializer,
)
from .authentication import QueryAuthentication, ExpiringTokenAuthentication
from .permissions import IsApprovedUser
from .permissions import HasReadAccess, HasWriteAccess

def retcode(message):
if 'error' in message: return status.HTTP_400_BAD_REQUEST
Expand Down Expand Up @@ -68,7 +68,7 @@ def post(self, request, *args, **kwargs):

class ConeView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasWriteAccess]

def get(self, request):
serializer = ConeSerializer(data=request.GET, context={'request': request})
Expand All @@ -87,7 +87,7 @@ def post(self, request, format=None):

class ObjectsView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasReadAccess]

def get(self, request):
serializer = ObjectsSerializer(data=request.GET, context={'request': request})
Expand All @@ -106,7 +106,7 @@ def post(self, request, format=None):

class ObjectListView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasReadAccess]

def get(self, request):
serializer = ObjectListSerializer(data=request.GET, context={'request': request})
Expand All @@ -125,7 +125,7 @@ def post(self, request, format=None):

class VRAScoresView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasWriteAccess]

def get(self, request):
return Response({"Error": "GET is not implemented for this service."})
Expand All @@ -142,7 +142,7 @@ def post(self, request, format=None):

class VRAScoresListView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasReadAccess]

def get(self, request):
serializer = VRAScoresListSerializer(data=request.GET, context={'request': request})
Expand All @@ -162,7 +162,7 @@ def post(self, request, format=None):
# appropriate to the circumstances. E.g. if object is not found generate a 404, etc.
class VRATodoView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasWriteAccess]

def get(self, request):
return Response({"Error": "GET is not implemented for this service."})
Expand All @@ -179,7 +179,7 @@ def post(self, request, format=None):
# 2024-05-07 KWS Added VRATodoListView.
class VRATodoListView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasReadAccess]

def get(self, request):
serializer = VRATodoListSerializer(data=request.GET, context={'request': request})
Expand All @@ -197,7 +197,7 @@ def post(self, request, format=None):

class TcsObjectGroupsView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasWriteAccess]

def get(self, request):
return Response({"Error": "GET is not implemented for this service."})
Expand All @@ -213,7 +213,7 @@ def post(self, request, format=None):

class TcsObjectGroupsListView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasReadAccess]

def get(self, request):
serializer = TcsObjectGroupsListSerializer(data=request.GET, context={'request': request})
Expand All @@ -232,7 +232,8 @@ def post(self, request, format=None):

class TcsObjectGroupsDeleteView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
# TODO: Change this to HasDeleteAccess?
permission_classes = [IsAuthenticated&HasWriteAccess]

def get(self, request):
return Response({"Error": "GET is not implemented for this service."})
Expand All @@ -255,7 +256,7 @@ def post(self, request, format=None):
# appropriate to the circumstances. E.g. if object is not found generate a 404, etc.
class VRARankView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasWriteAccess]

def get(self, request):
return Response({"Error": "GET is not implemented for this service."})
Expand All @@ -272,7 +273,7 @@ def post(self, request, format=None):
# 2024-05-22 KWS Added VRARankListView.
class VRARankListView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasReadAccess]

def get(self, request):
serializer = VRARankListSerializer(data=request.GET, context={'request': request})
Expand All @@ -292,7 +293,7 @@ def post(self, request, format=None):
# 2024-09-24 KWS Added ExternalCrossmatchesListView.
class ExternalCrossmatchesListView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasReadAccess]

def get(self, request):
serializer = ExternalCrossmatchesListSerializer(data=request.GET, context={'request': request})
Expand All @@ -312,7 +313,7 @@ def post(self, request, format=None):
# 2024-09-24 KWS Added ExternalCrossmatchesListView.
class ObjectDetectionListView(APIView):
authentication_classes = [ExpiringTokenAuthentication, QueryAuthentication]
permission_classes = [IsAuthenticated&IsApprovedUser]
permission_classes = [IsAuthenticated&HasWriteAccess]

def get(self, request):
serializer = ObjectDetectionListSerializer(data=request.GET, context={'request': request})
Expand Down
Loading

0 comments on commit 9f0c215

Please sign in to comment.