diff --git a/jupyter_server_proxy/config.py b/jupyter_server_proxy/config.py index 4b21cf70..b816938b 100644 --- a/jupyter_server_proxy/config.py +++ b/jupyter_server_proxy/config.py @@ -2,6 +2,8 @@ Traitlets based configuration for jupyter_server_proxy """ +from __future__ import annotations + import sys from textwrap import dedent, indent from warnings import warn @@ -263,60 +265,83 @@ def cats_only(response, path): """, ).tag(config=True) + def get_proxy_base_class(self) -> tuple[type | None, dict]: + """ + Return the appropriate ProxyHandler Subclass and its kwargs + """ + if self.command: + return ( + SuperviseAndRawSocketHandler + if self.raw_socket_proxy + else SuperviseAndProxyHandler + ), dict(state={}) + + if not (self.port or isinstance(self.unix_socket, str)): + warn( + f"""Server proxy {self.name} does not have a command, port number or unix_socket path. + At least one of these is required.""" + ) + return None, dict() + + return ( + RawSocketHandler if self.raw_socket_proxy else NamedLocalProxyHandler + ), dict() -def _make_proxy_handler(sp: ServerProcess): - """ - Create an appropriate handler with given parameters - """ - if sp.command: - cls = ( - SuperviseAndRawSocketHandler - if sp.raw_socket_proxy - else SuperviseAndProxyHandler - ) - args = dict(state={}) - elif not (sp.port or isinstance(sp.unix_socket, str)): - warn( - f"Server proxy {sp.name} does not have a command, port " - f"number or unix_socket path. At least one of these is " - f"required." - ) - return - else: - cls = RawSocketHandler if sp.raw_socket_proxy else NamedLocalProxyHandler - args = {} - - # FIXME: Set 'name' properly - class _Proxy(cls): - kwargs = args - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.name = sp.name - self.command = sp.command - self.proxy_base = sp.name - self.absolute_url = sp.absolute_url - if sp.command: - self.requested_port = sp.port - self.requested_unix_socket = sp.unix_socket - else: - self.port = sp.port - self.unix_socket = sp.unix_socket - self.mappath = sp.mappath - self.rewrite_response = sp.rewrite_response - self.update_last_activity = sp.update_last_activity - - def get_request_headers_override(self): - return self._realize_rendered_template(sp.request_headers_override) - - # these two methods are only used in supervise classes, but do no harm otherwise - def get_env(self): - return self._realize_rendered_template(sp.environment) - - def get_timeout(self): - return sp.timeout - - return _Proxy + def get_proxy_attributes(self) -> dict: + """ + Return the required attributes, which will be set on the proxy handler + """ + attributes = { + "name": self.name, + "command": self.command, + "proxy_base": self.name, + "absolute_url": self.absolute_url, + "mappath": self.mappath, + "rewrite_response": self.rewrite_response, + "update_last_activity": self.update_last_activity, + "request_headers_override": self.request_headers_override, + } + + if self.command: + attributes["requested_port"] = self.port + attributes["requested_unix_socket"] = self.unix_socket + attributes["environment"] = self.environment + attributes["timeout"] = self.timeout + else: + attributes["port"] = self.port + attributes["unix_socket"] = self.unix_socket + + return attributes + + def make_proxy_handler(self) -> tuple[type | None, dict]: + """ + Create an appropriate handler for this ServerProxy Configuration + """ + cls, proxy_kwargs = self.get_proxy_base_class() + if cls is None: + return None, proxy_kwargs + + # FIXME: Set 'name' properly + attributes = self.get_proxy_attributes() + + class _Proxy(cls): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for name, value in attributes.items(): + setattr(self, name, value) + + def get_request_headers_override(self): + return self._realize_rendered_template(self.request_headers_override) + + # these two methods are only used in supervise classes, but do no harm otherwise + def get_env(self): + return self._realize_rendered_template(self.environment) + + def get_timeout(self): + return self.timeout + + return _Proxy, proxy_kwargs def get_entrypoint_server_processes(serverproxy_config): @@ -332,21 +357,21 @@ def get_entrypoint_server_processes(serverproxy_config): return sps -def make_handlers(base_url, server_processes): +def make_handlers(base_url: str, server_processes: list[ServerProcess]): """ Get tornado handlers for registered server_processes """ handlers = [] - for sp in server_processes: - handler = _make_proxy_handler(sp) + for server in server_processes: + handler, kwargs = server.make_proxy_handler() if not handler: continue - handlers.append((ujoin(base_url, sp.name, r"(.*)"), handler, handler.kwargs)) - handlers.append((ujoin(base_url, sp.name), AddSlashHandler)) + handlers.append((ujoin(base_url, server.name, r"(.*)"), handler, kwargs)) + handlers.append((ujoin(base_url, server.name), AddSlashHandler)) return handlers -def make_server_process(name, server_process_config, serverproxy_config): +def make_server_process(name: str, server_process_config: dict, serverproxy_config): return ServerProcess(name=name, **server_process_config) diff --git a/jupyter_server_proxy/standalone/app.py b/jupyter_server_proxy/standalone/app.py index bae67252..8239ad94 100644 --- a/jupyter_server_proxy/standalone/app.py +++ b/jupyter_server_proxy/standalone/app.py @@ -16,7 +16,7 @@ from ..config import ServerProcess from .activity import start_activity_update -from .proxy import make_proxy +from .proxy import make_standalone_proxy class StandaloneProxyServer(TraitletsApplication, ServerProcess): @@ -128,8 +128,8 @@ def _default_command(self): # ToDo: Find a better way to do this return self.extra_args - def __init__(self): - super().__init__() + def __init__(self, **kwargs): + super().__init__(**kwargs) # Flags for CLI self.flags = { @@ -174,7 +174,21 @@ def __init__(self): "websocket_max_message_size": "StandaloneProxyServer.websocket_max_message_size", } - def _create_app(self) -> web.Application: + def get_proxy_base_class(self) -> tuple[type | None, dict]: + cls, kwargs = super().get_proxy_base_class() + if cls is None: + return None, kwargs + + return make_standalone_proxy(cls, kwargs) + + def get_proxy_attributes(self) -> dict: + attributes = super().get_proxy_attributes() + attributes["requested_port"] = self.server_port + attributes["skip_authentication"] = self.skip_authentication + + return attributes + + def create_app(self) -> web.Application: self.log.debug(f"Process will use port = {self.port}") self.log.debug(f"Process will use unix_socket = {self.unix_socket}") self.log.debug(f"Process environment: {self.environment}") @@ -196,15 +210,7 @@ def _create_app(self) -> web.Application: settings["websocket_max_message_size"] = self.websocket_max_message_size # Create the proxy class with out arguments - proxy_handler, proxy_kwargs = make_proxy( - self.command, - self.server_port, - self.unix_socket, - self.environment, - self.mappath, - self.timeout, - self.skip_authentication, - ) + proxy_handler, proxy_kwargs = self.make_proxy_handler() base_url = re.escape(self.base_url) return web.Application( @@ -253,7 +259,7 @@ def start(self): if self.skip_authentication: self.log.warn("Disabling Authentication with JuypterHub Server!") - app = self._create_app() + app = self.create_app() ssl_options = self._configure_ssl() http_server = httpserver.HTTPServer(app, ssl_options=ssl_options, xheaders=True) diff --git a/jupyter_server_proxy/standalone/proxy.py b/jupyter_server_proxy/standalone/proxy.py index 645d3941..35c30991 100644 --- a/jupyter_server_proxy/standalone/proxy.py +++ b/jupyter_server_proxy/standalone/proxy.py @@ -13,83 +13,69 @@ from ..handlers import SuperviseAndProxyHandler -class StandaloneHubProxyHandler(HubOAuthenticated, SuperviseAndProxyHandler): - """ - Base class for standalone proxies. - Will restrict access to the application by authentication with the JupyterHub API. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.environment = {} - self.timeout = 60 - self.skip_authentication = False - - @property - def log(self) -> Logger: - return app_log - - @property - def hub_users(self): - if "hub_user" in self.settings: - return {self.settings["hub_user"]} - return set() - - @property - def hub_groups(self): - if "hub_group" in self.settings: - return {self.settings["hub_group"]} - return set() - - def set_default_headers(self): - self.set_header("X-JupyterHub-Version", __jh_version__) - - def prepare(self, *args, **kwargs): - pass - - def check_origin(self, origin: str = None): - # Skip JupyterHandler.check_origin - return WebSocketHandler.check_origin(self, origin) - - def check_xsrf_cookie(self): - # Skip HubAuthenticated.check_xsrf_cookie - pass - - def write_error(self, status_code: int, **kwargs): - # ToDo: Return proper error page, like in jupyter-server/JupyterHub - return RequestHandler.write_error(self, status_code, **kwargs) - - async def proxy(self, port, path): - if self.skip_authentication: - return await super().proxy(port, path) - else: - return await ensure_async(self.oauth_proxy(port, path)) - - @web.authenticated - async def oauth_proxy(self, port, path): - return await super().proxy(port, path) - - def get_env(self): - return self._render_template(self.environment) - - def get_timeout(self): - return self.timeout +def make_standalone_proxy( + base_proxy_class: type, proxy_kwargs: dict +) -> tuple[type | None, dict]: + if not issubclass(base_proxy_class, SuperviseAndProxyHandler): + app_log.error( + "Cannot create a 'StandaloneHubProxyHandler' from a class not inheriting from 'SuperviseAndProxyHandler'" + ) + return None, dict() + + class StandaloneHubProxyHandler(HubOAuthenticated, base_proxy_class): + """ + Base class for standalone proxies. + Will restrict access to the application by authentication with the JupyterHub API. + """ - -def make_proxy( - command, port, unix_socket, environment, mappath, timeout, skip_authentication -) -> tuple[type, dict]: - class Proxy(StandaloneHubProxyHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.name = f"{command[0]!r} Process" - self.proxy_base = command[0] - self.requested_port = port - self.requested_unix_socket = unix_socket - self.mappath = mappath - self.command = command - self.environment = environment - self.timeout = timeout - self.skip_authentication = skip_authentication - - return Proxy, dict(state={}) + self.environment = {} + self.timeout = 60 + self.skip_authentication = False + + @property + def log(self) -> Logger: + return app_log + + @property + def hub_users(self): + if "hub_user" in self.settings: + return {self.settings["hub_user"]} + return set() + + @property + def hub_groups(self): + if "hub_group" in self.settings: + return {self.settings["hub_group"]} + return set() + + def set_default_headers(self): + self.set_header("X-JupyterHub-Version", __jh_version__) + + def prepare(self, *args, **kwargs): + pass + + def check_origin(self, origin: str = None): + # Skip JupyterHandler.check_origin + return WebSocketHandler.check_origin(self, origin) + + def check_xsrf_cookie(self): + # Skip HubAuthenticated.check_xsrf_cookie + pass + + def write_error(self, status_code: int, **kwargs): + # ToDo: Return proper error page, like in jupyter-server/JupyterHub + return RequestHandler.write_error(self, status_code, **kwargs) + + async def proxy(self, port, path): + if self.skip_authentication: + return await super().proxy(port, path) + else: + return await ensure_async(self.oauth_proxy(port, path)) + + @web.authenticated + async def oauth_proxy(self, port, path): + return await super().proxy(port, path) + + return StandaloneHubProxyHandler, proxy_kwargs