From f40526c12daa2f1d619287167eecd365f1e21a6e Mon Sep 17 00:00:00 2001 From: danielfromearth Date: Mon, 25 Mar 2024 10:28:15 -0400 Subject: [PATCH] use uat urls in sessions passed around, and todo for uat test --- earthaccess/api.py | 13 +++++-------- earthaccess/auth.py | 25 ++++++++++++++++++------ earthaccess/search.py | 44 ++++++++++++++++++------------------------ tests/unit/test_uat.py | 20 +++++++++++-------- 4 files changed, 55 insertions(+), 47 deletions(-) diff --git a/earthaccess/api.py b/earthaccess/api.py index ce6c88b6..162028fb 100644 --- a/earthaccess/api.py +++ b/earthaccess/api.py @@ -84,7 +84,7 @@ def search_datasets( def search_data( - count: int = -1, **kwargs: Any + count: int = -1, session: Optional[requests.Session] = None, **kwargs: Any ) -> List[earthaccess.results.DataGranule]: """Search dataset granules using NASA's CMR. @@ -119,14 +119,11 @@ def search_data( ``` """ if earthaccess.__auth__.authenticated: - query = DataGranules( - earthaccess.__auth__, - earthdata_environment=earthaccess.__auth__.earthdata_environment, - ).parameters(**kwargs) + query = DataGranules(earthaccess.__auth__, existing_session=session).parameters( + **kwargs + ) else: - query = DataGranules( - earthdata_environment=earthaccess.__auth__.earthdata_environment - ).parameters(**kwargs) + query = DataGranules(existing_session=session).parameters(**kwargs) granules_found = query.hits() print(f"Granules found: {granules_found}") if count > 0: diff --git a/earthaccess/auth.py b/earthaccess/auth.py index b635db06..9a4800eb 100644 --- a/earthaccess/auth.py +++ b/earthaccess/auth.py @@ -51,7 +51,10 @@ class SessionWithHeaderRedirection(requests.Session): ] def __init__( - self, username: Optional[str] = None, password: Optional[str] = None + self, + username: Optional[str] = None, + password: Optional[str] = None, + earthdata_environment: Optional[str] = None, ) -> None: super().__init__() self.headers.update({"User-Agent": user_agent}) @@ -59,6 +62,10 @@ def __init__( if username and password: self.auth = (username, password) + if earthdata_environment is not None: + self.AUTH_HOSTS.pop(0) + self.AUTH_HOSTS.insert(0, earthdata_environment.value) + # Overrides from the library to keep headers when redirected to or from # the NASA auth host. def rebuild_auth(self, prepared_request: Any, response: Any) -> None: @@ -110,6 +117,9 @@ def login( Returns: An instance of Auth. """ + if earthdata_environment is not None: + self._set_earthdata_environment(earthdata_environment) + if self.authenticated and (earthdata_environment == self.earthdata_environment): logger.debug("We are already authenticated with NASA EDL") return self @@ -120,9 +130,6 @@ def login( if strategy == "environment": self._environment() - if earthdata_environment is not None: - self._set_earthdata_environment(earthdata_environment) - return self def _set_earthdata_environment(self, earthdata_environment: Env) -> None: @@ -255,7 +262,9 @@ def get_s3_credentials( print("We need to authenticate with EDL first") return {} - def get_session(self, bearer_token: bool = True) -> requests.Session: + def get_session( + self, bearer_token: bool = True, earthdata_environment: Optional[Env] = None + ) -> requests.Session: """Returns a new request session instance. Parameters: @@ -264,7 +273,11 @@ def get_session(self, bearer_token: bool = True) -> requests.Session: Returns: class Session instance with Auth and bearer token headers """ - session = SessionWithHeaderRedirection() + if earthdata_environment is not None: + self._set_earthdata_environment(earthdata_environment) + session = SessionWithHeaderRedirection( + earthdata_environment=earthdata_environment + ) if bearer_token and self.authenticated: # This will avoid the use of the netrc after we are logged in session.trust_env = False diff --git a/earthaccess/search.py b/earthaccess/search.py index 727eea3d..82caffd2 100644 --- a/earthaccess/search.py +++ b/earthaccess/search.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import dateutil.parser as parser # type: ignore +import requests from cmr import CMR_OPS, CMR_SIT, CMR_UAT, CollectionQuery, GranuleQuery # type: ignore from requests import exceptions, session @@ -36,7 +37,7 @@ class DataCollections(CollectionQuery): def __init__( self, auth: Optional[Auth] = None, - earthdata_environment: Optional[Env] = None, + existing_session: Optional[requests.Session] = None, *args: Any, **kwargs: Any, ) -> None: @@ -47,22 +48,19 @@ def __init__( for queries that need authentication, e.g. restricted datasets. """ super().__init__(*args, **kwargs) - self.session = session() - if auth is not None: - earthdata_environment = auth.earthdata_environment - if (earthdata_environment is None) or (earthdata_environment == Env.PROD): + if existing_session is not None: + self.session = existing_session + else: + self.session = session() + + if self.session.AUTH_HOSTS[0] == Env.PROD.value: self.mode(CMR_OPS) - elif earthdata_environment == Env.UAT: + elif self.session.AUTH_HOSTS[0] == Env.UAT.value: self.mode(CMR_UAT) - elif earthdata_environment == Env.SIT: + elif self.session.AUTH_HOSTS[0] == Env.SIT.value: self.mode(CMR_SIT) - print(f"[in DataCollections] Earthdata environment: {earthdata_environment}") - print( - f"[in DataCollections] earthdata_environment == Env.PROD -----> {earthdata_environment == Env.PROD}" - ) - if auth is not None and auth.authenticated: # To search, we need the new bearer tokens from NASA Earthdata self.session = auth.get_session(bearer_token=True) @@ -389,29 +387,25 @@ class DataGranules(GranuleQuery): def __init__( self, auth: Any = None, - earthdata_environment: Optional[Env] = None, + # earthdata_environment: Optional[Env] = None, + existing_session: Optional[requests.Session] = None, *args: Any, **kwargs: Any, ) -> None: """Base class for Granule and Collection CMR queries.""" super().__init__(*args, **kwargs) - self.session = session() - if auth is not None: - earthdata_environment = auth.earthdata_environment + if existing_session is not None: + self.session = existing_session + else: + self.session = session() - # TODO: Move this in to a data structure, e.g. the existing Enum? - if (earthdata_environment is None) or (earthdata_environment == Env.PROD): + if self.session.AUTH_HOSTS[0] == Env.PROD.value: self.mode(CMR_OPS) - elif earthdata_environment == Env.UAT: + elif self.session.AUTH_HOSTS[0] == Env.UAT.value: self.mode(CMR_UAT) - elif earthdata_environment == Env.SIT: + elif self.session.AUTH_HOSTS[0] == Env.SIT.value: self.mode(CMR_SIT) - print(f"[in DataGranules] Earthdata environment: {earthdata_environment}") - print( - f"[in DataGranules] earthdata_environment == Env.PROD -----> {earthdata_environment == Env.PROD}" - ) - if auth is not None and auth.authenticated: # To search, we need the new bearer tokens from NASA Earthdata self.session = auth.get_session(bearer_token=True) diff --git a/tests/unit/test_uat.py b/tests/unit/test_uat.py index f92263c5..20dac45f 100644 --- a/tests/unit/test_uat.py +++ b/tests/unit/test_uat.py @@ -13,7 +13,7 @@ class TestUatEnvironmentArgument: "builtins.input", new=mock.Mock(return_value="user"), ) - def test_uat_is_requested_when_uat_selected(self) -> bool: + def test_uat_login_when_uat_selected(self) -> bool: """Test the correct env is queried based on what's selected at login-time.""" json_response = [ {"access_token": "EDL-token-1", "expiration_date": "12/15/2021"}, @@ -38,16 +38,20 @@ def test_uat_is_requested_when_uat_selected(self) -> bool: status=200, ) - # Test - # Login - # TODO: Can we use the top-level API? Why do other tests manually create - # an Auth instance instead of: - # earthaccess.login(strategy=..., earthdata_environment=Env.UAT) + # TODO: Add a mock for CMR query? Does it need a "CMR-HITS" field in the response? + # responses.add( + # responses.GET, + # "https://cmr.uat.earthdata.nasa.gov/search/granules.umm_json?page_size=0", + # json=json_response, + # status=200, + # ) + # Use Auth instance instead of the top-level (`earthaccess.`) API since this is a unit, + # not an integration, test auth = Auth() # Check that we're not already authenticated. - session = auth.get_session() + session = auth.get_session(earthdata_environment=Env.UAT) headers = session.headers assert not auth.authenticated @@ -56,7 +60,7 @@ def test_uat_is_requested_when_uat_selected(self) -> bool: assert auth.authenticated assert auth.token in json_response - # test that we are creating a session with the proper headers + # Test that we are creating a session with the proper headers assert "User-Agent" in headers assert "earthaccess" in headers["User-Agent"]