From 5ec1110dfb51dbd94486022618bf6b1355069079 Mon Sep 17 00:00:00 2001 From: Matt Pryor Date: Wed, 8 Nov 2023 12:34:05 +0000 Subject: [PATCH] Changes to SSHD for Pydantic v2 --- sshd/requirements.txt | 22 +++++----- sshd/zenith/sshd/config.py | 12 +++--- sshd/zenith/sshd/tunnel.py | 82 ++++++++++++++++++++++++-------------- 3 files changed, 70 insertions(+), 46 deletions(-) diff --git a/sshd/requirements.txt b/sshd/requirements.txt index dad47c4b..f5c8ca02 100644 --- a/sshd/requirements.txt +++ b/sshd/requirements.txt @@ -1,13 +1,15 @@ -certifi==2023.5.7 -cffi==1.15.1 -charset-normalizer==3.1.0 -click==8.1.3 -configomatic @ git+https://github.com/stackhpc/configomatic.git@8b81e8f216762b2e1664f60681c4a8e618ab151d -cryptography==41.0.1 +annotated-types==0.6.0 +certifi==2023.7.22 +cffi==1.16.0 +charset-normalizer==3.3.2 +click==8.1.7 +configomatic @ git+https://github.com/stackhpc/configomatic.git@c485afefb9850430012e1526e5339c31b1ecee33 +cryptography==41.0.5 idna==3.4 pycparser==2.21 -pydantic==1.10.8 -PyYAML==6.0 +pydantic==2.4.2 +pydantic_core==2.10.1 +PyYAML==6.0.1 requests==2.31.0 -typing_extensions==4.6.3 -urllib3==2.0.2 +typing_extensions==4.8.0 +urllib3==2.0.7 diff --git a/sshd/zenith/sshd/config.py b/sshd/zenith/sshd/config.py index 40af6a49..a2037229 100644 --- a/sshd/zenith/sshd/config.py +++ b/sshd/zenith/sshd/config.py @@ -13,15 +13,15 @@ def default_service_host(): return socket.gethostbyname(socket.gethostname()) -class SSHDConfig(Configuration): +class SSHDConfig( + Configuration, + default_path = "/etc/zenith/sshd.yaml", + path_env_var = "ZENITH_SSHD_CONFIG", + env_prefix = "ZENITH_SSHD" +): """ Configuration model for the zenith-sshd package. """ - class Config: - default_path = "/etc/zenith/sshd.yaml" - path_env_var = "ZENITH_SSHD_CONFIG" - env_prefix = "ZENITH_SSHD" - #: The logging configuration logging: LoggingConfiguration = Field(default_factory = LoggingConfiguration) diff --git a/sshd/zenith/sshd/tunnel.py b/sshd/zenith/sshd/tunnel.py index bf3a04fb..9c908d28 100755 --- a/sshd/zenith/sshd/tunnel.py +++ b/sshd/zenith/sshd/tunnel.py @@ -14,12 +14,14 @@ from pydantic import ( BaseModel, - Extra, + TypeAdapter, Field, - AnyHttpUrl, + AfterValidator, + AnyHttpUrl as PyAnyHttpUrl, conint, constr, - validator + field_validator, + ValidationInfo ) import requests @@ -52,14 +54,17 @@ class TunnelExit(RuntimeError): #: Type for an RFC3986 compliant URL path component UrlPath = constr(pattern =r"/[a-zA-Z0-9._~!$&'()*+,;=:@%/-]*", min_length = 1) +#: Type for a string that is validated as a URL +AnyHttpUrl = typing.Annotated[ + str, + AfterValidator(lambda v: TypeAdapter(PyAnyHttpUrl).validate_python(v)) +] -class ClientConfig(BaseModel): + +class ClientConfig(BaseModel, extra = "forbid"): """ Object for validating the client configuration. """ - class Config: - extra = Extra.forbid - #: The port for the service (the tunnel port) allocated_port: int #: The backend protocol @@ -71,20 +76,29 @@ class Config: #: The URL of the OIDC issuer to use auth_oidc_issuer: typing.Optional[AnyHttpUrl] = None #: The OIDC client ID to use - auth_oidc_client_id: typing.Optional[constr(min_length = 1)] = None + auth_oidc_client_id: typing.Optional[constr(min_length = 1)] = Field( + None, + validate_default = True + ) #: The OIDC client secret to use - auth_oidc_client_secret: typing.Optional[constr(min_length = 1)] = None + auth_oidc_client_secret: typing.Optional[constr(min_length = 1)] = Field( + None, + validate_default = True + ) #: The OIDC groups that are allowed access to the the service #: The user must have at least one of these groups in their groups claim auth_oidc_allowed_groups: typing.List[AllowedGroup] = Field(default_factory = list) #: Parameters for the external authentication service (deprecated name) auth_params: typing.Dict[AuthParamsKey, AuthParamsValue] = Field(default_factory = dict) #: Parameters for the external authentication service - auth_external_params: typing.Dict[AuthParamsKey, AuthParamsValue] = Field(default_factory = dict) + auth_external_params: typing.Dict[AuthParamsKey, AuthParamsValue] = Field( + default_factory = dict, + validate_default = True + ) #: Base64-encoded TLS certificate to use tls_cert: typing.Optional[str] = None #: Base64-encoded TLS private key to use (corresponds to TLS cert) - tls_key: typing.Optional[str] = None + tls_key: typing.Optional[str] = Field(None, validate_default = True) #: Base64-encoded CA for validating TLS client certificates, if required tls_client_ca: typing.Optional[str] = None #: An optional liveness path @@ -94,7 +108,8 @@ class Config: #: The number of liveness checks that can fail before the tunnel is considered unhealthy liveness_failures: conint(gt = 0) = 3 - @validator("allocated_port", always = True) + @field_validator("allocated_port") + @classmethod def validate_port(cls, v): """ Validate the given input as a port. @@ -115,15 +130,17 @@ def validate_port(cls, v): else: raise ValueError("Given port is not in use") - @validator("auth_external_params", pre = True, always = True) - def validate_auth_external_params(cls, v, values, **kwargs): + @field_validator("auth_external_params", mode = "before") + @classmethod + def validate_auth_external_params(cls, v, info: ValidationInfo): """ Makes sure that the old name for external auth params is respected. """ - return v or values.get("auth_params", {}) + return v or info.data.get("auth_params", {}) - @validator("auth_oidc_issuer") - def validate_auth_oidc_issuer(cls, v, values, **kwargs): + @field_validator("auth_oidc_issuer") + @classmethod + def validate_auth_oidc_issuer(cls, v): """ Validates that the OIDC issuer supports discovery. """ @@ -134,29 +151,32 @@ def validate_auth_oidc_issuer(cls, v, values, **kwargs): else: raise ValueError("OIDC issuer does not support discovery") - @validator("auth_oidc_client_id", always = True) - def validate_auth_oidc_client_id(cls, v, values, **kwargs): + @field_validator("auth_oidc_client_id") + @classmethod + def validate_auth_oidc_client_id(cls, v, info: ValidationInfo): """ Validates that an OIDC client id is given when an OIDC issuer is present. """ - skip_auth = values.get("skip_auth", False) - oidc_issuer = values.get("auth_oidc_issuer") + skip_auth = info.data.get("skip_auth", False) + oidc_issuer = info.data.get("auth_oidc_issuer") if not skip_auth and oidc_issuer and not v: raise ValueError("required for OIDC authentication") return v - @validator("auth_oidc_client_secret", always = True) - def validate_auth_oidc_client_secret(cls, v, values, **kwargs): + @field_validator("auth_oidc_client_secret") + @classmethod + def validate_auth_oidc_client_secret(cls, v, info: ValidationInfo): """ Validates that a client secret is given when a client ID is present. """ - skip_auth = values.get("skip_auth", False) - oidc_issuer = values.get("auth_oidc_issuer") + skip_auth = info.data.get("skip_auth", False) + oidc_issuer = info.data.get("auth_oidc_issuer") if not skip_auth and oidc_issuer and not v: raise ValueError("required for OIDC authentication") return v - @validator("tls_cert") + @field_validator("tls_cert") + @classmethod def validate_tls_cert(cls, v): """ Validate the given value decoding it and trying to load it as a @@ -165,20 +185,22 @@ def validate_tls_cert(cls, v): _ = load_pem_x509_certificate(base64.b64decode(v)) return v - @validator("tls_key", always = True) - def validate_tls_key(cls, v, values, **kwargs): + @field_validator("tls_key") + @classmethod + def validate_tls_key(cls, v, info: ValidationInfo): """ Validate the given value by decoding it and trying to load it as a PEM-encoded private key. """ - tls_cert = values.get("tls_cert") + tls_cert = info.data.get("tls_cert") if tls_cert and not v: raise ValueError("required if TLS cert is specified") if v: _ = load_pem_private_key(base64.b64decode(v), None) return v - @validator("tls_client_ca") + @field_validator("tls_client_ca") + @classmethod def validate_tls_client_ca(cls, v): """ Validate the given value by decoding it and trying to load it as a