diff --git a/example/__init__.py b/example/__init__.py index 77a7988..370eaff 100644 --- a/example/__init__.py +++ b/example/__init__.py @@ -1,6 +1,15 @@ from ariadne.asgi import GraphQL +from ariadne.asgi.handlers import GraphQLTransportWSHandler +from starlette.middleware.cors import CORSMiddleware from .schema import schema -app = GraphQL(schema, debug=True) +app = CORSMiddleware( + GraphQL( + schema, + debug=True, + websocket_handler=GraphQLTransportWSHandler(), + ), + allow_origins="*", +) diff --git a/example/database.py b/example/database.py index 7e2b55f..ab3f38e 100644 --- a/example/database.py +++ b/example/database.py @@ -5,11 +5,14 @@ class DataBase: _data: dict[str, dict[int, Any]] - _id: int + _id: dict[str, int] - def __init__(self, data: dict[str, dict[int, Any]], counter: int = 0): + def __init__(self, data: dict[str, dict[int, Any]]): self._data = data - self._id = counter + self._id = {} + + for table_name, table_data in data.items(): + self._id[table_name] = max(list(table_data)) async def get_row(self, table: str, **kwargs) -> Any: assert kwargs, "use kwargs to filter" @@ -40,8 +43,8 @@ async def get_all(self, table: str, **kwargs) -> list[Any]: async def insert(self, table: str, obj: Any): assert obj.id is None, "obj.id attr must be None" - self._id += 1 - obj.id = self._id + self._id[table] += 1 + obj.id = self._id[table] self._data[table][obj.id] = obj @@ -52,4 +55,4 @@ async def delete(self, table: str, id: int): self._data[table].pop(id, None) -db = DataBase(get_data(), 1000) +db = DataBase(get_data()) diff --git a/example/fixture.py b/example/fixture.py index 6958161..001534d 100644 --- a/example/fixture.py +++ b/example/fixture.py @@ -2,33 +2,12 @@ from .models.category import Category from .models.group import Group +from .models.post import Post from .models.user import User def get_data() -> dict[str, dict[int, Any]]: return { - "categories": { - 1: Category( - id=1, - name="First category", - parent_id=None, - ), - 2: Category( - id=2, - name="Second category", - parent_id=None, - ), - 3: Category( - id=3, - name="Child category", - parent_id=1, - ), - 4: Category( - id=4, - name="Other child category", - parent_id=1, - ), - }, "groups": { 1: Group( id=1, @@ -45,26 +24,70 @@ def get_data() -> dict[str, dict[int, Any]]: 1: User( id=1, username="JohnDoe", - email="johndoe@example.com", group_id=1, ), 2: User( id=2, username="Alice", - email="alice@example.com", group_id=1, ), 3: User( id=3, username="Bob", - email="b0b@example.com", group_id=2, ), 4: User( id=4, username="Mia", - email="mia@example.com", group_id=2, ), }, + "categories": { + 1: Category( + id=1, + name="First category", + parent_id=None, + ), + 2: Category( + id=2, + name="Second category", + parent_id=None, + ), + 3: Category( + id=3, + name="Child category", + parent_id=1, + ), + 4: Category( + id=4, + name="Other child category", + parent_id=1, + ), + }, + "posts": { + 1: Post( + id=1, + message="Lorem ipsum", + category_id=1, + poster_id=1, + ), + 2: Post( + id=2, + message="Dolor met", + category_id=2, + poster_id=2, + ), + 3: Post( + id=3, + message="Sit amet", + category_id=3, + poster_id=3, + ), + 4: Post( + id=4, + message="Elit", + category_id=4, + poster_id=4, + ), + }, } diff --git a/example/models/post.py b/example/models/post.py new file mode 100644 index 0000000..f09ecae --- /dev/null +++ b/example/models/post.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass + + +@dataclass +class Post: + id: int + message: str + category_id: int | None + poster_id: int | None diff --git a/example/models/user.py b/example/models/user.py index 44381cf..5bcfbf4 100644 --- a/example/models/user.py +++ b/example/models/user.py @@ -5,5 +5,4 @@ class User: id: int username: str - email: str group_id: int diff --git a/example/queries/__init__.py b/example/queries/__init__.py index 546fb50..1366e10 100644 --- a/example/queries/__init__.py +++ b/example/queries/__init__.py @@ -1,10 +1,12 @@ from typing import Any -from . import categories, groups, hello, users +from . import calendar, categories, groups, hello, posts, users queries: Any = [ + calendar.Query, categories.Query, groups.Query, hello.Query, + posts.Query, users.Query, ] diff --git a/example/queries/calendar.py b/example/queries/calendar.py new file mode 100644 index 0000000..e1e8638 --- /dev/null +++ b/example/queries/calendar.py @@ -0,0 +1,20 @@ +from datetime import datetime + +from ariadne_graphql_modules import GraphQLObject +from graphql import GraphQLResolveInfo + +from ..database import db +from ..scalars.date import DateScalar +from ..scalars.datetime import DateTimeScalar + + +class Query(GraphQLObject): + @GraphQLObject.field() + @staticmethod + async def date(obj, info: GraphQLResolveInfo) -> DateScalar: + return DateScalar(datetime.now().date()) + + @GraphQLObject.field() + @staticmethod + async def datetime(obj, info: GraphQLResolveInfo) -> DateTimeScalar: + return DateTimeScalar(datetime.now()) diff --git a/example/queries/categories.py b/example/queries/categories.py index bbff01f..8af3e8a 100644 --- a/example/queries/categories.py +++ b/example/queries/categories.py @@ -1,4 +1,4 @@ -from ariadne_graphql_modules import GraphQLObject +from ariadne_graphql_modules import GraphQLID, GraphQLObject from graphql import GraphQLResolveInfo from ..database import db @@ -13,7 +13,9 @@ async def categories(obj, info: GraphQLResolveInfo) -> list[CategoryType]: @GraphQLObject.field() @staticmethod - async def category(obj, info: GraphQLResolveInfo, id: str) -> CategoryType | None: + async def category( + obj, info: GraphQLResolveInfo, id: GraphQLID + ) -> CategoryType | None: try: id_int = int(id) except (TypeError, ValueError): diff --git a/example/queries/groups.py b/example/queries/groups.py index bd28d4d..c76d51c 100644 --- a/example/queries/groups.py +++ b/example/queries/groups.py @@ -1,4 +1,4 @@ -from ariadne_graphql_modules import GraphQLObject +from ariadne_graphql_modules import GraphQLID, GraphQLObject from graphql import GraphQLResolveInfo from ..database import db @@ -22,7 +22,7 @@ async def groups( @GraphQLObject.field() @staticmethod - async def group(obj, info: GraphQLResolveInfo, id: str) -> GroupType | None: + async def group(obj, info: GraphQLResolveInfo, id: GraphQLID) -> GroupType | None: try: id_int = int(id) except (TypeError, ValueError): diff --git a/example/queries/posts.py b/example/queries/posts.py new file mode 100644 index 0000000..f57de2a --- /dev/null +++ b/example/queries/posts.py @@ -0,0 +1,25 @@ +from typing import Optional + +from ariadne_graphql_modules import GraphQLID, GraphQLObject +from graphql import GraphQLResolveInfo + +from ..database import db +from ..models.post import Post +from ..types.post import PostType + + +class Query(GraphQLObject): + @GraphQLObject.field(graphql_type=list[PostType]) + @staticmethod + async def posts(obj, info: GraphQLResolveInfo) -> list[Post]: + return await db.get_all("posts") + + @GraphQLObject.field(graphql_type=Optional[PostType]) + @staticmethod + async def post(obj, info: GraphQLResolveInfo, id: GraphQLID) -> Post | None: + try: + id_int = int(id) + except (TypeError, ValueError): + return None + + return await db.get_row("posts", id=id_int) diff --git a/example/queries/users.py b/example/queries/users.py index e917a30..efb3816 100644 --- a/example/queries/users.py +++ b/example/queries/users.py @@ -1,4 +1,4 @@ -from ariadne_graphql_modules import GraphQLObject +from ariadne_graphql_modules import GraphQLID, GraphQLObject from graphql import GraphQLResolveInfo from ..database import db @@ -10,3 +10,13 @@ class Query(GraphQLObject): @staticmethod async def users(obj, info: GraphQLResolveInfo) -> list[UserType]: return await db.get_all("users") + + @GraphQLObject.field() + @staticmethod + async def user(obj, info: GraphQLResolveInfo, id: GraphQLID) -> UserType | None: + try: + id_int = int(id) + except (TypeError, ValueError): + return None + + return await db.get_row("users", id=id_int) diff --git a/example/scalars/__init__.py b/example/scalars/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/example/scalars/date.py b/example/scalars/date.py new file mode 100644 index 0000000..28ff68f --- /dev/null +++ b/example/scalars/date.py @@ -0,0 +1,17 @@ +from datetime import date, datetime +from typing import Union, cast + +from ariadne_graphql_modules import GraphQLScalar + + +class DateScalar(GraphQLScalar): + @classmethod + def serialize(cls, value: Union["DateScalar", date]) -> str: + if isinstance(value, cls): + value = cast(date, value.unwrap()) + + return value.strftime("%Y-%m-%d") + + @classmethod + def parse_value(cls, value: str) -> date: + return datetime.strptime("%Y-%m-%d").date() diff --git a/example/scalars/datetime.py b/example/scalars/datetime.py new file mode 100644 index 0000000..0ef14a7 --- /dev/null +++ b/example/scalars/datetime.py @@ -0,0 +1,17 @@ +from datetime import datetime +from typing import Union, cast + +from ariadne_graphql_modules import GraphQLScalar + + +class DateTimeScalar(GraphQLScalar): + @classmethod + def serialize(cls, value: Union["DateTimeScalar", datetime]) -> str: + if isinstance(value, cls): + value = cast(datetime, value.unwrap()) + + return value.isoformat() + + @classmethod + def parse_value(cls, value: str) -> datetime: + return datetime.fromisoformat(value) diff --git a/example/schema.py b/example/schema.py index e619b84..6d7a5ff 100644 --- a/example/schema.py +++ b/example/schema.py @@ -1,6 +1,11 @@ from ariadne_graphql_modules import make_executable_schema from .queries import queries +from .subscriptions import subscriptions -schema = make_executable_schema(queries, convert_names_case=True) +schema = make_executable_schema( + queries, + subscriptions, + convert_names_case=True, +) diff --git a/example/subscriptions/__init__.py b/example/subscriptions/__init__.py new file mode 100644 index 0000000..d3988a1 --- /dev/null +++ b/example/subscriptions/__init__.py @@ -0,0 +1,5 @@ +from typing import Any + +from . import events + +subscriptions: Any = [events.Subscription] diff --git a/example/subscriptions/events.py b/example/subscriptions/events.py new file mode 100644 index 0000000..910c309 --- /dev/null +++ b/example/subscriptions/events.py @@ -0,0 +1,26 @@ +import random +from asyncio import sleep +from typing import AsyncGenerator +from datetime import datetime + +from ariadne_graphql_modules import GraphQLSubscription +from graphql import GraphQLResolveInfo + +from ..types.event import EventType + + +class Subscription(GraphQLSubscription): + event: EventType + + @GraphQLSubscription.source("event") + async def source_event(obj, info: GraphQLResolveInfo) -> AsyncGenerator[int, None]: + i = 0 + + while True: + i += 1 + yield i + await sleep(float(random.randint(1, 50)) / 10) + + @GraphQLSubscription.resolver("event") + async def resolve_event(obj: int, info: GraphQLResolveInfo) -> dict: + return {"id": obj, "payload": datetime.now()} diff --git a/example/types/category.py b/example/types/category.py index 24d2731..f9d8128 100644 --- a/example/types/category.py +++ b/example/types/category.py @@ -1,9 +1,15 @@ +from typing import TYPE_CHECKING + from ariadne_graphql_modules import GraphQLObject from ariadne import gql from graphql import GraphQLResolveInfo from ..database import db from ..models.category import Category +from ..models.post import Post + +if TYPE_CHECKING: + from .post import PostType class CategoryType(GraphQLObject): @@ -14,6 +20,7 @@ class CategoryType(GraphQLObject): name: String! parent: Category children: [Category!]! + posts: [Post!]! } """ ) @@ -34,3 +41,8 @@ async def resolve_children( obj: Category, info: GraphQLResolveInfo ) -> list[Category]: return await db.get_all("categories", parent_id=obj.id) + + @GraphQLObject.resolver("posts", list["PostType"]) + @staticmethod + async def resolve_posts(obj: Category, info: GraphQLResolveInfo) -> list[Post]: + return await db.get_all("posts", category_id=obj.id) diff --git a/example/types/event.py b/example/types/event.py new file mode 100644 index 0000000..e3a4c7a --- /dev/null +++ b/example/types/event.py @@ -0,0 +1,10 @@ +from ariadne_graphql_modules import GraphQLID, GraphQLObject + +from ..scalars.datetime import DateTimeScalar + + +class EventType(GraphQLObject): + id: GraphQLID + message: DateTimeScalar + + __aliases__ = {"message": "payload"} diff --git a/example/types/post.py b/example/types/post.py new file mode 100644 index 0000000..52ef02a --- /dev/null +++ b/example/types/post.py @@ -0,0 +1,44 @@ +from typing import TYPE_CHECKING, Annotated, Optional + +from ariadne_graphql_modules import GraphQLObject, deferred +from ariadne import gql +from graphql import GraphQLResolveInfo + +from ..database import db +from ..models.category import Category +from ..models.post import Post +from ..models.user import User + +from .category import CategoryType + +if TYPE_CHECKING: + from .user import UserType + + +class PostType(GraphQLObject): + __schema__ = gql( + """ + type Post { + id: ID! + content: String! + category: Category + poster: User + } + """ + ) + __aliases__ = {"content": "message"} + + @GraphQLObject.resolver("category", CategoryType) + @staticmethod + async def resolve_category(obj: Post, info: GraphQLResolveInfo) -> Category: + return await db.get_row("categories", id=obj.category_id) + + @GraphQLObject.resolver( + "poster", Optional[Annotated["UserType", deferred(".user")]] + ) + @staticmethod + async def resolve_poster(obj: Post, info: GraphQLResolveInfo) -> User | None: + if not obj.poster_id: + return None + + return await db.get_row("users", id=obj.poster_id) diff --git a/example/types/user.py b/example/types/user.py index b719fee..0976d42 100644 --- a/example/types/user.py +++ b/example/types/user.py @@ -5,7 +5,9 @@ from ..database import db from ..models.group import Group +from ..models.post import Post from ..models.user import User +from .post import PostType if TYPE_CHECKING: from .group import GroupType @@ -14,10 +16,15 @@ class UserType(GraphQLObject): id: GraphQLID username: str - email: str group: Annotated["GroupType", deferred(".group")] + posts: list[PostType] @GraphQLObject.resolver("group") @staticmethod async def resolve_group(user: User, info: GraphQLResolveInfo) -> Group: return await db.get_row("groups", id=user.group_id) + + @GraphQLObject.resolver("posts") + @staticmethod + async def resolve_posts(user: User, info: GraphQLResolveInfo) -> list[Post]: + return await db.get_all("posts", poster_id=user.id) diff --git a/tests/test_query.py b/tests/test_query.py index d7172f3..d348f34 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -67,6 +67,20 @@ async def test_query_groups_field_member_arg(exec_query): } +@pytest.mark.asyncio +async def test_query_user(exec_query): + result = await exec_query('{ user(id: "2") { id username group { name } } }') + assert result.data == { + "user": { + "id": "2", + "username": "Alice", + "group": { + "name": "Admins", + }, + }, + } + + @pytest.mark.asyncio async def test_query_users(exec_query): result = await exec_query("{ users { id username group { name } } }") @@ -102,3 +116,74 @@ async def test_query_users(exec_query): }, ], } + + +@pytest.mark.asyncio +async def test_query_categories_field(exec_query): + result = await exec_query( + "{ categories { id name children { id name } posts { id content } } }" + ) + assert result.data == { + "categories": [ + { + "id": "1", + "name": "First category", + "children": [ + { + "id": "3", + "name": "Child category", + }, + { + "id": "4", + "name": "Other child category", + }, + ], + "posts": [ + { + "id": "1", + "content": "Lorem ipsum", + }, + ], + }, + { + "id": "2", + "name": "Second category", + "children": [], + "posts": [ + { + "id": "2", + "content": "Dolor met", + }, + ], + }, + ], + } + + +@pytest.mark.asyncio +async def test_query_category_field(exec_query): + result = await exec_query( + '{ category(id: "1") { id name children { id name } posts { id content } } }' + ) + assert result.data == { + "category": { + "id": "1", + "name": "First category", + "children": [ + { + "id": "3", + "name": "Child category", + }, + { + "id": "4", + "name": "Other child category", + }, + ], + "posts": [ + { + "id": "1", + "content": "Lorem ipsum", + }, + ], + }, + }