Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump SQLAlchemy to 2.0 #626

Merged
merged 14 commits into from
Oct 5, 2024
1 change: 0 additions & 1 deletion cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,6 @@ def deserialize(serialized: str) -> Dict[int, PublicKey]:
int(amount): PublicKey(bytes.fromhex(hex_key), raw=True)
for amount, hex_key in dict(json.loads(serialized)).items()
}

return cls(
id=row["id"],
unit=row["unit"],
Expand Down
25 changes: 14 additions & 11 deletions cashu/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
from sqlalchemy.sql.expression import TextClause

from cashu.core.settings import settings

Expand Down Expand Up @@ -64,7 +65,7 @@ def big_int(self) -> str:
def table_with_schema(self, table: str):
return f"{self.references_schema if self.schema else ''}{table}"


# https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.CursorResult
class Connection(Compat):
def __init__(self, conn: AsyncSession, txn, typ, name, schema):
self.conn = conn
Expand All @@ -73,19 +74,20 @@ def __init__(self, conn: AsyncSession, txn, typ, name, schema):
self.name = name
self.schema = schema

def rewrite_query(self, query) -> str:
def rewrite_query(self, query) -> TextClause:
if self.type in {POSTGRES, COCKROACH}:
query = query.replace("%", "%%")
query = query.replace("?", "%s")
return text(query)

async def fetchall(self, query: str, values: dict = {}) -> list:
async def fetchall(self, query: str, values: dict = {}):
result = await self.conn.execute(self.rewrite_query(query), values)
return result.all()
return [r._mapping for r in result.all()] # will return [] if result list is empty

async def fetchone(self, query: str, values: dict = {}):
result = await self.conn.execute(self.rewrite_query(query), values)
return result.fetchone()
r = result.fetchone()
return r._mapping if r is not None else None

async def execute(self, query: str, values: dict = {}):
return await self.conn.execute(self.rewrite_query(query), values)
Expand Down Expand Up @@ -132,9 +134,9 @@ def __init__(self, db_name: str, db_location: str):
if not settings.db_connection_pool:
kwargs["poolclass"] = NullPool
elif self.type == POSTGRES:
kwargs["poolclass"] = QueuePool
kwargs["pool_size"] = 50
kwargs["max_overflow"] = 100
kwargs["poolclass"] = AsyncAdaptedQueuePool # type: ignore[assignment]
kwargs["pool_size"] = 50 # type: ignore[assignment]
kwargs["max_overflow"] = 100 # type: ignore[assignment]

self.engine = create_async_engine(database_uri, **kwargs)
self.async_session = sessionmaker(
Expand Down Expand Up @@ -281,12 +283,13 @@ async def acquire_lock(
async def fetchall(self, query: str, values: dict = {}) -> list:
async with self.connect() as conn:
result = await conn.execute(query, values)
return result.all()
return [r._mapping for r in result.all()]

async def fetchone(self, query: str, values: dict = {}):
async with self.connect() as conn:
result = await conn.execute(query, values)
return result.fetchone()
r = result.fetchone()
return r._mapping if r is not None else None

async def execute(self, query: str, values: dict = {}):
async with self.connect() as conn:
Expand Down
2 changes: 1 addition & 1 deletion cashu/core/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,6 @@ async def run_migration(db, migrations_module):
f"SELECT * FROM {db.table_with_schema('dbversions')}"
)
rows = result.all()
current_versions = {row["db"]: row["version"] for row in rows}
current_versions = {row._mapping["db"]: row._mapping["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_")
await run_migration(db, migrations_module)
5 changes: 4 additions & 1 deletion cashu/mint/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,10 @@ async def get_balance(
"""
)
assert row, "Balance not found"
return int(row[0])

# sqlalchemy index of first element
key = next(iter(row))
return int(row[key])

async def get_keyset(
self,
Expand Down
6 changes: 3 additions & 3 deletions cashu/wallet/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ async def bump_secret_derivation(
)
counter = 0
else:
counter = int(rows[0])
counter = int(rows["counter"])

if not skip:
await (conn or db).execute(
Expand Down Expand Up @@ -437,8 +437,8 @@ async def get_seed_and_mnemonic(
)
return (
(
row[0],
row[1],
row["seed"],
row["mnemonic"],
)
if row
else None
Expand Down
Loading
Loading