From b8d6326865d119a1d026d8fc54155d16bfffa4dd Mon Sep 17 00:00:00 2001 From: Chuck Daniels Date: Mon, 8 Apr 2024 11:15:30 -0400 Subject: [PATCH] Fix CMR-related type hints There were a number of type hints in `search.py` and `api.py` related to CMR queries that were incorrect. These were fixed. In addition, there were a number of other static type errors that were masked because of ignored `cmr` imports. Added type stubs for `python_cmr` library to unmask and address these additional type errors. Limited static type changes as much as possible to only functions and methods dealing with CMR queries and results to keep this PR manageable. Fixes #508 --- .gitignore | 75 ++++++- CHANGELOG.md | 4 + earthaccess/api.py | 23 +-- earthaccess/search.py | 428 ++++++++++++++++++++++++++++++---------- earthaccess/typing_.py | 49 +++++ pyproject.toml | 28 ++- scripts/lint.sh | 5 +- stubs/cmr/__init__.pyi | 10 + stubs/cmr/queries.pyi | 112 +++++++++++ tests/__init__.py | 0 tests/unit/test_auth.py | 8 +- 11 files changed, 611 insertions(+), 131 deletions(-) create mode 100644 earthaccess/typing_.py create mode 100644 stubs/cmr/__init__.pyi create mode 100644 stubs/cmr/queries.pyi create mode 100644 tests/__init__.py diff --git a/.gitignore b/.gitignore index ea6fc19e..ab1f58ce 100644 --- a/.gitignore +++ b/.gitignore @@ -20,7 +20,78 @@ docs/tutorials/data tests/integration/data .ruff_cache -# OS X +notebooks/data/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Created by https://www.toptal.com/developers/gitignore/api/macos +# Edit at https://www.toptal.com/developers/gitignore?templates=macos + +### macOS ### +# General .DS_Store +.AppleDouble +.LSOverride -notebooks/data/ +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### macOS Patch ### +# iCloud generated files +*.icloud + +# End of https://www.toptal.com/developers/gitignore/api/macos + +# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode +# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode + +### VisualStudioCode ### +.vscode/ + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode + +# Created by https://www.toptal.com/developers/gitignore/api/direnv +# Edit at https://www.toptal.com/developers/gitignore?templates=direnv + +### direnv ### +.direnv +.envrc + +# End of https://www.toptal.com/developers/gitignore/api/direnv diff --git a/CHANGELOG.md b/CHANGELOG.md index bbe66ba8..0cbf0326 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## [Unreleased] +* Enhancements: + * Corrected and enhanced static type hints for functions and methods that make + CMR queries or handle CMR query results (#508) + ## [v0.9.0] 2024-02-28 * Bug fixes: diff --git a/earthaccess/api.py b/earthaccess/api.py index a7d35fb0..ab2d7b0a 100644 --- a/earthaccess/api.py +++ b/earthaccess/api.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, List, Optional, Type, Union - import requests import s3fs from fsspec import AbstractFileSystem @@ -7,9 +5,10 @@ import earthaccess from .auth import Auth -from .results import DataGranule +from .results import DataCollection, DataGranule from .search import CollectionQuery, DataCollections, DataGranules, GranuleQuery from .store import Store +from .typing_ import Any, Dict, List, Optional, Union from .utils import _validation as validate @@ -28,9 +27,7 @@ def _normalize_location(location: Optional[str]) -> Optional[str]: return location -def search_datasets( - count: int = -1, **kwargs: Any -) -> List[earthaccess.results.DataCollection]: +def search_datasets(count: int = -1, **kwargs: Any) -> List[DataCollection]: """Search datasets using NASA's CMR. [https://cmr.earthdata.nasa.gov/search/site/docs/search/api.html](https://cmr.earthdata.nasa.gov/search/site/docs/search/api.html) @@ -78,9 +75,7 @@ def search_datasets( return query.get_all() -def search_data( - count: int = -1, **kwargs: Any -) -> List[earthaccess.results.DataGranule]: +def search_data(count: int = -1, **kwargs: Any) -> List[DataGranule]: """Search dataset granules using NASA's CMR. [https://cmr.earthdata.nasa.gov/search/site/docs/search/api.html](https://cmr.earthdata.nasa.gov/search/site/docs/search/api.html) @@ -194,7 +189,7 @@ def download( def open( - granules: Union[List[str], List[earthaccess.results.DataGranule]], + granules: Union[List[str], List[DataGranule]], provider: Optional[str] = None, ) -> List[AbstractFileSystem]: """Returns a list of fsspec file-like objects that can be used to access files @@ -216,7 +211,7 @@ def open( def get_s3_credentials( daac: Optional[str] = None, provider: Optional[str] = None, - results: Optional[List[earthaccess.results.DataGranule]] = None, + results: Optional[List[DataGranule]] = None, ) -> Dict[str, Any]: """Returns temporary (1 hour) credentials for direct access to NASA S3 buckets. We can use the daac name, the provider, or a list of results from earthaccess.search_data(). @@ -239,7 +234,7 @@ def get_s3_credentials( return earthaccess.__auth__.get_s3_credentials(daac=daac, provider=provider) -def collection_query() -> Type[CollectionQuery]: +def collection_query() -> CollectionQuery: """Returns a query builder instance for NASA collections (datasets). Returns: @@ -252,7 +247,7 @@ def collection_query() -> Type[CollectionQuery]: return query_builder -def granule_query() -> Type[GranuleQuery]: +def granule_query() -> GranuleQuery: """Returns a query builder instance for data granules Returns: @@ -311,7 +306,7 @@ def get_requests_https_session() -> requests.Session: def get_s3fs_session( daac: Optional[str] = None, provider: Optional[str] = None, - results: Optional[earthaccess.results.DataGranule] = None, + results: Optional[DataGranule] = None, ) -> s3fs.S3FileSystem: """Returns a fsspec s3fs file session for direct access when we are in us-west-2. diff --git a/earthaccess/search.py b/earthaccess/search.py index 4ac8cb61..00dcbcc7 100644 --- a/earthaccess/search.py +++ b/earthaccess/search.py @@ -1,14 +1,31 @@ import datetime as dt from inspect import getmembers, ismethod -from typing import Any, Dict, List, Optional, Tuple, Type, Union -import dateutil.parser as parser # type: ignore -from cmr import CollectionQuery, GranuleQuery # type: ignore +import dateutil.parser as parser from requests import exceptions, session +from cmr import CollectionQuery, GranuleQuery + from .auth import Auth from .daac import find_provider, find_provider_by_shortname from .results import DataCollection, DataGranule +from .typing_ import ( + Any, + Dict, + List, + Never, + Optional, + Self, + Sequence, + SupportsFloat, + Tuple, + TypeAlias, + Union, + override, +) + +FloatLike: TypeAlias = Union[str, SupportsFloat] +PointLike: TypeAlias = Tuple[FloatLike, FloatLike] class DataCollections(CollectionQuery): @@ -18,7 +35,7 @@ class DataCollections(CollectionQuery): the response has to be in umm_json to use the result classes. """ - _fields = None + _fields: Optional[List[str]] = None _format = "umm_json" _valid_formats_regex = [ "json", @@ -51,11 +68,14 @@ def __init__(self, auth: Optional[Auth] = None, *args: Any, **kwargs: Any) -> No self.params["has_granules"] = True self.params["include_granule_counts"] = True - def hits(self) -> int: + @override + def hits(self) -> Union[int, Never]: """Returns the number of hits the current query will return. This is done by making a lightweight query to CMR and inspecting the returned headers. Restricted datasets will always return zero results even if there are results. + Raises: + RuntimeError: if the CMR query fails Returns: The number of results reported by CMR. @@ -71,7 +91,8 @@ def hits(self) -> int: return int(response.headers["CMR-Hits"]) - def concept_id(self, IDs: List[str]) -> Type[CollectionQuery]: + @override + def concept_id(self, IDs: Sequence[str]) -> Union[Self, Never]: """Filter by concept ID. For example: C1299783579-LPDAAC_ECS or G1327299284-LPDAAC_ECS, S12345678-LPDAAC_ECS @@ -84,22 +105,30 @@ def concept_id(self, IDs: List[str]) -> Type[CollectionQuery]: Parameters: IDs: ID(s) to search by. Can be provided as a string or list of strings. + + Raises: + ValueError: if an ID does not start with a valid prefix + + Returns: + self """ - super().concept_id(IDs) - return self + return super().concept_id(IDs) - def keyword(self, text: str) -> Type[CollectionQuery]: + @override + def keyword(self, text: str) -> Self: """Case-insensitive and wildcard (*) search through over two dozen fields in a CMR collection record. This allows for searching against fields like summary and science keywords. Parameters: text: text to search for + + Returns: + self """ - super().keyword(text) - return self + return super().keyword(text) - def doi(self, doi: str) -> Type[CollectionQuery]: + def doi(self, doi: str) -> Union[Self, Never]: """Search datasets by DOI. ???+ Tip @@ -109,6 +138,12 @@ def doi(self, doi: str) -> Type[CollectionQuery]: Parameters: doi: DOI of a datasets, e.g. 10.5067/AQR50-3Q7CS + + Raises: + TypeError: if `doi` is not of type `str` + + Returns: + self """ if not isinstance(doi, str): raise TypeError("doi must be of type str") @@ -116,7 +151,7 @@ def doi(self, doi: str) -> Type[CollectionQuery]: self.params["doi"] = doi return self - def instrument(self, instrument: str) -> Type[CollectionQuery]: + def instrument(self, instrument: str) -> Union[Self, Never]: """Searh datasets by instrument ???+ Tip @@ -125,6 +160,12 @@ def instrument(self, instrument: str) -> Type[CollectionQuery]: Parameters: instrument (String): instrument of a datasets, e.g. instrument=GEDI + + Raises: + TypeError: if `instrument` is not of type `str` + + Returns: + self """ if not isinstance(instrument, str): raise TypeError("instrument must be of type str") @@ -132,7 +173,7 @@ def instrument(self, instrument: str) -> Type[CollectionQuery]: self.params["instrument"] = instrument return self - def project(self, project: str) -> Type[CollectionQuery]: + def project(self, project: str) -> Union[Self, Never]: """Searh datasets by associated project ???+ Tip @@ -142,6 +183,12 @@ def project(self, project: str) -> Type[CollectionQuery]: Parameters: project (String): associated project of a datasets, e.g. project=EMIT + + Raises: + TypeError: if `project` is not of type `str` + + Returns: + self """ if not isinstance(project, str): raise TypeError("project must be of type str") @@ -149,7 +196,8 @@ def project(self, project: str) -> Type[CollectionQuery]: self.params["project"] = project return self - def parameters(self, **kwargs: Any) -> Type[CollectionQuery]: + @override + def parameters(self, **kwargs: Any) -> Union[Self, Never]: """Provide query parameters as keyword arguments. The keyword needs to match the name of the method, and the value should either be the value or a tuple of values. @@ -159,12 +207,16 @@ def parameters(self, **kwargs: Any) -> Type[CollectionQuery]: temporal=("2015-01","2015-02"), point=(42.5, -101.25)) ``` + + Raises: + ValueError: if the name of a keyword argument is not the name of a method + TypeError: if the value of a keyword argument is not an argument or tuple + of arguments matching the number and type(s) of the method's parameters + Returns: - Query instance + self """ - methods = {} - for name, func in getmembers(self, predicate=ismethod): - methods[name] = func + methods = dict(getmembers(self, predicate=ismethod)) for key, val in kwargs.items(): # verify the key matches one of our methods @@ -185,25 +237,31 @@ def print_help(self, method: str = "fields") -> None: print([method for method in dir(self) if method.startswith("_") is False]) help(getattr(self, method)) - def fields(self, fields: Optional[List[str]] = None) -> Type[CollectionQuery]: + def fields(self, fields: Optional[List[str]] = None) -> Self: """Masks the response by only showing the fields included in this list. Parameters: fields (List): list of fields to show, these fields come from the UMM model e.g. Abstract, Title + + Returns: + self """ self._fields = fields return self - def debug(self, debug: bool = True) -> Type[CollectionQuery]: + def debug(self, debug: bool = True) -> Self: """If True, prints the actual query to CMR, notice that the pagination happens in the headers. Parameters: debug (Boolean): Print CMR query. + + Returns: + self """ - self._debug = True + self._debug = debug return self - def cloud_hosted(self, cloud_hosted: bool = True) -> Type[CollectionQuery]: + def cloud_hosted(self, cloud_hosted: bool = True) -> Union[Self, Never]: """Only match granules that are hosted in the cloud. This is valid for public collections. ???+ Tip @@ -212,6 +270,12 @@ def cloud_hosted(self, cloud_hosted: bool = True) -> Type[CollectionQuery]: Parameters: cloud_hosted: True to require granules only be online + + Raises: + TypeError: if `cloud_hosted` is not of type `bool` + + Returns: + self """ if not isinstance(cloud_hosted, bool): raise TypeError("cloud_hosted must be of type bool") @@ -222,7 +286,8 @@ def cloud_hosted(self, cloud_hosted: bool = True) -> Type[CollectionQuery]: self.params["provider"] = provider return self - def provider(self, provider: str = "") -> Type[CollectionQuery]: + @override + def provider(self, provider: str) -> Self: """Only match collections from a given provider. A NASA datacenter or DAAC can have one or more providers. @@ -231,23 +296,32 @@ def provider(self, provider: str = "") -> Type[CollectionQuery]: Parameters: provider: a provider code for any DAAC, e.g. POCLOUD, NSIDC_CPRD, etc. + + Returns: + self """ self.params["provider"] = provider return self - def data_center(self, data_center_name: str = "") -> Type[CollectionQuery]: - """An alias name for `daac()`. + def data_center(self, data_center_name: str) -> Self: + """An alias name for the `daac` method. Parameters: data_center_name: DAAC shortname, e.g. NSIDC, PODAAC, GESDISC + + Returns: + self """ return self.daac(data_center_name) - def daac(self, daac_short_name: str = "") -> Type[CollectionQuery]: + def daac(self, daac_short_name: str) -> Self: """Only match collections for a given DAAC, by default the on-prem collections for the DAAC. Parameters: daac_short_name: a DAAC shortname, e.g. NSIDC, PODAAC, GESDISC + + Returns: + self """ if "cloud_hosted" in self.params: cloud_hosted = self.params["cloud_hosted"] @@ -257,7 +331,8 @@ def daac(self, daac_short_name: str = "") -> Type[CollectionQuery]: self.params["provider"] = find_provider(daac_short_name, cloud_hosted) return self - def get(self, limit: int = 2000) -> list: + @override + def get(self, limit: int = 2000) -> Union[List[Any], Never]: """Get all the collections (datasets) that match with our current parameters up to some limit, even if spanning multiple pages. @@ -269,14 +344,20 @@ def get(self, limit: int = 2000) -> list: Parameters: limit: The number of results to return + Raises: + RuntimeError: if the CMR query fails + Returns: - query results as a list of `DataCollection` instances. + List (possibly empty) of query results. Elements are `DataCollection` + instances if the the query format is `"umm_json"`, generic `dict`s if the + query format is `"json"`, or an unparsed `str` element per page of results + otherwise. """ page_size = min(limit, 2000) url = self._build_url() - results: List = [] + results: List[Any] = [] page = 1 while len(results) < limit: params = {"page_size": page_size, "page_num": page} @@ -311,12 +392,13 @@ def get(self, limit: int = 2000) -> list: return results + @override def temporal( self, date_from: Optional[Union[str, dt.datetime]] = None, date_to: Optional[Union[str, dt.datetime]] = None, exclude_boundary: bool = False, - ) -> Type[CollectionQuery]: + ) -> Union[Self, Never]: """Filter by an open or closed date range. Dates can be provided as datetime objects or ISO 8601 formatted strings. Multiple ranges can be provided by successive calls to this method before calling execute(). @@ -325,6 +407,15 @@ def temporal( date_from (String or Datetime object): earliest date of temporal range date_to (String or Datetime object): latest date of temporal range exclude_boundary (Boolean): whether or not to exclude the date_from/to in the matched range. + + Raises: + ValueError: if `date_from` or `date_to` is a non-`None` value that is + neither a datetime object nor a string that can be parsed as a datetime + object; or if `date_from` and `date_to` are both datetime objects (or + parsable as such) and `date_from` is greater than `date_to` + + Returns: + self """ DEFAULT = dt.datetime(1979, 1, 1) if date_from is not None and not isinstance(date_from, dt.datetime): @@ -341,8 +432,7 @@ def temporal( print("The provided end date was not recognized") date_to = "" - super().temporal(date_from, date_to, exclude_boundary) - return self + return super().temporal(date_from, date_to, exclude_boundary) class DataGranules(GranuleQuery): @@ -365,20 +455,27 @@ class DataGranules(GranuleQuery): "umm_json", ] - def __init__(self, auth: Any = None, *args: Any, **kwargs: Any) -> None: + def __init__(self, auth: Optional[Auth] = None, *args: Any, **kwargs: Any) -> None: """Base class for Granule and Collection CMR queries.""" super().__init__(*args, **kwargs) - self.session = session() - if auth is not None and auth.authenticated: + + self.session = ( # To search, we need the new bearer tokens from NASA Earthdata - self.session = auth.get_session(bearer_token=True) + auth.get_session(bearer_token=True) + if auth is not None and auth.authenticated + else session() + ) self._debug = False - def hits(self) -> int: + @override + def hits(self) -> Union[int, Never]: """Returns the number of hits the current query will return. This is done by making a lightweight query to CMR and inspecting the returned headers. + Raises: + RuntimeError: if the CMR query fails + Returns: The number of results reported by CMR. """ @@ -397,7 +494,8 @@ def hits(self) -> int: return int(response.headers["CMR-Hits"]) - def parameters(self, **kwargs: Any) -> Type[CollectionQuery]: + @override + def parameters(self, **kwargs: Any) -> Union[Self, Never]: """Provide query parameters as keyword arguments. The keyword needs to match the name of the method, and the value should either be the value or a tuple of values. @@ -408,8 +506,13 @@ def parameters(self, **kwargs: Any) -> Type[CollectionQuery]: point=(42.5, -101.25)) ``` + Raises: + ValueError: if the name of a keyword argument is not the name of a method + TypeError: if the value of a keyword argument is not an argument or tuple + of arguments matching the number and type(s) of the method's parameters + Returns: - Query instance + self """ methods = {} for name, func in getmembers(self, predicate=ismethod): @@ -428,7 +531,8 @@ def parameters(self, **kwargs: Any) -> Type[CollectionQuery]: return self - def provider(self, provider: str = "") -> Type[CollectionQuery]: + @override + def provider(self, provider: str) -> Self: """Only match collections from a given provider. A NASA datacenter or DAAC can have one or more providers. For example, PODAAC is a data center or DAAC, @@ -437,23 +541,32 @@ def provider(self, provider: str = "") -> Type[CollectionQuery]: Parameters: provider: a provider code for any DAAC, e.g. POCLOUD, NSIDC_CPRD, etc. + + Returns: + self """ self.params["provider"] = provider return self - def data_center(self, data_center_name: str = "") -> Type[CollectionQuery]: + def data_center(self, data_center_name: str) -> Self: """An alias name for `daac()`. Parameters: data_center_name (String): DAAC shortname, e.g. NSIDC, PODAAC, GESDISC + + Returns: + self """ return self.daac(data_center_name) - def daac(self, daac_short_name: str = "") -> Type[CollectionQuery]: + def daac(self, daac_short_name: str) -> Self: """Only match collections for a given DAAC. Default to on-prem collections for the DAAC. Parameters: daac_short_name: a DAAC shortname, e.g. NSIDC, PODAAC, GESDISC + + Returns: + self """ if "cloud_hosted" in self.params: cloud_hosted = self.params["cloud_hosted"] @@ -463,18 +576,25 @@ def daac(self, daac_short_name: str = "") -> Type[CollectionQuery]: self.params["provider"] = find_provider(daac_short_name, cloud_hosted) return self - def orbit_number(self, orbit1: int, orbit2: int) -> Type[GranuleQuery]: + @override + def orbit_number( + self, + orbit1: FloatLike, + orbit2: Optional[FloatLike] = None, + ) -> Self: """Filter by the orbit number the granule was acquired during. Either a single orbit can be targeted or a range of orbits. Parameter: orbit1: orbit to target (lower limit of range when orbit2 is provided) orbit2: upper limit of range + + Returns: + self """ - super().orbit_number(orbit1, orbit2) - return self + return super().orbit_number(orbit1, orbit2) - def cloud_hosted(self, cloud_hosted: bool = True) -> Type[CollectionQuery]: + def cloud_hosted(self, cloud_hosted: bool = True) -> Union[Self, Never]: """Only match granules that are hosted in the cloud. This is valid for public collections and when using the short_name parameter. Concept-Id is unambiguous. @@ -485,6 +605,12 @@ def cloud_hosted(self, cloud_hosted: bool = True) -> Type[CollectionQuery]: Parameters: cloud_hosted: True to require granules only be online + + Raises: + TypeError: if `cloud_hosted` is not of type `bool` + + Returns: + self """ if not isinstance(cloud_hosted, bool): raise TypeError("cloud_hosted must be of type bool") @@ -497,7 +623,7 @@ def cloud_hosted(self, cloud_hosted: bool = True) -> Type[CollectionQuery]: self.params["provider"] = provider return self - def granule_name(self, granule_name: str) -> Type[CollectionQuery]: + def granule_name(self, granule_name: str) -> Union[Self, Never]: """Find granules matching either granule ur or producer granule id, queries using the readable_granule_name metadata field. @@ -507,6 +633,12 @@ def granule_name(self, granule_name: str) -> Type[CollectionQuery]: Parameters: granule_name: granule name (accepts wildcards) + + Raises: + TypeError: if `granule_name` is not of type `str` + + Returns: + self """ if not isinstance(granule_name, str): raise TypeError("granule_name must be of type string") @@ -515,54 +647,89 @@ def granule_name(self, granule_name: str) -> Type[CollectionQuery]: self.params["options[readable_granule_name][pattern]"] = True return self - def online_only(self, online_only: bool = True) -> Type[GranuleQuery]: + @override + def online_only(self, online_only: bool = True) -> Union[Self, Never]: """Only match granules that are listed online and not available for download. The opposite of this method is downloadable(). Parameters: online_only: True to require granules only be online + + Raises: + TypeError: if `online_only` is not of type `bool` + + Returns: + self """ - super().online_only(online_only) - return self + return super().online_only(online_only) - def day_night_flag(self, day_night_flag: str) -> Type[GranuleQuery]: + @override + def day_night_flag(self, day_night_flag: str) -> Union[Self, Never]: """Filter by period of the day the granule was collected during. Parameters: day_night_flag: "day", "night", or "unspecified" + + Raises: + TypeError: if `day_night_flag` is not of type `str` + ValueError: if `day_night_flag` is not one of `"day"`, `"night"`, or + `"unspecified"` + + Returns: + self """ - super().day_night_flag(day_night_flag) - return self + return super().day_night_flag(day_night_flag) - def instrument(self, instrument: str = "") -> Type[GranuleQuery]: + @override + def instrument(self, instrument: str) -> Union[Self, Never]: """Filter by the instrument associated with the granule. Parameters: instrument: name of the instrument + + Raises: + ValueError: if `instrument` is not a non-empty string + + Returns: + self """ - super().instrument(instrument) - return self + return super().instrument(instrument) - def platform(self, platform: str = "") -> Type[GranuleQuery]: + @override + def platform(self, platform: str) -> Union[Self, Never]: """Filter by the satellite platform the granule came from. Parameters: platform: name of the satellite + + Raises: + ValueError: if `platform` is not a non-empty string + + Returns: + self """ - super().platform(platform) - return self + return super().platform(platform) + @override def cloud_cover( - self, min_cover: int = 0, max_cover: int = 100 - ) -> Type[GranuleQuery]: + self, + min_cover: Optional[FloatLike] = 0, + max_cover: Optional[FloatLike] = 100, + ) -> Union[Self, Never]: """Filter by the percentage of cloud cover present in the granule. Parameters: min_cover: minimum percentage of cloud cover max_cover: maximum percentage of cloud cover + + Raises: + ValueError: if `min_cover` or `max_cover` is not convertible to a float, + or if `min_cover` is greater than `max_cover` + + Returns: + self """ - super().cloud_cover(min_cover, max_cover) - return self + return super().cloud_cover(min_cover, max_cover) def _valid_state(self) -> bool: # spatial params must be paired with a collection limiting parameter @@ -587,19 +754,20 @@ def _is_cloud_hosted(self, granule: Any) -> bool: return True return False - def short_name(self, short_name: str = "") -> Type[GranuleQuery]: + @override + def short_name(self, short_name: str) -> Self: """Filter by short name (aka product or collection name). Parameters: short_name: name of a collection Returns: - Query instance + self """ - super().short_name(short_name) - return self + return super().short_name(short_name) - def get(self, limit: int = 2000) -> list: + @override + def get(self, limit: int = 2000) -> Union[List[Any], Never]: """Get all the collections (datasets) that match with our current parameters up to some limit, even if spanning multiple pages. @@ -611,14 +779,20 @@ def get(self, limit: int = 2000) -> list: Parameters: limit: The number of results to return + Raises: + RuntimeError: if the CMR query fails + Returns: - query results as a list of `DataCollection` instances. + List (possibly empty) of query results. Elements are `DataGranule` + instances if the the query format is `"umm_json"`, generic `dict`s if the + query format is `"json"`, or an unparsed `str` element per page of results + otherwise. """ # TODO: implement items() iterator page_size = min(limit, 2000) url = self._build_url() - results: List = [] + results: List[Any] = [] page = 1 headers: Dict[str, str] = {} while len(results) < limit: @@ -669,21 +843,25 @@ def get(self, limit: int = 2000) -> list: return results - def debug(self, debug: bool = True) -> Type[GranuleQuery]: + def debug(self, debug: bool = True) -> Self: """If True, prints the actual query to CMR, notice that the pagination happens in the headers. Parameters: debug: Print CMR query. + + Returns: + self """ - self._debug = True + self._debug = debug return self + @override def temporal( self, date_from: Optional[Union[str, dt.datetime]] = None, date_to: Optional[Union[str, dt.datetime]] = None, exclude_boundary: bool = False, - ) -> Type[GranuleQuery]: + ) -> Union[Self, Never]: """Filter by an open or closed date range. Dates can be provided as a datetime objects or ISO 8601 formatted strings. Multiple ranges can be provided by successive calls to this method before calling execute(). @@ -692,6 +870,15 @@ def temporal( date_from: earliest date of temporal range date_to: latest date of temporal range exclude_boundary: whether to exclude the date_from/to in the matched range + + Raises: + ValueError: if `date_from` or `date_to` is a non-`None` value that is + neither a datetime object nor a string that can be parsed as a datetime + object; or if `date_from` and `date_to` are both datetime objects (or + parsable as such) and `date_from` is greater than `date_to` + + Returns: + self """ DEFAULT = dt.datetime(1979, 1, 1) if date_from is not None and not isinstance(date_from, dt.datetime): @@ -708,46 +895,63 @@ def temporal( print("The provided end date was not recognized") date_to = "" - super().temporal(date_from, date_to, exclude_boundary) - return self + return super().temporal(date_from, date_to, exclude_boundary) - def version(self, version: str = "") -> Type[GranuleQuery]: + @override + def version(self, version: str) -> Self: """Filter by version. Note that CMR defines this as a string. For example, MODIS version 6 products must be searched for with "006". Parameters: version: version string + + Returns: + self """ - super().version(version) - return self + return super().version(version) - def point(self, lon: str, lat: str) -> Type[GranuleQuery]: + @override + def point(self, lon: FloatLike, lat: FloatLike) -> Union[Self, Never]: """Filter by granules that include a geographic point. Parameters: lon (String): longitude of geographic point lat (String): latitude of geographic point + + Raises: + ValueError: if `lon` or `lat` cannot be converted to a float + + Returns: + self """ - super().point(lon, lat) - return self + return super().point(lon, lat) - def polygon(self, coordinates: List[Tuple[str, str]]) -> Type[GranuleQuery]: + @override + def polygon(self, coordinates: Sequence[PointLike]) -> Union[Self, Never]: """Filter by granules that overlap a polygonal area. Must be used in combination with a collection filtering parameter such as short_name or entry_title. Parameters: coordinates: list of (lon, lat) tuples + + Raises: + ValueError: if `coordinates` is not a sequence of at least 4 coordinate + pairs, any of the coordinates cannot be converted to a float, or the first + and last coordinate pairs are not equal + + Returns: + self """ - super().polygon(coordinates) - return self + return super().polygon(coordinates) + @override def bounding_box( self, - lower_left_lon: str, - lower_left_lat: str, - upper_right_lon: str, - upper_right_lat: str, - ) -> Type[GranuleQuery]: + lower_left_lon: FloatLike, + lower_left_lat: FloatLike, + upper_right_lon: FloatLike, + upper_right_lat: FloatLike, + ) -> Union[Self, Never]: """Filter by granules that overlap a bounding box. Must be used in combination with a collection filtering parameter such as short_name or entry_title. @@ -756,33 +960,51 @@ def bounding_box( lower_left_lat: lower left latitude of the box upper_right_lon: upper right longitude of the box upper_right_lat: upper right latitude of the box + + Raises: + ValueError: if any of the coordinates cannot be converted to a float + + Returns: + self """ - super().bounding_box( + return super().bounding_box( lower_left_lon, lower_left_lat, upper_right_lon, upper_right_lat ) - return self - def line(self, coordinates: List[Tuple[str, str]]) -> Type[GranuleQuery]: + @override + def line(self, coordinates: Sequence[PointLike]) -> Union[Self, Never]: """Filter by granules that overlap a series of connected points. Must be used in combination with a collection filtering parameter such as short_name or entry_title. Parameters: coordinates: a list of (lon, lat) tuples + + Raises: + ValueError: if `coordinates` is not a sequence of at least 2 coordinate + pairs, or any of the coordinates cannot be converted to a float + + Returns: + self """ - super().line(coordinates) - return self + return super().line(coordinates) - def downloadable(self, downloadable: bool = True) -> Type[GranuleQuery]: + @override + def downloadable(self, downloadable: bool = True) -> Union[Self, Never]: """Only match granules that are available for download. The opposite of this method is online_only(). Parameters: downloadable: True to require granules be downloadable + + Raises: + TypeError: if `downloadable` is not of type `bool` + + Returns: + self """ - super().downloadable(downloadable) - return self + return super().downloadable(downloadable) - def doi(self, doi: str) -> Type[GranuleQuery]: + def doi(self, doi: str) -> Union[Self, Never]: """Search data granules by DOI ???+ Tip @@ -790,7 +1012,13 @@ def doi(self, doi: str) -> Type[GranuleQuery]: earthaccess will grab the concept_id for the query to CMR. Parameters: - doi: DOI of a datasets, e.g. 10.5067/AQR50-3Q7CS + doi: DOI of a dataset, e.g. 10.5067/AQR50-3Q7CS + + Raises: + RuntimeError: if the CMR query to get the collection for the DOI fails + + Returns: + self """ collection = DataCollections().doi(doi).get() if len(collection) > 0: diff --git a/earthaccess/typing_.py b/earthaccess/typing_.py new file mode 100644 index 00000000..c8a3507f --- /dev/null +++ b/earthaccess/typing_.py @@ -0,0 +1,49 @@ +""" +Convenience module for importing types from the typing module, abstracting away +the differences between Python versions. +""" + +import sys +from typing import Any, Callable, Optional, SupportsFloat, Type, Union, cast + +if sys.version_info < (3, 9): + from typing import Dict, List, Mapping, Sequence, Tuple +else: + from builtins import dict as Dict, list as List, tuple as Tuple + from collections.abc import Mapping, Sequence + +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias + +if sys.version_info < (3, 11): + from typing import NoReturn as Never + + from typing_extensions import Self +else: + from typing import Never, Self + +if sys.version_info < (3, 12): + from typing_extensions import override +else: + from typing import override + +__all__ = [ + "Any", + "Callable", + "Dict", + "List", + "Mapping", + "Never", + "Optional", + "Self", + "Sequence", + "SupportsFloat", + "Tuple", + "Type", + "TypeAlias", + "Union", + "cast", + "override", +] diff --git a/pyproject.toml b/pyproject.toml index 889e33c4..777ea1a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,22 +86,37 @@ build-backend = "poetry.core.masonry.api" [tool.pytest] filterwarnings = ["error::UserWarning"] - [tool.mypy] -disallow_untyped_defs = false -ignore_missing_imports = true +mypy_path = ["earthaccess", "tests", "stubs"] +disallow_untyped_defs = true +# TODO: incrementally work towards strict mode (currently too many errors) +# strict = true +pretty = true # Show additional context in error messages +enable_error_code = "redundant-self" [[tool.mypy.overrides]] module = [ "tests.*", ] -ignore_errors = true +disallow_untyped_defs = false + +[[tool.mypy.overrides]] +module = [ + "fsspec.*", + "kerchunk.*", + "pqdm.*", + "s3fs", + "tinynetrc.*", # TODO: generate stubs for tinynetrc and remove this line +] +ignore_missing_imports = true +[tool.pyright] +include = ["earthaccess"] +stubPath = "./stubs" [tool.ruff] line-length = 88 -src = ["earthaccess", "tests"] -exclude = ["mypy-stubs", "stubs", "typeshed"] +src = ["earthaccess", "stubs", "tests"] [tool.ruff.lint] extend-select = ["I"] @@ -109,7 +124,6 @@ extend-select = ["I"] [tool.ruff.lint.isort] combine-as-imports = true - [tool.bumpversion] current_version = "0.9.0" commit = false diff --git a/scripts/lint.sh b/scripts/lint.sh index 3a528811..02f9c70a 100755 --- a/scripts/lint.sh +++ b/scripts/lint.sh @@ -1,7 +1,6 @@ #!/usr/bin/env bash -set -e -set -x +set -ex -mypy earthaccess --disallow-untyped-defs +mypy earthaccess stubs tests ruff check . diff --git a/stubs/cmr/__init__.pyi b/stubs/cmr/__init__.pyi new file mode 100644 index 00000000..3ea9733e --- /dev/null +++ b/stubs/cmr/__init__.pyi @@ -0,0 +1,10 @@ +from .queries import ( + CMR_OPS as CMR_OPS, + CMR_SIT as CMR_SIT, + CMR_UAT as CMR_UAT, + CollectionQuery as CollectionQuery, + GranuleQuery as GranuleQuery, + ServiceQuery as ServiceQuery, + ToolQuery as ToolQuery, + VariableQuery as VariableQuery, +) diff --git a/stubs/cmr/queries.pyi b/stubs/cmr/queries.pyi new file mode 100644 index 00000000..41d18b53 --- /dev/null +++ b/stubs/cmr/queries.pyi @@ -0,0 +1,112 @@ +import sys +from datetime import datetime +from typing import Any, Optional, SupportsFloat, Union + +if sys.version_info < (3, 9): + from typing import List, MutableMapping, Sequence, Tuple +else: + from builtins import list as List, tuple as Tuple + from collections.abc import MutableMapping, Sequence + +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias + +if sys.version_info < (3, 11): + from typing import NoReturn as Never + + from typing_extensions import Self +else: + from typing import Never, Self + +CMR_OPS: str +CMR_UAT: str +CMR_SIT: str + +FloatLike: TypeAlias = Union[str, SupportsFloat] +PointLike: TypeAlias = Tuple[FloatLike, FloatLike] + +class Query: + params: MutableMapping[str, Any] + options: MutableMapping[str, Any] + concept_id_chars: Sequence[str] + headers: MutableMapping[str, str] + + def __init__(self, route: str, mode: str = ...) -> None: ... + def _build_url(self) -> Union[str, Never]: ... + def get(self, limit: int = ...) -> Union[List[Any], Never]: ... + def hits(self) -> Union[int, Never]: ... + def get_all(self) -> Union[List[Any], Never]: ... + def parameters(self, **kwargs: Any) -> Self: ... + def format(self, output_format: str = "json") -> Union[Self, Never]: ... + def concept_id(self, ids: Sequence[str]) -> Union[Self, Never]: ... + def provider(self, provider: str) -> Self: ... + def mode(self, mode: str = ...) -> Union[None, Never]: ... + def token(self, token: str) -> Self: ... + def bearer_token(self, bearer_token: str) -> Self: ... + +class GranuleCollectionBaseQuery(Query): + def online_only(self, online_only: bool = True) -> Self: ... + def temporal( + self, + date_from: Optional[Union[str, datetime]], + date_to: Optional[Union[str, datetime]], + exclude_boundary: bool = False, + ) -> Union[Self, Never]: ... + def short_name(self, short_name: str) -> Self: ... + def version(self, version: str) -> Self: ... + def point(self, lon: FloatLike, lat: FloatLike) -> Self: ... + def circle( + self, lon: FloatLike, lat: FloatLike, dist: FloatLike + ) -> Union[Self, Never]: ... + def polygon(self, coordinates: Sequence[PointLike]) -> Union[Self, Never]: ... + def bounding_box( + self, + lower_left_lon: FloatLike, + lower_left_lat: FloatLike, + upper_right_lon: FloatLike, + upper_right_lat: FloatLike, + ) -> Self: ... + def line(self, coordinates: Sequence[PointLike]) -> Self: ... + def downloadable(self, downloadable: bool = True) -> Self: ... + def entry_title(self, entry_title: str) -> Self: ... + +class GranuleQuery(GranuleCollectionBaseQuery): + def __init__(self, mode: str = ...) -> None: ... + def orbit_number( + self, + orbit1: FloatLike, + orbit2: Optional[FloatLike] = ..., + ) -> Self: ... + def day_night_flag(self, day_night_flag: str) -> Union[Self, Never]: ... + def cloud_cover( + self, + min_cover: Optional[FloatLike] = ..., + max_cover: Optional[FloatLike] = ..., + ) -> Self: ... + def instrument(self, instrument: str) -> Union[Self, Never]: ... + def platform(self, platform: str) -> Union[Self, Never]: ... + def sort_key(self, sort_key: str) -> Union[Self, Never]: ... + def granule_ur(self, granule_ur: str) -> Union[Self, Never]: ... + +class CollectionQuery(GranuleCollectionBaseQuery): + def __init__(self, mode: str = ...) -> None: ... + def archive_center(self, center: str) -> Self: ... + def keyword(self, text: str) -> Self: ... + def native_id(self, native_ids: Sequence[str]) -> Self: ... + def tool_concept_id(self, ids: Sequence[str]) -> Union[Self, Never]: ... + def service_concept_id(self, ids: Sequence[str]) -> Union[Self, Never]: ... + +class ToolServiceVariableBaseQuery(Query): + def native_id(self, native_ids: Sequence[str]) -> Self: ... + def name(self, name: str) -> Self: ... + +class ToolQuery(ToolServiceVariableBaseQuery): + def __init__(self, mode: str = ...) -> None: ... + +class ServiceQuery(ToolServiceVariableBaseQuery): + def __init__(self, mode: str = ...) -> None: ... + +class VariableQuery(ToolServiceVariableBaseQuery): + def __init__(self, mode: str = ...) -> None: ... diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 0c59fc86..b2b0a048 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -11,7 +11,7 @@ class TestCreateAuth(unittest.TestCase): @responses.activate @mock.patch("getpass.getpass") @mock.patch("builtins.input") - def test_auth_gets_proper_credentials(self, user_input, user_password) -> bool: + def test_auth_gets_proper_credentials(self, user_input, user_password): user_input.return_value = "user" user_password.return_value = "password" json_response = [ @@ -53,9 +53,7 @@ def test_auth_gets_proper_credentials(self, user_input, user_password) -> bool: @responses.activate @mock.patch("getpass.getpass") @mock.patch("builtins.input") - def test_auth_can_create_proper_credentials( - self, user_input, user_password - ) -> bool: + def test_auth_can_create_proper_credentials(self, user_input, user_password): user_input.return_value = "user" user_password.return_value = "password" json_response = {"access_token": "EDL-token-1", "expiration_date": "12/15/2021"} @@ -94,7 +92,7 @@ def test_auth_can_create_proper_credentials( @responses.activate @mock.patch("getpass.getpass") @mock.patch("builtins.input") - def test_auth_fails_for_wrong_credentials(self, user_input, user_password) -> bool: + def test_auth_fails_for_wrong_credentials(self, user_input, user_password): user_input.return_value = "bad_user" user_password.return_value = "bad_password" json_response = {"error": "wrong credentials"}