From 205720cc70ee63be89a45af51ce5906c3d5ae6ed Mon Sep 17 00:00:00 2001
From: JalenCato
Date: Thu, 15 Feb 2024 19:44:26 +0000
Subject: [PATCH] add output dtype for numerical

---
 .../graphstorm_processing/constants.py        |  2 ++
 .../dist_numerical_transformation.py          | 18 ++++++------------
 2 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/graphstorm-processing/graphstorm_processing/constants.py b/graphstorm-processing/graphstorm_processing/constants.py
index 4881384683..c4dddde1e6 100644
--- a/graphstorm-processing/graphstorm_processing/constants.py
+++ b/graphstorm-processing/graphstorm_processing/constants.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+from pyspark.sql.types import FloatType, DoubleType
 
 ################### Categorical Limits #######################
 MAX_CATEGORIES_PER_FEATURE = 100
@@ -45,6 +46,7 @@
 VALID_IMPUTERS = ["none", "mean", "median", "most_frequent"]
 VALID_NORMALIZERS = ["none", "min-max", "standard", "rank-gauss"]
 VALID_OUTDTYPE = ["float32", "float64"]
+DTYPE_MAP = {"float32": FloatType(), "float64": DoubleType()}
 
 ################# Bert transformations ################
 HUGGINGFACE_TRANFORM = "huggingface"
diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_numerical_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_numerical_transformation.py
index 4056b2d931..b37c85ef56 100644
--- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_numerical_transformation.py
+++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_numerical_transformation.py
@@ -32,7 +32,7 @@
 # pylint: disable = no-name-in-module
 from scipy.special import erfinv
 
-from graphstorm_processing.constants import SPECIAL_CHARACTERS, VALID_IMPUTERS, VALID_NORMALIZERS
+from graphstorm_processing.constants import SPECIAL_CHARACTERS, VALID_IMPUTERS, VALID_NORMALIZERS, DTYPE_MAP
 from .base_dist_transformation import DistributedTransformation
 from ..spark_utils import rename_multiple_cols
 
@@ -119,15 +119,9 @@ def apply_norm(
     def single_vec_to_float(vec):
         return float(vec[0])
 
-    # Define a mapping from dtype strings to Spark SQL data types
-    dtype_map = {
-        "float32": FloatType(),
-        "float64": DoubleType(),
-    }
-
     # Use the map to get the corresponding data type object, or raise an error if not found
-    if out_dtype in dtype_map:
-        vec_udf = F.udf(single_vec_to_float, dtype_map[out_dtype])
+    if out_dtype in DTYPE_MAP:
+        vec_udf = F.udf(single_vec_to_float, DTYPE_MAP[out_dtype])
     else:
         raise ValueError("Unsupported feature output dtype")
 
@@ -196,7 +190,7 @@ def gauss_transform(rank: pd.Series) -> pd.Series:
         return pd.Series(erfinv(clipped_rank))
 
     num_rows = value_rank_df.count()
-    gauss_udf = F.pandas_udf(gauss_transform, dtype_map[out_dtype])
+    gauss_udf = F.pandas_udf(gauss_transform, DTYPE_MAP[out_dtype])
     normalized_df = value_rank_df.withColumn(column_name, gauss_udf(value_rank_col))
     scaled_df = normalized_df.orderBy(original_order_col).drop(
         value_rank_col, original_order_col
@@ -360,7 +354,7 @@ def convert_multistring_to_sequence_df(
             ),  # Split along the separator
             replace_empty_with_nan,
         )
-        .cast(ArrayType(FloatType(), True))
+        .cast(ArrayType(DTYPE_MAP[self.out_dtype], True))
         .alias(self.multi_column)
     )
 
@@ -422,7 +416,7 @@ def vector_df_has_nan(vector_df: DataFrame, vector_col: str) -> bool:
     else:
         split_array_df = input_df.select(
             F.col(self.multi_column)
-            .cast(ArrayType(FloatType(), True))
+            .cast(ArrayType(DTYPE_MAP[self.out_dtype], True))
             .alias(self.multi_column)
         )
 
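
Note (not part of the patch): a minimal, self-contained sketch of how the shared
DTYPE_MAP constant drives output types after this change. The SparkSession setup,
sample data, and column name below are illustrative assumptions, not code from
this PR.

    from pyspark.sql import SparkSession, functions as F
    from pyspark.sql.types import ArrayType, FloatType, DoubleType

    # Mirrors the new constant in graphstorm_processing/constants.py
    DTYPE_MAP = {"float32": FloatType(), "float64": DoubleType()}

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([("1.5|2.5",), ("3.0|4.0",)], ["feat"])

    out_dtype = "float64"  # must be one of VALID_OUTDTYPE
    # Split the delimited string, then cast to an array whose element type is
    # selected by out_dtype, as convert_multistring_to_sequence_df does above.
    arr_df = df.select(
        F.split(F.col("feat"), r"\|")
        .cast(ArrayType(DTYPE_MAP[out_dtype], True))
        .alias("feat")
    )
    arr_df.printSchema()  # feat: array (element: double)

    # The same map also sets UDF return types, as in apply_norm:
    first_elem = F.udf(lambda vec: float(vec[0]), DTYPE_MAP[out_dtype])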