Skip to content

Commit

Permalink
Bump SQLAlchemy to 2.0 (#626)
Browse files Browse the repository at this point in the history
* `SQLALCHEMY_WARN_20=1` fixed all removed warnings.

* fix some mypy errors

* fix fetchone

* make format

* ignore annotations

* let's try like this?

* remove

* make format

* Update pyproject.toml

Co-authored-by: Pavol Rusnak <[email protected]>

* extract _mapping in fetchone() and fetchall() + fix poetry lock

* fix

* make format

* fix integer indexing of row fields

* Update cashu/mint/crud.py

---------

Co-authored-by: Pavol Rusnak <[email protected]>
Co-authored-by: callebtc <[email protected]>
  • Loading branch information
3 people authored Oct 5, 2024
1 parent 7fdca3b commit c5ccf65
Show file tree
Hide file tree
Showing 7 changed files with 720 additions and 614 deletions.
1 change: 0 additions & 1 deletion cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,6 @@ def deserialize(serialized: str) -> Dict[int, PublicKey]:
int(amount): PublicKey(bytes.fromhex(hex_key), raw=True)
for amount, hex_key in dict(json.loads(serialized)).items()
}

return cls(
id=row["id"],
unit=row["unit"],
Expand Down
25 changes: 14 additions & 11 deletions cashu/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
from sqlalchemy.sql.expression import TextClause

from cashu.core.settings import settings

Expand Down Expand Up @@ -64,7 +65,7 @@ def big_int(self) -> str:
def table_with_schema(self, table: str):
return f"{self.references_schema if self.schema else ''}{table}"


# https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.CursorResult
class Connection(Compat):
def __init__(self, conn: AsyncSession, txn, typ, name, schema):
self.conn = conn
Expand All @@ -73,19 +74,20 @@ def __init__(self, conn: AsyncSession, txn, typ, name, schema):
self.name = name
self.schema = schema

def rewrite_query(self, query) -> str:
def rewrite_query(self, query) -> TextClause:
if self.type in {POSTGRES, COCKROACH}:
query = query.replace("%", "%%")
query = query.replace("?", "%s")
return text(query)

async def fetchall(self, query: str, values: dict = {}) -> list:
async def fetchall(self, query: str, values: dict = {}):
result = await self.conn.execute(self.rewrite_query(query), values)
return result.all()
return [r._mapping for r in result.all()] # will return [] if result list is empty

async def fetchone(self, query: str, values: dict = {}):
result = await self.conn.execute(self.rewrite_query(query), values)
return result.fetchone()
r = result.fetchone()
return r._mapping if r is not None else None

async def execute(self, query: str, values: dict = {}):
return await self.conn.execute(self.rewrite_query(query), values)
Expand Down Expand Up @@ -132,9 +134,9 @@ def __init__(self, db_name: str, db_location: str):
if not settings.db_connection_pool:
kwargs["poolclass"] = NullPool
elif self.type == POSTGRES:
kwargs["poolclass"] = QueuePool
kwargs["pool_size"] = 50
kwargs["max_overflow"] = 100
kwargs["poolclass"] = AsyncAdaptedQueuePool # type: ignore[assignment]
kwargs["pool_size"] = 50 # type: ignore[assignment]
kwargs["max_overflow"] = 100 # type: ignore[assignment]

self.engine = create_async_engine(database_uri, **kwargs)
self.async_session = sessionmaker(
Expand Down Expand Up @@ -281,12 +283,13 @@ async def acquire_lock(
async def fetchall(self, query: str, values: dict = {}) -> list:
async with self.connect() as conn:
result = await conn.execute(query, values)
return result.all()
return [r._mapping for r in result.all()]

async def fetchone(self, query: str, values: dict = {}):
async with self.connect() as conn:
result = await conn.execute(query, values)
return result.fetchone()
r = result.fetchone()
return r._mapping if r is not None else None

async def execute(self, query: str, values: dict = {}):
async with self.connect() as conn:
Expand Down
2 changes: 1 addition & 1 deletion cashu/core/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,6 @@ async def run_migration(db, migrations_module):
f"SELECT * FROM {db.table_with_schema('dbversions')}"
)
rows = result.all()
current_versions = {row["db"]: row["version"] for row in rows}
current_versions = {row._mapping["db"]: row._mapping["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_")
await run_migration(db, migrations_module)
5 changes: 4 additions & 1 deletion cashu/mint/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,10 @@ async def get_balance(
"""
)
assert row, "Balance not found"
return int(row[0])

# sqlalchemy index of first element
key = next(iter(row))
return int(row[key])

async def get_keyset(
self,
Expand Down
6 changes: 3 additions & 3 deletions cashu/wallet/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ async def bump_secret_derivation(
)
counter = 0
else:
counter = int(rows[0])
counter = int(rows["counter"])

if not skip:
await (conn or db).execute(
Expand Down Expand Up @@ -437,8 +437,8 @@ async def get_seed_and_mnemonic(
)
return (
(
row[0],
row[1],
row["seed"],
row["mnemonic"],
)
if row
else None
Expand Down
Loading

0 comments on commit c5ccf65

Please sign in to comment.