Improvement/use pathlib #459

Merged: 13 commits, Feb 17, 2024
6 changes: 2 additions & 4 deletions earthaccess/auth.py
@@ -257,9 +257,7 @@ def _netrc(self) -> bool:
         try:
             my_netrc = Netrc()
         except FileNotFoundError as err:
-            raise FileNotFoundError(
-                f"No .netrc found in {os.path.expanduser('~')}"
-            ) from err
+            raise FileNotFoundError(f"No .netrc found in {Path.home()}") from err
         except NetrcParseError as err:
             raise NetrcParseError("Unable to parse .netrc") from err
         if my_netrc["urs.earthdata.nasa.gov"] is not None:
@@ -365,7 +363,7 @@ def _persist_user_credentials(self, username: str, password: str) -> bool:
         try:
             netrc_path = Path().home().joinpath(".netrc")
             netrc_path.touch(exist_ok=True)
-            os.chmod(netrc_path.absolute(), 0o600)
+            netrc_path.chmod(0o600)
         except Exception as e:
             print(e)
             return False
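A minimal standalone sketch of the pathlib idiom auth.py adopts here, using nothing beyond the standard library (this is an illustration, not earthaccess code itself):

    from pathlib import Path

    netrc_path = Path.home() / ".netrc"  # replaces os.path.expanduser("~") + string joins
    netrc_path.touch(exist_ok=True)      # create the file only if it is missing
    netrc_path.chmod(0o600)              # replaces os.chmod(netrc_path.absolute(), 0o600)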
42 changes: 19 additions & 23 deletions earthaccess/store.py
@@ -1,5 +1,4 @@
 import datetime
-import os
 import shutil
 import traceback
 from functools import lru_cache
@@ -443,7 +442,7 @@ def _open_urls(
     def get(
         self,
         granules: Union[List[DataGranule], List[str]],
-        local_path: Optional[str] = None,
+        local_path: Optional[Path] = None,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -466,11 +465,10 @@ def get(
             List of downloaded files
         """
         if local_path is None:
-            local_path = os.path.join(
-                ".",
-                "data",
-                f"{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}",
-            )
+            today = datetime.datetime.today().strftime("%Y-%m-%d")
+            uuid = uuid4().hex[:6]
+            local_path = Path.cwd() / "data" / f"{today}-{uuid}"
+
         if len(granules):
             files = self._get(granules, local_path, provider, threads)
             return files
@@ -481,7 +479,7 @@ def get(
     def _get(
         self,
         granules: Union[List[DataGranule], List[str]],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -509,7 +507,7 @@ def _get(
     def _get_urls(
         self,
         granules: List[str],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -525,8 +523,8 @@ def _get_urls(
         s3_fs = self.get_s3fs_session(provider=provider)
         # TODO: make this parallel or concurrent
         for file in data_links:
-            s3_fs.get(file, local_path)
-            file_name = os.path.join(local_path, os.path.basename(file))
+            s3_fs.get(file, str(local_path))
+            file_name = local_path / Path(file).name
             print(f"Downloaded: {file_name}")
             downloaded_files.append(file_name)
         return downloaded_files
@@ -539,7 +537,7 @@ def _get_urls(
     def _get_granules(
         self,
         granules: List[DataGranule],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -571,8 +569,8 @@ def _get_granules(
             s3_fs = self.get_s3fs_session(provider=provider)
             # TODO: make this async
             for file in data_links:
-                s3_fs.get(file, local_path)
-                file_name = os.path.join(local_path, os.path.basename(file))
+                s3_fs.get(file, str(local_path))
+                file_name = local_path / Path(file).name
                 print(f"Downloaded: {file_name}")
                 downloaded_files.append(file_name)
             return downloaded_files
@@ -581,7 +579,7 @@ def _get_granules(
             # it will be downloaded as if it was on prem
             return self._download_onprem_granules(data_links, local_path, threads)

-    def _download_file(self, url: str, directory: str) -> str:
+    def _download_file(self, url: str, directory: Path) -> str:
         """Download a single file from an on-prem location, a DAAC data center.

         Parameters:
@@ -595,9 +593,8 @@ def _download_file(self, url: str, directory: str) -> str:
         if "opendap" in url and url.endswith(".html"):
             url = url.replace(".html", "")
         local_filename = url.split("/")[-1]
-        path = Path(directory) / Path(local_filename)
-        local_path = str(path)
-        if not os.path.exists(local_path):
+        path = directory / Path(local_filename)
+        if not path.exists():
             try:
                 session = self.auth.get_session()
                 with session.get(
@@ -606,7 +603,7 @@ def _download_file(self, url: str, directory: str) -> str:
                     allow_redirects=True,
                 ) as r:
                     r.raise_for_status()
-                    with open(local_path, "wb") as f:
+                    with open(path, "wb") as f:
                         # This is to cap memory usage for large files at 1MB per write to disk per thread
                         # https://docs.python-requests.org/en/latest/user/quickstart/#raw-response-content
                         shutil.copyfileobj(r.raw, f, length=1024 * 1024)
@@ -616,10 +613,10 @@ def _download_file(self, url: str, directory: str) -> str:
                 raise Exception
         else:
             print(f"File {local_filename} already downloaded")
-        return local_path
+        return str(path)

     def _download_onprem_granules(
-        self, urls: List[str], directory: str, threads: int = 8
+        self, urls: List[str], directory: Path, threads: int = 8
     ) -> List[Any]:
         """Downloads a list of URLS into the data directory.
@@ -638,8 +635,7 @@ def _download_onprem_granules(
             raise ValueError(
                 "We need to be logged into NASA EDL in order to download data granules"
             )
-        if not os.path.exists(directory):
-            os.makedirs(directory)
+        directory.mkdir(parents=True, exist_ok=True)

         arguments = [(url, directory) for url in urls]
         results = pqdm(
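Taken together, the store.py changes swap four os/os.path patterns for pathlib equivalents. A minimal standalone sketch with illustrative names (the URL below is made up):

    import datetime
    from pathlib import Path
    from uuid import uuid4

    # Default download directory, e.g. ./data/2024-02-17-3f9a1c,
    # built with the / operator instead of os.path.join.
    today = datetime.datetime.today().strftime("%Y-%m-%d")
    local_path = Path.cwd() / "data" / f"{today}-{uuid4().hex[:6]}"

    # One call replaces `if not os.path.exists(...): os.makedirs(...)`.
    local_path.mkdir(parents=True, exist_ok=True)

    # Path(url).name keeps only the last segment, so it stands in for
    # os.path.basename even on URL strings.
    url = "https://data.example.gov/granules/scene.nc"  # illustrative
    file_name = local_path / Path(url).name
    print(file_name)

Note that the diff still passes str(local_path) to s3_fs.get, converting back to a plain string at the fsspec boundary.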
3 changes: 2 additions & 1 deletion tests/integration/test_api.py
@@ -2,6 +2,7 @@
 import logging
 import os
 import unittest
+from pathlib import Path

 import earthaccess
 import pytest
@@ -84,7 +85,7 @@ def test_download(tmp_path, selection, use_url):
     result = results[selection]
     files = earthaccess.download(result, str(tmp_path))
     assertions.assertIsInstance(files, list)
-    assert all(os.path.exists(f) for f in files)
+    assert all(Path(f).exists() for f in files)


 def test_auth_environ():
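The test relies on pytest's tmp_path fixture, which already yields a pathlib.Path; only the earthaccess.download call still takes a string. A self-contained sketch of the assertion pattern (the touch loop stands in for a real download):

    from pathlib import Path

    def test_download_sketch(tmp_path: Path) -> None:
        files = [str(tmp_path / "granule.nc")]  # stand-in for earthaccess.download(...)
        for f in files:
            Path(f).touch()                     # simulate a downloaded file
        assert all(Path(f).exists() for f in files)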
2 changes: 1 addition & 1 deletion tests/integration/test_auth.py
@@ -30,7 +30,7 @@ def activate_netrc():
         f.write(
             f"machine urs.earthdata.nasa.gov login {username} password {password}\n"
         )
-        os.chmod(NETRC_PATH, 0o600)
+        NETRC_PATH.chmod(0o600)


 def delete_netrc():
2 changes: 1 addition & 1 deletion tests/integration/test_cloud_download.py
@@ -166,4 +166,4 @@ def test_multi_file_granule(tmp_path):
     urls = granules[0].data_links()
     assert len(urls) > 1
     files = earthaccess.download(granules, str(tmp_path))
-    assert set(map(os.path.basename, urls)) == set(map(os.path.basename, files))
+    assert set([Path(f).name for f in urls]) == set([Path(f).name for f in files])
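Path(...).name replaces os.path.basename on both sides of this assertion; because Path keeps only the final segment, it works on URL strings as well as local paths. A sketch with made-up values (a set comprehension would be slightly more idiomatic than set([...])):

    from pathlib import Path

    urls = ["https://example.com/d/granule1.nc", "https://example.com/d/granule2.nc"]
    files = ["/tmp/downloads/granule1.nc", "/tmp/downloads/granule2.nc"]

    # Compare only the file names, ignoring host and directory parts.
    assert {Path(u).name for u in urls} == {Path(f).name for f in files}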
5 changes: 3 additions & 2 deletions tests/integration/test_kerchunk.py
@@ -1,6 +1,7 @@
 import logging
 import os
 import unittest
+from pathlib import Path

 import earthaccess
 import pytest
@@ -32,14 +33,14 @@ def granules():
 @pytest.mark.parametrize("protocol", ["", "file://"])
 def test_consolidate_metadata_outfile(tmp_path, granules, protocol):
     outfile = f"{protocol}{tmp_path / 'metadata.json'}"
-    assert not os.path.exists(outfile)
+    assert not Path(outfile).exists()
     result = earthaccess.consolidate_metadata(
         granules,
         outfile=outfile,
         access="indirect",
         kerchunk_options={"concat_dims": "Time"},
     )
-    assert os.path.exists(strip_protocol(outfile))
+    assert Path(strip_protocol(outfile)).exists()
     assert result == outfile

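strip_protocol here is a helper defined or imported elsewhere in the test module; assuming it merely drops a leading scheme such as "file://" (an assumption, not the test's actual code), the pathlib check behaves like this sketch:

    from pathlib import Path

    def strip_protocol(url: str) -> str:
        # Assumed behavior: drop a leading scheme like "file://".
        return url.split("://", 1)[-1]

    # Path can then test existence of the local file behind the URI.
    print(Path(strip_protocol("file:///tmp/metadata.json")))  # /tmp/metadata.json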