Commit

Adjust PK chunking with a fallback to REST request
butkeraites-hotglue committed Oct 3, 2024
1 parent afd0dc6 commit ab6a26b
Showing 4 changed files with 106 additions and 81 deletions.
5 changes: 2 additions & 3 deletions tap_salesforce/__init__.py
@@ -13,7 +13,6 @@
TapSalesforceException, TapSalesforceQuotaExceededException, TapSalesforceBulkAPIDisabledException)

from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial


LOGGER = singer.get_logger()
@@ -555,10 +554,10 @@ def main_impl():
# Use ThreadPoolExecutor to process the catalog entries in parallel using threads
with ThreadPoolExecutor() as executor:
# Partial function with shared session and config
process_func = partial(process_catalog_entry, sf_data=sf_data, state=args.state, catalog=catalog, config=CONFIG)
state = args.state

# Submit tasks to the executor for each stream
futures = [executor.submit(process_func, stream) for stream in catalog["streams"]]
futures = [executor.submit(process_catalog_entry, stream, sf_data, state, catalog, CONFIG) for stream in catalog["streams"]]

# Optionally wait for all tasks to complete and handle exceptions
for future in futures:
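For context on the change above: `executor.submit` already forwards positional arguments to the callable, so the `functools.partial` wrapper (and its import) can be dropped. A minimal, self-contained sketch of the pattern, using a hypothetical `process_catalog_entry` stub rather than the tap's real implementation:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def process_catalog_entry(stream, sf_data, state, catalog, config):
    # Hypothetical stub; the real function syncs one Salesforce stream.
    return f"synced {stream['tap_stream_id']}"

catalog = {"streams": [{"tap_stream_id": "Account"}, {"tap_stream_id": "Contact"}]}

with ThreadPoolExecutor() as executor:
    # Shared arguments are passed directly to submit(); no partial needed.
    futures = [
        executor.submit(process_catalog_entry, stream, None, {}, catalog, {})
        for stream in catalog["streams"]
    ]
    # Iterating as_completed() re-raises any worker exception in the caller.
    for future in as_completed(futures):
        print(future.result())
```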
21 changes: 16 additions & 5 deletions tap_salesforce/salesforce/__init__.py
@@ -289,13 +289,18 @@ def _make_request(self, http_method, url, headers=None, body=None, stream=False,
elif http_method == "POST":
LOGGER.info("Making %s request to %s with body %s", http_method, url, body)
resp = self.session.post(url, headers=headers, data=body)
LOGGER.info("Completed %s request to %s with body: %s", http_method, url, body)
else:
raise TapSalesforceException("Unsupported HTTP method")

try:
resp.raise_for_status()
except RequestException as ex:
raise ex
try:
if "is not supported to use PKChunking" not in ex.response.json()['exceptionMessage']:
raise ex
except Exception:
raise ex

if resp.headers.get('Sforce-Limit-Info') is not None:
self.rest_requests_attempted += 1
@@ -435,10 +440,16 @@ def query(self, catalog_entry, state, query_override=None):
if state["bookmarks"].get("ListView"):
if state["bookmarks"]["ListView"].get("SystemModstamp"):
del state["bookmarks"]["ListView"]["SystemModstamp"]
if self.api_type == BULK_API_TYPE and query_override is None:
bulk = Bulk(self)
return bulk.query(catalog_entry, state)
elif self.api_type == REST_API_TYPE or query_override is not None:
try_rest_call = False
try:
if self.api_type == BULK_API_TYPE and query_override is None:
bulk = Bulk(self)
return bulk.query(catalog_entry, state)
except Exception as e:
LOGGER.warning(f"[FAILURE] Bulk API failed for catalog entry {catalog_entry} and state {state}. Trying a REST call.")
LOGGER.info(e)
try_rest_call = True
if try_rest_call or self.api_type == REST_API_TYPE or query_override is not None:
rest = Rest(self)
return rest.query(catalog_entry, state, query_override=query_override)
else:
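The `query` change above is the heart of the fallback: the Bulk path runs inside a try/except, and any failure flips `try_rest_call` so the REST branch takes over. A minimal sketch of that control flow, with hypothetical `BulkClient`/`RestClient` stand-ins for the tap's `Bulk` and `Rest` wrappers (note that consuming the results inside the try block is what lets errors raised during lazy iteration trigger the fallback):

```python
import logging

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)

class BulkClient:
    # Hypothetical stand-in: always fails, like an object that rejects PK chunking.
    def query(self, catalog_entry):
        raise RuntimeError("Object is not supported to use PKChunking")

class RestClient:
    # Hypothetical stand-in for the REST query path.
    def query(self, catalog_entry):
        yield {"Id": "001000000000001", "attributes": {"type": catalog_entry["stream"]}}

def query_with_fallback(catalog_entry):
    try:
        # list(...) consumes the Bulk results eagerly so that failures raised
        # during iteration are still caught here, not in the caller.
        return list(BulkClient().query(catalog_entry))
    except Exception as e:
        LOGGER.warning("Bulk API failed for %s; trying a REST call.", catalog_entry["stream"])
        LOGGER.info(e)
        return list(RestClient().query(catalog_entry))

print(query_with_fallback({"stream": "Account"}))
```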
158 changes: 88 additions & 70 deletions tap_salesforce/salesforce/bulk.py
@@ -49,6 +49,7 @@ def __init__(self, sf):
# Set csv max reading size to the platform's max size available.
csv.field_size_limit(sys.maxsize)
self.sf = sf
self.closed_jobs = []

def has_permissions(self):
try:
@@ -113,25 +114,47 @@ def _can_pk_chunk_job(self, failure_message): # pylint: disable=no-self-use
"Failed to write query result" in failure_message

def _bulk_query(self, catalog_entry, state):
start_date = self.sf.get_start_date(state, catalog_entry)
batch_status = self._bulk_query_with_pk_chunking(catalog_entry, start_date)
job_id = batch_status['job_id']
self.sf.pk_chunking = True
# Write job ID and batch state for resumption
tap_stream_id = catalog_entry['tap_stream_id']
state = singer.write_bookmark(state, tap_stream_id, 'JobID', job_id)
state = singer.write_bookmark(state, tap_stream_id, 'BatchIDs', batch_status['completed'][:])

# Parallelize the batch result processing
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.process_batch, job_id, completed_batch_id, catalog_entry, state)
for completed_batch_id in batch_status['completed']
]

# Process the results as they complete
for future in futures:
for result in future.result():
job_id = None
try:
start_date = self.sf.get_start_date(state, catalog_entry)
batch_status = self._bulk_query_with_pk_chunking(catalog_entry, start_date)
job_id = batch_status['job_id']
self.sf.pk_chunking = True
# Write job ID and batch state for resumption
tap_stream_id = catalog_entry['tap_stream_id']
state = singer.write_bookmark(state, tap_stream_id, 'JobID', job_id)
state = singer.write_bookmark(state, tap_stream_id, 'BatchIDs', batch_status['completed'][:])

# Parallelize the batch result processing
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.process_batch, job_id, completed_batch_id, catalog_entry, state)
for completed_batch_id in batch_status['completed']
]

# Process the results as they complete
for future in futures:
for result in future.result():
yield result
except Exception as e:
if job_id in self.closed_jobs:
LOGGER.info("Another batch failed before. Ignoring this new job...")
return
LOGGER.info(f"PK Chunking failed on job {job_id}: {e}. Trying without it...")
if job_id is not None:
self._close_job(job_id)

job_id = self._create_job(catalog_entry)
start_date = self.sf.get_start_date(state, catalog_entry)
self.sf.pk_chunking = False

batch_id = self._add_batch(catalog_entry, job_id, start_date)

self._close_job(job_id)

batch_status = self._poll_on_batch_status(job_id, batch_id)
if batch_status['state'] == 'Failed':
raise TapSalesforceException(batch_status['stateMessage'])
else:
for result in self.get_batch_results(job_id, batch_id, catalog_entry):
yield result

def process_batch(self, job_id, batch_id, catalog_entry, state):
@@ -146,7 +169,7 @@ def process_batch(self, job_id, batch_id, catalog_entry, state):
singer.write_state(state)

def _bulk_query_with_pk_chunking(self, catalog_entry, start_date):
LOGGER.info("Retrying Bulk Query with PK Chunking")
LOGGER.info("Trying Bulk Query with PK Chunking")

# Create a new job
job_id = self._create_job(catalog_entry, True)
@@ -220,6 +243,8 @@ def _poll_on_pk_chunked_batch_status(self, job_id):
if not queued_batches and not in_progress_batches:
completed_batches = [b['id'] for b in batches if b['state'] == "Completed"]
failed_batches = [b['id'] for b in batches if b['state'] == "Failed"]
if failed_batches:
LOGGER.error(f"Failed batches: {[{b['id']: b.get('stateMessage')} for b in batches if b['state'] == 'Failed']}")
return {'completed': completed_batches, 'failed': failed_batches}
else:
time.sleep(PK_CHUNKED_BATCH_STATUS_POLLING_SLEEP)
@@ -230,6 +255,7 @@ def _poll_on_batch_status(self, job_id, batch_id):
batch_id=batch_id)

while batch_status['state'] not in ['Completed', 'Failed', 'Not Processed']:
LOGGER.info(f'job_id: {job_id}, batch_id: {batch_id} - batch_status["state"]: {batch_status["state"]} - Sleeping for {BATCH_STATUS_POLLING_SLEEP} seconds...')
time.sleep(BATCH_STATUS_POLLING_SLEEP)
batch_status = self._get_batch(job_id=job_id,
batch_id=batch_id)
@@ -280,73 +306,65 @@ def _get_batch(self, job_id, batch_id):

return batch['batchInfo']

# Function to fetch and process each result in parallel
def process_result(self, job_id, batch_id, result):
endpoint = f"job/{job_id}/batch/{batch_id}/result/{result}"
url = self.bulk_url.format(self.sf.instance_url, endpoint)
headers = {'Content-Type': 'text/csv'}

# Use a context manager for temporary file handling
with tempfile.NamedTemporaryFile(mode="w+", encoding="utf8", delete=False) as csv_file:
# Stream the CSV content from Salesforce Bulk API
try:
resp = self.sf._make_request('GET', url, headers=headers, stream=True)
resp.raise_for_status() # Ensure we handle errors from the request

# Write chunks of CSV data to the temp file
for chunk in resp.iter_content(chunk_size=ITER_CHUNK_SIZE, decode_unicode=True):
if chunk:
csv_file.write(chunk.replace('\0', '')) # Replace NULL bytes

csv_file.seek(0) # Move back to the start of the file after writing

except requests.exceptions.RequestException as e:
# Handle any request errors (timeouts, connection errors, etc.)
raise TapSalesforceException(f"Error fetching results: {str(e)}")

# Now process the CSV file
with open(csv_file.name, mode='r', encoding='utf8') as f:
csv_reader = csv.reader(f, delimiter=',', quotechar='"')

try:
# Read column names from the first line
column_name_list = next(csv_reader)
except StopIteration:
# Handle case where no data is returned (empty CSV)
raise TapSalesforceException(f"No data found in batch {batch_id} result.")

# Process each row in the CSV file
for line in csv_reader:
record = dict(zip(column_name_list, line))
yield record

def get_batch_results(self, job_id, batch_id, catalog_entry):
"""Given a job_id and batch_id, queries the batch results and reads
CSV lines, yielding each line as a record."""
headers = self._get_bulk_headers()
endpoint = f"job/{job_id}/batch/{batch_id}/result"
url = self.bulk_url.format(self.sf.instance_url, endpoint)
batch_url = self.bulk_url.format(self.sf.instance_url, endpoint)

# Timing the request
with metrics.http_request_timer("batch_result_list") as timer:
timer.tags['sobject'] = catalog_entry['stream']
batch_result_resp = self.sf._make_request('GET', url, headers=headers)
batch_result_resp = self.sf._make_request('GET', batch_url, headers=headers)

# Parse the result list from the XML response
batch_result_list = xmltodict.parse(batch_result_resp.text, xml_attribs=False, force_list={'result'})['result-list']

# Use ThreadPoolExecutor to parallelize the processing of results
with ThreadPoolExecutor() as executor:
# Submit tasks to the executor for parallel execution
futures = [executor.submit(self.process_result, job_id, batch_id, result) for result in batch_result_list['result']]

# Yield the results as they complete
for future in futures:
# `future.result()` is a generator, so yield each record from it
for record in future.result():
for result in batch_result_list['result']:
url = batch_url + f"/{result}"
headers['Content-Type'] = 'text/csv'

# Use a context manager for temporary file handling
with tempfile.NamedTemporaryFile(mode="w+", encoding="utf8", delete=False) as csv_file:
# Stream the CSV content from Salesforce Bulk API
try:
resp = self.sf._make_request('GET', url, headers=headers, stream=True)
resp.raise_for_status() # Ensure we handle errors from the request

# Write chunks of CSV data to the temp file
for chunk in resp.iter_content(chunk_size=ITER_CHUNK_SIZE, decode_unicode=True):
if chunk:
csv_file.write(chunk.replace('\0', '')) # Replace NULL bytes

csv_file.seek(0) # Move back to the start of the file after writing

except requests.exceptions.RequestException as e:
# Handle any request errors (timeouts, connection errors, etc.)
raise TapSalesforceException(f"Error fetching results: {str(e)}")

# Now process the CSV file
with open(csv_file.name, mode='r', encoding='utf8') as f:
csv_reader = csv.reader(f, delimiter=',', quotechar='"')

try:
# Read column names from the first line
column_name_list = next(csv_reader)
except StopIteration:
# Handle case where no data is returned (empty CSV)
raise TapSalesforceException(f"No data found in batch {batch_id} result.")

# Process each row in the CSV file
for line in csv_reader:
record = dict(zip(column_name_list, line))
yield record

def _close_job(self, job_id):
if job_id in self.closed_jobs:
LOGGER.info(f"Job {job_id} already closed. Skipping the request")
return
self.closed_jobs.append(job_id)
endpoint = "job/{}".format(job_id)
url = self.bulk_url.format(self.sf.instance_url, endpoint)
body = {"state": "Closed"}
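The reworked `get_batch_results` above drops the per-result thread pool and instead streams each result set to a temporary file before parsing it. A compact, standalone sketch of that stream-to-tempfile-then-parse pattern (the HTTP fetch is replaced here by an in-memory chunk list; names are illustrative, not the tap's API):

```python
import csv
import tempfile

def iter_csv_records(chunks):
    """Write streamed CSV chunks to a temp file, then yield dict records."""
    with tempfile.NamedTemporaryFile(mode="w+", encoding="utf8") as csv_file:
        for chunk in chunks:
            if chunk:
                csv_file.write(chunk.replace("\0", ""))  # strip NULL bytes, as the tap does
        csv_file.seek(0)  # rewind before reading back
        reader = csv.reader(csv_file, delimiter=",", quotechar='"')
        try:
            columns = next(reader)  # header row
        except StopIteration:
            return  # empty result set: nothing to yield
        for row in reader:
            yield dict(zip(columns, row))

# Usage with a fake two-chunk response body:
chunks = ['Id,Name\n001,"Acme, Inc."\n', "002,Globex\n"]
for record in iter_csv_records(chunks):
    print(record)
```

Keeping the file handle inside the context manager (rather than `delete=False` plus a reopen, as in the diff) also avoids leaking temp files; the diff's approach works but never removes them.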
3 changes: 0 additions & 3 deletions tap_salesforce/sync.py
@@ -222,7 +222,6 @@ def sync_records(sf, catalog_entry, state, input_state, counter, catalog,config=
start_time = singer_utils.now()

LOGGER.info('Syncing Salesforce data for stream %s', stream)
records_post = []

if "/" in state["current_stream"]:
# get current name
@@ -367,8 +366,6 @@ def sync_list_views_stream(sf, catalog_entry, state, input_state, catalog, repli
if selected_list==f"ListView_{isob['SobjectType']}_{isob['DeveloperName']}":
selected_lists_names.append(isob)

replication_key_value = replication_key and singer_utils.strptime_with_tz(rec[replication_key])

for list_info in selected_lists_names:
sobject = list_info['SobjectType']
lv_name = list_info['DeveloperName']
