Skip to content

Commit

Permalink
Refactor and reuse Proxy generation in standalone
Browse files Browse the repository at this point in the history
  • Loading branch information
jwindgassen committed Dec 7, 2024
1 parent 0228e86 commit 7ed8974
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 150 deletions.
143 changes: 84 additions & 59 deletions jupyter_server_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Traitlets based configuration for jupyter_server_proxy
"""

from __future__ import annotations

import sys
from textwrap import dedent, indent
from warnings import warn
Expand Down Expand Up @@ -263,60 +265,83 @@ def cats_only(response, path):
""",
).tag(config=True)

def get_proxy_base_class(self) -> tuple[type | None, dict]:
"""
Return the appropriate ProxyHandler Subclass and its kwargs
"""
if self.command:
return (
SuperviseAndRawSocketHandler
if self.raw_socket_proxy
else SuperviseAndProxyHandler
), dict(state={})

if not (self.port or isinstance(self.unix_socket, str)):
warn(
f"""Server proxy {self.name} does not have a command, port number or unix_socket path.
At least one of these is required."""
)
return None, dict()

return (
RawSocketHandler if self.raw_socket_proxy else NamedLocalProxyHandler
), dict()

def _make_proxy_handler(sp: ServerProcess):
"""
Create an appropriate handler with given parameters
"""
if sp.command:
cls = (
SuperviseAndRawSocketHandler
if sp.raw_socket_proxy
else SuperviseAndProxyHandler
)
args = dict(state={})
elif not (sp.port or isinstance(sp.unix_socket, str)):
warn(
f"Server proxy {sp.name} does not have a command, port "
f"number or unix_socket path. At least one of these is "
f"required."
)
return
else:
cls = RawSocketHandler if sp.raw_socket_proxy else NamedLocalProxyHandler
args = {}

# FIXME: Set 'name' properly
class _Proxy(cls):
kwargs = args

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.name = sp.name
self.command = sp.command
self.proxy_base = sp.name
self.absolute_url = sp.absolute_url
if sp.command:
self.requested_port = sp.port
self.requested_unix_socket = sp.unix_socket
else:
self.port = sp.port
self.unix_socket = sp.unix_socket
self.mappath = sp.mappath
self.rewrite_response = sp.rewrite_response
self.update_last_activity = sp.update_last_activity

def get_request_headers_override(self):
return self._realize_rendered_template(sp.request_headers_override)

# these two methods are only used in supervise classes, but do no harm otherwise
def get_env(self):
return self._realize_rendered_template(sp.environment)

def get_timeout(self):
return sp.timeout

return _Proxy
def get_proxy_attributes(self) -> dict:
"""
Return the required attributes, which will be set on the proxy handler
"""
attributes = {
"name": self.name,
"command": self.command,
"proxy_base": self.name,
"absolute_url": self.absolute_url,
"mappath": self.mappath,
"rewrite_response": self.rewrite_response,
"update_last_activity": self.update_last_activity,
"request_headers_override": self.request_headers_override,
}

if self.command:
attributes["requested_port"] = self.port
attributes["requested_unix_socket"] = self.unix_socket
attributes["environment"] = self.environment
attributes["timeout"] = self.timeout
else:
attributes["port"] = self.port
attributes["unix_socket"] = self.unix_socket

return attributes

def make_proxy_handler(self) -> tuple[type | None, dict]:
"""
Create an appropriate handler for this ServerProxy Configuration
"""
cls, proxy_kwargs = self.get_proxy_base_class()
if cls is None:
return None, proxy_kwargs

# FIXME: Set 'name' properly
attributes = self.get_proxy_attributes()

class _Proxy(cls):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

for name, value in attributes.items():
setattr(self, name, value)

def get_request_headers_override(self):
return self._realize_rendered_template(self.request_headers_override)

# these two methods are only used in supervise classes, but do no harm otherwise
def get_env(self):
return self._realize_rendered_template(self.environment)

def get_timeout(self):
return self.timeout

return _Proxy, proxy_kwargs


def get_entrypoint_server_processes(serverproxy_config):
Expand All @@ -332,21 +357,21 @@ def get_entrypoint_server_processes(serverproxy_config):
return sps


def make_handlers(base_url, server_processes):
def make_handlers(base_url: str, server_processes: list[ServerProcess]):
"""
Get tornado handlers for registered server_processes
"""
handlers = []
for sp in server_processes:
handler = _make_proxy_handler(sp)
for server in server_processes:
handler, kwargs = server.make_proxy_handler()
if not handler:
continue
handlers.append((ujoin(base_url, sp.name, r"(.*)"), handler, handler.kwargs))
handlers.append((ujoin(base_url, sp.name), AddSlashHandler))
handlers.append((ujoin(base_url, server.name, r"(.*)"), handler, kwargs))
handlers.append((ujoin(base_url, server.name), AddSlashHandler))
return handlers


def make_server_process(name, server_process_config, serverproxy_config):
def make_server_process(name: str, server_process_config: dict, serverproxy_config):
return ServerProcess(name=name, **server_process_config)


Expand Down
34 changes: 20 additions & 14 deletions jupyter_server_proxy/standalone/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from ..config import ServerProcess
from .activity import start_activity_update
from .proxy import make_proxy
from .proxy import make_standalone_proxy


class StandaloneProxyServer(TraitletsApplication, ServerProcess):
Expand Down Expand Up @@ -128,8 +128,8 @@ def _default_command(self):
# ToDo: Find a better way to do this
return self.extra_args

def __init__(self):
super().__init__()
def __init__(self, **kwargs):
super().__init__(**kwargs)

# Flags for CLI
self.flags = {
Expand Down Expand Up @@ -174,7 +174,21 @@ def __init__(self):
"websocket_max_message_size": "StandaloneProxyServer.websocket_max_message_size",
}

def _create_app(self) -> web.Application:
def get_proxy_base_class(self) -> tuple[type | None, dict]:
cls, kwargs = super().get_proxy_base_class()
if cls is None:
return None, kwargs

return make_standalone_proxy(cls, kwargs)

def get_proxy_attributes(self) -> dict:
attributes = super().get_proxy_attributes()
attributes["requested_port"] = self.server_port
attributes["skip_authentication"] = self.skip_authentication

return attributes

def create_app(self) -> web.Application:
self.log.debug(f"Process will use port = {self.port}")
self.log.debug(f"Process will use unix_socket = {self.unix_socket}")
self.log.debug(f"Process environment: {self.environment}")
Expand All @@ -196,15 +210,7 @@ def _create_app(self) -> web.Application:
settings["websocket_max_message_size"] = self.websocket_max_message_size

# Create the proxy class with out arguments
proxy_handler, proxy_kwargs = make_proxy(
self.command,
self.server_port,
self.unix_socket,
self.environment,
self.mappath,
self.timeout,
self.skip_authentication,
)
proxy_handler, proxy_kwargs = self.make_proxy_handler()

base_url = re.escape(self.base_url)
return web.Application(
Expand Down Expand Up @@ -253,7 +259,7 @@ def start(self):
if self.skip_authentication:
self.log.warn("Disabling Authentication with JuypterHub Server!")

app = self._create_app()
app = self.create_app()

ssl_options = self._configure_ssl()
http_server = httpserver.HTTPServer(app, ssl_options=ssl_options, xheaders=True)
Expand Down
140 changes: 63 additions & 77 deletions jupyter_server_proxy/standalone/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,83 +13,69 @@
from ..handlers import SuperviseAndProxyHandler


class StandaloneHubProxyHandler(HubOAuthenticated, SuperviseAndProxyHandler):
"""
Base class for standalone proxies.
Will restrict access to the application by authentication with the JupyterHub API.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.environment = {}
self.timeout = 60
self.skip_authentication = False

@property
def log(self) -> Logger:
return app_log

@property
def hub_users(self):
if "hub_user" in self.settings:
return {self.settings["hub_user"]}
return set()

@property
def hub_groups(self):
if "hub_group" in self.settings:
return {self.settings["hub_group"]}
return set()

def set_default_headers(self):
self.set_header("X-JupyterHub-Version", __jh_version__)

def prepare(self, *args, **kwargs):
pass

def check_origin(self, origin: str = None):
# Skip JupyterHandler.check_origin
return WebSocketHandler.check_origin(self, origin)

def check_xsrf_cookie(self):
# Skip HubAuthenticated.check_xsrf_cookie
pass

def write_error(self, status_code: int, **kwargs):
# ToDo: Return proper error page, like in jupyter-server/JupyterHub
return RequestHandler.write_error(self, status_code, **kwargs)

async def proxy(self, port, path):
if self.skip_authentication:
return await super().proxy(port, path)
else:
return await ensure_async(self.oauth_proxy(port, path))

@web.authenticated
async def oauth_proxy(self, port, path):
return await super().proxy(port, path)

def get_env(self):
return self._render_template(self.environment)

def get_timeout(self):
return self.timeout
def make_standalone_proxy(
base_proxy_class: type, proxy_kwargs: dict
) -> tuple[type | None, dict]:
if not issubclass(base_proxy_class, SuperviseAndProxyHandler):
app_log.error(
"Cannot create a 'StandaloneHubProxyHandler' from a class not inheriting from 'SuperviseAndProxyHandler'"
)
return None, dict()

class StandaloneHubProxyHandler(HubOAuthenticated, base_proxy_class):
"""
Base class for standalone proxies.
Will restrict access to the application by authentication with the JupyterHub API.
"""


def make_proxy(
command, port, unix_socket, environment, mappath, timeout, skip_authentication
) -> tuple[type, dict]:
class Proxy(StandaloneHubProxyHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.name = f"{command[0]!r} Process"
self.proxy_base = command[0]
self.requested_port = port
self.requested_unix_socket = unix_socket
self.mappath = mappath
self.command = command
self.environment = environment
self.timeout = timeout
self.skip_authentication = skip_authentication

return Proxy, dict(state={})
self.environment = {}
self.timeout = 60
self.skip_authentication = False

@property
def log(self) -> Logger:
return app_log

@property
def hub_users(self):
if "hub_user" in self.settings:
return {self.settings["hub_user"]}
return set()

@property
def hub_groups(self):
if "hub_group" in self.settings:
return {self.settings["hub_group"]}
return set()

def set_default_headers(self):
self.set_header("X-JupyterHub-Version", __jh_version__)

def prepare(self, *args, **kwargs):
pass

def check_origin(self, origin: str = None):
# Skip JupyterHandler.check_origin
return WebSocketHandler.check_origin(self, origin)

def check_xsrf_cookie(self):
# Skip HubAuthenticated.check_xsrf_cookie
pass

def write_error(self, status_code: int, **kwargs):
# ToDo: Return proper error page, like in jupyter-server/JupyterHub
return RequestHandler.write_error(self, status_code, **kwargs)

async def proxy(self, port, path):
if self.skip_authentication:
return await super().proxy(port, path)
else:
return await ensure_async(self.oauth_proxy(port, path))

@web.authenticated
async def oauth_proxy(self, port, path):
return await super().proxy(port, path)

return StandaloneHubProxyHandler, proxy_kwargs

0 comments on commit 7ed8974

Please sign in to comment.