diff --git a/backend/nodes/image_nodes.py b/backend/nodes/image_nodes.py index 1e974df6b..55c20001f 100644 --- a/backend/nodes/image_nodes.py +++ b/backend/nodes/image_nodes.py @@ -30,6 +30,7 @@ def __init__(self): # IntegerOutput("Height"), # IntegerOutput("Width"), # IntegerOutput("Channels"), + TextOutput("Image Name"), ] self.icon = "BsFillImageFill" self.sub = "Input & Output" @@ -54,14 +55,16 @@ def run(self, path: str) -> np.ndarray: try: dtype_max = np.iinfo(img.dtype).max except: - logger.info("img dtype is not an int") + logger.debug("img dtype is not an int") img = img.astype("float32") / dtype_max h, w = img.shape[:2] c = img.shape[2] if img.ndim > 2 else 1 - return img, h, w, c + # return img, h, w, c + basename = os.path.splitext(os.path.basename(path))[0] + return img, basename @NodeFactory.register("Image", "Save Image") @@ -248,6 +251,92 @@ def run( return result +@NodeFactory.register("Image (Utility)", "Overlay Images") +class ImOverlay(NodeBase): + """OpenCV transparency overlay node""" + + def __init__(self): + """Constructor""" + super().__init__() + self.description = "Overlay transparent images on base image." + self.inputs = [ + ImageInput("Base"), + ImageInput("Overlay A"), + SliderInput("Opacity A", default=50, min=1, max=99), + ImageInput("Overlay B ", optional=True), + SliderInput("Opacity B", default=50, min=1, max=99, optional=True), + ] + self.outputs = [ImageOutput()] + self.icon = "BsLayersHalf" + self.sub = "Miscellaneous" + + def run( + self, + base: np.ndarray = None, + ov1: np.ndarray = None, + op1: int = 50, + ov2: np.ndarray = None, + op2: int = 50, + ) -> np.ndarray: + """Overlay transparent images on base image""" + + # Convert to 0.0-1.0 range + op1 = int(op1) / 100 + op2 = int(op2) / 100 + + imgs = [] + max_h, max_w, max_c = 0, 0, 1 + for img in base, ov1, ov2: + if img is not None and type(img) not in (int, str): + h, w = img.shape[:2] + if img.ndim == 2: # len(img.shape) needs to be 3, grayscale len only 2 + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + c = img.shape[2] + max_h = max(h, max_h) + max_w = max(w, max_w) + max_c = max(c, max_c) + imgs.append(img) + else: + assert ( + base.shape[0] >= max_h and base.shape[1] >= max_w + ), "Base must be largest image." + + # Expand channels if necessary + channel_fixed_imgs = [] + for img in imgs: + c = img.shape[2] + fixed_img = img + if c < max_c: + h, w = img.shape[:2] + temp_img = np.ones((h, w, max_c)) + temp_img[:, :, :c] = fixed_img + fixed_img = temp_img + channel_fixed_imgs.append(fixed_img.astype("float32")) + imgout = channel_fixed_imgs[0] + imgs = channel_fixed_imgs[1:] + + center_x = imgout.shape[1] // 2 + center_y = imgout.shape[0] // 2 + for img, op in zip(imgs, (op1, op2)): + h, w = img.shape[:2] + + # Center overlay + x_offset = center_x - (w // 2) + y_offset = center_y - (h // 2) + + cv2.addWeighted( + imgout[y_offset : y_offset + h, x_offset : x_offset + w], + 1 - op, + img, + op, + 0, + img, + ) + imgout[y_offset : y_offset + h, x_offset : x_offset + w] = img + + return imgout + + @NodeFactory.register("Image (Utility)", "Change Colorspace") class ColorConvertNode(NodeBase): """OpenCV color conversion node""" @@ -445,7 +534,9 @@ def run( for img in im1, im2, im3, im4: if img is not None and type(img) != str: h, w = img.shape[:2] - c = img.shape[2] or 1 + if img.ndim == 2: # len(img.shape) needs to be 3, grayscale len only 2 + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + c = img.shape[2] max_h = max(h, max_h) max_w = max(w, max_w) max_c = max(c, max_c) @@ -551,7 +642,7 @@ class BlurNode(NodeBase): def __init__(self): """Constructor""" super().__init__() - self.description = "Blur an image" + self.description = "Apply blur to an image" self.inputs = [ ImageInput(), IntegerInput("Amount X"), @@ -564,20 +655,77 @@ def __init__(self): def run( self, img: np.ndarray, - amountX: int, - amountY: int, + amount_x: int, + amount_y: int, # sigma: int, ) -> np.ndarray: """Adjusts the blur of an image""" - # ksize=(math.floor(int(amountX)/2)*2+1,math.floor(int(amountY)/2)*2+1) - # img=cv2.GaussianBlur(img,ksize,int(sigma)) - ksize = (int(amountX), int(amountY)) + ksize = (int(amount_x), int(amount_y)) for __i in range(16): img = cv2.blur(img, ksize) return img +@NodeFactory.register("Image (Effect)", "Gaussian Blur") +class GaussianBlurNode(NodeBase): + """OpenCV Gaussian Blur Node""" + + def __init__(self): + """Constructor""" + super().__init__() + self.description = "Apply Gaussian Blur to an image" + self.inputs = [ + ImageInput(), + IntegerInput("Amount X"), + IntegerInput("Amount Y"), + ] + self.outputs = [ImageOutput()] + self.icon = "MdBlurOn" + self.sub = "Adjustment" + + def run( + self, + img: np.ndarray, + amount_x: str, + amount_y: str, + ) -> np.ndarray: + """Adjusts the sharpening of an image""" + blurred = cv2.GaussianBlur( + img, (0, 0), sigmaX=float(amount_x), sigmaY=float(amount_y) + ) + + return blurred + + +@NodeFactory.register("Image (Effect)", "Sharpen") +class SharpenNode(NodeBase): + """OpenCV Sharpen Node""" + + def __init__(self): + """Constructor""" + super().__init__() + self.description = "Apply sharpening to an image" + self.inputs = [ + ImageInput(), + IntegerInput("Amount"), + ] + self.outputs = [ImageOutput()] + self.icon = "MdBlurOff" + self.sub = "Adjustment" + + def run( + self, + img: np.ndarray, + amount: int, + ) -> np.ndarray: + """Adjusts the sharpening of an image""" + blurred = cv2.GaussianBlur(img, (0, 0), float(amount)) + img = cv2.addWeighted(img, 2.0, blurred, -1.0, 0) + + return img + + @NodeFactory.register("Image (Effect)", "Shift") class ShiftNode(NodeBase): """OpenCV Shift Node""" @@ -598,24 +746,24 @@ def __init__(self): def run( self, img: np.ndarray, - amountX: int, - amountY: int, + amount_x: int, + amount_y: int, ) -> np.ndarray: """Adjusts the position of an image""" num_rows, num_cols = img.shape[:2] - translation_matrix = np.float32([[1, 0, amountX], [0, 1, amountY]]) + translation_matrix = np.float32([[1, 0, amount_x], [0, 1, amount_y]]) img = cv2.warpAffine(img, translation_matrix, (num_cols, num_rows)) return img @NodeFactory.register("Image (Utility)", "Split Channels") class ChannelSplitRGBANode(NodeBase): - """NumPy Splitter node""" + """NumPy Splitter node""" def __init__(self): """Constructor""" super().__init__() - self.description = "Split numpy image channels into separate channels. Typically used for splitting off an alpha (transparency) layer." + self.description = "Split image channels into separate channels. Typically used for splitting off an alpha (transparency) layer." self.inputs = [ImageInput()] self.outputs = [ ImageOutput("Blue Channel"), @@ -625,7 +773,7 @@ def __init__(self): ] self.icon = "MdCallSplit" - self.sub = "Miscellaneous" + self.sub = "Splitting & Merging" def run(self, img: np.ndarray) -> np.ndarray: """Split a multi-channel image into separate channels""" @@ -651,6 +799,41 @@ def run(self, img: np.ndarray) -> np.ndarray: return out +@NodeFactory.register("Image (Utility)", "Split Transparency") +class TransparencySplitNode(NodeBase): + """Transparency-specific Splitter node""" + + def __init__(self): + """Constructor""" + super().__init__() + self.description = ( + "Split image channels into RGB and Alpha (transparency) channels." + ) + self.inputs = [ImageInput()] + self.outputs = [ + ImageOutput("RGB Channels"), + ImageOutput("Alpha Channel"), + ] + + self.icon = "MdCallSplit" + self.sub = "Splitting & Merging" + + def run(self, img: np.ndarray) -> np.ndarray: + """Split a multi-channel image into separate channels""" + if img.ndim == 2: + logger.debug("Expanding image channels") + img = np.tile(np.expand_dims(img, axis=2), (1, 1, min(4, 3))) + # Pad with solid alpha channel if needed (i.e three channel image) + elif img.shape[2] == 3: + logger.debug("Expanding image channels") + img = np.dstack((img, np.full(img.shape[:-1], 1.0))) + + rgb = img[:, :, :3] + alpha = img[:, :, 3] + + return rgb, alpha + + @NodeFactory.register("Image (Utility)", "Merge Channels") class ChannelMergeRGBANode(NodeBase): """NumPy Merger node""" @@ -658,7 +841,7 @@ class ChannelMergeRGBANode(NodeBase): def __init__(self): """Constructor""" super().__init__() - self.description = "Merge numpy channels together into a <= 4 channel image. Typically used for combining an image with an alpha layer." + self.description = "Merge image channels together into a <= 4 channel image. Typically used for combining an image with an alpha layer." self.inputs = [ ImageInput("Channel(s) A"), ImageInput("Channel(s) B", optional=True), @@ -668,7 +851,7 @@ def __init__(self): self.outputs = [ImageOutput()] self.icon = "MdCallMerge" - self.sub = "Miscellaneous" + self.sub = "Splitting & Merging" def run( self, @@ -710,6 +893,54 @@ def run( return img +@NodeFactory.register("Image (Utility)", "Merge Transparency") +class TransparencyMergeNode(NodeBase): + """Transparency-specific Merge node""" + + def __init__(self): + """Constructor""" + super().__init__() + self.description = "Merge RGB and Alpha (transparency) image channels into 4-channel RGBA channels." + self.inputs = [ImageInput("RGB Channels"), ImageInput("Alpha Channel")] + self.outputs = [ImageOutput()] + + self.icon = "MdCallMerge" + self.sub = "Splitting & Merging" + + def run(self, rgb: np.ndarray, a: np.ndarray) -> np.ndarray: + """Combine separate channels into a multi-chanel image""" + + start_shape = rgb.shape[:2] + logger.info(start_shape) + + for im in rgb, a: + if im is not None: + logger.info(im.shape[:2]) + assert ( + im.shape[:2] == start_shape + ), "All images to be merged must be the same resolution" + + if rgb.ndim == 2: + rgb = cv2.merge((rgb, rgb, rgb)) + elif rgb.ndim > 2 and rgb.shape[2] == 2: + rgb = cv2.merge( + (rgb, np.zeros((rgb.shape[0], rgb.shape[1], 1), dtype=rgb.dtype)) + ) + elif rgb.shape[2] > 3: + rgb = rgb[:, :, :3] + + if a.ndim > 2: + a = a[:, :, 0] + a = np.expand_dims(a, axis=2) + + imgs = [rgb, a] + for img in imgs: + logger.info(img.shape) + img = np.concatenate(imgs, axis=2) + + return img + + @NodeFactory.register("Image (Utility)", "Crop (Offsets)") class CropNode(NodeBase): """NumPy Crop node""" @@ -819,3 +1050,53 @@ def run( result = img[top : h - bottom, left : w - right] return result + + +@NodeFactory.register("Image (Utility)", "Add Caption") +class CaptionNode(NodeBase): + """Caption node""" + + def __init__(self): + """Constructor""" + super().__init__() + self.description = "Add a caption to an image." + self.inputs = [ + ImageInput(), + TextInput("Caption"), + ] + self.outputs = [ImageOutput()] + + self.icon = "MdVideoLabel" + self.sub = "Miscellaneous" + + def run(self, img: np.ndarray, caption: str) -> np.ndarray: + """Add caption an image""" + + font = cv2.FONT_HERSHEY_SIMPLEX + font_size = 1 + font_thickness = 1 + + textsize = cv2.getTextSize(caption, font, font_size, font_thickness) + logger.info(textsize) + textsize = textsize[0] + + caption_height = textsize[1] + 20 + + img = cv2.copyMakeBorder( + img, 0, caption_height, 0, 0, cv2.BORDER_CONSTANT, value=(0, 0, 0, 255) + ) + + text_x = math.floor((img.shape[1] - textsize[0]) / 2) + text_y = math.ceil(img.shape[0] - ((caption_height - textsize[1]) / 2)) + + cv2.putText( + img, + caption, + (text_x, text_y), + font, + font_size, + color=(255, 255, 255, 255), + thickness=font_thickness, + lineType=cv2.LINE_AA, + ) + return img diff --git a/backend/nodes/ncnn_nodes.py b/backend/nodes/ncnn_nodes.py index 56bcb29d9..4c02b7f1e 100644 --- a/backend/nodes/ncnn_nodes.py +++ b/backend/nodes/ncnn_nodes.py @@ -95,44 +95,94 @@ def __init__(self): self.icon = "NCNN" self.sub = "NCNN" - def run(self, net_tuple: tuple, img: np.ndarray) -> np.ndarray: + def upscale(self, img: np.ndarray, net: tuple, input_name: str, output_name: str): dtype_max = 1 try: dtype_max = np.iinfo(img.dtype).max except: - logger.info("img dtype is not an int") + logger.debug("img dtype is not an int") img = (img.astype("float32") / dtype_max * 255).astype( np.uint8 ) # don't ask lol - # ncnn only supports 3 apparently - in_nc = 3 - gray = False - if img.ndim == 2: - gray = True - logger.warn("Expanding image channels") - img = np.tile(np.expand_dims(img, axis=2), (1, 1, min(in_nc, 3))) - # Remove extra channels if too many (i.e three channel image, single channel model) - elif img.shape[2] > in_nc: - logger.warn("Truncating image channels") - img = img[:, :, :in_nc] - # Pad with solid alpha channel if needed (i.e three channel image, four channel model) - elif img.shape[2] == 3 and in_nc == 4: - logger.warn("Expanding image channels") - img = np.dstack((img, np.full(img.shape[:-1], 1.0))) - - param_path, bin_path, input_name, output_name, net = net_tuple - # Try/except block to catch errors try: + vkdev = ncnn.get_gpu_device(0) + blob_vkallocator = ncnn.VkBlobAllocator(vkdev) + staging_vkallocator = ncnn.VkStagingAllocator(vkdev) output, _ = ncnn_auto_split_process( - img, net, input_name=input_name, output_name=output_name + img, + net, + input_name=input_name, + output_name=output_name, + blob_vkallocator=blob_vkallocator, + staging_vkallocator=staging_vkallocator, ) + # blob_vkallocator.clear() # this slows stuff down + # staging_vkallocator.clear() # as does this # net.clear() # don't do this, it makes chaining break - if gray: - output = np.average(output, axis=2) - return np.clip(output.astype(np.float32) / 255, 0, 1) + return output except Exception as e: logger.error(e) raise RuntimeError("An unexpected error occurred during NCNN processing.") + + def run(self, net_tuple: tuple, img: np.ndarray) -> np.ndarray: + + h, w = img.shape[:2] + c = img.shape[2] if len(img.shape) > 2 else 1 + + param_path, bin_path, input_name, output_name, net = net_tuple + + # ncnn only supports 3 apparently + in_nc = 3 + + # TODO: This can prob just be a shared function tbh + # Transparency hack (white/black background difference alpha) + if in_nc == 3 and c == 4: + # Ignore single-color alpha + unique = np.unique(img[:, :, 3]) + if len(unique) == 1: + logger.info("Single color alpha channel, ignoring.") + output = self.upscale(img[:, :, :3], net, input_name, output_name) + output = np.dstack( + (output, np.full(output.shape[:-1], (unique[0] * 255))) + ) + output = np.clip(output.astype(np.float32) / 255, 0, 1) + else: + img1 = np.copy(img[:, :, :3]) + img2 = np.copy(img[:, :, :3]) + for c in range(3): + img1[:, :, c] *= img[:, :, 3] + img2[:, :, c] = (img2[:, :, c] - 1) * img[:, :, 3] + 1 + + output1 = self.upscale(img1, net, input_name, output_name) + output2 = self.upscale(img2, net, input_name, output_name) + output1 = np.clip(output1.astype(np.float32) / 255, 0, 1) + output2 = np.clip(output2.astype(np.float32) / 255, 0, 1) + alpha = 1 - np.mean(output2 - output1, axis=2) + output = np.dstack((output1, alpha)) + else: + gray = False + if img.ndim == 2: + gray = True + logger.debug("Expanding image channels") + img = np.tile(np.expand_dims(img, axis=2), (1, 1, min(in_nc, 3))) + # Remove extra channels if too many (i.e three channel image, single channel model) + elif img.shape[2] > in_nc: + logger.warn("Truncating image channels") + img = img[:, :, :in_nc] + # Pad with solid alpha channel if needed (i.e three channel image, four channel model) + elif img.shape[2] == 3 and in_nc == 4: + logger.debug("Expanding image channels") + img = np.dstack((img, np.full(img.shape[:-1], 1.0))) + output = self.upscale(img, net, input_name, output_name) + + if gray: + output = np.average(output, axis=2) + + output = output.astype(np.float32) / 255 + + output = np.clip(output, 0, 1) + + return output diff --git a/backend/nodes/properties/inputs/generic_inputs.py b/backend/nodes/properties/inputs/generic_inputs.py index e1b2d549f..4cae4079c 100644 --- a/backend/nodes/properties/inputs/generic_inputs.py +++ b/backend/nodes/properties/inputs/generic_inputs.py @@ -1,21 +1,26 @@ from typing import Dict, List -def DropDownInput(input_type: str, label: str, options: List[str]) -> Dict: +def DropDownInput( + input_type: str, label: str, options: List[str], optional: bool = False +) -> Dict: """Input for a dropdown""" return { "type": f"dropdown::{input_type}", "label": label, "options": options, + "optional": optional, } -def TextInput(label: str) -> Dict: +def TextInput(label: str, has_handle=True, max_length=None, optional=False) -> Dict: """Input for arbitrary text""" return { "type": "text::any", "label": label, - "hasHandle": True, + "hasHandle": has_handle, + "maxLength": max_length, + "optional": optional, } @@ -54,6 +59,25 @@ def OddIntegerInput(label: str) -> Dict: } +def BoundedIntegerInput( + label: str, + minimum: int = 0, + maximum: int = 100, + default: int = 50, + optional: bool = False, +) -> Dict: + """Bounded input for integer number""" + return { + "type": "number::integer", + "label": label, + "min": minimum, + "max": maximum, + "def": default, + "hasHandle": True, + "optional": optional, + } + + def BoundlessIntegerInput(label: str) -> Dict: """Input for integer number""" return { @@ -66,7 +90,9 @@ def BoundlessIntegerInput(label: str) -> Dict: } -def SliderInput(label: str, min: int, max: int, default: int) -> Dict: +def SliderInput( + label: str, min: int, max: int, default: int, optional: bool = False +) -> Dict: """Input for integer number via slider""" return { "type": "number::slider", @@ -74,6 +100,7 @@ def SliderInput(label: str, min: int, max: int, default: int) -> Dict: "min": min, "max": max, "def": default, + "optional": optional, } @@ -133,4 +160,5 @@ def StackOrientationDropdown() -> Dict: "value": "vertical", }, ], + optional=True, ) diff --git a/backend/nodes/properties/outputs/generic_outputs.py b/backend/nodes/properties/outputs/generic_outputs.py index 24bf7582b..fe74826d8 100644 --- a/backend/nodes/properties/outputs/generic_outputs.py +++ b/backend/nodes/properties/outputs/generic_outputs.py @@ -15,3 +15,12 @@ def IntegerOutput(label: str) -> Dict: "type": "number::integer", "label": label, } + + +def TextOutput(label: str) -> Dict: + """Output for arbitrary text""" + return { + "type": "text::any", + "label": label, + "hasHandle": True, + } diff --git a/backend/nodes/pytorch_nodes.py b/backend/nodes/pytorch_nodes.py index 852e364e6..ebca40412 100644 --- a/backend/nodes/pytorch_nodes.py +++ b/backend/nodes/pytorch_nodes.py @@ -63,7 +63,7 @@ def __init__(self): super().__init__() self.description = "Load PyTorch state dict file (.pth) into an auto-detected supported model architecture. Supports most variations of the RRDB architecture (ESRGAN, Real-ESRGAN, RealSR, BSRGAN, SPSR) and Real-ESRGAN's SRVGG architecture." self.inputs = [PthFileInput()] - self.outputs = [ModelOutput()] + self.outputs = [ModelOutput(), TextOutput("Model Name")] self.icon = "PyTorch" self.sub = "Input & Output" @@ -82,7 +82,9 @@ def run(self, path: str) -> Any: model.eval() model = model.to(torch.device(os.environ["device"])) - return model + basename = os.path.splitext(os.path.basename(path))[0] + + return model, basename @NodeFactory.register("PyTorch", "Upscale Image") @@ -100,9 +102,32 @@ def __init__(self): self.icon = "PyTorch" self.sub = "Processing" + def upscale(self, img: np.ndarray, model: torch.nn.Module, scale: int): + # Borrowed from iNNfer + logger.info("Converting image to tensor") + img_tensor = np2tensor(img, change_range=True) + if os.environ["isFp16"] == "True": + model = model.half() + logger.info("Upscaling image") + t_out, _ = auto_split_process( + img_tensor, + model, + scale, + ) + del img_tensor, model + logger.info("Converting tensor to image") + img_out = tensor2np(t_out.detach(), change_range=False, imtype=np.float32) + logger.info("Done upscaling") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + del t_out + return img_out + def run(self, model: torch.nn.Module, img: np.ndarray) -> np.ndarray: """Upscales an image with a pretrained model""" + torch.load + check_env() logger.info(f"Upscaling image...") @@ -128,46 +153,49 @@ def run(self, model: torch.nn.Module, img: np.ndarray) -> np.ndarray: # The frontend should type-validate this enough where it shouldn't be needed, # But I want to be extra safe - # # Add extra channels if not enough (i.e single channel img, three channel model) - gray = False - if img.ndim == 2: - gray = True - logger.warn("Expanding image channels") - img = np.tile(np.expand_dims(img, axis=2), (1, 1, min(in_nc, 3))) - # Remove extra channels if too many (i.e three channel image, single channel model) - elif img.shape[2] > in_nc: - logger.warn("Truncating image channels") - img = img[:, :, :in_nc] - # Pad with solid alpha channel if needed (i.e three channel image, four channel model) - elif img.shape[2] == 3 and in_nc == 4: - logger.warn("Expanding image channels") - img = np.dstack((img, np.full(img.shape[:-1], 1.0))) + # Transparency hack (white/black background difference alpha) + if in_nc == 3 and c == 4: + # Ignore single-color alpha + unique = np.unique(img[:, :, 3]) + if len(unique) == 1: + logger.info("Single color alpha channel, ignoring.") + output = self.upscale(img[:, :, :3], model, model.scale) + output = np.dstack((output, np.full(output.shape[:-1], unique[0]))) + else: + img1 = np.copy(img[:, :, :3]) + img2 = np.copy(img[:, :, :3]) + for c in range(3): + img1[:, :, c] *= img[:, :, 3] + img2[:, :, c] = (img2[:, :, c] - 1) * img[:, :, 3] + 1 + + output1 = self.upscale(img1, model, model.scale) + output2 = self.upscale(img2, model, model.scale) + alpha = 1 - np.mean(output2 - output1, axis=2) + output = np.dstack((output1, alpha)) + else: + # # Add extra channels if not enough (i.e single channel img, three channel model) + gray = False + if img.ndim == 2: + gray = True + logger.debug("Expanding image channels") + img = np.tile(np.expand_dims(img, axis=2), (1, 1, min(in_nc, 3))) + # Remove extra channels if too many (i.e three channel image, single channel model) + elif img.shape[2] > in_nc: + logger.warn("Truncating image channels") + img = img[:, :, :in_nc] + # Pad with solid alpha channel if needed (i.e three channel image, four channel model) + elif img.shape[2] == 3 and in_nc == 4: + logger.debug("Expanding image channels") + img = np.dstack((img, np.full(img.shape[:-1], 1.0))) - # Borrowed from iNNfer - logger.info("Converting image to tensor") - img_tensor = np2tensor(img, change_range=True) - if os.environ["isFp16"] == "True": - model = model.half() - logger.info("Upscaling image") - t_out, _ = auto_split_process( - img_tensor, - model, - scale, - ) - del img_tensor, model - logger.info("Converting tensor to image") - img_out = tensor2np(t_out.detach(), change_range=False, imtype=np.float32) - logger.info("Done upscaling") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - del t_out + output = self.upscale(img, model, model.scale) - if gray: - img_out = np.average(img_out, axis=2).astype("float32") + if gray: + output = np.average(output, axis=2).astype("float32") - img_out = np.clip(img_out, 0, 1) + output = np.clip(output, 0, 1) - return img_out + return output @NodeFactory.register("PyTorch", "Interpolate Models") @@ -247,17 +275,17 @@ def __init__(self): """Constructor""" super().__init__() self.description = "Save a PyTorch model to specified directory." - self.inputs = [StateDictInput(), DirectoryInput(), TextInput("Model Name")] + self.inputs = [ModelInput(), DirectoryInput(), TextInput("Model Name")] self.outputs = [] self.icon = "PyTorch" self.sub = "Input & Output" - def run(self, model: OrderedDict(), directory: str, name: str) -> bool: + def run(self, model: torch.nn.Module, directory: str, name: str) -> bool: fullFile = f"{name}.pth" fullPath = os.path.join(directory, fullFile) logger.info(f"Writing model to path: {fullPath}") - status = torch.save(model, fullPath) + status = torch.save(model.state, fullPath) return status diff --git a/backend/nodes/utility_nodes.py b/backend/nodes/utility_nodes.py index d735f2c17..78179d1bc 100644 --- a/backend/nodes/utility_nodes.py +++ b/backend/nodes/utility_nodes.py @@ -57,3 +57,34 @@ def run(self, text: str) -> None: # return in1 / in2 # elif op == "pow": # return in1 ** in2 + + +@NodeFactory.register("Utility", "Text Append") +class TextAppendNode(NodeBase): + """Text Append node""" + + def __init__(self): + """Constructor""" + super().__init__() + self.description = "Perform mathematical operations on numbers." + self.inputs = [ + TextInput("Separator", has_handle=False, max_length=3), + TextInput("Text A"), + TextInput("Text B", optional=True), + TextInput("Text C", optional=True), + TextInput("Text D", optional=True), + ] + self.outputs = [TextOutput("Output Text")] + self.icon = "MdTextFields" + self.sub = "Text" + + def run( + self, + separator: str, + str1: str, + str2: str = None, + str3: str = None, + str4: str = None, + ) -> int: + strings = [x for x in [str1, str2, str3, str4] if x != "" and x is not None] + return separator.join(strings) diff --git a/backend/nodes/utils/ncnn_auto_split.py b/backend/nodes/utils/ncnn_auto_split.py index a3563c546..f9b7a3126 100644 --- a/backend/nodes/utils/ncnn_auto_split.py +++ b/backend/nodes/utils/ncnn_auto_split.py @@ -12,7 +12,7 @@ def fix_dtype_range(img): try: dtype_max = np.iinfo(img.dtype).max except: - logger.info("img dtype is not an int") + logger.debug("img dtype is not an int") img = (np.clip(img.astype("float32") / dtype_max, 0, 1) * 255).astype(np.uint8) return img @@ -27,6 +27,8 @@ def ncnn_auto_split_process( current_depth: int = 1, input_name: str = "data", output_name: str = "output", + blob_vkallocator=None, + staging_vkallocator=None, ) -> Tuple[Tensor, int]: # Original code: https://github.com/JoeyBallentine/ESRGAN/blob/master/utils/dataops.py @@ -38,9 +40,6 @@ def ncnn_auto_split_process( # Attempt to upscale if unknown depth or if reached known max depth if max_depth is None or max_depth == current_depth: ex = net.create_extractor() - vkdev = ncnn.get_gpu_device(0) - blob_vkallocator = ncnn.VkBlobAllocator(vkdev) - staging_vkallocator = ncnn.VkStagingAllocator(vkdev) ex.set_blob_vkallocator(blob_vkallocator) ex.set_workspace_vkallocator(blob_vkallocator) ex.set_staging_vkallocator(staging_vkallocator) @@ -59,9 +58,9 @@ def ncnn_auto_split_process( _, mat_out = ex.extract(output_name) result = fix_dtype_range(np.array(mat_out).transpose(1, 2, 0)) del ex, mat_in, mat_out - # Clear VRAM - blob_vkallocator.clear() - staging_vkallocator.clear() + # # Clear VRAM + # blob_vkallocator.clear() + # staging_vkallocator.clear() return result, current_depth except Exception as e: # Check to see if its actually the NCNN out of memory error @@ -69,7 +68,7 @@ def ncnn_auto_split_process( # clear VRAM blob_vkallocator.clear() staging_vkallocator.clear() - del ex, vkdev + del ex gc.collect() # Re-raise the exception if not an OOM error else: @@ -94,6 +93,8 @@ def ncnn_auto_split_process( net, overlap=overlap, current_depth=current_depth + 1, + blob_vkallocator=blob_vkallocator, + staging_vkallocator=staging_vkallocator, ) top_right_rlt, _ = ncnn_auto_split_process( top_right, @@ -101,6 +102,8 @@ def ncnn_auto_split_process( overlap=overlap, max_depth=depth, current_depth=current_depth + 1, + blob_vkallocator=blob_vkallocator, + staging_vkallocator=staging_vkallocator, ) bottom_left_rlt, _ = ncnn_auto_split_process( bottom_left, @@ -108,6 +111,8 @@ def ncnn_auto_split_process( overlap=overlap, max_depth=depth, current_depth=current_depth + 1, + blob_vkallocator=blob_vkallocator, + staging_vkallocator=staging_vkallocator, ) bottom_right_rlt, _ = ncnn_auto_split_process( bottom_right, @@ -115,6 +120,8 @@ def ncnn_auto_split_process( overlap=overlap, max_depth=depth, current_depth=current_depth + 1, + blob_vkallocator=blob_vkallocator, + staging_vkallocator=staging_vkallocator, ) tl_h, _ = top_left.shape[:2] diff --git a/package-lock.json b/package-lock.json index 0127dff37..61deebfe5 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "chainner", - "version": "0.3.1", + "version": "0.4.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "chainner", - "version": "0.3.1", + "version": "0.4.0", "license": "GPLv3", "dependencies": { "@chakra-ui/icons": "^1.0.15", @@ -48,9 +48,11 @@ "@electron-forge/maker-zip": "^6.0.0-beta.60", "@electron-forge/plugin-webpack": "^6.0.0-beta.60", "@electron-forge/publisher-github": "*", + "@pmmmwh/react-refresh-webpack-plugin": "^0.5.4", "@vercel/webpack-asset-relocator-loader": "^1.7.0", "babel-loader": "^8.2.2", "concurrently": "^6.5.1", + "cross-env": "^7.0.3", "css-loader": "^6.2.0", "electron": "^17.0.0", "electron-installer-common": "^0.10.3", @@ -59,6 +61,7 @@ "file-loader": "^6.2.0", "image-webpack-loader": "^8.1.0", "node-loader": "^2.0.0", + "react-refresh": "^0.11.0", "semver-regex": ">=3.1.3", "style-loader": "^3.2.1" } @@ -3207,6 +3210,97 @@ "@octokit/openapi-types": "^11.2.0" } }, + "node_modules/@pmmmwh/react-refresh-webpack-plugin": { + "version": "0.5.4", + "resolved": "https://registry.npmjs.org/@pmmmwh/react-refresh-webpack-plugin/-/react-refresh-webpack-plugin-0.5.4.tgz", + "integrity": "sha512-zZbZeHQDnoTlt2AF+diQT0wsSXpvWiaIOZwBRdltNFhG1+I3ozyaw7U/nBiUwyJ0D+zwdXp0E3bWOl38Ag2BMw==", + "dev": true, + "dependencies": { + "ansi-html-community": "^0.0.8", + "common-path-prefix": "^3.0.0", + "core-js-pure": "^3.8.1", + "error-stack-parser": "^2.0.6", + "find-up": "^5.0.0", + "html-entities": "^2.1.0", + "loader-utils": "^2.0.0", + "schema-utils": "^3.0.0", + "source-map": "^0.7.3" + }, + "engines": { + "node": ">= 10.13" + }, + "peerDependencies": { + "@types/webpack": "4.x || 5.x", + "react-refresh": ">=0.10.0 <1.0.0", + "sockjs-client": "^1.4.0", + "type-fest": ">=0.17.0 <3.0.0", + "webpack": ">=4.43.0 <6.0.0", + "webpack-dev-server": "3.x || 4.x", + "webpack-hot-middleware": "2.x", + "webpack-plugin-serve": "0.x || 1.x" + }, + "peerDependenciesMeta": { + "@types/webpack": { + "optional": true + }, + "sockjs-client": { + "optional": true + }, + "type-fest": { + "optional": true + }, + "webpack-dev-server": { + "optional": true + }, + "webpack-hot-middleware": { + "optional": true + }, + "webpack-plugin-serve": { + "optional": true + } + } + }, + "node_modules/@pmmmwh/react-refresh-webpack-plugin/node_modules/loader-utils": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/loader-utils/-/loader-utils-2.0.2.tgz", + "integrity": "sha512-TM57VeHptv569d/GKh6TAYdzKblwDNiumOdkFnejjD0XwTH87K90w3O7AiJRqdQoXygvi1VQTJTLGhJl7WqA7A==", + "dev": true, + "dependencies": { + "big.js": "^5.2.2", + "emojis-list": "^3.0.0", + "json5": "^2.1.2" + }, + "engines": { + "node": ">=8.9.0" + } + }, + "node_modules/@pmmmwh/react-refresh-webpack-plugin/node_modules/schema-utils": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.1.1.tgz", + "integrity": "sha512-Y5PQxS4ITlC+EahLuXaY86TXfR7Dc5lw294alXOq86JAHCihAIZfqv8nNCWvaEJvaC51uN9hbLGeV0cFBdH+Fw==", + "dev": true, + "dependencies": { + "@types/json-schema": "^7.0.8", + "ajv": "^6.12.5", + "ajv-keywords": "^3.5.2" + }, + "engines": { + "node": ">= 10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + } + }, + "node_modules/@pmmmwh/react-refresh-webpack-plugin/node_modules/source-map": { + "version": "0.7.3", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.7.3.tgz", + "integrity": "sha512-CkCj6giN3S+n9qrYiBTX5gystlENnRW5jZeNLHpe6aue+SrHcG5VYwujhW9s4dY31mEGsxBDrHR6oI69fTXsaQ==", + "dev": true, + "engines": { + "node": ">= 8" + } + }, "node_modules/@popperjs/core": { "version": "2.11.2", "resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.2.tgz", @@ -5682,6 +5776,12 @@ "node": ">= 6" } }, + "node_modules/common-path-prefix": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/common-path-prefix/-/common-path-prefix-3.0.0.tgz", + "integrity": "sha512-QE33hToZseCH3jS0qN96O/bSh3kaw/h+Tq7ngyY9eWDUnTlTNUyqfqvCXioLe5Na5jFsL78ra/wuBU4iuEgd4w==", + "dev": true + }, "node_modules/commondir": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/commondir/-/commondir-1.0.1.tgz", @@ -5994,7 +6094,6 @@ "integrity": "sha512-12VZfFIu+wyVbBebyHmRTuEE/tZrB4tJToWcwAMcsp3h4+sHR+fMJWbKpYiCRWlhFBq+KNyO8rIV9rTkeVmznQ==", "dev": true, "hasInstallScript": true, - "peer": true, "funding": { "type": "opencollective", "url": "https://opencollective.com/core-js" @@ -6020,6 +6119,24 @@ "node": ">=8" } }, + "node_modules/cross-env": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/cross-env/-/cross-env-7.0.3.tgz", + "integrity": "sha512-+/HKd6EgcQCJGh2PSjZuUitQBQynKor4wrFbRg4DtAgS1aWO+gU52xpH7M9ScGgXSYmAVS9bIJ8EzuaGw0oNAw==", + "dev": true, + "dependencies": { + "cross-spawn": "^7.0.1" + }, + "bin": { + "cross-env": "src/bin/cross-env.js", + "cross-env-shell": "src/bin/cross-env-shell.js" + }, + "engines": { + "node": ">=10.14", + "npm": ">=6", + "yarn": ">=1" + } + }, "node_modules/cross-spawn": { "version": "7.0.3", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", @@ -8214,6 +8331,15 @@ "is-arrayish": "^0.2.1" } }, + "node_modules/error-stack-parser": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/error-stack-parser/-/error-stack-parser-2.0.7.tgz", + "integrity": "sha512-chLOW0ZGRf4s8raLrDxa5sdkvPec5YdvwbFnqJme4rk0rFajP8mPtrDL1+I+CwrQDCjswDA5sREX7jYQDQs9vA==", + "dev": true, + "dependencies": { + "stackframe": "^1.1.1" + } + }, "node_modules/es-abstract": { "version": "1.19.1", "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.19.1.tgz", @@ -15277,6 +15403,15 @@ "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==" }, + "node_modules/react-refresh": { + "version": "0.11.0", + "resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.11.0.tgz", + "integrity": "sha512-F27qZr8uUqwhWZboondsPx8tnC3Ct3SxZA3V5WyEvujRyyNv0VYPhoBg1gZ8/MV5tubQp76Trw8lTv9hzRBa+A==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/read-pkg": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/read-pkg/-/read-pkg-2.0.0.tgz", @@ -16527,6 +16662,12 @@ "dev": true, "optional": true }, + "node_modules/stackframe": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/stackframe/-/stackframe-1.2.1.tgz", + "integrity": "sha512-h88QkzREN/hy8eRdyNhhsO7RSJ5oyTqxxmmn0dzBIMUclZsjpfmrsg81vp8mjjAs2vAZ72nyWxRUwSwmh0e4xg==", + "dev": true + }, "node_modules/statuses": { "version": "1.5.0", "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.5.0.tgz", @@ -20897,6 +21038,53 @@ "@octokit/openapi-types": "^11.2.0" } }, + "@pmmmwh/react-refresh-webpack-plugin": { + "version": "0.5.4", + "resolved": "https://registry.npmjs.org/@pmmmwh/react-refresh-webpack-plugin/-/react-refresh-webpack-plugin-0.5.4.tgz", + "integrity": "sha512-zZbZeHQDnoTlt2AF+diQT0wsSXpvWiaIOZwBRdltNFhG1+I3ozyaw7U/nBiUwyJ0D+zwdXp0E3bWOl38Ag2BMw==", + "dev": true, + "requires": { + "ansi-html-community": "^0.0.8", + "common-path-prefix": "^3.0.0", + "core-js-pure": "^3.8.1", + "error-stack-parser": "^2.0.6", + "find-up": "^5.0.0", + "html-entities": "^2.1.0", + "loader-utils": "^2.0.0", + "schema-utils": "^3.0.0", + "source-map": "^0.7.3" + }, + "dependencies": { + "loader-utils": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/loader-utils/-/loader-utils-2.0.2.tgz", + "integrity": "sha512-TM57VeHptv569d/GKh6TAYdzKblwDNiumOdkFnejjD0XwTH87K90w3O7AiJRqdQoXygvi1VQTJTLGhJl7WqA7A==", + "dev": true, + "requires": { + "big.js": "^5.2.2", + "emojis-list": "^3.0.0", + "json5": "^2.1.2" + } + }, + "schema-utils": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.1.1.tgz", + "integrity": "sha512-Y5PQxS4ITlC+EahLuXaY86TXfR7Dc5lw294alXOq86JAHCihAIZfqv8nNCWvaEJvaC51uN9hbLGeV0cFBdH+Fw==", + "dev": true, + "requires": { + "@types/json-schema": "^7.0.8", + "ajv": "^6.12.5", + "ajv-keywords": "^3.5.2" + } + }, + "source-map": { + "version": "0.7.3", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.7.3.tgz", + "integrity": "sha512-CkCj6giN3S+n9qrYiBTX5gystlENnRW5jZeNLHpe6aue+SrHcG5VYwujhW9s4dY31mEGsxBDrHR6oI69fTXsaQ==", + "dev": true + } + } + }, "@popperjs/core": { "version": "2.11.2", "resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.2.tgz", @@ -22959,6 +23147,12 @@ "integrity": "sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==", "dev": true }, + "common-path-prefix": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/common-path-prefix/-/common-path-prefix-3.0.0.tgz", + "integrity": "sha512-QE33hToZseCH3jS0qN96O/bSh3kaw/h+Tq7ngyY9eWDUnTlTNUyqfqvCXioLe5Na5jFsL78ra/wuBU4iuEgd4w==", + "dev": true + }, "commondir": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/commondir/-/commondir-1.0.1.tgz", @@ -23203,8 +23397,7 @@ "version": "3.21.1", "resolved": "https://registry.npmjs.org/core-js-pure/-/core-js-pure-3.21.1.tgz", "integrity": "sha512-12VZfFIu+wyVbBebyHmRTuEE/tZrB4tJToWcwAMcsp3h4+sHR+fMJWbKpYiCRWlhFBq+KNyO8rIV9rTkeVmznQ==", - "dev": true, - "peer": true + "dev": true }, "core-util-is": { "version": "1.0.3", @@ -23223,6 +23416,15 @@ "yaml": "^1.7.2" } }, + "cross-env": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/cross-env/-/cross-env-7.0.3.tgz", + "integrity": "sha512-+/HKd6EgcQCJGh2PSjZuUitQBQynKor4wrFbRg4DtAgS1aWO+gU52xpH7M9ScGgXSYmAVS9bIJ8EzuaGw0oNAw==", + "dev": true, + "requires": { + "cross-spawn": "^7.0.1" + } + }, "cross-spawn": { "version": "7.0.3", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", @@ -24874,6 +25076,15 @@ "is-arrayish": "^0.2.1" } }, + "error-stack-parser": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/error-stack-parser/-/error-stack-parser-2.0.7.tgz", + "integrity": "sha512-chLOW0ZGRf4s8raLrDxa5sdkvPec5YdvwbFnqJme4rk0rFajP8mPtrDL1+I+CwrQDCjswDA5sREX7jYQDQs9vA==", + "dev": true, + "requires": { + "stackframe": "^1.1.1" + } + }, "es-abstract": { "version": "1.19.1", "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.19.1.tgz", @@ -30356,6 +30567,12 @@ "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==" }, + "react-refresh": { + "version": "0.11.0", + "resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.11.0.tgz", + "integrity": "sha512-F27qZr8uUqwhWZboondsPx8tnC3Ct3SxZA3V5WyEvujRyyNv0VYPhoBg1gZ8/MV5tubQp76Trw8lTv9hzRBa+A==", + "dev": true + }, "read-pkg": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/read-pkg/-/read-pkg-2.0.0.tgz", @@ -31372,6 +31589,12 @@ "dev": true, "optional": true }, + "stackframe": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/stackframe/-/stackframe-1.2.1.tgz", + "integrity": "sha512-h88QkzREN/hy8eRdyNhhsO7RSJ5oyTqxxmmn0dzBIMUclZsjpfmrsg81vp8mjjAs2vAZ72nyWxRUwSwmh0e4xg==", + "dev": true + }, "statuses": { "version": "1.5.0", "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.5.0.tgz", diff --git a/package.json b/package.json index 3c0a892fa..b6ec78f10 100644 --- a/package.json +++ b/package.json @@ -1,18 +1,18 @@ { "name": "chainner", "productName": "chaiNNer", - "version": "0.3.1", + "version": "0.4.0", "description": "A flowchart based image processing GUI", "main": ".webpack/main", "scripts": { "start": "electron-forge start", "dev": "concurrently \"nodemon ./backend/run.py 8000\" \"electron-forge start -- --no-backend\"", - "package": "electron-forge package", - "make": "electron-forge make", - "make-linux-zip": "electron-forge make --targets @electron-forge/maker-zip --platform linux", - "make-win-zip": "electron-forge make --targets @electron-forge/maker-zip --platform win32", - "make-mac-zip": "electron-forge make --targets @electron-forge/maker-zip --platform darwin", - "publish": "electron-forge publish", + "package": "cross-env NODE_ENV=production electron-forge package", + "make": "cross-env NODE_ENV=production electron-forge make", + "make-linux-zip": "cross-env NODE_ENV=production electron-forge make --targets @electron-forge/maker-zip --platform linux", + "make-win-zip": "cross-env NODE_ENV=production electron-forge make --targets @electron-forge/maker-zip --platform win32", + "make-mac-zip": "cross-env NODE_ENV=production electron-forge make --targets @electron-forge/maker-zip --platform darwin", + "publish": "cross-env NODE_ENV=production electron-forge publish", "lint": "eslint . --fix" }, "keywords": [], @@ -116,9 +116,11 @@ "@electron-forge/maker-zip": "^6.0.0-beta.60", "@electron-forge/plugin-webpack": "^6.0.0-beta.60", "@electron-forge/publisher-github": "*", + "@pmmmwh/react-refresh-webpack-plugin": "^0.5.4", "@vercel/webpack-asset-relocator-loader": "^1.7.0", "babel-loader": "^8.2.2", "concurrently": "^6.5.1", + "cross-env": "^7.0.3", "css-loader": "^6.2.0", "electron": "^17.0.0", "electron-installer-common": "^0.10.3", @@ -127,6 +129,7 @@ "file-loader": "^6.2.0", "image-webpack-loader": "^8.1.0", "node-loader": "^2.0.0", + "react-refresh": "^0.11.0", "semver-regex": ">=3.1.3", "style-loader": "^3.2.1" }, diff --git a/src/components/Header.jsx b/src/components/Header.jsx index f43469b8a..2ef87e361 100644 --- a/src/components/Header.jsx +++ b/src/components/Header.jsx @@ -213,7 +213,12 @@ const Header = ({ port }) => { - + { - + { {vramUsage && ( - + { const { useInputData } = useContext(GlobalContext); const [input, setInput] = useInputData(id, index); + console.log('🚀 ~ file: SliderInput.jsx ~ line 17 ~ input', input); const [sliderValue, setSliderValue] = useState(input ?? def); const [showTooltip, setShowTooltip] = useState(false); @@ -38,6 +40,7 @@ const SliderInput = memo(({ onMouseEnter={() => setShowTooltip(true)} onMouseLeave={() => setShowTooltip(false)} isDisabled={isLocked} + focusThumbOnChange={false} > @@ -49,6 +52,9 @@ const SliderInput = memo(({ placement="top" isOpen={showTooltip} label={`${sliderValue}%`} + borderRadius={8} + py={1} + px={2} > @@ -58,6 +64,28 @@ const SliderInput = memo(({ > {max} + { + setInput(Math.min(Math.max(v, min), max)); + }} + draggable={false} + className="nodrag" + disabled={isLocked} + step={1} + size="xs" + > + + + + + + ); }); diff --git a/src/components/inputs/TextAreaInput.jsx b/src/components/inputs/TextAreaInput.jsx index 6ce75170f..400e4b588 100644 --- a/src/components/inputs/TextAreaInput.jsx +++ b/src/components/inputs/TextAreaInput.jsx @@ -11,7 +11,9 @@ const TextAreaInput = memo(({ const [input, setInput] = useInputData(id, index); useEffect(() => { - setInput(''); + if (!input) { + setInput(''); + } }, []); const handleChange = (event) => { diff --git a/src/components/inputs/TextInput.jsx b/src/components/inputs/TextInput.jsx index ff7182e2c..4bbed649d 100644 --- a/src/components/inputs/TextInput.jsx +++ b/src/components/inputs/TextInput.jsx @@ -1,32 +1,46 @@ /* eslint-disable import/extensions */ /* eslint-disable react/prop-types */ import { Input } from '@chakra-ui/react'; -import React, { memo, useContext, useEffect } from 'react'; +import React, { + memo, useContext, useEffect, useState, +} from 'react'; +import { useDebouncedCallback } from 'use-debounce'; import { GlobalContext } from '../../helpers/GlobalNodeState.jsx'; const TextInput = memo(({ - label, id, index, isLocked, + label, id, index, isLocked, maxLength, }) => { - const { useInputData } = useContext(GlobalContext); + const { useInputData, useNodeLock } = useContext(GlobalContext); const [input, setInput] = useInputData(id, index); + const [tempText, setTempText] = useState(''); + const [, , isInputLocked] = useNodeLock(id, index); useEffect(() => { - setInput(''); + if (!input) { + setInput(''); + } else { + setTempText(input); + } }, []); - const handleChange = (event) => { - const text = event.target.value; + const handleChange = useDebouncedCallback((event) => { + let text = event.target.value; + text = maxLength ? text.slice(0, maxLength) : text; setInput(text); - }; + }, 500); return ( { + setTempText(event.target.value); + handleChange(event); + }} draggable={false} className="nodrag" - disabled={isLocked} + disabled={isLocked || isInputLocked} + maxLength={maxLength} /> ); }); diff --git a/src/components/inputs/previews/ImagePreview.jsx b/src/components/inputs/previews/ImagePreview.jsx index 646aff644..fc8d6dc37 100644 --- a/src/components/inputs/previews/ImagePreview.jsx +++ b/src/components/inputs/previews/ImagePreview.jsx @@ -3,9 +3,15 @@ import { Center, HStack, Image, Spinner, Tag, VStack, } from '@chakra-ui/react'; +import { constants } from 'fs'; +import { access } from 'fs/promises'; import { Image as ImageJS } from 'image-js'; import React, { memo, useEffect, useState } from 'react'; +const checkFileExists = (file) => new Promise((resolve) => access(file, constants.F_OK) + .then(() => resolve(true)) + .catch(() => resolve(false))); + const getColorMode = (img) => { if (!img) { return '?'; @@ -32,8 +38,10 @@ export default memo(({ (async () => { if (path) { setIsLoading(true); - const loadedImg = await ImageJS.load(path); - setImg(loadedImg); + if (await checkFileExists(path)) { + const loadedImg = await ImageJS.load(path); + setImg(loadedImg); + } setIsLoading(false); } })(); diff --git a/src/components/node/NodeFooter.jsx b/src/components/node/NodeFooter.jsx index a530d79a6..921d317a4 100644 --- a/src/components/node/NodeFooter.jsx +++ b/src/components/node/NodeFooter.jsx @@ -15,7 +15,7 @@ import { MdMoreHoriz } from 'react-icons/md'; import { GlobalContext } from '../../helpers/GlobalNodeState.jsx'; const NodeFooter = ({ - id, validity, isLocked, toggleLock, + id, validity, isLocked, toggleLock, accentColor, }) => { const { removeNodeById, duplicateNode, clearNode, @@ -23,10 +23,16 @@ const NodeFooter = ({ const [isValid, invalidReason] = validity; + const iconShade = useColorModeValue('gray.400', 'gray.800'); + const validShade = useColorModeValue('gray.900', 'gray.100'); + // const invalidShade = useColorModeValue('red.200', 'red.900'); + const invalidShade = useColorModeValue('red.400', 'red.600'); + // const iconShade = useColorModeValue('gray.400', 'gray.800'); + return (
- toggleLock()} cursor="pointer" /> + toggleLock()} cursor="pointer" />
-
- +
+
+ +
@@ -45,7 +68,7 @@ const NodeFooter = ({
- +
diff --git a/src/helpers/GlobalNodeState.jsx b/src/helpers/GlobalNodeState.jsx index c270743c2..67e4d8383 100644 --- a/src/helpers/GlobalNodeState.jsx +++ b/src/helpers/GlobalNodeState.jsx @@ -2,11 +2,11 @@ /* eslint-disable react/prop-types */ import { ipcRenderer } from 'electron'; import React, { - createContext, useCallback, useEffect, useMemo, useState + createContext, useCallback, useEffect, useMemo, useState, } from 'react'; import { getOutgoers, - isEdge, isNode, removeElements as rfRemoveElements, useZoomPanHelper + isEdge, isNode, removeElements as rfRemoveElements, useZoomPanHelper, } from 'react-flow-renderer'; import { useHotkeys } from 'react-hotkeys-hook'; import { v4 as uuidv4 } from 'uuid'; @@ -60,7 +60,8 @@ export const GlobalProvider = ({ (element) => isNode(element), ); const validNodes = justNodes.filter( - (node) => availableNodes[node.data.category][node.data.type], + (node) => availableNodes[node.data.category] + && availableNodes[node.data.category][node.data.type], ) || []; if (justNodes.length !== validNodes.length) { await ipcRenderer.invoke( @@ -488,7 +489,7 @@ export const GlobalProvider = ({ }; let isInputLocked = false; - if (index) { + if (index !== undefined && index !== null) { const edge = edges.find((e) => e.target === id && String(e.targetHandle.split('-').slice(-1)) === String(index)); isInputLocked = !!edge; } diff --git a/src/helpers/checkNodeValidity.js b/src/helpers/checkNodeValidity.js index 96597b418..2b19ed9be 100644 --- a/src/helpers/checkNodeValidity.js +++ b/src/helpers/checkNodeValidity.js @@ -1,3 +1,5 @@ +/* eslint-disable max-len */ +// TODO: This file is a monstrosity, I need to make it so inputs are done by name and not by index const checkNodeValidity = ({ id, inputs, inputData, edges, }) => { @@ -9,18 +11,26 @@ const checkNodeValidity = ({ // Check to make sure the node has all the data it should based on the schema. // Compares the schema against the connections and the entered data const nonOptionalInputs = inputs.filter((input) => !input.optional); - const emptyInputs = Object.entries(inputData).filter(([key, value]) => nonOptionalInputs.includes(key) && (value === '' || value === undefined || value === null)).map(([key]) => String(key)); - // eslint-disable-next-line max-len - const isMissingInputs = nonOptionalInputs.length > Object.keys(inputData).length + filteredEdges.length; + // Grabs all the indexes of the inputs that the connections are targeting + const edgeTargetIndexes = edges.filter((edge) => edge.target === id).map((edge) => edge.targetHandle.split('-').slice(-1)[0]); + // Finds all empty inputs + const emptyInputs = Object.entries(inputData) + .filter( + ([key, value]) => Object.keys(nonOptionalInputs).includes(String(key)) + && (value === '' || value === undefined || value === null) + && !edgeTargetIndexes.includes(String(key)), + ) + .map(([key]) => String(key)); + const enteredOptionalInputs = inputs.filter((input, i) => input.optional && Object.keys(inputData).map((index) => String(index)).includes(String(i))); + const filteredInputDataKeys = Object.entries(inputData).filter(([key, value]) => !edgeTargetIndexes.includes(String(key)) && value !== '').map((([key]) => key)); + const isMissingInputs = nonOptionalInputs.length + enteredOptionalInputs.length > filteredInputDataKeys.length + filteredEdges.length; if (isMissingInputs || emptyInputs.length > 0) { - // Grabs all the indexes of the inputs that the connections are targeting - const edgeTargetIndexes = edges.filter((edge) => edge.target === id).map((edge) => edge.targetHandle.split('-').slice(-1)[0]); // Grab all inputs that do not have data or a connected edge const missingInputs = nonOptionalInputs.filter( - (input, i) => !Object.keys(inputData).includes(String(i)) - && !edgeTargetIndexes.includes(String(i)), + (input, i) => !edgeTargetIndexes.includes(String(i)) + && (!Object.keys(inputData).includes(String(i)) + || emptyInputs.includes(String(i))), ); - // TODO: This fails to output the missing inputs when a node is connected to another return [false, `Missing required input data: ${missingInputs.map((input) => input.label).join(', ')}`]; } return [true, '']; diff --git a/src/helpers/dependencies.js b/src/helpers/dependencies.js index 850746f58..69889b521 100644 --- a/src/helpers/dependencies.js +++ b/src/helpers/dependencies.js @@ -16,5 +16,5 @@ export default (isNvidiaAvailable) => [{ }, { name: 'NCNN', packageName: 'ncnn-vulkan', - version: '2022.3.2', + version: '2022.3.14', }]; diff --git a/src/main.js b/src/main.js index 8006f2afc..bdf7370d7 100644 --- a/src/main.js +++ b/src/main.js @@ -6,6 +6,7 @@ import log from 'electron-log'; import { access, readFile, writeFile, } from 'fs/promises'; +import https from 'https'; import os from 'os'; import path from 'path'; import portfinder from 'portfinder'; @@ -17,11 +18,42 @@ import { downloadPython, extractPython, installSanic } from './setupIntegratedPy const exec = util.promisify(_exec); +// Handle creating/removing shortcuts on Windows when installing/uninstalling. +if (require('electron-squirrel-startup')) { // eslint-disable-line global-require + app.quit(); +} + // log.transports.file.resolvePath = () => path.join(app.getAppPath(), 'logs/main.log'); // eslint-disable-next-line max-len log.transports.file.resolvePath = (variables) => path.join(variables.electronDefaultDir, variables.fileName); log.transports.file.level = 'info'; +log.catchErrors({ + showDialog: false, + onError(error, versions, submitIssue) { + dialog.showMessageBox({ + title: 'An error occurred', + message: error.message, + detail: error.stack, + type: 'error', + buttons: ['Ignore', 'Report', 'Exit'], + }) + .then((result) => { + if (result.response === 1) { + submitIssue('https://github.com/joeyballentine/chaiNNer/issues/new', { + title: `Error report for ${versions.app}`, + body: `Error:\n\`\`\`${error.stack}\n\`\`\`\nOS: ${versions.os}`, + }); + return; + } + + if (result.response === 2) { + app.quit(); + } + }); + }, +}); + let gpuInfo; process.env.ELECTRON_DISABLE_SECURITY_WARNINGS = 'true'; @@ -32,9 +64,58 @@ const pythonKeys = { python: 'python', }; -// Handle creating/removing shortcuts on Windows when installing/uninstalling. -if (require('electron-squirrel-startup')) { // eslint-disable-line global-require - app.quit(); +// Check for update +if (app.isPackaged) { + const options = { + hostname: 'api.github.com', + path: '/repos/joeyballentine/chaiNNer/releases', + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'User-Agent': 'chaiNNer', + }, + }; + const req = https.request(options, (res) => { + let response = ''; + + res.on('data', (data) => { + response += String(data); + }); + + res.on('close', async () => { + try { + const releases = JSON.parse(response); + const gtVersions = releases.filter( + (v) => semver.gt(semver.coerce(v.tag_name), app.getVersion()), + ); + if (gtVersions.length > 0) { + const sorted = gtVersions.sort((a, b) => semver.gt(a, b)); + const latestVersion = sorted[0]; + const releaseUrl = latestVersion.html_url; + const latestVersionNum = semver.coerce(latestVersion.tag_name); + const buttonResult = dialog.showMessageBoxSync(BrowserWindow.getFocusedWindow(), { + type: 'info', + title: 'An update is available for chaiNNer!', + message: `Version ${latestVersionNum} is available for download from GitHub.`, + buttons: [`Get version ${latestVersionNum}`, 'Ok'], + defaultId: 1, + }); + if (buttonResult === 0) { + await shell.openExternal(releaseUrl); + app.exit(); + } + } + } catch (error) { + log.error(error); + } + }); + }); + + req.on('error', (error) => { + log.error(error); + }); + + req.end(); } if (app.isPackaged) { diff --git a/webpack.renderer.config.js b/webpack.renderer.config.js index e5f6a64c5..5a75cdc40 100644 --- a/webpack.renderer.config.js +++ b/webpack.renderer.config.js @@ -1,5 +1,12 @@ +// eslint-disable-next-line import/no-extraneous-dependencies +const ReactRefreshWebpackPlugin = require('@pmmmwh/react-refresh-webpack-plugin'); + const rules = require('./webpack.rules'); +const isDevelopment = process.env.NODE_ENV !== 'production'; + +console.log(`\nbuilding in ${isDevelopment ? 'development' : 'production'} mode`); + rules.push({ test: /\.css$/, use: [{ loader: 'style-loader' }, { loader: 'css-loader' }], @@ -10,4 +17,9 @@ module.exports = { module: { rules, }, + mode: isDevelopment ? 'development' : 'production', + devServer: { + hot: isDevelopment, + }, + plugins: [isDevelopment && new ReactRefreshWebpackPlugin()].filter(Boolean), }; diff --git a/webpack.rules.js b/webpack.rules.js index 173d76d9c..0222b933c 100644 --- a/webpack.rules.js +++ b/webpack.rules.js @@ -1,13 +1,18 @@ +const isDevelopment = process.env.NODE_ENV !== 'production'; + +console.log(`\nbuilding in ${isDevelopment ? 'development' : 'production'} mode`); + module.exports = [ // ... existing loader config ... { test: /\.jsx?$/, + exclude: /node_modules/, use: { - loader: 'babel-loader', + loader: require.resolve('babel-loader'), options: { exclude: /node_modules/, presets: ['@babel/preset-react'], - cacheDirectory: true, + plugins: [isDevelopment && require.resolve('react-refresh/babel')].filter(Boolean), }, }, },