Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Feb 16, 2024
1 parent 62f9118 commit 7b7de6e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging
from typing import Optional, Sequence
import uuid
import numpy

from pyspark.sql import DataFrame
from pyspark.sql import functions as F
Expand Down Expand Up @@ -131,6 +132,8 @@ def single_vec_to_float(vec):
)

if shared_norm == "none":
    # Skip the dtype cast to save time when no normalization
    # is being applied
    scaled_df = imputed_df
elif shared_norm == "min-max":
# Because the scalers expect Vector input, we need to use VectorAssembler on each,
Expand Down Expand Up @@ -163,7 +166,7 @@ def single_vec_to_float(vec):
"normalization. Use an imputer in the transformation."
)
scaled_df = imputed_df.select(
[(F.col(c) / col_sums[f"sum({c})"]).alias(c) for c in cols] + other_cols
[(F.col(c) / col_sums[f"sum({c})"]).cast(DTYPE_MAP[out_dtype]).alias(c) for c in cols] + other_cols
)
elif shared_norm == "rank-gauss":
assert len(cols) == 1, "Rank-Gauss numerical transformation only supports a single column"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ def test_numerical_transformation_without_transformation(input_df: DataFrame, ch
assert row["salary"] == expected_salary


@pytest.mark.parametrize("norm", ["min-max", "standard", "rank-gauss"])
@pytest.mark.parametrize("out_dtype", ["float32", "float64"])
def test_numerical_min_max_transformation_precision(spark: SparkSession, out_dtype):
def test_numerical_min_max_transformation_precision(spark: SparkSession, check_df_schema, out_dtype, norm):
"""Test numerical transformation without any transformation applied"""
# Adjust the number to be an integer
high_precision_integer = 1.2345678901234562
Expand All @@ -127,6 +128,7 @@ def test_numerical_min_max_transformation_precision(spark: SparkSession, out_dty
)

transformed_df = dist_numerical_transformation.apply(input_df)
check_df_schema(transformed_df)
column_data_type = [field.dataType for field in transformed_df.schema.fields if field.name == "age"][0]
if out_dtype == "float32":
assert isinstance(column_data_type, FloatType), f"The column 'age' is not of type FloatType."
Expand Down Expand Up @@ -335,23 +337,17 @@ def rank_gauss(feat, eps):
return erfinv(feat)


@pytest.mark.parametrize("out_dtype", ["float32", "float64"])
@pytest.mark.parametrize("epsilon", [0.0, 1e-6])
def test_rank_gauss(spark: SparkSession, check_df_schema, epsilon, out_dtype):
def test_rank_gauss(spark: SparkSession, check_df_schema, epsilon):
data = [(0.0,), (15.0,), (26.0,), (40.0,)]

input_df = spark.createDataFrame(data, schema=["age"])
rg_transformation = DistNumericalTransformation(
["age"], imputer="none", normalizer="rank-gauss", out_dtype=out_dtype, epsilon=epsilon
["age"], imputer="none", normalizer="rank-gauss", epsilon=epsilon
)

output_df = rg_transformation.apply(input_df)
check_df_schema(output_df)
column_data_type = [field.dataType for field in output_df.schema.fields if field.name == "age"][0]
if out_dtype == "float32":
assert isinstance(column_data_type, FloatType), f"The column 'age' is not of type FloatType."
elif out_dtype == "float64":
assert isinstance(column_data_type, DoubleType), f"The column 'age' is not of type DoubleType."

out_rows = output_df.collect()

Expand Down

0 comments on commit 7b7de6e

Please sign in to comment.