Skip to content

Commit

Permalink
[DPE-5202] Test TLS through Manual TLS Certificates (#409)
Browse files Browse the repository at this point in the history
Add a test for an OpenSearch deployment with TLS certificates issued by
the `manual-tls-certificates` operator.
  • Loading branch information
skourta authored Aug 28, 2024
1 parent d411b11 commit 12ecb8b
Show file tree
Hide file tree
Showing 2 changed files with 291 additions and 0 deletions.
191 changes: 191 additions & 0 deletions tests/integration/tls/helpers_manual_tls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
#!/usr/bin/env python3
# Copyright 2024 Canonical Ltd.
# See LICENSE file for licensing details.
from __future__ import annotations

import asyncio
import base64
import json
import logging
from typing import TYPE_CHECKING, NamedTuple

from charms.tls_certificates_interface.v3.tls_certificates import (
generate_ca,
generate_certificate,
generate_private_key,
)
from juju.model import Model
from juju.unit import Unit
from tenacity import retry, stop_after_attempt, wait_exponential

if TYPE_CHECKING:
from juju.action import Action

logger = logging.getLogger(__name__)

MANUAL_TLS_CERTIFICATES_APP_NAME = "manual-tls-certificates"


class GetOutstandingCertificateRequestsError(Exception):
"""Exception raised when getting outstanding certificate requests fails."""


class CSRsMissingError(Exception):
"""Exception raised when the number of CSRs in the queue is less than the expected number."""


class ProvidingCertificateFailedError(Exception):
"""Exception raised when providing a certificate fails."""


class CSR(NamedTuple):
"""CSR represents the information about a certificate signing request."""

relation_id: str
application_name: str
unit_name: str
csr: bytes
is_ca: bool

@classmethod
def from_dict(cls, csr: dict[str, str]) -> CSR:
"""Create a CSR object from a dictionary.
Arguments:
---------
csr : dict
The dictionary containing the information about
the certificate signing request gotten from the charm.
Returns:
-------
CSR: The CSR object.
"""
return cls(
relation_id=csr["relation_id"],
application_name=csr["application_name"],
unit_name=csr["unit_name"],
csr=csr["csr"].encode(),
is_ca=csr["is_ca"],
)


class ManualTLSAgent:
"""An agent that processes certificate signing requests from the TLS operator."""

def __init__(self, tls_unit: Unit) -> None:
"""Initialise the agent."""
self.tls_unit = tls_unit
self.ca_key = generate_private_key()
self.ca = generate_ca(self.ca_key, "CN_CA")
self.csr_queue: list[CSR] = []

async def get_outstanding_certificate_requests(self) -> None:
"""Get the outstanding certificate requests from the TLS operator.
Raises
------
GetOutstandingCertificateRequestsError:
If getting the outstanding certificate requests fails.
"""
logging.info("Getting outstanding certificate requests")
action = await self.tls_unit.run_action("get-outstanding-certificate-requests")
action: Action = await action.wait()
if action.status != "completed":
message = action.safe_data.get(
"message",
"Failed to get outstanding certificate requests",
)
raise GetOutstandingCertificateRequestsError(message)
csrs = json.loads(action.results["result"])
self.csr_queue = [CSR.from_dict(csr) for csr in csrs]

@retry(
wait=wait_exponential(multiplier=1, min=5, max=20),
stop=stop_after_attempt(100),
)
async def wait_for_csrs_in_queue(self, csrs_count: int = 1) -> None:
"""Wait for the number of csrs in the queue to be equal to the number of csrs specified.
Arguments:
---------
csrs_count : int
The number of csrs to wait for in the queue.
Raises:
------
CSRsMissingError: If the number of csrs in the queue is less than the expected number.
"""
await self.get_outstanding_certificate_requests()
if len(self.csr_queue) < csrs_count:
message = f"{csrs_count - len(self.csr_queue)} CSRs missing in queue"
raise CSRsMissingError(message)

async def process_csr(self, csr: CSR) -> None:
"""Process the certificate signing request.
Arguments:
---------
csr : CSR
The certificate signing request to process.
Raises:
------
ProvidingCertificateFailedError: If providing the certificate fails.
"""
# Generate a certificate
certificate = generate_certificate(
csr=csr.csr,
ca=self.ca,
ca_key=self.ca_key,
is_ca=csr.is_ca,
)
logger.info("Generated certificate for %s", csr.unit_name)
# Send the certificate back to the charm
action = await self.tls_unit.run_action(
"provide-certificate",
relation_id=csr.relation_id,
**{
"certificate": base64.b64encode(certificate).decode(),
"ca-certificate": base64.b64encode(self.ca).decode(),
"certificate-signing-request": base64.b64encode(
csr.csr,
).decode(),
},
)
action = await action.wait()
if action.status != "completed":
message = f"Failed to provide certificate for {csr.unit_name}"
logging.error(message)
raise ProvidingCertificateFailedError(message)
logger.info("Provided certificate to %s", csr.unit_name)

async def process_queue(self) -> None:
"""Process the certificate signing requests in the queue."""
while self.csr_queue:
csr = self.csr_queue.pop()
await self.process_csr(csr)


async def main() -> None:
"""Run the ManualTLSAgent."""
logging.info("Starting ManualTLSAgent")
model = Model()
await model.connect()

tls_unit = model.applications[MANUAL_TLS_CERTIFICATES_APP_NAME].units[0]

agent = ManualTLSAgent(tls_unit)

while True:
await agent.wait_for_csrs_in_queue()
await agent.process_queue()
await asyncio.sleep(5)


if __name__ == "__main__":
asyncio.run(main())
100 changes: 100 additions & 0 deletions tests/integration/tls/test_manual_tls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2024 Canonical Ltd.
# See LICENSE file for licensing details.

import logging

import pytest
from juju.application import Application
from pytest_operator.plugin import OpsTest

from ..helpers import APP_NAME, MODEL_CONFIG, SERIES, UNIT_IDS
from ..helpers_deployments import wait_until
from .helpers_manual_tls import MANUAL_TLS_CERTIFICATES_APP_NAME, ManualTLSAgent

logger = logging.getLogger(__name__)


@pytest.mark.runner(["self-hosted", "linux", "X64", "jammy", "large"])
@pytest.mark.group(1)
@pytest.mark.abort_on_fail
@pytest.mark.skip_if_deployed
async def test_build_and_deploy_with_manual_tls(ops_test: OpsTest) -> None:
"""Build and deploy prod cluster of OpenSearch with Manual TLS Operator integration."""
my_charm = await ops_test.build_charm(".")
await ops_test.model.set_config(MODEL_CONFIG)

os_app: Application = await ops_test.model.deploy(
my_charm,
num_units=len(UNIT_IDS),
series=SERIES,
application_name=APP_NAME,
)

# Deploy TLS Certificates operator.
tls_app: Application = await ops_test.model.deploy(
MANUAL_TLS_CERTIFICATES_APP_NAME,
channel="stable",
)
await wait_until(
ops_test,
apps=[MANUAL_TLS_CERTIFICATES_APP_NAME],
apps_statuses=["active"],
)
logger.info("Deployed %s application", MANUAL_TLS_CERTIFICATES_APP_NAME)

# Integrate it to OpenSearch to set up TLS.
await ops_test.model.integrate(APP_NAME, MANUAL_TLS_CERTIFICATES_APP_NAME)
logger.info("Integrated %s with %s", APP_NAME, MANUAL_TLS_CERTIFICATES_APP_NAME)

# Initialize the ManualTLSAgent to process the CSRs
manual_tls_daemon = ManualTLSAgent(tls_app.units[0])
# Wait for len(UNIT_IDS)*2+1 CSRs to be created.
# 1 for each unit for http and transport and 1 for the admin cert.
logger.info("Waiting for CSRs to be created")
await manual_tls_daemon.wait_for_csrs_in_queue(len(UNIT_IDS) * 2 + 1)

# Sign all CSRs
logger.info("Signing CSRs")
await manual_tls_daemon.process_queue()

await wait_until(
ops_test,
apps=[APP_NAME],
apps_statuses=["active"],
units_statuses=["active"],
wait_for_exact_units=len(UNIT_IDS),
timeout=2000,
)
assert len(ops_test.model.applications[APP_NAME].units) == len(UNIT_IDS)

# Scale up the application by adding a new unit
logger.info("Scaling up the application by adding a new unit")
await os_app.add_unit(1)

# Wait for the new unit to be in maintenance
logger.info("Waiting for the new unit to be in maintenance waiting for certificates")
await wait_until(
ops_test,
apps=[APP_NAME],
units_statuses=["active", "maintenance"],
wait_for_exact_units=len(UNIT_IDS) + 1,
)

# Wait for the new unit request certificates
logger.info("Waiting for the new unit to request certificates")
await manual_tls_daemon.wait_for_csrs_in_queue(2)

# Sign all CSRs
logger.info("Signing CSRs")
await manual_tls_daemon.process_queue()

# Wait for the new unit to be active
logger.info("Waiting for the new unit to be active")
await wait_until(
ops_test,
apps=[APP_NAME],
units_statuses=["active"],
wait_for_exact_units=len(UNIT_IDS) + 1,
)
assert len(ops_test.model.applications[APP_NAME].units) == len(UNIT_IDS) + 1

0 comments on commit 12ecb8b

Please sign in to comment.