diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py index 714eace46..45d57fefd 100644 --- a/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -320,7 +320,7 @@ def _on_topic_requested(self, event: TopicRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 24 +LIBPATCH = 25 PYDEPS = ["ops>=2.0.0"] @@ -347,16 +347,6 @@ class SecretGroup(Enum): EXTRA = "extra" -# Local map to associate mappings with secrets potentially as a group -SECRET_LABEL_MAP = { - "username": SecretGroup.USER, - "password": SecretGroup.USER, - "uris": SecretGroup.USER, - "tls": SecretGroup.TLS, - "tls-ca": SecretGroup.TLS, -} - - class DataInterfacesError(Exception): """Common ancestor for DataInterfaces related exceptions.""" @@ -453,7 +443,7 @@ def leader_only(f): """Decorator to ensure that only leader can perform given operation.""" def wrapper(self, *args, **kwargs): - if not self.local_unit.is_leader(): + if self.component == self.local_app and not self.local_unit.is_leader(): logger.error( "This operation (%s()) can only be performed by the leader unit", f.__name__ ) @@ -502,7 +492,9 @@ def add_secret(self, content: Dict[str, str], relation: Relation) -> Secret: ) secret = self.charm.app.add_secret(content, label=self.label) - secret.grant(relation) + if relation.app != self.charm.app: + # If it's not a peer relation, grant is to be applied + secret.grant(relation) self._secret_uri = secret.id self._secret_meta = secret return self._secret_meta @@ -587,6 +579,15 @@ def add(self, label: str, content: Dict[str, str], relation: Relation) -> Cached class DataRelation(Object, ABC): """Base relation data mainpulation (abstract) class.""" + # Local map to associate mappings with secrets potentially as a group + SECRET_LABEL_MAP = { + "username": SecretGroup.USER, + 
"password": SecretGroup.USER, + "uris": SecretGroup.USER, + "tls": SecretGroup.TLS, + "tls-ca": SecretGroup.TLS, + } + def __init__(self, charm: CharmBase, relation_name: str) -> None: super().__init__(charm, relation_name) self.charm = charm @@ -599,6 +600,7 @@ def __init__(self, charm: CharmBase, relation_name: str) -> None: ) self._jujuversion = None self.secrets = SecretCache(self.charm) + self.component = self.local_app @property def relations(self) -> List[Relation]: @@ -677,8 +679,7 @@ def _generate_secret_label( """Generate unique group_mappings for secrets within a relation context.""" return f"{relation_name}.{relation_id}.{group_mapping.value}.secret" - @staticmethod - def _generate_secret_field_name(group_mapping: SecretGroup) -> str: + def _generate_secret_field_name(self, group_mapping: SecretGroup) -> str: """Generate unique group_mappings for secrets within a relation context.""" return f"{PROV_SECRET_PREFIX}{group_mapping.value}" @@ -705,8 +706,8 @@ def _relation_from_secret_label(self, secret_label: str) -> Optional[Relation]: except ModelError: return - @staticmethod - def _group_secret_fields(secret_fields: List[str]) -> Dict[SecretGroup, List[str]]: + @classmethod + def _group_secret_fields(cls, secret_fields: List[str]) -> Dict[SecretGroup, List[str]]: """Helper function to arrange secret mappings under their group. NOTE: All unrecognized items end up in the 'extra' secret bucket. 
@@ -714,7 +715,7 @@ def _group_secret_fields(secret_fields: List[str]) -> Dict[SecretGroup, List[str """ secret_fieldnames_grouped = {} for key in secret_fields: - if group := SECRET_LABEL_MAP.get(key): + if group := cls.SECRET_LABEL_MAP.get(key): secret_fieldnames_grouped.setdefault(group, []).append(key) else: secret_fieldnames_grouped.setdefault(SecretGroup.EXTRA, []).append(key) @@ -736,22 +737,22 @@ def _get_group_secret_contents( return {k: v for k, v in secret_data.items() if k in secret_fields} return {} - @staticmethod + @classmethod def _content_for_secret_group( - content: Dict[str, str], secret_fields: Set[str], group_mapping: SecretGroup + cls, content: Dict[str, str], secret_fields: Set[str], group_mapping: SecretGroup ) -> Dict[str, str]: """Select : pairs from input, that belong to this particular Secret group.""" if group_mapping == SecretGroup.EXTRA: return { k: v for k, v in content.items() - if k in secret_fields and k not in SECRET_LABEL_MAP.keys() + if k in secret_fields and k not in cls.SECRET_LABEL_MAP.keys() } return { k: v for k, v in content.items() - if k in secret_fields and SECRET_LABEL_MAP.get(k) == group_mapping + if k in secret_fields and cls.SECRET_LABEL_MAP.get(k) == group_mapping } @juju_secrets_only @@ -784,7 +785,7 @@ def _process_secret_fields( fallback_to_databag = ( req_secret_fields and self.local_unit.is_leader() - and set(req_secret_fields) & set(relation.data[self.local_app]) + and set(req_secret_fields) & set(relation.data[self.component]) ) normal_fields = set(impacted_rel_fields) @@ -807,7 +808,7 @@ def _process_secret_fields( return (result, normal_fields) def _fetch_relation_data_without_secrets( - self, app: Application, relation: Relation, fields: Optional[List[str]] + self, app: Union[Application, Unit], relation: Relation, fields: Optional[List[str]] ) -> Dict[str, str]: """Fetching databag contents when no secrets are involved. 
@@ -826,7 +827,7 @@ def _fetch_relation_data_without_secrets( def _fetch_relation_data_with_secrets( self, - app: Application, + app: Union[Application, Unit], req_secret_fields: Optional[List[str]], relation: Relation, fields: Optional[List[str]] = None, @@ -867,33 +868,30 @@ def _fetch_relation_data_with_secrets( return result def _update_relation_data_without_secrets( - self, app: Application, relation: Relation, data: Dict[str, str] + self, app: Union[Application, Unit], relation: Relation, data: Dict[str, str] ) -> None: """Updating databag contents when no secrets are involved.""" if app not in relation.data or relation.data[app] is None: return - if any(self._is_secret_field(key) for key in data.keys()): - raise SecretsIllegalUpdateError("Can't update secret {key}.") - if relation: relation.data[app].update(data) def _delete_relation_data_without_secrets( - self, app: Application, relation: Relation, fields: List[str] + self, app: Union[Application, Unit], relation: Relation, fields: List[str] ) -> None: """Remove databag fields 'fields' from Relation.""" - if app not in relation.data or not relation.data[app]: + if app not in relation.data or relation.data[app] is None: return for field in fields: try: relation.data[app].pop(field) except KeyError: - logger.debug( - "Non-existing field was attempted to be removed from the databag %s, %s", - str(relation.id), + logger.error( + "Non-existing field '%s' was attempted to be removed from the databag (relation ID: %s)", str(field), + str(relation.id), ) pass @@ -954,7 +952,6 @@ def fetch_relation_field( .get(field) ) - @leader_only def fetch_my_relation_data( self, relation_ids: Optional[List[int]] = None, @@ -983,7 +980,6 @@ def fetch_my_relation_data( data[relation.id] = self._fetch_my_specific_relation_data(relation, fields) return data - @leader_only def fetch_my_relation_field( self, relation_id: int, field: str, relation_name: Optional[str] = None ) -> Optional[str]: @@ -1035,27 +1031,38 @@ def _diff(self, 
event: RelationChangedEvent) -> Diff: @juju_secrets_only def _add_relation_secret( - self, relation: Relation, content: Dict[str, str], group_mapping: SecretGroup + self, + relation: Relation, + group_mapping: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], + uri_to_databag=True, ) -> bool: """Add a new Juju Secret that will be registered in the relation databag.""" secret_field = self._generate_secret_field_name(group_mapping) - if relation.data[self.local_app].get(secret_field): + if uri_to_databag and relation.data[self.component].get(secret_field): logging.error("Secret for relation %s already exists, not adding again", relation.id) return False + content = self._content_for_secret_group(data, secret_fields, group_mapping) + label = self._generate_secret_label(self.relation_name, relation.id, group_mapping) secret = self.secrets.add(label, content, relation) # According to lint we may not have a Secret ID - if secret.meta and secret.meta.id: - relation.data[self.local_app][secret_field] = secret.meta.id + if uri_to_databag and secret.meta and secret.meta.id: + relation.data[self.component][secret_field] = secret.meta.id # Return the content that was added return True @juju_secrets_only def _update_relation_secret( - self, relation: Relation, content: Dict[str, str], group_mapping: SecretGroup + self, + relation: Relation, + group_mapping: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], ) -> bool: """Update the contents of an existing Juju Secret, referred in the relation databag.""" secret = self._get_relation_secret(relation.id, group_mapping) @@ -1064,6 +1071,8 @@ def _update_relation_secret( logging.error("Can't update secret for relation %s", relation.id) return False + content = self._content_for_secret_group(data, secret_fields, group_mapping) + old_content = secret.get_content() full_content = copy.deepcopy(old_content) full_content.update(content) @@ -1078,13 +1087,13 @@ def _add_or_update_relation_secrets( group: 
SecretGroup, secret_fields: Set[str], data: Dict[str, str], + uri_to_databag=True, ) -> bool: """Update contents for Secret group. If the Secret doesn't exist, create it.""" - secret_content = self._content_for_secret_group(data, secret_fields, group) if self._get_relation_secret(relation.id, group): - return self._update_relation_secret(relation, secret_content, group) + return self._update_relation_secret(relation, group, secret_fields, data) else: - return self._add_relation_secret(relation, secret_content, group) + return self._add_relation_secret(relation, group, secret_fields, data, uri_to_databag) @juju_secrets_only def _delete_relation_secret( @@ -1116,7 +1125,7 @@ def _delete_relation_secret( if not new_content: field = self._generate_secret_field_name(group) try: - relation.data[self.local_app].pop(field) + relation.data[self.component].pop(field) except KeyError: pass @@ -1233,6 +1242,11 @@ def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: """ self.update_relation_data(relation_id, {"tls-ca": tls_ca}) + # Public functions -- inherited + + fetch_my_relation_data = leader_only(DataRelation.fetch_my_relation_data) + fetch_my_relation_field = leader_only(DataRelation.fetch_my_relation_field) + class DataRequires(DataRelation): """Requires-side of the relation.""" @@ -1426,6 +1440,218 @@ def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: """ return self._delete_relation_data_without_secrets(self.local_app, relation, fields) + # Public functions -- inherited + + fetch_my_relation_data = leader_only(DataRelation.fetch_my_relation_data) + fetch_my_relation_field = leader_only(DataRelation.fetch_my_relation_field) + + +# Base DataPeer + + +class DataPeer(DataRequires, DataProvides): + """Represents peer relations.""" + + SECRET_FIELDS = ["operator-password"] + SECRET_FIELD_NAME = "internal_secret" + SECRET_LABEL_MAP = {} + + def __init__( + self, + charm, + relation_name: str, + extra_user_roles: Optional[str] = None, + 
additional_secret_fields: Optional[List[str]] = [], + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + ): + """Manager of base client relations.""" + DataRequires.__init__( + self, charm, relation_name, extra_user_roles, additional_secret_fields + ) + self.secret_field_name = secret_field_name if secret_field_name else self.SECRET_FIELD_NAME + self.deleted_label = deleted_label + + @property + def scope(self) -> Optional[Scope]: + """Turn component information into Scope.""" + if isinstance(self.component, Application): + return Scope.APP + if isinstance(self.component, Unit): + return Scope.UNIT + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + pass + + def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: + """Event emitted when the secret has changed.""" + pass + + def _generate_secret_label( + self, relation_name: str, relation_id: int, group_mapping: SecretGroup + ) -> str: + members = [self.charm.app.name] + if self.scope: + members.append(self.scope.value) + return f"{'.'.join(members)}" + + def _generate_secret_field_name(self, group_mapping: SecretGroup = SecretGroup.EXTRA) -> str: + """Generate unique group_mappings for secrets within a relation context.""" + return f"{self.secret_field_name}" + + @juju_secrets_only + def _get_relation_secret( + self, + relation_id: int, + group_mapping: SecretGroup = SecretGroup.EXTRA, + relation_name: Optional[str] = None, + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret specifically for peer relations. + + In case this code may be executed within a rolling upgrade, and we may need to + migrate secrets from the databag to labels, we make sure to stick the correct + label on the secret, and clean up the local databag. 
+ """ + if not relation_name: + relation_name = self.relation_name + + relation = self.charm.model.get_relation(relation_name, relation_id) + if not relation: + return + + label = self._generate_secret_label(relation_name, relation_id, group_mapping) + secret_uri = relation.data[self.component].get(self._generate_secret_field_name(), None) + + # Fetching the secret with fallback to URI (in case label is not yet known) + # Label would we "stuck" on the secret in case it is found + secret = self.secrets.get(label, secret_uri) + + # Either app scope secret with leader executing, or unit scope secret + leader_or_unit_scope = self.component != self.local_app or self.local_unit.is_leader() + if secret_uri and secret and leader_or_unit_scope: + # Databag reference to the secret URI can be removed, now that it's labelled + relation.data[self.component].pop(self._generate_secret_field_name(), None) + return secret + + def _get_group_secret_contents( + self, + relation: Relation, + group: SecretGroup, + secret_fields: Optional[Union[Set[str], List[str]]] = None, + ) -> Dict[str, str]: + """Helper function to retrieve collective, requested contents of a secret.""" + result = super()._get_group_secret_contents(relation, group, secret_fields) + if not self.deleted_label: + return result + return {key: result[key] for key in result if result[key] != self.deleted_label} + + def _remove_secret_from_databag(self, relation, fields: List[str]) -> None: + """For Rolling Upgrades -- when moving from databag to secrets usage. + + Practically what happens here is to remove stuff from the databag that is + to be stored in secrets. 
+ """ + if not self.secret_fields: + return + + secret_fields_passed = set(self.secret_fields) & set(fields) + for field in secret_fields_passed: + if self._fetch_relation_data_without_secrets(self.component, relation, [field]): + self._delete_relation_data_without_secrets(self.component, relation, [field]) + + def _fetch_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation.""" + return self._fetch_relation_data_with_secrets( + self.component, self.secret_fields, relation, fields + ) + + def _fetch_my_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + return self._fetch_relation_data_with_secrets( + self.component, self.secret_fields, relation, fields + ) + + def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: + """Update data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + self._remove_secret_from_databag(relation, list(data.keys())) + _, normal_fields = self._process_secret_fields( + relation, + self.secret_fields, + list(data), + self._add_or_update_relation_secrets, + data=data, + uri_to_databag=False, + ) + + normal_content = {k: v for k, v in data.items() if k in normal_fields} + self._update_relation_data_without_secrets(self.component, relation, normal_content) + + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Delete data available (directily or indirectly -- i.e. 
secrets) from the relation for owner/this_app.""" + if self.secret_fields and self.deleted_label: + current_data = self.fetch_my_relation_data([relation.id], fields) + if current_data is not None: + # Check if the secret we wanna delete actually exists + # Given the "deleted label", here we can't rely on the default mechanism (i.e. 'key not found') + if non_existent := (set(fields) & set(self.secret_fields)) - set( + current_data.get(relation.id, []) + ): + logger.error( + "Non-existing secret %s was attempted to be removed.", non_existent + ) + + _, normal_fields = self._process_secret_fields( + relation, + self.secret_fields, + fields, + self._update_relation_secret, + data={field: self.deleted_label for field in fields}, + ) + else: + _, normal_fields = self._process_secret_fields( + relation, self.secret_fields, fields, self._delete_relation_secret, fields=fields + ) + self._delete_relation_data_without_secrets(self.component, relation, list(normal_fields)) + + def fetch_relation_data( + self, + relation_ids: Optional[List[int]] = None, + fields: Optional[List[str]] = None, + relation_name: Optional[str] = None, + ) -> Dict[int, Dict[str, str]]: + """This method makes no sense for a Peer Relation.""" + raise NotImplementedError( + "Peer Relation only supports 'self-side' fetch methods: " + "fetch_my_relation_data() and fetch_my_relation_field()" + ) + + def fetch_relation_field( + self, relation_id: int, field: str, relation_name: Optional[str] = None + ) -> Optional[str]: + """This method makes no sense for a Peer Relation.""" + raise NotImplementedError( + "Peer Relation only supports 'self-side' fetch methods: " + "fetch_my_relation_data() and fetch_my_relation_field()" + ) + + # Public functions -- inherited + + fetch_my_relation_data = DataRelation.fetch_my_relation_data + fetch_my_relation_field = DataRelation.fetch_my_relation_field + + +class DataPeerUnit(DataPeer): + """Unit databag representation.""" + + def __init__(self, *args, **kwargs): + 
super().__init__(*args, **kwargs) + self.component = self.local_unit + # General events diff --git a/lib/charms/data_platform_libs/v0/s3.py b/lib/charms/data_platform_libs/v0/s3.py index 9fb518a56..7beb113b6 100644 --- a/lib/charms/data_platform_libs/v0/s3.py +++ b/lib/charms/data_platform_libs/v0/s3.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""A library for communicating with the S3 credentials providers and consumers. +r"""A library for communicating with the S3 credentials providers and consumers. This library provides the relevant interface code implementing the communication specification for fetching, retrieving, triggering, and responding to events related to @@ -113,7 +113,7 @@ def _on_credential_gone(self, event: CredentialsGoneEvent): import json import logging from collections import namedtuple -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import ops.charm import ops.framework @@ -121,15 +121,13 @@ def _on_credential_gone(self, event: CredentialsGoneEvent): from ops.charm import ( CharmBase, CharmEvents, - EventSource, - Object, - ObjectEvents, RelationBrokenEvent, RelationChangedEvent, RelationEvent, RelationJoinedEvent, ) -from ops.model import Relation +from ops.framework import EventSource, Object, ObjectEvents +from ops.model import Application, Relation, RelationDataContent, Unit # The unique Charmhub library identifier, never change it LIBID = "fca396f6254246c9bfa565b1f85ab528" @@ -139,7 +137,7 @@ def _on_credential_gone(self, event: CredentialsGoneEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 4 logger = logging.getLogger(__name__) @@ -152,7 +150,7 @@ def _on_credential_gone(self, event: CredentialsGoneEvent): deleted - key that were deleted""" -def diff(event: RelationChangedEvent, bucket: str) -> Diff: +def 
diff(event: RelationChangedEvent, bucket: Union[Unit, Application]) -> Diff: """Retrieves the diff of the data in the relation changed databag. Args: @@ -166,9 +164,11 @@ def diff(event: RelationChangedEvent, bucket: str) -> Diff: # Retrieve the old data from the data key in the application relation databag. old_data = json.loads(event.relation.data[bucket].get("data", "{}")) # Retrieve the new data from the event relation databag. - new_data = { - key: value for key, value in event.relation.data[event.app].items() if key != "data" - } + new_data = ( + {key: value for key, value in event.relation.data[event.app].items() if key != "data"} + if event.app + else {} + ) # These are the keys that were added to the databag and triggered this event. added = new_data.keys() - old_data.keys() @@ -193,7 +193,10 @@ class BucketEvent(RelationEvent): @property def bucket(self) -> Optional[str]: """Returns the bucket was requested.""" - return self.relation.data[self.relation.app].get("bucket") + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("bucket", "") class CredentialRequestedEvent(BucketEvent): @@ -209,7 +212,7 @@ class S3CredentialEvents(CharmEvents): class S3Provider(Object): """A provider handler for communicating S3 credentials to consumers.""" - on = S3CredentialEvents() + on = S3CredentialEvents() # pyright: ignore [reportGeneralTypeIssues] def __init__( self, @@ -232,7 +235,9 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: diff = self._diff(event) # emit on credential requested if bucket is provided by the requirer application if "bucket" in diff.added: - self.on.credentials_requested.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "credentials_requested").emit( + event.relation, app=event.app, unit=event.unit + ) def _load_relation_data(self, raw_relation_data: dict) -> dict: """Loads relation data from the relation data bag. 
@@ -242,7 +247,7 @@ def _load_relation_data(self, raw_relation_data: dict) -> dict: Returns: dict: Relation data in dict format. """ - connection_data = dict() + connection_data = {} for key in raw_relation_data: try: connection_data[key] = json.loads(raw_relation_data[key]) @@ -309,9 +314,11 @@ def fetch_relation_data(self) -> dict: """ data = {} for relation in self.relations: - data[relation.id] = { - key: value for key, value in relation.data[relation.app].items() if key != "data" - } + data[relation.id] = ( + {key: value for key, value in relation.data[relation.app].items() if key != "data"} + if relation.app + else {} + ) return data def update_connection_info(self, relation_id: int, connection_data: dict) -> None: @@ -493,46 +500,73 @@ class S3Event(RelationEvent): @property def bucket(self) -> Optional[str]: """Returns the bucket name.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("bucket") @property def access_key(self) -> Optional[str]: """Returns the access key.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("access-key") @property def secret_key(self) -> Optional[str]: """Returns the secret key.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("secret-key") @property def path(self) -> Optional[str]: """Returns the path where data can be stored.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("path") @property def endpoint(self) -> Optional[str]: """Returns the endpoint address.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("endpoint") @property def region(self) -> Optional[str]: """Returns the region.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("region") @property def s3_uri_style(self) -> Optional[str]: """Returns the s3 uri style.""" + if not self.relation.app: 
+ return None + return self.relation.data[self.relation.app].get("s3-uri-style") @property def storage_class(self) -> Optional[str]: """Returns the storage class name.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("storage-class") @property def tls_ca_chain(self) -> Optional[List[str]]: """Returns the TLS CA chain.""" + if not self.relation.app: + return None + tls_ca_chain = self.relation.data[self.relation.app].get("tls-ca-chain") if tls_ca_chain is not None: return json.loads(tls_ca_chain) @@ -541,11 +575,17 @@ def tls_ca_chain(self) -> Optional[List[str]]: @property def s3_api_version(self) -> Optional[str]: """Returns the S3 API version.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("s3-api-version") @property def attributes(self) -> Optional[List[str]]: """Returns the attributes.""" + if not self.relation.app: + return None + attributes = self.relation.data[self.relation.app].get("attributes") if attributes is not None: return json.loads(attributes) @@ -573,9 +613,11 @@ class S3CredentialRequiresEvents(ObjectEvents): class S3Requirer(Object): """Requires-side of the s3 relation.""" - on = S3CredentialRequiresEvents() + on = S3CredentialRequiresEvents() # pyright: ignore[reportGeneralTypeIssues] - def __init__(self, charm: ops.charm.CharmBase, relation_name: str, bucket_name: str = None): + def __init__( + self, charm: ops.charm.CharmBase, relation_name: str, bucket_name: Optional[str] = None + ): """Manager of the s3 client relations.""" super().__init__(charm, relation_name) @@ -658,7 +700,7 @@ def update_connection_info(self, relation_id: int, connection_data: dict) -> Non relation.data[self.local_app].update(updated_connection_data) logger.debug(f"Updated S3 credentials: {updated_connection_data}") - def _load_relation_data(self, raw_relation_data: dict) -> dict: + def _load_relation_data(self, raw_relation_data: RelationDataContent) -> Dict[str, str]: 
"""Loads relation data from the relation data bag. Args: @@ -666,7 +708,7 @@ def _load_relation_data(self, raw_relation_data: dict) -> dict: Returns: dict: Relation data in dict format. """ - connection_data = dict() + connection_data = {} for key in raw_relation_data: try: connection_data[key] = json.loads(raw_relation_data[key]) @@ -700,22 +742,25 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: missing_options.append(configuration_option) # emit credential change event only if all mandatory fields are present if contains_required_options: - self.on.credentials_changed.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "credentials_changed").emit( + event.relation, app=event.app, unit=event.unit + ) else: logger.warning( f"Some mandatory fields: {missing_options} are not present, do not emit credential change event!" ) - def get_s3_connection_info(self) -> Dict: + def get_s3_connection_info(self) -> Dict[str, str]: """Return the s3 credentials as a dictionary.""" - relation = self.charm.model.get_relation(self.relation_name) - if not relation: - return {} - return self._load_relation_data(relation.data[relation.app]) + for relation in self.relations: + if relation and relation.app: + return self._load_relation_data(relation.data[relation.app]) + + return {} def _on_relation_broken(self, event: RelationBrokenEvent) -> None: """Notify the charm about a broken S3 credential store relation.""" - self.on.credentials_gone.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "credentials_gone").emit(event.relation, app=event.app, unit=event.unit) @property def relations(self) -> List[Relation]: diff --git a/lib/charms/grafana_agent/v0/cos_agent.py b/lib/charms/grafana_agent/v0/cos_agent.py index 0acaed361..259a90170 100644 --- a/lib/charms/grafana_agent/v0/cos_agent.py +++ b/lib/charms/grafana_agent/v0/cos_agent.py @@ -22,7 +22,7 @@ Using the `COSAgentProvider` object only requires instantiating it, 
typically in the `__init__` method of your charm (the one which sends telemetry). -The constructor of `COSAgentProvider` has only one required and eight optional parameters: +The constructor of `COSAgentProvider` has only one required and nine optional parameters: ```python def __init__( @@ -36,6 +36,7 @@ def __init__( log_slots: Optional[List[str]] = None, dashboard_dirs: Optional[List[str]] = None, refresh_events: Optional[List] = None, + scrape_configs: Optional[Union[List[Dict], Callable]] = None, ): ``` @@ -47,7 +48,8 @@ def __init__( the `cos_agent` interface, this is where you have to specify that. - `metrics_endpoints`: In this parameter you can specify the metrics endpoints that Grafana Agent - machine Charmed Operator will scrape. + machine Charmed Operator will scrape. The configs of this list will be merged with the configs + from `scrape_configs`. - `metrics_rules_dir`: The directory in which the Charmed Operator stores its metrics alert rules files. @@ -63,6 +65,10 @@ def __init__( - `refresh_events`: List of events on which to refresh relation data. +- `scrape_configs`: List of standard scrape_configs dicts or a callable that returns the list in + case the configs need to be generated dynamically. The contents of this list will be merged + with the configs from `metrics_endpoints`. 
+ ### Example 1 - Minimal instrumentation: @@ -91,6 +97,7 @@ def __init__(self, *args): self, relation_name="custom-cos-agent", metrics_endpoints=[ + # specify "path" and "port" to scrape from localhost {"path": "/metrics", "port": 9000}, {"path": "/metrics", "port": 9001}, {"path": "/metrics", "port": 9002}, @@ -101,6 +108,46 @@ def __init__(self, *args): log_slots=["my-app:slot"], dashboard_dirs=["./src/dashboards_1", "./src/dashboards_2"], refresh_events=["update-status", "upgrade-charm"], + scrape_configs=[ + { + "job_name": "custom_job", + "metrics_path": "/metrics", + "authorization": {"credentials": "bearer-token"}, + "static_configs": [ + { + "targets": ["localhost:9003"]}, + "labels": {"key": "value"}, + }, + ], + }, + ] + ) +``` + +### Example 3 - Dynamic scrape configs generation: + +Pass a function to the `scrape_configs` to decouple the generation of the configs +from the instantiation of the COSAgentProvider object. + +```python +from charms.grafana_agent.v0.cos_agent import COSAgentProvider +... + +class TelemetryProviderCharm(CharmBase): + def generate_scrape_configs(self): + return [ + { + "job_name": "custom", + "metrics_path": "/metrics", + "static_configs": [{"targets": ["localhost:9000"]}], + }, + ] + + def __init__(self, *args): + ... 
+ self._grafana_agent = COSAgentProvider( + self, + scrape_configs=self.generate_scrape_configs, ) ``` @@ -159,19 +206,17 @@ def __init__(self, *args): ``` """ -import base64 import json import logging -import lzma from collections import namedtuple from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Set, Union import pydantic -from cosl import JujuTopology +from cosl import GrafanaDashboard, JujuTopology from cosl.rules import AlertRules -from ops.charm import RelationChangedEvent, RelationEvent +from ops.charm import RelationChangedEvent from ops.framework import EventBase, EventSource, Object, ObjectEvents from ops.model import Relation, Unit from ops.testing import CharmType @@ -185,46 +230,25 @@ class _MetricsEndpointDict(TypedDict): port: int except ModuleNotFoundError: - _MetricsEndpointDict = dict + _MetricsEndpointDict = Dict # pyright: ignore LIBID = "dc15fa84cef84ce58155fb84f6c6213a" LIBAPI = 0 -LIBPATCH = 3 +LIBPATCH = 7 -PYDEPS = ["cosl", "pydantic"] +PYDEPS = ["cosl", "pydantic < 2"] DEFAULT_RELATION_NAME = "cos-agent" DEFAULT_PEER_RELATION_NAME = "peers" -DEFAULT_METRICS_ENDPOINT = { - "path": "/metrics", - "port": 80, +DEFAULT_SCRAPE_CONFIG = { + "static_configs": [{"targets": ["localhost:80"]}], + "metrics_path": "/metrics", } logger = logging.getLogger(__name__) SnapEndpoint = namedtuple("SnapEndpoint", "owner, name") -class GrafanaDashboard(str): - """Grafana Dashboard encoded json; lzma-compressed.""" - - # TODO Replace this with a custom type when pydantic v2 released (end of 2023 Q1?) 
- # https://github.com/pydantic/pydantic/issues/4887 - @staticmethod - def _serialize(raw_json: Union[str, bytes]) -> "GrafanaDashboard": - if not isinstance(raw_json, bytes): - raw_json = raw_json.encode("utf-8") - encoded = base64.b64encode(lzma.compress(raw_json)).decode("utf-8") - return GrafanaDashboard(encoded) - - def _deserialize(self) -> Dict: - raw = lzma.decompress(base64.b64decode(self.encode("utf-8"))).decode() - return json.loads(raw) - - def __repr__(self): - """Return string representation of self.""" - return "" - - class CosAgentProviderUnitData(pydantic.BaseModel): """Unit databag model for `cos-agent` relation.""" @@ -234,6 +258,7 @@ class CosAgentProviderUnitData(pydantic.BaseModel): metrics_alert_rules: dict log_alert_rules: dict dashboards: List[GrafanaDashboard] + subordinate: Optional[bool] # The following entries may vary across units of the same principal app. # this data does not need to be forwarded to the gagent leader @@ -247,7 +272,7 @@ class CosAgentProviderUnitData(pydantic.BaseModel): class CosAgentPeersUnitData(pydantic.BaseModel): - """Unit databag model for `cluster` cos-agent machine charm peer relation.""" + """Unit databag model for `peers` cos-agent machine charm peer relation.""" # We need the principal unit name and relation metadata to be able to render identifiers # (e.g. topology) on the leader side, after all the data moves into peer data (the grafana @@ -291,6 +316,8 @@ def __init__( log_slots: Optional[List[str]] = None, dashboard_dirs: Optional[List[str]] = None, refresh_events: Optional[List] = None, + *, + scrape_configs: Optional[Union[List[dict], Callable]] = None, ): """Create a COSAgentProvider instance. @@ -298,6 +325,8 @@ def __init__( charm: The `CharmBase` instance that is instantiating this object. relation_name: The name of the relation to communicate over. metrics_endpoints: List of endpoints in the form [{"path": path, "port": port}, ...]. + This argument is a simplified form of the `scrape_configs`. 
+ The contents of this list will be merged with the contents of `scrape_configs`. metrics_rules_dir: Directory where the metrics rules are stored. logs_rules_dir: Directory where the logs rules are stored. recurse_rules_dirs: Whether to recurse into rule paths. @@ -305,14 +334,17 @@ def __init__( in the form ["snap-name:slot", ...]. dashboard_dirs: Directory where the dashboards are stored. refresh_events: List of events on which to refresh relation data. + scrape_configs: List of standard scrape_configs dicts or a callable + that returns the list in case the configs need to be generated dynamically. + The contents of this list will be merged with the contents of `metrics_endpoints`. """ super().__init__(charm, relation_name) - metrics_endpoints = metrics_endpoints or [DEFAULT_METRICS_ENDPOINT] dashboard_dirs = dashboard_dirs or ["./src/grafana_dashboards"] self._charm = charm self._relation_name = relation_name - self._metrics_endpoints = metrics_endpoints + self._metrics_endpoints = metrics_endpoints or [] + self._scrape_configs = scrape_configs or [] self._metrics_rules = metrics_rules_dir self._logs_rules = logs_rules_dir self._recursive = recurse_rules_dirs @@ -328,10 +360,7 @@ def __init__( def _on_refresh(self, event): """Trigger the class to update relation data.""" - if isinstance(event, RelationEvent): - relations = [event.relation] - else: - relations = self._charm.model.relations[self._relation_name] + relations = self._charm.model.relations[self._relation_name] for relation in relations: # Before a principal is related to the grafana-agent subordinate, we'd get @@ -339,23 +368,52 @@ def _on_refresh(self, event): # Add a guard to make sure it doesn't happen. if relation.data and self._charm.unit in relation.data: # Subordinate relations can communicate only over unit data. 
- data = CosAgentProviderUnitData( - metrics_alert_rules=self._metrics_alert_rules, - log_alert_rules=self._log_alert_rules, - dashboards=self._dashboards, - metrics_scrape_jobs=self._scrape_jobs, - log_slots=self._log_slots, - ) - relation.data[self._charm.unit][data.KEY] = data.json() + try: + data = CosAgentProviderUnitData( + metrics_alert_rules=self._metrics_alert_rules, + log_alert_rules=self._log_alert_rules, + dashboards=self._dashboards, + metrics_scrape_jobs=self._scrape_jobs, + log_slots=self._log_slots, + subordinate=self._charm.meta.subordinate, + ) + relation.data[self._charm.unit][data.KEY] = data.json() + except ( + pydantic.ValidationError, + json.decoder.JSONDecodeError, + ) as e: + logger.error("Invalid relation data provided: %s", e) @property def _scrape_jobs(self) -> List[Dict]: - """Return a prometheus_scrape-like data structure for jobs.""" - job_name_prefix = self._charm.app.name - return [ - {"job_name": f"{job_name_prefix}_{key}", **endpoint} - for key, endpoint in enumerate(self._metrics_endpoints) - ] + """Return a prometheus_scrape-like data structure for jobs. 
+ + https://prometheus.io/docs/prometheus/latest/configuration/configuration/#scrape_config + """ + if callable(self._scrape_configs): + scrape_configs = self._scrape_configs() + else: + # Create a copy of the user scrape_configs, since we will mutate this object + scrape_configs = self._scrape_configs.copy() + + # Convert "metrics_endpoints" to standard scrape_configs, and add them in + for endpoint in self._metrics_endpoints: + scrape_configs.append( + { + "metrics_path": endpoint["path"], + "static_configs": [{"targets": [f"localhost:{endpoint['port']}"]}], + } + ) + + scrape_configs = scrape_configs or [DEFAULT_SCRAPE_CONFIG] + + # Augment job name to include the app name and a unique id (index) + for idx, scrape_config in enumerate(scrape_configs): + scrape_config["job_name"] = "_".join( + [self._charm.app.name, str(idx), scrape_config.get("job_name", "default")] + ) + + return scrape_configs @property def _metrics_alert_rules(self) -> Dict: @@ -387,16 +445,39 @@ class COSAgentDataChanged(EventBase): """Event emitted by `COSAgentRequirer` when relation data changes.""" +class COSAgentValidationError(EventBase): + """Event emitted by `COSAgentRequirer` when there is an error in the relation data.""" + + def __init__(self, handle, message: str = ""): + super().__init__(handle) + self.message = message + + def snapshot(self) -> Dict: + """Save COSAgentValidationError source information.""" + return {"message": self.message} + + def restore(self, snapshot): + """Restore COSAgentValidationError source information.""" + self.message = snapshot["message"] + + class COSAgentRequirerEvents(ObjectEvents): """`COSAgentRequirer` events.""" data_changed = EventSource(COSAgentDataChanged) + validation_error = EventSource(COSAgentValidationError) + + +class MultiplePrincipalsError(Exception): + """Custom exception for when there are multiple principal applications.""" + + pass class COSAgentRequirer(Object): """Integration endpoint wrapper for the Requirer side of the 
cos_agent interface.""" - on = COSAgentRequirerEvents() + on = COSAgentRequirerEvents() # pyright: ignore def __init__( self, @@ -426,7 +507,7 @@ def __init__( ) # TODO: do we need this? self.framework.observe(events.relation_changed, self._on_relation_data_changed) for event in self._refresh_events: - self.framework.observe(event, self.trigger_refresh) + self.framework.observe(event, self.trigger_refresh) # pyright: ignore # Peer relation events # A peer relation is needed as it is the only mechanism for exchanging data across @@ -450,7 +531,7 @@ def _on_peer_relation_changed(self, _): # Peer data is used for forwarding data from principal units to the grafana agent # subordinate leader, for updating the app data of the outgoing o11y relations. if self._charm.unit.is_leader(): - self.on.data_changed.emit() + self.on.data_changed.emit() # pyright: ignore def _on_relation_data_changed(self, event: RelationChangedEvent): # Peer data is the only means of communication between subordinate units. @@ -474,7 +555,9 @@ def _on_relation_data_changed(self, event: RelationChangedEvent): if not (raw := cos_agent_relation.data[principal_unit].get(CosAgentProviderUnitData.KEY)): return - provider_data = CosAgentProviderUnitData(**json.loads(raw)) + + if not (provider_data := self._validated_provider_data(raw)): + return # Copy data from the principal relation to the peer relation, so the leader could # follow up. 
@@ -487,17 +570,26 @@ def _on_relation_data_changed(self, event: RelationChangedEvent): log_alert_rules=provider_data.log_alert_rules, dashboards=provider_data.dashboards, ) - self.peer_relation.data[self._charm.unit][data.KEY] = data.json() + self.peer_relation.data[self._charm.unit][ + f"{CosAgentPeersUnitData.KEY}-{event.unit.name}" + ] = data.json() # We can't easily tell if the data that was changed is limited to only the data # that goes into peer relation (in which case, if this is not a leader unit, we wouldn't # need to emit `on.data_changed`), so we're emitting `on.data_changed` either way. - self.on.data_changed.emit() + self.on.data_changed.emit() # pyright: ignore + + def _validated_provider_data(self, raw) -> Optional[CosAgentProviderUnitData]: + try: + return CosAgentProviderUnitData(**json.loads(raw)) + except (pydantic.ValidationError, json.decoder.JSONDecodeError) as e: + self.on.validation_error.emit(message=str(e)) # pyright: ignore + return None def trigger_refresh(self, _): """Trigger a refresh of relation data.""" # FIXME: Figure out what we should do here - self.on.data_changed.emit() + self.on.data_changed.emit() # pyright: ignore @property def _principal_unit(self) -> Optional[Unit]: @@ -518,28 +610,40 @@ def _principal_unit(self) -> Optional[Unit]: @property def _principal_relations(self): - # Technically it's a list, but for subordinates there can only be one. - return self._charm.model.relations[self._relation_name] + relations = [] + for relation in self._charm.model.relations[self._relation_name]: + if not json.loads(relation.data[next(iter(relation.units))]["config"]).get( + "subordinate", False + ): + relations.append(relation) + if len(relations) > 1: + logger.error( + "Multiple applications claiming to be principal. Update the cos-agent library in the client application charms." 
+ ) + raise MultiplePrincipalsError("Multiple principal applications.") + return relations @property - def _principal_unit_data(self) -> Optional[CosAgentProviderUnitData]: - """Return the principal unit's data. + def _remote_data(self) -> List[CosAgentProviderUnitData]: + """Return a list of remote data from each of the related units. Assumes that the relation is of type subordinate. Relies on the fact that, for subordinate relations, the only remote unit visible to *this unit* is the principal unit that this unit is attached to. """ - if relations := self._principal_relations: - # Technically it's a list, but for subordinates there can only be one relation - principal_relation = next(iter(relations)) - if units := principal_relation.units: - # Technically it's a list, but for subordinates there can only be one - unit = next(iter(units)) - raw = principal_relation.data[unit].get(CosAgentProviderUnitData.KEY) - if raw: - return CosAgentProviderUnitData(**json.loads(raw)) + all_data = [] - return None + for relation in self._charm.model.relations[self._relation_name]: + if not relation.units: + continue + unit = next(iter(relation.units)) + if not (raw := relation.data[unit].get(CosAgentProviderUnitData.KEY)): + continue + if not (provider_data := self._validated_provider_data(raw)): + continue + all_data.append(provider_data) + + return all_data def _gather_peer_data(self) -> List[CosAgentPeersUnitData]: """Collect data from the peers. @@ -557,18 +661,21 @@ def _gather_peer_data(self) -> List[CosAgentPeersUnitData]: app_names: Set[str] = set() for unit in chain((self._charm.unit,), relation.units): - if not relation.data.get(unit) or not ( - raw := relation.data[unit].get(CosAgentPeersUnitData.KEY) - ): - logger.info(f"peer {unit} has not set its primary data yet; skipping for now...") + if not relation.data.get(unit): continue - data = CosAgentPeersUnitData(**json.loads(raw)) - app_name = data.app_name - # Have we already seen this principal app? 
- if app_name in app_names: - continue - peer_data.append(data) + for unit_name in relation.data.get(unit): # pyright: ignore + if not unit_name.startswith(CosAgentPeersUnitData.KEY): + continue + raw = relation.data[unit].get(unit_name) + if raw is None: + continue + data = CosAgentPeersUnitData(**json.loads(raw)) + # Have we already seen this principal app? + if (app_name := data.app_name) in app_names: + continue + peer_data.append(data) + app_names.add(app_name) return peer_data @@ -578,7 +685,7 @@ def metrics_alerts(self) -> Dict[str, Any]: alert_rules = {} seen_apps: List[str] = [] - for data in self._gather_peer_data(): # type: CosAgentPeersUnitData + for data in self._gather_peer_data(): if rules := data.metrics_alert_rules: app_name = data.app_name if app_name in seen_apps: @@ -604,16 +711,23 @@ def metrics_alerts(self) -> Dict[str, Any]: def metrics_jobs(self) -> List[Dict]: """Parse the relation data contents and extract the metrics jobs.""" scrape_jobs = [] - if data := self._principal_unit_data: - jobs = data.metrics_scrape_jobs - if jobs: - for job in jobs: - job_config = { + for data in self._remote_data: + for job in data.metrics_scrape_jobs: + # In #220, relation schema changed from a simplified dict to the standard + # `scrape_configs`. + # This is to ensure backwards compatibility with Providers older than v0.5. + if "path" in job and "port" in job and "job_name" in job: + job = { "job_name": job["job_name"], "metrics_path": job["path"], "static_configs": [{"targets": [f"localhost:{job['port']}"]}], + # We include insecure_skip_verify because we are always scraping localhost. + # Even if we have the certs for the scrape targets, we'd rather specify the scrape + # jobs with localhost rather than the SAN DNS the cert was issued for. 
+ "tls_config": {"insecure_skip_verify": True}, } - scrape_jobs.append(job_config) + + scrape_jobs.append(job) return scrape_jobs @@ -621,7 +735,7 @@ def metrics_jobs(self) -> List[Dict]: def snap_log_endpoints(self) -> List[SnapEndpoint]: """Fetch logging endpoints exposed by related snaps.""" plugs = [] - if data := self._principal_unit_data: + for data in self._remote_data: targets = data.log_slots if targets: for target in targets: @@ -649,7 +763,7 @@ def logs_alerts(self) -> Dict[str, Any]: alert_rules = {} seen_apps: List[str] = [] - for data in self._gather_peer_data(): # type: CosAgentPeersUnitData + for data in self._gather_peer_data(): if rules := data.log_alert_rules: # This is only used for naming the file, so be as specific as we can be app_name = data.app_name @@ -678,10 +792,10 @@ def dashboards(self) -> List[Dict[str, str]]: Dashboards are assumed not to vary across units of the same primary. """ - dashboards: List[Dict[str, str]] = [] + dashboards: List[Dict[str, Any]] = [] seen_apps: List[str] = [] - for data in self._gather_peer_data(): # type: CosAgentPeersUnitData + for data in self._gather_peer_data(): app_name = data.app_name if app_name in seen_apps: continue # dedup! 
diff --git a/lib/charms/mongodb/v0/config_server_interface.py b/lib/charms/mongodb/v0/config_server_interface.py index eeb1ee299..65242d9b3 100644 --- a/lib/charms/mongodb/v0/config_server_interface.py +++ b/lib/charms/mongodb/v0/config_server_interface.py @@ -35,7 +35,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 1 +LIBPATCH = 2 class ClusterProvider(Object): @@ -139,6 +139,7 @@ def __init__( self.database_requires = DatabaseRequires( self.charm, relation_name=self.relation_name, + relations_aliases=[self.relation_name], database_name=self.charm.database, extra_user_roles=self.charm.extra_user_roles, additional_secret_fields=[KEYFILE_KEY], @@ -149,11 +150,24 @@ def __init__( charm.on[self.relation_name].relation_created, self.database_requires._on_relation_created_event, ) + + self.framework.observe( + self.database_requires.on.database_created, self._on_database_created + ) self.framework.observe( charm.on[self.relation_name].relation_changed, self._on_relation_changed ) # TODO Future PRs handle scale down + def _on_database_created(self, event) -> None: + if not self.charm.unit.is_leader(): + return + + logger.info("Database and user created for mongos application") + self.charm.set_secret(Config.Relations.APP_SCOPE, Config.Secrets.USERNAME, event.username) + self.charm.set_secret(Config.Relations.APP_SCOPE, Config.Secrets.PASSWORD, event.password) + self.charm.share_connection_info() + def _on_relation_changed(self, event) -> None: """Starts/restarts monogs with config server information.""" key_file_contents = self.database_requires.fetch_relation_field( @@ -186,7 +200,6 @@ def _on_relation_changed(self, event) -> None: event.defer() return - self.charm.share_uri() self.charm.unit.status = ActiveStatus() # BEGIN: helper functions diff --git a/lib/charms/mongodb/v1/mongos.py b/lib/charms/mongodb/v1/mongos.py index ee2c11678..90711652b 100644 --- 
a/lib/charms/mongodb/v1/mongos.py +++ b/lib/charms/mongodb/v1/mongos.py @@ -52,8 +52,11 @@ class MongosConfiguration: @property def uri(self): """Return URI concatenated from fields.""" - hosts = [f"{host}:{self.port}" for host in self.hosts] - hosts = ",".join(hosts) + # mongos using Unix Domain Socket to communicate do not use port + if self.port: + hosts = ",".join(f"{host}:{self.port}" for host in self.hosts) + else: + hosts = ",".join(self.hosts) # Auth DB should be specified while user connects to application DB. auth_source = "" if self.database != "admin": diff --git a/lib/charms/operator_libs_linux/v1/systemd.py b/lib/charms/operator_libs_linux/v1/systemd.py index d75ade184..cdcbad6a9 100644 --- a/lib/charms/operator_libs_linux/v1/systemd.py +++ b/lib/charms/operator_libs_linux/v1/systemd.py @@ -23,6 +23,7 @@ service_resume with run the mask/unmask and enable/disable invocations. Example usage: + ```python from charms.operator_libs_linux.v0.systemd import service_running, service_reload @@ -33,13 +34,14 @@ # Attempt to reload a service, restarting if necessary success = service_reload("nginx", restart_on_failure=True) ``` - """ -import logging -import subprocess - __all__ = [ # Don't export `_systemctl`. (It's not the intended way of using this lib.) 
+ "SystemdError", + "daemon_reload", + "service_disable", + "service_enable", + "service_failed", "service_pause", "service_reload", "service_restart", @@ -47,9 +49,11 @@ "service_running", "service_start", "service_stop", - "daemon_reload", ] +import logging +import subprocess + logger = logging.getLogger(__name__) # The unique Charmhub library identifier, never change it @@ -60,133 +64,168 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 3 +LIBPATCH = 4 class SystemdError(Exception): """Custom exception for SystemD related errors.""" - pass +def _systemctl(*args: str, check: bool = False) -> int: + """Control a system service using systemctl. -def _popen_kwargs(): - return { - "stdout": subprocess.PIPE, - "stderr": subprocess.STDOUT, - "bufsize": 1, - "universal_newlines": True, - "encoding": "utf-8", - } - + Args: + *args: Arguments to pass to systemctl. + check: Check the output of the systemctl command. Default: False. -def _systemctl( - sub_cmd: str, service_name: str = None, now: bool = None, quiet: bool = None -) -> bool: - """Control a system service. + Returns: + Returncode of systemctl command execution. - Args: - sub_cmd: the systemctl subcommand to issue - service_name: the name of the service to perform the action on - now: passes the --now flag to the shell invocation. - quiet: passes the --quiet flag to the shell invocation. + Raises: + SystemdError: Raised if calling systemctl returns a non-zero returncode and check is True. 
""" - cmd = ["systemctl", sub_cmd] - - if service_name is not None: - cmd.append(service_name) - if now is not None: - cmd.append("--now") - if quiet is not None: - cmd.append("--quiet") - if sub_cmd != "is-active": - logger.debug("Attempting to {} '{}' with command {}.".format(cmd, service_name, cmd)) - else: - logger.debug("Checking if '{}' is active".format(service_name)) - - proc = subprocess.Popen(cmd, **_popen_kwargs()) - last_line = "" - for line in iter(proc.stdout.readline, ""): - last_line = line - logger.debug(line) - - proc.wait() - - if proc.returncode < 1: - return True - - # If we are just checking whether a service is running, return True/False, rather - # than raising an error. - if sub_cmd == "is-active" and proc.returncode == 3: # Code returned when service not active. - return False - - if sub_cmd == "is-failed": - return False - - raise SystemdError( - "Could not {}{}: systemd output: {}".format( - sub_cmd, " {}".format(service_name) if service_name else "", last_line + cmd = ["systemctl", *args] + logger.debug(f"Executing command: {cmd}") + try: + proc = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + encoding="utf-8", + check=check, + ) + logger.debug( + f"Command {cmd} exit code: {proc.returncode}. systemctl output:\n{proc.stdout}" + ) + return proc.returncode + except subprocess.CalledProcessError as e: + raise SystemdError( + f"Command {cmd} failed with returncode {e.returncode}. systemctl output:\n{e.stdout}" ) - ) def service_running(service_name: str) -> bool: - """Determine whether a system service is running. + """Report whether a system service is running. Args: - service_name: the name of the service to check + service_name: The name of the service to check. + + Return: + True if service is running/active; False if not. """ - return _systemctl("is-active", service_name, quiet=True) + # If returncode is 0, this means that is service is active. 
+ return _systemctl("--quiet", "is-active", service_name) == 0 def service_failed(service_name: str) -> bool: - """Determine whether a system service has failed. + """Report whether a system service has failed. Args: - service_name: the name of the service to check + service_name: The name of the service to check. + + Returns: + True if service is marked as failed; False if not. """ - return _systemctl("is-failed", service_name, quiet=True) + # If returncode is 0, this means that the service has failed. + return _systemctl("--quiet", "is-failed", service_name) == 0 -def service_start(service_name: str) -> bool: +def service_start(*args: str) -> bool: """Start a system service. Args: - service_name: the name of the service to start + *args: Arguments to pass to `systemctl start` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl start ...` returns a non-zero returncode. """ - return _systemctl("start", service_name) + return _systemctl("start", *args, check=True) == 0 -def service_stop(service_name: str) -> bool: +def service_stop(*args: str) -> bool: """Stop a system service. Args: - service_name: the name of the service to stop + *args: Arguments to pass to `systemctl stop` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl stop ...` returns a non-zero returncode. """ - return _systemctl("stop", service_name) + return _systemctl("stop", *args, check=True) == 0 -def service_restart(service_name: str) -> bool: +def service_restart(*args: str) -> bool: """Restart a system service. Args: - service_name: the name of the service to restart + *args: Arguments to pass to `systemctl restart` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. 
+ + Raises: + SystemdError: Raised if `systemctl restart ...` returns a non-zero returncode. """ - return _systemctl("restart", service_name) + return _systemctl("restart", *args, check=True) == 0 + + +def service_enable(*args: str) -> bool: + """Enable a system service. + + Args: + *args: Arguments to pass to `systemctl enable` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl enable ...` returns a non-zero returncode. + """ + return _systemctl("enable", *args, check=True) == 0 + + +def service_disable(*args: str) -> bool: + """Disable a system service. + + Args: + *args: Arguments to pass to `systemctl disable` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl disable ...` returns a non-zero returncode. + """ + return _systemctl("disable", *args, check=True) == 0 def service_reload(service_name: str, restart_on_failure: bool = False) -> bool: """Reload a system service, optionally falling back to restart if reload fails. Args: - service_name: the name of the service to reload - restart_on_failure: boolean indicating whether to fallback to a restart if the - reload fails. + service_name: The name of the service to reload. + restart_on_failure: + Boolean indicating whether to fall back to a restart if the reload fails. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl reload|restart ...` returns a non-zero returncode. 
""" try: - return _systemctl("reload", service_name) + return _systemctl("reload", service_name, check=True) == 0 except SystemdError: if restart_on_failure: - return _systemctl("restart", service_name) + return service_restart(service_name) else: raise @@ -194,37 +233,56 @@ def service_reload(service_name: str, restart_on_failure: bool = False) -> bool: def service_pause(service_name: str) -> bool: """Pause a system service. - Stop it, and prevent it from starting again at boot. + Stops the service and prevents the service from starting again at boot. Args: - service_name: the name of the service to pause + service_name: The name of the service to pause. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if service is still running after being paused by systemctl. """ - _systemctl("disable", service_name, now=True) + _systemctl("disable", "--now", service_name) _systemctl("mask", service_name) - if not service_running(service_name): - return True + if service_running(service_name): + raise SystemdError(f"Attempted to pause {service_name!r}, but it is still running.") - raise SystemdError("Attempted to pause '{}', but it is still running.".format(service_name)) + return True def service_resume(service_name: str) -> bool: """Resume a system service. - Re-enable starting again at boot. Start the service. + Re-enable starting the service again at boot. Start the service. Args: - service_name: the name of the service to resume + service_name: The name of the service to resume. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if service is not running after being resumed by systemctl. 
""" _systemctl("unmask", service_name) - _systemctl("enable", service_name, now=True) + _systemctl("enable", "--now", service_name) - if service_running(service_name): - return True + if not service_running(service_name): + raise SystemdError(f"Attempted to resume {service_name!r}, but it is not running.") - raise SystemdError("Attempted to resume '{}', but it is not running.".format(service_name)) + return True def daemon_reload() -> bool: - """Reload systemd manager configuration.""" - return _systemctl("daemon-reload") + """Reload systemd manager configuration. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl daemon-reload` returns a non-zero returncode. + """ + return _systemctl("daemon-reload", check=True) == 0 diff --git a/requirements.txt b/requirements.txt index e892ba369..a4a2f819c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ attrs==22.2.0 cffi==1.15.1 -cosl==0.0.5 +cosl==0.0.7 importlib-resources==5.10.2 tenacity==8.1.0 pymongo==4.3.3 @@ -16,4 +16,4 @@ pyyaml==6.0.1 zipp==3.11.0 pyOpenSSL==22.1.0 typing-extensions==4.5.0 -parameterized==0.9.0 +parameterized==0.9.0 \ No newline at end of file