Skip to content

Commit

Permalink
add config for rank-gauss
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Nov 2, 2023
1 parent 80d4762 commit 43dca3f
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
BucketNumericalFeatureConfig,
MultiNumericalFeatureConfig,
NumericalFeatureConfig,
RankGaussNumericalFeatureConfig,
)
from .data_config_base import DataStorageConfig

Expand Down Expand Up @@ -62,6 +63,8 @@ def parse_feat_config(feature_dict: Dict) -> FeatureConfig:
return MultiNumericalFeatureConfig(feature_dict)
elif transformation_name == "bucket-numerical":
return BucketNumericalFeatureConfig(feature_dict)
elif transformation_name == "rank-gauss":
return RankGaussNumericalFeatureConfig(feature_dict)
else:
raise RuntimeError(f"Unknown transformation name: '{transformation_name}'")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,49 @@ def _sanity_check(self) -> None:
assert (
isinstance(self.slide_window_size, numbers.Number) or self.slide_window_size == "none"
), f"Expect no slide window size or expect {self.slide_window_size} is a number"


class RankGaussNumericalFeatureConfig(FeatureConfig):
    """Feature configuration for rank gauss numerical features.

    The values are read from the transformation kwargs of the feature
    configuration and validated immediately on construction.

    Supported kwargs
    ----------------
    imputer: str
        A method to fill in missing values in the data. Valid values are:
        "none" (Default), "mean", "median", and "most_frequent". Missing values will be replaced
        with the respective value computed from the data.
    normalizer: str
        A normalization to apply to each column. Valid values are
        "none", "min-max", and "standard".
        The transformation applied will be:
        * "none": (Default) Don't normalize the numerical values during encoding.
        * "min-max": Normalize each value by subtracting the minimum value from it,
        and then dividing it by the difference between the maximum value and the minimum.
        * "standard": Normalize each value by dividing it by the sum of all the values.
    epsilon: float
        Epsilon for normalization used to avoid INF float during computation.
        Must be a number in the range [0, 1]. When not provided, the
        ``epsilon`` attribute holds the string "none".
    """

    def __init__(self, config: Mapping):
        super().__init__(config)
        kwargs = self._transformation_kwargs
        self.imputer = kwargs.get("imputer", "none")
        self.norm = kwargs.get("normalizer", "none")
        # The string "none" is the sentinel for "epsilon was not provided".
        self.epsilon = kwargs.get("epsilon", "none")

        self._sanity_check()

    def _sanity_check(self) -> None:
        """Validate the imputer, normalizer and epsilon settings.

        Raises
        ------
        AssertionError
            If the imputer or normalizer is not one of the supported values,
            or if epsilon is provided but is not a number in [0, 1].
        """
        super()._sanity_check()
        assert (
            self.imputer in VALID_IMPUTERS
        ), f"Unknown imputer requested, expected one of {VALID_IMPUTERS}, got {self.imputer}"
        assert (
            self.norm in VALID_NORMALIZERS
        ), f"Unknown normalizer requested, expected one of {VALID_NORMALIZERS}, got {self.norm}"

        # bool is a subclass of int (and therefore a numbers.Number), so it
        # must be excluded explicitly; otherwise `epsilon: true` would be
        # silently accepted as 1.
        epsilon_is_valid = self.epsilon == "none" or (
            isinstance(self.epsilon, numbers.Number)
            and not isinstance(self.epsilon, bool)
            and 0.0 <= self.epsilon <= 1.0
        )
        assert (
            epsilon_is_valid
        ), f"Expect epsilon {self.epsilon} be in range [0, 1], or not provided"

0 comments on commit 43dca3f

Please sign in to comment.