update typehints and code for _get methods to use Path consistently
kvenkman committed Feb 14, 2024
1 parent 95ce27f commit c68e634
Showing 1 changed file with 11 additions and 12 deletions.
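The gist of the change: the private _get* methods now take local_path as a
pathlib.Path instead of a str, so path joins use the / operator directly and
the value is converted back with str() only at boundaries that still expect
plain strings (the s3fs calls and _download_onprem_granules). A minimal sketch
of the before/after idiom, with hypothetical values for illustration only:

    from pathlib import Path

    # Hypothetical values, not from the commit.
    local_path = Path("/tmp/downloads")
    file = "s3://example-bucket/granule_001.nc"

    # Before: local_path was a str, so each join had to re-wrap it in Path().
    old_style = Path("/tmp/downloads") / Path(file).name

    # After: local_path is already a Path, so / applies directly.
    new_style = local_path / Path(file).name

    assert old_style == new_style
    # Anything that still wants a plain string gets str(local_path) at the call site.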
23 changes: 11 additions & 12 deletions earthaccess/store.py
@@ -480,7 +480,7 @@ def get(
     def _get(
         self,
         granules: Union[List[DataGranule], List[str]],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -508,7 +508,7 @@ def _get
     def _get_urls(
         self,
         granules: List[str],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -524,8 +524,8 @@ def _get_urls
             s3_fs = self.get_s3fs_session(provider=provider)
             # TODO: make this parallel or concurrent
             for file in data_links:
-                s3_fs.get(file, local_path)
-                file_name = Path(local_path) / Path(file).name
+                s3_fs.get(file, str(local_path))
+                file_name = local_path / Path(file).name
                 print(f"Downloaded: {file_name}")
                 downloaded_files.append(file_name)
             return downloaded_files
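Note on the str(local_path) conversion above: fsspec-style filesystems such as
s3fs take plain string paths, so the Path is stringified at the s3_fs.get()
call while the / join happens on the Path itself. A rough standalone sketch of
the same loop, assuming anonymous access to a hypothetical public bucket (the
real code builds an authenticated filesystem via get_s3fs_session):

    import s3fs
    from pathlib import Path

    # Hypothetical links; the real ones come from CMR granule metadata.
    data_links = ["s3://example-bucket/g1.nc", "s3://example-bucket/g2.nc"]
    local_path = Path("/tmp/granules")
    local_path.mkdir(parents=True, exist_ok=True)

    s3_fs = s3fs.S3FileSystem(anon=True)  # assumption: anonymous access suffices here
    downloaded_files = []
    for file in data_links:
        s3_fs.get(file, str(local_path))  # fsspec expects str, not Path
        downloaded_files.append(local_path / Path(file).name)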
@@ -538,7 +538,7 @@ def _get_urls
     def _get_granules(
         self,
         granules: List[DataGranule],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:
@@ -570,15 +570,15 @@ def _get_granules
             s3_fs = self.get_s3fs_session(provider=provider)
             # TODO: make this async
             for file in data_links:
-                s3_fs.get(file, local_path)
-                file_name = Path(local_path) / Path(file).name
+                s3_fs.get(file, str(local_path))
+                file_name = local_path / Path(file).name
                 print(f"Downloaded: {file_name}")
                 downloaded_files.append(file_name)
             return downloaded_files
         else:
             # if the data are cloud-based, but we are not in AWS,
             # it will be downloaded as if it was on prem
-            return self._download_onprem_granules(data_links, local_path, threads)
+            return self._download_onprem_granules(data_links, str(local_path), threads)

     def _download_file(self, url: str, directory: str) -> str:
         """Download a single file from an on-prem location, a DAAC data center.
@@ -595,8 +595,7 @@ def _download_file
         url = url.replace(".html", "")
         local_filename = url.split("/")[-1]
         path = Path(directory) / Path(local_filename)
-        local_path = str(path)
-        if not Path(local_path).exists():
+        if not path.exists():
             try:
                 session = self.auth.get_session()
                 with session.get(
@@ -605,7 +604,7 @@ def _download_file
                     allow_redirects=True,
                 ) as r:
                     r.raise_for_status()
-                    with open(local_path, "wb") as f:
+                    with open(path, "wb") as f:
                         # This is to cap memory usage for large files at 1MB per write to disk per thread
                         # https://docs.python-requests.org/en/latest/user/quickstart/#raw-response-content
                         shutil.copyfileobj(r.raw, f, length=1024 * 1024)
@@ -615,7 +614,7 @@ def _download_file
                 raise Exception
         else:
             print(f"File {local_filename} already downloaded")
-        return local_path
+        return str(path)

     def _download_onprem_granules(
         self, urls: List[str], directory: str, threads: int = 8
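For context, the _download_file body touched by the last three hunks follows
the standard requests streaming pattern: open the response with stream=True
and copy r.raw to disk in 1 MB chunks so large granules never sit fully in
memory. A self-contained sketch of that pattern, using plain requests.get with
a hypothetical URL (the real method downloads through an authenticated session
from self.auth.get_session()):

    import shutil
    from pathlib import Path

    import requests

    def download_file(url: str, directory: str) -> str:
        path = Path(directory) / url.split("/")[-1]
        if not path.exists():
            with requests.get(url, stream=True, allow_redirects=True) as r:
                r.raise_for_status()
                with open(path, "wb") as f:
                    # Cap memory at 1 MB per write to disk, as in the original.
                    shutil.copyfileobj(r.raw, f, length=1024 * 1024)
        return str(path)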
