Skip to content

Commit

Permalink
feat: add custom Azure host/port to support custom blob endpoint
Browse files Browse the repository at this point in the history
 e.g. azurite
  • Loading branch information
jeqo committed Dec 20, 2023
1 parent cc2b427 commit 3f1bb1f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 11 deletions.
54 changes: 43 additions & 11 deletions rohmu/object_storage/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def __init__(
account_key: Optional[str] = None,
sas_token: Optional[str] = None,
prefix: Optional[str] = None,
is_secure: bool = True,
host: Optional[str] = None,
port: Optional[int] = None,
azure_cloud: Optional[str] = None,
proxy_info: Optional[dict[str, Union[str, int]]] = None,
notifier: Optional[Notifier] = None,
Expand All @@ -78,17 +81,12 @@ def __init__(
self.account_key = account_key
self.container_name = bucket_name
self.sas_token = sas_token
try:
endpoint_suffix = ENDPOINT_SUFFIXES[azure_cloud]
except KeyError:
raise InvalidConfigurationError(f"Unknown azure cloud {repr(azure_cloud)}")

conn_str = (
"DefaultEndpointsProtocol=https;"
f"AccountName={self.account_name};"
f"AccountKey={self.account_key};"
f"EndpointSuffix={endpoint_suffix}"
)
conn_str = self.conn_string(account_name=account_name,
account_key=account_key,
azure_cloud=azure_cloud,
host=host,
port=port,
is_secure=is_secure)
config: dict[str, Any] = {"max_block_size": MAX_BLOCK_SIZE}
if proxy_info:
username = proxy_info.get("user")
Expand All @@ -113,6 +111,40 @@ def __init__(
self.container = self.get_or_create_container(self.container_name)
self.log.debug("AzureTransfer initialized, %r", self.container_name)

@staticmethod
def conn_string(account_name: str,
account_key: str,
azure_cloud: Optional[str],
host: Optional[str],
port: Optional[int],
is_secure: bool):
if not host and not port:
try:
endpoint_suffix = ENDPOINT_SUFFIXES[azure_cloud]
except KeyError:
raise InvalidConfigurationError(f"Unknown azure cloud {repr(azure_cloud)}")

conn_str = (
"DefaultEndpointsProtocol=https;"
f"AccountName={account_name};"
f"AccountKey={account_key};"
f"EndpointSuffix={endpoint_suffix}"
)
else:
print(host)
print(port)
if not host or not port:
raise InvalidConfigurationError("Custom host and port must be specified together")

protocol = "https" if is_secure else "http"
conn_str = (
f"DefaultEndpointsProtocol={protocol};"
f"AccountName={account_name};"
f"AccountKey={account_key};"
f"BlobEndpoint={protocol}://{host}:{port}/{account_name};"
)
return conn_str

def copy_file(
self, *, source_key: str, destination_key: str, metadata: Optional[Metadata] = None, **kwargs: Any
) -> None:
Expand Down
3 changes: 3 additions & 0 deletions rohmu/object_storage/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class AzureObjectStorageConfig(StorageModel):
account_key: Optional[str] = Field(None, repr=False)
sas_token: Optional[str] = Field(None, repr=False)
prefix: Optional[str] = None
is_secure: bool = True
host: Optional[str] = None
port: Optional[int] = None
azure_cloud: Optional[str] = None
proxy_info: Optional[ProxyInfo] = None
storage_type: Literal[StorageDriver.azure] = StorageDriver.azure
Expand Down
14 changes: 14 additions & 0 deletions test/object_storage/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,17 @@ def test_get_contents_to_fileobj_raises_error_on_invalid_byte_range(azure_module
fileobj_to_store_to=BytesIO(),
byte_range=(100, 10),
)


def test_conn_string_custom_host_port() -> None:
from rohmu.object_storage.azure import AzureTransfer
conn_string = AzureTransfer.conn_string(
account_name="test_name",
account_key="test_key",
azure_cloud=None,
host="localhost",
port=10000,
is_secure=False)
assert ("DefaultEndpointsProtocol=http;AccountName=test_name;AccountKey=test_key;"
"BlobEndpoint=http://localhost:10000/test_name;") == conn_string

0 comments on commit 3f1bb1f

Please sign in to comment.