Skip to content

Commit

Permalink
Merge pull request #55 from SmartReports/dev
Browse files Browse the repository at this point in the history
update
  • Loading branch information
PaulMagos authored Nov 27, 2023
2 parents e4323f8 + 47ea989 commit e1dd15d
Show file tree
Hide file tree
Showing 5 changed files with 408 additions and 0 deletions.
1 change: 1 addition & 0 deletions django_storage_supabase/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .supabase import *
24 changes: 24 additions & 0 deletions django_storage_supabase/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from django.core.exceptions import ImproperlyConfigured
from django.core.files.storage import Storage


class BaseStorage(Storage):
    """Storage base class that merges constructor settings with defaults.

    Subclasses declare their configurable options via
    ``get_default_settings``; any keyword passed to the constructor must
    be one of those options, otherwise configuration is rejected.
    """

    def __init__(self, **settings):
        defaults = self.get_default_settings()

        # Apply defaults first, never clobbering attributes that already
        # exist on the instance/class (e.g. set by a subclass).
        for key, default in defaults.items():
            if not hasattr(self, key):
                setattr(self, key, default)

        # Explicit settings override the defaults; an unknown key is a
        # configuration error rather than a silent typo.
        for key, override in settings.items():
            if key not in defaults:
                raise ImproperlyConfigured(
                    f"Invalid setting '{key}' for {type(self).__name__}"
                )
            setattr(self, key, override)

    def get_default_settings(self):
        """Return a mapping of setting name -> default value."""
        return {}
49 changes: 49 additions & 0 deletions django_storage_supabase/compress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import io
import zlib
from gzip import GzipFile
from typing import Optional

from django_storage_supabase.utils import to_bytes


class GzipCompressionWrapper(io.RawIOBase):
    """Read-only stream that gzips the wrapped file's bytes on the fly."""

    def __init__(self, raw, level=zlib.Z_BEST_COMPRESSION):
        super().__init__()
        self.raw = raw
        # wbits=31 selects the gzip container format for zlib.
        self.compress = zlib.compressobj(level=level, wbits=31)
        self.leftover = bytearray()

    @staticmethod
    def readable():
        return True

    def readinto(self, buf: bytearray) -> Optional[int]:
        requested = len(buf)
        # Pull from the source until enough compressed bytes are buffered,
        # or both the source and the compressor are exhausted.
        while len(self.leftover) < requested:
            raw_chunk = to_bytes(self.raw.read(requested))
            if not raw_chunk:
                if self.compress:
                    # Source exhausted: emit the gzip trailer exactly once,
                    # then drop the compressor so we never flush twice.
                    self.leftover += self.compress.flush(zlib.Z_FINISH)
                    self.compress = None
                break
            self.leftover += self.compress.compress(raw_chunk)
        if not self.leftover:
            return 0
        emitted = self.leftover[:requested]
        count = len(emitted)
        buf[:count] = emitted
        self.leftover = self.leftover[count:]
        return count


class CompressStorageMixin:
    """Mixin adding on-the-fly gzip compression of uploaded content."""

    def _compress_content(self, content):
        """Gzip a given string content."""
        wrapper = GzipCompressionWrapper(content)
        return wrapper


class CompressedFileMixin:
    """Mixin providing transparent gzip decompression of stored files."""

    def _decompress_file(self, mode, file, mtime=0.0):
        # Wrap the underlying file object so reads come back decompressed
        # (or writes are compressed, depending on ``mode``).
        gzipped = GzipFile(fileobj=file, mode=mode, mtime=mtime)
        return gzipped
210 changes: 210 additions & 0 deletions django_storage_supabase/supabase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# Supabase storage class for Django pluggable storage system.
# Author: Joel Lee <[email protected]>
# License: MIT
__all__ = ["SupabaseStorage"]
import posixpath

# Add below to settings.py:
# SUPABASE_ACCESS_TOKEN = 'YourOauthToken'
# SUPABASE_URL = "https:<your-supabase-id>"
# SUPABASE_ROOT_PATH = '/dir/'
from io import BytesIO
from shutil import copyfileobj
from tempfile import SpooledTemporaryFile
import os
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation
from django.core.files.base import File
from django.utils import timezone
from django.utils.deconstruct import deconstructible
import supabase

from django_storage_supabase.base import BaseStorage
from django_storage_supabase.compress import CompressedFileMixin, CompressStorageMixin

from .utils import (
check_location,
clean_name,
get_available_overwrite_name,
safe_join,
setting,
)


@deconstructible
class SupabaseFile(CompressedFileMixin, File):
    """Lazy file wrapper for an object stored in Supabase storage.

    The remote object is only downloaded (into a ``SpooledTemporaryFile``)
    the first time ``.file`` is accessed.

    Parameters
    ----------
    name : str
        Object path inside the storage bucket.
    storage : SupabaseStorage
        Storage backend used to download the object.
    """

    def __init__(self, name, storage):
        self.name = name
        self._file = None
        self._storage = storage

    def _get_file(self):
        if self._file is None:
            self._file = SpooledTemporaryFile()
            # Download through the storage's bucket client. The previous
            # implementation read ``self._storage_client``, which was set
            # to None in __init__ and never assigned, so every download
            # raised AttributeError.
            response = self._storage.bucket.download(self.name)
            # TODO: Modify Supabase-py to return response so we can check status == 200 before trying the op
            with BytesIO(response) as file_content:
                copyfileobj(file_content, self._file)
            self._file.seek(0)
        return self._file

    def _set_file(self, value):
        self._file = value

    file = property(_get_file, _set_file)


@deconstructible
class SupabaseStorage(CompressStorageMixin, BaseStorage):
    """Django storage backend that stores files in a Supabase bucket.

    Configuration comes from Django settings (see
    ``get_default_settings``): SUPABASE_URL, SUPABASE_API_KEY,
    SUPABASE_STORAGE_BUCKET_NAME and SUPABASE_STORAGE_FILE_OVERWRITE.
    """

    def __init__(self, **settings):
        super().__init__(**settings)
        # Client and bucket handles are created lazily on first use.
        self._bucket = None
        self._client = None
        # Root prefix inside the bucket; empty string means bucket root.
        self.location = ""
        check_location(self)

    def _normalize_name(self, name):
        """
        Normalizes the name so that paths like /path/to/ignored/../something.txt
        work. We check to make sure that the path pointed to is not outside
        the directory specified by the LOCATION setting.
        """
        try:
            return safe_join(self.location, name)
        except ValueError:
            raise SuspiciousOperation("Attempted access to '%s' denied." % name)

    def _open(self, name, mode="rb"):
        """Return a lazy SupabaseFile for ``name``.

        ``mode`` is accepted for compatibility with Django's Storage API;
        the remote object is always fetched as bytes.
        """
        remote_file = SupabaseFile(self._clean_name(name), self)
        return remote_file

    def _save(self, name, content):
        """Upload ``content`` under ``name`` and return the stored name."""
        content.open()
        try:
            # TODO: Get content.read() to become a file
            self.bucket.upload(self._clean_name(name), content.read())
        finally:
            # Close even if the upload fails so the handle never leaks.
            content.close()
        return name

    @property
    def client(self):
        """Lazily build the Supabase storage client from the settings."""
        if self._client is None:
            # Read the instance attributes populated by BaseStorage.__init__
            # so per-instance overrides are honoured; the previous code
            # re-read the raw defaults and ignored overrides. Also reject
            # the both-missing case instead of only the XOR case, which
            # let create_client(None, None) fail with a confusing error.
            if not (self.supabase_url and self.supabase_api_key):
                raise ImproperlyConfigured(
                    "Both SUPABASE_URL and SUPABASE_API_KEY must be "
                    "provided together."
                )
            self._client = supabase.create_client(
                self.supabase_url, self.supabase_api_key
            ).storage

        return self._client

    @property
    def bucket(self):
        """
        Get the current bucket. If there is no current bucket object
        create it.
        """
        if self._bucket is None:
            self._bucket = self.client.from_(self.bucket_name)
        return self._bucket

    def get_valid_name(self, name):
        # TODO: Add valid name checks
        return name

    def get_default_settings(self):
        """Map setting attribute names to their Django-settings defaults."""
        # Return Access token and URL
        return {
            "supabase_url": setting("SUPABASE_URL"),
            "supabase_api_key": setting("SUPABASE_API_KEY"),
            "file_overwrite": setting("SUPABASE_STORAGE_FILE_OVERWRITE", True),
            "bucket_name": setting("SUPABASE_STORAGE_BUCKET_NAME"),
        }

    def listdir(self, name: str):
        """Return a (directories, files) pair for the entries under ``name``."""
        name = self._normalize_name(clean_name(name))
        # For bucket.list and the logic below, name needs to end in /
        # but for the root path "" we leave it as an empty string
        if name and not name.endswith("/"):
            name += "/"

        # Go through the lazy ``bucket`` property: ``self._bucket`` may
        # still be None if the bucket was never accessed before.
        directory_contents = self.bucket.list(path=name)

        files = []
        dirs = []
        for entry in directory_contents:
            # Entries with metadata are objects; folder placeholders have
            # no metadata.
            if entry.get("metadata"):
                files.append(entry["name"])
            else:
                dirs.append(entry["name"])

        return dirs, files

    def delete(self, name: str):
        """Best-effort removal of ``name``."""
        name = self._normalize_name(clean_name(name))
        try:
            self.bucket.remove(name)
        except Exception:
            # Deliberate best-effort delete: Django's Storage.delete is
            # expected not to raise when the object is already gone.
            pass

    def exists(self, name: str):
        """Return True if an entry named ``name`` is listed in the bucket."""
        name = self._normalize_name(clean_name(name))
        return bool(self.bucket.list(name))

    def get_accessed_time(self, name: str):
        name = self._normalize_name(clean_name(name))
        # NOTE(review): the list API presumably returns an ISO timestamp
        # string, but timezone.make_naive expects a datetime — confirm the
        # upstream type before relying on the naive branch.
        accessed = self.bucket.list(name)[0]["accessed_at"]
        return accessed if setting("USE_TZ") else timezone.make_naive(accessed)

    def get_available_name(self, name, max_length=None):
        """Honour the file_overwrite setting when picking a target name."""
        name = clean_name(name)
        if self.file_overwrite:
            return get_available_overwrite_name(name, max_length)

        return super().get_available_name(name, max_length)

    def get_created_time(self, name: str):
        name = self._normalize_name(clean_name(name))
        # NOTE(review): same datetime-vs-string caveat as get_accessed_time.
        created = self.bucket.list(name)[0]["created_at"]
        return created if setting("USE_TZ") else timezone.make_naive(created)

    def get_modified_time(self, name: str):
        name = self._normalize_name(clean_name(name))
        # NOTE(review): same datetime-vs-string caveat as get_accessed_time.
        updated = self.bucket.list(name)[0]["updated_at"]
        return updated if setting("USE_TZ") else timezone.make_naive(updated)

    def _clean_name(self, name: str) -> str:
        """
        Cleans the name so that Windows style paths work
        """
        # Normalize Windows style paths
        clean_name = posixpath.normpath(name).replace("\\", "/")

        # os.path.normpath() can strip trailing slashes so we implement
        # a workaround here.
        if name.endswith("/") and not clean_name.endswith("/"):
            # Add a trailing slash as it was stripped.
            clean_name += "/"
        return clean_name

    def size(self, name: str) -> int:
        """Return the stored object's size in bytes."""
        name = self._normalize_name(clean_name(name))
        return int(self.bucket.list(name)[0]["metadata"]["size"])

    def url(self, name: str) -> str:
        """Return the public URL for ``name``."""
        name = self._normalize_name(clean_name(name))
        return self.bucket.get_public_url(name)
Loading

0 comments on commit e1dd15d

Please sign in to comment.