From bb3b863564941dbd7cdfbd8f6fb6f2ec7eff47c1 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 13 Oct 2023 13:00:13 -0500 Subject: [PATCH] Minor earthaccess.download updates --- earthaccess/api.py | 7 +++++-- earthaccess/store.py | 12 +++++------- tests/integration/test_api.py | 11 +++++------ 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/earthaccess/api.py b/earthaccess/api.py index f0264454..19391264 100644 --- a/earthaccess/api.py +++ b/earthaccess/api.py @@ -9,6 +9,7 @@ from .search import CollectionQuery, DataCollections, DataGranules, GranuleQuery from .store import Store from .utils import _validation as validate +from .results import DataGranule def search_datasets( @@ -150,7 +151,7 @@ def login(strategy: str = "all", persist: bool = False) -> Auth: def download( - granules: Union[List[earthaccess.results.DataGranule], List[str]], + granules: Union[DataGranule, List[DataGranule], List[str]], local_path: Optional[str], provider: Optional[str] = None, threads: int = 8, @@ -161,7 +162,7 @@ def download( * If we run it outside AWS (us-west-2 region) and the dataset is cloud hostes we'll use HTTP links Parameters: - granules: a list of granules(DataGranule) instances or a list of granule links (HTTP) + granules: a granule, list of granules, or a list of granule links (HTTP) local_path: local directory to store the remote data granules provider: if we download a list of URLs we need to specify the provider. threads: parallel number of threads to use to download the files, adjust as necessary, default = 8 @@ -169,6 +170,8 @@ def download( Returns: List of downloaded files """ + if isinstance(granules, DataGranule): + granules = [granules] try: results = earthaccess.__store__.get(granules, local_path, provider, threads) except AttributeError as err: diff --git a/earthaccess/store.py b/earthaccess/store.py index 62421e9f..d4e7f181 100644 --- a/earthaccess/store.py +++ b/earthaccess/store.py @@ -509,16 +509,15 @@ def _get_urls( s3_fs = self.get_s3fs_session(provider=provider) # TODO: make this parallel or concurrent for file in data_links: - file_name = file.split("/")[-1] s3_fs.get(file, local_path) - print(f"Retrieved: {file} to {local_path}") + file_name = os.path.join(local_path, os.path.basename(file)) + print(f"Downloaded: {file_name}") downloaded_files.append(file_name) return downloaded_files else: # if we are not in AWS return self._download_onprem_granules(data_links, local_path, threads) - return None @_get.register def _get_granules( @@ -557,14 +556,13 @@ def _get_granules( # TODO: make this async for file in data_links: s3_fs.get(file, local_path) - file_name = file.split("/")[-1] - print(f"Retrieved: {file} to {local_path}") + file_name = os.path.join(local_path, os.path.basename(file)) + print(f"Downloaded: {file_name}") downloaded_files.append(file_name) return downloaded_files else: # if the data is cloud based bu we are not in AWS it will be downloaded as if it was on prem return self._download_onprem_granules(data_links, local_path, threads) - return None def _download_file(self, url: str, directory: str) -> str: """ @@ -598,7 +596,7 @@ def _download_file(self, url: str, directory: str) -> str: raise Exception else: print(f"File {local_filename} already downloaded") - return local_filename + return local_path def _download_onprem_granules( self, urls: List[str], directory: Optional[str] = None, threads: int = 8 diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index d4848d60..5736fa50 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -69,16 +69,15 @@ def test_granules_search_returns_valid_results(kwargs): assertions.assertTrue(len(results) <= 10) -def test_earthaccess_api_can_download_granules(): +@pytest.mark.parametrize("selection", [0, slice(None)]) +def test_earthaccess_api_can_download_granules(tmp_path, selection): results = earthaccess.search_data( count=2, short_name="ATL08", cloud_hosted=True, bounding_box=(-92.86, 16.26, -91.58, 16.97), ) - local_path = "./tests/integration/data/ATL08" - assertions.assertIsInstance(results, list) - assertions.assertTrue(len(results) <= 2) - files = earthaccess.download(results, local_path=local_path) + result = results[selection] + files = earthaccess.download(result, str(tmp_path)) assertions.assertIsInstance(files, list) - shutil.rmtree(local_path) + assert all(os.path.exists(f) for f in files)