Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate the temporal module #604

Merged
merged 17 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 42 additions & 36 deletions icepyx/core/APIformatting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Generate and format information for submitting to API (CMR and NSIDC)."""

import datetime as dt
from typing import Any, Generic, Literal, TypeVar, Union, overload
from typing import Any, Generic, Literal, Optional, TypeVar, Union, overload

from icepyx.core.types import (
CMRParams,
Expand Down Expand Up @@ -36,10 +36,6 @@ def _fmt_temporal(start, end, key):

assert isinstance(start, dt.datetime)
assert isinstance(end, dt.datetime)
assert key in [
"time",
"temporal",
], "An invalid time key was submitted for formatting."

if key == "temporal":
fmt_timerange = (
Expand All @@ -53,6 +49,8 @@ def _fmt_temporal(start, end, key):
+ ","
+ end.strftime("%Y-%m-%dT%H:%M:%S")
)
else:
raise RuntimeError("An invalid time key was submitted for formatting.")

return {key: fmt_timerange}

Expand Down Expand Up @@ -231,7 +229,7 @@ def __get__(
Returns the dictionary of formatted keys associated with the
parameter object.
"""
return instance._fmted_keys
return instance._fmted_keys # pyright: ignore[reportReturnType]


# ----------------------------------------------------------------------
Expand All @@ -257,13 +255,16 @@ class Parameters(Generic[T]):
on the type of query. Must be one of ['search','download']
"""

partype: T
_reqtype: Optional[Literal["search", "download"]]
fmted_keys = _FmtedKeysDescriptor()
# _fmted_keys: Union[CMRParams, EGISpecificRequiredParams, EGIParamsSubset]

def __init__(
self,
partype: T,
values=None,
reqtype=None,
values: Optional[dict] = None,
reqtype: Optional[Literal["search", "download"]] = None,
):
assert partype in [
"CMR",
Expand All @@ -282,31 +283,14 @@ def __init__(
self._fmted_keys = values if values is not None else {}

@property
def poss_keys(self):
def poss_keys(self) -> dict[str, list[str]]:
"""
Returns a list of possible input keys for the given parameter object.
Possible input keys depend on the parameter type (partype).
"""

if not hasattr(self, "_poss_keys"):
self._get_possible_keys()

return self._poss_keys

# @property
# def wanted_keys(self):
# if not hasattr(_wanted):
# self._wanted = []

# return self._wanted

def _get_possible_keys(self) -> dict[str, list[str]]:
"""
Use the parameter type to get a list of possible parameter keys.
"""

if self.partype == "CMR":
self._poss_keys = {
return {
"spatial": ["bounding_box", "polygon"],
"optional": [
"temporal",
Expand All @@ -316,7 +300,7 @@ def _get_possible_keys(self) -> dict[str, list[str]]:
],
}
elif self.partype == "required":
self._poss_keys = {
return {
"search": ["short_name", "version", "page_size"],
"download": [
"short_name",
Expand All @@ -331,7 +315,7 @@ def _get_possible_keys(self) -> dict[str, list[str]]:
],
}
elif self.partype == "subset":
self._poss_keys = {
return {
"spatial": ["bbox", "Boundingshape"],
"optional": [
"time",
Expand All @@ -341,8 +325,17 @@ def _get_possible_keys(self) -> dict[str, list[str]]:
"Coverage",
],
}
else:
raise RuntimeError("Programmer error!")

# @property
# def wanted_keys(self):
# if not hasattr(_wanted):
# self._wanted = []

def _check_valid_keys(self):
# return self._wanted

def _check_valid_keys(self) -> None:
"""
Checks that any keys passed in with values are valid keys.
"""
Expand All @@ -352,13 +345,13 @@ def _check_valid_keys(self):

val_list = list({val for lis in self.poss_keys.values() for val in lis})

for key in self.fmted_keys:
for key in self.fmted_keys: # pyright: ignore[reportAttributeAccessIssue]
assert key in val_list, (
"An invalid key (" + key + ") was passed. Please remove it using `del`"
)

# DevNote: can check_req_values and check_values be combined?
def check_req_values(self):
def check_req_values(self) -> bool:
"""
Check that all of the required keys have values, if the key was passed in with
the values parameter.
Expand All @@ -367,17 +360,22 @@ def check_req_values(self):
assert (
self.partype == "required"
), "You cannot call this function for your parameter type"

if not self._reqtype:
raise RuntimeError("Programmer error!")

reqkeys = self.poss_keys[self._reqtype]

if all(keys in self.fmted_keys for keys in reqkeys):
if all(keys in self.fmted_keys for keys in reqkeys): # pyright: ignore[reportAttributeAccessIssue]
assert all(
self.fmted_keys.get(key, -9999) != -9999 for key in reqkeys
self.fmted_keys.get(key, -9999) != -9999 # pyright: ignore[reportAttributeAccessIssue]
for key in reqkeys
), "One of your formatted parameters is missing a value"
return True
else:
return False

def check_values(self):
def check_values(self) -> bool:
"""
Check that the non-required keys have values, if the key was
passed in with the values parameter.
Expand All @@ -391,7 +389,8 @@ def check_values(self):
# not the most robust check, but better than nothing...
if any(keys in self._fmted_keys for keys in spatial_keys):
assert any(
self.fmted_keys.get(key, -9999) != -9999 for key in spatial_keys
self.fmted_keys.get(key, -9999) != -9999 # pyright: ignore[reportAttributeAccessIssue]
for key in spatial_keys
), "One of your formatted parameters is missing a value"
return True
else:
Expand Down Expand Up @@ -427,6 +426,9 @@ def build_params(self, **kwargs) -> None:
self._check_valid_keys()

if self.partype == "required":
if not self._reqtype:
raise RuntimeError("Programmer error!")

if self.check_req_values() and kwargs == {}:
pass
else:
Expand Down Expand Up @@ -484,6 +486,7 @@ def build_params(self, **kwargs) -> None:
if any(keys in self._fmted_keys for keys in spatial_keys):
pass
else:
k = None
if self.partype == "CMR":
k = kwargs["extent_type"]
elif self.partype == "subset":
Expand All @@ -492,6 +495,9 @@ def build_params(self, **kwargs) -> None:
elif kwargs["extent_type"] == "polygon":
k = "Boundingshape"

if not k:
raise RuntimeError("Programmer error!")

self._fmted_keys.update({k: kwargs["spatial_extent"]})


Expand Down
11 changes: 7 additions & 4 deletions icepyx/core/query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime as dt
import pprint
from typing import Optional, Union, cast

Expand Down Expand Up @@ -126,6 +127,8 @@ class GenQuery:
Quest
"""

_temporal: tp.Temporal

def __init__(
self,
spatial_extent=None,
Expand Down Expand Up @@ -157,7 +160,7 @@ def __str__(self):
# Properties

@property
def temporal(self):
def temporal(self) -> Union[tp.Temporal, list[str]]:
"""
Return the Temporal object containing date/time range information for the query object.

Expand Down Expand Up @@ -254,7 +257,7 @@ def spatial_extent(self):
return (self._spatial._ext_type, self._spatial._spatial_ext)

@property
def dates(self):
def dates(self) -> Union[list[str], list[dt.time]]:
mfisher87 marked this conversation as resolved.
Show resolved Hide resolved
"""
Return an array showing the date range of the query object.
Dates are returned as an array containing the start and end datetime
Expand All @@ -279,7 +282,7 @@ def dates(self):
] # could also use self._start.date()

@property
def start_time(self):
def start_time(self) -> Union[list[str], str]:
JessicaS11 marked this conversation as resolved.
Show resolved Hide resolved
"""
Return the start time specified for the start date.

Expand All @@ -303,7 +306,7 @@ def start_time(self):
return self._temporal._start.strftime("%H:%M:%S")

@property
def end_time(self):
def end_time(self) -> Union[list[str], str]:
"""
Return the end time specified for the end date.

Expand Down
Loading