From 81ec25a3092ffb238b9ae37e69ba6f9388e1b1a9 Mon Sep 17 00:00:00 2001
From: Theodore Vasiloudis
Date: Fri, 8 Nov 2024 19:57:54 +0000
Subject: [PATCH] [GSProcessing] Add transformation saving and re-applying for
 numerical transforms.

---
 .../distributed/example.rst                   |   6 +-
 .../dist_bucket_numerical_transformation.py   |   2 +-
 .../dist_numerical_transformation.py          | 446 +++++++++++++++---
 .../tests/test_dist_executor.py               |  11 +-
 .../test_dist_numerical_transformation.py     | 129 ++++-
 5 files changed, 511 insertions(+), 83 deletions(-)

diff --git a/docs/source/cli/graph-construction/distributed/example.rst b/docs/source/cli/graph-construction/distributed/example.rst
index 14d2d0e984..b215fd8370 100644
--- a/docs/source/cli/graph-construction/distributed/example.rst
+++ b/docs/source/cli/graph-construction/distributed/example.rst
@@ -259,7 +259,9 @@ the graph structure, features, and labels. In more detail:
   GSProcessing will use the transformation values listed here instead
   of creating new ones, ensuring that models trained with the original
   data can still be used in the newly transformed data. Currently only
-  categorical transformations can be re-applied.
+  categorical and numerical transformations can be re-applied. Note that
+  the Rank-Gauss transformation does not support re-application; it only
+  works for transductive tasks.
 
 * ``updated_row_counts_metadata.json``:
   This file is meant to be used as the input configuration for the
   distributed partitioning pipeline. ``gs-repartition`` produces
@@ -313,7 +315,7 @@ you can use the following command to run the partition job locally:
         --num-parts 2 \
         --dgl-tool-path ./dgl/tools \
         --partition-algorithm random \
-        --ip-config ip_list.txt
+        --ip-config ip_list.txt
 
 The command above will first do graph partitioning to determine the ownership for each partition and save the results.
 Then it will do data dispatching to physically assign the partitions to graph data and dispatch them to each machine.
 
diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_bucket_numerical_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_bucket_numerical_transformation.py
index 8848062437..f53ad78786 100644
--- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_bucket_numerical_transformation.py
+++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_bucket_numerical_transformation.py
@@ -67,7 +67,7 @@ def get_transformation_name() -> str:
         return "DistBucketNumericalTransformation"
 
     def apply(self, input_df: DataFrame) -> DataFrame:
-        imputed_df = apply_imputation(self.cols, self.shared_imputation, input_df)
+        imputed_df = apply_imputation(self.cols, self.shared_imputation, input_df).imputed_df
         # TODO: Make range optional by getting min/max from data.
min_val, max_val = self.range diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_numerical_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_numerical_transformation.py index 29f3a9519e..bf52b6caa5 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_numerical_transformation.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_numerical_transformation.py @@ -15,13 +15,20 @@ """ import logging -from typing import Optional, Sequence import uuid +from dataclasses import dataclass +from typing import Any, Optional, Sequence from pyspark.sql import DataFrame from pyspark.sql import functions as F from pyspark.sql.types import ArrayType -from pyspark.ml.feature import MinMaxScaler, Imputer, VectorAssembler, ElementwiseProduct +from pyspark.ml.feature import ( + MinMaxScaler, + MinMaxScalerModel, + Imputer, + VectorAssembler, + ElementwiseProduct, +) from pyspark.ml.linalg import DenseVector from pyspark.ml.stat import Summarizer from pyspark.ml import Pipeline @@ -44,8 +51,42 @@ from ..spark_utils import rename_multiple_cols -def apply_imputation(cols: Sequence[str], shared_imputation: str, input_df: DataFrame) -> DataFrame: - """Applies a single imputation to input dataframe, individually to each of the columns +@dataclass +class ImputationResult: + """Container class to store the results of imputation. + + Parameters + ---------- + imputed_df: DataFrame + The imputed DataFrame. + impute_representation: dict[str, dict] + A representation of the imputation. + """ + + imputed_df: DataFrame + impute_representation: dict[str, Any] + + +@dataclass +class NormalizationResult: + """Container class to store the results of normalization. + + Parameters + ---------- + scaled_df: DataFrame + The normalized DataFrame. + normalization_representation: dict[str, dict] + A representation of the normalization. + """ + + scaled_df: DataFrame + normalization_representation: dict[str, Any] + + +def apply_imputation( + cols: Sequence[str], shared_imputation: str, input_df: DataFrame +) -> ImputationResult: + """Applies a single imputation to input DataFrame, individually to each of the columns provided in the cols argument. Parameters @@ -60,8 +101,9 @@ def apply_imputation(cols: Sequence[str], shared_imputation: str, input_df: Data Returns ------- - DataFrame - The imputed DataFrame. + ImputationResult + A dataclass containing the imputed DataFrame in the ``imputed_df`` element + and a dict representation of the imputation in the ``impute_representation`` element. 
""" # "mode" is another way to say most frequent, used by SparkML valid_inner_imputers = VALID_IMPUTERS + ["mode"] @@ -70,20 +112,34 @@ def apply_imputation(cols: Sequence[str], shared_imputation: str, input_df: Data f"Unsupported imputation strategy requested: {shared_imputation}, the supported " f"strategies are : {valid_inner_imputers}" ) + imputer_model = None if shared_imputation == "most_frequent": shared_imputation = "mode" + if shared_imputation == "none": imputed_df = input_df else: imputed_col_names = [col_name + "_imputed" for col_name in cols] imputer = Imputer(strategy=shared_imputation, inputCols=cols, outputCols=imputed_col_names) - model = imputer.fit(input_df) + imputer_model = imputer.fit(input_df) # Create transformed columns and drop originals, then rename transformed cols to original - input_df = model.transform(input_df).drop(*cols) + input_df = imputer_model.transform(input_df).drop(*cols) imputed_df, _ = rename_multiple_cols(input_df, imputed_col_names, cols) - return imputed_df + imputed_val_dict = {} + if imputer_model: + # Structure: {col_name[str]: imputed_val[float]} + imputed_val_dict = imputer_model.surrogateDF.collect()[0].asDict() + + impute_representation = { + "imputed_val_dict": imputed_val_dict, + "imputer_name": shared_imputation, + } + + imputed_df = imputed_df.select(*cols) + + return ImputationResult(imputed_df, impute_representation) def apply_norm( @@ -92,7 +148,7 @@ def apply_norm( imputed_df: DataFrame, out_dtype: str = TYPE_FLOAT32, epsilon: float = 1e-6, -) -> DataFrame: +) -> NormalizationResult: """Applies a single normalizer to the imputed dataframe, individually to each of the columns provided in the cols argument. @@ -114,8 +170,10 @@ def apply_norm( Returns ------- - DataFrame - The normalized DataFrame with only the columns listed in `cols` retained. + NormalizationResult + A dataclass containing the normalized DataFrame with only the + columns listed in ``cols`` retained in the ``scaled_df`` element, + and a dict representation of the transformation in the ``normalization_representation`` Raises ------ @@ -125,62 +183,32 @@ def apply_norm( ValueError If unsupported feature output dtype is provided. """ - other_cols = list(set(imputed_df.columns).difference(cols)) - - def single_vec_to_float(vec): - return float(vec[0]) - - # Use the map to get the corresponding data type object, or raise an error if not found - if out_dtype in DTYPE_MAP: - vec_udf = F.udf(single_vec_to_float, DTYPE_MAP[out_dtype]) - else: - raise ValueError("Unsupported feature output dtype") - assert shared_norm in VALID_NORMALIZERS, ( f"Unsupported normalization requested: {shared_norm}, the supported " f"strategies are : {VALID_NORMALIZERS}" ) + norm_representation: dict[str, Any] = { + "norm_name": shared_norm, + } if shared_norm == "none": # Save the time and efficiency for not casting the type # when not doing any normalization scaled_df = imputed_df + norm_representation["norm_reconstruction"] = {} elif shared_norm == "min-max": - # Because the scalers expect Vector input, we need to use VectorAssembler on each, - # creating one (scaled) vector per normalizer type - # TODO: See if it's possible to have all features under one assembler and scaler, - # speeding up the process. Then do the "disentaglement" on the caller side. 
- assemblers = [VectorAssembler(inputCols=[col], outputCol=col + "_vec") for col in cols] - scalers = [MinMaxScaler(inputCol=col + "_vec", outputCol=col + "_scaled") for col in cols] - - vector_cols = [col + "_vec" for col in cols] - scaled_cols = [col + "_scaled" for col in cols] - pipeline = Pipeline(stages=assemblers + scalers) - scaler_model = pipeline.fit(imputed_df) - scaled_df = scaler_model.transform(imputed_df).drop(*vector_cols).drop(*cols) - - scaled_df = scaled_df.select( - *[ - (vec_udf(F.col(scaled_col_name))).alias(orig_col) - for scaled_col_name, orig_col in zip(scaled_cols, cols) - ] - + other_cols + scaled_df, norm_reconstruction = _apply_min_max_transform( + imputed_df, + cols, + out_dtype, ) + norm_representation["norm_reconstruction"] = norm_reconstruction elif shared_norm == "standard": - col_sums = imputed_df.agg({col: "sum" for col in cols}).collect()[0].asDict() - # TODO: See if it's possible to exclude NaN values from the sum - for _, val in col_sums.items(): - if np.isinf(val) or np.isnan(val): - raise RuntimeError( - "Missing values found in the data, cannot apply " - "normalization. Use an imputer in the transformation." - ) - scaled_df = imputed_df.select( - [(F.col(c) / col_sums[f"sum({c})"]).cast(DTYPE_MAP[out_dtype]).alias(c) for c in cols] - + other_cols - ) + scaled_df, norm_reconstruction = _apply_standard_transform(imputed_df, cols, out_dtype) + norm_representation["norm_reconstruction"] = norm_reconstruction elif shared_norm == "rank-gauss": - assert len(cols) == 1, "Rank-Guass numerical transformation only supports single column" + assert len(cols) == 1, "Rank-Gauss numerical transformation only supports single column" + norm_representation["norm_reconstruction"] = {} column_name = cols[0] select_df = imputed_df.select(column_name) # original id is the original order for the input data frame, @@ -209,8 +237,166 @@ def gauss_transform(rank: pd.Series) -> pd.Series: scaled_df = normalized_df.orderBy(original_order_col).drop( value_rank_col, original_order_col ) + else: + raise ValueError(f"Unsupported normalization requested: {shared_norm}") + + return NormalizationResult(scaled_df, norm_representation) - return scaled_df + +def _apply_standard_transform( + input_df: DataFrame, + cols: list[str], + out_dtype: str, + col_sums: Optional[dict[str, float]] = None, +) -> tuple[DataFrame, dict]: + """Applies standard scaling to the input DataFrame, individually to each of the columns. + + Parameters + ---------- + input_df : DataFrame + Input data to transform + cols : list[str] + List of column names to apply standard normalization to. + out_dtype : str + Type of output data. + col_sums : Optional[dict[str, float]], optional + Pre-calculated sums per column, by default None + + Returns + ------- + tuple[DataFrame, dict] + The transformed dataframe and the representation of the standard transform as dict. + + Raises + ------ + RuntimeError + When there's missing values in the input DF. + """ + if col_sums is None: + col_sums = input_df.agg({col: "sum" for col in cols}).collect()[0].asDict() + # TODO: See if it's possible to exclude NaN values from the sum + for _, val in col_sums.items(): + if np.isinf(val) or np.isnan(val): + raise RuntimeError( + "Missing values found in the data, cannot apply " + "normalization. Use an imputer in the transformation." 
+            )
+    scaled_df = input_df.select(
+        [(F.col(c) / col_sums[f"sum({c})"]).cast(DTYPE_MAP[out_dtype]).alias(c) for c in cols]
+    )
+
+    norm_reconstruction = {"col_sums": col_sums}
+
+    return scaled_df, norm_reconstruction
+
+
+def _apply_min_max_transform(
+    input_df: DataFrame,
+    cols: list[str],
+    out_dtype: str,
+    original_min_vals: Optional[list[float]] = None,
+    original_max_vals: Optional[list[float]] = None,
+) -> tuple[DataFrame, dict]:
+    """Applies min-max normalization to the input, individually to each of the columns.
+
+    Parameters
+    ----------
+    input_df : DataFrame
+        The input DF to be transformed
+    cols : list[str]
+        List of column names to apply min-max normalization to.
+    out_dtype : str
+        Numerical type of output data.
+    original_min_vals : Optional[list[float]], optional
+        Pre-calculated minimum values for each column, by default None
+    original_max_vals : Optional[list[float]], optional
+        Pre-calculated maximum values for each column, by default None
+
+    Returns
+    -------
+    tuple[DataFrame, dict]
+        The transformed DataFrame and the representation of the min-max transform as dict.
+    """
+
+    # Validate the requested output dtype, raise an error if it is not supported
+    if out_dtype not in DTYPE_MAP:
+        raise ValueError("Unsupported feature output dtype")
+
+    # Because the scalers expect Vector input, we need to use VectorAssembler on each,
+    # creating one (scaled) vector per normalizer type
+    # TODO: See if it's possible to have all features under one assembler and scaler,
+    # speeding up the process. Then do the "disentanglement" on the caller side.
+    assemblers = [VectorAssembler(inputCols=[col], outputCol=col + "_vec") for col in cols]
+    scalers = [MinMaxScaler(inputCol=col + "_vec", outputCol=col + "_scaled") for col in cols]
+
+    vector_cols = [col + "_vec" for col in cols]
+    scaled_cols = [col + "_scaled" for col in cols]
+
+    pipeline = Pipeline(stages=assemblers + scalers)
+    # If a transformation representation exists, use that to fit the pipeline,
+    # otherwise we just use the input DF
+    if original_max_vals and original_min_vals:
+        # Create a DF with just the min and the max value per column,
+        # and use that to fit each scaler in the pipeline.
+        # The dummy DF will have two rows and len(cols) columns.
+        # We use the first row to get the min values and the second row to get the max values
+        min_exprs = [F.lit(val).alias(col) for val, col in zip(original_min_vals, cols)]
+        max_exprs = [F.lit(val).alias(col) for val, col in zip(original_max_vals, cols)]
+
+        # We add a zipWithIndex column to distinguish the first and second rows.
+        # We use a list comprehension with F.when() to set the values for each column.
+        # For the first row (where the row number is 0), we use the values from original_min_vals,
+        # and for the second row, we use the values from original_max_vals.
+        # Example: minvals=[0, 4], maxvals=[100, 256], cols=["col1", "col2"] becomes
+        # |"col1"|"col2"|
+        # |0     |4     |
+        # |100   |256   |
+        dummy_df = (
+            input_df.limit(2)
+            .rdd.zipWithIndex()
+            .toDF()
+            .select(
+                *[
+                    F.when(F.col("_2") == 0, min_expr).otherwise(max_expr).alias(col_name)
+                    for min_expr, max_expr, col_name in zip(min_exprs, max_exprs, cols)
+                ]
+            )
+        )
+
+        # Fit a pipeline on just the dummy DF
+        scaler_pipeline = pipeline.fit(dummy_df)
+    else:
+        # Fit a pipeline on the entire input DF
+        scaler_pipeline = pipeline.fit(input_df)
+
+    # Transform the input DF
+    scaled_df = scaler_pipeline.transform(input_df).drop(*vector_cols).drop(*cols)
+
+    # vector_to_array(F.col(scaled_col_name), dtype=out_dtype)[0] extracts the first (and only)
+    # element of each scaled vector column as a scalar of the requested output dtype
+    scaled_df = scaled_df.select(
+        *[
+            (vector_to_array(F.col(scaled_col_name), dtype=out_dtype)[0].alias(orig_col))
+            for scaled_col_name, orig_col in zip(scaled_cols, cols)
+        ]
+    )
+    # Spark pipelines arrange transformations in a list, ordered by stage.
+    # So here we have first all the VectorAssemblers for each feature, then all the
+    # MinMaxScalerModels for each feature. So we skip the first len(cols) stages to
+    # get just the MinMaxScalerModels
+    min_max_models: list[MinMaxScalerModel] = scaler_pipeline.stages[len(cols) :]
+    norm_reconstruction = {
+        "originalMinValues": [
+            min_max_model.originalMin.toArray()[0] for min_max_model in min_max_models
+        ],
+        "originalMaxValues": [
+            min_max_model.originalMax.toArray()[0] for min_max_model in min_max_models
+        ],
+    }
+
+    return scaled_df, norm_reconstruction
 
 
 class DistNumericalTransformation(DistributedTransformation):
@@ -240,8 +426,11 @@ def __init__(
         imputer: str = "none",
         out_dtype: str = TYPE_FLOAT32,
         epsilon: float = 1e-6,
+        json_representation: Optional[dict] = None,
     ) -> None:
-        super().__init__(cols)
+        if not json_representation:
+            json_representation = {}
+        super().__init__(cols, json_representation=json_representation)
         self.cols = cols
         self.shared_norm = normalizer
         self.epsilon = epsilon
@@ -254,14 +443,151 @@ def apply(self, input_df: DataFrame) -> DataFrame:
             "Applying normalizer: %s, imputation: %s", self.shared_norm, self.shared_imputation
         )
 
-        imputed_df = apply_imputation(self.cols, self.shared_imputation, input_df)
-        scaled_df = apply_norm(
+        imputation_result = apply_imputation(self.cols, self.shared_imputation, input_df)
+        imputed_df, impute_representation = (
+            imputation_result.imputed_df,
+            imputation_result.impute_representation,
+        )
+
+        norm_result = apply_norm(
             self.cols, self.shared_norm, imputed_df, self.out_dtype, self.epsilon
         )
+        scaled_df, norm_representation = (
+            norm_result.scaled_df,
+            norm_result.normalization_representation,
+        )
+
+        # see get_json_representation() docstring for structure
+        self.json_representation = {
+            "cols": self.cols,
+            "imputer_model": impute_representation,
+            "normalizer_model": norm_representation,
+            "out_dtype": self.out_dtype,
+            "transformation_name": self.get_transformation_name(),
+        }
 
         # TODO: Figure out why the transformation is producing Double values, and switch to float
         return scaled_df
 
+    def get_json_representation(self) -> dict:
+        """Representation of numerical transformation for one or more columns.
+
+        Returns
+        -------
+        dict
+            Structure:
+            cols: list[str]
+                The list of columns the transformation is applied to. Order matters.
+
+            imputer_model: dict[str, Any]
+                A dict representation of the imputation applied.
+ + Structure: + imputed_val_dict: dict[str, float] + The imputed values for each column, {col_name: imputation_val}. + Empty if no imputation was applied. + imputer_name: str + The name of imputer used. + + normalizer_model: dict[str, Any] + A dict representation of the normalization applied. + + Structure: + norm_name: str + The name of normalizer used. + norm_reconstruction: dict[str, Any] + The reconstruction information for the normalizer. Empty if no normalization + was applied. Inner structure depends on normalizer. + + Structure for MinMaxScaler: + originalMinValues: list[float] + The original minimum values for each column, in the order of the cols key. + originalMaxValues: list[float] + The original maximum values for each column, in the order of the cols key. + + Structure for StandardScaler: + col_sums: dict[str, float] + The sum of each column. + + out_dtype: str + The output feature dtype, can take the values 'float32' and 'float64'. + + transformation_name: str + Will be DistNumericalTransformation. + """ + return self.json_representation + + def apply_precomputed_transformation(self, input_df: DataFrame) -> DataFrame: + """Applies a numerical transformation using pre-computed representation. + + Parameters + ---------- + input_df : DataFrame + Input DataFrame to apply the transformation to. + + Returns + ------- + DataFrame + The input DataFrame, modified according to the pre-computed transformation values. + """ + assert self.json_representation, ( + "No precomputed transformation found. Please run `apply()` " + "first or set self.json_representation." + ) + + cols = self.json_representation["cols"] + impute_representation = self.json_representation["imputer_model"] + norm_representation = self.json_representation["normalizer_model"] + out_dtype = self.json_representation.get("out_dtype", TYPE_FLOAT32) + + # First reapply pre-computed imputation if needed + if impute_representation["imputer_name"] == "none": + imputed_df = input_df + else: + imputed_vals = impute_representation["imputed_val_dict"] + shared_imputation = impute_representation["imputer_name"] + imputed_col_names = [col_name + "_imputed" for col_name in cols] + + # Create a DF with a single value per column name, used to fit an imputer + single_val_df = input_df.limit(1).select( + [F.lit(imputed_vals[col_name]).alias(col_name) for col_name in cols] + ) + + imputer = Imputer( + strategy=shared_imputation, inputCols=cols, outputCols=imputed_col_names + ) + imputer_model = imputer.fit(single_val_df) + + # Create transformed columns and drop originals, + # then rename transformed cols to original + input_df = imputer_model.transform(input_df).drop(*cols) + imputed_df, _ = rename_multiple_cols(input_df, imputed_col_names, cols) + imputed_df = imputed_df.select(*cols) + + # Second, re-apply normalization if needed + norm_name = norm_representation["norm_name"] + norm_reconstruction = norm_representation["norm_reconstruction"] + if norm_name == "none": + scaled_df = imputed_df + elif norm_name == "min-max": + scaled_df, _ = _apply_min_max_transform( + imputed_df, + cols, + out_dtype, + norm_reconstruction["originalMinValues"], + norm_reconstruction["originalMaxValues"], + ) + elif norm_name == "standard": + scaled_df, _ = _apply_standard_transform( + imputed_df, cols, out_dtype, norm_reconstruction["col_sums"] + ) + elif norm_name == "rank-gauss": + raise ValueError("Rank-Gauss transformation does not support re-applying.") + else: + raise ValueError(f"Unknown normalizer: {norm_name=}") + + return scaled_df + 
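+    # Illustrative example of the representation produced by apply() and consumed by
+    # apply_precomputed_transformation(), following the structure documented in
+    # get_json_representation(). The column names and values below are hypothetical,
+    # assuming two columns transformed with mean imputation and min-max normalization:
+    # {
+    #     "cols": ["salary", "age"],
+    #     "imputer_model": {
+    #         "imputed_val_dict": {"salary": 27500.0, "age": 40.0},
+    #         "imputer_name": "mean",
+    #     },
+    #     "normalizer_model": {
+    #         "norm_name": "min-max",
+    #         "norm_reconstruction": {
+    #             "originalMinValues": [10000.0, 20.0],
+    #             "originalMaxValues": [40000.0, 60.0],
+    #         },
+    #     },
+    #     "out_dtype": "float32",
+    #     "transformation_name": "DistNumericalTransformation",
+    # }
+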
@staticmethod def get_transformation_name() -> str: return "DistNumericalTransformation" @@ -458,7 +784,7 @@ def vector_df_has_nan(vector_df: DataFrame, vector_col: str) -> bool: # and call the base numerical transformer imputed_df = apply_imputation( split_col_df.columns, self.shared_imputation, split_col_df - ) + ).imputed_df # Assemble the separate columns back into a single vector column assembler = VectorAssembler( diff --git a/graphstorm-processing/tests/test_dist_executor.py b/graphstorm-processing/tests/test_dist_executor.py index 3290126bb0..75faf34c0d 100644 --- a/graphstorm-processing/tests/test_dist_executor.py +++ b/graphstorm-processing/tests/test_dist_executor.py @@ -110,8 +110,15 @@ def test_dist_executor_run_with_precomputed( with open(os.path.join(tempdir, TRANSFORMATIONS_FILENAME), "r", encoding="utf-8") as f: reapplied_transformations = json.load(f) - # There should be no difference between original and re-applied transformation dicts - assert reapplied_transformations == original_transformations + # There should be no difference between original and + # pre-existing, pre-applied transformation dicts + node_feature_transforms = original_transformations["node_features"] + for node_type, node_type_transforms in node_feature_transforms.items(): + for feature_name, feature_transforms in node_type_transforms.items(): + assert ( + feature_transforms + == reapplied_transformations["node_features"][node_type][feature_name] + ) # TODO: Verify other metadata files that verify_integ_test_output doesn't check for diff --git a/graphstorm-processing/tests/test_dist_numerical_transformation.py b/graphstorm-processing/tests/test_dist_numerical_transformation.py index aaae7c41d6..74324f395e 100644 --- a/graphstorm-processing/tests/test_dist_numerical_transformation.py +++ b/graphstorm-processing/tests/test_dist_numerical_transformation.py @@ -31,6 +31,7 @@ LongType, ) from scipy.special import erfinv # pylint: disable=no-name-in-module +from pandas.testing import assert_frame_equal from graphstorm_processing.data_transformations.dist_transformations import ( DistNumericalTransformation, @@ -81,36 +82,40 @@ def test_numerical_transformation_with_mode_imputer(input_df: DataFrame): transformed_df = dist_numerical_transformation.apply(input_df) - transformed_rows = transformed_df.collect() + transformed_pd_df = transformed_df.toPandas() - for row in transformed_rows: - if row["name"] == "mark": - assert row["salary"] == 10000 - elif row["name"] == "john": - assert row["age"] == 40 - else: - assert row["salary"] in {10000, 20000, 40000} - assert row["age"] in {20, 40, 60} + expected_pd_df = pd.DataFrame( + { + "salary": [10000, 10000, 20000, 10000, 40000], + "age": [40, 40, 20, 60, 40], + } + ) + + assert_frame_equal(transformed_pd_df, expected_pd_df) def test_numerical_transformation_with_minmax_scaler(input_df: DataFrame): """Test numerical min-max normalizer""" no_na_df = input_df.na.fill(0) dist_numerical_transformation = DistNumericalTransformation( - ["age", "salary"], imputer="none", normalizer="min-max" + ["age", "salary"], + imputer="none", + normalizer="min-max", + out_dtype="float64", ) transformed_df = dist_numerical_transformation.apply(no_na_df) - transformed_rows = transformed_df.collect() + transformed_pd_df = transformed_df.toPandas() - for row in transformed_rows: - if row["name"] == "kate": - assert row["salary"] == 1.0 - elif row["name"] == "mark": - assert row["salary"] == 0.0 - else: - assert row["salary"] < 1.0 and row["salary"] > 0.0 + expected_pd_df = pd.DataFrame( + 
{ + "age": [0.666667, 0.0, 0.333333, 1.0, 0.666667], + "salary": [0.0, 0.25, 0.5, 0.25, 1.0], + } + ) + + assert_frame_equal(transformed_pd_df, expected_pd_df, rtol=0.001) def test_numerical_transformation_without_transformation(input_df: DataFrame, check_df_schema): @@ -453,3 +458,91 @@ def test_rank_gauss_reshuffling(spark: SparkSession, check_df_schema, epsilon): assert_almost_equal( [row["rand"]], expected_vals[i, :], decimal=4, err_msg=f"Row {i} is not equal" ) + + +def test_json_representation(input_df: DataFrame, check_df_schema): + """Test that the generated representation is correct""" + dist_numerical_transformation = DistNumericalTransformation( + ["salary", "age"], imputer="mean", normalizer="min-max" + ) + transformed_df = dist_numerical_transformation.apply(input_df) + json_rep = dist_numerical_transformation.get_json_representation() + + assert "cols" in json_rep + assert "imputer_model" in json_rep + assert "normalizer_model" in json_rep + assert "out_dtype" in json_rep + assert json_rep["cols"] == ["salary", "age"] + assert json_rep["imputer_model"]["imputer_name"] == "mean" + assert json_rep["normalizer_model"]["norm_name"] == "min-max" + + check_df_schema(transformed_df) + + +def test_precomputed_transformation(input_df: DataFrame, check_df_schema): + """Test that the precomputed transformation produced works as intended""" + # First, apply the transformation and get the JSON representation + dist_numerical_transformation = DistNumericalTransformation( + ["salary", "age"], imputer="mean", normalizer="min-max" + ) + original_df = dist_numerical_transformation.apply(input_df) + json_rep = dist_numerical_transformation.get_json_representation() + + # Now, create a new transformation with the precomputed values + precomputed_transformation = DistNumericalTransformation( + ["salary", "age"], json_representation=json_rep + ) + precomputed_df = precomputed_transformation.apply_precomputed_transformation(input_df) + + check_df_schema(precomputed_df) + + # Compare the results + assert_frame_equal(original_df.toPandas(), precomputed_df.toPandas()) + + +@pytest.mark.parametrize("imputer", ["mean", "median", "most_frequent"]) +def test_precomputed_imputation(imputer, input_df: DataFrame, check_df_schema): + """Test various precomputed imputations""" + dist_numerical_transformation = DistNumericalTransformation( + ["salary", "age"], imputer=imputer, normalizer="none" + ) + original_df = dist_numerical_transformation.apply(input_df) + json_rep = dist_numerical_transformation.get_json_representation() + + precomputed_transformation = DistNumericalTransformation( + ["salary", "age"], json_representation=json_rep + ) + precomputed_df = precomputed_transformation.apply_precomputed_transformation(input_df) + + check_df_schema(precomputed_df) + + assert_frame_equal(original_df.toPandas(), precomputed_df.toPandas()) + + +@pytest.mark.parametrize("normalizer", ["min-max", "standard"]) +def test_precomputed_normalization(normalizer, input_df: DataFrame, check_df_schema): + """Test various precomputed norms""" + input_df = input_df.na.fill(0) + dist_numerical_transformation = DistNumericalTransformation( + ["salary", "age"], imputer="none", normalizer=normalizer + ) + original_df = dist_numerical_transformation.apply(input_df) + json_rep = dist_numerical_transformation.get_json_representation() + + precomputed_transformation = DistNumericalTransformation( + ["salary", "age"], json_representation=json_rep + ) + precomputed_df = 
precomputed_transformation.apply_precomputed_transformation(input_df) + + check_df_schema(precomputed_df) + + assert_frame_equal(original_df.toPandas(), precomputed_df.toPandas()) + + +def test_precomputed_transformation_without_json(input_df: DataFrame): + """Test trying to re-apply transformation without a representation""" + dist_numerical_transformation = DistNumericalTransformation( + ["salary", "age"], imputer="mean", normalizer="min-max" + ) + with pytest.raises(AssertionError): + dist_numerical_transformation.apply_precomputed_transformation(input_df)
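
A minimal sketch of the save/re-apply flow introduced by this patch, for illustration only. The
DistNumericalTransformation calls mirror the tests above; the JSON file name and the train_df and
new_df DataFrames are hypothetical stand-ins for data produced elsewhere in the pipeline.

    import json

    from graphstorm_processing.data_transformations.dist_transformations import (
        DistNumericalTransformation,
    )

    # train_df and new_df are hypothetical Spark DataFrames with "salary" and "age" columns.
    # Fit the transformation on the original data and keep its JSON representation.
    transformation = DistNumericalTransformation(
        ["salary", "age"], imputer="mean", normalizer="min-max"
    )
    transformed_train_df = transformation.apply(train_df)
    with open("precomputed_transformations.json", "w", encoding="utf-8") as f:
        json.dump(transformation.get_json_representation(), f)

    # Later, re-apply the same transformation to new data without re-fitting.
    with open("precomputed_transformations.json", "r", encoding="utf-8") as f:
        saved_representation = json.load(f)
    reapply_transformation = DistNumericalTransformation(
        ["salary", "age"], json_representation=saved_representation
    )
    transformed_new_df = reapply_transformation.apply_precomputed_transformation(new_df)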