Refactor data loading and querying functions
MitchellAV committed Dec 7, 2023
1 parent ea89593 commit 56bdf2c
Showing 1 changed file with 66 additions and 17 deletions.
83 changes: 66 additions & 17 deletions solardatatools/dataio.py
@@ -268,6 +268,15 @@ def load_redshift_data(
Pandas DataFrame containing the queried data
"""

class QueryParams(TypedDict):
api_key: str
siteid: str
column: str
sensor: int | list[int] | None
tmin: datetime | None
tmax: datetime | None
limit: int | None

def decompress_data_to_dataframe(encoded_data):
# Decompress gzip data
decompressed_buffer = BytesIO(encoded_data)
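
The rest of decompress_data_to_dataframe is collapsed in this view. A minimal sketch of the gzip-to-DataFrame pattern its first lines suggest, assuming a CSV payload (the read_csv call and the payload format are assumptions, not shown in the diff):

import gzip
from io import BytesIO

import pandas as pd


def decompress_to_dataframe_sketch(encoded_data: bytes) -> pd.DataFrame:
    # Wrap the raw gzip bytes in an in-memory buffer and decompress them.
    decompressed_buffer = BytesIO(encoded_data)
    with gzip.GzipFile(fileobj=decompressed_buffer, mode="rb") as f:
        # Assumption: the API ships CSV; the real helper may use another format.
        return pd.read_csv(f)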
@@ -295,19 +304,21 @@ def wrapper(*args, **kwargs):

return decorator

@timing(verbose)
def query_redshift_w_api(page: int, is_batch: bool = False) -> requests.Response:
# @timing(verbose)
def query_redshift_w_api(
params: QueryParams, page: int, is_batch: bool = False
) -> requests.Response:
url = "https://lmojfukey3rylrbqughzlfu6ca0ujdby.lambda-url.us-west-1.on.aws/"
payload = {
"api_key": api_key,
"siteid": siteid,
"column": column,
"sensor": sensor,
"tmin": str(tmin),
"tmax": str(tmax),
"limit": str(limit),
"api_key": params.get("api_key"),
"siteid": params.get("siteid"),
"column": params.get("column"),
"sensor": params.get("sensor"),
"tmin": str(params.get("tmin")),
"tmax": str(params.get("tmax")),
"limit": str(params.get("limit")),
"page": str(page),
"batch_num": str(is_batch),
"is_batch": str(is_batch),
}
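
One consequence of the str(...) conversions above is that unset optional values would serialize as the literal string "None", which is presumably why the surrounding code pops those keys when they are unset. A hedged sketch of that pruning, where prune_unset_fields is a hypothetical helper; the real code checks each key under its own if statement:

from typing import Any


def prune_unset_fields(payload: dict[str, Any], params: dict[str, Any]) -> dict[str, Any]:
    # Drop optional keys that were never supplied so the endpoint is not
    # sent the string "None" produced by str(None).
    for key in ("sensor", "tmin", "tmax", "limit"):
        if params.get(key) is None:
            payload.pop(key, None)
    return payload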

if sensor is None:
Expand All @@ -320,17 +331,23 @@ def query_redshift_w_api(page: int, is_batch: bool = False) -> requests.Response
payload.pop("limit")

response = requests.post(url, json=payload, timeout=60 * 5)
if response.status_code != 200:
raise Exception(f"Error {response.status_code} returned from API")

if response.status_code != 200:
error = response.json()
error_msg = error["error"]
raise Exception(
f"Query failed with status code {response.status_code}: {error_msg}"
)
if verbose:
print(f"Content size: {len(response.content)}")

return response

def fetch_data(df_list: list[pd.DataFrame], index: int, page: int):
def fetch_data(
query_params: QueryParams, df_list: list[pd.DataFrame], index: int, page: int
):
try:
response = query_redshift_w_api(page)
response = query_redshift_w_api(query_params, page)
new_df = decompress_data_to_dataframe(response.content)

if new_df.empty:
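
The remainder of fetch_data is collapsed. A sketch of the worker pattern the threading code below appears to rely on, assuming each thread writes its result into its own slot of a shared list so that no lock is needed (fetch_worker is a hypothetical stand-in for the query-and-decompress calls):

import threading

import pandas as pd


def fetch_worker(df_list: list[pd.DataFrame], index: int, page: int) -> None:
    # Stand-in for query_redshift_w_api + decompress_data_to_dataframe.
    new_df = pd.DataFrame({"page": [page]})
    # Each worker owns exactly one slot, so concurrent writes never collide.
    df_list[index] = new_df


pages = [0, 1, 2]
results: list[pd.DataFrame] = [pd.DataFrame() for _ in pages]
threads = [
    threading.Thread(target=fetch_worker, args=(results, i, p))
    for i, p in enumerate(pages)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
combined = pd.concat(results, ignore_index=True)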
@@ -345,8 +362,25 @@ def fetch_data(df_list: list[pd.DataFrame], index: int, page: int):

import threading

batch_df: requests.Response = query_redshift_w_api(0, is_batch=True)
data = batch_df.json()
data: Dict[str, Any] = {}

query_params: QueryParams = {
"api_key": api_key,
"siteid": siteid,
"column": column,
"sensor": sensor,
"tmin": tmin,
"tmax": tmax,
"limit": limit,
}

try:
batch_df: requests.Response = query_redshift_w_api(
query_params, 0, is_batch=True
)
data = batch_df.json()
except Exception as e:
raise e
max_limit = int(data["max_limit"])
total_count = int(data["total_count"])
batches = int(data["batches"])
@@ -358,9 +392,19 @@ def fetch_data(df_list: list[pd.DataFrame], index: int, page: int):
batch_size = 2 # Max number of threads to run at once (limited by redshift)

loops = math.ceil(batches / batch_size)

if batches <= batch_size:
loops = 1
batch_size = batches
# if limit is not None:
# if limit <= max_limit:
# loops = 1
# batch_size = 1
# else:
# loops = math.ceil(limit / max_limit)
# batch_size = math.ceil(batches / loops)

running_count = total_count
page = 0
df = pd.DataFrame()
list_of_dfs: list[pd.DataFrame] = []
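
A worked example of the batching arithmetic above, using made-up numbers (in the real code max_limit, total_count, and batches come back from the batch query rather than being computed locally):

import math

total_count = 25_000  # rows matching the query (illustrative)
max_limit = 10_000    # rows the API returns per page (illustrative)
batches = math.ceil(total_count / max_limit)  # 3 pages in this example
batch_size = 2        # max threads per loop (Redshift concurrency limit)

loops = math.ceil(batches / batch_size)       # 2 loops of up to 2 threads
if batches <= batch_size:
    loops = 1
    batch_size = batches
print(loops, batch_size)  # -> 2 2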
@@ -371,12 +415,17 @@ def fetch_data(df_list: list[pd.DataFrame], index: int, page: int):

# Create threads for each batch of pages
for i in range(len(page_batch)):
query_params_copy = query_params.copy()
if running_count < max_limit:
query_params_copy["limit"] = running_count
thread = threading.Thread(
target=fetch_data, args=(df_list, i, page_batch[i])
target=fetch_data, args=(query_params_copy, df_list, i, page_batch[i])
)
threads.append(thread)
thread.start()

running_count -= max_limit

# Wait for all threads to complete
for thread in threads:
thread.join()
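
The tail of the function is collapsed. A small worked trace of the running_count bookkeeping above, with made-up numbers and assuming the decrement runs once per page, showing how the last page's limit gets trimmed to the rows that remain:

max_limit = 10_000
total_count = 25_000
running_count = total_count

limits = []
for page in range(3):  # three pages of up to max_limit rows each
    # Mirror of the check above: the final page only asks for what is left.
    page_limit = running_count if running_count < max_limit else max_limit
    limits.append(page_limit)
    running_count -= max_limit

print(limits)  # -> [10000, 10000, 5000]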
