From a2c4226cfa68012829e2cc4eb19e25f9437b3f80 Mon Sep 17 00:00:00 2001 From: Dragomir Penev Date: Tue, 26 Sep 2023 13:37:31 +0300 Subject: [PATCH] Bump libs --- .../data_platform_libs/v0/data_interfaces.py | 899 ++++++++++++++++-- 1 file changed, 802 insertions(+), 97 deletions(-) diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py index d894130e2b..9fa0021ec9 100644 --- a/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -291,19 +291,23 @@ def _on_topic_requested(self, event: TopicRequestedEvent): exchanged in the relation databag. """ +import copy import json import logging from abc import ABC, abstractmethod from collections import namedtuple from datetime import datetime -from typing import List, Optional, Union +from enum import Enum +from typing import Dict, List, Optional, Set, Union +from ops import JujuVersion, Secret, SecretInfo, SecretNotFoundError from ops.charm import ( CharmBase, CharmEvents, RelationChangedEvent, RelationCreatedEvent, RelationEvent, + SecretChangedEvent, ) from ops.framework import EventSource, Object from ops.model import Application, ModelError, Relation, Unit @@ -316,7 +320,7 @@ def _on_topic_requested(self, event: TopicRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 17 +LIBPATCH = 18 PYDEPS = ["ops>=2.0.0"] @@ -331,6 +335,58 @@ def _on_topic_requested(self, event: TopicRequestedEvent): deleted - key that were deleted""" +PROV_SECRET_PREFIX = "secret-" +REQ_SECRET_FIELDS = "requested-secrets" + + +class SecretGroup(Enum): + """Secret groups as constants.""" + + USER = "user" + TLS = "tls" + EXTRA = "extra" + + +# Local map to associate mappings with secrets potentially as a group +SECRET_LABEL_MAP = { + "username": SecretGroup.USER, + "password": SecretGroup.USER, + "uris": SecretGroup.USER, + "tls": SecretGroup.TLS, + "tls-ca": SecretGroup.TLS, +} + + +class DataInterfacesError(Exception): + """Common ancestor for DataInterfaces related exceptions.""" + + +class SecretError(Exception): + """Common ancestor for Secrets related exceptions.""" + + +class SecretAlreadyExistsError(SecretError): + """A secret that was to be added already exists.""" + + +class SecretsUnavailableError(SecretError): + """Secrets aren't yet available for Juju version used.""" + + +class SecretsIllegalUpdateError(SecretError): + """Secrets aren't yet available for Juju version used.""" + + +def get_encoded_field(relation, member, field) -> Dict[str, str]: + """Retrieve and decode an encoded field from relation data.""" + return json.loads(relation.data[member].get(field, "{}")) + + +def set_encoded_field(relation, member, field, value) -> None: + """Set an encoded field from relation data.""" + relation.data[member].update({field: json.dumps(value)}) + + def diff(event: RelationChangedEvent, bucket: Union[Unit, Application]) -> Diff: """Retrieves the diff of the data in the relation changed databag. @@ -343,7 +399,7 @@ def diff(event: RelationChangedEvent, bucket: Union[Unit, Application]) -> Diff: keys from the event relation databag. """ # Retrieve the old data from the data key in the application relation databag. - old_data = json.loads(event.relation.data[bucket].get("data", "{}")) + old_data = get_encoded_field(event.relation, bucket, "data") # Retrieve the new data from the event relation databag. new_data = ( {key: value for key, value in event.relation.data[event.app].items() if key != "data"} @@ -359,17 +415,132 @@ def diff(event: RelationChangedEvent, bucket: Union[Unit, Application]) -> Diff: # but had their values changed. changed = {key for key in old_data.keys() & new_data.keys() if old_data[key] != new_data[key]} # Convert the new_data to a serializable format and save it for a next diff check. - event.relation.data[bucket].update({"data": json.dumps(new_data)}) + set_encoded_field(event.relation, bucket, "data", new_data) # Return the diff with all possible changes. return Diff(added, changed, deleted) +def leader_only(f): + """Decorator to ensure that only leader can perform given operation.""" + + def wrapper(self, *args, **kwargs): + if not self.local_unit.is_leader(): + return + return f(self, *args, **kwargs) + + return wrapper + + +def juju_secrets_only(f): + """Decorator to ensure that certain operations would be only executed on Juju3.""" + + def wrapper(self, *args, **kwargs): + if not self.secrets_enabled: + raise SecretsUnavailableError("Secrets unavailable on current Juju version") + return f(self, *args, **kwargs) + + return wrapper + + +class Scope(Enum): + """Peer relations scope.""" + + APP = "app" + UNIT = "unit" + + +class CachedSecret: + """Locally cache a secret. + + The data structure is precisely re-using/simulating as in the actual Secret Storage + """ + + def __init__(self, charm: CharmBase, label: str, secret_uri: Optional[str] = None): + self._secret_meta = None + self._secret_content = {} + self._secret_uri = secret_uri + self.label = label + self.charm = charm + + def add_secret(self, content: Dict[str, str], relation: Relation) -> Secret: + """Create a new secret.""" + if self._secret_uri: + raise SecretAlreadyExistsError( + "Secret is already defined with uri %s", self._secret_uri + ) + + secret = self.charm.app.add_secret(content, label=self.label) + secret.grant(relation) + self._secret_uri = secret.id + self._secret_meta = secret + return self._secret_meta + + @property + def meta(self) -> Optional[Secret]: + """Getting cached secret meta-information.""" + if not self._secret_meta: + if not (self._secret_uri or self.label): + return + try: + self._secret_meta = self.charm.model.get_secret(label=self.label) + except SecretNotFoundError: + if self._secret_uri: + self._secret_meta = self.charm.model.get_secret( + id=self._secret_uri, label=self.label + ) + return self._secret_meta + + def get_content(self) -> Dict[str, str]: + """Getting cached secret content.""" + if not self._secret_content: + if self.meta: + self._secret_content = self.meta.get_content() + return self._secret_content + + def set_content(self, content: Dict[str, str]) -> None: + """Setting cached secret content.""" + if self.meta: + self.meta.set_content(content) + self._secret_content = content + + def get_info(self) -> Optional[SecretInfo]: + """Wrapper function to apply the corresponding call on the Secret object within CachedSecret if any.""" + if self.meta: + return self.meta.get_info() + + +class SecretCache: + """A data structure storing CachedSecret objects.""" + + def __init__(self, charm): + self.charm = charm + self._secrets: Dict[str, CachedSecret] = {} + + def get(self, label: str, uri: Optional[str] = None) -> Optional[CachedSecret]: + """Getting a secret from Juju Secret store or cache.""" + if not self._secrets.get(label): + secret = CachedSecret(self.charm, label, uri) + if secret.meta: + self._secrets[label] = secret + return self._secrets.get(label) + + def add(self, label: str, content: Dict[str, str], relation: Relation) -> CachedSecret: + """Adding a secret to Juju Secret.""" + if self._secrets.get(label): + raise SecretAlreadyExistsError(f"Secret {label} already exists") + + secret = CachedSecret(self.charm, label) + secret.add_secret(content, relation) + self._secrets[label] = secret + return self._secrets[label] + + # Base DataRelation class DataRelation(Object, ABC): - """Base relation data mainpulation class.""" + """Base relation data mainpulation (abstract) class.""" def __init__(self, charm: CharmBase, relation_name: str) -> None: super().__init__(charm, relation_name) @@ -381,13 +552,143 @@ def __init__(self, charm: CharmBase, relation_name: str) -> None: charm.on[relation_name].relation_changed, self._on_relation_changed_event, ) + self._jujuversion = None + self.secrets = SecretCache(self.charm) + + @property + def relations(self) -> List[Relation]: + """The list of Relation instances associated with this relation_name.""" + return [ + relation + for relation in self.charm.model.relations[self.relation_name] + if self._is_relation_active(relation) + ] + + @property + def secrets_enabled(self): + """Is this Juju version allowing for Secrets usage?""" + if not self._jujuversion: + self._jujuversion = JujuVersion.from_environ() + return self._jujuversion.has_secrets + + # Mandatory overrides for internal/helper methods @abstractmethod def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation data has changed.""" raise NotImplementedError - def fetch_relation_data(self) -> dict: + @abstractmethod + def _get_relation_secret( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + raise NotImplementedError + + @abstractmethod + def _fetch_specific_relation_data( + self, relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation.""" + raise NotImplementedError + + # Internal helper methods + + @staticmethod + def _is_relation_active(relation: Relation): + """Whether the relation is active based on contained data.""" + try: + _ = repr(relation.data) + return True + except (RuntimeError, ModelError): + return False + + @staticmethod + def _is_secret_field(field: str) -> bool: + """Is the field in question a secret reference (URI) field or not?""" + return field.startswith(PROV_SECRET_PREFIX) + + @staticmethod + def _generate_secret_label( + relation_name: str, relation_id: int, group_mapping: SecretGroup + ) -> str: + """Generate unique group_mappings for secrets within a relation context.""" + return f"{relation_name}.{relation_id}.{group_mapping.value}.secret" + + @staticmethod + def _generate_secret_field_name(group_mapping: SecretGroup) -> str: + """Generate unique group_mappings for secrets within a relation context.""" + return f"{PROV_SECRET_PREFIX}{group_mapping.value}" + + def _relation_from_secret_label(self, secret_label: str) -> Optional[Relation]: + """Retrieve the relation that belongs to a secret label.""" + contents = secret_label.split(".") + + if not (contents and len(contents) >= 3): + return + + contents.pop() # ".secret" at the end + contents.pop() # Group mapping + relation_id = contents.pop() + try: + relation_id = int(relation_id) + except ValueError: + return + + # In case '.' character appeared in relation name + relation_name = ".".join(contents) + + try: + return self.get_relation(relation_name, relation_id) + except ModelError: + return + + @staticmethod + def _group_secret_fields(secret_fields: List[str]) -> Dict[SecretGroup, List[str]]: + """Helper function to arrange secret mappings under their group. + + NOTE: All unrecognized items end up in the 'extra' secret bucket. + Make sure only secret fields are passed! + """ + secret_fieldnames_grouped = {} + for key in secret_fields: + if group := SECRET_LABEL_MAP.get(key): + secret_fieldnames_grouped.setdefault(group, []).append(key) + else: + secret_fieldnames_grouped.setdefault(SecretGroup.EXTRA, []).append(key) + return secret_fieldnames_grouped + + @juju_secrets_only + def _get_relation_secret_data( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[Dict[str, str]]: + """Retrieve contents of a Juju Secret that's been stored in the relation databag.""" + secret = self._get_relation_secret(relation_id, group_mapping, relation_name) + if secret: + return secret.get_content() + + # Public methods + + def get_relation(self, relation_name, relation_id) -> Relation: + """Safe way of retrieving a relation.""" + relation = self.charm.model.get_relation(relation_name, relation_id) + + if not relation: + raise DataInterfacesError( + "Relation %s %s couldn't be retrieved", relation_name, relation_id + ) + + if not relation.app: + raise DataInterfacesError("Relation's application missing") + + return relation + + def fetch_relation_data( + self, + relation_ids: Optional[List[int]] = None, + fields: Optional[List[str]] = None, + relation_name: Optional[str] = None, + ) -> Dict[int, Dict[str, str]]: """Retrieves data from relation. This function can be used to retrieve data from a relation @@ -398,48 +699,29 @@ def fetch_relation_data(self) -> dict: a dict of the values stored in the relation data bag for all relation instances (indexed by the relation ID). """ + if not relation_name: + relation_name = self.relation_name + + relations = [] + if relation_ids: + relations = [ + self.get_relation(relation_name, relation_id) for relation_id in relation_ids + ] + else: + relations = self.relations + data = {} - for relation in self.relations: - data[relation.id] = ( - {key: value for key, value in relation.data[relation.app].items() if key != "data"} - if relation.app - else {} - ) + for relation in relations: + if not relation_ids or (relation_ids and relation.id in relation_ids): + data[relation.id] = self._fetch_specific_relation_data(relation, fields) return data - def _update_relation_data(self, relation_id: int, data: dict) -> None: - """Updates a set of key-value pairs in the relation. - - This function writes in the application data bag, therefore, - only the leader unit can call it. - - Args: - relation_id: the identifier for a particular relation. - data: dict containing the key-value pairs - that should be updated in the relation. - """ - if self.local_unit.is_leader(): - relation = self.charm.model.get_relation(self.relation_name, relation_id) - if relation: - relation.data[self.local_app].update(data) - - @staticmethod - def _is_relation_active(relation: Relation): - """Whether the relation is active based on contained data.""" - try: - _ = repr(relation.data) - return True - except (RuntimeError, ModelError): - return False + # Public methods - mandatory override - @property - def relations(self) -> List[Relation]: - """The list of Relation instances associated with this relation_name.""" - return [ - relation - for relation in self.charm.model.relations[self.relation_name] - if self._is_relation_active(relation) - ] + @abstractmethod + def update_relation_data(self, relation_id: int, data: dict) -> None: + """Update the data within the relation.""" + raise NotImplementedError # Base DataProvides and DataRequires @@ -463,6 +745,127 @@ def _diff(self, event: RelationChangedEvent) -> Diff: """ return diff(event, self.local_app) + # Private methods handling secrets + + @leader_only + @juju_secrets_only + def _add_relation_secret( + self, relation_id: int, content: Dict[str, str], group_mapping: SecretGroup + ) -> Optional[Secret]: + """Add a new Juju Secret that will be registered in the relation databag.""" + relation = self.get_relation(self.relation_name, relation_id) + + secret_field = self._generate_secret_field_name(group_mapping) + if relation.data[self.local_app].get(secret_field): + logging.error("Secret for relation %s already exists, not adding again", relation_id) + return + + label = self._generate_secret_label(self.relation_name, relation_id, group_mapping) + secret = self.secrets.add(label, content, relation) + + # According to lint we may not have a Secret ID + if secret.meta and secret.meta.id: + relation.data[self.local_app][secret_field] = secret.meta.id + + @leader_only + @juju_secrets_only + def _update_relation_secret( + self, relation_id: int, content: Dict[str, str], group_mapping: SecretGroup + ): + """Update the contents of an existing Juju Secret, referred in the relation databag.""" + secret = self._get_relation_secret(relation_id, group_mapping) + + if not secret: + logging.error("Can't update secret for relation %s", relation_id) + return + + old_content = secret.get_content() + full_content = copy.deepcopy(old_content) + full_content.update(content) + secret.set_content(full_content) + + @staticmethod + def _secret_content_grouped( + content: Dict[str, str], secret_fields: Set[str], group_mapping: SecretGroup + ) -> Dict[str, str]: + if group_mapping == SecretGroup.EXTRA: + return { + k: v + for k, v in content.items() + if k in secret_fields and k not in SECRET_LABEL_MAP.keys() + } + + return { + k: v + for k, v in content.items() + if k in secret_fields and SECRET_LABEL_MAP.get(k) == group_mapping + } + + # Mandatory internal overrides + + @juju_secrets_only + def _get_relation_secret( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + if not relation_name: + relation_name = self.relation_name + + label = self._generate_secret_label(relation_name, relation_id, group_mapping) + if secret := self.secrets.get(label): + return secret + + relation = self.charm.model.get_relation(relation_name, relation_id) + if not relation: + return + + secret_field = self._generate_secret_field_name(group_mapping) + if secret_uri := relation.data[self.local_app].get(secret_field): + return self.secrets.get(label, secret_uri) + + def _fetch_specific_relation_data(self, relation, fields: Optional[List[str]]) -> dict: + """Fetching relation data for Provides. + + NOTE: Since all secret fields are in the Requires side of the databag, we don't need to worry about that + """ + if not relation.app: + return {} + + if fields: + return {k: relation.data[relation.app].get(k) for k in fields} + else: + return relation.data[relation.app] + + # Public methods -- mandatory overrides + + @leader_only + def update_relation_data(self, relation_id: int, fields: Dict[str, str]) -> None: + """Set values for fields not caring whether it's a secret or not.""" + relation = self.get_relation(self.relation_name, relation_id) + + relation_secret_fields = get_encoded_field(relation, relation.app, REQ_SECRET_FIELDS) + + normal_fields = list(fields) + if relation_secret_fields and self.secrets_enabled: + normal_fields = set(fields.keys()) - set(relation_secret_fields) + secret_fields = set(fields.keys()) - set(normal_fields) + + secret_fieldnames_grouped = self._group_secret_fields(list(secret_fields)) + + for group in secret_fieldnames_grouped: + secret_content = self._secret_content_grouped(fields, secret_fields, group) + if self._get_relation_secret(relation_id, group): + self._update_relation_secret(relation_id, secret_content, group) + else: + self._add_relation_secret(relation_id, secret_content, group) + + normal_content = {k: v for k, v in fields.items() if k in normal_fields} + relation.data[self.local_app].update( # pyright: ignore [reportGeneralTypeIssues] + normal_content + ) + + # Public methods - "native" + def set_credentials(self, relation_id: int, username: str, password: str) -> None: """Set credentials. @@ -474,13 +877,7 @@ def set_credentials(self, relation_id: int, username: str, password: str) -> Non username: user that was created. password: password of the created user. """ - self._update_relation_data( - relation_id, - { - "username": username, - "password": password, - }, - ) + self.update_relation_data(relation_id, {"username": username, "password": password}) def set_tls(self, relation_id: int, tls: str) -> None: """Set whether TLS is enabled. @@ -489,7 +886,7 @@ def set_tls(self, relation_id: int, tls: str) -> None: relation_id: the identifier for a particular relation. tls: whether tls is enabled (True or False). """ - self._update_relation_data(relation_id, {"tls": tls}) + self.update_relation_data(relation_id, {"tls": tls}) def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: """Set the TLS CA in the application relation databag. @@ -498,29 +895,41 @@ def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: relation_id: the identifier for a particular relation. tls_ca: TLS certification authority. """ - self._update_relation_data(relation_id, {"tls-ca": tls_ca}) + self.update_relation_data(relation_id, {"tls-ca": tls_ca}) class DataRequires(DataRelation): """Requires-side of the relation.""" + SECRET_FIELDS = ["username", "password", "tls", "tls-ca", "uris"] + def __init__( self, charm, relation_name: str, extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], ): """Manager of base client relations.""" super().__init__(charm, relation_name) self.extra_user_roles = extra_user_roles + self._secret_fields = list(self.SECRET_FIELDS) + if additional_secret_fields: + self._secret_fields += additional_secret_fields + self.framework.observe( self.charm.on[relation_name].relation_created, self._on_relation_created_event ) + self.framework.observe( + charm.on.secret_changed, + self._on_secret_changed_event, + ) - @abstractmethod - def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: - """Event emitted when the relation is created.""" - raise NotImplementedError + @property + def secret_fields(self) -> Optional[List[str]]: + """Local access to secrets field, in case they are being used.""" + if self.secrets_enabled: + return self._secret_fields def _diff(self, event: RelationChangedEvent) -> Diff: """Retrieves the diff of the data in the relation changed databag. @@ -534,14 +943,50 @@ def _diff(self, event: RelationChangedEvent) -> Diff: """ return diff(event, self.local_unit) - @staticmethod - def _is_resource_created_for_relation(relation: Relation) -> bool: + # Internal helper functions + + def _register_secret_to_relation( + self, relation_name: str, relation_id: int, secret_id: str, group: SecretGroup + ): + """Fetch secrets and apply local label on them. + + [MAGIC HERE] + If we fetch a secret using get_secret(id=, label=), + then will be "stuck" on the Secret object, whenever it may + appear (i.e. as an event attribute, or fetched manually) on future occasions. + + This will allow us to uniquely identify the secret on Provides side (typically on + 'secret-changed' events), and map it to the corresponding relation. + """ + label = self._generate_secret_label(relation_name, relation_id, group) + + # Fetchin the Secret's meta information ensuring that it's locally getting registered with + CachedSecret(self.charm, label, secret_id).meta + + def _register_secrets_to_relation(self, relation: Relation, params_name_list: List[str]): + """Make sure that secrets of the provided list are locally 'registered' from the databag. + + More on 'locally registered' magic is described in _register_secret_to_relation() method + """ + if not relation.app: + return + + for group in SecretGroup: + secret_field = self._generate_secret_field_name(group) + if secret_field in params_name_list: + if secret_uri := relation.data[relation.app].get(secret_field): + self._register_secret_to_relation( + relation.name, relation.id, secret_uri, group + ) + + def _is_resource_created_for_relation(self, relation: Relation) -> bool: if not relation.app: return False - return ( - "username" in relation.data[relation.app] and "password" in relation.data[relation.app] + data = self.fetch_relation_data([relation.id], ["username", "password"]).get( + relation.id, {} ) + return bool(data.get("username")) and bool(data.get("password")) def is_resource_created(self, relation_id: Optional[int] = None) -> bool: """Check if the resource has been created. @@ -576,6 +1021,132 @@ def is_resource_created(self, relation_id: Optional[int] = None) -> bool: else False ) + def _retrieve_group_secret_contents( + self, + relation_id, + group: SecretGroup, + secret_fields: Optional[Union[Set[str], List[str]]] = None, + ) -> Dict[str, str]: + """Helper function to retrieve collective, requested contents of a secret.""" + if not secret_fields: + secret_fields = [] + + if (secret := self._get_relation_secret(relation_id, group)) and ( + secret_data := secret.get_content() + ): + return {k: v for k, v in secret_data.items() if k in secret_fields} + return {} + + # Event handlers + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the relation is created.""" + if not self.local_unit.is_leader(): + return + + if self.secret_fields: + set_encoded_field( + event.relation, self.charm.app, REQ_SECRET_FIELDS, self.secret_fields + ) + + @abstractmethod + def _on_secret_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation data has changed.""" + raise NotImplementedError + + # Mandatory internal overrides + + @juju_secrets_only + def _get_relation_secret( + self, relation_id: int, group: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + if not relation_name: + relation_name = self.relation_name + + label = self._generate_secret_label(relation_name, relation_id, group) + return self.secrets.get(label) + + def _fetch_specific_relation_data( + self, relation, fields: Optional[List[str]] = None + ) -> Dict[str, str]: + if not relation.app: + return {} + + result = {} + + normal_fields = fields + if not normal_fields: + normal_fields = list(relation.data[relation.app].keys()) + + if self.secret_fields and self.secrets_enabled: + if fields: + # Processing from what was requested + normal_fields = set(fields) - set(self.secret_fields) + secret_fields = set(fields) - set(normal_fields) + + secret_fieldnames_grouped = self._group_secret_fields(list(secret_fields)) + + for group in secret_fieldnames_grouped: + if contents := self._retrieve_group_secret_contents( + relation.id, group, secret_fields + ): + result.update(contents) + else: + # If it wasn't found as a secret, let's give it a 2nd chance as "normal" field + normal_fields |= set(secret_fieldnames_grouped[group]) + else: + # Processing from what is given, i.e. retrieving all + normal_fields = [ + f for f in relation.data[relation.app].keys() if not self._is_secret_field(f) + ] + secret_fields = [ + f for f in relation.data[relation.app].keys() if self._is_secret_field(f) + ] + for group in SecretGroup: + result.update( + self._retrieve_group_secret_contents( + relation.id, group, self.secret_fields + ) + ) + + # Processing "normal" fields. May include leftover from what we couldn't retrieve as a secret. + result.update({k: relation.data[relation.app].get(k) for k in normal_fields}) + return result + + # Public methods -- mandatory overrides + + @leader_only + def update_relation_data(self, relation_id: int, data: dict) -> None: + """Updates a set of key-value pairs in the relation. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation_id: the identifier for a particular relation. + data: dict containing the key-value pairs + that should be updated in the relation. + """ + if any(self._is_secret_field(key) for key in data.keys()): + raise SecretsIllegalUpdateError("Requires side can't update secrets.") + + relation = self.charm.model.get_relation(self.relation_name, relation_id) + if relation: + relation.data[self.local_app].update(data) + + # "Native" public methods + + def fetch_relation_field( + self, relation_id: int, field: str, relation_name: Optional[str] = None + ) -> Optional[str]: + """Get a single field from the relation data.""" + return ( + self.fetch_relation_data([relation_id], [field], relation_name) + .get(relation_id, {}) + .get(field) + ) + # General events @@ -593,7 +1164,50 @@ def extra_user_roles(self) -> Optional[str]: class AuthenticationEvent(RelationEvent): - """Base class for authentication fields for events.""" + """Base class for authentication fields for events. + + The amount of logic added here is not ideal -- but this was the only way to preserve + the interface when moving to Juju Secrets + """ + + @property + def _secrets(self) -> dict: + """Caching secrets to avoid fetching them each time a field is referrd. + + DON'T USE the encapsulated helper variable outside of this function + """ + if not hasattr(self, "_cached_secrets"): + self._cached_secrets = {} + return self._cached_secrets + + @property + def _jujuversion(self) -> JujuVersion: + """Caching jujuversion to avoid a Juju call on each field evaluation. + + DON'T USE the encapsulated helper variable outside of this function + """ + if not hasattr(self, "_cached_jujuversion"): + self._cached_jujuversion = None + if not self._cached_jujuversion: + self._cached_jujuversion = JujuVersion.from_environ() + return self._cached_jujuversion + + def _get_secret(self, group) -> Optional[Dict[str, str]]: + """Retrieveing secrets.""" + if not self.app: + return + if not self._secrets.get(group): + self._secrets[group] = None + secret_field = f"{PROV_SECRET_PREFIX}{group}" + if secret_uri := self.relation.data[self.app].get(secret_field): + secret = self.framework.model.get_secret(id=secret_uri) + self._secrets[group] = secret.get_content() + return self._secrets[group] + + @property + def secrets_enabled(self): + """Is this Juju version allowing for Secrets usage?""" + return self._jujuversion.has_secrets @property def username(self) -> Optional[str]: @@ -601,6 +1215,11 @@ def username(self) -> Optional[str]: if not self.relation.app: return None + if self.secrets_enabled: + secret = self._get_secret("user") + if secret: + return secret.get("username") + return self.relation.data[self.relation.app].get("username") @property @@ -609,6 +1228,11 @@ def password(self) -> Optional[str]: if not self.relation.app: return None + if self.secrets_enabled: + secret = self._get_secret("user") + if secret: + return secret.get("password") + return self.relation.data[self.relation.app].get("password") @property @@ -617,6 +1241,11 @@ def tls(self) -> Optional[str]: if not self.relation.app: return None + if self.secrets_enabled: + secret = self._get_secret("tls") + if secret: + return secret.get("tls") + return self.relation.data[self.relation.app].get("tls") @property @@ -625,6 +1254,11 @@ def tls_ca(self) -> Optional[str]: if not self.relation.app: return None + if self.secrets_enabled: + secret = self._get_secret("tls") + if secret: + return secret.get("tls-ca") + return self.relation.data[self.relation.app].get("tls-ca") @@ -761,10 +1395,9 @@ def __init__(self, charm: CharmBase, relation_name: str) -> None: def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" - # Only the leader should handle this event. + # Leader only if not self.local_unit.is_leader(): return - # Check which data has changed to emit customs events. diff = self._diff(event) @@ -785,7 +1418,7 @@ def set_database(self, relation_id: int, database_name: str) -> None: relation_id: the identifier for a particular relation. database_name: database name. """ - self._update_relation_data(relation_id, {"database": database_name}) + self.update_relation_data(relation_id, {"database": database_name}) def set_endpoints(self, relation_id: int, connection_strings: str) -> None: """Set database primary connections. @@ -801,7 +1434,7 @@ def set_endpoints(self, relation_id: int, connection_strings: str) -> None: relation_id: the identifier for a particular relation. connection_strings: database hosts and ports comma separated list. """ - self._update_relation_data(relation_id, {"endpoints": connection_strings}) + self.update_relation_data(relation_id, {"endpoints": connection_strings}) def set_read_only_endpoints(self, relation_id: int, connection_strings: str) -> None: """Set database replicas connection strings. @@ -813,7 +1446,7 @@ def set_read_only_endpoints(self, relation_id: int, connection_strings: str) -> relation_id: the identifier for a particular relation. connection_strings: database hosts and ports comma separated list. """ - self._update_relation_data(relation_id, {"read-only-endpoints": connection_strings}) + self.update_relation_data(relation_id, {"read-only-endpoints": connection_strings}) def set_replset(self, relation_id: int, replset: str) -> None: """Set replica set name in the application relation databag. @@ -824,7 +1457,7 @@ def set_replset(self, relation_id: int, replset: str) -> None: relation_id: the identifier for a particular relation. replset: replica set name. """ - self._update_relation_data(relation_id, {"replset": replset}) + self.update_relation_data(relation_id, {"replset": replset}) def set_uris(self, relation_id: int, uris: str) -> None: """Set the database connection URIs in the application relation databag. @@ -835,7 +1468,7 @@ def set_uris(self, relation_id: int, uris: str) -> None: relation_id: the identifier for a particular relation. uris: connection URIs. """ - self._update_relation_data(relation_id, {"uris": uris}) + self.update_relation_data(relation_id, {"uris": uris}) def set_version(self, relation_id: int, version: str) -> None: """Set the database version in the application relation databag. @@ -844,7 +1477,7 @@ def set_version(self, relation_id: int, version: str) -> None: relation_id: the identifier for a particular relation. version: database version. """ - self._update_relation_data(relation_id, {"version": version}) + self.update_relation_data(relation_id, {"version": version}) class DatabaseRequires(DataRequires): @@ -859,9 +1492,10 @@ def __init__( database_name: str, extra_user_roles: Optional[str] = None, relations_aliases: Optional[List[str]] = None, + additional_secret_fields: Optional[List[str]] = [], ): """Manager of database client relations.""" - super().__init__(charm, relation_name, extra_user_roles) + super().__init__(charm, relation_name, extra_user_roles, additional_secret_fields) self.database = database_name self.relations_aliases = relations_aliases @@ -886,6 +1520,10 @@ def __init__( DatabaseReadOnlyEndpointsChangedEvent, ) + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + pass + def _assign_relation_alias(self, relation_id: int) -> None: """Assigns an alias to a relation. @@ -962,16 +1600,21 @@ def is_postgresql_plugin_enabled(self, plugin: str, relation_index: int = 0) -> if len(self.relations) == 0: return False - relation_data = self.fetch_relation_data()[self.relations[relation_index].id] - host = relation_data.get("endpoints") + relation_id = self.relations[relation_index].id + host = self.fetch_relation_field(relation_id, "endpoints") # Return False if there is no endpoint available. if host is None: return False host = host.split(":")[0] - user = relation_data.get("username") - password = relation_data.get("password") + + content = self.fetch_relation_data([relation_id], ["username", "password"]).get( + relation_id, {} + ) + user = content.get("username") + password = content.get("password") + connection_string = ( f"host='{host}' dbname='{self.database}' user='{user}' password='{password}'" ) @@ -990,13 +1633,15 @@ def is_postgresql_plugin_enabled(self, plugin: str, relation_index: int = 0) -> def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: """Event emitted when the database relation is created.""" + super()._on_relation_created_event(event) + # If relations aliases were provided, assign one to the relation. self._assign_relation_alias(event.relation.id) # Sets both database and extra user roles in the relation # if the roles are provided. Otherwise, sets only the database. if self.extra_user_roles: - self._update_relation_data( + self.update_relation_data( event.relation.id, { "database": self.database, @@ -1004,16 +1649,23 @@ def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: }, ) else: - self._update_relation_data(event.relation.id, {"database": self.database}) + self.update_relation_data(event.relation.id, {"database": self.database}) def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the database relation has changed.""" # Check which data has changed to emit customs events. diff = self._diff(event) + # Register all new secrets with their labels + if any(newval for newval in diff.added if self._is_secret_field(newval)): + self._register_secrets_to_relation(event.relation, diff.added) + # Check if the database is created # (the database charm shared the credentials). - if "username" in diff.added and "password" in diff.added: + secret_field_user = self._generate_secret_field_name(SecretGroup.USER) + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: # Emit the default event (the one without an alias). logger.info("database created at %s", datetime.now()) getattr(self.on, "database_created").emit( @@ -1159,7 +1811,7 @@ def __init__(self, charm: CharmBase, relation_name: str) -> None: def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" - # Only the leader should handle this event. + # Leader only if not self.local_unit.is_leader(): return @@ -1180,7 +1832,7 @@ def set_topic(self, relation_id: int, topic: str) -> None: relation_id: the identifier for a particular relation. topic: the topic name. """ - self._update_relation_data(relation_id, {"topic": topic}) + self.update_relation_data(relation_id, {"topic": topic}) def set_bootstrap_server(self, relation_id: int, bootstrap_server: str) -> None: """Set the bootstrap server in the application relation databag. @@ -1189,7 +1841,7 @@ def set_bootstrap_server(self, relation_id: int, bootstrap_server: str) -> None: relation_id: the identifier for a particular relation. bootstrap_server: the bootstrap server address. """ - self._update_relation_data(relation_id, {"endpoints": bootstrap_server}) + self.update_relation_data(relation_id, {"endpoints": bootstrap_server}) def set_consumer_group_prefix(self, relation_id: int, consumer_group_prefix: str) -> None: """Set the consumer group prefix in the application relation databag. @@ -1198,7 +1850,7 @@ def set_consumer_group_prefix(self, relation_id: int, consumer_group_prefix: str relation_id: the identifier for a particular relation. consumer_group_prefix: the consumer group prefix string. """ - self._update_relation_data(relation_id, {"consumer-group-prefix": consumer_group_prefix}) + self.update_relation_data(relation_id, {"consumer-group-prefix": consumer_group_prefix}) def set_zookeeper_uris(self, relation_id: int, zookeeper_uris: str) -> None: """Set the zookeeper uris in the application relation databag. @@ -1207,7 +1859,7 @@ def set_zookeeper_uris(self, relation_id: int, zookeeper_uris: str) -> None: relation_id: the identifier for a particular relation. zookeeper_uris: comma-separated list of ZooKeeper server uris. """ - self._update_relation_data(relation_id, {"zookeeper-uris": zookeeper_uris}) + self.update_relation_data(relation_id, {"zookeeper-uris": zookeeper_uris}) class KafkaRequires(DataRequires): @@ -1222,10 +1874,11 @@ def __init__( topic: str, extra_user_roles: Optional[str] = None, consumer_group_prefix: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], ): """Manager of Kafka client relations.""" # super().__init__(charm, relation_name) - super().__init__(charm, relation_name, extra_user_roles) + super().__init__(charm, relation_name, extra_user_roles, additional_secret_fields) self.charm = charm self.topic = topic self.consumer_group_prefix = consumer_group_prefix or "" @@ -1244,13 +1897,19 @@ def topic(self, value): def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: """Event emitted when the Kafka relation is created.""" + super()._on_relation_created_event(event) + # Sets topic, extra user roles, and "consumer-group-prefix" in the relation relation_data = { f: getattr(self, f.replace("-", "_"), "") for f in ["consumer-group-prefix", "extra-user-roles", "topic"] } - self._update_relation_data(event.relation.id, relation_data) + self.update_relation_data(event.relation.id, relation_data) + + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + pass def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the Kafka relation has changed.""" @@ -1259,7 +1918,15 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: # Check if the topic is created # (the Kafka charm shared the credentials). - if "username" in diff.added and "password" in diff.added: + + # Register all new secrets with their labels + if any(newval for newval in diff.added if self._is_secret_field(newval)): + self._register_secrets_to_relation(event.relation, diff.added) + + secret_field_user = self._generate_secret_field_name(SecretGroup.USER) + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: # Emit the default event (the one without an alias). logger.info("topic created at %s", datetime.now()) getattr(self.on, "topic_created").emit(event.relation, app=event.app, unit=event.unit) @@ -1339,10 +2006,9 @@ def __init__(self, charm: CharmBase, relation_name: str) -> None: def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" - # Only the leader should handle this event. + # Leader only if not self.local_unit.is_leader(): return - # Check which data has changed to emit customs events. diff = self._diff(event) @@ -1362,7 +2028,7 @@ def set_index(self, relation_id: int, index: str) -> None: requested index, and can be used to present a different index name if, for example, the requested index is invalid. """ - self._update_relation_data(relation_id, {"index": index}) + self.update_relation_data(relation_id, {"index": index}) def set_endpoints(self, relation_id: int, endpoints: str) -> None: """Set the endpoints in the application relation databag. @@ -1371,7 +2037,7 @@ def set_endpoints(self, relation_id: int, endpoints: str) -> None: relation_id: the identifier for a particular relation. endpoints: the endpoint addresses for opensearch nodes. """ - self._update_relation_data(relation_id, {"endpoints": endpoints}) + self.update_relation_data(relation_id, {"endpoints": endpoints}) def set_version(self, relation_id: int, version: str) -> None: """Set the opensearch version in the application relation databag. @@ -1380,7 +2046,7 @@ def set_version(self, relation_id: int, version: str) -> None: relation_id: the identifier for a particular relation. version: database version. """ - self._update_relation_data(relation_id, {"version": version}) + self.update_relation_data(relation_id, {"version": version}) class OpenSearchRequires(DataRequires): @@ -1389,22 +2055,54 @@ class OpenSearchRequires(DataRequires): on = OpenSearchRequiresEvents() # pyright: ignore[reportGeneralTypeIssues] def __init__( - self, charm, relation_name: str, index: str, extra_user_roles: Optional[str] = None + self, + charm, + relation_name: str, + index: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], ): """Manager of OpenSearch client relations.""" - super().__init__(charm, relation_name, extra_user_roles) + super().__init__(charm, relation_name, extra_user_roles, additional_secret_fields) self.charm = charm self.index = index def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: """Event emitted when the OpenSearch relation is created.""" + super()._on_relation_created_event(event) + # Sets both index and extra user roles in the relation if the roles are provided. # Otherwise, sets only the index. data = {"index": self.index} if self.extra_user_roles: data["extra-user-roles"] = self.extra_user_roles - self._update_relation_data(event.relation.id, data) + self.update_relation_data(event.relation.id, data) + + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + if not event.secret.label: + return + + relation = self._relation_from_secret_label(event.secret.label) + if not relation: + logging.info( + f"Received secret {event.secret.label} but couldn't parse, seems irrelevant" + ) + return + + if relation.app == self.charm.app: + logging.info("Secret changed event ignored for Secret Owner") + + remote_unit = None + for unit in relation.units: + if unit.app != self.charm.app: + remote_unit = unit + + logger.info("authentication updated") + getattr(self.on, "authentication_updated").emit( + relation, app=relation.app, unit=remote_unit + ) def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the OpenSearch relation has changed. @@ -1414,8 +2112,13 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: # Check which data has changed to emit customs events. diff = self._diff(event) - # Check if authentication has updated, emit event if so - updates = {"username", "password", "tls", "tls-ca"} + # Register all new secrets with their labels + if any(newval for newval in diff.added if self._is_secret_field(newval)): + self._register_secrets_to_relation(event.relation, diff.added) + + secret_field_user = self._generate_secret_field_name(SecretGroup.USER) + secret_field_tls = self._generate_secret_field_name(SecretGroup.TLS) + updates = {"username", "password", "tls", "tls-ca", secret_field_user, secret_field_tls} if len(set(diff._asdict().keys()) - updates) < len(diff): logger.info("authentication updated at: %s", datetime.now()) getattr(self.on, "authentication_updated").emit( @@ -1424,7 +2127,9 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: # Check if the index is created # (the OpenSearch charm shares the credentials). - if "username" in diff.added and "password" in diff.added: + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: # Emit the default event (the one without an alias). logger.info("index created at: %s", datetime.now()) getattr(self.on, "index_created").emit(event.relation, app=event.app, unit=event.unit)