Skip to content

Commit

Permalink
Updates to make UPath work with zip and tar files
Browse files Browse the repository at this point in the history
  • Loading branch information
remi-braun committed Dec 13, 2024
1 parent 9c6b808 commit 7a2cb5d
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 23 deletions.
3 changes: 2 additions & 1 deletion CI/SCRIPTS/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@

import pytest
import rasterio
from cloudpathlib import AnyPath
from tempenv import tempenv

from CI.SCRIPTS.script_utils import CI_SERTIT_S3
from sertit import AnyPath, rasters
from sertit import rasters
from sertit.s3 import USE_S3_STORAGE, s3_env, temp_s3


Expand Down
3 changes: 2 additions & 1 deletion CI/SCRIPTS/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

import numpy as np
from cloudpathlib import CloudPath
from upath import UPath

from sertit import AnyPath
from sertit.types import AnyPathType, is_iterable, make_iterable


def test_types():
    """Check that the AnyPathType alias covers every supported path flavour.

    After adding universal_pathlib support, the alias must be exactly
    ``Union[Path, CloudPath, UPath]`` — local paths, cloudpathlib cloud
    paths, and UPath-based paths (zip/tar-capable).
    """
    # The pre-UPath form (Union[Path, CloudPath]) is intentionally NOT
    # asserted anymore: it is no longer equal to AnyPathType.
    assert AnyPathType == Union[Path, CloudPath, UPath]


def test_is_iterable():
Expand Down
6 changes: 3 additions & 3 deletions CI/SCRIPTS/test_unistra.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
# limitations under the License.
""" Script testing the CI """
import pytest
from cloudpathlib import S3Client
from cloudpathlib import AnyPath, S3Client
from tempenv import tempenv

from CI.SCRIPTS.script_utils import CI_SERTIT_S3
from sertit import AnyPath, ci, misc, rasters, s3
from sertit import ci, misc, rasters, s3
from sertit.unistra import (
_get_db_path,
get_db2_path,
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_unistra_s3():
assert with_s3() == 1

# Test get_geodatastore with s3
assert str(get_geodatastore()) == "s3://sertit-geodatastore"
assert str(get_geodatastore()) == "s3://sertit-geodatastore/"

# Test get_geodatastore without s3
with tempenv.TemporaryEnvironment({s3.USE_S3_STORAGE: "0"}):
Expand Down
25 changes: 19 additions & 6 deletions sertit/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,16 +221,25 @@ def extract_sub_dir(arch, filename_list):
arch.extractall(archive_output)

# Manage archive type

if file_path.suffix == ".zip":
if path.is_cloud_path(file_path):
file_path = s3.read(file_path)
with zipfile.ZipFile(file_path, "r") as zip_file:
extract_sub_dir(zip_file, zip_file.namelist())
elif file_path.suffix == ".tar" or file_path.suffixes == [".tar", ".gz"]:
with tarfile.open(file_path, "r") as tar_file:
if path.is_cloud_path(file_path):
args = {"fileobj": s3.read(file_path), "mode": "r"}
else:
args = {"name": file_path, "mode": "r"}
with tarfile.open(**args) as tar_file:
extract_sub_dir(tar_file, tar_file.getnames())
elif file_path.suffix == ".7z":
try:
import py7zr

if path.is_cloud_path(file_path):
file_path = s3.read(file_path)
with py7zr.SevenZipFile(file_path, "r") as z7_file:
extract_sub_dir(z7_file, z7_file.getnames())
except ModuleNotFoundError:
Expand Down Expand Up @@ -394,14 +403,18 @@ def read_archived_file(
bytes: Archived file in bytes
"""
archive_path = AnyPath(archive_path)

archive_fn = get_filename(archive_path)
# Compile regex
regex = re.compile(regex)

# Open tar and zip XML
try:
if archive_path.suffix == ".tar":
with tarfile.open(archive_path) as tar_ds:
if path.is_cloud_path(archive_path):
args = {"fileobj": s3.read(archive_path), "mode": "r"}
else:
args = {"name": archive_path, "mode": "r"}
with tarfile.open(**args) as tar_ds:
# file_list is not very useful for TAR files...
if file_list is None:
tar_mb = tar_ds.getmembers()
Expand All @@ -410,6 +423,8 @@ def read_archived_file(
tarinfo = tar_ds.getmember(name)
file_str = tar_ds.extractfile(tarinfo).read()
elif archive_path.suffix == ".zip":
if path.is_cloud_path(archive_path):
archive_path = s3.read(archive_path)
with zipfile.ZipFile(archive_path) as zip_ds:
if file_list is None:
file_list = [f.filename for f in zip_ds.filelist]
Expand All @@ -425,9 +440,7 @@ def read_archived_file(
"Only .zip and .tar files can be read from inside its archive."
)
except IndexError:
raise FileNotFoundError(
f"Impossible to find file {regex} in {path.get_filename(archive_path)}"
)
raise FileNotFoundError(f"Impossible to find file {regex} in {archive_fn}")

return file_str

Expand Down
21 changes: 15 additions & 6 deletions sertit/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import zipfile
from typing import Any, Union

from sertit import AnyPath, logs
from sertit import AnyPath, logs, s3
from sertit.logs import SU_NAME
from sertit.types import AnyPathStrType, AnyPathType

Expand Down Expand Up @@ -167,18 +167,27 @@ def get_archived_file_list(archive_path: AnyPathStrType) -> list:
['file_1.txt', 'file_2.tif', 'file_3.xml', 'file_4.geojson']
"""
archive_path = AnyPath(archive_path)
if archive_path.suffix == ".zip":

is_zip = archive_path.suffix == ".zip"
archive_fn = get_filename(archive_path)
if is_zip:

if is_cloud_path(archive_path):
archive_path = s3.read(archive_path)

with zipfile.ZipFile(archive_path) as zip_ds:
file_list = [f.filename for f in zip_ds.filelist]
else:
try:
with tarfile.open(archive_path) as tar_ds:
if is_cloud_path(archive_path):
args = {"fileobj": s3.read(archive_path), "mode": "r"}
else:
args = {"name": archive_path, "mode": "r"}
with tarfile.open(**args) as tar_ds:
tar_mb = tar_ds.getmembers()
file_list = [mb.name for mb in tar_mb]
except tarfile.ReadError as ex:
raise tarfile.ReadError(
f"Impossible to open archive: {archive_path}"
) from ex
raise tarfile.ReadError(f"Impossible to open archive: {archive_fn}") from ex

return file_list

Expand Down
12 changes: 12 additions & 0 deletions sertit/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
from contextlib import contextmanager
from functools import wraps
from io import BytesIO

from cloudpathlib import S3Client

Expand Down Expand Up @@ -305,3 +306,14 @@ def download(src, dst):
downloaded_path = src.download_to(dst)

return downloaded_path


def read(src):
    """Read a (possibly cloud-hosted) file fully into memory.

    Args:
        src: Path-like object pointing to the file; normalized with ``AnyPath``.

    Returns:
        BytesIO: In-memory buffer holding the file's raw bytes, suitable for
        libraries (``zipfile``, ``tarfile``...) that need a seekable file object.
    """
    src = AnyPath(src)

    # Fast path: most path implementations expose read_bytes() directly.
    # Some backends don't support it (or fail on it), hence the fallback
    # below on an explicit binary open.
    try:
        raw = src.read_bytes()
    except Exception:
        with src.open("rb") as stream:
            raw = stream.read()

    return BytesIO(raw)
17 changes: 14 additions & 3 deletions sertit/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,11 @@ def get_aoi_wkt(aoi_path: AnyPathStrType, as_str: bool = True) -> Union[str, Pol

if aoi_path.suffix == ".wkt":
try:
with open(aoi_path, "r") as aoi_f:
aoi = wkt.load(aoi_f)
if path.is_cloud_path(aoi_path):
aoi = wkt.load(s3.read(aoi_path))
else:
with open(aoi_path, "r") as aoi_f:
aoi = wkt.load(aoi_f)
except Exception as ex:
raise ValueError("AOI WKT cannot be read") from ex
else:
Expand Down Expand Up @@ -707,11 +710,19 @@ def ogr2geojson(
vector_path = AnyPath(vector_path)

# archived vector_path are extracted in a tmp folder so no need to be downloaded

if vector_path.suffix == ".zip":
if path.is_cloud_path(vector_path):
vector_path = s3.read(vector_path)
with zipfile.ZipFile(vector_path, "r") as zip_ds:
vect_path = zip_ds.extract(arch_vect_path, out_dir)
elif vector_path.suffix == ".tar":
with tarfile.open(vector_path, "r") as tar_ds:
if path.is_cloud_path(vector_path):
args = {"fileobj": s3.read(vector_path), "mode": "r"}
else:
args = {"name": vector_path, "mode": "r"}

with tarfile.open(**args) as tar_ds:
tar_ds.extract(arch_vect_path, out_dir)
vect_path = os.path.join(out_dir, arch_vect_path)
else:
Expand Down
6 changes: 3 additions & 3 deletions sertit/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from lxml.html.builder import E

from sertit import AnyPath, files, path
from sertit import AnyPath, files, path, s3
from sertit.logs import SU_NAME
from sertit.misc import ListEnum
from sertit.types import AnyPathStrType
Expand All @@ -55,12 +55,12 @@ def read(xml_path: AnyPathStrType) -> _Element:
try:
# Try using read_text (faster)
root = fromstring(xml_path.read_text())
except ValueError:
except (ValueError, PermissionError):
# Try using read_bytes
# Slower but works with:
# {ValueError}Unicode strings with encoding declaration are not supported.
# Please use bytes input or XML fragments without declaration.
root = fromstring(xml_path.read_bytes())
root = fromstring(s3.read(xml_path))
else:
# pylint: disable=I1101:
# Module 'lxml.etree' has no 'parse' member, but source is unavailable.
Expand Down

0 comments on commit 7a2cb5d

Please sign in to comment.