Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump SQLAlchemy to 2.0 #626

Merged
merged 14 commits into from
Oct 5, 2024
13 changes: 8 additions & 5 deletions cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ class BlindedSignature(BaseModel):
dleq: Optional[DLEQ] = None # DLEQ proof

@classmethod
def from_row(cls, row: Row):
def from_row(cls, r: Row):
row = r._mapping # type: ignore[attr-defined]
return cls(
id=row["id"],
amount=row["amount"],
Expand Down Expand Up @@ -311,7 +312,8 @@ class MeltQuote(LedgerEvent):
change: Optional[List[BlindedSignature]] = None

@classmethod
def from_row(cls, row: Row):
def from_row(cls, r: Row):
row = r._mapping # type: ignore[attr-defined]
try:
created_time = int(row["created_time"]) if row["created_time"] else None
paid_time = int(row["paid_time"]) if row["paid_time"] else None
Expand Down Expand Up @@ -408,7 +410,8 @@ class MintQuote(LedgerEvent):
expiry: Optional[int] = None

@classmethod
def from_row(cls, row: Row):
def from_row(cls, r: Row):
row = r._mapping # type: ignore[attr-defined]
try:
# SQLITE: row is timestamp (string)
created_time = int(row["created_time"]) if row["created_time"] else None
Expand Down Expand Up @@ -641,13 +644,13 @@ def serialize(self):
)

@classmethod
def from_row(cls, row: Row):
def from_row(cls, r: Row):
def deserialize(serialized: str) -> Dict[int, PublicKey]:
return {
int(amount): PublicKey(bytes.fromhex(hex_key), raw=True)
for amount, hex_key in dict(json.loads(serialized)).items()
}

row = r._mapping # type: ignore[attr-defined]
return cls(
id=row["id"],
unit=row["unit"],
Expand Down
13 changes: 7 additions & 6 deletions cashu/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.pool import NullPool
from sqlalchemy.sql.expression import TextClause

from cashu.core.settings import settings

Expand Down Expand Up @@ -73,13 +74,13 @@ def __init__(self, conn: AsyncSession, txn, typ, name, schema):
self.name = name
self.schema = schema

def rewrite_query(self, query) -> str:
def rewrite_query(self, query) -> TextClause:
if self.type in {POSTGRES, COCKROACH}:
query = query.replace("%", "%%")
query = query.replace("?", "%s")
return text(query)

async def fetchall(self, query: str, values: dict = {}) -> list:
async def fetchall(self, query: str, values: dict = {}):
result = await self.conn.execute(self.rewrite_query(query), values)
return result.all()

Expand Down Expand Up @@ -132,9 +133,9 @@ def __init__(self, db_name: str, db_location: str):
if not settings.db_connection_pool:
kwargs["poolclass"] = NullPool
elif self.type == POSTGRES:
kwargs["poolclass"] = QueuePool
kwargs["pool_size"] = 50
kwargs["max_overflow"] = 100
#kwargs["poolclass"] = AsyncQueuePool # type: ignore[assignment]
kwargs["pool_size"] = 50 # type: ignore[assignment]
kwargs["max_overflow"] = 100 # type: ignore[assignment]

self.engine = create_async_engine(database_uri, **kwargs)
self.async_session = sessionmaker(
Expand Down
2 changes: 1 addition & 1 deletion cashu/core/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,6 @@ async def run_migration(db, migrations_module):
f"SELECT * FROM {db.table_with_schema('dbversions')}"
)
rows = result.all()
current_versions = {row["db"]: row["version"] for row in rows}
current_versions = {row._mapping["db"]: row._mapping["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_")
await run_migration(db, migrations_module)
8 changes: 4 additions & 4 deletions cashu/mint/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ async def get_pending_proofs_for_quote(
""",
{"quote_id": quote_id},
)
return [Proof(**r) for r in rows]
return [Proof(**r._mapping) for r in rows]

async def get_proofs_pending(
self,
Expand All @@ -380,7 +380,7 @@ async def get_proofs_pending(
"""
values = {f"y_{i}": Ys[i] for i in range(len(Ys))}
rows = await (conn or db).fetchall(query, values)
return [Proof(**r) for r in rows]
return [Proof(**r._mapping) for r in rows]

async def set_proof_pending(
self,
Expand Down Expand Up @@ -718,7 +718,7 @@ async def get_keyset(
""",
values,
)
return [MintKeyset(**row) for row in rows]
return [MintKeyset(**row._mapping) for row in rows]

async def get_proofs_used(
self,
Expand All @@ -733,4 +733,4 @@ async def get_proofs_used(
"""
values = {f"y_{i}": Ys[i] for i in range(len(Ys))}
rows = await (conn or db).fetchall(query, values)
return [Proof(**r) for r in rows] if rows else []
return [Proof(**r._mapping) for r in rows] if rows else []
8 changes: 4 additions & 4 deletions cashu/wallet/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def get_proofs(
""",
values,
)
return [Proof.from_dict(dict(r)) for r in rows] if rows else []
return [Proof.from_dict(dict(r._mapping)) for r in rows] if rows else []


async def get_reserved_proofs(
Expand All @@ -75,7 +75,7 @@ async def get_reserved_proofs(
WHERE reserved
"""
)
return [Proof.from_dict(dict(r)) for r in rows]
return [Proof.from_dict(dict(r._mapping)) for r in rows]


async def invalidate_proof(
Expand Down Expand Up @@ -294,7 +294,7 @@ async def get_lightning_invoice(
query,
values,
)
return Invoice(**row) if row else None
return Invoice(**row._mapping) if row else None


async def get_lightning_invoices(
Expand Down Expand Up @@ -327,7 +327,7 @@ async def get_lightning_invoices(
""",
values,
)
return [Invoice(**r) for r in rows]
return [Invoice(**r._mapping) for r in rows]


async def update_lightning_invoice(
Expand Down
Loading
Loading