diff --git a/earthaccess/auth.py b/earthaccess/auth.py
index a8335f73..2a820402 100644
--- a/earthaccess/auth.py
+++ b/earthaccess/auth.py
@@ -257,9 +257,7 @@ def _netrc(self) -> bool:
         try:
             my_netrc = Netrc()
         except FileNotFoundError as err:
-            raise FileNotFoundError(
-                f"No .netrc found in {os.path.expanduser('~')}"
-            ) from err
+            raise FileNotFoundError(f"No .netrc found in {Path.home()}") from err
         except NetrcParseError as err:
             raise NetrcParseError("Unable to parse .netrc") from err
         if my_netrc["urs.earthdata.nasa.gov"] is not None:
@@ -365,7 +363,7 @@ def _persist_user_credentials(self, username: str, password: str) -> bool:
         try:
             netrc_path = Path().home().joinpath(".netrc")
             netrc_path.touch(exist_ok=True)
-            os.chmod(netrc_path.absolute(), 0o600)
+            netrc_path.chmod(0o600)
         except Exception as e:
             print(e)
             return False
diff --git a/earthaccess/store.py b/earthaccess/store.py
index 4981fe50..15ae4ef2 100644
--- a/earthaccess/store.py
+++ b/earthaccess/store.py
@@ -1,5 +1,4 @@
 import datetime
-import os
 import shutil
 import traceback
 from functools import lru_cache
@@ -443,7 +442,7 @@ def _open_urls(
     def get(
         self,
         granules: Union[List[DataGranule], List[str]],
-        local_path: Optional[str] = None,
+        local_path: Optional[Path] = None,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -466,11 +465,10 @@ def get(
             List of downloaded files
         """
         if local_path is None:
-            local_path = os.path.join(
-                ".",
-                "data",
-                f"{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}",
-            )
+            today = datetime.datetime.today().strftime("%Y-%m-%d")
+            uuid = uuid4().hex[:6]
+            local_path = Path.cwd() / "data" / f"{today}-{uuid}"
+
         if len(granules):
             files = self._get(granules, local_path, provider, threads)
             return files
@@ -481,7 +479,7 @@ def get(
     def _get(
         self,
         granules: Union[List[DataGranule], List[str]],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -509,7 +507,7 @@ def _get(
     def _get_urls(
         self,
         granules: List[str],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -525,8 +523,8 @@ def _get_urls(
             s3_fs = self.get_s3fs_session(provider=provider)
             # TODO: make this parallel or concurrent
             for file in data_links:
-                s3_fs.get(file, local_path)
-                file_name = os.path.join(local_path, os.path.basename(file))
+                s3_fs.get(file, str(local_path))
+                file_name = local_path / Path(file).name
                 print(f"Downloaded: {file_name}")
                 downloaded_files.append(file_name)
             return downloaded_files
@@ -539,7 +537,7 @@ def _get_urls(
     def _get_granules(
         self,
         granules: List[DataGranule],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -571,8 +569,8 @@ def _get_granules(
             s3_fs = self.get_s3fs_session(provider=provider)
             # TODO: make this async
             for file in data_links:
-                s3_fs.get(file, local_path)
-                file_name = os.path.join(local_path, os.path.basename(file))
+                s3_fs.get(file, str(local_path))
+                file_name = local_path / Path(file).name
                 print(f"Downloaded: {file_name}")
                 downloaded_files.append(file_name)
             return downloaded_files
@@ -581,7 +579,7 @@ def _get_granules(
             # it will be downloaded as if it was on prem
             return self._download_onprem_granules(data_links, local_path, threads)
 
-    def _download_file(self, url: str, directory: str) -> str:
+    def _download_file(self, url: str, directory: Path) -> str:
         """Download a single file from an on-prem location, a DAAC data center.
 
         Parameters:
@@ -595,9 +593,8 @@ def _download_file(self, url: str, directory: str) -> str:
         if "opendap" in url and url.endswith(".html"):
             url = url.replace(".html", "")
         local_filename = url.split("/")[-1]
-        path = Path(directory) / Path(local_filename)
-        local_path = str(path)
-        if not os.path.exists(local_path):
+        path = directory / Path(local_filename)
+        if not path.exists():
             try:
                 session = self.auth.get_session()
                 with session.get(
@@ -606,7 +603,7 @@ def _download_file(self, url: str, directory: str) -> str:
                     allow_redirects=True,
                 ) as r:
                     r.raise_for_status()
-                    with open(local_path, "wb") as f:
+                    with open(path, "wb") as f:
                         # This is to cap memory usage for large files at 1MB per write to disk per thread
                         # https://docs.python-requests.org/en/latest/user/quickstart/#raw-response-content
                         shutil.copyfileobj(r.raw, f, length=1024 * 1024)
@@ -616,10 +613,10 @@
                 raise Exception
         else:
             print(f"File {local_filename} already downloaded")
-        return local_path
+        return str(path)
 
     def _download_onprem_granules(
-        self, urls: List[str], directory: str, threads: int = 8
+        self, urls: List[str], directory: Path, threads: int = 8
     ) -> List[Any]:
         """Downloads a list of URLS into the data directory.
@@ -638,8 +635,7 @@
             raise ValueError(
                 "We need to be logged into NASA EDL in order to download data granules"
             )
-        if not os.path.exists(directory):
-            os.makedirs(directory)
+        directory.mkdir(parents=True, exist_ok=True)
 
         arguments = [(url, directory) for url in urls]
         results = pqdm(
diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py
index 6fa1ccea..8fd45489 100644
--- a/tests/integration/test_api.py
+++ b/tests/integration/test_api.py
@@ -2,6 +2,7 @@
 import logging
 import os
 import unittest
+from pathlib import Path
 
 import earthaccess
 import pytest
@@ -84,7 +85,7 @@ def test_download(tmp_path, selection, use_url):
     result = results[selection]
     files = earthaccess.download(result, str(tmp_path))
     assertions.assertIsInstance(files, list)
-    assert all(os.path.exists(f) for f in files)
+    assert all(Path(f).exists() for f in files)
 
 
 def test_auth_environ():
diff --git a/tests/integration/test_auth.py b/tests/integration/test_auth.py
index d1bfae1e..a4879d12 100644
--- a/tests/integration/test_auth.py
+++ b/tests/integration/test_auth.py
@@ -30,7 +30,7 @@ def activate_netrc():
         f.write(
             f"machine urs.earthdata.nasa.gov login {username} password {password}\n"
         )
-        os.chmod(NETRC_PATH, 0o600)
+        NETRC_PATH.chmod(0o600)
 
 
 def delete_netrc():
diff --git a/tests/integration/test_cloud_download.py b/tests/integration/test_cloud_download.py
index 63a05b93..a9b9432c 100644
--- a/tests/integration/test_cloud_download.py
+++ b/tests/integration/test_cloud_download.py
@@ -166,4 +166,4 @@ def test_multi_file_granule(tmp_path):
     urls = granules[0].data_links()
     assert len(urls) > 1
     files = earthaccess.download(granules, str(tmp_path))
-    assert set(map(os.path.basename, urls)) == set(map(os.path.basename, files))
+    assert set([Path(f).name for f in urls]) == set([Path(f).name for f in files])
diff --git a/tests/integration/test_kerchunk.py b/tests/integration/test_kerchunk.py
index 58f93077..2e981cce 100644
--- a/tests/integration/test_kerchunk.py
+++ b/tests/integration/test_kerchunk.py
@@ -1,6 +1,7 @@
 import logging
 import os
 import unittest
+from pathlib import Path
 
 import earthaccess
 import pytest
@@ -32,14 +33,14 @@ def granules():
 
 
 @pytest.mark.parametrize("protocol", ["", "file://"])
 def test_consolidate_metadata_outfile(tmp_path, granules, protocol):
     outfile = f"{protocol}{tmp_path / 'metadata.json'}"
-    assert not os.path.exists(outfile)
+    assert not Path(outfile).exists()
     result = earthaccess.consolidate_metadata(
         granules,
         outfile=outfile,
         access="indirect",
         kerchunk_options={"concat_dims": "Time"},
     )
-    assert os.path.exists(strip_protocol(outfile))
+    assert Path(strip_protocol(outfile)).exists()
     assert result == outfile
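For review context, the whole diff is a mechanical `os`/`os.path` → `pathlib` migration. Below is a minimal, self-contained sketch of the equivalences it applies; the URL and directory names are illustrative placeholders, not earthaccess API calls:

```python
from pathlib import Path

# Placeholder values for illustration only (not from earthaccess).
url = "https://example.com/data/granule.nc4"

# os.path.expanduser("~")             -> Path.home()
home = Path.home()

# os.path.join(".", "data", name)     -> Path.cwd() / "data" / name
local_path = Path.cwd() / "data" / "2024-01-01-abc123"

# os.path.exists(p) + os.makedirs(p)  -> p.mkdir(parents=True, exist_ok=True)
local_path.mkdir(parents=True, exist_ok=True)

# os.path.basename(url)               -> Path(url).name
file_name = local_path / Path(url).name

# os.path.exists(p)                   -> p.exists()
print(file_name, file_name.exists())

# os.chmod(p, 0o600)                  -> p.chmod(0o600)
file_name.touch(exist_ok=True)
file_name.chmod(0o600)
```

Note that the diff deliberately keeps two string seams: `local_path` is stringified at the `s3fs` boundary (`s3_fs.get(file, str(local_path))`), and `_download_file` still returns `str(path)`, so the public `List[str]` return type is unchanged.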