Skip to content

Commit

Permalink
Added pqdm_kwargs to the functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Sherwin-14 committed Oct 23, 2024
1 parent 42f97ce commit 0e76a23
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 34 deletions.
22 changes: 16 additions & 6 deletions earthaccess/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import requests
import s3fs
from fsspec import AbstractFileSystem
from typing_extensions import Any, Dict, List, Optional, Union, deprecated
from typing_extensions import Any, Dict, List, Optional, Union, deprecated, Mapping

import earthaccess
from earthaccess.services import DataServices
Expand Down Expand Up @@ -205,7 +205,7 @@ def download(
local_path: Optional[str],
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand All @@ -226,14 +226,19 @@ def download(
Exception: A file download failed.
"""
provider = _normalize_location(provider)
pqdm_kwargs = {
"exception_behavior": "immediate",
"n_jobs": threads,
**(pqdm_kwargs or {}),
}
if isinstance(granules, DataGranule):
granules = [granules]
elif isinstance(granules, str):
granules = [granules]
try:
results = earthaccess.__store__.get(
granules, local_path, provider, threads, fail_fast=fail_fast
)
granules, local_path, provider, threads, pqdm_kwargs
)
except AttributeError as err:
logger.error(
f"{err}: You must call earthaccess.login() before you can download data"
Expand All @@ -245,7 +250,7 @@ def download(
def open(
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
fail_fast: bool = True,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[AbstractFileSystem]:
"""Returns a list of file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Expand All @@ -259,8 +264,13 @@ def open(
A list of "file pointers" to remote (i.e. s3 or https) files.
"""
provider = _normalize_location(provider)
pqdm_kwargs = {
"exception_behavior": "immediate",
"n_jobs": threads,
**(pqdm_kwargs or {}),
}
results = earthaccess.__store__.open(
granules=granules, provider=provider, fail_fast=fail_fast
granules=granules, provider=provider, pqdm_kwargs=pqdm_kwargs
)
return results

Expand Down
75 changes: 47 additions & 28 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,18 @@ def __repr__(self) -> str:
def _open_files(
url_mapping: Mapping[str, Union[DataGranule, None]],
fs: fsspec.AbstractFileSystem,
threads: int = 8,
) -> List[fsspec.spec.AbstractBufferedFile]:
threads: Optional[int] = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.AbstractFileSystem]:
def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFile:
url, granule = data
return EarthAccessFile(fs.open(url), granule) # type: ignore
urls, granule = data
return EarthAccessFile(fs.open(urls), granule) # type: ignore

fileset = pqdm(
url_mapping.items(),
multi_thread_open,
n_jobs=threads,
exception_behaviour=exception_behavior,
**pqdm_kwargs
)
return fileset

Expand Down Expand Up @@ -341,8 +342,9 @@ def open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.spec.AbstractBufferedFile]:
"""Returns a list of file-like objects that can be used to access files
"""Returns a list of fsspec file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Parameters:
Expand All @@ -354,15 +356,15 @@ def open(
A list of "file pointers" to remote (i.e. s3 or https) files.
"""
if len(granules):
return self._open(granules, provider, fail_fast=fail_fast)
return self._open(granules, provider,**pqdm_kwargs)
return []

@singledispatchmethod
def _open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
fail_fast: bool = True,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
raise NotImplementedError("granules should be a list of DataGranule or URLs")

Expand All @@ -372,6 +374,7 @@ def _open_granules(
granules: List[DataGranule],
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
fileset: List = []
total_size = round(sum([granule.size() for granule in granules]) / 1024, 2)
Expand All @@ -397,13 +400,23 @@ def _open_granules(
else:
access = "on_prem"
s3_fs = None

access = "direct"
provider = granules[0]["meta"]["provider-id"]
# if the data has its own S3 credentials endpoint, we will use it
endpoint = self._own_s3_credentials(granules[0]["umm"]["RelatedUrls"])
if endpoint is not None:
logger.info(f"using endpoint: {endpoint}")
s3_fs = self.get_s3_filesystem(endpoint=endpoint)
else:
logger.info(f"using provider: {provider}")
s3_fs = self.get_s3_filesystem(provider=provider)

url_mapping = _get_url_granule_mapping(granules, access)
if s3_fs is not None:
try:
fileset = _open_files(
url_mapping, fs=s3_fs, threads=threads, fail_fast=fail_fast
)
url_mapping, fs=s3_fs, threads=threads, **pqdm_kwargs
)
except Exception as e:
raise RuntimeError(
"An exception occurred while trying to access remote files on S3. "
Expand All @@ -412,13 +425,13 @@ def _open_granules(
) from e
else:
fileset = self._open_urls_https(
url_mapping, threads=threads, fail_fast=fail_fast
url_mapping, threads=threads, **pqdm_kwargs
)
return fileset
else:
url_mapping = _get_url_granule_mapping(granules, access="on_prem")
fileset = self._open_urls_https(
url_mapping, threads=threads, fail_fast=fail_fast
url_mapping, threads=threads, **pqdm_kwargs
)
return fileset

Expand All @@ -428,6 +441,7 @@ def _open_urls(
granules: List[str],
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
fileset: List = []

Expand All @@ -452,7 +466,7 @@ def _open_urls(
if s3_fs is not None:
try:
fileset = _open_files(
url_mapping, fs=s3_fs, threads=threads, fail_fast=fail_fast
url_mapping, fs=s3_fs, threads=threads, **pqdm_kwargs
)
except Exception as e:
raise RuntimeError(
Expand All @@ -472,7 +486,7 @@ def _open_urls(
raise ValueError(
"We cannot open S3 links when we are not in-region, try using HTTPS links"
)
fileset = self._open_urls_https(url_mapping, threads, fail_fast=fail_fast)
fileset = self._open_urls_https(url_mapping, threads,**pqdm_kwargs)
return fileset

def get(
Expand All @@ -481,7 +495,7 @@ def get(
local_path: Union[Path, str, None] = None,
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand Down Expand Up @@ -509,9 +523,15 @@ def get(
elif isinstance(local_path, str):
local_path = Path(local_path)

pqdm_kwargs = {
"exception_behavior": "immediate",
"n_jobs": threads,
**pqdm_kwargs,
}

if len(granules):
files = self._get(
granules, local_path, provider, threads, fail_fast=fail_fast
granules, local_path, provider, threads, **pqdm_kwargs
)
return files
else:
Expand All @@ -524,7 +544,7 @@ def _get(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand Down Expand Up @@ -554,7 +574,7 @@ def _get_urls(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
data_links = granules
downloaded_files: List = []
Expand All @@ -577,8 +597,8 @@ def _get_urls(
else:
# if we are not in AWS
return self._download_onprem_granules(
data_links, local_path, threads, fail_fast=fail_fast
)
data_links, local_path, threads, **pqdm_kwargs
)

@_get.register
def _get_granules(
Expand All @@ -587,7 +607,7 @@ def _get_granules(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
data_links: List = []
downloaded_files: List = []
Expand Down Expand Up @@ -629,7 +649,7 @@ def _get_granules(
# if the data are cloud-based, but we are not in AWS,
# it will be downloaded as if it was on prem
return self._download_onprem_granules(
data_links, local_path, threads, fail_fast=fail_fast
data_links, local_path, threads, **pqdm_kwargs
)

def _download_file(self, url: str, directory: Path) -> str:
Expand Down Expand Up @@ -668,7 +688,7 @@ def _download_file(self, url: str, directory: Path) -> str:
return str(path)

def _download_onprem_granules(
self, urls: List[str], directory: Path, threads: int = 8, fail_fast: bool = True
self, urls: List[str], directory: Path, threads: int = 8, pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
"""Downloads a list of URLS into the data directory.
Expand All @@ -694,14 +714,13 @@ def _download_onprem_granules(

arguments = [(url, directory) for url in urls]

exception_behavior = "immediate" if fail_fast else "deferred"

results = pqdm(
arguments,
self._download_file,
n_jobs=threads,
argument_type="args",
exception_behaviour=exception_behavior,
exception_behaviour=exception_behavior,
**pqdm_kwargs
)
return results

Expand All @@ -713,7 +732,7 @@ def _open_urls_https(
https_fs = self.get_fsspec_session()

try:
return _open_files(url_mapping, https_fs, threads)
return _open_files(url_mapping, https_fs, threads,**pqdm_kwargs)
except Exception:
logger.exception(
"An exception occurred while trying to access remote files via HTTPS"
Expand Down

0 comments on commit 0e76a23

Please sign in to comment.