diff --git a/README.md b/README.md index 544759c..3b04ed3 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,24 @@ $ juju integrate slurmdbd:database slurmdbd-mysql-router:database $ juju integrate slurmctld:slurmdbd slurmdbd:slurmdbd ``` +### Operations +This charm hardens and simplifies operations by codifying common administration operations as charm actions. + +#### Node Config +You can get and set the node configuration using the `node-config` action. + +##### Use the `node-config` action to get the node configuration for the unit. +```bash +$ juju run --quiet slurmd/0 node-config --format json | jq ".[].results.node.config" +"NodeName=juju-462521-4 NodeAddr=10.240.222.28 State=UNKNOWN RealMemory=64012 CPUs=12 ThreadsPerCore=2 CoresPerSocket=6 SocketsPerBoard=1" +``` + +##### Use the `node-config` action to set a custom weight value for the node. +```bash +$ juju run --quiet slurmd/0 node-config node-config="Weight=5000" --format json | jq ".[].results.node.config" +"NodeName=juju-462521-4 NodeAddr=10.240.222.28 State=UNKNOWN RealMemory=64012 CPUs=12 ThreadsPerCore=2 CoresPerSocket=6 SocketsPerBoard=1 Weight=5000" +``` + ## Project & Community The slurmd operator is a project of the [Ubuntu HPC](https://discourse.ubuntu.com/t/high-performance-computing-team/35988) diff --git a/actions.yaml b/actions.yaml deleted file mode 100644 index 5be9cf4..0000000 --- a/actions.yaml +++ /dev/null @@ -1,15 +0,0 @@ -version: - description: Return version of installed software. -node-configured: - description: Remove a node from DownNodes when the reason is `New node`. -get-node-inventory: - description: Return node inventory. -set-node-inventory: - description: Modify node inventory. - params: - real-memory: - type: integer - description: Total amount of memory of the node, in MB. - -show-nhc-config: - description: Display the currently used `nhc.conf`. 
diff --git a/charmcraft.yaml b/charmcraft.yaml index 19539d7..36b309e 100644 --- a/charmcraft.yaml +++ b/charmcraft.yaml @@ -1,7 +1,29 @@ -# Copyright 2020 Omnivector, LLC -# See LICENSE file for licensing details. - +name: slurmd type: charm + +summary: | + Slurmd, the compute node daemon of Slurm. + +description: | + This charm provides slurmd, munged, and the bindings to other utilities + that make lifecycle operations a breeze. + + slurmd is the compute node daemon of SLURM. It monitors all tasks running + on the compute node, accepts work (tasks), launches tasks, and kills + running tasks upon request. + +links: + contact: https://matrix.to/#/#hpc:ubuntu.com + + issues: + - https://github.com/charmed-hpc/slurmd-operator/issues + + source: + - https://github.com/charmed-hpc/slurmd-operator + +assumes: + - juju + bases: - build-on: - name: ubuntu @@ -10,25 +32,73 @@ bases: - name: ubuntu channel: "22.04" architectures: [amd64] - - name: centos - channel: "7" - architectures: [amd64] parts: charm: - build-packages: [git] - charm-python-packages: [setuptools] - - # Create a version file and pack it into the charm. This is dynamically generated - # as part of the build process for a charm to ensure that the git revision of the - # charm is always recorded in this version file. - version-file: - plugin: nil build-packages: - - git + - wget override-build: | - VERSION=$(git -C $CRAFT_PART_SRC/../../charm/src describe --dirty --always) - echo "Setting version to $VERSION" - echo $VERSION > $CRAFT_PART_INSTALL/version - stage: - - version + wget https://github.com/mej/nhc/releases/download/1.4.3/lbnl-nhc-1.4.3.tar.gz + craftctl default + +requires: + fluentbit: + interface: fluentbit + +provides: + slurmd: + interface: slurmd + limit: 1 + +config: + options: + partition-config: + type: string + default: "" + description: > + Extra partition configuration, specified as a space separated `key=value` + in a single line. 
+ + Example usage: + ```bash + $ juju config slurmd partition-config="DefaultTime=45:00 MaxTime=1:00:00" + ``` + + nhc-conf: + default: "" + type: string + description: > + Multiline string. + These lines are appended to the `nhc.conf` maintained by the charm. + + Example usage: + ```bash + $ juju config slurmd nhc-conf="$(cat extra-nhc.conf)" + ``` + +actions: + node-configured: + description: Remove a node from DownNodes when the reason is `New node`. + + node-config: + description: > + Set or return node configuration parameters. + + To get the current node configuration for this unit: + ```bash + $ juju run slurmd/0 node-config + ``` + + To set node level configuration parameters for this unit: + ```bash + $ juju run slurmd/0 node-config node-config="Weight=200 Gres=gpu:tesla:1,gpu:kepler:1,bandwidth:lustre:no_consume:4G" + ``` + + params: + node-config: + type: string + description: > + Node configuration as defined [here](https://slurm.schedmd.com/slurm.conf.html#SECTION_NODE-CONFIGURATION). + + show-nhc-config: + description: Display `nhc.conf`. diff --git a/config.yaml b/config.yaml deleted file mode 100644 index b2ebfea..0000000 --- a/config.yaml +++ /dev/null @@ -1,40 +0,0 @@ -options: - custom-slurm-repo: - type: string - default: "" - description: > - Use a custom repository for Slurm installation. - - This can be set to the Organization's local mirror/cache of packages and - supersedes the Omnivector repositories. Alternatively, it can be used to - track a `testing` Slurm version, e.g. by setting to - `ppa:omnivector/osd-testing` (on Ubuntu), or - `https://omnivector-solutions.github.io/repo/centos7/stable/$basearch` - (on CentOS). - - Note: The configuration `custom-slurm-repo` must be set *before* - deploying the units. Changing this value after deploying the units will - not reinstall Slurm. - partition-config: - type: string - default: "" - description: > - Extra partition configuration, specified as a space separated `key=value` - in a single line. 
- - Example usage: - $ juju config slurmd partition-config="DefaultTime=45:00 MaxTime=1:00:00" - partition-state: - type: string - default: "UP" - description: > - State of partition or availability for use. Possible values are `UP`, - `DOWN`, `DRAIN` and `INACTIVE`. The default value is `UP`. See also the - related `Alternate` keyword. - nhc-conf: - default: "" - type: string - description: > - Custom extra configuration to use for Node Health Check. - - These lines are appended to a basic `nhc.conf` provided by the charm. diff --git a/dispatch b/dispatch index 7f58019..2b6f3b8 100755 --- a/dispatch +++ b/dispatch @@ -1,44 +1,11 @@ #!/bin/bash -# This hook installs the dependencies needed to run the charm, -# creates the dispatch executable, regenerates the symlinks for start and -# upgrade-charm, and kicks off the operator framework. - set -e -# Source the os-release information into the env -. /etc/os-release - if ! [[ -f '.installed' ]] then - if [[ $ID == 'centos' ]] - then - # Install dependencies and build custom python - yum -y install epel-release - yum -y install wget gcc make tar bzip2-devel zlib-devel xz-devel openssl-devel libffi-devel sqlite-devel ncurses-devel - - export PYTHON_VERSION=3.8.16 - wget https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tar.xz -P /tmp - tar xvf /tmp/Python-${PYTHON_VERSION}.tar.xz -C /tmp - cd /tmp/Python-${PYTHON_VERSION} - ./configure --enable-optimizations - make -C /tmp/Python-${PYTHON_VERSION} -j $(nproc) altinstall - cd $OLDPWD - rm -rf /tmp/Python* - - elif [[ $ID == 'ubuntu' ]] - then - # Necessary to compile and install NHC - apt-get install --assume-yes make - fi - touch .installed -fi - -# set the correct python bin path -if [[ $ID == "centos" ]] -then - PYTHON_BIN="/usr/bin/env python3.8" -else - PYTHON_BIN="/usr/bin/env python3" + # Necessary to compile and install NHC + apt-get install --assume-yes make + touch .installed fi -JUJU_DISPATCH_PATH="${JUJU_DISPATCH_PATH:-$0}" 
PYTHONPATH=lib:venv $PYTHON_BIN ./src/charm.py \ No newline at end of file +JUJU_DISPATCH_PATH="${JUJU_DISPATCH_PATH:-$0}" PYTHONPATH=lib:venv /usr/bin/env python3 ./src/charm.py diff --git a/lib/charms/operator_libs_linux/v0/apt.py b/lib/charms/operator_libs_linux/v0/apt.py new file mode 100644 index 0000000..1400df7 --- /dev/null +++ b/lib/charms/operator_libs_linux/v0/apt.py @@ -0,0 +1,1361 @@ +# Copyright 2021 Canonical Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Abstractions for the system's Debian/Ubuntu package information and repositories. + +This module contains abstractions and wrappers around Debian/Ubuntu-style repositories and +packages, in order to easily provide an idiomatic and Pythonic mechanism for adding packages and/or +repositories to systems for use in machine charms. + +A sane default configuration is attainable through nothing more than instantiation of the +appropriate classes. `DebianPackage` objects provide information about the architecture, version, +name, and status of a package. + +`DebianPackage` will try to look up a package either from `dpkg -L` or from `apt-cache` when +provided with a string indicating the package name. If it cannot be located, `PackageNotFoundError` +will be returned, as `apt` and `dpkg` otherwise return `100` for all errors, and a meaningful error +message if the package is not known is desirable. 
+ +To install packages with convenience methods: + +```python +try: + # Run `apt-get update` + apt.update() + apt.add_package("zsh") + apt.add_package(["vim", "htop", "wget"]) +except PackageNotFoundError: + logger.error("a specified package not found in package cache or on system") +except PackageError as e: + logger.error("could not install package. Reason: %s", e.message) +``` + +To find details of a specific package: + +```python +try: + vim = apt.DebianPackage.from_system("vim") + + # To find from the apt cache only + # apt.DebianPackage.from_apt_cache("vim") + + # To find from installed packages only + # apt.DebianPackage.from_installed_package("vim") + + vim.ensure(PackageState.Latest) + logger.info("updated vim to version: %s", vim.fullversion) +except PackageNotFoundError: + logger.error("a specified package not found in package cache or on system") +except PackageError as e: + logger.error("could not install package. Reason: %s", e.message) +``` + + +`RepositoryMapping` will return a dict-like object containing enabled system repositories +and their properties (available groups, baseuri, gpg key). This class can add, disable, or +manipulate repositories. Items can be retrieved as `DebianRepository` objects. + +In order to add a new repository with explicit details for fields, a new `DebianRepository` can +be added to `RepositoryMapping`. + +`RepositoryMapping` provides an abstraction around the existing repositories on the system, +and can be accessed and iterated over like any `Mapping` object, to retrieve values by key, +iterate, or perform other operations. + +Keys are constructed as `{repo_type}-{}-{release}` in order to uniquely identify a repository. + +Repositories can be added with explicit values through a Python constructor. 
+ +Example: +```python +repositories = apt.RepositoryMapping() + +if "deb-example.com-focal" not in repositories: + repositories.add(DebianRepository(enabled=True, repotype="deb", + uri="https://example.com", release="focal", groups=["universe"])) +``` + +Alternatively, any valid `sources.list` line may be used to construct a new +`DebianRepository`. + +Example: +```python +repositories = apt.RepositoryMapping() + +if "deb-us.archive.ubuntu.com-xenial" not in repositories: + line = "deb http://us.archive.ubuntu.com/ubuntu xenial main restricted" + repo = DebianRepository.from_repo_line(line) + repositories.add(repo) +``` +""" + +import fileinput +import glob +import logging +import os +import re +import subprocess +from collections.abc import Mapping +from enum import Enum +from subprocess import PIPE, CalledProcessError, check_output +from typing import Iterable, List, Optional, Tuple, Union +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + +# The unique Charmhub library identifier, never change it +LIBID = "7c3dbc9c2ad44a47bd6fcb25caa270e5" + +# Increment this major API version when introducing breaking changes +LIBAPI = 0 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 13 + + +VALID_SOURCE_TYPES = ("deb", "deb-src") +OPTIONS_MATCHER = re.compile(r"\[.*?\]") + + +class Error(Exception): + """Base class of most errors raised by this library.""" + + def __repr__(self): + """Represent the Error.""" + return "<{}.{} {}>".format(type(self).__module__, type(self).__name__, self.args) + + @property + def name(self): + """Return a string representation of the model plus class.""" + return "<{}.{}>".format(type(self).__module__, type(self).__name__) + + @property + def message(self): + """Return the message passed as an argument.""" + return self.args[0] + + +class PackageError(Error): + """Raised when there's an error installing or removing a 
package.""" + + +class PackageNotFoundError(Error): + """Raised when a requested package is not known to the system.""" + + +class PackageState(Enum): + """A class to represent possible package states.""" + + Present = "present" + Absent = "absent" + Latest = "latest" + Available = "available" + + +class DebianPackage: + """Represents a traditional Debian package and its utility functions. + + `DebianPackage` wraps information and functionality around a known package, whether installed + or available. The version, epoch, name, and architecture can be easily queried and compared + against other `DebianPackage` objects to determine the latest version or to install a specific + version. + + The representation of this object as a string mimics the output from `dpkg` for familiarity. + + Installation and removal of packages is handled through the `state` property or `ensure` + method, with the following options: + + apt.PackageState.Absent + apt.PackageState.Available + apt.PackageState.Present + apt.PackageState.Latest + + When `DebianPackage` is initialized, the state of a given `DebianPackage` object will be set to + `Available`, `Present`, or `Latest`, with `Absent` implemented as a convenience for removal + (though it operates essentially the same as `Available`). + """ + + def __init__( + self, name: str, version: str, epoch: str, arch: str, state: PackageState + ) -> None: + self._name = name + self._arch = arch + self._state = state + self._version = Version(version, epoch) + + def __eq__(self, other) -> bool: + """Equality for comparison. 
+ + Args: + other: a `DebianPackage` object for comparison + + Returns: + A boolean reflecting equality + """ + return isinstance(other, self.__class__) and ( + self._name, + self._version.number, + ) == (other._name, other._version.number) + + def __hash__(self): + """Return a hash of this package.""" + return hash((self._name, self._version.number)) + + def __repr__(self): + """Represent the package.""" + return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) + + def __str__(self): + """Return a human-readable representation of the package.""" + return "<{}: {}-{}.{} -- {}>".format( + self.__class__.__name__, + self._name, + self._version, + self._arch, + str(self._state), + ) + + @staticmethod + def _apt( + command: str, + package_names: Union[str, List], + optargs: Optional[List[str]] = None, + ) -> None: + """Wrap package management commands for Debian/Ubuntu systems. + + Args: + command: the command given to `apt-get` + package_names: a package name or list of package names to operate on + optargs: an (Optional) list of additioanl arguments + + Raises: + PackageError if an error is encountered + """ + optargs = optargs if optargs is not None else [] + if isinstance(package_names, str): + package_names = [package_names] + _cmd = ["apt-get", "-y", *optargs, command, *package_names] + try: + env = os.environ.copy() + env["DEBIAN_FRONTEND"] = "noninteractive" + subprocess.run(_cmd, capture_output=True, check=True, text=True, env=env) + except CalledProcessError as e: + raise PackageError( + "Could not {} package(s) [{}]: {}".format(command, [*package_names], e.stderr) + ) from None + + def _add(self) -> None: + """Add a package to the system.""" + self._apt( + "install", + "{}={}".format(self.name, self.version), + optargs=["--option=Dpkg::Options::=--force-confold"], + ) + + def _remove(self) -> None: + """Remove a package from the system. 
Implementation-specific.""" + return self._apt("remove", "{}={}".format(self.name, self.version)) + + @property + def name(self) -> str: + """Returns the name of the package.""" + return self._name + + def ensure(self, state: PackageState): + """Ensure that a package is in a given state. + + Args: + state: a `PackageState` to reconcile the package to + + Raises: + PackageError from the underlying call to apt + """ + if self._state is not state: + if state not in (PackageState.Present, PackageState.Latest): + self._remove() + else: + self._add() + self._state = state + + @property + def present(self) -> bool: + """Returns whether or not a package is present.""" + return self._state in (PackageState.Present, PackageState.Latest) + + @property + def latest(self) -> bool: + """Returns whether the package is the most recent version.""" + return self._state is PackageState.Latest + + @property + def state(self) -> PackageState: + """Returns the current package state.""" + return self._state + + @state.setter + def state(self, state: PackageState) -> None: + """Set the package state to a given value. + + Args: + state: a `PackageState` to reconcile the package to + + Raises: + PackageError from the underlying call to apt + """ + if state in (PackageState.Latest, PackageState.Present): + self._add() + else: + self._remove() + self._state = state + + @property + def version(self) -> "Version": + """Returns the version for a package.""" + return self._version + + @property + def epoch(self) -> str: + """Returns the epoch for a package. 
May be unset.""" + return self._version.epoch + + @property + def arch(self) -> str: + """Returns the architecture for a package.""" + return self._arch + + @property + def fullversion(self) -> str: + """Returns the version and architecture for a package, formatted as `version.arch`.""" + return "{}.{}".format(self._version, self._arch) + + @staticmethod + def _get_epoch_from_version(version: str) -> Tuple[str, str]: + """Pull the epoch, if any, out of a version string. + + Args: + version: a version string, optionally prefixed with `epoch:` + + Returns: + A tuple of (epoch, version); the epoch is `None` when the string has no epoch prefix + """ + # The named groups are looked up by key below, so they must be called `epoch` and `version`. + epoch_matcher = re.compile(r"^((?P<epoch>\d+):)?(?P<version>.*)") + matches = epoch_matcher.search(version).groupdict() + return matches.get("epoch", ""), matches.get("version") + + @classmethod + def from_system( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Locates a package, either on the system or known to apt, and serializes the information. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. If an + architecture is not specified, this will be used for selection. + + """ + try: + return DebianPackage.from_installed_package(package, version, arch) + except PackageNotFoundError: + logger.debug( + "package '%s' is not currently installed or has the wrong architecture.", package + ) + + # Ok, try `apt-cache ...` + try: + return DebianPackage.from_apt_cache(package, version, arch) + except (PackageNotFoundError, PackageError): + # If we get here, it's not known to the systems. + # This seems unnecessary, but virtually all `apt` commands have a return code of `100`, + # and providing meaningful error messages without this is ugly. 
+ raise PackageNotFoundError( + "Package '{}{}' could not be found on the system or in the apt cache!".format( + package, ".{}".format(arch) if arch else "" + ) + ) from None + + @classmethod + def from_installed_package( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Check whether the package is already installed and return an instance. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. + If an architecture is not specified, this will be used for selection. + + Raises: + PackageNotFoundError if `dpkg -l` fails for the package + """ + system_arch = check_output( + ["dpkg", "--print-architecture"], universal_newlines=True + ).strip() + arch = arch if arch else system_arch + + # Regexps are a really terrible way to do this. Thanks dpkg + output = "" + try: + output = check_output(["dpkg", "-l", package], stderr=PIPE, universal_newlines=True) + except CalledProcessError: + raise PackageNotFoundError("Package is not installed: {}".format(package)) from None + + # Skip the first five lines of the `dpkg -l` output because there's no flag to + # omit them + lines = str(output).splitlines()[5:] + + # The group names are looked up by key below (`package_status`, `package_name`, + # `version`, `arch`); the remaining two groups are matched but not used. + dpkg_matcher = re.compile( + r""" + ^(?P<package_status>\w+?)\s+ + (?P<package_name>.*?)(?P<throwaway_arch>:\w+?)?\s+ + (?P<version>.*?)\s+ + (?P<arch>\w+?)\s+ + (?P<description>.*) + """, + re.VERBOSE, + ) + + for line in lines: + try: + matches = dpkg_matcher.search(line).groupdict() + package_status = matches["package_status"] + + if not package_status.endswith("i"): + logger.debug( + "package '%s' in dpkg output but not installed, status: '%s'", + package, + package_status, + ) + break + + epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"]) + pkg = DebianPackage( + matches["package_name"], + split_version, + epoch, + matches["arch"], + PackageState.Present, + ) + if (pkg.arch == "all" or pkg.arch == arch) and ( + version == "" or str(pkg.version) == version + ): + return pkg + except AttributeError: + 
logger.warning("dpkg matcher could not parse line: %s", line) + + # If we didn't find it, fail through + raise PackageNotFoundError("Package {}.{} is not installed!".format(package, arch)) + + @classmethod + def from_apt_cache( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Check whether the package is already installed and return an instance. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. + If an architecture is not specified, this will be used for selection. + """ + system_arch = check_output( + ["dpkg", "--print-architecture"], universal_newlines=True + ).strip() + arch = arch if arch else system_arch + + # Regexps are a really terrible way to do this. Thanks dpkg + keys = ("Package", "Architecture", "Version") + + try: + output = check_output( + ["apt-cache", "show", package], stderr=PIPE, universal_newlines=True + ) + except CalledProcessError as e: + raise PackageError( + "Could not list packages in apt-cache: {}".format(e.stderr) + ) from None + + pkg_groups = output.strip().split("\n\n") + keys = ("Package", "Architecture", "Version") + + for pkg_raw in pkg_groups: + lines = str(pkg_raw).splitlines() + vals = {} + for line in lines: + if line.startswith(keys): + items = line.split(":", 1) + vals[items[0]] = items[1].strip() + else: + continue + + epoch, split_version = DebianPackage._get_epoch_from_version(vals["Version"]) + pkg = DebianPackage( + vals["Package"], + split_version, + epoch, + vals["Architecture"], + PackageState.Available, + ) + + if (pkg.arch == "all" or pkg.arch == arch) and ( + version == "" or str(pkg.version) == version + ): + return pkg + + # If we didn't find it, fail through + raise PackageNotFoundError("Package {}.{} is not in the apt cache!".format(package, arch)) + + +class Version: + """An abstraction around package versions. 
+ + This seems like it should be strictly unnecessary, except that `apt_pkg` is not usable inside a + venv, and wedging version comparisons into `DebianPackage` would overcomplicate it. + + This class implements the algorithm found here: + https://www.debian.org/doc/debian-policy/ch-controlfields.html#version + """ + + def __init__(self, version: str, epoch: str): + self._version = version + self._epoch = epoch or "" + + def __repr__(self): + """Represent the package.""" + return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) + + def __str__(self): + """Return human-readable representation of the package.""" + return "{}{}".format("{}:".format(self._epoch) if self._epoch else "", self._version) + + @property + def epoch(self): + """Returns the epoch for a package. May be empty.""" + return self._epoch + + @property + def number(self) -> str: + """Returns the version number for a package.""" + return self._version + + def _get_parts(self, version: str) -> Tuple[str, str]: + """Separate the version into component upstream and Debian pieces.""" + try: + version.rindex("-") + except ValueError: + # No hyphens means no Debian version + return version, "0" + + upstream, debian = version.rsplit("-", 1) + return upstream, debian + + def _listify(self, revision: str) -> List[str]: + """Split a revision string into a listself. + + This list is comprised of alternating between strings and numbers, + padded on either end to always be "str, int, str, int..." and + always be of even length. This allows us to trivially implement the + comparison algorithm described. 
+ """ + result = [] + while revision: + rev_1, remains = self._get_alphas(revision) + rev_2, remains = self._get_digits(remains) + result.extend([rev_1, rev_2]) + revision = remains + return result + + def _get_alphas(self, revision: str) -> Tuple[str, str]: + """Return a tuple of the first non-digit characters of a revision.""" + # get the index of the first digit + for i, char in enumerate(revision): + if char.isdigit(): + if i == 0: + return "", revision + return revision[0:i], revision[i:] + # string is entirely alphas + return revision, "" + + def _get_digits(self, revision: str) -> Tuple[int, str]: + """Return a tuple of the first integer characters of a revision.""" + # If the string is empty, return (0,'') + if not revision: + return 0, "" + # get the index of the first non-digit + for i, char in enumerate(revision): + if not char.isdigit(): + if i == 0: + return 0, revision + return int(revision[0:i]), revision[i:] + # string is entirely digits + return int(revision), "" + + def _dstringcmp(self, a, b): # noqa: C901 + """Debian package version string section lexical sort algorithm. + + The lexical comparison is a comparison of ASCII values modified so + that all the letters sort earlier than all the non-letters and so that + a tilde sorts before anything, even the end of a part. 
+ """ + if a == b: + return 0 + try: + for i, char in enumerate(a): + if char == b[i]: + continue + # "a tilde sorts before anything, even the end of a part" + # (emptyness) + if char == "~": + return -1 + if b[i] == "~": + return 1 + # "all the letters sort earlier than all the non-letters" + if char.isalpha() and not b[i].isalpha(): + return -1 + if not char.isalpha() and b[i].isalpha(): + return 1 + # otherwise lexical sort + if ord(char) > ord(b[i]): + return 1 + if ord(char) < ord(b[i]): + return -1 + except IndexError: + # a is longer than b but otherwise equal, greater unless there are tildes + if char == "~": + return -1 + return 1 + # if we get here, a is shorter than b but otherwise equal, so check for tildes... + if b[len(a)] == "~": + return 1 + return -1 + + def _compare_revision_strings(self, first: str, second: str): # noqa: C901 + """Compare two debian revision strings. + + Returns: + 0 if the revisions are equal, 1 if `first` sorts later than `second`, and -1 if it + sorts earlier + """ + if first == second: + return 0 + + # listify pads results so that we will always be comparing ints to ints + # and strings to strings (at least until we fall off the end of a list) + first_list = self._listify(first) + second_list = self._listify(second) + if first_list == second_list: + return 0 + try: + for i, item in enumerate(first_list): + # explicitly raise IndexError if we've fallen off the edge of list2 + if i >= len(second_list): + raise IndexError + # if the items are equal, next + if item == second_list[i]: + continue + # numeric comparison + if isinstance(item, int): + if item > second_list[i]: + return 1 + if item < second_list[i]: + return -1 + else: + # string comparison + return self._dstringcmp(item, second_list[i]) + except IndexError: + # rev1 is longer than rev2 but otherwise equal, hence greater + # ...unless the extra part starts with a tilde, which sorts before the end of a part + if first_list[len(second_list)][0][0] == "~": + return -1 + return 1 + # rev1 is shorter than rev2 but otherwise equal, hence lesser + # ...unless rev2's extra part starts with a tilde, which sorts before the end of a part + if second_list[len(first_list)][0][0] == "~": + return 1 + 
return -1 + + def _compare_version(self, other) -> int: + if (self.number, self.epoch) == (other.number, other.epoch): + return 0 + + if self.epoch < other.epoch: + return -1 + if self.epoch > other.epoch: + return 1 + + # If none of these are true, follow the algorithm + upstream_version, debian_version = self._get_parts(self.number) + other_upstream_version, other_debian_version = self._get_parts(other.number) + + upstream_cmp = self._compare_revision_strings(upstream_version, other_upstream_version) + if upstream_cmp != 0: + return upstream_cmp + + debian_cmp = self._compare_revision_strings(debian_version, other_debian_version) + if debian_cmp != 0: + return debian_cmp + + return 0 + + def __lt__(self, other) -> bool: + """Less than magic method impl.""" + return self._compare_version(other) < 0 + + def __eq__(self, other) -> bool: + """Equality magic method impl.""" + return self._compare_version(other) == 0 + + def __gt__(self, other) -> bool: + """Greater than magic method impl.""" + return self._compare_version(other) > 0 + + def __le__(self, other) -> bool: + """Less than or equal to magic method impl.""" + return self.__eq__(other) or self.__lt__(other) + + def __ge__(self, other) -> bool: + """Greater than or equal to magic method impl.""" + return self.__gt__(other) or self.__eq__(other) + + def __ne__(self, other) -> bool: + """Not equal to magic method impl.""" + return not self.__eq__(other) + + +def add_package( + package_names: Union[str, List[str]], + version: Optional[str] = "", + arch: Optional[str] = "", + update_cache: Optional[bool] = False, +) -> Union[DebianPackage, List[DebianPackage]]: + """Add a package or list of packages to the system. + + Args: + package_names: single package name, or list of package names + name: the name(s) of the package(s) + version: an (Optional) version as a string. 
Defaults to the latest known + arch: an optional architecture for the package + update_cache: whether or not to run `apt-get update` prior to operating + + Raises: + TypeError if no package name is given, or explicit version is set for multiple packages + PackageNotFoundError if the package is not in the cache. + PackageError if packages fail to install + """ + cache_refreshed = False + if update_cache: + update() + cache_refreshed = True + + packages = {"success": [], "retry": [], "failed": []} + + package_names = [package_names] if isinstance(package_names, str) else package_names + if not package_names: + raise TypeError("Expected at least one package name to add, received zero!") + + if len(package_names) != 1 and version: + raise TypeError( + "Explicit version should not be set if more than one package is being added!" + ) + + for p in package_names: + pkg, success = _add(p, version, arch) + if success: + packages["success"].append(pkg) + else: + logger.warning("failed to locate and install/update '%s'", pkg) + packages["retry"].append(p) + + if packages["retry"] and not cache_refreshed: + logger.info("updating the apt-cache and retrying installation of failed packages.") + update() + + for p in packages["retry"]: + pkg, success = _add(p, version, arch) + if success: + packages["success"].append(pkg) + else: + packages["failed"].append(p) + + if packages["failed"]: + raise PackageError("Failed to install packages: {}".format(", ".join(packages["failed"]))) + + return packages["success"] if len(packages["success"]) > 1 else packages["success"][0] + + +def _add( + name: str, + version: Optional[str] = "", + arch: Optional[str] = "", +) -> Tuple[Union[DebianPackage, str], bool]: + """Add a package to the system. + + Args: + name: the name(s) of the package(s) + version: an (Optional) version as a string. 
Defaults to the latest known + arch: an optional architecture for the package + + Returns: a tuple of `DebianPackage` if found, or a :str: if it is not, and + a boolean indicating success + """ + try: + pkg = DebianPackage.from_system(name, version, arch) + pkg.ensure(state=PackageState.Present) + return pkg, True + except PackageNotFoundError: + return name, False + + +def remove_package( + package_names: Union[str, List[str]] +) -> Union[DebianPackage, List[DebianPackage]]: + """Remove package(s) from the system. + + Args: + package_names: the name of a package + + Raises: + PackageNotFoundError if the package is not found. + """ + packages = [] + + package_names = [package_names] if isinstance(package_names, str) else package_names + if not package_names: + raise TypeError("Expected at least one package name to add, received zero!") + + for p in package_names: + try: + pkg = DebianPackage.from_installed_package(p) + pkg.ensure(state=PackageState.Absent) + packages.append(pkg) + except PackageNotFoundError: + logger.info("package '%s' was requested for removal, but it was not installed.", p) + + # the list of packages will be empty when no package is removed + logger.debug("packages: '%s'", packages) + return packages[0] if len(packages) == 1 else packages + + +def update() -> None: + """Update the apt cache via `apt-get update`.""" + subprocess.run(["apt-get", "update"], capture_output=True, check=True) + + +def import_key(key: str) -> str: + """Import an ASCII Armor key. + + A Radix64 format keyid is also supported for backwards + compatibility. In this case Ubuntu keyserver will be + queried for a key via HTTPS by its keyid. This method + is less preferable because https proxy servers may + require traffic decryption which is equivalent to a + man-in-the-middle attack (a proxy server impersonates + keyserver TLS certificates and has to be explicitly + trusted by the system). 
+ + Args: + key: A GPG key in ASCII armor format, including BEGIN + and END markers or a keyid. + + Returns: + The GPG key filename written. + + Raises: + GPGKeyError if the key could not be imported + """ + key = key.strip() + if "-" in key or "\n" in key: + # Send everything not obviously a keyid to GPG to import, as + # we trust its validation better than our own. eg. handling + # comments before the key. + logger.debug("PGP key found (looks like ASCII Armor format)") + if ( + "-----BEGIN PGP PUBLIC KEY BLOCK-----" in key + and "-----END PGP PUBLIC KEY BLOCK-----" in key + ): + logger.debug("Writing provided PGP key in the binary format") + key_bytes = key.encode("utf-8") + key_name = DebianRepository._get_keyid_by_gpg_key(key_bytes) + key_gpg = DebianRepository._dearmor_gpg_key(key_bytes) + gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key_name) + DebianRepository._write_apt_gpg_keyfile( + key_name=gpg_key_filename, key_material=key_gpg + ) + return gpg_key_filename + else: + raise GPGKeyError("ASCII armor markers missing from GPG key") + else: + logger.warning( + "PGP key found (looks like Radix64 format). " + "SECURELY importing PGP key from keyserver; " + "full key not provided." + ) + # as of bionic add-apt-repository uses curl with an HTTPS keyserver URL + # to retrieve GPG keys. `apt-key adv` command is deprecated as is + # apt-key in general as noted in its manpage. See lp:1433761 for more + # history. 
Instead, /etc/apt/trusted.gpg.d is used directly to drop + # gpg + key_asc = DebianRepository._get_key_by_keyid(key) + # write the key in GPG format so that apt-key list shows it + key_gpg = DebianRepository._dearmor_gpg_key(key_asc.encode("utf-8")) + gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key) + DebianRepository._write_apt_gpg_keyfile(key_name=gpg_key_filename, key_material=key_gpg) + return gpg_key_filename + + +class InvalidSourceError(Error): + """Exceptions for invalid source entries.""" + + +class GPGKeyError(Error): + """Exceptions for GPG keys.""" + + +class DebianRepository: + """An abstraction to represent a repository.""" + + def __init__( + self, + enabled: bool, + repotype: str, + uri: str, + release: str, + groups: List[str], + filename: Optional[str] = "", + gpg_key_filename: Optional[str] = "", + options: Optional[dict] = None, + ): + self._enabled = enabled + self._repotype = repotype + self._uri = uri + self._release = release + self._groups = groups + self._filename = filename + self._gpg_key_filename = gpg_key_filename + self._options = options + + @property + def enabled(self): + """Return whether or not the repository is enabled.""" + return self._enabled + + @property + def repotype(self): + """Return whether it is binary or source.""" + return self._repotype + + @property + def uri(self): + """Return the URI.""" + return self._uri + + @property + def release(self): + """Return which Debian/Ubuntu releases it is valid for.""" + return self._release + + @property + def groups(self): + """Return the enabled package groups.""" + return self._groups + + @property + def filename(self): + """Returns the filename for a repository.""" + return self._filename + + @filename.setter + def filename(self, fname: str) -> None: + """Set the filename used when a repo is written back to disk. + + Args: + fname: a filename to write the repository information to. 
+ """ + if not fname.endswith(".list"): + raise InvalidSourceError("apt source filenames should end in .list!") + + self._filename = fname + + @property + def gpg_key(self): + """Returns the path to the GPG key for this repository.""" + return self._gpg_key_filename + + @property + def options(self): + """Returns any additional repo options which are set.""" + return self._options + + def make_options_string(self) -> str: + """Generate the complete options string for a a repository. + + Combining `gpg_key`, if set, and the rest of the options to find + a complex repo string. + """ + options = self._options if self._options else {} + if self._gpg_key_filename: + options["signed-by"] = self._gpg_key_filename + + return ( + "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in options.items()])) + if options + else "" + ) + + @staticmethod + def prefix_from_uri(uri: str) -> str: + """Get a repo list prefix from the uri, depending on whether a path is set.""" + uridetails = urlparse(uri) + path = ( + uridetails.path.lstrip("/").replace("/", "-") if uridetails.path else uridetails.netloc + ) + return "/etc/apt/sources.list.d/{}".format(path) + + @staticmethod + def from_repo_line(repo_line: str, write_file: Optional[bool] = True) -> "DebianRepository": + """Instantiate a new `DebianRepository` a `sources.list` entry line. + + Args: + repo_line: a string representing a repository entry + write_file: boolean to enable writing the new repo to disk + """ + repo = RepositoryMapping._parse(repo_line, "UserInput") + fname = "{}-{}.list".format( + DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") + ) + repo.filename = fname + + options = repo.options if repo.options else {} + if repo.gpg_key: + options["signed-by"] = repo.gpg_key + + # For Python 3.5 it's required to use sorted in the options dict in order to not have + # different results in the order of the options between executions. 
+ options_str = ( + "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in sorted(options.items())])) + if options + else "" + ) + + if write_file: + with open(fname, "wb") as f: + f.write( + ( + "{}".format("#" if not repo.enabled else "") + + "{} {}{} ".format(repo.repotype, options_str, repo.uri) + + "{} {}\n".format(repo.release, " ".join(repo.groups)) + ).encode("utf-8") + ) + + return repo + + def disable(self) -> None: + """Remove this repository from consideration. + + Disable it instead of removing from the repository file. + """ + searcher = "{} {}{} {}".format( + self.repotype, self.make_options_string(), self.uri, self.release + ) + for line in fileinput.input(self._filename, inplace=True): + if re.match(r"^{}\s".format(re.escape(searcher)), line): + print("# {}".format(line), end="") + else: + print(line, end="") + + def import_key(self, key: str) -> None: + """Import an ASCII Armor key. + + A Radix64 format keyid is also supported for backwards + compatibility. In this case Ubuntu keyserver will be + queried for a key via HTTPS by its keyid. This method + is less preferable because https proxy servers may + require traffic decryption which is equivalent to a + man-in-the-middle attack (a proxy server impersonates + keyserver TLS certificates and has to be explicitly + trusted by the system). + + Args: + key: A GPG key in ASCII armor format, + including BEGIN and END markers or a keyid. + + Raises: + GPGKeyError if the key could not be imported + """ + self._gpg_key_filename = import_key(key) + + @staticmethod + def _get_keyid_by_gpg_key(key_material: bytes) -> str: + """Get a GPG key fingerprint by GPG key material. + + Gets a GPG key fingerprint (40-digit, 160-bit) by the ASCII armor-encoded + or binary GPG key material. Can be used, for example, to generate file + names for keys passed via charm options. 
+ """ + # Use the same gpg command for both Xenial and Bionic + cmd = ["gpg", "--with-colons", "--with-fingerprint"] + ps = subprocess.run( + cmd, + stdout=PIPE, + stderr=PIPE, + input=key_material, + ) + out, err = ps.stdout.decode(), ps.stderr.decode() + if "gpg: no valid OpenPGP data found." in err: + raise GPGKeyError("Invalid GPG key material provided") + # from gnupg2 docs: fpr :: Fingerprint (fingerprint is in field 10) + return re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE).group(1) + + @staticmethod + def _get_key_by_keyid(keyid: str) -> str: + """Get a key via HTTPS from the Ubuntu keyserver. + + Different key ID formats are supported by SKS keyservers (the longer ones + are more secure, see "dead beef attack" and https://evil32.com/). Since + HTTPS is used, if SSLBump-like HTTPS proxies are in place, they will + impersonate keyserver.ubuntu.com and generate a certificate with + keyserver.ubuntu.com in the CN field or in SubjAltName fields of a + certificate. If such proxy behavior is expected it is necessary to add the + CA certificate chain containing the intermediate CA of the SSLBump proxy to + every machine that this code runs on via ca-certs cloud-init directive (via + cloudinit-userdata model-config) or via other means (such as through a + custom charm option). Also note that DNS resolution for the hostname in a + URL is done at a proxy server - not at the client side. 
+ 8-digit (32 bit) key ID + https://keyserver.ubuntu.com/pks/lookup?search=0x4652B4E6 + 16-digit (64 bit) key ID + https://keyserver.ubuntu.com/pks/lookup?search=0x6E85A86E4652B4E6 + 40-digit key ID: + https://keyserver.ubuntu.com/pks/lookup?search=0x35F77D63B5CEC106C577ED856E85A86E4652B4E6 + + Args: + keyid: An 8, 16 or 40 hex digit keyid to find a key for + + Returns: + A string contining key material for the specified GPG key id + + + Raises: + subprocess.CalledProcessError + """ + # options=mr - machine-readable output (disables html wrappers) + keyserver_url = ( + "https://keyserver.ubuntu.com" "/pks/lookup?op=get&options=mr&exact=on&search=0x{}" + ) + curl_cmd = ["curl", keyserver_url.format(keyid)] + # use proxy server settings in order to retrieve the key + return check_output(curl_cmd).decode() + + @staticmethod + def _dearmor_gpg_key(key_asc: bytes) -> bytes: + """Convert a GPG key in the ASCII armor format to the binary format. + + Args: + key_asc: A GPG key in ASCII armor format. + + Returns: + A GPG key in binary format as a string + + Raises: + GPGKeyError + """ + ps = subprocess.run(["gpg", "--dearmor"], stdout=PIPE, stderr=PIPE, input=key_asc) + out, err = ps.stdout, ps.stderr.decode() + if "gpg: no valid OpenPGP data found." in err: + raise GPGKeyError( + "Invalid GPG key material. Check your network setup" + " (MTU, routing, DNS) and/or proxy server settings" + " as well as destination keyserver status." + ) + else: + return out + + @staticmethod + def _write_apt_gpg_keyfile(key_name: str, key_material: bytes) -> None: + """Write GPG key material into a file at a provided path. + + Args: + key_name: A key name to use for a key file (could be a fingerprint) + key_material: A GPG key material (binary) + """ + with open(key_name, "wb") as keyf: + keyf.write(key_material) + + +class RepositoryMapping(Mapping): + """An representation of known repositories. 
+ + Instantiation of `RepositoryMapping` will iterate through the + filesystem, parse out repository files in `/etc/apt/...`, and create + `DebianRepository` objects in this list. + + Typical usage: + + repositories = apt.RepositoryMapping() + repositories.add(DebianRepository( + enabled=True, repotype="deb", uri="https://example.com", release="focal", + groups=["universe"] + )) + """ + + def __init__(self): + self._repository_map = {} + # Repositories that we're adding -- used to implement mode param + self.default_file = "/etc/apt/sources.list" + + # read sources.list if it exists + if os.path.isfile(self.default_file): + self.load(self.default_file) + + # read sources.list.d + for file in glob.iglob("/etc/apt/sources.list.d/*.list"): + self.load(file) + + def __contains__(self, key: str) -> bool: + """Magic method for checking presence of repo in mapping.""" + return key in self._repository_map + + def __len__(self) -> int: + """Return number of repositories in map.""" + return len(self._repository_map) + + def __iter__(self) -> Iterable[DebianRepository]: + """Return iterator for RepositoryMapping.""" + return iter(self._repository_map.values()) + + def __getitem__(self, repository_uri: str) -> DebianRepository: + """Return a given `DebianRepository`.""" + return self._repository_map[repository_uri] + + def __setitem__(self, repository_uri: str, repository: DebianRepository) -> None: + """Add a `DebianRepository` to the cache.""" + self._repository_map[repository_uri] = repository + + def load(self, filename: str): + """Load a repository source file into the cache. 
+ + Args: + filename: the path to the repository file + """ + parsed = [] + skipped = [] + with open(filename, "r") as f: + for n, line in enumerate(f): + try: + repo = self._parse(line, filename) + except InvalidSourceError: + skipped.append(n) + else: + repo_identifier = "{}-{}-{}".format(repo.repotype, repo.uri, repo.release) + self._repository_map[repo_identifier] = repo + parsed.append(n) + logger.debug("parsed repo: '%s'", repo_identifier) + + if skipped: + skip_list = ", ".join(str(s) for s in skipped) + logger.debug("skipped the following lines in file '%s': %s", filename, skip_list) + + if parsed: + logger.info("parsed %d apt package repositories", len(parsed)) + else: + raise InvalidSourceError("all repository lines in '{}' were invalid!".format(filename)) + + @staticmethod + def _parse(line: str, filename: str) -> DebianRepository: + """Parse a line in a sources.list file. + + Args: + line: a single line from `load` to parse + filename: the filename being read + + Raises: + InvalidSourceError if the source type is unknown + """ + enabled = True + repotype = uri = release = gpg_key = "" + options = {} + groups = [] + + line = line.strip() + if line.startswith("#"): + enabled = False + line = line[1:] + + # Check for "#" in the line and treat a part after it as a comment then strip it off. + i = line.find("#") + if i > 0: + line = line[:i] + + # Split a source into substrings to initialize a new repo. + source = line.strip() + if source: + # Match any repo options, and get a dict representation. 
+ for v in re.findall(OPTIONS_MATCHER, source): + opts = dict(o.split("=") for o in v.strip("[]").split()) + # Extract the 'signed-by' option for the gpg_key + gpg_key = opts.pop("signed-by", "") + options = opts + + # Remove any options from the source string and split the string into chunks + source = re.sub(OPTIONS_MATCHER, "", source) + chunks = source.split() + + # Check we've got a valid list of chunks + if len(chunks) < 3 or chunks[0] not in VALID_SOURCE_TYPES: + raise InvalidSourceError("An invalid sources line was found in %s!", filename) + + repotype = chunks[0] + uri = chunks[1] + release = chunks[2] + groups = chunks[3:] + + return DebianRepository( + enabled, repotype, uri, release, groups, filename, gpg_key, options + ) + else: + raise InvalidSourceError("An invalid sources line was found in %s!", filename) + + def add(self, repo: DebianRepository, default_filename: Optional[bool] = False) -> None: + """Add a new repository to the system. + + Args: + repo: a `DebianRepository` object + default_filename: an (Optional) filename if the default is not desirable + """ + new_filename = "{}-{}.list".format( + DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") + ) + + fname = repo.filename or new_filename + + options = repo.options if repo.options else {} + if repo.gpg_key: + options["signed-by"] = repo.gpg_key + + with open(fname, "wb") as f: + f.write( + ( + "{}".format("#" if not repo.enabled else "") + + "{} {}{} ".format(repo.repotype, repo.make_options_string(), repo.uri) + + "{} {}\n".format(repo.release, " ".join(repo.groups)) + ).encode("utf-8") + ) + + self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo + + def disable(self, repo: DebianRepository) -> None: + """Remove a repository. Disable by default. 
+ + Args: + repo: a `DebianRepository` to disable + """ + searcher = "{} {}{} {}".format( + repo.repotype, repo.make_options_string(), repo.uri, repo.release + ) + + for line in fileinput.input(repo.filename, inplace=True): + if re.match(r"^{}\s".format(re.escape(searcher)), line): + print("# {}".format(line), end="") + else: + print(line, end="") + + self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo diff --git a/lib/charms/operator_libs_linux/v0/juju_systemd_notices.py b/lib/charms/operator_libs_linux/v0/juju_systemd_notices.py index 08157c9..72cbdb9 100644 --- a/lib/charms/operator_libs_linux/v0/juju_systemd_notices.py +++ b/lib/charms/operator_libs_linux/v0/juju_systemd_notices.py @@ -198,7 +198,7 @@ def subscribe(self) -> None: Type=simple Restart=always WorkingDirectory={self._charm.framework.charm_dir} - Environment="PYTHONPATH={self._charm.framework.charm_dir / "venv"}" + Environment="PYTHONPATH={self._charm.framework.charm_dir / "venv"}:{self._charm.framework.charm_dir / "lib"}" ExecStart=/usr/bin/python3 {__file__} {self._charm.unit.name} [Install] diff --git a/lib/charms/operator_libs_linux/v1/systemd.py b/lib/charms/operator_libs_linux/v1/systemd.py index d75ade1..cdcbad6 100644 --- a/lib/charms/operator_libs_linux/v1/systemd.py +++ b/lib/charms/operator_libs_linux/v1/systemd.py @@ -23,6 +23,7 @@ service_resume with run the mask/unmask and enable/disable invocations. Example usage: + ```python from charms.operator_libs_linux.v0.systemd import service_running, service_reload @@ -33,13 +34,14 @@ # Attempt to reload a service, restarting if necessary success = service_reload("nginx", restart_on_failure=True) ``` - """ -import logging -import subprocess - __all__ = [ # Don't export `_systemctl`. (It's not the intended way of using this lib.) 
+ "SystemdError", + "daemon_reload", + "service_disable", + "service_enable", + "service_failed", "service_pause", "service_reload", "service_restart", @@ -47,9 +49,11 @@ "service_running", "service_start", "service_stop", - "daemon_reload", ] +import logging +import subprocess + logger = logging.getLogger(__name__) # The unique Charmhub library identifier, never change it @@ -60,133 +64,168 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 3 +LIBPATCH = 4 class SystemdError(Exception): """Custom exception for SystemD related errors.""" - pass +def _systemctl(*args: str, check: bool = False) -> int: + """Control a system service using systemctl. -def _popen_kwargs(): - return { - "stdout": subprocess.PIPE, - "stderr": subprocess.STDOUT, - "bufsize": 1, - "universal_newlines": True, - "encoding": "utf-8", - } - + Args: + *args: Arguments to pass to systemctl. + check: Check the output of the systemctl command. Default: False. -def _systemctl( - sub_cmd: str, service_name: str = None, now: bool = None, quiet: bool = None -) -> bool: - """Control a system service. + Returns: + Returncode of systemctl command execution. - Args: - sub_cmd: the systemctl subcommand to issue - service_name: the name of the service to perform the action on - now: passes the --now flag to the shell invocation. - quiet: passes the --quiet flag to the shell invocation. + Raises: + SystemdError: Raised if calling systemctl returns a non-zero returncode and check is True. 
""" - cmd = ["systemctl", sub_cmd] - - if service_name is not None: - cmd.append(service_name) - if now is not None: - cmd.append("--now") - if quiet is not None: - cmd.append("--quiet") - if sub_cmd != "is-active": - logger.debug("Attempting to {} '{}' with command {}.".format(cmd, service_name, cmd)) - else: - logger.debug("Checking if '{}' is active".format(service_name)) - - proc = subprocess.Popen(cmd, **_popen_kwargs()) - last_line = "" - for line in iter(proc.stdout.readline, ""): - last_line = line - logger.debug(line) - - proc.wait() - - if proc.returncode < 1: - return True - - # If we are just checking whether a service is running, return True/False, rather - # than raising an error. - if sub_cmd == "is-active" and proc.returncode == 3: # Code returned when service not active. - return False - - if sub_cmd == "is-failed": - return False - - raise SystemdError( - "Could not {}{}: systemd output: {}".format( - sub_cmd, " {}".format(service_name) if service_name else "", last_line + cmd = ["systemctl", *args] + logger.debug(f"Executing command: {cmd}") + try: + proc = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + encoding="utf-8", + check=check, + ) + logger.debug( + f"Command {cmd} exit code: {proc.returncode}. systemctl output:\n{proc.stdout}" + ) + return proc.returncode + except subprocess.CalledProcessError as e: + raise SystemdError( + f"Command {cmd} failed with returncode {e.returncode}. systemctl output:\n{e.stdout}" ) - ) def service_running(service_name: str) -> bool: - """Determine whether a system service is running. + """Report whether a system service is running. Args: - service_name: the name of the service to check + service_name: The name of the service to check. + + Return: + True if service is running/active; False if not. """ - return _systemctl("is-active", service_name, quiet=True) + # If returncode is 0, this means that is service is active. 
+ return _systemctl("--quiet", "is-active", service_name) == 0 def service_failed(service_name: str) -> bool: - """Determine whether a system service has failed. + """Report whether a system service has failed. Args: - service_name: the name of the service to check + service_name: The name of the service to check. + + Returns: + True if service is marked as failed; False if not. """ - return _systemctl("is-failed", service_name, quiet=True) + # If returncode is 0, this means that the service has failed. + return _systemctl("--quiet", "is-failed", service_name) == 0 -def service_start(service_name: str) -> bool: +def service_start(*args: str) -> bool: """Start a system service. Args: - service_name: the name of the service to start + *args: Arguments to pass to `systemctl start` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl start ...` returns a non-zero returncode. """ - return _systemctl("start", service_name) + return _systemctl("start", *args, check=True) == 0 -def service_stop(service_name: str) -> bool: +def service_stop(*args: str) -> bool: """Stop a system service. Args: - service_name: the name of the service to stop + *args: Arguments to pass to `systemctl stop` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl stop ...` returns a non-zero returncode. """ - return _systemctl("stop", service_name) + return _systemctl("stop", *args, check=True) == 0 -def service_restart(service_name: str) -> bool: +def service_restart(*args: str) -> bool: """Restart a system service. Args: - service_name: the name of the service to restart + *args: Arguments to pass to `systemctl restart` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. 
+ + Raises: + SystemdError: Raised if `systemctl restart ...` returns a non-zero returncode. """ - return _systemctl("restart", service_name) + return _systemctl("restart", *args, check=True) == 0 + + +def service_enable(*args: str) -> bool: + """Enable a system service. + + Args: + *args: Arguments to pass to `systemctl enable` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl enable ...` returns a non-zero returncode. + """ + return _systemctl("enable", *args, check=True) == 0 + + +def service_disable(*args: str) -> bool: + """Disable a system service. + + Args: + *args: Arguments to pass to `systemctl disable` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl disable ...` returns a non-zero returncode. + """ + return _systemctl("disable", *args, check=True) == 0 def service_reload(service_name: str, restart_on_failure: bool = False) -> bool: """Reload a system service, optionally falling back to restart if reload fails. Args: - service_name: the name of the service to reload - restart_on_failure: boolean indicating whether to fallback to a restart if the - reload fails. + service_name: The name of the service to reload. + restart_on_failure: + Boolean indicating whether to fall back to a restart if the reload fails. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl reload|restart ...` returns a non-zero returncode. 
""" try: - return _systemctl("reload", service_name) + return _systemctl("reload", service_name, check=True) == 0 except SystemdError: if restart_on_failure: - return _systemctl("restart", service_name) + return service_restart(service_name) else: raise @@ -194,37 +233,56 @@ def service_reload(service_name: str, restart_on_failure: bool = False) -> bool: def service_pause(service_name: str) -> bool: """Pause a system service. - Stop it, and prevent it from starting again at boot. + Stops the service and prevents the service from starting again at boot. Args: - service_name: the name of the service to pause + service_name: The name of the service to pause. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if service is still running after being paused by systemctl. """ - _systemctl("disable", service_name, now=True) + _systemctl("disable", "--now", service_name) _systemctl("mask", service_name) - if not service_running(service_name): - return True + if service_running(service_name): + raise SystemdError(f"Attempted to pause {service_name!r}, but it is still running.") - raise SystemdError("Attempted to pause '{}', but it is still running.".format(service_name)) + return True def service_resume(service_name: str) -> bool: """Resume a system service. - Re-enable starting again at boot. Start the service. + Re-enable starting the service again at boot. Start the service. Args: - service_name: the name of the service to resume + service_name: The name of the service to resume. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if service is not running after being resumed by systemctl. 
""" _systemctl("unmask", service_name) - _systemctl("enable", service_name, now=True) + _systemctl("enable", "--now", service_name) - if service_running(service_name): - return True + if not service_running(service_name): + raise SystemdError(f"Attempted to resume {service_name!r}, but it is not running.") - raise SystemdError("Attempted to resume '{}', but it is not running.".format(service_name)) + return True def daemon_reload() -> bool: - """Reload systemd manager configuration.""" - return _systemctl("daemon-reload") + """Reload systemd manager configuration. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl daemon-reload` returns a non-zero returncode. + """ + return _systemctl("daemon-reload", check=True) == 0 diff --git a/metadata.yaml b/metadata.yaml deleted file mode 100644 index f056e7e..0000000 --- a/metadata.yaml +++ /dev/null @@ -1,33 +0,0 @@ -name: slurmd -summary: | - Slurmd, the compute node daemon of Slurm. -description: | - This charm provides slurmd, munged, and the bindings to other utilities - that make lifecycle operations a breeze. - - slurmd is the compute node daemon of SLURM. It monitors all tasks running - on the compute node, accepts work (tasks), launches tasks, and kills - running tasks upon request. -source: https://github.com/omnivector-solutions/slurmd-operator -issues: https://github.com/omnivector-solutions/slurmd-operator/issues -maintainers: - - OmniVector Solutions - - Jason C. Nucciarone - - David Gomez - -requires: - fluentbit: - interface: fluentbit -provides: - slurmd: - interface: slurmd - -resources: - nhc: - type: file - filename: lbnl-nhc-1.4.3.tar.gz - description: | - Official tarball containing NHC. Retrieved from Github Releases. 
- -assumes: - - juju diff --git a/pyproject.toml b/pyproject.toml index ca3e80b..3f5a323 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,8 +35,8 @@ target-version = ["py38"] # Linting tools configuration [tool.ruff] line-length = 99 -select = ["E", "W", "F", "C", "N", "D", "I001"] -extend-ignore = [ +lint.select = ["E", "W", "F", "C", "N", "D", "I001"] +lint.extend-ignore = [ "D203", "D204", "D213", @@ -47,9 +47,9 @@ extend-ignore = [ "D409", "D413", ] -ignore = ["E501", "D107"] +lint.ignore = ["E501", "D107"] extend-exclude = ["__pycache__", "*.egg_info"] -per-file-ignores = {"tests/*" = ["D100","D101","D102","D103","D104"]} +lint.per-file-ignores = {"tests/*" = ["D100","D101","D102","D103","D104"]} -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] max-complexity = 10 diff --git a/requirements.txt b/requirements.txt index 77d33a8..e974a11 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,2 @@ -ops==2.* distro -git+https://github.com/omnivector-solutions/slurm-ops-manager.git@0.8.16 +ops diff --git a/src/charm.py b/src/charm.py index 6a802d3..869cda1 100755 --- a/src/charm.py +++ b/src/charm.py @@ -1,33 +1,42 @@ #!/usr/bin/env python3 -# Copyright 2020 Omnivector Solutions, LLC. +# Copyright 2024 Omnivector, LLC. # See LICENSE file for licensing details. 
-"""SlurmdCharm.""" +"""Slurmd Operator Charm.""" import logging -from pathlib import Path +import socket +from dataclasses import fields +from typing import Any, Dict -import distro -from charms.fluentbit.v0.fluentbit import FluentbitClient -from charms.operator_libs_linux.v0.juju_systemd_notices import ( +from charms.fluentbit.v0.fluentbit import FluentbitClient # type: ignore[import-untyped] +from charms.operator_libs_linux.v0.juju_systemd_notices import ( # type: ignore[import-untyped] ServiceStartedEvent, ServiceStoppedEvent, SystemdNotices, ) -from interface_slurmd import Slurmd -from ops.charm import ActionEvent, CharmBase -from ops.framework import StoredState -from ops.main import main -from ops.model import ActiveStatus, BlockedStatus, WaitingStatus -from slurm_ops_manager import SlurmManager -from utils import monkeypatch, slurmd +from interface_slurmd import ( + SlurmctldAvailableEvent, + Slurmd, +) +from ops import ( + ActionEvent, + ActiveStatus, + BlockedStatus, + CharmBase, + ConfigChangedEvent, + InstallEvent, + RelationCreatedEvent, + StoredState, + UpdateStatusEvent, + WaitingStatus, + main, +) +from slurm_conf_editor import Node, Partition +from slurmd_ops import SlurmdManager +from utils import machine, slurmd logger = logging.getLogger(__name__) -if distro.id() == "centos": - logger.debug("Monkeypatching slurmd operator to support CentOS base") - SystemdNotices = monkeypatch.juju_systemd_notices(SystemdNotices) - slurmd = monkeypatch.slurmd_override_default(slurmd) - slurmd = monkeypatch.slurmd_override_service(slurmd) class SlurmdCharm(CharmBase): @@ -40,138 +49,166 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._stored.set_default( + cluster_name=str(), + munge_key=str(), + new_node=True, nhc_conf=str(), + nhc_params=str(), slurm_installed=False, slurmctld_available=False, - slurmctld_started=False, - cluster_name=str(), + slurmctld_host=str(), + user_supplied_node_config={}, + 
user_supplied_partition_params={}, ) - self._slurm_manager = SlurmManager(self, "slurmd") self._fluentbit = FluentbitClient(self, "fluentbit") - # interface to slurmctld, should only have one slurmctld per slurmd app + + self._slurm_manager = SlurmdManager() + self._slurmd = Slurmd(self, "slurmd") + self._systemd_notices = SystemdNotices(self, ["slurmd"]) event_handler_bindings = { self.on.install: self._on_install, - self.on.upgrade_charm: self._on_upgrade, self.on.update_status: self._on_update_status, self.on.config_changed: self._on_config_changed, - self.on.service_slurmd_started: self._on_slurmd_started, - self.on.service_slurmd_stopped: self._on_slurmd_stopped, + # slurmd self._slurmd.on.slurmctld_available: self._on_slurmctld_available, self._slurmd.on.slurmctld_unavailable: self._on_slurmctld_unavailable, + # juju systemd services + self.on.service_slurmd_started: self._on_slurmd_started, + self.on.service_slurmd_stopped: self._on_slurmd_stopped, # fluentbit self.on["fluentbit"].relation_created: self._on_configure_fluentbit, # actions - self.on.version_action: self._on_version_action, self.on.node_configured_action: self._on_node_configured_action, - self.on.get_node_inventory_action: self._on_get_node_inventory_action, - self.on.set_node_inventory_action: self._on_set_node_inventory_action, self.on.show_nhc_config_action: self._on_show_nhc_config, + self.on.node_config_action: self._on_node_config_action_event, } for event, handler in event_handler_bindings.items(): self.framework.observe(event, handler) - def _on_install(self, event): + def _on_install(self, event: InstallEvent) -> None: """Perform installation operations for slurmd.""" - try: - nhc_path = self.model.resources.fetch("nhc") - logger.debug(f"## Found nhc resource: {nhc_path}") - except Exception as e: - logger.error(f"## Missing nhc resource: {e}") - self.unit.status = BlockedStatus("Missing nhc resource") - event.defer() - return - - 
self.unit.set_workload_version(Path("version").read_text().strip()) self.unit.status = WaitingStatus("Installing slurmd") - successful_installation = self._slurm_manager.install( - self.config.get("custom-slurm-repo"), nhc_path - ) - slurmd.override_service() - logger.debug(f"### slurmd installed: {successful_installation}") - if successful_installation: - self._stored.slurm_installed = True + if self._slurm_manager.install(): + self.unit.set_workload_version(self._slurm_manager.version()) + slurmd.override_service() self._systemd_notices.subscribe() + + self._stored.slurm_installed = True else: self.unit.status = BlockedStatus("Error installing slurmd") event.defer() self._check_status() - def _on_configure_fluentbit(self, event): + def _on_config_changed(self, event: ConfigChangedEvent) -> None: + """Handle charm configuration changes.""" + if nhc_conf := self.model.config.get("nhc-conf"): + if nhc_conf != self._stored.nhc_conf: + self._stored.nhc_conf = nhc_conf + self._slurm_manager.render_nhc_config(nhc_conf) + + if self.model.unit.is_leader(): + if ( + user_supplied_partition_params := self.model.config.get("partition-config") + ) is not None and self.model.config.get("partition-config") != "": + if user_supplied_partition_params != self._stored.user_supplied_partition_params: + # Parse the user supplied partition configuration. + tmp_params = {} + try: + tmp_params = { + item.split("=")[0]: item.split("=")[1] + for item in str(user_supplied_partition_params).split() + } + except IndexError: + logger.error( + "Error parsing partition-config. Please use KEY1=VALUE KEY2=VALUE." + ) + return + + # Validate the user supplied params are valid params. + for param in tmp_params: + if param not in [ + partition_param.name for partition_param in fields(Partition) + ]: + logger.error( + f"Invalid user supplied partition configuration parameter: {param}." 
+ ) + return + + logger.debug(f"tmp_params={tmp_params}") + self._stored.user_supplied_partition_params = tmp_params + + if self._slurmd.is_joined: + self._slurmd.set_partition() + + def _on_configure_fluentbit(self, event: RelationCreatedEvent) -> None: """Set up Fluentbit log forwarding.""" self._configure_fluentbit() - def _configure_fluentbit(self): + def _configure_fluentbit(self) -> None: logger.debug("## Configuring fluentbit") - cfg = [] - cfg.extend(self._slurm_manager.fluentbit_config_nhc) - cfg.extend(self._slurm_manager.fluentbit_config_slurm) - self._fluentbit.configure(cfg) - - def _on_upgrade(self, event): - """Perform upgrade operations.""" - self.unit.set_workload_version(Path("version").read_text().strip()) + self._fluentbit.configure( + [ + self._slurm_manager.fluentbit_config_nhc(self.cluster_name, self.model.app), + self._slurm_manager.fluentbit_config_slurm(self.cluster_name, self.model.app), + ] + ) - def _on_update_status(self, event): + def _on_update_status(self, event: UpdateStatusEvent) -> None: """Handle update status.""" self._check_status() - def _check_status(self) -> bool: - """Check if we have all needed components. 
- - - partition name - - slurm installed - - slurmctld available and working - - munge key configured and working - """ - if not self._stored.slurm_installed: - self.unit.status = BlockedStatus("Error installing slurmd") - return False - - if not self._slurmd.is_joined: - self.unit.status = BlockedStatus("Need relations: slurmctld") - return False - - if not self._stored.slurmctld_available: - self.unit.status = WaitingStatus("Waiting on: slurmctld") - return False + def _on_slurmctld_available(self, event: SlurmctldAvailableEvent) -> None: + """Get data from slurmctld and send inventory.""" + if self._stored.slurm_installed is not True: + event.defer() + return - if not self._slurm_manager.check_munged(): - self.unit.status = BlockedStatus("Error configuring munge key") - return False + if (slurmctld_host := event.slurmctld_host) != self._stored.slurmctld_host: + slurmd.override_default(slurmctld_host) + self._stored.slurmctld_host = slurmctld_host + logger.debug(f"slurmctld_host={slurmctld_host}") - return True + if (munge_key := event.munge_key) != self._stored.munge_key: + self._stored.munge_key = munge_key + self._slurm_manager.configure_munge_key(munge_key) + logger.debug(f"munge_key={munge_key}") - def _set_slurmctld_available(self, flag: bool): - """Change stored value for slurmctld availability.""" - self._stored.slurmctld_available = flag + if (nhc_params := event.nhc_params) != self._stored.nhc_params: + self._stored.nhc_params = nhc_params + self._slurm_manager.render_nhc_wrapper(nhc_params) + logger.debug(f"nhc_params={nhc_params}") - def _on_slurmctld_available(self, event): - """Get data from slurmctld and send inventory.""" - if not self._stored.slurm_installed: - event.defer() - return + logger.debug( + "#### Storing slurmctld_available event relation data in charm StoredState." 
"" + ) + self._stored.slurmctld_available = True + self.cluster_name = event.cluster_name - logger.debug("#### Slurmctld available - setting overrides for configless") - self._set_slurmctld_available(True) - # Get slurmctld host:port from relation and override systemd services. - slurmd.override_default(self._slurmd.slurmctld_hostname, self._slurmd.slurmctld_port) - self._on_set_partition_info_on_app_relation_data(event) - self._write_munge_key_and_restart_munge() # Only set up fluentbit if we have a relation to it. if self._fluentbit._relation is not None: self._configure_fluentbit() + + if self._slurm_manager.restart_munged(): + logger.debug("## Munge restarted successfully") + else: + logger.error("## Unable to restart munge") + slurmd.restart() self._check_status() - def _on_slurmctld_unavailable(self, event): + def _on_slurmctld_unavailable(self, event) -> None: + """Stop slurmd and set slurmctld_available = False when we lose slurmctld.""" logger.debug("## Slurmctld unavailable") - self._set_slurmctld_available(False) + self._stored.slurmctld_available = False + self._stored.nhc_params = "" + self._stored.munge_key = "" + self._stored.slurmctld_host = "" slurmd.stop() self._check_status() @@ -183,109 +220,149 @@ def _on_slurmd_stopped(self, _: ServiceStoppedEvent) -> None: """Handle event emitted by systemd after slurmd daemon is stopped.""" self.unit.status = BlockedStatus("slurmd not running") - def _on_config_changed(self, event): - """Handle charm configuration changes.""" - if self.model.unit.is_leader(): - logger.debug("## slurmd config changed - leader") - self._on_set_partition_info_on_app_relation_data(event) - - nhc_conf = self.model.config.get("nhc-conf") - if nhc_conf: - if nhc_conf != self._stored.nhc_conf: - self._stored.nhc_conf = nhc_conf - self._slurm_manager.render_nhc_config(nhc_conf) - - def _write_munge_key_and_restart_munge(self): - logger.debug("#### slurmd charm - writing munge key") - - 
self._slurm_manager.configure_munge_key(self._slurmd.get_stored_munge_key()) - - if self._slurm_manager.restart_munged(): - logger.debug("## Munge restarted successfully") - else: - logger.error("## Unable to restart munge") - - def _on_version_action(self, event): - """Return version of installed components. - - - Slurm - - munge - """ - version = {} - version["slurm"] = self._slurm_manager.slurm_version() - version["munge"] = self._slurm_manager.munge_version() - - event.set_results(version) - def _on_node_configured_action(self, _: ActionEvent) -> None: """Remove node from DownNodes and mark as active.""" # Trigger reconfiguration of slurmd node. - self._slurmd.new_node = False + self._new_node = False + self._slurmd.set_node() slurmd.restart() logger.debug("### This node is not new anymore") - def _on_get_node_inventory_action(self, event): - """Return node inventory.""" - inventory = self._slurmd.node_inventory - logger.debug(f"### Node inventory: {inventory}") - - # Juju does not like underscores in dictionaries - inv = {k.replace("_", "-"): v for k, v in inventory.items()} - event.set_results(inv) - - def _on_set_node_inventory_action(self, event): - """Overwrite the node inventory.""" - inventory = self._slurmd.node_inventory - - # update local copy of inventory - memory = event.params.get("real-memory", inventory["real_memory"]) - inventory["real_memory"] = memory - - # send it to slurmctld - self._slurmd.node_inventory = inventory - - event.set_results({"real-memory": memory}) - - def _on_show_nhc_config(self, event): + def _on_show_nhc_config(self, event: ActionEvent) -> None: """Show current nhc.conf.""" nhc_conf = self._slurm_manager.get_nhc_config() event.set_results({"nhc.conf": nhc_conf}) - def _on_set_partition_info_on_app_relation_data(self, event): - """Set the slurm partition info on the application relation data.""" - # Only the leader can set data on the relation. 
- if self.model.unit.is_leader(): - # If the relation with slurmctld exists then set our - # partition info on the application relation data. - # This handler shouldn't fire if the relation isn't made, - # but add this extra check here just in case. - if self._slurmd.is_joined: - if partition := { - "partition_name": self.app.name, - "partition_config": self.config.get("partition-config"), - "partition_state": self.config.get("partition-state"), - }: - self._slurmd.set_partition_info_on_app_relation_data(partition) - else: - event.defer() - else: - event.defer() + def _on_node_config_action_event(self, event: ActionEvent) -> None: + """Get or set the user_supplied_node_conifg. + + Return the node config if the `node-config` parameter is not specified, otherwise + parse, validate, and store the input of the `node-config` parameter in stored state. + Lastly, update slurmctld if there are updates to the node config. + """ + valid_config = True + + if (node_config_input := event.params.get("node-config")) is not None: + + # Parse the user supplied node-config. + try: + node_config_tmp = { + item.split("=")[0]: item.split("=")[1] for item in node_config_input.split() + } + except IndexError: + logger.error("Incorrect node-config. Please use KEY1=VAL KEY2=VAL format.") + valid_config = False + + # Validate the user supplied params are valid params. + for param in node_config_tmp: + if param not in [node_param.name for node_param in fields(Node)]: + logger.error(f"Invalid user supplied node parameter: {param}.") + valid_config = False + + # Validate the user supplied params have valid keys. 
+ for k, v in node_config_tmp.items(): + if v == "": + logger.error(f"Invalid user supplied node parameter: {k}={v}.") + valid_config = False + + if valid_config: + if (node_config := node_config_tmp) != self._user_supplied_node_config: + self._user_supplied_node_config = node_config + self._slurmd.set_node() + + event.set_results( + { + "node.config": " ".join( + [f"{k}={v}" for k, v in self.get_node()["node_config"].items()] + ), + "user-supplied-configuration-accepted": f"{valid_config}", + } + ) + # Charm class properties @property def hostname(self) -> str: """Return the hostname.""" - return self._slurm_manager.hostname + return socket.gethostname().split(".")[0] @property def cluster_name(self) -> str: """Return the cluster-name.""" - return self._stored.cluster_name + return f"{self._stored.cluster_name}" @cluster_name.setter - def cluster_name(self, name: str): + def cluster_name(self, name: str) -> None: """Set the cluster-name.""" self._stored.cluster_name = name + @property + def _user_supplied_node_config(self) -> dict[Any, Any]: + """Return the user_supplied_node_config from stored state.""" + return self._stored.user_supplied_node_config # type: ignore[return-value] + + @_user_supplied_node_config.setter + def _user_supplied_node_config(self, node_config: dict) -> None: + """Set the node_config in stored state.""" + self._stored.user_supplied_node_config = node_config + + @property + def _new_node(self) -> bool: + """Get the new_node from stored state.""" + return True if self._stored.new_node is True else False + + @_new_node.setter + def _new_node(self, new_node: bool) -> None: + """Set the new_node in stored state.""" + self._stored.new_node = new_node + + # Charm methods + def _check_status(self) -> bool: + """Check if we have all needed components. 
+ + - slurmd installed + - slurmctld available and working + - munge key configured and working + """ + if self._stored.slurm_installed is not True: + self.unit.status = BlockedStatus("Error installing slurmd") + return False + + if self._slurmd.is_joined is not True: + self.unit.status = BlockedStatus("Need relations: slurmctld") + return False + + if self._stored.slurmctld_available is not True: + self.unit.status = WaitingStatus("Waiting on: slurmctld") + return False + + if not self._slurm_manager.check_munged(): + self.unit.status = BlockedStatus("Error configuring munge key") + return False + + logger.debug(f"Slurmctld joined: {self._slurmd.is_joined}") + logger.debug(f"Slurmctld available: {self._stored.slurmctld_available}") + return True + + def get_node(self) -> Dict[Any, Any]: + """Get the node from stored state.""" + node = {} + if binding := self.model.get_binding("slurmd"): + node = { + "node_config": { + **machine.get_inventory(self.hostname, f"{binding.network.ingress_address}"), + **self._user_supplied_node_config, + }, + "new_node": self._new_node, + } + logger.debug(f"Node Configuration: {node}") + return node + + def get_partition(self) -> Dict[Any, Any]: + """Return the partition.""" + partition = {self.app.name: {**{"State": "UP"}, **self._stored.user_supplied_partition_params}} # type: ignore[dict-item] + logger.debug(f"partition={partition}") + return partition + if __name__ == "__main__": # pragma: nocover - main(SlurmdCharm) + main.main(SlurmdCharm) diff --git a/src/interface_slurmd.py b/src/interface_slurmd.py index 3509c17..380a24e 100644 --- a/src/interface_slurmd.py +++ b/src/interface_slurmd.py @@ -2,16 +2,18 @@ """Slurmd.""" import json import logging +from typing import Union -from ops.framework import ( +from ops import ( EventBase, EventSource, Object, ObjectEvents, - StoredState, + Relation, + RelationBrokenEvent, + RelationChangedEvent, + RelationCreatedEvent, ) -from ops.model import Relation -from utils import machine logger = 
logging.getLogger(__name__) @@ -19,6 +21,37 @@ class SlurmctldAvailableEvent(EventBase): """Emitted when slurmctld is available.""" + def __init__( + self, + handle, + cluster_name, + munge_key, + nhc_params, + slurmctld_host, + ): + super().__init__(handle) + + self.cluster_name = cluster_name + self.munge_key = munge_key + self.nhc_params = nhc_params + self.slurmctld_host = slurmctld_host + + def snapshot(self): + """Snapshot the event data.""" + return { + "cluster_name": self.cluster_name, + "munge_key": self.munge_key, + "nhc_params": self.nhc_params, + "slurmctld_host": self.slurmctld_host, + } + + def restore(self, snapshot): + """Restore the snapshot of the event data.""" + self.cluster_name = snapshot.get("cluster_name") + self.munge_key = snapshot.get("munge_key") + self.nhc_params = snapshot.get("nhc_params") + self.slurmctld_host = snapshot.get("slurmctld_host") + class SlurmctldUnavailableEvent(EventBase): """Emit when the relation to slurmctld is broken.""" @@ -34,7 +67,6 @@ class SlurmdEvents(ObjectEvents): class Slurmd(Object): """Slurmd.""" - _stored = StoredState() on = SlurmdEvents() def __init__(self, charm, relation_name): @@ -43,24 +75,11 @@ def __init__(self, charm, relation_name): self._charm = charm self._relation_name = relation_name - self._stored.set_default( - munge_key=str(), - slurmctld_hostname=str(), - slurmctld_addr=str(), - slurmctld_port=str(), - nhc_params=str(), - ) - self.framework.observe( self._charm.on[self._relation_name].relation_created, self._on_relation_created, ) - self.framework.observe( - self._charm.on[self._relation_name].relation_joined, - self._on_relation_joined, - ) - self.framework.observe( self._charm.on[self._relation_name].relation_changed, self._on_relation_changed, @@ -71,152 +90,78 @@ def __init__(self, charm, relation_name): self._on_relation_broken, ) - def _on_relation_created(self, event): + def _on_relation_created(self, event: RelationCreatedEvent) -> None: """Handle the relation-created event. 
- Set the node inventory on the relation data. - """ - # Generate the inventory and set it on the relation data. - node_name = self._charm.hostname - node_addr = event.relation.data[self.model.unit]["ingress-address"] - - inv = machine.get_inventory(node_name, node_addr) - inv["new_node"] = True - self.node_inventory = inv - - def _on_relation_joined(self, event): - """Handle the relation-joined event. - - Get the munge_key, slurmctld_host and slurmctld_port, NHC params, - the cluster name from slurmctld and save it to the charm stored - state. + Set the node and partition config on the relation. """ - app_data = event.relation.data[event.app] - if not app_data.get("munge_key"): - event.defer() - return - - slurmctld_addr = event.relation.data[event.unit]["ingress-address"] - # slurmctld sets the munge_key on the relation-created event - # which happens before relation-joined. We can guarantee that - # the munge_key, slurmctld_host and slurmctld_port will exist - # at this point so retrieve them from the relation data and store - # them in the charm's stored state. - self._store_munge_key(app_data["munge_key"]) - self._store_slurmctld_host_port( - app_data["slurmctld_host"], app_data["slurmctld_port"], slurmctld_addr - ) - - self._charm.cluster_name = app_data.get("cluster_name") + self.set_node() - self._store_nhc_params(app_data.get("nhc_params")) + if self.model.unit.is_leader(): + self.set_partition() - self.on.slurmctld_available.emit() + def _on_relation_changed(self, event: RelationChangedEvent) -> None: + """Handle the relation-changed event. - def _on_relation_changed(self, event): - """Perform relation changed operations. + Get the cluster_info from slurmctld and emit the slurmctld_available event. 
- Possible scenarios: - - nhc parameters changed - - tls parameters changed + Ensure all cases are accounted for: + - no application in event + - no application data in relation + - no cluster_info in application relation data + - application exists in event, and application data exists on relation, cluster_info + exists in application relation data """ - app_data = event.relation.data[event.app] - self._store_nhc_params(app_data.get("nhc_params")) + if app := event.app: + if app_data := event.relation.data.get(app): + if cluster_info_json := app_data.get("cluster_info"): + try: + cluster_info = json.loads(cluster_info_json) + except json.JSONDecodeError as e: + logger.error(e) + raise (e) + + logger.debug(f"cluster_info: {cluster_info}") + self.on.slurmctld_available.emit(**cluster_info) + else: + logger.debug( + f"No cluster_info in application data, deferring {self._relation_name}" + ) + event.defer() + else: + logger.debug("No application data on relation.") + else: + logger.debug("No application on the event.") - def _on_relation_broken(self, event): - """Perform relation broken operations.""" + def _on_relation_broken(self, event: RelationBrokenEvent) -> None: + """Emit slurmctld_unavailable when the relation-broken event occurs.""" self.on.slurmctld_unavailable.emit() - @property - def _relation(self) -> Relation: - """Return the relation.""" - return self.framework.model.get_relation(self._relation_name) - - @property - def is_joined(self) -> bool: - """Return True if relation is joined.""" - if self._charm.framework.model.relations.get(self._relation_name): - return True + def set_node(self) -> None: + """Set the node on the unit data.""" + if relation := self._relation: + relation.data[self.model.unit]["node"] = json.dumps(self._charm.get_node()) else: - return False + logger.debug("No relation, cannot set 'node'.") - @property - def slurmctld_hostname(self) -> str: - """Get slurmctld hostname.""" - return self._stored.slurmctld_hostname - - @property - 
def slurmctld_port(self) -> str: - """Get slurmctld port.""" - return self._stored.slurmctld_port - - @property - def node_inventory(self) -> dict: - """Return unit inventory.""" - return json.loads(self._relation.data[self.model.unit]["inventory"]) - - @node_inventory.setter - def node_inventory(self, inventory: dict): - """Set unit inventory.""" - self._relation.data[self.model.unit]["inventory"] = json.dumps(inventory) - - @property - def new_node(self) -> bool: - """Get `new_node` value in integration data.""" - return self.node_inventory["new_node"] - - @new_node.setter - def new_node(self, value: bool) -> None: - """Update `new_node` field in integration data.""" - inv = self.node_inventory - inv["new_node"] = value - self.node_inventory = inv - - def set_partition_info_on_app_relation_data(self, partition_info): + def set_partition(self) -> None: """Set the slurmd partition on the app relation data. Setting data on the application relation forces the units of related slurmctld application(s) to observe the relation-changed event so they can acquire and redistribute the updated slurm config. 
""" - # there is only one slurmctld, so there should be only one relation here - relations = self._charm.framework.model.relations["slurmd"] - for relation in relations: - relation.data[self.model.app]["partition_info"] = json.dumps(partition_info) - - def _store_munge_key(self, munge_key: str): - """Store the munge_key in the StoredState.""" - self._stored.munge_key = munge_key - - def _store_nhc_params(self, params: str): - """Store the NHC params.""" - if params != self._stored.nhc_params: - self._stored.nhc_params = params - - logger.debug(f"## rendering /usr/sbin/omni-nhc-wrapper: {params}") - self._charm._slurm_manager.render_nhc_wrapper(params) + if relation := self._relation: + relation.data[self.model.app]["partition"] = json.dumps(self._charm.get_partition()) + else: + logger.debug("No relation, cannot set 'partition'.") @property - def slurmctld_address(self) -> str: - """Get slurmctld IP address.""" - return self._stored.slurmctld_addr - - @slurmctld_address.setter - def slurmctld_address(self, addr: str): - """Set slurmctld IP address.""" - self._stored.slurmctld_addr = addr - - def _store_slurmctld_host_port(self, host: str, port: str, addr: str): - """Store the hostname, port and IP of slurmctld in StoredState.""" - if host != self._stored.slurmctld_hostname: - self._stored.slurmctld_hostname = host - - if port != self._stored.slurmctld_port: - self._stored.slurmctld_port = port - - if addr != self.slurmctld_address: - self.slurmctld_address = addr + def _relation(self) -> Union[Relation, None]: + """Return the relation.""" + return self.model.get_relation(self._relation_name) - def get_stored_munge_key(self) -> str: - """Retrieve the munge_key from the StoredState.""" - return self._stored.munge_key + @property + def is_joined(self) -> bool: + """Return True if relation is joined.""" + return True if self.model.relations.get(self._relation_name) else False diff --git a/src/slurm_conf_editor.py b/src/slurm_conf_editor.py new file mode 100644 index 
0000000..ae8d6c8 --- /dev/null +++ b/src/slurm_conf_editor.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +"""Slurm Node Options.""" +from dataclasses import dataclass + + +@dataclass +class Partition: + """A slurm partition.""" + + PartitionName: str + AllocNodes: str + AllowAccounts: str + AllowGroups: str + AllowQos: str + Alternate: str + CpuBind: str + Default: str + DefaultTime: str + DefCpuPerGPU: str + DefMemPerCPU: str + DefMemPerGPU: str + DefMemPerNode: str + DenyAccounts: str + DenyQos: str + DisableRootJobs: str + ExclusiveUser: str + GraceTime: str + Hidden: str + LLN: str + MaxCPUsPerNode: str + MaxCPUsPerSocket: str + MaxMemPerCPU: str + MaxMemPerNode: str + MaxNodes: str + MaxTime: str + MinNodes: str + Nodes: str + OverSubscribe: str + OverTimeLimit: str + PowerDownOnIdle: str + PreemptMode: str + PriorityJobFactor: str + PriorityTier: str + QOS: str + ReqResv: str + ResumeTimeout: str + RootOnly: str + SelectTypeParameters: str + State: str + SuspendTime: str + SuspendTimeout: str + TRESBillingWeights: str + + +@dataclass +class Node: + """A slurm node entry.""" + + NodeName: str + NodeHostname: str + NodeAddr: str + BcastAddr: str + Boards: str + CoreSpecCount: str + CoresPerSocket: str + CpuBind: str + CPUs: str + CpuSpecList: str + Features: str + Gres: str + MemSpecLimit: str + Port: str + Procs: str + RealMemory: str + Reason: str + Sockets: str + SocketsPerBoard: str + State: str + ThreadsPerCore: str + TmpDisk: str + Weight: str diff --git a/src/slurmd_ops.py b/src/slurmd_ops.py new file mode 100644 index 0000000..f373198 --- /dev/null +++ b/src/slurmd_ops.py @@ -0,0 +1,422 @@ +# Copyright 2024 Omnivector, LLC. +# See LICENSE file for licensing details. 
+"""SlurmManager.""" + +import logging +import os +import shlex +import subprocess +import textwrap +from base64 import b64decode +from pathlib import Path +from shutil import rmtree + +import charms.operator_libs_linux.v0.apt as apt # type: ignore [import-untyped] +import charms.operator_libs_linux.v1.systemd as systemd # type: ignore [import-untyped] +import distro + +logger = logging.getLogger() + + +TEMPLATE_DIR = Path(os.path.dirname(os.path.abspath(__file__))) / "templates" + + +class SlurmdException(BaseException): + """SlurmdException.""" + + def __init__(self, msg): + pass + + +SLURM_PPA_KEY: str = """ +-----BEGIN PGP PUBLIC KEY BLOCK----- +Comment: Hostname: +Version: Hockeypuck 2.1.1-10-gec3b0e7 + +xsFNBGTuZb8BEACtJ1CnZe6/hv84DceHv+a54y3Pqq0gqED0xhTKnbj/E2ByJpmT +NlDNkpeITwPAAN1e3824Me76Qn31RkogTMoPJ2o2XfG253RXd67MPxYhfKTJcnM3 +CEkmeI4u2Lynh3O6RQ08nAFS2AGTeFVFH2GPNWrfOsGZW03Jas85TZ0k7LXVHiBs +W6qonbsFJhshvwC3SryG4XYT+z/+35x5fus4rPtMrrEOD65hij7EtQNaE8owuAju +Kcd0m2b+crMXNcllWFWmYMV0VjksQvYD7jwGrWeKs+EeHgU8ZuqaIP4pYHvoQjag +umqnH9Qsaq5NAXiuAIAGDIIV4RdAfQIR4opGaVgIFJdvoSwYe3oh2JlrLPBlyxyY +dayDifd3X8jxq6/oAuyH1h5K/QLs46jLSR8fUbG98SCHlRmvozTuWGk+e07ALtGe +sGv78ToHKwoM2buXaTTHMwYwu7Rx8LZ4bZPHdersN1VW/m9yn1n5hMzwbFKy2s6/ +D4Q2ZBsqlN+5aW2q0IUmO+m0GhcdaDv8U7RVto1cWWPr50HhiCi7Yvei1qZiD9jq +57oYZVqTUNCTPxi6NeTOdEc+YqNynWNArx4PHh38LT0bqKtlZCGHNfoAJLPVYhbB +b2AHj9edYtHU9AAFSIy+HstET6P0UDxy02IeyE2yxoUBqdlXyv6FL44E+wARAQAB +zRxMYXVuY2hwYWQgUFBBIGZvciBVYnVudHUgSFBDwsGOBBMBCgA4FiEErocSHcPk +oLD4H/Aj9tDF1ca+s3sFAmTuZb8CGwMFCwkIBwIGFQoJCAsCBBYCAwECHgECF4AA +CgkQ9tDF1ca+s3sz3w//RNawsgydrutcbKf0yphDhzWS53wgfrs2KF1KgB0u/H+u +6Kn2C6jrVM0vuY4NKpbEPCduOj21pTCepL6PoCLv++tICOLVok5wY7Zn3WQFq0js +Iy1wO5t3kA1cTD/05v/qQVBGZ2j4DsJo33iMcQS5AjHvSr0nu7XSvDDEE3cQE55D +87vL7lgGjuTOikPh5FpCoS1gpemBfwm2Lbm4P8vGOA4/witRjGgfC1fv1idUnZLM +TbGrDlhVie8pX2kgB6yTYbJ3P3kpC1ZPpXSRWO/cQ8xoYpLBTXOOtqwZZUnxyzHh +gM+hv42vPTOnCo+apD97/VArsp59pDqEVoAtMTk72fdBqR+BB77g2hBkKESgQIEq 
+EiE1/TOISioMkE0AuUdaJ2ebyQXugSHHuBaqbEC47v8t5DVN5Qr9OriuzCuSDNFn +6SBHpahN9ZNi9w0A/Yh1+lFfpkVw2t04Q2LNuupqOpW+h3/62AeUqjUIAIrmfeML +IDRE2VdquYdIXKuhNvfpJYGdyvx/wAbiAeBWg0uPSepwTfTG59VPQmj0FtalkMnN +ya2212K5q68O5eXOfCnGeMvqIXxqzpdukxSZnLkgk40uFJnJVESd/CxHquqHPUDE +fy6i2AnB3kUI27D4HY2YSlXLSRbjiSxTfVwNCzDsIh7Czefsm6ITK2+cVWs0hNQ= +=cs1s +-----END PGP PUBLIC KEY BLOCK----- +""" + + +class Slurmd: + """Facilitate slurmd package lifecycle ops.""" + + _package_name: str = "slurmd" + _keyring_path: Path = Path("/usr/share/keyrings/slurm-wlm.asc") + + def _repo(self) -> apt.DebianRepository: + """Return the ubuntu-hpc slurm-wlm repo.""" + ppa_url: str = "https://ppa.launchpadcontent.net/ubuntu-hpc/slurm-wlm-23.02/ubuntu" + sources_list: str = ( + f"deb [signed-by={self._keyring_path}] {ppa_url} {distro.codename()} main" + ) + return apt.DebianRepository.from_repo_line(sources_list) + + def install(self) -> bool: + """Install the slurmd package using lib apt.""" + # Install the key. + if self._keyring_path.exists(): + self._keyring_path.unlink() + self._keyring_path.write_text(SLURM_PPA_KEY) + + # Add the repo. + repositories = apt.RepositoryMapping() + repositories.add(self._repo()) + + # Install the slurmd, slurm-client packages. + slurmd_installed = False + try: + # Run `apt-get update` + apt.update() + apt.add_package(["mailutils", "logrotate"]) + apt.add_package([self._package_name, "slurm-client"]) + slurmd_installed = True + except apt.PackageNotFoundError: + logger.error(f"{self._package_name} not found in package cache or on system") + except apt.PackageError as e: + logger.error(f"Could not install {self._package_name}. Reason: %s", e.message) + return slurmd_installed + + def uninstall(self) -> None: + """Uninstall the slurmd package using libapt.""" + # Uninstall the slurmd package. 
+ if apt.remove_package(self._package_name): + logger.info(f"{self._package_name} removed from system.") + else: + logger.error(f"{self._package_name} not found on system") + + # Disable the ubuntu-hpc repo. + repositories = apt.RepositoryMapping() + repositories.disable(self._repo()) + + # Remove the key. + if self._keyring_path.exists(): + self._keyring_path.unlink() + + def version(self) -> str: + """Return the slurmd version.""" + slurmd_version = "" + try: + slurmd_version = apt.DebianPackage.from_installed_package( + self._package_name + ).version.number + except apt.PackageNotFoundError: + logger.error(f"{self._package_name} not found on system") + return slurmd_version + + +class SlurmdManager: + """SlurmdManager.""" + + @property + def _munge_key_path(self) -> Path: + """Return the full path to the munge key.""" + return Path("/etc/munge/munge.key") + + @property + def _munged_systemd_service(self) -> str: + """Return the name of the Munge Systemd unit file.""" + return "munge.service" + + @property + def _slurmd_user(self) -> str: + """Return the slurmd user.""" + return "root" + + @property + def _slurmd_group(self) -> str: + """Return the slurmd group.""" + return "root" + + def install(self) -> bool: + """Install slurmd to the system. + + Returns: + bool: True on success, False otherwise. 
+ """ + slurmd_installed = False + nhc_installed = False + if installed_slurmd := Slurmd().install(): + slurmd_installed = installed_slurmd + systemd.service_stop("slurmd") + systemd.service_stop("munge") + + if installed_nhc := self._install_nhc_from_tarball(): + nhc_installed = installed_nhc + self.render_nhc_config() + + logger.debug(f"NHC installed: {nhc_installed}") + logger.debug(f"slurmd installed: {slurmd_installed}") + return nhc_installed and slurmd_installed + + def write_munge_key(self, munge_key: str) -> None: + """Base64 decode and write the munge key.""" + key = b64decode(munge_key.encode()) + self._munge_key_path.write_bytes(key) + + def version(self) -> str: + """Return slurm version.""" + return Slurmd().version() + + # Fluentbit + def fluentbit_config_nhc(self, cluster_name: str, app_name: str) -> list: + """Return Fluentbit configuration parameters to forward NHC logs.""" + cfg = [ + { + "input": [ + ("name", "tail"), + ("path", "/var/log/nhc.log"), + ("path_key", "filename"), + ("tag", "nhc"), + ("multiline.parser", "nhc"), + ] + }, + { + "multiline_parser": [ + ("name", "nhc"), + ("type", "regex"), + ("flush_timeout", "1000"), + ("rule", '"start_state"', '"/^([\d]{8} [\d:]*) (.*)/"', '"cont"'), # noqa + ("rule", '"cont"', r'"/^([^\d].*)/"', '"cont"'), + ] + }, # noqa + { + "filter": [ + ("name", "record_modifier"), + ("match", "nhc"), + ("record", "hostname ${HOSTNAME}"), + ("record", f"cluster-name {cluster_name}"), + ("record", "service nhc"), + ("record", f"partition-name {app_name}"), + ] + }, + ] + return cfg + + def fluentbit_config_slurm(self, cluster_name: str, app_name: str) -> list: + """Return Fluentbit configuration parameters to forward Slurm logs.""" + log_file = self._slurmd_log_file + + cfg = [ + { + "input": [ + ("name", "tail"), + ("path", log_file.as_posix()), + ("path_key", "filename"), + ("tag", "slurmd"), + ("parser", "slurm"), + ] + }, + { + "parser": [ + ("name", "slurm"), + ("format", "regex"), + ("regex", r"^\[(?