Skip to content

Commit

Permalink
Merge pull request #58 from kizill/main
Browse files Browse the repository at this point in the history
  • Loading branch information
ponytailer authored Jul 29, 2024
2 parents ae4a084 + d16fb84 commit 1bb52c8
Show file tree
Hide file tree
Showing 11 changed files with 151 additions and 169 deletions.
9 changes: 6 additions & 3 deletions pydantic_client/clients/abstract_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Any, Dict, Tuple
from typing import Any, Dict, Tuple, Callable

from pydantic_client.schema.http_request import HttpRequest


class AbstractClient:
@staticmethod
def data_encoder(x):
return x

def do_request(self, request: HttpRequest) -> Any:
raise NotImplementedError
Expand All @@ -12,8 +15,8 @@ def do_request(self, request: HttpRequest) -> Any:
def parse_request(request: HttpRequest) -> Tuple[Dict, Dict]:
if request.data:
data = request.data
json = {}
json = None
else:
data = {}
data = None
json = request.json_body
return data, json
4 changes: 2 additions & 2 deletions pydantic_client/clients/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class AIOHttpClient(AbstractClient):
runner_class: Type[Proxy] = AsyncClientProxy
runner_class: Proxy = AsyncClientProxy

def __init__(self, base_url: str, headers: Dict[str, Any] = None):
self.base_url = base_url.rstrip("/")
Expand All @@ -31,6 +31,6 @@ async def do_request(self, request: HttpRequest) -> Any:
async with req as resp:
resp.raise_for_status()
if resp.status == 200:
return resp.json()
return await resp.json()
except BaseException as e:
raise e
9 changes: 4 additions & 5 deletions pydantic_client/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ def _get_url(self, args) -> str:
)
return url_result.path + "?" + querystring if querystring else url_result.path

@staticmethod
def dict_to_body(func_args: Dict[str, Any]) -> Dict:
def dict_to_body(self, func_args: Dict[str, Any]) -> Dict:
keys = list(func_args.keys())
if len(keys) == 1:
return func_args[keys[0]]
return self.instance.data_encoder(func_args[keys[0]])
return {}

def get_request(self, *args, **kwargs):
Expand Down Expand Up @@ -81,7 +80,7 @@ def __call__(self, *args, **kwargs):
request = self.get_request(*args, **kwargs)
raw_response = self.instance.do_request(request)
if self.method_info.response_type:
return self.method_info.response_type(**raw_response)
return self.method_info.response_type(val=raw_response).val
return raw_response


Expand All @@ -91,5 +90,5 @@ async def __call__(self, *args, **kwargs):
request = self.get_request(*args, **kwargs)
raw_response = await self.instance.do_request(request)
if self.method_info.response_type:
return self.method_info.response_type(**raw_response)
return self.method_info.response_type(val=raw_response).val
return raw_response
16 changes: 12 additions & 4 deletions pydantic_client/utility.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import inspect
from typing import Any, Callable, Dict, Type
from typing import Any, Callable, Dict, Optional, Type

from pydantic import BaseModel

from pydantic_client.schema.method_info import MethodInfo


def create_response_type(annotations: Dict[str, Any]) -> Type:
return annotations.pop("return", None)
def create_response_type(annotations: Dict[str, Any]) -> Optional[Type]:
response_type = annotations.pop("return", None)
if response_type is None:
return response_type

class T(BaseModel):
val: response_type

return T


def parse_func(
Expand All @@ -15,7 +24,6 @@ def parse_func(
method: str,
form_body: bool,
):

spec = inspect.getfullargspec(func)
annotations = spec.annotations.copy()
return MethodInfo(
Expand Down
12 changes: 9 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

autoflake==2.3.1
black==24.4.0
fastapi==0.110.3
flake8==6.0.0
isort==5.13.2
mypy==1.8.0
pre-commit==3.5.0
pytest==8.1.1
pytest==8.2.0
pytest-cov==5.0.0
pytest-asyncio
aiohttp
httpx[http2]
aiohttp==3.9.5
httpx[http2]==0.27.0

pydantic==2.5.2
python-multipart==0.0.9
requests==2.31.0
uvicorn==0.29.0
4 changes: 4 additions & 0 deletions tests/book.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@
class Book(BaseModel):
name: str
age: int


def get_the_book() -> Book:
return Book(name="name", age=1)
52 changes: 41 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@


class R(RequestsClient):
def __init__(self):
super().__init__("http://localhost")

@get("/books/{book_id}?query={query}")
def get_book(self, book_id: int, query: str) -> Book:
Expand All @@ -19,6 +17,10 @@ def get_book(self, book_id: int, query: str) -> Book:
def get_raw_book(self, book_id: int):
...

@get("/books/{book_id}/num_pages")
def get_book_num_pages(self, book_id: int) -> int:
...

@post("/books", form_body=True)
def create_book_form(self, book: Book) -> Book:
""" will post the form with book"""
Expand All @@ -39,9 +41,6 @@ def patch_book(self, book_id: int, book: Book) -> Book:


class AsyncR(AIOHttpClient):
def __init__(self):
super().__init__("http://localhost")

@get("/books/{book_id}?query={query}")
async def get_book(self, book_id: int, query: str) -> Book:
...
Expand All @@ -50,6 +49,10 @@ async def get_book(self, book_id: int, query: str) -> Book:
async def get_raw_book(self, book_id: int):
...

@get("/books/{book_id}/num_pages")
def get_book_num_pages(self, book_id: int) -> int:
...

@post("/books", form_body=True)
async def create_book_form(self, book: Book) -> Book:
""" will post the form with book"""
Expand All @@ -70,14 +73,41 @@ async def patch_book(self, book_id: int, book: Book) -> Book:


class HttpxR(HttpxClient, AsyncR):
def __init__(self):
super().__init__("http://localhost")
...


@pytest.fixture(scope="session")
def fastapi_server_url() -> str:
from uvicorn import run
from .fastapi_service import app
from threading import Thread
host = "localhost"
port = 12098 # TODO: add port availability check
def start_server():
run(app, host=host, port=port)

thread = Thread(target=start_server, daemon=True)
thread.start()

for _ in range(10):
assert thread.is_alive(), "Fastapi thread died"
try:
url = f"http://{host}:{port}"
import requests
book = requests.get(f"{url}/books/5")
book.raise_for_status()
return "http://localhost:12098/"
except Exception:
from time import sleep
sleep(1)

raise Exception("Can't start fastapi server in 10 seconds")


@pytest.fixture
def clients():
def clients(fastapi_server_url):
yield (
R(),
AsyncR(),
HttpxR()
R(fastapi_server_url),
AsyncR(fastapi_server_url),
HttpxR(fastapi_server_url)
)
42 changes: 42 additions & 0 deletions tests/fastapi_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing_extensions import Annotated

from fastapi import FastAPI, Form

from tests.book import Book, get_the_book

app = FastAPI()


@app.get("/books/{book_id}?query={query}")
async def get() -> Book:
return get_the_book()


@app.get("/books/{book_id}")
def get_raw_book(book_id: int):
return get_the_book()


@app.get("/books/{book_id}/num_pages")
def get_book_num_pages(book_id: int) -> int:
return 42


@app.post("/books")
def create_book_form(name: Annotated[str, Form()], age: Annotated[int, Form()]) -> Book:
return Book(name=name, age=age)


@app.put("/books/{book_id}")
def change_book(book_id: int, book: Book) -> Book:
return book


@app.delete("/books/{book_id}")
def delete_book(book_id: int) -> Book:
return get_the_book()


@app.patch("/books/{book_id}")
def patch_book(book_id: int, book: Book) -> Book:
return book
100 changes: 0 additions & 100 deletions tests/helpers.py

This file was deleted.

Loading

0 comments on commit 1bb52c8

Please sign in to comment.