From d8c0475d182224a8b9ddefe17b9e7b042ffbd2c8 Mon Sep 17 00:00:00 2001 From: Micah Sandusky Date: Tue, 20 Aug 2024 12:29:16 -0600 Subject: [PATCH] add some tests --- snowexsql/api.py | 82 +++++++++++++++++--------------------------- tests/test_api.py | 87 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 117 insertions(+), 52 deletions(-) diff --git a/snowexsql/api.py b/snowexsql/api.py index 8cd6b6f..fe0aa8b 100644 --- a/snowexsql/api.py +++ b/snowexsql/api.py @@ -25,6 +25,12 @@ class LargeQueryCheckException(RuntimeError): pass +class NoColumnException(RuntimeError): + """ + The object does not have that column + """ + + @contextmanager def db_session(db_name): # use default_name @@ -35,7 +41,7 @@ def db_session(db_name): def get_points(): - # Lets grab a single row from the points table + # Let's grab a single row from the points table with db_session(DB_NAME) as session: qry = session.query(PointData).limit(1) # Execute that query! @@ -164,65 +170,57 @@ def from_unique_entries(cls, columns_to_search, **kwargs): return results + def _all_from_attribute(self, attribute_name): + if not hasattr(self.MODEL, attribute_name): + raise NoColumnException( + f"{self.MODEL} does not have {attribute_name}" + ) + with db_session(self.DB_NAME) as (session, engine): + qry = session.query(getattr(self.MODEL, attribute_name)).distinct() + result = qry.all() + return self.retrieve_single_value_result(result) + @property def all_site_names(self): """ Return all types of the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.site_name).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("site_name") @property def all_types(self): """ Return all types of the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.type).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("type") @property def all_dates(self): """ Return all distinct dates in the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.date).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("date") @property def all_observers(self): """ Return all distinct observers in the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.observers).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("observers") @property def all_units(self): """ Return all distinct units in the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.units).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("units") @property def all_instruments(self): """ Return all distinct instruments in the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.instrument).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("instrument") class PointMeasurements(BaseDataset): @@ -244,7 +242,7 @@ def from_filter(cls, **kwargs): df = query_to_geopandas(qry, engine) except Exception as e: session.close() - LOG.error("Failed query for PointData") + LOG.error(f"Failed query for {cls.MODEL}") raise e return df @@ -306,7 +304,7 @@ class SiteMeasurements(PointMeasurements): ALLOWED_QRY_KWARGS = [ "site_name", "site_id", "date", "pit_id", "utm_zone", "aspect", "sky_cover", "ground_roughness", - "ground_vegetation", "tree_canopy", "weather_description" + "ground_vegetation", "tree_canopy", "weather_description", "date_greater_equal", "date_less_equal", ] MODEL = SiteData @@ -316,60 +314,42 @@ def all_weather_description(self): """ Return all types of the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.weather_description).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("weather_description") @property def all_ground_vegetation(self): """ Return all types of the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.ground_vegetation).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("ground_vegetation") @property def all_tree_canopy(self): """ Return all types of the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.tree_canopy).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("tree_canopy") @property def all_ground_roughness(self): """ Return all types of the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.ground_roughness).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("ground_roughness") @property def all_sky_cover(self): """ Return all types of the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.sky_cover).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("sky_cover") @property def all_aspect(self): """ Return all types of the data """ - with db_session(self.DB_NAME) as (session, engine): - qry = session.query(self.MODEL.aspect).distinct() - result = qry.all() - return self.retrieve_single_value_result(result) + return self._all_from_attribute("aspect") class TooManyRastersException(Exception): diff --git a/tests/test_api.py b/tests/test_api.py index 46fa128..ae5bd97 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -5,7 +5,8 @@ from datetime import date from snowexsql.api import ( - PointMeasurements, LargeQueryCheckException, LayerMeasurements + PointMeasurements, LargeQueryCheckException, LayerMeasurements, + SiteMeasurements, NoColumnException ) from snowexsql.db import get_db, initialize @@ -248,3 +249,87 @@ def test_from_area_point(self, clz): type="density", ) assert len(result) == 0 + + +class TestSiteMeasurements(DBConnection): + """ + Test the Layer Measurement class + """ + CLZ = SiteMeasurements + + def test_all_types_fails(self, clz): + with pytest.raises(NoColumnException): + clz().all_types + + def test_all_site_names(self, clz): + result = clz().all_site_names + assert result == [] + + def test_all_dates(self, clz): + result = clz().all_dates + assert len(result) == 0 + + def test_all_weather_description(self, clz): + result = clz().all_weather_description + assert unsorted_list_tuple_compare(result, []) + + def test_all_instruments(self, clz): + result = clz().all_instruments + assert unsorted_list_tuple_compare(result, []) + + @pytest.mark.parametrize( + "kwargs, expected_length, mean_value", [ + ({ + "date": date(2020, 3, 12), "ground_roughness": "Smooth", + "pit_id": "COERIB_20200312_0938" + }, 0, np.nan), # filter to 1 pit + ({"ground_roughness": "Smooth", "limit": 10}, 0, np.nan), # limit works + ({ + "date": date(2020, 5, 28), + }, 0, np.nan), # nothing returned + ({ + "date_less_equal": date(2019, 12, 15), + }, 0, np.nan), + ({ + "date_greater_equal": date(2020, 5, 13), + }, 0, np.nan), + ] + ) + def test_from_filter(self, clz, kwargs, expected_length, mean_value): + result = clz.from_filter(**kwargs) + assert len(result) == expected_length + + @pytest.mark.parametrize( + "kwargs, expected_error", [ + ({"notakey": "value"}, ValueError), + # ({"date": date(2020, 3, 12)}, LargeQueryCheckException), + ({"date": [date(2020, 5, 28), date(2019, 10, 3)]}, ValueError), + ] + ) + def test_from_filter_fails(self, clz, kwargs, expected_error): + """ + Test failure on not-allowed key and too many returns + """ + with pytest.raises(expected_error): + clz.from_filter(**kwargs) + + def test_from_area(self, clz): + df = gpd.GeoDataFrame( + geometry=gpd.points_from_xy( + [743766.4794971556], [4321444.154620216], crs="epsg:26912" + ).buffer(1000.0) + ).set_crs("epsg:26912") + result = clz.from_area( + ground_roughness="Smooth", + shp=df.iloc[0].geometry, + ) + assert len(result) == 0 + + def test_from_area_point(self, clz): + pts = gpd.points_from_xy([743766.4794971556], [4321444.154620216]) + crs = "26912" + result = clz.from_area( + pt=pts[0], buffer=1000, crs=crs, + ground_roughness="Smooth", + ) + assert len(result) == 0