From 9d8e9f891febf15aadc282200c689ddd1e2e3579 Mon Sep 17 00:00:00 2001 From: Shubha Rajan Date: Tue, 28 Jun 2022 17:29:50 -0700 Subject: [PATCH] refactor DataAccessLayer class --- api/openapi_server/__main__.py | 2 ++ .../service_provider_controller.py | 13 +++++----- api/openapi_server/models/database.py | 24 ++++++++++++------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/api/openapi_server/__main__.py b/api/openapi_server/__main__.py index 4dd755c5..0477888a 100644 --- a/api/openapi_server/__main__.py +++ b/api/openapi_server/__main__.py @@ -4,6 +4,7 @@ from os import environ as env from openapi_server import encoder +from openapi_server.models.database import DataAccessLayer from openapi_server.exceptions import AuthError, handle_auth_error from dotenv import load_dotenv, find_dotenv @@ -13,6 +14,7 @@ load_dotenv(ENV_FILE) SECRET_KEY=env.get('SECRET_KEY') +DataAccessLayer.db_init() def main(): app = connexion.App(__name__, specification_dir='./_spec/') diff --git a/api/openapi_server/controllers/service_provider_controller.py b/api/openapi_server/controllers/service_provider_controller.py index 40c2e37e..3d5b26fe 100644 --- a/api/openapi_server/controllers/service_provider_controller.py +++ b/api/openapi_server/controllers/service_provider_controller.py @@ -9,8 +9,7 @@ from openapi_server.models import database as db from sqlalchemy.orm import Session -dal = db.DataAccessLayer() -dal.db_init() +db_engine = db.DataAccessLayer.get_engine() def create_service_provider(): # noqa: E501 """Create a housing program service provider @@ -27,7 +26,7 @@ def create_service_provider(): # noqa: E501 connexion.request.get_json()).to_dict() except ValueError: return traceback.format_exc(ValueError), 400 - with Session(dal.engine) as session: + with Session(db_engine) as session: row = db.HousingProgramServiceProvider( provider_name=provider["provider_name"] ) @@ -51,7 +50,7 @@ def delete_service_provider(provider_id): # noqa: E501 :rtype: None """ - with Session(dal.engine) as session: + with Session(db_engine) as session: query = session.query( db.HousingProgramServiceProvider).filter( db.HousingProgramServiceProvider.id == provider_id) @@ -71,7 +70,7 @@ def get_service_provider_by_id(provider_id): # noqa: E501 :rtype: ServiceProviderWithId """ - with Session(dal.engine) as session: + with Session(db_engine) as session: row = session.get( db.HousingProgramServiceProvider, provider_id) if row != None: @@ -93,7 +92,7 @@ def get_service_providers(): # noqa: E501 :rtype: List[ServiceProviderWithId] """ resp = [] - with Session(dal.engine) as session: + with Session(db_engine) as session: table = session.query(db.HousingProgramServiceProvider).all() for row in table: provider = ServiceProvider( @@ -122,7 +121,7 @@ def update_service_provider(provider_id): # noqa: E501 connexion.request.get_json()).to_dict() except ValueError: return traceback.format_exc(ValueError), 400 - with Session(dal.engine) as session: + with Session(db_engine) as session: query = session.query( db.HousingProgramServiceProvider).filter( db.HousingProgramServiceProvider.id == provider_id) diff --git a/api/openapi_server/models/database.py b/api/openapi_server/models/database.py index 5077ff1c..becb01ff 100644 --- a/api/openapi_server/models/database.py +++ b/api/openapi_server/models/database.py @@ -233,16 +233,22 @@ class ProgramCaseStatusLog(Base): src_status = Column(Integer, ForeignKey('case_status.id'), nullable=False) dest_status = Column(Integer, ForeignKey('case_status.id'), nullable=False) - - class DataAccessLayer: - connection = None - engine = None + _engine = None # temporary local sqlite DB, replace with conn str for postgres container port for real e2e - conn_string = "sqlite:///./homeuniteus.db" + _conn_string = "sqlite:///./homeuniteus.db" - def db_init(self, conn_string=None): - self.engine = create_engine(conn_string or self.conn_string, echo=True, future=True) - Base.metadata.create_all(bind=self.engine) - self.connection = self.engine.connect() + @classmethod + def db_init(cls, conn_string=None): + Base.metadata.create_all(bind=cls.get_engine(conn_string)) + + @classmethod + def connect(cls): + return cls.get_engine().connect() + + @classmethod + def get_engine(cls, conn_string=None): + if cls._engine == None: + cls._engine = create_engine(conn_string or cls._conn_string, echo=True, future=True) + return cls._engine