Skip to content

Commit

Permalink
[Fix] Implement schema validation with pydantic
Browse files Browse the repository at this point in the history
Use schema defined in charm-relation-interface/vault-kv to validate
schema.

Rename property-like methods on Requirer object to getters.
  • Loading branch information
gboutry committed Sep 25, 2023
1 parent 87f9955 commit a10cbfd
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 19 deletions.
124 changes: 105 additions & 19 deletions lib/charms/vault_k8s/v0/vault_kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _on_ready(self, event: vault_kv.VaultKvReadyEvent):
unit_credentials = self.interface.unit_credentials(relation)
# unit_credentials is a juju secret id
secret = self.model.get_secret(id=unit_credentials)
secret_content = secret.get_content(refresh=True)
secret_content = secret.get_content()
role_id = secret_content["role-id"]
role_secret_id = secret_content["role-secret-id"]
Expand Down Expand Up @@ -106,9 +106,11 @@ def _on_update_status(self, event):

import json
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union

import ops
from interface_tester.schema_base import DataBagSchema
from pydantic import BaseModel, Field, Json, ValidationError

logger = logging.getLogger(__name__)

Expand All @@ -123,6 +125,82 @@ def _on_update_status(self, event):
# to 0 if you are raising the major API version
LIBPATCH = 1

PYDEPS = ["pydantic", "pytest-interface-tester"]


class VaultKvProviderSchema(BaseModel):
"""Provider side of the vault-kv interface."""

vault_url: str = Field(description="The URL of the Vault server to connect to.")
mount: str = Field(
description=(
"The KV mount available for the requirer application, "
"respecting the pattern 'charm-<requirer app>-<user provided suffix>'."
)
)
ca_certificate: str = Field(
description="The CA certificate to use when validating the Vault server's certificate."
)
credentials: Json[Mapping[str, str]] = Field(
description=(
"Mapping of unit name and credentials for that unit."
" Credentials are a juju secret containing a 'role-id' and a 'role-secret-id'."
)
)


class AppVaultKvRequirerSchema(BaseModel):
"""App schema of the requirer side of the vault-kv interface."""

mount_suffix: str = Field(
description="Suffix to append to the mount name to get the KV mount."
)


class UnitVaultKvRequirerSchema(BaseModel):
"""Unit schema of the requirer side of the vault-kv interface."""

egress_subnet: str = Field(description="Egress subnet to use, in CIDR notation.")
nonce: str = Field(
description="Uniquely identifying value for this unit. `secrets.token_hex(16)` is recommended."
)


class ProviderSchema(DataBagSchema):
"""The schema for the provider side of this interface."""

app: VaultKvProviderSchema


class RequirerSchema(DataBagSchema):
"""The schema for the requirer side of this interface."""

app: AppVaultKvRequirerSchema
unit: UnitVaultKvRequirerSchema


def is_requirer_data_valid(app_data: dict, unit_data: dict) -> bool:
"""Return whether the requirer data is valid."""
try:
RequirerSchema(
app=AppVaultKvRequirerSchema(**app_data),
unit=UnitVaultKvRequirerSchema(**unit_data),
)
return True
except ValidationError as e:
logger.debug("Invalid data: %s", e)
return False


def is_provider_data_valid(data: dict) -> bool:
"""Return whether the provider data is valid."""
try:
ProviderSchema(app=VaultKvProviderSchema(**data))
return True
except ValidationError as e:
logger.debug("Invalid data: %s", e)
return False


class NewVaultKvClientAttachedEvent(ops.EventBase):
"""New vault kv client attached event."""
Expand Down Expand Up @@ -180,18 +258,30 @@ def __init__(
)

def _on_relation_changed(self, event: ops.RelationChangedEvent):
"""Handle client changed relation."""
"""Handle client changed relation.
This handler will emit a new_vault_kv_client_attached event if at least one unit data is
valid.
"""
if event.app is None:
logger.debug("No remote application yet")
return

mount_suffix = event.relation.data[event.app].get("mount_suffix")
app_data = dict(event.relation.data[event.app])

any_valid = False
for unit in event.relation.units:
if not is_requirer_data_valid(app_data, dict(event.relation.data[unit])):
logger.debug("Invalid data from unit %r", unit.name)
print(False)
continue
any_valid = True

if mount_suffix is not None:
if any_valid:
self.on.new_vault_kv_client_attached.emit(
event.relation.id,
event.relation.name,
mount_suffix,
event.relation.data[event.app]["mount_suffix"],
)

def set_vault_url(self, relation: ops.Relation, vault_url: str):
Expand Down Expand Up @@ -394,18 +484,14 @@ def _on_vault_kv_relation_changed(self, event: ops.RelationChangedEvent):
logger.debug("No remote application yet")
return

vault_url = self.vault_url(event.relation)
ca_certificate = self.ca_certificate(event.relation)
mount = self.mount(event.relation)
unit_credentials_secret = self.unit_credentials(event.relation)
if all((vault_url, ca_certificate, mount, unit_credentials_secret)):
if is_provider_data_valid(dict(event.relation.data[event.app])):
self.on.ready.emit(
event.relation.id,
event.relation.name,
vault_url,
ca_certificate,
mount,
unit_credentials_secret,
self.get_vault_url(event.relation),
self.get_ca_certificate(event.relation),
self.get_mount(event.relation),
self.get_unit_credentials(event.relation),
)

def _on_vault_kv_relation_broken(self, event: ops.RelationBrokenEvent):
Expand All @@ -424,25 +510,25 @@ def request_credentials(self, relation: ops.Relation, egress_subnet: str) -> Non
self._set_unit_egress_subnet(relation, egress_subnet)
self._set_unit_nonce(relation, self.nonce)

def vault_url(self, relation: ops.Relation) -> Optional[str]:
def get_vault_url(self, relation: ops.Relation) -> Optional[str]:
"""Return the vault_url from the relation."""
if relation.app is None:
return None
return relation.data[relation.app].get("vault_url")

def ca_certificate(self, relation: ops.Relation) -> Optional[str]:
def get_ca_certificate(self, relation: ops.Relation) -> Optional[str]:
"""Return the ca_certificate from the relation."""
if relation.app is None:
return None
return relation.data[relation.app].get("ca_certificate")

def mount(self, relation: ops.Relation) -> Optional[str]:
def get_mount(self, relation: ops.Relation) -> Optional[str]:
"""Return the mount from the relation."""
if relation.app is None:
return None
return relation.data[relation.app].get("mount")

def unit_credentials(self, relation: ops.Relation) -> Optional[str]:
def get_unit_credentials(self, relation: ops.Relation) -> Optional[str]:
"""Return the unit credentials from the relation.
Unit credentials are stored in the relation data as a Juju secret id.
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ hvac
jsonschema
lightkube
lightkube-models
pydantic
pytest-interface-tester
requests
jsonschema
cryptography

0 comments on commit a10cbfd

Please sign in to comment.