-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Add sqlserver database adapter #240
Changes from 3 commits
06d3410
b30319e
01a539e
3761240
e94cb94
cf51430
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,77 @@ | ||
from abc import ABCMeta, abstractmethod | ||
import logging | ||
|
||
import pymongo | ||
import pyodbc | ||
import threading | ||
from bson.objectid import ObjectId | ||
from pymongo.errors import ConnectionFailure, PyMongoError | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class InsertionResponse:
    """Uniform result object returned by database write operations.

    Attributes:
        ok: True when the operation succeeded.
        need_upsert: True when the caller should retry the write as an upsert.
        error: None on success (or when no exception was given); otherwise a
            string combining the exception class name and its message, so logs
            show *why* the write failed, not just the exception type.
    """

    def __init__(self, ok, exception=None, need_upsert=False):
        self.ok = ok
        self.need_upsert = need_upsert
        self.error = (
            None
            if (ok or exception is None)
            else f"{exception.__class__.__name__}: {exception}"
        )
|
||
class DatabaseReaderInterface(metaclass=ABCMeta): | ||
"""Database adapter that just support read operations.""" | ||
|
||
class DatabaseInterface(metaclass=ABCMeta): | ||
@abstractmethod | ||
def get_connection(self): | ||
pass | ||
|
||
@abstractmethod | ||
def delete_collection_data(self): | ||
def get_all_dataset_data(self): | ||
pass | ||
|
||
@abstractmethod | ||
def get_all_collection_data(self): | ||
def get_chunked_dataset_data(self): | ||
pass | ||
|
||
@abstractmethod | ||
def get_chunked_collection_data(self): | ||
def get_paginated_dataset_data(self): | ||
pass | ||
|
||
@abstractmethod | ||
def get_paginated_collection_data(self): | ||
def get_estimated_item_count(self): | ||
pass | ||
|
||
@abstractmethod | ||
def get_estimated_document_count(self): | ||
def get_estimated_item_size(self): | ||
pass | ||
|
||
@abstractmethod | ||
def get_estimated_document_size(self): | ||
def get_database_size(self): | ||
pass | ||
|
||
class DatabaseWriterInterface(metaclass=ABCMeta): | ||
"""Database adapter that just support write operations.""" | ||
|
||
@abstractmethod | ||
def insert_one_to_unique_collection(self): | ||
def get_connection(self): | ||
pass | ||
|
||
@abstractmethod | ||
def insert_one_to_collection(self): | ||
def delete_dataset_data(self): | ||
pass | ||
|
||
@abstractmethod | ||
def insert_many_to_collection(self): | ||
def insert_one_to_unique_dataset(self): | ||
pass | ||
|
||
@abstractmethod | ||
def get_database_size(self): | ||
def insert_one_to_dataset(self): | ||
pass | ||
|
||
@abstractmethod | ||
def get_collection_size(self): | ||
def insert_many_to_dataset(self): | ||
pass | ||
|
||
|
||
|
||
class MongoAdapter(DatabaseInterface): | ||
class MongoAdapter(DatabaseWriterInterface, DatabaseReaderInterface): | ||
def __init__(self, mongo_connection, mongo_production, mongo_certificate_path): | ||
self.mongo_connection = mongo_connection | ||
self.mongo_production = mongo_production | ||
|
@@ -85,7 +94,7 @@ def get_connection(self): | |
self.client = client | ||
return True | ||
|
||
def delete_collection_data(self, database_name, collection_name): | ||
def delete_dataset_data(self, database_name, collection_name): | ||
collection = self.client[database_name][collection_name] | ||
try: | ||
collection.drop() | ||
|
@@ -94,17 +103,17 @@ def delete_collection_data(self, database_name, collection_name): | |
print(ex) | ||
return False | ||
|
||
def get_collection_data(self, database_name, collection_name, limit=10000): | ||
def get_dataset_data(self, database_name, collection_name, limit=10000): | ||
collection = self.client[database_name][collection_name] | ||
result = collection.find({}, {"_id": False}).limit(limit) | ||
return list(result) | ||
|
||
def get_all_collection_data(self, database_name, collection_name): | ||
def get_all_dataset_data(self, database_name, collection_name): | ||
collection = self.client[database_name][collection_name] | ||
result = collection.find({}, {"_id": False}) | ||
return list(result) | ||
|
||
def get_chunked_collection_data( | ||
def get_chunked_dataset_data( | ||
self, database_name, collection_name, chunk_size, current_chunk=None | ||
): | ||
collection = self.client[database_name][collection_name] | ||
|
@@ -131,7 +140,7 @@ def get_jobs_set_stats(self, database_name, jobs_ids): | |
) | ||
return list(result) | ||
|
||
def get_paginated_collection_data( | ||
def get_paginated_dataset_data( | ||
self, database_name, collection_name, page, page_size | ||
): | ||
collection = self.client[database_name][collection_name] | ||
|
@@ -147,16 +156,16 @@ def update_document(self, database_name, collection_name, document_id, new_field | |
result = collection.update_one({"_id": document_id}, {"$set": new_field}) | ||
return result.acknowledged | ||
|
||
def get_estimated_document_count(self, database_name, collection_name): | ||
def get_estimated_item_count(self, database_name, collection_name): | ||
collection = self.client[database_name][collection_name] | ||
return collection.estimated_document_count() | ||
|
||
def get_estimated_document_size(self, database_name, collection_name): | ||
def get_estimated_item_size(self, database_name, collection_name): | ||
database = self.client[database_name] | ||
document_size = database.command("collstats", collection_name)["avgObjSize"] | ||
return document_size | ||
|
||
def insert_one_to_unique_collection(self, database_name, collection_name, item): | ||
def insert_one_to_unique_dataset(self, database_name, collection_name, item): | ||
response = None | ||
try: | ||
self.client[database_name][collection_name].update_one( | ||
|
@@ -168,7 +177,7 @@ def insert_one_to_unique_collection(self, database_name, collection_name, item): | |
finally: | ||
return response | ||
|
||
def insert_one_to_collection(self, database_name, collection_name, item): | ||
def insert_one_to_dataset(self, database_name, collection_name, item): | ||
response = None | ||
try: | ||
self.client[database_name][collection_name].insert_one(item) | ||
|
@@ -178,7 +187,7 @@ def insert_one_to_collection(self, database_name, collection_name, item): | |
finally: | ||
return response | ||
|
||
def insert_many_to_collection( | ||
def insert_many_to_dataset( | ||
self, database_name, collection_name, items, ordered=False | ||
): | ||
response = None | ||
|
@@ -198,17 +207,86 @@ def get_database_size(self, database_name, data_type): | |
total_size_bytes = 0 | ||
for collection in collections: | ||
if data_type in collection: | ||
total_size_bytes += self.get_collection_size(database_name, collection) | ||
total_size_bytes += self.get_dataset_size(database_name, collection) | ||
return total_size_bytes | ||
|
||
def get_collection_size(self, database_name, collection_name): | ||
def get_dataset_size(self, database_name, collection_name): | ||
database = self.client[database_name] | ||
collection_size = database.command("collstats", collection_name)["size"] | ||
return collection_size | ||
|
||
class SqlServerWriterAdapter(DatabaseWriterInterface):
    """Write-only database adapter for SQL Server, backed by pyodbc.

    A separate connection is kept per thread (via ``threading.local``)
    because pyodbc connections are not safe to share across threads.
    """

    def __init__(self, connection_string, production=False, certificate_path=None):
        """Store connection settings; no connection is opened yet.

        `production` and `certificate_path` default to benign values: they
        are accepted only for signature parity with MongoAdapter — TLS and
        environment options are expected to be encoded in the ODBC
        connection string itself.
        """
        self.connection_string = connection_string
        self.production = production
        self.certificate_path = certificate_path
        self.local_storage = threading.local()

    def get_connection(self):
        """Ensure the current thread has an open connection.

        Returns True on success (including an already-open connection),
        False when connecting fails.
        """
        if hasattr(self.local_storage, "connection"):
            return True
        try:
            self.local_storage.connection = pyodbc.connect(self.connection_string)
            return True
        except Exception:
            logger.exception("Error connecting to SQL Server.")
            return False

    def _execute_query(self, database_name, query, values=(), execute_many=False):
        """Run `query` against `database_name` on this thread's connection.

        Returns a `(ok, error)` tuple where `error` is None on success,
        the pyodbc exception on query failure, or the string
        "Connection Error" when no connection could be established.
        """
        if not self.get_connection():
            return False, "Connection Error"

        connection = self.local_storage.connection
        try:
            with connection.cursor() as cursor:
                logger.debug("Executing query: %s", query)
                # NOTE(review): `database_name` (and the table/column names
                # built by callers) are interpolated directly into SQL —
                # identifiers cannot be parameterized, so these values must
                # come from trusted configuration, never from user input.
                cursor.execute(f"USE {database_name}")
                if execute_many:
                    cursor.executemany(query, values)
                else:
                    cursor.execute(query, values)
            connection.commit()
            return True, None
        except pyodbc.Error as query_error:
            logger.debug("Error executing query: %s", query)
            try:
                connection.rollback()
            except pyodbc.Error:
                # The connection itself may be gone; report the original error.
                logger.exception("Rollback failed after query error.")
            return False, query_error

    def insert_one_to_dataset(self, database_name, table_name, item):
        """Insert the dict `item` as a single row of `table_name`."""
        columns = ", ".join(item.keys())
        placeholders = ", ".join("?" * len(item))
        query = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
        ok, error = self._execute_query(database_name, query, values=list(item.values()))
        return InsertionResponse(ok, error)

    def insert_many_to_dataset(self, database_name, table_name, items):
        """Insert each dict in `items` as a row of `table_name`.

        All rows are assumed to share the keys of the first item. An empty
        `items` list is treated as a successful no-op instead of raising
        IndexError.
        """
        if not items:
            return InsertionResponse(True)
        columns = ", ".join(items[0].keys())
        placeholders = ", ".join("?" * len(items[0]))
        query = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
        values_to_insert = [tuple(item.values()) for item in items]
        logger.debug("values to insert: %s", values_to_insert)
        ok, error = self._execute_query(
            database_name, query, values_to_insert, execute_many=True
        )
        # executemany runs in a single transaction, so no per-row upsert retry.
        return InsertionResponse(ok, error)

    def delete_dataset_data(self, database_name, table_name):
        """Delete every row of `table_name` (the table itself is kept)."""
        ok, error = self._execute_query(database_name, f"DELETE FROM {table_name}")
        return InsertionResponse(ok, error)

    def insert_one_to_unique_dataset(self, database_name, table_name, item):
        # TODO(review): a "unique" insert should be an upsert/MERGE; for now
        # this simply delegates to a plain insert — needs more discussion.
        return self.insert_one_to_dataset(database_name, table_name, item)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @joaquingx In what sense does this require more discussion? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It appears that this function is not being utilized as intended. We will need to discuss the appropriate use of this function in future implementations. |
||
|
||
def get_database_interface(engine, connection, production, certificate_path):
    """Return a read/write database adapter for `engine`.

    Args:
        engine: adapter key; currently only "mongodb" is supported.
        connection: engine-specific connection settings.
        production: whether this is a production environment.
        certificate_path: path to a TLS certificate, if the engine needs one.

    Raises:
        KeyError: if `engine` is not a supported adapter key.
    """
    database_interfaces = {
        "mongodb": MongoAdapter,
    }
    # Look up the class first and instantiate lazily, so only the requested
    # adapter is constructed (the original built every adapter per call).
    return database_interfaces[engine](connection, production, certificate_path)
|
||
def get_database_writer_interface(engine, connection, production, certificate_path):
    """Return a write-capable database adapter for `engine`.

    Args:
        engine: adapter key; "mongodb" or "sqlserver".
        connection: engine-specific connection settings (e.g. an ODBC string).
        production: whether this is a production environment.
        certificate_path: path to a TLS certificate, if the engine needs one.

    Raises:
        KeyError: if `engine` is not a supported adapter key.
    """
    database_interfaces = {
        "mongodb": MongoAdapter,
        "sqlserver": SqlServerWriterAdapter,
    }
    # Instantiate only the requested adapter; the original constructed BOTH
    # adapters on every call just to return one of them.
    return database_interfaces[engine](connection, production, certificate_path)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import pytest | ||
from unittest.mock import Mock, patch | ||
import pyodbc | ||
from database_adapters.db_adapters import SqlServerWriterAdapter, InsertionResponse | ||
|
||
class TestSqlServerWriterAdapter:
    """Unit tests for SqlServerWriterAdapter with pyodbc fully mocked."""

    @pytest.fixture
    def mock_pyodbc_connect(self):
        """Patch pyodbc.connect so no real database is touched."""
        with patch('database_adapters.db_adapters.pyodbc.connect') as mock_connect:
            cursor = mock_connect.return_value.cursor.return_value.__enter__.return_value
            cursor.execute = Mock()
            cursor.executemany = Mock()
            yield mock_connect

    @pytest.fixture
    def writer_adapter(self, mock_pyodbc_connect):
        # The adapter signature takes (connection_string, production,
        # certificate_path); passing all three avoids a TypeError when the
        # parameters are required positionally.
        return SqlServerWriterAdapter("dummy_connection_string", False, None)

    def test_get_connection_success(self, writer_adapter):
        assert writer_adapter.get_connection() is True

    def test_get_connection_failure(self, mock_pyodbc_connect, writer_adapter):
        mock_pyodbc_connect.side_effect = Exception("Connection Error")
        assert writer_adapter.get_connection() is False

    def test_execute_query_success(self, writer_adapter):
        assert writer_adapter._execute_query("test_db", "SELECT * FROM test_table") == (True, None)

    def test_execute_query_failure(self, mock_pyodbc_connect, writer_adapter):
        assert writer_adapter.get_connection() is True
        mock_pyodbc_connect.return_value.cursor.side_effect = pyodbc.DatabaseError("Error")

        # Run the query and capture the outcome.
        result, error = writer_adapter._execute_query("test_db", "SELECT * FROM test_table")
        assert result is False
        assert isinstance(error, pyodbc.DatabaseError)
        assert str(error) == "Error"

    def test_execute_query_with_execute_many(self, writer_adapter):
        result = writer_adapter._execute_query(
            "test_db",
            "INSERT INTO test_table VALUES (?)",
            values=[(1,), (2,)],
            execute_many=True,
        )
        assert result == (True, None)

    def test_insert_one_to_dataset(self, writer_adapter):
        item = {"col1": "val1", "col2": "val2"}
        response = writer_adapter.insert_one_to_dataset("test_db", "test_table", item)
        assert isinstance(response, InsertionResponse)
        assert response.ok is True

    def test_insert_many_to_dataset(self, writer_adapter):
        items = [{"col1": "val1", "col2": "val2"}, {"col1": "val3", "col2": "val4"}]
        response = writer_adapter.insert_many_to_dataset("test_db", "test_table", items)
        assert isinstance(response, InsertionResponse)
        assert response.ok is True
        assert not response.error
        assert not response.need_upsert

    def test_delete_dataset_data(self, writer_adapter):
        response = writer_adapter.delete_dataset_data("test_db", "test_table")
        assert response.ok is True
        assert not response.error
        assert not response.need_upsert

    def test_insert_one_to_unique_dataset(self, writer_adapter):
        item = {"col1": "val1", "col2": "val2"}
        response = writer_adapter.insert_one_to_unique_dataset("test_db", "test_table", item)
        assert isinstance(response, InsertionResponse)
        assert response.ok is True
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `production` and `certificate_path` variables are not being used here, although they may be required for some production environments. Or are these parameters also included in the connection string if needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have been reviewing the Azure SQL documentation and it appears that using a certificate path is not a requirement, even for production environments.