Merge pull request #459 from kvenkman/improvement/use-pathlib
Improvement/use pathlib
jhkennedy authored Feb 17, 2024
2 parents 6b86f49 + 453e7f9 commit 8b00ea3
Showing 6 changed files with 28 additions and 32 deletions.
6 changes: 2 additions & 4 deletions earthaccess/auth.py

@@ -257,9 +257,7 @@ def _netrc(self) -> bool:
         try:
             my_netrc = Netrc()
         except FileNotFoundError as err:
-            raise FileNotFoundError(
-                f"No .netrc found in {os.path.expanduser('~')}"
-            ) from err
+            raise FileNotFoundError(f"No .netrc found in {Path.home()}") from err
         except NetrcParseError as err:
             raise NetrcParseError("Unable to parse .netrc") from err
         if my_netrc["urs.earthdata.nasa.gov"] is not None:
@@ -365,7 +363,7 @@ def _persist_user_credentials(self, username: str, password: str) -> bool:
         try:
             netrc_path = Path().home().joinpath(".netrc")
             netrc_path.touch(exist_ok=True)
-            os.chmod(netrc_path.absolute(), 0o600)
+            netrc_path.chmod(0o600)
         except Exception as e:
             print(e)
             return False
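For reference, a minimal sketch of the two pathlib idioms this file adopts (illustrative, not a drop-in for the class methods):

```python
from pathlib import Path

# Path.home() replaces os.path.expanduser("~") and returns a Path object
netrc_path = Path.home() / ".netrc"

# Create the file if it is missing, then restrict it to owner
# read/write (0o600), the permissions .netrc consumers expect;
# Path.chmod replaces os.chmod(path, 0o600)
netrc_path.touch(exist_ok=True)
netrc_path.chmod(0o600)
```

Because the path is already absolute, the new `netrc_path.chmod(0o600)` also drops the redundant `.absolute()` call the old `os.chmod` line carried.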
42 changes: 19 additions & 23 deletions earthaccess/store.py

@@ -1,5 +1,4 @@
 import datetime
-import os
 import shutil
 import traceback
 from functools import lru_cache
@@ -443,7 +442,7 @@ def _open_urls(
     def get(
         self,
         granules: Union[List[DataGranule], List[str]],
-        local_path: Optional[str] = None,
+        local_path: Optional[Path] = None,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -466,11 +465,10 @@ def get(
             List of downloaded files
         """
         if local_path is None:
-            local_path = os.path.join(
-                ".",
-                "data",
-                f"{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}",
-            )
+            today = datetime.datetime.today().strftime("%Y-%m-%d")
+            uuid = uuid4().hex[:6]
+            local_path = Path.cwd() / "data" / f"{today}-{uuid}"
+
         if len(granules):
             files = self._get(granules, local_path, provider, threads)
             return files
@@ -481,7 +479,7 @@ def get(
     def _get(
         self,
         granules: Union[List[DataGranule], List[str]],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -509,7 +507,7 @@ def _get(
     def _get_urls(
         self,
         granules: List[str],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -525,8 +523,8 @@ def _get_urls(
             s3_fs = self.get_s3fs_session(provider=provider)
             # TODO: make this parallel or concurrent
             for file in data_links:
-                s3_fs.get(file, local_path)
-                file_name = os.path.join(local_path, os.path.basename(file))
+                s3_fs.get(file, str(local_path))
+                file_name = local_path / Path(file).name
                 print(f"Downloaded: {file_name}")
                 downloaded_files.append(file_name)
             return downloaded_files
@@ -539,7 +537,7 @@ def _get_urls(
     def _get_granules(
         self,
         granules: List[DataGranule],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -571,8 +569,8 @@ def _get_granules(
             s3_fs = self.get_s3fs_session(provider=provider)
             # TODO: make this async
             for file in data_links:
-                s3_fs.get(file, local_path)
-                file_name = os.path.join(local_path, os.path.basename(file))
+                s3_fs.get(file, str(local_path))
+                file_name = local_path / Path(file).name
                 print(f"Downloaded: {file_name}")
                 downloaded_files.append(file_name)
             return downloaded_files
@@ -581,7 +579,7 @@
             # it will be downloaded as if it was on prem
             return self._download_onprem_granules(data_links, local_path, threads)

-    def _download_file(self, url: str, directory: str) -> str:
+    def _download_file(self, url: str, directory: Path) -> str:
         """Download a single file from an on-prem location, a DAAC data center.

         Parameters:
@@ -595,9 +593,8 @@ def _download_file(self, url: str, directory: str) -> str:
         if "opendap" in url and url.endswith(".html"):
            url = url.replace(".html", "")
         local_filename = url.split("/")[-1]
-        path = Path(directory) / Path(local_filename)
-        local_path = str(path)
-        if not os.path.exists(local_path):
+        path = directory / Path(local_filename)
+        if not path.exists():
             try:
                 session = self.auth.get_session()
                 with session.get(
@@ -606,7 +603,7 @@ def _download_file(self, url: str, directory: str) -> str:
                     allow_redirects=True,
                 ) as r:
                     r.raise_for_status()
-                    with open(local_path, "wb") as f:
+                    with open(path, "wb") as f:
                         # This is to cap memory usage for large files at 1MB per write to disk per thread
                         # https://docs.python-requests.org/en/latest/user/quickstart/#raw-response-content
                         shutil.copyfileobj(r.raw, f, length=1024 * 1024)
@@ -616,10 +613,10 @@ def _download_file(self, url: str, directory: str) -> str:
                 raise Exception
         else:
             print(f"File {local_filename} already downloaded")
-        return local_path
+        return str(path)

     def _download_onprem_granules(
-        self, urls: List[str], directory: str, threads: int = 8
+        self, urls: List[str], directory: Path, threads: int = 8
     ) -> List[Any]:
         """Downloads a list of URLS into the data directory.
@@ -638,8 +635,7 @@ def _download_onprem_granules(
             raise ValueError(
                 "We need to be logged into NASA EDL in order to download data granules"
             )
-        if not os.path.exists(directory):
-            os.makedirs(directory)
+        directory.mkdir(parents=True, exist_ok=True)

         arguments = [(url, directory) for url in urls]
         results = pqdm(
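Taken together, the store.py changes reduce to a handful of pathlib idioms. A minimal standalone sketch of those idioms, using illustrative names and an illustrative URL rather than anything from the PR:

```python
import datetime
from pathlib import Path
from uuid import uuid4

# The / operator replaces os.path.join; Path.cwd() replaces the "." anchor.
today = datetime.datetime.today().strftime("%Y-%m-%d")
local_path = Path.cwd() / "data" / f"{today}-{uuid4().hex[:6]}"

# One call replaces the "if not os.path.exists(...): os.makedirs(...)" pair.
local_path.mkdir(parents=True, exist_ok=True)

# Path(...).name replaces os.path.basename, even on a URL-like string.
url = "https://example.com/granules/file.nc"  # illustrative URL
target = local_path / Path(url).name

# Path objects work directly with exists() and open().
if not target.exists():
    with open(target, "wb") as f:
        f.write(b"")  # placeholder for the streamed download
```

Note the seams where strings survive: the diff passes `str(local_path)` to `s3_fs.get`, converting the `Path` explicitly at the fsspec boundary, and `_download_file` still returns `str(path)` so the public return type remains a list of strings.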
3 changes: 2 additions & 1 deletion tests/integration/test_api.py

@@ -2,6 +2,7 @@
 import logging
 import os
 import unittest
+from pathlib import Path

 import earthaccess
 import pytest
@@ -84,7 +85,7 @@ def test_download(tmp_path, selection, use_url):
     result = results[selection]
     files = earthaccess.download(result, str(tmp_path))
     assertions.assertIsInstance(files, list)
-    assert all(os.path.exists(f) for f in files)
+    assert all(Path(f).exists() for f in files)


 def test_auth_environ():
2 changes: 1 addition & 1 deletion tests/integration/test_auth.py

@@ -30,7 +30,7 @@ def activate_netrc():
         f.write(
             f"machine urs.earthdata.nasa.gov login {username} password {password}\n"
         )
-        os.chmod(NETRC_PATH, 0o600)
+        NETRC_PATH.chmod(0o600)


 def delete_netrc():
2 changes: 1 addition & 1 deletion tests/integration/test_cloud_download.py

@@ -166,4 +166,4 @@ def test_multi_file_granule(tmp_path):
     urls = granules[0].data_links()
     assert len(urls) > 1
     files = earthaccess.download(granules, str(tmp_path))
-    assert set(map(os.path.basename, urls)) == set(map(os.path.basename, files))
+    assert set([Path(f).name for f in urls]) == set([Path(f).name for f in files])
5 changes: 3 additions & 2 deletions tests/integration/test_kerchunk.py

@@ -1,6 +1,7 @@
 import logging
 import os
 import unittest
+from pathlib import Path

 import earthaccess
 import pytest
@@ -32,14 +33,14 @@ def granules():
 @pytest.mark.parametrize("protocol", ["", "file://"])
 def test_consolidate_metadata_outfile(tmp_path, granules, protocol):
     outfile = f"{protocol}{tmp_path / 'metadata.json'}"
-    assert not os.path.exists(outfile)
+    assert not Path(outfile).exists()
     result = earthaccess.consolidate_metadata(
         granules,
         outfile=outfile,
         access="indirect",
         kerchunk_options={"concat_dims": "Time"},
     )
-    assert os.path.exists(strip_protocol(outfile))
+    assert Path(strip_protocol(outfile)).exists()
     assert result == outfile

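The test-file changes are mechanical os.path-to-pathlib substitutions; a quick equivalence sketch, with an illustrative path string:

```python
import os
from pathlib import Path

p = "some/dir/file.nc"  # illustrative path string

# existence check
assert Path(p).exists() == os.path.exists(p)

# basename extraction
assert Path(p).name == os.path.basename(p)

# Path.chmod(0o600) likewise mirrors os.chmod(p, 0o600)
```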
