Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
mfosterw committed Feb 20, 2024
1 parent 5a5861f commit 6d6efa7
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 40 deletions.
2 changes: 1 addition & 1 deletion democrasite/webiscite/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class PullRequestFactory(DjangoModelFactory):
number = Sequence(lambda n: -n) # Use negative numbers to represent fake PRs
title = Faker("text", max_nb_chars=50)
author_name = Faker("user_name")
state = Faker("random_element", elements=["open", "closed"])
state = "open"
additions = Faker("random_int")
deletions = Faker("random_int")
sha = Faker("pystr", min_chars=40, max_chars=40)
Expand Down
34 changes: 20 additions & 14 deletions democrasite/webiscite/tests/test_drf_views.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,40 @@
import pytest
from django.utils.timezone import get_current_timezone
from rest_framework.test import APIRequestFactory

from democrasite.webiscite.api.serializers import PullRequestSerializer
from democrasite.webiscite.tests.factories import PullRequestFactory

from ..api.views import BillViewSet, PullRequestViewSet
from ..models import Bill, PullRequest
from ..models import Bill


class TestPullRequestSerializer:
def test_serializer_read_only(self):
serializer = PullRequestSerializer(PullRequestFactory())
with pytest.raises(NotImplementedError):
serializer.create({})
with pytest.raises(NotImplementedError):
assert serializer.update({}, {})


class TestPullRequestViewSet:
def test_viewset_fields(self, api_rf: APIRequestFactory, user):
pr = PullRequest.objects.create(
number=-1,
title="Test PR",
additions=0,
deletions=0,
sha="123",
author_name=user.username,
)
pull_request = PullRequestFactory()

view = PullRequestViewSet.as_view(actions={"get": "retrieve"})
request = api_rf.get("/fake-url/")
request.user = user

response = view(request, pk=pr.pk)
response = view(request, pk=pull_request.number)

assert (
response.data.items()
>= { # Subset of the response data
"title": pr.title,
"sha": pr.sha,
"time_created": pr.time_created.astimezone(get_current_timezone()).isoformat(),
"url": f"http://testserver/api/pull-requests/{pr.pk}/",
"title": pull_request.title,
"sha": pull_request.sha,
"time_created": pull_request.time_created.astimezone(get_current_timezone()).isoformat(),
"url": f"http://testserver/api/pull-requests/{pull_request.number}/",
}.items()
)

Expand Down
22 changes: 20 additions & 2 deletions democrasite/webiscite/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,16 @@
from django_celery_beat.models import PeriodicTask
from factory import Faker

from ..models import Bill, PullRequest
from .factories import BillFactory, GithubPullRequestFactory
from ..models import Bill, PullRequest, Vote
from .factories import BillFactory, GithubPullRequestFactory, PullRequestFactory


class TestVote:
"""Test class for all tests related to the Vote model"""

def test_vote_str(self, bill: Bill, user: Any):
bill.vote(user, True)
assert str(Vote.objects.get(user=user, support=True)) == f"{user} for {bill}"


class TestPullRequest:
Expand All @@ -36,6 +44,16 @@ def test_close(self, bill: Bill):
assert pull_request.state == "closed"
assert not pull_request.bill_set.filter(state=Bill.States.OPEN).exists()

def test_close_no_bill(self):
pull_request = PullRequestFactory()
assert pull_request.state == "open"

pull_request.close()

pull_request.refresh_from_db()
assert pull_request.state == "closed"
assert not pull_request.bill_set.exists()


class TestBill:
"""Test class for all tests related to the Bill model"""
Expand Down
19 changes: 11 additions & 8 deletions democrasite/webiscite/tests/test_webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from django.test import RequestFactory
from django.utils.encoding import force_bytes

from democrasite.webiscite.tests.factories import PullRequestFactory

from ..models import Bill, PullRequest
from ..webhooks import GithubWebhookView, PullRequestHandler, github_webhook_view

Expand Down Expand Up @@ -108,14 +106,19 @@ def test_pr_handler_get_response(self, pr_handler: PullRequestHandler, bill: Bil

@patch.object(Bill, "create_from_pr")
def test_opened(self, mock_create, pr_handler: PullRequestHandler):
pull_request = PullRequestFactory()
pr = {"number": pull_request.number, "title": "Test PR", "diff_url": "http://test.com/diff"}
mock_create.return_value = (pull_request, None)
# This just calls the create_from_pr method
response = pr_handler.opened({})

mock_create.assert_called_once_with({})
assert response == mock_create.return_value

response = pr_handler.opened(pr)
@patch.object(Bill, "create_from_pr")
def test_reopened(self, mock_create, pr_handler: PullRequestHandler):
# This does the exact same thing
response = pr_handler.reopened({})

mock_create.assert_called_once_with(pr)
assert response == (pull_request, None)
mock_create.assert_called_once_with({})
assert response == mock_create.return_value

def test_closed_no_pr(self, pr_handler: PullRequestHandler):
response = pr_handler.closed({"number": 1})
Expand Down
31 changes: 16 additions & 15 deletions democrasite/webiscite/webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from logging import getLogger
from typing import Any

import requests
# import requests
from django.conf import settings
from django.contrib.auth import get_user_model
from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest, HttpResponseForbidden, JsonResponse, request
Expand Down Expand Up @@ -130,23 +130,24 @@ def _validate_signature(header_signature: str, request_body: bytes) -> HttpRespo

return None

@staticmethod
def _validate_remote_addr(remote_addr: str) -> str:
"""Validate the remote address of a request from a webhook
# Unused because I'm worried it will take too long
# @staticmethod
# def _validate_remote_addr(remote_addr: str) -> str:
# """Validate the remote address of a request from a webhook

Args:
request: The request from the webhook
# Args:
# request: The request from the webhook

Returns:
str: Error message if the remote address is invalid, otherwise an empty string
"""
# Get the list of IP addresses that GitHub uses to send webhooks
# This will slow the response but since it's not a frequent request, it's acceptable
webhook_allowed_hosts = requests.get("https://api.github.com/meta", timeout=5).json()["hooks"]
if remote_addr not in webhook_allowed_hosts:
return "Invalid remote address for GitHub webhook request"
# Returns:
# str: Error message if the remote address is invalid, otherwise an empty string
# """
# # Get the list of IP addresses that GitHub uses to send webhooks
# # This will slow the response but since it's not a frequent request, it's acceptable
# webhook_allowed_hosts = requests.get("https://api.github.com/meta", timeout=5).json()["hooks"]
# if remote_addr not in webhook_allowed_hosts:
# return "Invalid remote address for GitHub webhook request"

return ""
# return ""

def validate_request(self, headers: request.HttpHeaders, body: bytes) -> HttpResponse | None:
return self._validate_header(headers) or self._validate_signature(headers["x-hub-signature-256"], body)
Expand Down

0 comments on commit 6d6efa7

Please sign in to comment.