Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(slurm_ops): implement initial jwt key manager #34

Merged
merged 6 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ slurmutils ~= 0.7.0
python-dotenv ~= 1.0.1
pyyaml >= 6.0.2
distro ~=1.9.0
cryptography ~= 43.0.1

# tests deps
coverage[toml] ~= 7.6
Expand Down
67 changes: 64 additions & 3 deletions lib/charms/hpc_libs/v0/slurm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def _on_install(self, _) -> None:
import distro
import dotenv
import yaml
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from slurmutils.editors import cgroupconfig, slurmconfig, slurmdbdconfig
from slurmutils.models import CgroupConfig, SlurmConfig, SlurmdbdConfig

Expand All @@ -96,7 +98,13 @@ def _on_install(self, _) -> None:
LIBPATCH = 7

# Charm library dependencies to fetch during `charmcraft pack`.
# `cryptography` is required by the jwt key manager for RSA key generation.
PYDEPS = [
    "cryptography~=43.0.1",
    "pyyaml>=6.0.2",
    "python-dotenv~=1.0.1",
    "slurmutils~=0.7.0",
    "distro~=1.9.0",
]

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -391,6 +399,11 @@ def version(self) -> str:
def etc_path(self) -> Path:
"""Get the path to the Slurm configuration directory."""

@property
@abstractmethod
def var_lib_path(self) -> Path:
    """Get the path to the Slurm variable state data directory.

    Concrete managers return the directory that matches their installation
    layout. The jwt key manager resolves its keyfile relative to this path.
    """

@abstractmethod
def service_manager_for(self, type: _ServiceType) -> _ServiceManager:
"""Return the `ServiceManager` for the specified `ServiceType`."""
Expand All @@ -405,9 +418,13 @@ class _SnapManager(_OpsManager):

def install(self) -> None:
    """Install Slurm using the `slurm` snap.

    Requests the `latest/candidate` channel with `--classic`, then aliases
    `slurm.mungectl` to `mungectl`.
    """
    # TODO: https://github.com/charmed-hpc/hpc-libs/issues/35 -
    #   Pin Slurm snap to stable channel.
    _snap("install", "slurm", "--channel", "latest/candidate", "--classic")
    # TODO: https://github.com/charmed-hpc/slurm-snap/issues/49 -
    #   Request automatic alias for the Slurm snap so we don't need to do it here.
    #   We will possibly need to account for a third-party Slurm snap installation
    #   where aliasing is not automatically performed.
    _snap("alias", "slurm.mungectl", "mungectl")

def version(self) -> str:
Expand All @@ -424,6 +441,11 @@ def etc_path(self) -> Path:
"""Get the path to the Slurm configuration directory."""
return Path("/var/snap/slurm/common/etc/slurm")

@property
def var_lib_path(self) -> Path:
    """Return the `slurm` snap's directory for Slurm variable state data."""
    state_root = Path("/var/snap/slurm/common/var/lib/slurm")
    return state_root

def service_manager_for(self, type: _ServiceType) -> _ServiceManager:
    """Build the `_SnapServiceManager` that controls the given service `type`."""
    manager = _SnapServiceManager(type)
    return manager
Expand Down Expand Up @@ -549,6 +571,9 @@ def install(self) -> None:
raise SlurmOpsError(f"failed to install {self._service_name}. reason: {e}")

self._env_file.touch(exist_ok=True)
# Debian package postinst hook does not create a `StateSaveLocation` directory
# so we make one here that is only r/w by owner.
Path("/var/lib/slurm/slurm.state").mkdir(mode=0o600, exist_ok=True)

if self._service_name == "slurmd":
override = Path("/etc/systemd/system/slurmd.service.d/10-slurmd-conf-server.conf")
Expand All @@ -575,6 +600,11 @@ def etc_path(self) -> Path:
"""Get the path to the Slurm configuration directory."""
return Path("/etc/slurm")

@property
def var_lib_path(self) -> Path:
    """Return `/var/lib/slurm`, this manager's Slurm variable state data directory."""
    state_root = Path("/var/lib/slurm")
    return state_root

def service_manager_for(self, type: _ServiceType) -> _ServiceManager:
    """Build the `_SystemctlServiceManager` that controls the given service `type`."""
    manager = _SystemctlServiceManager(type)
    return manager
Expand All @@ -584,6 +614,36 @@ def _env_manager_for(self, type: _ServiceType) -> _EnvManager:
return _EnvManager(file=self._env_file, prefix=type.value)


# TODO: https://github.com/charmed-hpc/hpc-libs/issues/36 -
#   Use `jwtctl` to provide the backend for generating, setting, and getting
#   the jwt signing key used by `slurmctld` and `slurmdbd`. This way we also
#   won't need to hard-code the keyfile location in the `__init__` constructor.
class _JWTKeyManager:
    """Control the jwt signing key used by Slurm.

    The key is stored in `slurm.state/jwt_hs256.key` underneath the
    `var_lib_path` reported by the ops manager.
    """

    def __init__(self, ops_manager: _OpsManager) -> None:
        """Resolve the keyfile location from the ops manager's state directory.

        Args:
            ops_manager: Manager whose `var_lib_path` contains `slurm.state`.
        """
        self._keyfile = ops_manager.var_lib_path / "slurm.state/jwt_hs256.key"

    def get(self) -> str:
        """Get the current jwt key.

        Returns:
            The full text content of the keyfile.

        Raises:
            OSError: If the keyfile does not exist or cannot be read.
        """
        return self._keyfile.read_text()

    def set(self, key: str) -> None:
        """Set a new jwt key.

        The keyfile is restricted to read/write by its owner *before* the key
        is written, so the secret never sits in a file readable by others.

        Args:
            key: Key material to store in the keyfile.

        Raises:
            OSError: If the keyfile cannot be created, re-permissioned, or written.
        """
        # `touch` only applies `mode` when it creates the file (and the umask can
        # narrow it further), so `chmod` is still needed for a pre-existing keyfile.
        self._keyfile.touch(mode=0o600, exist_ok=True)
        self._keyfile.chmod(0o600)
        # NOTE(review): ownership is left untouched here - confirm the charm makes
        # the keyfile readable by the user that the Slurm daemons run as.
        self._keyfile.write_text(key)

    def generate(self) -> None:
        """Generate a new, cryptographically secure jwt key.

        Creates a 2048-bit RSA private key, serializes it as unencrypted PEM,
        and stores it with `set`.
        """
        # NOTE(review): the keyfile name says `hs256` but an RSA private key is
        # stored - confirm Slurm is expected to consume the PEM text as its secret.
        key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        self.set(
            key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.TraditionalOpenSSL,
                encryption_algorithm=serialization.NoEncryption(),
            ).decode()
        )


class _MungeKeyManager:
"""Control the munge key via `mungectl ...` commands."""

Expand Down Expand Up @@ -633,6 +693,7 @@ def __init__(self, service: _ServiceType, snap: bool = False) -> None:
self._ops_manager = _SnapManager() if snap else _AptManager(service)
self.service = self._ops_manager.service_manager_for(service)
self.munge = _MungeManager(self._ops_manager)
self.jwt = _JWTKeyManager(self._ops_manager)
self.exporter = _PrometheusExporterManager(self._ops_manager)
self.install = self._ops_manager.install
self.version = self._ops_manager.version
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/test_slurm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import base64
import subprocess
import textwrap
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch

Expand All @@ -25,6 +26,34 @@

# Munge key fixture and its base64-encoded form.
MUNGEKEY = b"1234567890"
MUNGEKEY_BASE64 = base64.b64encode(MUNGEKEY)
# PEM-encoded RSA private key used as the jwt keyfile content in these tests.
# Test fixture only - it is published in this repository and must never be
# used as a real signing key.
JWT_KEY = """-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAt3PLWkwUOeckDwyMpHgGqmOZhitC8KfOQY/zPWfo+up5RQXz
gVWqsTIt1RWynxIwCGeKYfVlhoKNDEDL1ZjYPcrrGBgMEC8ifqxkN4RC8bwwaGrJ
9Zf0kknPHI5AJ9Fkv6EjgAZW1lwV0uEE5kf0wmlgfThXfzwwGVHVwemE1EgUzdI/
rVxFP5Oe+mRM7kWdtXQrfizGhfmr8laCs+dgExpPa37mk7u/3LZfNXXSWYiaNtie
vax5BxmI4bnTIXxdTT4VP9rMxG8nSspVj5NSWcplKUANlIkMKiO7k/CCD/YzRzM0
0yZttiTvECG+rKy+KJd97dbtj6wSvbJ7cjfq2wIDAQABAoIBACNTfPkqZUqxI9Ry
CjMxmbb97vZTJlTJO4KMgb51X/vRYwDToIxrPq9YhlLeFsNi8TTtG0y5wI8iXJ7b
a2T6RcnAZX0CRHBpYy8Za0L1iR6bqoaw6asNU99Hr0ZEbj48qDXuhbOFhPtKSDmP
cy4U9SDqwdXbH540rN5zT8JDgXyPAVJpwgsShk7rhgOFGIPIZqQoxEjPV3jr1sbk
k7c39fJR6Kxywppn7flSmNX3v1LDu4NDIp0Llt1NlcKlbdy5XWEW9IbiIYi3JTpB
kMpkFQFIuUyledeFyVFPsP8O7Da2rZS6Fb1dYNWzh3WkDRiAwYgTspiYiSf4AAi4
TgrOmiECgYEA312O5bXqXOapU+S2yAFRTa8wkZ1iRR2E66NypZKVsv/vfe0bO+WQ
kI6MRmTluvOKsKe3JulJZpjbl167gge45CHnFPZxEODAJN6OYp+Z4aOvTYBWQPpO
A75AGSheL66PWe4d+ZGvxYCZB5vf4THAs8BsGlFK04RKL1vHADkUjHUCgYEA0kFh
2ei/NP8ODrwygjrpjYSc2OSH9tBUoB7y5zIfLsXshb3Fn4pViF9vl01YkJJ57kki
KQm7rgqCsFnKS4oUFbjDDFbo351m1e3XRbPAATIiqtJmtLoLoSWuhXpsCbneM5bB
xLhFmm8RcFC6ORPBE2WMTGYzTEKydhImvUo+8A8CgYEAssWpyjaoRgSjP68Nj9Rm
Izv1LoZ9kX3H1eUyrEw/Hk3ze6EbK/xXkStWID0/FTs5JJyHXVBX3BK5plQ+1Rqj
I4vy7Hc2FWEcyCWMZmkA+3RLqUbvQgBUEnDh0oDZqWYX+802FnpA6V08nbdnH1D3
v6Zhn0qzDcmSqobVJluJE8UCgYB93FO1/QSQtel1WqUlnhx28Z5um4bkcVtnKn+f
dDqEZkiq2qn1UfrXksGbIdrVWEmTIcZIKKJnkbUf2fAl/fb99ccUmOX4DiIkB6co
+2wBi0CDX0XKA+C4S3VIQ7tuqwvfd+xwVRqdUsVupXSEfFXExbIRfdBRY0+vLDhy
cYJxcwKBgQCK+dW+F0UJTQq1rDxfI0rt6yuRnhtSdAq2+HbXNx/0nwdLQg7SubWe
1QnLcdjnBNxg0m3a7S15nyO2xehvB3rhGeWSfOrHYKJNX7IUqluVLJ+lIwgE2eAz
94qOCvkFCP3pnm/MKN6/rezyOzrVJn7GbyDhcjElu+DD+WRLjfxiSw==
-----END RSA PRIVATE KEY-----
"""
SLURM_INFO = """
name: slurm
summary: "Slurm: A Highly Scalable Workload Manager"
Expand Down Expand Up @@ -132,6 +161,11 @@ class SlurmOpsBase:
def setUp(self):
    """Prepare a fake filesystem holding the snap `.env` file and a seeded jwt keyfile."""
    self.setUpPyfakefs()
    self.fs.create_file("/var/snap/slurm/common/.env")

    # Point the jwt manager at a known keyfile and fill it with the fixture key.
    keyfile = Path("/var/snap/slurm/common/var/lib/slurm/slurm.state/jwt_hs256.key")
    self.fs.create_file(str(keyfile))
    self.manager.jwt._keyfile = keyfile
    keyfile.write_text(JWT_KEY)

def test_config_name(self, *_) -> None:
"""Test that the config name is correctly set."""
Expand Down Expand Up @@ -203,6 +237,20 @@ def test_configure_munge(self, *_) -> None:
self.manager.munge.max_thread_count = 24
self.assertEqual(self.manager.munge.max_thread_count, 24)

def test_get_jwt_key(self, *_) -> None:
    """Test that `get` returns the key that `setUp` wrote to the keyfile."""
    current_key = self.manager.jwt.get()
    self.assertEqual(current_key, JWT_KEY)

def test_set_jwt_key(self, *_) -> None:
    """Test that the jwt key is set correctly."""
    # `setUp` already seeds the keyfile with JWT_KEY, so writing JWT_KEY again
    # would pass even if `set` did nothing. Use a distinct value instead.
    new_key = "a-different-jwt-signing-key\n"
    self.assertNotEqual(self.manager.jwt.get(), new_key)
    self.manager.jwt.set(new_key)
    self.assertEqual(self.manager.jwt.get(), new_key)

def test_generate_jwt_key(self, *_) -> None:
    """Test that the jwt key is properly generated."""
    self.manager.jwt.generate()
    generated = self.manager.jwt.get()
    # The seeded fixture key must have been replaced ...
    self.assertNotEqual(generated, JWT_KEY)
    # ... by a PEM-framed RSA private key, not by empty or arbitrary content.
    self.assertTrue(generated.startswith("-----BEGIN RSA PRIVATE KEY-----"))
    self.assertTrue(generated.strip().endswith("-----END RSA PRIVATE KEY-----"))

@patch("charms.hpc_libs.v0.slurm_ops.socket.gethostname")
def test_hostname(self, gethostname, *_) -> None:
"""Test that manager is able to correctly get the host name."""
Expand Down