From af60e1d75f1872e2d7c0c29d7837a82498638e35 Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Thu, 9 May 2024 09:04:28 +0200 Subject: [PATCH] Initial commit --- .github/workflows/code_quality.yml | 33 +++ .github/workflows/deploy.yml | 39 ++++ .github/workflows/run_tests.yml | 33 +++ .gitignore | 61 +++++ LICENSE | 29 +++ README.md | 74 +++++++ ariadne_lambda/__init__.py | 4 + ariadne_lambda/base.py | 55 +++++ ariadne_lambda/graphql.py | 65 ++++++ ariadne_lambda/http_handler.py | 318 +++++++++++++++++++++++++++ ariadne_lambda/schema.py | 77 +++++++ pyproject.toml | 109 +++++++++ tests/__init__.py | 0 tests/conftest.py | 27 +++ tests/data/api_gateway_v1_event.json | 42 ++++ tests/data/api_gateway_v2_event.json | 28 +++ tests/test_graphql.py | 61 +++++ tests/test_http_handler.py | 136 ++++++++++++ tests/test_schema.py | 70 ++++++ 19 files changed, 1261 insertions(+) create mode 100644 .github/workflows/code_quality.yml create mode 100644 .github/workflows/deploy.yml create mode 100644 .github/workflows/run_tests.yml create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 ariadne_lambda/__init__.py create mode 100644 ariadne_lambda/base.py create mode 100644 ariadne_lambda/graphql.py create mode 100644 ariadne_lambda/http_handler.py create mode 100644 ariadne_lambda/schema.py create mode 100644 pyproject.toml create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/data/api_gateway_v1_event.json create mode 100644 tests/data/api_gateway_v2_event.json create mode 100644 tests/test_graphql.py create mode 100644 tests/test_http_handler.py create mode 100644 tests/test_schema.py diff --git a/.github/workflows/code_quality.yml b/.github/workflows/code_quality.yml new file mode 100644 index 0000000..52a80e3 --- /dev/null +++ b/.github/workflows/code_quality.yml @@ -0,0 +1,33 @@ +name: Code Quality + +on: + push: + branches: + - "*" + tags: + - "*" + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: "3.12" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch + + - name: Build the package + run: | + hatch build + + - name: Run linter + run: | + hatch run ruff check . diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml new file mode 100644 index 0000000..2abcab0 --- /dev/null +++ b/.github/workflows/deploy.yml @@ -0,0 +1,39 @@ +name: Publish on PyPI + +on: + release: + types: + - published + +jobs: + build-and-publish: + name: Build and publish to PyPI + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@master + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.12" + + - name: Install Hatch + run: | + python -m pip install --upgrade pip + pip install hatch + + - name: Build with Hatch + run: | + hatch build + + - name: Publish distribution to PyPI + uses: pypa/gh-action-pypi-publish@v1.4.2 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} + packages_dir: dist + verbose: true + + - name: Clean distribution directory + run: rm -rf dist/* diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml new file mode 100644 index 0000000..f220d9b --- /dev/null +++ b/.github/workflows/run_tests.yml @@ -0,0 +1,33 @@ +name: Run Tests + +on: + push: + branches: + - "*" + tags: + - "*" + +jobs: + build-and-test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: "3.12" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch + + - name: Build the package + run: | + hatch build + + - name: Run tests + run: | + hatch run test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bc8ee0b --- /dev/null +++ b/.gitignore @@ -0,0 +1,61 @@ +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# Serverless directories +.serverless + +.ruff_cache/ +node_modules/ +__pycache__ + +.DS_Store +.env +.pytest_cache +.python-version + +package +*.zip + +# Ignore the .terraform directory +.terraform/ + +# Ignore Terraform lock files +.terraform.lock.hcl + +# Ignore all .tfstate files +*.tfstate +*.tfstate.* + +# Ignore crash log files +crash.log + +# Ignore any .tfvars files that may contain sensitive data +*.tfvars + +# Ignore override files as they are usually used to override resources locally +override.tf +override.tf.json +*_override.tf +*_override.tf.json + +# Ignore CLI configuration files +.terraformrc +terraform.rc + +# pytest-cov +.coverage \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0d51753 --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2024, Mirumee Labs +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..2d567a1 --- /dev/null +++ b/README.md @@ -0,0 +1,74 @@ +# Ariadne AWS Lambda Extension + +This package extends the Ariadne library by adding a GraphQL HTTP handler designed for use in AWS Lambda environments. It enables easy integration of GraphQL services with AWS serverless infrastructure, making it straightforward to deploy GraphQL APIs without worrying about the underlying server management. + +## Introduction + +This project provides an extension to the Ariadne GraphQL library, specifically tailored for deploying GraphQL APIs on AWS Lambda. It simplifies handling GraphQL requests by providing a custom HTTP handler that seamlessly integrates with the AWS Lambda and API Gateway, allowing developers to focus on their GraphQL schema and resolvers instead of server and infrastructure management. + +## Installation + +To install the extension, use pip: + +```bash +pip install ariadne-lambda +``` + +## Quick Start + +Here's a basic example of how to use the extension in your AWS Lambda function: + +```python +from typing import Any + +from ariadne import QueryType, gql, make_executable_schema +from ariadne_lambda.graphql import GraphQLLambda +from asgiref.sync import async_to_sync +from aws_lambda_powertools.utilities.typing import LambdaContext + +type_defs = gql( + """ + type Query { + hello: String! + } +""" +) +query = QueryType() + + +@query.field("hello") +def resolve_hello(_, info): + request = info.context["request"] + user_agent = request.headers.get("user-agent", "guest") + return "Hello, %s!" % user_agent + + +schema = make_executable_schema(type_defs, query) +graphql_app = GraphQLLambda(schema=schema) + + +def graphql_http_handler(event: dict[str, Any], context: LambdaContext): + return async_to_sync(graphql_app)(event, context) +``` + +## Documentation + +For full documentation on Ariadne, visit [Ariadne's Documentation](https://ariadnegraphql.org/docs/). For details on AWS Lambda, refer to the [AWS Lambda Developer Guide](https://docs.aws.amazon.com/lambda/latest/dg/welcome.html). + +## Features + +- Easy integration with AWS Lambda and API Gateway. +- Support for GraphQL queries and mutations. +- Customizable context and error handling. +- Seamless extension of the Ariadne library for serverless applications. + +## Contributing + +We welcome all contributions to Ariadne! If you've found a bug or issue, feel free to use [GitHub issues](https://github.com/mirumee/ariadne-lambda/issues). If you have any questions or feedback, don't hesitate to catch us on [GitHub discussions](https://github.com/mirumee/ariadne/discussions/). + +For guidance and instructions, please see [CONTRIBUTING.md](CONTRIBUTING.md). + +Also make sure you follow [@AriadneGraphQL](https://twitter.com/AriadneGraphQL) on Twitter for latest updates, news and random musings! + +**Crafted with ❤️ by [Mirumee Software](http://mirumee.com)** +hello@mirumee.com diff --git a/ariadne_lambda/__init__.py b/ariadne_lambda/__init__.py new file mode 100644 index 0000000..4dbaeb0 --- /dev/null +++ b/ariadne_lambda/__init__.py @@ -0,0 +1,4 @@ +from ariadne_lambda.graphql import GraphQLLambda +from ariadne_lambda.http_handler import GraphQLAWSAPIHTTPGatewayHandler + +__all__ = ["GraphQLLambda", "GraphQLAWSAPIHTTPGatewayHandler"] diff --git a/ariadne_lambda/base.py b/ariadne_lambda/base.py new file mode 100644 index 0000000..76ebd23 --- /dev/null +++ b/ariadne_lambda/base.py @@ -0,0 +1,55 @@ +from abc import abstractmethod +from inspect import isawaitable +from typing import Any + +from ariadne.asgi.handlers.base import GraphQLHandler +from aws_lambda_powertools.utilities.typing import LambdaContext + + +class GraphQLLambdaHandler(GraphQLHandler): + @abstractmethod + async def handle(self, event: dict, context: LambdaContext): + """An entrypoint for the AWS Lambda connection handler. + + This method is called by Ariadne AWS Lambda GraphQL application. Subclasses + are expected to handle specific event content based on the gateway of invocation + triggeting the lambda function (e.g. API Gateway, ALB, etc.). + + # Required arguments + + `event`: The AWS Lambda event dictionary. + + `context`: The AWS Lambda context object. + """ + raise NotImplementedError( + "Subclasses of GraphQLLambdaHandler must implement the 'handle' method" + ) + + async def get_context_for_request( + self, + request: Any, + data: dict, + ) -> Any: + """Return the context value for the request. + + This method is called by the handler to get the context value for the + request. Subclasses can override it to provide custom context value + based on the request. + + # Required arguments + + `request`: The request object as defined by the 'handle' method. + + `data`: GraphQL data from connection. + """ + if callable(self.context_value): + try: + context = self.context_value(request, data) # type: ignore + except TypeError: # TODO: remove in 0.20 + context = self.context_value(request) # type: ignore + + if isawaitable(context): + context = await context + return context + + return self.context_value or {"request": request} \ No newline at end of file diff --git a/ariadne_lambda/graphql.py b/ariadne_lambda/graphql.py new file mode 100644 index 0000000..ef13358 --- /dev/null +++ b/ariadne_lambda/graphql.py @@ -0,0 +1,65 @@ +from logging import Logger, LoggerAdapter +from typing import Any + +from ariadne.explorer import Explorer, ExplorerGraphiQL +from ariadne.format_error import format_error +from ariadne.types import ( + ContextValue, + ErrorFormatter, + QueryParser, + QueryValidator, + RootValue, + ValidationRules, +) +from graphql import ExecutionContext, GraphQLSchema + +from ariadne_lambda.base import GraphQLLambdaHandler +from ariadne_lambda.http_handler import GraphQLAWSAPIHTTPGatewayHandler + + +class GraphQLLambda: + def __init__( + self, + schema: GraphQLSchema, + *, + context_value: ContextValue | None = None, + root_value: RootValue | None = None, + query_parser: QueryParser | None = None, + query_validator: QueryValidator | None = None, + validation_rules: ValidationRules | None = None, + execute_get_queries: bool = False, + debug: bool = False, + introspection: bool = True, + explorer: Explorer | None = None, + logger: None | str | Logger | LoggerAdapter = None, + error_formatter: ErrorFormatter = format_error, + execution_context_class: type[ExecutionContext] | None = None, + http_handler: GraphQLLambdaHandler | None = None, + ) -> None: + if http_handler: + self.http_handler = http_handler + else: + self.http_handler = GraphQLAWSAPIHTTPGatewayHandler() + + if not explorer: + explorer = ExplorerGraphiQL() + + self.http_handler.configure( + schema, + context_value, + root_value, + query_parser, + query_validator, + validation_rules, + execute_get_queries, + debug, + introspection, + explorer, + logger, + error_formatter, + execution_context_class, + ) + + async def __call__(self, event: dict, context: Any) -> dict: + response = await self.http_handler.handle(event, context) + return response diff --git a/ariadne_lambda/http_handler.py b/ariadne_lambda/http_handler.py new file mode 100644 index 0000000..bb4c453 --- /dev/null +++ b/ariadne_lambda/http_handler.py @@ -0,0 +1,318 @@ +import json +from inspect import isawaitable +from typing import Any + +from ariadne.constants import ( + DATA_TYPE_JSON, + DATA_TYPE_MULTIPART, +) +from ariadne.exceptions import HttpBadRequestError, HttpError +from ariadne.explorer import Explorer +from ariadne.graphql import graphql +from ariadne.types import ( + ContextValue, + ExtensionList, + Extensions, + GraphQLResult, + Middlewares, +) +from aws_lambda_powertools.utilities.typing import LambdaContext +from graphql import DocumentNode, MiddlewareManager + +from ariadne_lambda.base import GraphQLLambdaHandler +from ariadne_lambda.schema import Request, Response + + +class GraphQLAWSAPIHTTPGatewayHandler(GraphQLLambdaHandler): + """Handler for AWS Lambda functions triggered by HTTP requests via API Gateway. + + Designed to process both Query and Mutation operations in a GraphQL schema. + Ideal for serverless architectures, providing a bridge between AWS Lambda and GraphQL. + """ + + def __init__( + self, + extensions: Extensions | None = None, + middleware: Middlewares | None = None, + middleware_manager_class: type[MiddlewareManager] | None = None, + ) -> None: + super().__init__() + + self.extensions = extensions + self.middleware = middleware + self.middleware_manager_class = middleware_manager_class or MiddlewareManager + + async def handle(self, event: dict, context: LambdaContext): + """Processes AWS Lambda event triggered by an API Gateway HTTP request. + + Extracts the HTTP request from the Lambda event, + and delegates to the appropriate handler based on the HTTP method. + """ + request = Request.create_from_event(event) + return (await self.handle_request(request)).render() + + async def handle_request(self, request: Request) -> Response: + """Determines the request type (GET or POST) and routes to the corresponding GraphQL + processor. + Supports executing queries directly from GET requests, or handling + introspection and GraphQL explorers.""" + if request.method == "GET": + if self.execute_get_queries and request.params and request.params.get("query"): + return await self.graphql_http_server(request) + if self.introspection and self.explorer: + # only render explorer when introspection is enabled + return await self.render_explorer(request, self.explorer) + + if request.method == "POST": + return await self.graphql_http_server(request) + + return self.handle_not_allowed_method(request) + + async def render_explorer(self, request: Request, explorer: Explorer) -> Response: + """Return a HTML response with GraphQL explorer. + + # Required arguments: + + `request`: the `Request` instance from Starlette or FastAPI. + + `explorer`: an `Explorer` instance that implements the + `html(request: Request)` method which returns either the `str` with HTML + or `None`. If explorer returns `None`, `405` method not allowed response + is returned instead. + """ + explorer_html = explorer.html(request) + if isawaitable(explorer_html): + explorer_html = await explorer_html + if explorer_html: + return Response(body=explorer_html, headers={"Content-Type": "text/html"}) + + return self.handle_not_allowed_method(request) + + async def graphql_http_server(self, request: Request) -> Response: + """Executes GraphQL queries or mutations based on the POST request's body. + + Parses the request, executes the GraphQL query, and formats the response as JSON. + """ + try: + data = await self.extract_data_from_request(request) + except HttpError as error: + return Response( + status_code=400, + body=error.message or error.status, + headers={"Content-Type": "text/plain"}, + ) + + success, result = await self.execute_graphql_query(request, data) + return await self.create_json_response(request, result, success) + + async def extract_data_from_request(self, request: Request): + """ + Executes a GraphQL query or mutation based on the parsed request data. + + This method processes the GraphQL request, executes the query or mutation, + and returns the results in a JSON-formatted response suitable for + AWS Lambda's HTTP response format. + + Args: + request: A `Request` object containing the parsed HTTP request from API Gateway. + + Returns: + A `Response` object containing the JSON-formatted result of the GraphQL operation. + """ + content_type = request.headers.get("content-type", "") + content_type = content_type.split(";")[0] + + if content_type == DATA_TYPE_JSON: + return await self.extract_data_from_json_request(request) + if content_type == DATA_TYPE_MULTIPART: + return await self.extract_data_from_multipart_request(request) + if ( + request.method == "GET" + and self.execute_get_queries + and request.params + and request.params.get("query") + ): + return self.extract_data_from_get_request(request) + + raise HttpBadRequestError( + "Posted content must be of type {} or {}".format( # noqa: UP032 + DATA_TYPE_JSON, DATA_TYPE_MULTIPART + ) + ) + + async def extract_data_from_json_request(self, request: Request) -> dict: + """ + Parses the JSON body of an HTTP request to extract the GraphQL query data. + + Args: + request: A `Request` object containing the parsed HTTP request from API Gateway. + + Returns: + A dictionary containing the extracted GraphQL query data. + + Raises: + HttpBadRequestError: If the request body is not valid JSON. + """ + try: + return json.loads(request.body) + except (TypeError, ValueError) as ex: + raise HttpBadRequestError("Request body is not a valid JSON") from ex + + async def extract_data_from_multipart_request(self, request: Request): + raise NotImplementedError("Multipart requests are not yet supported in AWS Lambda") + + def extract_data_from_get_request(self, request: Request) -> dict: + """ + Extracts the GraphQL query data from the query string parameters of a GET request. + + Args: + request: A `Request` object containing the parsed HTTP request from API Gateway. + + Returns: + A dictionary containing the GraphQL query, operation name, and variables. + + Raises: + HttpBadRequestError: If the query parameters are missing or invalid. + """ + if not request.params: + raise HttpBadRequestError("Query variables are not valid") + query = request.params["query"].strip() + operation_name = request.params.get("operationName", "").strip() + variables = request.params.get("variables", "").strip() + + clean_variables = None + + if variables: + try: + clean_variables = json.loads(variables) + except (TypeError, ValueError) as ex: + raise HttpBadRequestError("Variables query arg is not a valid JSON") from ex + + return { + "query": query, + "operationName": operation_name or None, + "variables": clean_variables, + } + + async def execute_graphql_query( + self, + request: Any, + data: Any, + *, + context_value: Any = None, + query_document: DocumentNode | None = None, + ) -> GraphQLResult: + """ + Executes the GraphQL query using the provided data and optional context. + + Args: + request: The request object, typically containing metadata and headers. + data: A dictionary containing the query, variables, and operation name. + context_value: Optional context passed to the GraphQL execution. + query_document: An optional pre-parsed GraphQL query document. + + Returns: + A `GraphQLResult` object containing the results of the query execution. + """ + if context_value is None: + context_value = await self.get_context_for_request(request, data) + + extensions = await self.get_extensions_for_request(request, context_value) + # TODO: figure out how to mix those with powertools middleware + # middleware = await self.get_middleware_for_request(request, context_value) + middleware = None + + if self.schema is None: + raise TypeError("schema is not set, call configure method to initialize it") + + if isinstance(request, Request): + require_query = request.method == "GET" + else: + require_query = False + + return await graphql( + self.schema, + data, + context_value=context_value, + root_value=self.root_value, + query_parser=self.query_parser, + query_validator=self.query_validator, + query_document=query_document, + validation_rules=self.validation_rules, + require_query=require_query, + debug=self.debug, + introspection=self.introspection, + logger=self.logger, + error_formatter=self.error_formatter, + extensions=extensions, + middleware=middleware, + middleware_manager_class=self.middleware_manager_class, + execution_context_class=self.execution_context_class, + ) + + async def get_extensions_for_request( + self, request: Any, context: ContextValue | None + ) -> ExtensionList: + """ + Determines the extensions to be used for the current GraphQL request. + + Args: + request: The request object, providing access to request-specific data. + context: Optional context associated with the request. + + Returns: + A list of extensions to be used during the execution of the GraphQL query. + """ + if callable(self.extensions): + extensions = self.extensions(request, context) + if isawaitable(extensions): + extensions = await extensions # type: ignore + return extensions + return self.extensions + + async def create_json_response( + self, + request: Request, # pylint: disable=unused-argument + result: dict, + success: bool, + ) -> Response: + """ + Formats the GraphQL execution result into a JSON response suitable for AWS Lambda. + + Args: + request: The original request object. + result: The result dictionary from GraphQL execution. + success: A boolean flag indicating if the query execution was successful. + + Returns: + A `Response` object containing the JSON-formatted GraphQL result. + """ + status_code = 200 if success else 400 + return Response( + status_code=status_code, + body=json.dumps(result), + headers={"Content-Type": "application/json"}, + ) + + def handle_not_allowed_method(self, request: Request): + """ + Generates a response for HTTP methods not supported by the GraphQL handler. + + This method is typically invoked for HTTP methods other than GET or POST, or + when the GraphQL endpoint is configured to reject certain types of requests. + + Args: + request: The original request object. + + Returns: + A `Response` object indicating the HTTP method is not allowed. + """ + allowed_methods = ["OPTIONS", "POST"] + if self.introspection: + allowed_methods.append("GET") + allow_header = {"Allow": ", ".join(allowed_methods)} + + if request.method == "OPTIONS": + return Response(headers=allow_header) + + return Response(status_code=405, headers=allow_header) diff --git a/ariadne_lambda/schema.py b/ariadne_lambda/schema.py new file mode 100644 index 0000000..a917b4b --- /dev/null +++ b/ariadne_lambda/schema.py @@ -0,0 +1,77 @@ +from typing import Any, Literal + +from pydantic import BaseModel + + +class Request(BaseModel): + event: dict[str, Any] + + path: str + method: Literal["DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"] + + body: str + is_base64_encoded: bool + + headers: dict[str, str] + params: dict[str, str] + + @property + def route_key(self): + return f"{self.method} {self.path}" + + @classmethod + def create_from_event(cls, event: dict[str, Any]) -> "Request": + # this is needed for API Gateway V1 when header keys comes capitalized + # but on API Gateway V2 it comes as lowered + lowered_key_headers = {key.lower(): value for key, value in event["headers"].items()} + request_data = { + "event": event, + "body": "", + "is_base64_encoded": event["isBase64Encoded"], + "headers": lowered_key_headers, + "params": event["queryStringParameters"], + } + + if http_context := event["requestContext"].get("http"): + # Api Gateway V2 + request_data["path"] = http_context["path"] + request_data["method"] = http_context["method"].upper() + + else: + # API Gateway V1 + # Application Load Balancer + request_data["path"] = event["path"] + request_data["method"] = event["httpMethod"].upper() + + if body := event["body"]: + request_data["body"] = body + + if not request_data["params"]: + request_data["params"] = {} + + return cls(**request_data) + + +class Response: + status_code: int + body: str + headers: dict + + def __init__(self, status_code: int = 200, body: str = "", headers: dict | None = None): + self.status_code = status_code + self.body = body + if not headers: + headers = {} + self.headers = headers + + def __iter__(self): + yield "statusCode", self.status_code + yield "body", self.body + yield "headers", self.headers + + def render(self) -> dict: + return { + "statusCode": self.status_code, + "body": self.body, + "headers": self.headers, + } diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..5ecfd44 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,109 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "ariadne-lambda" +version = "0.3.0" +description = 'This package extends the Ariadne library by adding a GraphQL HTTP handler designed for use in AWS Lambda environments.' +readme = "README.md" +requires-python = ">=3.8" +license = "MIT" +keywords = [] +authors = [ + { name = "Mirumee Software", email = "it@mirumee.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "ariadne>=0.23.0,<0.24.0", + "aws-lambda-powertools>=2.35.1,<3.0.0", + "jmespath", + "pydantic>=2.4.0,<3.0.0", +] + +[project.urls] +Documentation = "https://github.com/mirumee/ariadne-lambda#readme" +Issues = "https://github.com/mirumee/ariadne-lambda/issues" +Source = "https://github.com/mirumee/ariadne-lambda" + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "pytest-asyncio", + "ruff", +] + +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11", "3.12"] + +[tool.hatch.envs.types] +dependencies = [ + "mypy>=1.0.0", +] +[tool.hatch.envs.types.scripts] +check = "mypy --install-types --non-interactive {args:ariadne_lambda tests}" + +[tool.coverage.run] +source_pkgs = ["ariadne_lambda", "tests"] +branch = true +parallel = true + + +[tool.coverage.paths] +ariadne_lambda = ["ariadne_lambda"] +tests = ["tests", "*/ariadne-lambda/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + +[tool.ruff] +line-length = 99 +target-version = "py312" + +# rules: https://beta.ruff.rs/docs/rules +# F - pyflakes +# E - pycodestyle +# G - flake8-logging-format +# I - isort +# N - pep8-naming +# Q - flake8-quotes +# UP - pyupgrade +# C90 - mccabe (complexity) +# T20 - flake8-print +# TID - flake8-tidy-imports + +[tool.ruff.lint] +select = ["E", "F", "G", "I", "N", "Q", "UP", "C90", "T20", "TID"] + +[tool.ruff.lint.mccabe] +max-complexity = 10 + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ca469f0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,27 @@ +import json +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +TESTS_DIR = Path(__file__).parent.absolute() + + +def load_data_file(fine_name: str): + with open(TESTS_DIR / "data" / fine_name) as data_file: + return json.load(data_file) + + +@pytest.fixture +def api_gateway_v1_event_payload(): + return load_data_file("api_gateway_v1_event.json") + + +@pytest.fixture +def api_gateway_v2_event_payload(): + return load_data_file("api_gateway_v2_event.json") + + +@pytest.fixture +def lambda_context(): + return MagicMock() diff --git a/tests/data/api_gateway_v1_event.json b/tests/data/api_gateway_v1_event.json new file mode 100644 index 0000000..d5edeca --- /dev/null +++ b/tests/data/api_gateway_v1_event.json @@ -0,0 +1,42 @@ +{ + + "path": "/my-resource", + "resource": "/my-resource", + "httpMethod": "GET", + + "body": null, + "isBase64Encoded": false, + + "headers": { + "Host": "api.example.com", + "User-Agent": "Mozilla/5.0", + "Accept": "application/json" + }, + + "queryStringParameters": { + "param1": "value1", + "param2": "value2" + }, + + "requestContext": { + "resourceId": "resource-id", + "apiId": "api-id", + "resourcePath": "/my-resource", + "httpMethod": "GET", + "requestId": "request-id", + "accountId": "account-id", + "identity": { + "apiKey": null, + "userAgent": "Mozilla/5.0", + "user": null, + "cognitoAuthenticationType": null, + "cognitoAuthenticationProvider": null, + "sourceIp": "127.0.0.1", + "accountId": null + }, + "stage": "prod" + }, + + "pathParameters": null, + "stageVariables": null + } \ No newline at end of file diff --git a/tests/data/api_gateway_v2_event.json b/tests/data/api_gateway_v2_event.json new file mode 100644 index 0000000..42b5b7c --- /dev/null +++ b/tests/data/api_gateway_v2_event.json @@ -0,0 +1,28 @@ +{ + "version": "2.0", + "rawPath": "/my-resource", + "body": "", + "isBase64Encoded": false, + "headers": { + "host": "api.example.com", + "user-agent": "Mozilla/5.0", + "accept": "application/json" + }, + "queryStringParameters": { + "param1": "value1", + "param2": "value2" + }, + "requestContext": { + "http": { + "method": "GET", + "path": "/my-resource", + "protocol": "HTTP/1.1", + "sourceIp": "127.0.0.1" + }, + "routeKey": "GET /my-resource", + "accountId": "account-id", + "stage": "prod" + }, + "routeKey": "GET /my-resource", + "rawQueryString": "param1=value1¶m2=value2" + } \ No newline at end of file diff --git a/tests/test_graphql.py b/tests/test_graphql.py new file mode 100644 index 0000000..edc8025 --- /dev/null +++ b/tests/test_graphql.py @@ -0,0 +1,61 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from graphql import GraphQLSchema + +from ariadne_lambda.graphql import GraphQLLambda +from ariadne_lambda.http_handler import GraphQLAWSAPIHTTPGatewayHandler + + +@pytest.fixture +def schema(): + return GraphQLSchema() + + +@pytest.fixture +def http_handler_mock(): + return MagicMock(spec=GraphQLAWSAPIHTTPGatewayHandler) + + +@pytest.fixture +def event(): + return { + "httpMethod": "POST", + "body": '{"query": "{ testQuery }"}', + "headers": {"Content-Type": "application/json"}, + } + + +@pytest.fixture +def context(): + return {} + + +@pytest.mark.asyncio +@patch("ariadne_lambda.graphql.GraphQLAWSAPIHTTPGatewayHandler") +async def test_graphql_lambda_initialization(http_handler_class_mock, schema): + # Given + http_handler_instance_mock = http_handler_class_mock.return_value + + # When + GraphQLLambda(schema) + + # Then + http_handler_class_mock.assert_called_once() + http_handler_instance_mock.configure.assert_called_once() + + +@pytest.mark.asyncio +async def test_graphql_lambda_call(schema, event, context, http_handler_mock): + # Given + http_handler_mock.handle = AsyncMock( + return_value={"statusCode": 200, "body": '{"data": {"test": "value"}}'} + ) + graphql_lambda = GraphQLLambda(schema, http_handler=http_handler_mock) + + # When + response = await graphql_lambda(event, context) + + # Then + http_handler_mock.handle.assert_called_once_with(event, context) + assert response == {"statusCode": 200, "body": '{"data": {"test": "value"}}'} diff --git a/tests/test_http_handler.py b/tests/test_http_handler.py new file mode 100644 index 0000000..fc1fa8a --- /dev/null +++ b/tests/test_http_handler.py @@ -0,0 +1,136 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from ariadne_lambda.http_handler import ( + GraphQLAWSAPIHTTPGatewayHandler, + Request, + Response, +) + + +@pytest.fixture +def handler(): + handler = GraphQLAWSAPIHTTPGatewayHandler() + handler.schema = MagicMock() + handler.execute_graphql_query = AsyncMock() + return handler + + +@pytest.mark.asyncio +async def test_handle(handler, api_gateway_v1_event_payload, lambda_context): + # Given + mocked_handler_request = AsyncMock(return_value=Response(body="response", status_code=200)) + handler.handle_request = mocked_handler_request + + # When + result = await handler.handle(api_gateway_v1_event_payload, lambda_context) + + # Then + mocked_handler_request.assert_called_once() + assert result["statusCode"] == 200 + assert result["body"] == "response" + + +@pytest.mark.asyncio +async def test_handle_request_post_graphql(handler, api_gateway_v1_event_payload): + # Given + request = Request.create_from_event(api_gateway_v1_event_payload) + request.method = "POST" + mocked_graphql_http_server = AsyncMock(return_value=Response(body="response", status_code=200)) + handler.graphql_http_server = mocked_graphql_http_server + + # When + response = await handler.handle_request(request) + + # Then + mocked_graphql_http_server.assert_called_once() + assert response.status_code == 200 + assert "response" == response.body + + +@pytest.mark.asyncio +async def test_handle_request_get_explorer(handler, api_gateway_v1_event_payload): + # Given + handler.introspection = True + handler.explorer = True + mocked_render_explorer = AsyncMock(return_value=Response(body="response", status_code=200)) + handler.render_explorer = mocked_render_explorer + request = Request.create_from_event(api_gateway_v1_event_payload) + + # When + response = await handler.handle_request(request) + + # Then + mocked_render_explorer.assert_called_once() + assert response.status_code == 200 + assert "response" == response.body + + +@pytest.mark.asyncio +async def test_handle_request_method_not_allowed(handler, api_gateway_v1_event_payload): + # Given + request = Request.create_from_event(api_gateway_v1_event_payload) + request.method = "PUT" + mocked_handle_not_allowed_method = MagicMock( + return_value=Response(body="Method Not Allowed", status_code=405) + ) + handler.handle_not_allowed_method = mocked_handle_not_allowed_method + + # Then + response = await handler.handle_request(request) + + # Then + mocked_handle_not_allowed_method.assert_called_once() + assert response.status_code == 405 + assert "Method Not Allowed" in response.body + + +@pytest.mark.asyncio +async def test_handle_request_get_graphql_http_server(handler, api_gateway_v1_event_payload): + # Given + handler.execute_get_queries = True + mocked_graphql_http_server = AsyncMock(return_value=Response(body="response", status_code=200)) + handler.graphql_http_server = mocked_graphql_http_server + request = Request.create_from_event(api_gateway_v1_event_payload) + request.params = {"query": "hello"} + + # When + response = await handler.handle_request(request) + + # Then + mocked_graphql_http_server.assert_called_once() + assert response.status_code == 200 + assert "response" == response.body + + +@pytest.mark.asyncio +async def test_render_explorer_enabled(handler, api_gateway_v1_event_payload): + # Given + request = Request.create_from_event(api_gateway_v1_event_payload) + mocked_explorer = MagicMock() + mocked_explorer.html.return_value = "
GraphQL Explorer
" + handler.explorer = mocked_explorer + + # When + response = await handler.render_explorer(request, handler.explorer) + + # Then + assert response.status_code == 200 + assert "
GraphQL Explorer
" in response.body + assert response.headers["Content-Type"] == "text/html" + + +@pytest.mark.asyncio +async def test_render_explorer_no_explorer(handler, api_gateway_v1_event_payload): + # Given + request = Request.create_from_event(api_gateway_v1_event_payload) + mocked_explorer = MagicMock() + mocked_explorer.html.return_value = None + handler.explorer = mocked_explorer + + # When + response = await handler.render_explorer(request, handler.explorer) + + # Then + assert response.status_code == 405 diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..181fbe9 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,70 @@ +from ariadne_lambda.schema import Request, Response + + +def test_api_v1_event(api_gateway_v1_event_payload): + lowered_keys_headers = { + key.lower(): value for key, value in api_gateway_v1_event_payload["headers"].items() + } + request = Request.create_from_event(api_gateway_v1_event_payload) + assert request.method == api_gateway_v1_event_payload["httpMethod"] + assert request.path == api_gateway_v1_event_payload["path"] + assert request.body == "" + assert request.is_base64_encoded is False + assert request.headers == lowered_keys_headers + assert request.params == api_gateway_v1_event_payload["queryStringParameters"] + + +def test_api_v2_event(api_gateway_v2_event_payload): + request = Request.create_from_event(api_gateway_v2_event_payload) + assert request.method == api_gateway_v2_event_payload["requestContext"]["http"]["method"] + assert request.path == api_gateway_v2_event_payload["requestContext"]["http"]["path"] + assert request.body == "" + assert request.is_base64_encoded is False + assert request.headers == api_gateway_v2_event_payload["headers"] + assert request.params == api_gateway_v2_event_payload["queryStringParameters"] + + +def test_response_initialization(): + # When + response = Response(status_code=200, body="OK", headers={"Content-Type": "application/json"}) + + # Then + assert response.status_code == 200 + assert response.body == "OK" + assert response.headers == {"Content-Type": "application/json"} + + +def test_response_default_values(): + # When + response = Response() + + # Then + assert response.status_code == 200 + assert response.body == "" + assert response.headers == {} + + +def test_response_iter(): + # When + response = Response(status_code=404, body="Not Found", headers={"X-Custom-Header": "value"}) + response_dict = dict(response) + + # Then + assert response_dict == { + "statusCode": 404, + "body": "Not Found", + "headers": {"X-Custom-Header": "value"}, + } + + +def test_response_render(): + # When + response = Response(status_code=500, body="Error", headers={"Content-Type": "text/plain"}) + response_rendered = response.render() + + # Then + assert response_rendered == { + "statusCode": 500, + "body": "Error", + "headers": {"Content-Type": "text/plain"}, + }