Skip to content

Commit

Permalink
Merge branch 'feat/mlops-2456' into stable-release
Browse files Browse the repository at this point in the history
  • Loading branch information
albjoaov committed Dec 4, 2024
2 parents e1ace97 + be44b15 commit 212fe59
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 3 deletions.
46 changes: 44 additions & 2 deletions butterfree/clients/cassandra_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""CassandraClient entity."""

from ssl import CERT_REQUIRED, PROTOCOL_TLSv1
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import (
Expand All @@ -16,6 +16,12 @@
from typing_extensions import TypedDict

from butterfree.clients import AbstractClient
from butterfree.configs.logger import __logger

logger = __logger("cassandra_client")

EMPTY_STRING_HOST_ERROR = "The value of Cassandra host is empty. Please fill correctly with your endpoints" # noqa: E501
GENERIC_INVALID_HOST_ERROR = "The Cassandra host must be a valid string, a string that represents a list or list of strings" # noqa: E501


class CassandraColumn(TypedDict):
Expand Down Expand Up @@ -53,12 +59,48 @@ def __init__(
user: Optional[str] = None,
password: Optional[str] = None,
) -> None:
self.host = host
self.host = self._validate_and_format_cassandra_host(host)
logger.info(f"The host setted is {self.host}")
self.keyspace = keyspace
self.user = user
self.password = password
self._session: Optional[Session] = None

def _validate_and_format_cassandra_host(self, host: Union[List, str]):
"""
Validate and format the provided Cassandra host input.
This method checks if the input `host` is either a string, a list of strings, or
a list containing a single string with comma-separated values. It splits the string
by commas and trims whitespace, returning a list of hosts. If the input is already
a list of strings, it returns that list. If the input is empty or invalid, a
ValueError is raised.
Args:
host (str | list): The Cassandra host input, which can be a comma-separated
string or a list of string endpoints.
Returns:
list: A list of formatted Cassandra host strings.
Raises:
ValueError: If the input is an empty list/string or if it is not a string
(or a representation of a list) or a list of strings.
""" # noqa: E501
if isinstance(host, str):
if host:
return [item.strip() for item in host.split(",")]
else:
raise ValueError(EMPTY_STRING_HOST_ERROR)

if isinstance(host, list):
if len(host) == 1 and isinstance(host[0], str):
return [item.strip() for item in host[0].split(",")]
elif all(isinstance(item, str) for item in host):
return host

raise ValueError(GENERIC_INVALID_HOST_ERROR)

@property
def conn(self, *, ssl_path: str = None) -> Session: # type: ignore
"""Establishes a Cassandra connection."""
Expand Down
52 changes: 51 additions & 1 deletion tests/unit/butterfree/clients/test_cassandra_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from typing import Any, Dict, List
from unittest.mock import MagicMock

import pytest

from butterfree.clients import CassandraClient
from butterfree.clients.cassandra_client import CassandraColumn
from butterfree.clients.cassandra_client import (
EMPTY_STRING_HOST_ERROR,
GENERIC_INVALID_HOST_ERROR,
CassandraColumn,
)


def sanitize_string(query: str) -> str:
Expand Down Expand Up @@ -86,3 +92,47 @@ def test_cassandra_create_table(
query = cassandra_client.sql.call_args[0][0]

assert sanitize_string(query) == sanitize_string(expected_query)

def test_initialize_with_string_host(self):
client = CassandraClient(host="127.0.0.0, 127.0.0.1", keyspace="dummy_keyspace")
assert client.host == ["127.0.0.0", "127.0.0.1"]

def test_initialize_with_list_host(self):
client = CassandraClient(
host=["127.0.0.0", "127.0.0.1"], keyspace="test_keyspace"
)
assert client.host == ["127.0.0.0", "127.0.0.1"]

def test_initialize_with_empty_string_host(self):
with pytest.raises(
ValueError,
match=EMPTY_STRING_HOST_ERROR,
):
CassandraClient(host="", keyspace="test_keyspace")

def test_initialize_with_none_host(self):
with pytest.raises(
ValueError,
match=GENERIC_INVALID_HOST_ERROR,
):
CassandraClient(host=None, keyspace="test_keyspace")

def test_initialize_with_invalid_host_type(self):
with pytest.raises(
ValueError,
match=GENERIC_INVALID_HOST_ERROR,
):
CassandraClient(host=123, keyspace="test_keyspace")

def test_initialize_with_invalid_list_host(self):
with pytest.raises(
ValueError,
match=GENERIC_INVALID_HOST_ERROR,
):
CassandraClient(host=["127.0.0.0", 123], keyspace="test_keyspace")

def test_initialize_with_list_of_string_hosts(self):
client = CassandraClient(
host=["127.0.0.0, 127.0.0.1"], keyspace="test_keyspace"
)
assert client.host == ["127.0.0.0", "127.0.0.1"]

0 comments on commit 212fe59

Please sign in to comment.