From a10cbfd72a808c3f054dab2486c4904f9f120ccc Mon Sep 17 00:00:00 2001 From: Guillaume Boutry Date: Mon, 25 Sep 2023 18:17:30 +0200 Subject: [PATCH] [Fix] Implement schema validation with pydantic Use schema defined in charm-relation-interface/vault-kv to validate schema. Rename property-like methods on Requirer object to getters. --- lib/charms/vault_k8s/v0/vault_kv.py | 124 +++++++++++++++++++++++----- requirements.txt | 2 + 2 files changed, 107 insertions(+), 19 deletions(-) diff --git a/lib/charms/vault_k8s/v0/vault_kv.py b/lib/charms/vault_k8s/v0/vault_kv.py index 12c7f7f4..f1b48fa8 100644 --- a/lib/charms/vault_k8s/v0/vault_kv.py +++ b/lib/charms/vault_k8s/v0/vault_kv.py @@ -64,7 +64,7 @@ def _on_ready(self, event: vault_kv.VaultKvReadyEvent): unit_credentials = self.interface.unit_credentials(relation) # unit_credentials is a juju secret id secret = self.model.get_secret(id=unit_credentials) - secret_content = secret.get_content(refresh=True) + secret_content = secret.get_content() role_id = secret_content["role-id"] role_secret_id = secret_content["role-secret-id"] @@ -106,9 +106,11 @@ def _on_update_status(self, event): import json import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Union import ops +from interface_tester.schema_base import DataBagSchema +from pydantic import BaseModel, Field, Json, ValidationError logger = logging.getLogger(__name__) @@ -123,6 +125,82 @@ def _on_update_status(self, event): # to 0 if you are raising the major API version LIBPATCH = 1 +PYDEPS = ["pydantic", "pytest-interface-tester"] + + +class VaultKvProviderSchema(BaseModel): + """Provider side of the vault-kv interface.""" + + vault_url: str = Field(description="The URL of the Vault server to connect to.") + mount: str = Field( + description=( + "The KV mount available for the requirer application, " + "respecting the pattern 'charm--'." + ) + ) + ca_certificate: str = Field( + description="The CA certificate to use when validating the Vault server's certificate." + ) + credentials: Json[Mapping[str, str]] = Field( + description=( + "Mapping of unit name and credentials for that unit." + " Credentials are a juju secret containing a 'role-id' and a 'role-secret-id'." + ) + ) + + +class AppVaultKvRequirerSchema(BaseModel): + """App schema of the requirer side of the vault-kv interface.""" + + mount_suffix: str = Field( + description="Suffix to append to the mount name to get the KV mount." + ) + + +class UnitVaultKvRequirerSchema(BaseModel): + """Unit schema of the requirer side of the vault-kv interface.""" + + egress_subnet: str = Field(description="Egress subnet to use, in CIDR notation.") + nonce: str = Field( + description="Uniquely identifying value for this unit. `secrets.token_hex(16)` is recommended." + ) + + +class ProviderSchema(DataBagSchema): + """The schema for the provider side of this interface.""" + + app: VaultKvProviderSchema + + +class RequirerSchema(DataBagSchema): + """The schema for the requirer side of this interface.""" + + app: AppVaultKvRequirerSchema + unit: UnitVaultKvRequirerSchema + + +def is_requirer_data_valid(app_data: dict, unit_data: dict) -> bool: + """Return whether the requirer data is valid.""" + try: + RequirerSchema( + app=AppVaultKvRequirerSchema(**app_data), + unit=UnitVaultKvRequirerSchema(**unit_data), + ) + return True + except ValidationError as e: + logger.debug("Invalid data: %s", e) + return False + + +def is_provider_data_valid(data: dict) -> bool: + """Return whether the provider data is valid.""" + try: + ProviderSchema(app=VaultKvProviderSchema(**data)) + return True + except ValidationError as e: + logger.debug("Invalid data: %s", e) + return False + class NewVaultKvClientAttachedEvent(ops.EventBase): """New vault kv client attached event.""" @@ -180,18 +258,30 @@ def __init__( ) def _on_relation_changed(self, event: ops.RelationChangedEvent): - """Handle client changed relation.""" + """Handle client changed relation. + + This handler will emit a new_vault_kv_client_attached event if at least one unit data is + valid. + """ if event.app is None: logger.debug("No remote application yet") return - mount_suffix = event.relation.data[event.app].get("mount_suffix") + app_data = dict(event.relation.data[event.app]) + + any_valid = False + for unit in event.relation.units: + if not is_requirer_data_valid(app_data, dict(event.relation.data[unit])): + logger.debug("Invalid data from unit %r", unit.name) + print(False) + continue + any_valid = True - if mount_suffix is not None: + if any_valid: self.on.new_vault_kv_client_attached.emit( event.relation.id, event.relation.name, - mount_suffix, + event.relation.data[event.app]["mount_suffix"], ) def set_vault_url(self, relation: ops.Relation, vault_url: str): @@ -394,18 +484,14 @@ def _on_vault_kv_relation_changed(self, event: ops.RelationChangedEvent): logger.debug("No remote application yet") return - vault_url = self.vault_url(event.relation) - ca_certificate = self.ca_certificate(event.relation) - mount = self.mount(event.relation) - unit_credentials_secret = self.unit_credentials(event.relation) - if all((vault_url, ca_certificate, mount, unit_credentials_secret)): + if is_provider_data_valid(dict(event.relation.data[event.app])): self.on.ready.emit( event.relation.id, event.relation.name, - vault_url, - ca_certificate, - mount, - unit_credentials_secret, + self.get_vault_url(event.relation), + self.get_ca_certificate(event.relation), + self.get_mount(event.relation), + self.get_unit_credentials(event.relation), ) def _on_vault_kv_relation_broken(self, event: ops.RelationBrokenEvent): @@ -424,25 +510,25 @@ def request_credentials(self, relation: ops.Relation, egress_subnet: str) -> Non self._set_unit_egress_subnet(relation, egress_subnet) self._set_unit_nonce(relation, self.nonce) - def vault_url(self, relation: ops.Relation) -> Optional[str]: + def get_vault_url(self, relation: ops.Relation) -> Optional[str]: """Return the vault_url from the relation.""" if relation.app is None: return None return relation.data[relation.app].get("vault_url") - def ca_certificate(self, relation: ops.Relation) -> Optional[str]: + def get_ca_certificate(self, relation: ops.Relation) -> Optional[str]: """Return the ca_certificate from the relation.""" if relation.app is None: return None return relation.data[relation.app].get("ca_certificate") - def mount(self, relation: ops.Relation) -> Optional[str]: + def get_mount(self, relation: ops.Relation) -> Optional[str]: """Return the mount from the relation.""" if relation.app is None: return None return relation.data[relation.app].get("mount") - def unit_credentials(self, relation: ops.Relation) -> Optional[str]: + def get_unit_credentials(self, relation: ops.Relation) -> Optional[str]: """Return the unit credentials from the relation. Unit credentials are stored in the relation data as a Juju secret id. diff --git a/requirements.txt b/requirements.txt index 5c66f4c8..963fe244 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,8 @@ hvac jsonschema lightkube lightkube-models +pydantic +pytest-interface-tester requests jsonschema cryptography