Skip to content

Commit

Permalink
Adds session context manager (#4819)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevgliss authored Jun 11, 2024
1 parent 4beb74e commit b51000f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 30 deletions.
30 changes: 30 additions & 0 deletions src/dispatch/database/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import re
from contextlib import contextmanager
from typing import Annotated, Any

from fastapi import Depends
Expand All @@ -12,6 +13,7 @@
from sqlalchemy_utils import get_mapper
from starlette.requests import Request


from dispatch import config
from dispatch.exceptions import NotFoundError
from dispatch.search.fulltext import make_searchable
Expand Down Expand Up @@ -201,3 +203,31 @@ def refetch_db_session(organization_slug: str) -> Session:
)
db_session = sessionmaker(bind=schema_engine)()
return db_session


@contextmanager
def get_session():
"""Context manager to ensure the session is closed after use."""
session = SessionLocal()
try:
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()


@contextmanager
def get_organization_session(organization_slug: str):
"""Context manager to ensure the session is closed after use."""
session = refetch_db_session(organization_slug)
try:
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()
56 changes: 26 additions & 30 deletions src/dispatch/plugins/dispatch_slack/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dispatch.auth import service as user_service
from dispatch.auth.models import DispatchUser, UserRegister
from dispatch.conversation import service as conversation_service
from dispatch.database.core import SessionLocal, refetch_db_session
from dispatch.database.core import get_session, get_organization_session, refetch_db_session
from dispatch.decorators import timer
from dispatch.enums import SubjectNames
from dispatch.organization import service as organization_service
Expand Down Expand Up @@ -38,35 +38,32 @@
@timer
def resolve_context_from_conversation(channel_id: str, thread_id: str = None) -> Optional[Subject]:
"""Attempts to resolve a conversation based on the channel id and thread_id."""
db_session = SessionLocal()
organization_slugs = [o.slug for o in organization_service.get_all(db_session=db_session)]
db_session.close()
organization_slugs = []
with get_session() as db_session:
organization_slugs = [o.slug for o in organization_service.get_all(db_session=db_session)]

for slug in organization_slugs:
scoped_db_session = refetch_db_session(slug)

conversation = conversation_service.get_by_channel_id_ignoring_channel_type(
db_session=scoped_db_session, channel_id=channel_id, thread_id=thread_id
)

if conversation:
if conversation.incident:
subject = SubjectMetadata(
type=IncidentSubjects.incident,
id=conversation.incident_id,
organization_slug=slug,
project_id=conversation.incident.project_id,
)
else:
subject = SubjectMetadata(
type=CaseSubjects.case,
id=conversation.case_id,
organization_slug=slug,
project_id=conversation.case.project_id,
)
return Subject(subject, db_session=scoped_db_session)
with get_organization_session(slug) as scoped_db_session:
conversation = conversation_service.get_by_channel_id_ignoring_channel_type(
db_session=scoped_db_session, channel_id=channel_id, thread_id=thread_id
)

scoped_db_session.close()
if conversation:
if conversation.incident:
subject = SubjectMetadata(
type=IncidentSubjects.incident,
id=conversation.incident_id,
organization_slug=slug,
project_id=conversation.incident.project_id,
)
else:
subject = SubjectMetadata(
type=CaseSubjects.case,
id=conversation.case_id,
organization_slug=slug,
project_id=conversation.case.project_id,
)
return Subject(subject, db_session=scoped_db_session)


def select_context_middleware(payload: dict, context: BoltContext, next: Callable) -> None:
Expand Down Expand Up @@ -416,7 +413,6 @@ def configuration_middleware(context: BoltContext, next: Callable):


def get_default_org_slug() -> str:
db_session = SessionLocal()
slug = organization_service.get_default(db_session=db_session).slug
db_session.close()
with get_session() as db_session:
slug = organization_service.get_default(db_session=db_session).slug
return slug

0 comments on commit b51000f

Please sign in to comment.