Skip to content

Commit

Permalink
add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
micah-prime committed Aug 20, 2024
1 parent a199e3d commit d8c0475
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 52 deletions.
82 changes: 31 additions & 51 deletions snowexsql/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ class LargeQueryCheckException(RuntimeError):
pass


class NoColumnException(RuntimeError):
"""
The object does not have that column
"""


@contextmanager
def db_session(db_name):
# use default_name
Expand All @@ -35,7 +41,7 @@ def db_session(db_name):


def get_points():
# Lets grab a single row from the points table
# Let's grab a single row from the points table
with db_session(DB_NAME) as session:
qry = session.query(PointData).limit(1)
# Execute that query!
Expand Down Expand Up @@ -164,65 +170,57 @@ def from_unique_entries(cls, columns_to_search, **kwargs):

return results

def _all_from_attribute(self, attribute_name):
if not hasattr(self.MODEL, attribute_name):
raise NoColumnException(
f"{self.MODEL} does not have {attribute_name}"
)
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(getattr(self.MODEL, attribute_name)).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)

@property
def all_site_names(self):
"""
Return all types of the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.site_name).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("site_name")

@property
def all_types(self):
"""
Return all types of the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.type).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("type")

@property
def all_dates(self):
"""
Return all distinct dates in the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.date).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("date")

@property
def all_observers(self):
"""
Return all distinct observers in the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.observers).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("observers")

@property
def all_units(self):
"""
Return all distinct units in the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.units).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("units")

@property
def all_instruments(self):
"""
Return all distinct instruments in the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.instrument).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("instrument")


class PointMeasurements(BaseDataset):
Expand All @@ -244,7 +242,7 @@ def from_filter(cls, **kwargs):
df = query_to_geopandas(qry, engine)
except Exception as e:
session.close()
LOG.error("Failed query for PointData")
LOG.error(f"Failed query for {cls.MODEL}")
raise e

return df
Expand Down Expand Up @@ -306,7 +304,7 @@ class SiteMeasurements(PointMeasurements):
ALLOWED_QRY_KWARGS = [
"site_name", "site_id", "date", "pit_id", "utm_zone",
"aspect", "sky_cover", "ground_roughness",
"ground_vegetation", "tree_canopy", "weather_description"
"ground_vegetation", "tree_canopy", "weather_description",
"date_greater_equal", "date_less_equal",
]
MODEL = SiteData
Expand All @@ -316,60 +314,42 @@ def all_weather_description(self):
"""
Return all types of the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.weather_description).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("weather_description")

@property
def all_ground_vegetation(self):
"""
Return all types of the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.ground_vegetation).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("ground_vegetation")

@property
def all_tree_canopy(self):
"""
Return all types of the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.tree_canopy).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("tree_canopy")

@property
def all_ground_roughness(self):
"""
Return all types of the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.ground_roughness).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("ground_roughness")

@property
def all_sky_cover(self):
"""
Return all types of the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.sky_cover).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("sky_cover")

@property
def all_aspect(self):
"""
Return all types of the data
"""
with db_session(self.DB_NAME) as (session, engine):
qry = session.query(self.MODEL.aspect).distinct()
result = qry.all()
return self.retrieve_single_value_result(result)
return self._all_from_attribute("aspect")


class TooManyRastersException(Exception):
Expand Down
87 changes: 86 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from datetime import date

from snowexsql.api import (
PointMeasurements, LargeQueryCheckException, LayerMeasurements
PointMeasurements, LargeQueryCheckException, LayerMeasurements,
SiteMeasurements, NoColumnException
)
from snowexsql.db import get_db, initialize

Expand Down Expand Up @@ -248,3 +249,87 @@ def test_from_area_point(self, clz):
type="density",
)
assert len(result) == 0


class TestSiteMeasurements(DBConnection):
"""
Test the Layer Measurement class
"""
CLZ = SiteMeasurements

def test_all_types_fails(self, clz):
with pytest.raises(NoColumnException):
clz().all_types

def test_all_site_names(self, clz):
result = clz().all_site_names
assert result == []

def test_all_dates(self, clz):
result = clz().all_dates
assert len(result) == 0

def test_all_weather_description(self, clz):
result = clz().all_weather_description
assert unsorted_list_tuple_compare(result, [])

def test_all_instruments(self, clz):
result = clz().all_instruments
assert unsorted_list_tuple_compare(result, [])

@pytest.mark.parametrize(
"kwargs, expected_length, mean_value", [
({
"date": date(2020, 3, 12), "ground_roughness": "Smooth",
"pit_id": "COERIB_20200312_0938"
}, 0, np.nan), # filter to 1 pit
({"ground_roughness": "Smooth", "limit": 10}, 0, np.nan), # limit works
({
"date": date(2020, 5, 28),
}, 0, np.nan), # nothing returned
({
"date_less_equal": date(2019, 12, 15),
}, 0, np.nan),
({
"date_greater_equal": date(2020, 5, 13),
}, 0, np.nan),
]
)
def test_from_filter(self, clz, kwargs, expected_length, mean_value):
result = clz.from_filter(**kwargs)
assert len(result) == expected_length

@pytest.mark.parametrize(
"kwargs, expected_error", [
({"notakey": "value"}, ValueError),
# ({"date": date(2020, 3, 12)}, LargeQueryCheckException),
({"date": [date(2020, 5, 28), date(2019, 10, 3)]}, ValueError),
]
)
def test_from_filter_fails(self, clz, kwargs, expected_error):
"""
Test failure on not-allowed key and too many returns
"""
with pytest.raises(expected_error):
clz.from_filter(**kwargs)

def test_from_area(self, clz):
df = gpd.GeoDataFrame(
geometry=gpd.points_from_xy(
[743766.4794971556], [4321444.154620216], crs="epsg:26912"
).buffer(1000.0)
).set_crs("epsg:26912")
result = clz.from_area(
ground_roughness="Smooth",
shp=df.iloc[0].geometry,
)
assert len(result) == 0

def test_from_area_point(self, clz):
pts = gpd.points_from_xy([743766.4794971556], [4321444.154620216])
crs = "26912"
result = clz.from_area(
pt=pts[0], buffer=1000, crs=crs,
ground_roughness="Smooth",
)
assert len(result) == 0

0 comments on commit d8c0475

Please sign in to comment.