Skip to content

Commit

Permalink
Client: Don’t validate API responses
Browse files Browse the repository at this point in the history
It’s better to break when an accessed field doesn’t exist or contains
unexpected contents than to break when an otherwise unused field can’t
be validated because the client and server use different versions of the
API model.

Fixes: #941

Signed-off-by: Nils Philippsen <[email protected]>
  • Loading branch information
nphilipp committed Sep 28, 2023
1 parent 23d14e6 commit 3b44216
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 160 deletions.
100 changes: 45 additions & 55 deletions duffy/client/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,8 @@
from typing import Generator

import yaml
from pydantic import BaseModel

from ..api_models import (
PoolModel,
PoolResult,
PoolResultCollection,
SessionModel,
SessionResult,
SessionResultCollection,
)
from .main import DuffyAPIErrorModel
from .main import DuffyAPIErrorModel, JSONValue


class DuffyFormatter:
Expand All @@ -26,31 +17,27 @@ def __init_subclass__(cls, format, **kwargs):
def new_for_format(cls, format, *args, **kwargs):
return cls._subclasses_for_format[format](*args, **kwargs)

@staticmethod
def result_as_compatible_dict(result: BaseModel) -> dict:
return json.loads(result.model_dump_json())

def format(self, result: BaseModel) -> str:
def format(self, result: JSONValue) -> str:
raise NotImplementedError()


class DuffyJSONFormatter(DuffyFormatter, format="json"):
def format(self, result: BaseModel) -> str:
return result.model_dump_json(indent=2)
def format(self, result: JSONValue) -> str:
return json.dumps(result)


class DuffyYAMLFormatter(DuffyFormatter, format="yaml"):
def format(self, result: BaseModel) -> str:
return yaml.dump(self.result_as_compatible_dict(result))
def format(self, result: JSONValue) -> str:
return yaml.dump(result)


class DuffyFlatFormatter(DuffyFormatter, format="flat"):
model_to_flattener = {
DuffyAPIErrorModel: "flatten_api_error",
PoolResult: "flatten_pool_result",
PoolResultCollection: "flatten_pools_result",
SessionResult: "flatten_session_result",
SessionResultCollection: "flatten_sessions_result",
field_name_to_flattener = {
"error": "flatten_api_error",
"pool": "flatten_pool_result",
"pools": "flatten_pools_result",
"session": "flatten_session_result",
"sessions": "flatten_sessions_result",
}

@staticmethod
Expand All @@ -71,52 +58,55 @@ def format_key_value(key, value):
def flatten_api_error(self, api_error: DuffyAPIErrorModel) -> Generator[str, None, None]:
yield self.format_key_value("error", api_error.error.detail)

def flatten_pool(self, pool: PoolModel) -> Generator[str, None, None]:
fields = {"pool_name": pool.name, "fill_level": pool.fill_level}
if hasattr(pool, "levels"):
def flatten_pool(self, pool: JSONValue) -> Generator[str, None, None]:
fields = {
"pool_name": pool["name"],
"fill_level": pool.get("fill-level", pool.get("fill_level")),
}
if "levels" in pool:
fields.update(
{
"levels_provisioning": pool.levels.provisioning,
"levels_ready": pool.levels.ready,
"levels_contextualizing": pool.levels.contextualizing,
"levels_deployed": pool.levels.deployed,
"levels_deprovisioning": pool.levels.deprovisioning,
"levels_provisioning": pool["levels"]["provisioning"],
"levels_ready": pool["levels"]["ready"],
"levels_contextualizing": pool["levels"]["contextualizing"],
"levels_deployed": pool["levels"]["deployed"],
"levels_deprovisioning": pool["levels"]["deprovisioning"],
}
)
yield " ".join(self.format_key_value(key, value) for key, value in fields.items())

def flatten_pool_result(self, result: PoolResult) -> Generator[str, None, None]:
yield from self.flatten_pool(result.pool)
def flatten_pool_result(self, result: JSONValue) -> Generator[str, None, None]:
yield from self.flatten_pool(result["pool"])

def flatten_pools_result(self, result: PoolResultCollection) -> Generator[str, None, None]:
for pool in result.pools:
def flatten_pools_result(self, result: JSONValue) -> Generator[str, None, None]:
for pool in result["pools"]:
yield from self.flatten_pool(pool)

def flatten_session(self, session: SessionModel) -> Generator[str, None, None]:
for node in sorted(session.nodes, key=lambda node: (node.pool, node.hostname, node.ipaddr)):
def flatten_session(self, session: JSONValue) -> Generator[str, None, None]:
for node in sorted(
session["nodes"], key=lambda node: (node["pool"], node["hostname"], node["ipaddr"])
):
fields = {
"session_id": session.id,
"active": session.active,
"created_at": session.created_at,
"retired_at": session.retired_at,
"pool": node.pool,
"hostname": node.hostname,
"ipaddr": node.ipaddr,
"session_id": session["id"],
"active": session["active"],
"created_at": session["created_at"],
"retired_at": session["retired_at"],
"pool": node["pool"],
"hostname": node["hostname"],
"ipaddr": node["ipaddr"],
}
yield " ".join(self.format_key_value(key, value) for key, value in fields.items())

def flatten_session_result(self, result: SessionResult) -> Generator[str, None, None]:
yield from self.flatten_session(result.session)
def flatten_session_result(self, result: JSONValue) -> Generator[str, None, None]:
yield from self.flatten_session(result["session"])

def flatten_sessions_result(
self, result: SessionResultCollection
) -> Generator[str, None, None]:
for session in result.sessions:
def flatten_sessions_result(self, result: JSONValue) -> Generator[str, None, None]:
for session in result["sessions"]:
yield from self.flatten_session(session)

def format(self, result: BaseModel) -> str:
for model, flattener in self.model_to_flattener.items():
if isinstance(result, model):
def format(self, result: JSONValue) -> str:
for field_name, flattener in self.field_name_to_flattener.items():
if field_name in result:
return "\n".join(getattr(self, flattener)(result))

raise TypeError("Can't flatten {result!r}")
56 changes: 16 additions & 40 deletions duffy/client/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,11 @@
import httpx
from pydantic import BaseModel, ConfigDict

from ..api_models import (
PoolResult,
PoolResultCollection,
SessionCreateModel,
SessionResult,
SessionResultCollection,
SessionUpdateModel,
)
from ..api_models import SessionCreateModel, SessionUpdateModel
from ..configuration import config

JSONValue = Union[None, bool, str, float, int, List["JSONValue"], Dict[str, "JSONValue"]]


class _MethodEnum(str, Enum):
get = "get"
Expand Down Expand Up @@ -80,9 +75,8 @@ def _query_method(
*,
in_dict: Optional[Dict[str, Any]] = None,
in_model: Optional[BaseModel] = None,
out_model: BaseModel,
expected_status: Union[HTTPStatus, Sequence[HTTPStatus]] = HTTPStatus.OK,
) -> BaseModel:
) -> JSONValue:
add_kwargs = {}
if in_dict is not None:
add_kwargs["json"] = in_model(**in_dict).model_dump()
Expand All @@ -96,56 +90,38 @@ def _query_method(

if response.status_code not in expected_status:
try:
return DuffyAPIErrorModel(error=response.json())
return DuffyAPIErrorModel(error=response.json()).model_dump(by_alias=True)
except Exception as exc:
response.raise_for_status()
raise RuntimeError(f"Can't process response: {response}") from exc

return out_model(**response.json())
return response.json()

def list_sessions(self) -> SessionResultCollection:
return self._query_method(
_MethodEnum.get,
"/sessions",
out_model=SessionResultCollection,
)
def list_sessions(self) -> JSONValue:
return self._query_method(_MethodEnum.get, "/sessions")

def show_session(self, session_id: int) -> SessionResult:
return self._query_method(
_MethodEnum.get,
f"/sessions/{session_id}",
out_model=SessionResult,
)
def show_session(self, session_id: int) -> JSONValue:
return self._query_method(_MethodEnum.get, f"/sessions/{session_id}")

def request_session(self, nodes_specs: List[Dict[str, str]]) -> SessionResult:
def request_session(self, nodes_specs: List[Dict[str, str]]) -> JSONValue:
return self._query_method(
_MethodEnum.post,
"/sessions",
in_dict={"nodes_specs": nodes_specs},
in_model=SessionCreateModel,
out_model=SessionResult,
expected_status=HTTPStatus.CREATED,
)

def retire_session(self, session_id: int) -> SessionResult:
def retire_session(self, session_id: int) -> JSONValue:
return self._query_method(
_MethodEnum.put,
f"/sessions/{session_id}",
in_dict={"active": False},
in_model=SessionUpdateModel,
out_model=SessionResult,
)

def list_pools(self) -> PoolResultCollection:
return self._query_method(
_MethodEnum.get,
"/pools",
out_model=PoolResultCollection,
)
def list_pools(self) -> JSONValue:
return self._query_method(_MethodEnum.get, "/pools")

def show_pool(self, pool_name: str) -> PoolResult:
return self._query_method(
_MethodEnum.get,
f"/pools/{pool_name}",
out_model=PoolResult,
)
def show_pool(self, pool_name: str) -> JSONValue:
return self._query_method(_MethodEnum.get, f"/pools/{pool_name}")
63 changes: 31 additions & 32 deletions tests/client/test_formatters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import datetime as dt
import json
from contextlib import nullcontext
from enum import Enum

import pytest
from pydantic import BaseModel

from duffy.api_models import (
PoolConciseModel,
Expand All @@ -26,16 +24,7 @@
)
from duffy.client.main import DuffyAPIErrorModel


class _TestEnum(str, Enum):
bar = "bar"


class _TestModel(BaseModel):
test_enum: _TestEnum


TEST_MODEL_DICT = {"test_enum": _TestEnum.bar}
TEST_JSON_DICT = {"test_key": "test_value"}


class TestDuffyFormatter:
Expand All @@ -51,26 +40,21 @@ def test_new_for_format(self, format, formatter_cls):
fmtobj = DuffyFormatter.new_for_format(format)
assert isinstance(fmtobj, formatter_cls)

def test_result_as_compatible_dict(self):
result = _TestModel(test_enum=_TestEnum.bar)

assert DuffyFormatter.result_as_compatible_dict(result) == {"test_enum": "bar"}

def test_format(self):
with pytest.raises(NotImplementedError):
DuffyFormatter().format(_TestModel(test_enum=_TestEnum.bar))
DuffyFormatter().format(TEST_JSON_DICT)


class TestDuffyJSONFormatter:
def test_format(self):
formatted = DuffyJSONFormatter().format(_TestModel(**TEST_MODEL_DICT))
assert json.loads(formatted) == TEST_MODEL_DICT
formatted = DuffyJSONFormatter().format(TEST_JSON_DICT)
assert json.loads(formatted) == TEST_JSON_DICT


class TestDuffyYAMLFormatter:
def test_format(self):
formatted = DuffyYAMLFormatter().format(_TestModel(**TEST_MODEL_DICT))
assert formatted == "test_enum: bar\n"
formatted = DuffyYAMLFormatter().format(TEST_JSON_DICT)
assert formatted == "test_key: test_value\n"


class TestDuffyFlatFormatter:
Expand Down Expand Up @@ -132,7 +116,9 @@ def test_flatten_api_error(self):
assert node_line == "error='Hullo.'"

def test_flatten_pool(self):
pool_line = next(DuffyFlatFormatter().flatten_pool(pool=self.TEST_POOL_VERBOSE))
pool_line = next(
DuffyFlatFormatter().flatten_pool(pool=self.TEST_POOL_VERBOSE.model_dump(by_alias=True))
)

assert pool_line == (
"pool_name='pool' fill_level=15 levels_provisioning=0 levels_ready=15"
Expand All @@ -142,7 +128,7 @@ def test_flatten_pool(self):
def test_flatten_pool_result(self):
pool_line = next(
DuffyFlatFormatter().flatten_pool_result(
PoolResult(action="get", pool=self.TEST_POOL_VERBOSE)
PoolResult(action="get", pool=self.TEST_POOL_VERBOSE).model_dump(by_alias=True)
)
)

Expand All @@ -154,14 +140,20 @@ def test_flatten_pool_result(self):
def test_flatten_pools_result(self):
pool_line = next(
DuffyFlatFormatter().flatten_pools_result(
PoolResultCollection(action="get", pools=[self.TEST_POOL_CONCISE])
PoolResultCollection(action="get", pools=[self.TEST_POOL_CONCISE]).model_dump(
by_alias=True
)
)
)

assert pool_line == "pool_name='pool' fill_level=15"

def test_flatten_session(self):
node_line = next(DuffyFlatFormatter().flatten_session(session=self.TEST_SESSION))
node_line = next(
DuffyFlatFormatter().flatten_session(
session=self.TEST_SESSION.model_dump(by_alias=True)
)
)

assert node_line == (
"session_id=17 active=TRUE created_at='2022-05-31 12:00:00' retired_at= pool='pool'"
Expand All @@ -171,7 +163,7 @@ def test_flatten_session(self):
def test_flatten_session_result(self):
node_line = next(
DuffyFlatFormatter().flatten_session_result(
SessionResult(action="get", session=self.TEST_SESSION)
SessionResult(action="get", session=self.TEST_SESSION).model_dump(by_alias=True)
)
)

Expand All @@ -183,7 +175,9 @@ def test_flatten_session_result(self):
def test_flatten_sessions_result(self):
node_line = next(
DuffyFlatFormatter().flatten_sessions_result(
SessionResultCollection(action="get", sessions=[self.TEST_SESSION])
SessionResultCollection(action="get", sessions=[self.TEST_SESSION]).model_dump(
by_alias=True
)
)
)

Expand All @@ -198,18 +192,23 @@ def test_flatten_sessions_result(self):
)
def test_format(self, result_cls):
expectation = nullcontext()
api_result = None

if result_cls == PoolResult:
api_result = PoolResult(action="get", pool=self.TEST_POOL_VERBOSE)
model_result = PoolResult(action="get", pool=self.TEST_POOL_VERBOSE)
elif result_cls == PoolResultCollection:
api_result = PoolResultCollection(action="get", pools=[self.TEST_POOL_CONCISE])
model_result = PoolResultCollection(action="get", pools=[self.TEST_POOL_CONCISE])
elif result_cls == SessionResult:
api_result = SessionResult(action="get", session=self.TEST_SESSION)
model_result = SessionResult(action="get", session=self.TEST_SESSION)
elif result_cls == SessionResultCollection:
api_result = SessionResultCollection(action="get", sessions=[self.TEST_SESSION])
model_result = SessionResultCollection(action="get", sessions=[self.TEST_SESSION])
else:
api_result = {"a dict": "contents don't matter"}
expectation = pytest.raises(TypeError)

if not api_result:
api_result = model_result.model_dump(by_alias=True)

with expectation:
formatted = DuffyFlatFormatter().format(api_result)

Expand Down
Loading

0 comments on commit 3b44216

Please sign in to comment.