Commit
first commit about code structure on tokenize feature transformation
jalencato committed Jan 10, 2024
1 parent 0964886 commit 0a67217
Showing 7 changed files with 141 additions and 0 deletions.
45 changes: 45 additions & 0 deletions graphstorm-processing/graphstorm_processing/config/bert_configs.py
@@ -0,0 +1,45 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Mapping
import numbers

from graphstorm_processing.constants import VALID_BERT_MODEL
from .feature_config_base import FeatureConfig


class BertConfig(FeatureConfig):
    """Feature configuration for text features transformed with a BERT-style model.

    Supported kwargs
    ----------------
    bert_model: str
        The BERT-like model to use for tokenization; must be one of VALID_BERT_MODEL.
    max_seq_length: int
        The maximum sequence length used during tokenization; must be a positive integer.
    """

    def __init__(self, config: Mapping):
        super().__init__(config)
        self.bert_model = self._transformation_kwargs.get("bert_model", "none")
        self.max_seq_length = self._transformation_kwargs.get("max_seq_length", "none")

        self._sanity_check()

    def _sanity_check(self) -> None:
        super()._sanity_check()
        assert (
            self.bert_model in VALID_BERT_MODEL
        ), f"Unknown BERT model requested, expected one of {VALID_BERT_MODEL}, got {self.bert_model}"
        assert isinstance(self.max_seq_length, int) and self.max_seq_length > 0, \
            f"Expected max_seq_length to be an integer larger than zero, got {self.max_seq_length}."

@@ -27,6 +27,7 @@
NumericalFeatureConfig,
)
from .categorical_configs import MultiCategoricalFeatureConfig
from .bert_configs import BertConfig
from .data_config_base import DataStorageConfig


@@ -67,6 +68,8 @@ def parse_feat_config(feature_dict: Dict) -> FeatureConfig:
return FeatureConfig(feature_dict)
elif transformation_name == "multi-categorical":
return MultiCategoricalFeatureConfig(feature_dict)
elif transformation_name == "bert":
return BertConfig(feature_dict)
else:
raise RuntimeError(f"Unknown transformation name: '{transformation_name}'")
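
As a rough illustration of the new dispatch branch, a feature entry whose transformation name is "bert" would now be parsed into a BertConfig. The surrounding dictionary layout below is assumed for illustration; only the transformation name and the kwargs keys are grounded in this commit:

# Hypothetical feature entry; the overall layout is assumed, not shown in this diff.
feature_dict = {
    "column": "text",
    "transformation": {
        "name": "bert",
        "kwargs": {"bert_model": "bert-base-uncased", "max_seq_length": 128},
    },
}
feat_config = parse_feat_config(feature_dict)  # would return a BertConfig instance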

2 changes: 2 additions & 0 deletions graphstorm-processing/graphstorm_processing/constants.py
@@ -43,3 +43,5 @@
################# Numerical transformations ################
VALID_IMPUTERS = ["none", "mean", "median", "most_frequent"]
VALID_NORMALIZERS = ["none", "min-max", "standard", "rank-gauss"]
VALID_BERT_MODEL = ["bert-base-uncased", "bert", "roberta", "albert", "camembert", "ernie", "ibert",
                    "luke", "mega", "mpnet", "nezha", "qdqbert", "roc_bert"]
@@ -26,6 +26,7 @@
DistBucketNumericalTransformation,
DistCategoryTransformation,
DistMultiCategoryTransformation,
DistBertTransformation,
)


@@ -57,6 +58,8 @@ def __init__(self, feature_config: FeatureConfig):
self.transformation = DistCategoryTransformation(**default_kwargs, **args_dict)
elif feat_type == "multi-categorical":
self.transformation = DistMultiCategoryTransformation(**default_kwargs, **args_dict)
elif feat_type == "bert":
self.transformation = DistBertTransformation(**default_kwargs, **args_dict)
else:
raise NotImplementedError(
f"Feature {feat_name} has type: {feat_type} that is not supported"
@@ -13,3 +13,4 @@
DistNumericalTransformation,
)
from .dist_bucket_numerical_transformation import DistBucketNumericalTransformation
from .dist_bert_transformation import DistBertTransformation
@@ -0,0 +1,85 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from typing import Optional, Sequence
import uuid

from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType
from pyspark.ml.feature import MinMaxScaler, Imputer, VectorAssembler, ElementwiseProduct
from pyspark.ml.linalg import DenseVector
from pyspark.ml.stat import Summarizer
from pyspark.ml import Pipeline
from pyspark.ml.functions import array_to_vector, vector_to_array

import numpy as np
import pandas as pd

from .base_dist_transformation import DistributedTransformation
from ..spark_utils import rename_multiple_cols


def apply_norm(
    cols: Sequence[str], bert_norm: str, input_df: DataFrame
) -> DataFrame:
    """Applies a single normalizer to the input dataframe, individually to each of the columns
    provided in the cols argument.

    Parameters
    ----------
    cols : Sequence[str]
        List of column names to apply normalization to.
    bert_norm : str
        The type of normalization to use. The only valid value is "tokenize".
    input_df : DataFrame
        The input DataFrame to apply normalization to.
    """

    if bert_norm == "tokenize":
        # Tokenization is not implemented in this first structural commit,
        # so the input passes through unchanged.
        scaled_df = input_df
    else:
        # Guard against unsupported values so an undefined variable is never returned.
        raise ValueError(f"Unknown normalization requested for BERT transformation: {bert_norm}")

    return scaled_df


class DistBertTransformation(DistributedTransformation):
    """Transformation to apply various forms of BERT normalization to a text input.

    Parameters
    ----------
    cols : Sequence[str]
        List of column names to apply normalization to.
    normalizer : str
        The type of normalization to use. The only valid value is "tokenize".
    bert_model : str
        The BERT-like model to tokenize with, expected to be one of VALID_BERT_MODEL.
    max_seq_length : int
        The maximum sequence length to use when tokenizing.
    """

    def __init__(
        self, cols: Sequence[str], normalizer: str, bert_model: str, max_seq_length: int
    ) -> None:
        super().__init__(cols)
        self.cols = cols
        self.bert_norm = normalizer
        self.bert_model = bert_model
        self.max_seq_length = max_seq_length

    def apply(self, input_df: DataFrame) -> DataFrame:
        scaled_df = apply_norm(self.cols, self.bert_norm, input_df)

        return scaled_df

    @staticmethod
    def get_transformation_name() -> str:
        return "DistBertTransformation"
Expand Up @@ -101,6 +101,8 @@ def determine_spark_feature_type(feature_type: str) -> Type[DataType]:
return StringType
if feature_type in ["numerical", "bucket-numerical", "none"]:
return FloatType
if feature_type in ["bert"]:
return StringType
else:
raise NotImplementedError(f"Unknown feature type: {feature_type}")

