diff --git a/fink_broker/hbaseUtils.py b/fink_broker/hbaseUtils.py
index e3a20b81..3879910a 100644
--- a/fink_broker/hbaseUtils.py
+++ b/fink_broker/hbaseUtils.py
@@ -44,7 +44,7 @@ def load_fink_cols():
     --------
     >>> fink_cols, fink_nested_cols = load_fink_cols()
     >>> print(len(fink_cols))
-    17
+    19
 
     >>> print(len(fink_nested_cols))
     18
@@ -66,7 +66,9 @@ def load_fink_cols():
         'tracklet': {'type': 'string', 'default': ''},
         'vsx': {'type': 'string', 'default': 'Unknown'},
         'x3hsp': {'type': 'string', 'default': 'Unknown'},
-        'x4lac': {'type': 'string', 'default': 'Unknown'}
+        'x4lac': {'type': 'string', 'default': 'Unknown'},
+        'lc_features_g': {'type': 'string', 'default': '[]'},
+        'lc_features_r': {'type': 'string', 'default': '[]'},
     }
 
     fink_nested_cols = {}
@@ -97,7 +99,7 @@ def load_all_cols():
     >>> root_level, candidates, images, fink_cols, fink_nested_cols = load_all_cols()
     >>> out = {**root_level, **candidates, **images, **fink_cols, **fink_nested_cols}
     >>> print(len(out))
-    146
+    148
     """
     fink_cols, fink_nested_cols = load_fink_cols()
 
@@ -308,7 +310,7 @@ def load_ztf_index_cols():
     --------
     >>> out = load_ztf_index_cols()
     >>> print(len(out))
-    70
+    72
     """
     common = [
         'objectId', 'candid', 'publisher', 'rcid', 'chipsf', 'distnr',
@@ -326,7 +328,7 @@ def load_ztf_index_cols():
         'vsx', 'snn_snia_vs_nonia', 'snn_sn_vs_all', 'rf_snia_vs_nonia',
         'classtar', 'drb', 'ndethist', 'rf_kn_vs_nonkn', 'tracklet',
-        'anomaly_score', 'x4lac', 'x3hsp'
+        'anomaly_score', 'x4lac', 'x3hsp', 'lc_features_g', 'lc_features_r'
     ]
 
     mangrove = [
@@ -801,8 +803,11 @@ def push_full_df_to_hbase(df, row_key_name, table_name, catalog_name):
     catalog_name: str
         Name for the JSON catalog (saved locally for inspection)
     """
+    # Cast feature columns
+    df_casted = cast_features(df)
+
     # Check all columns exist, fill if necessary, and cast data
-    df_flat, cols_i, cols_d, cols_b = bring_to_current_schema(df)
+    df_flat, cols_i, cols_d, cols_b = bring_to_current_schema(df_casted)
 
     # Assign each column to a specific column family
     # This is independent from the final structure
@@ -832,6 +837,42 @@ def push_full_df_to_hbase(df, row_key_name, table_name, catalog_name):
         catfolder=catalog_name
     )
 
+def cast_features(df):
+    """ Cast feature columns into a string representation of an array
+
+    Parameters
+    ----------
+    df: Spark DataFrame
+        DataFrame of alerts
+
+    Returns
+    ----------
+    df: Spark DataFrame
+
+    Examples
+    ----------
+    # Read alert from the raw database
+    >>> df = spark.read.format("parquet").load(ztf_alert_sample_scidatabase)
+
+    >>> df = cast_features(df)
+    >>> assert 'lc_features_g' in df.columns, df.columns
+
+    >>> a_row = df.select('lc_features_g').limit(1).toPandas().values[0][0]
+    >>> assert isinstance(a_row, str), a_row
+    """
+    if ('lc_features_g' in df.columns) and ('lc_features_r' in df.columns):
+        df = df.withColumn(
+            'lc_features_g',
+            F.array('lc_features_g.*').astype('string')
+        )
+
+        df = df.withColumn(
+            'lc_features_r',
+            F.array('lc_features_r.*').astype('string')
+        )
+
+    return df
+
 
 if __name__ == "__main__":
     """ Execute the test suite with SparkSession initialised """
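
For reviewers unfamiliar with Spark's star expansion: in cast_features, F.array('lc_features_g.*') expands every field of the lc_features_g struct column into an array, and .astype('string') serializes that array into a single string, which is what HBase stores in one cell (and why the new fink_cols defaults are the string '[]'). Below is a minimal standalone sketch of the same cast; the two-field schema and the values are invented for illustration, only the lc_features_g/lc_features_r column names come from the patch:

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # Toy alerts whose feature columns are structs (the field names are
    # invented; the real lc_features_* structs carry many more fields)
    df = spark.createDataFrame(
        [((1.0, 2.5), (0.3, 4.0)), ((0.2, 0.9), (1.1, 0.7))],
        "lc_features_g struct<amplitude:double,mean:double>, "
        "lc_features_r struct<amplitude:double,mean:double>",
    )

    # Same expression as in cast_features: star-expand the struct into
    # an array, then cast the array to its string representation
    out = df.withColumn(
        "lc_features_g", F.array("lc_features_g.*").astype("string")
    )

    out.select("lc_features_g").show(truncate=False)
    # With Spark 3.x this prints rows such as:
    #   [1.0, 2.5]
    # i.e. a plain string, matching the '[]' default declared in fink_cols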