Skip to content

Commit

Permalink
Minor earthaccess.download updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau committed Oct 13, 2023
1 parent 25a38ea commit bb3b863
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
7 changes: 5 additions & 2 deletions earthaccess/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .search import CollectionQuery, DataCollections, DataGranules, GranuleQuery
from .store import Store
from .utils import _validation as validate
from .results import DataGranule


def search_datasets(
Expand Down Expand Up @@ -150,7 +151,7 @@ def login(strategy: str = "all", persist: bool = False) -> Auth:


def download(
granules: Union[List[earthaccess.results.DataGranule], List[str]],
granules: Union[DataGranule, List[DataGranule], List[str]],
local_path: Optional[str],
provider: Optional[str] = None,
threads: int = 8,
Expand All @@ -161,14 +162,16 @@ def download(
* If we run it outside AWS (us-west-2 region) and the dataset is cloud hostes we'll use HTTP links
Parameters:
granules: a list of granules(DataGranule) instances or a list of granule links (HTTP)
granules: a granule, list of granules, or a list of granule links (HTTP)
local_path: local directory to store the remote data granules
provider: if we download a list of URLs we need to specify the provider.
threads: parallel number of threads to use to download the files, adjust as necessary, default = 8
Returns:
List of downloaded files
"""
if isinstance(granules, DataGranule):
granules = [granules]
try:
results = earthaccess.__store__.get(granules, local_path, provider, threads)
except AttributeError as err:
Expand Down
12 changes: 5 additions & 7 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,16 +509,15 @@ def _get_urls(
s3_fs = self.get_s3fs_session(provider=provider)
# TODO: make this parallel or concurrent
for file in data_links:
file_name = file.split("/")[-1]
s3_fs.get(file, local_path)
print(f"Retrieved: {file} to {local_path}")
file_name = os.path.join(local_path, os.path.basename(file))
print(f"Downloaded: {file_name}")
downloaded_files.append(file_name)
return downloaded_files

else:
# if we are not in AWS
return self._download_onprem_granules(data_links, local_path, threads)
return None

@_get.register
def _get_granules(
Expand Down Expand Up @@ -557,14 +556,13 @@ def _get_granules(
# TODO: make this async
for file in data_links:
s3_fs.get(file, local_path)
file_name = file.split("/")[-1]
print(f"Retrieved: {file} to {local_path}")
file_name = os.path.join(local_path, os.path.basename(file))
print(f"Downloaded: {file_name}")
downloaded_files.append(file_name)
return downloaded_files
else:
# if the data is cloud based bu we are not in AWS it will be downloaded as if it was on prem
return self._download_onprem_granules(data_links, local_path, threads)
return None

def _download_file(self, url: str, directory: str) -> str:
"""
Expand Down Expand Up @@ -598,7 +596,7 @@ def _download_file(self, url: str, directory: str) -> str:
raise Exception
else:
print(f"File {local_filename} already downloaded")
return local_filename
return local_path

def _download_onprem_granules(
self, urls: List[str], directory: Optional[str] = None, threads: int = 8
Expand Down
11 changes: 5 additions & 6 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,15 @@ def test_granules_search_returns_valid_results(kwargs):
assertions.assertTrue(len(results) <= 10)


def test_earthaccess_api_can_download_granules():
@pytest.mark.parametrize("selection", [0, slice(None)])
def test_earthaccess_api_can_download_granules(tmp_path, selection):
results = earthaccess.search_data(
count=2,
short_name="ATL08",
cloud_hosted=True,
bounding_box=(-92.86, 16.26, -91.58, 16.97),
)
local_path = "./tests/integration/data/ATL08"
assertions.assertIsInstance(results, list)
assertions.assertTrue(len(results) <= 2)
files = earthaccess.download(results, local_path=local_path)
result = results[selection]
files = earthaccess.download(result, str(tmp_path))
assertions.assertIsInstance(files, list)
shutil.rmtree(local_path)
assert all(os.path.exists(f) for f in files)

0 comments on commit bb3b863

Please sign in to comment.