diff --git a/solardatatools/load_redshift_data.py b/solardatatools/load_redshift_data.py
index 9aa1b956..f8d73127 100644
--- a/solardatatools/load_redshift_data.py
+++ b/solardatatools/load_redshift_data.py
@@ -1,8 +1,7 @@
-import sshtunnel
 import os
 import numpy as np
-import redshift_connector
 from functools import wraps
+from datetime import datetime
 import pandas as pd
 from typing import Callable, TypedDict
 from dotenv import load_dotenv
@@ -43,86 +42,166 @@ class DBConnectionParams(TypedDict):
     db_local_hostname: str = "127.0.0.1"
 
 
-def create_tunnel_and_connect(ssh_params: SSHParams):
-    def decorator(func: Callable):
-        @wraps(func)
-        def inner_wrapper(db_connection_params: DBConnectionParams, *args, **kwargs):
-            with sshtunnel.SSHTunnelForwarder(
-                ssh_address_or_host=ssh_params["ssh_address_or_host"],
-                ssh_username=ssh_params["ssh_username"],
-                ssh_pkey=os.path.abspath(ssh_params["ssh_private_key"]),
-                remote_bind_address=ssh_params["remote_bind_address"],
-                host_pkey_directories=[
-                    os.path.dirname(os.path.abspath(ssh_params["ssh_private_key"]))
-                ],
-            ) as tunnel:
-                if tunnel is None:
-                    raise Exception("Tunnel is None")
-
-                tunnel.start()
-
-                if tunnel.is_active is False:
-                    raise Exception("Tunnel is not active")
-
-                local_port = tunnel.local_bind_port
-                db_connection_params["port"] = local_port
-
-                return func(db_connection_params, *args, **kwargs)
-
-        return inner_wrapper
-
-    return decorator
-
-
 def load_redshift_data(
     ssh_params: SSHParams,
     redshift_params: DBConnectionParams,
     siteid: str,
     column: str = "ac_power",
-    sensor: str | None = None,
-    tmin: str | None = None,
-    tmax: str | None = None,
+    sensor: int | None = None,
+    tmin: datetime | None = None,
+    tmax: datetime | None = None,
     limit: int | None = None,
-    verbose: bool = True,
-):
-    sql_query = """
-    SELECT site, meas_name, ts, sensor, meas_val_f FROM measurements
-    WHERE site = '{}'
-    AND meas_name = '{}'
-    """.format(
-        siteid, column
-    )
+    verbose: bool = False,
+) -> pd.DataFrame:
+    """Load data for a given site id from a Redshift database into a Pandas
+    DataFrame, connecting through an SSH tunnel.
+
+    Parameters
+    ----------
+    ssh_params : SSHParams
+        SSH connection parameters
+    redshift_params : DBConnectionParams
+        Redshift connection parameters
+    siteid : str
+        site id to query
+    column : str, optional
+        meas_name to query (default "ac_power")
+    sensor : int, optional
+        1-based index of the sensor to query at the given site (default None)
+    tmin : datetime, optional
+        minimum timestamp to query (default None)
+    tmax : datetime, optional
+        maximum timestamp to query (default None)
+    limit : int, optional
+        maximum number of rows to query (default None)
+    verbose : bool, optional
+        whether to print out timing information (default False)
+
+    Returns
+    -------
+    df : pd.DataFrame
+        Pandas DataFrame containing the queried data
+    """
-    ts_constraint = np.logical_or(tmin is not None, tmax is not None)
-    if tmin is not None:
-        sql_query += "and ts > '{}'\n".format(tmin)
-    if tmax is not None:
-        sql_query += "and ts < '{}'\n".format(tmax)
-    if sensor is not None and ts_constraint:
-        sql_query += "and sensor = '{}'\n".format(sensor)
-    elif sensor is not None and not ts_constraint:
-        sql_query += "and ts > '2000-01-01'\n"
-        sql_query += "and sensor = '{}'\n".format(sensor)
-    if limit is not None:
-        sql_query += "LIMIT {}".format(limit)
-    sql_query += ";"
+    # sshtunnel and redshift_connector are imported lazily so that they
+    # remain optional dependencies of the package.
+    try:
+        import sshtunnel
+    except ImportError:
+        raise Exception(
+            "Please install sshtunnel into your Python environment to use this function"
+        )
+
+    try:
+        import redshift_connector
+    except ImportError:
+        raise Exception(
+            "Please install redshift_connector into your Python environment to use this function"
+        )
+
+    def timing(verbose: bool = True):
+        def decorator(func: Callable):
+            @wraps(func)
+            def wrapper(*args, **kwargs):
+                start_time = time()
+                result = func(*args, **kwargs)
+                end_time = time()
+                execution_time = end_time - start_time
+                if verbose:
+                    print(f"{func.__name__} took {execution_time:.2f} seconds to run")
+                return result
+
+            return wrapper
+
+        return decorator
+
+    def create_tunnel_and_connect(ssh_params: SSHParams):
+        def decorator(func: Callable):
+            @wraps(func)
+            def inner_wrapper(
+                db_connection_params: DBConnectionParams, *args, **kwargs
+            ):
+                with sshtunnel.SSHTunnelForwarder(
+                    ssh_address_or_host=ssh_params["ssh_address_or_host"],
+                    ssh_username=ssh_params["ssh_username"],
+                    ssh_pkey=os.path.abspath(ssh_params["ssh_private_key"]),
+                    remote_bind_address=ssh_params["remote_bind_address"],
+                    host_pkey_directories=[
+                        os.path.dirname(os.path.abspath(ssh_params["ssh_private_key"]))
+                    ],
+                ) as tunnel:
+                    if tunnel is None:
+                        raise Exception("Tunnel is None")
+
+                    tunnel.start()
+
+                    if not tunnel.is_active:
+                        raise Exception("Tunnel is not active")
+
+                    # Point the database connection at the local end of the tunnel.
+                    local_port = tunnel.local_bind_port
+                    db_connection_params["port"] = local_port
+
+                    return func(db_connection_params, *args, **kwargs)
+
+            return inner_wrapper
+
+        return decorator
+
+    @timing(verbose)
     @create_tunnel_and_connect(ssh_params)
-    def create_df_from_query(redshift_params, sql_query):
+    def create_df_from_query(
+        redshift_params: DBConnectionParams, sql_query: str
+    ) -> pd.DataFrame:
         with redshift_connector.connect(**redshift_params) as conn:
             with conn.cursor() as cursor:
                 cursor.execute(sql_query)
                 df = cursor.fetch_dataframe()
         return df
 
-    ti = time()
+    sensor_found: bool = False
+    sensor_dict: dict = {}
+    if sensor is not None:
+        # Convert the user-facing 1-based sensor index to a 0-based index.
+        sensor = sensor - 1
+
+        site_sensor_map_query = f"""
+        SELECT sensor FROM measurements
+        WHERE site = '{siteid}'
+        GROUP BY sensor
+        ORDER BY sensor ASC
+        """
+
+        site_sensor_df = create_df_from_query(redshift_params, site_sensor_map_query)
+
+        if site_sensor_df is None:
+            raise Exception("No data returned from query when getting sensor map")
+        sensor_dict = site_sensor_df.to_dict()["sensor"]
+        if sensor not in sensor_dict:
+            raise Exception(
+                f"The index of {sensor + 1} for a sensor at site {siteid} is out of bounds. "
+                f"For site {siteid}, please choose a sensor index ranging from 1 to {len(sensor_dict)}."
+            )
+        sensor_found = True
+
+    sql_query = f"""
+    SELECT site, meas_name, ts, sensor, meas_val_f FROM measurements
+    WHERE site = '{siteid}'
+    AND meas_name = '{column}'
+    """
+
+    if sensor_found:
+        # Resolve the 0-based index to the sensor's actual name.
+        sql_query += f"AND sensor = '{sensor_dict.get(sensor)}'\n"
+    if tmin is not None:
+        sql_query += f"AND ts > '{tmin}'\n"
+    if tmax is not None:
+        sql_query += f"AND ts < '{tmax}'\n"
+    if limit is not None:
+        sql_query += f"LIMIT {limit}\n"
+
     df = create_df_from_query(redshift_params, sql_query)
     if df is None:
         raise Exception("No data returned from query")
-    # df.replace(-999999.0, np.NaN, inplace=True)
-    tf = time()
-    if verbose:
-        print("Query of {} rows complete in {:.2f} seconds".format(len(df), tf - ti))
     return df
@@ -142,11 +221,20 @@ def create_df_from_query(redshift_params, sql_query):
         "port": 0,
     }
     try:
+        start_time = datetime(2017, 8, 2, 19)
+
         df = load_redshift_data(
-            ssh_params, redshift_params, siteid="ZT163485000441C1369", limit=10
+            ssh_params,
+            redshift_params,
+            siteid="ZT163485000441C1369",
+            sensor=3,
+            tmin=start_time,
+            limit=100,
         )
         if df is None:
             raise Exception("No data returned from query")
+        if df.empty:
+            raise Exception("Empty dataframe returned from query")
         print(df.head(100))
     except Exception as e:
         print(e)
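
Reviewer note on the new sensor-index mapping: the 1-based sensor argument is shifted to a 0-based key into the dict produced by DataFrame.to_dict(), whose keys are the row positions of the per-site sensor list. A small self-contained illustration of that behavior; the inv_* sensor names below are made up, not values from this patch:

    import pandas as pd

    # Hypothetical result of the per-site sensor-map query for a site
    # with three sensors (sensor names invented for illustration).
    site_sensor_df = pd.DataFrame({"sensor": ["inv_1", "inv_2", "inv_3"]})

    # DataFrame.to_dict() keys each column's values by row position,
    # so this yields {0: "inv_1", 1: "inv_2", 2: "inv_3"}.
    sensor_dict = site_sensor_df.to_dict()["sensor"]

    # A caller passing sensor=3 is shifted to index 2 and resolved to
    # "inv_3"; sensor=4 shifts to index 3, which is not a key of
    # sensor_dict, so load_redshift_data raises the out-of-bounds error.
    assert sensor_dict.get(3 - 1) == "inv_3"
    assert (4 - 1) not in sensor_dict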
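
The patch still interpolates siteid, column, tmin, and tmax into the SQL text with f-strings. Since redshift_connector is a DB-API 2.0 driver, the same query could instead be built with placeholders and bound values. A minimal sketch, assuming the driver's default "format" paramstyle (%s placeholders); build_measurements_query is a hypothetical helper, not part of this patch:

    from datetime import datetime

    def build_measurements_query(siteid, column, tmin=None, tmax=None, limit=None):
        # Hypothetical alternative: build the measurements query with %s
        # placeholders so the driver binds the values, rather than
        # interpolating them into the SQL string.
        sql = (
            "SELECT site, meas_name, ts, sensor, meas_val_f FROM measurements "
            "WHERE site = %s AND meas_name = %s"
        )
        params = [siteid, column]
        if tmin is not None:
            sql += " AND ts > %s"
            params.append(tmin)
        if tmax is not None:
            sql += " AND ts < %s"
            params.append(tmax)
        if limit is not None:
            # LIMIT stays out of the bound parameters; int() guards it.
            sql += f" LIMIT {int(limit)}"
        return sql, tuple(params)

    sql, params = build_measurements_query(
        "ZT163485000441C1369", "ac_power", tmin=datetime(2017, 8, 2, 19)
    )
    # Inside create_df_from_query this would become:
    # cursor.execute(sql, params)

The resolved sensor name could be bound the same way as the site and measurement name.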