Skip to content

Commit

Permalink
feat: add custom Azure host/port to support custom blob endpoint
Browse files Browse the repository at this point in the history
 e.g. azurite
  • Loading branch information
jeqo committed Jan 3, 2024
1 parent cc2b427 commit 579d154
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 23 deletions.
57 changes: 36 additions & 21 deletions rohmu/object_storage/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
SourceStorageModelT,
)
from rohmu.object_storage.config import ( # pylint: disable=unused-import
AZURE_ENDPOINT_SUFFIXES,
AZURE_MAX_BLOCK_SIZE as MAX_BLOCK_SIZE,
AzureObjectStorageConfig as Config,
calculate_azure_max_block_size as calculate_max_block_size,
Expand All @@ -42,14 +43,6 @@
from azure.storage.blob._models import BlobPrefix, BlobType # type: ignore


ENDPOINT_SUFFIXES = {
None: "core.windows.net",
"germany": "core.cloudapi.de", # Azure Germany is a completely separate cloud from the regular Azure Public cloud
"china": "core.chinacloudapi.cn",
"public": "core.windows.net",
}


# Reduce Azure logging verbocity of http requests and responses
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)

Expand All @@ -64,6 +57,9 @@ def __init__(
account_key: Optional[str] = None,
sas_token: Optional[str] = None,
prefix: Optional[str] = None,
is_secure: bool = True,
host: Optional[str] = None,
port: Optional[int] = None,
azure_cloud: Optional[str] = None,
proxy_info: Optional[dict[str, Union[str, int]]] = None,
notifier: Optional[Notifier] = None,
Expand All @@ -78,16 +74,13 @@ def __init__(
self.account_key = account_key
self.container_name = bucket_name
self.sas_token = sas_token
try:
endpoint_suffix = ENDPOINT_SUFFIXES[azure_cloud]
except KeyError:
raise InvalidConfigurationError(f"Unknown azure cloud {repr(azure_cloud)}")

conn_str = (
"DefaultEndpointsProtocol=https;"
f"AccountName={self.account_name};"
f"AccountKey={self.account_key};"
f"EndpointSuffix={endpoint_suffix}"
conn_str = self.conn_string(
account_name=account_name,
account_key=account_key,
azure_cloud=azure_cloud,
host=host,
port=port,
is_secure=is_secure,
)
config: dict[str, Any] = {"max_block_size": MAX_BLOCK_SIZE}
if proxy_info:
Expand All @@ -97,13 +90,13 @@ def __init__(
auth = f"{username}:{password}@"
else:
auth = ""
host = proxy_info["host"]
port = proxy_info["port"]
proxy_host = proxy_info["host"]
proxy_port = proxy_info["port"]
if proxy_info.get("type") == "socks5":
schema = "socks5"
else:
schema = "http"
config["proxies"] = {"https": f"{schema}://{auth}{host}:{port}"}
config["proxies"] = {"https": f"{schema}://{auth}{proxy_host}:{proxy_port}"}

self.conn: BlobServiceClient = BlobServiceClient.from_connection_string(
conn_str=conn_str,
Expand All @@ -113,6 +106,28 @@ def __init__(
self.container = self.get_or_create_container(self.container_name)
self.log.debug("AzureTransfer initialized, %r", self.container_name)

@staticmethod
def conn_string(
account_name: str,
account_key: Optional[str],
azure_cloud: Optional[str],
host: Optional[str],
port: Optional[int],
is_secure: bool,
) -> str:
protocol = "https" if is_secure else "http"
conn = [
f"DefaultEndpointsProtocol={protocol}",
f"AccountName={account_name}",
f"AccountKey={account_key}",
]
if not host and not port:
endpoint_suffix = AZURE_ENDPOINT_SUFFIXES[azure_cloud]
conn.append(f"EndpointSuffix={endpoint_suffix}")
else:
conn.append(f"BlobEndpoint={protocol}://{host}:{port}/{account_name}")
return ";".join(conn)

def copy_file(
self, *, source_key: str, destination_key: str, metadata: Optional[Metadata] = None, **kwargs: Any
) -> None:
Expand Down
25 changes: 24 additions & 1 deletion rohmu/object_storage/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from enum import Enum, unique
from pathlib import Path
from pydantic import Field, root_validator
from pydantic import Field, root_validator, validator
from rohmu.common.models import ProxyInfo, StorageDriver, StorageModel
from typing import Any, Dict, Final, Literal, Optional, TypeVar

Expand Down Expand Up @@ -42,6 +42,12 @@ def calculate_azure_max_block_size() -> int:
return max(min(int(total_mem_mib / 1000), 100), 4) * 1024 * 1024


AZURE_ENDPOINT_SUFFIXES = {
None: "core.windows.net",
"germany": "core.cloudapi.de", # Azure Germany is a completely separate cloud from the regular Azure Public cloud
"china": "core.chinacloudapi.cn",
"public": "core.windows.net",
}
# Increase block size based on host memory. Azure supports up to 50k blocks and up to 5 TiB individual
# files. Default block size is set to 4 MiB so only ~200 GB files can be uploaded. In order to get close
# to that 5 TiB increase the block size based on host memory; we don't want to use the max 100 for all
Expand Down Expand Up @@ -83,10 +89,27 @@ class AzureObjectStorageConfig(StorageModel):
account_key: Optional[str] = Field(None, repr=False)
sas_token: Optional[str] = Field(None, repr=False)
prefix: Optional[str] = None
is_secure: bool = True
host: Optional[str] = None
port: Optional[int] = None
azure_cloud: Optional[str] = None
proxy_info: Optional[ProxyInfo] = None
storage_type: Literal[StorageDriver.azure] = StorageDriver.azure

@root_validator
@classmethod
def host_and_port_must_be_set_together(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if (values["host"] is None) != (values["port"] is None):
raise ValueError("host and port must be set together")
return values

@validator("azure_cloud")
@classmethod
def valid_azure_cloud_endpoint(cls, v: str) -> str:
if v not in AZURE_ENDPOINT_SUFFIXES:
raise ValueError(f"azure_cloud must be one of {AZURE_ENDPOINT_SUFFIXES.keys()}")
return v


class GoogleObjectStorageConfig(StorageModel):
project_id: str
Expand Down
100 changes: 99 additions & 1 deletion test/object_storage/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from datetime import datetime
from io import BytesIO
from rohmu.errors import InvalidByteRangeError
from rohmu.object_storage.config import AzureObjectStorageConfig
from tempfile import NamedTemporaryFile
from types import ModuleType
from typing import Any, Tuple
from typing import Any, Optional, Tuple
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -103,3 +104,100 @@ def test_get_contents_to_fileobj_raises_error_on_invalid_byte_range(azure_module
fileobj_to_store_to=BytesIO(),
byte_range=(100, 10),
)


def test_minimal_config() -> None:
config = AzureObjectStorageConfig(account_name="test")
assert config.account_name == "test"


def test_azure_config_host_port_set_together() -> None:
with pytest.raises(ValueError):
AzureObjectStorageConfig(account_name="test", host="localhost")
with pytest.raises(ValueError):
AzureObjectStorageConfig(account_name="test", port=10000)
config = AzureObjectStorageConfig(account_name="test", host="localhost", port=10000)
assert config.host == "localhost"
assert config.port == 10000


def test_valid_azure_cloud_endpoint() -> None:
with pytest.raises(ValueError):
AzureObjectStorageConfig(account_name="test", azure_cloud="invalid")
config = AzureObjectStorageConfig(account_name="test", azure_cloud="public")
assert config.azure_cloud == "public"


@pytest.mark.parametrize(
"host,port,is_secured,expected",
[
(
None,
None,
True,
";".join(
[
"DefaultEndpointsProtocol=https",
"AccountName=test_name",
"AccountKey=test_key",
"EndpointSuffix=core.windows.net",
]
),
),
(
None,
None,
False,
";".join(
[
"DefaultEndpointsProtocol=http",
"AccountName=test_name",
"AccountKey=test_key",
"EndpointSuffix=core.windows.net",
]
),
),
(
"localhost",
10000,
True,
";".join(
[
"DefaultEndpointsProtocol=https",
"AccountName=test_name",
"AccountKey=test_key",
"BlobEndpoint=https://localhost:10000/test_name",
]
),
),
(
"localhost",
10000,
False,
";".join(
[
"DefaultEndpointsProtocol=http",
"AccountName=test_name",
"AccountKey=test_key",
"BlobEndpoint=http://localhost:10000/test_name",
]
),
),
],
)
def test_conn_string(host: Optional[str], port: Optional[int], is_secured: bool, expected: str) -> None:
get_blob_client_mock = MagicMock()
blob_client = MagicMock(get_blob_client=get_blob_client_mock)
service_client = MagicMock(from_connection_string=MagicMock(return_value=blob_client))
module_patches = {
"azure.common": MagicMock(),
"azure.core.exceptions": MagicMock(),
"azure.storage.blob": MagicMock(BlobServiceClient=service_client),
}
with patch.dict(sys.modules, module_patches):
from rohmu.object_storage.azure import AzureTransfer

conn_string = AzureTransfer.conn_string(
account_name="test_name", account_key="test_key", azure_cloud=None, host=host, port=port, is_secure=is_secured
)
assert expected == conn_string

0 comments on commit 579d154

Please sign in to comment.