Skip to content

Commit

Permalink
feat: add custom Azure host/port to support custom blob endpoint
Browse files Browse the repository at this point in the history
 e.g. azurite
  • Loading branch information
jeqo committed Dec 20, 2023
1 parent cc2b427 commit 9a54a02
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 13 deletions.
62 changes: 49 additions & 13 deletions rohmu/object_storage/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def __init__(
account_key: Optional[str] = None,
sas_token: Optional[str] = None,
prefix: Optional[str] = None,
is_secure: bool = True,
host: Optional[str] = None,
port: Optional[int] = None,
azure_cloud: Optional[str] = None,
proxy_info: Optional[dict[str, Union[str, int]]] = None,
notifier: Optional[Notifier] = None,
Expand All @@ -78,16 +81,13 @@ def __init__(
self.account_key = account_key
self.container_name = bucket_name
self.sas_token = sas_token
try:
endpoint_suffix = ENDPOINT_SUFFIXES[azure_cloud]
except KeyError:
raise InvalidConfigurationError(f"Unknown azure cloud {repr(azure_cloud)}")

conn_str = (
"DefaultEndpointsProtocol=https;"
f"AccountName={self.account_name};"
f"AccountKey={self.account_key};"
f"EndpointSuffix={endpoint_suffix}"
conn_str = self.conn_string(
account_name=account_name,
account_key=account_key,
azure_cloud=azure_cloud,
host=host,
port=port,
is_secure=is_secure,
)
config: dict[str, Any] = {"max_block_size": MAX_BLOCK_SIZE}
if proxy_info:
Expand All @@ -97,13 +97,13 @@ def __init__(
auth = f"{username}:{password}@"
else:
auth = ""
host = proxy_info["host"]
port = proxy_info["port"]
proxy_host = proxy_info["host"]
proxy_port = proxy_info["port"]
if proxy_info.get("type") == "socks5":
schema = "socks5"
else:
schema = "http"
config["proxies"] = {"https": f"{schema}://{auth}{host}:{port}"}
config["proxies"] = {"https": f"{schema}://{auth}{proxy_host}:{proxy_port}"}

self.conn: BlobServiceClient = BlobServiceClient.from_connection_string(
conn_str=conn_str,
Expand All @@ -113,6 +113,42 @@ def __init__(
self.container = self.get_or_create_container(self.container_name)
self.log.debug("AzureTransfer initialized, %r", self.container_name)

@staticmethod
def conn_string(
account_name: str,
account_key: Optional[str],
azure_cloud: Optional[str],
host: Optional[str],
port: Optional[int],
is_secure: bool,
) -> str:
if not host and not port:
try:
endpoint_suffix = ENDPOINT_SUFFIXES[azure_cloud]
except KeyError:
raise InvalidConfigurationError(f"Unknown azure cloud {repr(azure_cloud)}")

conn_str = (
"DefaultEndpointsProtocol=https;"
f"AccountName={account_name};"
f"AccountKey={account_key};"
f"EndpointSuffix={endpoint_suffix}"
)
else:
print(host)
print(port)
if not host or not port:
raise InvalidConfigurationError("Custom host and port must be specified together")

protocol = "https" if is_secure else "http"
conn_str = (
f"DefaultEndpointsProtocol={protocol};"
f"AccountName={account_name};"
f"AccountKey={account_key};"
f"BlobEndpoint={protocol}://{host}:{port}/{account_name};"
)
return conn_str

def copy_file(
self, *, source_key: str, destination_key: str, metadata: Optional[Metadata] = None, **kwargs: Any
) -> None:
Expand Down
3 changes: 3 additions & 0 deletions rohmu/object_storage/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class AzureObjectStorageConfig(StorageModel):
account_key: Optional[str] = Field(None, repr=False)
sas_token: Optional[str] = Field(None, repr=False)
prefix: Optional[str] = None
is_secure: bool = True
host: Optional[str] = None
port: Optional[int] = None
azure_cloud: Optional[str] = None
proxy_info: Optional[ProxyInfo] = None
storage_type: Literal[StorageDriver.azure] = StorageDriver.azure
Expand Down
12 changes: 12 additions & 0 deletions test/object_storage/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,15 @@ def test_get_contents_to_fileobj_raises_error_on_invalid_byte_range(azure_module
fileobj_to_store_to=BytesIO(),
byte_range=(100, 10),
)


def test_conn_string_custom_host_port() -> None:
from rohmu.object_storage.azure import AzureTransfer

conn_string = AzureTransfer.conn_string(
account_name="test_name", account_key="test_key", azure_cloud=None, host="localhost", port=10000, is_secure=False
)
assert (
"DefaultEndpointsProtocol=http;AccountName=test_name;AccountKey=test_key;"
"BlobEndpoint=http://localhost:10000/test_name;"
) == conn_string

0 comments on commit 9a54a02

Please sign in to comment.