Skip to content

Commit

Permalink
Merge pull request #33 from TOMToolkit/feature/max_alerts
Browse files Browse the repository at this point in the history
Feature/max alerts
  • Loading branch information
jchate6 authored Mar 24, 2023
2 parents 5afc892 + 5e37907 commit ac82aff
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
13 changes: 12 additions & 1 deletion tom_antares/antares.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ class ANTARESBrokerForm(GenericQueryForm):
label='Elastic Search query in JSON format',
widget=forms.TextInput(attrs={'placeholder': '{"query":{}}'}),
)
max_alerts = forms.FloatField(
label='Maximum number of alerts to fetch',
widget=forms.TextInput(attrs={'placeholder': 'Max Alerts'}),
min_value=1,
initial=20
)

# cone_search = ConeSearchField()
# api_search_tags = forms.MultipleChoiceField(choices=get_tag_choices)
Expand Down Expand Up @@ -221,6 +227,10 @@ def __init__(self, *args, **kwargs):
'View Tags',
'tag'
),
Fieldset(
'Max Alerts',
'max_alerts'
),
HTML('<hr/>'),
HTML('<p style="color:blue;font-size:30px">Advanced query</p>'),
Fieldset(
Expand Down Expand Up @@ -318,6 +328,7 @@ def fetch_alerts(self, parameters: dict) -> iter:
mag_max = parameters.get('mag__max')
elsquery = parameters.get('esquery')
ztfid = parameters.get('ztfid')
max_alerts = parameters.get('max_alerts', 20)
if ztfid:
query = {
"query": {
Expand Down Expand Up @@ -384,7 +395,7 @@ def fetch_alerts(self, parameters: dict) -> iter:
# if ztfid:
# loci = get_by_ztf_object_id(ztfid)
alerts = []
while len(alerts) < 20:
while len(alerts) < max_alerts:
try:
locus = next(loci)
except (marshmallow.exceptions.ValidationError, StopIteration):
Expand Down
9 changes: 8 additions & 1 deletion tom_antares/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,18 @@ def test_fetch_alerts(self, mock_client):
# NOTE: if .side_effect is going to return a list, it needs a function that returns a list
mock_client.search.search.side_effect = lambda loci: iter(self.loci)
expected_alert = ANTARESBroker.alert_to_dict(self.locus)
alerts = ANTARESBroker().fetch_alerts({'tag': [self.tag]})
alerts = ANTARESBroker().fetch_alerts({'tag': [self.tag], 'max_alerts': 3})

# TODO: compare iterator length with len(self.loci)
self.assertEqual(next(alerts), expected_alert)

@mock.patch('tom_antares.antares.antares_client')
def test_fetch_alerts_max_alerts(self, mock_client):
"""Tests that the max_alerts parameter actually affects the length of the alert stream"""
mock_client.search.search.side_effect = lambda loci: iter(self.loci)
alerts = ANTARESBroker().fetch_alerts({'max_alerts': 4})
self.assertEqual(len(list(alerts)), 4)

def test_to_target_with_horizons_targetname(self):
"""
Test that the expected names are created.
Expand Down

0 comments on commit ac82aff

Please sign in to comment.