-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathconnection.py
113 lines (84 loc) · 3.31 KB
/
connection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from __future__ import annotations
import logging
import time
from contextlib import contextmanager
from typing import Generator, TYPE_CHECKING
from sqlalchemy import event
from sqlalchemy.engine import create_engine, Engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import Session, sessionmaker
if TYPE_CHECKING:
from aspen.config.config import Config
class SqlAlchemyInterface:
    """Pair a SQLAlchemy engine with a session factory bound to it.

    Parameters
    ----------
    engine : Engine
        Engine that sessions produced by this interface are bound to.
    use_async : bool
        If True, the session factory is built with ``class_=AsyncSession``,
        ``expire_on_commit=False`` and ``future=True``; otherwise a plain
        ``sessionmaker(bind=engine)`` is used. (default: False)
    """

    def __init__(self, engine: Engine, use_async: bool = False):
        self._engine = engine
        # Only the async flavour needs non-default factory options; the
        # synchronous factory is created with the bind alone.
        factory_options = (
            {"expire_on_commit": False, "class_": AsyncSession, "future": True}
            if use_async
            else {}
        )
        self._session_maker = sessionmaker(bind=engine, **factory_options)

    @property
    def engine(self) -> Engine:
        """Return the engine that was supplied at construction time."""
        return self._engine

    def make_session(self) -> Session:
        """Return a new session created by the configured session factory."""
        return self._session_maker()
def init_async_db(db_uri: str, **kwargs) -> SqlAlchemyInterface:
    """Create a ``SqlAlchemyInterface`` backed by an async engine.

    Parameters
    ----------
    db_uri : str
        Database URI passed to ``create_async_engine``.
    **kwargs
        Additional keyword arguments forwarded to ``create_async_engine``.
        They take precedence over the defaults used here (``echo=False``,
        ``pool_size=5``, ``max_overflow=5``, ``future=True``), so a caller
        may override any of them.

    Returns
    -------
    SqlAlchemyInterface constructed with ``use_async=True``
    """
    # Merge the defaults with the caller's options instead of passing both
    # explicitly: spelling out e.g. ``pool_size=5`` while ``kwargs`` also
    # carries ``pool_size`` would raise TypeError for a duplicate keyword.
    engine_options = {
        "echo": False,
        "pool_size": 5,
        "max_overflow": 5,
        "future": True,
        **kwargs,
    }
    engine = create_async_engine(db_uri, **engine_options)
    return SqlAlchemyInterface(engine, use_async=True)
def init_db(db_uri: str) -> SqlAlchemyInterface:
    """Create a ``SqlAlchemyInterface`` backed by a synchronous engine.

    Parameters
    ----------
    db_uri : str
        Database URI passed to ``create_engine``.

    Returns
    -------
    SqlAlchemyInterface wrapping the newly created engine
    """
    return SqlAlchemyInterface(create_engine(db_uri))
@contextmanager
def session_scope(interface: SqlAlchemyInterface) -> Generator[Session, None, None]:
    """Yield a session and commit on success, roll back on error.

    A session is obtained from ``interface.make_session()``. When the
    ``with`` body finishes normally the session is committed; if the body
    or the commit raises an ``Exception`` the session is rolled back and
    the exception is re-raised. The session is closed in every case.
    """
    db_session = interface.make_session()
    try:
        # The commit lives inside the inner ``try`` on purpose: a failing
        # commit has to trigger the rollback as well.
        try:
            yield db_session
            db_session.commit()
        except Exception:
            db_session.rollback()
            raise
    finally:
        db_session.close()
def get_db_uri(runtime_config: Config, readonly: bool = False) -> str:
    """Provides a URI for the database based on a runtime environment.

    Parameters
    ----------
    runtime_config : Config
        The runtime config that contains the database we want to access.
    readonly : bool
        Returns a read-only handle for the database if True. (default: False)

    Returns
    -------
    string that can be used to connect to a postgres database

    Raises
    ------
    ValueError
        If ``readonly`` is True and reading
        ``runtime_config.DATABASE_READONLY_URI`` raises NotImplementedError.
    """
    if not readonly:
        return runtime_config.DATABASE_URI

    try:
        return runtime_config.DATABASE_READONLY_URI
    except NotImplementedError as err:
        # Chain explicitly so the traceback shows the NotImplementedError as
        # the direct cause instead of an error "during handling".
        raise ValueError(
            f"Config {runtime_config} does not have a read-only mode."
        ) from err
# Dedicated logger, named "sqltime", for the query-timing messages emitted by
# the cursor-execute hooks in this module.
sqltime_logger = logging.getLogger("sqltime")
# The hooks log at DEBUG, so this logger's own threshold is set to DEBUG.
# NOTE(review): whether the records are actually output still depends on the
# handler configuration, which is not visible here -- confirm.
sqltime_logger.setLevel(logging.DEBUG)
def _before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """Record the start time of a statement and log the statement text.

    Start times are pushed onto a list stored under
    ``conn.info["query_start_time"]`` (created on first use), from which
    ``_after_cursor_execute`` pops the most recent entry.
    """
    pending_starts = conn.info.setdefault("query_start_time", [])
    pending_starts.append(time.time())
    sqltime_logger.debug("Start Query: %s", statement)
def _after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """Log how long the statement that just finished took.

    Pops the most recent start time from ``conn.info["query_start_time"]``
    and logs the elapsed wall-clock seconds at DEBUG level.
    """
    finished_at = time.time()
    started_at = conn.info["query_start_time"].pop()
    sqltime_logger.debug("Query Complete!")
    sqltime_logger.debug("Total Time: %f", finished_at - started_at)
def enable_profiling():
    """Register the query-timing hooks on the ``Engine`` class."""
    hooks = (
        ("before_cursor_execute", _before_cursor_execute),
        ("after_cursor_execute", _after_cursor_execute),
    )
    for event_name, hook in hooks:
        event.listen(Engine, event_name, hook)
def disable_profiling():
    """Remove the query-timing hooks from the ``Engine`` class."""
    hooks = (
        ("before_cursor_execute", _before_cursor_execute),
        ("after_cursor_execute", _after_cursor_execute),
    )
    for event_name, hook in hooks:
        event.remove(Engine, event_name, hook)
@contextmanager
def enable_profiling_ctx():
    """Enable query profiling for the duration of a ``with`` block.

    Calls ``enable_profiling()`` on entry and ``disable_profiling()`` on
    exit. Yields nothing.
    """
    enable_profiling()
    try:
        yield
    finally:
        # Run the teardown even when the ``with`` body raises; otherwise the
        # listeners would stay registered after the block has exited.
        disable_profiling()