Skip to content

Commit

Permalink
Test with UPath
Browse files Browse the repository at this point in the history
  • Loading branch information
remi-braun committed Dec 13, 2024
1 parent a546d02 commit 9c6b808
Show file tree
Hide file tree
Showing 12 changed files with 166 additions and 122 deletions.
9 changes: 7 additions & 2 deletions CI/SCRIPTS/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ class Polarization(ListEnum):

def get_s3_ci_path():
"""Get S3 CI path"""
unistra.define_s3_client()
return AnyPath("s3://sertit-sertit-utils-ci")
# unistra.define_s3_client()

from sertit.unistra import UNISTRA_S3_ENDPOINT

return AnyPath(
"s3://sertit-sertit-utils-ci", endpoint_url=f"https://{UNISTRA_S3_ENDPOINT}"
)


def get_proj_path():
Expand Down
27 changes: 13 additions & 14 deletions CI/SCRIPTS/test_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from lxml import etree

from CI.SCRIPTS.script_utils import files_path, rasters_path, s3_env, vectors_path
from sertit import ci, path, rasters, rasters_rio, vectors
from sertit import ci, path, rasters, rasters_rio, s3, vectors

ci.reduce_verbosity()

Expand Down Expand Up @@ -169,15 +169,15 @@ def test_assert_raster():


@s3_env
def test_assert_xml():
def test_assert_xml(tmp_path):
# XML
xml_folder = files_path().joinpath("LM05_L1TP_200030_20121230_20200820_02_T2_CI")
xml_path = xml_folder.joinpath("LM05_L1TP_200030_20121230_20200820_02_T2_MTL.xml")
xml_bad_path = xml_folder.joinpath("false_xml.xml")

if path.is_cloud_path(files_path()):
xml_path = xml_path.fspath
xml_bad_path = xml_bad_path.fspath
xml_path = s3.download(xml_path, tmp_path)
xml_bad_path = s3.download(xml_bad_path, tmp_path)

xml_ok = etree.parse(str(xml_path)).getroot()
xml_nok = etree.parse(str(xml_bad_path)).getroot()
Expand All @@ -188,19 +188,18 @@ def test_assert_xml():


@s3_env
def test_assert_html():
def test_assert_html(tmp_path):
# HTML
html_path = files_path().joinpath("productPreview.html")
html_bad_path = files_path().joinpath("false.html")

with tempfile.TemporaryDirectory() as tmp_dir:
if path.is_cloud_path(files_path()):
html_path = html_path.download_to(tmp_dir)
html_bad_path = html_bad_path.download_to(tmp_dir)
if path.is_cloud_path(files_path()):
html_path = s3.download(html_path, tmp_path)
html_bad_path = s3.download(html_bad_path, tmp_path)

html_ok = etree.parse(str(html_path)).getroot()
html_nok = etree.parse(str(html_bad_path)).getroot()
html_ok = etree.parse(str(html_path)).getroot()
html_nok = etree.parse(str(html_bad_path)).getroot()

ci.assert_xml_equal(html_ok, html_ok)
with pytest.raises(AssertionError):
ci.assert_xml_equal(html_ok, html_nok)
ci.assert_xml_equal(html_ok, html_ok)
with pytest.raises(AssertionError):
ci.assert_xml_equal(html_ok, html_nok)
125 changes: 59 additions & 66 deletions CI/SCRIPTS/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from lxml import etree, html

from CI.SCRIPTS.script_utils import Polarization, files_path, s3_env
from sertit import AnyPath, ci, files, path, vectors
from sertit import AnyPath, ci, files, path, s3, vectors

ci.reduce_verbosity()

Expand Down Expand Up @@ -94,7 +94,7 @@ def test_archive():


@s3_env
def test_archived_files():
def test_archived_files(tmp_path):
landsat_name = "LM05_L1TP_200030_20121230_20200820_02_T2_CI"
ok_folder = files_path().joinpath(landsat_name)
zip_file = files_path().joinpath(f"{landsat_name}.zip")
Expand All @@ -114,70 +114,63 @@ def test_archived_files():
ci.assert_geom_equal(vect_ok, vect_zip)
ci.assert_geom_equal(vect_ok, vect_tar)

with tempfile.TemporaryDirectory() as tmp_dir:
# XML
xml_name = "LM05_L1TP_200030_20121230_20200820_02_T2_MTL.xml"
xml_ok_path = ok_folder.joinpath(xml_name)
if path.is_cloud_path(files_path()):
xml_ok_path = str(xml_ok_path.download_to(tmp_dir))
else:
xml_ok_path = str(xml_ok_path)

xml_regex = f".*{xml_name}"
xml_zip = files.read_archived_xml(zip_file, xml_regex)
xml_tar = files.read_archived_xml(tar_file, r".*_MTL\.xml")
xml_ok = etree.parse(xml_ok_path).getroot()
ci.assert_xml_equal(xml_ok, xml_zip)
ci.assert_xml_equal(xml_ok, xml_tar)

# FILE + HTML
html_zip_file = files_path().joinpath("productPreview.zip")
html_tar_file = files_path().joinpath("productPreview.tar")
html_name = "productPreview.html"
html_ok_path = files_path().joinpath(html_name)
if path.is_cloud_path(files_path()):
html_ok_path = str(html_ok_path.download_to(tmp_dir))
else:
html_ok_path = str(html_ok_path)

html_regex = f".*{html_name}"

# FILE
file_zip = files.read_archived_file(html_zip_file, html_regex)
file_tar = files.read_archived_file(html_tar_file, html_regex)
html_ok = html.parse(html_ok_path).getroot()
ci.assert_html_equal(html_ok, html.fromstring(file_zip))
ci.assert_html_equal(html_ok, html.fromstring(file_tar))

file_list = path.get_archived_file_list(html_zip_file)
ci.assert_html_equal(
html_ok,
html.fromstring(
files.read_archived_file(html_zip_file, html_regex, file_list=file_list)
),
)

# HTML
html_zip = files.read_archived_html(html_zip_file, html_regex)
html_tar = files.read_archived_html(html_tar_file, html_regex)
ci.assert_html_equal(html_ok, html_zip)
ci.assert_html_equal(html_ok, html_tar)
ci.assert_html_equal(
html_ok,
files.read_archived_html(
html_tar_file,
html_regex,
file_list=path.get_archived_file_list(html_tar_file),
),
)

# ERRORS
with pytest.raises(TypeError):
files.read_archived_file(targz_file, xml_regex)
with pytest.raises(TypeError):
files.read_archived_file(sz_file, xml_regex)
with pytest.raises(FileNotFoundError):
files.read_archived_file(zip_file, "cdzeferf")
# XML
xml_name = "LM05_L1TP_200030_20121230_20200820_02_T2_MTL.xml"
xml_ok_path = ok_folder.joinpath(xml_name)
xml_ok_path = str(s3.download(xml_ok_path, tmp_path))

xml_regex = f".*{xml_name}"
xml_zip = files.read_archived_xml(zip_file, xml_regex)
xml_tar = files.read_archived_xml(tar_file, r".*_MTL\.xml")
xml_ok = etree.parse(xml_ok_path).getroot()
ci.assert_xml_equal(xml_ok, xml_zip)
ci.assert_xml_equal(xml_ok, xml_tar)

# FILE + HTML
html_zip_file = files_path().joinpath("productPreview.zip")
html_tar_file = files_path().joinpath("productPreview.tar")
html_name = "productPreview.html"
html_ok_path = files_path().joinpath(html_name)
html_ok_path = str(s3.download(html_ok_path, tmp_path))

html_regex = f".*{html_name}"

# FILE
file_zip = files.read_archived_file(html_zip_file, html_regex)
file_tar = files.read_archived_file(html_tar_file, html_regex)
html_ok = html.parse(html_ok_path).getroot()
ci.assert_html_equal(html_ok, html.fromstring(file_zip))
ci.assert_html_equal(html_ok, html.fromstring(file_tar))

file_list = path.get_archived_file_list(html_zip_file)
ci.assert_html_equal(
html_ok,
html.fromstring(
files.read_archived_file(html_zip_file, html_regex, file_list=file_list)
),
)

# HTML
html_zip = files.read_archived_html(html_zip_file, html_regex)
html_tar = files.read_archived_html(html_tar_file, html_regex)
ci.assert_html_equal(html_ok, html_zip)
ci.assert_html_equal(html_ok, html_tar)
ci.assert_html_equal(
html_ok,
files.read_archived_html(
html_tar_file,
html_regex,
file_list=path.get_archived_file_list(html_tar_file),
),
)

# ERRORS
with pytest.raises(TypeError):
files.read_archived_file(targz_file, xml_regex)
with pytest.raises(TypeError):
files.read_archived_file(sz_file, xml_regex)
with pytest.raises(FileNotFoundError):
files.read_archived_file(zip_file, "cdzeferf")


def test_cp_rm():
Expand Down
5 changes: 3 additions & 2 deletions CI/SCRIPTS/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@

import pytest
import rasterio
from cloudpathlib import AnyPath, S3Client
from tempenv import tempenv

from CI.SCRIPTS.script_utils import CI_SERTIT_S3
from sertit import rasters
from sertit import AnyPath, rasters
from sertit.s3 import USE_S3_STORAGE, s3_env, temp_s3


Expand All @@ -43,6 +42,8 @@ def with_s3(variable_1, variable_2):


def without_s3():
from cloudpathlib import S3Client

S3Client().set_as_default_client()
return base_fct(None)

Expand Down
4 changes: 2 additions & 2 deletions CI/SCRIPTS/test_unistra.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
# limitations under the License.
""" Script testing the CI """
import pytest
from cloudpathlib import AnyPath, S3Client
from cloudpathlib import S3Client
from tempenv import tempenv

from CI.SCRIPTS.script_utils import CI_SERTIT_S3
from sertit import ci, misc, rasters, s3
from sertit import AnyPath, ci, misc, rasters, s3
from sertit.unistra import (
_get_db_path,
get_db2_path,
Expand Down
12 changes: 9 additions & 3 deletions sertit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@
.. include:: ../README.md
"""
try:
from cloudpathlib import AnyPath
from upath import UPath

AnyPath = UPath

AnyPath = AnyPath
except ImportError:
pass
try:
from cloudpathlib import AnyPath

AnyPath = AnyPath
except ImportError:
pass

# flake8: noqa
from .__meta__ import (
Expand Down
36 changes: 15 additions & 21 deletions sertit/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from lxml import etree, html
from tqdm import tqdm

from sertit import AnyPath, logs, path
from sertit import AnyPath, logs, path, s3
from sertit.logs import SU_NAME
from sertit.strings import DATE_FORMAT
from sertit.types import AnyPathStrType, AnyPathType
Expand Down Expand Up @@ -515,26 +515,20 @@ def archive(
'D:/path/to/output/folder_to_archive.tar.gz'
"""
archive_path = AnyPath(archive_path)
folder_path = AnyPath(folder_path)

tmp_dir = None
if path.is_cloud_path(folder_path):
tmp_dir = tempfile.TemporaryDirectory()
folder_path = folder_path.download_to(tmp_dir.name)

# Shutil make_archive needs a path without extension
archive_base = os.path.splitext(archive_path)[0]

# Archive the folder
archive_fn = shutil.make_archive(
archive_base,
format=fmt,
root_dir=folder_path.parent,
base_dir=folder_path.name,
)

if tmp_dir is not None:
tmp_dir.cleanup()
with tempfile.TemporaryDirectory() as tmp_path:
folder_path = s3.download(AnyPath(folder_path), tmp_path)

# Shutil make_archive needs a path without extension
archive_base = os.path.splitext(archive_path)[0]

# Archive the folder
archive_fn = shutil.make_archive(
archive_base,
format=fmt,
root_dir=folder_path.parent,
base_dir=folder_path.name,
)

return AnyPath(archive_fn)

Expand Down Expand Up @@ -755,7 +749,7 @@ def copy(src: AnyPathStrType, dst: AnyPathStrType) -> AnyPathType:
src = AnyPath(src)

if path.is_cloud_path(src):
out = src.download_to(dst)
out = s3.download(src, dst)
else:
out = None
try:
Expand Down
22 changes: 17 additions & 5 deletions sertit/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,11 +593,22 @@ def is_cloud_path(path: AnyPathStrType):
bool: True if the file is stored on the cloud.
"""
try:
from cloudpathlib import CloudPath
return AnyPath(path).protocol in [
"s3",
"az",
"adl",
"abfs",
"abfss",
"gs",
"gcs",
]
except ImportError:
try:
from cloudpathlib import CloudPath

return isinstance(AnyPath(path), CloudPath)
except Exception:
return False
return isinstance(AnyPath(path), CloudPath)
except Exception:
return False


def is_path(path: Any) -> bool:
Expand All @@ -613,5 +624,6 @@ def is_path(path: Any) -> bool:
from pathlib import Path

from cloudpathlib import CloudPath
from upath import UPath

return isinstance(path, (str, Path, CloudPath))
return isinstance(path, (str, Path, CloudPath, UPath))
4 changes: 2 additions & 2 deletions sertit/rasters_rio.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"Please install 'rasterio' to use the 'rasters_rio' package."
) from ex

from sertit import AnyPath, geometry, logs, misc, path, strings, vectors, xml
from sertit import AnyPath, geometry, logs, misc, path, s3, strings, vectors, xml
from sertit.logs import SU_NAME
from sertit.types import AnyNumpyArray, AnyPathStrType, AnyPathType, AnyRasterType

Expand Down Expand Up @@ -1435,7 +1435,7 @@ def merge_vrt(
crs_path = AnyPath(crs_path)
# Download file if VRT is needed
if path.is_cloud_path(crs_path):
crs_path = crs_path.download_to(merged_path.parent)
crs_path = s3.download(crs_path, merged_path.parent)

with rasterio.open(str(crs_path)) as src:
if first_crs is None:
Expand Down
Loading

0 comments on commit 9c6b808

Please sign in to comment.