Skip to content

Commit

Permalink
Merge pull request #172 from Aiven-Open/kmichel-preloaded-keys
Browse files Browse the repository at this point in the history
Accept preloaded RSA keys in encryptor/rohmufile
  • Loading branch information
Khatskevich authored Feb 29, 2024
2 parents 66ef7db + 2805cc7 commit 6bef80f
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 24 deletions.
44 changes: 26 additions & 18 deletions rohmu/encryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,15 +428,19 @@ def write(self, data: BinaryData) -> int:


class Encryptor(BaseEncryptor):
def __init__(self, public_key_pem: Union[str, bytes]):
if not isinstance(public_key_pem, bytes):
public_key_pem = public_key_pem.encode("ascii")
public_key = serialization.load_pem_public_key(public_key_pem, backend=default_backend())
if not isinstance(public_key, RSAPublicKey):
raise ValueError("Key must be RSA")
def __init__(self, public_key_pem: Union[str, bytes, RSAPublicKey]):
if isinstance(public_key_pem, RSAPublicKey):
rsa_public_key = public_key_pem
else:
if not isinstance(public_key_pem, bytes):
public_key_pem = public_key_pem.encode("ascii")
public_key = serialization.load_pem_public_key(public_key_pem, backend=default_backend())
if not isinstance(public_key, RSAPublicKey):
raise ValueError("Key must be RSA")
rsa_public_key = public_key

super().__init__()
self.rsa_public_key = public_key
self.rsa_public_key = rsa_public_key

def init_cipher(self) -> bytes:
cipher_key = os.urandom(16)
Expand All @@ -450,27 +454,31 @@ def init_cipher(self) -> bytes:


class EncryptorFile(BaseEncryptorFile):
def __init__(self, next_fp: FileLike, public_key_pem: Union[str, bytes]) -> None:
def __init__(self, next_fp: FileLike, public_key_pem: Union[str, bytes, RSAPublicKey]) -> None:
super().__init__(next_fp, Encryptor(public_key_pem))


class EncryptorStream(BaseEncryptorStream):
"""Non-seekable stream of data that adds encryption on top of given source stream"""

def __init__(self, src_fp: HasRead, public_key_pem: Union[str, bytes]) -> None:
def __init__(self, src_fp: HasRead, public_key_pem: Union[str, bytes, RSAPublicKey]) -> None:
super().__init__(src_fp, Encryptor(public_key_pem))


class Decryptor(BaseDecryptor):
def __init__(self, private_key_pem: Union[str, bytes]) -> None:
if not isinstance(private_key_pem, bytes):
private_key_pem = private_key_pem.encode("ascii")
private_key = serialization.load_pem_private_key(data=private_key_pem, password=None, backend=default_backend())
if not isinstance(private_key, RSAPrivateKey):
raise ValueError("Key must be RSA")
def __init__(self, private_key_pem: Union[str, bytes, RSAPrivateKey]) -> None:
if isinstance(private_key_pem, RSAPrivateKey):
rsa_private_key = private_key_pem
else:
if not isinstance(private_key_pem, bytes):
private_key_pem = private_key_pem.encode("ascii")
private_key = serialization.load_pem_private_key(data=private_key_pem, password=None, backend=default_backend())
if not isinstance(private_key, RSAPrivateKey):
raise ValueError("Key must be RSA")
rsa_private_key = private_key

super().__init__()
self.rsa_private_key = private_key
self.rsa_private_key = rsa_private_key
self._key_size = None
self._header_size = None

Expand Down Expand Up @@ -513,12 +521,12 @@ def process_header(self, data: bytes) -> None:


class DecryptorFile(BaseDecryptorFile):
def __init__(self, next_fp: FileLike, private_key_pem: Union[bytes, str]):
def __init__(self, next_fp: FileLike, private_key_pem: Union[bytes, str, RSAPrivateKey]):
super().__init__(next_fp, lambda: Decryptor(private_key_pem))


class DecryptSink(BaseDecryptSink):
def __init__(self, next_sink: HasWrite, file_size: int, private_key_pem: Union[bytes, str]):
def __init__(self, next_sink: HasWrite, file_size: int, private_key_pem: Union[bytes, str, RSAPrivateKey]):
super().__init__(next_sink, file_size, Decryptor(private_key_pem))


Expand Down
13 changes: 7 additions & 6 deletions rohmu/rohmufile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .filewrap import ThrottleSink
from .typing import FileLike, HasRead, HasWrite, Metadata
from contextlib import suppress
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from inspect import signature
from rohmu.object_storage.base import IncrementalProgressCallbackType
from typing import Any, Callable, Optional, Union
Expand All @@ -27,8 +28,8 @@ def _obj_name(input_obj: Any) -> str:


def _get_encryption_key_data(
metadata: Optional[Metadata], key_lookup: Optional[Callable[[str], Optional[str]]]
) -> Optional[str]:
metadata: Optional[Metadata], key_lookup: Optional[Callable[[str], Optional[str | bytes | RSAPrivateKey]]]
) -> Optional[str | bytes | RSAPrivateKey]:
if not metadata or not metadata.get("encryption-key-id"):
return None

Expand All @@ -47,7 +48,7 @@ def file_reader(
*,
fileobj: FileLike,
metadata: Optional[Metadata] = None,
key_lookup: Optional[Callable[[str], Optional[str]]] = None,
key_lookup: Optional[Callable[[str], Optional[str | bytes | RSAPrivateKey]]] = None,
) -> FileLike:
if not metadata:
return fileobj
Expand All @@ -68,7 +69,7 @@ def create_sink_pipeline(
output: HasWrite,
file_size: int = 0,
metadata: Optional[Metadata] = None,
key_lookup: Optional[Callable[[str], Optional[str]]] = None,
key_lookup: Optional[Callable[[str], Optional[str | bytes | RSAPrivateKey]]] = None,
throttle_time: float = 0.001,
) -> HasWrite:
if throttle_time:
Expand Down Expand Up @@ -143,7 +144,7 @@ def file_writer(
compression_algorithm: Optional[str] = None,
compression_level: int = 0,
compression_threads: int = 0,
rsa_public_key: Union[None, str, bytes] = None,
rsa_public_key: Union[None, str, bytes, RSAPublicKey] = None,
) -> FileLike:
if rsa_public_key:
fileobj = EncryptorFile(fileobj, rsa_public_key)
Expand All @@ -162,7 +163,7 @@ def write_file(
compression_algorithm: Optional[str] = None,
compression_level: int = 0,
compression_threads: int = 0,
rsa_public_key: Union[None, str, bytes] = None,
rsa_public_key: Union[None, str, bytes, RSAPublicKey] = None,
log_func: Optional[Callable[..., None]] = None,
header_func: Optional[Callable[[bytes], None]] = None,
data_callback: Optional[Callable[[bytes], None]] = None,
Expand Down
41 changes: 41 additions & 0 deletions test/test_encryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

from __future__ import annotations

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
from pathlib import Path
from rohmu.common.constants import IO_BLOCK_SIZE
from rohmu.encryptor import (
Expand Down Expand Up @@ -66,6 +69,12 @@
-----END PRIVATE KEY-----"""
)

LOADED_RSA_PUBLIC_KEY = load_pem_public_key(RSA_PUBLIC_KEY.encode(), backend=default_backend())
assert isinstance(LOADED_RSA_PUBLIC_KEY, RSAPublicKey)

LOADED_RSA_PRIVATE_KEY = load_pem_private_key(RSA_PRIVATE_KEY.encode(), password=None, backend=default_backend())
assert isinstance(LOADED_RSA_PRIVATE_KEY, RSAPrivateKey)


@pytest.mark.parametrize(
("plaintext"),
Expand All @@ -81,6 +90,14 @@
lambda: Encryptor(RSA_PUBLIC_KEY),
lambda: Decryptor(RSA_PRIVATE_KEY),
),
(
lambda: Encryptor(RSA_PUBLIC_KEY.encode()),
lambda: Decryptor(RSA_PRIVATE_KEY.encode()),
),
(
lambda: Encryptor(LOADED_RSA_PUBLIC_KEY),
lambda: Decryptor(LOADED_RSA_PRIVATE_KEY),
),
(
lambda: SymmetricEncryptor(SYMMETRIC_KEY),
lambda: SymmetricDecryptor(SYMMETRIC_KEY),
Expand Down Expand Up @@ -115,6 +132,14 @@ def test_encryptor_decryptor(
lambda x: EncryptorStream(x, RSA_PUBLIC_KEY),
lambda x: DecryptorFile(x, RSA_PRIVATE_KEY),
),
(
lambda x: EncryptorStream(x, RSA_PUBLIC_KEY.encode()),
lambda x: DecryptorFile(x, RSA_PRIVATE_KEY),
),
(
lambda x: EncryptorStream(x, LOADED_RSA_PUBLIC_KEY),
lambda x: DecryptorFile(x, RSA_PRIVATE_KEY),
),
(
lambda x: SymmetricEncryptorStream(x, SYMMETRIC_KEY),
lambda x: SymmetricDecryptorFile(x, SYMMETRIC_KEY),
Expand Down Expand Up @@ -160,6 +185,14 @@ def test_encryptor_stream(
lambda: Encryptor(RSA_PUBLIC_KEY),
lambda x: DecryptorFile(x, RSA_PRIVATE_KEY),
),
(
lambda: Encryptor(RSA_PUBLIC_KEY),
lambda x: DecryptorFile(x, RSA_PRIVATE_KEY.encode()),
),
(
lambda: Encryptor(RSA_PUBLIC_KEY),
lambda x: DecryptorFile(x, LOADED_RSA_PRIVATE_KEY),
),
(
lambda: SymmetricEncryptor(SYMMETRIC_KEY),
lambda x: SymmetricDecryptorFile(x, SYMMETRIC_KEY),
Expand Down Expand Up @@ -305,6 +338,14 @@ def test_decryptorfile_for_tarfile(
lambda x: EncryptorFile(x, RSA_PUBLIC_KEY),
lambda x: DecryptorFile(x, RSA_PRIVATE_KEY),
),
(
lambda x: EncryptorFile(x, RSA_PUBLIC_KEY.encode()),
lambda x: DecryptorFile(x, RSA_PRIVATE_KEY),
),
(
lambda x: EncryptorFile(x, LOADED_RSA_PUBLIC_KEY),
lambda x: DecryptorFile(x, RSA_PRIVATE_KEY),
),
(
lambda x: SymmetricEncryptorFile(x, SYMMETRIC_KEY),
lambda x: SymmetricDecryptorFile(x, SYMMETRIC_KEY),
Expand Down

0 comments on commit 6bef80f

Please sign in to comment.