Skip to content

Commit

Permalink
Refactor client; add tests (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
eliseygusev authored Nov 16, 2022
1 parent ea32d93 commit fd9bcf6
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 7 deletions.
23 changes: 16 additions & 7 deletions dune_client/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
https://duneanalytics.notion.site/API-Documentation-1b93d16e0fa941398e15047f643e003a
"""
import asyncio
from typing import Any
from typing import Any, Optional

from aiohttp import (
ClientSession,
Expand Down Expand Up @@ -43,25 +43,30 @@ def __init__(self, api_key: str, connection_limit: int = 3):
"""
super().__init__(api_key=api_key)
self._connection_limit = connection_limit
self._session = self._create_session()
self._session: Optional[ClientSession] = None

def _create_session(self) -> ClientSession:
async def _create_session(self) -> ClientSession:
conn = TCPConnector(limit=self._connection_limit)
return ClientSession(
connector=conn,
base_url=self.BASE_URL,
timeout=ClientTimeout(total=self.DEFAULT_TIMEOUT),
)

async def close_session(self) -> None:
async def connect(self) -> None:
"""Opens a client session (can be used instead of async with)"""
self._session = await self._create_session()

async def disconnect(self) -> None:
"""Closes client session"""
await self._session.close()
if self._session:
await self._session.close()

async def __aenter__(self) -> None:
self._session = self._create_session()
self._session = await self._create_session()

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close_session()
await self.disconnect()

async def _handle_response(
self,
Expand All @@ -78,6 +83,8 @@ async def _handle_response(
raise ValueError("Unreachable since previous line raises") from err

async def _get(self, url: str) -> Any:
if self._session is None:
raise ValueError("Client is not connected; call `await cl.connect()`")
self.logger.debug(f"GET received input url={url}")
response = await self._session.get(
url=f"{self.API_PATH}{url}",
Expand All @@ -86,6 +93,8 @@ async def _get(self, url: str) -> Any:
return await self._handle_response(response)

async def _post(self, url: str, params: Any) -> Any:
if self._session is None:
raise ValueError("Client is not connected; call `await cl.connect()`")
self.logger.debug(f"POST received input url={url}, params={params}")
response = await self._session.post(
url=f"{self.API_PATH}{url}",
Expand Down
28 changes: 28 additions & 0 deletions tests/e2e/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def setUp(self) -> None:
async def test_get_status(self):
query = Query(name="No Name", query_id=1276442, params=[])
dune = AsyncDuneClient(self.valid_api_key)
await dune.connect()
job_id = (await dune.execute(query)).execution_id
status = await dune.get_status(job_id)
self.assertTrue(
Expand All @@ -45,6 +46,7 @@ async def test_get_status(self):

async def test_refresh(self):
dune = AsyncDuneClient(self.valid_api_key)
await dune.connect()
results = (await dune.refresh(self.query)).get_rows()
self.assertGreater(len(results), 0)
await dune.close_session()
Expand All @@ -62,6 +64,7 @@ async def test_parameters_recognized(self):
self.assertEqual(query.parameters(), new_params)

dune = AsyncDuneClient(self.valid_api_key)
await dune.connect()
results = await dune.refresh(query)
self.assertEqual(
results.get_rows(),
Expand All @@ -78,6 +81,7 @@ async def test_parameters_recognized(self):

async def test_endpoints(self):
dune = AsyncDuneClient(self.valid_api_key)
await dune.connect()
execution_response = await dune.execute(self.query)
self.assertIsInstance(execution_response, ExecutionResponse)
job_id = execution_response.execution_id
Expand All @@ -93,6 +97,7 @@ async def test_endpoints(self):

async def test_cancel_execution(self):
dune = AsyncDuneClient(self.valid_api_key)
await dune.connect()
query = Query(
name="Long Running Query",
query_id=1229120,
Expand All @@ -109,6 +114,7 @@ async def test_cancel_execution(self):

async def test_invalid_api_key_error(self):
dune = AsyncDuneClient(api_key="Invalid Key")
await dune.connect()
with self.assertRaises(DuneError) as err:
await dune.execute(self.query)
self.assertEqual(
Expand All @@ -131,6 +137,7 @@ async def test_invalid_api_key_error(self):

async def test_query_not_found_error(self):
dune = AsyncDuneClient(self.valid_api_key)
await dune.connect()
query = copy.copy(self.query)
query.query_id = 99999999 # Invalid Query Id.

Expand All @@ -144,6 +151,7 @@ async def test_query_not_found_error(self):

async def test_internal_error(self):
dune = AsyncDuneClient(self.valid_api_key)
await dune.connect()
query = copy.copy(self.query)
# This query ID is too large!
query.query_id = 9999999999999
Expand All @@ -158,6 +166,7 @@ async def test_internal_error(self):

async def test_invalid_job_id_error(self):
dune = AsyncDuneClient(self.valid_api_key)
await dune.connect()

with self.assertRaises(DuneError) as err:
await dune.get_status("Wonky Job ID")
Expand All @@ -168,6 +177,25 @@ async def test_invalid_job_id_error(self):
)
await dune.close_session()

async def test_disconnect(self):
dune = AsyncDuneClient(self.valid_api_key)
await dune.connect()
results = (await dune.refresh(self.query)).get_rows()
self.assertGreater(len(results), 0)
await dune.close_session()
self.assertTrue(cl._session.closed)

async def test_refresh_context_manager_singleton(self):
dune = AsyncDuneClient(self.valid_api_key)
async with dune as cl:
results = (await cl.refresh(self.query)).get_rows()
self.assertGreater(len(results), 0)

async def test_refresh_context_manager(self):
async with AsyncDuneClient(self.valid_api_key) as cl:
results = (await cl.refresh(self.query)).get_rows()
self.assertGreater(len(results), 0)


if __name__ == "__main__":
unittest.main()

0 comments on commit fd9bcf6

Please sign in to comment.