diff --git a/README.md b/README.md index 8181adf..f9b98a7 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ dwh = Dwh(SQL_SERVER, SQL_DATABASE, SQL_USER, SQL_PASSWORD, driver_index) ## Example usage ``` -from pyprediktorutilities.dwh import Dwh +from pyprediktorutilities.dwh.dwh import Dwh dwh = Dwh("localhost", "mydatabase", "myusername", "mypassword") results = dwh.fetch("SELECT * FROM mytable") diff --git a/src/pyprediktorutilities/dwh/__init__.py b/src/pyprediktorutilities/dwh/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyprediktorutilities/dwh.py b/src/pyprediktorutilities/dwh/dwh.py similarity index 97% rename from src/pyprediktorutilities/dwh.py rename to src/pyprediktorutilities/dwh/dwh.py index a353746..9ae0006 100644 --- a/src/pyprediktorutilities/dwh.py +++ b/src/pyprediktorutilities/dwh/dwh.py @@ -4,8 +4,6 @@ from typing import List, Any from pydantic import validate_call -from pyprediktorutilities import singleton_class - logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -304,11 +302,3 @@ def __disconnect(self) -> None: def __commit(self) -> None: """Commits any changes to the database.""" self.connection.commit() - - -class DwhSingleton(Dwh, metaclass=singleton_class.SingletonMeta): - pass - - -def get_dwh_instance(*args, **kwargs): - return DwhSingleton(*args, **kwargs) diff --git a/src/pyprediktorutilities/dwh/dwh_singleton.py b/src/pyprediktorutilities/dwh/dwh_singleton.py new file mode 100644 index 0000000..16e7744 --- /dev/null +++ b/src/pyprediktorutilities/dwh/dwh_singleton.py @@ -0,0 +1,6 @@ +from pyprediktorutilities import singleton +from pyprediktorutilities.dwh import dwh + + +class DwhSingleton(dwh.Dwh, metaclass=singleton.SingletonMeta): + pass diff --git a/src/pyprediktorutilities/singleton_class.py b/src/pyprediktorutilities/singleton.py similarity index 100% rename from src/pyprediktorutilities/singleton_class.py rename to src/pyprediktorutilities/singleton.py diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..51b0a2c --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,25 @@ +import random +import string + +import pyodbc + + +class mock_pyodbc_connection: + def __init__(self, connection_string): + pass + + def cursor(self): + return + + +def mock_pyodbc_connection_throws_error_not_tolerant_to_attempts(connection_string): + raise pyodbc.DataError("Error code", "Error message") + + +def mock_pyodbc_connection_throws_error_tolerant_to_attempts(connection_string): + raise pyodbc.DatabaseError("Error code", "Error message") + + +def grs(): + """Generate a random string.""" + return "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) diff --git a/tests/test_dwh.py b/tests/test_dwh.py index 8630d83..e0d337d 100644 --- a/tests/test_dwh.py +++ b/tests/test_dwh.py @@ -1,43 +1,13 @@ -import pytest -import random -import string -import pyodbc import logging -import pandas as pd - from unittest.mock import Mock -from pyprediktorutilities.dwh import Dwh, get_dwh_instance -from pandas.testing import assert_frame_equal - -""" -Helpers -""" - - -class mock_pyodbc_connection: - def __init__(self, connection_string): - pass - - def cursor(self): - return - - -def mock_pyodbc_connection_throws_error_not_tolerant_to_attempts(connection_string): - raise pyodbc.DataError("Error code", "Error message") - - -def mock_pyodbc_connection_throws_error_tolerant_to_attempts(connection_string): - raise pyodbc.DatabaseError("Error code", "Error message") - - -def grs(): - """Generate a random string.""" - return "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) +import pandas as pd +import pyodbc +import pytest +from pandas.testing import assert_frame_equal -""" -__init__ -""" +import helpers +from pyprediktorutilities.dwh.dwh import Dwh def test_init_when_instantiate_db_then_instance_is_created(monkeypatch): @@ -45,10 +15,10 @@ def test_init_when_instantiate_db_then_instance_is_created(monkeypatch): # Mock the database connection monkeypatch.setattr( - "pyprediktorutilities.dwh.pyodbc.connect", mock_pyodbc_connection + "pyprediktorutilities.dwh.dwh.pyodbc.connect", helpers.mock_pyodbc_connection ) - db = Dwh(grs(), grs(), grs(), grs(), driver_index) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) assert db is not None @@ -58,10 +28,10 @@ def test_init_when_instantiate_db_but_no_pyodbc_drivers_available_then_throw_exc driver_index = 0 # Mock the absence of ODBC drivers - monkeypatch.setattr("pyprediktorutilities.dwh.pyodbc.drivers", lambda: []) + monkeypatch.setattr("pyprediktorutilities.dwh.dwh.pyodbc.drivers", lambda: []) with pytest.raises(ValueError) as excinfo: - Dwh(grs(), grs(), grs(), grs(), driver_index) + Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) assert "Driver index 0 is out of range." in str(excinfo.value) @@ -72,12 +42,12 @@ def test_init_when_instantiate_db_but_pyodbc_throws_error_with_tolerance_to_atte # Mock the database connection monkeypatch.setattr( - "pyprediktorutilities.dwh.pyodbc.connect", - mock_pyodbc_connection_throws_error_not_tolerant_to_attempts, + "pyprediktorutilities.dwh.dwh.pyodbc.connect", + helpers.mock_pyodbc_connection_throws_error_not_tolerant_to_attempts, ) with pytest.raises(pyodbc.DataError): - Dwh(grs(), grs(), grs(), grs(), driver_index) + Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) def test_init_when_instantiate_db_but_pyodbc_throws_error_tolerant_to_attempts_then_retry_connecting_and_throw_exception( @@ -87,13 +57,13 @@ def test_init_when_instantiate_db_but_pyodbc_throws_error_tolerant_to_attempts_t # Mock the database connection monkeypatch.setattr( - "pyprediktorutilities.dwh.pyodbc.connect", - mock_pyodbc_connection_throws_error_tolerant_to_attempts, + "pyprediktorutilities.dwh.dwh.pyodbc.connect", + helpers.mock_pyodbc_connection_throws_error_tolerant_to_attempts, ) with caplog.at_level(logging.ERROR): with pytest.raises(pyodbc.DatabaseError): - Dwh(grs(), grs(), grs(), grs(), driver_index) + Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) assert any( "Failed to connect to the DataWarehouse after 3 attempts." in message @@ -111,7 +81,7 @@ def test_init_when_instantiate_dwh_but_driver_index_is_not_passed_then_instance_ monkeypatch.setattr("pyodbc.connect", lambda *args, **kwargs: mock_connection) monkeypatch.setattr("pyodbc.drivers", lambda: ["Driver1", "Driver2"]) - db = Dwh(grs(), grs(), grs(), grs()) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs()) assert db is not None assert db.driver == "Driver1" @@ -145,7 +115,7 @@ def test_fetch_when_init_db_connection_is_successfull_but_fails_when_calling_fet ) with pytest.raises(pyodbc.DataError): - db = Dwh(grs(), grs(), grs(), grs(), driver_index) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) db.connection = False db.fetch(query) @@ -178,7 +148,7 @@ def test_fetch_when_to_dataframe_is_false_and_no_data_is_returned_then_return_em monkeypatch.setattr("pyodbc.connect", lambda *args, **kwargs: mock_connection) monkeypatch.setattr("pyodbc.drivers", lambda: ["Driver1", "Driver2", "Driver3"]) - db = Dwh(grs(), grs(), grs(), grs(), driver_index) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) actual_result = db.fetch(query) mock_cursor.execute.assert_called_once_with(query) @@ -270,7 +240,7 @@ def test_fetch_when_to_dataframe_is_false_and_single_data_set_is_returned_then_r monkeypatch.setattr("pyodbc.connect", lambda *args, **kwargs: mock_connection) monkeypatch.setattr("pyodbc.drivers", lambda: ["Driver1", "Driver2", "Driver3"]) - db = Dwh(grs(), grs(), grs(), grs(), driver_index) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) actual_result = db.fetch(query) mock_cursor.execute.assert_called_once_with(query) @@ -407,7 +377,7 @@ def test_fetch_when_to_dataframe_is_false_and_multiple_data_sets_are_returned_th monkeypatch.setattr("pyodbc.connect", lambda *args, **kwargs: mock_connection) monkeypatch.setattr("pyodbc.drivers", lambda: ["Driver1", "Driver2", "Driver3"]) - db = Dwh(grs(), grs(), grs(), grs(), driver_index) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) actual_result = db.fetch(query) mock_cursor.execute.assert_called_once_with(query) @@ -440,7 +410,7 @@ def test_fetch_when_to_dataframe_is_true_and_no_data_is_returned_then_return_emp monkeypatch.setattr("pyodbc.connect", lambda *args, **kwargs: mock_connection) monkeypatch.setattr("pyodbc.drivers", lambda: ["Driver1", "Driver2", "Driver3"]) - db = Dwh(grs(), grs(), grs(), grs(), driver_index) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) actual_result = db.fetch(query, True) mock_cursor.execute.assert_called_once_with(query) @@ -533,7 +503,7 @@ def test_fetch_when_to_dataframe_is_true_and_single_data_set_is_returned_then_re monkeypatch.setattr("pyodbc.connect", lambda *args, **kwargs: mock_connection) monkeypatch.setattr("pyodbc.drivers", lambda: ["Driver1", "Driver2", "Driver3"]) - db = Dwh(grs(), grs(), grs(), grs(), driver_index) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) actual_result = db.fetch(query, True) mock_cursor.execute.assert_called_once_with(query) @@ -674,7 +644,7 @@ def test_fetch_when_to_dataframe_is_true_and_multiple_data_sets_are_returned_the monkeypatch.setattr("pyodbc.connect", lambda *args, **kwargs: mock_connection) monkeypatch.setattr("pyodbc.drivers", lambda: ["Driver1", "Driver2", "Driver3"]) - db = Dwh(grs(), grs(), grs(), grs(), driver_index) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) actual_result = db.fetch(query, True) mock_cursor.execute.assert_called_once_with(query) @@ -719,7 +689,7 @@ def test_execute_when_init_db_connection_is_successfull_but_fails_when_calling_e ) with pytest.raises(pyodbc.DataError): - db = Dwh(grs(), grs(), grs(), grs(), driver_index) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) db.connection = False db.execute(query) @@ -749,7 +719,7 @@ def test_execute_when_parameter_passed_then_fetch_results_and_return_data(monkey mock_cursor.execute = mock_execute mock_cursor.fetchall = mock_fetch - db = Dwh(grs(), grs(), grs(), grs(), driver_index) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) actual_result = db.execute(query, param_one, param_two) mock_execute.assert_called_once_with(query, param_one, param_two) @@ -781,26 +751,9 @@ def test_execute_when_fetchall_throws_error_then_return_empty_list(monkeypatch): mock_cursor.execute = mock_execute mock_cursor.fetchall = mock_fetchall - db = Dwh(grs(), grs(), grs(), grs(), driver_index) + db = Dwh(helpers.grs(), helpers.grs(), helpers.grs(), helpers.grs(), driver_index) actual_result = db.execute(query, param_one, param_two) mock_execute.assert_called_once_with(query, param_one, param_two) mock_fetchall.assert_called_once() assert actual_result == [] - - -def test_dwh_singleton_can_be_created_only_once(monkeypatch): - driver_index = 0 - - # Mock the database connection - monkeypatch.setattr( - "pyprediktorutilities.dwh.pyodbc.connect", mock_pyodbc_connection - ) - - db = get_dwh_instance(grs(), grs(), grs(), grs(), driver_index) - db_instance_address = id(db) - - db_2 = get_dwh_instance(grs(), grs(), grs(), grs(), driver_index) - db_2_instance_address = id(db_2) - - assert db_instance_address == db_2_instance_address diff --git a/tests/test_dwh_singleton.py b/tests/test_dwh_singleton.py new file mode 100644 index 0000000..6f13b66 --- /dev/null +++ b/tests/test_dwh_singleton.py @@ -0,0 +1,19 @@ +from helpers import grs, mock_pyodbc_connection +from pyprediktorutilities.dwh import dwh_singleton + + +def test_dwh_singleton_can_be_created_only_once(monkeypatch): + driver_index = 0 + + # Mock the database connection + monkeypatch.setattr( + "pyprediktorutilities.dwh.dwh.pyodbc.connect", mock_pyodbc_connection + ) + + db = dwh_singleton.DwhSingleton(grs(), grs(), grs(), grs(), driver_index) + db_instance_address = id(db) + + db_2 = dwh_singleton.DwhSingleton(grs(), grs(), grs(), grs(), driver_index) + db_2_instance_address = id(db_2) + + assert db_instance_address == db_2_instance_address diff --git a/tests/test_shared.py b/tests/test_shared.py index d453342..5c9e53b 100644 --- a/tests/test_shared.py +++ b/tests/test_shared.py @@ -1,6 +1,5 @@ import unittest from unittest import mock -import requests import pytest from pydantic import ValidationError @@ -17,6 +16,7 @@ } ] + # This method will be used by the mock to replace requests def mocked_requests(*args, **kwargs): class MockResponse: