Skip to content

Commit

Permalink
Changes to SSHD for Pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
mkjpryor committed Nov 8, 2023
1 parent ffd8a6e commit 5ec1110
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 46 deletions.
22 changes: 12 additions & 10 deletions sshd/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
certifi==2023.5.7
cffi==1.15.1
charset-normalizer==3.1.0
click==8.1.3
configomatic @ git+https://github.com/stackhpc/configomatic.git@8b81e8f216762b2e1664f60681c4a8e618ab151d
cryptography==41.0.1
annotated-types==0.6.0
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
configomatic @ git+https://github.com/stackhpc/configomatic.git@c485afefb9850430012e1526e5339c31b1ecee33
cryptography==41.0.5
idna==3.4
pycparser==2.21
pydantic==1.10.8
PyYAML==6.0
pydantic==2.4.2
pydantic_core==2.10.1
PyYAML==6.0.1
requests==2.31.0
typing_extensions==4.6.3
urllib3==2.0.2
typing_extensions==4.8.0
urllib3==2.0.7
12 changes: 6 additions & 6 deletions sshd/zenith/sshd/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def default_service_host():
return socket.gethostbyname(socket.gethostname())


class SSHDConfig(Configuration):
class SSHDConfig(
Configuration,
default_path = "/etc/zenith/sshd.yaml",
path_env_var = "ZENITH_SSHD_CONFIG",
env_prefix = "ZENITH_SSHD"
):
"""
Configuration model for the zenith-sshd package.
"""
class Config:
default_path = "/etc/zenith/sshd.yaml"
path_env_var = "ZENITH_SSHD_CONFIG"
env_prefix = "ZENITH_SSHD"

#: The logging configuration
logging: LoggingConfiguration = Field(default_factory = LoggingConfiguration)

Expand Down
82 changes: 52 additions & 30 deletions sshd/zenith/sshd/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

from pydantic import (
BaseModel,
Extra,
TypeAdapter,
Field,
AnyHttpUrl,
AfterValidator,
AnyHttpUrl as PyAnyHttpUrl,
conint,
constr,
validator
field_validator,
ValidationInfo
)

import requests
Expand Down Expand Up @@ -52,14 +54,17 @@ class TunnelExit(RuntimeError):
#: Type for an RFC3986 compliant URL path component
UrlPath = constr(pattern =r"/[a-zA-Z0-9._~!$&'()*+,;=:@%/-]*", min_length = 1)

#: Type for a string that is validated as a URL
AnyHttpUrl = typing.Annotated[
str,
AfterValidator(lambda v: TypeAdapter(PyAnyHttpUrl).validate_python(v))
]

class ClientConfig(BaseModel):

class ClientConfig(BaseModel, extra = "forbid"):
"""
Object for validating the client configuration.
"""
class Config:
extra = Extra.forbid

#: The port for the service (the tunnel port)
allocated_port: int
#: The backend protocol
Expand All @@ -71,20 +76,29 @@ class Config:
#: The URL of the OIDC issuer to use
auth_oidc_issuer: typing.Optional[AnyHttpUrl] = None
#: The OIDC client ID to use
auth_oidc_client_id: typing.Optional[constr(min_length = 1)] = None
auth_oidc_client_id: typing.Optional[constr(min_length = 1)] = Field(
None,
validate_default = True
)
#: The OIDC client secret to use
auth_oidc_client_secret: typing.Optional[constr(min_length = 1)] = None
auth_oidc_client_secret: typing.Optional[constr(min_length = 1)] = Field(
None,
validate_default = True
)
#: The OIDC groups that are allowed access to the the service
#: The user must have at least one of these groups in their groups claim
auth_oidc_allowed_groups: typing.List[AllowedGroup] = Field(default_factory = list)
#: Parameters for the external authentication service (deprecated name)
auth_params: typing.Dict[AuthParamsKey, AuthParamsValue] = Field(default_factory = dict)
#: Parameters for the external authentication service
auth_external_params: typing.Dict[AuthParamsKey, AuthParamsValue] = Field(default_factory = dict)
auth_external_params: typing.Dict[AuthParamsKey, AuthParamsValue] = Field(
default_factory = dict,
validate_default = True
)
#: Base64-encoded TLS certificate to use
tls_cert: typing.Optional[str] = None
#: Base64-encoded TLS private key to use (corresponds to TLS cert)
tls_key: typing.Optional[str] = None
tls_key: typing.Optional[str] = Field(None, validate_default = True)
#: Base64-encoded CA for validating TLS client certificates, if required
tls_client_ca: typing.Optional[str] = None
#: An optional liveness path
Expand All @@ -94,7 +108,8 @@ class Config:
#: The number of liveness checks that can fail before the tunnel is considered unhealthy
liveness_failures: conint(gt = 0) = 3

@validator("allocated_port", always = True)
@field_validator("allocated_port")
@classmethod
def validate_port(cls, v):
"""
Validate the given input as a port.
Expand All @@ -115,15 +130,17 @@ def validate_port(cls, v):
else:
raise ValueError("Given port is not in use")

@validator("auth_external_params", pre = True, always = True)
def validate_auth_external_params(cls, v, values, **kwargs):
@field_validator("auth_external_params", mode = "before")
@classmethod
def validate_auth_external_params(cls, v, info: ValidationInfo):
"""
Makes sure that the old name for external auth params is respected.
"""
return v or values.get("auth_params", {})
return v or info.data.get("auth_params", {})

@validator("auth_oidc_issuer")
def validate_auth_oidc_issuer(cls, v, values, **kwargs):
@field_validator("auth_oidc_issuer")
@classmethod
def validate_auth_oidc_issuer(cls, v):
"""
Validates that the OIDC issuer supports discovery.
"""
Expand All @@ -134,29 +151,32 @@ def validate_auth_oidc_issuer(cls, v, values, **kwargs):
else:
raise ValueError("OIDC issuer does not support discovery")

@validator("auth_oidc_client_id", always = True)
def validate_auth_oidc_client_id(cls, v, values, **kwargs):
@field_validator("auth_oidc_client_id")
@classmethod
def validate_auth_oidc_client_id(cls, v, info: ValidationInfo):
"""
Validates that an OIDC client id is given when an OIDC issuer is present.
"""
skip_auth = values.get("skip_auth", False)
oidc_issuer = values.get("auth_oidc_issuer")
skip_auth = info.data.get("skip_auth", False)
oidc_issuer = info.data.get("auth_oidc_issuer")
if not skip_auth and oidc_issuer and not v:
raise ValueError("required for OIDC authentication")
return v

@validator("auth_oidc_client_secret", always = True)
def validate_auth_oidc_client_secret(cls, v, values, **kwargs):
@field_validator("auth_oidc_client_secret")
@classmethod
def validate_auth_oidc_client_secret(cls, v, info: ValidationInfo):
"""
Validates that a client secret is given when a client ID is present.
"""
skip_auth = values.get("skip_auth", False)
oidc_issuer = values.get("auth_oidc_issuer")
skip_auth = info.data.get("skip_auth", False)
oidc_issuer = info.data.get("auth_oidc_issuer")
if not skip_auth and oidc_issuer and not v:
raise ValueError("required for OIDC authentication")
return v

@validator("tls_cert")
@field_validator("tls_cert")
@classmethod
def validate_tls_cert(cls, v):
"""
Validate the given value decoding it and trying to load it as a
Expand All @@ -165,20 +185,22 @@ def validate_tls_cert(cls, v):
_ = load_pem_x509_certificate(base64.b64decode(v))
return v

@validator("tls_key", always = True)
def validate_tls_key(cls, v, values, **kwargs):
@field_validator("tls_key")
@classmethod
def validate_tls_key(cls, v, info: ValidationInfo):
"""
Validate the given value by decoding it and trying to load it as a
PEM-encoded private key.
"""
tls_cert = values.get("tls_cert")
tls_cert = info.data.get("tls_cert")
if tls_cert and not v:
raise ValueError("required if TLS cert is specified")
if v:
_ = load_pem_private_key(base64.b64decode(v), None)
return v

@validator("tls_client_ca")
@field_validator("tls_client_ca")
@classmethod
def validate_tls_client_ca(cls, v):
"""
Validate the given value by decoding it and trying to load it as a
Expand Down

0 comments on commit 5ec1110

Please sign in to comment.