Catch all XML errors
aaronfriedman6 committed Nov 21, 2024
1 parent 2f77f3d commit e09b12d
Showing 6 changed files with 164 additions and 192 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,7 @@
## 2024-11-21 -- v1.0.5
### Fixed
- Catch non-fatal XML errors and continue requesting. Only throw an error and stop when the API limit has been exceeded, when the request itself has failed, or when the ShopperTrak server cannot be reached even after retrying.

## 2024-11-14 -- v1.0.4
### Fixed
- When site ID is not found (error code "E101"), skip it without throwing an error
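The v1.0.5 entry above describes the new error-handling policy: request-specific XML errors are skipped so that later requests can still run, while an exceeded API limit, a failed request, or a server that stays unreachable after retrying still halts the run. A minimal sketch of that policy, reusing the APIStatus enum added in this commit; the driver loop and client call below are simplified assumptions, not the repository's actual code:

from lib import APIStatus  # enum added in this commit: SUCCESS, RETRY, ERROR


def run_requests(client, requests):
    """Hypothetical driver loop showing the skip-vs-stop behavior."""
    results = []
    for request in requests:
        # query is assumed to raise for fatal conditions (API limit exceeded,
        # failed request, server unreachable even after retrying) and to
        # return APIStatus.ERROR for non-fatal, request-specific failures.
        response = client.query(request)
        if response == APIStatus.ERROR:
            continue  # skip this request and keep going
        results.append(response)
    return results
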
1 change: 1 addition & 0 deletions lib/__init__.py
@@ -1,4 +1,5 @@
from .shoppertrak_api_client import (
APIStatus,
ShopperTrakApiClient,
ShopperTrakApiClientError,
ALL_SITES_ENDPOINT,
47 changes: 18 additions & 29 deletions lib/pipeline_controller.py
@@ -1,6 +1,5 @@
import os
import pytz
import xml.etree.ElementTree as ET

from datetime import datetime, timedelta
from helpers.query_helper import (
@@ -12,6 +11,7 @@
REDSHIFT_RECOVERABLE_QUERY,
)
from lib import (
APIStatus,
ShopperTrakApiClient,
ALL_SITES_ENDPOINT,
SINGLE_SITE_ENDPOINT,
@@ -106,27 +106,26 @@ def get_location_hours_dict(self):
)
)
self.redshift_client.close_connection()
return {(row[0], row[1]): (row[2], row[3]) for row in raw_hours}
return {
(branch_code, weekday): (regular_open, regular_close)
for branch_code, weekday, regular_open, regular_close in raw_hours
}

def process_all_sites_data(self, end_date, batch_num):
"""Gets visits data from all available sites for the given day(s)"""
last_poll_date = self._get_poll_date(batch_num)
poll_date = last_poll_date + timedelta(days=1)
if poll_date <= end_date:
self.logger.info(f"Beginning batch {batch_num+1}: {poll_date.isoformat()}")
all_sites_xml_root = self.shoppertrak_api_client.query(
all_sites_response = self.shoppertrak_api_client.query(
ALL_SITES_ENDPOINT, poll_date
)
if type(all_sites_xml_root) != ET.Element:
self.logger.error(
"Error querying ShopperTrak API for all sites visits data"
)
raise PipelineControllerError(
"Error querying ShopperTrak API for all sites visits data"
) from None
if all_sites_response == APIStatus.ERROR:
self.logger.error("Failed to retrieve all sites visits data")
return

results = self.shoppertrak_api_client.parse_response(
all_sites_xml_root, poll_date
all_sites_response, poll_date
)
encoded_records = self.avro_encoder.encode_batch(results)
if not self.ignore_kinesis:
@@ -165,8 +164,8 @@ def process_broken_orbits(self, start_date, end_date):
known_data_dict = dict()
if known_data:
known_data_dict = {
(row[0], row[1], row[2]): (row[3], row[4], row[5], row[6])
for row in known_data
(site_id, orbit, inc_start): (redshift_id, is_healthy, enters, exits)
for site_id, orbit, inc_start, redshift_id, is_healthy, enters, exits in known_data
}
self._recover_data(recoverable_site_dates, known_data_dict)
self.redshift_client.close_connection()
@@ -177,20 +176,15 @@ def _recover_data(self, site_dates, known_data_dict):
unhealthy data. Then check to see if the returned data is actually "recovered"
data, as it may have never been unhealthy to begin with. If so, send to Kinesis.
"""
for row in site_dates:
site_xml_root = self.shoppertrak_api_client.query(
SINGLE_SITE_ENDPOINT + row[0], row[1]
for site_id, visits_date in site_dates:
site_response = self.shoppertrak_api_client.query(
SINGLE_SITE_ENDPOINT + site_id, visits_date
)
# If the site ID can't be found (E101) or multiple sites match the same site
# ID (E104), continue to the next site. If the API limit has been reached
# (E107), stop.
if site_xml_root == "E101" or site_xml_root == "E104":
continue
elif site_xml_root == "E107":
break
if site_response == APIStatus.ERROR:
self.logger.error(f"Failed to retrieve site visits data for {site_id}")
else:
site_results = self.shoppertrak_api_client.parse_response(
site_xml_root, row[1], is_recovery_mode=True
site_response, visits_date, is_recovery_mode=True
)
self._process_recovered_data(site_results, known_data_dict)

@@ -248,8 +242,3 @@ def _get_poll_date(self, batch_num):
else:
poll_str = self.s3_client.fetch_cache()["last_poll_date"]
return datetime.strptime(poll_str, "%Y-%m-%d").date()


class PipelineControllerError(Exception):
def __init__(self, message=None):
self.message = message
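
With the APIStatus enum in place, _recover_data no longer branches on raw error-code strings (E101/E104/E107): any request-specific failure comes back as APIStatus.ERROR, gets logged, and the loop simply moves on to the next site/date pair. A condensed sketch of that loop, assuming the APIStatus and SINGLE_SITE_ENDPOINT exports from lib and a hypothetical downstream process_recovered callable:

from lib import APIStatus, SINGLE_SITE_ENDPOINT


def recover_data(api_client, site_dates, known_data_dict, process_recovered, logger):
    """Hypothetical condensation of PipelineController._recover_data."""
    # site_dates: iterable of (site_id, visits_date) pairs flagged as recoverable
    for site_id, visits_date in site_dates:
        response = api_client.query(SINGLE_SITE_ENDPOINT + site_id, visits_date)
        if response == APIStatus.ERROR:
            # Non-fatal failure (e.g. unknown site ID): log it and keep going.
            logger.error(f"Failed to retrieve site visits data for {site_id}")
            continue
        results = api_client.parse_response(
            response, visits_date, is_recovery_mode=True
        )
        process_recovered(results, known_data_dict)
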
77 changes: 36 additions & 41 deletions lib/shoppertrak_api_client.py
@@ -5,6 +5,7 @@
import xml.etree.ElementTree as ET

from datetime import datetime
from enum import Enum
from nypl_py_utils.functions.log_helper import create_log
from requests.auth import HTTPBasicAuth
from requests.exceptions import RequestException
@@ -14,6 +15,15 @@
_WEEKDAY_MAP = {0: "Mon", 1: "Tue", 2: "Wed", 3: "Thu", 4: "Fri", 5: "Sat", 6: "Sun"}


class APIStatus(Enum):
SUCCESS = 1 # The API successfully retrieved the data
RETRY = 2 # The API is busy or down and should be retried later

# The API returned a request-specific error. This status indicates that while this
# request failed, others with different parameters may still succeed.
ERROR = 3


class ShopperTrakApiClient:
"""Class for querying the ShopperTrak API for location visits data"""

@@ -27,9 +37,9 @@ def __init__(self, username, password, location_hours_dict):

def query(self, endpoint, query_date, query_count=1):
"""
Sends query to ShopperTrak API and returns the result as an XML root if
possible. If the API returns that it's busy, this method waits and recursively
calls itself.
Sends query to ShopperTrak API and either a) returns the result as an XML root
if the query was successful, b) returns APIStatus.ERROR if the query failed but
others should be attempted, or c) waits and tries again if the API was busy.
"""
full_url = self.base_url + endpoint
date_str = query_date.strftime("%Y%m%d")
@@ -48,23 +58,26 @@ def query(self, endpoint, query_date, query_count=1):
f"Failed to retrieve response from {full_url}: {e}"
) from None

response_root = self._check_response(response.text)
if response_root == "E108" or response_root == "E000":
response_status, response_root = self._check_response(response.text)
if response_status == APIStatus.SUCCESS:
return response_root
elif response_status == APIStatus.ERROR:
return response_status
elif response_status == APIStatus.RETRY:
if query_count < self.max_retries:
self.logger.info("Waiting 5 minutes and trying again")
time.sleep(300)
return self.query(endpoint, query_date, query_count + 1)
else:
self.logger.error(
f"Reached max retries: sent {self.max_retries} queries with no "
f"response"
f"Hit retry limit: sent {self.max_retries} queries with no response"
)
raise ShopperTrakApiClientError(
f"Reached max retries: sent {self.max_retries} queries with no "
f"response"
f"Hit retry limit: sent {self.max_retries} queries with no response"
)
else:
return response_root
self.logger.error(f"Unknown API status: {response_status}")
raise ShopperTrakApiClientError(f"Unknown API status: {response_status}")

def parse_response(self, xml_root, input_date, is_recovery_mode=False):
"""
@@ -191,51 +204,33 @@ def _form_row(

def _check_response(self, response_text):
"""
If API response is XML, does not contain an error, and contains at least one
traffic attribute, returns response as an XML root. Otherwise, throws an error.
Checks response for errors. If none are found, returns the XML root. Otherwise,
either throws an error or returns an APIStatus where appropriate.
"""
try:
root = ET.fromstring(response_text)
error = root.find("error")
except ET.ParseError as e:
self.logger.error(f"Could not parse XML response {response_text}: {e}")
raise ShopperTrakApiClientError(
f"Could not parse XML response {response_text}: {e}"
) from None
return APIStatus.ERROR, None

if error is not None and error.text is not None:
# E000 is used when ShopperTrak is down and they recommend trying again
if error.text == "E000":
self.logger.info("E000: ShopperTrak is down")
return "E000"
# E101 is used when the given site ID is not recognized
elif error.text == "E101":
self.logger.warning("E101: site ID not found")
return "E101"
# E104 is used when the given site ID matches multiple sites
elif error.text == "E104":
self.logger.warning("E104: site ID has multiple matches")
return "E104"
# E107 is used when the daily API limit has been exceeded
elif error.text == "E107":
self.logger.info("E107: API limit exceeded")
return "E107"
# E108 is used when ShopperTrak is busy and they recommend trying again
elif error.text == "E108":
self.logger.info("E108: ShopperTrak is busy")
return "E108"
if error.text == "E107":
self.logger.error("API limit exceeded")
raise ShopperTrakApiClientError(f"API limit exceeded")
# E000 is used when ShopperTrak is down and E108 is used when it's busy
elif error.text == "E000" or error.text == "E108":
self.logger.info("ShopperTrak is unavailable")
return APIStatus.RETRY, None
else:
self.logger.error(f"Error found in XML response: {response_text}")
raise ShopperTrakApiClientError(
f"Error found in XML response: {response_text}"
)
return APIStatus.ERROR, None
elif len(root.findall(".//traffic")) == 0:
self.logger.error(f"No traffic found in XML response: {response_text}")
raise ShopperTrakApiClientError(
f"No traffic found in XML response: {response_text}"
)
return APIStatus.ERROR, None
else:
return root
return APIStatus.SUCCESS, root

def _get_xml_str(self, xml, attribute):
"""
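
After this change, _check_response collapses the old per-code string returns into three outcomes: E107 (daily API limit exceeded) raises ShopperTrakApiClientError, E000 and E108 (ShopperTrak down or busy) map to APIStatus.RETRY so query waits five minutes and retries up to max_retries, and everything else, including unparseable XML, unrecognized error codes, and responses with no traffic data, maps to APIStatus.ERROR. A condensed sketch of that mapping (logging trimmed; APIStatus and ShopperTrakApiClientError are assumed importable from lib as shown in the diff above):

import xml.etree.ElementTree as ET

from lib import APIStatus, ShopperTrakApiClientError


def check_response(response_text):
    """Return (status, xml_root_or_None) following the new error policy."""
    try:
        root = ET.fromstring(response_text)
    except ET.ParseError:
        return APIStatus.ERROR, None  # malformed XML: skip this request
    error = root.find("error")
    if error is not None and error.text is not None:
        if error.text == "E107":  # daily API limit exceeded
            raise ShopperTrakApiClientError("API limit exceeded")
        if error.text in ("E000", "E108"):  # ShopperTrak down or busy
            return APIStatus.RETRY, None  # caller waits and retries
        return APIStatus.ERROR, None  # any other error code: skip
    if not root.findall(".//traffic"):
        return APIStatus.ERROR, None  # no traffic data: skip
    return APIStatus.SUCCESS, root
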
56 changes: 19 additions & 37 deletions tests/test_pipeline_controller.py
@@ -4,7 +4,8 @@

from datetime import date, datetime, time
from helpers.query_helper import REDSHIFT_DROP_QUERY, REDSHIFT_RECOVERABLE_QUERY
from lib.pipeline_controller import PipelineController, PipelineControllerError
from lib.pipeline_controller import PipelineController
from lib.shoppertrak_api_client import APIStatus
from tests.test_helpers import TestHelpers


@@ -219,14 +220,15 @@ def test_process_all_sites_data_multi_run(self, test_instance, mock_logger, mock
]
)

def test_process_all_sites_error(self, test_instance, mock_logger, mocker):
def test_process_all_sites_error(self, test_instance, mock_logger, mocker, caplog):
test_instance.s3_client.fetch_cache.return_value = {
"last_poll_date": "2023-12-30"}
test_instance.shoppertrak_api_client.query.return_value = "error"
test_instance.shoppertrak_api_client.query.return_value = APIStatus.ERROR

with pytest.raises(PipelineControllerError):
with caplog.at_level(logging.WARNING):
test_instance.process_all_sites_data(date(2023, 12, 31), 0)

assert "Failed to retrieve all sites visits data" in caplog.text
test_instance.s3_client.fetch_cache.assert_called_once()
test_instance.shoppertrak_api_client.query.assert_called_once_with(
"allsites", date(2023, 12, 31)
@@ -323,51 +325,31 @@ def test_recover_data(self, test_instance, mock_logger, mocker):
]
)

def test_recover_data_bad_sites(self, test_instance, mock_logger, mocker):
def test_recover_data_error(self, test_instance, mocker, caplog):
test_instance.shoppertrak_api_client.query.side_effect = [
_TEST_XML_ROOT, "E104", "E101", _TEST_XML_ROOT]
_TEST_XML_ROOT, APIStatus.ERROR, _TEST_XML_ROOT]
mocked_process_recovered_data_method = mocker.patch(
"lib.pipeline_controller.PipelineController._process_recovered_data"
)

test_instance._recover_data(
[
("aa", date(2023, 12, 1)),
("bb", date(2023, 12, 1)),
("cc", date(2023, 12, 1)),
("aa", date(2023, 12, 2)),
],
_TEST_KNOWN_DATA_DICT,
)

with caplog.at_level(logging.WARNING):
test_instance._recover_data(
[
("aa", date(2023, 12, 1)),
("bb", date(2023, 12, 1)),
("aa", date(2023, 12, 2)),
],
_TEST_KNOWN_DATA_DICT,
)

assert "Failed to retrieve site visits data for bb" in caplog.text
test_instance.shoppertrak_api_client.parse_response.assert_has_calls(
[
mocker.call(_TEST_XML_ROOT, date(2023, 12, 1), is_recovery_mode=True),
mocker.call(_TEST_XML_ROOT, date(2023, 12, 2), is_recovery_mode=True),
]
)
assert mocked_process_recovered_data_method.call_count == 2

def test_recover_data_api_limit(self, test_instance, mock_logger, mocker):
test_instance.shoppertrak_api_client.query.side_effect = [
_TEST_XML_ROOT, "E107"]
mocked_process_recovered_data_method = mocker.patch(
"lib.pipeline_controller.PipelineController._process_recovered_data"
)

test_instance._recover_data(
[
("aa", date(2023, 12, 1)),
("bb", date(2023, 12, 2)),
("aa", date(2023, 12, 2)),
],
_TEST_KNOWN_DATA_DICT,
)

test_instance.shoppertrak_api_client.parse_response.assert_called_once_with(
_TEST_XML_ROOT, date(2023, 12, 1), is_recovery_mode=True
)
mocked_process_recovered_data_method.assert_called_once()

def test_process_recovered_data(self, test_instance, mocker, caplog):
mocked_update_query = mocker.patch(
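
The updated tests swap pytest.raises for pytest's caplog fixture, since the controller now logs the failure and continues instead of raising. A minimal, standalone illustration of the caplog pattern used above (hypothetical function and message, not one of the repository's tests):

import logging


def load_data(logger):
    # Stand-in for the code under test: logs an error instead of raising.
    logger.error("Failed to retrieve all sites visits data")


def test_load_data_logs_error(caplog):
    with caplog.at_level(logging.WARNING):
        load_data(logging.getLogger("test"))
    assert "Failed to retrieve all sites visits data" in caplog.text
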