diff --git a/earthaccess/api.py b/earthaccess/api.py index 6b758aa2..741f06cc 100644 --- a/earthaccess/api.py +++ b/earthaccess/api.py @@ -85,7 +85,7 @@ def search_datasets(count: int = -1, **kwargs: Any) -> List[DataCollection]: return query.get_all() -def search_data(count: int = -1, **kwargs: Any) -> List[DataGranule]: +def search_data(count: int = -1, **kwargs: Any) -> DataGranules: """Search dataset granules using NASA's CMR. [https://cmr.earthdata.nasa.gov/search/site/docs/search/api.html](https://cmr.earthdata.nasa.gov/search/site/docs/search/api.html) @@ -122,14 +122,13 @@ def search_data(count: int = -1, **kwargs: Any) -> List[DataGranule]: ``` """ if earthaccess.__auth__.authenticated: - query = DataGranules(earthaccess.__auth__).parameters(**kwargs) + results = DataGranules(earthaccess.__auth__).parameters(**kwargs) else: - query = DataGranules().parameters(**kwargs) - granules_found = query.hits() - logger.info(f"Granules found: {granules_found}") - if count > 0: - return query.get(count) - return query.get_all() + results = DataGranules().parameters(**kwargs) + + results.load(count) + logger.info(f"Granules found: {len(results)}") + return results def search_services(count: int = -1, **kwargs: Any) -> List[Any]: diff --git a/earthaccess/search.py b/earthaccess/search.py index 63bb0e8f..6f5992e5 100644 --- a/earthaccess/search.py +++ b/earthaccess/search.py @@ -960,13 +960,13 @@ def __iter__(self): def __len__(self): return len(self.granules) - def __getitem__(self, index: int): + def __getitem__(self, index: int) -> DataGranule: # FIXME: allow slicing # if isinstance(index, slice): # return DataGranules(self.jobs[index]) return self.granules[index] - def __setitem__(self, index: int, granule: DataGranule): + def __setitem__(self, index: int, granule: DataGranule) -> 'DataGranules': self.granules[index] = granule return self @@ -974,11 +974,11 @@ def __setitem__(self, index: int, granule: DataGranule): # def __contains__(self, job: Job): # return job in self.jobs - def __eq__(self, other: 'DataGranules'): + def __eq__(self, other: 'DataGranules') -> bool: # FIXME: compare query parameters too? what does it mean to be equal? - return self.graunles == other.granules + return self.granules == other.granules # TODO: display methods - def __repr__(self): + def __repr__(self) -> str: reprs = ", ".join([granule.__repr__() for granule in self.granules]) return f'DataGranules([{reprs}])'