From 47ea98996e32bc97f0b27d929bceb9f424c9fe09 Mon Sep 17 00:00:00 2001 From: PaulMagos Date: Tue, 28 Nov 2023 00:23:31 +0100 Subject: [PATCH] update --- django_storage_supabase/__init__.py | 1 + django_storage_supabase/base.py | 24 ++++ django_storage_supabase/compress.py | 49 +++++++ django_storage_supabase/supabase.py | 210 ++++++++++++++++++++++++++++ django_storage_supabase/utils.py | 124 ++++++++++++++++ 5 files changed, 408 insertions(+) create mode 100644 django_storage_supabase/__init__.py create mode 100644 django_storage_supabase/base.py create mode 100644 django_storage_supabase/compress.py create mode 100644 django_storage_supabase/supabase.py create mode 100644 django_storage_supabase/utils.py diff --git a/django_storage_supabase/__init__.py b/django_storage_supabase/__init__.py new file mode 100644 index 0000000..b851473 --- /dev/null +++ b/django_storage_supabase/__init__.py @@ -0,0 +1 @@ +from .supabase import * diff --git a/django_storage_supabase/base.py b/django_storage_supabase/base.py new file mode 100644 index 0000000..f919b78 --- /dev/null +++ b/django_storage_supabase/base.py @@ -0,0 +1,24 @@ +from django.core.exceptions import ImproperlyConfigured +from django.core.files.storage import Storage + + +class BaseStorage(Storage): + def __init__(self, **settings): + default_settings = self.get_default_settings() + + for name, value in default_settings.items(): + if not hasattr(self, name): + setattr(self, name, value) + + for name, value in settings.items(): + if name not in default_settings: + raise ImproperlyConfigured( + "Invalid setting '{}' for {}".format( + name, + self.__class__.__name__, + ) + ) + setattr(self, name, value) + + def get_default_settings(self): + return {} diff --git a/django_storage_supabase/compress.py b/django_storage_supabase/compress.py new file mode 100644 index 0000000..bec50d0 --- /dev/null +++ b/django_storage_supabase/compress.py @@ -0,0 +1,49 @@ +import io +import zlib +from gzip import GzipFile +from typing import Optional + +from django_storage_supabase.utils import to_bytes + + +class GzipCompressionWrapper(io.RawIOBase): + """Wrapper for compressing file contents on the fly.""" + + def __init__(self, raw, level=zlib.Z_BEST_COMPRESSION): + super().__init__() + self.raw = raw + self.compress = zlib.compressobj(level=level, wbits=31) + self.leftover = bytearray() + + @staticmethod + def readable(): + return True + + def readinto(self, buf: bytearray) -> Optional[int]: + size = len(buf) + while len(self.leftover) < size: + chunk = to_bytes(self.raw.read(size)) + if not chunk: + if self.compress: + self.leftover += self.compress.flush(zlib.Z_FINISH) + self.compress = None + break + self.leftover += self.compress.compress(chunk) + if len(self.leftover) == 0: + return 0 + output = self.leftover[:size] + size = len(output) + buf[:size] = output + self.leftover = self.leftover[size:] + return size + + +class CompressStorageMixin: + def _compress_content(self, content): + """Gzip a given string content.""" + return GzipCompressionWrapper(content) + + +class CompressedFileMixin: + def _decompress_file(self, mode, file, mtime=0.0): + return GzipFile(mode=mode, fileobj=file, mtime=mtime) diff --git a/django_storage_supabase/supabase.py b/django_storage_supabase/supabase.py new file mode 100644 index 0000000..93b99ba --- /dev/null +++ b/django_storage_supabase/supabase.py @@ -0,0 +1,210 @@ +# Supabase storage class for Django pluggable storage system. +# Author: Joel Lee +# License: MIT +__all__ = ["SupabaseStorage"] +import posixpath + +# Add below to settings.py: +# SUPABASE_ACCESS_TOKEN = 'YourOauthToken' +# SUPABASE_URL = "https:" +# SUPABASE_ROOT_PATH = '/dir/' +from io import BytesIO +from shutil import copyfileobj +from tempfile import SpooledTemporaryFile +import os +from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation +from django.core.files.base import File +from django.utils import timezone +from django.utils.deconstruct import deconstructible +import supabase + +from django_storage_supabase.base import BaseStorage +from django_storage_supabase.compress import CompressedFileMixin, CompressStorageMixin + +from .utils import ( + check_location, + clean_name, + get_available_overwrite_name, + safe_join, + setting, +) + + +@deconstructible +class SupabaseFile(CompressedFileMixin, File): + """The default file object used by the Supabase Storage. + + Parameters + ---------- + CompressedFileMixin : [type] + [description] + File : [type] + [description] + """ + + def __init__(self, name, storage): + self._storage_client = None + self.name = name + self._file = None + self._storage = storage + + def _get_file(self): + if self._file is None: + self._file = SpooledTemporaryFile() + response = self._storage_client.download(self.name) + # TODO: Modify Supabase-py to return response so we can check status == 200 before trying the op + with BytesIO(response) as file_content: + copyfileobj(file_content, self._file) + self._file.seek(0) + return self._file + + def _set_file(self, value): + self._file = value + + file = property(_get_file, _set_file) + + +@deconstructible +class SupabaseStorage(CompressStorageMixin, BaseStorage): + def __init__(self, **settings): + super().__init__(**settings) + self._bucket = None + self._client = None + self.location = "" + check_location(self) + print("Supabase Storage") + + def _normalize_name(self, name): + """ + Normalizes the name so that paths like /path/to/ignored/../something.txt + work. We check to make sure that the path pointed to is not outside + the directory specified by the LOCATION setting. + """ + try: + return safe_join(self.location, name) + except ValueError: + raise SuspiciousOperation("Attempted access to '%s' denied." % name) + + def _open(self, name): + remote_file = SupabaseFile(self._clean_name(name), self) + return remote_file + + def _save(self, name, content): + content.open() + # TODO: Get content.read() to become a file + self.bucket.upload(self._clean_name(name), content.read()) + content.close() + return name + + @property + def client(self): + if self._client is None: + settings = self.get_default_settings() + supabase_url, supabase_api_key = settings.get( + "supabase_url" + ), settings.get("supabase_api_key") + if bool(supabase_url) ^ bool(supabase_api_key): + raise ImproperlyConfigured( + "Both SUPABASE_URL and SUPABASE_API_KEY must be " + "provided together." + ) + self._client = supabase.create_client(supabase_url, supabase_api_key).storage + + return self._client + + @property + def bucket(self): + """ + Get the current bucket. If there is no current bucket object + create it. + """ + if self._bucket is None: + self._bucket = self.client.from_(self.bucket_name) + return self._bucket + + def get_valid_name(self, name): + # TODO: Add valid name checks + return name + + def get_default_settings(self): + # Return Access token and URL + return { + "supabase_url": setting("SUPABASE_URL"), + "supabase_api_key": setting("SUPABASE_API_KEY"), + "file_overwrite": setting('SUPABASE_STORAGE_FILE_OVERWRITE', True), + 'bucket_name': setting("SUPABASE_STORAGE_BUCKET_NAME") + } + + def listdir(self, name: str): + name = self._normalize_name(clean_name(name)) + # For bucket.list_blobs and logic below name needs to end in / + # but for the root path "" we leave it as an empty string + if name and not name.endswith("/"): + name += "/" + + directory_contents = self._bucket.list(path=name) + + files = [] + dirs = [] + for entry in directory_contents: + if entry.get("metadata"): + files.append(entry["name"]) + else: + dirs.append(entry["name"]) + + return dirs, files + + def delete(self, name: str): + name = self._normalize_name(clean_name(name)) + try: + self._bucket.remove(name) + except Exception as e: + pass + + def exists(self, name: str): + name = self._normalize_name(clean_name(name)) + return bool(self.bucket.list(name)) + + def get_accessed_time(self, name: str): + name = self._normalize_name(clean_name(name)) + accessed = self._bucket.list(name)[0]["accessed_at"] + return accessed if setting("USE_TZ") else timezone.make_naive(accessed) + + def get_available_name(self, name, max_length=None): + name = clean_name(name) + if self.file_overwrite: + return get_available_overwrite_name(name, max_length) + + return super().get_available_name(name, max_length) + + def get_created_time(self, name: str): + name = self._normalize_name(clean_name(name)) + created = self._bucket.list(name)[0]["created_at"] + return created if setting("USE_TZ") else timezone.make_naive(created) + + def get_modified_time(self, name: str): + name = self._normalize_name(clean_name(name)) + updated = self._bucket.list(name)[0]["updated_at"] + return updated if setting("USE_TZ") else timezone.make_naive(updated) + + def _clean_name(self, name: str) -> str: + """ + Cleans the name so that Windows style paths work + """ + # Normalize Windows style paths + clean_name = posixpath.normpath(name).replace("\\", "/") + + # os.path.normpath() can strip trailing slashes so we implement + # a workaround here. + if name.endswith("/") and not clean_name.endswith("/"): + # Add a trailing slash as it was stripped. + clean_name += "/" + return clean_name + + def size(self, name: str) -> int: + name = self._normalize_name(clean_name(name)) + return int(self._bucket.list(name)[0]["metadata"]["size"]) + + def url(self, name: str) -> str: + name = self._normalize_name(clean_name(name)) + return self._bucket.get_public_url(name) diff --git a/django_storage_supabase/utils.py b/django_storage_supabase/utils.py new file mode 100644 index 0000000..73ac292 --- /dev/null +++ b/django_storage_supabase/utils.py @@ -0,0 +1,124 @@ +import os +import posixpath + +from django.conf import settings +from django.core.exceptions import ImproperlyConfigured, SuspiciousFileOperation +from django.utils.encoding import force_bytes + + +def to_bytes(content): + """Wrap Django's force_bytes to pass through bytearrays.""" + if isinstance(content, bytearray): + return content + + return force_bytes(content) + + +def setting(name, default=None): + + """ + Helper function to get a Django setting by name. If setting doesn't exists + it will return a default. + :param name: Name of setting + :type name: str + :param default: Value if setting is unfound + :returns: Setting's value + """ + return getattr(settings, name, default) + + +def clean_name(name): + """ + Cleans the name so that Windows style paths work + """ + # Normalize Windows style paths + clean_name = posixpath.normpath(name).replace("\\", "/") + + # os.path.normpath() can strip trailing slashes so we implement + # a workaround here. + if name.endswith("/") and not clean_name.endswith("/"): + # Add a trailing slash as it was stripped. + clean_name = clean_name + "/" + + # Given an empty string, os.path.normpath() will return ., which we don't want + if clean_name == ".": + clean_name = "" + + return clean_name + + +def safe_join(base, *paths): + """ + A version of django.utils._os.safe_join for S3 paths. + Joins one or more path components to the base path component + intelligently. Returns a normalized version of the final path. + The final path must be located inside of the base path component + (otherwise a ValueError is raised). + Paths outside the base path indicate a possible security + sensitive operation. + """ + base_path = base + base_path = base_path.rstrip("/") + paths = [p for p in paths] + + final_path = base_path + "/" + for path in paths: + _final_path = posixpath.normpath(posixpath.join(final_path, path)) + # posixpath.normpath() strips the trailing /. Add it back. + if path.endswith("/") or _final_path + "/" == final_path: + _final_path += "/" + final_path = _final_path + if final_path == base_path: + final_path += "/" + + # Ensure final_path starts with base_path and that the next character after + # the base path is /. + base_path_len = len(base_path) + if not final_path.startswith(base_path) or final_path[base_path_len] != "/": + raise ValueError( + "the joined path is located outside of the base path" " component" + ) + + return final_path.lstrip("/") + + +def check_location(storage): + if storage.location.startswith("/"): + correct = storage.location.lstrip("/") + raise ImproperlyConfigured( + "{}.location cannot begin with a leading slash. Found '{}'. Use '{}' instead.".format( + storage.__class__.__name__, + storage.location, + correct, + ) + ) + + +def lookup_env(names): + """ + Look up for names in environment. Returns the first element + found. + """ + for name in names: + value = os.environ.get(name) + if value: + return value + + +def get_available_overwrite_name(name, max_length): + if max_length is None or len(name) <= max_length: + return name + + # Adapted from Django + dir_name, file_name = os.path.split(name) + file_root, file_ext = os.path.splitext(file_name) + truncation = len(name) - max_length + + file_root = file_root[:-truncation] + if not file_root: + raise SuspiciousFileOperation( + 'Storage tried to truncate away entire filename "%s". ' + "Please make sure that the corresponding file field " + 'allows sufficient "max_length".' % name + ) + return os.path.join(dir_name, f"{file_root}{file_ext}")