Skip to content

Commit

Permalink
add output dtype for numerical
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Feb 15, 2024
1 parent 61a3a15 commit 205720c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
2 changes: 2 additions & 0 deletions graphstorm-processing/graphstorm_processing/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from pyspark.sql.types import FloatType, DoubleType

################### Categorical Limits #######################
MAX_CATEGORIES_PER_FEATURE = 100
Expand Down Expand Up @@ -45,6 +46,7 @@
# Imputation strategies accepted for filling in missing feature values.
VALID_IMPUTERS = ["none", "mean", "median", "most_frequent"]
# Normalization strategies accepted for numerical feature transformations.
VALID_NORMALIZERS = ["none", "min-max", "standard", "rank-gauss"]
# Output dtype strings accepted for transformed numerical features.
VALID_OUTDTYPE = ["float32", "float64"]
# Maps each supported output-dtype string to its Spark SQL data type
# (keys must stay in sync with VALID_OUTDTYPE).
DTYPE_MAP = {"float32": FloatType(), "float64": DoubleType()}

################# Bert transformations ################
HUGGINGFACE_TRANFORM = "huggingface"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# pylint: disable = no-name-in-module
from scipy.special import erfinv

from graphstorm_processing.constants import SPECIAL_CHARACTERS, VALID_IMPUTERS, VALID_NORMALIZERS
from graphstorm_processing.constants import SPECIAL_CHARACTERS, VALID_IMPUTERS, VALID_NORMALIZERS, DTYPE_MAP
from .base_dist_transformation import DistributedTransformation
from ..spark_utils import rename_multiple_cols

Expand Down Expand Up @@ -119,15 +119,9 @@ def apply_norm(
def single_vec_to_float(vec):
    """Return the first element of a single-entry vector as a Python float."""
    first_entry = vec[0]
    return float(first_entry)

# Define a mapping from dtype strings to Spark SQL data types
dtype_map = {
"float32": FloatType(),
"float64": DoubleType(),
}

# Use the map to get the corresponding data type object, or raise an error if not found
if out_dtype in dtype_map:
vec_udf = F.udf(single_vec_to_float, dtype_map[out_dtype])
if out_dtype in DTYPE_MAP:
vec_udf = F.udf(single_vec_to_float, DTYPE_MAP[out_dtype])
else:
raise ValueError("Unsupported feature output dtype")

Expand Down Expand Up @@ -196,7 +190,7 @@ def gauss_transform(rank: pd.Series) -> pd.Series:
return pd.Series(erfinv(clipped_rank))

num_rows = value_rank_df.count()
gauss_udf = F.pandas_udf(gauss_transform, dtype_map[out_dtype])
gauss_udf = F.pandas_udf(gauss_transform, DTYPE_MAP[out_dtype])
normalized_df = value_rank_df.withColumn(column_name, gauss_udf(value_rank_col))
scaled_df = normalized_df.orderBy(original_order_col).drop(
value_rank_col, original_order_col
Expand Down Expand Up @@ -360,7 +354,7 @@ def convert_multistring_to_sequence_df(
), # Split along the separator
replace_empty_with_nan,
)
.cast(ArrayType(FloatType(), True))
.cast(ArrayType(DTYPE_MAP[self.out_dtype], True))
.alias(self.multi_column)
)

Expand Down Expand Up @@ -422,7 +416,7 @@ def vector_df_has_nan(vector_df: DataFrame, vector_col: str) -> bool:
else:
split_array_df = input_df.select(
F.col(self.multi_column)
.cast(ArrayType(FloatType(), True))
.cast(ArrayType(DTYPE_MAP[self.out_dtype], True))
.alias(self.multi_column)
)

Expand Down

0 comments on commit 205720c

Please sign in to comment.