-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
first commit about code structure on tokenize feature transformation
- Loading branch information
Showing
7 changed files
with
141 additions
and
0 deletions.
There are no files selected for viewing
45 changes: 45 additions & 0 deletions
45
graphstorm-processing/graphstorm_processing/config/bert_configs.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
""" | ||
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"). | ||
You may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
from typing import Mapping | ||
import numbers | ||
|
||
from graphstorm_processing.constants import VALID_BERT_MODEL | ||
from .feature_config_base import FeatureConfig | ||
|
||
|
||
class BertConfig(FeatureConfig): | ||
"""Feature configuration for single-column numerical features. | ||
Supported kwargs | ||
---------------- | ||
""" | ||
|
||
def __init__(self, config: Mapping): | ||
super().__init__(config) | ||
self.bert_model = self._transformation_kwargs.get("bert_model", "none") | ||
self.max_seq_length = self._transformation_kwargs.get("max_seq_length", "none") | ||
|
||
self._sanity_check() | ||
|
||
def _sanity_check(self) -> None: | ||
super()._sanity_check() | ||
assert ( | ||
self.bert_model in VALID_BERT_MODEL | ||
), f"Unknown imputer requested, expected one of {VALID_BERT_MODEL}, got {self.bert_model}" | ||
assert isinstance(self.max_seq_length, int) and self.max_seq_length > 0, \ | ||
f"Expect bucket_cnt {self.bucket_cnt} be an integer and larger than zero." | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
85 changes: 85 additions & 0 deletions
85
...aphstorm_processing/data_transformations/dist_transformations/dist_bert_transformation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
""" | ||
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"). | ||
You may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import logging | ||
from typing import Optional, Sequence | ||
import uuid | ||
|
||
from pyspark.sql import DataFrame | ||
from pyspark.sql import functions as F | ||
from pyspark.sql.types import ArrayType, FloatType | ||
from pyspark.ml.feature import MinMaxScaler, Imputer, VectorAssembler, ElementwiseProduct | ||
from pyspark.ml.linalg import DenseVector | ||
from pyspark.ml.stat import Summarizer | ||
from pyspark.ml import Pipeline | ||
from pyspark.ml.functions import array_to_vector, vector_to_array | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from .base_dist_transformation import DistributedTransformation | ||
from ..spark_utils import rename_multiple_cols | ||
|
||
|
||
def apply_norm( | ||
cols: Sequence[str], bert_norm: str, input_df: DataFrame | ||
) -> DataFrame: | ||
"""Applies a single normalizer to the imputed dataframe, individually to each of the columns | ||
provided in the cols argument. | ||
Parameters | ||
---------- | ||
cols : Sequence[str] | ||
List of column names to apply normalization to. | ||
bert_norm : str | ||
The type of normalization to use. Valid values is "tokenize" | ||
input_df : DataFrame | ||
The input DataFrame to apply normalization to. | ||
""" | ||
|
||
if bert_norm == "tokenize": | ||
scaled_df = input_df | ||
|
||
return scaled_df | ||
|
||
|
||
class DistBertTransformation(DistributedTransformation): | ||
"""Transformation to apply various forms of bert normalization to a text input. | ||
Parameters | ||
---------- | ||
cols : Sequence[str] | ||
List of column names to apply normalization to. | ||
bert_norm : str | ||
The type of normalization to use. Valid values is "tokenize" | ||
""" | ||
|
||
def __init__( | ||
self, cols: Sequence[str], normalizer: str, bert_model: str, max_seq_length: int | ||
) -> None: | ||
super().__init__(cols) | ||
self.cols = cols | ||
self.bert_norm = normalizer | ||
self.bert_model = bert_model | ||
self.max_seq_length = max_seq_length | ||
|
||
def apply(self, input_df: DataFrame) -> DataFrame: | ||
scaled_df = apply_norm(self.cols, self.bert_norm, input_df) | ||
|
||
return scaled_df | ||
|
||
@staticmethod | ||
def get_transformation_name() -> str: | ||
return "DistBertTransformation" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters