Skip to content

Commit

Permalink
Enable progressive proxy via flag
Browse files Browse the repository at this point in the history
  • Loading branch information
jwindgassen committed Dec 10, 2024
1 parent 76a98c9 commit b56a548
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 6 deletions.
21 changes: 21 additions & 0 deletions jupyter_server_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,27 @@ def cats_only(response, path):
""",
).tag(config=True)

progressive = Union(
[Bool(), Callable()],
default_value=None,
allow_none=True,
help="""
Makes the proxy progressive, meaning it won't buffer any requests from the server.
Useful for applications streaming their data, where the buffering of requests can lead
to a lagging, e.g. in video streams.
Must be either None (default), a bool, or a function. Setting it to a boolean will enable/disable
progressive requests for all requests. Setting to None, jupyter-server-proxy will only enable progressive
for somespecial types, like videos, images and binary data. A function must be taking the "Accept" header of
the request from the client as input and returning a bool, whether this request should be made progressive.
Note: `progressive` and `rewrite_response` are mutually exclusive on the same request. When rewrite_response
is given and progressive is None, the proxying will never be progressive. If progressive is a function,
rewrite_response will only be called on requests where it returns False. Progressive takes precedence over
rewrite_response when both are given!
""",
).tag(config=True)

update_last_activity = Bool(
True, help="Will cause the proxy to report activity back to jupyter server."
).tag(config=True)
Expand Down
53 changes: 47 additions & 6 deletions jupyter_server_proxy/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from traitlets.traitlets import HasTraits

from .unixsock import UnixResolver
from .utils import call_with_asked_args
from .utils import call_with_asked_args, mime_types_match
from .websocket import WebSocketHandlerMixin, pingable_ws_connect


Expand Down Expand Up @@ -95,6 +95,15 @@ def get(self, *args):
self.redirect(urlunparse(dest))


COMMON_BINARY_MIME_TYPES = [
"image/*",
"audio/*",
"video/*",
"application/*",
"text/event-stream",
]


class ProxyHandler(WebSocketHandlerMixin, JupyterHandler):
"""
A tornado request handler that proxies HTTP and websockets from
Expand All @@ -117,10 +126,41 @@ def __init__(self, *args, **kwargs):
"rewrite_response",
tuple(),
)
self.progressive = kwargs.pop("progressive", None)
self._requested_subprotocols = None
self.update_last_activity = kwargs.pop("update_last_activity", True)
super().__init__(*args, **kwargs)

@property
def progressive(self):
accept_header = self.request.headers.get("Accept")

if self._progressive is not None:
if callable(self._progressive):
return self._progressive(accept_header)
else:
return self._progressive

# Progressive and RewritableResponse are mutually exclusive
if self.rewrite_response:
return False

if accept_header is None:
return False

# If the client can accept multiple types, we will not make the request progressive
if "," in accept_header:
return False

return any(
mime_types_match(pattern, accept_header)
for pattern in COMMON_BINARY_MIME_TYPES
)

@progressive.setter
def progressive(self, value):
self._progressive = value

# Support/use jupyter_server config arguments allow_origin and allow_origin_pat
# to enable cross origin requests propagated by e.g. inverting proxies.

Expand Down Expand Up @@ -376,16 +416,16 @@ async def proxy(self, host, port, proxied_path):
)
else:
client = httpclient.AsyncHTTPClient(force_instance=True)
# check if the request is stream request
accept_header = self.request.headers.get("Accept")
if accept_header == "text/event-stream":

if self.progressive:
return await self._proxy_progressive(host, port, proxied_path, body, client)
else:
return await self._proxy_buffered(host, port, proxied_path, body, client)

async def _proxy_progressive(self, host, port, proxied_path, body, client):
# Proxy in progressive flush mode, whenever chunks are received. Potentially slower but get results quicker for voila
# Set up handlers so we can progressively flush result
self.log.debug(f"Request to '{proxied_path}' will be proxied progressive")

headers_raw = []

Expand Down Expand Up @@ -466,9 +506,10 @@ def streaming_callback(chunk):
self.write(response.body)

async def _proxy_buffered(self, host, port, proxied_path, body, client):
req = self._build_proxy_request(host, port, proxied_path, body)
self.log.debug(f"Request to '{proxied_path}' will be proxied buffered")

self.log.debug(f"Proxying request to {req.url}")
req = self._build_proxy_request(host, port, proxied_path, body)
self.log.debug(f"Proxy request URL: {req.url}")

try:
# Here, "response" is a tornado.httpclient.HTTPResponse object.
Expand Down
17 changes: 17 additions & 0 deletions jupyter_server_proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,20 @@ def call_with_asked_args(callback, args):
)
)
return callback(*asked_arg_values)


def mime_types_match(pattern: str, value: str) -> bool:
"""
Compare a MIME type pattern, possibly with wildcards, and a value
"""
value = value.split(";")[0] # Remove optional details
if pattern == value:
return True

type, subtype = value.split("/")
pattern = pattern.split("/")

if pattern[0] == "*" or (pattern[0] == type and pattern[1] == "*"):
return True

return False
25 changes: 25 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,28 @@ def _test_func(a, b):
return c

assert utils.call_with_asked_args(_test_func, {"a": 5, "b": 4, "c": 8}) == 20


def test_mime_types_match():
# Exact match
assert utils.mime_types_match("text/plain", "text/plain")
assert not utils.mime_types_match("text/plain", "text/html")

# With optional parameters
assert utils.mime_types_match("text/plain", "text/plain;charset=UTF-8")
assert not utils.mime_types_match("text/plain", "text/html;charset=UTF-8")

# With a single widcard
assert utils.mime_types_match("*", "text/plain")
assert utils.mime_types_match("*", "text/plain;charset=UTF-8")

# With both components wildcard
assert utils.mime_types_match("*/*", "text/plain")
assert utils.mime_types_match("*/*", "text/plain;charset=UTF-8")

# With a subtype wildcard
assert utils.mime_types_match("text/*", "text/plain")
assert not utils.mime_types_match("image/*", "text/plain")

assert utils.mime_types_match("text/*", "text/plain;charset=UTF-8")
assert not utils.mime_types_match("image/*", "text/plain;charset=UTF-8")

0 comments on commit b56a548

Please sign in to comment.