Skip to content

Commit

Permalink
Fix pqdm_kwargs integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chuckwondo committed Nov 4, 2024
1 parent 84be54e commit 7b7b132
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 89 deletions.
40 changes: 20 additions & 20 deletions earthaccess/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from pathlib import Path

import requests
import s3fs
Expand Down Expand Up @@ -202,9 +203,10 @@ def login(strategy: str = "all", persist: bool = False, system: System = PROD) -

def download(
granules: Union[DataGranule, List[DataGranule], str, List[str]],
local_path: Optional[str],
local_path: Optional[Union[Path, str]] = None,
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand All @@ -215,7 +217,11 @@ def download(
Parameters:
granules: a granule, list of granules, a granule link (HTTP), or a list of granule links (HTTP)
local_path: local directory to store the remote data granules
local_path: Local directory to store the remote data granules. If not
supplied, defaults to a subdirectory of the current working directory
of the form `data/YYYY-MM-DD-UUID`, where `YYYY-MM-DD` is the year,
month, and day of the current date, and `UUID` is the first 6 hex digits
of a UUID4 value.
provider: if we download a list of URLs, we need to specify the provider.
threads: number of parallel threads to use to download the files; adjust as necessary, default = 8
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
Expand All @@ -228,31 +234,29 @@ def download(
Raises:
Exception: A file download failed.
"""
provider = _normalize_location(provider)
pqdm_kwargs = {
"exception_behavior": "immediate",
"n_jobs": threads,
**(pqdm_kwargs or {}),
}
provider = _normalize_location(str(provider))

if isinstance(granules, DataGranule):
granules = [granules]
elif isinstance(granules, str):
granules = [granules]

try:
results = earthaccess.__store__.get(
granules, local_path, provider, threads, pqdm_kwargs
return earthaccess.__store__.get(
granules, local_path, provider, threads, pqdm_kwargs=pqdm_kwargs
)
except AttributeError as err:
logger.error(
f"{err}: You must call earthaccess.login() before you can download data"
)
return []
return results

return []


def open(
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[AbstractFileSystem]:
"""Returns a list of file-like objects that can be used to access files
Expand All @@ -269,15 +273,11 @@ def open(
Returns:
A list of "file pointers" to remote (i.e. s3 or https) files.
"""
provider = _normalize_location(provider)
pqdm_kwargs = {
"exception_behavior": "immediate",
**(pqdm_kwargs or {}),
}
results = earthaccess.__store__.open(
granules=granules, provider=provider, pqdm_kwargs=pqdm_kwargs
return earthaccess.__store__.open(
granules=granules,
provider=_normalize_location(provider),
pqdm_kwargs=pqdm_kwargs,
)
return results


def get_s3_credentials(
Expand Down
94 changes: 51 additions & 43 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,20 @@ def __repr__(self) -> str:
def _open_files(
url_mapping: Mapping[str, Union[DataGranule, None]],
fs: fsspec.AbstractFileSystem,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.spec.AbstractBufferedFile]:
def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFile:
url, granule = data
return EarthAccessFile(fs.open(url), granule) # type: ignore

pqdm_kwargs = {
"exception_behavior": "immediate",
"exception_behaviour": "immediate",
"n_jobs": 8,
**(pqdm_kwargs or {}),
}

fileset = pqdm(
url_mapping.items(), multi_thread_open, n_jobs=threads, **pqdm_kwargs
)
return fileset
return pqdm(url_mapping.items(), multi_thread_open, **pqdm_kwargs)


def make_instance(
Expand Down Expand Up @@ -344,6 +342,7 @@ def open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.spec.AbstractBufferedFile]:
"""Returns a list of file-like objects that can be used to access files
Expand All @@ -361,14 +360,15 @@ def open(
A list of "file pointers" to remote (i.e. s3 or https) files.
"""
if len(granules):
return self._open(granules, provider, pqdm_kwargs)
return self._open(granules, provider, pqdm_kwargs=pqdm_kwargs)
return []

@singledispatchmethod
def _open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
raise NotImplementedError("granules should be a list of DataGranule or URLs")
Expand All @@ -378,7 +378,8 @@ def _open_granules(
self,
granules: List[DataGranule],
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
fileset: List = []
total_size = round(sum([granule.size() for granule in granules]) / 1024, 2)
Expand Down Expand Up @@ -411,7 +412,7 @@ def _open_granules(
fileset = _open_files(
url_mapping,
fs=s3_fs,
threads=threads,
pqdm_kwargs=pqdm_kwargs,
)
except Exception as e:
raise RuntimeError(
Expand All @@ -420,19 +421,19 @@ def _open_granules(
f"Exception: {traceback.format_exc()}"
) from e
else:
fileset = self._open_urls_https(url_mapping, threads=threads)
return fileset
fileset = self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs)
else:
url_mapping = _get_url_granule_mapping(granules, access="on_prem")
fileset = self._open_urls_https(url_mapping, threads=threads)
return fileset
fileset = self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs)

return fileset

@_open.register
def _open_urls(
self,
granules: List[str],
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
fileset: List = []
Expand Down Expand Up @@ -460,7 +461,6 @@ def _open_urls(
fileset = _open_files(
url_mapping,
fs=s3_fs,
threads=threads,
pqdm_kwargs=pqdm_kwargs,
)
except Exception as e:
Expand All @@ -481,15 +481,16 @@ def _open_urls(
raise ValueError(
"We cannot open S3 links when we are not in-region, try using HTTPS links"
)
fileset = self._open_urls_https(url_mapping, threads, pqdm_kwargs)
fileset = self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs)
return fileset

def get(
self,
granules: Union[List[DataGranule], List[str]],
local_path: Union[Path, str, None] = None,
local_path: Optional[Union[Path, str]] = None,
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand All @@ -503,7 +504,11 @@ def get(
Parameters:
granules: A list of granules(DataGranule) instances or a list of granule links (HTTP).
local_path: Local directory to store the remote data granules.
local_path: Local directory to store the remote data granules. If not
supplied, defaults to a subdirectory of the current working directory
of the form `data/YYYY-MM-DD-UUID`, where `YYYY-MM-DD` is the year,
month, and day of the current date, and `UUID` is the first 6 hex digits
of a UUID4 value.
provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions
threads: Number of parallel threads to use to download the files;
adjust as necessary, default = 8.
Expand All @@ -514,26 +519,28 @@ def get(
Returns:
List of downloaded files
"""
if not granules:
raise ValueError("List of URLs or DataGranule instances expected")

if local_path is None:
today = datetime.datetime.today().strftime("%Y-%m-%d")
today = datetime.datetime.now().strftime("%Y-%m-%d")
uuid = uuid4().hex[:6]
local_path = Path.cwd() / "data" / f"{today}-{uuid}"
elif isinstance(local_path, str):
local_path = Path(local_path)

if len(granules):
files = self._get(granules, local_path, provider, threads, pqdm_kwargs)
return files
else:
raise ValueError("List of URLs or DataGranule instances expected")
pqdm_kwargs = {
"n_jobs": threads,
**(pqdm_kwargs or {}),
}

return self._get(granules, Path(local_path), provider, pqdm_kwargs=pqdm_kwargs)

@singledispatchmethod
def _get(
self,
granules: Union[List[DataGranule], List[str]],
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand Down Expand Up @@ -566,7 +573,7 @@ def _get_urls(
granules: List[str],
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
data_links = granules
Expand All @@ -590,7 +597,7 @@ def _get_urls(
else:
# if we are not in AWS
return self._download_onprem_granules(
data_links, local_path, threads, pqdm_kwargs
data_links, local_path, pqdm_kwargs=pqdm_kwargs
)

@_get.register
Expand All @@ -599,7 +606,7 @@ def _get_granules(
granules: List[DataGranule],
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
data_links: List = []
Expand All @@ -615,7 +622,7 @@ def _get_granules(
for granule in granules
)
)
total_size = round(sum([granule.size() for granule in granules]) / 1024, 2)
total_size = round(sum(granule.size() for granule in granules) / 1024, 2)
logger.info(
f" Getting {len(granules)} granules, approx download size: {total_size} GB"
)
Expand All @@ -642,7 +649,7 @@ def _get_granules(
# if the data are cloud-based, but we are not in AWS,
# it will be downloaded as if it was on prem
return self._download_onprem_granules(
data_links, local_path, threads, pqdm_kwargs
data_links, local_path, pqdm_kwargs=pqdm_kwargs
)

def _download_file(self, url: str, directory: Path) -> str:
Expand Down Expand Up @@ -684,7 +691,7 @@ def _download_onprem_granules(
self,
urls: List[str],
directory: Path,
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
"""Downloads a list of URLS into the data directory.
Expand All @@ -711,25 +718,26 @@ def _download_onprem_granules(

arguments = [(url, directory) for url in urls]

results = pqdm(
arguments,
self._download_file,
n_jobs=threads,
argument_type="args",
**pqdm_kwargs,
)
return results
pqdm_kwargs = {
"exception_behaviour": "immediate",
**(pqdm_kwargs or {}),
# We don't want a user to be able to override the following kwargs,
# which is why they appear *after* spreading pqdm_kwargs above.
"argument_type": "args",
}

return pqdm(arguments, self._download_file, **pqdm_kwargs)

def _open_urls_https(
self,
url_mapping: Mapping[str, Union[DataGranule, None]],
threads: int = 8,
*,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.AbstractFileSystem]:
https_fs = self.get_fsspec_session()

try:
return _open_files(url_mapping, https_fs, threads, pqdm_kwargs)
return _open_files(url_mapping, https_fs, pqdm_kwargs=pqdm_kwargs)
except Exception:
logger.exception(
"An exception occurred while trying to access remote files via HTTPS"
Expand Down
Loading

0 comments on commit 7b7b132

Please sign in to comment.