diff --git a/frontend/.eslintrc.json b/.eslintrc.json similarity index 100% rename from frontend/.eslintrc.json rename to .eslintrc.json diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 64d0cac7f..ec476d315 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,6 +28,6 @@ jobs: with: node-version: 14 - name: install dependencies - run: cd frontend && npm install + run: npm install - name: build - run: cd frontend && npm run make + run: npm run make diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c3542e637..b796db1a2 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,9 +10,9 @@ on: - 'v*' # Triggers on creation of a release - release: - types: - - created + # release: + # types: + # - created # Allows you to run this workflow manually from the Actions tab workflow_dispatch: @@ -32,8 +32,8 @@ jobs: with: node-version: 14 - name: install dependencies - run: cd frontend && npm install + run: npm install - name: publish env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: cd frontend && npm run publish + run: npm run publish diff --git a/.gitignore b/.gitignore index 8e3a10669..dc7b30e10 100644 --- a/.gitignore +++ b/.gitignore @@ -86,4 +86,4 @@ typings/ .webpack/ # Electron-Forge -out/ +out/ \ No newline at end of file diff --git a/README.md b/README.md index 05a57c2ee..882407f96 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,18 @@ # chaiNNer +![GitHub Latest Release](https://img.shields.io/github/v/release/joeyballentine/chaiNNer) ![GitHub Total Downloads](https://img.shields.io/github/downloads/joeyballentine/chaiNNer/total) ![License](https://img.shields.io/github/license/joeyballentine/chaiNNer) ![Discord](https://img.shields.io/discord/930865462852591648?label=Discord&logo=Discord&logoColor=white) + +

+ +

+ A flowchart/node-based image processing GUI aimed at making chaining image processing tasks (especially those done by neural networks) easy, intuitive, and customizable. No existing GUI gives you the level of customization of your image processing workflow that chaiNNer does. Not only do you have full control over your processing pipeline, you can do incredibly complex tasks just by connecting a few nodes together. -ChaiNNer is also cross-platform, meaning you can run it on Windows, MacOS, and Linux. +chaiNNer is also cross-platform, meaning you can run it on Windows, MacOS, and Linux. + +For help, suggestions, or just to hang out, you can join the [chaiNNer Discord server](https://discord.gg/pzvAKPKyHM) ## Installation @@ -15,3 +23,59 @@ The only dependency you need to have installed already is Python 3.7-3.9. All ot ## GPU Support Currently, chaiNNer's neural network support (via PyTorch) only supports Nvidia GPUs. There is currently no plan to support pre-compiled `.exe`s for NCNN processing. PyTorch also does not support GPU processing on MacOS. + +## Planned Features + +**Embedded Python** + +> I'm currently figuring out the best way to add this in. There are standalone python binaries for every platform that I plan on supporting. I am still just trying to figure out whether it should be downloaded and installed to on first run, or if all that should be done in the build action and bundled with the installer. + +**NCNN** + +> Once the python api for NCNN supports GPU, I will be adding the ability to convert from PyTorch to TorchScript to ONNX to NCNN. It'll be a bit convoluted but it'll allow AMD support I think + +**PIL & Wand** + +> I do plan on adding support for PIL and Wand for image processing. + +**Batch Processing** + +> I am waiting to add this until the node-graph library I use supports nested flows (which is coming relatively soon). The way I will be doing this will be similar to how for loops work, in that you will have iterator panels that will iterate over some sort of loaded array of items (i.e. folder input or frames of a video) + +**Undo History, Copy & Paste** + +> For now I am having difficulty adding these in. I plan on revisiting this later after I am forced to refactor my implementation due to the node-graph library I use releasing breaking changes soon. + +**Drag and Drop Images** + +> This is planned, ideally for both dragging into the file selection box and onto the window to make a new image read node + +**Presets** + +> Some things that are common tasks should have presets you can drag in, that are basically just multiple nodes packaged together + +**More SR Networks, More Image Processing Libraries** + +> What the title says + +**Live Updating** + +> This is something that will be a bit complex to do, but basically I'd like to have a mode where it constantly is running and refreshing on any node change, and displays previews of each node + +## FAQ + +**What does the name mean?** + +> chaiNNer is a play on the fact that you can "chain" different tasks together, with the NN in the name being a common abbreviation for Neural Networks. This is following the brilliant naming scheme of victorca25's machine learning tools (traiNNer, iNNfer, augmeNNt) which he granted me permission to use for this as well. + +**Why not just use Cupscale/IEU/CLI?** + +> All of these tools are viable options, but as anyone who has used them before knows, they can be limited in what it can do, as many features like chaining or interpolating models are hardcoded in and provide little flexibility. Certain features that would be useful, like being able to use a separate model on the alpha layer of an image, just do not exist in Cupscale, for example. Inversely, you can pretty much do whatever you want with chaiNNer provided there are nodes implemented. Whatever weird feature you want implemented, you can implement yourself by connecting nodes however you want. Cupscale also does not have other image processing abilities like chaiNNer does, such as adjusting contrast. + +**Wouldn't this make it more difficult to do things?** + +> In a way, yes. Similarly to how programming your own script to do this stuff is more difficult, chaiNNer will also be a bit more difficult than simply dragging and dropping and image and messing with some sliders and pressing an upscale button. However, this gives you a lot more flexibility in what you can do. The added complexity is really just connecting some dots together to do what you want. That doesn't sound that bad, right? + +**What platforms are supported?** + +> Windows, Linux, and MacOS are all supported by chaiNNer. However, MacOS currently lacks GPU support for pytorch, so I highly recommend using another OS if you need that functionality. diff --git a/backend/build.sh b/backend/build.sh deleted file mode 100644 index 20a745208..000000000 --- a/backend/build.sh +++ /dev/null @@ -1,12 +0,0 @@ -python -m nuitka --mingw64 --standalone --lto no run.py \ ---plugin-enable=torch --plugin-enable=pylint-warnings --plugin-enable=numpy --enable-plugin=anti-bloat \ ---noinclude-pytest-mode=nofollow --noinclude-setuptools-mode=nofollow \ ---noinclude-matplotlib --noinclude-scipy \ ---nofollow-import-to=PyQt5 \ ---nofollow-import-to=matplotlib \ ---nofollow-import-to=scipy \ ---nofollow-import-to=tkinter \ ---nofollow-import-to=torchvision \ ---nofollow-import-to=torchaudio \ ---nofollow-import-to=IPython \ ---nofollow-import-to=jedi \ \ No newline at end of file diff --git a/backend/nodes/node_factory.py b/backend/nodes/node_factory.py index 9ef0ce7dc..18a34cd47 100644 --- a/backend/nodes/node_factory.py +++ b/backend/nodes/node_factory.py @@ -1,38 +1,41 @@ -from typing import Callable, Dict - -from .node_base import NodeBase - -from sanic.log import logger - - -# Implementation based on https://medium.com/@geoffreykoh/implementing-the-factory-pattern-via-dynamic-registry-and-python-decorators-479fc1537bbe -class NodeFactory: - """The factory class for creating nodes""" - - registry = {} - """ Internal registry for available nodes """ - - @classmethod - def create_node(cls, category: str, name: str) -> NodeBase: - """Factory command to create the node""" - - node_class = cls.registry[category][name] - node = node_class() - logger.info(f"Created {category}, {name} node") - return node - - @classmethod - def register(cls, category: str, name: str) -> Callable: - def inner_wrapper(wrapped_class: NodeBase) -> Callable: - if category not in cls.registry: - cls.registry[category] = {} - if name in cls.registry[category]: - logger.warning(f"Node {name} already exists. Will replace it") - cls.registry[category][name] = wrapped_class - return wrapped_class - - return inner_wrapper - - @classmethod - def get_registry(cls) -> Dict: - return cls.registry +import sys +from typing import Callable, Dict + +sys.path.append("..") + +from sanic_server.sanic.log import logger + +from .node_base import NodeBase + + +# Implementation based on https://medium.com/@geoffreykoh/implementing-the-factory-pattern-via-dynamic-registry-and-python-decorators-479fc1537bbe +class NodeFactory: + """The factory class for creating nodes""" + + registry = {} + """ Internal registry for available nodes """ + + @classmethod + def create_node(cls, category: str, name: str) -> NodeBase: + """Factory command to create the node""" + + node_class = cls.registry[category][name] + node = node_class() + logger.info(f"Created {category}, {name} node") + return node + + @classmethod + def register(cls, category: str, name: str) -> Callable: + def inner_wrapper(wrapped_class: NodeBase) -> Callable: + if category not in cls.registry: + cls.registry[category] = {} + if name in cls.registry[category]: + logger.warning(f"Node {name} already exists. Will replace it") + cls.registry[category][name] = wrapped_class + return wrapped_class + + return inner_wrapper + + @classmethod + def get_registry(cls) -> Dict: + return cls.registry diff --git a/backend/nodes/numpy_nodes.py b/backend/nodes/numpy_nodes.py index 451639db4..6cefe50be 100644 --- a/backend/nodes/numpy_nodes.py +++ b/backend/nodes/numpy_nodes.py @@ -2,11 +2,15 @@ Nodes that provide functionality for numpy array manipulation """ +import sys from typing import List import cv2 import numpy as np -from sanic.log import logger + +sys.path.append("..") + +from sanic_server.sanic.log import logger from .node_base import NodeBase from .node_factory import NodeFactory diff --git a/backend/nodes/opencv_nodes.py b/backend/nodes/opencv_nodes.py index ca226526a..8152a6aab 100644 --- a/backend/nodes/opencv_nodes.py +++ b/backend/nodes/opencv_nodes.py @@ -3,10 +3,14 @@ """ import os +import sys import cv2 import numpy as np -from sanic.log import logger + +sys.path.append("..") + +from sanic_server.sanic.log import logger from .node_base import NodeBase from .node_factory import NodeFactory diff --git a/backend/nodes/properties/inputs/file_inputs.py b/backend/nodes/properties/inputs/file_inputs.py index 2674a010e..383ada38b 100644 --- a/backend/nodes/properties/inputs/file_inputs.py +++ b/backend/nodes/properties/inputs/file_inputs.py @@ -1,62 +1,67 @@ -from typing import Dict, List - -from .generic_inputs import DropDownInput - - -def FileInput( - input_type: str, label: str, accepts: List[str], filetypes: List[str] -) -> Dict: - """ Input for submitting a local file """ - return { - "type": f"file::{input_type}", - "label": label, - "accepts": None, - "filetypes": filetypes, - } - - -def ImageFileInput() -> Dict: - """ Input for submitting a local image file """ - return FileInput( - "image", "Image File", None, ["png", "jpg", "jpeg", "gif", "tiff", "webp"] - ) - - -def PthFileInput() -> Dict: - """ Input for submitting a local .pth file """ - return FileInput("pth", "Pretrained Model", None, ["pth"]) - - -def DirectoryInput() -> Dict: - """ Input for submitting a local directory """ - return FileInput("directory", "Directory", None, ["directory"]) - - -def ImageExtensionDropdown() -> Dict: - """ Input for selecting file type from dropdown """ - return DropDownInput( - "image-extensions", - "Image Extension", - [ - { - "option": "PNG", - "value": "png", - }, - { - "option": "JPG", - "value": "jpg", - }, - { - "option": "GIF", - "value": "gif", - }, - { - "option": "TIFF", - "value": "tiff", - }, - { - "option": "WEBP", - "value": "webp", - }, - ], - ) +from typing import Dict, List + +from .generic_inputs import DropDownInput + + +def FileInput( + input_type: str, label: str, accepts: List[str], filetypes: List[str] +) -> Dict: + """ Input for submitting a local file """ + return { + "type": f"file::{input_type}", + "label": label, + "accepts": None, + "filetypes": filetypes, + } + + +def ImageFileInput() -> Dict: + """ Input for submitting a local image file """ + return FileInput( + "image", "Image File", None, ["png", "jpg", "jpeg", "gif", "tiff", "webp"] + ) + + +def PthFileInput() -> Dict: + """ Input for submitting a local .pth file """ + return FileInput("pth", "Pretrained Model", None, ["pth"]) + + +def TorchFileInput() -> Dict: + """ Input for submitting a local .pth or .pt file """ + return FileInput("pth", "Pretrained Model", None, ["pth", "pt"]) + + +def DirectoryInput() -> Dict: + """ Input for submitting a local directory """ + return FileInput("directory", "Directory", None, ["directory"]) + + +def ImageExtensionDropdown() -> Dict: + """ Input for selecting file type from dropdown """ + return DropDownInput( + "image-extensions", + "Image Extension", + [ + { + "option": "PNG", + "value": "png", + }, + { + "option": "JPG", + "value": "jpg", + }, + { + "option": "GIF", + "value": "gif", + }, + { + "option": "TIFF", + "value": "tiff", + }, + { + "option": "WEBP", + "value": "webp", + }, + ], + ) diff --git a/backend/nodes/properties/inputs/pytorch_inputs.py b/backend/nodes/properties/inputs/pytorch_inputs.py index 92b10b9fe..bf324582f 100644 --- a/backend/nodes/properties/inputs/pytorch_inputs.py +++ b/backend/nodes/properties/inputs/pytorch_inputs.py @@ -15,3 +15,11 @@ def ModelInput() -> Any: "type": "pytorch::model", "label": "Loaded Model", } + + +def TorchScriptInput() -> Any: + """ Input a JIT traced model """ + return { + "type": "pytorch::torchscript", + "label": "Traced Model", + } diff --git a/backend/nodes/properties/outputs/pytorch_outputs.py b/backend/nodes/properties/outputs/pytorch_outputs.py index 37e8c5295..a0e357f20 100644 --- a/backend/nodes/properties/outputs/pytorch_outputs.py +++ b/backend/nodes/properties/outputs/pytorch_outputs.py @@ -1,4 +1,4 @@ -from typing import OrderedDict, Any +from typing import Any, OrderedDict def StateDictOutput() -> OrderedDict: @@ -15,3 +15,11 @@ def ModelOutput() -> Any: "type": "pytorch::model", "label": "Loaded Model", } + + +def TorchScriptOutput() -> Any: + """ Output a JIT traced model """ + return { + "type": "pytorch::torchscript", + "label": "Traced Model", + } diff --git a/backend/nodes/pytorch_nodes.py b/backend/nodes/pytorch_nodes.py index 7cf4c6438..1bdb179fe 100644 --- a/backend/nodes/pytorch_nodes.py +++ b/backend/nodes/pytorch_nodes.py @@ -1,275 +1,322 @@ -""" -Nodes that provide functionality for pytorch inference -""" - - -import os -from typing import Any, OrderedDict - -import numpy as np -import torch -from sanic.log import logger - -from .node_base import NodeBase -from .node_factory import NodeFactory -from .properties.inputs.file_inputs import DirectoryInput, PthFileInput -from .properties.inputs.generic_inputs import SliderInput, TextInput -from .properties.inputs.numpy_inputs import ImageInput -from .properties.inputs.pytorch_inputs import ModelInput, StateDictInput -from .properties.outputs.numpy_outputs import ImageOutput -from .properties.outputs.pytorch_outputs import ModelOutput, StateDictOutput -from .utils.architectures.RRDB import RRDBNet -from .utils.utils import auto_split_process, np2tensor, tensor2np - - -def check_env(): - os.environ["device"] = ( - "cuda" if torch.cuda.is_available() and os.environ["device"] != "cpu" else "cpu" - ) - - if bool(os.environ["isFp16"]): - if os.environ["device"] == "cpu": - torch.set_default_tensor_type(torch.HalfTensor) - elif os.environ["device"] == "cuda": - torch.set_default_tensor_type(torch.cuda.HalfTensor) - else: - logger.warn("Something isn't set right with the device env var") - - -@NodeFactory.register("PyTorch", "Model::Read") -class LoadStateDictNode(NodeBase): - """Load Model node""" - - def __init__(self): - """Constructor""" - self.description = "Load PyTorch state dict file (.pth) from path" - self.inputs = [PthFileInput()] - self.outputs = [StateDictOutput()] - - def run(self, path: str) -> OrderedDict: - """Read a pth file from the specified path and return it as a state dict""" - - logger.info(f"Reading state dict from path: {path}") - state_dict = torch.load(path) - - return state_dict - - -@NodeFactory.register("PyTorch", "ESRGAN::Load") -class LoadEsrganModelNode(NodeBase): - """Load ESRGAN Model node""" - - def __init__(self): - """Constructor""" - self.description = "Load PyTorch state dict into the ESRGAN model architecture" - self.inputs = [StateDictInput()] - self.outputs = [ModelOutput()] - - def run(self, state_dict: OrderedDict) -> Any: - """Loads the state dict to an ESRGAN model after finding arch config""" - - logger.info(f"Loading state dict into ESRGAN model") - - # Convert a 'new-arch' model to 'old-arch' - if "conv_first.weight" in state_dict: - state_dict = self.convert_new_to_old(state_dict) - - # extract model information - scale2 = 0 - max_part = 0 - in_nc = 0 - out_nc = 0 - plus = False - for part in list(state_dict): - parts = part.split(".") - n_parts = len(parts) - if n_parts == 5 and parts[2] == "sub": - nb = int(parts[3]) - elif n_parts == 3: - part_num = int(parts[1]) - if part_num > 6 and parts[0] == "model" and parts[2] == "weight": - scale2 += 1 - if part_num > max_part: - max_part = part_num - out_nc = state_dict[part].shape[0] - if "conv1x1" in part and not plus: - plus = True - - upscale = 2 ** scale2 - in_nc = state_dict["model.0.weight"].shape[1] - nf = state_dict["model.0.weight"].shape[0] - - model = RRDBNet( - in_nc=in_nc, - out_nc=out_nc, - nf=nf, - nb=nb, - gc=32, - upscale=upscale, - norm_type=None, - act_type="leakyrelu", - mode="CNA", - upsample_mode="upconv", - plus=plus, - ) - - model.load_state_dict(state_dict, strict=True) - for _, v in model.named_parameters(): - v.requires_grad = False - model.eval() - model.to(torch.device(os.environ["device"])) - - return model - - def convert_new_to_old(self, state_dict): - logger.warn("Attempting to convert and load a new-format model") - old_net = {} - items = [] - for k, _ in state_dict.items(): - items.append(k) - - old_net["model.0.weight"] = state_dict["conv_first.weight"] - old_net["model.0.bias"] = state_dict["conv_first.bias"] - - for k in items.copy(): - if "RDB" in k: - ori_k = k.replace("RRDB_trunk.", "model.1.sub.") - if ".weight" in k: - ori_k = ori_k.replace(".weight", ".0.weight") - elif ".bias" in k: - ori_k = ori_k.replace(".bias", ".0.bias") - old_net[ori_k] = state_dict[k] - items.remove(k) - - old_net["model.1.sub.23.weight"] = state_dict["trunk_conv.weight"] - old_net["model.1.sub.23.bias"] = state_dict["trunk_conv.bias"] - old_net["model.3.weight"] = state_dict["upconv1.weight"] - old_net["model.3.bias"] = state_dict["upconv1.bias"] - old_net["model.6.weight"] = state_dict["upconv2.weight"] - old_net["model.6.bias"] = state_dict["upconv2.bias"] - old_net["model.8.weight"] = state_dict["HRconv.weight"] - old_net["model.8.bias"] = state_dict["HRconv.bias"] - old_net["model.10.weight"] = state_dict["conv_last.weight"] - old_net["model.10.bias"] = state_dict["conv_last.bias"] - return old_net - - -@NodeFactory.register("PyTorch", "ESRGAN::Run") -class EsrganNode(NodeBase): - """ESRGAN node""" - - def __init__(self): - """Constructor""" - self.description = "Upscales a BGR numpy array using an ESRGAN model" - self.inputs = [ModelInput(), ImageInput()] - self.outputs = [ImageOutput("Upscaled Image")] - - def run(self, model: RRDBNet, img: np.ndarray) -> np.ndarray: - """Upscales an image with an ESRGAN pretrained model""" - - check_env() - - logger.info(f"Upscaling image...") - - img = img / np.iinfo(img.dtype).max - - in_nc = model.in_nc - out_nc = model.out_nc - scale = model.scale - h, w = img.shape[:2] - c = img.shape[2] if len(img.shape) > 2 else 1 - logger.info( - f"Upscaling a {h}x{w}x{c} image with a {scale}x model (in_nc: {in_nc}, out_nc: {out_nc})" - ) - - # Ensure correct amount of image channels for the model. - # The frontend should type-validate this enough where it shouldn't be needed, - # But I want to be extra safe - - # # Add extra channels if not enough (i.e single channel img, three channel model) - gray = False - if img.ndim == 2: - gray = True - logger.warn("Expanding image channels") - img = np.tile(np.expand_dims(img, axis=2), (1, 1, min(in_nc, 3))) - # Remove extra channels if too many (i.e three channel image, single channel model) - elif img.shape[2] > in_nc: - logger.warn("Truncating image channels") - img = img[:, :, :in_nc] - # Pad with solid alpha channel if needed (i.e three channel image, four channel model) - elif img.shape[2] == 3 and in_nc == 4: - logger.warn("Expanding image channels") - img = np.dstack((img, np.full(img.shape[:-1], 1.0))) - - # Borrowed from iNNfer - logger.info("Converting image to tensor") - img_tensor = np2tensor(img) - t_img = np2tensor(img).to(torch.device(os.environ["device"])) - t_out = t_img.clone() - if bool(os.environ["isFp16"]): - model = model.half() - t_img = t_img.half() - logger.info("Upscaling image") - t_out, _ = auto_split_process( - t_img, - model, - scale, - ) - # t_out = model(t_out) - logger.info("Converting tensor to image") - img_out = tensor2np(t_out.detach()) - logger.info("Done upscaling") - - if gray: - img_out = np.average(img_out, axis=2).astype("uint8") - - return img_out - - -@NodeFactory.register("PyTorch", "Model::Interpolate") -class InterpolateNode(NodeBase): - """Interpolate node""" - - def __init__(self): - """Constructor""" - self.description = "Interpolate two models together" - self.inputs = [ - StateDictInput(), - StateDictInput(), - SliderInput("Amount", 0, 100, 50), - ] - self.outputs = [StateDictOutput()] - - def run(self, model_a: RRDBNet, model_b: RRDBNet, amount: int) -> np.ndarray: - """Upscales an image with an ESRGAN pretrained model""" - - logger.info(f"Interpolating models...") - - amount_a = amount / 100 - amount_b = 1 - amount_a - - state_dict = OrderedDict() - for k, v_1 in model_a.items(): - v_2 = model_b[k] - state_dict[k] = (amount_a * v_1) + (amount_b * v_2) - return state_dict - - -@NodeFactory.register("PyTorch", "Model::Save") -class PthSaveNode(NodeBase): - """Model Save node""" - - def __init__(self): - """Constructor""" - self.description = "Save a PyTorch model" - self.inputs = [StateDictInput(), DirectoryInput(), TextInput("Model Name")] - self.outputs = [] - - def run(self, model: OrderedDict(), directory: str, name: str) -> np.ndarray: - """Upscales an image with an ESRGAN pretrained model""" - fullFile = f"{name}.pth" - fullPath = os.path.join(directory, fullFile) - logger.info(f"Writing image to path: {fullPath}") - status = torch.save(model, fullPath) - - return status +""" +Nodes that provide functionality for pytorch inference +""" + + +import os +import sys +from typing import Any, OrderedDict + +import numpy as np +import torch + +sys.path.append("..") + +from typing import Union + +from sanic_server.sanic.log import logger + +from .node_base import NodeBase +from .node_factory import NodeFactory +from .properties.inputs.file_inputs import DirectoryInput, PthFileInput, TorchFileInput +from .properties.inputs.generic_inputs import SliderInput, TextInput +from .properties.inputs.numpy_inputs import ImageInput +from .properties.inputs.pytorch_inputs import ( + ModelInput, + StateDictInput, + TorchScriptInput, +) +from .properties.outputs.numpy_outputs import ImageOutput +from .properties.outputs.pytorch_outputs import ( + ModelOutput, + StateDictOutput, + TorchScriptOutput, +) +from .utils.architecture.RRDB import RRDBNet as ESRGAN +from .utils.architecture.SPSR import SPSRNet as SPSR +from .utils.architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 +from .utils.utils import auto_split_process, np2tensor, tensor2np + + +def check_env(): + os.environ["device"] = ( + "cuda" if torch.cuda.is_available() and os.environ["device"] != "cpu" else "cpu" + ) + + if bool(os.environ["isFp16"]): + if os.environ["device"] == "cpu": + torch.set_default_tensor_type(torch.HalfTensor) + elif os.environ["device"] == "cuda": + torch.set_default_tensor_type(torch.cuda.HalfTensor) + else: + logger.warn("Something isn't set right with the device env var") + + +@NodeFactory.register("PyTorch", "Model::Read") +class LoadStateDictNode(NodeBase): + """Load Model node""" + + def __init__(self): + """Constructor""" + self.description = "Load PyTorch state dict file (.pth) from path" + self.inputs = [PthFileInput()] + self.outputs = [StateDictOutput()] + + def run(self, path: str) -> OrderedDict: + """Read a pth file from the specified path and return it as a state dict""" + + logger.info(f"Reading state dict from path: {path}") + state_dict = torch.load(path) + + return state_dict + + +@NodeFactory.register("PyTorch", "Model::AutoLoad") +class AutoLoadModelNode(NodeBase): + """Load PyTorch Model node""" + + def __init__(self): + """Constructor""" + self.description = "Load PyTorch state dict into an auto-detected supported model architecture. Supports most variations of the RRDB architecture (ESRGAN, Real-ESRGAN, RealSR, BSRGAN, SPSR) and Real-ESRGAN's SRVGG architecture" + self.inputs = [StateDictInput()] + self.outputs = [ModelOutput()] + + def run(self, state_dict: OrderedDict) -> Any: + """Loads the state dict to an ESRGAN model after finding arch config""" + + logger.info(f"Loading state dict into ESRGAN model") + + # SRVGGNet Real-ESRGAN (v2) + if ( + "params" in state_dict.keys() + and "body.0.weight" in state_dict["params"].keys() + ): + model = RealESRGANv2(state_dict) + # SPSR (ESRGAN with lots of extra layers) + elif "f_HR_conv1.0.weight" in state_dict: + model = SPSR(state_dict) + # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1 + else: + model = ESRGAN(state_dict) + + for _, v in model.named_parameters(): + v.requires_grad = False + model.eval() + model.to(torch.device(os.environ["device"])) + + return model + + +@NodeFactory.register("PyTorch", "Image::Upscale") +class ImageUpscaleNode(NodeBase): + """Image Upscale node""" + + def __init__(self): + """Constructor""" + self.description = "Upscales a BGR numpy array using a Super-Resolution model" + self.inputs = [ModelInput(), ImageInput()] + self.outputs = [ImageOutput("Upscaled Image")] + + def run(self, model: torch.nn.Module, img: np.ndarray) -> np.ndarray: + """Upscales an image with a pretrained model""" + + check_env() + + logger.info(f"Upscaling image...") + + img = img / np.iinfo(img.dtype).max + + # TODO: Have all super resolution models inherit from something that forces them to use in_nc and out_nc + in_nc = model.in_nc + out_nc = model.out_nc + scale = model.scale + h, w = img.shape[:2] + c = img.shape[2] if len(img.shape) > 2 else 1 + logger.info( + f"Upscaling a {h}x{w}x{c} image with a {scale}x model (in_nc: {in_nc}, out_nc: {out_nc})" + ) + + # Ensure correct amount of image channels for the model. + # The frontend should type-validate this enough where it shouldn't be needed, + # But I want to be extra safe + + # # Add extra channels if not enough (i.e single channel img, three channel model) + gray = False + if img.ndim == 2: + gray = True + logger.warn("Expanding image channels") + img = np.tile(np.expand_dims(img, axis=2), (1, 1, min(in_nc, 3))) + # Remove extra channels if too many (i.e three channel image, single channel model) + elif img.shape[2] > in_nc: + logger.warn("Truncating image channels") + img = img[:, :, :in_nc] + # Pad with solid alpha channel if needed (i.e three channel image, four channel model) + elif img.shape[2] == 3 and in_nc == 4: + logger.warn("Expanding image channels") + img = np.dstack((img, np.full(img.shape[:-1], 1.0))) + + # Borrowed from iNNfer + logger.info("Converting image to tensor") + img_tensor = np2tensor(img) + t_img = np2tensor(img).to(torch.device(os.environ["device"])) + t_out = t_img.clone() + if bool(os.environ["isFp16"]): + model = model.half() + t_img = t_img.half() + logger.info("Upscaling image") + t_out, _ = auto_split_process( + t_img, + model, + scale, + ) + # t_out = model(t_out) + logger.info("Converting tensor to image") + img_out = tensor2np(t_out.detach()) + logger.info("Done upscaling") + + if gray: + img_out = np.average(img_out, axis=2).astype("uint8") + + return img_out + + +@NodeFactory.register("PyTorch", "Model::Interpolate") +class InterpolateNode(NodeBase): + """Interpolate node""" + + def __init__(self): + """Constructor""" + self.description = "Interpolate two of the same kind of model together" + self.inputs = [ + StateDictInput(), + StateDictInput(), + SliderInput("Amount", 0, 100, 50), + ] + self.outputs = [StateDictOutput()] + + def run( + self, model_a: torch.nn.Module, model_b: torch.nn.Module, amount: int + ) -> np.ndarray: + + logger.info(f"Interpolating models...") + + amount_a = amount / 100 + amount_b = 1 - amount_a + + state_dict = OrderedDict() + for k, v_1 in model_a.items(): + v_2 = model_b[k] + state_dict[k] = (amount_a * v_1) + (amount_b * v_2) + return state_dict + + +@NodeFactory.register("PyTorch", "Model::Save") +class PthSaveNode(NodeBase): + """Model Save node""" + + def __init__(self): + """Constructor""" + self.description = "Save a PyTorch model" + self.inputs = [StateDictInput(), DirectoryInput(), TextInput("Model Name")] + self.outputs = [] + + def run(self, model: OrderedDict(), directory: str, name: str) -> bool: + fullFile = f"{name}.pth" + fullPath = os.path.join(directory, fullFile) + logger.info(f"Writing model to path: {fullPath}") + status = torch.save(model, fullPath) + + return status + + +# @NodeFactory.register("PyTorch", "JIT::Trace") +# class JitTraceNode(NodeBase): +# """JIT trace node""" + +# def __init__(self): +# """Constructor""" +# self.description = "JIT trace a pytorch model" +# self.inputs = [ModelInput(), ImageInput("Example Input")] +# self.outputs = [TorchScriptOutput()] + +# def run(self, model: any, image: np.ndarray) -> torch.ScriptModule: +# tensor = np2tensor(image) +# traced = torch.jit.trace(model.cpu(), tensor.cpu()) + +# return traced + + +# @NodeFactory.register("PyTorch", "JIT::Optimize") +# class JitOptimizeNode(NodeBase): +# """JIT optimize node""" + +# def __init__(self): +# """Constructor""" +# self.description = "Optimize a JIT traced pytorch model for inference" +# self.inputs = [TorchScriptInput()] +# self.outputs = [TorchScriptOutput()] + +# def run(self, model: torch.ScriptModule) -> torch.ScriptModule: +# optimized = torch.jit.optimize_for_inference(model) + +# return optimized + + +# @NodeFactory.register("PyTorch", "JIT::Save") +# class JitSaveNode(NodeBase): +# """JIT save node""" + +# def __init__(self): +# """Constructor""" +# self.description = "Save a JIT traced pytorch model to a file" +# self.inputs = [TorchScriptInput(), DirectoryInput(), TextInput("Model Name")] +# self.outputs = [] + +# def run(self, model: torch.ScriptModule, directory: str, name: str): +# fullFile = f"{name}.pt" +# fullPath = os.path.join(directory, fullFile) +# logger.info(f"Writing model to path: {fullPath}") +# torch.jit.save(model, fullPath) + + +# @NodeFactory.register("PyTorch", "JIT::Load") +# class JitLoadNode(NodeBase): +# """JIT load node""" + +# def __init__(self): +# """Constructor""" +# self.description = "Load a JIT traced pytorch model from a file" +# self.inputs = [TorchFileInput()] +# self.outputs = [TorchScriptOutput()] + +# def run(self, path: str) -> torch.ScriptModule: +# # device = ( +# # f"cuda:0" +# # if torch.cuda.is_available() and os.environ["device"] != "cpu" +# # else "cpu" +# # ) +# model = torch.jit.load( +# path, map_location=torch.device("cpu") +# ) # , map_location=device) + +# return model + + +# @NodeFactory.register("PyTorch", "JIT::Run") +# class JitRunNode(NodeBase): +# """JIT run node""" + +# def __init__(self): +# """Constructor""" +# self.description = "Run a JIT traced pytorch model" +# self.inputs = [TorchScriptInput(), ImageInput()] +# self.outputs = [ImageOutput()] + +# def run(self, model: torch.ScriptModule, image: np.ndarray) -> np.ndarray: +# tensor = np2tensor(image).cpu() +# # if os.environ["device"] == "cuda": +# # model = model.cuda() +# # tensor = tensor.cuda() +# out = model.cpu()(tensor) + +# return out diff --git a/backend/nodes/utils/architecture/RRDB.py b/backend/nodes/utils/architecture/RRDB.py new file mode 100644 index 000000000..c142e1b1f --- /dev/null +++ b/backend/nodes/utils/architecture/RRDB.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import functools +import math +import re +from collections import OrderedDict + +import torch +import torch.nn as nn + +from . import block as B + + +# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py +# Which enhanced stuff that was already here +class RRDBNet(nn.Module): + def __init__( + self, + state_dict, + norm=None, + act: str = "leakyrelu", + upsampler: str = "upconv", + mode: str = "CNA", + ) -> None: + """ + ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks. + By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao, + and Chen Change Loy. + This is old-arch Residual in Residual Dense Block Network and is not + the newest revision that's available at github.com/xinntao/ESRGAN. + This is on purpose, the newest Network has severely limited the + potential use of the Network with no benefits. + This network supports model files from both new and old-arch. + Args: + norm: Normalization layer + act: Activation layer + upsampler: Upsample layer. upconv, pixel_shuffle + mode: Convolution mode + """ + super(RRDBNet, self).__init__() + + self.state = state_dict + self.norm = norm + self.act = act + self.upsampler = upsampler + self.mode = mode + + self.state_map = { + # currently supports old, new, and newer RRDBNet arch models + # ESRGAN, BSRGAN/RealSR, Real-ESRGAN + "model.0.weight": ("conv_first.weight",), + "model.0.bias": ("conv_first.bias",), + "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"), + "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"), + "model.3.weight": ("upconv1.weight", "conv_up1.weight"), + "model.3.bias": ("upconv1.bias", "conv_up1.bias"), + "model.6.weight": ("upconv2.weight", "conv_up2.weight"), + "model.6.bias": ("upconv2.bias", "conv_up2.bias"), + "model.8.weight": ("HRconv.weight", "conv_hr.weight"), + "model.8.bias": ("HRconv.bias", "conv_hr.bias"), + "model.10.weight": ("conv_last.weight",), + "model.10.bias": ("conv_last.bias",), + r"model.1.sub.\1.RDB\2.conv\3.0.\4": ( + r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)", + r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)", + ), + } + if "params_ema" in self.state: + self.state = self.state["params_ema"] + self.num_blocks = self.get_num_blocks() + self.plus = any("conv1x1" in k for k in self.state.keys()) + + self.state = self.new_to_old_arch(self.state) + + self.key_arr = list(self.state.keys()) + + self.in_nc = self.state[self.key_arr[0]].shape[1] + self.out_nc = self.state[self.key_arr[-1]].shape[0] + + self.scale = self.get_scale() + self.num_filters = self.state[self.key_arr[0]].shape[0] + + # Detect if pixelunshuffle was used (Real-ESRGAN) + if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in ( + self.in_nc / 4, + self.in_nc / 16, + ): + self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc)) + else: + self.shuffle_factor = None + + upsample_block = { + "upconv": B.upconv_block, + "pixel_shuffle": B.pixelshuffle_block, + }.get(self.upsampler) + if upsample_block is None: + raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found") + + if self.scale == 3: + upsample_blocks = upsample_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + upscale_factor=3, + act_type=self.act, + ) + else: + upsample_blocks = [ + upsample_block( + in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act + ) + for _ in range(int(math.log(self.scale, 2))) + ] + + self.model = B.sequential( + # fea conv + B.conv_block( + in_nc=self.in_nc, + out_nc=self.num_filters, + kernel_size=3, + norm_type=None, + act_type=None, + ), + B.ShortcutBlock( + B.sequential( + # rrdb blocks + *[ + B.RRDB( + nf=self.num_filters, + kernel_size=3, + gc=32, + stride=1, + bias=True, + pad_type="zero", + norm_type=self.norm, + act_type=self.act, + mode="CNA", + plus=self.plus, + ) + for _ in range(self.num_blocks) + ], + # lr conv + B.conv_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + kernel_size=3, + norm_type=self.norm, + act_type=None, + mode=self.mode, + ), + ) + ), + *upsample_blocks, + # hr_conv0 + B.conv_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + kernel_size=3, + norm_type=None, + act_type=self.act, + ), + # hr_conv1 + B.conv_block( + in_nc=self.num_filters, + out_nc=self.out_nc, + kernel_size=3, + norm_type=None, + act_type=None, + ), + ) + + self.load_state_dict(self.state, strict=False) + + def new_to_old_arch(self, state): + """Convert a new-arch model state dictionary to an old-arch dictionary.""" + if "params_ema" in state: + state = state["params_ema"] + + if "conv_first.weight" not in state: + # model is already old arch, this is a loose check, but should be sufficient + return state + + # add nb to state keys + for kind in ("weight", "bias"): + self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[ + f"model.1.sub./NB/.{kind}" + ] + del self.state_map[f"model.1.sub./NB/.{kind}"] + + old_state = OrderedDict() + for old_key, new_keys in self.state_map.items(): + for new_key in new_keys: + if r"\1" in old_key: + for k, v in state.items(): + sub = re.sub(new_key, old_key, k) + if sub != k: + old_state[sub] = v + else: + if new_key in state: + old_state[old_key] = state[new_key] + + # Sort by first numeric value of each layer + def compare(item1, item2): + parts1 = item1.split(".") + parts2 = item2.split(".") + int1 = int(parts1[1]) + int2 = int(parts2[1]) + return int1 - int2 + + sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare)) + + # Rebuild the output dict in the right order + out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys) + + return out_dict + + def get_scale(self, min_part: int = 6) -> int: + n = 0 + for part in list(self.state): + parts = part.split(".")[1:] + if len(parts) == 2: + part_num = int(parts[0]) + if part_num > min_part and parts[1] == "weight": + n += 1 + return 2 ** n + + def get_num_blocks(self) -> int: + nbs = [] + state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + ( + r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)", + ) + for state_key in state_keys: + for k in self.state: + m = re.search(state_key, k) + if m: + nbs.append(int(m.group(1))) + if nbs: + break + return max(*nbs) + 1 + + def forward(self, x): + if self.shuffle_factor: + x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor) + return self.model(x) diff --git a/backend/nodes/utils/architecture/SPSR.py b/backend/nodes/utils/architecture/SPSR.py new file mode 100644 index 000000000..2db873482 --- /dev/null +++ b/backend/nodes/utils/architecture/SPSR.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . import block as B + + +class Get_gradient_nopadding(nn.Module): + def __init__(self): + super(Get_gradient_nopadding, self).__init__() + kernel_v = [[0, -1, 0], [0, 0, 0], [0, 1, 0]] + kernel_h = [[0, 0, 0], [-1, 0, 1], [0, 0, 0]] + kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) + kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) + self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False) + + self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False) + + def forward(self, x): + x_list = [] + for i in range(x.shape[1]): + x_i = x[:, i] + x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1) + x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1) + x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6) + x_list.append(x_i) + + x = torch.cat(x_list, dim=1) + + return x + + +class SPSRNet(nn.Module): + def __init__( + self, + state_dict, + norm=None, + act: str = "leakyrelu", + upsampler: str = "upconv", + mode: str = "CNA", + ): + super(SPSRNet, self).__init__() + + self.state = state_dict + self.norm = norm + self.act = act + self.upsampler = upsampler + self.mode = mode + + self.num_blocks = self.get_num_blocks() + + self.in_nc = self.state["model.0.weight"].shape[1] + self.out_nc = self.state["f_HR_conv1.0.bias"].shape[0] + + self.scale = self.get_scale(4) + print(self.scale) + self.num_filters = self.state["model.0.weight"].shape[0] + + n_upscale = int(math.log(self.scale, 2)) + if self.scale == 3: + n_upscale = 1 + + fea_conv = B.conv_block( + self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None + ) + rb_blocks = [ + B.RRDB( + self.num_filters, + kernel_size=3, + gc=32, + stride=1, + bias=True, + pad_type="zero", + norm_type=norm, + act_type=act, + mode="CNA", + ) + for _ in range(self.num_blocks) + ] + LR_conv = B.conv_block( + self.num_filters, + self.num_filters, + kernel_size=3, + norm_type=norm, + act_type=None, + mode=mode, + ) + + if upsampler == "upconv": + upsample_block = B.upconv_block + elif upsampler == "pixelshuffle": + upsample_block = B.pixelshuffle_block + else: + raise NotImplementedError(f"upsample mode [{upsampler}] is not found") + if self.scale == 3: + a_upsampler = upsample_block( + self.num_filters, self.num_filters, 3, act_type=act + ) + else: + a_upsampler = [ + upsample_block(self.num_filters, self.num_filters, act_type=act) + for _ in range(n_upscale) + ] + self.HR_conv0_new = B.conv_block( + self.num_filters, + self.num_filters, + kernel_size=3, + norm_type=None, + act_type=act, + ) + self.HR_conv1_new = B.conv_block( + self.num_filters, + self.num_filters, + kernel_size=3, + norm_type=None, + act_type=None, + ) + + self.model = B.sequential( + fea_conv, + B.ShortcutBlockSPSR(B.sequential(*rb_blocks, LR_conv)), + *a_upsampler, + self.HR_conv0_new, + ) + + self.get_g_nopadding = Get_gradient_nopadding() + + self.b_fea_conv = B.conv_block( + self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None + ) + + self.b_concat_1 = B.conv_block( + 2 * self.num_filters, + self.num_filters, + kernel_size=3, + norm_type=None, + act_type=None, + ) + self.b_block_1 = B.RRDB( + self.num_filters * 2, + kernel_size=3, + gc=32, + stride=1, + bias=True, + pad_type="zero", + norm_type=norm, + act_type=act, + mode="CNA", + ) + + self.b_concat_2 = B.conv_block( + 2 * self.num_filters, + self.num_filters, + kernel_size=3, + norm_type=None, + act_type=None, + ) + self.b_block_2 = B.RRDB( + self.num_filters * 2, + kernel_size=3, + gc=32, + stride=1, + bias=True, + pad_type="zero", + norm_type=norm, + act_type=act, + mode="CNA", + ) + + self.b_concat_3 = B.conv_block( + 2 * self.num_filters, + self.num_filters, + kernel_size=3, + norm_type=None, + act_type=None, + ) + self.b_block_3 = B.RRDB( + self.num_filters * 2, + kernel_size=3, + gc=32, + stride=1, + bias=True, + pad_type="zero", + norm_type=norm, + act_type=act, + mode="CNA", + ) + + self.b_concat_4 = B.conv_block( + 2 * self.num_filters, + self.num_filters, + kernel_size=3, + norm_type=None, + act_type=None, + ) + self.b_block_4 = B.RRDB( + self.num_filters * 2, + kernel_size=3, + gc=32, + stride=1, + bias=True, + pad_type="zero", + norm_type=norm, + act_type=act, + mode="CNA", + ) + + self.b_LR_conv = B.conv_block( + self.num_filters, + self.num_filters, + kernel_size=3, + norm_type=norm, + act_type=None, + mode=mode, + ) + + if upsampler == "upconv": + upsample_block = B.upconv_block + elif upsampler == "pixelshuffle": + upsample_block = B.pixelshuffle_block + else: + raise NotImplementedError(f"upsample mode [{upsampler}] is not found") + if self.scale == 3: + b_upsampler = upsample_block( + self.num_filters, self.num_filters, 3, act_type=act + ) + else: + b_upsampler = [ + upsample_block(self.num_filters, self.num_filters, act_type=act) + for _ in range(n_upscale) + ] + + b_HR_conv0 = B.conv_block( + self.num_filters, + self.num_filters, + kernel_size=3, + norm_type=None, + act_type=act, + ) + b_HR_conv1 = B.conv_block( + self.num_filters, + self.num_filters, + kernel_size=3, + norm_type=None, + act_type=None, + ) + + self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1) + + self.conv_w = B.conv_block( + self.num_filters, self.out_nc, kernel_size=1, norm_type=None, act_type=None + ) + + self.f_concat = B.conv_block( + self.num_filters * 2, + self.num_filters, + kernel_size=3, + norm_type=None, + act_type=None, + ) + + self.f_block = B.RRDB( + self.num_filters * 2, + kernel_size=3, + gc=32, + stride=1, + bias=True, + pad_type="zero", + norm_type=norm, + act_type=act, + mode="CNA", + ) + + self.f_HR_conv0 = B.conv_block( + self.num_filters, + self.num_filters, + kernel_size=3, + norm_type=None, + act_type=act, + ) + self.f_HR_conv1 = B.conv_block( + self.num_filters, self.out_nc, kernel_size=3, norm_type=None, act_type=None + ) + + self.load_state_dict(self.state, strict=False) + + def get_scale(self, min_part: int = 4) -> int: + n = 0 + for part in list(self.state): + parts = part.split(".") + if len(parts) == 3: + part_num = int(parts[1]) + if part_num > min_part and parts[0] == "model" and parts[2] == "weight": + n += 1 + return 2 ** n + + def get_num_blocks(self) -> int: + nb = 0 + for part in list(self.state): + parts = part.split(".") + n_parts = len(parts) + if n_parts == 5 and parts[2] == "sub": + nb = int(parts[3]) + return nb + + def forward(self, x): + x_grad = self.get_g_nopadding(x) + x = self.model[0](x) + + x, block_list = self.model[1](x) + + x_ori = x + for i in range(5): + x = block_list[i](x) + x_fea1 = x + + for i in range(5): + x = block_list[i + 5](x) + x_fea2 = x + + for i in range(5): + x = block_list[i + 10](x) + x_fea3 = x + + for i in range(5): + x = block_list[i + 15](x) + x_fea4 = x + + x = block_list[20:](x) + # short cut + x = x_ori + x + x = self.model[2:](x) + x = self.HR_conv1_new(x) + + x_b_fea = self.b_fea_conv(x_grad) + x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1) + + x_cat_1 = self.b_block_1(x_cat_1) + x_cat_1 = self.b_concat_1(x_cat_1) + + x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1) + + x_cat_2 = self.b_block_2(x_cat_2) + x_cat_2 = self.b_concat_2(x_cat_2) + + x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1) + + x_cat_3 = self.b_block_3(x_cat_3) + x_cat_3 = self.b_concat_3(x_cat_3) + + x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1) + + x_cat_4 = self.b_block_4(x_cat_4) + x_cat_4 = self.b_concat_4(x_cat_4) + + x_cat_4 = self.b_LR_conv(x_cat_4) + + # short cut + x_cat_4 = x_cat_4 + x_b_fea + x_branch = self.b_module(x_cat_4) + + # x_out_branch = self.conv_w(x_branch) + ######## + x_branch_d = x_branch + x_f_cat = torch.cat([x_branch_d, x], dim=1) + x_f_cat = self.f_block(x_f_cat) + x_out = self.f_concat(x_f_cat) + x_out = self.f_HR_conv0(x_out) + x_out = self.f_HR_conv1(x_out) + + ######### + # return x_out_branch, x_out, x_grad + return x_out diff --git a/backend/nodes/utils/architecture/SRVGG.py b/backend/nodes/utils/architecture/SRVGG.py new file mode 100644 index 000000000..8b76b38d4 --- /dev/null +++ b/backend/nodes/utils/architecture/SRVGG.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import math +from collections import OrderedDict +from typing import Union + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + It is a compact network structure, which performs upsampling in the last layer and no convolution is + conducted on the HR feature space. + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_conv (int): Number of convolution layers in the body network. Default: 16. + upscale (int): Upsampling factor. Default: 4. + act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. + """ + + def __init__( + self, + state_dict, + act_type: str = "prelu", + ): + super(SRVGGNetCompact, self).__init__() + self.act_type = act_type + + self.state = state_dict + + if "params" in self.state: + self.state = self.state["params"] + + self.key_arr = list(self.state.keys()) + + self.in_nc = self.get_in_nc() + self.num_feat = self.get_num_feats() + self.num_conv = self.get_num_conv() + self.out_nc = self.in_nc # :( + self.scale = self.get_scale() + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(self.in_nc, self.num_feat, 3, 1, 1)) + # the first activation + if act_type == "relu": + activation = nn.ReLU(inplace=True) + elif act_type == "prelu": + activation = nn.PReLU(num_parameters=self.num_feat) + elif act_type == "leakyrelu": + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the body structure + for _ in range(self.num_conv): + self.body.append(nn.Conv2d(self.num_feat, self.num_feat, 3, 1, 1)) + # activation + if act_type == "relu": + activation = nn.ReLU(inplace=True) + elif act_type == "prelu": + activation = nn.PReLU(num_parameters=self.num_feat) + elif act_type == "leakyrelu": + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(self.num_feat, self.pixelshuffle_shape, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(self.scale) + + self.load_state_dict(self.state, strict=False) + + def get_num_conv(self) -> int: + return (int(self.key_arr[-1].split(".")[1]) - 2) // 2 + + def get_num_feats(self) -> int: + return self.state[self.key_arr[0]].shape[0] + + def get_in_nc(self) -> int: + return self.state[self.key_arr[0]].shape[1] + + def get_scale(self) -> int: + self.pixelshuffle_shape = self.state[self.key_arr[-1]].shape[0] + # Assume out_nc is the same as in_nc + # I cant think of a better way to do that + self.out_nc = self.in_nc + scale = math.sqrt(self.pixelshuffle_shape / self.out_nc) + if scale - int(scale) > 0: + print( + "out_nc is probably different than in_nc, scale calculation might be wrong" + ) + scale = int(scale) + return scale + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor=self.scale, mode="nearest") + out += base + return out diff --git a/backend/nodes/utils/architecture/__index__.py b/backend/nodes/utils/architecture/__index__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/nodes/utils/architectures/block.py b/backend/nodes/utils/architecture/block.py similarity index 92% rename from backend/nodes/utils/architectures/block.py rename to backend/nodes/utils/architecture/block.py index d7b17ecb8..7fa1ca6dc 100644 --- a/backend/nodes/utils/architectures/block.py +++ b/backend/nodes/utils/architecture/block.py @@ -1,9 +1,10 @@ -# pylint: skip-file +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- from collections import OrderedDict +import torch import torch.nn as nn -from torch import cat as torch_cat #################### # Basic blocks @@ -72,7 +73,7 @@ def __init__(self, submodule): self.sub = submodule def forward(self, x): - output = torch_cat((x, self.sub(x)), dim=1) + output = torch.cat((x, self.sub(x)), dim=1) return output def __repr__(self): @@ -99,6 +100,22 @@ def __repr__(self): return tmpstr +class ShortcutBlockSPSR(nn.Module): + # Elementwise sum the output of a submodule to its input + def __init__(self, submodule): + super(ShortcutBlockSPSR, self).__init__() + self.sub = submodule + + def forward(self, x): + return x, self.sub + + def __repr__(self): + tmpstr = "Identity + \n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + def sequential(*args): # Flatten Sequential. It unwraps nn.Sequential. if len(args) == 1: @@ -309,9 +326,9 @@ class ResidualDenseBlock_5C(nn.Module): style: 5 convs The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) Modified options that can be used: - - 'Partial Convolution based Padding' arXiv:1811.11718 - - 'Spectral normalization' arXiv:1802.05957 - - 'ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN' N. C. + - "Partial Convolution based Padding" arXiv:1811.11718 + - "Spectral normalization" arXiv:1802.05957 + - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. {Rakotonirina} and A. {Rasoanaivo} Args: @@ -340,9 +357,9 @@ def __init__( ): super(ResidualDenseBlock_5C, self).__init__() - # + + ## + self.conv1x1 = conv1x1(nf, gc) if plus else None - # + + ## + self.conv1 = conv_block( nf, @@ -406,14 +423,14 @@ def __init__( def forward(self, x): x1 = self.conv1(x) - x2 = self.conv2(torch_cat((x, x1), 1)) + x2 = self.conv2(torch.cat((x, x1), 1)) if self.conv1x1: x2 = x2 + self.conv1x1(x) # + - x3 = self.conv3(torch_cat((x, x1, x2), 1)) - x4 = self.conv4(torch_cat((x, x1, x2, x3), 1)) + x3 = self.conv3(torch.cat((x, x1, x2), 1)) + x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) if self.conv1x1: x4 = x4 + x2 # + - x5 = self.conv5(torch_cat((x, x1, x2, x3, x4), 1)) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 * 0.2 + x diff --git a/backend/nodes/utils/architectures/RRDB.py b/backend/nodes/utils/architectures/RRDB.py deleted file mode 100644 index cb73fbaf9..000000000 --- a/backend/nodes/utils/architectures/RRDB.py +++ /dev/null @@ -1,108 +0,0 @@ -# pylint: skip-file - -import math - -import torch.nn as nn -from torch import clamp, sigmoid, tanh - -from . import block as B - - -class RRDBNet(nn.Module): - def __init__( - self, - in_nc, - out_nc, - nf, - nb, - gc=32, - upscale=4, - norm_type=None, - act_type="leakyrelu", - mode="CNA", - upsample_mode="upconv", - convtype="Conv2D", - finalact=None, - plus=False, - ): - super(RRDBNet, self).__init__() - - # Extra class-level values for checking later on - self.in_nc = in_nc - self.out_nc = out_nc - - n_upscale = int(math.log(upscale, 2)) - if upscale == 3: - n_upscale = 1 - - self.scale = n_upscale ** 2 - - fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) - rb_blocks = [ - B.RRDB( - nf, - kernel_size=3, - gc=32, - stride=1, - bias=1, - pad_type="zero", - norm_type=norm_type, - act_type=act_type, - mode="CNA", - convtype=convtype, - plus=plus, - ) - for _ in range(nb) - ] - LR_conv = B.conv_block( - nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode - ) - - if upsample_mode == "upconv": - upsample_block = B.upconv_block - elif upsample_mode == "pixelshuffle": - upsample_block = B.pixelshuffle_block - else: - raise NotImplementedError( - "upsample mode [{:s}] is not found".format(upsample_mode) - ) - if upscale == 3: - upsampler = upsample_block(nf, nf, 3, act_type=act_type) - else: - upsampler = [ - upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale) - ] - HR_conv0 = B.conv_block( - nf, nf, kernel_size=3, norm_type=None, act_type=act_type - ) - HR_conv1 = B.conv_block( - nf, out_nc, kernel_size=3, norm_type=None, act_type=None - ) - - # Note: this option adds new parameters to the architecture, another option is to use 'outm' in the forward - outact = B.act(finalact) if finalact else None - - self.model = B.sequential( - fea_conv, - B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), - *upsampler, - HR_conv0, - HR_conv1, - outact - ) - - def forward(self, x, outm=None): - x = self.model(x) - - if ( - outm == "scaltanh" - ): # limit output range to [-1,1] range with tanh and rescale to [0,1] Idea from: https://github.com/goldhuang/SRGAN-PyTorch/blob/master/model.py - return (tanh(x) + 1.0) / 2.0 - elif outm == "tanh": # limit output to [-1,1] range - return tanh(x) - elif outm == "sigmoid": # limit output to [0,1] range - return sigmoid(x) - elif outm == "clamp": - return clamp(x, min=0.0, max=1.0) - else: # Default, no cap for the output - return x diff --git a/backend/nodes/utils/utils.py b/backend/nodes/utils/utils.py index 41c3ab6c7..b6f19263e 100644 --- a/backend/nodes/utils/utils.py +++ b/backend/nodes/utils/utils.py @@ -1,280 +1,283 @@ -# pylint: skip-file -# From https://github.com/victorca25/iNNfer/blob/main/utils/utils.py - -import gc -from typing import Tuple - -import numpy as np -from sanic.log import logger -from torch import Tensor, cuda, empty, from_numpy - -MAX_VALUES_BY_DTYPE = { - np.dtype("int8"): 127, - np.dtype("uint8"): 255, - np.dtype("int16"): 32767, - np.dtype("uint16"): 65535, - np.dtype("int32"): 2147483647, - np.dtype("uint32"): 4294967295, - np.dtype("int64"): 9223372036854775807, - np.dtype("uint64"): 18446744073709551615, - np.dtype("float32"): 1.0, - np.dtype("float64"): 1.0, -} - - -def bgr_to_rgb(image: Tensor) -> Tensor: - # flip image channels - # https://github.com/pytorch/pytorch/issues/229 - out: Tensor = image.flip(-3) - # RGB to BGR #may be faster: - # out: Tensor = image[[2, 1, 0], :, :] - return out - - -def rgb_to_bgr(image: Tensor) -> Tensor: - # same operation as bgr_to_rgb(), flip image channels - return bgr_to_rgb(image) - - -def bgra_to_rgba(image: Tensor) -> Tensor: - out: Tensor = image[[2, 1, 0, 3], :, :] - return out - - -def rgba_to_bgra(image: Tensor) -> Tensor: - # same operation as bgra_to_rgba(), flip image channels - return bgra_to_rgba(image) - - -def denorm(x, min_max=(-1.0, 1.0)): - """Denormalize from [-1,1] range to [0,1] - formula: xi' = (xi - mu)/sigma - Example: "out = (x + 1.0) / 2.0" for denorm - range (-1,1) to (0,1) - for use with proper act in Generator output (ie. tanh) - """ - out = (x - min_max[0]) / (min_max[1] - min_max[0]) - if isinstance(x, Tensor): - return out.clamp(0, 1) - elif isinstance(x, np.ndarray): - return np.clip(out, 0, 1) - else: - raise TypeError("Got unexpected object type, expected Tensor or np.ndarray") - - -def norm(x): - """ Normalize (z-norm) from [0,1] range to [-1,1] """ - out = (x - 0.5) * 2.0 - if isinstance(x, Tensor): - return out.clamp(-1, 1) - elif isinstance(x, np.ndarray): - return np.clip(out, -1, 1) - else: - raise TypeError("Got unexpected object type, expected Tensor or np.ndarray") - - -def np2tensor( - img: np.ndarray, - bgr2rgb=True, - data_range=1.0, - normalize=False, - change_range=True, - add_batch=True, -) -> Tensor: - """Converts a numpy image array into a Tensor array. - Parameters: - img (numpy array): the input image numpy array - add_batch (bool): choose if new tensor needs batch dimension added - """ - if not isinstance(img, np.ndarray): # images expected to be uint8 -> 255 - raise TypeError("Got unexpected object type, expected np.ndarray") - # check how many channels the image has, then condition. ie. RGB, RGBA, Gray - # if bgr2rgb: - # img = img[ - # :, :, [2, 1, 0] - # ] # BGR to RGB -> in numpy, if using OpenCV, else not needed. Only if image has colors. - if change_range: - dtype = img.dtype - maxval = MAX_VALUES_BY_DTYPE.get(dtype, 1.0) - t_dtype = np.dtype("float32") - img = img.astype(t_dtype) / maxval # ie: uint8 = /255 - img = from_numpy( - np.ascontiguousarray(np.transpose(img, (2, 0, 1))) - ).float() # "HWC to CHW" and "numpy to tensor" - if bgr2rgb: - # BGR to RGB -> in tensor, if using OpenCV, else not needed. Only if image has colors.) - if ( - img.shape[0] % 3 == 0 - ): # RGB or MultixRGB (3xRGB, 5xRGB, etc. For video tensors.) - img = bgr_to_rgb(img) - elif img.shape[0] == 4: # RGBA - img = bgra_to_rgba(img) - if add_batch: - img.unsqueeze_( - 0 - ) # Add fake batch dimension = 1 . squeeze() will remove the dimensions of size 1 - if normalize: - img = norm(img) - return img - - -def tensor2np( - img: Tensor, - rgb2bgr=True, - remove_batch=True, - data_range=255, - denormalize=False, - change_range=True, - imtype=np.uint8, -) -> np.ndarray: - """Converts a Tensor array into a numpy image array. - Parameters: - img (tensor): the input image tensor array - 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order - remove_batch (bool): choose if tensor of shape BCHW needs to be squeezed - denormalize (bool): Used to denormalize from [-1,1] range back to [0,1] - imtype (type): the desired type of the converted numpy array (np.uint8 - default) - Output: - img (np array): 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) - """ - if not isinstance(img, Tensor): - raise TypeError("Got unexpected object type, expected Tensor") - n_dim = img.dim() - - # TODO: Check: could denormalize here in tensor form instead, but end result is the same - - img = img.float().cpu() - - if n_dim in (4, 3): - # if n_dim == 4, has to convert to 3 dimensions - if n_dim == 4 and remove_batch: - # remove a fake batch dimension - img = img.squeeze(dim=0) - - if img.shape[0] == 3 and rgb2bgr: # RGB - # RGB to BGR -> in tensor, if using OpenCV, else not needed. Only if image has colors. - img_np = rgb_to_bgr(img).numpy() - elif img.shape[0] == 4 and rgb2bgr: # RGBA - # RGBA to BGRA -> in tensor, if using OpenCV, else not needed. Only if image has colors. - img_np = rgba_to_bgra(img).numpy() - else: - img_np = img.numpy() - img_np = np.transpose(img_np, (1, 2, 0)) # CHW to HWC - elif n_dim == 2: - img_np = img.numpy() - else: - raise TypeError( - f"Only support 4D, 3D and 2D tensor. But received with dimension: {n_dim:d}" - ) - - # if rgb2bgr: - # img_np = img_np[[2, 1, 0], :, :] #RGB to BGR -> in numpy, if using OpenCV, else not needed. Only if image has colors. - # TODO: Check: could denormalize in the begining in tensor form instead - if denormalize: - img_np = denorm(img_np) # denormalize if needed - if change_range: - img_np = np.clip( - data_range * img_np, 0, data_range - ).round() # np.clip to the data_range - - # has to be in range (0,255) before changing to np.uint8, else np.float32 - return img_np.astype(imtype) - - -def auto_split_process( - lr_img: Tensor, - model, - scale: int = 4, - overlap: int = 32, - max_depth: int = None, - current_depth: int = 1, -) -> Tuple[Tensor, int]: - # Original code: https://github.com/JoeyBallentine/ESRGAN/blob/master/utils/dataops.py - - # Prevent splitting from causing an infinite out-of-vram loop - if current_depth > 15: - cuda.empty_cache() - gc.collect() - raise RuntimeError("Splitting stopped to prevent infinite loop") - - # Attempt to upscale if unknown depth or if reached known max depth - if max_depth is None or max_depth == current_depth: - try: - result = model(lr_img) - return result, current_depth - except RuntimeError as e: - print(e) - # Check to see if its actually the CUDA out of memory error - if "allocate" in str(e): - # Collect garbage (clear VRAM) - cuda.empty_cache() - gc.collect() - # Re-raise the exception if not an OOM error - else: - raise RuntimeError(e) - - b, c, h, w = lr_img.shape - - # Split image into 4ths - top_left = lr_img[..., : h // 2 + overlap, : w // 2 + overlap] - top_right = lr_img[..., : h // 2 + overlap, w // 2 - overlap :] - bottom_left = lr_img[..., h // 2 - overlap :, : w // 2 + overlap] - bottom_right = lr_img[..., h // 2 - overlap :, w // 2 - overlap :] - - # Recursively upscale the quadrants - # After we go through the top left quadrant, we know the maximum depth and no longer need to test for out-of-memory - top_left_rlt, depth = auto_split_process( - top_left, - model, - scale=scale, - overlap=overlap, - current_depth=current_depth + 1, - ) - top_right_rlt, _ = auto_split_process( - top_right, - model, - scale=scale, - overlap=overlap, - max_depth=depth, - current_depth=current_depth + 1, - ) - bottom_left_rlt, _ = auto_split_process( - bottom_left, - model, - scale=scale, - overlap=overlap, - max_depth=depth, - current_depth=current_depth + 1, - ) - bottom_right_rlt, _ = auto_split_process( - bottom_right, - model, - scale=scale, - overlap=overlap, - max_depth=depth, - current_depth=current_depth + 1, - ) - - # Define output shape - out_h = h * scale - out_w = w * scale - - # Create blank output image - output_img = empty((b, c, out_h, out_w), dtype=lr_img.dtype, device=lr_img.device) - - # Fill output image with tiles, cropping out the overlaps - output_img[..., : out_h // 2, : out_w // 2] = top_left_rlt[ - ..., : out_h // 2, : out_w // 2 - ] - output_img[..., : out_h // 2, -out_w // 2 :] = top_right_rlt[ - ..., : out_h // 2, -out_w // 2 : - ] - output_img[..., -out_h // 2 :, : out_w // 2] = bottom_left_rlt[ - ..., -out_h // 2 :, : out_w // 2 - ] - output_img[..., -out_h // 2 :, -out_w // 2 :] = bottom_right_rlt[ - ..., -out_h // 2 :, -out_w // 2 : - ] - - return output_img, depth +# pylint: skip-file +# From https://github.com/victorca25/iNNfer/blob/main/utils/utils.py +import sys + +sys.path.append("...") + +import gc +from typing import Tuple + +import numpy as np +from sanic_server.sanic.log import logger +from torch import Tensor, cuda, empty, from_numpy + +MAX_VALUES_BY_DTYPE = { + np.dtype("int8"): 127, + np.dtype("uint8"): 255, + np.dtype("int16"): 32767, + np.dtype("uint16"): 65535, + np.dtype("int32"): 2147483647, + np.dtype("uint32"): 4294967295, + np.dtype("int64"): 9223372036854775807, + np.dtype("uint64"): 18446744073709551615, + np.dtype("float32"): 1.0, + np.dtype("float64"): 1.0, +} + + +def bgr_to_rgb(image: Tensor) -> Tensor: + # flip image channels + # https://github.com/pytorch/pytorch/issues/229 + out: Tensor = image.flip(-3) + # RGB to BGR #may be faster: + # out: Tensor = image[[2, 1, 0], :, :] + return out + + +def rgb_to_bgr(image: Tensor) -> Tensor: + # same operation as bgr_to_rgb(), flip image channels + return bgr_to_rgb(image) + + +def bgra_to_rgba(image: Tensor) -> Tensor: + out: Tensor = image[[2, 1, 0, 3], :, :] + return out + + +def rgba_to_bgra(image: Tensor) -> Tensor: + # same operation as bgra_to_rgba(), flip image channels + return bgra_to_rgba(image) + + +def denorm(x, min_max=(-1.0, 1.0)): + """Denormalize from [-1,1] range to [0,1] + formula: xi' = (xi - mu)/sigma + Example: "out = (x + 1.0) / 2.0" for denorm + range (-1,1) to (0,1) + for use with proper act in Generator output (ie. tanh) + """ + out = (x - min_max[0]) / (min_max[1] - min_max[0]) + if isinstance(x, Tensor): + return out.clamp(0, 1) + elif isinstance(x, np.ndarray): + return np.clip(out, 0, 1) + else: + raise TypeError("Got unexpected object type, expected Tensor or np.ndarray") + + +def norm(x): + """Normalize (z-norm) from [0,1] range to [-1,1]""" + out = (x - 0.5) * 2.0 + if isinstance(x, Tensor): + return out.clamp(-1, 1) + elif isinstance(x, np.ndarray): + return np.clip(out, -1, 1) + else: + raise TypeError("Got unexpected object type, expected Tensor or np.ndarray") + + +def np2tensor( + img: np.ndarray, + bgr2rgb=True, + data_range=1.0, + normalize=False, + change_range=True, + add_batch=True, +) -> Tensor: + """Converts a numpy image array into a Tensor array. + Parameters: + img (numpy array): the input image numpy array + add_batch (bool): choose if new tensor needs batch dimension added + """ + if not isinstance(img, np.ndarray): # images expected to be uint8 -> 255 + raise TypeError("Got unexpected object type, expected np.ndarray") + # check how many channels the image has, then condition. ie. RGB, RGBA, Gray + # if bgr2rgb: + # img = img[ + # :, :, [2, 1, 0] + # ] # BGR to RGB -> in numpy, if using OpenCV, else not needed. Only if image has colors. + if change_range: + dtype = img.dtype + maxval = MAX_VALUES_BY_DTYPE.get(dtype, 1.0) + t_dtype = np.dtype("float32") + img = img.astype(t_dtype) / maxval # ie: uint8 = /255 + img = from_numpy( + np.ascontiguousarray(np.transpose(img, (2, 0, 1))) + ).float() # "HWC to CHW" and "numpy to tensor" + if bgr2rgb: + # BGR to RGB -> in tensor, if using OpenCV, else not needed. Only if image has colors.) + if ( + img.shape[0] % 3 == 0 + ): # RGB or MultixRGB (3xRGB, 5xRGB, etc. For video tensors.) + img = bgr_to_rgb(img) + elif img.shape[0] == 4: # RGBA + img = bgra_to_rgba(img) + if add_batch: + img.unsqueeze_( + 0 + ) # Add fake batch dimension = 1 . squeeze() will remove the dimensions of size 1 + if normalize: + img = norm(img) + return img + + +def tensor2np( + img: Tensor, + rgb2bgr=True, + remove_batch=True, + data_range=255, + denormalize=False, + change_range=True, + imtype=np.uint8, +) -> np.ndarray: + """Converts a Tensor array into a numpy image array. + Parameters: + img (tensor): the input image tensor array + 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + remove_batch (bool): choose if tensor of shape BCHW needs to be squeezed + denormalize (bool): Used to denormalize from [-1,1] range back to [0,1] + imtype (type): the desired type of the converted numpy array (np.uint8 + default) + Output: + img (np array): 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + """ + if not isinstance(img, Tensor): + raise TypeError("Got unexpected object type, expected Tensor") + n_dim = img.dim() + + # TODO: Check: could denormalize here in tensor form instead, but end result is the same + + img = img.float().cpu() + + if n_dim in (4, 3): + # if n_dim == 4, has to convert to 3 dimensions + if n_dim == 4 and remove_batch: + # remove a fake batch dimension + img = img.squeeze(dim=0) + + if img.shape[0] == 3 and rgb2bgr: # RGB + # RGB to BGR -> in tensor, if using OpenCV, else not needed. Only if image has colors. + img_np = rgb_to_bgr(img).numpy() + elif img.shape[0] == 4 and rgb2bgr: # RGBA + # RGBA to BGRA -> in tensor, if using OpenCV, else not needed. Only if image has colors. + img_np = rgba_to_bgra(img).numpy() + else: + img_np = img.numpy() + img_np = np.transpose(img_np, (1, 2, 0)) # CHW to HWC + elif n_dim == 2: + img_np = img.numpy() + else: + raise TypeError( + f"Only support 4D, 3D and 2D tensor. But received with dimension: {n_dim:d}" + ) + + # if rgb2bgr: + # img_np = img_np[[2, 1, 0], :, :] #RGB to BGR -> in numpy, if using OpenCV, else not needed. Only if image has colors. + # TODO: Check: could denormalize in the begining in tensor form instead + if denormalize: + img_np = denorm(img_np) # denormalize if needed + if change_range: + img_np = np.clip( + data_range * img_np, 0, data_range + ).round() # np.clip to the data_range + + # has to be in range (0,255) before changing to np.uint8, else np.float32 + return img_np.astype(imtype) + + +def auto_split_process( + lr_img: Tensor, + model, + scale: int = 4, + overlap: int = 32, + max_depth: int = None, + current_depth: int = 1, +) -> Tuple[Tensor, int]: + # Original code: https://github.com/JoeyBallentine/ESRGAN/blob/master/utils/dataops.py + + # Prevent splitting from causing an infinite out-of-vram loop + if current_depth > 15: + cuda.empty_cache() + gc.collect() + raise RuntimeError("Splitting stopped to prevent infinite loop") + + # Attempt to upscale if unknown depth or if reached known max depth + if max_depth is None or max_depth == current_depth: + try: + result = model(lr_img) + return result, current_depth + except RuntimeError as e: + print(e) + # Check to see if its actually the CUDA out of memory error + if "allocate" in str(e): + # Collect garbage (clear VRAM) + cuda.empty_cache() + gc.collect() + # Re-raise the exception if not an OOM error + else: + raise RuntimeError(e) + + b, c, h, w = lr_img.shape + + # Split image into 4ths + top_left = lr_img[..., : h // 2 + overlap, : w // 2 + overlap] + top_right = lr_img[..., : h // 2 + overlap, w // 2 - overlap :] + bottom_left = lr_img[..., h // 2 - overlap :, : w // 2 + overlap] + bottom_right = lr_img[..., h // 2 - overlap :, w // 2 - overlap :] + + # Recursively upscale the quadrants + # After we go through the top left quadrant, we know the maximum depth and no longer need to test for out-of-memory + top_left_rlt, depth = auto_split_process( + top_left, + model, + scale=scale, + overlap=overlap, + current_depth=current_depth + 1, + ) + top_right_rlt, _ = auto_split_process( + top_right, + model, + scale=scale, + overlap=overlap, + max_depth=depth, + current_depth=current_depth + 1, + ) + bottom_left_rlt, _ = auto_split_process( + bottom_left, + model, + scale=scale, + overlap=overlap, + max_depth=depth, + current_depth=current_depth + 1, + ) + bottom_right_rlt, _ = auto_split_process( + bottom_right, + model, + scale=scale, + overlap=overlap, + max_depth=depth, + current_depth=current_depth + 1, + ) + + # Define output shape + out_h = h * scale + out_w = w * scale + + # Create blank output image + output_img = empty((b, c, out_h, out_w), dtype=lr_img.dtype, device=lr_img.device) + + # Fill output image with tiles, cropping out the overlaps + output_img[..., : out_h // 2, : out_w // 2] = top_left_rlt[ + ..., : out_h // 2, : out_w // 2 + ] + output_img[..., : out_h // 2, -out_w // 2 :] = top_right_rlt[ + ..., : out_h // 2, -out_w // 2 : + ] + output_img[..., -out_h // 2 :, : out_w // 2] = bottom_left_rlt[ + ..., -out_h // 2 :, : out_w // 2 + ] + output_img[..., -out_h // 2 :, -out_w // 2 :] = bottom_right_rlt[ + ..., -out_h // 2 :, -out_w // 2 : + ] + + return output_img, depth diff --git a/backend/package.json b/backend/package.json deleted file mode 100644 index 5becb77f5..000000000 --- a/backend/package.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "name": "chainner-backend", - "scripts": { - "start": "python run.py", - "build": "bash build.sh" - } -} \ No newline at end of file diff --git a/backend/process.py b/backend/process.py index ce6fd1b57..74cfb29cb 100644 --- a/backend/process.py +++ b/backend/process.py @@ -3,10 +3,9 @@ import uuid from typing import Dict, List -from sanic import app -from sanic.log import logger - from nodes.node_factory import NodeFactory +from sanic_server.sanic import app +from sanic_server.sanic.log import logger class Executor: diff --git a/backend/run.py b/backend/run.py index b1645e1f4..38d74c0cb 100644 --- a/backend/run.py +++ b/backend/run.py @@ -2,38 +2,41 @@ import os import sys -from sanic import Sanic -from sanic.log import logger -from sanic.response import json -from sanic_cors import CORS +from sanic_server.sanic import Sanic +from sanic_server.sanic.log import logger +from sanic_server.sanic.response import json +from sanic_server.sanic_cors import CORS try: import cv2 from nodes import opencv_nodes -except: - logger.info("OpenCV not installed") +except Exception as e: + logger.warning(e) + logger.info("OpenCV most likely not installed") try: import numpy from nodes import numpy_nodes -except: - logger.info("NumPy not installed") +except Exception as e: + logger.warning(e) + logger.info("NumPy most likely not installed") try: import torch from nodes import pytorch_nodes -except: - logger.info("PyTorch not installed") +except Exception as e: + logger.warning(e) + logger.info("PyTorch most likely not installed") from nodes.node_factory import NodeFactory from process import Executor app = Sanic("chaiNNer") CORS(app) -app.executor = None +app.ctx.executor = None @app.route("/nodes") @@ -59,9 +62,9 @@ async def nodes(_): async def run(request): """Runs the provided nodes""" try: - if request.app.executor: + if request.app.ctx.executor: logger.info("Resuming existing executor...") - executor = request.app.executor + executor = request.app.ctx.executor await executor.run() else: logger.info("Running new executor...") @@ -73,14 +76,14 @@ async def run(request): os.environ["resolutionX"] = str(full_data["resolutionX"]) os.environ["resolutionY"] = str(full_data["resolutionY"]) executor = Executor(nodes_list, app.loop) - request.app.executor = executor + request.app.ctx.executor = executor await executor.run() if not executor.paused: - request.app.executor = None + request.app.ctx.executor = None return json({"message": "Successfully ran nodes!"}, status=200) except Exception as exception: logger.log(2, exception, exc_info=1) - request.app.executor = None + request.app.ctx.executor = None return json( {"message": "Error running nodes!", "exception": str(exception)}, status=500 ) @@ -90,7 +93,7 @@ async def run(request): async def check(request): """Check the execution status""" try: - executor = request.app.executor + executor = request.app.ctx.executor if executor: response = await executor.check() return json(response, status=200) @@ -98,7 +101,7 @@ async def check(request): return json({"message": "No executor to check!"}, status=400) except Exception as exception: logger.log(2, exception, exc_info=1) - request.app.executor = None + request.app.ctx.executor = None return json( {"message": "Error checking nodes!", "exception": str(exception)}, status=500, @@ -109,9 +112,9 @@ async def check(request): async def kill(request): """Pauses the current execution""" try: - if request.app.executor: + if request.app.ctx.executor: logger.info("Executor found. Attempting to pause...") - await request.app.executor.pause() + await request.app.ctx.executor.pause() return json({"message": "Successfully paused execution!"}, status=200) logger.info("No executor to pause") return json({"message": "No executor to pause!"}, status=200) @@ -127,10 +130,10 @@ async def kill(request): async def kill(request): """Kills the current execution""" try: - if request.app.executor: + if request.app.ctx.executor: logger.info("Executor found. Attempting to kill...") - await request.app.executor.kill() - request.app.executor = None + await request.app.ctx.executor.kill() + request.app.ctx.executor = None return json({"message": "Successfully killed execution!"}, status=200) logger.info("No executor to kill") return json({"message": "No executor to kill!"}, status=200) @@ -143,5 +146,8 @@ async def kill(request): if __name__ == "__main__": - port = sys.argv[1] or 8000 + try: + port = sys.argv[1] or 8000 + except: + port = 8000 app.run(port=port) diff --git a/backend/sanic_server/sanic/LICENSE b/backend/sanic_server/sanic/LICENSE new file mode 100644 index 000000000..35740e3da --- /dev/null +++ b/backend/sanic_server/sanic/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2016-present Sanic Community + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/backend/sanic_server/sanic/__init__.py b/backend/sanic_server/sanic/__init__.py new file mode 100644 index 000000000..8948b64a1 --- /dev/null +++ b/backend/sanic_server/sanic/__init__.py @@ -0,0 +1,18 @@ +from ..sanic.__version__ import __version__ +from ..sanic.app import Sanic +from ..sanic.blueprints import Blueprint +from ..sanic.constants import HTTPMethod +from ..sanic.request import Request +from ..sanic.response import HTTPResponse, html, json, text + +__all__ = ( + "__version__", + "Sanic", + "Blueprint", + "HTTPMethod", + "HTTPResponse", + "Request", + "html", + "json", + "text", +) diff --git a/backend/sanic_server/sanic/__main__.py b/backend/sanic_server/sanic/__main__.py new file mode 100644 index 000000000..a3874a9da --- /dev/null +++ b/backend/sanic_server/sanic/__main__.py @@ -0,0 +1,195 @@ +import os +import sys +from argparse import ArgumentParser, RawTextHelpFormatter +from importlib import import_module +from pathlib import Path +from typing import Any, Dict, Optional + +from ..sanic import __version__ +from ..sanic.app import Sanic +from ..sanic.config import BASE_LOGO +from ..sanic.log import error_logger +from ..sanic.simple import create_simple_server +from ..sanic_routing import __version__ as __routing_version__ # type: ignore + + +class SanicArgumentParser(ArgumentParser): + def add_bool_arguments(self, *args, **kwargs): + group = self.add_mutually_exclusive_group() + group.add_argument(*args, action="store_true", **kwargs) + kwargs["help"] = f"no {kwargs['help']}\n " + group.add_argument( + "--no-" + args[0][2:], *args[1:], action="store_false", **kwargs + ) + + +def main(): + parser = SanicArgumentParser( + prog="sanic", + description=BASE_LOGO, + formatter_class=lambda prog: RawTextHelpFormatter( + prog, max_help_position=33 + ), + ) + parser.add_argument( + "-v", + "--version", + action="version", + version=f"Sanic {__version__}; Routing {__routing_version__}", + ) + parser.add_argument( + "--factory", + action="store_true", + help=( + "Treat app as an application factory, " + "i.e. a () -> callable" + ), + ) + parser.add_argument( + "-s", + "--simple", + dest="simple", + action="store_true", + help="Run Sanic as a Simple Server (module arg should be a path)\n ", + ) + parser.add_argument( + "-H", + "--host", + dest="host", + type=str, + default="127.0.0.1", + help="Host address [default 127.0.0.1]", + ) + parser.add_argument( + "-p", + "--port", + dest="port", + type=int, + default=8000, + help="Port to serve on [default 8000]", + ) + parser.add_argument( + "-u", + "--unix", + dest="unix", + type=str, + default="", + help="location of unix socket\n ", + ) + parser.add_argument( + "--cert", dest="cert", type=str, help="Location of certificate for SSL" + ) + parser.add_argument( + "--key", dest="key", type=str, help="location of keyfile for SSL\n " + ) + parser.add_bool_arguments( + "--access-logs", dest="access_log", help="display access logs" + ) + parser.add_argument( + "-w", + "--workers", + dest="workers", + type=int, + default=1, + help="number of worker processes [default 1]\n ", + ) + parser.add_argument("-d", "--debug", dest="debug", action="store_true") + parser.add_argument( + "-r", + "--reload", + "--auto-reload", + dest="auto_reload", + action="store_true", + help="Watch source directory for file changes and reload on changes", + ) + parser.add_argument( + "-R", + "--reload-dir", + dest="path", + action="append", + help="Extra directories to watch and reload on changes\n ", + ) + parser.add_argument( + "module", + help=( + "Path to your Sanic app. Example: path.to.server:app\n" + "If running a Simple Server, path to directory to serve. " + "Example: ./\n" + ), + ) + args = parser.parse_args() + + try: + module_path = os.path.abspath(os.getcwd()) + if module_path not in sys.path: + sys.path.append(module_path) + + if args.simple: + path = Path(args.module) + app = create_simple_server(path) + else: + delimiter = ":" if ":" in args.module else "." + module_name, app_name = args.module.rsplit(delimiter, 1) + + if app_name.endswith("()"): + args.factory = True + app_name = app_name[:-2] + + module = import_module(module_name) + app = getattr(module, app_name, None) + if args.factory: + app = app() + + app_type_name = type(app).__name__ + + if not isinstance(app, Sanic): + raise ValueError( + f"Module is not a Sanic app, it is a {app_type_name}. " + f"Perhaps you meant {args.module}.app?" + ) + if args.cert is not None or args.key is not None: + ssl: Optional[Dict[str, Any]] = { + "cert": args.cert, + "key": args.key, + } + else: + ssl = None + + kwargs = { + "host": args.host, + "port": args.port, + "unix": args.unix, + "workers": args.workers, + "debug": args.debug, + "access_log": args.access_log, + "ssl": ssl, + } + if args.auto_reload: + kwargs["auto_reload"] = True + + if args.path: + if args.auto_reload or args.debug: + kwargs["reload_dir"] = args.path + else: + error_logger.warning( + "Ignoring '--reload-dir' since auto reloading was not " + "enabled. If you would like to watch directories for " + "changes, consider using --debug or --auto-reload." + ) + + app.run(**kwargs) + except ImportError as e: + if module_name.startswith(e.name): + error_logger.error( + f"No module named {e.name} found.\n" + " Example File: project/sanic_server.py -> app\n" + " Example Module: project.sanic_server.app" + ) + else: + raise e + except ValueError: + error_logger.exception("Failed to run app") + + +if __name__ == "__main__": + main() diff --git a/backend/sanic_server/sanic/__version__.py b/backend/sanic_server/sanic/__version__.py new file mode 100644 index 000000000..ec8701ae2 --- /dev/null +++ b/backend/sanic_server/sanic/__version__.py @@ -0,0 +1 @@ +__version__ = "21.9.3" diff --git a/backend/sanic_server/sanic/app.py b/backend/sanic_server/sanic/app.py new file mode 100644 index 000000000..54ef09661 --- /dev/null +++ b/backend/sanic_server/sanic/app.py @@ -0,0 +1,1476 @@ +from __future__ import annotations + +import logging +import logging.config +import os +import re +from asyncio import (AbstractEventLoop, CancelledError, Protocol, + ensure_future, get_event_loop, wait_for) +from asyncio.futures import Future +from collections import defaultdict, deque +from functools import partial +from inspect import isawaitable +from pathlib import Path +from socket import socket +from ssl import Purpose, SSLContext, create_default_context +from traceback import format_exc +from types import SimpleNamespace +from typing import (Any, AnyStr, Awaitable, Callable, Coroutine, Deque, Dict, + Iterable, List, Optional, Set, Tuple, Type, Union) +from urllib.parse import urlencode, urlunparse + +from ..sanic import reloader_helpers +from ..sanic.asgi import ASGIApp +from ..sanic.base import BaseSanic +from ..sanic.blueprint_group import BlueprintGroup +from ..sanic.blueprints import Blueprint +from ..sanic.config import BASE_LOGO, SANIC_PREFIX, Config +from ..sanic.exceptions import (InvalidUsage, SanicException, ServerError, + URLBuildError) +from ..sanic.handlers import ErrorHandler +from ..sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger +from ..sanic.mixins.listeners import ListenerEvent +from ..sanic.models.futures import (FutureException, FutureListener, + FutureMiddleware, FutureRoute, + FutureSignal, FutureStatic) +from ..sanic.models.handler_types import ListenerType, MiddlewareType +from ..sanic.request import Request +from ..sanic.response import BaseHTTPResponse, HTTPResponse +from ..sanic.router import Router +from ..sanic.server import AsyncioServer, HttpProtocol +from ..sanic.server import Signal as ServerSignal +from ..sanic.server import serve, serve_multiple, serve_single +from ..sanic.server.protocols.websocket_protocol import WebSocketProtocol +from ..sanic.server.websockets.impl import ConnectionClosed +from ..sanic.signals import Signal, SignalRouter +from ..sanic.touchup import TouchUp, TouchUpMeta +from ..sanic_routing.exceptions import FinalizationError # type: ignore +from ..sanic_routing.exceptions import NotFound # type: ignore +from ..sanic_routing.route import Route # type: ignore + + +class Sanic(BaseSanic, metaclass=TouchUpMeta): + """ + The main application instance + """ + + __touchup__ = ( + "handle_request", + "handle_exception", + "_run_response_middleware", + "_run_request_middleware", + ) + __fake_slots__ = ( + "_asgi_app", + "_app_registry", + "_asgi_client", + "_blueprint_order", + "_delayed_tasks", + "_future_routes", + "_future_statics", + "_future_middleware", + "_future_listeners", + "_future_exceptions", + "_future_signals", + "_test_client", + "_test_manager", + "auto_reload", + "asgi", + "blueprints", + "config", + "configure_logging", + "ctx", + "debug", + "error_handler", + "go_fast", + "is_running", + "is_stopping", + "listeners", + "name", + "named_request_middleware", + "named_response_middleware", + "reload_dirs", + "request_class", + "request_middleware", + "response_middleware", + "router", + "signal_router", + "sock", + "strict_slashes", + "test_mode", + "websocket_enabled", + "websocket_tasks", + ) + + _app_registry: Dict[str, "Sanic"] = {} + test_mode = False + + def __init__( + self, + name: str = None, + config: Optional[Config] = None, + ctx: Optional[Any] = None, + router: Optional[Router] = None, + signal_router: Optional[SignalRouter] = None, + error_handler: Optional[ErrorHandler] = None, + load_env: Union[bool, str] = True, + env_prefix: Optional[str] = SANIC_PREFIX, + request_class: Optional[Type[Request]] = None, + strict_slashes: bool = False, + log_config: Optional[Dict[str, Any]] = None, + configure_logging: bool = True, + register: Optional[bool] = None, + dumps: Optional[Callable[..., AnyStr]] = None, + ) -> None: + super().__init__(name=name) + + # logging + if configure_logging: + logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS) + + if config and (load_env is not True or env_prefix != SANIC_PREFIX): + raise SanicException( + "When instantiating Sanic with config, you cannot also pass " + "load_env or env_prefix" + ) + + self._asgi_client = None + self._blueprint_order: List[Blueprint] = [] + self._delayed_tasks: List[str] = [] + self._test_client = None + self._test_manager = None + self.asgi = False + self.auto_reload = False + self.blueprints: Dict[str, Blueprint] = {} + self.config: Config = config or Config( + load_env=load_env, + env_prefix=env_prefix, + app=self, + ) + self.configure_logging: bool = configure_logging + self.ctx: Any = ctx or SimpleNamespace() + self.debug = None + self.error_handler: ErrorHandler = error_handler or ErrorHandler() + self.is_running = False + self.is_stopping = False + self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) + self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} + self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} + self.reload_dirs: Set[Path] = set() + self.request_class = request_class + self.request_middleware: Deque[MiddlewareType] = deque() + self.response_middleware: Deque[MiddlewareType] = deque() + self.router = router or Router() + self.signal_router = signal_router or SignalRouter() + self.sock = None + self.strict_slashes = strict_slashes + self.websocket_enabled = False + self.websocket_tasks: Set[Future] = set() + + # Register alternative method names + self.go_fast = self.run + + if register is not None: + self.config.REGISTER = register + if self.config.REGISTER: + self.__class__.register_app(self) + + self.router.ctx.app = self + self.signal_router.ctx.app = self + + if dumps: + BaseHTTPResponse._dumps = dumps # type: ignore + + @property + def loop(self): + """ + Synonymous with asyncio.get_event_loop(). + + .. note:: + + Only supported when using the `app.run` method. + """ + if not self.is_running and self.asgi is False: + raise SanicException( + "Loop can only be retrieved after the app has started " + "running. Not supported with `create_server` function" + ) + return get_event_loop() + + # -------------------------------------------------------------------- # + # Registration + # -------------------------------------------------------------------- # + + def add_task(self, task) -> None: + """ + Schedule a task to run later, after the loop has started. + Different from asyncio.ensure_future in that it does not + also return a future, and the actual ensure_future call + is delayed until before server start. + + `See user guide re: background tasks + `__ + + :param task: future, couroutine or awaitable + """ + try: + loop = self.loop # Will raise SanicError if loop is not started + self._loop_add_task(task, self, loop) + except SanicException: + task_name = f"sanic.delayed_task.{hash(task)}" + if not self._delayed_tasks: + self.after_server_start(partial(self.dispatch_delayed_tasks)) + + self.signal(task_name)(partial(self.run_delayed_task, task=task)) + self._delayed_tasks.append(task_name) + + def register_listener(self, listener: Callable, event: str) -> Any: + """ + Register the listener for a given event. + + :param listener: callable i.e. setup_db(app, loop) + :param event: when to register listener i.e. 'before_server_start' + :return: listener + """ + + try: + _event = ListenerEvent[event.upper()] + except (ValueError, AttributeError): + valid = ", ".join( + map(lambda x: x.lower(), ListenerEvent.__members__.keys()) + ) + raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}") + + if "." in _event: + self.signal(_event.value)( + partial(self._listener, listener=listener) + ) + else: + self.listeners[_event.value].append(listener) + + return listener + + def register_middleware(self, middleware, attach_to: str = "request"): + """ + Register an application level middleware that will be attached + to all the API URLs registered under this application. + + This method is internally invoked by the :func:`middleware` + decorator provided at the app level. + + :param middleware: Callback method to be attached to the + middleware + :param attach_to: The state at which the middleware needs to be + invoked in the lifecycle of an *HTTP Request*. + **request** - Invoke before the request is processed + **response** - Invoke before the response is returned back + :return: decorated method + """ + if attach_to == "request": + if middleware not in self.request_middleware: + self.request_middleware.append(middleware) + if attach_to == "response": + if middleware not in self.response_middleware: + self.response_middleware.appendleft(middleware) + return middleware + + def register_named_middleware( + self, + middleware, + route_names: Iterable[str], + attach_to: str = "request", + ): + """ + Method for attaching middleware to specific routes. This is mainly an + internal tool for use by Blueprints to attach middleware to only its + specfic routes. But, it could be used in a more generalized fashion. + + :param middleware: the middleware to execute + :param route_names: a list of the names of the endpoints + :type route_names: Iterable[str] + :param attach_to: whether to attach to request or response, + defaults to "request" + :type attach_to: str, optional + """ + if attach_to == "request": + for _rn in route_names: + if _rn not in self.named_request_middleware: + self.named_request_middleware[_rn] = deque() + if middleware not in self.named_request_middleware[_rn]: + self.named_request_middleware[_rn].append(middleware) + if attach_to == "response": + for _rn in route_names: + if _rn not in self.named_response_middleware: + self.named_response_middleware[_rn] = deque() + if middleware not in self.named_response_middleware[_rn]: + self.named_response_middleware[_rn].appendleft(middleware) + return middleware + + def _apply_exception_handler( + self, + handler: FutureException, + route_names: Optional[List[str]] = None, + ): + """Decorate a function to be registered as a handler for exceptions + + :param exceptions: exceptions + :return: decorated function + """ + + for exception in handler.exceptions: + if isinstance(exception, (tuple, list)): + for e in exception: + self.error_handler.add(e, handler.handler, route_names) + else: + self.error_handler.add(exception, handler.handler, route_names) + return handler.handler + + def _apply_listener(self, listener: FutureListener): + return self.register_listener(listener.listener, listener.event) + + def _apply_route(self, route: FutureRoute) -> List[Route]: + params = route._asdict() + websocket = params.pop("websocket", False) + subprotocols = params.pop("subprotocols", None) + + if websocket: + self.enable_websocket() + websocket_handler = partial( + self._websocket_handler, + route.handler, + subprotocols=subprotocols, + ) + websocket_handler.__name__ = route.handler.__name__ # type: ignore + websocket_handler.is_websocket = True # type: ignore + params["handler"] = websocket_handler + + routes = self.router.add(**params) + if isinstance(routes, Route): + routes = [routes] + for r in routes: + r.ctx.websocket = websocket + r.ctx.static = params.get("static", False) + + return routes + + def _apply_static(self, static: FutureStatic) -> Route: + return self._register_static(static) + + def _apply_middleware( + self, + middleware: FutureMiddleware, + route_names: Optional[List[str]] = None, + ): + if route_names: + return self.register_named_middleware( + middleware.middleware, route_names, middleware.attach_to + ) + else: + return self.register_middleware( + middleware.middleware, middleware.attach_to + ) + + def _apply_signal(self, signal: FutureSignal) -> Signal: + return self.signal_router.add(*signal) + + def dispatch( + self, + event: str, + *, + condition: Optional[Dict[str, str]] = None, + context: Optional[Dict[str, Any]] = None, + fail_not_found: bool = True, + inline: bool = False, + reverse: bool = False, + ) -> Coroutine[Any, Any, Awaitable[Any]]: + return self.signal_router.dispatch( + event, + context=context, + condition=condition, + inline=inline, + reverse=reverse, + fail_not_found=fail_not_found, + ) + + async def event( + self, event: str, timeout: Optional[Union[int, float]] = None + ): + signal = self.signal_router.name_index.get(event) + if not signal: + if self.config.EVENT_AUTOREGISTER: + self.signal_router.reset() + self.add_signal(None, event) + signal = self.signal_router.name_index[event] + self.signal_router.finalize() + else: + raise NotFound("Could not find signal %s" % event) + return await wait_for(signal.ctx.event.wait(), timeout=timeout) + + def enable_websocket(self, enable=True): + """Enable or disable the support for websocket. + + Websocket is enabled automatically if websocket routes are + added to the application. + """ + if not self.websocket_enabled: + # if the server is stopped, we want to cancel any ongoing + # websocket tasks, to allow the server to exit promptly + self.listener("before_server_stop")(self._cancel_websocket_tasks) + + self.websocket_enabled = enable + + def blueprint( + self, + blueprint: Union[ + Blueprint, List[Blueprint], Tuple[Blueprint], BlueprintGroup + ], + **options: Any, + ): + """Register a blueprint on the application. + + :param blueprint: Blueprint object or (list, tuple) thereof + :param options: option dictionary with blueprint defaults + :return: Nothing + """ + if isinstance(blueprint, (list, tuple, BlueprintGroup)): + for item in blueprint: + params = {**options} + if isinstance(blueprint, BlueprintGroup): + if blueprint.url_prefix: + merge_from = [ + options.get("url_prefix", ""), + blueprint.url_prefix, + ] + if not isinstance(item, BlueprintGroup): + merge_from.append(item.url_prefix or "") + merged_prefix = "/".join( + u.strip("/") for u in merge_from + ).rstrip("/") + params["url_prefix"] = f"/{merged_prefix}" + + for _attr in ["version", "strict_slashes"]: + if getattr(item, _attr) is None: + params[_attr] = getattr( + blueprint, _attr + ) or options.get(_attr) + if item.version_prefix == "/v": + if blueprint.version_prefix == "/v": + params["version_prefix"] = options.get( + "version_prefix" + ) + else: + params["version_prefix"] = blueprint.version_prefix + self.blueprint(item, **params) + return + if blueprint.name in self.blueprints: + assert self.blueprints[blueprint.name] is blueprint, ( + 'A blueprint with the name "%s" is already registered. ' + "Blueprint names must be unique." % (blueprint.name,) + ) + else: + self.blueprints[blueprint.name] = blueprint + self._blueprint_order.append(blueprint) + + if ( + self.strict_slashes is not None + and blueprint.strict_slashes is None + ): + blueprint.strict_slashes = self.strict_slashes + blueprint.register(self, options) + + def url_for(self, view_name: str, **kwargs): + """Build a URL based on a view name and the values provided. + + In order to build a URL, all request parameters must be supplied as + keyword arguments, and each parameter must pass the test for the + specified parameter type. If these conditions are not met, a + `URLBuildError` will be thrown. + + Keyword arguments that are not request parameters will be included in + the output URL's query string. + + There are several _special_ keyword arguments that will alter how the + URL will be returned: + + 1. **_anchor**: ``str`` - Adds an ``#anchor`` to the end + 2. **_scheme**: ``str`` - Should be either ``"http"`` or ``"https"``, + default is ``"http"`` + 3. **_external**: ``bool`` - Whether to return the path or a full URL + with scheme and host + 4. **_host**: ``str`` - Used when one or more hosts are defined for a + route to tell Sanic which to use + (only applies with ``_external=True``) + 5. **_server**: ``str`` - If not using ``_host``, this will be used + for defining the hostname of the URL + (only applies with ``_external=True``), + defaults to ``app.config.SERVER_NAME`` + + If you want the PORT to appear in your URL, you should set it in: + + .. code-block:: + + app.config.SERVER_NAME = "myserver:7777" + + `See user guide re: routing + `__ + + :param view_name: string referencing the view name + :param kwargs: keys and values that are used to build request + parameters and query string arguments. + + :return: the built URL + + Raises: + URLBuildError + """ + # find the route by the supplied view name + kw: Dict[str, str] = {} + # special static files url_for + + if "." not in view_name: + view_name = f"{self.name}.{view_name}" + + if view_name.endswith(".static"): + name = kwargs.pop("name", None) + if name: + view_name = view_name.replace("static", name) + kw.update(name=view_name) + + route = self.router.find_route_by_view_name(view_name, **kw) + if not route: + raise URLBuildError( + f"Endpoint with name `{view_name}` was not found" + ) + + uri = route.path + + if getattr(route.ctx, "static", None): + filename = kwargs.pop("filename", "") + # it's static folder + if "__file_uri__" in uri: + folder_ = uri.split("<__file_uri__:", 1)[0] + if folder_.endswith("/"): + folder_ = folder_[:-1] + + if filename.startswith("/"): + filename = filename[1:] + + kwargs["__file_uri__"] = filename + + if ( + uri != "/" + and uri.endswith("/") + and not route.strict + and not route.raw_path[:-1] + ): + uri = uri[:-1] + + if not uri.startswith("/"): + uri = f"/{uri}" + + out = uri + + # _method is only a placeholder now, don't know how to support it + kwargs.pop("_method", None) + anchor = kwargs.pop("_anchor", "") + # _external need SERVER_NAME in config or pass _server arg + host = kwargs.pop("_host", None) + external = kwargs.pop("_external", False) or bool(host) + scheme = kwargs.pop("_scheme", "") + if route.ctx.hosts and external: + if not host and len(route.ctx.hosts) > 1: + raise ValueError( + f"Host is ambiguous: {', '.join(route.ctx.hosts)}" + ) + elif host and host not in route.ctx.hosts: + raise ValueError( + f"Requested host ({host}) is not available for this " + f"route: {route.ctx.hosts}" + ) + elif not host: + host = list(route.ctx.hosts)[0] + + if scheme and not external: + raise ValueError("When specifying _scheme, _external must be True") + + netloc = kwargs.pop("_server", None) + if netloc is None and external: + netloc = host or self.config.get("SERVER_NAME", "") + + if external: + if not scheme: + if ":" in netloc[:8]: + scheme = netloc[:8].split(":", 1)[0] + else: + scheme = "http" + + if "://" in netloc[:8]: + netloc = netloc.split("://", 1)[-1] + + # find all the parameters we will need to build in the URL + # matched_params = re.findall(self.router.parameter_pattern, uri) + route.finalize() + for param_info in route.params.values(): + # name, _type, pattern = self.router.parse_parameter_string(match) + # we only want to match against each individual parameter + + try: + supplied_param = str(kwargs.pop(param_info.name)) + except KeyError: + raise URLBuildError( + f"Required parameter `{param_info.name}` was not " + "passed to url_for" + ) + + # determine if the parameter supplied by the caller + # passes the test in the URL + if param_info.pattern: + pattern = ( + param_info.pattern[1] + if isinstance(param_info.pattern, tuple) + else param_info.pattern + ) + passes_pattern = pattern.match(supplied_param) + if not passes_pattern: + if param_info.cast != str: + msg = ( + f'Value "{supplied_param}" ' + f"for parameter `{param_info.name}` does " + "not match pattern for type " + f"`{param_info.cast.__name__}`: " + f"{pattern.pattern}" + ) + else: + msg = ( + f'Value "{supplied_param}" for parameter ' + f"`{param_info.name}` does not satisfy " + f"pattern {pattern.pattern}" + ) + raise URLBuildError(msg) + + # replace the parameter in the URL with the supplied value + replacement_regex = f"(<{param_info.name}.*?>)" + out = re.sub(replacement_regex, supplied_param, out) + + # parse the remainder of the keyword arguments into a querystring + query_string = urlencode(kwargs, doseq=True) if kwargs else "" + # scheme://netloc/path;parameters?query#fragment + out = urlunparse((scheme, netloc, out, "", query_string, anchor)) + + return out + + # -------------------------------------------------------------------- # + # Request Handling + # -------------------------------------------------------------------- # + + async def handle_exception( + self, request: Request, exception: BaseException + ): # no cov + """ + A handler that catches specific exceptions and outputs a response. + + :param request: The current request object + :type request: :class:`SanicASGITestClient` + :param exception: The exception that was raised + :type exception: BaseException + :raises ServerError: response 500 + """ + await self.dispatch( + "http.lifecycle.exception", + inline=True, + context={"request": request, "exception": exception}, + ) + + # -------------------------------------------- # + # Request Middleware + # -------------------------------------------- # + response = await self._run_request_middleware( + request, request_name=None + ) + # No middleware results + if not response: + try: + response = self.error_handler.response(request, exception) + if isawaitable(response): + response = await response + except Exception as e: + if isinstance(e, SanicException): + response = self.error_handler.default(request, e) + elif self.debug: + response = HTTPResponse( + ( + f"Error while handling error: {e}\n" + f"Stack: {format_exc()}" + ), + status=500, + ) + else: + response = HTTPResponse( + "An error occurred while handling an error", status=500 + ) + if response is not None: + try: + response = await request.respond(response) + except BaseException: + # Skip response middleware + if request.stream: + request.stream.respond(response) + await response.send(end_stream=True) + raise + else: + if request.stream: + response = request.stream.response + if isinstance(response, BaseHTTPResponse): + await response.send(end_stream=True) + else: + raise ServerError( + f"Invalid response type {response!r} (need HTTPResponse)" + ) + + async def handle_request(self, request: Request): # no cov + """Take a request from the HTTP Server and return a response object + to be sent back The HTTP Server only expects a response object, so + exception handling must be done here + + :param request: HTTP Request object + :return: Nothing + """ + await self.dispatch( + "http.lifecycle.handle", + inline=True, + context={"request": request}, + ) + + # Define `response` var here to remove warnings about + # allocation before assignment below. + response = None + try: + + await self.dispatch( + "http.routing.before", + inline=True, + context={"request": request}, + ) + # Fetch handler from router + route, handler, kwargs = self.router.get( + request.path, + request.method, + request.headers.getone("host", None), + ) + + request._match_info = {**kwargs} + request.route = route + + await self.dispatch( + "http.routing.after", + inline=True, + context={ + "request": request, + "route": route, + "kwargs": kwargs, + "handler": handler, + }, + ) + + if ( + request.stream + and request.stream.request_body + and not route.ctx.ignore_body + ): + + if hasattr(handler, "is_stream"): + # Streaming handler: lift the size limit + request.stream.request_max_size = float("inf") + else: + # Non-streaming handler: preload body + await request.receive_body() + + # -------------------------------------------- # + # Request Middleware + # -------------------------------------------- # + response = await self._run_request_middleware( + request, request_name=route.name + ) + + # No middleware results + if not response: + # -------------------------------------------- # + # Execute Handler + # -------------------------------------------- # + + if handler is None: + raise ServerError( + ( + "'None' was returned while requesting a " + "handler from the router" + ) + ) + + # Run response handler + response = handler(request, **request.match_info) + if isawaitable(response): + response = await response + + if response is not None: + response = await request.respond(response) + elif not hasattr(handler, "is_websocket"): + response = request.stream.response # type: ignore + + # Make sure that response is finished / run StreamingHTTP callback + if isinstance(response, BaseHTTPResponse): + await self.dispatch( + "http.lifecycle.response", + inline=True, + context={ + "request": request, + "response": response, + }, + ) + await response.send(end_stream=True) + else: + if not hasattr(handler, "is_websocket"): + raise ServerError( + f"Invalid response type {response!r} " + "(need HTTPResponse)" + ) + + except CancelledError: + raise + except Exception as e: + # Response Generation Failed + await self.handle_exception(request, e) + + async def _websocket_handler( + self, handler, request, *args, subprotocols=None, **kwargs + ): + if self.asgi: + ws = request.transport.get_websocket_connection() + await ws.accept(subprotocols) + else: + protocol = request.transport.get_protocol() + ws = await protocol.websocket_handshake(request, subprotocols) + + # schedule the application handler + # its future is kept in self.websocket_tasks in case it + # needs to be cancelled due to the server being stopped + fut = ensure_future(handler(request, ws, *args, **kwargs)) + self.websocket_tasks.add(fut) + cancelled = False + try: + await fut + except Exception as e: + self.error_handler.log(request, e) + except (CancelledError, ConnectionClosed): + cancelled = True + finally: + self.websocket_tasks.remove(fut) + if cancelled: + ws.end_connection(1000) + else: + await ws.close() + + # -------------------------------------------------------------------- # + # Testing + # -------------------------------------------------------------------- # + + @property + def test_client(self): # noqa + if self._test_client: + return self._test_client + elif self._test_manager: + return self._test_manager.test_client + from ..sanic_testing.testing import SanicTestClient # type: ignore + + self._test_client = SanicTestClient(self) + return self._test_client + + @property + def asgi_client(self): # noqa + """ + A testing client that uses ASGI to reach into the application to + execute hanlers. + + :return: testing client + :rtype: :class:`SanicASGITestClient` + """ + if self._asgi_client: + return self._asgi_client + elif self._test_manager: + return self._test_manager.asgi_client + from ..sanic_testing.testing import SanicASGITestClient # type: ignore + + self._asgi_client = SanicASGITestClient(self) + return self._asgi_client + + # -------------------------------------------------------------------- # + # Execution + # -------------------------------------------------------------------- # + + def run( + self, + host: Optional[str] = None, + port: Optional[int] = None, + *, + debug: bool = False, + auto_reload: Optional[bool] = None, + ssl: Union[Dict[str, str], SSLContext, None] = None, + sock: Optional[socket] = None, + workers: int = 1, + protocol: Optional[Type[Protocol]] = None, + backlog: int = 100, + register_sys_signals: bool = True, + access_log: Optional[bool] = None, + unix: Optional[str] = None, + loop: None = None, + reload_dir: Optional[Union[List[str], str]] = None, + ) -> None: + """ + Run the HTTP Server and listen until keyboard interrupt or term + signal. On termination, drain connections before closing. + + :param host: Address to host on + :type host: str + :param port: Port to host on + :type port: int + :param debug: Enables debug output (slows server) + :type debug: bool + :param auto_reload: Reload app whenever its source code is changed. + Enabled by default in debug mode. + :type auto_relaod: bool + :param ssl: SSLContext, or location of certificate and key + for SSL encryption of worker(s) + :type ssl: SSLContext or dict + :param sock: Socket for the server to accept connections from + :type sock: socket + :param workers: Number of processes received before it is respected + :type workers: int + :param protocol: Subclass of asyncio Protocol class + :type protocol: type[Protocol] + :param backlog: a number of unaccepted connections that the system + will allow before refusing new connections + :type backlog: int + :param register_sys_signals: Register SIG* events + :type register_sys_signals: bool + :param access_log: Enables writing access logs (slows server) + :type access_log: bool + :param unix: Unix socket to listen on instead of TCP port + :type unix: str + :return: Nothing + """ + if reload_dir: + if isinstance(reload_dir, str): + reload_dir = [reload_dir] + + for directory in reload_dir: + direc = Path(directory) + if not direc.is_dir(): + logger.warning( + f"Directory {directory} could not be located" + ) + self.reload_dirs.add(Path(directory)) + + if loop is not None: + raise TypeError( + "loop is not a valid argument. To use an existing loop, " + "change to create_server().\nSee more: " + "https://sanic.readthedocs.io/en/latest/sanic/deploying.html" + "#asynchronous-support" + ) + + if auto_reload or auto_reload is None and debug: + self.auto_reload = True + if os.environ.get("SANIC_SERVER_RUNNING") != "true": + return reloader_helpers.watchdog(1.0, self) + + if sock is None: + host, port = host or "127.0.0.1", port or 8000 + + if protocol is None: + protocol = ( + WebSocketProtocol if self.websocket_enabled else HttpProtocol + ) + # if access_log is passed explicitly change config.ACCESS_LOG + if access_log is not None: + self.config.ACCESS_LOG = access_log + + server_settings = self._helper( + host=host, + port=port, + debug=debug, + ssl=ssl, + sock=sock, + unix=unix, + workers=workers, + protocol=protocol, + backlog=backlog, + register_sys_signals=register_sys_signals, + auto_reload=auto_reload, + ) + + try: + self.is_running = True + self.is_stopping = False + if workers > 1 and os.name != "posix": + logger.warn( + f"Multiprocessing is currently not supported on {os.name}," + " using workers=1 instead" + ) + workers = 1 + if workers == 1: + serve_single(server_settings) + else: + serve_multiple(server_settings, workers) + except BaseException: + error_logger.exception( + "Experienced exception while trying to serve" + ) + raise + finally: + self.is_running = False + logger.info("Server Stopped") + + def stop(self): + """ + This kills the Sanic + """ + if not self.is_stopping: + self.is_stopping = True + get_event_loop().stop() + + async def create_server( + self, + host: Optional[str] = None, + port: Optional[int] = None, + *, + debug: bool = False, + ssl: Union[Dict[str, str], SSLContext, None] = None, + sock: Optional[socket] = None, + protocol: Type[Protocol] = None, + backlog: int = 100, + access_log: Optional[bool] = None, + unix: Optional[str] = None, + return_asyncio_server: bool = False, + asyncio_server_kwargs: Dict[str, Any] = None, + ) -> Optional[AsyncioServer]: + """ + Asynchronous version of :func:`run`. + + This method will take care of the operations necessary to invoke + the *before_start* events via :func:`trigger_events` method invocation + before starting the *sanic* app in Async mode. + + .. note:: + This does not support multiprocessing and is not the preferred + way to run a :class:`Sanic` application. + + :param host: Address to host on + :type host: str + :param port: Port to host on + :type port: int + :param debug: Enables debug output (slows server) + :type debug: bool + :param ssl: SSLContext, or location of certificate and key + for SSL encryption of worker(s) + :type ssl: SSLContext or dict + :param sock: Socket for the server to accept connections from + :type sock: socket + :param protocol: Subclass of asyncio Protocol class + :type protocol: type[Protocol] + :param backlog: a number of unaccepted connections that the system + will allow before refusing new connections + :type backlog: int + :param access_log: Enables writing access logs (slows server) + :type access_log: bool + :param return_asyncio_server: flag that defines whether there's a need + to return asyncio.Server or + start it serving right away + :type return_asyncio_server: bool + :param asyncio_server_kwargs: key-value arguments for + asyncio/uvloop create_server method + :type asyncio_server_kwargs: dict + :return: AsyncioServer if return_asyncio_server is true, else Nothing + """ + + if sock is None: + host, port = host or "127.0.0.1", port or 8000 + + if protocol is None: + protocol = ( + WebSocketProtocol if self.websocket_enabled else HttpProtocol + ) + # if access_log is passed explicitly change config.ACCESS_LOG + if access_log is not None: + self.config.ACCESS_LOG = access_log + + server_settings = self._helper( + host=host, + port=port, + debug=debug, + ssl=ssl, + sock=sock, + unix=unix, + loop=get_event_loop(), + protocol=protocol, + backlog=backlog, + run_async=return_asyncio_server, + ) + + main_start = server_settings.pop("main_start", None) + main_stop = server_settings.pop("main_stop", None) + if main_start or main_stop: + logger.warning( + "Listener events for the main process are not available " + "with create_server()" + ) + + return await serve( + asyncio_server_kwargs=asyncio_server_kwargs, **server_settings + ) + + async def _run_request_middleware( + self, request, request_name=None + ): # no cov + # The if improves speed. I don't know why + named_middleware = self.named_request_middleware.get( + request_name, deque() + ) + applicable_middleware = self.request_middleware + named_middleware + + # request.request_middleware_started is meant as a stop-gap solution + # until RFC 1630 is adopted + if applicable_middleware and not request.request_middleware_started: + request.request_middleware_started = True + + for middleware in applicable_middleware: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) + + response = middleware(request) + if isawaitable(response): + response = await response + + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) + + if response: + return response + return None + + async def _run_response_middleware( + self, request, response, request_name=None + ): # no cov + named_middleware = self.named_response_middleware.get( + request_name, deque() + ) + applicable_middleware = self.response_middleware + named_middleware + if applicable_middleware: + for middleware in applicable_middleware: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": response, + }, + condition={"attach_to": "response"}, + ) + + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": _response if _response else response, + }, + condition={"attach_to": "response"}, + ) + + if _response: + response = _response + if isinstance(response, BaseHTTPResponse): + response = request.stream.respond(response) + break + return response + + def _helper( + self, + host=None, + port=None, + debug=False, + ssl=None, + sock=None, + unix=None, + workers=1, + loop=None, + protocol=HttpProtocol, + backlog=100, + register_sys_signals=True, + run_async=False, + auto_reload=False, + ): + """Helper function used by `run` and `create_server`.""" + + if isinstance(ssl, dict): + # try common aliaseses + cert = ssl.get("cert") or ssl.get("certificate") + key = ssl.get("key") or ssl.get("keyfile") + if cert is None or key is None: + raise ValueError("SSLContext or certificate and key required.") + context = create_default_context(purpose=Purpose.CLIENT_AUTH) + context.load_cert_chain(cert, keyfile=key) + ssl = context + if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0: + raise ValueError( + "PROXIES_COUNT cannot be negative. " + "https://sanic.readthedocs.io/en/latest/sanic/config.html" + "#proxy-configuration" + ) + + self.error_handler.debug = debug + self.debug = debug + + server_settings = { + "protocol": protocol, + "host": host, + "port": port, + "sock": sock, + "unix": unix, + "ssl": ssl, + "app": self, + "signal": ServerSignal(), + "loop": loop, + "register_sys_signals": register_sys_signals, + "backlog": backlog, + } + + # Register start/stop events + + for event_name, settings_name, reverse in ( + ("main_process_start", "main_start", False), + ("main_process_stop", "main_stop", True), + ): + listeners = self.listeners[event_name].copy() + if reverse: + listeners.reverse() + # Prepend sanic to the arguments when listeners are triggered + listeners = [partial(listener, self) for listener in listeners] + server_settings[settings_name] = listeners + + if self.configure_logging and debug: + logger.setLevel(logging.DEBUG) + + if ( + self.config.LOGO + and os.environ.get("SANIC_SERVER_RUNNING") != "true" + ): + logger.debug( + self.config.LOGO + if isinstance(self.config.LOGO, str) + else BASE_LOGO + ) + + if run_async: + server_settings["run_async"] = True + + # Serve + if host and port: + proto = "http" + if ssl is not None: + proto = "https" + if unix: + logger.info(f"Goin' Fast @ {unix} {proto}://...") + else: + logger.info(f"Goin' Fast @ {proto}://{host}:{port}") + + debug_mode = "enabled" if self.debug else "disabled" + reload_mode = "enabled" if auto_reload else "disabled" + logger.debug(f"Sanic auto-reload: {reload_mode}") + logger.debug(f"Sanic debug mode: {debug_mode}") + + return server_settings + + def _build_endpoint_name(self, *parts): + parts = [self.name, *parts] + return ".".join(parts) + + @classmethod + def _prep_task(cls, task, app, loop): + if callable(task): + try: + task = task(app) + except TypeError: + task = task() + + return task + + @classmethod + def _loop_add_task(cls, task, app, loop): + prepped = cls._prep_task(task, app, loop) + loop.create_task(prepped) + + @classmethod + def _cancel_websocket_tasks(cls, app, loop): + for task in app.websocket_tasks: + task.cancel() + + @staticmethod + async def dispatch_delayed_tasks(app, loop): + for name in app._delayed_tasks: + await app.dispatch(name, context={"app": app, "loop": loop}) + app._delayed_tasks.clear() + + @staticmethod + async def run_delayed_task(app, loop, task): + prepped = app._prep_task(task, app, loop) + await prepped + + @staticmethod + async def _listener( + app: Sanic, loop: AbstractEventLoop, listener: ListenerType + ): + maybe_coro = listener(app, loop) + if maybe_coro and isawaitable(maybe_coro): + await maybe_coro + + # -------------------------------------------------------------------- # + # ASGI + # -------------------------------------------------------------------- # + + async def __call__(self, scope, receive, send): + """ + To be ASGI compliant, our instance must be a callable that accepts + three arguments: scope, receive, send. See the ASGI reference for more + details: https://asgi.readthedocs.io/en/latest + """ + self.asgi = True + self._asgi_app = await ASGIApp.create(self, scope, receive, send) + asgi_app = self._asgi_app + await asgi_app() + + _asgi_single_callable = True # We conform to ASGI 3.0 single-callable + + # -------------------------------------------------------------------- # + # Configuration + # -------------------------------------------------------------------- # + + def update_config(self, config: Union[bytes, str, dict, Any]): + """ + Update app.config. Full implementation can be found in the user guide. + + `See user guide re: configuration + `__ + """ + + self.config.update_config(config) + + # -------------------------------------------------------------------- # + # Class methods + # -------------------------------------------------------------------- # + + @classmethod + def register_app(cls, app: "Sanic") -> None: + """ + Register a Sanic instance + """ + if not isinstance(app, cls): + raise SanicException("Registered app must be an instance of Sanic") + + name = app.name + if name in cls._app_registry and not cls.test_mode: + raise SanicException(f'Sanic app name "{name}" already in use.') + + cls._app_registry[name] = app + + @classmethod + def get_app( + cls, name: Optional[str] = None, *, force_create: bool = False + ) -> "Sanic": + """ + Retrieve an instantiated Sanic instance + """ + if name is None: + if len(cls._app_registry) > 1: + raise SanicException( + 'Multiple Sanic apps found, use Sanic.get_app("app_name")' + ) + elif len(cls._app_registry) == 0: + raise SanicException("No Sanic apps have been registered.") + else: + return list(cls._app_registry.values())[0] + try: + return cls._app_registry[name] + except KeyError: + if force_create: + return cls(name) + raise SanicException(f'Sanic app name "{name}" not found.') + + # -------------------------------------------------------------------- # + # Lifecycle + # -------------------------------------------------------------------- # + + def finalize(self): + try: + self.router.finalize() + except FinalizationError as e: + if not Sanic.test_mode: + raise e + + def signalize(self): + try: + self.signal_router.finalize() + except FinalizationError as e: + if not Sanic.test_mode: + raise e + + async def _startup(self): + self.signalize() + self.finalize() + ErrorHandler.finalize( + self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT + ) + TouchUp.run(self) + + async def _server_event( + self, + concern: str, + action: str, + loop: Optional[AbstractEventLoop] = None, + ) -> None: + event = f"server.{concern}.{action}" + if action not in ("before", "after") or concern not in ( + "init", + "shutdown", + ): + raise SanicException(f"Invalid server event: {event}") + logger.debug(f"Triggering server events: {event}") + reverse = concern == "shutdown" + if loop is None: + loop = self.loop + await self.dispatch( + event, + fail_not_found=False, + reverse=reverse, + inline=True, + context={ + "app": self, + "loop": loop, + }, + ) diff --git a/backend/sanic_server/sanic/asgi.py b/backend/sanic_server/sanic/asgi.py new file mode 100644 index 000000000..aa8f003db --- /dev/null +++ b/backend/sanic_server/sanic/asgi.py @@ -0,0 +1,199 @@ +import warnings +from typing import Optional +from urllib.parse import quote + +import sanic.app # noqa + +from ..sanic.compat import Header +from ..sanic.exceptions import ServerError +from ..sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport +from ..sanic.request import Request +from ..sanic.server import ConnInfo +from ..sanic.server.websockets.connection import WebSocketConnection + + +class Lifespan: + def __init__(self, asgi_app: "ASGIApp") -> None: + self.asgi_app = asgi_app + + if ( + "server.init.before" + in self.asgi_app.sanic_app.signal_router.name_index + ): + warnings.warn( + 'You have set a listener for "before_server_start" ' + "in ASGI mode. " + "It will be executed as early as possible, but not before " + "the ASGI server is started." + ) + if ( + "server.shutdown.after" + in self.asgi_app.sanic_app.signal_router.name_index + ): + warnings.warn( + 'You have set a listener for "after_server_stop" ' + "in ASGI mode. " + "It will be executed as late as possible, but not after " + "the ASGI server is stopped." + ) + + async def startup(self) -> None: + """ + Gather the listeners to fire on server start. + Because we are using a third-party server and not Sanic server, we do + not have access to fire anything BEFORE the server starts. + Therefore, we fire before_server_start and after_server_start + in sequence since the ASGI lifespan protocol only supports a single + startup event. + """ + await self.asgi_app.sanic_app._startup() + await self.asgi_app.sanic_app._server_event("init", "before") + await self.asgi_app.sanic_app._server_event("init", "after") + + async def shutdown(self) -> None: + """ + Gather the listeners to fire on server stop. + Because we are using a third-party server and not Sanic server, we do + not have access to fire anything AFTER the server stops. + Therefore, we fire before_server_stop and after_server_stop + in sequence since the ASGI lifespan protocol only supports a single + shutdown event. + """ + await self.asgi_app.sanic_app._server_event("shutdown", "before") + await self.asgi_app.sanic_app._server_event("shutdown", "after") + + async def __call__( + self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend + ) -> None: + message = await receive() + if message["type"] == "lifespan.startup": + await self.startup() + await send({"type": "lifespan.startup.complete"}) + + message = await receive() + if message["type"] == "lifespan.shutdown": + await self.shutdown() + await send({"type": "lifespan.shutdown.complete"}) + + +class ASGIApp: + sanic_app: "sanic.app.Sanic" + request: Request + transport: MockTransport + lifespan: Lifespan + ws: Optional[WebSocketConnection] + + def __init__(self) -> None: + self.ws = None + + @classmethod + async def create( + cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend + ) -> "ASGIApp": + instance = cls() + instance.sanic_app = sanic_app + instance.transport = MockTransport(scope, receive, send) + instance.transport.loop = sanic_app.loop + setattr(instance.transport, "add_task", sanic_app.loop.create_task) + + headers = Header( + [ + (key.decode("latin-1"), value.decode("latin-1")) + for key, value in scope.get("headers", []) + ] + ) + instance.lifespan = Lifespan(instance) + + if scope["type"] == "lifespan": + await instance.lifespan(scope, receive, send) + else: + path = ( + scope["path"][1:] + if scope["path"].startswith("/") + else scope["path"] + ) + url = "/".join([scope.get("root_path", ""), quote(path)]) + url_bytes = url.encode("latin-1") + url_bytes += b"?" + scope["query_string"] + + if scope["type"] == "http": + version = scope["http_version"] + method = scope["method"] + elif scope["type"] == "websocket": + version = "1.1" + method = "GET" + + instance.ws = instance.transport.create_websocket_connection( + send, receive + ) + else: + raise ServerError("Received unknown ASGI scope") + + request_class = sanic_app.request_class or Request + instance.request = request_class( + url_bytes, + headers, + version, + method, + instance.transport, + sanic_app, + ) + instance.request.stream = instance + instance.request_body = True + instance.request.conn_info = ConnInfo(instance.transport) + + return instance + + async def read(self) -> Optional[bytes]: + """ + Read and stream the body in chunks from an incoming ASGI message. + """ + message = await self.transport.receive() + body = message.get("body", b"") + if not message.get("more_body", False): + self.request_body = False + if not body: + return None + return body + + async def __aiter__(self): + while self.request_body: + data = await self.read() + if data: + yield data + + def respond(self, response): + response.stream, self.response = self, response + return response + + async def send(self, data, end_stream): + if self.response: + response, self.response = self.response, None + await self.transport.send( + { + "type": "http.response.start", + "status": response.status, + "headers": response.processed_headers, + } + ) + response_body = getattr(response, "body", None) + if response_body: + data = response_body + data if data else response_body + await self.transport.send( + { + "type": "http.response.body", + "body": data.encode() if hasattr(data, "encode") else data, + "more_body": not end_stream, + } + ) + + _asgi_single_callable = True # We conform to ASGI 3.0 single-callable + + async def __call__(self) -> None: + """ + Handle the incoming request. + """ + try: + await self.sanic_app.handle_request(self.request) + except Exception as e: + await self.sanic_app.handle_exception(self.request, e) diff --git a/backend/sanic_server/sanic/base.py b/backend/sanic_server/sanic/base.py new file mode 100644 index 000000000..54dc2e32b --- /dev/null +++ b/backend/sanic_server/sanic/base.py @@ -0,0 +1,64 @@ +import re +from typing import Any, Tuple +from warnings import warn + +from ..sanic.exceptions import SanicException +from ..sanic.mixins.exceptions import ExceptionMixin +from ..sanic.mixins.listeners import ListenerMixin +from ..sanic.mixins.middleware import MiddlewareMixin +from ..sanic.mixins.routes import RouteMixin +from ..sanic.mixins.signals import SignalMixin + +VALID_NAME = re.compile(r"^[a-zA-Z][a-zA-Z0-9_\-]*$") + + +class BaseSanic( + RouteMixin, + MiddlewareMixin, + ListenerMixin, + ExceptionMixin, + SignalMixin, +): + __fake_slots__: Tuple[str, ...] + + def __init__(self, name: str = None, *args, **kwargs) -> None: + class_name = self.__class__.__name__ + + if name is None: + raise SanicException( + f"{class_name} instance cannot be unnamed. " + "Please use Sanic(name='your_application_name') instead.", + ) + + if not VALID_NAME.match(name): + warn( + f"{class_name} instance named '{name}' uses a format that is" + f"deprecated. Starting in version 21.12, {class_name} objects " + "must be named only using alphanumeric characters, _, or -.", + DeprecationWarning, + ) + + self.name = name + + for base in BaseSanic.__bases__: + base.__init__(self, *args, **kwargs) # type: ignore + + def __str__(self) -> str: + return f"<{self.__class__.__name__} {self.name}>" + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(name="{self.name}")' + + def __setattr__(self, name: str, value: Any) -> None: + # This is a temporary compat layer so we can raise a warning until + # setting attributes on the app instance can be removed and deprecated + # with a proper implementation of __slots__ + if name not in self.__fake_slots__: + warn( + f"Setting variables on {self.__class__.__name__} instances is " + "deprecated and will be removed in version 21.12. You should " + f"change your {self.__class__.__name__} instance to use " + f"instance.ctx.{name} instead.", + DeprecationWarning, + ) + super().__setattr__(name, value) diff --git a/backend/sanic_server/sanic/blueprint_group.py b/backend/sanic_server/sanic/blueprint_group.py new file mode 100644 index 000000000..8328980ab --- /dev/null +++ b/backend/sanic_server/sanic/blueprint_group.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +from collections.abc import MutableSequence +from functools import partial +from typing import TYPE_CHECKING, List, Optional, Union + +if TYPE_CHECKING: + from ..sanic.blueprints import Blueprint + + +class BlueprintGroup(MutableSequence): + """ + This class provides a mechanism to implement a Blueprint Group + using the :meth:`~sanic.blueprints.Blueprint.group` method in + :class:`~sanic.blueprints.Blueprint`. To avoid having to re-write + some of the existing implementation, this class provides a custom + iterator implementation that will let you use the object of this + class as a list/tuple inside the existing implementation. + + .. code-block:: python + + bp1 = Blueprint('bp1', url_prefix='/bp1') + bp2 = Blueprint('bp2', url_prefix='/bp2') + + bp3 = Blueprint('bp3', url_prefix='/bp4') + bp3 = Blueprint('bp3', url_prefix='/bp4') + + bpg = BlueprintGroup(bp3, bp4, url_prefix="/api", version="v1") + + @bp1.middleware('request') + async def bp1_only_middleware(request): + print('applied on Blueprint : bp1 Only') + + @bp1.route('/') + async def bp1_route(request): + return text('bp1') + + @bp2.route('/') + async def bp2_route(request, param): + return text(param) + + @bp3.route('/') + async def bp1_route(request): + return text('bp1') + + @bp4.route('/') + async def bp2_route(request, param): + return text(param) + + group = Blueprint.group(bp1, bp2) + + @group.middleware('request') + async def group_middleware(request): + print('common middleware applied for both bp1 and bp2') + + # Register Blueprint group under the app + app.blueprint(group) + app.blueprint(bpg) + """ + + __slots__ = ( + "_blueprints", + "_url_prefix", + "_version", + "_strict_slashes", + "_version_prefix", + ) + + def __init__( + self, + url_prefix: Optional[str] = None, + version: Optional[Union[int, str, float]] = None, + strict_slashes: Optional[bool] = None, + version_prefix: str = "/v", + ): + """ + Create a new Blueprint Group + + :param url_prefix: URL: to be prefixed before all the Blueprint Prefix + :param version: API Version for the blueprint group. This will be + inherited by each of the Blueprint + :param strict_slashes: URL Strict slash behavior indicator + """ + self._blueprints: List[Blueprint] = [] + self._url_prefix = url_prefix + self._version = version + self._version_prefix = version_prefix + self._strict_slashes = strict_slashes + + @property + def url_prefix(self) -> Optional[Union[int, str, float]]: + """ + Retrieve the URL prefix being used for the Current Blueprint Group + + :return: string with url prefix + """ + return self._url_prefix + + @property + def blueprints(self) -> List[Blueprint]: + """ + Retrieve a list of all the available blueprints under this group. + + :return: List of Blueprint instance + """ + return self._blueprints + + @property + def version(self) -> Optional[Union[str, int, float]]: + """ + API Version for the Blueprint Group. This will be applied only in case + if the Blueprint doesn't already have a version specified + + :return: Version information + """ + return self._version + + @property + def strict_slashes(self) -> Optional[bool]: + """ + URL Slash termination behavior configuration + + :return: bool + """ + return self._strict_slashes + + @property + def version_prefix(self) -> str: + """ + Version prefix; defaults to ``/v`` + + :return: str + """ + return self._version_prefix + + def __iter__(self): + """ + Tun the class Blueprint Group into an Iterable item + """ + return iter(self._blueprints) + + def __getitem__(self, item): + """ + This method returns a blueprint inside the group specified by + an index value. This will enable indexing, splice and slicing + of the blueprint group like we can do with regular list/tuple. + + This method is provided to ensure backward compatibility with + any of the pre-existing usage that might break. + + :param item: Index of the Blueprint item in the group + :return: Blueprint object + """ + return self._blueprints[item] + + def __setitem__(self, index, item) -> None: + """ + Abstract method implemented to turn the `BlueprintGroup` class + into a list like object to support all the existing behavior. + + This method is used to perform the list's indexed setter operation. + + :param index: Index to use for inserting a new Blueprint item + :param item: New `Blueprint` object. + :return: None + """ + self._blueprints[index] = item + + def __delitem__(self, index) -> None: + """ + Abstract method implemented to turn the `BlueprintGroup` class + into a list like object to support all the existing behavior. + + This method is used to delete an item from the list of blueprint + groups like it can be done on a regular list with index. + + :param index: Index to use for removing a new Blueprint item + :return: None + """ + del self._blueprints[index] + + def __len__(self) -> int: + """ + Get the Length of the blueprint group object. + + :return: Length of Blueprint group object + """ + return len(self._blueprints) + + def append(self, value: Blueprint) -> None: + """ + The Abstract class `MutableSequence` leverages this append method to + perform the `BlueprintGroup.append` operation. + :param value: New `Blueprint` object. + :return: None + """ + self._blueprints.append(value) + + def exception(self, *exceptions, **kwargs): + """ + A decorator that can be used to implement a global exception handler + for all the Blueprints that belong to this Blueprint Group. + + In case of nested Blueprint Groups, the same handler is applied + across each of the Blueprints recursively. + + :param args: List of Python exceptions to be caught by the handler + :param kwargs: Additional optional arguments to be passed to the + exception handler + :return a decorated method to handle global exceptions for any + blueprint registered under this group. + """ + + def register_exception_handler_for_blueprints(fn): + for blueprint in self.blueprints: + blueprint.exception(*exceptions, **kwargs)(fn) + + return register_exception_handler_for_blueprints + + def insert(self, index: int, item: Blueprint) -> None: + """ + The Abstract class `MutableSequence` leverages this insert method to + perform the `BlueprintGroup.append` operation. + + :param index: Index to use for removing a new Blueprint item + :param item: New `Blueprint` object. + :return: None + """ + self._blueprints.insert(index, item) + + def middleware(self, *args, **kwargs): + """ + A decorator that can be used to implement a Middleware plugin to + all of the Blueprints that belongs to this specific Blueprint Group. + + In case of nested Blueprint Groups, the same middleware is applied + across each of the Blueprints recursively. + + :param args: Optional positional Parameters to be use middleware + :param kwargs: Optional Keyword arg to use with Middleware + :return: Partial function to apply the middleware + """ + + def register_middleware_for_blueprints(fn): + for blueprint in self.blueprints: + blueprint.middleware(fn, *args, **kwargs) + + if args and callable(args[0]): + fn = args[0] + args = list(args)[1:] + return register_middleware_for_blueprints(fn) + return register_middleware_for_blueprints + + def on_request(self, middleware=None): + if callable(middleware): + return self.middleware(middleware, "request") + else: + return partial(self.middleware, attach_to="request") + + def on_response(self, middleware=None): + if callable(middleware): + return self.middleware(middleware, "response") + else: + return partial(self.middleware, attach_to="response") diff --git a/backend/sanic_server/sanic/blueprints.py b/backend/sanic_server/sanic/blueprints.py new file mode 100644 index 000000000..d12e37cea --- /dev/null +++ b/backend/sanic_server/sanic/blueprints.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +import asyncio +from collections import defaultdict +from copy import deepcopy +from types import SimpleNamespace +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union + +from ..sanic.base import BaseSanic +from ..sanic.blueprint_group import BlueprintGroup +from ..sanic.exceptions import SanicException +from ..sanic.helpers import Default, _default +from ..sanic.models.futures import FutureRoute, FutureStatic +from ..sanic.models.handler_types import (ListenerType, MiddlewareType, + RouteHandler) +from ..sanic_routing.exceptions import NotFound # type: ignore +from ..sanic_routing.route import Route # type: ignore + +if TYPE_CHECKING: + from ..sanic import Sanic # noqa + + +class Blueprint(BaseSanic): + """ + In *Sanic* terminology, a **Blueprint** is a logical collection of + URLs that perform a specific set of tasks which can be identified by + a unique name. + + It is the main tool for grouping functionality and similar endpoints. + + `See user guide re: blueprints + `__ + + :param name: unique name of the blueprint + :param url_prefix: URL to be prefixed before all route URLs + :param host: IP Address of FQDN for the sanic server to use. + :param version: Blueprint Version + :param strict_slashes: Enforce the API urls are requested with a + trailing */* + """ + + __fake_slots__ = ( + "_apps", + "_future_routes", + "_future_statics", + "_future_middleware", + "_future_listeners", + "_future_exceptions", + "_future_signals", + "ctx", + "exceptions", + "host", + "listeners", + "middlewares", + "name", + "routes", + "statics", + "strict_slashes", + "url_prefix", + "version", + "version_prefix", + "websocket_routes", + ) + + def __init__( + self, + name: str = None, + url_prefix: Optional[str] = None, + host: Optional[str] = None, + version: Optional[Union[int, str, float]] = None, + strict_slashes: Optional[bool] = None, + version_prefix: str = "/v", + ): + super().__init__(name=name) + self.reset() + self.ctx = SimpleNamespace() + self.host = host + self.strict_slashes = strict_slashes + self.url_prefix = ( + url_prefix[:-1] + if url_prefix and url_prefix.endswith("/") + else url_prefix + ) + self.version = version + self.version_prefix = version_prefix + + def __repr__(self) -> str: + args = ", ".join( + [ + f'{attr}="{getattr(self, attr)}"' + if isinstance(getattr(self, attr), str) + else f"{attr}={getattr(self, attr)}" + for attr in ( + "name", + "url_prefix", + "host", + "version", + "strict_slashes", + ) + ] + ) + return f"Blueprint({args})" + + @property + def apps(self): + if not self._apps: + raise SanicException( + f"{self} has not yet been registered to an app" + ) + return self._apps + + def route(self, *args, **kwargs): + kwargs["apply"] = False + return super().route(*args, **kwargs) + + def static(self, *args, **kwargs): + kwargs["apply"] = False + return super().static(*args, **kwargs) + + def middleware(self, *args, **kwargs): + kwargs["apply"] = False + return super().middleware(*args, **kwargs) + + def listener(self, *args, **kwargs): + kwargs["apply"] = False + return super().listener(*args, **kwargs) + + def exception(self, *args, **kwargs): + kwargs["apply"] = False + return super().exception(*args, **kwargs) + + def signal(self, event: str, *args, **kwargs): + kwargs["apply"] = False + return super().signal(event, *args, **kwargs) + + def reset(self): + self._apps: Set[Sanic] = set() + self.exceptions: List[RouteHandler] = [] + self.listeners: Dict[str, List[ListenerType]] = {} + self.middlewares: List[MiddlewareType] = [] + self.routes: List[Route] = [] + self.statics: List[RouteHandler] = [] + self.websocket_routes: List[Route] = [] + + def copy( + self, + name: str, + url_prefix: Optional[Union[str, Default]] = _default, + version: Optional[Union[int, str, float, Default]] = _default, + version_prefix: Union[str, Default] = _default, + strict_slashes: Optional[Union[bool, Default]] = _default, + with_registration: bool = True, + with_ctx: bool = False, + ): + """ + Copy a blueprint instance with some optional parameters to + override the values of attributes in the old instance. + + :param name: unique name of the blueprint + :param url_prefix: URL to be prefixed before all route URLs + :param version: Blueprint Version + :param version_prefix: the prefix of the version number shown in the + URL. + :param strict_slashes: Enforce the API urls are requested with a + trailing */* + :param with_registration: whether register new blueprint instance with + sanic apps that were registered with the old instance or not. + :param with_ctx: whether ``ctx`` will be copied or not. + """ + + attrs_backup = { + "_apps": self._apps, + "routes": self.routes, + "websocket_routes": self.websocket_routes, + "middlewares": self.middlewares, + "exceptions": self.exceptions, + "listeners": self.listeners, + "statics": self.statics, + } + + self.reset() + new_bp = deepcopy(self) + new_bp.name = name + + if not isinstance(url_prefix, Default): + new_bp.url_prefix = url_prefix + if not isinstance(version, Default): + new_bp.version = version + if not isinstance(strict_slashes, Default): + new_bp.strict_slashes = strict_slashes + if not isinstance(version_prefix, Default): + new_bp.version_prefix = version_prefix + + for key, value in attrs_backup.items(): + setattr(self, key, value) + + if with_registration and self._apps: + if new_bp._future_statics: + raise SanicException( + "Static routes registered with the old blueprint instance," + " cannot be registered again." + ) + for app in self._apps: + app.blueprint(new_bp) + + if not with_ctx: + new_bp.ctx = SimpleNamespace() + + return new_bp + + @staticmethod + def group( + *blueprints: Union[Blueprint, BlueprintGroup], + url_prefix: Optional[str] = None, + version: Optional[Union[int, str, float]] = None, + strict_slashes: Optional[bool] = None, + version_prefix: str = "/v", + ): + """ + Create a list of blueprints, optionally grouping them under a + general URL prefix. + + :param blueprints: blueprints to be registered as a group + :param url_prefix: URL route to be prepended to all sub-prefixes + :param version: API Version to be used for Blueprint group + :param strict_slashes: Indicate strict slash termination behavior + for URL + """ + + def chain(nested) -> Iterable[Blueprint]: + """itertools.chain() but leaves strings untouched""" + for i in nested: + if isinstance(i, (list, tuple)): + yield from chain(i) + else: + yield i + + bps = BlueprintGroup( + url_prefix=url_prefix, + version=version, + strict_slashes=strict_slashes, + version_prefix=version_prefix, + ) + for bp in chain(blueprints): + bps.append(bp) + return bps + + def register(self, app, options): + """ + Register the blueprint to the sanic app. + + :param app: Instance of :class:`sanic.app.Sanic` class + :param options: Options to be used while registering the + blueprint into the app. + *url_prefix* - URL Prefix to override the blueprint prefix + """ + + self._apps.add(app) + url_prefix = options.get("url_prefix", self.url_prefix) + opt_version = options.get("version", None) + opt_strict_slashes = options.get("strict_slashes", None) + opt_version_prefix = options.get("version_prefix", self.version_prefix) + error_format = options.get( + "error_format", app.config.FALLBACK_ERROR_FORMAT + ) + + routes = [] + middleware = [] + exception_handlers = [] + listeners = defaultdict(list) + + # Routes + for future in self._future_routes: + # attach the blueprint name to the handler so that it can be + # prefixed properly in the router + future.handler.__blueprintname__ = self.name + # Prepend the blueprint URI prefix if available + uri = url_prefix + future.uri if url_prefix else future.uri + + version_prefix = self.version_prefix + for prefix in ( + future.version_prefix, + opt_version_prefix, + ): + if prefix and prefix != "/v": + version_prefix = prefix + break + + version = self._extract_value( + future.version, opt_version, self.version + ) + strict_slashes = self._extract_value( + future.strict_slashes, opt_strict_slashes, self.strict_slashes + ) + + name = app._generate_name(future.name) + + apply_route = FutureRoute( + future.handler, + uri[1:] if uri.startswith("//") else uri, + future.methods, + future.host or self.host, + strict_slashes, + future.stream, + version, + name, + future.ignore_body, + future.websocket, + future.subprotocols, + future.unquote, + future.static, + version_prefix, + error_format, + ) + + route = app._apply_route(apply_route) + operation = ( + routes.extend if isinstance(route, list) else routes.append + ) + operation(route) + + # Static Files + for future in self._future_statics: + # Prepend the blueprint URI prefix if available + uri = url_prefix + future.uri if url_prefix else future.uri + apply_route = FutureStatic(uri, *future[1:]) + route = app._apply_static(apply_route) + routes.append(route) + + route_names = [route.name for route in routes if route] + + if route_names: + # Middleware + for future in self._future_middleware: + middleware.append(app._apply_middleware(future, route_names)) + + # Exceptions + for future in self._future_exceptions: + exception_handlers.append( + app._apply_exception_handler(future, route_names) + ) + + # Event listeners + for listener in self._future_listeners: + listeners[listener.event].append(app._apply_listener(listener)) + + # Signals + for signal in self._future_signals: + signal.condition.update({"blueprint": self.name}) + app._apply_signal(signal) + + self.routes = [route for route in routes if isinstance(route, Route)] + self.websocket_routes = [ + route for route in self.routes if route.ctx.websocket + ] + self.middlewares = middleware + self.exceptions = exception_handlers + self.listeners = dict(listeners) + + async def dispatch(self, *args, **kwargs): + condition = kwargs.pop("condition", {}) + condition.update({"blueprint": self.name}) + kwargs["condition"] = condition + await asyncio.gather( + *[app.dispatch(*args, **kwargs) for app in self.apps] + ) + + def event(self, event: str, timeout: Optional[Union[int, float]] = None): + events = set() + for app in self.apps: + signal = app.signal_router.name_index.get(event) + if not signal: + raise NotFound("Could not find signal %s" % event) + events.add(signal.ctx.event) + + return asyncio.wait( + [event.wait() for event in events], + return_when=asyncio.FIRST_COMPLETED, + timeout=timeout, + ) + + @staticmethod + def _extract_value(*values): + value = values[-1] + for v in values: + if v is not None: + value = v + break + return value diff --git a/backend/sanic_server/sanic/compat.py b/backend/sanic_server/sanic/compat.py new file mode 100644 index 000000000..f8b3a74ae --- /dev/null +++ b/backend/sanic_server/sanic/compat.py @@ -0,0 +1,76 @@ +import asyncio +import os +import signal + +from sys import argv + +from multidict import CIMultiDict # type: ignore + + +OS_IS_WINDOWS = os.name == "nt" + + +class Header(CIMultiDict): + """ + Container used for both request and response headers. It is a subclass of + `CIMultiDict + `_. + + It allows for multiple values for a single key in keeping with the HTTP + spec. Also, all keys are *case in-sensitive*. + + Please checkout `the MultiDict documentation + `_ + for more details about how to use the object. In general, it should work + very similar to a regular dictionary. + """ + + def get_all(self, key: str): + """ + Convenience method mapped to ``getall()``. + """ + return self.getall(key, default=[]) + + +use_trio = argv[0].endswith("hypercorn") and "trio" in argv + +if use_trio: # pragma: no cover + import trio # type: ignore + + def stat_async(path): + return trio.Path(path).stat() + + open_async = trio.open_file + CancelledErrors = tuple([asyncio.CancelledError, trio.Cancelled]) +else: + from aiofiles import open as aio_open # type: ignore + from aiofiles.os import stat as stat_async # type: ignore # noqa: F401 + + async def open_async(file, mode="r", **kwargs): + return aio_open(file, mode, **kwargs) + + CancelledErrors = tuple([asyncio.CancelledError]) + + +def ctrlc_workaround_for_windows(app): + async def stay_active(app): + """Asyncio wakeups to allow receiving SIGINT in Python""" + while not die: + # If someone else stopped the app, just exit + if app.is_stopping: + return + # Windows Python blocks signal handlers while the event loop is + # waiting for I/O. Frequent wakeups keep interrupts flowing. + await asyncio.sleep(0.1) + # Can't be called from signal handler, so call it from here + app.stop() + + def ctrlc_handler(sig, frame): + nonlocal die + if die: + raise KeyboardInterrupt("Non-graceful Ctrl+C") + die = True + + die = False + signal.signal(signal.SIGINT, ctrlc_handler) + app.add_task(stay_active) diff --git a/backend/sanic_server/sanic/config.py b/backend/sanic_server/sanic/config.py new file mode 100644 index 000000000..8dd8cff9e --- /dev/null +++ b/backend/sanic_server/sanic/config.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +from inspect import isclass +from os import environ +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from warnings import warn + +from ..sanic.errorpages import check_error_format +from ..sanic.http import Http +from .utils import load_module_from_file_location, str_to_bool + +if TYPE_CHECKING: # no cov + from ..sanic import Sanic + + +SANIC_PREFIX = "SANIC_" +BASE_LOGO = """ + + Sanic + Build Fast. Run Fast. + +""" + +DEFAULT_CONFIG = { + "ACCESS_LOG": True, + "EVENT_AUTOREGISTER": False, + "FALLBACK_ERROR_FORMAT": "auto", + "FORWARDED_FOR_HEADER": "X-Forwarded-For", + "FORWARDED_SECRET": None, + "GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec + "KEEP_ALIVE_TIMEOUT": 5, # 5 seconds + "KEEP_ALIVE": True, + "PROXIES_COUNT": None, + "REAL_IP_HEADER": None, + "REGISTER": True, + "REQUEST_BUFFER_SIZE": 65536, # 64 KiB + "REQUEST_MAX_HEADER_SIZE": 8192, # 8 KiB, but cannot exceed 16384 + "REQUEST_ID_HEADER": "X-Request-ID", + "REQUEST_MAX_SIZE": 100000000, # 100 megabytes + "REQUEST_TIMEOUT": 60, # 60 seconds + "RESPONSE_TIMEOUT": 60, # 60 seconds + "WEBSOCKET_MAX_SIZE": 2 ** 20, # 1 megabyte + "WEBSOCKET_PING_INTERVAL": 20, + "WEBSOCKET_PING_TIMEOUT": 20, +} + + +class Config(dict): + ACCESS_LOG: bool + EVENT_AUTOREGISTER: bool + FALLBACK_ERROR_FORMAT: str + FORWARDED_FOR_HEADER: str + FORWARDED_SECRET: Optional[str] + GRACEFUL_SHUTDOWN_TIMEOUT: float + KEEP_ALIVE_TIMEOUT: int + KEEP_ALIVE: bool + PROXIES_COUNT: Optional[int] + REAL_IP_HEADER: Optional[str] + REGISTER: bool + REQUEST_BUFFER_SIZE: int + REQUEST_MAX_HEADER_SIZE: int + REQUEST_ID_HEADER: str + REQUEST_MAX_SIZE: int + REQUEST_TIMEOUT: int + RESPONSE_TIMEOUT: int + SERVER_NAME: str + WEBSOCKET_MAX_SIZE: int + WEBSOCKET_PING_INTERVAL: int + WEBSOCKET_PING_TIMEOUT: int + + def __init__( + self, + defaults: Dict[str, Union[str, bool, int, float, None]] = None, + load_env: Optional[Union[bool, str]] = True, + env_prefix: Optional[str] = SANIC_PREFIX, + keep_alive: Optional[bool] = None, + *, + app: Optional[Sanic] = None, + ): + defaults = defaults or {} + super().__init__({**DEFAULT_CONFIG, **defaults}) + + self._app = app + self._LOGO = BASE_LOGO + + if keep_alive is not None: + self.KEEP_ALIVE = keep_alive + + if env_prefix != SANIC_PREFIX: + if env_prefix: + self.load_environment_vars(env_prefix) + elif load_env is not True: + if load_env: + self.load_environment_vars(prefix=load_env) + warn( + "Use of load_env is deprecated and will be removed in " + "21.12. Modify the configuration prefix by passing " + "env_prefix instead.", + DeprecationWarning, + ) + else: + self.load_environment_vars(SANIC_PREFIX) + + self._configure_header_size() + self._check_error_format() + self._init = True + + def __getattr__(self, attr): + try: + return self[attr] + except KeyError as ke: + raise AttributeError(f"Config has no '{ke.args[0]}'") + + def __setattr__(self, attr, value) -> None: + self.update({attr: value}) + + def __setitem__(self, attr, value) -> None: + self.update({attr: value}) + + def update(self, *other, **kwargs) -> None: + other_mapping = {k: v for item in other for k, v in dict(item).items()} + super().update(*other, **kwargs) + for attr, value in {**other_mapping, **kwargs}.items(): + self._post_set(attr, value) + + def _post_set(self, attr, value) -> None: + if self.get("_init"): + if attr in ( + "REQUEST_MAX_HEADER_SIZE", + "REQUEST_BUFFER_SIZE", + "REQUEST_MAX_SIZE", + ): + self._configure_header_size() + elif attr == "FALLBACK_ERROR_FORMAT": + self._check_error_format() + if self.app and value != self.app.error_handler.fallback: + if self.app.error_handler.fallback != "auto": + warn( + "Overriding non-default ErrorHandler fallback " + "value. Changing from " + f"{self.app.error_handler.fallback} to {value}." + ) + self.app.error_handler.fallback = value + elif attr == "LOGO": + self._LOGO = value + warn( + "Setting the config.LOGO is deprecated and will no longer " + "be supported starting in v22.6.", + DeprecationWarning, + ) + + @property + def app(self): + return self._app + + @property + def LOGO(self): + return self._LOGO + + def _configure_header_size(self): + Http.set_header_max_size( + self.REQUEST_MAX_HEADER_SIZE, + self.REQUEST_BUFFER_SIZE - 4096, + self.REQUEST_MAX_SIZE, + ) + + def _check_error_format(self): + check_error_format(self.FALLBACK_ERROR_FORMAT) + + def load_environment_vars(self, prefix=SANIC_PREFIX): + """ + Looks for prefixed environment variables and applies + them to the configuration if present. This is called automatically when + Sanic starts up to load environment variables into config. + + It will automatically hyrdate the following types: + + - ``int`` + - ``float`` + - ``bool`` + + Anything else will be imported as a ``str``. + """ + for k, v in environ.items(): + if k.startswith(prefix): + _, config_key = k.split(prefix, 1) + try: + self[config_key] = int(v) + except ValueError: + try: + self[config_key] = float(v) + except ValueError: + try: + self[config_key] = str_to_bool(v) + except ValueError: + self[config_key] = v + + def update_config(self, config: Union[bytes, str, dict, Any]): + """ + Update app.config. + + .. note:: + + Only upper case settings are considered + + You can upload app config by providing path to py file + holding settings. + + .. code-block:: python + + # /some/py/file + A = 1 + B = 2 + + .. code-block:: python + + config.update_config("${some}/py/file") + + Yes you can put environment variable here, but they must be provided + in format: ``${some_env_var}``, and mark that ``$some_env_var`` is + treated as plain string. + + You can upload app config by providing dict holding settings. + + .. code-block:: python + + d = {"A": 1, "B": 2} + config.update_config(d) + + You can upload app config by providing any object holding settings, + but in such case config.__dict__ will be used as dict holding settings. + + .. code-block:: python + + class C: + A = 1 + B = 2 + + config.update_config(C) + + `See user guide re: config + `__ + """ + + if isinstance(config, (bytes, str, Path)): + config = load_module_from_file_location(location=config) + + if not isinstance(config, dict): + cfg = {} + if not isclass(config): + cfg.update( + { + key: getattr(config, key) + for key in config.__class__.__dict__.keys() + } + ) + + config = dict(config.__dict__) + config.update(cfg) + + config = dict(filter(lambda i: i[0].isupper(), config.items())) + + self.update(config) + + load = update_config diff --git a/backend/sanic_server/sanic/constants.py b/backend/sanic_server/sanic/constants.py new file mode 100644 index 000000000..80f1d2a9b --- /dev/null +++ b/backend/sanic_server/sanic/constants.py @@ -0,0 +1,28 @@ +from enum import Enum, auto + + +class HTTPMethod(str, Enum): + def _generate_next_value_(name, start, count, last_values): + return name.upper() + + def __eq__(self, value: object) -> bool: + value = str(value).upper() + return super().__eq__(value) + + def __hash__(self) -> int: + return hash(self.value) + + def __str__(self) -> str: + return self.value + + GET = auto() + POST = auto() + PUT = auto() + HEAD = auto() + OPTIONS = auto() + PATCH = auto() + DELETE = auto() + + +HTTP_METHODS = tuple(HTTPMethod.__members__.values()) +DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream" diff --git a/backend/sanic_server/sanic/cookies.py b/backend/sanic_server/sanic/cookies.py new file mode 100644 index 000000000..993ce3522 --- /dev/null +++ b/backend/sanic_server/sanic/cookies.py @@ -0,0 +1,156 @@ +import re +import string + +from datetime import datetime +from typing import Dict + + +DEFAULT_MAX_AGE = 0 + +# ------------------------------------------------------------ # +# SimpleCookie +# ------------------------------------------------------------ # + +# Straight up copied this section of dark magic from SimpleCookie + +_LegalChars = string.ascii_letters + string.digits + "!#$%&'*+-.^_`|~:" +_UnescapedChars = _LegalChars + " ()/<=>?@[]{}" + +_Translator = { + n: "\\%03o" % n for n in set(range(256)) - set(map(ord, _UnescapedChars)) +} +_Translator.update({ord('"'): '\\"', ord("\\"): "\\\\"}) + + +def _quote(str): + r"""Quote a string for use in a cookie header. + If the string does not need to be double-quoted, then just return the + string. Otherwise, surround the string in doublequotes and quote + (with a \) special characters. + """ + if str is None or _is_legal_key(str): + return str + else: + return '"' + str.translate(_Translator) + '"' + + +_is_legal_key = re.compile("[%s]+" % re.escape(_LegalChars)).fullmatch + +# ------------------------------------------------------------ # +# Custom SimpleCookie +# ------------------------------------------------------------ # + + +class CookieJar(dict): + """ + CookieJar dynamically writes headers as cookies are added and removed + It gets around the limitation of one header per name by using the + MultiHeader class to provide a unique key that encodes to Set-Cookie. + """ + + def __init__(self, headers): + super().__init__() + self.headers: Dict[str, str] = headers + self.cookie_headers: Dict[str, str] = {} + self.header_key: str = "Set-Cookie" + + def __setitem__(self, key, value): + # If this cookie doesn't exist, add it to the header keys + if not self.cookie_headers.get(key): + cookie = Cookie(key, value) + cookie["path"] = "/" + self.cookie_headers[key] = self.header_key + self.headers.add(self.header_key, cookie) + return super().__setitem__(key, cookie) + else: + self[key].value = value + + def __delitem__(self, key): + if key not in self.cookie_headers: + self[key] = "" + self[key]["max-age"] = 0 + else: + cookie_header = self.cookie_headers[key] + # remove it from header + cookies = self.headers.popall(cookie_header) + for cookie in cookies: + if cookie.key != key: + self.headers.add(cookie_header, cookie) + del self.cookie_headers[key] + return super().__delitem__(key) + + +class Cookie(dict): + """A stripped down version of Morsel from SimpleCookie #gottagofast""" + + _keys = { + "expires": "expires", + "path": "Path", + "comment": "Comment", + "domain": "Domain", + "max-age": "Max-Age", + "secure": "Secure", + "httponly": "HttpOnly", + "version": "Version", + "samesite": "SameSite", + } + _flags = {"secure", "httponly"} + + def __init__(self, key, value): + if key in self._keys: + raise KeyError("Cookie name is a reserved word") + if not _is_legal_key(key): + raise KeyError("Cookie key contains illegal characters") + self.key = key + self.value = value + super().__init__() + + def __setitem__(self, key, value): + if key not in self._keys: + raise KeyError("Unknown cookie property") + if value is not False: + if key.lower() == "max-age": + if not str(value).isdigit(): + raise ValueError("Cookie max-age must be an integer") + elif key.lower() == "expires": + if not isinstance(value, datetime): + raise TypeError( + "Cookie 'expires' property must be a datetime" + ) + return super().__setitem__(key, value) + + def encode(self, encoding): + """ + Encode the cookie content in a specific type of encoding instructed + by the developer. Leverages the :func:`str.encode` method provided + by python. + + This method can be used to encode and embed ``utf-8`` content into + the cookies. + + :param encoding: Encoding to be used with the cookie + :return: Cookie encoded in a codec of choosing. + :except: UnicodeEncodeError + """ + return str(self).encode(encoding) + + def __str__(self): + """Format as a Set-Cookie header value.""" + output = ["%s=%s" % (self.key, _quote(self.value))] + for key, value in self.items(): + if key == "max-age": + try: + output.append("%s=%d" % (self._keys[key], value)) + except TypeError: + output.append("%s=%s" % (self._keys[key], value)) + elif key == "expires": + output.append( + "%s=%s" + % (self._keys[key], value.strftime("%a, %d-%b-%Y %T GMT")) + ) + elif key in self._flags and self[key]: + output.append(self._keys[key]) + else: + output.append("%s=%s" % (self._keys[key], value)) + + return "; ".join(output) diff --git a/backend/sanic_server/sanic/errorpages.py b/backend/sanic_server/sanic/errorpages.py new file mode 100644 index 000000000..af6765732 --- /dev/null +++ b/backend/sanic_server/sanic/errorpages.py @@ -0,0 +1,469 @@ +""" +Sanic `provides a pattern +`_ +for providing a response when an exception occurs. However, if you do no handle +an exception, it will provide a fallback. There are three fallback types: + +- HTML - *default* +- Text +- JSON + +Setting ``app.config.FALLBACK_ERROR_FORMAT = "auto"`` will enable a switch that +will attempt to provide an appropriate response format based upon the +request type. +""" + +import sys +import typing as t +from functools import partial +from traceback import extract_tb + +from ..sanic.exceptions import InvalidUsage, SanicException +from ..sanic.helpers import STATUS_CODES +from ..sanic.request import Request +from ..sanic.response import HTTPResponse, html, json, text + +try: + from ujson import dumps + + dumps = partial(dumps, escape_forward_slashes=False) +except ImportError: # noqa + from json import dumps # type: ignore + + +FALLBACK_TEXT = ( + "The server encountered an internal error and " "cannot complete your request." +) +FALLBACK_STATUS = 500 + + +class BaseRenderer: + """ + Base class that all renderers must inherit from. + """ + + def __init__(self, request, exception, debug): + self.request = request + self.exception = exception + self.debug = debug + + @property + def headers(self): + if isinstance(self.exception, SanicException): + return getattr(self.exception, "headers", {}) + return {} + + @property + def status(self): + if isinstance(self.exception, SanicException): + return getattr(self.exception, "status_code", FALLBACK_STATUS) + return FALLBACK_STATUS + + @property + def text(self): + if self.debug or isinstance(self.exception, SanicException): + return str(self.exception) + return FALLBACK_TEXT + + @property + def title(self): + status_text = STATUS_CODES.get(self.status, b"Error Occurred").decode() + return f"{self.status} — {status_text}" + + def render(self) -> HTTPResponse: + """ + Outputs the exception as a :class:`HTTPResponse`. + + :return: The formatted exception + :rtype: str + """ + output = ( + self.full + if self.debug and not getattr(self.exception, "quiet", False) + else self.minimal + ) + return output() + + def minimal(self) -> HTTPResponse: # noqa + """ + Provide a formatted message that is meant to not show any sensitive + data or details. + """ + raise NotImplementedError + + def full(self) -> HTTPResponse: # noqa + """ + Provide a formatted message that has all details and is mean to be used + primarily for debugging and non-production environments. + """ + raise NotImplementedError + + +class HTMLRenderer(BaseRenderer): + """ + Render an exception as HTML. + + The default fallback type. + """ + + TRACEBACK_STYLE = """ + html { font-family: sans-serif } + h2 { color: #888; } + .tb-wrapper p { margin: 0 } + .frame-border { margin: 1rem } + .frame-line > * { padding: 0.3rem 0.6rem } + .frame-line { margin-bottom: 0.3rem } + .frame-code { font-size: 16px; padding-left: 4ch } + .tb-wrapper { border: 1px solid #eee } + .tb-header { background: #eee; padding: 0.3rem; font-weight: bold } + .frame-descriptor { background: #e2eafb; font-size: 14px } + """ + TRACEBACK_WRAPPER_HTML = ( + "
{exc_name}: {exc_value}
" + "
{frame_html}
" + ) + TRACEBACK_BORDER = ( + "
" + "The above exception was the direct cause of the following exception:" + "
" + ) + TRACEBACK_LINE_HTML = ( + "
" + "

" + "File {0.filename}, line {0.lineno}, " + "in {0.name}" + "

{0.line}" + "

" + ) + OUTPUT_HTML = ( + "" + "{title}\n" + "\n" + "

{title}

{text}\n" + "{body}" + ) + + def full(self) -> HTTPResponse: + return html( + self.OUTPUT_HTML.format( + title=self.title, + text=self.text, + style=self.TRACEBACK_STYLE, + body=self._generate_body(), + ), + status=self.status, + ) + + def minimal(self) -> HTTPResponse: + return html( + self.OUTPUT_HTML.format( + title=self.title, + text=self.text, + style=self.TRACEBACK_STYLE, + body="", + ), + status=self.status, + headers=self.headers, + ) + + @property + def text(self): + return escape(super().text) + + @property + def title(self): + return escape(f"⚠️ {super().title}") + + def _generate_body(self): + _, exc_value, __ = sys.exc_info() + exceptions = [] + while exc_value: + exceptions.append(self._format_exc(exc_value)) + exc_value = exc_value.__cause__ + + traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions)) + appname = escape(self.request.app.name) + name = escape(self.exception.__class__.__name__) + value = escape(self.exception) + path = escape(self.request.path) + lines = [ + f"

Traceback of {appname} (most recent call last):

", + f"{traceback_html}", + "

", + f"{name}: {value} while handling path {path}", + "

", + ] + return "\n".join(lines) + + def _format_exc(self, exc): + frames = extract_tb(exc.__traceback__) + frame_html = "".join(self.TRACEBACK_LINE_HTML.format(frame) for frame in frames) + return self.TRACEBACK_WRAPPER_HTML.format( + exc_name=escape(exc.__class__.__name__), + exc_value=escape(exc), + frame_html=frame_html, + ) + + +class TextRenderer(BaseRenderer): + """ + Render an exception as plain text. + """ + + OUTPUT_TEXT = "{title}\n{bar}\n{text}\n\n{body}" + SPACER = " " + + def full(self) -> HTTPResponse: + return text( + self.OUTPUT_TEXT.format( + title=self.title, + text=self.text, + bar=("=" * len(self.title)), + body=self._generate_body(), + ), + status=self.status, + ) + + def minimal(self) -> HTTPResponse: + return text( + self.OUTPUT_TEXT.format( + title=self.title, + text=self.text, + bar=("=" * len(self.title)), + body="", + ), + status=self.status, + headers=self.headers, + ) + + @property + def title(self): + return f"⚠️ {super().title}" + + def _generate_body(self): + _, exc_value, __ = sys.exc_info() + exceptions = [] + + lines = [ + f"{self.exception.__class__.__name__}: {self.exception} while " + f"handling path {self.request.path}", + f"Traceback of {self.request.app.name} (most recent call last):\n", + ] + + while exc_value: + exceptions.append(self._format_exc(exc_value)) + exc_value = exc_value.__cause__ + + return "\n".join(lines + exceptions[::-1]) + + def _format_exc(self, exc): + frames = "\n\n".join( + [ + f"{self.SPACER * 2}File {frame.filename}, " + f"line {frame.lineno}, in " + f"{frame.name}\n{self.SPACER * 2}{frame.line}" + for frame in extract_tb(exc.__traceback__) + ] + ) + return f"{self.SPACER}{exc.__class__.__name__}: {exc}\n{frames}" + + +class JSONRenderer(BaseRenderer): + """ + Render an exception as JSON. + """ + + def full(self) -> HTTPResponse: + output = self._generate_output(full=True) + return json(output, status=self.status, dumps=dumps) + + def minimal(self) -> HTTPResponse: + output = self._generate_output(full=False) + return json(output, status=self.status, dumps=dumps) + + def _generate_output(self, *, full): + output = { + "description": self.title, + "status": self.status, + "message": self.text, + } + + if full: + _, exc_value, __ = sys.exc_info() + exceptions = [] + + while exc_value: + exceptions.append( + { + "type": exc_value.__class__.__name__, + "exception": str(exc_value), + "frames": [ + { + "file": frame.filename, + "line": frame.lineno, + "name": frame.name, + "src": frame.line, + } + for frame in extract_tb(exc_value.__traceback__) + ], + } + ) + exc_value = exc_value.__cause__ + + output["path"] = self.request.path + output["args"] = self.request.args + output["exceptions"] = exceptions[::-1] + + return output + + @property + def title(self): + return STATUS_CODES.get(self.status, b"Error Occurred").decode() + + +def escape(text): + """ + Minimal HTML escaping, not for attribute values (unlike html.escape). + """ + return f"{text}".replace("&", "&").replace("<", "<") + + +RENDERERS_BY_CONFIG = { + "html": HTMLRenderer, + "json": JSONRenderer, + "text": TextRenderer, +} + +RENDERERS_BY_CONTENT_TYPE = { + "text/plain": TextRenderer, + "application/json": JSONRenderer, + "multipart/form-data": HTMLRenderer, + "text/html": HTMLRenderer, +} +CONTENT_TYPE_BY_RENDERERS = {v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items()} + +RESPONSE_MAPPING = { + "empty": "html", + "json": "json", + "text": "text", + "raw": "text", + "html": "html", + "file": "html", + "file_stream": "text", + "stream": "text", + "redirect": "html", + "text/plain": "text", + "text/html": "html", + "application/json": "json", +} + + +def check_error_format(format): + if format not in RENDERERS_BY_CONFIG and format != "auto": + raise SanicException(f"Unknown format: {format}") + + +def exception_response( + request: Request, + exception: Exception, + debug: bool, + fallback: str, + base: t.Type[BaseRenderer], + renderer: t.Type[t.Optional[BaseRenderer]] = None, +) -> HTTPResponse: + """ + Render a response for the default FALLBACK exception handler. + """ + content_type = None + + if not renderer: + # Make sure we have something set + renderer = base + render_format = fallback + + if request: + # If there is a request, try and get the format + # from the route + if request.route: + try: + if request.route.ctx.error_format: + render_format = request.route.ctx.error_format + except AttributeError: + ... + + content_type = request.headers.getone("content-type", "").split(";")[0] + + acceptable = request.accept + + # If the format is auto still, make a guess + if render_format == "auto": + # First, if there is an Accept header, check if text/html + # is the first option + # According to MDN Web Docs, all major browsers use text/html + # as the primary value in Accept (with the exception of IE 8, + # and, well, if you are supporting IE 8, then you have bigger + # problems to concern yourself with than what default exception + # renderer is used) + # Source: + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values + + if acceptable and acceptable[0].match( + "text/html", + allow_type_wildcard=False, + allow_subtype_wildcard=False, + ): + renderer = HTMLRenderer + + # Second, if there is an Accept header, check if + # application/json is an option, or if the content-type + # is application/json + elif ( + acceptable + and acceptable.match( + "application/json", + allow_type_wildcard=False, + allow_subtype_wildcard=False, + ) + or content_type == "application/json" + ): + renderer = JSONRenderer + + # Third, if there is no Accept header, assume we want text. + # The likely use case here is a raw socket. + elif not acceptable: + renderer = TextRenderer + else: + # Fourth, look to see if there was a JSON body + # When in this situation, the request is probably coming + # from curl, an API client like Postman or Insomnia, or a + # package like requests or httpx + try: + # Give them the benefit of the doubt if they did: + # $ curl localhost:8000 -d '{"foo": "bar"}' + # And provide them with JSONRenderer + renderer = JSONRenderer if request.json else base + except InvalidUsage: + renderer = base + else: + renderer = RENDERERS_BY_CONFIG.get(render_format, renderer) + + # Lastly, if there is an Accept header, make sure + # our choice is okay + if acceptable: + type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore + if type_ and type_ not in acceptable: + # If the renderer selected is not in the Accept header + # look through what is in the Accept header, and select + # the first option that matches. Otherwise, just drop back + # to the original default + for accept in acceptable: + mtype = f"{accept.type_}/{accept.subtype}" + maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype) + if maybe: + renderer = maybe + break + else: + renderer = base + + renderer = t.cast(t.Type[BaseRenderer], renderer) + return renderer(request, exception, debug).render() diff --git a/backend/sanic_server/sanic/exceptions.py b/backend/sanic_server/sanic/exceptions.py new file mode 100644 index 000000000..17af1f957 --- /dev/null +++ b/backend/sanic_server/sanic/exceptions.py @@ -0,0 +1,262 @@ +from typing import Optional, Union + +from ..sanic.helpers import STATUS_CODES + + +class SanicException(Exception): + message: str = "" + + def __init__( + self, + message: Optional[Union[str, bytes]] = None, + status_code: Optional[int] = None, + quiet: Optional[bool] = None, + ) -> None: + if message is None: + if self.message: + message = self.message + elif status_code is not None: + msg: bytes = STATUS_CODES.get(status_code, b"") + message = msg.decode("utf8") + + super().__init__(message) + + if status_code is not None: + self.status_code = status_code + + # quiet=None/False/True with None meaning choose by status + if quiet or quiet is None and status_code not in (None, 500): + self.quiet = True + + +class NotFound(SanicException): + """ + **Status**: 404 Not Found + """ + + status_code = 404 + quiet = True + + +class InvalidUsage(SanicException): + """ + **Status**: 400 Bad Request + """ + + status_code = 400 + quiet = True + + +class MethodNotSupported(SanicException): + """ + **Status**: 405 Method Not Allowed + """ + + status_code = 405 + quiet = True + + def __init__(self, message, method, allowed_methods): + super().__init__(message) + self.headers = {"Allow": ", ".join(allowed_methods)} + + +class ServerError(SanicException): + """ + **Status**: 500 Internal Server Error + """ + + status_code = 500 + + +class ServiceUnavailable(SanicException): + """ + **Status**: 503 Service Unavailable + + The server is currently unavailable (because it is overloaded or + down for maintenance). Generally, this is a temporary state. + """ + + status_code = 503 + quiet = True + + +class URLBuildError(ServerError): + """ + **Status**: 500 Internal Server Error + """ + + status_code = 500 + + +class FileNotFound(NotFound): + """ + **Status**: 404 Not Found + """ + + def __init__(self, message, path, relative_url): + super().__init__(message) + self.path = path + self.relative_url = relative_url + + +class RequestTimeout(SanicException): + """The Web server (running the Web site) thinks that there has been too + long an interval of time between 1) the establishment of an IP + connection (socket) between the client and the server and + 2) the receipt of any data on that socket, so the server has dropped + the connection. The socket connection has actually been lost - the Web + server has 'timed out' on that particular socket connection. + """ + + status_code = 408 + quiet = True + + +class PayloadTooLarge(SanicException): + """ + **Status**: 413 Payload Too Large + """ + + status_code = 413 + quiet = True + + +class HeaderNotFound(InvalidUsage): + """ + **Status**: 400 Bad Request + """ + + +class InvalidHeader(InvalidUsage): + """ + **Status**: 400 Bad Request + """ + + +class ContentRangeError(SanicException): + """ + **Status**: 416 Range Not Satisfiable + """ + + status_code = 416 + quiet = True + + def __init__(self, message, content_range): + super().__init__(message) + self.headers = {"Content-Range": f"bytes */{content_range.total}"} + + +class HeaderExpectationFailed(SanicException): + """ + **Status**: 417 Expectation Failed + """ + + status_code = 417 + quiet = True + + +class Forbidden(SanicException): + """ + **Status**: 403 Forbidden + """ + + status_code = 403 + quiet = True + + +class InvalidRangeType(ContentRangeError): + """ + **Status**: 416 Range Not Satisfiable + """ + + status_code = 416 + quiet = True + + +class PyFileError(Exception): + def __init__(self, file): + super().__init__("could not execute config file %s", file) + + +class Unauthorized(SanicException): + """ + **Status**: 401 Unauthorized + + :param message: Message describing the exception. + :param status_code: HTTP Status code. + :param scheme: Name of the authentication scheme to be used. + + When present, kwargs is used to complete the WWW-Authentication header. + + Examples:: + + # With a Basic auth-scheme, realm MUST be present: + raise Unauthorized("Auth required.", + scheme="Basic", + realm="Restricted Area") + + # With a Digest auth-scheme, things are a bit more complicated: + raise Unauthorized("Auth required.", + scheme="Digest", + realm="Restricted Area", + qop="auth, auth-int", + algorithm="MD5", + nonce="abcdef", + opaque="zyxwvu") + + # With a Bearer auth-scheme, realm is optional so you can write: + raise Unauthorized("Auth required.", scheme="Bearer") + + # or, if you want to specify the realm: + raise Unauthorized("Auth required.", + scheme="Bearer", + realm="Restricted Area") + """ + + status_code = 401 + quiet = True + + def __init__(self, message, status_code=None, scheme=None, **kwargs): + super().__init__(message, status_code) + + # if auth-scheme is specified, set "WWW-Authenticate" header + if scheme is not None: + values = ['{!s}="{!s}"'.format(k, v) for k, v in kwargs.items()] + challenge = ", ".join(values) + + self.headers = {"WWW-Authenticate": f"{scheme} {challenge}".rstrip()} + + +class LoadFileException(SanicException): + pass + + +class InvalidSignal(SanicException): + pass + + +class WebsocketClosed(SanicException): + quiet = True + message = "Client has closed the websocket connection" + + +def abort(status_code: int, message: Optional[Union[str, bytes]] = None): + """ + Raise an exception based on SanicException. Returns the HTTP response + message appropriate for the given status code, unless provided. + + STATUS_CODES from ..sanic.helpers for the given status code. + + :param status_code: The HTTP status code to return. + :param message: The HTTP response body. Defaults to the messages in + """ + import warnings + + warnings.warn( + "sanic.exceptions.abort has been marked as deprecated, and will be " + "removed in release 21.12.\n To migrate your code, simply replace " + "abort(status_code, msg) with raise SanicException(msg, status_code), " + "or even better, raise an appropriate SanicException subclass." + ) + + raise SanicException(message=message, status_code=status_code) diff --git a/backend/sanic_server/sanic/handlers.py b/backend/sanic_server/sanic/handlers.py new file mode 100644 index 000000000..7a8b1ee11 --- /dev/null +++ b/backend/sanic_server/sanic/handlers.py @@ -0,0 +1,278 @@ +from inspect import signature +from typing import Dict, List, Optional, Tuple, Type + +from ..sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response +from ..sanic.exceptions import (ContentRangeError, HeaderNotFound, + InvalidRangeType) +from ..sanic.log import error_logger +from ..sanic.models.handler_types import RouteHandler +from ..sanic.response import text + + +class ErrorHandler: + """ + Provide :class:`sanic.app.Sanic` application with a mechanism to handle + and process any and all uncaught exceptions in a way the application + developer will set fit. + + This error handling framework is built into the core that can be extended + by the developers to perform a wide range of tasks from recording the error + stats to reporting them to an external service that can be used for + realtime alerting system. + + """ + + # Beginning in v22.3, the base renderer will be TextRenderer + def __init__( + self, fallback: str = "auto", base: Type[BaseRenderer] = HTMLRenderer + ): + self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = [] + self.cached_handlers: Dict[ + Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] + ] = {} + self.debug = False + self.fallback = fallback + self.base = base + + @classmethod + def finalize(cls, error_handler, fallback: Optional[str] = None): + if ( + fallback + and fallback != "auto" + and error_handler.fallback == "auto" + ): + error_handler.fallback = fallback + + if not isinstance(error_handler, cls): + error_logger.warning( + f"Error handler is non-conforming: {type(error_handler)}" + ) + + sig = signature(error_handler.lookup) + if len(sig.parameters) == 1: + error_logger.warning( + DeprecationWarning( + "You are using a deprecated error handler. The lookup " + "method should accept two positional parameters: " + "(exception, route_name: Optional[str]). " + "Until you upgrade your ErrorHandler.lookup, Blueprint " + "specific exceptions will not work properly. Beginning " + "in v22.3, the legacy style lookup method will not " + "work at all." + ), + ) + error_handler._lookup = error_handler._legacy_lookup + + def _full_lookup(self, exception, route_name: Optional[str] = None): + return self.lookup(exception, route_name) + + def _legacy_lookup(self, exception, route_name: Optional[str] = None): + return self.lookup(exception) + + def add(self, exception, handler, route_names: Optional[List[str]] = None): + """ + Add a new exception handler to an already existing handler object. + + :param exception: Type of exception that need to be handled + :param handler: Reference to the method that will handle the exception + + :type exception: :class:`sanic.exceptions.SanicException` or + :class:`Exception` + :type handler: ``function`` + + :return: None + """ + # self.handlers is deprecated and will be removed in version 22.3 + self.handlers.append((exception, handler)) + + if route_names: + for route in route_names: + self.cached_handlers[(exception, route)] = handler + else: + self.cached_handlers[(exception, None)] = handler + + def lookup(self, exception, route_name: Optional[str] = None): + """ + Lookup the existing instance of :class:`ErrorHandler` and fetch the + registered handler for a specific type of exception. + + This method leverages a dict lookup to speedup the retrieval process. + + :param exception: Type of exception + + :type exception: :class:`sanic.exceptions.SanicException` or + :class:`Exception` + + :return: Registered function if found ``None`` otherwise + """ + exception_class = type(exception) + + for name in (route_name, None): + exception_key = (exception_class, name) + handler = self.cached_handlers.get(exception_key) + if handler: + return handler + + for name in (route_name, None): + for ancestor in type.mro(exception_class): + exception_key = (ancestor, name) + if exception_key in self.cached_handlers: + handler = self.cached_handlers[exception_key] + self.cached_handlers[ + (exception_class, route_name) + ] = handler + return handler + + if ancestor is BaseException: + break + self.cached_handlers[(exception_class, route_name)] = None + handler = None + return handler + + _lookup = _full_lookup + + def response(self, request, exception): + """Fetches and executes an exception handler and returns a response + object + + :param request: Instance of :class:`sanic.request.Request` + :param exception: Exception to handle + + :type request: :class:`sanic.request.Request` + :type exception: :class:`sanic.exceptions.SanicException` or + :class:`Exception` + + :return: Wrap the return value obtained from :func:`default` + or registered handler for that type of exception. + """ + route_name = request.name if request else None + handler = self._lookup(exception, route_name) + response = None + try: + if handler: + response = handler(request, exception) + if response is None: + response = self.default(request, exception) + except Exception: + try: + url = repr(request.url) + except AttributeError: + url = "unknown" + response_message = ( + "Exception raised in exception handler " '"%s" for uri: %s' + ) + error_logger.exception(response_message, handler.__name__, url) + + if self.debug: + return text(response_message % (handler.__name__, url), 500) + else: + return text("An error occurred while handling an error", 500) + return response + + def default(self, request, exception): + """ + Provide a default behavior for the objects of :class:`ErrorHandler`. + If a developer chooses to extent the :class:`ErrorHandler` they can + provide a custom implementation for this method to behave in a way + they see fit. + + :param request: Incoming request + :param exception: Exception object + + :type request: :class:`sanic.request.Request` + :type exception: :class:`sanic.exceptions.SanicException` or + :class:`Exception` + :return: + """ + self.log(request, exception) + return exception_response( + request, + exception, + debug=self.debug, + base=self.base, + fallback=self.fallback, + ) + + @staticmethod + def log(request, exception): + quiet = getattr(exception, "quiet", False) + if quiet is False: + try: + url = repr(request.url) + except AttributeError: + url = "unknown" + + error_logger.exception( + "Exception occurred while handling uri: %s", url + ) + + +class ContentRangeHandler: + """ + A mechanism to parse and process the incoming request headers to + extract the content range information. + + :param request: Incoming api request + :param stats: Stats related to the content + + :type request: :class:`sanic.request.Request` + :type stats: :class:`posix.stat_result` + + :ivar start: Content Range start + :ivar end: Content Range end + :ivar size: Length of the content + :ivar total: Total size identified by the :class:`posix.stat_result` + instance + :ivar ContentRangeHandler.headers: Content range header ``dict`` + """ + + __slots__ = ("start", "end", "size", "total", "headers") + + def __init__(self, request, stats): + self.total = stats.st_size + _range = request.headers.getone("range", None) + if _range is None: + raise HeaderNotFound("Range Header Not Found") + unit, _, value = tuple(map(str.strip, _range.partition("="))) + if unit != "bytes": + raise InvalidRangeType( + "%s is not a valid Range Type" % (unit,), self + ) + start_b, _, end_b = tuple(map(str.strip, value.partition("-"))) + try: + self.start = int(start_b) if start_b else None + except ValueError: + raise ContentRangeError( + "'%s' is invalid for Content Range" % (start_b,), self + ) + try: + self.end = int(end_b) if end_b else None + except ValueError: + raise ContentRangeError( + "'%s' is invalid for Content Range" % (end_b,), self + ) + if self.end is None: + if self.start is None: + raise ContentRangeError( + "Invalid for Content Range parameters", self + ) + else: + # this case represents `Content-Range: bytes 5-` + self.end = self.total - 1 + else: + if self.start is None: + # this case represents `Content-Range: bytes -5` + self.start = self.total - self.end + self.end = self.total - 1 + if self.start >= self.end: + raise ContentRangeError( + "Invalid for Content Range parameters", self + ) + self.size = self.end - self.start + 1 + self.headers = { + "Content-Range": "bytes %s-%s/%s" + % (self.start, self.end, self.total) + } + + def __bool__(self): + return self.size > 0 diff --git a/backend/sanic_server/sanic/headers.py b/backend/sanic_server/sanic/headers.py new file mode 100644 index 000000000..23e9364c3 --- /dev/null +++ b/backend/sanic_server/sanic/headers.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +import re +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from urllib.parse import unquote + +from ..sanic.exceptions import InvalidHeader +from ..sanic.helpers import STATUS_CODES + +# TODO: +# - the Options object should be a typed object to allow for less casting +# across the application (in request.py for example) +HeaderIterable = Iterable[Tuple[str, Any]] # Values convertible to str +HeaderBytesIterable = Iterable[Tuple[bytes, bytes]] +Options = Dict[str, Union[int, str]] # key=value fields in various headers +OptionsIterable = Iterable[Tuple[str, str]] # May contain duplicate keys + +_token, _quoted = r"([\w!#$%&'*+\-.^_`|~]+)", r'"([^"]*)"' +_param = re.compile(fr";\s*{_token}=(?:{_token}|{_quoted})", re.ASCII) +_firefox_quote_escape = re.compile(r'\\"(?!; |\s*$)') +_ipv6 = "(?:[0-9A-Fa-f]{0,4}:){2,7}[0-9A-Fa-f]{0,4}" +_ipv6_re = re.compile(_ipv6) +_host_re = re.compile(r"((?:\[" + _ipv6 + r"\])|[a-zA-Z0-9.\-]{1,253})(?::(\d{1,5}))?") + +# RFC's quoted-pair escapes are mostly ignored by browsers. Chrome, Firefox and +# curl all have different escaping, that we try to handle as well as possible, +# even though no client espaces in a way that would allow perfect handling. + +# For more information, consult ../tests/test_requests.py + + +def parse_arg_as_accept(f): + def func(self, other, *args, **kwargs): + if not isinstance(other, Accept) and other: + other = Accept.parse(other) + return f(self, other, *args, **kwargs) + + return func + + +class MediaType(str): + def __new__(cls, value: str): + return str.__new__(cls, value) + + def __init__(self, value: str) -> None: + self.value = value + self.is_wildcard = self.check_if_wildcard(value) + + def __eq__(self, other): + if self.is_wildcard: + return True + + if self.match(other): + return True + + other_is_wildcard = ( + other.is_wildcard + if isinstance(other, MediaType) + else self.check_if_wildcard(other) + ) + + return other_is_wildcard + + def match(self, other): + other_value = other.value if isinstance(other, MediaType) else other + return self.value == other_value + + @staticmethod + def check_if_wildcard(value): + return value == "*" + + +class Accept(str): + def __new__(cls, value: str, *args, **kwargs): + return str.__new__(cls, value) + + def __init__( + self, + value: str, + type_: MediaType, + subtype: MediaType, + *, + q: str = "1.0", + **kwargs: str, + ): + qvalue = float(q) + if qvalue > 1 or qvalue < 0: + raise InvalidHeader( + f"Accept header qvalue must be between 0 and 1, not: {qvalue}" + ) + self.value = value + self.type_ = type_ + self.subtype = subtype + self.qvalue = qvalue + self.params = kwargs + + def _compare(self, other, method): + try: + return method(self.qvalue, other.qvalue) + except (AttributeError, TypeError): + return NotImplemented + + @parse_arg_as_accept + def __lt__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s < o) + + @parse_arg_as_accept + def __le__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s <= o) + + @parse_arg_as_accept + def __eq__(self, other: Union[str, Accept]): # type: ignore + return self._compare(other, lambda s, o: s == o) + + @parse_arg_as_accept + def __ge__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s >= o) + + @parse_arg_as_accept + def __gt__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s > o) + + @parse_arg_as_accept + def __ne__(self, other: Union[str, Accept]): # type: ignore + return self._compare(other, lambda s, o: s != o) + + @parse_arg_as_accept + def match( + self, + other, + *, + allow_type_wildcard: bool = True, + allow_subtype_wildcard: bool = True, + ) -> bool: + type_match = ( + self.type_ == other.type_ + if allow_type_wildcard + else ( + self.type_.match(other.type_) + and not self.type_.is_wildcard + and not other.type_.is_wildcard + ) + ) + subtype_match = ( + self.subtype == other.subtype + if allow_subtype_wildcard + else ( + self.subtype.match(other.subtype) + and not self.subtype.is_wildcard + and not other.subtype.is_wildcard + ) + ) + + return type_match and subtype_match + + @classmethod + def parse(cls, raw: str) -> Accept: + invalid = False + mtype = raw.strip() + + try: + media, *raw_params = mtype.split(";") + type_, subtype = media.split("/") + except ValueError: + invalid = True + + if invalid or not type_ or not subtype: + raise InvalidHeader(f"Header contains invalid Accept value: {raw}") + + params = dict( + [ + (key.strip(), value.strip()) + for key, value in (param.split("=", 1) for param in raw_params) + ] + ) + + return cls(mtype, MediaType(type_), MediaType(subtype), **params) + + +class AcceptContainer(list): + def __contains__(self, o: object) -> bool: + return any(item.match(o) for item in self) + + def match( + self, + o: object, + *, + allow_type_wildcard: bool = True, + allow_subtype_wildcard: bool = True, + ) -> bool: + return any( + item.match( + o, + allow_type_wildcard=allow_type_wildcard, + allow_subtype_wildcard=allow_subtype_wildcard, + ) + for item in self + ) + + +def parse_content_header(value: str) -> Tuple[str, Options]: + """Parse content-type and content-disposition header values. + + E.g. 'form-data; name=upload; filename=\"file.txt\"' to + ('form-data', {'name': 'upload', 'filename': 'file.txt'}) + + Mostly identical to cgi.parse_header and werkzeug.parse_options_header + but runs faster and handles special characters better. Unescapes quotes. + """ + value = _firefox_quote_escape.sub("%22", value) + pos = value.find(";") + if pos == -1: + options: Dict[str, Union[int, str]] = {} + else: + options = { + m.group(1).lower(): m.group(2) or m.group(3).replace("%22", '"') + for m in _param.finditer(value[pos:]) + } + value = value[:pos] + return value.strip().lower(), options + + +# https://tools.ietf.org/html/rfc7230#section-3.2.6 and +# https://tools.ietf.org/html/rfc7239#section-4 +# This regex is for *reversed* strings because that works much faster for +# right-to-left matching than the other way around. Be wary that all things are +# a bit backwards! _rparam matches forwarded pairs alike ";key=value" +_rparam = re.compile(f"(?:{_token}|{_quoted})={_token}\\s*($|[;,])", re.ASCII) + + +def parse_forwarded(headers, config) -> Optional[Options]: + """Parse RFC 7239 Forwarded headers. + The value of `by` or `secret` must match `config.FORWARDED_SECRET` + :return: dict with keys and values, or None if nothing matched + """ + header = headers.getall("forwarded", None) + secret = config.FORWARDED_SECRET + if header is None or not secret: + return None + header = ",".join(header) # Join multiple header lines + if secret not in header: + return None + # Loop over = elements from right to left + sep = pos = None + options: List[Tuple[str, str]] = [] + found = False + for m in _rparam.finditer(header[::-1]): + # Start of new element? (on parser skips and non-semicolon right sep) + if m.start() != pos or sep != ";": + # Was the previous element (from right) what we wanted? + if found: + break + # Clear values and parse as new element + del options[:] + pos = m.end() + val_token, val_quoted, key, sep = m.groups() + key = key.lower()[::-1] + val = (val_token or val_quoted.replace('"\\', '"'))[::-1] + options.append((key, val)) + if key in ("secret", "by") and val == secret: + found = True + # Check if we would return on next round, to avoid useless parse + if found and sep != ";": + break + # If secret was found, return the matching options in left-to-right order + return fwd_normalize(reversed(options)) if found else None + + +def parse_xforwarded(headers, config) -> Optional[Options]: + """Parse traditional proxy headers.""" + real_ip_header = config.REAL_IP_HEADER + proxies_count = config.PROXIES_COUNT + addr = real_ip_header and headers.getone(real_ip_header, None) + if not addr and proxies_count: + assert proxies_count > 0 + try: + # Combine, split and filter multiple headers' entries + forwarded_for = headers.getall(config.FORWARDED_FOR_HEADER) + proxies = [ + p for p in (p.strip() for h in forwarded_for for p in h.split(",")) if p + ] + addr = proxies[-proxies_count] + except (KeyError, IndexError): + pass + # No processing of other headers if no address is found + if not addr: + return None + + def options(): + yield "for", addr + for key, header in ( + ("proto", "x-scheme"), + ("proto", "x-forwarded-proto"), # Overrides X-Scheme if present + ("host", "x-forwarded-host"), + ("port", "x-forwarded-port"), + ("path", "x-forwarded-path"), + ): + yield key, headers.getone(header, None) + + return fwd_normalize(options()) + + +def fwd_normalize(fwd: OptionsIterable) -> Options: + """Normalize and convert values extracted from forwarded headers.""" + ret: Dict[str, Union[int, str]] = {} + for key, val in fwd: + if val is not None: + try: + if key in ("by", "for"): + ret[key] = fwd_normalize_address(val) + elif key in ("host", "proto"): + ret[key] = val.lower() + elif key == "port": + ret[key] = int(val) + elif key == "path": + ret[key] = unquote(val) + else: + ret[key] = val + except ValueError: + pass + return ret + + +def fwd_normalize_address(addr: str) -> str: + """Normalize address fields of proxy headers.""" + if addr == "unknown": + raise ValueError() # omit unknown value identifiers + if addr.startswith("_"): + return addr # do not lower-case obfuscated strings + if _ipv6_re.fullmatch(addr): + addr = f"[{addr}]" # bracket IPv6 + return addr.lower() + + +def parse_host(host: str) -> Tuple[Optional[str], Optional[int]]: + """Split host:port into hostname and port. + :return: None in place of missing elements + """ + m = _host_re.fullmatch(host) + if not m: + return None, None + host, port = m.groups() + return host.lower(), int(port) if port is not None else None + + +_HTTP1_STATUSLINES = [ + b"HTTP/1.1 %d %b\r\n" % (status, STATUS_CODES.get(status, b"UNKNOWN")) + for status in range(1000) +] + + +def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes: + """Format a HTTP/1.1 response header.""" + # Note: benchmarks show that here bytes concat is faster than bytearray, + # b"".join() or %-formatting. %timeit any changes you make. + ret = _HTTP1_STATUSLINES[status] + for h in headers: + ret += b"%b: %b\r\n" % h + ret += b"\r\n" + return ret + + +def _sort_accept_value(accept: Accept): + return ( + accept.qvalue, + len(accept.params), + accept.subtype != "*", + accept.type_ != "*", + ) + + +def parse_accept(accept: str) -> AcceptContainer: + """Parse an Accept header and order the acceptable media types in + accorsing to RFC 7231, s. 5.3.2 + https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + """ + media_types = accept.split(",") + accept_list: List[Accept] = [] + + for mtype in media_types: + if not mtype: + continue + + accept_list.append(Accept.parse(mtype)) + + return AcceptContainer(sorted(accept_list, key=_sort_accept_value, reverse=True)) diff --git a/backend/sanic_server/sanic/helpers.py b/backend/sanic_server/sanic/helpers.py new file mode 100644 index 000000000..87d51b53a --- /dev/null +++ b/backend/sanic_server/sanic/helpers.py @@ -0,0 +1,171 @@ +"""Defines basics of HTTP standard.""" + +from importlib import import_module +from inspect import ismodule +from typing import Dict + + +STATUS_CODES: Dict[int, bytes] = { + 100: b"Continue", + 101: b"Switching Protocols", + 102: b"Processing", + 103: b"Early Hints", + 200: b"OK", + 201: b"Created", + 202: b"Accepted", + 203: b"Non-Authoritative Information", + 204: b"No Content", + 205: b"Reset Content", + 206: b"Partial Content", + 207: b"Multi-Status", + 208: b"Already Reported", + 226: b"IM Used", + 300: b"Multiple Choices", + 301: b"Moved Permanently", + 302: b"Found", + 303: b"See Other", + 304: b"Not Modified", + 305: b"Use Proxy", + 307: b"Temporary Redirect", + 308: b"Permanent Redirect", + 400: b"Bad Request", + 401: b"Unauthorized", + 402: b"Payment Required", + 403: b"Forbidden", + 404: b"Not Found", + 405: b"Method Not Allowed", + 406: b"Not Acceptable", + 407: b"Proxy Authentication Required", + 408: b"Request Timeout", + 409: b"Conflict", + 410: b"Gone", + 411: b"Length Required", + 412: b"Precondition Failed", + 413: b"Request Entity Too Large", + 414: b"Request-URI Too Long", + 415: b"Unsupported Media Type", + 416: b"Requested Range Not Satisfiable", + 417: b"Expectation Failed", + 418: b"I'm a teapot", + 422: b"Unprocessable Entity", + 423: b"Locked", + 424: b"Failed Dependency", + 426: b"Upgrade Required", + 428: b"Precondition Required", + 429: b"Too Many Requests", + 431: b"Request Header Fields Too Large", + 451: b"Unavailable For Legal Reasons", + 500: b"Internal Server Error", + 501: b"Not Implemented", + 502: b"Bad Gateway", + 503: b"Service Unavailable", + 504: b"Gateway Timeout", + 505: b"HTTP Version Not Supported", + 506: b"Variant Also Negotiates", + 507: b"Insufficient Storage", + 508: b"Loop Detected", + 510: b"Not Extended", + 511: b"Network Authentication Required", +} + +# According to https://tools.ietf.org/html/rfc2616#section-7.1 +_ENTITY_HEADERS = frozenset( + [ + "allow", + "content-encoding", + "content-language", + "content-length", + "content-location", + "content-md5", + "content-range", + "content-type", + "expires", + "last-modified", + "extension-header", + ] +) + +# According to https://tools.ietf.org/html/rfc2616#section-13.5.1 +_HOP_BY_HOP_HEADERS = frozenset( + [ + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + ] +) + + +def has_message_body(status): + """ + According to the following RFC message body and length SHOULD NOT + be included in responses status 1XX, 204 and 304. + https://tools.ietf.org/html/rfc2616#section-4.4 + https://tools.ietf.org/html/rfc2616#section-4.3 + """ + return status not in (204, 304) and not (100 <= status < 200) + + +def is_entity_header(header): + """Checks if the given header is an Entity Header""" + return header.lower() in _ENTITY_HEADERS + + +def is_hop_by_hop_header(header): + """Checks if the given header is a Hop By Hop header""" + return header.lower() in _HOP_BY_HOP_HEADERS + + +def remove_entity_headers(headers, allowed=("content-location", "expires")): + """ + Removes all the entity headers present in the headers given. + According to RFC 2616 Section 10.3.5, + Content-Location and Expires are allowed as for the + "strong cache validator". + https://tools.ietf.org/html/rfc2616#section-10.3.5 + + returns the headers without the entity headers + """ + allowed = set([h.lower() for h in allowed]) + headers = { + header: value + for header, value in headers.items() + if not is_entity_header(header) or header.lower() in allowed + } + return headers + + +def import_string(module_name, package=None): + """ + import a module or class by string path. + + :module_name: str with path of module or path to import and + instanciate a class + :returns: a module object or one instance from class if + module_name is a valid path to class + + """ + module, klass = module_name.rsplit(".", 1) + module = import_module(module, package=package) + obj = getattr(module, klass) + if ismodule(obj): + return obj + return obj() + + +class Default: + """ + It is used to replace `None` or `object()` as a sentinel + that represents a default value. Sometimes we want to set + a value to `None` so we cannot use `None` to represent the + default value, and `object()` is hard to be typed. + """ + + pass + + +_default = Default() diff --git a/backend/sanic_server/sanic/http.py b/backend/sanic_server/sanic/http.py new file mode 100644 index 000000000..5b1c11310 --- /dev/null +++ b/backend/sanic_server/sanic/http.py @@ -0,0 +1,594 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from ..sanic.request import Request + from ..sanic.response import BaseHTTPResponse + +from asyncio import CancelledError, sleep +from enum import Enum + +from ..sanic.compat import Header +from ..sanic.exceptions import (HeaderExpectationFailed, InvalidUsage, + PayloadTooLarge, ServerError, + ServiceUnavailable) +from ..sanic.headers import format_http1_response +from ..sanic.helpers import has_message_body +from ..sanic.log import access_logger, error_logger, logger +from ..sanic.touchup import TouchUpMeta + + +class Stage(Enum): + """ + Enum for representing the stage of the request/response cycle + + | ``IDLE`` Waiting for request + | ``REQUEST`` Request headers being received + | ``HANDLER`` Headers done, handler running + | ``RESPONSE`` Response headers sent, body in progress + | ``FAILED`` Unrecoverable state (error while sending response) + | + """ + + IDLE = 0 # Waiting for request + REQUEST = 1 # Request headers being received + HANDLER = 3 # Headers done, handler running + RESPONSE = 4 # Response headers sent, body in progress + FAILED = 100 # Unrecoverable state (error while sending response) + + +HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n" + + +class Http(metaclass=TouchUpMeta): + """ + Internal helper for managing the HTTP request/response cycle + + :raises ServerError: + :raises PayloadTooLarge: + :raises Exception: + :raises InvalidUsage: + :raises HeaderExpectationFailed: + :raises RuntimeError: + :raises ServerError: + :raises ServerError: + :raises InvalidUsage: + :raises InvalidUsage: + :raises InvalidUsage: + :raises PayloadTooLarge: + :raises RuntimeError: + """ + + HEADER_CEILING = 16_384 + HEADER_MAX_SIZE = 0 + + __touchup__ = ( + "http1_request_header", + "http1_response_header", + "read", + ) + __slots__ = [ + "_send", + "_receive_more", + "dispatch", + "recv_buffer", + "protocol", + "expecting_continue", + "stage", + "keep_alive", + "head_only", + "request", + "exception", + "url", + "request_body", + "request_bytes", + "request_bytes_left", + "request_max_size", + "response", + "response_func", + "response_size", + "response_bytes_left", + "upgrade_websocket", + ] + + def __init__(self, protocol): + self._send = protocol.send + self._receive_more = protocol.receive_more + self.recv_buffer = protocol.recv_buffer + self.protocol = protocol + self.keep_alive = True + self.stage: Stage = Stage.IDLE + self.dispatch = self.protocol.app.dispatch + + def init_for_request(self): + """Init/reset all per-request variables.""" + self.exception = None + self.expecting_continue: bool = False + self.head_only = None + self.request_body = None + self.request_bytes = None + self.request_bytes_left = None + self.request_max_size = self.protocol.request_max_size + self.request: Request = None + self.response: BaseHTTPResponse = None + self.upgrade_websocket = False + self.url = None + + def __bool__(self): + """Test if request handling is in progress""" + return self.stage in (Stage.HANDLER, Stage.RESPONSE) + + async def http1(self): + """ + HTTP 1.1 connection handler + """ + # Handle requests while the connection stays reusable + while self.keep_alive and self.stage is Stage.IDLE: + self.init_for_request() + # Wait for incoming bytes (in IDLE stage) + if not self.recv_buffer: + await self._receive_more() + self.stage = Stage.REQUEST + try: + # Receive and handle a request + self.response_func = self.http1_response_header + + await self.http1_request_header() + + self.stage = Stage.HANDLER + self.request.conn_info = self.protocol.conn_info + await self.protocol.request_handler(self.request) + + # Handler finished, response should've been sent + if self.stage is Stage.HANDLER and not self.upgrade_websocket: + raise ServerError("Handler produced no response") + + if self.stage is Stage.RESPONSE: + await self.response.send(end_stream=True) + except CancelledError: + # Write an appropriate response before exiting + if not self.protocol.transport: + logger.info( + f"Request: {self.request.method} {self.request.url} " + "stopped. Transport is closed." + ) + return + e = self.exception or ServiceUnavailable("Cancelled") + self.exception = None + self.keep_alive = False + await self.error_response(e) + except Exception as e: + # Write an error response + await self.error_response(e) + + # Try to consume any remaining request body + if self.request_body: + if self.response and 200 <= self.response.status < 300: + error_logger.error(f"{self.request} body not consumed.") + # Limit the size because the handler may have set it infinite + self.request_max_size = min( + self.request_max_size, self.protocol.request_max_size + ) + try: + async for _ in self: + pass + except PayloadTooLarge: + # We won't read the body and that may cause httpx and + # tests to fail. This little delay allows clients to push + # a small request into network buffers before we close the + # socket, so that they are then able to read the response. + await sleep(0.001) + self.keep_alive = False + + # Clean up to free memory and for the next request + if self.request: + self.request.stream = None + if self.response: + self.response.stream = None + + async def http1_request_header(self): # no cov + """ + Receive and parse request header into self.request. + """ + # Receive until full header is in buffer + buf = self.recv_buffer + pos = 0 + + while True: + pos = buf.find(b"\r\n\r\n", pos) + if pos != -1: + break + + pos = max(0, len(buf) - 3) + if pos >= self.HEADER_MAX_SIZE: + break + + await self._receive_more() + + if pos >= self.HEADER_MAX_SIZE: + raise PayloadTooLarge("Request header exceeds the size limit") + + # Parse header content + try: + head = buf[:pos] + raw_headers = head.decode(errors="surrogateescape") + reqline, *split_headers = raw_headers.split("\r\n") + method, self.url, protocol = reqline.split(" ") + + await self.dispatch( + "http.lifecycle.read_head", + inline=True, + context={"head": bytes(head)}, + ) + + if protocol == "HTTP/1.1": + self.keep_alive = True + elif protocol == "HTTP/1.0": + self.keep_alive = False + else: + raise Exception # Raise a Bad Request on try-except + + self.head_only = method.upper() == "HEAD" + request_body = False + headers = [] + + for name, value in (h.split(":", 1) for h in split_headers): + name, value = h = name.lower(), value.lstrip() + + if name in ("content-length", "transfer-encoding"): + request_body = True + elif name == "connection": + self.keep_alive = value.lower() == "keep-alive" + + headers.append(h) + except Exception: + raise InvalidUsage("Bad Request") + + headers_instance = Header(headers) + self.upgrade_websocket = ( + headers_instance.getone("upgrade", "").lower() == "websocket" + ) + + # Prepare a Request object + request = self.protocol.request_class( + url_bytes=self.url.encode(), + headers=headers_instance, + head=bytes(head), + version=protocol[5:], + method=method, + transport=self.protocol.transport, + app=self.protocol.app, + ) + await self.dispatch( + "http.lifecycle.request", + inline=True, + context={"request": request}, + ) + + # Prepare for request body + self.request_bytes_left = self.request_bytes = 0 + if request_body: + headers = request.headers + expect = headers.getone("expect", None) + + if expect is not None: + if expect.lower() == "100-continue": + self.expecting_continue = True + else: + raise HeaderExpectationFailed(f"Unknown Expect: {expect}") + + if headers.getone("transfer-encoding", None) == "chunked": + self.request_body = "chunked" + pos -= 2 # One CRLF stays in buffer + else: + self.request_body = True + self.request_bytes_left = self.request_bytes = int( + headers["content-length"] + ) + + # Remove header and its trailing CRLF + del buf[: pos + 4] + self.request, request.stream = request, self + self.protocol.state["requests_count"] += 1 + + async def http1_response_header( + self, data: bytes, end_stream: bool + ) -> None: # no cov + res = self.response + + # Compatibility with simple response body + if not data and getattr(res, "body", None): + data, end_stream = res.body, True # type: ignore + + size = len(data) + headers = res.headers + status = res.status + self.response_size = size + + if not isinstance(status, int) or status < 200: + raise RuntimeError(f"Invalid response status {status!r}") + + if not has_message_body(status): + # Header-only response status + self.response_func = None + if ( + data + or not end_stream + or "content-length" in headers + or "transfer-encoding" in headers + ): + data, size, end_stream = b"", 0, True + headers.pop("content-length", None) + headers.pop("transfer-encoding", None) + logger.warning( + f"Message body set in response on {self.request.path}. " + f"A {status} response may only have headers, no body." + ) + elif self.head_only and "content-length" in headers: + self.response_func = None + elif end_stream: + # Non-streaming response (all in one block) + headers["content-length"] = size + self.response_func = None + elif "content-length" in headers: + # Streaming response with size known in advance + self.response_bytes_left = int(headers["content-length"]) - size + self.response_func = self.http1_response_normal + else: + # Length not known, use chunked encoding + headers["transfer-encoding"] = "chunked" + data = b"%x\r\n%b\r\n" % (size, data) if size else b"" + self.response_func = self.http1_response_chunked + + if self.head_only: + # Head request: don't send body + data = b"" + self.response_func = self.head_response_ignored + + headers["connection"] = "keep-alive" if self.keep_alive else "close" + ret = format_http1_response(status, res.processed_headers) + if data: + ret += data + + # Send a 100-continue if expected and not Expectation Failed + if self.expecting_continue: + self.expecting_continue = False + if status != 417: + ret = HTTP_CONTINUE + ret + + # Send response + if self.protocol.access_log: + self.log_response() + + await self._send(ret) + self.stage = Stage.IDLE if end_stream else Stage.RESPONSE + + def head_response_ignored(self, data: bytes, end_stream: bool) -> None: + """ + HEAD response: body data silently ignored. + """ + if end_stream: + self.response_func = None + self.stage = Stage.IDLE + + async def http1_response_chunked( + self, data: bytes, end_stream: bool + ) -> None: + """ + Format a part of response body in chunked encoding. + """ + # Chunked encoding + size = len(data) + if end_stream: + await self._send( + b"%x\r\n%b\r\n0\r\n\r\n" % (size, data) + if size + else b"0\r\n\r\n" + ) + self.response_func = None + self.stage = Stage.IDLE + elif size: + await self._send(b"%x\r\n%b\r\n" % (size, data)) + + async def http1_response_normal( + self, data: bytes, end_stream: bool + ) -> None: + """ + Format / keep track of non-chunked response. + """ + bytes_left = self.response_bytes_left - len(data) + if bytes_left <= 0: + if bytes_left < 0: + raise ServerError("Response was bigger than content-length") + + await self._send(data) + self.response_func = None + self.stage = Stage.IDLE + else: + if end_stream: + raise ServerError("Response was smaller than content-length") + + await self._send(data) + self.response_bytes_left = bytes_left + + async def error_response(self, exception: Exception) -> None: + """ + Handle response when exception encountered + """ + # Disconnect after an error if in any other state than handler + if self.stage is not Stage.HANDLER: + self.keep_alive = False + + # Request failure? Respond but then disconnect + if self.stage is Stage.REQUEST: + self.stage = Stage.HANDLER + + # From request and handler states we can respond, otherwise be silent + if self.stage is Stage.HANDLER: + app = self.protocol.app + + if self.request is None: + self.create_empty_request() + + await app.handle_exception(self.request, exception) + + def create_empty_request(self) -> None: + """ + Current error handling code needs a request object that won't exist + if an error occurred during before a request was received. Create a + bogus response for error handling use. + """ + + # FIXME: Avoid this by refactoring error handling and response code + self.request = self.protocol.request_class( + url_bytes=self.url.encode() if self.url else b"*", + headers=Header({}), + version="1.1", + method="NONE", + transport=self.protocol.transport, + app=self.protocol.app, + ) + self.request.stream = self + + def log_response(self) -> None: + """ + Helper method provided to enable the logging of responses in case if + the :attr:`HttpProtocol.access_log` is enabled. + """ + req, res = self.request, self.response + extra = { + "status": getattr(res, "status", 0), + "byte": getattr( + self, "response_bytes_left", getattr(self, "response_size", -1) + ), + "host": "UNKNOWN", + "request": "nil", + } + if req is not None: + if req.remote_addr or req.ip: + extra["host"] = f"{req.remote_addr or req.ip}:{req.port}" + extra["request"] = f"{req.method} {req.url}" + access_logger.info("", extra=extra) + + # Request methods + + async def __aiter__(self): + """ + Async iterate over request body. + """ + while self.request_body: + data = await self.read() + + if data: + yield data + + async def read(self) -> Optional[bytes]: # no cov + """ + Read some bytes of request body. + """ + + # Send a 100-continue if needed + if self.expecting_continue: + self.expecting_continue = False + await self._send(HTTP_CONTINUE) + + # Receive request body chunk + buf = self.recv_buffer + if self.request_bytes_left == 0 and self.request_body == "chunked": + # Process a chunk header: \r\n[;]\r\n + while True: + pos = buf.find(b"\r\n", 3) + + if pos != -1: + break + + if len(buf) > 64: + self.keep_alive = False + raise InvalidUsage("Bad chunked encoding") + + await self._receive_more() + + try: + size = int(buf[2:pos].split(b";", 1)[0].decode(), 16) + except Exception: + self.keep_alive = False + raise InvalidUsage("Bad chunked encoding") + + if size <= 0: + self.request_body = None + + if size < 0: + self.keep_alive = False + raise InvalidUsage("Bad chunked encoding") + + # Consume CRLF, chunk size 0 and the two CRLF that follow + pos += 4 + # Might need to wait for the final CRLF + while len(buf) < pos: + await self._receive_more() + del buf[:pos] + return None + + # Remove CRLF, chunk size and the CRLF that follows + del buf[: pos + 2] + + self.request_bytes_left = size + self.request_bytes += size + + # Request size limit + if self.request_bytes > self.request_max_size: + self.keep_alive = False + raise PayloadTooLarge("Request body exceeds the size limit") + + # End of request body? + if not self.request_bytes_left: + self.request_body = None + return None + + # At this point we are good to read/return up to request_bytes_left + if not buf: + await self._receive_more() + + data = bytes(buf[: self.request_bytes_left]) + size = len(data) + + del buf[:size] + + self.request_bytes_left -= size + + await self.dispatch( + "http.lifecycle.read_body", + inline=True, + context={"body": data}, + ) + + return data + + # Response methods + + def respond(self, response: BaseHTTPResponse) -> BaseHTTPResponse: + """ + Initiate new streaming response. + + Nothing is sent until the first send() call on the returned object, and + calling this function multiple times will just alter the response to be + given. + """ + if self.stage is not Stage.HANDLER: + self.stage = Stage.FAILED + raise RuntimeError("Response already started") + + self.response, response.stream = response, self + return response + + @property + def send(self): + return self.response_func + + @classmethod + def set_header_max_size(cls, *sizes: int): + cls.HEADER_MAX_SIZE = min( + *sizes, + cls.HEADER_CEILING, + ) diff --git a/backend/sanic_server/sanic/log.py b/backend/sanic_server/sanic/log.py new file mode 100644 index 000000000..2e3608359 --- /dev/null +++ b/backend/sanic_server/sanic/log.py @@ -0,0 +1,69 @@ +import logging +import sys + + +LOGGING_CONFIG_DEFAULTS = dict( + version=1, + disable_existing_loggers=False, + loggers={ + "sanic.root": {"level": "INFO", "handlers": ["console"]}, + "sanic.error": { + "level": "INFO", + "handlers": ["error_console"], + "propagate": True, + "qualname": "sanic.error", + }, + "sanic.access": { + "level": "INFO", + "handlers": ["access_console"], + "propagate": True, + "qualname": "sanic.access", + }, + }, + handlers={ + "console": { + "class": "logging.StreamHandler", + "formatter": "generic", + "stream": sys.stdout, + }, + "error_console": { + "class": "logging.StreamHandler", + "formatter": "generic", + "stream": sys.stderr, + }, + "access_console": { + "class": "logging.StreamHandler", + "formatter": "access", + "stream": sys.stdout, + }, + }, + formatters={ + "generic": { + "format": "%(asctime)s [%(process)d] [%(levelname)s] %(message)s", + "datefmt": "[%Y-%m-%d %H:%M:%S %z]", + "class": "logging.Formatter", + }, + "access": { + "format": "%(asctime)s - (%(name)s)[%(levelname)s][%(host)s]: " + + "%(request)s %(message)s %(status)d %(byte)d", + "datefmt": "[%Y-%m-%d %H:%M:%S %z]", + "class": "logging.Formatter", + }, + }, +) + + +logger = logging.getLogger("sanic.root") +""" +General Sanic logger +""" + +error_logger = logging.getLogger("sanic.error") +""" +Logger used by Sanic for error logging +""" + +access_logger = logging.getLogger("sanic.access") +""" +Logger used by Sanic for access logging +""" diff --git a/backend/sanic_server/sanic/mixins/__init__.py b/backend/sanic_server/sanic/mixins/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic/mixins/exceptions.py b/backend/sanic_server/sanic/mixins/exceptions.py new file mode 100644 index 000000000..c99d95cc7 --- /dev/null +++ b/backend/sanic_server/sanic/mixins/exceptions.py @@ -0,0 +1,39 @@ +from typing import Set + +from ...sanic.models.futures import FutureException + + +class ExceptionMixin: + def __init__(self, *args, **kwargs) -> None: + self._future_exceptions: Set[FutureException] = set() + + def _apply_exception_handler(self, handler: FutureException): + raise NotImplementedError # noqa + + def exception(self, *exceptions, apply=True): + """ + This method enables the process of creating a global exception + handler for the current blueprint under question. + + :param args: List of Python exceptions to be caught by the handler + :param kwargs: Additional optional arguments to be passed to the + exception handler + + :return a decorated method to handle global exceptions for any + route registered under this blueprint. + """ + + def decorator(handler): + nonlocal apply + nonlocal exceptions + + if isinstance(exceptions[0], list): + exceptions = tuple(*exceptions) + + future_exception = FutureException(handler, exceptions) + self._future_exceptions.add(future_exception) + if apply: + self._apply_exception_handler(future_exception) + return handler + + return decorator diff --git a/backend/sanic_server/sanic/mixins/listeners.py b/backend/sanic_server/sanic/mixins/listeners.py new file mode 100644 index 000000000..bd18c4735 --- /dev/null +++ b/backend/sanic_server/sanic/mixins/listeners.py @@ -0,0 +1,81 @@ +from enum import Enum, auto +from functools import partial +from typing import List, Optional, Union + +from ...sanic.models.futures import FutureListener +from ...sanic.models.handler_types import ListenerType + + +class ListenerEvent(str, Enum): + def _generate_next_value_(name: str, *args) -> str: # type: ignore + return name.lower() + + BEFORE_SERVER_START = "server.init.before" + AFTER_SERVER_START = "server.init.after" + BEFORE_SERVER_STOP = "server.shutdown.before" + AFTER_SERVER_STOP = "server.shutdown.after" + MAIN_PROCESS_START = auto() + MAIN_PROCESS_STOP = auto() + + +class ListenerMixin: + def __init__(self, *args, **kwargs) -> None: + self._future_listeners: List[FutureListener] = [] + + def _apply_listener(self, listener: FutureListener): + raise NotImplementedError # noqa + + def listener( + self, + listener_or_event: Union[ListenerType, str], + event_or_none: Optional[str] = None, + apply: bool = True, + ): + """ + Create a listener from a decorated function. + + To be used as a decorator: + + .. code-block:: python + + @bp.listener("before_server_start") + async def before_server_start(app, loop): + ... + + `See user guide re: listeners + `__ + + :param event: event to listen to + """ + + def register_listener(listener, event): + nonlocal apply + + future_listener = FutureListener(listener, event) + self._future_listeners.append(future_listener) + if apply: + self._apply_listener(future_listener) + return listener + + if callable(listener_or_event): + return register_listener(listener_or_event, event_or_none) + else: + return partial(register_listener, event=listener_or_event) + + def main_process_start(self, listener: ListenerType) -> ListenerType: + return self.listener(listener, "main_process_start") + + def main_process_stop(self, listener: ListenerType) -> ListenerType: + return self.listener(listener, "main_process_stop") + + def before_server_start(self, listener: ListenerType) -> ListenerType: + return self.listener(listener, "before_server_start") + + def after_server_start(self, listener: ListenerType) -> ListenerType: + return self.listener(listener, "after_server_start") + + def before_server_stop(self, listener: ListenerType) -> ListenerType: + return self.listener(listener, "before_server_stop") + + def after_server_stop(self, listener: ListenerType) -> ListenerType: + return self.listener(listener, "after_server_stop") diff --git a/backend/sanic_server/sanic/mixins/middleware.py b/backend/sanic_server/sanic/mixins/middleware.py new file mode 100644 index 000000000..22d3dfd72 --- /dev/null +++ b/backend/sanic_server/sanic/mixins/middleware.py @@ -0,0 +1,52 @@ +from functools import partial +from typing import List + +from ...sanic.models.futures import FutureMiddleware + + +class MiddlewareMixin: + def __init__(self, *args, **kwargs) -> None: + self._future_middleware: List[FutureMiddleware] = [] + + def _apply_middleware(self, middleware: FutureMiddleware): + raise NotImplementedError # noqa + + def middleware(self, middleware_or_request, attach_to="request", apply=True): + """ + Decorate and register middleware to be called before a request. + Can either be called as *@app.middleware* or + *@app.middleware('request')* + + `See user guide re: middleware + `__ + + :param: middleware_or_request: Optional parameter to use for + identifying which type of middleware is being registered. + """ + + def register_middleware(middleware, attach_to="request"): + nonlocal apply + + future_middleware = FutureMiddleware(middleware, attach_to) + self._future_middleware.append(future_middleware) + if apply: + self._apply_middleware(future_middleware) + return middleware + + # Detect which way this was called, @middleware or @middleware('AT') + if callable(middleware_or_request): + return register_middleware(middleware_or_request, attach_to=attach_to) + else: + return partial(register_middleware, attach_to=middleware_or_request) + + def on_request(self, middleware=None): + if callable(middleware): + return self.middleware(middleware, "request") + else: + return partial(self.middleware, attach_to="request") + + def on_response(self, middleware=None): + if callable(middleware): + return self.middleware(middleware, "response") + else: + return partial(self.middleware, attach_to="response") diff --git a/backend/sanic_server/sanic/mixins/routes.py b/backend/sanic_server/sanic/mixins/routes.py new file mode 100644 index 000000000..cdfea00a3 --- /dev/null +++ b/backend/sanic_server/sanic/mixins/routes.py @@ -0,0 +1,947 @@ +from ast import NodeVisitor, Return, parse +from functools import partial, wraps +from inspect import getsource, signature +from mimetypes import guess_type +from os import path +from pathlib import PurePath +from re import sub +from textwrap import dedent +from time import gmtime, strftime +from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union +from urllib.parse import unquote + +from ...sanic_routing.route import Route # type: ignore + +from ...sanic.compat import stat_async +from ...sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS +from ...sanic.errorpages import RESPONSE_MAPPING +from ...sanic.exceptions import ( + ContentRangeError, + FileNotFound, + HeaderNotFound, + InvalidUsage, +) +from ...sanic.handlers import ContentRangeHandler +from ...sanic.log import error_logger +from ...sanic.models.futures import FutureRoute, FutureStatic +from ...sanic.models.handler_types import RouteHandler +from ...sanic.response import HTTPResponse, file, file_stream +from ...sanic.views import CompositionView + + +RouteWrapper = Callable[[RouteHandler], Union[RouteHandler, Tuple[Route, RouteHandler]]] + + +class RouteMixin: + name: str + + def __init__(self, *args, **kwargs) -> None: + self._future_routes: Set[FutureRoute] = set() + self._future_statics: Set[FutureStatic] = set() + self.strict_slashes: Optional[bool] = False + + def _apply_route(self, route: FutureRoute) -> List[Route]: + raise NotImplementedError # noqa + + def _apply_static(self, static: FutureStatic) -> Route: + raise NotImplementedError # noqa + + def route( + self, + uri: str, + methods: Optional[Iterable[str]] = None, + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + stream: bool = False, + version: Optional[Union[int, str, float]] = None, + name: Optional[str] = None, + ignore_body: bool = False, + apply: bool = True, + subprotocols: Optional[List[str]] = None, + websocket: bool = False, + unquote: bool = False, + static: bool = False, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ) -> RouteWrapper: + """ + Decorate a function to be registered as a route + + :param uri: path of the URL + :param methods: list or tuple of methods allowed + :param host: the host, if required + :param strict_slashes: whether to apply strict slashes to the route + :param stream: whether to allow the request to stream its body + :param version: route specific versioning + :param name: user defined route name for url_for + :param ignore_body: whether the handler should ignore request + body (eg. GET requests) + :param version_prefix: URL path that should be before the version + value; default: ``/v`` + :return: tuple of routes, decorated function + """ + + # Fix case where the user did not prefix the URL with a / + # and will probably get confused as to why it's not working + if not uri.startswith("/") and (uri or hasattr(self, "router")): + uri = "/" + uri + + if strict_slashes is None: + strict_slashes = self.strict_slashes + + if not methods and not websocket: + methods = frozenset({"GET"}) + + def decorator(handler): + nonlocal uri + nonlocal methods + nonlocal host + nonlocal strict_slashes + nonlocal stream + nonlocal version + nonlocal name + nonlocal ignore_body + nonlocal subprotocols + nonlocal websocket + nonlocal static + nonlocal version_prefix + nonlocal error_format + + if isinstance(handler, tuple): + # if a handler fn is already wrapped in a route, the handler + # variable will be a tuple of (existing routes, handler fn) + _, handler = handler + + name = self._generate_name(name, handler) + + if isinstance(host, str): + host = frozenset([host]) + elif host and not isinstance(host, frozenset): + try: + host = frozenset(host) + except TypeError: + raise ValueError( + "Expected either string or Iterable of host strings, " + "not %s" % host + ) + if isinstance(subprotocols, list): + # Ordered subprotocols, maintain order + subprotocols = tuple(subprotocols) + elif isinstance(subprotocols, set): + # subprotocol is unordered, keep it unordered + subprotocols = frozenset(subprotocols) + + if not error_format or error_format == "auto": + error_format = self._determine_error_format(handler) + + route = FutureRoute( + handler, + uri, + None if websocket else frozenset([x.upper() for x in methods]), + host, + strict_slashes, + stream, + version, + name, + ignore_body, + websocket, + subprotocols, + unquote, + static, + version_prefix, + error_format, + ) + + self._future_routes.add(route) + + args = list(signature(handler).parameters.keys()) + if websocket and len(args) < 2: + handler_name = handler.__name__ + + raise ValueError( + f"Required parameter `request` and/or `ws` missing " + f"in the {handler_name}() route?" + ) + elif not args: + handler_name = handler.__name__ + + raise ValueError( + f"Required parameter `request` missing " + f"in the {handler_name}() route?" + ) + + if not websocket and stream: + handler.is_stream = stream + + if apply: + self._apply_route(route) + + if static: + return route, handler + return handler + + return decorator + + def add_route( + self, + handler: RouteHandler, + uri: str, + methods: Iterable[str] = frozenset({"GET"}), + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + version: Optional[int] = None, + name: Optional[str] = None, + stream: bool = False, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ) -> RouteHandler: + """A helper method to register class instance or + functions as a handler to the application url + routes. + + :param handler: function or class instance + :param uri: path of the URL + :param methods: list or tuple of methods allowed, these are overridden + if using a HTTPMethodView + :param host: + :param strict_slashes: + :param version: + :param name: user defined route name for url_for + :param stream: boolean specifying if the handler is a stream handler + :param version_prefix: URL path that should be before the version + value; default: ``/v`` + :return: function or class instance + """ + # Handle HTTPMethodView differently + if hasattr(handler, "view_class"): + methods = set() + + for method in HTTP_METHODS: + view_class = getattr(handler, "view_class") + _handler = getattr(view_class, method.lower(), None) + if _handler: + methods.add(method) + if hasattr(_handler, "is_stream"): + stream = True + + # handle composition view differently + if isinstance(handler, CompositionView): + methods = handler.handlers.keys() + for _handler in handler.handlers.values(): + if hasattr(_handler, "is_stream"): + stream = True + break + + if strict_slashes is None: + strict_slashes = self.strict_slashes + + self.route( + uri=uri, + methods=methods, + host=host, + strict_slashes=strict_slashes, + stream=stream, + version=version, + name=name, + version_prefix=version_prefix, + error_format=error_format, + )(handler) + return handler + + # Shorthand method decorators + def get( + self, + uri: str, + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + version: Optional[int] = None, + name: Optional[str] = None, + ignore_body: bool = True, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ) -> RouteWrapper: + """ + Add an API URL under the **GET** *HTTP* method + + :param uri: URL to be tagged to **GET** method of *HTTP* + :param host: Host IP or FQDN for the service to use + :param strict_slashes: Instruct :class:`Sanic` to check if the request + URLs need to terminate with a */* + :param version: API Version + :param name: Unique name that can be used to identify the Route + :param version_prefix: URL path that should be before the version + value; default: ``/v`` + :return: Object decorated with :func:`route` method + """ + return self.route( + uri, + methods=frozenset({"GET"}), + host=host, + strict_slashes=strict_slashes, + version=version, + name=name, + ignore_body=ignore_body, + version_prefix=version_prefix, + error_format=error_format, + ) + + def post( + self, + uri: str, + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + stream: bool = False, + version: Optional[int] = None, + name: Optional[str] = None, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ) -> RouteWrapper: + """ + Add an API URL under the **POST** *HTTP* method + + :param uri: URL to be tagged to **POST** method of *HTTP* + :param host: Host IP or FQDN for the service to use + :param strict_slashes: Instruct :class:`Sanic` to check if the request + URLs need to terminate with a */* + :param version: API Version + :param name: Unique name that can be used to identify the Route + :param version_prefix: URL path that should be before the version + value; default: ``/v`` + :return: Object decorated with :func:`route` method + """ + return self.route( + uri, + methods=frozenset({"POST"}), + host=host, + strict_slashes=strict_slashes, + stream=stream, + version=version, + name=name, + version_prefix=version_prefix, + error_format=error_format, + ) + + def put( + self, + uri: str, + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + stream: bool = False, + version: Optional[int] = None, + name: Optional[str] = None, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ) -> RouteWrapper: + """ + Add an API URL under the **PUT** *HTTP* method + + :param uri: URL to be tagged to **PUT** method of *HTTP* + :param host: Host IP or FQDN for the service to use + :param strict_slashes: Instruct :class:`Sanic` to check if the request + URLs need to terminate with a */* + :param version: API Version + :param name: Unique name that can be used to identify the Route + :param version_prefix: URL path that should be before the version + value; default: ``/v`` + :return: Object decorated with :func:`route` method + """ + return self.route( + uri, + methods=frozenset({"PUT"}), + host=host, + strict_slashes=strict_slashes, + stream=stream, + version=version, + name=name, + version_prefix=version_prefix, + error_format=error_format, + ) + + def head( + self, + uri: str, + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + version: Optional[int] = None, + name: Optional[str] = None, + ignore_body: bool = True, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ) -> RouteWrapper: + """ + Add an API URL under the **HEAD** *HTTP* method + + :param uri: URL to be tagged to **HEAD** method of *HTTP* + :type uri: str + :param host: Host IP or FQDN for the service to use + :type host: Optional[str], optional + :param strict_slashes: Instruct :class:`Sanic` to check if the request + URLs need to terminate with a */* + :type strict_slashes: Optional[bool], optional + :param version: API Version + :type version: Optional[str], optional + :param name: Unique name that can be used to identify the Route + :type name: Optional[str], optional + :param ignore_body: whether the handler should ignore request + body (eg. GET requests), defaults to True + :type ignore_body: bool, optional + :param version_prefix: URL path that should be before the version + value; default: ``/v`` + :return: Object decorated with :func:`route` method + """ + return self.route( + uri, + methods=frozenset({"HEAD"}), + host=host, + strict_slashes=strict_slashes, + version=version, + name=name, + ignore_body=ignore_body, + version_prefix=version_prefix, + error_format=error_format, + ) + + def options( + self, + uri: str, + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + version: Optional[int] = None, + name: Optional[str] = None, + ignore_body: bool = True, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ) -> RouteWrapper: + """ + Add an API URL under the **OPTIONS** *HTTP* method + + :param uri: URL to be tagged to **OPTIONS** method of *HTTP* + :type uri: str + :param host: Host IP or FQDN for the service to use + :type host: Optional[str], optional + :param strict_slashes: Instruct :class:`Sanic` to check if the request + URLs need to terminate with a */* + :type strict_slashes: Optional[bool], optional + :param version: API Version + :type version: Optional[str], optional + :param name: Unique name that can be used to identify the Route + :type name: Optional[str], optional + :param ignore_body: whether the handler should ignore request + body (eg. GET requests), defaults to True + :type ignore_body: bool, optional + :param version_prefix: URL path that should be before the version + value; default: ``/v`` + :return: Object decorated with :func:`route` method + """ + return self.route( + uri, + methods=frozenset({"OPTIONS"}), + host=host, + strict_slashes=strict_slashes, + version=version, + name=name, + ignore_body=ignore_body, + version_prefix=version_prefix, + error_format=error_format, + ) + + def patch( + self, + uri: str, + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + stream=False, + version: Optional[int] = None, + name: Optional[str] = None, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ) -> RouteWrapper: + """ + Add an API URL under the **PATCH** *HTTP* method + + :param uri: URL to be tagged to **PATCH** method of *HTTP* + :type uri: str + :param host: Host IP or FQDN for the service to use + :type host: Optional[str], optional + :param strict_slashes: Instruct :class:`Sanic` to check if the request + URLs need to terminate with a */* + :type strict_slashes: Optional[bool], optional + :param stream: whether to allow the request to stream its body + :type stream: Optional[bool], optional + :param version: API Version + :type version: Optional[str], optional + :param name: Unique name that can be used to identify the Route + :type name: Optional[str], optional + :param ignore_body: whether the handler should ignore request + body (eg. GET requests), defaults to True + :type ignore_body: bool, optional + :param version_prefix: URL path that should be before the version + value; default: ``/v`` + :return: Object decorated with :func:`route` method + """ + return self.route( + uri, + methods=frozenset({"PATCH"}), + host=host, + strict_slashes=strict_slashes, + stream=stream, + version=version, + name=name, + version_prefix=version_prefix, + error_format=error_format, + ) + + def delete( + self, + uri: str, + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + version: Optional[int] = None, + name: Optional[str] = None, + ignore_body: bool = True, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ) -> RouteWrapper: + """ + Add an API URL under the **DELETE** *HTTP* method + + :param uri: URL to be tagged to **DELETE** method of *HTTP* + :param host: Host IP or FQDN for the service to use + :param strict_slashes: Instruct :class:`Sanic` to check if the request + URLs need to terminate with a */* + :param version: API Version + :param name: Unique name that can be used to identify the Route + :param version_prefix: URL path that should be before the version + value; default: ``/v`` + :return: Object decorated with :func:`route` method + """ + return self.route( + uri, + methods=frozenset({"DELETE"}), + host=host, + strict_slashes=strict_slashes, + version=version, + name=name, + ignore_body=ignore_body, + version_prefix=version_prefix, + error_format=error_format, + ) + + def websocket( + self, + uri: str, + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + subprotocols: Optional[List[str]] = None, + version: Optional[int] = None, + name: Optional[str] = None, + apply: bool = True, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ): + """ + Decorate a function to be registered as a websocket route + + :param uri: path of the URL + :param host: Host IP or FQDN details + :param strict_slashes: If the API endpoint needs to terminate + with a "/" or not + :param subprotocols: optional list of str with supported subprotocols + :param name: A unique name assigned to the URL so that it can + be used with :func:`url_for` + :param version_prefix: URL path that should be before the version + value; default: ``/v`` + :return: tuple of routes, decorated function + """ + return self.route( + uri=uri, + host=host, + methods=None, + strict_slashes=strict_slashes, + version=version, + name=name, + apply=apply, + subprotocols=subprotocols, + websocket=True, + version_prefix=version_prefix, + error_format=error_format, + ) + + def add_websocket_route( + self, + handler, + uri: str, + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + subprotocols=None, + version: Optional[int] = None, + name: Optional[str] = None, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ): + """ + A helper method to register a function as a websocket route. + + :param handler: a callable function or instance of a class + that can handle the websocket request + :param host: Host IP or FQDN details + :param uri: URL path that will be mapped to the websocket + handler + handler + :param strict_slashes: If the API endpoint needs to terminate + with a "/" or not + :param subprotocols: Subprotocols to be used with websocket + handshake + :param name: A unique name assigned to the URL so that it can + be used with :func:`url_for` + :param version_prefix: URL path that should be before the version + value; default: ``/v`` + :return: Objected decorated by :func:`websocket` + """ + return self.websocket( + uri=uri, + host=host, + strict_slashes=strict_slashes, + subprotocols=subprotocols, + version=version, + name=name, + version_prefix=version_prefix, + error_format=error_format, + )(handler) + + def static( + self, + uri, + file_or_directory: Union[str, bytes, PurePath], + pattern=r"/?.+", + use_modified_since=True, + use_content_range=False, + stream_large_files=False, + name="static", + host=None, + strict_slashes=None, + content_type=None, + apply=True, + resource_type=None, + ): + """ + Register a root to serve files from. The input can either be a + file or a directory. This method will enable an easy and simple way + to setup the :class:`Route` necessary to serve the static files. + + :param uri: URL path to be used for serving static content + :param file_or_directory: Path for the Static file/directory with + static files + :param pattern: Regex Pattern identifying the valid static files + :param use_modified_since: If true, send file modified time, and return + not modified if the browser's matches the server's + :param use_content_range: If true, process header for range requests + and sends the file part that is requested + :param stream_large_files: If true, use the + :func:`StreamingHTTPResponse.file_stream` handler rather + than the :func:`HTTPResponse.file` handler to send the file. + If this is an integer, this represents the threshold size to + switch to :func:`StreamingHTTPResponse.file_stream` + :param name: user defined name used for url_for + :param host: Host IP or FQDN for the service to use + :param strict_slashes: Instruct :class:`Sanic` to check if the request + URLs need to terminate with a */* + :param content_type: user defined content type for header + :return: routes registered on the router + :rtype: List[sanic.router.Route] + """ + + name = self._generate_name(name) + + if strict_slashes is None and self.strict_slashes is not None: + strict_slashes = self.strict_slashes + + if not isinstance(file_or_directory, (str, bytes, PurePath)): + raise ValueError( + f"Static route must be a valid path, not {file_or_directory}" + ) + + static = FutureStatic( + uri, + file_or_directory, + pattern, + use_modified_since, + use_content_range, + stream_large_files, + name, + host, + strict_slashes, + content_type, + resource_type, + ) + self._future_statics.add(static) + + if apply: + self._apply_static(static) + + def _generate_name(self, *objects) -> str: + name = None + + for obj in objects: + if obj: + if isinstance(obj, str): + name = obj + break + + try: + name = obj.name + except AttributeError: + try: + name = obj.__name__ + except AttributeError: + continue + else: + break + + if not name: # noqa + raise ValueError("Could not generate a name for handler") + + if not name.startswith(f"{self.name}."): + name = f"{self.name}.{name}" + + return name + + async def _static_request_handler( + self, + file_or_directory, + use_modified_since, + use_content_range, + stream_large_files, + request, + content_type=None, + __file_uri__=None, + ): + # Using this to determine if the URL is trying to break out of the path + # served. os.path.realpath seems to be very slow + if __file_uri__ and "../" in __file_uri__: + raise InvalidUsage("Invalid URL") + # Merge served directory and requested file if provided + # Strip all / that in the beginning of the URL to help prevent python + # from herping a derp and treating the uri as an absolute path + root_path = file_path = file_or_directory + if __file_uri__: + file_path = path.join(file_or_directory, sub("^[/]*", "", __file_uri__)) + + # URL decode the path sent by the browser otherwise we won't be able to + # match filenames which got encoded (filenames with spaces etc) + file_path = path.abspath(unquote(file_path)) + if not file_path.startswith(path.abspath(unquote(root_path))): + error_logger.exception( + f"File not found: path={file_or_directory}, " + f"relative_url={__file_uri__}" + ) + raise FileNotFound( + "File not found", + path=file_or_directory, + relative_url=__file_uri__, + ) + try: + headers = {} + # Check if the client has been sent this file before + # and it has not been modified since + stats = None + if use_modified_since: + stats = await stat_async(file_path) + modified_since = strftime( + "%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime) + ) + if request.headers.getone("if-modified-since", None) == modified_since: + return HTTPResponse(status=304) + headers["Last-Modified"] = modified_since + _range = None + if use_content_range: + _range = None + if not stats: + stats = await stat_async(file_path) + headers["Accept-Ranges"] = "bytes" + headers["Content-Length"] = str(stats.st_size) + if request.method != "HEAD": + try: + _range = ContentRangeHandler(request, stats) + except HeaderNotFound: + pass + else: + del headers["Content-Length"] + for key, value in _range.headers.items(): + headers[key] = value + + if "content-type" not in headers: + content_type = ( + content_type + or guess_type(file_path)[0] + or DEFAULT_HTTP_CONTENT_TYPE + ) + + if "charset=" not in content_type and ( + content_type.startswith("text/") + or content_type == "application/javascript" + ): + content_type += "; charset=utf-8" + + headers["Content-Type"] = content_type + + if request.method == "HEAD": + return HTTPResponse(headers=headers) + else: + if stream_large_files: + if type(stream_large_files) == int: + threshold = stream_large_files + else: + threshold = 1024 * 1024 + + if not stats: + stats = await stat_async(file_path) + if stats.st_size >= threshold: + return await file_stream( + file_path, headers=headers, _range=_range + ) + return await file(file_path, headers=headers, _range=_range) + except ContentRangeError: + raise + except FileNotFoundError: + raise FileNotFound( + "File not found", + path=file_or_directory, + relative_url=__file_uri__, + ) + except Exception: + error_logger.exception( + f"Exception in static request handler: " + f"path={file_or_directory}, " + f"relative_url={__file_uri__}" + ) + raise + + def _register_static( + self, + static: FutureStatic, + ): + # TODO: Though sanic is not a file server, I feel like we should + # at least make a good effort here. Modified-since is nice, but + # we could also look into etags, expires, and caching + """ + Register a static directory handler with Sanic by adding a route to the + router and registering a handler. + + :param app: Sanic + :param file_or_directory: File or directory path to serve from + :type file_or_directory: Union[str,bytes,Path] + :param uri: URL to serve from + :type uri: str + :param pattern: regular expression used to match files in the URL + :param use_modified_since: If true, send file modified time, and return + not modified if the browser's matches the + server's + :param use_content_range: If true, process header for range requests + and sends the file part that is requested + :param stream_large_files: If true, use the file_stream() handler + rather than the file() handler to send the file + If this is an integer, this represents the + threshold size to switch to file_stream() + :param name: user defined name used for url_for + :type name: str + :param content_type: user defined content type for header + :return: registered static routes + :rtype: List[sanic.router.Route] + """ + + if isinstance(static.file_or_directory, bytes): + file_or_directory = static.file_or_directory.decode("utf-8") + elif isinstance(static.file_or_directory, PurePath): + file_or_directory = str(static.file_or_directory) + elif not isinstance(static.file_or_directory, str): + raise ValueError("Invalid file path string.") + else: + file_or_directory = static.file_or_directory + + uri = static.uri + name = static.name + # If we're not trying to match a file directly, + # serve from the folder + if not static.resource_type: + if not path.isfile(file_or_directory): + uri += "/<__file_uri__:path>" + elif static.resource_type == "dir": + if path.isfile(file_or_directory): + raise TypeError( + "Resource type improperly identified as directory. " + f"'{file_or_directory}'" + ) + uri += "/<__file_uri__:path>" + elif static.resource_type == "file" and not path.isfile(file_or_directory): + raise TypeError( + "Resource type improperly identified as file. " f"'{file_or_directory}'" + ) + elif static.resource_type != "file": + raise ValueError("The resource_type should be set to 'file' or 'dir'") + + # special prefix for static files + # if not static.name.startswith("_static_"): + # name = f"_static_{static.name}" + + _handler = wraps(self._static_request_handler)( + partial( + self._static_request_handler, + file_or_directory, + static.use_modified_since, + static.use_content_range, + static.stream_large_files, + content_type=static.content_type, + ) + ) + + route, _ = self.route( # type: ignore + uri=uri, + methods=["GET", "HEAD"], + name=name, + host=static.host, + strict_slashes=static.strict_slashes, + static=True, + )(_handler) + + return route + + def _determine_error_format(self, handler) -> Optional[str]: + if not isinstance(handler, CompositionView): + try: + src = dedent(getsource(handler)) + tree = parse(src) + http_response_types = self._get_response_types(tree) + + if len(http_response_types) == 1: + return next(iter(http_response_types)) + except (OSError, TypeError): + ... + + return None + + def _get_response_types(self, node): + types = set() + + class HttpResponseVisitor(NodeVisitor): + def visit_Return(self, node: Return) -> Any: + nonlocal types + + try: + checks = [node.value.func.id] # type: ignore + if node.value.keywords: # type: ignore + checks += [ + k.value + for k in node.value.keywords # type: ignore + if k.arg == "content_type" + ] + + for check in checks: + if check in RESPONSE_MAPPING: + types.add(RESPONSE_MAPPING[check]) + except AttributeError: + ... + + HttpResponseVisitor().visit(node) + + return types diff --git a/backend/sanic_server/sanic/mixins/signals.py b/backend/sanic_server/sanic/mixins/signals.py new file mode 100644 index 000000000..b154c5e94 --- /dev/null +++ b/backend/sanic_server/sanic/mixins/signals.py @@ -0,0 +1,75 @@ +from typing import Any, Callable, Dict, Optional, Set + +from ...sanic.models.futures import FutureSignal +from ...sanic.models.handler_types import SignalHandler +from ...sanic.signals import Signal + + +class HashableDict(dict): + def __hash__(self): + return hash(tuple(sorted(self.items()))) + + +class SignalMixin: + def __init__(self, *args, **kwargs) -> None: + self._future_signals: Set[FutureSignal] = set() + + def _apply_signal(self, signal: FutureSignal) -> Signal: + raise NotImplementedError # noqa + + def signal( + self, + event: str, + *, + apply: bool = True, + condition: Dict[str, Any] = None, + ) -> Callable[[SignalHandler], SignalHandler]: + """ + For creating a signal handler, used similar to a route handler: + + .. code-block:: python + + @app.signal("foo.bar.") + async def signal_handler(thing, **kwargs): + print(f"[signal_handler] {thing=}", kwargs) + + :param event: Representation of the event in ``one.two.three`` form + :type event: str + :param apply: For lazy evaluation, defaults to True + :type apply: bool, optional + :param condition: For use with the ``condition`` argument in dispatch + filtering, defaults to None + :type condition: Dict[str, Any], optional + """ + + def decorator(handler: SignalHandler): + nonlocal event + nonlocal apply + + future_signal = FutureSignal(handler, event, HashableDict(condition or {})) + self._future_signals.add(future_signal) + + if apply: + self._apply_signal(future_signal) + + return handler + + return decorator + + def add_signal( + self, + handler: Optional[Callable[..., Any]], + event: str, + condition: Dict[str, Any] = None, + ): + if not handler: + + async def noop(): + ... + + handler = noop + self.signal(event=event, condition=condition)(handler) + return handler + + def event(self, event: str): + raise NotImplementedError diff --git a/backend/sanic_server/sanic/models/__init__.py b/backend/sanic_server/sanic/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic/models/asgi.py b/backend/sanic_server/sanic/models/asgi.py new file mode 100644 index 000000000..7abdcd85c --- /dev/null +++ b/backend/sanic_server/sanic/models/asgi.py @@ -0,0 +1,93 @@ +import asyncio + +from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union + +from ...sanic.exceptions import InvalidUsage +from ...sanic.server.websockets.connection import WebSocketConnection + + +ASGIScope = MutableMapping[str, Any] +ASGIMessage = MutableMapping[str, Any] +ASGISend = Callable[[ASGIMessage], Awaitable[None]] +ASGIReceive = Callable[[], Awaitable[ASGIMessage]] + + +class MockProtocol: + def __init__(self, transport: "MockTransport", loop): + self.transport = transport + self._not_paused = asyncio.Event(loop=loop) + self._not_paused.set() + self._complete = asyncio.Event(loop=loop) + + def pause_writing(self) -> None: + self._not_paused.clear() + + def resume_writing(self) -> None: + self._not_paused.set() + + async def complete(self) -> None: + self._not_paused.set() + await self.transport.send( + {"type": "http.response.body", "body": b"", "more_body": False} + ) + + @property + def is_complete(self) -> bool: + return self._complete.is_set() + + async def push_data(self, data: bytes) -> None: + if not self.is_complete: + await self.transport.send( + {"type": "http.response.body", "body": data, "more_body": True} + ) + + async def drain(self) -> None: + await self._not_paused.wait() + + +class MockTransport: + _protocol: Optional[MockProtocol] + + def __init__(self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend) -> None: + self.scope = scope + self._receive = receive + self._send = send + self._protocol = None + self.loop = None + + def get_protocol(self) -> MockProtocol: + if not self._protocol: + self._protocol = MockProtocol(self, self.loop) + return self._protocol + + def get_extra_info(self, info: str) -> Union[str, bool, None]: + if info == "peername": + return self.scope.get("client") + elif info == "sslcontext": + return self.scope.get("scheme") in ["https", "wss"] + return None + + def get_websocket_connection(self) -> WebSocketConnection: + try: + return self._websocket_connection + except AttributeError: + raise InvalidUsage("Improper websocket connection.") + + def create_websocket_connection( + self, send: ASGISend, receive: ASGIReceive + ) -> WebSocketConnection: + self._websocket_connection = WebSocketConnection( + send, receive, self.scope.get("subprotocols", []) + ) + return self._websocket_connection + + def add_task(self) -> None: + raise NotImplementedError + + async def send(self, data) -> None: + # TODO: + # - Validation on data and that it is formatted properly and is valid + await self._send(data) + + async def receive(self) -> ASGIMessage: + return await self._receive() diff --git a/backend/sanic_server/sanic/models/futures.py b/backend/sanic_server/sanic/models/futures.py new file mode 100644 index 000000000..58ca030ec --- /dev/null +++ b/backend/sanic_server/sanic/models/futures.py @@ -0,0 +1,62 @@ +from pathlib import PurePath +from typing import Dict, Iterable, List, NamedTuple, Optional, Union + +from ...sanic.models.handler_types import ( + ErrorMiddlewareType, + ListenerType, + MiddlewareType, + SignalHandler, +) + + +class FutureRoute(NamedTuple): + handler: str + uri: str + methods: Optional[Iterable[str]] + host: str + strict_slashes: bool + stream: bool + version: Optional[int] + name: str + ignore_body: bool + websocket: bool + subprotocols: Optional[List[str]] + unquote: bool + static: bool + version_prefix: str + error_format: Optional[str] + + +class FutureListener(NamedTuple): + listener: ListenerType + event: str + + +class FutureMiddleware(NamedTuple): + middleware: MiddlewareType + attach_to: str + + +class FutureException(NamedTuple): + handler: ErrorMiddlewareType + exceptions: List[BaseException] + + +class FutureStatic(NamedTuple): + uri: str + file_or_directory: Union[str, bytes, PurePath] + pattern: str + use_modified_since: bool + use_content_range: bool + stream_large_files: bool + name: str + host: Optional[str] + strict_slashes: Optional[bool] + content_type: Optional[bool] + resource_type: Optional[str] + + +class FutureSignal(NamedTuple): + handler: SignalHandler + event: str + condition: Optional[Dict[str, str]] diff --git a/backend/sanic_server/sanic/models/handler_types.py b/backend/sanic_server/sanic/models/handler_types.py new file mode 100644 index 000000000..390a54f69 --- /dev/null +++ b/backend/sanic_server/sanic/models/handler_types.py @@ -0,0 +1,21 @@ +from asyncio.events import AbstractEventLoop +from typing import Any, Callable, Coroutine, Optional, TypeVar, Union + +from ...sanic.request import Request +from ...sanic.response import BaseHTTPResponse, HTTPResponse + + +Sanic = TypeVar("Sanic") + +MiddlewareResponse = Union[ + Optional[HTTPResponse], Coroutine[Any, Any, Optional[HTTPResponse]] +] +RequestMiddlewareType = Callable[[Request], MiddlewareResponse] +ResponseMiddlewareType = Callable[[Request, BaseHTTPResponse], MiddlewareResponse] +ErrorMiddlewareType = Callable[ + [Request, BaseException], Optional[Coroutine[Any, Any, None]] +] +MiddlewareType = Union[RequestMiddlewareType, ResponseMiddlewareType] +ListenerType = Callable[[Sanic, AbstractEventLoop], Optional[Coroutine[Any, Any, None]]] +RouteHandler = Callable[..., Coroutine[Any, Any, Optional[HTTPResponse]]] +SignalHandler = Callable[..., Coroutine[Any, Any, None]] diff --git a/backend/sanic_server/sanic/models/protocol_types.py b/backend/sanic_server/sanic/models/protocol_types.py new file mode 100644 index 000000000..4e139e4f5 --- /dev/null +++ b/backend/sanic_server/sanic/models/protocol_types.py @@ -0,0 +1,44 @@ +import sys + +from typing import Any, AnyStr, TypeVar, Union + + +if sys.version_info < (3, 8): + from asyncio import BaseTransport + + # from ...sanic.models.asgi import MockTransport + MockTransport = TypeVar("MockTransport") + + TransportProtocol = Union[MockTransport, BaseTransport] + Range = Any + HTMLProtocol = Any +else: + # Protocol is a 3.8+ feature + from typing import Protocol + + class TransportProtocol(Protocol): + def get_protocol(self): + ... + + def get_extra_info(self, info: str) -> Union[str, bool, None]: + ... + + class HTMLProtocol(Protocol): + def __html__(self) -> AnyStr: + ... + + def _repr_html_(self) -> AnyStr: + ... + + class Range(Protocol): + def start(self) -> int: + ... + + def end(self) -> int: + ... + + def size(self) -> int: + ... + + def total(self) -> int: + ... diff --git a/backend/sanic_server/sanic/models/server_types.py b/backend/sanic_server/sanic/models/server_types.py new file mode 100644 index 000000000..a80db91c8 --- /dev/null +++ b/backend/sanic_server/sanic/models/server_types.py @@ -0,0 +1,52 @@ +from types import SimpleNamespace + +from ...sanic.models.protocol_types import TransportProtocol + + +class Signal: + stopped = False + + +class ConnInfo: + """ + Local and remote addresses and SSL status info. + """ + + __slots__ = ( + "client_port", + "client", + "client_ip", + "ctx", + "peername", + "server_port", + "server", + "sockname", + "ssl", + ) + + def __init__(self, transport: TransportProtocol, unix=None): + self.ctx = SimpleNamespace() + self.peername = None + self.server = self.client = "" + self.server_port = self.client_port = 0 + self.client_ip = "" + self.sockname = addr = transport.get_extra_info("sockname") + self.ssl: bool = bool(transport.get_extra_info("sslcontext")) + + if isinstance(addr, str): # UNIX socket + self.server = unix or addr + return + + # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) + if isinstance(addr, tuple): + self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" + self.server_port = addr[1] + # self.server gets non-standard port appended + if addr[1] != (443 if self.ssl else 80): + self.server = f"{self.server}:{addr[1]}" + self.peername = addr = transport.get_extra_info("peername") + + if isinstance(addr, tuple): + self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" + self.client_ip = addr[0] + self.client_port = addr[1] diff --git a/backend/sanic_server/sanic/py.typed b/backend/sanic_server/sanic/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic/reloader_helpers.py b/backend/sanic_server/sanic/reloader_helpers.py new file mode 100644 index 000000000..c61ead264 --- /dev/null +++ b/backend/sanic_server/sanic/reloader_helpers.py @@ -0,0 +1,120 @@ +import itertools +import os +import signal +import subprocess +import sys +from time import sleep + +from ..sanic.config import BASE_LOGO +from ..sanic.log import logger + + +def _iter_module_files(): + """This iterates over all relevant Python files. + + It goes through all + loaded files from modules, all files in folders of already loaded modules + as well as all files reachable through a package. + """ + # The list call is necessary on Python 3 in case the module + # dictionary modifies during iteration. + for module in list(sys.modules.values()): + if module is None: + continue + filename = getattr(module, "__file__", None) + if filename: + old = None + while not os.path.isfile(filename): + old = filename + filename = os.path.dirname(filename) + if filename == old: + break + else: + if filename[-4:] in (".pyc", ".pyo"): + filename = filename[:-1] + yield filename + + +def _get_args_for_reloading(): + """Returns the executable.""" + main_module = sys.modules["__main__"] + mod_spec = getattr(main_module, "__spec__", None) + if sys.argv[0] in ("", "-c"): + raise RuntimeError(f"Autoreloader cannot work with argv[0]={sys.argv[0]!r}") + if mod_spec: + # Parent exe was launched as a module rather than a script + return [sys.executable, "-m", mod_spec.name] + sys.argv[1:] + return [sys.executable] + sys.argv + + +def restart_with_reloader(): + """Create a new process and a subprocess in it with the same arguments as + this one. + """ + return subprocess.Popen( + _get_args_for_reloading(), + env={**os.environ, "SANIC_SERVER_RUNNING": "true"}, + ) + + +def _check_file(filename, mtimes): + need_reload = False + + mtime = os.stat(filename).st_mtime + old_time = mtimes.get(filename) + if old_time is None: + mtimes[filename] = mtime + elif mtime > old_time: + mtimes[filename] = mtime + need_reload = True + + return need_reload + + +def watchdog(sleep_interval, app): + """Watch project files, restart worker process if a change happened. + + :param sleep_interval: interval in second. + :return: Nothing + """ + + def interrupt_self(*args): + raise KeyboardInterrupt + + mtimes = {} + signal.signal(signal.SIGTERM, interrupt_self) + if os.name == "nt": + signal.signal(signal.SIGBREAK, interrupt_self) + + worker_process = restart_with_reloader() + + if app.config.LOGO: + logger.debug(app.config.LOGO if isinstance(app.config.LOGO, str) else BASE_LOGO) + + try: + while True: + need_reload = False + + for filename in itertools.chain( + _iter_module_files(), + *(d.glob("**/*") for d in app.reload_dirs), + ): + try: + check = _check_file(filename, mtimes) + except OSError: + continue + + if check: + need_reload = True + + if need_reload: + worker_process.terminate() + worker_process.wait() + worker_process = restart_with_reloader() + + sleep(sleep_interval) + except KeyboardInterrupt: + pass + finally: + worker_process.terminate() + worker_process.wait() diff --git a/backend/sanic_server/sanic/request.py b/backend/sanic_server/sanic/request.py new file mode 100644 index 000000000..f0408e63d --- /dev/null +++ b/backend/sanic_server/sanic/request.py @@ -0,0 +1,789 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + DefaultDict, + Dict, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +from ..sanic_routing.route import Route # type: ignore + +if TYPE_CHECKING: + from ..sanic.server import ConnInfo + from ..sanic.app import Sanic + from ..sanic.http import Http + +import email.utils +import uuid +from collections import defaultdict +from http.cookies import SimpleCookie +from types import SimpleNamespace +from urllib.parse import parse_qs, parse_qsl, unquote, urlunparse + +from httptools import parse_url # type: ignore + +from ..sanic.compat import CancelledErrors, Header +from ..sanic.constants import DEFAULT_HTTP_CONTENT_TYPE +from ..sanic.exceptions import InvalidUsage +from ..sanic.headers import ( + AcceptContainer, + Options, + parse_accept, + parse_content_header, + parse_forwarded, + parse_host, + parse_xforwarded, +) +from ..sanic.log import error_logger, logger +from ..sanic.models.protocol_types import TransportProtocol +from ..sanic.response import BaseHTTPResponse, HTTPResponse + +try: + from ujson import loads as json_loads # type: ignore +except ImportError: + from json import loads as json_loads # type: ignore + + +class RequestParameters(dict): + """ + Hosts a dict with lists as values where get returns the first + value of the list and getlist returns the whole shebang + """ + + def get(self, name: str, default: Optional[Any] = None) -> Optional[Any]: + """Return the first value, either the default or actual""" + return super().get(name, [default])[0] + + def getlist(self, name: str, default: Optional[Any] = None) -> Optional[Any]: + """ + Return the entire list + """ + return super().get(name, default) + + +class Request: + """ + Properties of an HTTP request such as URL, headers, etc. + """ + + __slots__ = ( + "__weakref__", + "_cookies", + "_id", + "_ip", + "_parsed_url", + "_port", + "_protocol", + "_remote_addr", + "_socket", + "_match_info", + "_name", + "app", + "body", + "conn_info", + "ctx", + "head", + "headers", + "method", + "parsed_accept", + "parsed_args", + "parsed_not_grouped_args", + "parsed_files", + "parsed_form", + "parsed_json", + "parsed_forwarded", + "raw_url", + "request_middleware_started", + "route", + "stream", + "transport", + "version", + ) + + def __init__( + self, + url_bytes: bytes, + headers: Header, + version: str, + method: str, + transport: TransportProtocol, + app: Sanic, + head: bytes = b"", + ): + self.raw_url = url_bytes + # TODO: Content-Encoding detection + self._parsed_url = parse_url(url_bytes) + self._id: Optional[Union[uuid.UUID, str, int]] = None + self._name: Optional[str] = None + self.app = app + + self.headers = Header(headers) + self.version = version + self.method = method + self.transport = transport + self.head = head + + # Init but do not inhale + self.body = b"" + self.conn_info: Optional[ConnInfo] = None + self.ctx = SimpleNamespace() + self.parsed_forwarded: Optional[Options] = None + self.parsed_accept: Optional[AcceptContainer] = None + self.parsed_json = None + self.parsed_form = None + self.parsed_files = None + self.parsed_args: DefaultDict[ + Tuple[bool, bool, str, str], RequestParameters + ] = defaultdict(RequestParameters) + self.parsed_not_grouped_args: DefaultDict[ + Tuple[bool, bool, str, str], List[Tuple[str, str]] + ] = defaultdict(list) + self.request_middleware_started = False + self._cookies: Optional[Dict[str, str]] = None + self._match_info: Dict[str, Any] = {} + self.stream: Optional[Http] = None + self.route: Optional[Route] = None + self._protocol = None + + def __repr__(self): + class_name = self.__class__.__name__ + return f"<{class_name}: {self.method} {self.path}>" + + @classmethod + def generate_id(*_): + return uuid.uuid4() + + async def respond( + self, + response: Optional[BaseHTTPResponse] = None, + *, + status: int = 200, + headers: Optional[Union[Header, Dict[str, str]]] = None, + content_type: Optional[str] = None, + ): + # This logic of determining which response to use is subject to change + if response is None: + response = (self.stream and self.stream.response) or HTTPResponse( + status=status, + headers=headers, + content_type=content_type, + ) + # Connect the response + if isinstance(response, BaseHTTPResponse) and self.stream: + response = self.stream.respond(response) + # Run response middleware + try: + response = await self.app._run_response_middleware( + self, response, request_name=self.name + ) + except CancelledErrors: + raise + except Exception: + error_logger.exception( + "Exception occurred in one of response middleware handlers" + ) + return response + + async def receive_body(self): + """Receive request.body, if not already received. + + Streaming handlers may call this to receive the full body. Sanic calls + this function before running any handlers of non-streaming routes. + + Custom request classes can override this for custom handling of both + streaming and non-streaming routes. + """ + if not self.body: + self.body = b"".join([data async for data in self.stream]) + + @property + def name(self): + if self._name: + return self._name + elif self.route: + return self.route.name + return None + + @property + def endpoint(self): + return self.name + + @property + def uri_template(self): + return f"/{self.route.path}" + + @property + def protocol(self): + if not self._protocol: + self._protocol = self.transport.get_protocol() + return self._protocol + + @property + def raw_headers(self): + _, headers = self.head.split(b"\r\n", 1) + return bytes(headers) + + @property + def request_line(self): + reqline, _ = self.head.split(b"\r\n", 1) + return bytes(reqline) + + @property + def id(self) -> Optional[Union[uuid.UUID, str, int]]: + """ + A request ID passed from the client, or generated from the backend. + + By default, this will look in a request header defined at: + ``self.app.config.REQUEST_ID_HEADER``. It defaults to + ``X-Request-ID``. Sanic will try to cast the ID into a ``UUID`` or an + ``int``. If there is not a UUID from the client, then Sanic will try + to generate an ID by calling ``Request.generate_id()``. The default + behavior is to generate a ``UUID``. You can customize this behavior + by subclassing ``Request``. + + .. code-block:: python + + from ..sanic import Request, Sanic + from itertools import count + + class IntRequest(Request): + counter = count() + + def generate_id(self): + return next(self.counter) + + app = Sanic("MyApp", request_class=IntRequest) + """ + if not self._id: + self._id = self.headers.getone( + self.app.config.REQUEST_ID_HEADER, + self.__class__.generate_id(self), # type: ignore + ) + + # Try casting to a UUID or an integer + if isinstance(self._id, str): + try: + self._id = uuid.UUID(self._id) + except ValueError: + try: + self._id = int(self._id) # type: ignore + except ValueError: + ... + + return self._id # type: ignore + + @property + def json(self): + if self.parsed_json is None: + self.load_json() + + return self.parsed_json + + def load_json(self, loads=json_loads): + try: + self.parsed_json = loads(self.body) + except Exception: + if not self.body: + return None + raise InvalidUsage("Failed when parsing body as json") + + return self.parsed_json + + @property + def accept(self) -> AcceptContainer: + if self.parsed_accept is None: + accept_header = self.headers.getone("accept", "") + self.parsed_accept = parse_accept(accept_header) + return self.parsed_accept + + @property + def token(self): + """Attempt to return the auth header token. + + :return: token related to request + """ + prefixes = ("Bearer", "Token") + auth_header = self.headers.getone("authorization", None) + + if auth_header is not None: + for prefix in prefixes: + if prefix in auth_header: + return auth_header.partition(prefix)[-1].strip() + + return auth_header + + @property + def form(self): + if self.parsed_form is None: + self.parsed_form = RequestParameters() + self.parsed_files = RequestParameters() + content_type = self.headers.getone( + "content-type", DEFAULT_HTTP_CONTENT_TYPE + ) + content_type, parameters = parse_content_header(content_type) + try: + if content_type == "application/x-www-form-urlencoded": + self.parsed_form = RequestParameters( + parse_qs(self.body.decode("utf-8")) + ) + elif content_type == "multipart/form-data": + # TODO: Stream this instead of reading to/from memory + boundary = parameters["boundary"].encode("utf-8") + self.parsed_form, self.parsed_files = parse_multipart_form( + self.body, boundary + ) + except Exception: + error_logger.exception("Failed when parsing form") + + return self.parsed_form + + @property + def files(self): + if self.parsed_files is None: + self.form # compute form to get files + + return self.parsed_files + + def get_args( + self, + keep_blank_values: bool = False, + strict_parsing: bool = False, + encoding: str = "utf-8", + errors: str = "replace", + ) -> RequestParameters: + """ + Method to parse `query_string` using `urllib.parse.parse_qs`. + This methods is used by `args` property. + Can be used directly if you need to change default parameters. + + :param keep_blank_values: + flag indicating whether blank values in + percent-encoded queries should be treated as blank strings. + A true value indicates that blanks should be retained as blank + strings. The default false value indicates that blank values + are to be ignored and treated as if they were not included. + :type keep_blank_values: bool + :param strict_parsing: + flag indicating what to do with parsing errors. + If false (the default), errors are silently ignored. If true, + errors raise a ValueError exception. + :type strict_parsing: bool + :param encoding: + specify how to decode percent-encoded sequences + into Unicode characters, as accepted by the bytes.decode() method. + :type encoding: str + :param errors: + specify how to decode percent-encoded sequences + into Unicode characters, as accepted by the bytes.decode() method. + :type errors: str + :return: RequestParameters + """ + if ( + keep_blank_values, + strict_parsing, + encoding, + errors, + ) not in self.parsed_args: + if self.query_string: + self.parsed_args[ + (keep_blank_values, strict_parsing, encoding, errors) + ] = RequestParameters( + parse_qs( + qs=self.query_string, + keep_blank_values=keep_blank_values, + strict_parsing=strict_parsing, + encoding=encoding, + errors=errors, + ) + ) + + return self.parsed_args[(keep_blank_values, strict_parsing, encoding, errors)] + + args = property(get_args) + + def get_query_args( + self, + keep_blank_values: bool = False, + strict_parsing: bool = False, + encoding: str = "utf-8", + errors: str = "replace", + ) -> list: + """ + Method to parse `query_string` using `urllib.parse.parse_qsl`. + This methods is used by `query_args` property. + Can be used directly if you need to change default parameters. + + :param keep_blank_values: + flag indicating whether blank values in + percent-encoded queries should be treated as blank strings. + A true value indicates that blanks should be retained as blank + strings. The default false value indicates that blank values + are to be ignored and treated as if they were not included. + :type keep_blank_values: bool + :param strict_parsing: + flag indicating what to do with parsing errors. + If false (the default), errors are silently ignored. If true, + errors raise a ValueError exception. + :type strict_parsing: bool + :param encoding: + specify how to decode percent-encoded sequences + into Unicode characters, as accepted by the bytes.decode() method. + :type encoding: str + :param errors: + specify how to decode percent-encoded sequences + into Unicode characters, as accepted by the bytes.decode() method. + :type errors: str + :return: list + """ + if ( + keep_blank_values, + strict_parsing, + encoding, + errors, + ) not in self.parsed_not_grouped_args: + if self.query_string: + self.parsed_not_grouped_args[ + (keep_blank_values, strict_parsing, encoding, errors) + ] = parse_qsl( + qs=self.query_string, + keep_blank_values=keep_blank_values, + strict_parsing=strict_parsing, + encoding=encoding, + errors=errors, + ) + return self.parsed_not_grouped_args[ + (keep_blank_values, strict_parsing, encoding, errors) + ] + + query_args = property(get_query_args) + """ + Convenience property to access :meth:`Request.get_query_args` with + default values. + """ + + @property + def cookies(self) -> Dict[str, str]: + """ + :return: Incoming cookies on the request + :rtype: Dict[str, str] + """ + + if self._cookies is None: + cookie = self.headers.getone("cookie", None) + if cookie is not None: + cookies: SimpleCookie = SimpleCookie() + cookies.load(cookie) + self._cookies = {name: cookie.value for name, cookie in cookies.items()} + else: + self._cookies = {} + return self._cookies + + @property + def content_type(self) -> str: + """ + :return: Content-Type header form the request + :rtype: str + """ + return self.headers.getone("content-type", DEFAULT_HTTP_CONTENT_TYPE) + + @property + def match_info(self): + """ + :return: matched info after resolving route + """ + return self._match_info + + @match_info.setter + def match_info(self, value): + self._match_info = value + + # Transport properties (obtained from local interface only) + + @property + def ip(self) -> str: + """ + :return: peer ip of the socket + :rtype: str + """ + return self.conn_info.client_ip if self.conn_info else "" + + @property + def port(self) -> int: + """ + :return: peer port of the socket + :rtype: int + """ + return self.conn_info.client_port if self.conn_info else 0 + + @property + def socket(self): + return self.conn_info.peername if self.conn_info else (None, None) + + @property + def path(self) -> str: + """ + :return: path of the local HTTP request + :rtype: str + """ + return self._parsed_url.path.decode("utf-8") + + # Proxy properties (using SERVER_NAME/forwarded/request/transport info) + + @property + def forwarded(self) -> Options: + """ + Active proxy information obtained from request headers, as specified in + Sanic configuration. + + Field names by, for, proto, host, port and path are normalized. + - for and by IPv6 addresses are bracketed + - port (int) is only set by port headers, not from host. + - path is url-unencoded + + Additional values may be available from new style Forwarded headers. + + :return: forwarded address info + :rtype: Dict[str, str] + """ + if self.parsed_forwarded is None: + self.parsed_forwarded = ( + parse_forwarded(self.headers, self.app.config) + or parse_xforwarded(self.headers, self.app.config) + or {} + ) + return self.parsed_forwarded + + @property + def remote_addr(self) -> str: + """ + Client IP address, if available. + 1. proxied remote address `self.forwarded['for']` + 2. local remote address `self.ip` + + :return: IPv4, bracketed IPv6, UNIX socket name or arbitrary string + :rtype: str + """ + if not hasattr(self, "_remote_addr"): + self._remote_addr = str(self.forwarded.get("for", "")) # or self.ip + return self._remote_addr + + @property + def scheme(self) -> str: + """ + Determine request scheme. + 1. `config.SERVER_NAME` if in full URL format + 2. proxied proto/scheme + 3. local connection protocol + + :return: http|https|ws|wss or arbitrary value given by the headers. + :rtype: str + """ + if "//" in self.app.config.get("SERVER_NAME", ""): + return self.app.config.SERVER_NAME.split("//")[0] + if "proto" in self.forwarded: + return str(self.forwarded["proto"]) + + if ( + self.app.websocket_enabled + and self.headers.getone("upgrade", "").lower() == "websocket" + ): + scheme = "ws" + else: + scheme = "http" + + if self.transport.get_extra_info("sslcontext"): + scheme += "s" + + return scheme + + @property + def host(self) -> str: + """ + The currently effective server 'host' (hostname or hostname:port). + 1. `config.SERVER_NAME` overrides any client headers + 2. proxied host of original request + 3. request host header + hostname and port may be separated by + `sanic.headers.parse_host(request.host)`. + + :return: the first matching host found, or empty string + :rtype: str + """ + server_name = self.app.config.get("SERVER_NAME") + if server_name: + return server_name.split("//", 1)[-1].split("/", 1)[0] + return str(self.forwarded.get("host") or self.headers.getone("host", "")) + + @property + def server_name(self) -> str: + """ + :return: hostname the client connected to, by ``request.host`` + :rtype: str + """ + return parse_host(self.host)[0] or "" + + @property + def server_port(self) -> int: + """ + The port the client connected to, by forwarded ``port`` or + ``request.host``. + + Default port is returned as 80 and 443 based on ``request.scheme``. + + :return: port number + :rtype: int + """ + port = self.forwarded.get("port") or parse_host(self.host)[1] + return int(port or (80 if self.scheme in ("http", "ws") else 443)) + + @property + def server_path(self) -> str: + """ + :return: full path of current URL; uses proxied or local path + :rtype: str + """ + return str(self.forwarded.get("path") or self.path) + + @property + def query_string(self) -> str: + """ + :return: representation of the requested query + :rtype: str + """ + if self._parsed_url.query: + return self._parsed_url.query.decode("utf-8") + else: + return "" + + @property + def url(self) -> str: + """ + :return: the URL + :rtype: str + """ + return urlunparse( + (self.scheme, self.host, self.path, None, self.query_string, None) + ) + + def url_for(self, view_name: str, **kwargs) -> str: + """ + Same as :func:`sanic.Sanic.url_for`, but automatically determine + `scheme` and `netloc` base on the request. Since this method is aiming + to generate correct schema & netloc, `_external` is implied. + + :param kwargs: takes same parameters as in :func:`sanic.Sanic.url_for` + :return: an absolute url to the given view + :rtype: str + """ + # Full URL SERVER_NAME can only be handled in app.url_for + try: + if "//" in self.app.config.SERVER_NAME: + return self.app.url_for(view_name, _external=True, **kwargs) + except AttributeError: + pass + + scheme = self.scheme + host = self.server_name + port = self.server_port + + if (scheme.lower() in ("http", "ws") and port == 80) or ( + scheme.lower() in ("https", "wss") and port == 443 + ): + netloc = host + else: + netloc = f"{host}:{port}" + + return self.app.url_for( + view_name, _external=True, _scheme=scheme, _server=netloc, **kwargs + ) + + +class File(NamedTuple): + """ + Model for defining a file. It is a ``namedtuple``, therefore you can + iterate over the object, or access the parameters by name. + + :param type: The mimetype, defaults to text/plain + :param body: Bytes of the file + :param name: The filename + """ + + type: str + body: bytes + name: str + + +def parse_multipart_form(body, boundary): + """ + Parse a request body and returns fields and files + + :param body: bytes request body + :param boundary: bytes multipart boundary + :return: fields (RequestParameters), files (RequestParameters) + """ + files = RequestParameters() + fields = RequestParameters() + + form_parts = body.split(boundary) + for form_part in form_parts[1:-1]: + file_name = None + content_type = "text/plain" + content_charset = "utf-8" + field_name = None + line_index = 2 + line_end_index = 0 + while not line_end_index == -1: + line_end_index = form_part.find(b"\r\n", line_index) + form_line = form_part[line_index:line_end_index].decode("utf-8") + line_index = line_end_index + 2 + + if not form_line: + break + + colon_index = form_line.index(":") + form_header_field = form_line[0:colon_index].lower() + form_header_value, form_parameters = parse_content_header( + form_line[colon_index + 2 :] + ) + + if form_header_field == "content-disposition": + field_name = form_parameters.get("name") + file_name = form_parameters.get("filename") + + # non-ASCII filenames in RFC2231, "filename*" format + if file_name is None and form_parameters.get("filename*"): + encoding, _, value = email.utils.decode_rfc2231( + form_parameters["filename*"] + ) + file_name = unquote(value, encoding=encoding) + elif form_header_field == "content-type": + content_type = form_header_value + content_charset = form_parameters.get("charset", "utf-8") + + if field_name: + post_data = form_part[line_index:-4] + if file_name is None: + value = post_data.decode(content_charset) + if field_name in fields: + fields[field_name].append(value) + else: + fields[field_name] = [value] + else: + form_file = File(type=content_type, name=file_name, body=post_data) + if field_name in files: + files[field_name].append(form_file) + else: + files[field_name] = [form_file] + else: + logger.debug( + "Form-data field does not have a 'name' parameter " + "in the Content-Disposition header" + ) + + return fields, files diff --git a/backend/sanic_server/sanic/response.py b/backend/sanic_server/sanic/response.py new file mode 100644 index 000000000..357b20f6b --- /dev/null +++ b/backend/sanic_server/sanic/response.py @@ -0,0 +1,510 @@ +from functools import partial +from mimetypes import guess_type +from os import path +from pathlib import PurePath +from typing import ( + Any, + AnyStr, + Callable, + Coroutine, + Dict, + Iterator, + Optional, + Tuple, + Union, +) +from urllib.parse import quote_plus +from warnings import warn + +from ..sanic.compat import Header, open_async +from ..sanic.constants import DEFAULT_HTTP_CONTENT_TYPE +from ..sanic.cookies import CookieJar +from ..sanic.helpers import has_message_body, remove_entity_headers +from ..sanic.http import Http +from ..sanic.models.protocol_types import HTMLProtocol, Range + + +try: + from ujson import dumps as json_dumps +except ImportError: + # This is done in order to ensure that the JSON response is + # kept consistent across both ujson and inbuilt json usage. + from json import dumps + + json_dumps = partial(dumps, separators=(",", ":")) + + +class BaseHTTPResponse: + """ + The base class for all HTTP Responses + """ + + _dumps = json_dumps + + def __init__(self): + self.asgi: bool = False + self.body: Optional[bytes] = None + self.content_type: Optional[str] = None + self.stream: Http = None + self.status: int = None + self.headers = Header({}) + self._cookies: Optional[CookieJar] = None + + def _encode_body(self, data: Optional[AnyStr]): + if data is None: + return b"" + return data.encode() if hasattr(data, "encode") else data # type: ignore + + @property + def cookies(self) -> CookieJar: + """ + The response cookies. Cookies should be set and written as follows: + + .. code-block:: python + + response.cookies["test"] = "It worked!" + response.cookies["test"]["domain"] = ".yummy-yummy-cookie.com" + response.cookies["test"]["httponly"] = True + + `See user guide re: cookies + `__ + + :return: the cookie jar + :rtype: CookieJar + """ + if self._cookies is None: + self._cookies = CookieJar(self.headers) + return self._cookies + + @property + def processed_headers(self) -> Iterator[Tuple[bytes, bytes]]: + """ + Obtain a list of header tuples encoded in bytes for sending. + + Add and remove headers based on status and content_type. + + :return: response headers + :rtype: Tuple[Tuple[bytes, bytes], ...] + """ + # TODO: Make a blacklist set of header names and then filter with that + if self.status in (304, 412): # Not Modified, Precondition Failed + self.headers = remove_entity_headers(self.headers) + if has_message_body(self.status): + self.headers.setdefault("content-type", self.content_type) + # Encode headers into bytes + return ( + (name.encode("ascii"), f"{value}".encode(errors="surrogateescape")) + for name, value in self.headers.items() + ) + + async def send( + self, + data: Optional[Union[AnyStr]] = None, + end_stream: Optional[bool] = None, + ) -> None: + """ + Send any pending response headers and the given data as body. + + :param data: str or bytes to be written + :param end_stream: whether to close the stream after this block + """ + if data is None and end_stream is None: + end_stream = True + if end_stream and not data and self.stream.send is None: + return + data = data.encode() if hasattr(data, "encode") else data or b"" # type: ignore + await self.stream.send(data, end_stream=end_stream) + + +StreamingFunction = Callable[[BaseHTTPResponse], Coroutine[Any, Any, None]] + + +class StreamingHTTPResponse(BaseHTTPResponse): + """ + Old style streaming response where you pass a streaming function: + + .. code-block:: python + + async def sample_streaming_fn(response): + await response.write("foo") + await asyncio.sleep(1) + await response.write("bar") + await asyncio.sleep(1) + + @app.post("/") + async def test(request): + return stream(sample_streaming_fn) + + .. warning:: + + **Deprecated** and set for removal in v21.12. You can now achieve the + same functionality without a callback. + + .. code-block:: python + + @app.post("/") + async def test(request): + response = await request.respond() + await response.send("foo", False) + await asyncio.sleep(1) + await response.send("bar", False) + await asyncio.sleep(1) + await response.send("", True) + return response + + """ + + __slots__ = ( + "streaming_fn", + "status", + "content_type", + "headers", + "_cookies", + ) + + def __init__( + self, + streaming_fn: StreamingFunction, + status: int = 200, + headers: Optional[Union[Header, Dict[str, str]]] = None, + content_type: str = "text/plain; charset=utf-8", + ignore_deprecation_notice: bool = False, + ): + if not ignore_deprecation_notice: + warn( + "Use of the StreamingHTTPResponse is deprecated in v21.6, and " + "will be removed in v21.12. Please upgrade your streaming " + "response implementation. You can learn more here: " + "https://sanicframework.org/en/guide/advanced/streaming.html" + "#response-streaming. If you use the builtin stream() or " + "file_stream() methods, this upgrade will be be done for you." + ) + + super().__init__() + + self.content_type = content_type + self.streaming_fn = streaming_fn + self.status = status + self.headers = Header(headers or {}) + self._cookies = None + + async def write(self, data): + """Writes a chunk of data to the streaming response. + + :param data: str or bytes-ish data to be written. + """ + await super().send(self._encode_body(data)) + + async def send(self, *args, **kwargs): + if self.streaming_fn is not None: + await self.streaming_fn(self) + self.streaming_fn = None + await super().send(*args, **kwargs) + + async def eof(self): + raise NotImplementedError + + +class HTTPResponse(BaseHTTPResponse): + """ + HTTP response to be sent back to the client. + + :param body: the body content to be returned + :type body: Optional[bytes] + :param status: HTTP response number. **Default=200** + :type status: int + :param headers: headers to be returned + :type headers: Optional; + :param content_type: content type to be returned (as a header) + :type content_type: Optional[str] + """ + + __slots__ = ("body", "status", "content_type", "headers", "_cookies") + + def __init__( + self, + body: Optional[AnyStr] = None, + status: int = 200, + headers: Optional[Union[Header, Dict[str, str]]] = None, + content_type: Optional[str] = None, + ): + super().__init__() + + self.content_type: Optional[str] = content_type + self.body = self._encode_body(body) + self.status = status + self.headers = Header(headers or {}) + self._cookies = None + + async def eof(self): + await self.send("", True) + + async def __aenter__(self): + return self.send + + async def __aexit__(self, *_): + await self.eof() + + +def empty(status=204, headers: Optional[Dict[str, str]] = None) -> HTTPResponse: + """ + Returns an empty response to the client. + + :param status Response code. + :param headers Custom Headers. + """ + return HTTPResponse(body=b"", status=status, headers=headers) + + +def json( + body: Any, + status: int = 200, + headers: Optional[Dict[str, str]] = None, + content_type: str = "application/json", + dumps: Optional[Callable[..., str]] = None, + **kwargs, +) -> HTTPResponse: + """ + Returns response object with body in json format. + + :param body: Response data to be serialized. + :param status: Response code. + :param headers: Custom Headers. + :param kwargs: Remaining arguments that are passed to the json encoder. + """ + if not dumps: + dumps = BaseHTTPResponse._dumps + return HTTPResponse( + dumps(body, **kwargs), + headers=headers, + status=status, + content_type=content_type, + ) + + +def text( + body: str, + status: int = 200, + headers: Optional[Dict[str, str]] = None, + content_type: str = "text/plain; charset=utf-8", +) -> HTTPResponse: + """ + Returns response object with body in text format. + + :param body: Response data to be encoded. + :param status: Response code. + :param headers: Custom Headers. + :param content_type: the content type (string) of the response + """ + if not isinstance(body, str): + raise TypeError(f"Bad body type. Expected str, got {type(body).__name__})") + + return HTTPResponse(body, status=status, headers=headers, content_type=content_type) + + +def raw( + body: Optional[AnyStr], + status: int = 200, + headers: Optional[Dict[str, str]] = None, + content_type: str = DEFAULT_HTTP_CONTENT_TYPE, +) -> HTTPResponse: + """ + Returns response object without encoding the body. + + :param body: Response data. + :param status: Response code. + :param headers: Custom Headers. + :param content_type: the content type (string) of the response. + """ + return HTTPResponse( + body=body, + status=status, + headers=headers, + content_type=content_type, + ) + + +def html( + body: Union[str, bytes, HTMLProtocol], + status: int = 200, + headers: Optional[Dict[str, str]] = None, +) -> HTTPResponse: + """ + Returns response object with body in html format. + + :param body: str or bytes-ish, or an object with __html__ or _repr_html_. + :param status: Response code. + :param headers: Custom Headers. + """ + if not isinstance(body, (str, bytes)): + if hasattr(body, "__html__"): + body = body.__html__() + elif hasattr(body, "_repr_html_"): + body = body._repr_html_() + + return HTTPResponse( # type: ignore + body, + status=status, + headers=headers, + content_type="text/html; charset=utf-8", + ) + + +async def file( + location: Union[str, PurePath], + status: int = 200, + mime_type: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + filename: Optional[str] = None, + _range: Optional[Range] = None, +) -> HTTPResponse: + """Return a response object with file data. + + :param location: Location of file on system. + :param mime_type: Specific mime_type. + :param headers: Custom Headers. + :param filename: Override filename. + :param _range: + """ + headers = headers or {} + if filename: + headers.setdefault("Content-Disposition", f'attachment; filename="{filename}"') + filename = filename or path.split(location)[-1] + + async with await open_async(location, mode="rb") as f: + if _range: + await f.seek(_range.start) + out_stream = await f.read(_range.size) + headers[ + "Content-Range" + ] = f"bytes {_range.start}-{_range.end}/{_range.total}" + status = 206 + else: + out_stream = await f.read() + + mime_type = mime_type or guess_type(filename)[0] or "text/plain" + return HTTPResponse( + body=out_stream, + status=status, + headers=headers, + content_type=mime_type, + ) + + +async def file_stream( + location: Union[str, PurePath], + status: int = 200, + chunk_size: int = 4096, + mime_type: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + filename: Optional[str] = None, + _range: Optional[Range] = None, +) -> StreamingHTTPResponse: + """Return a streaming response object with file data. + + :param location: Location of file on system. + :param chunk_size: The size of each chunk in the stream (in bytes) + :param mime_type: Specific mime_type. + :param headers: Custom Headers. + :param filename: Override filename. + :param chunked: Deprecated + :param _range: + """ + headers = headers or {} + if filename: + headers.setdefault("Content-Disposition", f'attachment; filename="{filename}"') + filename = filename or path.split(location)[-1] + mime_type = mime_type or guess_type(filename)[0] or "text/plain" + if _range: + start = _range.start + end = _range.end + total = _range.total + + headers["Content-Range"] = f"bytes {start}-{end}/{total}" + status = 206 + + async def _streaming_fn(response): + async with await open_async(location, mode="rb") as f: + if _range: + await f.seek(_range.start) + to_send = _range.size + while to_send > 0: + content = await f.read(min((_range.size, chunk_size))) + if len(content) < 1: + break + to_send -= len(content) + await response.write(content) + else: + while True: + content = await f.read(chunk_size) + if len(content) < 1: + break + await response.write(content) + + return StreamingHTTPResponse( + streaming_fn=_streaming_fn, + status=status, + headers=headers, + content_type=mime_type, + ignore_deprecation_notice=True, + ) + + +def stream( + streaming_fn: StreamingFunction, + status: int = 200, + headers: Optional[Dict[str, str]] = None, + content_type: str = "text/plain; charset=utf-8", +): + """Accepts an coroutine `streaming_fn` which can be used to + write chunks to a streaming response. Returns a `StreamingHTTPResponse`. + + Example usage:: + + @app.route("/") + async def index(request): + async def streaming_fn(response): + await response.write('foo') + await response.write('bar') + + return stream(streaming_fn, content_type='text/plain') + + :param streaming_fn: A coroutine accepts a response and + writes content to that response. + :param mime_type: Specific mime_type. + :param headers: Custom Headers. + :param chunked: Deprecated + """ + return StreamingHTTPResponse( + streaming_fn, + headers=headers, + content_type=content_type, + status=status, + ignore_deprecation_notice=True, + ) + + +def redirect( + to: str, + headers: Optional[Dict[str, str]] = None, + status: int = 302, + content_type: str = "text/html; charset=utf-8", +) -> HTTPResponse: + """ + Abort execution and cause a 302 redirect (by default) by setting a + Location header. + + :param to: path or fully qualified URL to redirect to + :param headers: optional dict of headers to include in the new request + :param status: status code (int) of the new request, defaults to 302 + :param content_type: the content type (string) of the response + """ + headers = headers or {} + + # URL Quote the URL before redirecting + safe_to = quote_plus(to, safe=":/%#?&=@[]!$&'()*+,;") + + # According to RFC 7231, a relative URI is now permitted. + headers["Location"] = safe_to + + return HTTPResponse(status=status, headers=headers, content_type=content_type) diff --git a/backend/sanic_server/sanic/router.py b/backend/sanic_server/sanic/router.py new file mode 100644 index 000000000..fc6cacaa6 --- /dev/null +++ b/backend/sanic_server/sanic/router.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from functools import lru_cache +from inspect import signature +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from uuid import UUID + +from ..sanic_routing import BaseRouter # type: ignore +from ..sanic_routing.exceptions import NoMethod # type: ignore +from ..sanic_routing.exceptions import ( + NotFound as RoutingNotFound, # type: ignore +) +from ..sanic_routing.route import Route # type: ignore + +from ..sanic.constants import HTTP_METHODS +from ..sanic.errorpages import check_error_format +from ..sanic.exceptions import MethodNotSupported, NotFound, SanicException +from ..sanic.models.handler_types import RouteHandler + + +ROUTER_CACHE_SIZE = 1024 +ALLOWED_LABELS = ("__file_uri__",) + + +class Router(BaseRouter): + """ + The router implementation responsible for routing a :class:`Request` object + to the appropriate handler. + """ + + DEFAULT_METHOD = "GET" + ALLOWED_METHODS = HTTP_METHODS + + def _get( + self, path: str, method: str, host: Optional[str] + ) -> Tuple[Route, RouteHandler, Dict[str, Any]]: + try: + return self.resolve( + path=path, + method=method, + extra={"host": host} if host else None, + ) + except RoutingNotFound as e: + raise NotFound("Requested URL {} not found".format(e.path)) + except NoMethod as e: + raise MethodNotSupported( + "Method {} not allowed for URL {}".format(method, path), + method=method, + allowed_methods=e.allowed_methods, + ) + + @lru_cache(maxsize=ROUTER_CACHE_SIZE) + def get( # type: ignore + self, path: str, method: str, host: Optional[str] + ) -> Tuple[Route, RouteHandler, Dict[str, Any]]: + """ + Retrieve a `Route` object containg the details about how to handle + a response for a given request + + :param request: the incoming request object + :type request: Request + :return: details needed for handling the request and returning the + correct response + :rtype: Tuple[ Route, RouteHandler, Dict[str, Any]] + """ + return self._get(path, method, host) + + def add( # type: ignore + self, + uri: str, + methods: Iterable[str], + handler: RouteHandler, + host: Optional[Union[str, Iterable[str]]] = None, + strict_slashes: bool = False, + stream: bool = False, + ignore_body: bool = False, + version: Union[str, float, int] = None, + name: Optional[str] = None, + unquote: bool = False, + static: bool = False, + version_prefix: str = "/v", + error_format: Optional[str] = None, + ) -> Union[Route, List[Route]]: + """ + Add a handler to the router + + :param uri: the path of the route + :type uri: str + :param methods: the types of HTTP methods that should be attached, + example: ``["GET", "POST", "OPTIONS"]`` + :type methods: Iterable[str] + :param handler: the sync or async function to be executed + :type handler: RouteHandler + :param host: host that the route should be on, defaults to None + :type host: Optional[str], optional + :param strict_slashes: whether to apply strict slashes, defaults + to False + :type strict_slashes: bool, optional + :param stream: whether to stream the response, defaults to False + :type stream: bool, optional + :param ignore_body: whether the incoming request body should be read, + defaults to False + :type ignore_body: bool, optional + :param version: a version modifier for the uri, defaults to None + :type version: Union[str, float, int], optional + :param name: an identifying name of the route, defaults to None + :type name: Optional[str], optional + :return: the route object + :rtype: Route + """ + if version is not None: + version = str(version).strip("/").lstrip("v") + uri = "/".join([f"{version_prefix}{version}", uri.lstrip("/")]) + + uri = self._normalize(uri, handler) + + params = dict( + path=uri, + handler=handler, + methods=frozenset(map(str, methods)) if methods else None, + name=name, + strict=strict_slashes, + unquote=unquote, + ) + + if isinstance(host, str): + hosts = [host] + else: + hosts = host or [None] # type: ignore + + routes = [] + + for host in hosts: + if host: + params.update({"requirements": {"host": host}}) + + route = super().add(**params) # type: ignore + route.ctx.ignore_body = ignore_body + route.ctx.stream = stream + route.ctx.hosts = hosts + route.ctx.static = static + route.ctx.error_format = error_format + + if error_format: + check_error_format(route.ctx.error_format) + + routes.append(route) + + if len(routes) == 1: + return routes[0] + return routes + + @lru_cache(maxsize=ROUTER_CACHE_SIZE) + def find_route_by_view_name(self, view_name, name=None): + """ + Find a route in the router based on the specified view name. + + :param view_name: string of view name to search by + :param kwargs: additional params, usually for static files + :return: tuple containing (uri, Route) + """ + if not view_name: + return None + + route = self.name_index.get(view_name) + if not route: + full_name = self.ctx.app._generate_name(view_name) + route = self.name_index.get(full_name) + + if not route: + return None + + return route + + @property + def routes_all(self): + return {route.parts: route for route in self.routes} + + @property + def routes_static(self): + return self.static_routes + + @property + def routes_dynamic(self): + return self.dynamic_routes + + @property + def routes_regex(self): + return self.regex_routes + + def finalize(self, *args, **kwargs): + super().finalize(*args, **kwargs) + + for route in self.dynamic_routes.values(): + if any( + label.startswith("__") and label not in ALLOWED_LABELS + for label in route.labels + ): + raise SanicException( + f"Invalid route: {route}. Parameter names cannot use '__'." + ) + + def _normalize(self, uri: str, handler: RouteHandler) -> str: + if "<" not in uri: + return uri + + sig = signature(handler) + mapping = { + param.name: param.annotation.__name__.lower() + for param in sig.parameters.values() + if param.annotation in (str, int, float, UUID) + } + + reconstruction = [] + for part in uri.split("/"): + if part.startswith("<") and ":" not in part: + name = part[1:-1] + annotation = mapping.get(name) + if annotation: + part = f"<{name}:{annotation}>" + reconstruction.append(part) + return "/".join(reconstruction) diff --git a/backend/sanic_server/sanic/server/__init__.py b/backend/sanic_server/sanic/server/__init__.py new file mode 100644 index 000000000..9f160beb0 --- /dev/null +++ b/backend/sanic_server/sanic/server/__init__.py @@ -0,0 +1,26 @@ +import asyncio + +from ...sanic.models.server_types import ConnInfo, Signal +from ...sanic.server.async_server import AsyncioServer +from ...sanic.server.protocols.http_protocol import HttpProtocol +from ...sanic.server.runners import serve, serve_multiple, serve_single + + +try: + import uvloop # type: ignore + + if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +except ImportError: + pass + + +__all__ = ( + "AsyncioServer", + "ConnInfo", + "HttpProtocol", + "Signal", + "serve", + "serve_multiple", + "serve_single", +) diff --git a/backend/sanic_server/sanic/server/async_server.py b/backend/sanic_server/sanic/server/async_server.py new file mode 100644 index 000000000..0241f0c1d --- /dev/null +++ b/backend/sanic_server/sanic/server/async_server.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import asyncio + +from ...sanic.exceptions import SanicException + + +class AsyncioServer: + """ + Wraps an asyncio server with functionality that might be useful to + a user who needs to manage the server lifecycle manually. + """ + + __slots__ = ("app", "connections", "loop", "serve_coro", "server", "init") + + def __init__( + self, + app, + loop, + serve_coro, + connections, + ): + # Note, Sanic already called "before_server_start" events + # before this helper was even created. So we don't need it here. + self.app = app + self.connections = connections + self.loop = loop + self.serve_coro = serve_coro + self.server = None + self.init = False + + def startup(self): + """ + Trigger "before_server_start" events + """ + self.init = True + return self.app._startup() + + def before_start(self): + """ + Trigger "before_server_start" events + """ + return self._server_event("init", "before") + + def after_start(self): + """ + Trigger "after_server_start" events + """ + return self._server_event("init", "after") + + def before_stop(self): + """ + Trigger "before_server_stop" events + """ + return self._server_event("shutdown", "before") + + def after_stop(self): + """ + Trigger "after_server_stop" events + """ + return self._server_event("shutdown", "after") + + def is_serving(self) -> bool: + if self.server: + return self.server.is_serving() + return False + + def wait_closed(self): + if self.server: + return self.server.wait_closed() + + def close(self): + if self.server: + self.server.close() + coro = self.wait_closed() + task = asyncio.ensure_future(coro, loop=self.loop) + return task + + def start_serving(self): + if self.server: + try: + return self.server.start_serving() + except AttributeError: + raise NotImplementedError( + "server.start_serving not available in this version " + "of asyncio or uvloop." + ) + + def serve_forever(self): + if self.server: + try: + return self.server.serve_forever() + except AttributeError: + raise NotImplementedError( + "server.serve_forever not available in this version " + "of asyncio or uvloop." + ) + + def _server_event(self, concern: str, action: str): + if not self.init: + raise SanicException( + "Cannot dispatch server event without " "first running server.startup()" + ) + return self.app._server_event(concern, action, loop=self.loop) + + def __await__(self): + """ + Starts the asyncio server, returns AsyncServerCoro + """ + task = asyncio.ensure_future(self.serve_coro) + while not task.done(): + yield + self.server = task.result() + return self diff --git a/backend/sanic_server/sanic/server/events.py b/backend/sanic_server/sanic/server/events.py new file mode 100644 index 000000000..3b71281d9 --- /dev/null +++ b/backend/sanic_server/sanic/server/events.py @@ -0,0 +1,16 @@ +from inspect import isawaitable +from typing import Any, Callable, Iterable, Optional + + +def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop): + """ + Trigger event callbacks (functions or async) + + :param events: one or more sync or async functions to execute + :param loop: event loop + """ + if events: + for event in events: + result = event(loop) + if isawaitable(result): + loop.run_until_complete(result) diff --git a/backend/sanic_server/sanic/server/protocols/__init__.py b/backend/sanic_server/sanic/server/protocols/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic/server/protocols/base_protocol.py b/backend/sanic_server/sanic/server/protocols/base_protocol.py new file mode 100644 index 000000000..cd877bae6 --- /dev/null +++ b/backend/sanic_server/sanic/server/protocols/base_protocol.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + + +if TYPE_CHECKING: + from ....sanic.app import Sanic + +import asyncio + +from asyncio import CancelledError +from asyncio.transports import Transport +from time import monotonic as current_time + +from ....sanic.log import error_logger +from ....sanic.models.server_types import ConnInfo, Signal + + +class SanicProtocol(asyncio.Protocol): + __slots__ = ( + "app", + # event loop, connection + "loop", + "transport", + "connections", + "conn_info", + "signal", + "_can_write", + "_time", + "_task", + "_unix", + "_data_received", + ) + + def __init__( + self, + *, + loop, + app: Sanic, + signal=None, + connections=None, + unix=None, + **kwargs, + ): + asyncio.set_event_loop(loop) + self.loop = loop + self.app: Sanic = app + self.signal = signal or Signal() + self.transport: Optional[Transport] = None + self.connections = connections if connections is not None else set() + self.conn_info: Optional[ConnInfo] = None + self._can_write = asyncio.Event() + self._can_write.set() + self._unix = unix + self._time = 0.0 # type: float + self._task = None # type: Optional[asyncio.Task] + self._data_received = asyncio.Event() + + @property + def ctx(self): + if self.conn_info is not None: + return self.conn_info.ctx + else: + return None + + async def send(self, data): + """ + Generic data write implementation with backpressure control. + """ + await self._can_write.wait() + if self.transport.is_closing(): + raise CancelledError + self.transport.write(data) + self._time = current_time() + + async def receive_more(self): + """ + Wait until more data is received into the Server protocol's buffer + """ + self.transport.resume_reading() + self._data_received.clear() + await self._data_received.wait() + + def close(self, timeout: Optional[float] = None): + """ + Attempt close the connection. + """ + # Cause a call to connection_lost where further cleanup occurs + if self.transport: + self.transport.close() + if timeout is None: + timeout = self.app.config.GRACEFUL_SHUTDOWN_TIMEOUT + self.loop.call_later(timeout, self.abort) + + def abort(self): + """ + Force close the connection. + """ + # Cause a call to connection_lost where further cleanup occurs + if self.transport: + self.transport.abort() + self.transport = None + + # asyncio.Protocol API Callbacks # + # ------------------------------ # + def connection_made(self, transport): + """ + Generic connection-made, with no connection_task, and no recv_buffer. + Override this for protocol-specific connection implementations. + """ + try: + transport.set_write_buffer_limits(low=16384, high=65536) + self.connections.add(self) + self.transport = transport + self.conn_info = ConnInfo(self.transport, unix=self._unix) + except Exception: + error_logger.exception("protocol.connect_made") + + def connection_lost(self, exc): + try: + self.connections.discard(self) + self.resume_writing() + if self._task: + self._task.cancel() + except BaseException: + error_logger.exception("protocol.connection_lost") + + def pause_writing(self): + self._can_write.clear() + + def resume_writing(self): + self._can_write.set() + + def data_received(self, data: bytes): + try: + self._time = current_time() + if not data: + return self.close() + + if self._data_received: + self._data_received.set() + except BaseException: + error_logger.exception("protocol.data_received") diff --git a/backend/sanic_server/sanic/server/protocols/http_protocol.py b/backend/sanic_server/sanic/server/protocols/http_protocol.py new file mode 100644 index 000000000..5c0609276 --- /dev/null +++ b/backend/sanic_server/sanic/server/protocols/http_protocol.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from ....sanic.touchup.meta import TouchUpMeta + + +if TYPE_CHECKING: + from ....sanic.app import Sanic + +from asyncio import CancelledError +from time import monotonic as current_time + +from ....sanic.exceptions import RequestTimeout, ServiceUnavailable +from ....sanic.http import Http, Stage +from ....sanic.log import error_logger, logger +from ....sanic.models.server_types import ConnInfo +from ....sanic.request import Request +from ....sanic.server.protocols.base_protocol import SanicProtocol + + +class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta): + """ + This class provides implements the HTTP 1.1 protocol on top of our + Sanic Server transport + """ + + __touchup__ = ( + "send", + "connection_task", + ) + __slots__ = ( + # request params + "request", + # request config + "request_handler", + "request_timeout", + "response_timeout", + "keep_alive_timeout", + "request_max_size", + "request_class", + "error_handler", + # enable or disable access log purpose + "access_log", + # connection management + "state", + "url", + "_handler_task", + "_http", + "_exception", + "recv_buffer", + ) + + def __init__( + self, + *, + loop, + app: Sanic, + signal=None, + connections=None, + state=None, + unix=None, + **kwargs, + ): + super().__init__( + loop=loop, + app=app, + signal=signal, + connections=connections, + unix=unix, + ) + self.url = None + self.request: Optional[Request] = None + self.access_log = self.app.config.ACCESS_LOG + self.request_handler = self.app.handle_request + self.error_handler = self.app.error_handler + self.request_timeout = self.app.config.REQUEST_TIMEOUT + self.response_timeout = self.app.config.RESPONSE_TIMEOUT + self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT + self.request_max_size = self.app.config.REQUEST_MAX_SIZE + self.request_class = self.app.request_class or Request + self.state = state if state else {} + if "requests_count" not in self.state: + self.state["requests_count"] = 0 + self._exception = None + + def _setup_connection(self): + self._http = Http(self) + self._time = current_time() + self.check_timeouts() + + async def connection_task(self): # no cov + """ + Run a HTTP connection. + + Timeouts and some additional error handling occur here, while most of + everything else happens in class Http or in code called from there. + """ + try: + self._setup_connection() + await self.app.dispatch( + "http.lifecycle.begin", + inline=True, + context={"conn_info": self.conn_info}, + ) + await self._http.http1() + except CancelledError: + pass + except Exception: + error_logger.exception("protocol.connection_task uncaught") + finally: + if ( + self.app.debug + and self._http + and self.transport + and not self._http.upgrade_websocket + ): + ip = self.transport.get_extra_info("peername") + error_logger.error( + "Connection lost before response written" + f" @ {ip} {self._http.request}" + ) + self._http = None + self._task = None + try: + self.close() + except BaseException: + error_logger.exception("Closing failed") + finally: + await self.app.dispatch( + "http.lifecycle.complete", + inline=True, + context={"conn_info": self.conn_info}, + ) + # Important to keep this Ellipsis here for the TouchUp module + ... + + def check_timeouts(self): + """ + Runs itself periodically to enforce any expired timeouts. + """ + try: + if not self._task: + return + duration = current_time() - self._time + stage = self._http.stage + if stage is Stage.IDLE and duration > self.keep_alive_timeout: + logger.debug("KeepAlive Timeout. Closing connection.") + elif stage is Stage.REQUEST and duration > self.request_timeout: + logger.debug("Request Timeout. Closing connection.") + self._http.exception = RequestTimeout("Request Timeout") + elif stage is Stage.HANDLER and self._http.upgrade_websocket: + logger.debug("Handling websocket. Timeouts disabled.") + return + elif ( + stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED) + and duration > self.response_timeout + ): + logger.debug("Response Timeout. Closing connection.") + self._http.exception = ServiceUnavailable("Response Timeout") + else: + interval = ( + min( + self.keep_alive_timeout, + self.request_timeout, + self.response_timeout, + ) + / 2 + ) + self.loop.call_later(max(0.1, interval), self.check_timeouts) + return + self._task.cancel() + except Exception: + error_logger.exception("protocol.check_timeouts") + + async def send(self, data): # no cov + """ + Writes HTTP data with backpressure control. + """ + await self._can_write.wait() + if self.transport.is_closing(): + raise CancelledError + await self.app.dispatch( + "http.lifecycle.send", + inline=True, + context={"data": data}, + ) + self.transport.write(data) + self._time = current_time() + + def close_if_idle(self) -> bool: + """ + Close the connection if a request is not being sent or received + + :return: boolean - True if closed, false if staying open + """ + if self._http is None or self._http.stage is Stage.IDLE: + self.close() + return True + return False + + # -------------------------------------------- # + # Only asyncio.Protocol callbacks below this + # -------------------------------------------- # + + def connection_made(self, transport): + """ + HTTP-protocol-specific new connection handler + """ + try: + # TODO: Benchmark to find suitable write buffer limits + transport.set_write_buffer_limits(low=16384, high=65536) + self.connections.add(self) + self.transport = transport + self._task = self.loop.create_task(self.connection_task()) + self.recv_buffer = bytearray() + self.conn_info = ConnInfo(self.transport, unix=self._unix) + except Exception: + error_logger.exception("protocol.connect_made") + + def data_received(self, data: bytes): + + try: + self._time = current_time() + if not data: + return self.close() + self.recv_buffer += data + + if ( + len(self.recv_buffer) >= self.app.config.REQUEST_BUFFER_SIZE + and self.transport + ): + self.transport.pause_reading() + + if self._data_received: + self._data_received.set() + except Exception: + error_logger.exception("protocol.data_received") diff --git a/backend/sanic_server/sanic/server/protocols/websocket_protocol.py b/backend/sanic_server/sanic/server/protocols/websocket_protocol.py new file mode 100644 index 000000000..c3990823b --- /dev/null +++ b/backend/sanic_server/sanic/server/protocols/websocket_protocol.py @@ -0,0 +1,161 @@ +from typing import TYPE_CHECKING, Optional, Sequence + +from websockets.connection import CLOSED, CLOSING, OPEN +from websockets.server import ServerConnection + +from ....sanic.exceptions import ServerError +from ....sanic.log import error_logger +from ....sanic.server import HttpProtocol + +from ..websockets.impl import WebsocketImplProtocol + + +if TYPE_CHECKING: + from websockets import http11 + + +class WebSocketProtocol(HttpProtocol): + + websocket: Optional[WebsocketImplProtocol] + websocket_timeout: float + websocket_max_size = Optional[int] + websocket_ping_interval = Optional[float] + websocket_ping_timeout = Optional[float] + + def __init__( + self, + *args, + websocket_timeout: float = 10.0, + websocket_max_size: Optional[int] = None, + websocket_max_queue: Optional[int] = None, # max_queue is deprecated + websocket_read_limit: Optional[int] = None, # read_limit is deprecated + websocket_write_limit: Optional[int] = None, # write_limit deprecated + websocket_ping_interval: Optional[float] = 20.0, + websocket_ping_timeout: Optional[float] = 20.0, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.websocket = None + self.websocket_timeout = websocket_timeout + self.websocket_max_size = websocket_max_size + if websocket_max_queue is not None and websocket_max_queue > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses queueing, so websocket_max_queue" + " is no longer required." + ) + ) + if websocket_read_limit is not None and websocket_read_limit > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses read buffers, so " + "websocket_read_limit is not required." + ) + ) + if websocket_write_limit is not None and websocket_write_limit > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses write buffers, so " + "websocket_write_limit is not required." + ) + ) + self.websocket_ping_interval = websocket_ping_interval + self.websocket_ping_timeout = websocket_ping_timeout + + def connection_lost(self, exc): + if self.websocket is not None: + self.websocket.connection_lost(exc) + super().connection_lost(exc) + + def data_received(self, data): + if self.websocket is not None: + self.websocket.data_received(data) + else: + # Pass it to HttpProtocol handler first + # That will (hopefully) upgrade it to a websocket. + super().data_received(data) + + def eof_received(self) -> Optional[bool]: + if self.websocket is not None: + return self.websocket.eof_received() + else: + return False + + def close(self, timeout: Optional[float] = None): + # Called by HttpProtocol at the end of connection_task + # If we've upgraded to websocket, we do our own closing + if self.websocket is not None: + # Note, we don't want to use websocket.close() + # That is used for user's application code to send a + # websocket close packet. This is different. + self.websocket.end_connection(1001) + else: + super().close() + + def close_if_idle(self): + # Called by Sanic Server when shutting down + # If we've upgraded to websocket, shut it down + if self.websocket is not None: + if self.websocket.connection.state in (CLOSING, CLOSED): + return True + elif self.websocket.loop is not None: + self.websocket.loop.create_task(self.websocket.close(1001)) + else: + self.websocket.end_connection(1001) + else: + return super().close_if_idle() + + async def websocket_handshake(self, request, subprotocols=Optional[Sequence[str]]): + # let the websockets package do the handshake with the client + try: + if subprotocols is not None: + # subprotocols can be a set or frozenset, + # but ServerConnection needs a list + subprotocols = list(subprotocols) + ws_conn = ServerConnection( + max_size=self.websocket_max_size, + subprotocols=subprotocols, + state=OPEN, + logger=error_logger, + ) + resp: "http11.Response" = ws_conn.accept(request) + except Exception: + msg = ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ) + raise ServerError(msg, status_code=500) + if 100 <= resp.status_code <= 299: + rbody = "".join( + [ + "HTTP/1.1 ", + str(resp.status_code), + " ", + resp.reason_phrase, + "\r\n", + ] + ) + rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items()) + if resp.body is not None: + rbody += f"\r\n{resp.body}\r\n\r\n" + else: + rbody += "\r\n" + await super().send(rbody.encode()) + else: + raise ServerError(resp.body, resp.status_code) + self.websocket = WebsocketImplProtocol( + ws_conn, + ping_interval=self.websocket_ping_interval, + ping_timeout=self.websocket_ping_timeout, + close_timeout=self.websocket_timeout, + ) + loop = ( + request.transport.loop + if hasattr(request, "transport") and hasattr(request.transport, "loop") + else None + ) + await self.websocket.connection_made(self, loop=loop) + return self.websocket diff --git a/backend/sanic_server/sanic/server/runners.py b/backend/sanic_server/sanic/server/runners.py new file mode 100644 index 000000000..74e74fd12 --- /dev/null +++ b/backend/sanic_server/sanic/server/runners.py @@ -0,0 +1,278 @@ +from __future__ import annotations + +from ssl import SSLContext +from typing import TYPE_CHECKING, Dict, Optional, Type, Union + +from ...sanic.config import Config +from ...sanic.server.events import trigger_events + + +if TYPE_CHECKING: + from ...sanic.app import Sanic + +import asyncio +import multiprocessing +import os +import socket + +from functools import partial +from signal import SIG_IGN, SIGINT, SIGTERM, Signals +from signal import signal as signal_func + +from ...sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows +from ...sanic.log import error_logger, logger +from ...sanic.models.server_types import Signal +from ...sanic.server.async_server import AsyncioServer +from ...sanic.server.protocols.http_protocol import HttpProtocol +from ...sanic.server.socket import ( + bind_socket, + bind_unix_socket, + remove_unix_socket, +) + + +def serve( + host, + port, + app: Sanic, + ssl: Optional[SSLContext] = None, + sock: Optional[socket.socket] = None, + unix: Optional[str] = None, + reuse_port: bool = False, + loop=None, + protocol: Type[asyncio.Protocol] = HttpProtocol, + backlog: int = 100, + register_sys_signals: bool = True, + run_multiple: bool = False, + run_async: bool = False, + connections=None, + signal=Signal(), + state=None, + asyncio_server_kwargs=None, +): + """Start asynchronous HTTP Server on an individual process. + + :param host: Address to host on + :param port: Port to host on + :param before_start: function to be executed before the server starts + listening. Takes arguments `app` instance and `loop` + :param after_start: function to be executed after the server starts + listening. Takes arguments `app` instance and `loop` + :param before_stop: function to be executed when a stop signal is + received before it is respected. Takes arguments + `app` instance and `loop` + :param after_stop: function to be executed when a stop signal is + received after it is respected. Takes arguments + `app` instance and `loop` + :param ssl: SSLContext + :param sock: Socket for the server to accept connections from + :param unix: Unix socket to listen on instead of TCP port + :param reuse_port: `True` for multiple workers + :param loop: asyncio compatible event loop + :param run_async: bool: Do not create a new event loop for the server, + and return an AsyncServer object rather than running it + :param asyncio_server_kwargs: key-value args for asyncio/uvloop + create_server method + :return: Nothing + """ + if not run_async and not loop: + # create new event_loop after fork + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if app.debug: + loop.set_debug(app.debug) + + app.asgi = False + + connections = connections if connections is not None else set() + protocol_kwargs = _build_protocol_kwargs(protocol, app.config) + server = partial( + protocol, + loop=loop, + connections=connections, + signal=signal, + app=app, + state=state, + unix=unix, + **protocol_kwargs, + ) + asyncio_server_kwargs = asyncio_server_kwargs if asyncio_server_kwargs else {} + # UNIX sockets are always bound by us (to preserve semantics between modes) + if unix: + sock = bind_unix_socket(unix, backlog=backlog) + server_coroutine = loop.create_server( + server, + None if sock else host, + None if sock else port, + ssl=ssl, + reuse_port=reuse_port, + sock=sock, + backlog=backlog, + **asyncio_server_kwargs, + ) + + if run_async: + return AsyncioServer( + app=app, + loop=loop, + serve_coro=server_coroutine, + connections=connections, + ) + + loop.run_until_complete(app._startup()) + loop.run_until_complete(app._server_event("init", "before")) + + try: + http_server = loop.run_until_complete(server_coroutine) + except BaseException: + error_logger.exception("Unable to start server") + return + + # Ignore SIGINT when run_multiple + if run_multiple: + signal_func(SIGINT, SIG_IGN) + + # Register signals for graceful termination + if register_sys_signals: + if OS_IS_WINDOWS: + ctrlc_workaround_for_windows(app) + else: + for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: + loop.add_signal_handler(_signal, app.stop) + + loop.run_until_complete(app._server_event("init", "after")) + pid = os.getpid() + try: + logger.info("Starting worker [%s]", pid) + loop.run_forever() + finally: + logger.info("Stopping worker [%s]", pid) + + # Run the on_stop function if provided + loop.run_until_complete(app._server_event("shutdown", "before")) + + # Wait for event loop to finish and all connections to drain + http_server.close() + loop.run_until_complete(http_server.wait_closed()) + + # Complete all tasks on the loop + signal.stopped = True + for connection in connections: + connection.close_if_idle() + + # Gracefully shutdown timeout. + # We should provide graceful_shutdown_timeout, + # instead of letting connection hangs forever. + # Let's roughly calcucate time. + graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT + start_shutdown: float = 0 + while connections and (start_shutdown < graceful): + loop.run_until_complete(asyncio.sleep(0.1)) + start_shutdown = start_shutdown + 0.1 + + # Force close non-idle connection after waiting for + # graceful_shutdown_timeout + for conn in connections: + if hasattr(conn, "websocket") and conn.websocket: + conn.websocket.fail_connection(code=1001) + else: + conn.abort() + loop.run_until_complete(app._server_event("shutdown", "after")) + + remove_unix_socket(unix) + + +def serve_single(server_settings): + main_start = server_settings.pop("main_start", None) + main_stop = server_settings.pop("main_stop", None) + + if not server_settings.get("run_async"): + # create new event_loop after fork + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + server_settings["loop"] = loop + + trigger_events(main_start, server_settings["loop"]) + serve(**server_settings) + trigger_events(main_stop, server_settings["loop"]) + + server_settings["loop"].close() + + +def serve_multiple(server_settings, workers): + """Start multiple server processes simultaneously. Stop on interrupt + and terminate signals, and drain connections when complete. + + :param server_settings: kw arguments to be passed to the serve function + :param workers: number of workers to launch + :param stop_event: if provided, is used as a stop signal + :return: + """ + server_settings["reuse_port"] = True + server_settings["run_multiple"] = True + + main_start = server_settings.pop("main_start", None) + main_stop = server_settings.pop("main_stop", None) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + trigger_events(main_start, loop) + + # Create a listening socket or use the one in settings + sock = server_settings.get("sock") + unix = server_settings["unix"] + backlog = server_settings["backlog"] + if unix: + sock = bind_unix_socket(unix, backlog=backlog) + server_settings["unix"] = unix + if sock is None: + sock = bind_socket( + server_settings["host"], server_settings["port"], backlog=backlog + ) + sock.set_inheritable(True) + server_settings["sock"] = sock + server_settings["host"] = None + server_settings["port"] = None + + processes = [] + + def sig_handler(signal, frame): + logger.info("Received signal %s. Shutting down.", Signals(signal).name) + for process in processes: + os.kill(process.pid, SIGTERM) + + signal_func(SIGINT, lambda s, f: sig_handler(s, f)) + signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) + mp = multiprocessing.get_context("fork") + + for _ in range(workers): + process = mp.Process(target=serve, kwargs=server_settings) + process.daemon = True + process.start() + processes.append(process) + + for process in processes: + process.join() + + # the above processes will block this until they're stopped + for process in processes: + process.terminate() + + trigger_events(main_stop, loop) + + sock.close() + loop.close() + remove_unix_socket(unix) + + +def _build_protocol_kwargs( + protocol: Type[asyncio.Protocol], config: Config +) -> Dict[str, Union[int, float]]: + if hasattr(protocol, "websocket_handshake"): + return { + "websocket_max_size": config.WEBSOCKET_MAX_SIZE, + "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, + "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, + } + return {} diff --git a/backend/sanic_server/sanic/server/socket.py b/backend/sanic_server/sanic/server/socket.py new file mode 100644 index 000000000..3d908306c --- /dev/null +++ b/backend/sanic_server/sanic/server/socket.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import os +import secrets +import socket +import stat + +from ipaddress import ip_address +from typing import Optional + + +def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket: + """Create TCP server socket. + :param host: IPv4, IPv6 or hostname may be specified + :param port: TCP port number + :param backlog: Maximum number of connections to queue + :return: socket.socket object + """ + try: # IP address: family must be specified for IPv6 at least + ip = ip_address(host) + host = str(ip) + sock = socket.socket( + socket.AF_INET6 if ip.version == 6 else socket.AF_INET + ) + except ValueError: # Hostname, may become AF_INET or AF_INET6 + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((host, port)) + sock.listen(backlog) + return sock + + +def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket: + """Create unix socket. + :param path: filesystem path + :param backlog: Maximum number of connections to queue + :return: socket.socket object + """ + """Open or atomically replace existing socket with zero downtime.""" + # Sanitise and pre-verify socket path + path = os.path.abspath(path) + folder = os.path.dirname(path) + if not os.path.isdir(folder): + raise FileNotFoundError(f"Socket folder does not exist: {folder}") + try: + if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): + raise FileExistsError(f"Existing file is not a socket: {path}") + except FileNotFoundError: + pass + # Create new socket with a random temporary name + tmp_path = f"{path}.{secrets.token_urlsafe()}" + sock = socket.socket(socket.AF_UNIX) + try: + # Critical section begins (filename races) + sock.bind(tmp_path) + try: + os.chmod(tmp_path, mode) + # Start listening before rename to avoid connection failures + sock.listen(backlog) + os.rename(tmp_path, path) + except: # noqa: E722 + try: + os.unlink(tmp_path) + finally: + raise + except: # noqa: E722 + try: + sock.close() + finally: + raise + return sock + + +def remove_unix_socket(path: Optional[str]) -> None: + """Remove dead unix socket during server exit.""" + if not path: + return + try: + if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): + # Is it actually dead (doesn't belong to a new server instance)? + with socket.socket(socket.AF_UNIX) as testsock: + try: + testsock.connect(path) + except ConnectionRefusedError: + os.unlink(path) + except FileNotFoundError: + pass diff --git a/backend/sanic_server/sanic/server/websockets/__init__.py b/backend/sanic_server/sanic/server/websockets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic/server/websockets/connection.py b/backend/sanic_server/sanic/server/websockets/connection.py new file mode 100644 index 000000000..c53a65a58 --- /dev/null +++ b/backend/sanic_server/sanic/server/websockets/connection.py @@ -0,0 +1,82 @@ +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + MutableMapping, + Optional, + Union, +) + + +ASIMessage = MutableMapping[str, Any] + + +class WebSocketConnection: + """ + This is for ASGI Connections. + It provides an interface similar to WebsocketProtocol, but + sends/receives over an ASGI connection. + """ + + # TODO + # - Implement ping/pong + + def __init__( + self, + send: Callable[[ASIMessage], Awaitable[None]], + receive: Callable[[], Awaitable[ASIMessage]], + subprotocols: Optional[List[str]] = None, + ) -> None: + self._send = send + self._receive = receive + self._subprotocols = subprotocols or [] + + async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: + message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} + + if isinstance(data, bytes): + message.update({"bytes": data}) + else: + message.update({"text": str(data)}) + + await self._send(message) + + async def recv(self, *args, **kwargs) -> Optional[str]: + message = await self._receive() + + if message["type"] == "websocket.receive": + return message["text"] + elif message["type"] == "websocket.disconnect": + pass + + return None + + receive = recv + + async def accept(self, subprotocols: Optional[List[str]] = None) -> None: + subprotocol = None + if subprotocols: + for subp in subprotocols: + if subp in self.subprotocols: + subprotocol = subp + break + + await self._send( + { + "type": "websocket.accept", + "subprotocol": subprotocol, + } + ) + + async def close(self, code: int = 1000, reason: str = "") -> None: + pass + + @property + def subprotocols(self): + return self._subprotocols + + @subprotocols.setter + def subprotocols(self, subprotocols: Optional[List[str]] = None): + self._subprotocols = subprotocols or [] diff --git a/backend/sanic_server/sanic/server/websockets/frame.py b/backend/sanic_server/sanic/server/websockets/frame.py new file mode 100644 index 000000000..8579d3bd2 --- /dev/null +++ b/backend/sanic_server/sanic/server/websockets/frame.py @@ -0,0 +1,291 @@ +import asyncio +import codecs + +from typing import TYPE_CHECKING, AsyncIterator, List, Optional + +from websockets.frames import Frame, Opcode +from websockets.typing import Data + +from ....sanic.exceptions import ServerError + + +if TYPE_CHECKING: + from .impl import WebsocketImplProtocol + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + + +class WebsocketFrameAssembler: + """ + Assemble a message from frames. + Code borrowed from aaugustin/websockets project: + https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py + """ + + __slots__ = ( + "protocol", + "read_mutex", + "write_mutex", + "message_complete", + "message_fetched", + "get_in_progress", + "decoder", + "completed_queue", + "chunks", + "chunks_queue", + "paused", + "get_id", + "put_id", + ) + if TYPE_CHECKING: + protocol: "WebsocketImplProtocol" + read_mutex: asyncio.Lock + write_mutex: asyncio.Lock + message_complete: asyncio.Event + message_fetched: asyncio.Event + completed_queue: asyncio.Queue + get_in_progress: bool + decoder: Optional[codecs.IncrementalDecoder] + # For streaming chunks rather than messages: + chunks: List[Data] + chunks_queue: Optional[asyncio.Queue[Optional[Data]]] + paused: bool + + def __init__(self, protocol) -> None: + + self.protocol = protocol + + self.read_mutex = asyncio.Lock() + self.write_mutex = asyncio.Lock() + + self.completed_queue = asyncio.Queue(maxsize=1) # type: asyncio.Queue[Data] + + # put() sets this event to tell get() that a message can be fetched. + self.message_complete = asyncio.Event() + # get() sets this event to let put() + self.message_fetched = asyncio.Event() + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # Decoder for text frames, None for binary frames. + self.decoder = None + + # Buffer data from frames belonging to the same message. + self.chunks = [] + + # When switching from "buffering" to "streaming", we use a thread-safe + # queue for transferring frames from the writing thread (library code) + # to the reading thread (user code). We're buffering when chunks_queue + # is None and streaming when it's a Queue. None is a sentinel + # value marking the end of the stream, superseding message_complete. + + # Stream data from frames belonging to the same message. + self.chunks_queue = None + + # Flag to indicate we've paused the protocol + self.paused = False + + async def get(self, timeout: Optional[float] = None) -> Optional[Data]: + """ + Read the next message. + :meth:`get` returns a single :class:`str` or :class:`bytes`. + If the :message was fragmented, :meth:`get` waits until the last frame + is received, then it reassembles the message. + If ``timeout`` is set and elapses before a complete message is + received, :meth:`get` returns ``None``. + """ + async with self.read_mutex: + if timeout is not None and timeout <= 0: + if not self.message_complete.is_set(): + return None + if self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is only here as a failsafe + raise ServerError( + "Called get() on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + self.get_in_progress = True + + # If the message_complete event isn't set yet, release the lock to + # allow put() to run and eventually set it. + # Locking with get_in_progress ensures only one task can get here. + if timeout is None: + completed = await self.message_complete.wait() + elif timeout <= 0: + completed = self.message_complete.is_set() + else: + try: + await asyncio.wait_for( + self.message_complete.wait(), timeout=timeout + ) + except asyncio.TimeoutError: + ... + finally: + completed = self.message_complete.is_set() + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + if not self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "State of Websocket frame assembler was modified while an " + "asynchronous get was in progress." + ) + self.get_in_progress = False + + # Waiting for a complete message timed out. + if not completed: + return None + if not self.message_complete.is_set(): + return None + + self.message_complete.clear() + + joiner: Data = b"" if self.decoder is None else "" + # mypy cannot figure out that chunks have the proper type. + message: Data = joiner.join(self.chunks) # type: ignore + if self.message_fetched.is_set(): + # This should be guarded against with the read_mutex, + # and get_in_progress check, this exception is here + # as a failsafe + raise ServerError( + "Websocket get() found a message when " "state was already fetched." + ) + self.message_fetched.set() + self.chunks = [] + # this should already be None, but set it here for safety + self.chunks_queue = None + return message + + async def get_iter(self) -> AsyncIterator[Data]: + """ + Stream the next message. + Iterating the return value of :meth:`get_iter` yields a :class:`str` + or :class:`bytes` for each frame in the message. + """ + async with self.read_mutex: + if self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is only here as a failsafe + raise ServerError( + "Called get_iter on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + self.get_in_progress = True + + chunks = self.chunks + self.chunks = [] + self.chunks_queue = asyncio.Queue() + + # Sending None in chunk_queue supersedes setting message_complete + # when switching to "streaming". If message is already complete + # when the switch happens, put() didn't send None, so we have to. + if self.message_complete.is_set(): + await self.chunks_queue.put(None) + + # Locking with get_in_progress ensures only one task can get here + for c in chunks: + yield c + while True: + chunk = await self.chunks_queue.get() + if chunk is None: + break + yield chunk + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + if not self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "State of Websocket frame assembler was modified while an " + "asynchronous get was in progress." + ) + self.get_in_progress = False + if not self.message_complete.is_set(): + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "Websocket frame assembler chunks queue ended before " + "message was complete." + ) + self.message_complete.clear() + if self.message_fetched.is_set(): + # This should be guarded against with the read_mutex, + # and get_in_progress check, this exception is + # here as a failsafe + raise ServerError( + "Websocket get_iter() found a message when state was " + "already fetched." + ) + + self.message_fetched.set() + # this should already be empty, but set it here for safety + self.chunks = [] + self.chunks_queue = None + + async def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + When ``frame`` is the final frame in a message, :meth:`put` waits + until the message is fetched, either by calling :meth:`get` or by + iterating the return value of :meth:`get_iter`. + :meth:`put` assumes that the stream of frames respects the protocol. + If it doesn't, the behavior is undefined. + """ + + async with self.write_mutex: + if frame.opcode is Opcode.TEXT: + self.decoder = UTF8Decoder(errors="strict") + elif frame.opcode is Opcode.BINARY: + self.decoder = None + elif frame.opcode is Opcode.CONT: + pass + else: + # Ignore control frames. + return + data: Data + if self.decoder is not None: + data = self.decoder.decode(frame.data, frame.fin) + else: + data = frame.data + if self.chunks_queue is None: + self.chunks.append(data) + else: + await self.chunks_queue.put(data) + + if not frame.fin: + return + if not self.get_in_progress: + # nobody is waiting for this frame, so try to pause subsequent + # frames at the protocol level + self.paused = self.protocol.pause_frames() + # Message is complete. Wait until it's fetched to return. + + if self.chunks_queue is not None: + await self.chunks_queue.put(None) + if self.message_complete.is_set(): + # This should be guarded against with the write_mutex + raise ServerError( + "Websocket put() got a new message when a message was " + "already in its chamber." + ) + self.message_complete.set() # Signal to get() it can serve the + if self.message_fetched.is_set(): + # This should be guarded against with the write_mutex + raise ServerError( + "Websocket put() got a new message when the previous " + "message was not yet fetched." + ) + + # Allow get() to run and eventually set the event. + await self.message_fetched.wait() + self.message_fetched.clear() + self.decoder = None diff --git a/backend/sanic_server/sanic/server/websockets/impl.py b/backend/sanic_server/sanic/server/websockets/impl.py new file mode 100644 index 000000000..c206f25c9 --- /dev/null +++ b/backend/sanic_server/sanic/server/websockets/impl.py @@ -0,0 +1,806 @@ +import asyncio +import random +import struct + +from typing import ( + AsyncIterator, + Dict, + Iterable, + Mapping, + Optional, + Sequence, + Union, +) + +from websockets.connection import CLOSED, CLOSING, OPEN, Event +from websockets.exceptions import ConnectionClosed, ConnectionClosedError +from websockets.frames import Frame, Opcode +from websockets.server import ServerConnection +from websockets.typing import Data + +from ....sanic.log import error_logger, logger +from ....sanic.server.protocols.base_protocol import SanicProtocol + +from ...exceptions import ServerError, WebsocketClosed +from .frame import WebsocketFrameAssembler + + +class WebsocketImplProtocol: + connection: ServerConnection + io_proto: Optional[SanicProtocol] + loop: Optional[asyncio.AbstractEventLoop] + max_queue: int + close_timeout: float + ping_interval: Optional[float] + ping_timeout: Optional[float] + assembler: WebsocketFrameAssembler + # Dict[bytes, asyncio.Future[None]] + pings: Dict[bytes, asyncio.Future] + conn_mutex: asyncio.Lock + recv_lock: asyncio.Lock + recv_cancel: Optional[asyncio.Future] + process_event_mutex: asyncio.Lock + can_pause: bool + # Optional[asyncio.Future[None]] + data_finished_fut: Optional[asyncio.Future] + # Optional[asyncio.Future[None]] + pause_frame_fut: Optional[asyncio.Future] + # Optional[asyncio.Future[None]] + connection_lost_waiter: Optional[asyncio.Future] + keepalive_ping_task: Optional[asyncio.Task] + auto_closer_task: Optional[asyncio.Task] + + def __init__( + self, + connection, + max_queue=None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: float = 10, + loop=None, + ): + self.connection = connection + self.io_proto = None + self.loop = None + self.max_queue = max_queue + self.close_timeout = close_timeout + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.assembler = WebsocketFrameAssembler(self) + self.pings = {} + self.conn_mutex = asyncio.Lock() + self.recv_lock = asyncio.Lock() + self.recv_cancel = None + self.process_event_mutex = asyncio.Lock() + self.data_finished_fut = None + self.can_pause = True + self.pause_frame_fut = None + self.keepalive_ping_task = None + self.auto_closer_task = None + self.connection_lost_waiter = None + + @property + def subprotocol(self): + return self.connection.subprotocol + + def pause_frames(self): + if not self.can_pause: + return False + if self.pause_frame_fut: + logger.debug("Websocket connection already paused.") + return False + if (not self.loop) or (not self.io_proto): + return False + if self.io_proto.transport: + self.io_proto.transport.pause_reading() + self.pause_frame_fut = self.loop.create_future() + logger.debug("Websocket connection paused.") + return True + + def resume_frames(self): + if not self.pause_frame_fut: + logger.debug("Websocket connection not paused.") + return False + if (not self.loop) or (not self.io_proto): + logger.debug( + "Websocket attempting to resume reading frames, " + "but connection is gone." + ) + return False + if self.io_proto.transport: + self.io_proto.transport.resume_reading() + self.pause_frame_fut.set_result(None) + self.pause_frame_fut = None + logger.debug("Websocket connection unpaused.") + return True + + async def connection_made( + self, + io_proto: SanicProtocol, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + if not loop: + try: + loop = getattr(io_proto, "loop") + except AttributeError: + loop = asyncio.get_event_loop() + if not loop: + # This catch is for mypy type checker + # to assert loop is not None here. + raise ServerError("Connection received with no asyncio loop.") + if self.auto_closer_task: + raise ServerError( + "Cannot call connection_made more than once " + "on a websocket connection." + ) + self.loop = loop + self.io_proto = io_proto + self.connection_lost_waiter = self.loop.create_future() + self.data_finished_fut = asyncio.shield(self.loop.create_future()) + + if self.ping_interval: + self.keepalive_ping_task = asyncio.create_task(self.keepalive_ping()) + self.auto_closer_task = asyncio.create_task(self.auto_close_connection()) + + async def wait_for_connection_lost(self, timeout=None) -> bool: + """ + Wait until the TCP connection is closed or ``timeout`` elapses. + If timeout is None, wait forever. + Recommend you should pass in self.close_timeout as timeout + + Return ``True`` if the connection is closed and ``False`` otherwise. + + """ + if not self.connection_lost_waiter: + return False + if self.connection_lost_waiter.done(): + return True + else: + try: + await asyncio.wait_for( + asyncio.shield(self.connection_lost_waiter), timeout + ) + return True + except asyncio.TimeoutError: + # Re-check self.connection_lost_waiter.done() synchronously + # because connection_lost() could run between the moment the + # timeout occurs and the moment this coroutine resumes running + return self.connection_lost_waiter.done() + + async def process_events(self, events: Sequence[Event]) -> None: + """ + Process a list of incoming events. + """ + # Wrapped in a mutex lock, to prevent other incoming events + # from processing at the same time + async with self.process_event_mutex: + for event in events: + if not isinstance(event, Frame): + # Event is not a frame. Ignore it. + continue + if event.opcode == Opcode.PONG: + await self.process_pong(event) + elif event.opcode == Opcode.CLOSE: + if self.recv_cancel: + self.recv_cancel.cancel() + else: + await self.assembler.put(event) + + async def process_pong(self, frame: Frame) -> None: + if frame.data in self.pings: + # Acknowledge all pings up to the one matching this pong. + ping_ids = [] + for ping_id, ping in self.pings.items(): + ping_ids.append(ping_id) + if not ping.done(): + ping.set_result(None) + if ping_id == frame.data: + break + else: # noqa + raise ServerError("ping_id is not in self.pings") + # Remove acknowledged pings from self.pings. + for ping_id in ping_ids: + del self.pings[ping_id] + + async def keepalive_ping(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + This coroutine exits when the connection terminates and one of the + following happens: + - :meth:`ping` raises :exc:`ConnectionClosed`, or + - :meth:`auto_close_connection` cancels :attr:`keepalive_ping_task`. + """ + if self.ping_interval is None: + return + + try: + while True: + await asyncio.sleep(self.ping_interval) + + # ping() raises CancelledError if the connection is closed, + # when auto_close_connection() cancels keepalive_ping_task. + + # ping() raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_pings(). + + ping_waiter = await self.ping() + + if self.ping_timeout is not None: + try: + await asyncio.wait_for(ping_waiter, self.ping_timeout) + except asyncio.TimeoutError: + error_logger.warning("Websocket timed out waiting for pong") + self.fail_connection(1011) + break + except asyncio.CancelledError: + # It is expected for this task to be cancelled during during + # normal operation, when the connection is closed. + logger.debug("Websocket keepalive ping task was cancelled.") + except (ConnectionClosed, WebsocketClosed): + logger.debug("Websocket closed. Keepalive ping task exiting.") + except Exception as e: + error_logger.warning( + "Unexpected exception in websocket keepalive ping task." + ) + logger.debug(str(e)) + + def _force_disconnect(self) -> bool: + """ + Internal methdod used by end_connection and fail_connection + only when the graceful auto-closer cannot be used + """ + if self.auto_closer_task and not self.auto_closer_task.done(): + self.auto_closer_task.cancel() + if self.data_finished_fut and not self.data_finished_fut.done(): + self.data_finished_fut.cancel() + self.data_finished_fut = None + if self.keepalive_ping_task and not self.keepalive_ping_task.done(): + self.keepalive_ping_task.cancel() + self.keepalive_ping_task = None + if self.loop and self.io_proto and self.io_proto.transport: + self.io_proto.transport.close() + self.loop.call_later(self.close_timeout, self.io_proto.transport.abort) + # We were never open, or already closed + return True + + def fail_connection(self, code: int = 1006, reason: str = "") -> bool: + """ + Fail the WebSocket Connection + This requires: + 1. Stopping all processing of incoming data, which means cancelling + pausing the underlying io protocol. The close code will be 1006 + unless a close frame was received earlier. + 2. Sending a close frame with an appropriate code if the opening + handshake succeeded and the other side is likely to process it. + 3. Closing the connection. :meth:`auto_close_connection` takes care + of this. + (The specification describes these steps in the opposite order.) + """ + if self.io_proto and self.io_proto.transport: + # Stop new data coming in + # In Python Version 3.7: pause_reading is idempotent + # ut can be called when the transport is already paused or closed + self.io_proto.transport.pause_reading() + + # Keeping fail_connection() synchronous guarantees it can't + # get stuck and simplifies the implementation of the callers. + # Not draining the write buffer is acceptable in this context. + + # clear the send buffer + _ = self.connection.data_to_send() + # If we're not already CLOSED or CLOSING, then send the close. + if self.connection.state is OPEN: + if code in (1000, 1001): + self.connection.send_close(code, reason) + else: + self.connection.fail(code, reason) + try: + data_to_send = self.connection.data_to_send() + while ( + len(data_to_send) and self.io_proto and self.io_proto.transport + ): + frame_data = data_to_send.pop(0) + self.io_proto.transport.write(frame_data) + except Exception: + # sending close frames may fail if the + # transport closes during this period + ... + if code == 1006: + # Special case: 1006 consider the transport already closed + self.connection.state = CLOSED + if self.data_finished_fut and not self.data_finished_fut.done(): + # We have a graceful auto-closer. Use it to close the connection. + self.data_finished_fut.cancel() + self.data_finished_fut = None + if (not self.auto_closer_task) or self.auto_closer_task.done(): + return self._force_disconnect() + return False + + def end_connection(self, code=1000, reason=""): + # This is like slightly more graceful form of fail_connection + # Use this instead of close() when you need an immediate + # close and cannot await websocket.close() handshake. + + if code == 1006 or not self.io_proto or not self.io_proto.transport: + return self.fail_connection(code, reason) + + # Stop new data coming in + # In Python Version 3.7: pause_reading is idempotent + # i.e. it can be called when the transport is already paused or closed. + self.io_proto.transport.pause_reading() + if self.connection.state == OPEN: + data_to_send = self.connection.data_to_send() + self.connection.send_close(code, reason) + data_to_send.extend(self.connection.data_to_send()) + try: + while len(data_to_send) and self.io_proto and self.io_proto.transport: + frame_data = data_to_send.pop(0) + self.io_proto.transport.write(frame_data) + except Exception: + # sending close frames may fail if the + # transport closes during this period + # But that doesn't matter at this point + ... + if self.data_finished_fut and not self.data_finished_fut.done(): + # We have the ability to signal the auto-closer + # try to trigger it to auto-close the connection + self.data_finished_fut.cancel() + self.data_finished_fut = None + if (not self.auto_closer_task) or self.auto_closer_task.done(): + # Auto-closer is not running, do force disconnect + return self._force_disconnect() + return False + + async def auto_close_connection(self) -> None: + """ + Close the WebSocket Connection + When the opening handshake succeeds, :meth:`connection_open` starts + this coroutine in a task. It waits for the data transfer phase to + complete then it closes the TCP connection cleanly. + When the opening handshake fails, :meth:`fail_connection` does the + same. There's no data transfer phase in that case. + """ + try: + # Wait for the data transfer phase to complete. + if self.data_finished_fut: + try: + await self.data_finished_fut + logger.debug("Websocket task finished. Closing the connection.") + except asyncio.CancelledError: + # Cancelled error is called when data phase is cancelled + # if an error occurred or the client closed the connection + logger.debug("Websocket handler cancelled. Closing the connection.") + + # Cancel the keepalive ping task. + if self.keepalive_ping_task: + self.keepalive_ping_task.cancel() + self.keepalive_ping_task = None + + # Half-close the TCP connection if possible (when there's no TLS). + if ( + self.io_proto + and self.io_proto.transport + and self.io_proto.transport.can_write_eof() + ): + logger.debug("Websocket half-closing TCP connection") + self.io_proto.transport.write_eof() + if self.connection_lost_waiter: + if await self.wait_for_connection_lost(timeout=0): + return + except asyncio.CancelledError: + ... + finally: + # The try/finally ensures that the transport never remains open, + # even if this coroutine is cancelled (for example). + if (not self.io_proto) or (not self.io_proto.transport): + # we were never open, or done. Can't do any finalization. + return + elif self.connection_lost_waiter and self.connection_lost_waiter.done(): + # connection confirmed closed already, proceed to abort waiter + ... + elif self.io_proto.transport.is_closing(): + # Connection is already closing (due to half-close above) + # proceed to abort waiter + ... + else: + self.io_proto.transport.close() + if not self.connection_lost_waiter: + # Our connection monitor task isn't running. + try: + await asyncio.sleep(self.close_timeout) + except asyncio.CancelledError: + ... + if self.io_proto and self.io_proto.transport: + self.io_proto.transport.abort() + else: + if await self.wait_for_connection_lost(timeout=self.close_timeout): + # Connection aborted before the timeout expired. + return + error_logger.warning( + "Timeout waiting for TCP connection to close. Aborting" + ) + if self.io_proto and self.io_proto.transport: + self.io_proto.transport.abort() + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending keepalive pings. + They'll never receive a pong once the connection is closed. + """ + if self.connection.state is not CLOSED: + raise ServerError( + "Webscoket about_pings should only be called " + "after connection state is changed to CLOSED" + ) + + for ping in self.pings.values(): + ping.set_exception(ConnectionClosedError(None, None)) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + ping.cancel() + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + This is a websocket-protocol level close. + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + :param code: WebSocket close code + :param reason: WebSocket close reason + """ + if code == 1006: + self.fail_connection(code, reason) + return + async with self.conn_mutex: + if self.connection.state is OPEN: + self.connection.send_close(code, reason) + data_to_send = self.connection.data_to_send() + await self.send_data(data_to_send) + + async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: + """ + Receive the next message. + Return a :class:`str` for a text frame and :class:`bytes` for a binary + frame. + When the end of the message stream is reached, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + If ``timeout`` is ``None``, block until a message is received. Else, + if no message is received within ``timeout`` seconds, return ``None``. + Set ``timeout`` to ``0`` to check if a message was already received. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises asyncio.CancelledError: if the websocket closes while waiting + :raises ServerError: if two tasks call :meth:`recv` or + :meth:`recv_streaming` concurrently + """ + + if self.recv_lock.locked(): + raise ServerError( + "cannot call recv while another task is " + "already waiting for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( + "Cannot receive from websocket interface after it is closed." + ) + try: + self.recv_cancel = asyncio.Future() + done, pending = await asyncio.wait( + (self.recv_cancel, self.assembler.get(timeout)), + return_when=asyncio.FIRST_COMPLETED, + ) + done_task = next(iter(done)) + if done_task is self.recv_cancel: + # recv was cancelled + for p in pending: + p.cancel() + raise asyncio.CancelledError() + else: + self.recv_cancel.cancel() + return done_task.result() + finally: + self.recv_cancel = None + self.recv_lock.release() + + async def recv_burst(self, max_recv=256) -> Sequence[Data]: + """ + Receive the messages which have arrived since last checking. + Return a :class:`list` containing :class:`str` for a text frame + and :class:`bytes` for a binary frame. + When the end of the message stream is reached, :meth:`recv_burst` + raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a + normal connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises ServerError: if two tasks call :meth:`recv_burst` or + :meth:`recv_streaming` concurrently + """ + + if self.recv_lock.locked(): + raise ServerError( + "cannot call recv_burst while another task is already waiting " + "for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( + "Cannot receive from websocket interface after it is closed." + ) + messages = [] + try: + # Prevent pausing the transport when we're + # receiving a burst of messages + self.can_pause = False + self.recv_cancel = asyncio.Future() + while True: + done, pending = await asyncio.wait( + (self.recv_cancel, self.assembler.get(timeout=0)), + return_when=asyncio.FIRST_COMPLETED, + ) + done_task = next(iter(done)) + if done_task is self.recv_cancel: + # recv_burst was cancelled + for p in pending: + p.cancel() + raise asyncio.CancelledError() + m = done_task.result() + if m is None: + # None left in the burst. This is good! + break + messages.append(m) + if len(messages) >= max_recv: + # Too much data in the pipe. Hit our burst limit. + break + # Allow an eventloop iteration for the + # next message to pass into the Assembler + await asyncio.sleep(0) + self.recv_cancel.cancel() + finally: + self.recv_cancel = None + self.can_pause = True + self.recv_lock.release() + return messages + + async def recv_streaming(self) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + Return an iterator of :class:`str` for a text frame and :class:`bytes` + for a binary frame. The iterator should be exhausted, or else the + connection will become unusable. + With the exception of the return value, :meth:`recv_streaming` behaves + like :meth:`recv`. + """ + if self.recv_lock.locked(): + raise ServerError( + "Cannot call recv_streaming while another task " + "is already waiting for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( + "Cannot receive from websocket interface after it is closed." + ) + try: + cancelled = False + self.recv_cancel = asyncio.Future() + self.can_pause = False + async for m in self.assembler.get_iter(): + if self.recv_cancel.done(): + cancelled = True + break + yield m + if cancelled: + raise asyncio.CancelledError() + finally: + self.can_pause = True + self.recv_cancel = None + self.recv_lock.release() + + async def send(self, message: Union[Data, Iterable[Data]]) -> None: + """ + Send a message. + A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a `Binary frame`_. + .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + :meth:`send` also accepts an iterable of strings, bytestrings, or + bytes-like objects. In that case the message is fragmented. Each item + is treated as a message fragment and sent in its own frame. All items + must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + :meth:`send` rejects dict-like objects because this is often an error. + If you wish to send the keys of a dict-like object as fragments, call + its :meth:`~dict.keys` method and pass the result to :meth:`send`. + :raises TypeError: for unsupported inputs + """ + async with self.conn_mutex: + + if self.connection.state in (CLOSED, CLOSING): + raise WebsocketClosed( + "Cannot write to websocket interface after it is closed." + ) + if (not self.data_finished_fut) or self.data_finished_fut.done(): + raise ServerError( + "Cannot write to websocket interface after it is finished." + ) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + self.connection.send_text(message.encode("utf-8")) + await self.send_data(self.connection.data_to_send()) + + elif isinstance(message, (bytes, bytearray, memoryview)): + self.connection.send_binary(message) + await self.send_data(self.connection.data_to_send()) + + elif isinstance(message, Mapping): + # Catch a common mistake -- passing a dict to send(). + raise TypeError("data is a dict-like object") + + elif isinstance(message, Iterable): + # Fragmented message -- regular iterator. + raise NotImplementedError( + "Fragmented websocket messages are not supported." + ) + else: + raise TypeError("Websocket data must be bytes, str.") + + async def ping(self, data: Optional[Data] = None) -> asyncio.Future: + """ + Send a ping. + Return an :class:`~asyncio.Future` that will be resolved when the + corresponding pong is received. You can ignore it if you don't intend + to wait. + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point:: + await pong_event = ws.ping() + await pong_event # only if you want to wait for the pong + By default, the ping contains four random bytes. This payload may be + overridden with the optional ``data`` argument which must be a string + (which will be encoded to UTF-8) or a bytes-like object. + """ + async with self.conn_mutex: + if self.connection.state in (CLOSED, CLOSING): + raise WebsocketClosed( + "Cannot send a ping when the websocket interface " "is closed." + ) + if (not self.io_proto) or (not self.io_proto.loop): + raise ServerError( + "Cannot send a ping when the websocket has no I/O " + "protocol attached." + ) + if data is not None: + if isinstance(data, str): + data = data.encode("utf-8") + elif isinstance(data, (bytearray, memoryview)): + data = bytes(data) + + # Protect against duplicates if a payload is explicitly set. + if data in self.pings: + raise ValueError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pings: + data = struct.pack("!I", random.getrandbits(32)) + + self.pings[data] = self.io_proto.loop.create_future() + + self.connection.send_ping(data) + await self.send_data(self.connection.data_to_send()) + + return asyncio.shield(self.pings[data]) + + async def pong(self, data: Data = b"") -> None: + """ + Send a pong. + An unsolicited pong may serve as a unidirectional heartbeat. + The payload may be set with the optional ``data`` argument which must + be a string (which will be encoded to UTF-8) or a bytes-like object. + """ + async with self.conn_mutex: + if self.connection.state in (CLOSED, CLOSING): + # Cannot send pong after transport is shutting down + return + if isinstance(data, str): + data = data.encode("utf-8") + elif isinstance(data, (bytearray, memoryview)): + data = bytes(data) + self.connection.send_pong(data) + await self.send_data(self.connection.data_to_send()) + + async def send_data(self, data_to_send): + for data in data_to_send: + if data: + await self.io_proto.send(data) + else: + # Send an EOF - We don't actually send it, + # just trigger to autoclose the connection + if ( + self.auto_closer_task + and not self.auto_closer_task.done() + and self.data_finished_fut + and not self.data_finished_fut.done() + ): + # Auto-close the connection + self.data_finished_fut.set_result(None) + else: + # This will fail the connection appropriately + SanicProtocol.close(self.io_proto, timeout=1.0) + + async def async_data_received(self, data_to_send, events_to_process): + if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: + # receiving data can generate data to send (eg, pong for a ping) + # send connection.data_to_send() + await self.send_data(data_to_send) + if len(events_to_process) > 0: + await self.process_events(events_to_process) + + def data_received(self, data): + self.connection.receive_data(data) + data_to_send = self.connection.data_to_send() + events_to_process = self.connection.events_received() + if len(data_to_send) > 0 or len(events_to_process) > 0: + asyncio.create_task( + self.async_data_received(data_to_send, events_to_process) + ) + + async def async_eof_received(self, data_to_send, events_to_process): + # receiving EOF can generate data to send + # send connection.data_to_send() + if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: + await self.send_data(data_to_send) + if len(events_to_process) > 0: + await self.process_events(events_to_process) + if self.recv_cancel: + self.recv_cancel.cancel() + if ( + self.auto_closer_task + and not self.auto_closer_task.done() + and self.data_finished_fut + and not self.data_finished_fut.done() + ): + # Auto-close the connection + self.data_finished_fut.set_result(None) + # Cancel the running handler if its waiting + else: + # This will fail the connection appropriately + SanicProtocol.close(self.io_proto, timeout=1.0) + + def eof_received(self) -> Optional[bool]: + self.connection.receive_eof() + data_to_send = self.connection.data_to_send() + events_to_process = self.connection.events_received() + asyncio.create_task(self.async_eof_received(data_to_send, events_to_process)) + return False + + def connection_lost(self, exc): + """ + The WebSocket Connection is Closed. + """ + if not self.connection.state == CLOSED: + # signal to the websocket connection handler + # we've lost the connection + self.connection.fail(code=1006) + self.connection.state = CLOSED + + self.abort_pings() + if self.connection_lost_waiter: + self.connection_lost_waiter.set_result(None) diff --git a/backend/sanic_server/sanic/signals.py b/backend/sanic_server/sanic/signals.py new file mode 100644 index 000000000..174bee398 --- /dev/null +++ b/backend/sanic_server/sanic/signals.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +import asyncio + +from inspect import isawaitable +from typing import Any, Dict, List, Optional, Tuple, Union + +from ..sanic_routing import BaseRouter, Route, RouteGroup # type: ignore +from ..sanic_routing.exceptions import NotFound # type: ignore +from ..sanic_routing.utils import path_to_parts # type: ignore + +from ..sanic.exceptions import InvalidSignal +from ..sanic.log import error_logger, logger +from ..sanic.models.handler_types import SignalHandler + + +RESERVED_NAMESPACES = { + "server": ( + # "server.main.start", + # "server.main.stop", + "server.init.before", + "server.init.after", + "server.shutdown.before", + "server.shutdown.after", + ), + "http": ( + "http.lifecycle.begin", + "http.lifecycle.complete", + "http.lifecycle.exception", + "http.lifecycle.handle", + "http.lifecycle.read_body", + "http.lifecycle.read_head", + "http.lifecycle.request", + "http.lifecycle.response", + "http.routing.after", + "http.routing.before", + "http.lifecycle.send", + "http.middleware.after", + "http.middleware.before", + ), +} + + +def _blank(): + ... + + +class Signal(Route): + ... + + +class SignalGroup(RouteGroup): + ... + + +class SignalRouter(BaseRouter): + def __init__(self) -> None: + super().__init__( + delimiter=".", + route_class=Signal, + group_class=SignalGroup, + stacking=True, + ) + self.ctx.loop = None + + def get( # type: ignore + self, + event: str, + condition: Optional[Dict[str, str]] = None, + ): + extra = condition or {} + try: + group, param_basket = self.find_route( + f".{event}", + self.DEFAULT_METHOD, + self, + {"__params__": {}, "__matches__": {}}, + extra=extra, + ) + except NotFound: + message = "Could not find signal %s" + terms: List[Union[str, Optional[Dict[str, str]]]] = [event] + if extra: + message += " with %s" + terms.append(extra) + raise NotFound(message % tuple(terms)) + + # Regex routes evaluate and can extract params directly. They are set + # on param_basket["__params__"] + params = param_basket["__params__"] + if not params: + # If param_basket["__params__"] does not exist, we might have + # param_basket["__matches__"], which are indexed based matches + # on path segments. They should already be cast types. + params = { + param.name: param_basket["__matches__"][idx] + for idx, param in group.params.items() + } + + return group, [route.handler for route in group], params + + async def _dispatch( + self, + event: str, + context: Optional[Dict[str, Any]] = None, + condition: Optional[Dict[str, str]] = None, + fail_not_found: bool = True, + reverse: bool = False, + ) -> Any: + try: + group, handlers, params = self.get(event, condition=condition) + except NotFound as e: + if fail_not_found: + raise e + else: + if self.ctx.app.debug: + error_logger.warning(str(e)) + return None + + events = [signal.ctx.event for signal in group] + for signal_event in events: + signal_event.set() + if context: + params.update(context) + + if not reverse: + handlers = handlers[::-1] + try: + for handler in handlers: + if condition is None or condition == handler.__requirements__: + maybe_coroutine = handler(**params) + if isawaitable(maybe_coroutine): + retval = await maybe_coroutine + if retval: + return retval + elif maybe_coroutine: + return maybe_coroutine + return None + finally: + for signal_event in events: + signal_event.clear() + + async def dispatch( + self, + event: str, + *, + context: Optional[Dict[str, Any]] = None, + condition: Optional[Dict[str, str]] = None, + fail_not_found: bool = True, + inline: bool = False, + reverse: bool = False, + ) -> Union[asyncio.Task, Any]: + dispatch = self._dispatch( + event, + context=context, + condition=condition, + fail_not_found=fail_not_found and inline, + reverse=reverse, + ) + logger.debug(f"Dispatching signal: {event}") + + if inline: + return await dispatch + + task = asyncio.get_running_loop().create_task(dispatch) + await asyncio.sleep(0) + return task + + def add( # type: ignore + self, + handler: SignalHandler, + event: str, + condition: Optional[Dict[str, Any]] = None, + ) -> Signal: + parts = self._build_event_parts(event) + if parts[2].startswith("<"): + name = ".".join([*parts[:-1], "*"]) + else: + name = event + + handler.__requirements__ = condition # type: ignore + + return super().add( + event, + handler, + requirements=condition, + name=name, + append=True, + ) # type: ignore + + def finalize(self, do_compile: bool = True, do_optimize: bool = False): + self.add(_blank, "sanic.__signal__.__init__") + + try: + self.ctx.loop = asyncio.get_running_loop() + except RuntimeError: + raise RuntimeError("Cannot finalize signals outside of event loop") + + for signal in self.routes: + signal.ctx.event = asyncio.Event() + + return super().finalize(do_compile=do_compile, do_optimize=do_optimize) + + def _build_event_parts(self, event: str) -> Tuple[str, str, str]: + parts = path_to_parts(event, self.delimiter) + if len(parts) != 3 or parts[0].startswith("<") or parts[1].startswith("<"): + raise InvalidSignal("Invalid signal event: %s" % event) + + if ( + parts[0] in RESERVED_NAMESPACES + and event not in RESERVED_NAMESPACES[parts[0]] + and not (parts[2].startswith("<") and parts[2].endswith(">")) + ): + raise InvalidSignal("Cannot declare reserved signal event: %s" % event) + return parts diff --git a/backend/sanic_server/sanic/simple.py b/backend/sanic_server/sanic/simple.py new file mode 100644 index 000000000..e12361233 --- /dev/null +++ b/backend/sanic_server/sanic/simple.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from ..sanic import Sanic +from ..sanic.exceptions import SanicException +from ..sanic.response import redirect + + +def create_simple_server(directory: Path): + if not directory.is_dir(): + raise SanicException( + "Cannot setup Sanic Simple Server without a path to a directory" + ) + + app = Sanic("SimpleServer") + app.static("/", directory, name="main") + + @app.get("/") + def index(_): + return redirect(app.url_for("main", filename="index.html")) + + return app diff --git a/backend/sanic_server/sanic/touchup/__init__.py b/backend/sanic_server/sanic/touchup/__init__.py new file mode 100644 index 000000000..6fe208abb --- /dev/null +++ b/backend/sanic_server/sanic/touchup/__init__.py @@ -0,0 +1,8 @@ +from .meta import TouchUpMeta +from .service import TouchUp + + +__all__ = ( + "TouchUp", + "TouchUpMeta", +) diff --git a/backend/sanic_server/sanic/touchup/meta.py b/backend/sanic_server/sanic/touchup/meta.py new file mode 100644 index 000000000..af811a456 --- /dev/null +++ b/backend/sanic_server/sanic/touchup/meta.py @@ -0,0 +1,22 @@ +from ...sanic.exceptions import SanicException + +from .service import TouchUp + + +class TouchUpMeta(type): + def __new__(cls, name, bases, attrs, **kwargs): + gen_class = super().__new__(cls, name, bases, attrs, **kwargs) + + methods = attrs.get("__touchup__") + attrs["__touched__"] = False + if methods: + + for method in methods: + if method not in attrs: + raise SanicException( + "Cannot perform touchup on non-existent method: " + f"{name}.{method}" + ) + TouchUp.register(gen_class, method) + + return gen_class diff --git a/backend/sanic_server/sanic/touchup/schemes/__init__.py b/backend/sanic_server/sanic/touchup/schemes/__init__.py new file mode 100644 index 000000000..87057a5fc --- /dev/null +++ b/backend/sanic_server/sanic/touchup/schemes/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseScheme +from .ode import OptionalDispatchEvent # noqa + + +__all__ = ("BaseScheme",) diff --git a/backend/sanic_server/sanic/touchup/schemes/base.py b/backend/sanic_server/sanic/touchup/schemes/base.py new file mode 100644 index 000000000..d16619b2f --- /dev/null +++ b/backend/sanic_server/sanic/touchup/schemes/base.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod +from typing import Set, Type + + +class BaseScheme(ABC): + ident: str + _registry: Set[Type] = set() + + def __init__(self, app) -> None: + self.app = app + + @abstractmethod + def run(self, method, module_globals) -> None: + ... + + def __init_subclass__(cls): + BaseScheme._registry.add(cls) + + def __call__(self, method, module_globals): + return self.run(method, module_globals) diff --git a/backend/sanic_server/sanic/touchup/schemes/ode.py b/backend/sanic_server/sanic/touchup/schemes/ode.py new file mode 100644 index 000000000..c8a0e3b57 --- /dev/null +++ b/backend/sanic_server/sanic/touchup/schemes/ode.py @@ -0,0 +1,65 @@ +from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse +from inspect import getsource +from textwrap import dedent +from typing import Any + +from ....sanic.log import logger + +from .base import BaseScheme + + +class OptionalDispatchEvent(BaseScheme): + ident = "ODE" + + def __init__(self, app) -> None: + super().__init__(app) + + self._registered_events = [signal.path for signal in app.signal_router.routes] + + def run(self, method, module_globals): + raw_source = getsource(method) + src = dedent(raw_source) + tree = parse(src) + node = RemoveDispatch(self._registered_events).visit(tree) + compiled_src = compile(node, method.__name__, "exec") + exec_locals: Dict[str, Any] = {} + exec(compiled_src, module_globals, exec_locals) # nosec + + return exec_locals[method.__name__] + + +class RemoveDispatch(NodeTransformer): + def __init__(self, registered_events) -> None: + self._registered_events = registered_events + + def visit_Expr(self, node: Expr) -> Any: + call = node.value + if isinstance(call, Await): + call = call.value + + func = getattr(call, "func", None) + args = getattr(call, "args", None) + if not func or not args: + return node + + if isinstance(func, Attribute) and func.attr == "dispatch": + event = args[0] + if hasattr(event, "s"): + event_name = getattr(event, "value", event.s) + if self._not_registered(event_name): + logger.debug(f"Disabling event: {event_name}") + return None + return node + + def _not_registered(self, event_name): + dynamic = [] + for event in self._registered_events: + if event.endswith(">"): + namespace_concern, _ = event.rsplit(".", 1) + dynamic.append(namespace_concern) + + namespace_concern, _ = event_name.rsplit(".", 1) + return ( + event_name not in self._registered_events + and namespace_concern not in dynamic + ) diff --git a/backend/sanic_server/sanic/touchup/service.py b/backend/sanic_server/sanic/touchup/service.py new file mode 100644 index 000000000..95792dca1 --- /dev/null +++ b/backend/sanic_server/sanic/touchup/service.py @@ -0,0 +1,33 @@ +from inspect import getmembers, getmodule +from typing import Set, Tuple, Type + +from .schemes import BaseScheme + + +class TouchUp: + _registry: Set[Tuple[Type, str]] = set() + + @classmethod + def run(cls, app): + for target, method_name in cls._registry: + method = getattr(target, method_name) + + if app.test_mode: + placeholder = f"_{method_name}" + if hasattr(target, placeholder): + method = getattr(target, placeholder) + else: + setattr(target, placeholder, method) + + module = getmodule(target) + module_globals = dict(getmembers(module)) + + for scheme in BaseScheme._registry: + modified = scheme(app)(method, module_globals) + setattr(target, method_name, modified) + + target.__touched__ = True + + @classmethod + def register(cls, target, method_name): + cls._registry.add((target, method_name)) diff --git a/backend/sanic_server/sanic/utils.py b/backend/sanic_server/sanic/utils.py new file mode 100644 index 000000000..2260da246 --- /dev/null +++ b/backend/sanic_server/sanic/utils.py @@ -0,0 +1,126 @@ +import types + +from importlib.util import module_from_spec, spec_from_file_location +from os import environ as os_environ +from pathlib import Path +from re import findall as re_findall +from typing import Union + +from ..sanic.exceptions import LoadFileException, PyFileError +from ..sanic.helpers import import_string + + +def str_to_bool(val: str) -> bool: + """Takes string and tries to turn it into bool as human would do. + + If val is in case insensitive ( + "y", "yes", "yep", "yup", "t", + "true", "on", "enable", "enabled", "1" + ) returns True. + If val is in case insensitive ( + "n", "no", "f", "false", "off", "disable", "disabled", "0" + ) returns False. + Else Raise ValueError.""" + + val = val.lower() + if val in { + "y", + "yes", + "yep", + "yup", + "t", + "true", + "on", + "enable", + "enabled", + "1", + }: + return True + elif val in {"n", "no", "f", "false", "off", "disable", "disabled", "0"}: + return False + else: + raise ValueError(f"Invalid truth value {val}") + + +def load_module_from_file_location( + location: Union[bytes, str, Path], encoding: str = "utf8", *args, **kwargs +): # noqa + """Returns loaded module provided as a file path. + + :param args: + Coresponds to importlib.util.spec_from_file_location location + parameters,but with this differences: + - It has to be of a string or bytes type. + - You can also use here environment variables + in format ${some_env_var}. + Mark that $some_env_var will not be resolved as environment variable. + :encoding: + If location parameter is of a bytes type, then use this encoding + to decode it into string. + :param args: + Coresponds to the rest of importlib.util.spec_from_file_location + parameters. + :param kwargs: + Coresponds to the rest of importlib.util.spec_from_file_location + parameters. + + For example You can: + + some_module = load_module_from_file_location( + "some_module_name", + "/some/path/${some_env_var}" + ) + """ + if isinstance(location, bytes): + location = location.decode(encoding) + + if isinstance(location, Path) or "/" in location or "$" in location: + + if not isinstance(location, Path): + # A) Check if location contains any environment variables + # in format ${some_env_var}. + env_vars_in_location = set(re_findall(r"\${(.+?)}", location)) + + # B) Check these variables exists in environment. + not_defined_env_vars = env_vars_in_location.difference(os_environ.keys()) + if not_defined_env_vars: + raise LoadFileException( + "The following environment variables are not set: " + f"{', '.join(not_defined_env_vars)}" + ) + + # C) Substitute them in location. + for env_var in env_vars_in_location: + location = location.replace("${" + env_var + "}", os_environ[env_var]) + + location = str(location) + if ".py" in location: + name = location.split("/")[-1].split(".")[ + 0 + ] # get just the file name without path and .py extension + _mod_spec = spec_from_file_location(name, location, *args, **kwargs) + assert _mod_spec is not None # type assertion for mypy + module = module_from_spec(_mod_spec) + _mod_spec.loader.exec_module(module) # type: ignore + + else: + module = types.ModuleType("config") + module.__file__ = str(location) + try: + with open(location) as config_file: + exec( # nosec + compile(config_file.read(), location, "exec"), + module.__dict__, + ) + except IOError as e: + e.strerror = "Unable to load configuration file (e.strerror)" + raise + except Exception as e: + raise PyFileError(location) from e + + return module + else: + try: + return import_string(location) + except ValueError: + raise IOError("Unable to load configuration %s" % str(location)) diff --git a/backend/sanic_server/sanic/views.py b/backend/sanic_server/sanic/views.py new file mode 100644 index 000000000..e7162977c --- /dev/null +++ b/backend/sanic_server/sanic/views.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + List, + Optional, + Union, +) +from warnings import warn + +from ..sanic.constants import HTTP_METHODS +from ..sanic.exceptions import InvalidUsage +from ..sanic.models.handler_types import RouteHandler + + +if TYPE_CHECKING: + from ..sanic import Sanic + from ..sanic.blueprints import Blueprint + + +class HTTPMethodView: + """Simple class based implementation of view for the sanic. + You should implement methods (get, post, put, patch, delete) for the class + to every HTTP method you want to support. + + For example: + + .. code-block:: python + + class DummyView(HTTPMethodView): + def get(self, request, *args, **kwargs): + return text('I am get method') + def put(self, request, *args, **kwargs): + return text('I am put method') + + If someone tries to use a non-implemented method, there will be a + 405 response. + + If you need any url params just mention them in method definition: + + .. code-block:: python + + class DummyView(HTTPMethodView): + def get(self, request, my_param_here, *args, **kwargs): + return text('I am get method with %s' % my_param_here) + + To add the view into the routing you could use + + 1) ``app.add_route(DummyView.as_view(), '/')``, OR + 2) ``app.route('/')(DummyView.as_view())`` + + To add any decorator you could set it into decorators variable + """ + + decorators: List[Callable[[Callable[..., Any]], Callable[..., Any]]] = [] + + def __init_subclass__( + cls, + attach: Optional[Union[Sanic, Blueprint]] = None, + uri: str = "", + methods: Iterable[str] = frozenset({"GET"}), + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + version: Optional[int] = None, + name: Optional[str] = None, + stream: bool = False, + version_prefix: str = "/v", + ) -> None: + if attach: + cls.attach( + attach, + uri=uri, + methods=methods, + host=host, + strict_slashes=strict_slashes, + version=version, + name=name, + stream=stream, + version_prefix=version_prefix, + ) + + def dispatch_request(self, request, *args, **kwargs): + handler = getattr(self, request.method.lower(), None) + return handler(request, *args, **kwargs) + + @classmethod + def as_view(cls, *class_args: Any, **class_kwargs: Any) -> RouteHandler: + """Return view function for use with the routing system, that + dispatches request to appropriate handler method. + """ + + def view(*args, **kwargs): + self = view.view_class(*class_args, **class_kwargs) + return self.dispatch_request(*args, **kwargs) + + if cls.decorators: + view.__module__ = cls.__module__ + for decorator in cls.decorators: + view = decorator(view) + + view.view_class = cls # type: ignore + view.__doc__ = cls.__doc__ + view.__module__ = cls.__module__ + view.__name__ = cls.__name__ + return view + + @classmethod + def attach( + cls, + to: Union[Sanic, Blueprint], + uri: str, + methods: Iterable[str] = frozenset({"GET"}), + host: Optional[str] = None, + strict_slashes: Optional[bool] = None, + version: Optional[int] = None, + name: Optional[str] = None, + stream: bool = False, + version_prefix: str = "/v", + ) -> None: + to.add_route( + cls.as_view(), + uri=uri, + methods=methods, + host=host, + strict_slashes=strict_slashes, + version=version, + name=name, + stream=stream, + version_prefix=version_prefix, + ) + + +def stream(func): + func.is_stream = True + return func + + +class CompositionView: + """Simple method-function mapped view for the sanic. + You can add handler functions to methods (get, post, put, patch, delete) + for every HTTP method you want to support. + + For example: + + .. code-block:: python + + view = CompositionView() + view.add(['GET'], lambda request: text('I am get method')) + view.add(['POST', 'PUT'], lambda request: text('I am post/put method')) + + If someone tries to use a non-implemented method, there will be a + 405 response. + """ + + def __init__(self): + self.handlers = {} + self.name = self.__class__.__name__ + warn( + "CompositionView has been deprecated and will be removed in " + "v21.12. Please update your view to HTTPMethodView.", + DeprecationWarning, + ) + + def __name__(self): + return self.name + + def add(self, methods, handler, stream=False): + if stream: + handler.is_stream = stream + for method in methods: + if method not in HTTP_METHODS: + raise InvalidUsage(f"{method} is not a valid HTTP method.") + + if method in self.handlers: + raise InvalidUsage(f"Method {method} is already registered.") + self.handlers[method] = handler + + def __call__(self, request, *args, **kwargs): + handler = self.handlers[request.method.upper()] + return handler(request, *args, **kwargs) diff --git a/backend/sanic_server/sanic/worker.py b/backend/sanic_server/sanic/worker.py new file mode 100644 index 000000000..c89b31df0 --- /dev/null +++ b/backend/sanic_server/sanic/worker.py @@ -0,0 +1,245 @@ +import asyncio +import logging +import os +import signal +import sys +import traceback + +from gunicorn.workers import base # type: ignore + +from ..sanic.log import logger +from ..sanic.server import HttpProtocol, Signal, serve +from ..sanic.server.protocols.websocket_protocol import WebSocketProtocol + +try: + import ssl # type: ignore +except ImportError: + ssl = None # type: ignore + +try: + import uvloop # type: ignore + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +except ImportError: + pass + + +class GunicornWorker(base.Worker): + + http_protocol = HttpProtocol + websocket_protocol = WebSocketProtocol + + def __init__(self, *args, **kw): # pragma: no cover + super().__init__(*args, **kw) + cfg = self.cfg + if cfg.is_ssl: + self.ssl_context = self._create_ssl_context(cfg) + else: + self.ssl_context = None + self.servers = {} + self.connections = set() + self.exit_code = 0 + self.signal = Signal() + + def init_process(self): + # create new event_loop after fork + asyncio.get_event_loop().close() + + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + super().init_process() + + def run(self): + is_debug = self.log.loglevel == logging.DEBUG + protocol = ( + self.websocket_protocol + if self.app.callable.websocket_enabled + else self.http_protocol + ) + + self._server_settings = self.app.callable._helper( + loop=self.loop, + debug=is_debug, + protocol=protocol, + ssl=self.ssl_context, + run_async=True, + ) + self._server_settings["signal"] = self.signal + self._server_settings.pop("sock") + self._await(self.app.callable._startup()) + self._await( + self.app.callable._server_event("init", "before", loop=self.loop) + ) + + main_start = self._server_settings.pop("main_start", None) + main_stop = self._server_settings.pop("main_stop", None) + + if main_start or main_stop: # noqa + logger.warning( + "Listener events for the main process are not available " + "with GunicornWorker" + ) + + try: + self._await(self._run()) + self.app.callable.is_running = True + self._await( + self.app.callable._server_event( + "init", "after", loop=self.loop + ) + ) + self.loop.run_until_complete(self._check_alive()) + self._await( + self.app.callable._server_event( + "shutdown", "before", loop=self.loop + ) + ) + self.loop.run_until_complete(self.close()) + except BaseException: + traceback.print_exc() + finally: + try: + self._await( + self.app.callable._server_event( + "shutdown", "after", loop=self.loop + ) + ) + except BaseException: + traceback.print_exc() + finally: + self.loop.close() + + sys.exit(self.exit_code) + + async def close(self): + if self.servers: + # stop accepting connections + self.log.info( + "Stopping server: %s, connections: %s", + self.pid, + len(self.connections), + ) + for server in self.servers: + server.close() + await server.wait_closed() + self.servers.clear() + + # prepare connections for closing + self.signal.stopped = True + for conn in self.connections: + conn.close_if_idle() + + # gracefully shutdown timeout + start_shutdown = 0 + graceful_shutdown_timeout = self.cfg.graceful_timeout + while self.connections and ( + start_shutdown < graceful_shutdown_timeout + ): + await asyncio.sleep(0.1) + start_shutdown = start_shutdown + 0.1 + + # Force close non-idle connection after waiting for + # graceful_shutdown_timeout + for conn in self.connections: + if hasattr(conn, "websocket") and conn.websocket: + conn.websocket.fail_connection(code=1001) + else: + conn.abort() + + async def _run(self): + for sock in self.sockets: + state = dict(requests_count=0) + self._server_settings["host"] = None + self._server_settings["port"] = None + server = await serve( + sock=sock, + connections=self.connections, + state=state, + **self._server_settings + ) + self.servers[server] = state + + async def _check_alive(self): + # If our parent changed then we shut down. + pid = os.getpid() + try: + while self.alive: + self.notify() + + req_count = sum( + self.servers[srv]["requests_count"] for srv in self.servers + ) + if self.max_requests and req_count > self.max_requests: + self.alive = False + self.log.info( + "Max requests exceeded, shutting down: %s", self + ) + elif pid == os.getpid() and self.ppid != os.getppid(): + self.alive = False + self.log.info("Parent changed, shutting down: %s", self) + else: + await asyncio.sleep(1.0, loop=self.loop) + except (Exception, BaseException, GeneratorExit, KeyboardInterrupt): + pass + + @staticmethod + def _create_ssl_context(cfg): + """Creates SSLContext instance for usage in asyncio.create_server. + See ssl.SSLSocket.__init__ for more details. + """ + ctx = ssl.SSLContext(cfg.ssl_version) + ctx.load_cert_chain(cfg.certfile, cfg.keyfile) + ctx.verify_mode = cfg.cert_reqs + if cfg.ca_certs: + ctx.load_verify_locations(cfg.ca_certs) + if cfg.ciphers: + ctx.set_ciphers(cfg.ciphers) + return ctx + + def init_signals(self): + # Set up signals through the event loop API. + + self.loop.add_signal_handler( + signal.SIGQUIT, self.handle_quit, signal.SIGQUIT, None + ) + + self.loop.add_signal_handler( + signal.SIGTERM, self.handle_exit, signal.SIGTERM, None + ) + + self.loop.add_signal_handler( + signal.SIGINT, self.handle_quit, signal.SIGINT, None + ) + + self.loop.add_signal_handler( + signal.SIGWINCH, self.handle_winch, signal.SIGWINCH, None + ) + + self.loop.add_signal_handler( + signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1, None + ) + + self.loop.add_signal_handler( + signal.SIGABRT, self.handle_abort, signal.SIGABRT, None + ) + + # Don't let SIGTERM and SIGUSR1 disturb active requests + # by interrupting system calls + signal.siginterrupt(signal.SIGTERM, False) + signal.siginterrupt(signal.SIGUSR1, False) + + def handle_quit(self, sig, frame): + self.alive = False + self.app.callable.is_running = False + self.cfg.worker_int(self) + + def handle_abort(self, sig, frame): + self.alive = False + self.exit_code = 1 + self.cfg.worker_abort(self) + sys.exit(1) + + def _await(self, coro): + fut = asyncio.ensure_future(coro, loop=self.loop) + self.loop.run_until_complete(fut) diff --git a/backend/sanic_server/sanic_cors/__init__.py b/backend/sanic_server/sanic_cors/__init__.py new file mode 100644 index 000000000..872fc8534 --- /dev/null +++ b/backend/sanic_server/sanic_cors/__init__.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +""" + sanic_cors + ~~~~ + Sanic-CORS is a simple extension to Sanic allowing you to support cross + origin resource sharing (CORS) using a simple decorator. + + :copyright: (c) 2021 by Ashley Sommer (based on flask-cors by Cory Dolphin). + :license: MIT, see LICENSE for more details. +""" +from .decorator import cross_origin +from .extension import CORS +from .version import __version__ + +__all__ = ['CORS', 'cross_origin'] + +# Set default logging handler to avoid "No handler found" warnings. +import logging +from logging import NullHandler + +# Set initial level to WARN. Users must manually enable logging for +# sanic_cors to see our logging. +rootlogger = logging.getLogger(__name__) +rootlogger.addHandler(NullHandler()) + +if rootlogger.level == logging.NOTSET: + rootlogger.setLevel(logging.WARN) diff --git a/backend/sanic_server/sanic_cors/core.py b/backend/sanic_server/sanic_cors/core.py new file mode 100644 index 000000000..fe72c9015 --- /dev/null +++ b/backend/sanic_server/sanic_cors/core.py @@ -0,0 +1,417 @@ +# -*- coding: utf-8 -*- +""" + core + ~~~~ + Core functionality shared between the extension and the decorator. + + :copyright: (c) 2021 by Ashley Sommer (based on flask-cors by Cory Dolphin). + :license: MIT, see LICENSE for more details. +""" +import collections +import logging +import re +from datetime import timedelta + +try: + # Sanic compat Header from Sanic v19.9.0 and above + from ..sanic.compat import Header as CIMultiDict +except ImportError: + try: + # Sanic server CIMultiDict from Sanic v0.8.0 and above + from ..sanic.server import CIMultiDict + except ImportError: + raise RuntimeError("Your version of sanic does not support " "CIMultiDict") + +LOG = logging.getLogger(__name__) + +# Response Headers +ACL_ORIGIN = "Access-Control-Allow-Origin" +ACL_METHODS = "Access-Control-Allow-Methods" +ACL_ALLOW_HEADERS = "Access-Control-Allow-Headers" +ACL_EXPOSE_HEADERS = "Access-Control-Expose-Headers" +ACL_CREDENTIALS = "Access-Control-Allow-Credentials" +ACL_MAX_AGE = "Access-Control-Max-Age" + +# Request Header +ACL_REQUEST_METHOD = "Access-Control-Request-Method" +ACL_REQUEST_HEADERS = "Access-Control-Request-Headers" + +ALL_METHODS = ["GET", "HEAD", "POST", "OPTIONS", "PUT", "PATCH", "DELETE"] +CONFIG_OPTIONS = [ + "CORS_ORIGINS", + "CORS_METHODS", + "CORS_ALLOW_HEADERS", + "CORS_EXPOSE_HEADERS", + "CORS_SUPPORTS_CREDENTIALS", + "CORS_MAX_AGE", + "CORS_SEND_WILDCARD", + "CORS_AUTOMATIC_OPTIONS", + "CORS_VARY_HEADER", + "CORS_RESOURCES", + "CORS_INTERCEPT_EXCEPTIONS", + "CORS_ALWAYS_SEND", +] +# Attribute added to request object by decorator to indicate that CORS +# was evaluated, in case the decorator and extension are both applied +# to a view. +# TODO: Refactor these two flags down into one flag. +SANIC_CORS_EVALUATED = "_sanic_cors_e" +SANIC_CORS_SKIP_RESPONSE_MIDDLEWARE = "_sanic_cors_srm" + +# Strange, but this gets the type of a compiled regex, which is otherwise not +# exposed in a public API. +RegexObject = type(re.compile("")) +DEFAULT_OPTIONS = dict( + origins="*", + methods=ALL_METHODS, + allow_headers="*", + expose_headers=None, + supports_credentials=False, + max_age=None, + send_wildcard=False, + automatic_options=True, + vary_header=True, + resources=r"/*", + intercept_exceptions=True, + always_send=True, +) + + +def parse_resources(resources): + if isinstance(resources, dict): + # To make the API more consistent with the decorator, allow a + # resource of '*', which is not actually a valid regexp. + resources = [(re_fix(k), v) for k, v in resources.items()] + + # Sort by regex length to provide consistency of matching and + # to provide a proxy for specificity of match. E.G. longer + # regular expressions are tried first. + def pattern_length(pair): + maybe_regex, _ = pair + return len(get_regexp_pattern(maybe_regex)) + + return sorted(resources, key=pattern_length, reverse=True) + + elif isinstance(resources, str): + return [(re_fix(resources), {})] + + elif isinstance(resources, collections.abc.Iterable): + return [(re_fix(r), {}) for r in resources] + + # Type of compiled regex is not part of the public API. Test for this + # at runtime. + elif isinstance(resources, RegexObject): + return [(re_fix(resources), {})] + + else: + raise ValueError("Unexpected value for resources argument.") + + +def get_regexp_pattern(regexp): + """ + Helper that returns regexp pattern from given value. + + :param regexp: regular expression to stringify + :type regexp: _sre.SRE_Pattern or str + :returns: string representation of given regexp pattern + :rtype: str + """ + try: + return regexp.pattern + except AttributeError: + return str(regexp) + + +def get_cors_origins(options, request_origin): + origins = options.get("origins") + wildcard = r".*" in origins + + # If the Origin header is not present terminate this set of steps. + # The request is outside the scope of this specification.-- W3Spec + if request_origin: + LOG.debug("CORS request received with 'Origin' %s", request_origin) + + # If the allowed origins is an asterisk or 'wildcard', always match + if wildcard and options.get("send_wildcard"): + LOG.debug("Allowed origins are set to '*'. Sending wildcard CORS header.") + return ["*"] + # If the value of the Origin header is a case-sensitive match + # for any of the values in list of origins + elif try_match_any(request_origin, origins): + LOG.debug( + "The request's Origin header matches. Sending CORS headers.", + ) + # Add a single Access-Control-Allow-Origin header, with either + # the value of the Origin header or the string "*" as value. + # -- W3Spec + return [request_origin] + else: + LOG.debug( + "The request's Origin header does not match any of allowed origins." + ) + return None + + elif options.get("always_send"): + if wildcard: + # If wildcard is in the origins, even if 'send_wildcard' is False, + # simply send the wildcard. It is the most-likely to be correct + # thing to do (the only other option is to return nothing, which) + # pretty is probably not whawt you want if you specify origins as + # '*' + return ["*"] + else: + # Return all origins that are not regexes. + return sorted([o for o in origins if not probably_regex(o)]) + + # Terminate these steps, return the original request untouched. + else: + LOG.debug( + "The request did not contain an 'Origin' header. " + "This means the browser or client did not request CORS, ensure the Origin Header is set." + ) + return None + + +def get_allow_headers(options, acl_request_headers): + if acl_request_headers: + request_headers = [h.strip() for h in acl_request_headers.split(",")] + + # any header that matches in the allow_headers + matching_headers = filter( + lambda h: try_match_any(h, options.get("allow_headers")), request_headers + ) + + return ", ".join(sorted(matching_headers)) + + return None + + +def get_cors_headers(options, request_headers, request_method): + origins_to_set = get_cors_origins(options, request_headers.get("Origin")) + headers = CIMultiDict() + + if not origins_to_set: # CORS is not enabled for this route + return headers + + for origin in origins_to_set: + # TODO, with CIDict, with will only allow one origin + # With CIMultiDict it should work with multiple + headers[ACL_ORIGIN] = origin + + headers[ACL_EXPOSE_HEADERS] = options.get("expose_headers") + + if options.get("supports_credentials"): + headers[ACL_CREDENTIALS] = "true" # case sensative + + # This is a preflight request + # http://www.w3.org/TR/cors/#resource-preflight-requests + if request_method == "OPTIONS": + acl_request_method = request_headers.get(ACL_REQUEST_METHOD, "").upper() + + # If there is no Access-Control-Request-Method header or if parsing + # failed, do not set any additional headers + if acl_request_method and acl_request_method in options.get("methods"): + + # If method is not a case-sensitive match for any of the values in + # list of methods do not set any additional headers and terminate + # this set of steps. + headers[ACL_ALLOW_HEADERS] = get_allow_headers( + options, request_headers.get(ACL_REQUEST_HEADERS) + ) + headers[ACL_MAX_AGE] = str( + options.get("max_age") + ) # sanic cannot handle integers in header values. + headers[ACL_METHODS] = options.get("methods") + else: + LOG.info( + "The request's Access-Control-Request-Method header does not match allowed methods. " + "CORS headers will not be applied." + ) + + # http://www.w3.org/TR/cors/#resource-implementation + if options.get("vary_header"): + # Only set header if the origin returned will vary dynamically, + # i.e. if we are not returning an asterisk, and there are multiple + # origins that can be matched. + if headers[ACL_ORIGIN] == "*": + pass + elif ( + len(options.get("origins")) > 1 + or len(origins_to_set) > 1 + or any(map(probably_regex, options.get("origins"))) + ): + headers["Vary"] = "Origin" + + return CIMultiDict((k, v) for k, v in headers.items() if v) + + +def set_cors_headers(req, resp, req_context, options): + """ + Performs the actual evaluation of Sanic-CORS options and actually + modifies the response object. + + This function is used in the decorator, the CORS exception wrapper, + and the after_request callback + :param sanic.request.Request req: + + """ + # If CORS has already been evaluated via the decorator, skip + if req_context is not None: + evaluated = getattr(req_context, SANIC_CORS_EVALUATED, False) + if evaluated: + LOG.debug("CORS have been already evaluated, skipping") + return resp + + # `resp` can be None or [] in the case of using Websockets + # however this case should have been handled in the `extension` and `decorator` methods + # before getting here. This is a final failsafe check to prevent crashing + if not resp: + return None + + if resp.headers is None: + resp.headers = CIMultiDict() + + headers_to_set = get_cors_headers(options, req.headers, req.method) + + LOG.debug("Settings CORS headers: %s", str(headers_to_set)) + + for k, v in headers_to_set.items(): + try: + resp.headers.add(k, v) + except Exception as e2: + resp.headers[k] = v + return resp + + +def probably_regex(maybe_regex): + if isinstance(maybe_regex, RegexObject): + return True + else: + common_regex_chars = ["*", "\\", "]", "?"] + # Use common characters used in regular expressions as a proxy + # for if this string is in fact a regex. + return any((c in maybe_regex for c in common_regex_chars)) + + +def re_fix(reg): + """ + Replace the invalid regex r'*' with the valid, wildcard regex r'/.*' to + enable the CORS app extension to have a more user friendly api. + """ + return r".*" if reg == r"*" else reg + + +def try_match_any(inst, patterns): + return any(try_match(inst, pattern) for pattern in patterns) + + +def try_match(request_origin, maybe_regex): + """Safely attempts to match a pattern or string to a request origin.""" + if isinstance(maybe_regex, RegexObject): + return re.match(maybe_regex, request_origin) + elif probably_regex(maybe_regex): + return re.match(maybe_regex, request_origin, flags=re.IGNORECASE) + else: + try: + return request_origin.lower() == maybe_regex.lower() + except AttributeError: + return request_origin == maybe_regex + + +def get_cors_options(appInstance, *dicts): + """ + Compute CORS options for an application by combining the DEFAULT_OPTIONS, + the app's configuration-specified options and any dictionaries passed. The + last specified option wins. + """ + options = DEFAULT_OPTIONS.copy() + options.update(get_app_kwarg_dict(appInstance)) + if dicts: + for d in dicts: + options.update(d) + + return serialize_options(options) + + +def get_app_kwarg_dict(appInstance): + """Returns the dictionary of CORS specific app configurations.""" + # In order to support blueprints which do not have a config attribute + app_config = getattr(appInstance, "config", {}) + return dict( + (k.lower().replace("cors_", ""), app_config.get(k)) + for k in CONFIG_OPTIONS + if app_config.get(k) is not None + ) + + +def flexible_str(obj): + """ + A more flexible str function which intelligently handles stringifying + strings, lists and other iterables. The results are lexographically sorted + to ensure generated responses are consistent when iterables such as Set + are used. + """ + if obj is None: + return None + elif not isinstance(obj, str) and isinstance(obj, collections.abc.Iterable): + return ", ".join(str(item) for item in sorted(obj)) + else: + return str(obj) + + +def serialize_option(options_dict, key, upper=False): + if key in options_dict: + value = flexible_str(options_dict[key]) + options_dict[key] = value.upper() if upper else value + + +def ensure_iterable(inst): + """ + Wraps scalars or string types as a list, or returns the iterable instance. + """ + if isinstance(inst, str): + return [inst] + elif not isinstance(inst, collections.abc.Iterable): + return [inst] + else: + return inst + + +def sanitize_regex_param(param): + return [re_fix(x) for x in ensure_iterable(param)] + + +def serialize_options(opts): + """ + A helper method to serialize and processes the options dictionary. + """ + options = (opts or {}).copy() + + for key in opts.keys(): + if key not in DEFAULT_OPTIONS: + LOG.warning("Unknown option passed to Sanic-CORS: %s", key) + + # Ensure origins is a list of allowed origins with at least one entry. + options["origins"] = sanitize_regex_param(options.get("origins")) + options["allow_headers"] = sanitize_regex_param(options.get("allow_headers")) + + # This is expressly forbidden by the spec. Raise a value error so people + # don't get burned in production. + if ( + r".*" in options["origins"] + and options["supports_credentials"] + and options["send_wildcard"] + ): + raise ValueError( + "Cannot use supports_credentials in conjunction with" + "an origin string of '*'. See: " + "http://www.w3.org/TR/cors/#resource-requests" + ) + + serialize_option(options, "expose_headers") + serialize_option(options, "methods", upper=True) + + if isinstance(options.get("max_age"), timedelta): + options["max_age"] = str(int(options["max_age"].total_seconds())) + + return options diff --git a/backend/sanic_server/sanic_cors/decorator.py b/backend/sanic_server/sanic_cors/decorator.py new file mode 100644 index 000000000..d6c9c1f40 --- /dev/null +++ b/backend/sanic_server/sanic_cors/decorator.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +""" + decorator + ~~~~ + This unit exposes a single decorator which should be used to wrap a + Sanic route with. It accepts all parameters and options as + the CORS extension. + + :copyright: (c) 2021 by Ashley Sommer (based on flask-cors by Cory Dolphin). + :license: MIT, see LICENSE for more details. +""" + +from ..sanic_plugin_toolkit import SanicPluginRealm +from .core import * +from .extension import cors + + +def cross_origin(app, *args, **kwargs): + """ + This function is the decorator which is used to wrap a Sanic route with. + In the simplest case, simply use the default parameters to allow all + origins in what is the most permissive configuration. If this method + modifies state or performs authentication which may be brute-forced, you + should add some degree of protection, such as Cross Site Forgery + Request protection. + + :param origins: + The origin, or list of origins to allow requests from. + The origin(s) may be regular expressions, case-sensitive strings, + or else an asterisk + + Default : '*' + :type origins: list, string or regex + + :param methods: + The method or list of methods which the allowed origins are allowed to + access for non-simple requests. + + Default : [GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE] + :type methods: list or string + + :param expose_headers: + The header or list which are safe to expose to the API of a CORS API + specification. + + Default : None + :type expose_headers: list or string + + :param allow_headers: + The header or list of header field names which can be used when this + resource is accessed by allowed origins. The header(s) may be regular + expressions, case-sensitive strings, or else an asterisk. + + Default : '*', allow all headers + :type allow_headers: list, string or regex + + :param supports_credentials: + Allows users to make authenticated requests. If true, injects the + `Access-Control-Allow-Credentials` header in responses. This allows + cookies and credentials to be submitted across domains. + + :note: This option cannot be used in conjuction with a '*' origin + + Default : False + :type supports_credentials: bool + + :param max_age: + The maximum time for which this CORS request maybe cached. This value + is set as the `Access-Control-Max-Age` header. + + Default : None + :type max_age: timedelta, integer, string or None + + :param send_wildcard: If True, and the origins parameter is `*`, a wildcard + `Access-Control-Allow-Origin` header is sent, rather than the + request's `Origin` header. + + Default : False + :type send_wildcard: bool + + :param vary_header: + If True, the header Vary: Origin will be returned as per the W3 + implementation guidelines. + + Setting this header when the `Access-Control-Allow-Origin` is + dynamically generated (e.g. when there is more than one allowed + origin, and an Origin than '*' is returned) informs CDNs and other + caches that the CORS headers are dynamic, and cannot be cached. + + If False, the Vary header will never be injected or altered. + + Default : True + :type vary_header: bool + + :param automatic_options: + Only applies to the `cross_origin` decorator. If True, Sanic-CORS will + override Sanic's default OPTIONS handling to return CORS headers for + OPTIONS requests. + + Default : True + :type automatic_options: bool + + """ + _options = kwargs + _real_decorator = cors.decorate( + app, *args, run_middleware=False, with_context=False, **kwargs + ) + + def wrapped_decorator(f): + realm = SanicPluginRealm(app) # get the singleton from the app + try: + plugin = realm.register_plugin(cors, skip_reg=True) + except ValueError as e: + # this is normal, if this plugin has been registered previously + assert e.args and len(e.args) > 1 + plugin = e.args[1] + context = cors.get_context_from_realm(realm) + log = context.log + log( + logging.DEBUG, + "Enabled {:s} for cross_origin using options: {}".format( + str(f), str(_options) + ), + ) + return _real_decorator(f) + + return wrapped_decorator diff --git a/backend/sanic_server/sanic_cors/extension.py b/backend/sanic_server/sanic_cors/extension.py new file mode 100644 index 000000000..e2d943388 --- /dev/null +++ b/backend/sanic_server/sanic_cors/extension.py @@ -0,0 +1,488 @@ +# -*- coding: utf-8 -*- +""" + extension + ~~~~ + Sanic-CORS is a simple extension to Sanic allowing you to support cross + origin resource sharing (CORS) using a simple decorator. + + :copyright: (c) 2021 by Ashley Sommer (based on flask-cors by Cory Dolphin). + :license: MIT, see LICENSE for more details. +""" +import logging +from asyncio import iscoroutinefunction +from distutils.version import LooseVersion +from functools import partial, update_wrapper +from inspect import isawaitable + +from ..sanic import __version__ as sanic_version +from ..sanic import exceptions, response +from ..sanic.exceptions import MethodNotSupported, NotFound +from ..sanic.handlers import ErrorHandler +from ..sanic_plugin_toolkit import SanicPlugin +from .core import * + +SANIC_VERSION = LooseVersion(sanic_version) +SANIC_18_12_0 = LooseVersion("18.12.0") +SANIC_19_9_0 = LooseVersion("19.9.0") +SANIC_19_12_0 = LooseVersion("19.12.0") +SANIC_21_9_0 = LooseVersion("21.9.0") + + +USE_ASYNC_EXCEPTION_HANDLER = False + + +class CORS(SanicPlugin): + __slots__ = tuple() + """ + Initializes Cross Origin Resource sharing for the application. The + arguments are identical to :py:func:`cross_origin`, with the addition of a + `resources` parameter. The resources parameter defines a series of regular + expressions for resource paths to match and optionally, the associated + options to be applied to the particular resource. These options are + identical to the arguments to :py:func:`cross_origin`. + + The settings for CORS are determined in the following order + + 1. Resource level settings (e.g when passed as a dictionary) + 2. Keyword argument settings + 3. App level configuration settings (e.g. CORS_*) + 4. Default settings + + Note: as it is possible for multiple regular expressions to match a + resource path, the regular expressions are first sorted by length, + from longest to shortest, in order to attempt to match the most + specific regular expression. This allows the definition of a + number of specific resource options, with a wildcard fallback + for all other resources. + + :param resources: + The series of regular expression and (optionally) associated CORS + options to be applied to the given resource path. + + If the argument is a dictionary, it's keys must be regular expressions, + and the values must be a dictionary of kwargs, identical to the kwargs + of this function. + + If the argument is a list, it is expected to be a list of regular + expressions, for which the app-wide configured options are applied. + + If the argument is a string, it is expected to be a regular expression + for which the app-wide configured options are applied. + + Default : Match all and apply app-level configuration + + :type resources: dict, iterable or string + + :param origins: + The origin, or list of origins to allow requests from. + The origin(s) may be regular expressions, case-sensitive strings, + or else an asterisk + + Default : '*' + :type origins: list, string or regex + + :param methods: + The method or list of methods which the allowed origins are allowed to + access for non-simple requests. + + Default : [GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE] + :type methods: list or string + + :param expose_headers: + The header or list which are safe to expose to the API of a CORS API + specification. + + Default : None + :type expose_headers: list or string + + :param allow_headers: + The header or list of header field names which can be used when this + resource is accessed by allowed origins. The header(s) may be regular + expressions, case-sensitive strings, or else an asterisk. + + Default : '*', allow all headers + :type allow_headers: list, string or regex + + :param supports_credentials: + Allows users to make authenticated requests. If true, injects the + `Access-Control-Allow-Credentials` header in responses. This allows + cookies and credentials to be submitted across domains. + + :note: This option cannot be used in conjuction with a '*' origin + + Default : False + :type supports_credentials: bool + + :param max_age: + The maximum time for which this CORS request maybe cached. This value + is set as the `Access-Control-Max-Age` header. + + Default : None + :type max_age: timedelta, integer, string or None + + :param send_wildcard: If True, and the origins parameter is `*`, a wildcard + `Access-Control-Allow-Origin` header is sent, rather than the + request's `Origin` header. + + Default : False + :type send_wildcard: bool + + :param vary_header: + If True, the header Vary: Origin will be returned as per the W3 + implementation guidelines. + + Setting this header when the `Access-Control-Allow-Origin` is + dynamically generated (e.g. when there is more than one allowed + origin, and an Origin than '*' is returned) informs CDNs and other + caches that the CORS headers are dynamic, and cannot be cached. + + If False, the Vary header will never be injected or altered. + + Default : True + :type vary_header: bool + """ + + def __init__(self, *args, **kwargs): + if SANIC_18_12_0 > SANIC_VERSION: + raise RuntimeError( + "You cannot use this version of Sanic-CORS with " + "Sanic earlier than v18.12.0" + ) + super(CORS, self).__init__(*args, **kwargs) + + def on_before_registered(self, context, *args, **kwargs): + context._options = kwargs + if not CORS.on_before_registered.has_run: + # debug = partial(context.log, logging.DEBUG) + _ = _make_cors_request_middleware_function(self) + _ = _make_cors_response_middleware_function(self) + CORS.on_before_registered.has_run = True + + on_before_registered.has_run = False + + def on_registered(self, context, *args, **kwargs): + # this will need to be called more than once, for every app it is registered on. + self.init_app(context, *args, **kwargs) + + def init_app(self, context, *args, **kwargs): + app = context.app + log = context.log + _options = context._options + debug = partial(log, logging.DEBUG) + # The resources and options may be specified in the App Config, the CORS constructor + # or the kwargs to the call to init_app. + options = get_cors_options(app, _options, kwargs) + + # Flatten our resources into a list of the form + # (pattern_or_regexp, dictionary_of_options) + resources = parse_resources(options.get("resources")) + + # Compute the options for each resource by combining the options from + # the app's configuration, the constructor, the kwargs to init_app, and + # finally the options specified in the resources dictionary. + resources = [ + (pattern, get_cors_options(app, options, opts)) + for (pattern, opts) in resources + ] + context.options = options + context.resources = resources + # Create a human readable form of these resources by converting the compiled + # regular expressions into strings. + resources_human = dict( + [(get_regexp_pattern(pattern), opts) for (pattern, opts) in resources] + ) + debug("Configuring CORS with resources: {}".format(resources_human)) + if hasattr(app, "error_handler"): + cors_error_handler = CORSErrorHandler( + context, app.error_handler, fallback="auto" + ) + setattr(app, "error_handler", cors_error_handler) + else: + # Blueprints have no error_handler. Just skip error_handler initialisation + pass + + async def route_wrapper( + self, + route, + req, + context, + request_args, + request_kw, + *decorator_args, + **decorator_kw + ): + _ = decorator_kw.pop("with_context") # ignore this. + _options = decorator_kw + options = get_cors_options(context.app, _options) + if options.get("automatic_options", True) and req.method == "OPTIONS": + resp = response.HTTPResponse() + else: + resp = route(req, *request_args, **request_kw) + while isawaitable(resp): + resp = await resp + # resp can be `None` or `[]` if using Websockets + if not resp: + return None + try: + request_context = context.request[id(req)] + except (AttributeError, LookupError): + if SANIC_19_9_0 <= SANIC_VERSION: + request_context = req.ctx + else: + request_context = None + set_cors_headers(req, resp, request_context, options) + if request_context is not None: + setattr(request_context, SANIC_CORS_EVALUATED, "1") + else: + context.log( + logging.DEBUG, + "Cannot access a sanic request " + "context. Has request started? Is request ended?", + ) + return resp + + +def unapplied_cors_request_middleware(req, context): + if req.method == "OPTIONS": + try: + path = req.path + except AttributeError: + path = req.url + resources = context.resources + log = context.log + debug = partial(log, logging.DEBUG) + for res_regex, res_options in resources: + if res_options.get("automatic_options", True) and try_match( + path, res_regex + ): + debug( + "Request to '{:s}' matches CORS resource '{}'. " + "Using options: {}".format( + path, get_regexp_pattern(res_regex), res_options + ) + ) + resp = response.HTTPResponse() + + try: + request_context = context.request[id(req)] + except (AttributeError, LookupError): + if SANIC_19_9_0 <= SANIC_VERSION: + request_context = req.ctx + else: + request_context = None + context.log( + logging.DEBUG, + "Cannot access a sanic request " + "context. Has request started? Is request ended?", + ) + set_cors_headers(req, resp, request_context, res_options) + if request_context is not None: + setattr(request_context, SANIC_CORS_EVALUATED, "1") + return resp + else: + debug("No CORS rule matches") + + +async def unapplied_cors_response_middleware(req, resp, context): + log = context.log + debug = partial(log, logging.DEBUG) + # `resp` can be None or [] in the case of using Websockets + if not resp: + return False + try: + request_context = context.request[id(req)] + except (AttributeError, LookupError): + if SANIC_19_9_0 <= SANIC_VERSION: + request_context = req.ctx + else: + debug( + "Cannot find the request context. " + "Is request already finished? Is request not started?" + ) + request_context = None + if request_context is not None: + # If CORS headers are set in the CORS error handler + if getattr(request_context, SANIC_CORS_SKIP_RESPONSE_MIDDLEWARE, False): + debug("CORS was handled in the exception handler, skipping") + return False + + # If CORS headers are set in a view decorator, pass + elif getattr(request_context, SANIC_CORS_EVALUATED, False): + debug("CORS have been already evaluated, skipping") + return False + try: + path = req.path + except AttributeError: + path = req.url + + resources = context.resources + for res_regex, res_options in resources: + if try_match(path, res_regex): + debug( + "Request to '{}' matches CORS resource '{:s}'. Using options: {}".format( + path, get_regexp_pattern(res_regex), res_options + ) + ) + set_cors_headers(req, resp, request_context, res_options) + if request_context is not None: + setattr(request_context, SANIC_CORS_EVALUATED, "1") + break + else: + debug("No CORS rule matches") + + +def _make_cors_request_middleware_function(plugin): + applied_cors_request_middleware = plugin.middleware( + relative="pre", attach_to="request", with_context=True + )(unapplied_cors_request_middleware) + return applied_cors_request_middleware + + +def _make_cors_response_middleware_function(plugin): + applied_cors_response_middleware = plugin.middleware( + relative="post", attach_to="response", with_context=True + )(unapplied_cors_response_middleware) + return applied_cors_response_middleware + + +class CORSErrorHandler(ErrorHandler): + @classmethod + def _apply_cors_to_exception(cls, ctx, req, resp): + try: + path = req.path + except AttributeError: + path = req.url + if path is not None: + resources = ctx.resources + log = ctx.log + debug = partial(log, logging.DEBUG) + try: + request_context = ctx.request[id(req)] + except (AttributeError, LookupError): + if SANIC_19_9_0 <= SANIC_VERSION: + request_context = req.ctx + else: + request_context = None + for res_regex, res_options in resources: + if try_match(path, res_regex): + debug( + "Request to '{:s}' matches CORS resource '{}'. " + "Using options: {}".format( + path, get_regexp_pattern(res_regex), res_options + ) + ) + set_cors_headers(req, resp, request_context, res_options) + break + else: + debug("No CORS rule matches") + else: + pass + + def __new__(cls, *args, **kwargs): + self = super(CORSErrorHandler, cls).__new__(cls) + if USE_ASYNC_EXCEPTION_HANDLER: + self.response = self.async_response + else: + self.response = self.sync_response + return self + + def __init__(self, context, orig_handler, fallback="auto"): + if SANIC_21_9_0 <= SANIC_VERSION: + super(CORSErrorHandler, self).__init__(fallback=fallback) + else: + super(CORSErrorHandler, self).__init__() + self.orig_handler = orig_handler + self.ctx = context + + if SANIC_21_9_0 <= SANIC_VERSION: + + def add(self, exception, handler, route_names=None): + self.orig_handler.add(exception, handler, route_names=route_names) + + def lookup(self, exception, route_name=None): + return self.orig_handler.lookup(exception, route_name=route_name) + + else: + + def add(self, exception, handler): + self.orig_handler.add(exception, handler) + + def lookup(self, exception): + return self.orig_handler.lookup(exception) + + # wrap app's original exception response function + # so that error responses have proper CORS headers + @classmethod + def wrapper(cls, f, ctx, req, e): + opts = ctx.options + log = ctx.log + # get response from the original handler + if ( + req is not None + and SANIC_19_12_0 <= SANIC_VERSION + and isinstance(e, MethodNotSupported) + and req.method == "OPTIONS" + and opts.get("automatic_options", True) + ): + # A very specific set of requirments to trigger this kind of + # automatic-options resp + resp = response.HTTPResponse() + else: + do_await = iscoroutinefunction(f) + resp = f(req, e) + if do_await: + log( + logging.DEBUG, + "Found an async Exception handler response. " + "Cannot apply CORS to it. Passing it on.", + ) + return resp + # SanicExceptions are equiv to Flask Aborts, + # always apply CORS to them. + if (req is not None and resp is not None) and ( + isinstance(e, exceptions.SanicException) + or opts.get("intercept_exceptions", True) + ): + try: + cls._apply_cors_to_exception(ctx, req, resp) + except AttributeError: + # not sure why certain exceptions doesn't have + # an accompanying request + pass + if req is None: + return resp + # These exceptions have normal CORS middleware applied automatically. + # So set a flag to skip our manual application of the middleware. + try: + request_context = ctx.request[id(req)] + except (LookupError, AttributeError): + # On Sanic 19.12.0, a NotFound error can be thrown _before_ + # the request_context is set up. This is a fallback routine: + if SANIC_19_12_0 <= SANIC_VERSION and isinstance( + e, (NotFound, MethodNotSupported) + ): + # On sanic 19.9.0+ request is a dict, so we can add our + # flag directly to it. + request_context = req.ctx + else: + log( + logging.DEBUG, + "Cannot find the request context. Is request started? " + "Is request already finished?", + ) + request_context = None + if request_context is not None: + setattr(request_context, SANIC_CORS_SKIP_RESPONSE_MIDDLEWARE, "1") + return resp + + async def async_response(self, request, exception): + orig_resp_handler = self.orig_handler.response + return await self.wrapper(orig_resp_handler, self.ctx, request, exception) + + def sync_response(self, request, exception): + orig_resp_handler = self.orig_handler.response + return self.wrapper(orig_resp_handler, self.ctx, request, exception) + + +instance = cors = CORS() +__all__ = ["cors", "CORS"] diff --git a/backend/sanic_server/sanic_cors/version.py b/backend/sanic_server/sanic_cors/version.py new file mode 100644 index 000000000..cd7ca4980 --- /dev/null +++ b/backend/sanic_server/sanic_cors/version.py @@ -0,0 +1 @@ +__version__ = '1.0.1' diff --git a/backend/sanic_server/sanic_ext/LICENSE b/backend/sanic_server/sanic_ext/LICENSE new file mode 100644 index 000000000..1377edf4c --- /dev/null +++ b/backend/sanic_server/sanic_ext/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 Channel Cat + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/backend/sanic_server/sanic_ext/__init__.py b/backend/sanic_server/sanic_ext/__init__.py new file mode 100644 index 000000000..84bda49b8 --- /dev/null +++ b/backend/sanic_server/sanic_ext/__init__.py @@ -0,0 +1,14 @@ +from ..sanic_ext.bootstrap import Extend +from ..sanic_ext.config import Config +from ..sanic_ext.extensions.http.cors import cors +from ..sanic_ext.extras.serializer.decorator import serializer +from ..sanic_ext.extras.validation.decorator import validate + +__version__ = "21.12.1" +__all__ = [ + "Config", + "Extend", + "cors", + "serializer", + "validate", +] diff --git a/backend/sanic_server/sanic_ext/bootstrap.py b/backend/sanic_server/sanic_ext/bootstrap.py new file mode 100644 index 000000000..434786dc3 --- /dev/null +++ b/backend/sanic_server/sanic_ext/bootstrap.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from string import ascii_lowercase +from types import SimpleNamespace +from typing import Any, Callable, Dict, List, Optional, Type, Union +from warnings import warn + +from ..sanic import Sanic, __version__ +from ..sanic.exceptions import SanicException +from ..sanic.log import logger +from ..sanic_ext.config import Config, add_fallback_config +from ..sanic_ext.extensions.base import Extension +from ..sanic_ext.extensions.http.extension import HTTPExtension +from ..sanic_ext.extensions.injection.extension import InjectionExtension +from ..sanic_ext.extensions.injection.registry import InjectionRegistry +from ..sanic_ext.utils.string import camel_to_snake + +MIN_SUPPORT = (21, 3, 2) + + +class Extend: + def __init__( + self, + app: Sanic, + *, + extensions: Optional[List[Type[Extension]]] = None, + built_in_extensions: bool = True, + config: Optional[Union[Config, Dict[str, Any]]] = None, + **kwargs, + ) -> None: + """ + Ingress for instantiating sanic-ext + + :param app: Sanic application + :type app: Sanic + """ + if not isinstance(app, Sanic): + raise SanicException(f"Cannot apply SanicExt to {app.__class__.__name__}") + + sanic_version = tuple( + map(int, __version__.strip(ascii_lowercase).split(".", 3)[:3]) + ) + + if MIN_SUPPORT > sanic_version: + min_version = ".".join(map(str, MIN_SUPPORT)) + raise SanicException( + f"SanicExt only works with Sanic v{min_version} and above. " + f"It looks like you are running {__version__}." + ) + + self.app = app + self.extensions = [] + self._injection_registry: Optional[InjectionRegistry] = None + app._ext = self + app.ctx._dependencies = SimpleNamespace() + + if not isinstance(config, Config): + config = Config.from_dict(config or {}) + self.config = add_fallback_config(app, config, **kwargs) + + extensions = extensions or [] + if built_in_extensions: + extensions.extend( + [ + InjectionExtension, + HTTPExtension, + ] + ) + for extclass in extensions[::-1]: + extension = extclass(app, self.config) + extension._startup(self) + self.extensions.append(extension) + + def _display(self): + init_logs = ["Sanic Extensions:"] + for extension in self.extensions: + init_logs.append(f" > {extension.name} {extension.label()}") + + list(map(logger.info, init_logs)) + + def injection( + self, + type: Type, + constructor: Optional[Callable[..., Any]] = None, + ) -> None: + warn( + "The 'ext.injection' method has been deprecated and will be " + "removed in v22.6. Please use 'ext.add_dependency' instead.", + DeprecationWarning, + ) + self.add_dependency(type=type, constructor=constructor) + + def add_dependency( + self, + type: Type, + constructor: Optional[Callable[..., Any]] = None, + ) -> None: + if not self._injection_registry: + raise SanicException("Injection extension not enabled") + self._injection_registry.register(type, constructor) + + def dependency(self, obj: Any, name: Optional[str] = None) -> None: + if not name: + name = camel_to_snake(obj.__class__.__name__) + setattr(self.app.ctx._dependencies, name, obj) + + def getter(*_): + return obj + + self.add_dependency(obj.__class__, getter) diff --git a/backend/sanic_server/sanic_ext/config.py b/backend/sanic_server/sanic_ext/config.py new file mode 100644 index 000000000..1b5dc37ea --- /dev/null +++ b/backend/sanic_server/sanic_ext/config.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional, Sequence + +from ..sanic import Sanic +from ..sanic.config import Config as SanicConfig + + +class Config(SanicConfig): + def __init__( + self, + cors: bool = True, + cors_allow_headers: str = "*", + cors_always_send: bool = True, + cors_automatic_options: bool = True, + cors_expose_headers: str = "", + cors_max_age: int = 5, + cors_methods: str = "", + cors_origins: str = "", + cors_send_wildcard: bool = False, + cors_supports_credentials: bool = False, + cors_vary_header: bool = True, + http_all_methods: bool = True, + http_auto_head: bool = True, + http_auto_options: bool = True, + http_auto_trace: bool = False, + oas: bool = True, + oas_autodoc: bool = True, + oas_ignore_head: bool = True, + oas_ignore_options: bool = True, + oas_path_to_redoc_html: Optional[str] = None, + oas_path_to_swagger_html: Optional[str] = None, + oas_ui_default: Optional[str] = "redoc", + oas_ui_redoc: bool = True, + oas_ui_swagger: bool = True, + oas_ui_swagger_version: str = "4.1.0", + oas_uri_to_config: str = "/swagger-config", + oas_uri_to_json: str = "/openapi.json", + oas_uri_to_redoc: str = "/redoc", + oas_uri_to_swagger: str = "/swagger", + oas_url_prefix: str = "/docs", + swagger_ui_configuration: Optional[Dict[str, Any]] = None, + trace_excluded_headers: Sequence[str] = ("authorization", "cookie"), + **kwargs, + ): + self.CORS = cors + self.CORS_ALLOW_HEADERS = cors_allow_headers + self.CORS_ALWAYS_SEND = cors_always_send + self.CORS_AUTOMATIC_OPTIONS = cors_automatic_options + self.CORS_EXPOSE_HEADERS = cors_expose_headers + self.CORS_MAX_AGE = cors_max_age + self.CORS_METHODS = cors_methods + self.CORS_ORIGINS = cors_origins + self.CORS_SEND_WILDCARD = cors_send_wildcard + self.CORS_SUPPORTS_CREDENTIALS = cors_supports_credentials + self.CORS_VARY_HEADER = cors_vary_header + self.HTTP_ALL_METHODS = http_all_methods + self.HTTP_AUTO_HEAD = http_auto_head + self.HTTP_AUTO_OPTIONS = http_auto_options + self.HTTP_AUTO_TRACE = http_auto_trace + self.OAS = oas + self.OAS_AUTODOC = oas_autodoc + self.OAS_IGNORE_HEAD = oas_ignore_head + self.OAS_IGNORE_OPTIONS = oas_ignore_options + self.OAS_PATH_TO_REDOC_HTML = oas_path_to_redoc_html + self.OAS_PATH_TO_SWAGGER_HTML = oas_path_to_swagger_html + self.OAS_UI_DEFAULT = oas_ui_default + self.OAS_UI_REDOC = oas_ui_redoc + self.OAS_UI_SWAGGER = oas_ui_swagger + self.OAS_UI_SWAGGER_VERSION = oas_ui_swagger_version + self.OAS_URI_TO_CONFIG = oas_uri_to_config + self.OAS_URI_TO_JSON = oas_uri_to_json + self.OAS_URI_TO_REDOC = oas_uri_to_redoc + self.OAS_URI_TO_SWAGGER = oas_uri_to_swagger + self.OAS_URL_PREFIX = oas_url_prefix + self.SWAGGER_UI_CONFIGURATION = swagger_ui_configuration or { + "apisSorter": "alpha", + "operationsSorter": "alpha", + "docExpansion": "full", + } + self.TRACE_EXCLUDED_HEADERS = trace_excluded_headers + + if isinstance(self.TRACE_EXCLUDED_HEADERS, str): + self.TRACE_EXCLUDED_HEADERS = tuple(self.TRACE_EXCLUDED_HEADERS.split(",")) + + self.load({key.upper(): value for key, value in kwargs.items()}) + + @classmethod + def from_dict(cls, mapping) -> Config: + return cls(**mapping) + + +def add_fallback_config( + app: Sanic, config: Optional[Config] = None, **kwargs +) -> Config: + if config is None: + config = Config(**kwargs) + + app.config.update( + {key: value for key, value in config.items() if key not in app.config} + ) + + return config diff --git a/backend/sanic_server/sanic_ext/exceptions.py b/backend/sanic_server/sanic_ext/exceptions.py new file mode 100644 index 000000000..fe3046a74 --- /dev/null +++ b/backend/sanic_server/sanic_ext/exceptions.py @@ -0,0 +1,9 @@ +from ..sanic.exceptions import SanicException + + +class ValidationError(SanicException): + status_code = 400 + + +class InitError(SanicException): + ... diff --git a/backend/sanic_server/sanic_ext/extensions/__init__.py b/backend/sanic_server/sanic_ext/extensions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic_ext/extensions/base.py b/backend/sanic_server/sanic_ext/extensions/base.py new file mode 100644 index 000000000..f6b819d41 --- /dev/null +++ b/backend/sanic_server/sanic_ext/extensions/base.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Type + +from ...sanic.app import Sanic +from ...sanic.exceptions import SanicException +from ...sanic_ext.config import Config +from ...sanic_ext.exceptions import InitError + + +class NoDuplicateDict(dict): # type: ignore + def __setitem__(self, key: Any, value: Any) -> None: + if key in self: + raise KeyError(f"Duplicate key: {key}") + return super().__setitem__(key, value) + + +class Extension(ABC): + _name_registry: Dict[str, Type[Extension]] = NoDuplicateDict() + _singleton = None + name: str + + def __new__(cls, *args, **kwargs): + if cls._singleton is None: + cls._singleton = super().__new__(cls) + cls._singleton._started = False + return cls._singleton + + def __init_subclass__(cls): + if not getattr(cls, "name", None) or not cls.name.isalpha(): + raise InitError( + "Extensions must be named, and may only contain " + "alphabetic characters" + ) + + if cls.name in cls._name_registry: + raise InitError(f'Extension "{cls.name}" already exists') + + cls._name_registry[cls.name] = cls + + def __init__(self, app: Sanic, config: Config) -> None: + self.app = app + self.config = config + + def _startup(self, bootstrap): + if self._started: + raise SanicException( + "Extension already started. Cannot start " + f"Extension:{self.name} multiple times." + ) + self.startup(bootstrap) + self._started = True + + @abstractmethod + def startup(self, bootstrap) -> None: + ... + + def label(self): + return "" diff --git a/backend/sanic_server/sanic_ext/extensions/http/__init__.py b/backend/sanic_server/sanic_ext/extensions/http/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic_ext/extensions/http/cors.py b/backend/sanic_server/sanic_ext/extensions/http/cors.py new file mode 100644 index 000000000..b381f4c79 --- /dev/null +++ b/backend/sanic_server/sanic_ext/extensions/http/cors.py @@ -0,0 +1,386 @@ +import re +from dataclasses import dataclass +from datetime import timedelta +from types import SimpleNamespace +from typing import Any, FrozenSet, List, Optional, Tuple, Union + +from ....sanic import HTTPResponse, Request, Sanic +from ....sanic.exceptions import SanicException +from ....sanic.helpers import Default, _default +from ....sanic.log import logger + +WILDCARD_PATTERN = re.compile(r".*") +ORIGIN_HEADER = "access-control-allow-origin" +ALLOW_HEADERS_HEADER = "access-control-allow-headers" +ALLOW_METHODS_HEADER = "access-control-allow-methods" +EXPOSE_HEADER = "access-control-expose-headers" +CREDENTIALS_HEADER = "access-control-allow-credentials" +REQUEST_METHOD_HEADER = "access-control-request-method" +REQUEST_HEADERS_HEADER = "access-control-request-headers" +MAX_AGE_HEADER = "access-control-max-age" +VARY_HEADER = "vary" + + +@dataclass(frozen=True) +class CORSSettings: + allow_headers: FrozenSet[str] + allow_methods: FrozenSet[str] + allow_origins: Tuple[re.Pattern, ...] + always_send: bool + automatic_options: bool + expose_headers: FrozenSet[str] + max_age: str + send_wildcard: bool + supports_credentials: bool + + +def add_cors(app: Sanic) -> None: + _setup_cors_settings(app) + + @app.on_response + async def _add_cors_headers(request, response): + preflight = ( + request.app.ctx.cors.automatic_options + and request.method == "OPTIONS" + ) + + if preflight and not request.headers.get(REQUEST_METHOD_HEADER): + logger.info( + "No Access-Control-Request-Method header found on request. " + "CORS headers will not be applied." + ) + return + + _add_origin_header(request, response) + + if ORIGIN_HEADER not in response.headers: + return + + _add_expose_header(request, response) + _add_credentials_header(request, response) + _add_vary_header(request, response) + + if preflight: + _add_max_age_header(request, response) + _add_allow_header(request, response) + _add_methods_header(request, response) + + @app.before_server_start + async def _assign_cors_settings(app, _): + for group in app.router.groups.values(): + _cors = SimpleNamespace() + for route in group: + cors = getattr(route.handler, "__cors__", None) + if cors: + for key, value in cors.__dict__.items(): + setattr(_cors, key, value) + + for route in group: + route.ctx._cors = _cors + + +def cors( + *, + origin: Union[str, Default] = _default, + expose_headers: Union[List[str], Default] = _default, + allow_headers: Union[List[str], Default] = _default, + allow_methods: Union[List[str], Default] = _default, + supports_credentials: Union[bool, Default] = _default, + max_age: Union[str, int, timedelta, Default] = _default, +): + def decorator(f): + f.__cors__ = SimpleNamespace( + _cors_origin=origin, + _cors_expose_headers=expose_headers, + _cors_supports_credentials=supports_credentials, + _cors_allow_origins=( + _parse_allow_origins(origin) + if origin is not _default + else origin + ), + _cors_allow_headers=( + _parse_allow_headers(allow_headers) + if allow_headers is not _default + else allow_headers + ), + _cors_allow_methods=( + _parse_allow_methods(allow_methods) + if allow_methods is not _default + else allow_methods + ), + _cors_max_age=( + _parse_max_age(max_age) if max_age is not _default else max_age + ), + ) + return f + + return decorator + + +def _setup_cors_settings(app: Sanic) -> None: + if app.config.CORS_ORIGINS == "*" and app.config.CORS_SUPPORTS_CREDENTIALS: + raise SanicException( + "Cannot use supports_credentials in conjunction with " + "an origin string of '*'. See: " + "http://www.w3.org/TR/cors/#resource-requests" + ) + + allow_headers = _get_allow_headers(app) + allow_methods = _get_allow_methods(app) + allow_origins = _get_allow_origins(app) + expose_headers = _get_expose_headers(app) + max_age = _get_max_age(app) + + app.ctx.cors = CORSSettings( + allow_headers=allow_headers, + allow_methods=allow_methods, + allow_origins=allow_origins, + always_send=app.config.CORS_ALWAYS_SEND, + automatic_options=app.config.CORS_AUTOMATIC_OPTIONS, + expose_headers=expose_headers, + max_age=max_age, + send_wildcard=( + app.config.CORS_SEND_WILDCARD and WILDCARD_PATTERN in allow_origins + ), + supports_credentials=app.config.CORS_SUPPORTS_CREDENTIALS, + ) + + +def _get_from_cors_ctx(request: Request, key: str, default: Any = None): + if request.route: + value = getattr(request.route.ctx._cors, key, default) + if value is not _default: + return value + return default + + +def _add_origin_header(request: Request, response: HTTPResponse) -> None: + request_origin = request.headers.get("origin") + origin_value = "" + allow_origins = _get_from_cors_ctx( + request, + "_cors_allow_origins", + request.app.ctx.cors.allow_origins, + ) + fallback_origin = _get_from_cors_ctx( + request, + "_cors_origin", + request.app.config.CORS_ORIGINS, + ) + + if request_origin: + if request.app.ctx.cors.send_wildcard: + origin_value = "*" + else: + for pattern in allow_origins: + if pattern.match(request_origin): + origin_value = request_origin + break + elif request.app.ctx.cors.always_send: + if WILDCARD_PATTERN in allow_origins: + origin_value = "*" + else: + if isinstance(fallback_origin, str) and "," not in fallback_origin: + origin_value = fallback_origin + else: + origin_value = request.app.config.get("SERVER_NAME", "") + + if origin_value: + response.headers[ORIGIN_HEADER] = origin_value + + +def _add_expose_header(request: Request, response: HTTPResponse) -> None: + with_credentials = _is_request_with_credentials(request) + headers = None + expose_headers = _get_from_cors_ctx( + request, "_cors_expose_headers", request.app.ctx.cors.expose_headers + ) + # MDN: The value "*" only counts as a special wildcard value for requests + # without credentials (requests without HTTP cookies or HTTP + # authentication information). In requests with credentials, it is + # treated as the literal header name "*" without special semantics. + # Note that the Authorization header can't be wildcarded and always + # needs to be listed explicitly. + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers + if not with_credentials and "*" in expose_headers: + headers = ["*"] + elif expose_headers: + headers = expose_headers + + if headers: + response.headers[EXPOSE_HEADER] = ",".join(headers) + + +def _add_credentials_header(request: Request, response: HTTPResponse) -> None: + supports_credentials = _get_from_cors_ctx( + request, + "_cors_supports_credentials", + request.app.ctx.cors.supports_credentials, + ) + if supports_credentials: + response.headers[CREDENTIALS_HEADER] = "true" + + +def _add_allow_header(request: Request, response: HTTPResponse) -> None: + with_credentials = _is_request_with_credentials(request) + request_headers = set( + h.strip().lower() + for h in request.headers.get(REQUEST_HEADERS_HEADER, "").split(",") + ) + allow_headers = _get_from_cors_ctx( + request, "_cors_allow_headers", request.app.ctx.cors.allow_headers + ) + + # MDN: The value "*" only counts as a special wildcard value for requests + # without credentials (requests without HTTP cookies or HTTP + # authentication information). In requests with credentials, + # it is treated as the literal header name "*" without special semantics. + # Note that the Authorization header can't be wildcarded and always needs + # to be listed explicitly. + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + if not with_credentials and "*" in allow_headers: + allow_headers = ["*"] + else: + allow_headers = request_headers & allow_headers + + if allow_headers: + response.headers[ALLOW_HEADERS_HEADER] = ",".join(allow_headers) + + +def _add_max_age_header(request: Request, response: HTTPResponse) -> None: + max_age = _get_from_cors_ctx( + request, "_cors_max_age", request.app.ctx.cors.max_age + ) + if max_age: + response.headers[MAX_AGE_HEADER] = max_age + + +def _add_methods_header(request: Request, response: HTTPResponse) -> None: + # MDN: The value "*" only counts as a special wildcard value for requests + # without credentials (requests without HTTP cookies or HTTP + # authentication information). In requests with credentials, it + # is treated as the literal method name "*" without + # special semantics. + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + methods = None + with_credentials = _is_request_with_credentials(request) + allow_methods = _get_from_cors_ctx( + request, "_cors_allow_methods", request.app.ctx.cors.allow_methods + ) + + if not with_credentials and "*" in allow_methods: + methods = {"*"} + elif request.route: + group = request.app.router.groups.get(request.route.segments) + if group: + group_methods = {method.lower() for method in group.methods} + if allow_methods: + methods = group_methods & allow_methods + else: + methods = group_methods + + if methods: + response.headers[ALLOW_METHODS_HEADER] = ",".join(methods).upper() + + +def _add_vary_header(request: Request, response: HTTPResponse) -> None: + allow_origins = _get_from_cors_ctx( + request, + "_cors_allow_origins", + request.app.ctx.cors.allow_origins, + ) + if len(allow_origins) > 1: + response.headers[VARY_HEADER] = "origin" + + +def _get_allow_origins(app: Sanic) -> Tuple[re.Pattern, ...]: + origins = app.config.CORS_ORIGINS + return _parse_allow_origins(origins) + + +def _parse_allow_origins( + value: Union[str, re.Pattern] +) -> Tuple[re.Pattern, ...]: + origins: Optional[Union[List[str], List[re.Pattern]]] = None + if value and isinstance(value, str): + if value == "*": + origins = [WILDCARD_PATTERN] + else: + origins = value.split(",") + elif isinstance(value, re.Pattern): + origins = [value] + + return tuple( + pattern if isinstance(pattern, re.Pattern) else re.compile(pattern) + for pattern in (origins or []) + ) + + +def _get_expose_headers(app: Sanic) -> FrozenSet[str]: + expose_headers = ( + ( + app.config.CORS_EXPOSE_HEADERS + if isinstance( + app.config.CORS_EXPOSE_HEADERS, (list, set, frozenset, tuple) + ) + else app.config.CORS_EXPOSE_HEADERS.split(",") + ) + if app.config.CORS_EXPOSE_HEADERS + else tuple() + ) + return frozenset(header.lower() for header in expose_headers) + + +def _get_allow_headers(app: Sanic) -> FrozenSet[str]: + return _parse_allow_headers(app.config.CORS_ALLOW_HEADERS) + + +def _parse_allow_headers(value: str) -> FrozenSet[str]: + allow_headers = ( + ( + value + if isinstance( + value, + (list, set, frozenset, tuple), + ) + else value.split(",") + ) + if value + else tuple() + ) + return frozenset(header.lower() for header in allow_headers) + + +def _get_max_age(app: Sanic) -> str: + return _parse_max_age(app.config.CORS_MAX_AGE or "") + + +def _parse_max_age(value) -> str: + max_age = value or "" + if isinstance(max_age, timedelta): + max_age = str(int(max_age.total_seconds())) + return str(max_age) + + +def _get_allow_methods(app: Sanic) -> FrozenSet[str]: + return _parse_allow_methods(app.config.CORS_METHODS) + + +def _parse_allow_methods(value) -> FrozenSet[str]: + allow_methods = ( + ( + value + if isinstance( + value, + (list, set, frozenset, tuple), + ) + else value.split(",") + ) + if value + else tuple() + ) + return frozenset(method.lower() for method in allow_methods) + + +def _is_request_with_credentials(request: Request) -> bool: + return bool(request.headers.get("authorization") or request.cookies) diff --git a/backend/sanic_server/sanic_ext/extensions/http/extension.py b/backend/sanic_server/sanic_ext/extensions/http/extension.py new file mode 100644 index 000000000..b7b13b48e --- /dev/null +++ b/backend/sanic_server/sanic_ext/extensions/http/extension.py @@ -0,0 +1,36 @@ +from ...exceptions import InitError +from ..base import Extension +from .cors import add_cors +from .methods import add_auto_handlers, add_http_methods + + +class HTTPExtension(Extension): + name = "http" + + def __init__(self, *args) -> None: + super().__init__(*args) + self.all_methods: bool = self.config.HTTP_ALL_METHODS + self.auto_head: bool = self.config.HTTP_AUTO_HEAD + self.auto_options: bool = self.config.HTTP_AUTO_OPTIONS + self.auto_trace: bool = self.config.HTTP_AUTO_TRACE + self.cors: bool = self.config.CORS + + def startup(self, _) -> None: + if self.all_methods: + add_http_methods(self.app, ["CONNECT", "TRACE"]) + + if self.auto_head or self.auto_options or self.auto_trace: + add_auto_handlers( + self.app, self.auto_head, self.auto_options, self.auto_trace + ) + + if self.cors: + add_cors(self.app) + else: + return + + if self.app.ctx.cors.automatic_options and not self.auto_options: + raise InitError( + "Configuration mismatch. If CORS_AUTOMATIC_OPTIONS is set to " + "True, then you must run SanicExt with auto_options=True" + ) diff --git a/backend/sanic_server/sanic_ext/extensions/http/methods.py b/backend/sanic_server/sanic_ext/extensions/http/methods.py new file mode 100644 index 000000000..ce4ddaa26 --- /dev/null +++ b/backend/sanic_server/sanic_ext/extensions/http/methods.py @@ -0,0 +1,120 @@ +from functools import partial +from inspect import isawaitable +from typing import Sequence, Union + +from sanic import Sanic +from sanic.constants import HTTPMethod +from sanic.exceptions import SanicException +from sanic.response import empty, raw + +from ...utils.route import clean_route_name +from ..openapi import openapi + + +def add_http_methods( + app: Sanic, methods: Sequence[Union[str, HTTPMethod]] +) -> None: + """ + Adds HTTP methods to an app + + :param app: Your Sanic app + :type app: Sanic + :param methods: The http methods being added, eg: CONNECT, TRACE + :type methods: Sequence[str] + """ + + app.router.ALLOWED_METHODS = tuple( + [*app.router.ALLOWED_METHODS, *methods] # type: ignore + ) + + +def add_auto_handlers( + app: Sanic, auto_head: bool, auto_options: bool, auto_trace: bool +) -> None: + if auto_trace and "TRACE" not in app.router.ALLOWED_METHODS: + raise SanicException( + "Cannot use apply(..., auto_trace=True) if TRACE is not an " + "allowed HTTP method. Make sure apply(..., all_http_methods=True) " + "has been set." + ) + + async def head_handler(request, get_handler, *args, **kwargs): + retval = get_handler(request, *args, **kwargs) + if isawaitable(retval): + retval = await retval + return retval + + async def options_handler(request, methods, *args, **kwargs): + resp = empty() + resp.headers["allow"] = ",".join([*methods, "OPTIONS"]) + return resp + + async def trace_handler(request): + cleaned_head = b"" + for line in request.head.split(b"\r\n"): + first_word, _ = line.split(b" ", 1) + + if ( + first_word.lower().replace(b":", b"").decode("utf-8") + not in request.app.config.TRACE_EXCLUDED_HEADERS + ): + cleaned_head += line + b"\r\n" + + message = "\r\n\r\n".join( + [part.decode("utf-8") for part in [cleaned_head, request.body]] + ) + return raw(message, content_type="message/http") + + @app.before_server_start + def _add_handlers(app, _): + nonlocal auto_head + nonlocal auto_options + + if auto_head: + app.router.reset() + for group in app.router.groups.values(): + if "GET" in group.methods and "HEAD" not in group.methods: + get_route = group.methods_index["GET"] + name = f"{get_route.name}_head" + app.add_route( + handler=openapi.definition( + summary=clean_route_name(get_route.name).title(), + description="Retrieve HEAD details", + )( + partial( + head_handler, get_handler=get_route.handler + ) + ), + uri=group.uri, + methods=["HEAD"], + strict_slashes=group.strict, + name=name, + ) + app.router.finalize() + + if auto_trace: + app.router.reset() + for group in app.router.groups.values(): + if "TRACE" not in group.methods: + app.add_route( + handler=trace_handler, + uri=group.uri, + methods=["TRACE"], + strict_slashes=group.strict, + ) + app.router.finalize() + + if auto_options: + app.router.reset() + for group in app.router.groups.values(): + if "OPTIONS" not in group.methods: + app.add_route( + handler=partial( + options_handler, methods=group.methods + ), + uri=group.uri, + methods=["OPTIONS"], + strict_slashes=group.strict, + name="_options", + ) + app.router.finalize() diff --git a/backend/sanic_server/sanic_ext/extensions/injection/__init__.py b/backend/sanic_server/sanic_ext/extensions/injection/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic_ext/extensions/injection/constructor.py b/backend/sanic_server/sanic_ext/extensions/injection/constructor.py new file mode 100644 index 000000000..5ae68cc01 --- /dev/null +++ b/backend/sanic_server/sanic_ext/extensions/injection/constructor.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from inspect import isawaitable +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Set, + Tuple, + Type, + get_type_hints, +) + +from sanic import Request +from sanic.exceptions import ServerError + +from sanic_ext.exceptions import InitError + +if TYPE_CHECKING: + from .registry import InjectionRegistry + + +class Constructor: + EXEMPT_ANNOTATIONS = (Request,) + + def __init__( + self, + func: Callable[..., Any], + ): + self.func = func + self.injections: Dict[str, Tuple[Type, Constructor]] = {} + self.pass_kwargs = False + + def __str__(self) -> str: + return f"<{self.__class__.__name__}:{self.func.__name__}>" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(func={self.func.__name__})>" + + async def __call__(self, request, **kwargs): + try: + args = await gather_args(self.injections, request, **kwargs) + if self.pass_kwargs: + args.update(kwargs) + retval = self.func(request, **args) + if isawaitable(retval): + retval = await retval + return retval + except TypeError as e: + raise ServerError( + "Failure to inject dependencies. Make sure that all " + f"dependencies for '{self.func.__name__}' have been " + "registered." + ) from e + + def prepare( + self, + injection_registry: InjectionRegistry, + allowed_types: Set[Type[object]], + ) -> None: + hints = get_type_hints(self.func) + hints.pop("return", None) + missing = [] + for param, annotation in hints.items(): + if annotation in allowed_types: + self.pass_kwargs = True + if ( + annotation not in self.EXEMPT_ANNOTATIONS + and annotation not in allowed_types + ): + dependency = injection_registry.get(annotation) + if not dependency: + missing.append((param, annotation)) + self.injections[param] = (annotation, dependency) + + if missing: + dependencies = "\n".join( + [f" - {param}: {annotation}" for param, annotation in missing] + ) + raise InitError( + "Unable to resolve dependencies for " + f"'{self.func.__name__}'. Could not find the following " + f"dependencies:\n{dependencies}.\nMake sure the dependencies " + "are declared using ext.injection. See " + "https://sanicframework.org/en/plugins/sanic-ext/injection." + "html#injecting-services for more details." + ) + + self.check_circular(set()) + + def check_circular( + self, + checked: Set[Type[object]], + ) -> None: + dependencies = set(self.injections.values()) + for dependency, constructor in dependencies: + if dependency in checked: + raise InitError( + "Circular dependency injection detected on " + f"'{self.func.__name__}'. Check dependencies of " + f"'{constructor.func.__name__}' which may contain " + f"circular dependency chain with {dependency}." + ) + checked.add(dependency) + constructor.check_circular(checked) + + +async def gather_args(injections, request, **kwargs) -> Dict[str, Any]: + return { + name: await do_cast(_type, constructor, request, **kwargs) + for name, (_type, constructor) in injections.items() + } + + +async def do_cast(_type, constructor, request, **kwargs): + cast = constructor if constructor else _type + args = [request] if constructor else [] + + retval = cast(*args, **kwargs) + if isawaitable(retval): + retval = await retval + return retval diff --git a/backend/sanic_server/sanic_ext/extensions/injection/extension.py b/backend/sanic_server/sanic_ext/extensions/injection/extension.py new file mode 100644 index 000000000..e4ab6f06b --- /dev/null +++ b/backend/sanic_server/sanic_ext/extensions/injection/extension.py @@ -0,0 +1,15 @@ +from ..base import Extension +from .injector import add_injection +from .registry import InjectionRegistry + + +class InjectionExtension(Extension): + name = "injection" + + def startup(self, bootstrap) -> None: + self.registry = InjectionRegistry() + add_injection(self.app, self.registry) + bootstrap._injection_registry = self.registry + + def label(self): + return f"[{self.registry.length}]" diff --git a/backend/sanic_server/sanic_ext/extensions/injection/injector.py b/backend/sanic_server/sanic_ext/extensions/injection/injector.py new file mode 100644 index 000000000..c902acc37 --- /dev/null +++ b/backend/sanic_server/sanic_ext/extensions/injection/injector.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from inspect import getmembers, isclass, isfunction +from typing import Any, Callable, Dict, Optional, Tuple, Type, get_type_hints + +from sanic import Sanic +from sanic.constants import HTTP_METHODS + +from sanic_ext.extensions.injection.constructor import gather_args + +from .registry import InjectionRegistry, SignatureRegistry + + +def add_injection(app: Sanic, injection_registry: InjectionRegistry) -> None: + signature_registry = _setup_signature_registry(app, injection_registry) + + @app.after_server_start + async def finalize_injections(app: Sanic, _): + router_converters = set( + allowed[0] for allowed in app.router.regex_types.values() + ) + router_types = set() + for converter in router_converters: + if isclass(converter): + router_types.add(converter) + elif isfunction(converter): + hints = get_type_hints(converter) + if return_type := hints.get("return"): + router_types.add(return_type) + injection_registry.finalize(router_types) + + @app.signal("http.routing.after") + async def inject_kwargs(request, route, kwargs, **_): + nonlocal signature_registry + + for name in (route.name, f"{route.name}_{request.method.lower()}"): + injections = signature_registry.get(name) + if injections: + break + + if injections: + injected_args = await gather_args(injections, request, **kwargs) + request.match_info.update(injected_args) + + +def _http_method_predicate(member): + return isfunction(member) and member.__name__ in HTTP_METHODS + + +def _setup_signature_registry( + app: Sanic, + injection_registry: InjectionRegistry, +) -> SignatureRegistry: + registry = SignatureRegistry() + + @app.after_server_start + async def setup_signatures(app, _): + nonlocal registry + + for route in app.router.routes: + handlers = [(route.name, route.handler)] + viewclass = getattr(route.handler, "view_class", None) + if viewclass: + handlers = [ + (f"{route.name}_{name}", member) + for name, member in getmembers( + viewclass, _http_method_predicate + ) + ] + for name, handler in handlers: + try: + hints = get_type_hints(handler) + except TypeError: + continue + + injections: Dict[ + str, Tuple[Type, Optional[Callable[..., Any]]] + ] = { + param: ( + annotation, + injection_registry[annotation], + ) + for param, annotation in hints.items() + if annotation in injection_registry + } + registry.register(name, injections) + + return registry diff --git a/backend/sanic_server/sanic_ext/extensions/injection/registry.py b/backend/sanic_server/sanic_ext/extensions/injection/registry.py new file mode 100644 index 000000000..2858be18e --- /dev/null +++ b/backend/sanic_server/sanic_ext/extensions/injection/registry.py @@ -0,0 +1,59 @@ +from typing import Any, Callable, Dict, Optional, Tuple, Type + +from .constructor import Constructor + + +class InjectionRegistry: + def __init__(self): + self._registry: Dict[Type, Optional[Callable[..., Any]]] = {} + + def __getitem__(self, key): + return self._registry[key] + + def __str__(self) -> str: + return str(self._registry) + + def __contains__(self, other: Any): + return other in self._registry + + def get(self, key, default=None): + return self._registry.get(key, default) + + def register( + self, _type: Type, constructor: Optional[Callable[..., Any]] + ) -> None: + if constructor: + constructor = Constructor(constructor) + self._registry[_type] = constructor + + def finalize(self, allowed_types): + for constructor in self._registry.values(): + if isinstance(constructor, Constructor): + constructor.prepare(self, allowed_types) + + @property + def length(self): + return len(self._registry) + + +class SignatureRegistry: + def __init__(self): + self._registry: Dict[ + str, Dict[str, Tuple[Type, Optional[Callable[..., Any]]]] + ] = {} + + def __getitem__(self, key): + return self._registry[key] + + def __str__(self) -> str: + return str(self._registry) + + def get(self, key, default=None): + return self._registry.get(key, default) + + def register( + self, + route_name: str, + injections: Dict[str, Tuple[Type, Optional[Callable[..., Any]]]], + ) -> None: + self._registry[route_name] = injections diff --git a/backend/sanic_server/sanic_ext/extras/__init__.py b/backend/sanic_server/sanic_ext/extras/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic_ext/extras/serializer/__init__.py b/backend/sanic_server/sanic_ext/extras/serializer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic_ext/extras/serializer/decorator.py b/backend/sanic_server/sanic_ext/extras/serializer/decorator.py new file mode 100644 index 000000000..5457d76f9 --- /dev/null +++ b/backend/sanic_server/sanic_ext/extras/serializer/decorator.py @@ -0,0 +1,16 @@ +from functools import wraps +from inspect import isawaitable + + +def serializer(func, *, status: int = 200): + def decorator(f): + @wraps(f) + async def decorated_function(*args, **kwargs): + retval = f(*args, **kwargs) + if isawaitable(retval): + retval = await retval + return func(retval, status=status) + + return decorated_function + + return decorator diff --git a/backend/sanic_server/sanic_ext/extras/validation/__init__.py b/backend/sanic_server/sanic_ext/extras/validation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic_ext/extras/validation/check.py b/backend/sanic_server/sanic_ext/extras/validation/check.py new file mode 100644 index 000000000..a0cad2e56 --- /dev/null +++ b/backend/sanic_server/sanic_ext/extras/validation/check.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from typing import Any, Literal, NamedTuple, Optional, Tuple, Union + + +class Hint(NamedTuple): + hint: Any + model: bool + literal: bool + typed: bool + nullable: bool + origin: Optional[Any] + allowed: Tuple[Hint, ...] # type: ignore + + def validate( + self, value, schema, allow_multiple=False, allow_coerce=False + ): + if not self.typed: + if self.model: + return check_data( + self.hint, + value, + schema, + allow_multiple=allow_multiple, + allow_coerce=allow_coerce, + ) + if ( + allow_multiple + and isinstance(value, list) + and self.hint is not list + and len(value) == 1 + ): + value = value[0] + try: + _check_types(value, self.literal, self.hint) + except ValueError as e: + if allow_coerce: + if isinstance(value, list): + value = [self.hint(item) for item in value] + else: + value = self.hint(value) + _check_types(value, self.literal, self.hint) + else: + raise e + else: + _check_nullability(value, self.nullable, self.allowed, schema) + + if not self.nullable: + if self.origin in (Union, Literal): + value = _check_inclusion( + value, + self.allowed, + schema, + allow_multiple, + allow_coerce, + ) + elif self.origin is list: + value = _check_list( + value, + self.allowed, + self.hint, + schema, + allow_multiple, + allow_coerce, + ) + elif self.origin is dict: + value = _check_dict( + value, + self.allowed, + self.hint, + schema, + allow_multiple, + allow_coerce, + ) + + return value + + +def check_data(model, data, schema, allow_multiple=False, allow_coerce=False): + if not isinstance(data, dict): + raise TypeError(f"Value '{data}' is not a dict") + sig = schema[model.__name__]["sig"] + hints = schema[model.__name__]["hints"] + bound = sig.bind(**data) + bound.apply_defaults() + params = dict(zip(sig.parameters, bound.args)) + params.update(bound.kwargs) + + hydration_values = {} + try: + for key, value in params.items(): + hint = hints.get(key, Any) + hydration_values[key] = hint.validate( + value, + schema, + allow_multiple=allow_multiple, + allow_coerce=allow_coerce, + ) + except ValueError as e: + raise TypeError(e) + + return model(**hydration_values) + + +def _check_types(value, literal, expected): + if literal: + if expected is Any: + return + elif value != expected: + raise ValueError(f"Value '{value}' must be {expected}") + else: + if not isinstance(value, expected): + raise ValueError(f"Value '{value}' is not of type {expected}") + + +def _check_nullability(value, nullable, allowed, schema): + if not nullable and value is None: + raise ValueError("Value cannot be None") + if nullable and value is not None: + allowed[0].validate(value, schema) + + +def _check_inclusion(value, allowed, schema, allow_multiple, allow_coerce): + for option in allowed: + try: + return option.validate(value, schema, allow_multiple, allow_coerce) + except (ValueError, TypeError): + ... + + options = ", ".join([str(option.hint) for option in allowed]) + raise ValueError(f"Value '{value}' must be one of {options}") + + +def _check_list(value, allowed, hint, schema, allow_multiple, allow_coerce): + if isinstance(value, list): + try: + return [ + _check_inclusion( + item, allowed, schema, allow_multiple, allow_coerce + ) + for item in value + ] + except (ValueError, TypeError): + ... + raise ValueError(f"Value '{value}' must be a {hint}") + + +def _check_dict(value, allowed, hint, schema, allow_multiple, allow_coerce): + if isinstance(value, dict): + try: + return { + key: _check_inclusion( + item, allowed, schema, allow_multiple, allow_coerce + ) + for key, item in value.items() + } + except (ValueError, TypeError): + ... + raise ValueError(f"Value '{value}' must be a {hint}") diff --git a/backend/sanic_server/sanic_ext/extras/validation/decorator.py b/backend/sanic_server/sanic_ext/extras/validation/decorator.py new file mode 100644 index 000000000..c2867344e --- /dev/null +++ b/backend/sanic_server/sanic_ext/extras/validation/decorator.py @@ -0,0 +1,76 @@ +from functools import wraps +from inspect import isawaitable +from typing import Callable, Optional, Type, Union + +from sanic import Request + +from sanic_ext.exceptions import InitError + +from .setup import do_validation, generate_schema + + +def validate( + json: Optional[Union[Callable[[Request], bool], Type[object]]] = None, + form: Optional[Union[Callable[[Request], bool], Type[object]]] = None, + query: Optional[Union[Callable[[Request], bool], Type[object]]] = None, + body_argument: str = "body", + query_argument: str = "query", +): + + schemas = { + key: generate_schema(param) + for key, param in ( + ("json", json), + ("form", form), + ("query", query), + ) + } + + if json and form: + raise InitError("Cannot define both a form and json route validator") + + def decorator(f): + @wraps(f) + async def decorated_function(request: Request, *args, **kwargs): + + if schemas["json"]: + await do_validation( + model=json, + data=request.json, + schema=schemas["json"], + request=request, + kwargs=kwargs, + body_argument=body_argument, + allow_multiple=False, + allow_coerce=False, + ) + elif schemas["form"]: + await do_validation( + model=form, + data=request.form, + schema=schemas["form"], + request=request, + kwargs=kwargs, + body_argument=body_argument, + allow_multiple=True, + allow_coerce=False, + ) + elif schemas["query"]: + await do_validation( + model=query, + data=request.args, + schema=schemas["query"], + request=request, + kwargs=kwargs, + body_argument=query_argument, + allow_multiple=True, + allow_coerce=True, + ) + retval = f(request, *args, **kwargs) + if isawaitable(retval): + retval = await retval + return retval + + return decorated_function + + return decorator diff --git a/backend/sanic_server/sanic_ext/extras/validation/schema.py b/backend/sanic_server/sanic_ext/extras/validation/schema.py new file mode 100644 index 000000000..fe34d1248 --- /dev/null +++ b/backend/sanic_server/sanic_ext/extras/validation/schema.py @@ -0,0 +1,76 @@ +from dataclasses import is_dataclass +from inspect import isclass, signature +from typing import ( # type: ignore + Dict, + Literal, + Union, + _GenericAlias, + get_args, + get_origin, + get_type_hints, +) + +from .check import Hint + + +def make_schema(agg, item): + if type(item) in (bool, str, int, float): + return agg + if isinstance(item, _GenericAlias) and (args := get_args(item)): + for arg in args: + make_schema(agg, arg) + elif item.__name__ not in agg and is_dataclass(item): + sig = signature(item) + hints = parse_hints(get_type_hints(item)) + + agg[item.__name__] = { + "sig": sig, + "hints": hints, + } + + for hint in hints.values(): + make_schema(agg, hint.hint) + + return agg + + +def parse_hints(hints) -> Dict[str, Hint]: + output: Dict[str, Hint] = { + name: parse_hint(hint) for name, hint in hints.items() + } + return output + + +def parse_hint(hint): + origin = None + literal = not isclass(hint) + nullable = False + typed = False + model = False + allowed = tuple() + + if is_dataclass(hint): + model = True + elif isinstance(hint, _GenericAlias): + typed = True + literal = False + origin = get_origin(hint) + args = get_args(hint) + nullable = origin == Union and type(None) in args + + if nullable: + allowed = (args[0],) + elif origin is dict: + allowed = (args[1],) + elif origin is list or origin is Literal or origin is Union: + allowed = args + + return Hint( + hint, + model, + literal, + typed, + nullable, + origin, + tuple([parse_hint(item) for item in allowed]), + ) diff --git a/backend/sanic_server/sanic_ext/extras/validation/setup.py b/backend/sanic_server/sanic_ext/extras/validation/setup.py new file mode 100644 index 000000000..2a8b86822 --- /dev/null +++ b/backend/sanic_server/sanic_ext/extras/validation/setup.py @@ -0,0 +1,79 @@ +from functools import partial +from inspect import isawaitable, isclass + +from sanic.log import logger + +from sanic_ext.exceptions import ValidationError + +from .schema import make_schema +from .validators import ( + _validate_annotations, + _validate_instance, + validate_body, +) + +try: + from pydantic import BaseModel + + PYDANTIC = True +except ImportError: + PYDANTIC = False + + +async def do_validation( + *, + model, + data, + schema, + request, + kwargs, + body_argument, + allow_multiple, + allow_coerce, +): + try: + logger.debug(f"Validating {request.path} using {model}") + if model is not None: + if isclass(model): + validator = _get_validator( + model, schema, allow_multiple, allow_coerce + ) + validation = validate_body(validator, model, data) + kwargs[body_argument] = validation + else: + validation = model( + request=request, data=data, handler_kwargs=kwargs + ) + if isawaitable(validation): + await validation + except TypeError as e: + raise ValidationError(e) + + +def generate_schema(param): + try: + if param is None or _is_pydantic(param): + return param + except TypeError: + ... + + return make_schema({}, param) if isclass(param) else param + + +def _is_pydantic(model): + is_pydantic = PYDANTIC and ( + issubclass(model, BaseModel) or hasattr(model, "__pydantic_model__") + ) + return is_pydantic + + +def _get_validator(model, schema, allow_multiple, allow_coerce): + if _is_pydantic(model): + return _validate_instance + + return partial( + _validate_annotations, + schema=schema, + allow_multiple=allow_multiple, + allow_coerce=allow_coerce, + ) diff --git a/backend/sanic_server/sanic_ext/extras/validation/validators.py b/backend/sanic_server/sanic_ext/extras/validation/validators.py new file mode 100644 index 000000000..76769bcff --- /dev/null +++ b/backend/sanic_server/sanic_ext/extras/validation/validators.py @@ -0,0 +1,36 @@ +from typing import Any, Callable, Dict, Tuple, Type + +from sanic_ext.exceptions import ValidationError + +from .check import check_data + +try: + from pydantic import ValidationError as PydanticValidationError + + VALIDATION_ERROR: Tuple[Type[Exception], ...] = ( + TypeError, + PydanticValidationError, + ) +except ImportError: + VALIDATION_ERROR = (TypeError,) + + +def validate_body( + validator: Callable[[Type[Any], Dict[str, Any]], Any], + model: Type[Any], + body: Dict[str, Any], +) -> Any: + try: + return validator(model, body) + except VALIDATION_ERROR as e: + raise ValidationError( + f"Invalid request body: {model.__name__}. Error: {e}" + ) + + +def _validate_instance(model, body): + return model(**body) + + +def _validate_annotations(model, body, schema, allow_multiple, allow_coerce): + return check_data(model, body, schema, allow_multiple, allow_coerce) diff --git a/backend/sanic_server/sanic_ext/utils/__init__.py b/backend/sanic_server/sanic_ext/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic_ext/utils/route.py b/backend/sanic_server/sanic_ext/utils/route.py new file mode 100644 index 000000000..c4ef47b1f --- /dev/null +++ b/backend/sanic_server/sanic_ext/utils/route.py @@ -0,0 +1,118 @@ +import re + + +def clean_route_name(name: str) -> str: + parts = name.split(".", 1) + name = parts[-1] + for target in ("_", ".", " "): + name = name.replace(target, " ") + + return name.title() + + +def get_uri_filter(app): + """ + Return a filter function that takes a URI and returns whether it should + be filter out from the swagger documentation or not. + + Arguments: + app: The application to take `config.API_URI_FILTER` from. Possible + values for this config option are: `slash` (to keep URIs that + end with a `/`), `all` (to keep all URIs). All other values + default to keep all URIs that don't end with a `/`. + + Returns: + `True` if the URI should be *filtered out* from the swagger + documentation, and `False` if it should be kept in the documentation. + """ + choice = getattr(app.config, "API_URI_FILTER", None) + + if choice == "slash": + # Keep URIs that end with a /. + return lambda uri: not uri.endswith("/") + + if choice == "all": + # Keep all URIs. + return lambda uri: False + + # Keep URIs that don't end with a /, (special case: "/"). + return lambda uri: len(uri) > 1 and uri.endswith("/") + + +def remove_nulls(dictionary, deep=True): + """ + Removes all null values from a dictionary. + """ + return { + k: remove_nulls(v, deep) if deep and type(v) is dict else v + for k, v in dictionary.items() + if v is not None + } + + +def remove_nulls_from_kwargs(**kwargs): + return remove_nulls(kwargs, deep=False) + + +def get_blueprinted_routes(app): + for blueprint in app.blueprints.values(): + if not hasattr(blueprint, "routes"): + continue + + for route in blueprint.routes: + if hasattr(route.handler, "view_class"): + # before sanic 21.3, route.handler could be a number of + # different things, so have to type check + for http_method in route.methods: + _handler = getattr( + route.handler.view_class, http_method.lower(), None + ) + if _handler: + yield (blueprint.name, _handler) + else: + yield (blueprint.name, route.handler) + + +def get_all_routes(app, skip_prefix): + uri_filter = get_uri_filter(app) + + for group in app.router.groups.values(): + uri = f"/{group.path}" + + # prior to sanic 21.3 routes came in both forms + # (e.g. /test and /test/ ) + # after sanic 21.3 routes come in one form, + # with an attribute "strict", + # so we simulate that ourselves: + + uris = [uri] + if not group.strict and len(uri) > 1: + alt = uri[:-1] if uri.endswith("/") else f"{uri}/" + uris.append(alt) + + for uri in uris: + if uri_filter(uri): + continue + + if skip_prefix and group.raw_path.startswith( + skip_prefix.lstrip("/") + ): + continue + + for parameter in group.params.values(): + uri = re.sub( + f"<{parameter.name}.*?>", + f"{{{parameter.name}}}", + uri, + ) + + for route in group: + if route.name and "static" in route.name: + continue + + method_handlers = [ + (method, route.handler) for method in route.methods + ] + + _, name = route.name.split(".", 1) + yield (uri, name, route.params.values(), method_handlers) diff --git a/backend/sanic_server/sanic_ext/utils/string.py b/backend/sanic_server/sanic_ext/utils/string.py new file mode 100644 index 000000000..549f1d915 --- /dev/null +++ b/backend/sanic_server/sanic_ext/utils/string.py @@ -0,0 +1,12 @@ +import re + +CAMEL_TO_SNAKE_PATTERNS = ( + re.compile(r"(.)([A-Z][a-z]+)"), + re.compile(r"([a-z0-9])([A-Z])"), +) + + +def camel_to_snake(name: str) -> str: + for pattern in CAMEL_TO_SNAKE_PATTERNS: + name = pattern.sub(r"\1_\2", name) + return name.lower() diff --git a/backend/sanic_server/sanic_plugin_toolkit/LICENSE.txt b/backend/sanic_server/sanic_plugin_toolkit/LICENSE.txt new file mode 100644 index 000000000..ee7e0c468 --- /dev/null +++ b/backend/sanic_server/sanic_plugin_toolkit/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017-2021 Ashley Sommer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/backend/sanic_server/sanic_plugin_toolkit/__init__.py b/backend/sanic_server/sanic_plugin_toolkit/__init__.py new file mode 100644 index 000000000..59d68a6a5 --- /dev/null +++ b/backend/sanic_server/sanic_plugin_toolkit/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: latin-1 -*- +# this is ascii, no unicode in this document +from .plugin import SanicPlugin +from .realm import SanicPluginRealm + + +__version__ = '1.2.0' +__all__ = ["SanicPlugin", "SanicPluginRealm", "__version__"] diff --git a/backend/sanic_server/sanic_plugin_toolkit/config.py b/backend/sanic_server/sanic_plugin_toolkit/config.py new file mode 100644 index 000000000..a6383f0c8 --- /dev/null +++ b/backend/sanic_server/sanic_plugin_toolkit/config.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +""" +Allows SPTK to parse a config file and automatically load defined plugins +""" + +import configparser +import importlib +import os + +import pkg_resources + + +def _find_config_file(filename): + abs = os.path.abspath(filename) + if os.path.isfile(abs): + return abs + raise FileNotFoundError(filename) + + +def _get_config_defaults(): + return {} + + +def _find_advertised_plugins(realm): + plugins = {} + for entrypoint in pkg_resources.iter_entry_points('sanic_plugins'): + if entrypoint.attrs: + attr = entrypoint.attrs[0] + else: + attr = None + name = entrypoint.name + try: + module = importlib.import_module(entrypoint.module_name) + except ImportError: + realm.error("Cannot import {}".format(entrypoint.module_name)) + continue + p_dict = {'name': name, 'module': module} + if attr: + try: + inst = getattr(module, attr) + except AttributeError: + realm.error("Cannot import {} from {}".format(attr, entrypoint.module_name)) + continue + p_dict['instance'] = inst + plugins[name] = p_dict + plugins[str(name).casefold()] = p_dict + return plugins + + +def _transform_option_dict(options): + parts = str(options).split(',') + args = [] + kwargs = {} + for part in parts: + if "=" in part: + kwparts = part.split('=', 1) + kwkey = kwparts[0] + val = kwparts[1] + else: + val = part + kwkey = None + + if val == "True": + val = True + elif val == "False": + val = False + elif val == "None": + val = None + elif '.' in val: + try: + f = float(val) + val = f + except ValueError: + pass + else: + try: + i = int(val) + val = i + except ValueError: + pass + if kwkey: + kwargs[kwkey] = val + else: + args.append(val) + args = tuple(args) + return args, kwargs + + +def _register_advertised_plugin(realm, app, plugin_def, *args, **kwargs): + name = plugin_def['name'] + realm.info("Found advertised plugin {}.".format(name)) + inst = plugin_def.get('instance', None) + if inst: + p = inst + else: + p = plugin_def['module'] + return realm.register_plugin(p, *args, **kwargs) + + +def _try_register_other_plugin(realm, app, plugin_name, *args, **kwargs): + try: + module = importlib.import_module(plugin_name) + except ImportError: + raise RuntimeError("Do not know how to register plugin: {}".format(plugin_name)) + return realm.register_plugin(module, *args, **kwargs) + + +def _register_plugins(realm, app, config_plugins): + advertised_plugins = _find_advertised_plugins(realm) + registered_plugins = {} + for plugin, options in config_plugins: + realm.info("Loading plugin: {}...".format(plugin)) + if options: + args, kwargs = _transform_option_dict(options) + else: + args = tuple() + kwargs = {} + p_fold = str(plugin).casefold() + if p_fold in advertised_plugins: + assoc = _register_advertised_plugin(realm, app, advertised_plugins[p_fold], *args, **kwargs) + else: + assoc = _try_register_other_plugin(realm, app, plugin, *args, **kwargs) + _p, reg = assoc + registered_plugins[reg.plugin_name] = assoc + return registered_plugins + + +def load_config_file(realm, app, filename): + """ + + :param realm: + :type realm: sanic_plugin_toolkit.SanicPluginRealm + :param app: + :type app: sanic.Sanic + :param filename: + :type filename: str + :return: + """ + location = _find_config_file(filename) + realm.info("Loading sanic_plugin_toolkit config file {}.".format(location)) + + defaults = _get_config_defaults() + parser = configparser.ConfigParser(defaults=defaults, allow_no_value=True, strict=False) + parser.read(location) + try: + config_plugins = parser.items('plugins') + except Exception as e: + raise e + # noinspection PyUnusedLocal + _ = _register_plugins(realm, app, config_plugins) # noqa: F841 + return diff --git a/backend/sanic_server/sanic_plugin_toolkit/context.py b/backend/sanic_server/sanic_plugin_toolkit/context.py new file mode 100644 index 000000000..21ab6991b --- /dev/null +++ b/backend/sanic_server/sanic_plugin_toolkit/context.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- +""" +This is the specialised dictionary that is used by Sanic Plugin Toolkit +to manage Context objects. It can be hierarchical, and it searches its +parents if it cannot find an item in its own dictionary. It can create its +own children. +""" + + +class HierDict(object): + """ + This is the specialised dictionary that is used by the Sanic Plugin Toolkit + to manage Context objects. It can be hierarchical, and it searches its + parents if it cannot find an item in its own dictionary. It can create its + own children. + """ + + __slots__ = ('_parent_hd', '_dict', '__weakref__') + + @classmethod + def _iter_slots(cls): + use_cls = cls + bases = cls.__bases__ + base_count = 0 + while True: + if use_cls.__slots__: + for _s in use_cls.__slots__: + yield _s + if base_count >= len(bases): + break + use_cls = bases[base_count] + base_count += 1 + return + + def _inner(self): + """ + :return: the internal dictionary + :rtype: dict + """ + return object.__getattribute__(self, '_dict') + + def __repr__(self): + _dict_repr = repr(self._inner()) + return "HierDict({:s})".format(_dict_repr) + + def __str__(self): + _dict_str = str(self._inner()) + return "HierDict({:s})".format(_dict_str) + + def __len__(self): + return len(self._inner()) + + def __setitem__(self, key, value): + # TODO: If key is in __slots__, ignore it and return + return self._inner().__setitem__(key, value) + + def __getitem__(self, item): + try: + return self._inner().__getitem__(item) + except KeyError as e1: + parents_searched = [self] + parent = self._parent_hd + while parent: + try: + return parent._inner().__getitem__(item) + except KeyError: + parents_searched.append(parent) + # noinspection PyProtectedMember + next_parent = parent._parent_hd + if next_parent in parents_searched: + raise RuntimeError("Recursive HierDict found!") + parent = next_parent + raise e1 + + def __delitem__(self, key): + self._inner().__delitem__(key) + + def __getattr__(self, item): + if item in self._iter_slots(): + return object.__getattribute__(self, item) + try: + return self.__getitem__(item) + except KeyError as e: + raise AttributeError(*e.args) + + def __setattr__(self, key, value): + if key in self._iter_slots(): + if key == '__weakref__': + if value is None: + return + else: + raise ValueError("Cannot set weakrefs on Context") + return object.__setattr__(self, key, value) + try: + return self.__setitem__(key, value) + except Exception as e: # pragma: no cover + # what exceptions can occur on setting an item? + raise e + + def __contains__(self, item): + return self._inner().__contains__(item) + + def get(self, key, default=None): + try: + return self.__getattr__(key) + except (AttributeError, KeyError): + return default + + def set(self, key, value): + try: + return self.__setattr__(key, value) + except Exception as e: # pragma: no cover + raise e + + def items(self): + """ + A set-like read-only view HierDict's (K,V) tuples + :return: + :rtype: frozenset + """ + return self._inner().items() + + def keys(self): + """ + An object containing a view on the HierDict's keys + :return: + :rtype: tuple # using tuple to represent an immutable list + """ + return self._inner().keys() + + def values(self): + """ + An object containing a view on the HierDict's values + :return: + :rtype: tuple # using tuple to represent an immutable list + """ + return self._inner().values() + + def replace(self, key, value): + """ + If this HierDict doesn't already have this key, it sets + the value on a parent HierDict if that parent has the key, + otherwise sets the value on this HierDict. + :param key: + :param value: + :return: Nothing + :rtype: None + """ + if key in self._inner().keys(): + return self.__setitem__(key, value) + parents_searched = [self] + parent = self._parent_hd + while parent: + try: + if key in parent.keys(): + return parent.__setitem__(key, value) + except (KeyError, AttributeError): + pass + parents_searched.append(parent) + # noinspection PyProtectedMember + next_parent = parent._parent_context + if next_parent in parents_searched: + raise RuntimeError("Recursive HierDict found!") + parent = next_parent + return self.__setitem__(key, value) + + # noinspection PyPep8Naming + def update(self, E=None, **F): + """ + Update HierDict from dict/iterable E and F + :return: Nothing + :rtype: None + """ + if E is not None: + if hasattr(E, 'keys'): + for K in E: + self.replace(K, E[K]) + elif hasattr(E, 'items'): + for K, V in E.items(): + self.replace(K, V) + else: + for K, V in E: + self.replace(K, V) + for K in F: + self.replace(K, F[K]) + + def __new__(cls, parent, *args, **kwargs): + self = super(HierDict, cls).__new__(cls) + self._dict = dict(*args, **kwargs) + if parent is not None: + assert isinstance(parent, HierDict), "Parent context must be a valid initialised HierDict" + self._parent_hd = parent + else: + self._parent_hd = None + return self + + def __init__(self, *args, **kwargs): + args = list(args) + args.pop(0) # remove parent + super(HierDict, self).__init__() + + def __getstate__(self): + state_dict = {} + for s in HierDict.__slots__: + if s == "__weakref__": + continue + state_dict[s] = object.__getattribute__(self, s) + return state_dict + + def __setstate__(self, state): + for s, v in state.items(): + setattr(self, s, v) + + def __reduce__(self): + state_dict = self.__getstate__() + _ = state_dict.pop('_stk_realm', None) + parent_context = state_dict.pop('_parent_hd') + return (HierDict.__new__, (self.__class__, parent_context), state_dict) + + +class SanicContext(HierDict): + __slots__ = ('_stk_realm',) + + def __repr__(self): + _dict_repr = repr(self._inner()) + return "SanicContext({:s})".format(_dict_repr) + + def __str__(self): + _dict_str = str(self._inner()) + return "SanicContext({:s})".format(_dict_str) + + def create_child_context(self, *args, **kwargs): + return SanicContext(self._stk_realm, self, *args, **kwargs) + + def __new__(cls, stk_realm, parent, *args, **kwargs): + if parent is not None: + assert isinstance(parent, SanicContext), "Parent context must be a valid initialised SanicContext" + self = super(SanicContext, cls).__new__(cls, parent, *args, **kwargs) + self._stk_realm = stk_realm + return self + + def __init__(self, *args, **kwargs): + args = list(args) + # remove realm + _stk_realm = args.pop(0) # noqa: F841 + super(SanicContext, self).__init__(*args) + + def __getstate__(self): + state_dict = super(SanicContext, self).__getstate__() + for s in SanicContext.__slots__: + state_dict[s] = object.__getattribute__(self, s) + return state_dict + + def __reduce__(self): + state_dict = self.__getstate__() + realm = state_dict.pop('_stk_realm') + parent_context = state_dict.pop('_parent_hd') + return (SanicContext.__new__, (self.__class__, realm, parent_context), state_dict) + + def for_request(self, req): + # shortcut for context.request[id(req)] + requests_ctx = self.request + return requests_ctx[id(req)] if req else None diff --git a/backend/sanic_server/sanic_plugin_toolkit/plugin.py b/backend/sanic_server/sanic_plugin_toolkit/plugin.py new file mode 100644 index 000000000..ba0b12ffd --- /dev/null +++ b/backend/sanic_server/sanic_plugin_toolkit/plugin.py @@ -0,0 +1,560 @@ +# -*- coding: utf-8 -*- +import importlib +from collections import defaultdict, deque, namedtuple +from distutils.version import LooseVersion +from functools import update_wrapper +from inspect import isawaitable +from typing import Type + +from ..sanic import Blueprint, Sanic +from ..sanic import __version__ as sanic_version + +SANIC_VERSION = LooseVersion(sanic_version) +SANIC_21_6_0 = LooseVersion("21.6.0") +SANIC_21_9_0 = LooseVersion("21.9.0") + +CRITICAL = 50 +ERROR = 40 +WARNING = 30 +INFO = 20 +DEBUG = 10 + +FutureMiddleware = namedtuple("FutureMiddleware", ["middleware", "args", "kwargs"]) +FutureRoute = namedtuple("FutureRoute", ["handler", "uri", "args", "kwargs"]) +FutureWebsocket = namedtuple("FutureWebsocket", ["handler", "uri", "args", "kwargs"]) +FutureStatic = namedtuple("FutureStatic", ["uri", "file_or_dir", "args", "kwargs"]) +FutureException = namedtuple("FutureException", ["handler", "exceptions", "kwargs"]) +PluginRegistration = namedtuple( + "PluginRegistration", ["realm", "plugin_name", "url_prefix"] +) +PluginAssociated = namedtuple("PluginAssociated", ["plugin", "reg"]) + + +class SanicPlugin(object): + __slots__ = ( + "registrations", + "_routes", + "_ws", + "_static", + "_middlewares", + "_exceptions", + "_listeners", + "_initialized", + "__weakref__", + ) + + AssociatedTuple: Type[object] = PluginAssociated + + # Decorator + def middleware(self, *args, **kwargs): + """Decorate and register middleware + :param args: captures all of the positional arguments passed in + :type args: tuple(Any) + :param kwargs: captures the keyword arguments passed in + :type kwargs: dict(Any) + :return: The middleware function to use as the decorator + :rtype: fn + """ + kwargs.setdefault("priority", 5) + kwargs.setdefault("relative", None) + kwargs.setdefault("attach_to", None) + kwargs.setdefault("with_context", False) + if len(args) == 1 and callable(args[0]): + middle_f = args[0] + self._middlewares.append( + FutureMiddleware(middle_f, args=tuple(), kwargs=kwargs) + ) + return middle_f + + def wrapper(middleware_f): + self._middlewares.append( + FutureMiddleware(middleware_f, args=args, kwargs=kwargs) + ) + return middleware_f + + return wrapper + + def exception(self, *args, **kwargs): + """Decorate and register an exception handler + :param args: captures all of the positional arguments passed in + :type args: tuple(Any) + :param kwargs: captures the keyword arguments passed in + :type kwargs: dict(Any) + :return: The exception function to use as the decorator + :rtype: fn + """ + if len(args) == 1 and callable(args[0]): + if isinstance(args[0], type) and issubclass(args[0], Exception): + pass + else: # pragma: no cover + raise RuntimeError( + "Cannot use the @exception decorator without arguments" + ) + + def wrapper(handler_f): + self._exceptions.append( + FutureException(handler_f, exceptions=args, kwargs=kwargs) + ) + return handler_f + + return wrapper + + def listener(self, event, *args, **kwargs): + """Create a listener from a decorated function. + :param event: Event to listen to. + :type event: str + :param args: captures all of the positional arguments passed in + :type args: tuple(Any) + :param kwargs: captures the keyword arguments passed in + :type kwargs: dict(Any) + :return: The function to use as the listener + :rtype: fn + """ + if len(args) == 1 and callable(args[0]): # pragma: no cover + raise RuntimeError("Cannot use the @listener decorator without arguments") + + def wrapper(listener_f): + if len(kwargs) > 0: + listener_f = (listener_f, kwargs) + self._listeners[event].append(listener_f) + return listener_f + + return wrapper + + def route(self, uri, *args, **kwargs): + """Create a plugin route from a decorated function. + :param uri: endpoint at which the route will be accessible. + :type uri: str + :param args: captures all of the positional arguments passed in + :type args: tuple(Any) + :param kwargs: captures the keyword arguments passed in + :type kwargs: dict(Any) + :return: The function to use as the decorator + :rtype: fn + """ + if len(args) == 0 and callable(uri): # pragma: no cover + raise RuntimeError("Cannot use the @route decorator without arguments.") + kwargs.setdefault("methods", frozenset({"GET"})) + kwargs.setdefault("host", None) + kwargs.setdefault("strict_slashes", False) + kwargs.setdefault("stream", False) + kwargs.setdefault("name", None) + kwargs.setdefault("version", None) + kwargs.setdefault("ignore_body", False) + kwargs.setdefault("websocket", False) + kwargs.setdefault("subprotocols", None) + kwargs.setdefault("unquote", False) + kwargs.setdefault("static", False) + if SANIC_21_6_0 <= SANIC_VERSION: + kwargs.setdefault("version_prefix", "/v") + if SANIC_21_9_0 <= SANIC_VERSION: + kwargs.setdefault("error_format", None) + + def wrapper(handler_f): + self._routes.append(FutureRoute(handler_f, uri, args, kwargs)) + return handler_f + + return wrapper + + def websocket(self, uri, *args, **kwargs): + """Create a websocket route from a decorated function + deprecated. now use @route("/path",websocket=True) + """ + kwargs["websocket"] = True + return self.route(uri, *args, **kwargs) + + def static(self, uri, file_or_directory, *args, **kwargs): + """Create a websocket route from a decorated function + :param uri: endpoint at which the socket endpoint will be accessible. + :type uri: str + :param args: captures all of the positional arguments passed in + :type args: tuple(Any) + :param kwargs: captures the keyword arguments passed in + :type kwargs: dict(Any) + :return: The function to use as the decorator + :rtype: fn + """ + + kwargs.setdefault("pattern", r"/?.+") + kwargs.setdefault("use_modified_since", True) + kwargs.setdefault("use_content_range", False) + kwargs.setdefault("stream_large_files", False) + kwargs.setdefault("name", "static") + kwargs.setdefault("host", None) + kwargs.setdefault("strict_slashes", None) + kwargs.setdefault("content_type", None) + + self._static.append(FutureStatic(uri, file_or_directory, args, kwargs)) + + def on_before_registered(self, context, *args, **kwargs): + pass + + def on_registered(self, context, reg, *args, **kwargs): + pass + + def find_plugin_registration(self, realm): + if isinstance(realm, PluginRegistration): + return realm + for reg in self.registrations: + (r, n, u) = reg + if r is not None and r == realm: + return reg + raise KeyError("Plugin registration not found") + + def first_plugin_context(self): + """Returns the context is associated with the first app this plugin was + registered on""" + # Note, because registrations are stored in a set, its not _really_ + # the first one, but whichever one it sees first in the set. + first_realm_reg = next(iter(self.registrations)) + return self.get_context_from_realm(first_realm_reg) + + def get_context_from_realm(self, realm): + rt = RuntimeError("Cannot use the plugin's Context before it is registered.") + if isinstance(realm, PluginRegistration): + reg = realm + else: + try: + reg = self.find_plugin_registration(realm) + except LookupError: + raise rt + (r, n, u) = reg + try: + return r.get_context(n) + except KeyError as k: + raise k + except AttributeError: + raise rt + + def get_app_from_realm_context(self, realm): + rt = RuntimeError( + "Cannot get the app from Realm before this plugin is registered on the Realm." + ) + if isinstance(realm, PluginRegistration): + reg = realm + else: + try: + reg = self.find_plugin_registration(realm) + except LookupError: + raise rt + context = self.get_context_from_realm(reg) + try: + app = context.app + except (LookupError, AttributeError): + raise rt + return app + + def resolve_url_for(self, realm, view_name, *args, **kwargs): + try: + reg = self.find_plugin_registration(realm) + except LookupError: + raise RuntimeError( + "Cannot resolve URL because this plugin is not registered on the PluginRealm." + ) + (realm, name, url_prefix) = reg + app = self.get_app_from_realm_context(reg) + if app is None: + return None + if isinstance(app, Blueprint): + self.warning( + "Cannot use url_for when plugin is registered on a Blueprint. Use `app.url_for` instead." + ) + return None + constructed_name = "{}.{}".format(name, view_name) + return app.url_for(constructed_name, *args, **kwargs) + + def log(self, realm, level, message, *args, **kwargs): + try: + reg = self.find_plugin_registration(realm) + except LookupError: + raise RuntimeError( + "Cannot log using this plugin, because this plugin is not registered on the Realm." + ) + context = self.get_context_from_realm(reg) + return context.log(level, message, *args, reg=self, **kwargs) + + def debug(self, message, *args, **kwargs): + return self.log(DEBUG, message, *args, **kwargs) + + def info(self, message, *args, **kwargs): + return self.log(INFO, message, *args, **kwargs) + + def warning(self, message, *args, **kwargs): + return self.log(WARNING, message, *args, **kwargs) + + def error(self, message, *args, **kwargs): + return self.log(ERROR, message, *args, **kwargs) + + def critical(self, message, *args, **kwargs): + return self.log(CRITICAL, message, *args, **kwargs) + + @classmethod + def decorate(cls, app, *args, run_middleware=False, with_context=False, **kwargs): + """ + This is a decorator that can be used to apply this plugin to a specific + route/view on your app, rather than the whole app. + :param app: + :type app: Sanic | Blueprint + :param args: + :type args: tuple(Any) + :param run_middleware: + :type run_middleware: bool + :param with_context: + :type with_context: bool + :param kwargs: + :param kwargs: dict(Any) + :return: the decorated route/view + :rtype: fn + """ + from ..sanic_plugin_toolkit.realm import SanicPluginRealm + + realm = SanicPluginRealm(app) # get the singleton from the app + try: + assoc = realm.register_plugin(cls, skip_reg=True) + except ValueError as e: + # this is normal, if this plugin has been registered previously + assert e.args and len(e.args) > 1 + assoc = e.args[1] + (plugin, reg) = assoc + # plugin may not actually be registered + inst = realm.get_plugin_inst(plugin) + # registered might be True, False or None at this point + regd = True if inst else None + if regd is True: + # middleware will be run on this route anyway, because the plugin + # is registered on the app. Turn it off on the route-level. + run_middleware = False + req_middleware = deque() + resp_middleware = deque() + if run_middleware: + for i, m in enumerate(plugin._middlewares): + attach_to = m.kwargs.pop("attach_to", "request") + priority = m.kwargs.pop("priority", 5) + with_context = m.kwargs.pop("with_context", False) + mw_handle_fn = m.middleware + if attach_to == "response": + relative = m.kwargs.pop("relative", "post") + if relative == "pre": + mw = ( + 0, + 0 - priority, + 0 - i, + mw_handle_fn, + with_context, + m.args, + m.kwargs, + ) + else: # relative = "post" + mw = ( + 1, + 0 - priority, + 0 - i, + mw_handle_fn, + with_context, + m.args, + m.kwargs, + ) + resp_middleware.append(mw) + else: # attach_to = "request" + relative = m.kwargs.pop("relative", "pre") + if relative == "post": + mw = ( + 1, + priority, + i, + mw_handle_fn, + with_context, + m.args, + m.kwargs, + ) + else: # relative = "pre" + mw = ( + 0, + priority, + i, + mw_handle_fn, + with_context, + m.args, + m.kwargs, + ) + req_middleware.append(mw) + + req_middleware = tuple(sorted(req_middleware)) + resp_middleware = tuple(sorted(resp_middleware)) + + def _decorator(f): + nonlocal realm, plugin, regd, run_middleware, with_context + nonlocal req_middleware, resp_middleware, args, kwargs + + async def wrapper(request, *a, **kw): + nonlocal realm, plugin, regd, run_middleware, with_context + nonlocal req_middleware, resp_middleware, f, args, kwargs + # the plugin was not registered on the app, it might be now + if regd is None: + _inst = realm.get_plugin_inst(plugin) + regd = _inst is not None + + context = plugin.get_context_from_realm(realm) + if run_middleware and not regd and len(req_middleware) > 0: + for ( + _a, + _p, + _i, + handler, + with_context, + args, + kwargs, + ) in req_middleware: + if with_context: + resp = handler(request, *args, context=context, **kwargs) + else: + resp = handler(request, *args, **kwargs) + if isawaitable(resp): + resp = await resp + if resp: + return + + response = await plugin.route_wrapper( + f, + request, + context, + a, + kw, + *args, + with_context=with_context, + **kwargs + ) + if isawaitable(response): + response = await response + if run_middleware and not regd and len(resp_middleware) > 0: + for ( + _a, + _p, + _i, + handler, + with_context, + args, + kwargs, + ) in resp_middleware: + if with_context: + _resp = handler( + request, response, *args, context=context, **kwargs + ) + else: + _resp = handler(request, response, *args, **kwargs) + if isawaitable(_resp): + _resp = await _resp + if _resp: + response = _resp + break + return response + + return update_wrapper(wrapper, f) + + return _decorator + + async def route_wrapper( + self, + route, + request, + context, + request_args, + request_kw, + *decorator_args, + with_context=None, + **decorator_kw + ): + """This is the function that is called when a route is decorated with + your plugin decorator. Context will normally be None, but the user + can pass use_context=True so the route will get the plugin + context + """ + # by default, do nothing, just run the wrapped function + if with_context: + resp = route(request, context, *request_args, **request_kw) + else: + resp = route(request, *request_args, **request_kw) + if isawaitable(resp): + resp = await resp + return resp + + def __new__(cls, *args, **kwargs): + # making a bold assumption here. + # Assuming that if a sanic plugin is initialized using + # `MyPlugin(app)`, then the user is attempting to do a legacy plugin + # instantiation, aka Flask-Style plugin instantiation. + if ( + args + and len(args) > 0 + and (isinstance(args[0], Sanic) or isinstance(args[0], Blueprint)) + ): + app = args[0] + try: + mod_name = cls.__module__ + mod = importlib.import_module(mod_name) + assert mod + except (ImportError, AssertionError): + raise RuntimeError( + "Failed attempting a legacy plugin instantiation. " + "Cannot find the module this plugin belongs to." + ) + # Get the sanic_plugin_toolkit singleton from this app + from ..sanic_plugin_toolkit.realm import SanicPluginRealm + + realm = SanicPluginRealm(app) + # catch cases like when the module is "__main__" or + # "__call__" or "__init__" + if mod_name.startswith("__"): + # In this case, we cannot use the module to register the + # plugin. Try to use the class method. + assoc = realm.register_plugin(cls, *args, **kwargs) + else: + assoc = realm.register_plugin(mod, *args, **kwargs) + return assoc + self = super(SanicPlugin, cls).__new__(cls) + try: + self._initialized # initialized may be True or Unknown + except AttributeError: + self._initialized = False + return self + + def is_registered_in_realm(self, check_realm): + for reg in self.registrations: + (realm, name, url) = reg + if realm is not None and realm == check_realm: + return True + return False + + def __init__(self, *args, **kwargs): + # Sometimes __init__ can be called twice. + # Ignore it on subsequent times + if self._initialized: + return + super(SanicPlugin, self).__init__(*args, **kwargs) + self._routes = [] + self._ws = [] + self._static = [] + self._middlewares = [] + self._exceptions = [] + self._listeners = defaultdict(list) + self.registrations = set() + self._initialized = True + + def __getstate__(self): + state_dict = {} + for s in SanicPlugin.__slots__: + state_dict[s] = getattr(self, s) + return state_dict + + def __setstate__(self, state): + for s, v in state.items(): + if s == "__weakref__": + if v is None: + continue + else: + raise NotImplementedError("Setting weakrefs on Plugin") + setattr(self, s, v) + + def __reduce__(self): + state_dict = self.__getstate__() + return SanicPlugin.__new__, (self.__class__,), state_dict diff --git a/backend/sanic_server/sanic_plugin_toolkit/plugins/__init__.py b/backend/sanic_server/sanic_plugin_toolkit/plugins/__init__.py new file mode 100644 index 000000000..25079c262 --- /dev/null +++ b/backend/sanic_server/sanic_plugin_toolkit/plugins/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +# +from .contextualize import Contextualize, contextualize + + +__all__ = ('Contextualize', 'contextualize') diff --git a/backend/sanic_server/sanic_plugin_toolkit/plugins/contextualize.py b/backend/sanic_server/sanic_plugin_toolkit/plugins/contextualize.py new file mode 100644 index 000000000..557f6ad23 --- /dev/null +++ b/backend/sanic_server/sanic_plugin_toolkit/plugins/contextualize.py @@ -0,0 +1,259 @@ +# -*- coding: utf-8 -*- +from collections import namedtuple + +from ...sanic_plugin_toolkit import SanicPlugin +from ...sanic_plugin_toolkit.plugin import ( + SANIC_21_6_0, + SANIC_21_9_0, + SANIC_VERSION, + FutureMiddleware, + FutureRoute, +) + +ContextualizeAssociatedTuple = namedtuple( + "ContextualizeAssociatedTuple", ["plugin", "reg"] +) + + +class ContextualizeAssociated(ContextualizeAssociatedTuple): + __slots__ = () + + # Decorator + def middleware(self, *args, **kwargs): + """Decorate and register middleware + :param args: captures all of the positional arguments passed in + :type args: tuple(Any) + :param kwargs: captures the keyword arguments passed in + :type kwargs: dict(Any) + :return: The middleware function to use as the decorator + :rtype: fn + """ + kwargs.setdefault("priority", 5) + kwargs.setdefault("relative", None) + kwargs.setdefault("attach_to", None) + kwargs["with_context"] = True # This is the whole point of this plugin + plugin = self.plugin + reg = self.reg + + if len(args) == 1 and callable(args[0]): + middle_f = args[0] + return plugin._add_new_middleware(reg, middle_f, **kwargs) + + def wrapper(middle_f): + nonlocal plugin, reg + nonlocal args, kwargs + return plugin._add_new_middleware(reg, middle_f, *args, **kwargs) + + return wrapper + + def route(self, uri, *args, **kwargs): + """Create a plugin route from a decorated function. + :param uri: endpoint at which the route will be accessible. + :type uri: str + :param args: captures all of the positional arguments passed in + :type args: tuple(Any) + :param kwargs: captures the keyword arguments passed in + :type kwargs: dict(Any) + :return: The exception function to use as the decorator + :rtype: fn + """ + if len(args) == 0 and callable(uri): + raise RuntimeError("Cannot use the @route decorator without " "arguments.") + kwargs.setdefault("methods", frozenset({"GET"})) + kwargs.setdefault("host", None) + kwargs.setdefault("strict_slashes", False) + kwargs.setdefault("stream", False) + kwargs.setdefault("name", None) + kwargs.setdefault("version", None) + kwargs.setdefault("ignore_body", False) + kwargs.setdefault("websocket", False) + kwargs.setdefault("subprotocols", None) + kwargs.setdefault("unquote", False) + kwargs.setdefault("static", False) + if SANIC_21_6_0 <= SANIC_VERSION: + kwargs.setdefault("version_prefix", "/v") + if SANIC_21_9_0 <= SANIC_VERSION: + kwargs.setdefault("error_format", None) + kwargs["with_context"] = True # This is the whole point of this plugin + plugin = self.plugin + reg = self.reg + + def wrapper(handler_f): + nonlocal plugin, reg + nonlocal uri, args, kwargs + return plugin._add_new_route(reg, uri, handler_f, *args, **kwargs) + + return wrapper + + def listener(self, event, *args, **kwargs): + """Create a listener from a decorated function. + :param event: Event to listen to. + :type event: str + :param args: captures all of the positional arguments passed in + :type args: tuple(Any) + :param kwargs: captures the keyword arguments passed in + :type kwargs: dict(Any) + :return: The function to use as the listener + :rtype: fn + """ + if len(args) == 1 and callable(args[0]): + raise RuntimeError( + "Cannot use the @listener decorator without " "arguments" + ) + kwargs["with_context"] = True # This is the whole point of this plugin + plugin = self.plugin + reg = self.reg + + def wrapper(listener_f): + nonlocal plugin, reg + nonlocal event, args, kwargs + return plugin._add_new_listener(reg, event, listener_f, *args, **kwargs) + + return wrapper + + def websocket(self, uri, *args, **kwargs): + """Create a websocket route from a decorated function + # Deprecated. Use @contextualize.route("/path", websocket=True) + """ + + kwargs["websocket"] = True + kwargs["with_context"] = True # This is the whole point of this plugin + + return self.route(uri, *args, **kwargs) + + +class Contextualize(SanicPlugin): + __slots__ = () + + AssociatedTuple = ContextualizeAssociated + + def _add_new_middleware(self, reg, middle_f, *args, **kwargs): + # A user should never call this directly. + # it should be called only by the AssociatedTuple + assert reg in self.registrations + (realm, p_name, url_prefix) = reg + context = self.get_context_from_realm(reg) + # This is how we add a new middleware _after_ the plugin is registered + m = FutureMiddleware(middle_f, args, kwargs) + realm._register_middleware_helper(m, realm, self, context) + return middle_f + + def _add_new_route(self, reg, uri, handler_f, *args, **kwargs): + # A user should never call this directly. + # it should be called only by the AssociatedTuple + assert reg in self.registrations + (realm, p_name, url_prefix) = reg + context = self.get_context_from_realm(reg) + # This is how we add a new route _after_ the plugin is registered + r = FutureRoute(handler_f, uri, args, kwargs) + realm._register_route_helper(r, realm, self, context, p_name, url_prefix) + return handler_f + + def _add_new_listener(self, reg, event, listener_f, *args, **kwargs): + # A user should never call this directly. + # it should be called only by the AssociatedTuple + assert reg in self.registrations + (realm, p_name, url_prefix) = reg + context = self.get_context_from_realm(reg) + # This is how we add a new listener _after_ the plugin is registered + realm._plugin_register_listener( + event, listener_f, self, context, *args, **kwargs + ) + return listener_f + + # Decorator + def middleware(self, *args, **kwargs): + """Decorate and register middleware + :param args: captures all of the positional arguments passed in + :type args: tuple(Any) + :param kwargs: captures the keyword arguments passed in + :type kwargs: dict(Any) + :return: The middleware function to use as the decorator + :rtype: fn + """ + kwargs.setdefault("priority", 5) + kwargs.setdefault("relative", None) + kwargs.setdefault("attach_to", None) + kwargs["with_context"] = True # This is the whole point of this plugin + if len(args) == 1 and callable(args[0]): + middle_f = args[0] + return super(Contextualize, self).middleware(middle_f, **kwargs) + + def wrapper(middle_f): + nonlocal self, args, kwargs + return super(Contextualize, self).middleware(*args, **kwargs)(middle_f) + + return wrapper + + # Decorator + def route(self, uri, *args, **kwargs): + """Create a plugin route from a decorated function. + :param uri: endpoint at which the route will be accessible. + :type uri: str + :param args: captures all of the positional arguments passed in + :type args: tuple(Any) + :param kwargs: captures the keyword arguments passed in + :type kwargs: dict(Any) + :return: The exception function to use as the decorator + :rtype: fn + """ + if len(args) == 0 and callable(uri): + raise RuntimeError("Cannot use the @route decorator without arguments.") + kwargs.setdefault("methods", frozenset({"GET"})) + kwargs.setdefault("host", None) + kwargs.setdefault("strict_slashes", False) + kwargs.setdefault("stream", False) + kwargs.setdefault("name", None) + kwargs.setdefault("version", None) + kwargs.setdefault("ignore_body", False) + kwargs.setdefault("websocket", False) + kwargs.setdefault("subprotocols", None) + kwargs.setdefault("unquote", False) + kwargs.setdefault("static", False) + kwargs["with_context"] = True # This is the whole point of this plugin + + def wrapper(handler_f): + nonlocal self, uri, args, kwargs + return super(Contextualize, self).route(uri, *args, **kwargs)(handler_f) + + return wrapper + + # Decorator + def listener(self, event, *args, **kwargs): + """Create a listener from a decorated function. + :param event: Event to listen to. + :type event: str + :param args: captures all of the positional arguments passed in + :type args: tuple(Any) + :param kwargs: captures the keyword arguments passed in + :type kwargs: dict(Any) + :return: The exception function to use as the listener + :rtype: fn + """ + if len(args) == 1 and callable(args[0]): + raise RuntimeError("Cannot use the @listener decorator without arguments") + kwargs["with_context"] = True # This is the whole point of this plugin + + def wrapper(listener_f): + nonlocal self, event, args, kwargs + return super(Contextualize, self).listener(event, *args, **kwargs)( + listener_f + ) + + return wrapper + + def websocket(self, uri, *args, **kwargs): + """Create a websocket route from a decorated function + # Deprecated. Use @contextualize.route("/path",websocket=True) + """ + + kwargs["websocket"] = True + kwargs["with_context"] = True # This is the whole point of this plugin + + return self.route(uri, *args, **kwargs) + + def __init__(self, *args, **kwargs): + super(Contextualize, self).__init__(*args, **kwargs) + + +instance = contextualize = Contextualize() diff --git a/backend/sanic_server/sanic_plugin_toolkit/realm.py b/backend/sanic_server/sanic_plugin_toolkit/realm.py new file mode 100644 index 000000000..dc99ead05 --- /dev/null +++ b/backend/sanic_server/sanic_plugin_toolkit/realm.py @@ -0,0 +1,1251 @@ +# -*- coding: utf-8 -*- +import importlib +import re +import sys +from asyncio import CancelledError +from collections import deque +from distutils.version import LooseVersion +from functools import partial, update_wrapper +from inspect import isawaitable, ismodule +from typing import Any, Dict +from uuid import uuid1 + +from ..sanic import Blueprint, Sanic +from ..sanic import __version__ as sanic_version +from ..sanic.exceptions import ServerError +from ..sanic.log import logger +from ..sanic.models.futures import FutureException as SanicFutureException +from ..sanic.models.futures import FutureListener as SanicFutureListener +from ..sanic.models.futures import FutureMiddleware as SanicFutureMiddleware +from ..sanic.models.futures import FutureRoute as SanicFutureRoute +from ..sanic.models.futures import FutureStatic as SanicFutureStatic + +try: + from ..sanic.response import BaseHTTPResponse +except ImportError: + from ..sanic.response import HTTPResponse as BaseHTTPResponse + +from ..sanic_plugin_toolkit.config import load_config_file +from ..sanic_plugin_toolkit.context import SanicContext +from ..sanic_plugin_toolkit.plugin import PluginRegistration, SanicPlugin + +module = sys.modules[__name__] +CONSTS: Dict[str, Any] = dict() +CONSTS["APP_CONFIG_INSTANCE_KEY"] = APP_CONFIG_INSTANCE_KEY = "__SPTK_INSTANCE" +CONSTS["SPTK_LOAD_INI_KEY"] = SPTK_LOAD_INI_KEY = "SPTK_LOAD_INI" +CONSTS["SPTK_INI_FILE_KEY"] = SPTK_INI_FILE_KEY = "SPTK_INI_FILE" +CONSTS["SANIC_19_12_0"] = SANIC_19_12_0 = LooseVersion("19.12.0") +CONSTS["SANIC_20_12_1"] = SANIC_20_12_1 = LooseVersion("20.12.1") +CONSTS["SANIC_21_3_0"] = SANIC_21_3_0 = LooseVersion("21.3.0") + +# Currently installed sanic version in this environment +SANIC_VERSION = LooseVersion(sanic_version) + +CRITICAL = 50 +ERROR = 40 +WARNING = 30 +INFO = 20 +DEBUG = 10 + +to_snake_case_first_cap_re = re.compile("(.)([A-Z][a-z]+)") +to_snake_case_all_cap_re = re.compile("([a-z0-9])([A-Z])") + + +def to_snake_case(name): + """ + Simple helper function. + Changes PascalCase, camelCase, and CAPS_CASE to snake_case. + :param name: variable name to convert + :type name: str + :return: the name of the variable, converted to snake_case + :rtype: str + """ + s1 = to_snake_case_first_cap_re.sub(r"\1_\2", name) + return to_snake_case_all_cap_re.sub(r"\1_\2", s1).lower() + + +class SanicPluginRealm(object): + __slots__ = ( + "_running", + "_app", + "_plugin_names", + "_contexts", + "_pre_request_middleware", + "_post_request_middleware", + "_pre_response_middleware", + "_post_response_middleware", + "_cleanup_middleware", + "_loop", + "__weakref__", + ) + + def log(self, level, message, reg=None, *args, **kwargs): + if reg is not None: + (_, n, _) = reg + message = "{:s}: {:s}".format(str(n), str(message)) + return logger.log(level, message, *args, **kwargs) + + def debug(self, message, reg=None, *args, **kwargs): + return self.log(DEBUG, message=message, reg=reg, *args, **kwargs) + + def info(self, message, reg=None, *args, **kwargs): + return self.log(INFO, message=message, reg=reg, *args, **kwargs) + + def warning(self, message, reg=None, *args, **kwargs): + return self.log(WARNING, message=message, reg=reg, *args, **kwargs) + + def error(self, message, reg=None, *args, **kwargs): + return self.log(ERROR, message=message, reg=reg, *args, **kwargs) + + def critical(self, message, reg=None, *args, **kwargs): + return self.log(CRITICAL, message=message, reg=reg, *args, **kwargs) + + def url_for(self, view_name, *args, reg=None, **kwargs): + if reg is not None: + (_, name, url_prefix) = reg + view_name = "{}.{}".format(name, view_name) + app = self._app + if app is None: + return None + if isinstance(app, Blueprint): + bp = app + view_name = "{}.{}".format(app.name, view_name) + return [a.url_for(view_name, *args, **kwargs) for a in bp.apps] + return app.url_for(view_name, *args, **kwargs) + + def _get_realm_plugin(self, plugin): + if isinstance(plugin, str): + if plugin not in self._plugin_names: + self.warning("Cannot lookup that plugin by its name.") + return None + name = plugin + else: + reg = plugin.find_plugin_registration(self) + (_, name, _) = reg + _p_context = self._plugins_context + try: + _plugin_reg = _p_context[name] + except KeyError as k: + self.warning("Plugin not found!") + raise k + return _plugin_reg + + def get_plugin_inst(self, plugin): + _plugin_reg = self._get_realm_plugin(plugin) + try: + inst = _plugin_reg["instance"] + except KeyError: + self.warning("Plugin is not registered properly") + inst = None + return inst + + def get_plugin_assoc(self, plugin): + _plugin_reg = self._get_realm_plugin(plugin) + p = _plugin_reg["instance"] + reg = _plugin_reg["reg"] + associated_tuple = p.AssociatedTuple + return associated_tuple(p, reg) + + def register_plugin(self, plugin, *args, name=None, skip_reg=False, **kwargs): + assert not self._running, ( + "Cannot add, remove, or change plugins " + "after the App has started serving." + ) + assert plugin, ( + "Plugin must be a valid type! Do not pass in `None` " "or `False`" + ) + + if isinstance(plugin, type): + # We got passed in a Class. That's ok, we can handle this! + module_name = getattr(plugin, "__module__") + class_name = getattr(plugin, "__name__") + lower_class = to_snake_case(class_name) + try: + mod = importlib.import_module(module_name) + try: + plugin = getattr(mod, lower_class) + except AttributeError: + plugin = mod # try the module-based resolution next + except ImportError: + raise + + if ismodule(plugin): + # We got passed in a module. That's ok, we can handle this! + try: # look for '.instance' on the module + plugin = getattr(plugin, "instance") + assert plugin is not None + except (AttributeError, AssertionError): + # now look for the same name, + # like my_module.my_module on the module. + try: + plugin_module_name = getattr(plugin, "__name__") + assert plugin_module_name and len(plugin_module_name) > 0 + plugin_module_name = plugin_module_name.split(".")[-1] + plugin = getattr(plugin, plugin_module_name) + assert plugin is not None + except (AttributeError, AssertionError): + raise RuntimeError("Cannot import this module as a Sanic Plugin.") + + assert isinstance( + plugin, SanicPlugin + ), "Plugin must be derived from SanicPlugin" + if name is None: + try: + name = str(plugin.__class__.__name__) + assert name is not None + except (AttributeError, AssertionError, ValueError, KeyError): + logger.warning( + "Cannot determine a name for {}, using UUID.".format(repr(plugin)) + ) + name = str(uuid1(None, None)) + assert isinstance(name, str), "Plugin name must be a python unicode string!" + + associated_tuple = plugin.AssociatedTuple + + if name in self._plugin_names: # we're already registered on this Realm + reg = plugin.find_plugin_registration(self) + assoc = associated_tuple(plugin, reg) + raise ValueError("Plugin {:s} is already registered!".format(name), assoc) + if plugin.is_registered_in_realm(self): + raise RuntimeError( + "Plugin already shows it is registered to this " + "sanic_plugin_toolkit, maybe under a different name?" + ) + self._plugin_names.add(name) + shared_context = self.shared_context + self._contexts[name] = context = SanicContext( + self, shared_context, {"shared": shared_context} + ) + _p_context = self._plugins_context + _plugin_reg = _p_context.get(name, None) + if _plugin_reg is None: + _p_context[name] = _plugin_reg = _p_context.create_child_context() + _plugin_reg["name"] = name + _plugin_reg["context"] = context + if skip_reg: + dummy_reg = PluginRegistration( + realm=self, plugin_name=name, url_prefix=None + ) + context["log"] = partial(self.log, reg=dummy_reg) + context["url_for"] = partial(self.url_for, reg=dummy_reg) + plugin.registrations.add(dummy_reg) + # This indicates the plugin is not registered on the app + _plugin_reg["instance"] = None + _plugin_reg["reg"] = None + return associated_tuple(plugin, dummy_reg) + if _plugin_reg.get("instance", False): + raise RuntimeError( + "The plugin we are trying to register already " "has a known instance!" + ) + reg = self._register_helper( + plugin, context, *args, _realm=self, _plugin_name=name, **kwargs + ) + _plugin_reg["instance"] = plugin + _plugin_reg["reg"] = reg + return associated_tuple(plugin, reg) + + @staticmethod + def _register_exception_helper(e, _realm, plugin, context): + return ( + _realm._plugin_register_bp_exception( + e.handler, plugin, context, *e.exceptions, **e.kwargs + ) + if isinstance(_realm._app, Blueprint) + else _realm._plugin_register_app_exception( + e.handler, plugin, context, *e.exceptions, **e.kwargs + ) + ) + + @staticmethod + def _register_listener_helper(event, listener, _realm, plugin, context, **kwargs): + return ( + _realm._plugin_register_bp_listener( + event, listener, plugin, context, **kwargs + ) + if isinstance(_realm._app, Blueprint) + else _realm._plugin_register_app_listener( + event, listener, plugin, context, **kwargs + ) + ) + + @staticmethod + def _register_middleware_helper(m, _realm, plugin, context): + return _realm._plugin_register_middleware( + m.middleware, plugin, context, *m.args, **m.kwargs + ) + + @staticmethod + def _register_route_helper(r, _realm, plugin, context, _p_name, _url_prefix): + # Prepend the plugin URI prefix if available + uri = _url_prefix + r.uri if _url_prefix else r.uri + uri = uri[1:] if uri.startswith("//") else uri + # attach the plugin name to the handler so that it can be + # prefixed properly in the router + _app = _realm._app + handler_name = str(r.handler.__name__) + plugin_prefix = _p_name + "." + kwargs = r.kwargs + if isinstance(_app, Blueprint): + # blueprint always handles adding __blueprintname__ + # So we identify ourselves here a different way. + # r.handler.__name__ = "{}.{}".format(_p_name, handler_name) + if "name" not in kwargs or kwargs["name"] is None: + kwargs["name"] = plugin_prefix + handler_name + elif not kwargs["name"].startswith(plugin_prefix): + kwargs["name"] = plugin_prefix + kwargs["name"] + _realm._plugin_register_bp_route( + r.handler, plugin, context, uri, *r.args, **kwargs + ) + else: + if "name" not in kwargs or kwargs["name"] is None: + kwargs["name"] = plugin_prefix + handler_name + elif not kwargs["name"].startswith(plugin_prefix): + kwargs["name"] = plugin_prefix + kwargs["name"] + _realm._plugin_register_app_route( + r.handler, plugin, context, uri, *r.args, **kwargs + ) + + @staticmethod + def _register_static_helper(s, _realm, plugin, context, _p_name, _url_prefix): + # attach the plugin name to the static route so that it can be + # prefixed properly in the router + kwargs = s.kwargs + name = kwargs.pop("name", "static") + plugin_prefix = _p_name + "." + _app = _realm._app + if not name.startswith(plugin_prefix): + name = plugin_prefix + name + # Prepend the plugin URI prefix if available + uri = _url_prefix + s.uri if _url_prefix else s.uri + uri = uri[1:] if uri.startswith("//") else uri + kwargs["name"] = name + return ( + _realm._plugin_register_bp_static( + uri, s.file_or_dir, plugin, context, *s.args, **kwargs + ) + if isinstance(_app, Blueprint) + else _realm._plugin_register_app_static( + uri, s.file_or_dir, plugin, context, *s.args, **kwargs + ) + ) + + @staticmethod + def _register_helper( + plugin, + context, + *args, + _realm=None, + _plugin_name=None, + _url_prefix=None, + **kwargs, + ): + error_str = ( + "Plugin must be initialised using the " "Sanic Plugin Toolkit PluginRealm." + ) + assert _realm is not None, error_str + assert _plugin_name is not None, error_str + _app = _realm._app + assert _app is not None, error_str + + reg = PluginRegistration( + realm=_realm, plugin_name=_plugin_name, url_prefix=_url_prefix + ) + context["log"] = partial(_realm.log, reg=reg) + context["url_for"] = partial(_realm.url_for, reg=reg) + continue_flag = plugin.on_before_registered(context, *args, **kwargs) + if continue_flag is False: + return plugin + + # Routes + [ + _realm._register_route_helper( + r, _realm, plugin, context, _plugin_name, _url_prefix + ) + for r in plugin._routes + ] + + # Websocket routes + # These are deprecated and should be handled in the _routes_ list above. + [ + _realm._register_route_helper( + w, _realm, plugin, context, _plugin_name, _url_prefix + ) + for w in plugin._ws + ] + + # Static routes + [ + _realm._register_static_helper( + s, _realm, plugin, context, _plugin_name, _url_prefix + ) + for s in plugin._static + ] + + # Middleware + [ + _realm._register_middleware_helper(m, _realm, plugin, context) + for m in plugin._middlewares + ] + + # Exceptions + [ + _realm._register_exception_helper(e, _realm, plugin, context) + for e in plugin._exceptions + ] + + # Listeners + for event, listeners in plugin._listeners.items(): + for listener in listeners: + if isinstance(listener, tuple): + listener, lkw = listener + else: + lkw = {} + _realm._register_listener_helper( + event, listener, _realm, plugin, context, **lkw + ) + + # # this should only ever run once! + plugin.registrations.add(reg) + plugin.on_registered(context, reg, *args, **kwargs) + + return reg + + def _plugin_register_app_route( + self, + r_handler, + plugin, + context, + uri, + *args, + name=None, + with_context=False, + **kwargs, + ): + if with_context: + r_handler = update_wrapper(partial(r_handler, context=context), r_handler) + fr = SanicFutureRoute(r_handler, uri, name=name, **kwargs) + routes = self._app._apply_route(fr) + return routes + + def _plugin_register_bp_route( + self, + r_handler, + plugin, + context, + uri, + *args, + name=None, + with_context=False, + **kwargs, + ): + bp = self._app + if with_context: + r_handler = update_wrapper(partial(r_handler, context=context), r_handler) + # __blueprintname__ gets added in the register() routine + # When app is a blueprint, it doesn't register right away, it happens + # in the blueprint.register() routine. + r_handler = bp.route(uri, *args, name=name, **kwargs)(r_handler) + return r_handler + + def _plugin_register_app_static( + self, uri, file_or_dir, plugin, context, *args, **kwargs + ): + fs = SanicFutureStatic(uri, file_or_dir, **kwargs) + return self._app._apply_static(fs) + + def _plugin_register_bp_static( + self, uri, file_or_dir, plugin, context, *args, **kwargs + ): + bp = self._app + return bp.static(uri, file_or_dir, *args, **kwargs) + + def _plugin_register_app_exception( + self, handler, plugin, context, *exceptions, with_context=False, **kwargs + ): + if with_context: + handler = update_wrapper(partial(handler, context=context), handler) + fe = SanicFutureException(handler, list(exceptions)) + return self._app._apply_exception_handler(fe) + + def _plugin_register_bp_exception( + self, handler, plugin, context, *exceptions, with_context=False, **kwargs + ): + if with_context: + handler = update_wrapper(partial(handler, context=context), handler) + return self._app.exception(*exceptions)(handler) + + def _plugin_register_app_listener( + self, event, listener, plugin, context, *args, with_context=False, **kwargs + ): + if with_context: + listener = update_wrapper(partial(listener, context=context), listener) + fl = SanicFutureListener(listener, event) + return self._app._apply_listener(fl) + + def _plugin_register_bp_listener( + self, event, listener, plugin, context, *args, with_context=False, **kwargs + ): + if with_context: + listener = update_wrapper(partial(listener, context=context), listener) + bp = self._app + return bp.listener(event)(listener) + + def _plugin_register_middleware( + self, + middleware, + plugin, + context, + *args, + priority=5, + relative=None, + attach_to=None, + with_context=False, + **kwargs, + ): + assert isinstance(priority, int), "Priority must be an integer!" + assert 0 <= priority <= 9, ( + "Priority must be between 0 and 9 (inclusive), " + "0 is highest priority, 9 is lowest." + ) + assert isinstance( + plugin, SanicPlugin + ), "Plugin middleware only works with a plugin from SPTK." + if len(args) > 0 and isinstance(args[0], str) and attach_to is None: + # for backwards/sideways compatibility with Sanic, + # the first arg is interpreted as 'attach_to' + attach_to = args[0] + if with_context: + middleware = update_wrapper( + partial(middleware, context=context), middleware + ) + if attach_to is None or attach_to == "request": + insert_order = len(self._pre_request_middleware) + len( + self._post_request_middleware + ) + priority_middleware = (priority, insert_order, middleware) + if relative is None or relative == "pre": + # plugin request middleware default to pre-app middleware + self._pre_request_middleware.append(priority_middleware) + else: # post + assert ( + relative == "post" + ), "A request middleware must have relative = pre or post" + self._post_request_middleware.append(priority_middleware) + elif attach_to == "cleanup": + insert_order = len(self._cleanup_middleware) + priority_middleware = (priority, insert_order, middleware) + assert ( + relative is None + ), "A cleanup middleware cannot have relative pre or post" + self._cleanup_middleware.append(priority_middleware) + else: # response + assert ( + attach_to == "response" + ), "A middleware kind must be either request or response." + insert_order = len(self._post_response_middleware) + len( + self._pre_response_middleware + ) + # so they are sorted backwards + priority_middleware = (0 - priority, 0.0 - insert_order, middleware) + if relative is None or relative == "post": + # plugin response middleware default to post-app middleware + self._post_response_middleware.append(priority_middleware) + else: # pre + assert ( + relative == "pre" + ), "A response middleware must have relative = pre or post" + self._pre_response_middleware.append(priority_middleware) + return middleware + + @property + def _plugins_context(self): + try: + return self._contexts["_plugins"] + except (AttributeError, KeyError): + raise RuntimeError("PluginRealm does not have a valid plugins context!") + + @property + def shared_context(self): + try: + return self._contexts["shared"] + except (AttributeError, KeyError): + raise RuntimeError("PluginRealm does not have a valid shared context!") + + def get_context(self, context=None): + context = context or "shared" + try: + _context = self._contexts[context] + except KeyError: + logger.error("Context {:s} does not exist!") + return None + return _context + + def get_from_context(self, item, context=None): + context = context or "shared" + try: + _context = self._contexts[context] + except KeyError: + logger.warning( + "Context {:s} does not exist! Falling back to shared context".format( + context + ) + ) + _context = self._contexts["shared"] + return _context.__getitem__(item) + + def create_temporary_request_context(self, request): + request_hash = id(request) + shared_context = self.shared_context + shared_requests_dict = shared_context.get("request", False) + if not shared_requests_dict: + new_ctx = SanicContext(self, None, {"id": "shared request contexts"}) + shared_context["request"] = shared_requests_dict = new_ctx + shared_request_ctx = shared_requests_dict.get(request_hash, None) + if shared_request_ctx: + # Somehow, we've already created a temporary context for this request. + return shared_request_ctx + shared_requests_dict[request_hash] = shared_request_ctx = SanicContext( + self, + None, + { + "request": request, + "id": "shared request context for request {}".format(id(request)), + }, + ) + for name, _p in self._plugins_context.items(): + if not ( + isinstance(_p, SanicContext) + and "instance" in _p + and isinstance(_p["instance"], SanicPlugin) + ): + continue + if not ("context" in _p and isinstance(_p["context"], SanicContext)): + continue + _p_context = _p["context"] + if "request" not in _p_context: + _p_context["request"] = p_request = SanicContext( + self, None, {"id": "private request contexts"} + ) + else: + p_request = _p_context.request + p_request[request_hash] = SanicContext( + self, + None, + { + "request": request, + "id": "private request context for {} on request {}".format( + name, id(request) + ), + }, + ) + return shared_request_ctx + + def delete_temporary_request_context(self, request): + request_hash = id(request) + shared_context = self.shared_context + try: + _shared_requests_dict = shared_context["request"] + del _shared_requests_dict[request_hash] + except KeyError: + pass + for name, _p in self._plugins_context.items(): + if not ( + isinstance(_p, SanicContext) + and "instance" in _p + and isinstance(_p["instance"], SanicPlugin) + ): + continue + if not ("context" in _p and isinstance(_p["context"], SanicContext)): + continue + _p_context = _p["context"] + try: + _p_requests_dict = _p_context["request"] + del _p_requests_dict[request_hash] + except KeyError: + pass + + async def _handle_request( + self, real_handle, request, write_callback, stream_callback + ): + cancelled = False + try: + _ = await real_handle(request, write_callback, stream_callback) + except CancelledError as ce: + # We still want to run cleanup middleware, even if cancelled + cancelled = ce + except BaseException as be: + logger.error( + "SPTK caught an error that should have been caught by Sanic response handler." + ) + logger.error(str(be)) + raise + finally: + # noinspection PyUnusedLocal + _ = await self._run_cleanup_middleware(request) # noqa: F841 + if cancelled: + raise cancelled + + async def _handle_request_21_03(self, real_handle, request): + cancelled = False + try: + _ = await real_handle(request) + except CancelledError as ce: + # We still want to run cleanup middleware, even if cancelled + cancelled = ce + except BaseException as be: + logger.error( + "SPTK caught an error that should have been caught by Sanic response handler." + ) + logger.error(str(be)) + raise + finally: + # noinspection PyUnusedLocal + _ = await self._run_cleanup_middleware(request) # noqa: F841 + if cancelled: + raise cancelled + + def wrap_handle_request(self, app, new_handler=None): + if new_handler is None: + new_handler = self._handle_request + orig_handle_request = app.handle_request + return update_wrapper(partial(new_handler, orig_handle_request), new_handler) + + async def _run_request_middleware_18_12(self, request): + if not self._running: + raise ServerError( + "Toolkit processing a request before App server is started." + ) + self.create_temporary_request_context(request) + if self._pre_request_middleware: + for (_pri, _ins, middleware) in self._pre_request_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return response + if self._app.request_middleware: + for middleware in self._app.request_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return response + if self._post_request_middleware: + for (_pri, _ins, middleware) in self._post_request_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return response + return None + + async def _run_request_middleware_19_12(self, request, request_name=None): + if not self._running: + # Test_mode is only present on Sanic 20.9+ + test_mode = getattr(self._app, "test_mode", False) + if self._app.asgi: + if test_mode: + # We're deliberately in Test Mode, we don't expect + # Server events to have been kicked off yet. + pass + else: + # An ASGI app can receive requests from HTTPX even if + # the app is not booted yet. + self.warning( + "Unexpected ASGI request. Forcing Toolkit " + "into running mode without a server." + ) + self._on_server_start(request.app, request.transport.loop) + elif test_mode: + self.warning( + "Unexpected test-mode request. Forcing Toolkit " + "into running mode without a server." + ) + self._on_server_start(request.app, request.transport.loop) + else: + raise RuntimeError( + "Sanic Plugin Toolkit received a request before Sanic server is started." + ) + self.create_temporary_request_context(request) + if self._pre_request_middleware: + for (_pri, _ins, middleware) in self._pre_request_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return response + app = self._app + named_middleware = app.named_request_middleware.get(request_name, deque()) + applicable_middleware = app.request_middleware + named_middleware + if applicable_middleware: + for middleware in applicable_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return response + if self._post_request_middleware: + for (_pri, _ins, middleware) in self._post_request_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return response + return None + + async def _run_request_middleware_21_03(self, request, request_name=None): + if not self._running: + test_mode = self._app.test_mode + if self._app.asgi: + if test_mode: + # We're deliberately in Test Mode, we don't expect + # Server events to have been kicked off yet. + pass + else: + # An ASGI app can receive requests from HTTPX even if + # the app is not booted yet. + self.warning( + "Unexpected ASGI request. Forcing Toolkit " + "into running mode without a server." + ) + self._on_server_start(request.app, request.transport.loop) + elif test_mode: + self.warning( + "Unexpected test-mode request. Forcing Toolkit " + "into running mode without a server." + ) + self._on_server_start(request.app, request.transport.loop) + else: + raise RuntimeError( + "Sanic Plugin Toolkit received a request before Sanic server is started." + ) + + shared_req_context = self.create_temporary_request_context(request) + realm_request_middleware_started = shared_req_context.get( + "realm_request_middleware_started", False + ) + if realm_request_middleware_started: + return None + shared_req_context["realm_request_middleware_started"] = True + if self._pre_request_middleware: + for (_pri, _ins, middleware) in self._pre_request_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return response + app = self._app + named_middleware = app.named_request_middleware.get(request_name, deque()) + applicable_middleware = app.request_middleware + named_middleware + # request.request_middleware_started is meant as a stop-gap solution + # until RFC 1630 is adopted + if applicable_middleware and not request.request_middleware_started: + request.request_middleware_started = True + for middleware in applicable_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return response + if self._post_request_middleware: + for (_pri, _ins, middleware) in self._post_request_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return response + return None + + async def _run_response_middleware_18_12(self, request, response): + if self._pre_response_middleware: + for (_pri, _ins, middleware) in self._pre_response_middleware: + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + if _response: + response = _response + break + if self._app.response_middleware: + for middleware in self._app.response_middleware: + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + if _response: + response = _response + break + if self._post_response_middleware: + for (_pri, _ins, middleware) in self._post_response_middleware: + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + if _response: + response = _response + break + return response + + async def _run_response_middleware_19_12( + self, request, response, request_name=None + ): + if self._pre_response_middleware: + for (_pri, _ins, middleware) in self._pre_response_middleware: + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + if _response: + response = _response + break + app = self._app + named_middleware = app.named_response_middleware.get(request_name, deque()) + applicable_middleware = app.response_middleware + named_middleware + if applicable_middleware: + for middleware in applicable_middleware: + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + if _response: + response = _response + break + if self._post_response_middleware: + for (_pri, _ins, middleware) in self._post_response_middleware: + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + if _response: + response = _response + break + return response + + async def _run_response_middleware_21_03( + self, request, response, request_name=None + ): + if self._pre_response_middleware: + for (_pri, _ins, middleware) in self._pre_response_middleware: + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + if _response: + response = _response + if isinstance(response, BaseHTTPResponse): + response = request.stream.respond(response) + break + app = self._app + named_middleware = app.named_response_middleware.get(request_name, deque()) + applicable_middleware = app.response_middleware + named_middleware + if applicable_middleware: + for middleware in applicable_middleware: + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + if _response: + response = _response + if isinstance(response, BaseHTTPResponse): + response = request.stream.respond(response) + break + if self._post_response_middleware: + for (_pri, _ins, middleware) in self._post_response_middleware: + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + if _response: + response = _response + if isinstance(response, BaseHTTPResponse): + response = request.stream.respond(response) + break + return response + + async def _run_cleanup_middleware(self, request): + return_this = None + if self._cleanup_middleware: + for (_pri, _ins, middleware) in self._cleanup_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return_this = response + break + self.delete_temporary_request_context(request) + return return_this + + def _on_server_start(self, app, loop): + if not isinstance(self._app, Blueprint): + assert self._app == app, ( + "Sanic Plugins Framework is not assigned to the correct " "Sanic App!" + ) + if self._running: + # during testing, this will be called _many_ times. + return # Ignore if this is already called. + self._loop = loop + + # sort and freeze these + self._pre_request_middleware = tuple(sorted(self._pre_request_middleware)) + self._post_request_middleware = tuple(sorted(self._post_request_middleware)) + self._pre_response_middleware = tuple(sorted(self._pre_response_middleware)) + self._post_response_middleware = tuple(sorted(self._post_response_middleware)) + self._cleanup_middleware = tuple(sorted(self._cleanup_middleware)) + self._running = True + + def _on_after_server_start(self, app, loop): + if not self._running: + # Missed before_server_start event + # Run startup now! + self._on_server_start(app, loop) + + async def _startup(self, app, real_startup): + _ = await real_startup() + # Patch app _after_ Touchup is done. + self._patch_app(app) + + def _patch_app(self, app): + # monkey patch the app! + + if SANIC_21_3_0 <= SANIC_VERSION: + app.handle_request = self.wrap_handle_request( + app, self._handle_request_21_03 + ) + app._run_request_middleware = self._run_request_middleware_21_03 + app._run_response_middleware = self._run_response_middleware_21_03 + setattr(app.ctx, APP_CONFIG_INSTANCE_KEY, self) + else: + if SANIC_19_12_0 <= SANIC_VERSION: + app.handle_request = self.wrap_handle_request(app) + app._run_request_middleware = self._run_request_middleware_19_12 + app._run_response_middleware = self._run_response_middleware_19_12 + else: + app.handle_request = self.wrap_handle_request(app) + app._run_request_middleware = self._run_request_middleware_18_12 + app._run_response_middleware = self._run_response_middleware_18_12 + app.config[APP_CONFIG_INSTANCE_KEY] = self + + def _patch_blueprint(self, bp): + # monkey patch the blueprint! + # Caveat! We cannot take over the sanic middleware runner when + # app is a blueprint. We will do this a different way. + _spf = self + + async def run_bp_pre_request_mw(request): + nonlocal _spf + _spf.create_temporary_request_context(request) + if _spf._pre_request_middleware: + for (_pri, _ins, middleware) in _spf._pre_request_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return response + + async def run_bp_post_request_mw(request): + nonlocal _spf + if _spf._post_request_middleware: + for (_pri, _ins, middleware) in _spf._post_request_middleware: + response = middleware(request) + if isawaitable(response): + response = await response + if response: + return response + + async def run_bp_pre_response_mw(request, response): + nonlocal _spf + altered = False + if _spf._pre_response_middleware: + for (_pri, _ins, middleware) in _spf._pre_response_middleware: + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + if _response: + response = _response + altered = True + break + if altered: + return response + + async def run_bp_post_response_mw(request, response): + nonlocal _spf + altered = False + if _spf._post_response_middleware: + for (_pri, _ins, middleware) in _spf._post_response_middleware: + _response = middleware(request, response) + if isawaitable(_response): + _response = await _response + if _response: + response = _response + altered = True + break + if self._cleanup_middleware: + for (_pri, _ins, middleware) in self._cleanup_middleware: + response2 = middleware(request) + if isawaitable(response2): + response2 = await response2 + if response2: + break + _spf.delete_temporary_request_context(request) + if altered: + return response + + def bp_register(bp_self, orig_register, app, options): + # from ..sanic.blueprints import FutureMiddleware as BPFutureMW + pre_request = SanicFutureMiddleware(run_bp_pre_request_mw, "request") + post_request = SanicFutureMiddleware(run_bp_post_request_mw, "request") + pre_response = SanicFutureMiddleware(run_bp_pre_response_mw, "response") + post_response = SanicFutureMiddleware(run_bp_post_response_mw, "response") + # this order is very important. Don't change it. It is correct. + bp_self._future_middleware.insert(0, post_response) + bp_self._future_middleware.insert(0, pre_request) + bp_self._future_middleware.append(post_request) + bp_self._future_middleware.append(pre_response) + + orig_register(app, options) + + if SANIC_21_3_0 <= SANIC_VERSION: + _slots = list(Blueprint.__fake_slots__) + _slots.extend(["register"]) + Sanic.__fake_slots__ = tuple(_slots) + bp.register = update_wrapper( + partial(bp_register, bp, bp.register), bp.register + ) + setattr(bp.ctx, APP_CONFIG_INSTANCE_KEY, self) + else: + bp.register = update_wrapper( + partial(bp_register, bp, bp.register), bp.register + ) + setattr(bp, APP_CONFIG_INSTANCE_KEY, self) + + @classmethod + def _recreate(cls, app): + self = super(SanicPluginRealm, cls).__new__(cls) + self._running = False + self._app = app + self._loop = None + self._plugin_names = set() + # these deques get replaced with frozen tuples at runtime + self._pre_request_middleware = deque() + self._post_request_middleware = deque() + self._pre_response_middleware = deque() + self._post_response_middleware = deque() + self._cleanup_middleware = deque() + self._contexts = SanicContext(self, None) + self._contexts["shared"] = SanicContext(self, None, {"app": app}) + self._contexts["_plugins"] = SanicContext( + self, None, {"sanic_plugin_toolkit": self} + ) + return self + + def __new__(cls, app, *args, **kwargs): + assert app, "Plugin Realm must be given a valid Sanic App to work with." + assert isinstance(app, Sanic) or isinstance(app, Blueprint), ( + "PluginRealm only works with Sanic Apps or Blueprints. " + "Please pass in an app instance to the Realm constructor." + ) + # An app _must_ only have one sanic_plugin_toolkit instance associated with it. + # If there is already one registered on the app, return that one. + try: + instance = getattr(app.ctx, APP_CONFIG_INSTANCE_KEY) + assert isinstance( + instance, cls + ), "This app is already registered to a different type of Sanic Plugin Realm!" + return instance + except (AttributeError, LookupError): + # App doesn't have .ctx or key is not present + try: + instance = app.config[APP_CONFIG_INSTANCE_KEY] + assert isinstance( + instance, cls + ), "This app is already registered to a different type of Sanic Plugin Realm!" + return instance + except AttributeError: # app must then be a blueprint + try: + instance = getattr(app, APP_CONFIG_INSTANCE_KEY) + assert isinstance( + instance, cls + ), "This Blueprint is already registered to a different type of Sanic Plugin Realm!" + return instance + except AttributeError: + pass + except LookupError: + pass + self = cls._recreate(app) + if isinstance(app, Blueprint): + bp = app + self._patch_blueprint(bp) + bp.listener("before_server_start")(self._on_server_start) + bp.listener("after_server_start")(self._on_after_server_start) + else: + if hasattr(Sanic, "__fake_slots__"): + _slots = list(Sanic.__fake_slots__) + _slots.extend( + [ + "_startup", + "handle_request", + "_run_request_middleware", + "_run_response_middleware", + ] + ) + Sanic.__fake_slots__ = tuple(_slots) + if hasattr(app, "_startup"): + # We can wrap startup, to patch _after_ Touchup is done + app._startup = update_wrapper( + partial(self._startup, app, app._startup), app._startup + ) + else: + self._patch_app(app) + app.listener("before_server_start")(self._on_server_start) + app.listener("after_server_start")(self._on_after_server_start) + config = getattr(app, "config", None) + if config: + load_ini = config.get(SPTK_LOAD_INI_KEY, True) + if load_ini: + ini_file = config.get(SPTK_INI_FILE_KEY, "sptk.ini") + try: + load_config_file(self, app, ini_file) + except FileNotFoundError: + pass + return self + + def __init__(self, *args, **kwargs): + args = list(args) # tuple is not mutable. Change it to a list. + if len(args) > 0: + args.pop(0) # remove 'app' arg + assert ( + self._app and self._contexts + ), "Sanic Plugin Realm was not initialized correctly." + assert len(args) < 1, "Unexpected arguments passed to the Sanic Plugin Realm." + assert ( + len(kwargs) < 1 + ), "Unexpected keyword arguments passed to the SanicPluginRealm." + super(SanicPluginRealm, self).__init__(*args, **kwargs) + + def __getstate__(self): + if self._running: + raise RuntimeError( + "Cannot call __getstate__ on an SPTK app that is already running." + ) + state_dict = {} + for s in SanicPluginRealm.__slots__: + if s in ("_running", "_loop"): + continue + state_dict[s] = getattr(self, s) + return state_dict + + def __setstate__(self, state): + running = getattr(self, "_running", False) + if running: + raise RuntimeError( + "Cannot call __setstate__ on an SPTK app that is already running." + ) + for s, v in state.items(): + if s in ("_running", "_loop"): + continue + if s == "__weakref__": + if v is None: + continue + else: + raise NotImplementedError("Setting weakrefs on SPTK PluginRealm") + setattr(self, s, v) + + def __reduce__(self): + if self._running: + raise RuntimeError( + "Cannot pickle a SPTK PluginRealm App after it has started running!" + ) + state_dict = self.__getstate__() + app = state_dict.pop("_app") + return SanicPluginRealm._recreate, (app,), state_dict diff --git a/backend/sanic_server/sanic_routing/LICENSE b/backend/sanic_server/sanic_routing/LICENSE new file mode 100644 index 000000000..3baabded1 --- /dev/null +++ b/backend/sanic_server/sanic_routing/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Sanic Community Organization + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/backend/sanic_server/sanic_routing/__init__.py b/backend/sanic_server/sanic_routing/__init__.py new file mode 100644 index 000000000..a64dc391b --- /dev/null +++ b/backend/sanic_server/sanic_routing/__init__.py @@ -0,0 +1,6 @@ +from .group import RouteGroup +from .route import Route +from .router import BaseRouter + +__version__ = "21.12.0" +__all__ = ("BaseRouter", "Route", "RouteGroup") diff --git a/backend/sanic_server/sanic_routing/exceptions.py b/backend/sanic_server/sanic_routing/exceptions.py new file mode 100644 index 000000000..9b9b1d379 --- /dev/null +++ b/backend/sanic_server/sanic_routing/exceptions.py @@ -0,0 +1,47 @@ +from typing import Optional, Set + + +class BaseException(Exception): + ... + + +class NotFound(BaseException): + def __init__( + self, + message: str = "Not Found", + path: Optional[str] = None, + ): + super().__init__(message) + self.path = path + + +class BadMethod(BaseException): + ... + + +class NoMethod(BaseException): + def __init__( + self, + message: str = "Method does not exist", + method: Optional[str] = None, + allowed_methods: Optional[Set[str]] = None, + ): + super().__init__(message) + self.method = method + self.allowed_methods = allowed_methods + + +class FinalizationError(BaseException): + ... + + +class InvalidUsage(BaseException): + ... + + +class RouteExists(BaseException): + ... + + +class ParameterNameConflicts(BaseException): + ... diff --git a/backend/sanic_server/sanic_routing/group.py b/backend/sanic_server/sanic_routing/group.py new file mode 100644 index 000000000..057dd00af --- /dev/null +++ b/backend/sanic_server/sanic_routing/group.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from typing import FrozenSet, List, Optional, Sequence, Tuple + +from ..sanic_routing.route import Requirements, Route +from ..sanic_routing.utils import Immutable +from .exceptions import InvalidUsage, RouteExists + + +class RouteGroup: + methods_index: Immutable + passthru_properties = ( + "labels", + "params", + "parts", + "path", + "pattern", + "raw_path", + "regex", + "router", + "segments", + "strict", + "unquote", + "uri", + ) + + #: The _reconstructed_ path after the Route has been normalized. + #: Does not contain preceding ``/`` (see also + #: :py:attr:`uri`) + path: str + + #: A regex version of the :py:attr:`~sanic_routing.route.Route.path` + pattern: Optional[str] + + #: Whether the route requires regular expression evaluation + regex: bool + + #: The raw version of the path exploded (see also + #: :py:attr:`segments`) + parts: Tuple[str, ...] + + #: Same as :py:attr:`parts` except + #: generalized so that any dynamic parts do not + #: include param keys since they have no impact on routing. + segments: Tuple[str, ...] + + #: Whether the route should be matched with strict evaluation + strict: bool + + #: Whether the route should be unquoted after matching if (for example) it + #: is suspected to contain non-URL friendly characters + unquote: bool + + #: Since :py:attr:`path` does NOT + #: include a preceding '/', this adds it back. + uri: str + + def __init__(self, *routes) -> None: + if len(set(route.parts for route in routes)) > 1: + raise InvalidUsage("Cannot group routes with differing paths") + + if any(routes[-1].strict != route.strict for route in routes): + raise InvalidUsage("Cannot group routes with differing strictness") + + route_list = list(routes) + route_list.pop() + + self._routes = routes + self.pattern_idx = 0 + + def __str__(self): + display = ( + f"path={self.path or self.router.delimiter} len={len(self.routes)}" + ) + return f"<{self.__class__.__name__}: {display}>" + + def __repr__(self) -> str: + return str(self) + + def __iter__(self): + return iter(self.routes) + + def __getitem__(self, key): + return self.routes[key] + + def __getattr__(self, key): + # There are a number of properties that all of the routes in the group + # share in common. We pass thrm through to make them available + # on the RouteGroup, and then cache them so that they are permanent. + if key in self.passthru_properties: + value = getattr(self[0], key) + setattr(self, key, value) + return value + + raise AttributeError(f"RouteGroup has no '{key}' attribute") + + def finalize(self): + self.methods_index = Immutable( + { + method: route + for route in self._routes + for method in route.methods + } + ) + + def reset(self): + self.methods_index = dict(self.methods_index) + + def merge( + self, group: RouteGroup, overwrite: bool = False, append: bool = False + ) -> None: + """ + The purpose of merge is to group routes with the same path, but + declarared individually. In other words to group these: + + .. code-block:: python + + @app.get("/path/to") + def handler1(...): + ... + + @app.post("/path/to") + def handler2(...): + ... + + The other main purpose is to look for conflicts and + raise ``RouteExists`` + + A duplicate route is when: + 1. They have the same path and any overlapping methods; AND + 2. If they have requirements, they are the same + + :param group: Incoming route group + :type group: RouteGroup + :param overwrite: whether to allow an otherwise duplicate route group + to overwrite the existing, if ``True`` will not raise exception + on duplicates, defaults to False + :type overwrite: bool, optional + :param append: whether to allow an otherwise duplicate route group to + append its routes to the existing route group, defaults to False + :type append: bool, optional + :raises RouteExists: Raised when there is a duplicate + """ + _routes = list(self._routes) + for other_route in group.routes: + for current_route in self: + if ( + current_route == other_route + or ( + current_route.requirements + and not other_route.requirements + ) + or ( + not current_route.requirements + and other_route.requirements + ) + ) and not append: + if not overwrite: + raise RouteExists( + f"Route already registered: {self.raw_path} " + f"[{','.join(self.methods)}]" + ) + else: + _routes.append(other_route) + self._routes = tuple(_routes) + + @property + def depth(self) -> int: + """ + The number of parts in :py:attr:`parts` + """ + return len(self[0].parts) + + @property + def dynamic_path(self) -> bool: + return any( + (param.label == "path") or ("/" in param.label) + for param in self.params.values() + ) + + @property + def methods(self) -> FrozenSet[str]: + """""" + return frozenset( + [method for route in self for method in route.methods] + ) + + @property + def routes(self) -> Sequence[Route]: + return self._routes + + @property + def requirements(self) -> List[Requirements]: + return [route.requirements for route in self if route.requirements] diff --git a/backend/sanic_server/sanic_routing/line.py b/backend/sanic_server/sanic_routing/line.py new file mode 100644 index 000000000..c3acb8637 --- /dev/null +++ b/backend/sanic_server/sanic_routing/line.py @@ -0,0 +1,17 @@ +class Line: + TAB = " " + + def __init__( + self, + src: str, + indent: int, + offset: int = 0, + render: bool = True, + ) -> None: + self.src = src + self.indent = indent + self.offset = offset + self.render = render + + def __str__(self): + return (self.TAB * self.indent) + self.src + "\n" diff --git a/backend/sanic_server/sanic_routing/patterns.py b/backend/sanic_server/sanic_routing/patterns.py new file mode 100644 index 000000000..696be1536 --- /dev/null +++ b/backend/sanic_server/sanic_routing/patterns.py @@ -0,0 +1,50 @@ +import re +import uuid +from datetime import date, datetime + + +def parse_date(d) -> date: + return datetime.strptime(d, "%Y-%m-%d").date() + + +def alpha(param: str) -> str: + if not param.isalpha(): + raise ValueError(f"Value {param} contains non-alphabetic chracters") + return param + + +def slug(param: str) -> str: + if not REGEX_TYPES["slug"][1].match(param): + raise ValueError(f"Value {param} does not match the slug format") + return param + + +REGEX_PARAM_NAME = re.compile(r"^<([a-zA-Z_][a-zA-Z0-9_]*)(?::(.*))?>$") + +# Predefined path parameter types. The value is a tuple consisteing of a +# callable and a compiled regular expression. +# The callable should: +# 1. accept a string input +# 2. cast the string to desired type +# 3. raise ValueError if it cannot +# The regular expression is generally NOT used. Unless the path is forced +# to use regex patterns. +REGEX_TYPES = { + "str": (str, re.compile(r"^[^/]+$")), + "slug": (slug, re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")), + "alpha": (alpha, re.compile(r"^[A-Za-z]+$")), + "path": (str, re.compile(r"^[^/]?.*?$")), + "float": (float, re.compile(r"^-?(?:\d+(?:\.\d*)?|\.\d+)$")), + "int": (int, re.compile(r"^-?\d+$")), + "ymd": ( + parse_date, + re.compile(r"^([12]\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01]))$"), + ), + "uuid": ( + uuid.UUID, + re.compile( + r"^[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-" + r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}$" + ), + ), +} diff --git a/backend/sanic_server/sanic_routing/py.typed b/backend/sanic_server/sanic_routing/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/backend/sanic_server/sanic_routing/route.py b/backend/sanic_server/sanic_routing/route.py new file mode 100644 index 000000000..f2ff97281 --- /dev/null +++ b/backend/sanic_server/sanic_routing/route.py @@ -0,0 +1,343 @@ +import re +import typing as t +from collections import namedtuple +from types import SimpleNamespace +from warnings import warn + +from .exceptions import InvalidUsage, ParameterNameConflicts +from .utils import Immutable, parts_to_path, path_to_parts + +ParamInfo = namedtuple( + "ParamInfo", + ("name", "raw_path", "label", "cast", "pattern", "regex", "priority"), +) + + +class Requirements(Immutable): + def __hash__(self): + return hash(frozenset(self.items())) + + +class Route: + __slots__ = ( + "_params", + "_raw_path", + "ctx", + "handler", + "labels", + "methods", + "name", + "overloaded", + "params", + "parts", + "path", + "pattern", + "regex", + "requirements", + "router", + "static", + "strict", + "unquote", + ) + + #: A container for route meta-data + ctx: SimpleNamespace + #: The route handler + handler: t.Callable[..., t.Any] + #: The HTTP methods that the route can handle + methods: t.FrozenSet[str] + #: The route name, either generated or as defined in the route definition + name: str + #: The raw version of the path exploded (see also + #: :py:attr:`~sanic_routing.route.Route.segments`) + parts: t.Tuple[str, ...] + #: The _reconstructed_ path after the Route has been normalized. + #: Does not contain preceding ``/`` (see also + #: :py:attr:`~sanic_routing.route.Route.uri`) + path: str + #: A regex version of the :py:attr:`~sanic_routing.route.Route.path` + pattern: t.Optional[str] + #: Whether the route requires regular expression evaluation + regex: bool + #: A representation of the non-path route requirements + requirements: Requirements + #: When ``True``, the route does not have any dynamic path parameters + static: bool + #: Whether the route should be matched with strict evaluation + strict: bool + #: Whether the route should be unquoted after matching if (for example) it + #: is suspected to contain non-URL friendly characters + unquote: bool + + def __init__( + self, + router, + raw_path: str, + name: str, + handler: t.Callable[..., t.Any], + methods: t.Union[t.Sequence[str], t.FrozenSet[str]], + requirements: t.Dict[str, t.Any] = None, + strict: bool = False, + unquote: bool = False, + static: bool = False, + regex: bool = False, + overloaded: bool = False, + ): + self.router = router + self.name = name + self.handler = handler # type: ignore + self.methods = frozenset(methods) + self.requirements = Requirements(requirements or {}) + + self.ctx = SimpleNamespace() + + self._params: t.Dict[int, ParamInfo] = {} + self._raw_path = raw_path + + # Main goal is to do some normalization. Any dynamic segments + # that are missing a type are rewritten with str type + ingested_path = self._ingest_path(raw_path) + + # By passing the path back and forth to deconstruct and reconstruct + # we can normalize it and make sure we are dealing consistently + parts = path_to_parts(ingested_path, self.router.delimiter) + self.path = parts_to_path(parts, delimiter=self.router.delimiter) + self.parts = parts + self.static = static + self.regex = regex + self.overloaded = overloaded + self.pattern = None + self.strict: bool = strict + self.unquote: bool = unquote + self.labels: t.Optional[t.List[str]] = None + + self._setup_params() + + def __str__(self): + display = ( + f"name={self.name} path={self.path or self.router.delimiter}" + if self.name and self.name != self.path + else f"path={self.path or self.router.delimiter}" + ) + return f"<{self.__class__.__name__}: {display}>" + + def __repr__(self) -> str: + return str(self) + + def __eq__(self, other) -> bool: + if not isinstance(other, self.__class__): + return False + + # Equality specifically uses self.segments and not self.parts. + # In general, these properties are nearly identical. + # self.segments is generalized and only displays dynamic param types + # and self.parts has both the param key and the param type. + # In this test, we use the & operator so that we create a union and a + # positive equality if there is one or more overlaps in the methods. + return bool( + ( + self.segments, + self.requirements, + ) + == ( + other.segments, + other.requirements, + ) + and (self.methods & other.methods) + ) + + def _ingest_path(self, path): + segments = [] + for part in path.split(self.router.delimiter): + if part.startswith("<") and ":" not in part: + name = part[1:-1] + part = f"<{name}:str>" + segments.append(part) + return self.router.delimiter.join(segments) + + def _setup_params(self): + key_path = parts_to_path( + path_to_parts(self.raw_path, self.router.delimiter), + self.router.delimiter, + ) + if not self.static: + parts = path_to_parts(key_path, self.router.delimiter) + for idx, part in enumerate(parts): + if part.startswith("<"): + ( + name, + label, + _type, + pattern, + ) = self.parse_parameter_string(part[1:-1]) + self.add_parameter( + idx, name, key_path, label, _type, pattern + ) + + def add_parameter( + self, + idx: int, + name: str, + raw_path: str, + label: str, + cast: t.Type, + pattern=None, + ): + if pattern and isinstance(pattern, str): + if not pattern.startswith("^"): + pattern = f"^{pattern}" + if not pattern.endswith("$"): + pattern = f"{pattern}$" + + pattern = re.compile(pattern) + + is_regex = label not in self.router.regex_types + priority = ( + 0 + if is_regex + else list(self.router.regex_types.keys()).index(label) + ) + self._params[idx] = ParamInfo( + name, raw_path, label, cast, pattern, is_regex, priority + ) + + def _finalize_params(self): + params = dict(self._params) + label_pairs = set([(param.name, idx) for idx, param in params.items()]) + labels = [item[0] for item in label_pairs] + if len(labels) != len(set(labels)): + raise ParameterNameConflicts( + f"Duplicate named parameters in: {self._raw_path}" + ) + self.labels = labels + self.params = dict( + sorted(params.items(), key=lambda param: self._sorting(param[1])) + ) + + def _compile_regex(self): + components = [] + + for part in self.parts: + if part.startswith("<"): + name, *_, pattern = self.parse_parameter_string(part) + if not isinstance(pattern, str): + pattern = pattern.pattern.strip("^$") + compiled = re.compile(pattern) + if compiled.groups == 1: + if compiled.groupindex: + if list(compiled.groupindex)[0] != name: + raise InvalidUsage( + f"Named group ({list(compiled.groupindex)[0]})" + f" must match your named parameter ({name})" + ) + components.append(pattern) + else: + if pattern.count("(") > 1: + raise InvalidUsage( + f"Could not compile pattern {pattern}. " + "Try using a named group instead: " + f"'(?P<{name}>your_matching_group)'" + ) + beginning, end = pattern.split("(") + components.append(f"{beginning}(?P<{name}>{end}") + elif compiled.groups > 1: + raise InvalidUsage(f"Invalid matching pattern {pattern}") + else: + components.append(f"(?P<{name}>{pattern})") + else: + components.append(part) + + self.pattern = self.router.delimiter + self.router.delimiter.join( + components + ) + + def finalize(self): + self._finalize_params() + if self.regex: + self._compile_regex() + self.requirements = Immutable(self.requirements) + + def reset(self): + self.requirements = dict(self.requirements) + + @property + def defined_params(self): + return self._params + + @property + def raw_path(self): + """ + The raw path from the route definition + """ + return self._raw_path + + @property + def segments(self) -> t.Tuple[str, ...]: + """ + Same as :py:attr:`~sanic_routing.route.Route.parts` except + generalized so that any dynamic parts do not + include param keys since they have no impact on routing. + """ + return tuple( + f"<__dynamic__:{self._params[idx].label}>" + if idx in self._params + else segment + for idx, segment in enumerate(self.parts) + ) + + @property + def uri(self): + """ + Since :py:attr:`~sanic_routing.route.Route.path` does NOT + include a preceding '/', this adds it back. + """ + return f"{self.router.delimiter}{self.path}" + + def _sorting(self, item) -> int: + try: + return list(self.router.regex_types.keys()).index(item.label) + except ValueError: + return len(list(self.router.regex_types.keys())) + + def parse_parameter_string(self, parameter_string: str): + """Parse a parameter string into its constituent name, type, and + pattern + + For example:: + + parse_parameter_string('')` -> + ('param_one', '[A-z]', , '[A-z]') + + :param parameter_string: String to parse + :return: tuple containing + (parameter_name, parameter_type, parameter_pattern) + """ + # We could receive NAME or NAME:PATTERN + parameter_string = parameter_string.strip("<>") + name = parameter_string + label = "str" + if ":" in parameter_string: + name, label = parameter_string.split(":", 1) + if not name: + raise ValueError( + f"Invalid parameter syntax: {parameter_string}" + ) + if label == "string": + warn( + "Use of 'string' as a path parameter type is deprected, " + "and will be removed in Sanic v21.12. " + f"Instead, use <{name}:str>.", + DeprecationWarning, + ) + elif label == "number": + warn( + "Use of 'number' as a path parameter type is deprected, " + "and will be removed in Sanic v21.12. " + f"Instead, use <{name}:float>.", + DeprecationWarning, + ) + + default = (str, label) + # Pull from pre-configured types + _type, pattern = self.router.regex_types.get(label, default) + return name, label, _type, pattern diff --git a/backend/sanic_server/sanic_routing/router.py b/backend/sanic_server/sanic_routing/router.py new file mode 100644 index 000000000..897b23f49 --- /dev/null +++ b/backend/sanic_server/sanic_routing/router.py @@ -0,0 +1,578 @@ +import ast +import sys +import typing as t +from abc import ABC, abstractmethod +from re import Pattern +from types import SimpleNamespace +from warnings import warn + +from ..sanic_routing.group import RouteGroup +from .exceptions import BadMethod, FinalizationError, InvalidUsage, NoMethod, NotFound +from .line import Line +from .patterns import REGEX_TYPES +from .route import Route +from .tree import Node, Tree +from .utils import parts_to_path, path_to_parts + +# The below functions might be called by the compiled source code, and +# therefore should be made available here by import +import re # noqa isort:skip +from datetime import datetime # noqa isort:skip +from urllib.parse import unquote # noqa isort:skip +from uuid import UUID # noqa isort:skip +from .patterns import parse_date, alpha, slug # noqa isort:skip + + +class BaseRouter(ABC): + DEFAULT_METHOD = "BASE" + ALLOWED_METHODS: t.Tuple[str, ...] = tuple() + + def __init__( + self, + delimiter: str = "/", + exception: t.Type[NotFound] = NotFound, + method_handler_exception: t.Type[NoMethod] = NoMethod, + route_class: t.Type[Route] = Route, + group_class: t.Type[RouteGroup] = RouteGroup, + stacking: bool = False, + cascade_not_found: bool = False, + ) -> None: + self._find_route = None + self._matchers = None + self.static_routes: t.Dict[t.Tuple[str, ...], RouteGroup] = {} + self.dynamic_routes: t.Dict[t.Tuple[str, ...], RouteGroup] = {} + self.regex_routes: t.Dict[t.Tuple[str, ...], RouteGroup] = {} + self.name_index: t.Dict[str, Route] = {} + self.delimiter = delimiter + self.exception = exception + self.method_handler_exception = method_handler_exception + self.route_class = route_class + self.group_class = group_class + self.tree = Tree(router=self) + self.finalized = False + self.stacking = stacking + self.ctx = SimpleNamespace() + self.cascade_not_found = cascade_not_found + + self.regex_types = {**REGEX_TYPES} + + @abstractmethod + def get(self, **kwargs): + ... + + def resolve( + self, + path: str, + *, + method: t.Optional[str] = None, + orig: t.Optional[str] = None, + extra: t.Optional[t.Dict[str, str]] = None, + ) -> t.Tuple[Route, t.Callable[..., t.Any], t.Dict[str, t.Any]]: + try: + route, param_basket = self.find_route( + path, + method, + self, + {"__params__": {}, "__matches__": {}}, + extra, + ) + except (NotFound, NoMethod) as e: + # If we did not find the route, we might need to try routing one + # more time to handle strict_slashes + if path.endswith(self.delimiter): + return self.resolve( + path=path[:-1], + method=method, + orig=path, + extra=extra, + ) + raise self.exception(str(e), path=path) + + if isinstance(route, RouteGroup): + try: + route = route.methods_index[method] + except KeyError: + raise self.method_handler_exception( + f"Method '{method}' not found on {route}", + method=method, + allowed_methods=route.methods, + ) + + # Regex routes evaluate and can extract params directly. They are set + # on param_basket["__params__"] + params = param_basket["__params__"] + if not params: + # If param_basket["__params__"] does not exist, we might have + # param_basket["__matches__"], which are indexed based matches + # on path segments. They should already be cast types. + params = { + param.name: param_basket["__matches__"][idx] + for idx, param in route.params.items() + } + + # Double check that if we made a match it is not a false positive + # because of strict_slashes + if route.strict and orig and orig[-1] != route.path[-1]: + raise self.exception("Path not found", path=path) + + if method not in route.methods: + raise self.method_handler_exception( + f"Method '{method}' not found on {route}", + method=method, + allowed_methods=route.methods, + ) + + return route, route.handler, params + + def add( + self, + path: str, + handler: t.Callable, + methods: t.Optional[t.Union[t.Sequence[str], t.FrozenSet[str], str]] = None, + name: t.Optional[str] = None, + requirements: t.Optional[t.Dict[str, t.Any]] = None, + strict: bool = False, + unquote: bool = False, # noqa + overwrite: bool = False, + append: bool = False, + ) -> Route: + # Can add a route with overwrite, or append, not both. + # - overwrite: if matching path exists, replace it + # - append: if matching path exists, append handler to it + if overwrite and append: + raise FinalizationError( + "Cannot add a route with both overwrite and append equal " "to True" + ) + if not methods: + methods = [self.DEFAULT_METHOD] + + if hasattr(methods, "__iter__") and not isinstance(methods, frozenset): + methods = frozenset(methods) + elif isinstance(methods, str): + methods = frozenset([methods]) + + if self.ALLOWED_METHODS and any( + method not in self.ALLOWED_METHODS for method in methods + ): + bad = [method for method in methods if method not in self.ALLOWED_METHODS] + raise BadMethod( + f"Bad method: {bad}. Must be one of: {self.ALLOWED_METHODS}" + ) + + if self.finalized: + raise FinalizationError("Cannot finalize router more than once.") + + static = "<" not in path and requirements is None + regex = self._is_regex(path) + + # There are generally three pools of routes on the router: + # - those that are static patterns with not matching + # - those that have one or more dynamic parts, but NO regex + # - those that have one or more dynamic parts, with at least one regex + if regex: + routes = self.regex_routes + elif static: + routes = self.static_routes + else: + routes = self.dynamic_routes + + # Only URL encode the static parts of the path + path = parts_to_path(path_to_parts(path, self.delimiter), self.delimiter) + + # We need to clean off the delimiters are the beginning, and maybe the + # end, depending upon whether we are in strict mode + strip = path.lstrip if strict else path.strip + path = strip(self.delimiter) + route = self.route_class( + self, + path, + name or "", + handler=handler, + methods=methods, + requirements=requirements, + strict=strict, + unquote=unquote, + static=static, + regex=regex, + ) + group = self.group_class(route) + + # Catch the scenario where a route is overloaded with and + # and without requirements, first as dynamic then as static + if static and route.segments in self.dynamic_routes: + routes = self.dynamic_routes + + # Catch the reverse scenario where a route is overload first as static + # and then as dynamic + if not static and route.segments in self.static_routes: + existing_group = self.static_routes.pop(route.segments) + group.merge(existing_group, overwrite, append) + + else: + if route.segments in routes: + existing_group = routes[route.segments] + group.merge(existing_group, overwrite, append) + + routes[route.segments] = group + + if name: + self.name_index[name] = route + + group.finalize() + + return route + + def register_pattern( + self, label: str, cast: t.Callable[[str], t.Any], pattern: Pattern + ): + """ + Add a custom parameter type to the router. The cast shoud raise a + ValueError if it is an incorrect type. The order of registration is + important if it is possible that a single value could pass multiple + pattern types. Therefore, patterns are tried in the REVERSE order of + registration. All custom patterns will be evaluated before any built-in + patterns. + + :param label: The parts that is used to signify the type: example + + :type label: str + :param cast: The callable that casts the value to the desired type, or + fails trying + :type cast: t.Callable[[str], t.Any] + :param pattern: A regular expression that could also match the path + segment + :type pattern: Pattern + """ + if not isinstance(label, str): + raise InvalidUsage( + "When registering a pattern, label must be a " + f"string, not label={label}" + ) + if not callable(cast): + raise InvalidUsage( + "When registering a pattern, cast must be a " + f"callable, not cast={cast}" + ) + if not isinstance(pattern, str): + raise InvalidUsage( + "When registering a pattern, pattern must be a " + f"string, not pattern={pattern}" + ) + + globals()[cast.__name__] = cast + self.regex_types[label] = (cast, pattern) + + def finalize(self, do_compile: bool = True, do_optimize: bool = False): + """ + After all routes are added, we can put everything into a final state + and build the routing dource + + :param do_compile: Whether to compile the source, mainly a debugging + tool, defaults to True + :type do_compile: bool, optional + :param do_optimize: Experimental feature that uses AST module to make + some optimizations, defaults to False + :type do_optimize: bool, optional + :raises FinalizationError: Cannot finalize if there are no routes, or + the router has already been finalized (can call reset() to undo it) + """ + if self.finalized: + raise FinalizationError("Cannot finalize router more than once.") + if not self.routes: + raise FinalizationError("Cannot finalize with no routes defined.") + self.finalized = True + + for group in ( + list(self.static_routes.values()) + + list(self.dynamic_routes.values()) + + list(self.regex_routes.values()) + ): + group.finalize() + for route in group.routes: + route.finalize() + + # Evaluates all of the paths and arranges them into a hierarchichal + # tree of nodes + self._generate_tree() + + # Renders the source code + self._render(do_compile, do_optimize) + + def reset(self): + self.finalized = False + self.tree = Tree(router=self) + self._find_route = None + + for group in ( + list(self.static_routes.values()) + + list(self.dynamic_routes.values()) + + list(self.regex_routes.values()) + ): + group.reset() + for route in group.routes: + route.reset() + + def _get_non_static_non_path_groups( + self, has_dynamic_path: bool + ) -> t.List[RouteGroup]: + """ + Paths that have some matching params (includes dynamic and regex), + but excludes any routes with a or delimiter in its regex. + This is because those special cases need to be evaluated seperately. + Anything else can be evaluated in the node tree. + + :param has_dynamic_path: Whether the path catches a path, or path-like + type + :type has_dynamic_path: bool + :return: list of routes that have no path, but do need matching + :rtype: List[RouteGroup] + """ + return sorted( + [ + group + for group in list(self.dynamic_routes.values()) + + list(self.regex_routes.values()) + if group.dynamic_path is has_dynamic_path + ], + key=lambda x: x.depth, + reverse=True, + ) + + def _generate_tree(self) -> None: + self.tree.generate(self._get_non_static_non_path_groups(False)) + self.tree.finalize() + + def _render(self, do_compile: bool = True, do_optimize: bool = False) -> None: + # Initial boilerplate for the function source + src = [ + Line("def find_route(path, method, router, basket, extra):", 0), + Line("parts = tuple(path[1:].split(router.delimiter))", 1), + ] + delayed = [] + + # Add static path matching + if self.static_routes: + # TODO: + # - future improvement would be to decide which option to use + # at runtime based upon the makeup of the router since this + # potentially has an impact on performance + src += [ + Line("try:", 1), + Line( + "group = router.static_routes[parts]", + 2, + ), + Line("basket['__raw_path__'] = path", 2), + Line("return group, basket", 2), + Line("except KeyError:", 1), + Line("pass", 2), + ] + # src += [ + # Line("if parts in router.static_routes:", 1), + # Line("route = router.static_routes[parts]", 2), + # Line("basket['__raw_path__'] = route.path", 2), + # Line("return route, basket", 2), + # ] + # src += [ + # Line("if path in router.static_routes:", 1), + # Line("route = router.static_routes.get(path)", 2), + # Line("basket['__raw_path__'] = route.path", 2), + # Line("return route, basket", 2), + # ] + + # Add in pre-compiled regular expressions so they do not need to + # compile at run time + if self.regex_routes: + routes = sorted( + self.regex_routes.values(), + key=lambda route: len(route.parts), + reverse=True, + ) + delayed.append(Line("matchers = [", 0)) + for idx, group in enumerate(routes): + group.pattern_idx = idx + delayed.append(Line(f"re.compile(r'^{group.pattern}$'),", 1)) + delayed.append(Line("]", 0)) + + # Generate all the dynamic code + if self.dynamic_routes or self.regex_routes: + src += [Line("num = len(parts)", 1)] + src += self.tree.render() + + # Inject regex matching that could not be in the tree + for group in self._get_non_static_non_path_groups(True): + route_container = "regex_routes" if group.regex else "dynamic_routes" + route_idx: t.Union[str, int] = 0 + holder: t.List[Line] = [] + + if len(group.routes) > 1: + route_idx = "route_idx" + Node._inject_method_check(holder, 2, group) + + src.extend( + [ + Line( + ( + "match = router.matchers" + f"[{group.pattern_idx}].match(path)" + ), + 1, + ), + Line("if match:", 1), + *holder, + Line("basket['__params__'] = match.groupdict()", 2), + Line( + ( + f"return router.{route_container}" + f"[{group.segments}][{route_idx}], basket" + ), + 2, + ), + ] + ) + + src.append(Line("raise NotFound", 1)) + src.extend(delayed) + + self.find_route_src = "".join(map(str, filter(lambda x: x.render, src))) + if do_compile: + try: + syntax_tree = ast.parse(self.find_route_src) + + if do_optimize: + self._optimize(syntax_tree.body[0]) + + if sys.version_info.major == 3 and sys.version_info.minor >= 9: + # This is purely a convenience thing. Python 3.9 added this + # feature, so it allows us to see exactly how the + # interpreter will see the code after compiling and any + # optimizing. + setattr( + self, + "find_route_src_compiled", + ast.unparse(syntax_tree), # type: ignore + ) + + # Sometimes there may be missing meta data, so we add it back + # before compiling + ast.fix_missing_locations(syntax_tree) + + compiled_src = compile( + syntax_tree, + "", + "exec", + ) + except SyntaxError as se: + syntax_error = ( + f"Line {se.lineno}: {se.msg}\n{se.text}" + f"{' '*max(0,int(se.offset or 0)-1) + '^'}" + ) + raise FinalizationError( + f"Cannot compile route AST:\n{self.find_route_src}" + f"\n{syntax_error}" + ) + ctx: t.Dict[t.Any, t.Any] = {} + exec(compiled_src, None, ctx) + self._find_route = ctx["find_route"] + self._matchers = ctx.get("matchers") + + @property + def find_route(self): + return self._find_route + + @property + def matchers(self): + return self._matchers + + @property + def groups(self): + return { + **self.static_routes, + **self.dynamic_routes, + **self.regex_routes, + } + + @property + def routes(self): + return tuple([route for group in self.groups.values() for route in group]) + + def _optimize(self, node) -> None: + warn( + "Router AST optimization is an experimental only feature. " + "Results may vary from unoptimized code." + ) + if hasattr(node, "body"): + for child in node.body: + self._optimize(child) + + # concatenate nested single if blocks + # EXAMPLE: + # if parts[1] == "foo": + # if num > 3: + # BECOMES: + # if parts[1] == 'foo' and num > 3: + # Testing has shown that further recursion does not actually + # produce any faster results. + if self._is_lone_if(node) and self._is_lone_if(node.body[0]): + current = node.body[0] + nested = node.body[0].body[0] + + values: t.List[t.Any] = [] + for test in [current.test, nested.test]: + if isinstance(test, ast.Compare): + values.append(test) + elif isinstance(test, ast.BoolOp) and isinstance(test.op, ast.And): + values.extend(test.values) + else: + ... + combined = ast.BoolOp(op=ast.And(), values=values) + + current.test = combined + current.body = nested.body + + # Look for identical successive if blocks + # EXAMPLE: + # if num == 5: + # foo1() + # if num == 5: + # foo2() + # BECOMES: + # if num == 5: + # foo1() + # foo2() + if ( + all(isinstance(child, ast.If) for child in node.body) + # TODO: create func to peoperly compare equality of conditions + # and len({child.test for child in node.body}) + and len(node.body) > 1 + ): + first, *rem = node.body + for item in rem: + first.body.extend(item.body) + + node.body = [first] + + if hasattr(node, "orelse"): + for child in node.orelse: + self._optimize(child) + + @staticmethod + def _is_lone_if(node): + return len(node.body) == 1 and isinstance(node.body[0], ast.If) + + def _is_regex(self, path: str): + parts = path_to_parts(path, self.delimiter) + + def requires(part): + if not part.startswith("<") or ":" not in part: + return False + + _, pattern_type = part[1:-1].split(":", 1) + + return ( + part.endswith(":path>") + or self.delimiter in part + or pattern_type not in self.regex_types + ) + + return any(requires(part) for part in parts) diff --git a/backend/sanic_server/sanic_routing/tree.py b/backend/sanic_server/sanic_routing/tree.py new file mode 100644 index 000000000..23c0e52ff --- /dev/null +++ b/backend/sanic_server/sanic_routing/tree.py @@ -0,0 +1,473 @@ +import typing as t +from logging import getLogger + +from .group import RouteGroup +from .line import Line +from .patterns import REGEX_PARAM_NAME + +logger = getLogger("sanic.root") + + +class Node: + def __init__( + self, + part: str = "", + root: bool = False, + parent=None, + router=None, + param=None, + ) -> None: + self.root = root + self.part = part + self.parent = parent + self.param = param + self._children: t.Dict[str, "Node"] = {} + self.children: t.Dict[str, "Node"] = {} + self.level = 0 + self.base_indent = 0 + self.offset = 0 + self.groups: t.List[RouteGroup] = [] + self.dynamic = False + self.first = False + self.last = False + self.children_basketed = False + self.children_param_injected = False + self.has_deferred = False + self.equality_check = False + self.unquote = False + self.router = router + + def __str__(self) -> str: + internals = ", ".join( + f"{prop}={getattr(self, prop)}" + for prop in ["part", "level", "groups", "dynamic"] + if getattr(self, prop) or prop in ["level"] + ) + return f"" + + def __repr__(self) -> str: + return str(self) + + @property + def ident(self) -> str: + prefix = ( + f"{self.parent.ident}." + if self.parent and not self.parent.root + else "" + ) + return f"{prefix}{self.idx}" + + @property + def idx(self) -> int: + if not self.parent: + return 1 + return list(self.parent.children.keys()).index(self.part) + 1 + + def finalize_children(self): + """ + Sort the children (if any), and set properties for easy checking + # they are at the beginning or end of the line. + """ + self.children = { + k: v for k, v in sorted(self._children.items(), key=self._sorting) + } + if self.children: + keys = list(self.children.keys()) + self.children[keys[0]].first = True + self.children[keys[-1]].last = True + + for child in self.children.values(): + child.finalize_children() + + def display(self) -> None: + """ + Visual display of the tree of nodes + """ + logger.info(" " * 4 * self.level + str(self)) + for child in self.children.values(): + child.display() + + def render(self) -> t.Tuple[t.List[Line], t.List[Line]]: + # output - code injected into the source as it is being + # called/evaluated + # delayed - code that is injected after you do all of its children + # first + # final - code that is injected at the very end of all rendering + src: t.List[Line] = [] + delayed: t.List[Line] = [] + final: t.List[Line] = [] + + if not self.root: + src, delayed, final = self.to_src() + for child in self.children.values(): + o, f = child.render() + src += o + final += f + return src + delayed, final + + def to_src(self) -> t.Tuple[t.List[Line], t.List[Line], t.List[Line]]: + siblings = self.parent.children if self.parent else {} + first_sibling: t.Optional[Node] = None + + if not self.first: + first_sibling = next(iter(siblings.values())) + + self.base_indent = ( + bool(self.level >= 1 or self.first) + self.parent.base_indent + if self.parent + else 0 + ) + + indent = self.base_indent + + # See render() docstring for definition of these three sequences + delayed: t.List[Line] = [] + final: t.List[Line] = [] + src: t.List[Line] = [] + + # Some cleanup to make code easier to read + src.append(Line("", indent)) + src.append(Line(f"# node={self.ident} // part={self.part}", indent)) + + level = self.level + idx = level - 1 + + return_bump = not self.dynamic + + operation = ">" + conditional = "if" + + # The "equality_check" is when we do a "==" operation to check + # that the incoming path is the same length as a particular target. + # Since this could take place in a few different locations, we need + # to be able to track if it has been set. + if self.groups: + operation = "==" if self.level == self.parent.depth else ">=" + self.equality_check = operation == "==" + + src.append( + Line( + f"{conditional} num {operation} {level}: # CHECK 1", + indent, + ) + ) + indent += 1 + + if self.dynamic: + # Injects code to try casting a segment to all POTENTIAL types that + # the defined routes could catch in this location + self._inject_param_check(src, indent, idx) + indent += 1 + + else: + if ( + not self.equality_check + and self.groups + and not self.first + and first_sibling + ): + self.equality_check = first_sibling.equality_check + + # Maybe try and sneak an equality check in? + if_stmt = "if" + len_check = ( + f" and num == {self.level}" + if not self.children and not self.equality_check + else "" + ) + + self.equality_check |= bool(len_check) + + src.append( + Line( + f'{if_stmt} parts[{idx}] == "{self.part}"{len_check}:' + " # CHECK 4", + indent, + ) + ) + self.base_indent += 1 + + # Get ready to return some handlers + if self.groups: + return_indent = indent + return_bump + route_idx: t.Union[int, str] = 0 + location = delayed + + # Do any missing equality_check + if not self.equality_check: + # If if we have not done an equality check and there are + # children nodes, then we know there is a CHECK 1 + # for the children that starts at the same level, and will + # be an exclusive conditional to what is being evaluated here. + # Therefore, we can use elif + # example: + # if num == 7: # CHECK 1 + # child_node_stuff + # elif num == 6: # CHECK 5 + # current_node_stuff + conditional = "elif" if self.children else "if" + operation = "==" + location.append( + Line( + f"{conditional} num {operation} {level}: # CHECK 5", + return_indent, + ) + ) + return_indent += 1 + + for group in sorted(self.groups, key=self._group_sorting): + group_bump = 0 + + # If the route had some requirements, let's make sure we check + # them in the source + if group.requirements: + route_idx = "route_idx" + self._inject_requirements( + location, return_indent + group_bump, group + ) + + # This is for any inline regex routes. It sould not include, + # path or path-like routes. + if group.regex: + self._inject_regex( + location, return_indent + group_bump, group + ) + group_bump += 1 + + # Since routes are grouped, we need to know which to select + # Inside the compiled source, we keep track so we know which + # handler to assign this to + if route_idx == 0 and len(group.routes) > 1: + route_idx = "route_idx" + self._inject_method_check( + location, return_indent + group_bump, group + ) + + # The return.kingdom + self._inject_return( + location, return_indent + group_bump, route_idx, group + ) + + return src, delayed, final + + def add_child(self, child: "Node") -> None: + self._children[child.part] = child + + def _inject_param_check(self, location, indent, idx): + """ + Try and cast relevant path segments. + """ + lines = [ + Line("try:", indent), + Line( + f"basket['__matches__'][{idx}] = " + f"{self.param.cast.__name__}(parts[{idx}])", + indent + 1, + ), + Line("except ValueError:", indent), + Line("pass", indent + 1), + Line("else:", indent), + ] + if self.unquote: + lines.append( + Line( + f"basket['__matches__'][{idx}] = " + f"unquote(basket['__matches__'][{idx}])", + indent + 1, + ) + ) + self.base_indent += 1 + + location.extend(lines) + + @staticmethod + def _inject_method_check(location, indent, group): + """ + Sometimes we need to check the routing methods inside the generated src + """ + for i, route in enumerate(group.routes): + if_stmt = "if" if i == 0 else "elif" + location.extend( + [ + Line( + f"{if_stmt} method in {route.methods}:", + indent, + ), + Line(f"route_idx = {i}", indent + 1), + ] + ) + location.extend( + [ + Line("else:", indent), + Line("raise NoMethod", indent + 1), + ] + ) + + def _inject_return(self, location, indent, route_idx, group): + """ + The return statement for the node if needed + """ + routes = "regex_routes" if group.regex else "dynamic_routes" + route_return = "" if group.router.stacking else f"[{route_idx}]" + location.extend( + [ + Line(f"# Return {self.ident}", indent), + Line( + ( + f"return router.{routes}[{group.segments}]" + f"{route_return}, basket" + ), + indent, + ), + ] + ) + + def _inject_requirements(self, location, indent, group): + """ + Check any extra checks needed for a route. In path routing, for exampe, + this is used for matching vhosts. + """ + for k, route in enumerate(group): + conditional = "if" if k == 0 else "elif" + location.extend( + [ + Line( + ( + f"{conditional} extra == {route.requirements} " + f"and method in {route.methods}:" + ), + indent, + ), + Line((f"route_idx = {k}"), indent + 1), + ] + ) + + location.extend( + [ + Line(("else:"), indent), + Line(("raise NotFound"), indent + 1), + ] + ) + + def _inject_regex(self, location, indent, group): + """ + For any path matching that happens in the course of the tree (anything + that has a path matching----or similar matching with regex + delimiter) + """ + location.extend( + [ + Line( + ( + "match = router.matchers" + f"[{group.pattern_idx}].match(path)" + ), + indent, + ), + Line("if match:", indent), + Line( + "basket['__params__'] = match.groupdict()", + indent + 1, + ), + ] + ) + + def _sorting(self, item) -> t.Tuple[bool, bool, int, int, int, bool, str]: + """ + Primarily use to sort nodes to determine the order of the nested tree + """ + key, child = item + type_ = 0 + if child.dynamic: + type_ = child.param.priority + + return ( + bool(child.groups), + child.dynamic, + type_ * -1, + child.depth * -1, + len(child._children), + not bool( + child.groups and any(group.regex for group in child.groups) + ), + key, + ) + + def _group_sorting(self, item) -> t.Tuple[int, ...]: + """ + When multiple RouteGroups terminate on the same node, we want to + evaluate them based upon the priority of the param matching types + """ + + def get_type(segment): + type_ = 0 + if segment.startswith("<"): + key = segment[1:-1] + if ":" in key: + key, param_type = key.split(":", 1) + try: + type_ = list(self.router.regex_types.keys()).index( + param_type + ) + except ValueError: + type_ = len(list(self.router.regex_types.keys())) + return type_ * -1 + + segments = tuple(map(get_type, item.parts)) + return segments + + @property + def depth(self): + if not self._children: + return self.level + return max(child.depth for child in self._children.values()) + + +class Tree: + def __init__(self, router) -> None: + self.root = Node(root=True, router=router) + self.root.level = 0 + self.router = router + + def generate(self, groups: t.Iterable[RouteGroup]) -> None: + """ + Arrange RouteGroups into hierarchical nodes and arrange them into + a tree + """ + for group in groups: + current = self.root + for level, part in enumerate(group.parts): + param = None + dynamic = part.startswith("<") + if dynamic: + if not REGEX_PARAM_NAME.match(part): + raise ValueError(f"Invalid declaration: {part}") + part = f"__dynamic__:{group.params[level].label}" + param = group.params[level] + if part not in current._children: + child = Node( + part=part, + parent=current, + router=self.router, + param=param, + ) + child.dynamic = dynamic + current.add_child(child) + current = current._children[part] + current.level = level + 1 + + current.groups.append(group) + current.unquote = current.unquote or group.unquote + + def display(self) -> None: + """ + Debug tool to output visual of the tree + """ + self.root.display() + + def render(self) -> t.List[Line]: + o, f = self.root.render() + return o + f + + def finalize(self): + self.root.finalize_children() diff --git a/backend/sanic_server/sanic_routing/utils.py b/backend/sanic_server/sanic_routing/utils.py new file mode 100644 index 000000000..c19ece416 --- /dev/null +++ b/backend/sanic_server/sanic_routing/utils.py @@ -0,0 +1,80 @@ +import re +from urllib.parse import quote, unquote + +from .patterns import REGEX_PARAM_NAME + + +class Immutable(dict): + def __setitem__(self, *args): + raise TypeError("Cannot change immutable dict") + + def __delitem__(self, *args): + raise TypeError("Cannot change immutable dict") + + +def parse_parameter_basket(route, basket, raw_path=None): + params = {} + if basket: + for idx, value in basket.items(): + for p in route.params[idx]: + if not raw_path or p.raw_path == raw_path: + if not p.regex: + raw_path = p.raw_path + params[p.name] = p.cast(value) + break + elif p.pattern.search(value): + raw_path = p.raw_path + if "(" in p.pattern: + groups = p.pattern.match(value) + value = groups.group(1) + params[p.name] = p.cast(value) + break + + if raw_path: + raise ValueError("Invalid parameter") + + if raw_path and not params[p.name]: + raise ValueError("Invalid parameter") + + if route.unquote: + for p in route.params[idx]: + if isinstance(params[p.name], str): + params[p.name] = unquote(params[p.name]) + + if raw_path is None: + raise ValueError("Invalid parameter") + return params, raw_path + + +def path_to_parts(path, delimiter="/"): + r""" + OK > /foo//bar/ + OK > /foo/ + OK > /foo/txt)>/ + OK > /foo// + OK > /foo//txt)d> + NOT OK > /foo/txt)d>/ + """ + path = unquote(path.lstrip(delimiter)) + delimiter = re.escape(delimiter) + return tuple( + part if part.startswith("<") else quote(part) + for part in re.split(rf"{delimiter}(?=[^>]*(?:<(?") + except AttributeError: + raise ValueError(f"Invalid declaration: {part}") + else: + path.append(part) + return delimiter.join(path) diff --git a/frontend/.gitignore b/frontend/.gitignore deleted file mode 100644 index 8e3a10669..000000000 --- a/frontend/.gitignore +++ /dev/null @@ -1,89 +0,0 @@ -# Logs -logs -*.log -npm-debug.log* -yarn-debug.log* -yarn-error.log* -lerna-debug.log* - -# Diagnostic reports (https://nodejs.org/api/report.html) -report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json - -# Runtime data -pids -*.pid -*.seed -*.pid.lock -.DS_Store - -# Directory for instrumented libs generated by jscoverage/JSCover -lib-cov - -# Coverage directory used by tools like istanbul -coverage -*.lcov - -# nyc test coverage -.nyc_output - -# node-waf configuration -.lock-wscript - -# Compiled binary addons (https://nodejs.org/api/addons.html) -build/Release - -# Dependency directories -node_modules/ -jspm_packages/ - -# TypeScript v1 declaration files -typings/ - -# TypeScript cache -*.tsbuildinfo - -# Optional npm cache directory -.npm - -# Optional eslint cache -.eslintcache - -# Optional REPL history -.node_repl_history - -# Output of 'npm pack' -*.tgz - -# Yarn Integrity file -.yarn-integrity - -# dotenv environment variables file -.env -.env.test - -# parcel-bundler cache (https://parceljs.org/) -.cache - -# next.js build output -.next - -# nuxt.js build output -.nuxt - -# vuepress build output -.vuepress/dist - -# Serverless directories -.serverless/ - -# FuseBox cache -.fusebox/ - -# DynamoDB Local files -.dynamodb/ - -# Webpack -.webpack/ - -# Electron-Forge -out/ diff --git a/frontend/package-lock.json b/package-lock.json similarity index 98% rename from frontend/package-lock.json rename to package-lock.json index d6c992cba..e72ac8c0b 100644 --- a/frontend/package-lock.json +++ b/package-lock.json @@ -26,6 +26,7 @@ "image-webpack-loader": "^8.0.1", "lodash": "^4.17.21", "meow": "^10.1.1", + "os-utils": "^0.0.14", "portastic": "^1.0.1", "react": "^17.0.2", "react-dom": "^17.0.2", @@ -59,6 +60,7 @@ "eslint": "^7.32.0", "eslint-config-airbnb": "^18.2.1", "node-loader": "^2.0.0", + "semver-regex": ">=3.1.3", "style-loader": "^3.2.1" } }, @@ -2101,7 +2103,6 @@ "version": "1.13.0", "resolved": "https://registry.npmjs.org/@electron/get/-/get-1.13.0.tgz", "integrity": "sha512-+SjZhRuRo+STTO1Fdhzqnv9D2ZhjxXP6egsJ9kiO8dtP68cDx7dFCwWi64dlMQV7sWcfW1OYCW4wviEBzmRsfQ==", - "dev": true, "dependencies": { "debug": "^4.1.1", "env-paths": "^2.2.0", @@ -2123,7 +2124,6 @@ "version": "8.1.0", "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-8.1.0.tgz", "integrity": "sha512-yhlQgA6mnOJUKOsRUFsgJdQCvkKhcz8tlZG5HBQfReYZy46OwLcY+Zia0mtdHsOo9y/hP+CxMN0TU9QxoOtG4g==", - "dev": true, "dependencies": { "graceful-fs": "^4.2.0", "jsonfile": "^4.0.0", @@ -2137,7 +2137,6 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/jsonfile/-/jsonfile-4.0.0.tgz", "integrity": "sha1-h3Gq4HmbZAdrdmQPygWPnBDjPss=", - "dev": true, "optionalDependencies": { "graceful-fs": "^4.1.6" } @@ -2146,7 +2145,6 @@ "version": "6.3.0", "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", - "dev": true, "bin": { "semver": "bin/semver.js" } @@ -2155,7 +2153,6 @@ "version": "0.1.2", "resolved": "https://registry.npmjs.org/universalify/-/universalify-0.1.2.tgz", "integrity": "sha512-rBJeI5CXAlmy1pV+617WB9J63U6XcazHHF2f2dbJix4XzpUF0RS3Zbj0FGIOCAva5P/d/GBOYaACQ1w+0azUkg==", - "dev": true, "engines": { "node": ">= 4.0.0" } @@ -2747,7 +2744,6 @@ "version": "0.14.0", "resolved": "https://registry.npmjs.org/@sindresorhus/is/-/is-0.14.0.tgz", "integrity": "sha512-9NET910DNaIPngYnLLPeg+Ogzqsi9uM4mSboU5y6p8S5DzMTVEsJZrawi+BoDNUVBa2DhJqQYUFvMDfgU062LQ==", - "dev": true, "engines": { "node": ">=6" } @@ -2764,7 +2760,6 @@ "version": "1.1.2", "resolved": "https://registry.npmjs.org/@szmarczak/http-timer/-/http-timer-1.1.2.tgz", "integrity": "sha512-XIB2XbzHTN6ieIjfIMV9hlVcfPU26s2vafYWQcZHWXHOxiaRZYEDKEwdl129Zyg50+foYV2jCgtrqSA6qNuNSA==", - "dev": true, "dependencies": { "defer-to-connect": "^1.0.1" }, @@ -3693,15 +3688,6 @@ "node": ">=6.0" } }, - "node_modules/array-find-index": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/array-find-index/-/array-find-index-1.0.2.tgz", - "integrity": "sha1-3wEKoSh+Fku9pvlyOwqWoexBh6E=", - "optional": true, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/array-flatten": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", @@ -4806,7 +4792,6 @@ "version": "3.1.4", "resolved": "https://registry.npmjs.org/boolean/-/boolean-3.1.4.tgz", "integrity": "sha512-3hx0kwU3uzG6ReQ3pnaFQPSktpBw6RHN3/ivDKEuU8g1XSfafowyvDnadjv1xp8IZqhtSukxlwv9bF6FhX8m0w==", - "dev": true, "optional": true }, "node_modules/brace-expansion": { @@ -4895,7 +4880,6 @@ "version": "0.2.13", "resolved": "https://registry.npmjs.org/buffer-crc32/-/buffer-crc32-0.2.13.tgz", "integrity": "sha1-DTM+PwDqxQqhRUq9MO+MKl2ackI=", - "devOptional": true, "engines": { "node": "*" } @@ -4968,7 +4952,6 @@ "version": "6.1.0", "resolved": "https://registry.npmjs.org/cacheable-request/-/cacheable-request-6.1.0.tgz", "integrity": "sha512-Oj3cAGPCqOZX7Rz64Uny2GYAZNliQSqfbePrgAQ1wKAihYmCUnraBtJtKcGR4xz7wF+LoJC+ssFZvv5BgF9Igg==", - "dev": true, "dependencies": { "clone-response": "^1.0.2", "get-stream": "^5.1.0", @@ -4986,7 +4969,6 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/lowercase-keys/-/lowercase-keys-2.0.0.tgz", "integrity": "sha512-tqNXrS78oMOE73NMxK4EMLQsQowWf8jKooH9g7xPavRT706R6bkQJ6DY2Te7QukaZsulxa30wQ7bk0pm4XiHmA==", - "dev": true, "engines": { "node": ">=8" } @@ -5303,7 +5285,6 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/clone-response/-/clone-response-1.0.2.tgz", "integrity": "sha1-0dyXOSAxTfZ/vrlCI7TuNQI56Ws=", - "devOptional": true, "dependencies": { "mimic-response": "^1.0.0" } @@ -5461,7 +5442,6 @@ "version": "1.6.2", "resolved": "https://registry.npmjs.org/concat-stream/-/concat-stream-1.6.2.tgz", "integrity": "sha512-27HBghJxjiZtIk3Ycvn/4kbJk/1uZuJFfuPEns6LaEvpvG1f0hTea8lilrouyo9mVc2GWdcEZ8OLoGmSADlrCw==", - "dev": true, "engines": [ "node >= 0.8" ], @@ -5476,7 +5456,6 @@ "version": "2.3.7", "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.7.tgz", "integrity": "sha512-Ebho8K4jIbHAxnuxi7o42OrZgF/ZTNcsZj6nRKyUmkhLFq8CHItp/fy6hQZuZmP/n3yZ9VBUbp4zz/mX8hmYPw==", - "dev": true, "dependencies": { "core-util-is": "~1.0.0", "inherits": "~2.0.3", @@ -5491,7 +5470,6 @@ "version": "1.1.1", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", "integrity": "sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==", - "dev": true, "dependencies": { "safe-buffer": "~5.1.0" } @@ -5706,7 +5684,6 @@ "version": "3.17.2", "resolved": "https://registry.npmjs.org/core-js/-/core-js-3.17.2.tgz", "integrity": "sha512-XkbXqhcXeMHPRk2ItS+zQYliAMilea2euoMsnpRRdDad6b2VY6CQQcwz1K8AnWesfw4p165RzY0bTnr3UrbYiA==", - "dev": true, "hasInstallScript": true, "optional": true, "funding": { @@ -5729,8 +5706,7 @@ "node_modules/core-util-is": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz", - "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==", - "devOptional": true + "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==" }, "node_modules/cosmiconfig": { "version": "6.0.0", @@ -5944,18 +5920,6 @@ "dev": true, "optional": true }, - "node_modules/currently-unhandled": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/currently-unhandled/-/currently-unhandled-0.4.1.tgz", - "integrity": "sha1-mI3zP+qxke95mmE2nddsF635V+o=", - "optional": true, - "dependencies": { - "array-find-index": "^1.0.1" - }, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/cwebp-bin": { "version": "5.1.0", "resolved": "https://registry.npmjs.org/cwebp-bin/-/cwebp-bin-5.1.0.tgz", @@ -6178,7 +6142,6 @@ "version": "3.3.0", "resolved": "https://registry.npmjs.org/decompress-response/-/decompress-response-3.3.0.tgz", "integrity": "sha1-gKTdMjdIOEv6JICDYirt7Jgq3/M=", - "devOptional": true, "dependencies": { "mimic-response": "^1.0.0" }, @@ -6407,8 +6370,7 @@ "node_modules/defer-to-connect": { "version": "1.1.3", "resolved": "https://registry.npmjs.org/defer-to-connect/-/defer-to-connect-1.1.3.tgz", - "integrity": "sha512-0ISdNousHvZT2EiFlZeZAHBUvSxmKswVCEf8hW7KWgG4a8MVEu/3Vb6uWYozkjylyCxe0JBIiRB1jV45S70WVQ==", - "dev": true + "integrity": "sha512-0ISdNousHvZT2EiFlZeZAHBUvSxmKswVCEf8hW7KWgG4a8MVEu/3Vb6uWYozkjylyCxe0JBIiRB1jV45S70WVQ==" }, "node_modules/define-lazy-prop": { "version": "2.0.0", @@ -6423,7 +6385,7 @@ "version": "1.1.3", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.1.3.tgz", "integrity": "sha512-3MqfYKj2lLzdMSf8ZIZE/V+Zuy+BgD6f164e8K2w7dgnpKArBDerGYpM46IYYcjnkdPNMjPk9A6VFB8+3SKlXQ==", - "dev": true, + "devOptional": true, "dependencies": { "object-keys": "^1.0.12" }, @@ -6505,7 +6467,7 @@ "version": "2.1.0", "resolved": "https://registry.npmjs.org/detect-node/-/detect-node-2.1.0.tgz", "integrity": "sha512-T0NIuQpnTvFDATNuHN5roPwSBG83rFsuO+MXXH9/3N1eFbn4wcPjttvjMLEPWJ0RGUYgQE7cGgS3tNxbqCGM7g==", - "dev": true + "devOptional": true }, "node_modules/detect-node-es": { "version": "1.1.0", @@ -6814,8 +6776,7 @@ "node_modules/duplexer3": { "version": "0.1.4", "resolved": "https://registry.npmjs.org/duplexer3/-/duplexer3-0.1.4.tgz", - "integrity": "sha1-7gHdHKwO08vH/b6jfcCo8c4ALOI=", - "devOptional": true + "integrity": "sha1-7gHdHKwO08vH/b6jfcCo8c4ALOI=" }, "node_modules/ecc-jsbn": { "version": "0.1.2", @@ -6837,7 +6798,6 @@ "version": "15.0.0", "resolved": "https://registry.npmjs.org/electron/-/electron-15.0.0.tgz", "integrity": "sha512-LlBjN5nCJoC7EDrgfDQwEGSGSAo/o08nSP5uJxN2m+ZtNA69SxpnWv4yPgo1K08X/iQPoGhoZu6C8LYYlk1Dtg==", - "dev": true, "hasInstallScript": true, "dependencies": { "@electron/get": "^1.13.0", @@ -7740,14 +7700,12 @@ "node_modules/electron/node_modules/@types/node": { "version": "14.17.15", "resolved": "https://registry.npmjs.org/@types/node/-/node-14.17.15.tgz", - "integrity": "sha512-D1sdW0EcSCmNdLKBGMYb38YsHUS6JcM7yQ6sLQ9KuZ35ck7LYCKE7kYFHOO59ayFOY3zobWVZxf4KXhYHcHYFA==", - "dev": true + "integrity": "sha512-D1sdW0EcSCmNdLKBGMYb38YsHUS6JcM7yQ6sLQ9KuZ35ck7LYCKE7kYFHOO59ayFOY3zobWVZxf4KXhYHcHYFA==" }, "node_modules/electron/node_modules/debug": { "version": "2.6.9", "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", - "dev": true, "dependencies": { "ms": "2.0.0" } @@ -7756,7 +7714,6 @@ "version": "1.7.0", "resolved": "https://registry.npmjs.org/extract-zip/-/extract-zip-1.7.0.tgz", "integrity": "sha512-xoh5G1W/PB0/27lXgMQyIhP5DSY/LhoCsOyZgb+6iMmRtCwVBo55uKaMoEYrDCKQhWvqEip5ZPKAc6eFNyf/MA==", - "dev": true, "dependencies": { "concat-stream": "^1.6.2", "debug": "^2.6.9", @@ -7771,7 +7728,6 @@ "version": "0.5.5", "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-0.5.5.tgz", "integrity": "sha512-NKmAlESf6jMGym1++R0Ra7wvhV+wFW63FaSOFPwRahvea0gMUcGUhVeAg/0BC0wiv9ih5NYPB1Wn1UEI1/L+xQ==", - "dev": true, "dependencies": { "minimist": "^1.2.5" }, @@ -7782,8 +7738,7 @@ "node_modules/electron/node_modules/ms": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", - "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=", - "dev": true + "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=" }, "node_modules/emoji-regex": { "version": "8.0.0", @@ -7803,7 +7758,7 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", "integrity": "sha1-rT/0yG7C0CkyL1oCw6mmBslbP1k=", - "dev": true, + "devOptional": true, "engines": { "node": ">= 0.8" } @@ -7835,7 +7790,6 @@ "version": "1.4.4", "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.4.tgz", "integrity": "sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==", - "devOptional": true, "dependencies": { "once": "^1.4.0" } @@ -7877,7 +7831,6 @@ "version": "2.2.1", "resolved": "https://registry.npmjs.org/env-paths/-/env-paths-2.2.1.tgz", "integrity": "sha512-+h1lkLKhZMTYjog1VEpJNG7NZJWcuc2DDk/qsqSTRRCOXiLjeQ1d1/udrUGhqMxUgAlwKNZ0cf2uqan5GLuS2A==", - "dev": true, "engines": { "node": ">=6" } @@ -7954,7 +7907,6 @@ "version": "4.1.1", "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz", "integrity": "sha512-Um/+FxMr9CISWh0bi5Zv0iOD+4cFh5qLeks1qhAopKVAJw3drgKbKySikp7wGhDL0HPeaja0P5ULZrxLkniUVg==", - "dev": true, "optional": true }, "node_modules/escalade": { @@ -9097,7 +9049,6 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/fd-slicer/-/fd-slicer-1.1.0.tgz", "integrity": "sha1-JcfInLH5B3+IkbvmHY85Dq4lbx4=", - "devOptional": true, "dependencies": { "pend": "~1.2.0" } @@ -9286,6 +9237,15 @@ "node": ">=6" } }, + "node_modules/find-versions/node_modules/semver-regex": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/semver-regex/-/semver-regex-2.0.0.tgz", + "integrity": "sha512-mUdIBBvdn0PLOeP3TEkMH7HHeUP3GjsXCwKarjv/kGmUFOYg1VqEemKhoQpWMu6X2I8kHeuVdGibLGkVK+/5Qw==", + "optional": true, + "engines": { + "node": ">=6" + } + }, "node_modules/flat-cache": { "version": "3.0.4", "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-3.0.4.tgz", @@ -9765,7 +9725,6 @@ "version": "5.2.0", "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-5.2.0.tgz", "integrity": "sha512-nBF+F1rAZVCu/p7rjzgA+Yb4lfYXrpl7a6VmJrU8wF9I1CKvP/QwPNZHnOlwbTkY6dvtFIzFMSyQXbLoTQPRpA==", - "devOptional": true, "dependencies": { "pump": "^3.0.0" }, @@ -9871,7 +9830,6 @@ "version": "2.2.0", "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-2.2.0.tgz", "integrity": "sha512-+20KpaW6DDLqhG7JDiJpD1JvNvb8ts+TNl7BPOYcURqCrXqnN1Vf+XVOrkKJAFPqfX+oEhsdzOj1hLWkBTdNJg==", - "dev": true, "optional": true, "dependencies": { "boolean": "^3.0.1", @@ -9932,7 +9890,6 @@ "version": "2.7.1", "resolved": "https://registry.npmjs.org/global-tunnel-ng/-/global-tunnel-ng-2.7.1.tgz", "integrity": "sha512-4s+DyciWBV0eK148wqXxcmVAbFVPqtc3sEtUE/GTQfuU80rySLcMhUmHKSHI7/LDj8q0gDYI1lIhRRB7ieRAqg==", - "dev": true, "optional": true, "dependencies": { "encodeurl": "^1.0.2", @@ -9956,7 +9913,6 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/globalthis/-/globalthis-1.0.2.tgz", "integrity": "sha512-ZQnSFO1la8P7auIOQECnm0sSuoMeaSq0EEdXMBFF2QJO4uNcwbyhSgG3MruWNbFTqCLmxVwGOl7LZ9kASvHdeQ==", - "dev": true, "optional": true, "dependencies": { "define-properties": "^1.1.3" @@ -9992,7 +9948,6 @@ "version": "9.6.0", "resolved": "https://registry.npmjs.org/got/-/got-9.6.0.tgz", "integrity": "sha512-R7eWptXuGYxwijs0eV+v3o6+XH1IqVK8dJOEecQfTmkncw9AV4dcw/Dhxi8MdlqPthxxpZyizMzyg8RTmEsG+Q==", - "dev": true, "dependencies": { "@sindresorhus/is": "^0.14.0", "@szmarczak/http-timer": "^1.1.2", @@ -10014,7 +9969,6 @@ "version": "4.1.0", "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-4.1.0.tgz", "integrity": "sha512-GMat4EJ5161kIy2HevLlr4luNjBgvmj413KaQA7jt4V8B4RDsfpHk7WQ9GVqfYyyx8OS/L66Kox+rJRNklLK7w==", - "dev": true, "dependencies": { "pump": "^3.0.0" }, @@ -10217,7 +10171,7 @@ "version": "2.8.9", "resolved": "https://registry.npmjs.org/hosted-git-info/-/hosted-git-info-2.8.9.tgz", "integrity": "sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw==", - "devOptional": true + "dev": true }, "node_modules/hotkeys-js": { "version": "3.8.7", @@ -10341,8 +10295,7 @@ "node_modules/http-cache-semantics": { "version": "4.1.0", "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.1.0.tgz", - "integrity": "sha512-carPklcUh7ROWRK7Cv27RPtdhYhUsela/ue5/jKzjegVvXDqM2ILE9Q2BGn9JZJh1g87cp56su/FgQSzcWS8cQ==", - "dev": true + "integrity": "sha512-carPklcUh7ROWRK7Cv27RPtdhYhUsela/ue5/jKzjegVvXDqM2ILE9Q2BGn9JZJh1g87cp56su/FgQSzcWS8cQ==" }, "node_modules/http-deceiver": { "version": "1.2.7", @@ -11594,12 +11547,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/is-utf8": { - "version": "0.2.1", - "resolved": "https://registry.npmjs.org/is-utf8/-/is-utf8-0.2.1.tgz", - "integrity": "sha1-Sw2hRCEE0bM2NA6AeX6GXPOffXI=", - "optional": true - }, "node_modules/is-windows": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/is-windows/-/is-windows-1.0.2.tgz", @@ -11624,8 +11571,7 @@ "node_modules/isarray": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/isarray/-/isarray-1.0.0.tgz", - "integrity": "sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE=", - "devOptional": true + "integrity": "sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE=" }, "node_modules/isbinaryfile": { "version": "3.0.3", @@ -11753,8 +11699,7 @@ "node_modules/json-buffer": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/json-buffer/-/json-buffer-3.0.0.tgz", - "integrity": "sha1-Wx85evx11ne96Lz8Dkfh+aPZqJg=", - "devOptional": true + "integrity": "sha1-Wx85evx11ne96Lz8Dkfh+aPZqJg=" }, "node_modules/json-parse-better-errors": { "version": "1.0.2", @@ -11787,7 +11732,7 @@ "version": "5.0.1", "resolved": "https://registry.npmjs.org/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz", "integrity": "sha1-Epai1Y/UXxmg9s4B1lcB4sc1tus=", - "dev": true + "devOptional": true }, "node_modules/json5": { "version": "2.2.0", @@ -11856,7 +11801,6 @@ "version": "3.1.0", "resolved": "https://registry.npmjs.org/keyv/-/keyv-3.1.0.tgz", "integrity": "sha512-9ykJ/46SN/9KPM/sichzQ7OvXyGDYKGTaDlKMGCAlg2UK8KRy4jb0d8sFc+0Tt0YYnThq8X2RZgCg74RPxgcVA==", - "dev": true, "dependencies": { "json-buffer": "3.0.0" } @@ -12095,19 +12039,6 @@ "loose-envify": "cli.js" } }, - "node_modules/loud-rejection": { - "version": "1.6.0", - "resolved": "https://registry.npmjs.org/loud-rejection/-/loud-rejection-1.6.0.tgz", - "integrity": "sha1-W0b4AUft7leIcPCG0Eghz5mOVR8=", - "optional": true, - "dependencies": { - "currently-unhandled": "^0.4.1", - "signal-exit": "^3.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/lower-case": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/lower-case/-/lower-case-2.0.2.tgz", @@ -12121,7 +12052,6 @@ "version": "1.0.1", "resolved": "https://registry.npmjs.org/lowercase-keys/-/lowercase-keys-1.0.1.tgz", "integrity": "sha512-G2Lj61tXDnVFFOi8VZds+SoQjtQC3dgokKdDG2mTm1tx4m50NUHBOZSBwQQHyy0V12A0JTG4icfZQH+xPyh8VA==", - "devOptional": true, "engines": { "node": ">=0.10.0" } @@ -12145,40 +12075,27 @@ } }, "node_modules/lpad-align/node_modules/camelcase": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-2.1.1.tgz", - "integrity": "sha1-fB0W1nmhu+WcoCys7PsBHiAfWh8=", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-1.2.1.tgz", + "integrity": "sha1-m7UwTS4LVmmLLHWLCKPqqdqlijk=", "optional": true, "engines": { "node": ">=0.10.0" } }, "node_modules/lpad-align/node_modules/camelcase-keys": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/camelcase-keys/-/camelcase-keys-2.1.0.tgz", - "integrity": "sha1-MIvur/3ygRkFHvodkyITyRuPkuc=", + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/camelcase-keys/-/camelcase-keys-1.0.0.tgz", + "integrity": "sha1-vRoRv5sxoc5JNJOpMN4aC69K1+w=", "optional": true, "dependencies": { - "camelcase": "^2.0.0", + "camelcase": "^1.0.1", "map-obj": "^1.0.0" }, "engines": { "node": ">=0.10.0" } }, - "node_modules/lpad-align/node_modules/find-up": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/find-up/-/find-up-1.1.2.tgz", - "integrity": "sha1-ay6YIrGizgpgq2TWEOzK1TyyTQ8=", - "optional": true, - "dependencies": { - "path-exists": "^2.0.0", - "pinkie-promise": "^2.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/lpad-align/node_modules/indent-string": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-2.1.0.tgz", @@ -12191,22 +12108,6 @@ "node": ">=0.10.0" } }, - "node_modules/lpad-align/node_modules/load-json-file": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/load-json-file/-/load-json-file-1.1.0.tgz", - "integrity": "sha1-lWkFcI1YtLq0wiYbBPWfMcmTdMA=", - "optional": true, - "dependencies": { - "graceful-fs": "^4.1.2", - "parse-json": "^2.2.0", - "pify": "^2.0.0", - "pinkie-promise": "^2.0.0", - "strip-bom": "^2.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/lpad-align/node_modules/map-obj": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/map-obj/-/map-obj-1.0.1.tgz", @@ -12217,97 +12118,57 @@ } }, "node_modules/lpad-align/node_modules/meow": { - "version": "3.7.0", - "resolved": "https://registry.npmjs.org/meow/-/meow-3.7.0.tgz", - "integrity": "sha1-cstmi0JSKCkKu/qFaJJYcwioAfs=", + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/meow/-/meow-3.3.0.tgz", + "integrity": "sha1-+Hd/0Ntn9z0d4b7uCMl8hmXvxu0=", "optional": true, "dependencies": { - "camelcase-keys": "^2.0.0", - "decamelize": "^1.1.2", - "loud-rejection": "^1.0.0", - "map-obj": "^1.0.1", - "minimist": "^1.1.3", - "normalize-package-data": "^2.3.4", - "object-assign": "^4.0.1", - "read-pkg-up": "^1.0.1", - "redent": "^1.0.0", - "trim-newlines": "^1.0.0" + "camelcase-keys": "^1.0.0", + "indent-string": "^1.1.0", + "minimist": "^1.1.0", + "object-assign": "^3.0.0" }, "engines": { "node": ">=0.10.0" } }, - "node_modules/lpad-align/node_modules/path-exists": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-2.1.0.tgz", - "integrity": "sha1-D+tsZPD8UY2adU3V77YscCJ2H0s=", + "node_modules/lpad-align/node_modules/meow/node_modules/indent-string": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-1.2.2.tgz", + "integrity": "sha1-25m8xYPrarux5I3LsZmamGBBy2s=", "optional": true, "dependencies": { - "pinkie-promise": "^2.0.0" + "get-stdin": "^4.0.1", + "minimist": "^1.1.0", + "repeating": "^1.1.0" }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/lpad-align/node_modules/path-type": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/path-type/-/path-type-1.1.0.tgz", - "integrity": "sha1-WcRPfuSR2nBNpBXaWkBwuk+P5EE=", - "optional": true, - "dependencies": { - "graceful-fs": "^4.1.2", - "pify": "^2.0.0", - "pinkie-promise": "^2.0.0" + "bin": { + "indent-string": "cli.js" }, "engines": { "node": ">=0.10.0" } }, - "node_modules/lpad-align/node_modules/pify": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/pify/-/pify-2.3.0.tgz", - "integrity": "sha1-7RQaasBDqEnqWISY59yosVMw6Qw=", - "optional": true, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/lpad-align/node_modules/read-pkg": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/read-pkg/-/read-pkg-1.1.0.tgz", - "integrity": "sha1-9f+qXs0pyzHAR0vKfXVra7KePyg=", + "node_modules/lpad-align/node_modules/meow/node_modules/repeating": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/repeating/-/repeating-1.1.3.tgz", + "integrity": "sha1-PUEUIYh3U3SU+X93+Xhfq4EPpKw=", "optional": true, "dependencies": { - "load-json-file": "^1.0.0", - "normalize-package-data": "^2.3.2", - "path-type": "^1.0.0" + "is-finite": "^1.0.0" }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/lpad-align/node_modules/read-pkg-up": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/read-pkg-up/-/read-pkg-up-1.0.1.tgz", - "integrity": "sha1-nWPBMnbAZZGNV/ACpX9AobZD+wI=", - "optional": true, - "dependencies": { - "find-up": "^1.0.0", - "read-pkg": "^1.0.0" + "bin": { + "repeating": "cli.js" }, "engines": { "node": ">=0.10.0" } }, - "node_modules/lpad-align/node_modules/redent": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/redent/-/redent-1.0.0.tgz", - "integrity": "sha1-z5Fqsf1fHxbfsggi3W7H9zDCr94=", + "node_modules/lpad-align/node_modules/object-assign": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-3.0.0.tgz", + "integrity": "sha1-m+3VygiXlJvKR+f/QIBi1Un1h/I=", "optional": true, - "dependencies": { - "indent-string": "^2.1.0", - "strip-indent": "^1.0.1" - }, "engines": { "node": ">=0.10.0" } @@ -12324,42 +12185,6 @@ "node": ">=0.10.0" } }, - "node_modules/lpad-align/node_modules/strip-bom": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-2.0.0.tgz", - "integrity": "sha1-YhmoVhZSBJHzV4i9vxRHqZx+aw4=", - "optional": true, - "dependencies": { - "is-utf8": "^0.2.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/lpad-align/node_modules/strip-indent": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/strip-indent/-/strip-indent-1.0.1.tgz", - "integrity": "sha1-DHlipq3vp7vUrDZkYKY4VSrhoKI=", - "optional": true, - "dependencies": { - "get-stdin": "^4.0.1" - }, - "bin": { - "strip-indent": "cli.js" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/lpad-align/node_modules/trim-newlines": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/trim-newlines/-/trim-newlines-1.0.0.tgz", - "integrity": "sha1-WIeWa7WCpFA6QetST301ARgVphM=", - "optional": true, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/lru-cache": { "version": "6.0.0", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", @@ -12464,7 +12289,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/matcher/-/matcher-3.0.0.tgz", "integrity": "sha512-OkeDaAZ/bQCxeFAozM55PKcKU0yJMPGifLwV4Qgjitu+5MoAfSQN4lsLJeXZ1b8w0x+/Emda6MZgXS1jvsapng==", - "dev": true, "optional": true, "dependencies": { "escape-string-regexp": "^4.0.0" @@ -12801,7 +12625,6 @@ "version": "1.0.1", "resolved": "https://registry.npmjs.org/mimic-response/-/mimic-response-1.0.1.tgz", "integrity": "sha512-j5EctnkH7amfV/q5Hgmoal1g2QHFJRraOtmx0JpIqkxhBhI/lJSl1nMpQ45hVarwNETOoWEimndZ4QK0RHxuxQ==", - "devOptional": true, "engines": { "node": ">=4" } @@ -13404,7 +13227,7 @@ "version": "2.5.0", "resolved": "https://registry.npmjs.org/normalize-package-data/-/normalize-package-data-2.5.0.tgz", "integrity": "sha512-/5CMN3T0R4XTj4DcGaexo+roZSdSFW/0AOOTROrjxzCG1wrWXEsGbRKevjlIL+ZDE4sZlJr5ED4YW0yqmkK+eA==", - "devOptional": true, + "dev": true, "dependencies": { "hosted-git-info": "^2.1.4", "resolve": "^1.10.0", @@ -13416,7 +13239,7 @@ "version": "5.7.1", "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", - "devOptional": true, + "dev": true, "bin": { "semver": "bin/semver" } @@ -13434,7 +13257,6 @@ "version": "4.5.1", "resolved": "https://registry.npmjs.org/normalize-url/-/normalize-url-4.5.1.tgz", "integrity": "sha512-9UZCFRHQdNrfTpGg8+1INIg93B6zE0aXMVFkw1WFwvO4SlZywU6aLg5Of0Ap/PgcbSw4LNxvMWXMeugwMCX0AA==", - "dev": true, "engines": { "node": ">=8" } @@ -13576,7 +13398,7 @@ "version": "1.1.1", "resolved": "https://registry.npmjs.org/object-keys/-/object-keys-1.1.1.tgz", "integrity": "sha512-NuAESUOUMrlIXOfHKzD6bpPu3tYt3xvjNdRIQ+FeT0lNb4K8WR70CaDxhuNguS2XG+GjkyMwOzsN5ZktImfhLA==", - "dev": true, + "devOptional": true, "engines": { "node": ">= 0.4" } @@ -13816,6 +13638,14 @@ "node": ">=0.10.0" } }, + "node_modules/os-utils": { + "version": "0.0.14", + "resolved": "https://registry.npmjs.org/os-utils/-/os-utils-0.0.14.tgz", + "integrity": "sha1-KeURaXsZgrjGJ3Ihdf45eX72QVY=", + "engines": { + "node": "*" + } + }, "node_modules/ow": { "version": "0.17.0", "resolved": "https://registry.npmjs.org/ow/-/ow-0.17.0.tgz", @@ -13847,7 +13677,6 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/p-cancelable/-/p-cancelable-1.1.0.tgz", "integrity": "sha512-s73XxOZ4zpt1edZYZzvhqFa6uvQc1vwUa0K0BdtIZgQMAJj9IbebH+JkgKZc9h+B05PKHLOTl4ajG1BmNrVZlw==", - "dev": true, "engines": { "node": ">=6" } @@ -14050,7 +13879,7 @@ "version": "2.2.0", "resolved": "https://registry.npmjs.org/parse-json/-/parse-json-2.2.0.tgz", "integrity": "sha1-9ID0BDTvgHQfhGkJn43qGPVaTck=", - "devOptional": true, + "dev": true, "dependencies": { "error-ex": "^1.2.0" }, @@ -14156,8 +13985,7 @@ "node_modules/pend": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/pend/-/pend-1.2.0.tgz", - "integrity": "sha1-elfrVQpng/kRUzH89GY9XI4AelA=", - "devOptional": true + "integrity": "sha1-elfrVQpng/kRUzH89GY9XI4AelA=" }, "node_modules/performance-now": { "version": "2.1.0", @@ -14561,7 +14389,6 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/prepend-http/-/prepend-http-2.0.0.tgz", "integrity": "sha1-6SQ0v6XqjBn0HN/UAddBo8gZ2Jc=", - "devOptional": true, "engines": { "node": ">=4" } @@ -14691,14 +14518,12 @@ "node_modules/process-nextick-args": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.1.tgz", - "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==", - "devOptional": true + "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==" }, "node_modules/progress": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/progress/-/progress-2.0.3.tgz", "integrity": "sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA==", - "dev": true, "engines": { "node": ">=0.4.0" } @@ -14777,7 +14602,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz", "integrity": "sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==", - "devOptional": true, "dependencies": { "end-of-stream": "^1.1.0", "once": "^1.3.1" @@ -15464,7 +15288,6 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/responselike/-/responselike-1.0.2.tgz", "integrity": "sha1-kYcg7ztjHFZCvgaPFa3lpG9Loec=", - "devOptional": true, "dependencies": { "lowercase-keys": "^1.0.0" } @@ -15519,7 +15342,6 @@ "version": "2.15.4", "resolved": "https://registry.npmjs.org/roarr/-/roarr-2.15.4.tgz", "integrity": "sha512-CHhPh+UNHD2GTXNYhPWLnU8ONHdI+5DI+4EYIAOaiD63rHeYlZvyh8P+in5999TTSFgUYuKUAjzRI4mdh/p+2A==", - "dev": true, "optional": true, "dependencies": { "boolean": "^3.0.1", @@ -15706,16 +15528,18 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/semver-compare/-/semver-compare-1.0.0.tgz", "integrity": "sha1-De4hahyUGrN+nvsXiPavxf9VN/w=", - "dev": true, "optional": true }, "node_modules/semver-regex": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/semver-regex/-/semver-regex-2.0.0.tgz", - "integrity": "sha512-mUdIBBvdn0PLOeP3TEkMH7HHeUP3GjsXCwKarjv/kGmUFOYg1VqEemKhoQpWMu6X2I8kHeuVdGibLGkVK+/5Qw==", - "optional": true, + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/semver-regex/-/semver-regex-4.0.2.tgz", + "integrity": "sha512-xyuBZk1XYqQkB687hMQqrCP+J9bdJSjPpZwdmmNjyxKW1K3LDXxqxw91Egaqkh/yheBIVtKPt4/1eybKVdCx3g==", + "dev": true, "engines": { - "node": ">=6" + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/semver-truncate": { @@ -15788,7 +15612,6 @@ "version": "7.0.1", "resolved": "https://registry.npmjs.org/serialize-error/-/serialize-error-7.0.1.tgz", "integrity": "sha512-8I8TjW5KMOKsZQTvoxjuSIa7foAwPWGOts+6o7sgjz41/qMD9VQHEDxi6PBvK2l0MXUmqZyNpUK+T2tQaaElvw==", - "dev": true, "optional": true, "dependencies": { "type-fest": "^0.13.1" @@ -16173,7 +15996,6 @@ "version": "1.1.2", "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.2.tgz", "integrity": "sha512-VE0SOVEHCk7Qc8ulkWw3ntAzXuqf7S2lvwQaDLRnUeIEaKNQJzV6BwmLKhOqT61aGhfUMrXeaBk+oDGCzvhcug==", - "dev": true, "optional": true }, "node_modules/squeak": { @@ -16524,7 +16346,6 @@ "version": "3.0.1", "resolved": "https://registry.npmjs.org/sumchecker/-/sumchecker-3.0.1.tgz", "integrity": "sha512-MvjXzkz/BOfyVDkG0oFOtBxHX2u3gKbMHIF/dXblZsgD3BWOFLmHovIpZY7BykJdAjcqRCBi1WYBNdEC9yI7vg==", - "dev": true, "dependencies": { "debug": "^4.1.0" }, @@ -17075,7 +16896,6 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/to-readable-stream/-/to-readable-stream-1.0.0.tgz", "integrity": "sha512-Iq25XBt6zD5npPhlLVXGFN3/gyR2/qODcKNNyTMd4vbm39HUaOiAM4PMq0eMVC/Tkxz+Zjdsc55g9yyz+Yq00Q==", - "dev": true, "engines": { "node": ">=6" } @@ -17194,7 +17014,6 @@ "version": "0.0.6", "resolved": "https://registry.npmjs.org/tunnel/-/tunnel-0.0.6.tgz", "integrity": "sha512-1h/Lnq9yajKY2PEbBadPXj3VxsDDu844OnaAo52UVmIzIvwwtBPIuNvkjuzBlTWpfJyUbG3ez0KSBibQkj4ojg==", - "dev": true, "optional": true, "engines": { "node": ">=0.6.11 <=0.7.0 || >=0.7.3" @@ -17244,7 +17063,6 @@ "version": "0.13.1", "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.13.1.tgz", "integrity": "sha512-34R7HTnG0XIJcBSn5XhDd7nNFPRcXYRZrBB2O2jdKqYODldSzBAqzsWoZYYvduky73toYS/ESqxPvkDf/F0XMg==", - "dev": true, "optional": true, "engines": { "node": ">=10" @@ -17269,8 +17087,7 @@ "node_modules/typedarray": { "version": "0.0.6", "resolved": "https://registry.npmjs.org/typedarray/-/typedarray-0.0.6.tgz", - "integrity": "sha1-hnrHTjhkGHsdPUfZlqeOxciDB3c=", - "dev": true + "integrity": "sha1-hnrHTjhkGHsdPUfZlqeOxciDB3c=" }, "node_modules/unbox-primitive": { "version": "1.0.1", @@ -17361,7 +17178,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/url-parse-lax/-/url-parse-lax-3.0.0.tgz", "integrity": "sha1-FrXK/Afb42dsGxmZF3gj1lA6yww=", - "devOptional": true, "dependencies": { "prepend-http": "^2.0.0" }, @@ -17617,8 +17433,7 @@ "node_modules/util-deprecate": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", - "integrity": "sha1-RQ1Nyfpw3nMnYvvS1KKJgUGaDM8=", - "devOptional": true + "integrity": "sha1-RQ1Nyfpw3nMnYvvS1KKJgUGaDM8=" }, "node_modules/utila": { "version": "0.4.0", @@ -18337,7 +18152,6 @@ "version": "2.10.0", "resolved": "https://registry.npmjs.org/yauzl/-/yauzl-2.10.0.tgz", "integrity": "sha1-x+sXyT4RLLEIb6bY5R+wZnt5pfk=", - "devOptional": true, "dependencies": { "buffer-crc32": "~0.2.3", "fd-slicer": "~1.1.0" @@ -19840,7 +19654,6 @@ "version": "1.13.0", "resolved": "https://registry.npmjs.org/@electron/get/-/get-1.13.0.tgz", "integrity": "sha512-+SjZhRuRo+STTO1Fdhzqnv9D2ZhjxXP6egsJ9kiO8dtP68cDx7dFCwWi64dlMQV7sWcfW1OYCW4wviEBzmRsfQ==", - "dev": true, "requires": { "debug": "^4.1.1", "env-paths": "^2.2.0", @@ -19857,7 +19670,6 @@ "version": "8.1.0", "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-8.1.0.tgz", "integrity": "sha512-yhlQgA6mnOJUKOsRUFsgJdQCvkKhcz8tlZG5HBQfReYZy46OwLcY+Zia0mtdHsOo9y/hP+CxMN0TU9QxoOtG4g==", - "dev": true, "requires": { "graceful-fs": "^4.2.0", "jsonfile": "^4.0.0", @@ -19868,7 +19680,6 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/jsonfile/-/jsonfile-4.0.0.tgz", "integrity": "sha1-h3Gq4HmbZAdrdmQPygWPnBDjPss=", - "dev": true, "requires": { "graceful-fs": "^4.1.6" } @@ -19876,14 +19687,12 @@ "semver": { "version": "6.3.0", "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", - "dev": true + "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==" }, "universalify": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/universalify/-/universalify-0.1.2.tgz", - "integrity": "sha512-rBJeI5CXAlmy1pV+617WB9J63U6XcazHHF2f2dbJix4XzpUF0RS3Zbj0FGIOCAva5P/d/GBOYaACQ1w+0azUkg==", - "dev": true + "integrity": "sha512-rBJeI5CXAlmy1pV+617WB9J63U6XcazHHF2f2dbJix4XzpUF0RS3Zbj0FGIOCAva5P/d/GBOYaACQ1w+0azUkg==" } } }, @@ -20353,8 +20162,7 @@ "@sindresorhus/is": { "version": "0.14.0", "resolved": "https://registry.npmjs.org/@sindresorhus/is/-/is-0.14.0.tgz", - "integrity": "sha512-9NET910DNaIPngYnLLPeg+Ogzqsi9uM4mSboU5y6p8S5DzMTVEsJZrawi+BoDNUVBa2DhJqQYUFvMDfgU062LQ==", - "dev": true + "integrity": "sha512-9NET910DNaIPngYnLLPeg+Ogzqsi9uM4mSboU5y6p8S5DzMTVEsJZrawi+BoDNUVBa2DhJqQYUFvMDfgU062LQ==" }, "@swiftcarrot/color-fns": { "version": "3.2.0", @@ -20368,7 +20176,6 @@ "version": "1.1.2", "resolved": "https://registry.npmjs.org/@szmarczak/http-timer/-/http-timer-1.1.2.tgz", "integrity": "sha512-XIB2XbzHTN6ieIjfIMV9hlVcfPU26s2vafYWQcZHWXHOxiaRZYEDKEwdl129Zyg50+foYV2jCgtrqSA6qNuNSA==", - "dev": true, "requires": { "defer-to-connect": "^1.0.1" } @@ -21211,12 +21018,6 @@ "@babel/runtime-corejs3": "^7.10.2" } }, - "array-find-index": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/array-find-index/-/array-find-index-1.0.2.tgz", - "integrity": "sha1-3wEKoSh+Fku9pvlyOwqWoexBh6E=", - "optional": true - }, "array-flatten": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", @@ -22100,7 +21901,6 @@ "version": "3.1.4", "resolved": "https://registry.npmjs.org/boolean/-/boolean-3.1.4.tgz", "integrity": "sha512-3hx0kwU3uzG6ReQ3pnaFQPSktpBw6RHN3/ivDKEuU8g1XSfafowyvDnadjv1xp8IZqhtSukxlwv9bF6FhX8m0w==", - "dev": true, "optional": true }, "brace-expansion": { @@ -22161,8 +21961,7 @@ "buffer-crc32": { "version": "0.2.13", "resolved": "https://registry.npmjs.org/buffer-crc32/-/buffer-crc32-0.2.13.tgz", - "integrity": "sha1-DTM+PwDqxQqhRUq9MO+MKl2ackI=", - "devOptional": true + "integrity": "sha1-DTM+PwDqxQqhRUq9MO+MKl2ackI=" }, "buffer-fill": { "version": "1.0.0", @@ -22223,7 +22022,6 @@ "version": "6.1.0", "resolved": "https://registry.npmjs.org/cacheable-request/-/cacheable-request-6.1.0.tgz", "integrity": "sha512-Oj3cAGPCqOZX7Rz64Uny2GYAZNliQSqfbePrgAQ1wKAihYmCUnraBtJtKcGR4xz7wF+LoJC+ssFZvv5BgF9Igg==", - "dev": true, "requires": { "clone-response": "^1.0.2", "get-stream": "^5.1.0", @@ -22237,8 +22035,7 @@ "lowercase-keys": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/lowercase-keys/-/lowercase-keys-2.0.0.tgz", - "integrity": "sha512-tqNXrS78oMOE73NMxK4EMLQsQowWf8jKooH9g7xPavRT706R6bkQJ6DY2Te7QukaZsulxa30wQ7bk0pm4XiHmA==", - "dev": true + "integrity": "sha512-tqNXrS78oMOE73NMxK4EMLQsQowWf8jKooH9g7xPavRT706R6bkQJ6DY2Te7QukaZsulxa30wQ7bk0pm4XiHmA==" } } }, @@ -22474,7 +22271,6 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/clone-response/-/clone-response-1.0.2.tgz", "integrity": "sha1-0dyXOSAxTfZ/vrlCI7TuNQI56Ws=", - "devOptional": true, "requires": { "mimic-response": "^1.0.0" } @@ -22604,7 +22400,6 @@ "version": "1.6.2", "resolved": "https://registry.npmjs.org/concat-stream/-/concat-stream-1.6.2.tgz", "integrity": "sha512-27HBghJxjiZtIk3Ycvn/4kbJk/1uZuJFfuPEns6LaEvpvG1f0hTea8lilrouyo9mVc2GWdcEZ8OLoGmSADlrCw==", - "dev": true, "requires": { "buffer-from": "^1.0.0", "inherits": "^2.0.3", @@ -22616,7 +22411,6 @@ "version": "2.3.7", "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.7.tgz", "integrity": "sha512-Ebho8K4jIbHAxnuxi7o42OrZgF/ZTNcsZj6nRKyUmkhLFq8CHItp/fy6hQZuZmP/n3yZ9VBUbp4zz/mX8hmYPw==", - "dev": true, "requires": { "core-util-is": "~1.0.0", "inherits": "~2.0.3", @@ -22631,7 +22425,6 @@ "version": "1.1.1", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", "integrity": "sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==", - "dev": true, "requires": { "safe-buffer": "~5.1.0" } @@ -22808,7 +22601,6 @@ "version": "3.17.2", "resolved": "https://registry.npmjs.org/core-js/-/core-js-3.17.2.tgz", "integrity": "sha512-XkbXqhcXeMHPRk2ItS+zQYliAMilea2euoMsnpRRdDad6b2VY6CQQcwz1K8AnWesfw4p165RzY0bTnr3UrbYiA==", - "dev": true, "optional": true }, "core-js-pure": { @@ -22821,8 +22613,7 @@ "core-util-is": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz", - "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==", - "devOptional": true + "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==" }, "cosmiconfig": { "version": "6.0.0", @@ -22962,15 +22753,6 @@ "dev": true, "optional": true }, - "currently-unhandled": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/currently-unhandled/-/currently-unhandled-0.4.1.tgz", - "integrity": "sha1-mI3zP+qxke95mmE2nddsF635V+o=", - "optional": true, - "requires": { - "array-find-index": "^1.0.1" - } - }, "cwebp-bin": { "version": "5.1.0", "resolved": "https://registry.npmjs.org/cwebp-bin/-/cwebp-bin-5.1.0.tgz", @@ -23150,7 +22932,6 @@ "version": "3.3.0", "resolved": "https://registry.npmjs.org/decompress-response/-/decompress-response-3.3.0.tgz", "integrity": "sha1-gKTdMjdIOEv6JICDYirt7Jgq3/M=", - "devOptional": true, "requires": { "mimic-response": "^1.0.0" } @@ -23309,8 +23090,7 @@ "defer-to-connect": { "version": "1.1.3", "resolved": "https://registry.npmjs.org/defer-to-connect/-/defer-to-connect-1.1.3.tgz", - "integrity": "sha512-0ISdNousHvZT2EiFlZeZAHBUvSxmKswVCEf8hW7KWgG4a8MVEu/3Vb6uWYozkjylyCxe0JBIiRB1jV45S70WVQ==", - "dev": true + "integrity": "sha512-0ISdNousHvZT2EiFlZeZAHBUvSxmKswVCEf8hW7KWgG4a8MVEu/3Vb6uWYozkjylyCxe0JBIiRB1jV45S70WVQ==" }, "define-lazy-prop": { "version": "2.0.0", @@ -23322,7 +23102,7 @@ "version": "1.1.3", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.1.3.tgz", "integrity": "sha512-3MqfYKj2lLzdMSf8ZIZE/V+Zuy+BgD6f164e8K2w7dgnpKArBDerGYpM46IYYcjnkdPNMjPk9A6VFB8+3SKlXQ==", - "dev": true, + "devOptional": true, "requires": { "object-keys": "^1.0.12" } @@ -23383,7 +23163,7 @@ "version": "2.1.0", "resolved": "https://registry.npmjs.org/detect-node/-/detect-node-2.1.0.tgz", "integrity": "sha512-T0NIuQpnTvFDATNuHN5roPwSBG83rFsuO+MXXH9/3N1eFbn4wcPjttvjMLEPWJ0RGUYgQE7cGgS3tNxbqCGM7g==", - "dev": true + "devOptional": true }, "detect-node-es": { "version": "1.1.0", @@ -23630,8 +23410,7 @@ "duplexer3": { "version": "0.1.4", "resolved": "https://registry.npmjs.org/duplexer3/-/duplexer3-0.1.4.tgz", - "integrity": "sha1-7gHdHKwO08vH/b6jfcCo8c4ALOI=", - "devOptional": true + "integrity": "sha1-7gHdHKwO08vH/b6jfcCo8c4ALOI=" }, "ecc-jsbn": { "version": "0.1.2", @@ -23653,7 +23432,6 @@ "version": "15.0.0", "resolved": "https://registry.npmjs.org/electron/-/electron-15.0.0.tgz", "integrity": "sha512-LlBjN5nCJoC7EDrgfDQwEGSGSAo/o08nSP5uJxN2m+ZtNA69SxpnWv4yPgo1K08X/iQPoGhoZu6C8LYYlk1Dtg==", - "dev": true, "requires": { "@electron/get": "^1.13.0", "@types/node": "^14.6.2", @@ -23663,14 +23441,12 @@ "@types/node": { "version": "14.17.15", "resolved": "https://registry.npmjs.org/@types/node/-/node-14.17.15.tgz", - "integrity": "sha512-D1sdW0EcSCmNdLKBGMYb38YsHUS6JcM7yQ6sLQ9KuZ35ck7LYCKE7kYFHOO59ayFOY3zobWVZxf4KXhYHcHYFA==", - "dev": true + "integrity": "sha512-D1sdW0EcSCmNdLKBGMYb38YsHUS6JcM7yQ6sLQ9KuZ35ck7LYCKE7kYFHOO59ayFOY3zobWVZxf4KXhYHcHYFA==" }, "debug": { "version": "2.6.9", "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", - "dev": true, "requires": { "ms": "2.0.0" } @@ -23679,7 +23455,6 @@ "version": "1.7.0", "resolved": "https://registry.npmjs.org/extract-zip/-/extract-zip-1.7.0.tgz", "integrity": "sha512-xoh5G1W/PB0/27lXgMQyIhP5DSY/LhoCsOyZgb+6iMmRtCwVBo55uKaMoEYrDCKQhWvqEip5ZPKAc6eFNyf/MA==", - "dev": true, "requires": { "concat-stream": "^1.6.2", "debug": "^2.6.9", @@ -23691,7 +23466,6 @@ "version": "0.5.5", "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-0.5.5.tgz", "integrity": "sha512-NKmAlESf6jMGym1++R0Ra7wvhV+wFW63FaSOFPwRahvea0gMUcGUhVeAg/0BC0wiv9ih5NYPB1Wn1UEI1/L+xQ==", - "dev": true, "requires": { "minimist": "^1.2.5" } @@ -23699,8 +23473,7 @@ "ms": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", - "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=", - "dev": true + "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=" } } }, @@ -24410,7 +24183,7 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", "integrity": "sha1-rT/0yG7C0CkyL1oCw6mmBslbP1k=", - "dev": true + "devOptional": true }, "encoding": { "version": "0.1.13", @@ -24438,7 +24211,6 @@ "version": "1.4.4", "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.4.tgz", "integrity": "sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==", - "devOptional": true, "requires": { "once": "^1.4.0" } @@ -24470,8 +24242,7 @@ "env-paths": { "version": "2.2.1", "resolved": "https://registry.npmjs.org/env-paths/-/env-paths-2.2.1.tgz", - "integrity": "sha512-+h1lkLKhZMTYjog1VEpJNG7NZJWcuc2DDk/qsqSTRRCOXiLjeQ1d1/udrUGhqMxUgAlwKNZ0cf2uqan5GLuS2A==", - "dev": true + "integrity": "sha512-+h1lkLKhZMTYjog1VEpJNG7NZJWcuc2DDk/qsqSTRRCOXiLjeQ1d1/udrUGhqMxUgAlwKNZ0cf2uqan5GLuS2A==" }, "err-code": { "version": "2.0.3", @@ -24533,7 +24304,6 @@ "version": "4.1.1", "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz", "integrity": "sha512-Um/+FxMr9CISWh0bi5Zv0iOD+4cFh5qLeks1qhAopKVAJw3drgKbKySikp7wGhDL0HPeaja0P5ULZrxLkniUVg==", - "dev": true, "optional": true }, "escalade": { @@ -25445,7 +25215,6 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/fd-slicer/-/fd-slicer-1.1.0.tgz", "integrity": "sha1-JcfInLH5B3+IkbvmHY85Dq4lbx4=", - "devOptional": true, "requires": { "pend": "~1.2.0" } @@ -25584,6 +25353,14 @@ "optional": true, "requires": { "semver-regex": "^2.0.0" + }, + "dependencies": { + "semver-regex": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/semver-regex/-/semver-regex-2.0.0.tgz", + "integrity": "sha512-mUdIBBvdn0PLOeP3TEkMH7HHeUP3GjsXCwKarjv/kGmUFOYg1VqEemKhoQpWMu6X2I8kHeuVdGibLGkVK+/5Qw==", + "optional": true + } } }, "flat-cache": { @@ -25990,7 +25767,6 @@ "version": "5.2.0", "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-5.2.0.tgz", "integrity": "sha512-nBF+F1rAZVCu/p7rjzgA+Yb4lfYXrpl7a6VmJrU8wF9I1CKvP/QwPNZHnOlwbTkY6dvtFIzFMSyQXbLoTQPRpA==", - "devOptional": true, "requires": { "pump": "^3.0.0" } @@ -26065,7 +25841,6 @@ "version": "2.2.0", "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-2.2.0.tgz", "integrity": "sha512-+20KpaW6DDLqhG7JDiJpD1JvNvb8ts+TNl7BPOYcURqCrXqnN1Vf+XVOrkKJAFPqfX+oEhsdzOj1hLWkBTdNJg==", - "dev": true, "optional": true, "requires": { "boolean": "^3.0.1", @@ -26116,7 +25891,6 @@ "version": "2.7.1", "resolved": "https://registry.npmjs.org/global-tunnel-ng/-/global-tunnel-ng-2.7.1.tgz", "integrity": "sha512-4s+DyciWBV0eK148wqXxcmVAbFVPqtc3sEtUE/GTQfuU80rySLcMhUmHKSHI7/LDj8q0gDYI1lIhRRB7ieRAqg==", - "dev": true, "optional": true, "requires": { "encodeurl": "^1.0.2", @@ -26134,7 +25908,6 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/globalthis/-/globalthis-1.0.2.tgz", "integrity": "sha512-ZQnSFO1la8P7auIOQECnm0sSuoMeaSq0EEdXMBFF2QJO4uNcwbyhSgG3MruWNbFTqCLmxVwGOl7LZ9kASvHdeQ==", - "dev": true, "optional": true, "requires": { "define-properties": "^1.1.3" @@ -26158,7 +25931,6 @@ "version": "9.6.0", "resolved": "https://registry.npmjs.org/got/-/got-9.6.0.tgz", "integrity": "sha512-R7eWptXuGYxwijs0eV+v3o6+XH1IqVK8dJOEecQfTmkncw9AV4dcw/Dhxi8MdlqPthxxpZyizMzyg8RTmEsG+Q==", - "dev": true, "requires": { "@sindresorhus/is": "^0.14.0", "@szmarczak/http-timer": "^1.1.2", @@ -26177,7 +25949,6 @@ "version": "4.1.0", "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-4.1.0.tgz", "integrity": "sha512-GMat4EJ5161kIy2HevLlr4luNjBgvmj413KaQA7jt4V8B4RDsfpHk7WQ9GVqfYyyx8OS/L66Kox+rJRNklLK7w==", - "dev": true, "requires": { "pump": "^3.0.0" } @@ -26332,7 +26103,7 @@ "version": "2.8.9", "resolved": "https://registry.npmjs.org/hosted-git-info/-/hosted-git-info-2.8.9.tgz", "integrity": "sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw==", - "devOptional": true + "dev": true }, "hotkeys-js": { "version": "3.8.7", @@ -26434,8 +26205,7 @@ "http-cache-semantics": { "version": "4.1.0", "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.1.0.tgz", - "integrity": "sha512-carPklcUh7ROWRK7Cv27RPtdhYhUsela/ue5/jKzjegVvXDqM2ILE9Q2BGn9JZJh1g87cp56su/FgQSzcWS8cQ==", - "dev": true + "integrity": "sha512-carPklcUh7ROWRK7Cv27RPtdhYhUsela/ue5/jKzjegVvXDqM2ILE9Q2BGn9JZJh1g87cp56su/FgQSzcWS8cQ==" }, "http-deceiver": { "version": "1.2.7", @@ -27362,12 +27132,6 @@ "integrity": "sha512-knxG2q4UC3u8stRGyAVJCOdxFmv5DZiRcdlIaAQXAbSfJya+OhopNotLQrstBhququ4ZpuKbDc/8S6mgXgPFPw==", "dev": true }, - "is-utf8": { - "version": "0.2.1", - "resolved": "https://registry.npmjs.org/is-utf8/-/is-utf8-0.2.1.tgz", - "integrity": "sha1-Sw2hRCEE0bM2NA6AeX6GXPOffXI=", - "optional": true - }, "is-windows": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/is-windows/-/is-windows-1.0.2.tgz", @@ -27386,8 +27150,7 @@ "isarray": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/isarray/-/isarray-1.0.0.tgz", - "integrity": "sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE=", - "devOptional": true + "integrity": "sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE=" }, "isbinaryfile": { "version": "3.0.3", @@ -27490,8 +27253,7 @@ "json-buffer": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/json-buffer/-/json-buffer-3.0.0.tgz", - "integrity": "sha1-Wx85evx11ne96Lz8Dkfh+aPZqJg=", - "devOptional": true + "integrity": "sha1-Wx85evx11ne96Lz8Dkfh+aPZqJg=" }, "json-parse-better-errors": { "version": "1.0.2", @@ -27524,7 +27286,7 @@ "version": "5.0.1", "resolved": "https://registry.npmjs.org/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz", "integrity": "sha1-Epai1Y/UXxmg9s4B1lcB4sc1tus=", - "dev": true + "devOptional": true }, "json5": { "version": "2.2.0", @@ -27576,7 +27338,6 @@ "version": "3.1.0", "resolved": "https://registry.npmjs.org/keyv/-/keyv-3.1.0.tgz", "integrity": "sha512-9ykJ/46SN/9KPM/sichzQ7OvXyGDYKGTaDlKMGCAlg2UK8KRy4jb0d8sFc+0Tt0YYnThq8X2RZgCg74RPxgcVA==", - "dev": true, "requires": { "json-buffer": "3.0.0" } @@ -27777,16 +27538,6 @@ "js-tokens": "^3.0.0 || ^4.0.0" } }, - "loud-rejection": { - "version": "1.6.0", - "resolved": "https://registry.npmjs.org/loud-rejection/-/loud-rejection-1.6.0.tgz", - "integrity": "sha1-W0b4AUft7leIcPCG0Eghz5mOVR8=", - "optional": true, - "requires": { - "currently-unhandled": "^0.4.1", - "signal-exit": "^3.0.0" - } - }, "lower-case": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/lower-case/-/lower-case-2.0.2.tgz", @@ -27799,8 +27550,7 @@ "lowercase-keys": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/lowercase-keys/-/lowercase-keys-1.0.1.tgz", - "integrity": "sha512-G2Lj61tXDnVFFOi8VZds+SoQjtQC3dgokKdDG2mTm1tx4m50NUHBOZSBwQQHyy0V12A0JTG4icfZQH+xPyh8VA==", - "devOptional": true + "integrity": "sha512-G2Lj61tXDnVFFOi8VZds+SoQjtQC3dgokKdDG2mTm1tx4m50NUHBOZSBwQQHyy0V12A0JTG4icfZQH+xPyh8VA==" }, "lpad-align": { "version": "1.1.2", @@ -27815,31 +27565,21 @@ }, "dependencies": { "camelcase": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-2.1.1.tgz", - "integrity": "sha1-fB0W1nmhu+WcoCys7PsBHiAfWh8=", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-1.2.1.tgz", + "integrity": "sha1-m7UwTS4LVmmLLHWLCKPqqdqlijk=", "optional": true }, "camelcase-keys": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/camelcase-keys/-/camelcase-keys-2.1.0.tgz", - "integrity": "sha1-MIvur/3ygRkFHvodkyITyRuPkuc=", + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/camelcase-keys/-/camelcase-keys-1.0.0.tgz", + "integrity": "sha1-vRoRv5sxoc5JNJOpMN4aC69K1+w=", "optional": true, "requires": { - "camelcase": "^2.0.0", + "camelcase": "^1.0.1", "map-obj": "^1.0.0" } }, - "find-up": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/find-up/-/find-up-1.1.2.tgz", - "integrity": "sha1-ay6YIrGizgpgq2TWEOzK1TyyTQ8=", - "optional": true, - "requires": { - "path-exists": "^2.0.0", - "pinkie-promise": "^2.0.0" - } - }, "indent-string": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-2.1.0.tgz", @@ -27849,19 +27589,6 @@ "repeating": "^2.0.0" } }, - "load-json-file": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/load-json-file/-/load-json-file-1.1.0.tgz", - "integrity": "sha1-lWkFcI1YtLq0wiYbBPWfMcmTdMA=", - "optional": true, - "requires": { - "graceful-fs": "^4.1.2", - "parse-json": "^2.2.0", - "pify": "^2.0.0", - "pinkie-promise": "^2.0.0", - "strip-bom": "^2.0.0" - } - }, "map-obj": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/map-obj/-/map-obj-1.0.1.tgz", @@ -27869,80 +27596,45 @@ "optional": true }, "meow": { - "version": "3.7.0", - "resolved": "https://registry.npmjs.org/meow/-/meow-3.7.0.tgz", - "integrity": "sha1-cstmi0JSKCkKu/qFaJJYcwioAfs=", - "optional": true, - "requires": { - "camelcase-keys": "^2.0.0", - "decamelize": "^1.1.2", - "loud-rejection": "^1.0.0", - "map-obj": "^1.0.1", - "minimist": "^1.1.3", - "normalize-package-data": "^2.3.4", - "object-assign": "^4.0.1", - "read-pkg-up": "^1.0.1", - "redent": "^1.0.0", - "trim-newlines": "^1.0.0" - } - }, - "path-exists": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-2.1.0.tgz", - "integrity": "sha1-D+tsZPD8UY2adU3V77YscCJ2H0s=", - "optional": true, - "requires": { - "pinkie-promise": "^2.0.0" - } - }, - "path-type": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/path-type/-/path-type-1.1.0.tgz", - "integrity": "sha1-WcRPfuSR2nBNpBXaWkBwuk+P5EE=", + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/meow/-/meow-3.3.0.tgz", + "integrity": "sha1-+Hd/0Ntn9z0d4b7uCMl8hmXvxu0=", "optional": true, "requires": { - "graceful-fs": "^4.1.2", - "pify": "^2.0.0", - "pinkie-promise": "^2.0.0" + "camelcase-keys": "^1.0.0", + "indent-string": "^1.1.0", + "minimist": "^1.1.0", + "object-assign": "^3.0.0" + }, + "dependencies": { + "indent-string": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-1.2.2.tgz", + "integrity": "sha1-25m8xYPrarux5I3LsZmamGBBy2s=", + "optional": true, + "requires": { + "get-stdin": "^4.0.1", + "minimist": "^1.1.0", + "repeating": "^1.1.0" + } + }, + "repeating": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/repeating/-/repeating-1.1.3.tgz", + "integrity": "sha1-PUEUIYh3U3SU+X93+Xhfq4EPpKw=", + "optional": true, + "requires": { + "is-finite": "^1.0.0" + } + } } }, - "pify": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/pify/-/pify-2.3.0.tgz", - "integrity": "sha1-7RQaasBDqEnqWISY59yosVMw6Qw=", + "object-assign": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-3.0.0.tgz", + "integrity": "sha1-m+3VygiXlJvKR+f/QIBi1Un1h/I=", "optional": true }, - "read-pkg": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/read-pkg/-/read-pkg-1.1.0.tgz", - "integrity": "sha1-9f+qXs0pyzHAR0vKfXVra7KePyg=", - "optional": true, - "requires": { - "load-json-file": "^1.0.0", - "normalize-package-data": "^2.3.2", - "path-type": "^1.0.0" - } - }, - "read-pkg-up": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/read-pkg-up/-/read-pkg-up-1.0.1.tgz", - "integrity": "sha1-nWPBMnbAZZGNV/ACpX9AobZD+wI=", - "optional": true, - "requires": { - "find-up": "^1.0.0", - "read-pkg": "^1.0.0" - } - }, - "redent": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/redent/-/redent-1.0.0.tgz", - "integrity": "sha1-z5Fqsf1fHxbfsggi3W7H9zDCr94=", - "optional": true, - "requires": { - "indent-string": "^2.1.0", - "strip-indent": "^1.0.1" - } - }, "repeating": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/repeating/-/repeating-2.0.1.tgz", @@ -27951,30 +27643,6 @@ "requires": { "is-finite": "^1.0.0" } - }, - "strip-bom": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-2.0.0.tgz", - "integrity": "sha1-YhmoVhZSBJHzV4i9vxRHqZx+aw4=", - "optional": true, - "requires": { - "is-utf8": "^0.2.0" - } - }, - "strip-indent": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/strip-indent/-/strip-indent-1.0.1.tgz", - "integrity": "sha1-DHlipq3vp7vUrDZkYKY4VSrhoKI=", - "optional": true, - "requires": { - "get-stdin": "^4.0.1" - } - }, - "trim-newlines": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/trim-newlines/-/trim-newlines-1.0.0.tgz", - "integrity": "sha1-WIeWa7WCpFA6QetST301ARgVphM=", - "optional": true } } }, @@ -28053,7 +27721,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/matcher/-/matcher-3.0.0.tgz", "integrity": "sha512-OkeDaAZ/bQCxeFAozM55PKcKU0yJMPGifLwV4Qgjitu+5MoAfSQN4lsLJeXZ1b8w0x+/Emda6MZgXS1jvsapng==", - "dev": true, "optional": true, "requires": { "escape-string-regexp": "^4.0.0" @@ -28282,8 +27949,7 @@ "mimic-response": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/mimic-response/-/mimic-response-1.0.1.tgz", - "integrity": "sha512-j5EctnkH7amfV/q5Hgmoal1g2QHFJRraOtmx0JpIqkxhBhI/lJSl1nMpQ45hVarwNETOoWEimndZ4QK0RHxuxQ==", - "devOptional": true + "integrity": "sha512-j5EctnkH7amfV/q5Hgmoal1g2QHFJRraOtmx0JpIqkxhBhI/lJSl1nMpQ45hVarwNETOoWEimndZ4QK0RHxuxQ==" }, "min-document": { "version": "2.19.0", @@ -28791,7 +28457,7 @@ "version": "2.5.0", "resolved": "https://registry.npmjs.org/normalize-package-data/-/normalize-package-data-2.5.0.tgz", "integrity": "sha512-/5CMN3T0R4XTj4DcGaexo+roZSdSFW/0AOOTROrjxzCG1wrWXEsGbRKevjlIL+ZDE4sZlJr5ED4YW0yqmkK+eA==", - "devOptional": true, + "dev": true, "requires": { "hosted-git-info": "^2.1.4", "resolve": "^1.10.0", @@ -28803,7 +28469,7 @@ "version": "5.7.1", "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", - "devOptional": true + "dev": true } } }, @@ -28816,8 +28482,7 @@ "normalize-url": { "version": "4.5.1", "resolved": "https://registry.npmjs.org/normalize-url/-/normalize-url-4.5.1.tgz", - "integrity": "sha512-9UZCFRHQdNrfTpGg8+1INIg93B6zE0aXMVFkw1WFwvO4SlZywU6aLg5Of0Ap/PgcbSw4LNxvMWXMeugwMCX0AA==", - "dev": true + "integrity": "sha512-9UZCFRHQdNrfTpGg8+1INIg93B6zE0aXMVFkw1WFwvO4SlZywU6aLg5Of0Ap/PgcbSw4LNxvMWXMeugwMCX0AA==" }, "npm-conf": { "version": "1.1.3", @@ -28928,7 +28593,7 @@ "version": "1.1.1", "resolved": "https://registry.npmjs.org/object-keys/-/object-keys-1.1.1.tgz", "integrity": "sha512-NuAESUOUMrlIXOfHKzD6bpPu3tYt3xvjNdRIQ+FeT0lNb4K8WR70CaDxhuNguS2XG+GjkyMwOzsN5ZktImfhLA==", - "dev": true + "devOptional": true }, "object.assign": { "version": "4.1.2", @@ -29100,6 +28765,11 @@ "integrity": "sha1-u+Z0BseaqFxc/sdm/lc0VV36EnQ=", "dev": true }, + "os-utils": { + "version": "0.0.14", + "resolved": "https://registry.npmjs.org/os-utils/-/os-utils-0.0.14.tgz", + "integrity": "sha1-KeURaXsZgrjGJ3Ihdf45eX72QVY=" + }, "ow": { "version": "0.17.0", "resolved": "https://registry.npmjs.org/ow/-/ow-0.17.0.tgz", @@ -29120,8 +28790,7 @@ "p-cancelable": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/p-cancelable/-/p-cancelable-1.1.0.tgz", - "integrity": "sha512-s73XxOZ4zpt1edZYZzvhqFa6uvQc1vwUa0K0BdtIZgQMAJj9IbebH+JkgKZc9h+B05PKHLOTl4ajG1BmNrVZlw==", - "dev": true + "integrity": "sha512-s73XxOZ4zpt1edZYZzvhqFa6uvQc1vwUa0K0BdtIZgQMAJj9IbebH+JkgKZc9h+B05PKHLOTl4ajG1BmNrVZlw==" }, "p-defer": { "version": "1.0.0", @@ -29266,7 +28935,7 @@ "version": "2.2.0", "resolved": "https://registry.npmjs.org/parse-json/-/parse-json-2.2.0.tgz", "integrity": "sha1-9ID0BDTvgHQfhGkJn43qGPVaTck=", - "devOptional": true, + "dev": true, "requires": { "error-ex": "^1.2.0" } @@ -29347,8 +29016,7 @@ "pend": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/pend/-/pend-1.2.0.tgz", - "integrity": "sha1-elfrVQpng/kRUzH89GY9XI4AelA=", - "devOptional": true + "integrity": "sha1-elfrVQpng/kRUzH89GY9XI4AelA=" }, "performance-now": { "version": "2.1.0", @@ -29653,8 +29321,7 @@ "prepend-http": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/prepend-http/-/prepend-http-2.0.0.tgz", - "integrity": "sha1-6SQ0v6XqjBn0HN/UAddBo8gZ2Jc=", - "devOptional": true + "integrity": "sha1-6SQ0v6XqjBn0HN/UAddBo8gZ2Jc=" }, "pretty-bytes": { "version": "1.0.4", @@ -29747,14 +29414,12 @@ "process-nextick-args": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.1.tgz", - "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==", - "devOptional": true + "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==" }, "progress": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/progress/-/progress-2.0.3.tgz", - "integrity": "sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA==", - "dev": true + "integrity": "sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA==" }, "progress-stream": { "version": "1.2.0", @@ -29824,7 +29489,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz", "integrity": "sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==", - "devOptional": true, "requires": { "end-of-stream": "^1.1.0", "once": "^1.3.1" @@ -30317,7 +29981,6 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/responselike/-/responselike-1.0.2.tgz", "integrity": "sha1-kYcg7ztjHFZCvgaPFa3lpG9Loec=", - "devOptional": true, "requires": { "lowercase-keys": "^1.0.0" } @@ -30356,7 +30019,6 @@ "version": "2.15.4", "resolved": "https://registry.npmjs.org/roarr/-/roarr-2.15.4.tgz", "integrity": "sha512-CHhPh+UNHD2GTXNYhPWLnU8ONHdI+5DI+4EYIAOaiD63rHeYlZvyh8P+in5999TTSFgUYuKUAjzRI4mdh/p+2A==", - "dev": true, "optional": true, "requires": { "boolean": "^3.0.1", @@ -30510,14 +30172,13 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/semver-compare/-/semver-compare-1.0.0.tgz", "integrity": "sha1-De4hahyUGrN+nvsXiPavxf9VN/w=", - "dev": true, "optional": true }, "semver-regex": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/semver-regex/-/semver-regex-2.0.0.tgz", - "integrity": "sha512-mUdIBBvdn0PLOeP3TEkMH7HHeUP3GjsXCwKarjv/kGmUFOYg1VqEemKhoQpWMu6X2I8kHeuVdGibLGkVK+/5Qw==", - "optional": true + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/semver-regex/-/semver-regex-4.0.2.tgz", + "integrity": "sha512-xyuBZk1XYqQkB687hMQqrCP+J9bdJSjPpZwdmmNjyxKW1K3LDXxqxw91Egaqkh/yheBIVtKPt4/1eybKVdCx3g==", + "dev": true }, "semver-truncate": { "version": "1.1.2", @@ -30586,7 +30247,6 @@ "version": "7.0.1", "resolved": "https://registry.npmjs.org/serialize-error/-/serialize-error-7.0.1.tgz", "integrity": "sha512-8I8TjW5KMOKsZQTvoxjuSIa7foAwPWGOts+6o7sgjz41/qMD9VQHEDxi6PBvK2l0MXUmqZyNpUK+T2tQaaElvw==", - "dev": true, "optional": true, "requires": { "type-fest": "^0.13.1" @@ -30909,7 +30569,6 @@ "version": "1.1.2", "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.2.tgz", "integrity": "sha512-VE0SOVEHCk7Qc8ulkWw3ntAzXuqf7S2lvwQaDLRnUeIEaKNQJzV6BwmLKhOqT61aGhfUMrXeaBk+oDGCzvhcug==", - "dev": true, "optional": true }, "squeak": { @@ -31169,7 +30828,6 @@ "version": "3.0.1", "resolved": "https://registry.npmjs.org/sumchecker/-/sumchecker-3.0.1.tgz", "integrity": "sha512-MvjXzkz/BOfyVDkG0oFOtBxHX2u3gKbMHIF/dXblZsgD3BWOFLmHovIpZY7BykJdAjcqRCBi1WYBNdEC9yI7vg==", - "dev": true, "requires": { "debug": "^4.1.0" } @@ -31605,8 +31263,7 @@ "to-readable-stream": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/to-readable-stream/-/to-readable-stream-1.0.0.tgz", - "integrity": "sha512-Iq25XBt6zD5npPhlLVXGFN3/gyR2/qODcKNNyTMd4vbm39HUaOiAM4PMq0eMVC/Tkxz+Zjdsc55g9yyz+Yq00Q==", - "dev": true + "integrity": "sha512-Iq25XBt6zD5npPhlLVXGFN3/gyR2/qODcKNNyTMd4vbm39HUaOiAM4PMq0eMVC/Tkxz+Zjdsc55g9yyz+Yq00Q==" }, "to-regex-range": { "version": "5.0.1", @@ -31699,7 +31356,6 @@ "version": "0.0.6", "resolved": "https://registry.npmjs.org/tunnel/-/tunnel-0.0.6.tgz", "integrity": "sha512-1h/Lnq9yajKY2PEbBadPXj3VxsDDu844OnaAo52UVmIzIvwwtBPIuNvkjuzBlTWpfJyUbG3ez0KSBibQkj4ojg==", - "dev": true, "optional": true }, "tunnel-agent": { @@ -31740,7 +31396,6 @@ "version": "0.13.1", "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.13.1.tgz", "integrity": "sha512-34R7HTnG0XIJcBSn5XhDd7nNFPRcXYRZrBB2O2jdKqYODldSzBAqzsWoZYYvduky73toYS/ESqxPvkDf/F0XMg==", - "dev": true, "optional": true }, "type-is": { @@ -31756,8 +31411,7 @@ "typedarray": { "version": "0.0.6", "resolved": "https://registry.npmjs.org/typedarray/-/typedarray-0.0.6.tgz", - "integrity": "sha1-hnrHTjhkGHsdPUfZlqeOxciDB3c=", - "dev": true + "integrity": "sha1-hnrHTjhkGHsdPUfZlqeOxciDB3c=" }, "unbox-primitive": { "version": "1.0.1", @@ -31847,7 +31501,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/url-parse-lax/-/url-parse-lax-3.0.0.tgz", "integrity": "sha1-FrXK/Afb42dsGxmZF3gj1lA6yww=", - "devOptional": true, "requires": { "prepend-http": "^2.0.0" } @@ -32027,8 +31680,7 @@ "util-deprecate": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", - "integrity": "sha1-RQ1Nyfpw3nMnYvvS1KKJgUGaDM8=", - "devOptional": true + "integrity": "sha1-RQ1Nyfpw3nMnYvvS1KKJgUGaDM8=" }, "utila": { "version": "0.4.0", @@ -32557,7 +32209,6 @@ "version": "2.10.0", "resolved": "https://registry.npmjs.org/yauzl/-/yauzl-2.10.0.tgz", "integrity": "sha1-x+sXyT4RLLEIb6bY5R+wZnt5pfk=", - "devOptional": true, "requires": { "buffer-crc32": "~0.2.3", "fd-slicer": "~1.1.0" diff --git a/frontend/package.json b/package.json similarity index 93% rename from frontend/package.json rename to package.json index fa289172f..6dd346a3c 100644 --- a/frontend/package.json +++ b/package.json @@ -1,12 +1,12 @@ { "name": "chainner", "productName": "chaiNNer", - "version": "0.0.1", - "description": "A modular interactive flowchart based image processing GUI", + "version": "0.1.0", + "description": "A flowchart based image processing GUI", "main": ".webpack/main", "scripts": { "start": "electron-forge start", - "dev": "concurrently \"electron-forge start\" \"nodemon ../backend/run.py 8000\"", + "dev": "concurrently \"nodemon ./backend/run.py 8000\" \"electron-forge start\"", "package": "electron-forge package", "make": "electron-forge make", "publish": "electron-forge publish", @@ -22,7 +22,7 @@ "forge": { "packagerConfig": { "executableName": "chainner", - "extraResource": "../backend/", + "extraResource": "./backend/", "icon": "./src/public/icons/cross_platform/icon" }, "publishers": [ @@ -42,8 +42,7 @@ { "name": "@electron-forge/maker-squirrel", "config": { - "name": "chainner", - "noMsi": false + "name": "chainner" } }, { @@ -111,6 +110,7 @@ "eslint": "^7.32.0", "eslint-config-airbnb": "^18.2.1", "node-loader": "^2.0.0", + "semver-regex": ">=3.1.3", "style-loader": "^3.2.1" }, "dependencies": { @@ -131,6 +131,7 @@ "image-webpack-loader": "^8.0.1", "lodash": "^4.17.21", "meow": "^10.1.1", + "os-utils": "^0.0.14", "portastic": "^1.0.1", "react": "^17.0.2", "react-dom": "^17.0.2", @@ -146,4 +147,4 @@ "use-http": "^1.0.26", "uuid": "^3.4.0" } -} \ No newline at end of file +} diff --git a/frontend/src/app.jsx b/src/app.jsx similarity index 100% rename from frontend/src/app.jsx rename to src/app.jsx diff --git a/frontend/src/assets/NumPy Logo.svg b/src/assets/NumPy Logo.svg similarity index 100% rename from frontend/src/assets/NumPy Logo.svg rename to src/assets/NumPy Logo.svg diff --git a/frontend/src/assets/OpenCV Logo.svg b/src/assets/OpenCV Logo.svg similarity index 100% rename from frontend/src/assets/OpenCV Logo.svg rename to src/assets/OpenCV Logo.svg diff --git a/frontend/src/assets/PyTorch Logo.svg b/src/assets/PyTorch Logo.svg similarity index 100% rename from frontend/src/assets/PyTorch Logo.svg rename to src/assets/PyTorch Logo.svg diff --git a/frontend/src/components/CustomIcons.jsx b/src/components/CustomIcons.jsx similarity index 100% rename from frontend/src/components/CustomIcons.jsx rename to src/components/CustomIcons.jsx diff --git a/frontend/src/components/DependencyManager.jsx b/src/components/DependencyManager.jsx similarity index 81% rename from frontend/src/components/DependencyManager.jsx rename to src/components/DependencyManager.jsx index b3391616b..ad1355b5a 100644 --- a/frontend/src/components/DependencyManager.jsx +++ b/src/components/DependencyManager.jsx @@ -52,17 +52,15 @@ const DependencyManager = ({ isOpen, onClose }) => { }, [isNvidiaAvailable]); useEffect(async () => { - const fullGpuInfo = await ipcRenderer.invoke('get-gpu-info'); - const gpuNames = fullGpuInfo?.controllers.map((gpu) => gpu.model); - setGpuInfo(gpuNames); - // Check if gpu string contains any nvidia-specific terms - const nvidiaGpu = gpuNames.find( - (gpu) => gpu.toLowerCase().split(' ').some( - (item) => ['nvidia', 'geforce', 'gtx', 'rtx'].includes(item), - ), - ); - setNvidiaGpuName(nvidiaGpu); - setIsNvidiaAvailable(!!nvidiaGpu); + const hasNvidia = await ipcRenderer.invoke('get-has-nvidia'); + if (hasNvidia) { + setNvidiaGpuName(await ipcRenderer.invoke('get-gpu-name')); + setIsNvidiaAvailable(await ipcRenderer.invoke('get-has-nvidia')); + } else { + const fullGpuInfo = await ipcRenderer.invoke('get-gpu-info'); + const gpuNames = fullGpuInfo?.controllers.map((gpu) => gpu.model); + setGpuInfo(gpuNames); + } }, []); useEffect(() => { @@ -72,7 +70,7 @@ const DependencyManager = ({ isOpen, onClose }) => { ...deps, pythonVersion: pKeys.version, }); - exec(`${pKeys.pip} list`, (error, stdout, stderr) => { + exec(`${pKeys.python} -m pip list`, (error, stdout, stderr) => { if (error) { setIsLoadingPipList(false); return; @@ -108,7 +106,7 @@ const DependencyManager = ({ isOpen, onClose }) => { useEffect(async () => { if (pipList && Object.keys(pipList).length) { setIsCheckingUpdates(true); - exec(`${pythonKeys.pip} list --outdated`, (error, stdout, stderr) => { + exec(`${pythonKeys.python} -m pip list --outdated`, (error, stdout, stderr) => { if (error) { console.log(error, stderr); setIsCheckingUpdates(false); @@ -131,17 +129,10 @@ const DependencyManager = ({ isOpen, onClose }) => { } }, [depChanged]); - const installPackage = (installCommand) => { + const runPipCommand = (args) => { setShellOutput(''); setIsRunningShell(true); - const args = installCommand.split(' '); - const installer = args.shift(); - let command = ''; - if (installer === 'pip') { - command = spawn(pythonKeys.pip, args); - } else { - command = spawn(installer, args); - } + const command = spawn(pythonKeys.python, ['-m', 'pip', ...args]); let outputString = ''; @@ -164,66 +155,27 @@ const DependencyManager = ({ isOpen, onClose }) => { }); }; - const updatePackage = (packageName) => { - setShellOutput(''); - setIsRunningShell(true); - const command = spawn(pythonKeys.pip, ['install', '--upgrade', packageName]); - - let outputString = ''; - - command.stdout.on('data', (data) => { - outputString += String(data); - setShellOutput(outputString); - }); - - command.stderr.on('data', (data) => { - setShellOutput(data); - }); - - command.on('error', (error) => { - setShellOutput(error); - }); + const installPackage = (installCommand) => { + const args = installCommand.split(' '); + const installer = args.shift(); + if (installer === 'pip') { + runPipCommand(args); + } + }; - command.on('close', (code) => { - console.log(`child process exited with code ${code}`); - setIsRunningShell(false); - }); + const updatePackage = (packageName) => { + runPipCommand(['install', '--upgrade', packageName]); }; const uninstallPackage = (packageName) => { const packageDep = availableDeps.find( (dep) => dep.name === packageName || dep.packageName === packageName, ); - setShellOutput(''); - setIsRunningShell(true); const args = packageDep.installCommand.split(' '); const installer = args.shift(); - let command = ''; if (installer === 'pip') { - command = spawn(pythonKeys.pip, ['uninstall', '-y', packageDep.packageName]); - } else { - return; + runPipCommand(['uninstall', '-y', packageDep.packageName]); } - - let outputString = ''; - - command.stdout.on('data', (data) => { - outputString += String(data); - setShellOutput(outputString); - }); - - command.stderr.on('data', (data) => { - setShellOutput(data); - }); - - command.on('error', (error) => { - setShellOutput(error); - }); - - command.on('close', (code) => { - console.log(`child process exited with code ${code}`); - setIsRunningShell(false); - }); }; useEffect(() => { @@ -251,19 +203,6 @@ const DependencyManager = ({ isOpen, onClose }) => { {`Python (${deps.pythonVersion})`} - {/* - - */} {isLoadingPipList ? : availableDeps.map((dep) => ( diff --git a/frontend/src/components/Header.jsx b/src/components/Header.jsx similarity index 78% rename from frontend/src/components/Header.jsx rename to src/components/Header.jsx index 308180b40..040360ece 100644 --- a/frontend/src/components/Header.jsx +++ b/src/components/Header.jsx @@ -5,19 +5,20 @@ import { import { AlertDialog, AlertDialogBody, AlertDialogCloseButton, AlertDialogContent, AlertDialogFooter, - AlertDialogHeader, AlertDialogOverlay, Box, - Button, Flex, Heading, HStack, - IconButton, Image, Menu, MenuButton, MenuItem, MenuList, - Portal, Spacer, Tag, useColorMode, useDisclosure, + AlertDialogHeader, AlertDialogOverlay, Box, Button, CircularProgress, + CircularProgressLabel, Flex, Heading, HStack, IconButton, + Image, Menu, MenuButton, MenuItem, MenuList, + Portal, Spacer, Tag, Tooltip, useColorMode, useColorModeValue, useDisclosure, } from '@chakra-ui/react'; -import { ipcRenderer } from 'electron'; +import { clipboard, ipcRenderer } from 'electron'; import React, { memo, useContext, useEffect, useState, } from 'react'; import { IoPause, IoPlay, IoStop } from 'react-icons/io5'; import useFetch from 'use-http'; import { GlobalContext } from '../helpers/GlobalNodeState.jsx'; -import useInterval from '../helpers/useInterval.js'; +import useInterval from '../helpers/hooks/useInterval.js'; +import useSystemUsage from '../helpers/hooks/useSystemUsage.js'; import logo from '../public/icons/png/256x256.png'; import DependencyManager from './DependencyManager.jsx'; import SettingsModal from './SettingsModal.jsx'; @@ -51,6 +52,7 @@ const Header = () => { const [running, setRunning] = useState(false); const { post, error, response: res } = useFetch(`http://localhost:${ipcRenderer.sendSync('get-port')}/run`, { cachePolicy: 'no-cache', + timeout: 0, }); const { post: checkPost, error: checkError, response: checkRes } = useFetch(`http://localhost:${ipcRenderer.sendSync('get-port')}/check`, { @@ -97,48 +99,50 @@ const Header = () => { const [isNvidiaAvailable, setIsNvidiaAvailable] = useState(false); const [nvidiaGpuIndex, setNvidiaGpuIndex] = useState(null); - const [vramUsage, setVramUsage] = useState(0); - const [ramUsage, setRamUsage] = useState(0); - const [cpuUsage, setCpuUsage] = useState(0); + // const [vramUsage, setVramUsage] = useState(0); + // const [ramUsage, setRamUsage] = useState(0); + // const [cpuUsage, setCpuUsage] = useState(0); const [hasCheckedGPU, setHasCheckedGPU] = useState(false); - const checkSysInfo = async () => { - const { gpu, ram, cpu } = await ipcRenderer.invoke('get-live-sys-info'); - - const vramCheck = (index) => { - const gpuInfo = gpu.controllers[index]; - const usage = Number(((gpuInfo?.memoryUsed || 0) / (gpuInfo?.memoryTotal || 1)) * 100); - setVramUsage(usage); - }; - if (!hasCheckedGPU) { - const gpuNames = gpu?.controllers.map((g) => g.model); - // Check if gpu string contains any nvidia-specific terms - const nvidiaGpu = gpuNames.find( - (g) => g.toLowerCase().split(' ').some( - (item) => ['nvidia', 'geforce', 'gtx', 'rtx'].includes(item), - ), - ); - setNvidiaGpuIndex(gpuNames.indexOf(nvidiaGpu)); - setIsNvidiaAvailable(!!nvidiaGpu); - setHasCheckedGPU(true); - if (nvidiaGpu) { - vramCheck(gpuNames.indexOf(nvidiaGpu)); - } - } - // if (isNvidiaAvailable && gpu) { - // vramCheck(nvidiaGpuIndex); - // } - // if (ram) { - // const usage = Number(((ram.used || 0) / (ram.total || 1)) * 100); - // setRamUsage(usage); - // } - // if (cpu) { - // setCpuUsage(cpu.currentLoad); - // } - }; + // const checkSysInfo = async () => { + // const { gpu, ram, cpu } = await ipcRenderer.invoke('get-live-sys-info'); - useEffect(async () => { - await checkSysInfo(); - }, []); + // const vramCheck = (index) => { + // const gpuInfo = gpu.controllers[index]; + // const usage = Number(((gpuInfo?.memoryUsed || 0) / (gpuInfo?.memoryTotal || 1)) * 100); + // setVramUsage(usage); + // }; + // if (!hasCheckedGPU) { + // const gpuNames = gpu?.controllers.map((g) => g.model); + // // Check if gpu string contains any nvidia-specific terms + // const nvidiaGpu = gpuNames.find( + // (g) => g.toLowerCase().split(' ').some( + // (item) => ['nvidia', 'geforce', 'gtx', 'rtx'].includes(item), + // ), + // ); + // setNvidiaGpuIndex(gpuNames.indexOf(nvidiaGpu)); + // setIsNvidiaAvailable(!!nvidiaGpu); + // setHasCheckedGPU(true); + // if (nvidiaGpu) { + // vramCheck(gpuNames.indexOf(nvidiaGpu)); + // } + // } + // if (isNvidiaAvailable && gpu) { + // vramCheck(nvidiaGpuIndex); + // } + // if (ram) { + // const usage = Number(((ram.used || 0) / (ram.total || 1)) * 100); + // setRamUsage(usage); + // } + // if (cpu) { + // setCpuUsage(cpu.currentLoad); + // } + // }; + + const { cpuUsage, ramUsage, vramUsage } = useSystemUsage(2500); + + // useEffect(async () => { + // await checkSysInfo(); + // }, []); // useInterval(async () => { // await checkSysInfo(); @@ -158,7 +162,7 @@ const Header = () => { if (invalidNodes.length === 0) { try { const data = convertToUsableFormat(); - const response = post({ + const response = await post({ data, isCpu, isFp16: isFp16 && !isCpu, @@ -166,12 +170,12 @@ const Header = () => { resolutionY: monitor?.resolutionY || 1080, }); console.log(response); - // if (!res.ok) { - // setErrorMessage(response.exception); - // onErrorOpen(); - // unAnimateEdges(); - // setRunning(false); - // } + if (!res.ok) { + setErrorMessage(response.exception); + onErrorOpen(); + unAnimateEdges(); + setRunning(false); + } } catch (err) { setErrorMessage(err.exception); onErrorOpen(); @@ -232,7 +236,7 @@ const Header = () => { chaiNNer - Pre-Alpha + Alpha {`v${appVersion}`} @@ -244,7 +248,7 @@ const Header = () => { - {/* + { VRAM - */} + } variant="outline" size="md"> @@ -335,6 +339,12 @@ const Header = () => { {errorMessage} + diff --git a/frontend/src/components/NodeSelectorPanel.jsx b/src/components/NodeSelectorPanel.jsx similarity index 100% rename from frontend/src/components/NodeSelectorPanel.jsx rename to src/components/NodeSelectorPanel.jsx diff --git a/frontend/src/components/ReactFlowBox.jsx b/src/components/ReactFlowBox.jsx similarity index 96% rename from frontend/src/components/ReactFlowBox.jsx rename to src/components/ReactFlowBox.jsx index f06f91e49..516082c55 100644 --- a/frontend/src/components/ReactFlowBox.jsx +++ b/src/components/ReactFlowBox.jsx @@ -3,6 +3,7 @@ import { Box, } from '@chakra-ui/react'; +import log from 'electron-log'; import React, { createContext, useCallback, useContext, } from 'react'; @@ -37,6 +38,7 @@ const ReactFlowBox = ({ }; const onDrop = (event) => { + log.info('dropped'); event.preventDefault(); const reactFlowBounds = wrapperRef.current.getBoundingClientRect(); @@ -48,6 +50,7 @@ const ReactFlowBox = ({ const category = event.dataTransfer.getData('application/reactflow/category'); const offsetX = event.dataTransfer.getData('application/reactflow/offsetX'); const offsetY = event.dataTransfer.getData('application/reactflow/offsetY'); + log.info(type, inputs, outputs, category); const position = reactFlowInstance.project({ x: event.clientX - reactFlowBounds.left - offsetX, diff --git a/frontend/src/components/SettingsModal.jsx b/src/components/SettingsModal.jsx similarity index 100% rename from frontend/src/components/SettingsModal.jsx rename to src/components/SettingsModal.jsx diff --git a/frontend/src/components/inputs/DirectoryInput.jsx b/src/components/inputs/DirectoryInput.jsx similarity index 100% rename from frontend/src/components/inputs/DirectoryInput.jsx rename to src/components/inputs/DirectoryInput.jsx diff --git a/frontend/src/components/inputs/DropdownInput.jsx b/src/components/inputs/DropDownInput.jsx similarity index 100% rename from frontend/src/components/inputs/DropdownInput.jsx rename to src/components/inputs/DropDownInput.jsx diff --git a/frontend/src/components/inputs/FileInput.jsx b/src/components/inputs/FileInput.jsx similarity index 100% rename from frontend/src/components/inputs/FileInput.jsx rename to src/components/inputs/FileInput.jsx diff --git a/frontend/src/components/inputs/GenericInput.jsx b/src/components/inputs/GenericInput.jsx similarity index 100% rename from frontend/src/components/inputs/GenericInput.jsx rename to src/components/inputs/GenericInput.jsx diff --git a/frontend/src/components/inputs/InputContainer.jsx b/src/components/inputs/InputContainer.jsx similarity index 100% rename from frontend/src/components/inputs/InputContainer.jsx rename to src/components/inputs/InputContainer.jsx diff --git a/frontend/src/components/inputs/NumberInput.jsx b/src/components/inputs/NumberInput.jsx similarity index 100% rename from frontend/src/components/inputs/NumberInput.jsx rename to src/components/inputs/NumberInput.jsx diff --git a/frontend/src/components/inputs/SliderInput.jsx b/src/components/inputs/SliderInput.jsx similarity index 72% rename from frontend/src/components/inputs/SliderInput.jsx rename to src/components/inputs/SliderInput.jsx index 75438a9d6..b49b22ea5 100644 --- a/frontend/src/components/inputs/SliderInput.jsx +++ b/src/components/inputs/SliderInput.jsx @@ -3,7 +3,8 @@ import { Slider, SliderFilledTrack, SliderThumb, SliderTrack, } from '@chakra-ui/react'; -import React, { memo, useContext } from 'react'; +import React, { memo, useContext, useState } from 'react'; +import { useDebouncedCallback } from 'use-debounce'; import getAccentColor from '../../helpers/getNodeAccentColors.js'; import { GlobalContext } from '../../helpers/GlobalNodeState.jsx'; import InputContainer from './InputContainer.jsx'; @@ -14,11 +15,15 @@ const SliderInput = memo(({ const { id } = data; const { useInputData, useNodeLock } = useContext(GlobalContext); const [input, setInput] = useInputData(id, index); + const [sliderValue, setSliderValue] = useState(input ?? def); const [isLocked] = useNodeLock(id); - const handleChange = (number) => { - setInput(number); - }; + const handleChange = useDebouncedCallback( + (number) => { + setInput(number); + }, + 1000, + ); return ( @@ -27,8 +32,8 @@ const SliderInput = memo(({ min={min} max={max} step={1} - onChange={handleChange} - value={input ?? def} + onChange={(v) => { handleChange(v); setSliderValue(v); }} + value={sliderValue ?? def} isDisabled={isLocked} > diff --git a/frontend/src/components/inputs/TextInput.jsx b/src/components/inputs/TextInput.jsx similarity index 100% rename from frontend/src/components/inputs/TextInput.jsx rename to src/components/inputs/TextInput.jsx diff --git a/frontend/src/components/inputs/previews/ImagePreview.jsx b/src/components/inputs/previews/ImagePreview.jsx similarity index 100% rename from frontend/src/components/inputs/previews/ImagePreview.jsx rename to src/components/inputs/previews/ImagePreview.jsx diff --git a/frontend/src/components/outputs/GenericOutput.jsx b/src/components/outputs/GenericOutput.jsx similarity index 100% rename from frontend/src/components/outputs/GenericOutput.jsx rename to src/components/outputs/GenericOutput.jsx diff --git a/frontend/src/components/outputs/ImageOutput.jsx b/src/components/outputs/ImageOutput.jsx similarity index 100% rename from frontend/src/components/outputs/ImageOutput.jsx rename to src/components/outputs/ImageOutput.jsx diff --git a/frontend/src/components/outputs/OutputContainer.jsx b/src/components/outputs/OutputContainer.jsx similarity index 100% rename from frontend/src/components/outputs/OutputContainer.jsx rename to src/components/outputs/OutputContainer.jsx diff --git a/src/downloads.json b/src/downloads.json new file mode 100644 index 000000000..edaea62b2 --- /dev/null +++ b/src/downloads.json @@ -0,0 +1,7 @@ +{ + "python": { + "linux": "https://github.com/indygreg/python-build-standalone/releases/download/20211017/cpython-3.9.7-x86_64-unknown-linux-gnu-install_only-20211017T1616.tar.gz", + "macos": "https://github.com/indygreg/python-build-standalone/releases/download/20211017/cpython-3.9.7-x86_64-apple-darwin-install_only-20211017T1616.tar.gz", + "windows": "https://github.com/indygreg/python-build-standalone/releases/download/20211017/cpython-3.9.7-x86_64-pc-windows-msvc-shared-install_only-20211017T1616.tar.gz" + } +} \ No newline at end of file diff --git a/frontend/src/global.css b/src/global.css similarity index 100% rename from frontend/src/global.css rename to src/global.css diff --git a/frontend/src/helpers/CustomEdge.jsx.jsx b/src/helpers/CustomEdge.jsx.jsx similarity index 100% rename from frontend/src/helpers/CustomEdge.jsx.jsx rename to src/helpers/CustomEdge.jsx.jsx diff --git a/frontend/src/helpers/GlobalNodeState.jsx b/src/helpers/GlobalNodeState.jsx similarity index 91% rename from frontend/src/helpers/GlobalNodeState.jsx rename to src/helpers/GlobalNodeState.jsx index 5286e2478..ddf51af04 100644 --- a/frontend/src/helpers/GlobalNodeState.jsx +++ b/src/helpers/GlobalNodeState.jsx @@ -10,15 +10,16 @@ import { } from 'react-flow-renderer'; import { useHotkeys } from 'react-hotkeys-hook'; import { v4 as uuidv4 } from 'uuid'; -import useLocalStorage from './useLocalStorage.js'; -import useUndoHistory from './useMultipleUndoHistory.js'; -import useSessionStorage from './useSessionStorage.js'; +import useLocalStorage from './hooks/useLocalStorage.js'; +import useUndoHistory from './hooks/useMultipleUndoHistory.js'; +import useSessionStorage from './hooks/useSessionStorage.js'; +import { migrate } from './migrations.js'; export const GlobalContext = createContext({}); const createUniqueId = () => uuidv4(); -export const GlobalProvider = ({ children }) => { +export const GlobalProvider = ({ children, nodeTypes }) => { const [nodes, setNodes] = useState([]); const [edges, setEdges] = useState([]); const [reactFlowInstance, setReactFlowInstance] = useState(null); @@ -39,15 +40,26 @@ export const GlobalProvider = ({ children }) => { const { transform } = useZoomPanHelper(); - const dumpStateToJSON = () => { - const output = JSON.stringify(reactFlowInstanceRfi); + const dumpStateToJSON = async () => { + const output = JSON.stringify({ + version: await ipcRenderer.invoke('get-app-version'), + content: reactFlowInstanceRfi, + timestamp: new Date(), + }); return output; }; const setStateFromJSON = (savedData, loadPosition = false) => { if (savedData) { - setNodes(savedData.elements.filter((element) => isNode(element)) || []); - setEdges(savedData.elements.filter((element) => isEdge(element)) || []); + const nodeTypesArr = Object.keys(nodeTypes); + const validNodes = savedData.elements.filter( + (element) => isNode(element) && nodeTypesArr.includes(element.type), + ) || []; + setNodes(validNodes); + setEdges(savedData.elements.filter((element) => isEdge(element) && ( + validNodes.some((el) => el.id === element.target) + && validNodes.some((el) => el.id === element.source) + )) || []); if (loadPosition) { const [x = 0, y = 0] = savedData.position; transform({ x, y, zoom: savedData.zoom || 0 }); @@ -83,7 +95,7 @@ export const GlobalProvider = ({ children }) => { }; const performSave = useCallback(async () => { - const json = dumpStateToJSON(); + const json = await dumpStateToJSON(); if (savePath) { ipcRenderer.invoke('file-save-json', json, savePath); } else { @@ -112,8 +124,15 @@ export const GlobalProvider = ({ children }) => { if (!loadedFromCli) { const contents = await ipcRenderer.invoke('get-cli-open'); if (contents) { - setStateFromJSON(contents, true); - setLoadedFromCli(true); + const { version, content } = contents; + if (version) { + const upgraded = migrate(version, content); + setStateFromJSON(upgraded, true); + } else { + // Legacy files + const upgraded = migrate(null, content); + setStateFromJSON(upgraded, true); + } } } }, []); @@ -121,8 +140,16 @@ export const GlobalProvider = ({ children }) => { // Register Open File event handler useEffect(() => { ipcRenderer.on('file-open', (event, json, openedFilePath) => { + const { version, content } = json; setSavePath(openedFilePath); - setStateFromJSON(json, true); + if (version) { + const upgraded = migrate(version, content); + setStateFromJSON(upgraded, true); + } else { + // Legacy files + const upgraded = migrate(null, json); + setStateFromJSON(upgraded, true); + } }); return () => { @@ -133,7 +160,7 @@ export const GlobalProvider = ({ children }) => { // Register Save/Save-As event handlers useEffect(() => { ipcRenderer.on('file-save-as', async () => { - const json = dumpStateToJSON(); + const json = await dumpStateToJSON(); const savedAsPath = await ipcRenderer.invoke('file-save-as-json', json, savePath); setSavePath(savedAsPath); }); diff --git a/frontend/src/helpers/createNodeTypes.jsx b/src/helpers/createNodeTypes.jsx similarity index 91% rename from frontend/src/helpers/createNodeTypes.jsx rename to src/helpers/createNodeTypes.jsx index 47a141cd7..abf08002c 100644 --- a/frontend/src/helpers/createNodeTypes.jsx +++ b/src/helpers/createNodeTypes.jsx @@ -12,7 +12,7 @@ import React, { memo, useContext } from 'react'; import { MdMoreHoriz } from 'react-icons/md'; import { IconFactory } from '../components/CustomIcons.jsx'; import DirectoryInput from '../components/inputs/DirectoryInput.jsx'; -import DropDownInput from '../components/inputs/DropdownInput.jsx'; +import DropDownInput from '../components/inputs/DropDownInput.jsx'; import FileInput from '../components/inputs/FileInput.jsx'; import GenericInput from '../components/inputs/GenericInput.jsx'; import NumberInput from '../components/inputs/NumberInput.jsx'; @@ -263,39 +263,6 @@ const UsableNode = ({ data, selected }) => ( ); -// export const createUsableNode = (category, node) => { -// const id = createUniqueId(); -// return ( -// -// -// - -// -// INPUTS -// -// {createUsableInputs(category, node, id)} - -// -// OUTPUTS -// -// {createUsableOutputs(category, node, id)} - -// -// -// -// ); -// }; - -// function RepresentativeNode({ data }) { -// return ( -// -// -// -// -// -// ); -// } - export const createRepresentativeNode = (category, node) => ( diff --git a/frontend/src/helpers/dependencies.js b/src/helpers/dependencies.js similarity index 100% rename from frontend/src/helpers/dependencies.js rename to src/helpers/dependencies.js diff --git a/frontend/src/helpers/getNodeAccentColors.js b/src/helpers/getNodeAccentColors.js similarity index 100% rename from frontend/src/helpers/getNodeAccentColors.js rename to src/helpers/getNodeAccentColors.js diff --git a/frontend/src/helpers/useInterval.js b/src/helpers/hooks/useInterval.js similarity index 100% rename from frontend/src/helpers/useInterval.js rename to src/helpers/hooks/useInterval.js diff --git a/frontend/src/helpers/useLocalStorage.js b/src/helpers/hooks/useLocalStorage.js similarity index 100% rename from frontend/src/helpers/useLocalStorage.js rename to src/helpers/hooks/useLocalStorage.js diff --git a/frontend/src/helpers/useMultipleUndoHistory.js b/src/helpers/hooks/useMultipleUndoHistory.js similarity index 100% rename from frontend/src/helpers/useMultipleUndoHistory.js rename to src/helpers/hooks/useMultipleUndoHistory.js diff --git a/frontend/src/helpers/usePrevious.js b/src/helpers/hooks/usePrevious.js similarity index 100% rename from frontend/src/helpers/usePrevious.js rename to src/helpers/hooks/usePrevious.js diff --git a/frontend/src/helpers/useSessionStorage.js b/src/helpers/hooks/useSessionStorage.js similarity index 100% rename from frontend/src/helpers/useSessionStorage.js rename to src/helpers/hooks/useSessionStorage.js diff --git a/src/helpers/hooks/useSystemUsage.js b/src/helpers/hooks/useSystemUsage.js new file mode 100644 index 000000000..2eecd24e7 --- /dev/null +++ b/src/helpers/hooks/useSystemUsage.js @@ -0,0 +1,39 @@ +import { ipcRenderer } from 'electron'; +import os from 'os-utils'; +import { useEffect, useMemo, useState } from 'react'; +import useInterval from './useInterval'; + +const useSystemUsage = (delay) => { + const [cpuUsage, setCpuUsage] = useState(0); + const [ramUsage, setRamUsage] = useState(0); + const [vramUsage, setVramUsage] = useState(0); + + useEffect(async () => { + // We set this up on mount, letting the main process handle it + // By doing it this way we avoid spawning multiple smi shells + await ipcRenderer.invoke('setup-vram-checker-process', delay); + }, []); + + useInterval(async () => { + // RAM + const totalMem = os.totalmem(); + const usedMem = os.freemem(); + const ramPercent = Number((usedMem / totalMem) * 100).toFixed(1); + setRamUsage(ramPercent); + + // CPU + os.cpuUsage((value) => { + setCpuUsage(value * 100); + }); + + // GPU/VRAM + const vramPercent = await ipcRenderer.invoke('get-vram-usage'); + setVramUsage(vramPercent); + }, delay); + + return useMemo(() => ({ + cpuUsage, ramUsage, vramUsage, + }), [cpuUsage, ramUsage, vramUsage]); +}; + +export default useSystemUsage; diff --git a/frontend/src/helpers/useUndoHistory.js b/src/helpers/hooks/useUndoHistory.js similarity index 100% rename from frontend/src/helpers/useUndoHistory.js rename to src/helpers/hooks/useUndoHistory.js diff --git a/src/helpers/migrations.js b/src/helpers/migrations.js new file mode 100644 index 000000000..335154636 --- /dev/null +++ b/src/helpers/migrations.js @@ -0,0 +1,30 @@ +/* eslint-disable import/prefer-default-export */ +import semver from 'semver'; + +const preAlpha = (data) => { + const newData = { ...data }; + const newElements = newData.elements.map((element) => { + const newElement = { ...element }; + if (newElement.type === 'ESRGAN::Load') { + newElement.type = 'Model::AutoLoad'; + newElement.data.type = 'Model::AutoLoad'; + } else if (newElement.type === 'ESRGAN::Run') { + newElement.type = 'Image::Upscale'; + newElement.data.type = 'Image::Upscale'; + } + return newElement; + }); + newData.elements = newElements; + return newData; +}; + +export const migrate = (version, data) => { + let convertedData = data; + + // Legacy files + if (!version || semver.lt(version, '0.1.0')) { + convertedData = preAlpha(convertedData); + } + + return convertedData; +}; diff --git a/src/helpers/nvidiaSmi.js b/src/helpers/nvidiaSmi.js new file mode 100644 index 000000000..67708f26e --- /dev/null +++ b/src/helpers/nvidiaSmi.js @@ -0,0 +1,49 @@ +// Borrowed and modified from https://github.com/sebhildebrandt/systeminformation/blob/master/lib/graphics.js + +// TODO: Convert this to a useNvidiaSmi hook +// Could auto-check gpu before letting you run whatever command? + +// Actually, should probably get the nvidia-smi path in the main process and use ipc to grab it. +// either that, or call this from the global state and pass the path/keyword into the hook. +// If getting in main, do it during splash screen dep check. +// Then in stuff like the dependency manager i can just use ipc to get the gpu name and isNvidia + +import fs from 'fs'; +import os from 'os'; + +let nvidiaSmiPath; + +// Best approximation of what drive windows is installed on +const homePath = os.homedir(); +const WINDIR = homePath ? `${homePath.charAt(0)}:\\Windows` : 'C:\\Windows'; + +export const getNvidiaSmi = () => { + if (nvidiaSmiPath) { + return nvidiaSmiPath; + } + + if (os.platform() === 'win32') { + try { + const basePath = `${WINDIR}\\System32\\DriverStore\\FileRepository`; + // find all directories that have an nvidia-smi.exe file + const candidateDirs = fs.readdirSync(basePath).filter((dir) => fs.readdirSync([basePath, dir].join('/')).includes('nvidia-smi.exe')); + // use the directory with the most recently created nvidia-smi.exe file + const targetDir = candidateDirs.reduce((prevDir, currentDir) => { + const previousNvidiaSmi = fs.statSync([basePath, prevDir, 'nvidia-smi.exe'].join('/')); + const currentNvidiaSmi = fs.statSync([basePath, currentDir, 'nvidia-smi.exe'].join('/')); + return (previousNvidiaSmi.ctimeMs > currentNvidiaSmi.ctimeMs) ? prevDir : currentDir; + }); + + if (targetDir) { + nvidiaSmiPath = [basePath, targetDir, 'nvidia-smi.exe'].join('/'); + } + } catch (e) { + // idk + } + } else if (os.platform() === 'linux') { + nvidiaSmiPath = 'nvidia-smi'; + } + return nvidiaSmiPath; +}; + +export const getSmiQuery = (delay) => `-lms ${delay} --query-gpu=name,memory.total,memory.used,memory.free,utilization.gpu,utilization.memory --format=csv,noheader,nounits`; diff --git a/frontend/src/helpers/shadeColor.js b/src/helpers/shadeColor.js similarity index 100% rename from frontend/src/helpers/shadeColor.js rename to src/helpers/shadeColor.js diff --git a/frontend/src/index.html b/src/index.html similarity index 100% rename from frontend/src/index.html rename to src/index.html diff --git a/frontend/src/index.jsx b/src/index.jsx similarity index 100% rename from frontend/src/index.jsx rename to src/index.jsx diff --git a/frontend/src/main.js b/src/main.js similarity index 83% rename from frontend/src/main.js rename to src/main.js index 8b898c2ef..7ae187f96 100644 --- a/frontend/src/main.js +++ b/src/main.js @@ -1,15 +1,15 @@ -import { execSync, spawn, spawnSync } from 'child_process'; +import { execSync, spawn } from 'child_process'; import { app, BrowserWindow, dialog, ipcMain, Menu, shell, } from 'electron'; import log from 'electron-log'; import { readFile, writeFile } from 'fs/promises'; -import hasbin from 'hasbin'; // import { readdir } from 'fs/promises'; import path from 'path'; import portastic from 'portastic'; import semver from 'semver'; import { currentLoad, graphics, mem } from 'systeminformation'; +import { getNvidiaSmi } from './helpers/nvidiaSmi'; // log.transports.file.resolvePath = () => path.join(app.getAppPath(), 'logs/main.log'); // eslint-disable-next-line max-len @@ -25,7 +25,6 @@ let port = 8000; const pythonKeys = { python: 'python', - pip: 'pip', }; // Handle creating/removing shortcuts on Windows when installing/uninstalling. @@ -109,7 +108,7 @@ const registerEventHandlers = () => { ipcMain.handle('get-app-version', async () => app.getVersion()); }; -const getValidPort = async () => { +const getValidPort = async (splashWindow) => { log.info('Attempting to check for a port...'); const ports = await portastic.find({ min: 8000, @@ -117,7 +116,7 @@ const getValidPort = async () => { }); if (!ports || ports.length === 0) { log.warn('An open port could not be found'); - splash.hide(); + splashWindow.hide(); const messageBoxOptions = { type: 'error', title: 'No open port', @@ -134,51 +133,41 @@ const getValidPort = async () => { }); }; -const checkPythonVersion = (pythonBin) => { - const { stdout } = spawnSync(pythonBin, ['--version'], { - stdio: 'pipe', - encoding: 'utf-8', - }); - log.info(`Python version (raw): ${stdout}`); - const { version: pythonVersion } = semver.coerce(stdout); - log.info(`Python version (semver): ${pythonVersion}`); - const hasValidPythonVersion = semver.gt(pythonVersion, '3.7.0') && semver.lt(pythonVersion, '3.10.0'); - return { pythonVersion, hasValidPythonVersion }; +const getPythonVersion = (pythonBin) => { + try { + const stdout = execSync(`${pythonBin} --version`).toString(); + log.info(`Python version (raw): ${stdout}`); + const { version } = semver.coerce(stdout); + log.info(`Python version (semver): ${version}`); + return version; + } catch (error) { + return null; + } }; -const checkPythonEnv = async () => { - log.info('Attempting to check Python env...'); +const checkPythonVersion = (version) => semver.gt(version, '3.7.0') && semver.lt(version, '3.10.0'); - // Check first for standard 'python' keyword - let pythonBin = hasbin.sync('python') ? 'python' : null; - let pipBin = hasbin.sync('pip') ? 'pip' : null; - let validPythonVersion = null; - log.info(`(Hasbin) Python binary: ${pythonBin}`, `Pip binary: ${pipBin}`); - if (pythonBin) { - const { pythonVersion, hasValidPythonVersion } = checkPythonVersion(pythonBin); - if (pythonVersion && hasValidPythonVersion) { - validPythonVersion = pythonVersion; - } - } +const checkPythonEnv = async (splashWindow) => { + log.info('Attempting to check Python env...'); - // If 'python' not available or not right version, check 'python3' - if (!pythonBin || !validPythonVersion) { - pythonBin = hasbin.sync('python3') ? 'python3' : null; - pipBin = hasbin.sync('pip3') ? 'pip3' : null; - log.info(`(Hasbin) Python3 binary: ${pythonBin}`, `Pip3 binary: ${pipBin}`); - if (pythonBin) { - const { pythonVersion, hasValidPythonVersion } = checkPythonVersion(pythonBin); - if (pythonVersion && hasValidPythonVersion) { - validPythonVersion = pythonVersion; - } - } + const pythonVersion = getPythonVersion('python'); + const python3Version = getPythonVersion('python3'); + let validPythonVersion; + let pythonBin; + + if (pythonVersion && checkPythonVersion(pythonVersion)) { + validPythonVersion = pythonVersion; + pythonBin = 'python'; + } else if (python3Version && checkPythonVersion(python3Version)) { + validPythonVersion = python3Version; + pythonBin = 'python3'; } - log.info(`Final Python binary: ${pythonBin}`, `Final Pip binary: ${pipBin}`); + log.info(`Final Python binary: ${pythonBin}`); if (!pythonBin) { log.warn('Python binary not found'); - splash.hide(); + splashWindow.hide(); const messageBoxOptions = { type: 'error', title: 'Python not installed', @@ -197,13 +186,12 @@ const checkPythonEnv = async () => { if (pythonBin) { pythonKeys.python = pythonBin; - pythonKeys.pip = pipBin; pythonKeys.version = validPythonVersion; log.info({ pythonKeys }); } if (!validPythonVersion) { - splash.hide(); + splashWindow.hide(); const messageBoxOptions = { type: 'error', title: 'Python version invalid', @@ -226,26 +214,56 @@ const checkPythonEnv = async () => { }); }; -const checkPythonDeps = async () => { +const checkPythonDeps = async (splashWindow) => { log.info('Attempting to check Python deps...'); try { - let pipList = execSync(`${pythonKeys.pip} list`); + let pipList = execSync(`${pythonKeys.python} -m pip list`); pipList = String(pipList).split('\n').map((pkg) => pkg.replace(/\s+/g, ' ').split(' ')); const hasSanic = pipList.some((pkg) => pkg[0] === 'sanic'); const hasSanicCors = pipList.some((pkg) => pkg[0] === 'Sanic-Cors'); if (!hasSanic || !hasSanicCors) { log.info('Sanic not found. Installing sanic...'); - splash.webContents.send('installing-deps'); - execSync(`${pythonKeys.pip} install sanic Sanic-Cors`); + splashWindow.webContents.send('installing-deps'); + execSync(`${pythonKeys.python} -m pip install sanic Sanic-Cors`); } } catch (error) { console.error(error); } }; +const checkNvidiaSmi = async () => { + const nvidiaSmi = getNvidiaSmi(); + ipcMain.handle('get-smi', () => nvidiaSmi); + if (nvidiaSmi) { + const [gpu] = execSync(`${nvidiaSmi} --query-gpu=name --format=csv,noheader,nounits ${process.platform === 'linux' ? ' 2>/dev/null' : ''}`).toString().split('\n'); + ipcMain.handle('get-has-nvidia', () => true); + ipcMain.handle('get-gpu-name', () => gpu.trim()); + let vramChecker; + ipcMain.handle('setup-vram-checker-process', (event, delay) => { + if (!vramChecker) { + vramChecker = spawn(nvidiaSmi, `-lms ${delay} --query-gpu=name,memory.total,memory.used,memory.free,utilization.gpu,utilization.memory --format=csv,noheader,nounits`.split(' ')); + } + + vramChecker.stdout.on('data', (data) => { + ipcMain.removeHandler('get-vram-usage'); + ipcMain.handle('get-vram-usage', () => { + const [, vramTotal, vramUsed] = String(data).split('\n')[0].split(', '); + const usage = (Number(vramUsed) / Number(vramTotal)) * 100; + return usage; + }); + }); + }); + } else { + ipcMain.handle('get-has-nvidia', () => false); + ipcMain.handle('get-gpu-name', () => null); + ipcMain.handle('setup-vram-checker-process', () => null); + ipcMain.handle('get-vram-usage', () => null); + } +}; + const spawnBackend = async () => { log.info('Attempting to spawn backend...'); - const backendPath = app.isPackaged ? path.join(process.resourcesPath, 'backend', 'run.py') : '../backend/run.py'; + const backendPath = app.isPackaged ? path.join(process.resourcesPath, 'backend', 'run.py') : './backend/run.py'; const backend = spawn(pythonKeys.python, [backendPath, port]); backend.stdout.on('data', (data) => { @@ -313,6 +331,7 @@ const doSplashScreenChecks = async () => new Promise((resolve) => { splash.webContents.send('checking-deps'); await checkPythonDeps(splash); + await checkNvidiaSmi(); splash.webContents.send('spawning-backend'); await spawnBackend(); diff --git a/frontend/src/pages/main.jsx b/src/pages/main.jsx similarity index 100% rename from frontend/src/pages/main.jsx rename to src/pages/main.jsx diff --git a/src/public/chaiNNer screenshot.png b/src/public/chaiNNer screenshot.png new file mode 100644 index 000000000..f6c38233b Binary files /dev/null and b/src/public/chaiNNer screenshot.png differ diff --git a/frontend/src/public/icons/cross_platform/icon.icns b/src/public/icons/cross_platform/icon.icns similarity index 100% rename from frontend/src/public/icons/cross_platform/icon.icns rename to src/public/icons/cross_platform/icon.icns diff --git a/frontend/src/public/icons/cross_platform/icon.ico b/src/public/icons/cross_platform/icon.ico similarity index 100% rename from frontend/src/public/icons/cross_platform/icon.ico rename to src/public/icons/cross_platform/icon.ico diff --git a/frontend/src/public/icons/cross_platform/icon.png b/src/public/icons/cross_platform/icon.png similarity index 100% rename from frontend/src/public/icons/cross_platform/icon.png rename to src/public/icons/cross_platform/icon.png diff --git a/frontend/src/public/icons/mac/icon.icns b/src/public/icons/mac/icon.icns similarity index 100% rename from frontend/src/public/icons/mac/icon.icns rename to src/public/icons/mac/icon.icns diff --git a/frontend/src/public/icons/png/1024x1024.png b/src/public/icons/png/1024x1024.png similarity index 100% rename from frontend/src/public/icons/png/1024x1024.png rename to src/public/icons/png/1024x1024.png diff --git a/frontend/src/public/icons/png/128x128.png b/src/public/icons/png/128x128.png similarity index 100% rename from frontend/src/public/icons/png/128x128.png rename to src/public/icons/png/128x128.png diff --git a/frontend/src/public/icons/png/16x16.png b/src/public/icons/png/16x16.png similarity index 100% rename from frontend/src/public/icons/png/16x16.png rename to src/public/icons/png/16x16.png diff --git a/frontend/src/public/icons/png/24x24.png b/src/public/icons/png/24x24.png similarity index 100% rename from frontend/src/public/icons/png/24x24.png rename to src/public/icons/png/24x24.png diff --git a/frontend/src/public/icons/png/256x256.png b/src/public/icons/png/256x256.png similarity index 100% rename from frontend/src/public/icons/png/256x256.png rename to src/public/icons/png/256x256.png diff --git a/frontend/src/public/icons/png/32x32.png b/src/public/icons/png/32x32.png similarity index 100% rename from frontend/src/public/icons/png/32x32.png rename to src/public/icons/png/32x32.png diff --git a/frontend/src/public/icons/png/48x48.png b/src/public/icons/png/48x48.png similarity index 100% rename from frontend/src/public/icons/png/48x48.png rename to src/public/icons/png/48x48.png diff --git a/frontend/src/public/icons/png/512x512.png b/src/public/icons/png/512x512.png similarity index 100% rename from frontend/src/public/icons/png/512x512.png rename to src/public/icons/png/512x512.png diff --git a/frontend/src/public/icons/png/64x64.png b/src/public/icons/png/64x64.png similarity index 100% rename from frontend/src/public/icons/png/64x64.png rename to src/public/icons/png/64x64.png diff --git a/frontend/src/public/icons/win/icon.ico b/src/public/icons/win/icon.ico similarity index 100% rename from frontend/src/public/icons/win/icon.ico rename to src/public/icons/win/icon.ico diff --git a/frontend/src/renderer.js b/src/renderer.js similarity index 100% rename from frontend/src/renderer.js rename to src/renderer.js diff --git a/frontend/src/splash.html b/src/splash.html similarity index 100% rename from frontend/src/splash.html rename to src/splash.html diff --git a/frontend/src/splash.jsx b/src/splash.jsx similarity index 100% rename from frontend/src/splash.jsx rename to src/splash.jsx index ba9318d2a..a6d1b6aff 100644 --- a/frontend/src/splash.jsx +++ b/src/splash.jsx @@ -8,8 +8,6 @@ import ReactDOM from 'react-dom'; // eslint-disable-next-line import/extensions import './global.css'; -ReactDOM.render(, document.getElementById('root')); - const Splash = () => { const [status, setStatus] = useState('Loading...'); @@ -58,4 +56,6 @@ const Splash = () => { ); }; +ReactDOM.render(, document.getElementById('root')); + export default Splash; diff --git a/frontend/src/splash_renderer.js b/src/splash_renderer.js similarity index 100% rename from frontend/src/splash_renderer.js rename to src/splash_renderer.js diff --git a/frontend/src/theme.js b/src/theme.js similarity index 100% rename from frontend/src/theme.js rename to src/theme.js diff --git a/frontend/webpack.main.config.js b/webpack.main.config.js similarity index 100% rename from frontend/webpack.main.config.js rename to webpack.main.config.js diff --git a/frontend/webpack.renderer.config.js b/webpack.renderer.config.js similarity index 100% rename from frontend/webpack.renderer.config.js rename to webpack.renderer.config.js diff --git a/frontend/webpack.rules.js b/webpack.rules.js similarity index 100% rename from frontend/webpack.rules.js rename to webpack.rules.js