diff --git a/jupyter_server_proxy/config.py b/jupyter_server_proxy/config.py index 4b21cf70..a215bc9a 100644 --- a/jupyter_server_proxy/config.py +++ b/jupyter_server_proxy/config.py @@ -249,6 +249,27 @@ def cats_only(response, path): """, ).tag(config=True) + progressive = Union( + [Bool(), Callable()], + default_value=None, + allow_none=True, + help=""" + Makes the proxy progressive, meaning it won't buffer any requests from the server. + Useful for applications streaming their data, where the buffering of requests can lead + to a lagging, e.g. in video streams. + + Must be either None (default), a bool, or a function. Setting it to a boolean will enable/disable + progressive requests for all requests. Setting to None, jupyter-server-proxy will only enable progressive + for somespecial types, like videos, images and binary data. A function must be taking the "Accept" header of + the request from the client as input and returning a bool, whether this request should be made progressive. + + Note: `progressive` and `rewrite_response` are mutually exclusive on the same request. When rewrite_response + is given and progressive is None, the proxying will never be progressive. If progressive is a function, + rewrite_response will only be called on requests where it returns False. Progressive takes precedence over + rewrite_response when both are given! + """, + ).tag(config=True) + update_last_activity = Bool( True, help="Will cause the proxy to report activity back to jupyter server." ).tag(config=True) @@ -304,6 +325,7 @@ def __init__(self, *args, **kwargs): self.unix_socket = sp.unix_socket self.mappath = sp.mappath self.rewrite_response = sp.rewrite_response + self.progressive = sp.progressive self.update_last_activity = sp.update_last_activity def get_request_headers_override(self): diff --git a/jupyter_server_proxy/handlers.py b/jupyter_server_proxy/handlers.py index a5987fc8..5707b02b 100644 --- a/jupyter_server_proxy/handlers.py +++ b/jupyter_server_proxy/handlers.py @@ -22,7 +22,7 @@ from traitlets.traitlets import HasTraits from .unixsock import UnixResolver -from .utils import call_with_asked_args +from .utils import call_with_asked_args, mime_types_match from .websocket import WebSocketHandlerMixin, pingable_ws_connect @@ -95,6 +95,15 @@ def get(self, *args): self.redirect(urlunparse(dest)) +COMMON_BINARY_MIME_TYPES = [ + "image/*", + "audio/*", + "video/*", + "application/*", + "text/event-stream", +] + + class ProxyHandler(WebSocketHandlerMixin, JupyterHandler): """ A tornado request handler that proxies HTTP and websockets from @@ -117,10 +126,41 @@ def __init__(self, *args, **kwargs): "rewrite_response", tuple(), ) + self.progressive = kwargs.pop("progressive", None) self._requested_subprotocols = None self.update_last_activity = kwargs.pop("update_last_activity", True) super().__init__(*args, **kwargs) + @property + def progressive(self): + accept_header = self.request.headers.get("Accept") + + if self._progressive is not None: + if callable(self._progressive): + return self._progressive(accept_header) + else: + return self._progressive + + # Progressive and RewritableResponse are mutually exclusive + if self.rewrite_response: + return False + + if accept_header is None: + return False + + # If the client can accept multiple types, we will not make the request progressive + if "," in accept_header: + return False + + return any( + mime_types_match(pattern, accept_header) + for pattern in COMMON_BINARY_MIME_TYPES + ) + + @progressive.setter + def progressive(self, value): + self._progressive = value + # Support/use jupyter_server config arguments allow_origin and allow_origin_pat # to enable cross origin requests propagated by e.g. inverting proxies. @@ -376,9 +416,8 @@ async def proxy(self, host, port, proxied_path): ) else: client = httpclient.AsyncHTTPClient(force_instance=True) - # check if the request is stream request - accept_header = self.request.headers.get("Accept") - if accept_header == "text/event-stream": + + if self.progressive: return await self._proxy_progressive(host, port, proxied_path, body, client) else: return await self._proxy_buffered(host, port, proxied_path, body, client) @@ -386,6 +425,7 @@ async def proxy(self, host, port, proxied_path): async def _proxy_progressive(self, host, port, proxied_path, body, client): # Proxy in progressive flush mode, whenever chunks are received. Potentially slower but get results quicker for voila # Set up handlers so we can progressively flush result + self.log.debug(f"Request to '{proxied_path}' will be proxied progressive") headers_raw = [] @@ -466,9 +506,10 @@ def streaming_callback(chunk): self.write(response.body) async def _proxy_buffered(self, host, port, proxied_path, body, client): - req = self._build_proxy_request(host, port, proxied_path, body) + self.log.debug(f"Request to '{proxied_path}' will be proxied buffered") - self.log.debug(f"Proxying request to {req.url}") + req = self._build_proxy_request(host, port, proxied_path, body) + self.log.debug(f"Proxy request URL: {req.url}") try: # Here, "response" is a tornado.httpclient.HTTPResponse object. diff --git a/jupyter_server_proxy/utils.py b/jupyter_server_proxy/utils.py index ce4d6e28..4448d467 100644 --- a/jupyter_server_proxy/utils.py +++ b/jupyter_server_proxy/utils.py @@ -28,3 +28,20 @@ def call_with_asked_args(callback, args): ) ) return callback(*asked_arg_values) + + +def mime_types_match(pattern: str, value: str) -> bool: + """ + Compare a MIME type pattern, possibly with wildcards, and a value + """ + value = value.split(";")[0] # Remove optional details + if pattern == value: + return True + + type, subtype = value.split("/") + pattern = pattern.split("/") + + if pattern[0] == "*" or (pattern[0] == type and pattern[1] == "*"): + return True + + return False diff --git a/tests/test_utils.py b/tests/test_utils.py index 2ae1485b..ea75eef8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,3 +7,28 @@ def _test_func(a, b): return c assert utils.call_with_asked_args(_test_func, {"a": 5, "b": 4, "c": 8}) == 20 + + +def test_mime_types_match(): + # Exact match + assert utils.mime_types_match("text/plain", "text/plain") + assert not utils.mime_types_match("text/plain", "text/html") + + # With optional parameters + assert utils.mime_types_match("text/plain", "text/plain;charset=UTF-8") + assert not utils.mime_types_match("text/plain", "text/html;charset=UTF-8") + + # With a single widcard + assert utils.mime_types_match("*", "text/plain") + assert utils.mime_types_match("*", "text/plain;charset=UTF-8") + + # With both components wildcard + assert utils.mime_types_match("*/*", "text/plain") + assert utils.mime_types_match("*/*", "text/plain;charset=UTF-8") + + # With a subtype wildcard + assert utils.mime_types_match("text/*", "text/plain") + assert not utils.mime_types_match("image/*", "text/plain") + + assert utils.mime_types_match("text/*", "text/plain;charset=UTF-8") + assert not utils.mime_types_match("image/*", "text/plain;charset=UTF-8")