Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GSProcessing] BERT Tokenizer #700

Merged
merged 40 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0964886
add gconstruct converter
jalencato Jan 10, 2024
0a67217
first commit about code structure on tokenize feature transformation
jalencato Jan 10, 2024
7c20436
add first version with udf implementation
jalencato Jan 11, 2024
626565d
remove torch related
jalencato Jan 17, 2024
cfebe1d
remove torch
jalencato Jan 17, 2024
32da0c8
add doc
jalencato Jan 18, 2024
18d61e4
add
Jan 23, 2024
178da9f
add
jalencato Jan 23, 2024
288604e
Merge branch 'main' into bert_tokenzier
jalencato Jan 25, 2024
8a0f872
add fix
jalencato Jan 25, 2024
69edfd9
rename
jalencato Jan 26, 2024
c26aa68
rename
jalencato Jan 26, 2024
ad74f46
add test for huggingface
jalencato Jan 26, 2024
d7bccff
black reformat
jalencato Jan 26, 2024
bd8c075
apply lint
jalencato Jan 26, 2024
ae967da
add dependency
jalencato Jan 26, 2024
a81f4b5
add
jalencato Jan 26, 2024
9181826
test fix
jalencato Jan 27, 2024
93a9685
add fix
jalencato Jan 29, 2024
e091bf9
add test
jalencato Jan 29, 2024
69bee37
change config
jalencato Jan 29, 2024
dfd63e4
apply comments
jalencato Jan 31, 2024
250938a
apply comment
jalencato Jan 31, 2024
efb1bfa
apply lint
jalencato Jan 31, 2024
6d13571
add final line
jalencato Jan 31, 2024
3d85407
Update docs/source/gs-processing/developer/input-configuration.rst
jalencato Jan 31, 2024
89924c6
name change
jalencato Jan 31, 2024
4ec8563
add build docker'
jalencato Jan 31, 2024
966e6fd
add doc
jalencato Jan 31, 2024
aa43a83
add doc
jalencato Jan 31, 2024
a2d36d5
add doc
jalencato Jan 31, 2024
9cbf74e
change dockerfile
jalencato Feb 1, 2024
b702b30
add docker packing
jalencato Feb 1, 2024
a67c66c
doc
jalencato Feb 1, 2024
fe49292
Apply suggestions from code review
jalencato Feb 1, 2024
5c6fd8c
final version
jalencato Feb 1, 2024
b09d058
apply black
jalencato Feb 1, 2024
9d1a234
doc
jalencato Feb 1, 2024
6247297
convert
jalencato Feb 1, 2024
3c7bef1
Merge branch 'main' into bert_tokenzier
jalencato Feb 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/source/gs-processing/developer/input-configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,15 @@ arguments.
will be considered as an array. For Parquet files, if the input type is ArrayType(StringType()), then the
separator is ignored; if it is StringType(), it will apply same logic as in CSV.

- ``huggingface``

- Transforms a text feature column into tokens (or, in the future, embeddings) using a pre-trained Hugging Face model, preparing natural language data for model training.
- ``kwargs``:

- ``normalizer`` (String, required): The transformation variant to apply. Currently only ``tokenize_hf`` is supported.
jalencato marked this conversation as resolved.
Show resolved Hide resolved
- ``bert_model`` (String, required): The identifier of a pre-trained model available on the Hugging Face Model Hub, used to load the corresponding tokenizer.
thvasilo marked this conversation as resolved.
Show resolved Hide resolved
- ``max_seq_length`` (Integer, required): The maximum number of tokens per input; longer inputs are truncated and shorter ones padded to this length.
thvasilo marked this conversation as resolved.
Show resolved Hide resolved

--------------

Creating a graph for inference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Any

from .converter_base import ConfigConverter
Expand Down Expand Up @@ -134,6 +135,13 @@ def _convert_feature(feats: list[dict]) -> list[dict]:
else:
gsp_transformation_dict["name"] = "categorical"
gsp_transformation_dict["kwargs"] = {}
elif gconstruct_transform_dict["name"] == "tokenize_hf":
gsp_transformation_dict["name"] = "huggingface"
gsp_transformation_dict["kwargs"] = {
"normalizer": "tokenize_hf",
"bert_model": gconstruct_transform_dict["bert_model"],
"max_seq_length": gconstruct_transform_dict["max_seq_length"],
}
# TODO: Add support for other common transformations here
else:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
NumericalFeatureConfig,
)
from .categorical_configs import MultiCategoricalFeatureConfig
from .hf_configs import HFConfig
from .data_config_base import DataStorageConfig


Expand Down Expand Up @@ -67,6 +68,8 @@ def parse_feat_config(feature_dict: Dict) -> FeatureConfig:
return FeatureConfig(feature_dict)
elif transformation_name == "multi-categorical":
return MultiCategoricalFeatureConfig(feature_dict)
elif transformation_name == "huggingface":
return HFConfig(feature_dict)
else:
raise RuntimeError(f"Unknown transformation name: '{transformation_name}'")

Expand Down
50 changes: 50 additions & 0 deletions graphstorm-processing/graphstorm_processing/config/hf_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Mapping

from graphstorm_processing.constants import HUGGINGFACE_TOKENIZE
from .feature_config_base import FeatureConfig


class HFConfig(FeatureConfig):
    """Feature configuration for huggingface text features.

    Supported kwargs
    ----------------
    normalizer: str, required
        The Hugging Face transformation variant to apply. Currently only
        "tokenize_hf" is supported.
    bert_model: str, required
        The name of the lm model, i.e. the identifier of a pre-trained
        model on the Hugging Face Model Hub.
    max_seq_length: int, required
        The maximal length of the tokenization results.
    """

    def __init__(self, config: Mapping):
        super().__init__(config)
        # All three kwargs are required; missing keys become None and are
        # rejected by _sanity_check below.
        self.bert_norm = self._transformation_kwargs.get("normalizer")
        self.bert_model = self._transformation_kwargs.get("bert_model")
        self.max_seq_length = self._transformation_kwargs.get("max_seq_length")

        self._sanity_check()

    def _sanity_check(self) -> None:
        """Validate the transformation kwargs, raising AssertionError on invalid values."""
        super()._sanity_check()
        assert self.bert_norm in [HUGGINGFACE_TOKENIZE], "bert normalizer needs to be tokenize_hf"
        assert isinstance(
            self.bert_model, str
        ), f"Expect bert_model to be a string, but got {self.bert_model}"
        assert (
            isinstance(self.max_seq_length, int) and self.max_seq_length > 0
        ), f"Expect max_seq_length to be an integer larger than zero, got {self.max_seq_length}."
5 changes: 5 additions & 0 deletions graphstorm-processing/graphstorm_processing/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

################### Categorical Limits #######################
MAX_CATEGORIES_PER_FEATURE = 100
RARE_CATEGORY = "GSP_CONSTANT_OTHER"
Expand Down Expand Up @@ -43,3 +44,7 @@
################# Numerical transformations ################
VALID_IMPUTERS = ["none", "mean", "median", "most_frequent"]
VALID_NORMALIZERS = ["none", "min-max", "standard", "rank-gauss"]

################# Bert transformations ################
# Transformation name used in GSProcessing configs for Hugging Face text features.
# NOTE(review): identifier is misspelled ("TRANFORM" -> "TRANSFORM"); renaming now
# would break existing references, so flagging instead of changing it here.
HUGGINGFACE_TRANFORM = "huggingface"
# The only normalizer currently supported by the huggingface transformation.
HUGGINGFACE_TOKENIZE = "tokenize_hf"
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging

from pyspark.sql import DataFrame
Expand All @@ -26,6 +27,7 @@
DistBucketNumericalTransformation,
DistCategoryTransformation,
DistMultiCategoryTransformation,
DistHFTransformation,
)


Expand Down Expand Up @@ -57,6 +59,8 @@ def __init__(self, feature_config: FeatureConfig):
self.transformation = DistCategoryTransformation(**default_kwargs, **args_dict)
elif feat_type == "multi-categorical":
self.transformation = DistMultiCategoryTransformation(**default_kwargs, **args_dict)
elif feat_type == "huggingface":
self.transformation = DistHFTransformation(**default_kwargs, **args_dict)
else:
raise NotImplementedError(
f"Feature {feat_name} has type: {feat_type} that is not supported"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
DistNumericalTransformation,
)
from .dist_bucket_numerical_transformation import DistBucketNumericalTransformation
from .dist_hf_transformation import DistHFTransformation
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Sequence
import numpy as np
from pyspark.sql import DataFrame
from pyspark.sql.types import ArrayType, IntegerType, StructType, StructField
from pyspark.sql.functions import udf
from transformers import AutoTokenizer

from graphstorm_processing.constants import HUGGINGFACE_TOKENIZE
from .base_dist_transformation import DistributedTransformation


def apply_transform(
    cols: Sequence[str], bert_norm: str, bert_model: str, max_seq_length: int, input_df: DataFrame
) -> DataFrame:
    """Applies a single normalizer to the imputed dataframe, individually to each of the columns
    provided in the cols argument.

    Parameters
    ----------
    cols : Sequence[str]
        List of column names to apply normalization to. Only the first
        column is used, since tokenization fans out into multiple output columns.
    bert_norm : str
        The type of normalization to use. The only valid value is "tokenize_hf".
    bert_model : str
        The name of the huggingface model whose tokenizer to load.
    max_seq_length: int
        The maximal length of the tokenization results. Longer inputs are
        truncated and shorter inputs are padded to this length.
    input_df : DataFrame
        The input DataFrame to apply normalization to.

    Returns
    -------
    DataFrame
        A DataFrame with three array columns: "input_ids", "attention_mask",
        and "token_type_ids".

    Raises
    ------
    ValueError
        If ``bert_norm`` is not a supported normalizer, or (at execution time)
        if a non-string value is encountered in the input column.
    """

    if bert_norm == HUGGINGFACE_TOKENIZE:
        # Initialize the tokenizer on the driver; Spark ships it to the
        # executors as part of the UDF closure.
        tokenizer = AutoTokenizer.from_pretrained(bert_model)

        # Define the schema of the struct the tokenizer UDF returns
        schema = StructType(
            [
                StructField("input_ids", ArrayType(IntegerType())),
                StructField("attention_mask", ArrayType(IntegerType())),
                StructField("token_type_ids", ArrayType(IntegerType())),
            ]
        )

        # Define UDF
        @udf(returnType=schema)
        def tokenize(text):
            # Check if text is a string
            if not isinstance(text, str):
                raise ValueError("The input of the tokenizer has to be a string.")

            # Tokenize the text
            t = tokenizer(
                text,
                max_length=max_seq_length,
                truncation=True,
                padding="max_length",
                return_tensors="np",
            )
            # Models without token type embeddings (e.g. RoBERTa) don't return
            # token_type_ids, so fall back to zeros of matching shape.
            token_type_ids = t.get("token_type_ids", np.zeros_like(t["input_ids"], dtype=np.int8))
            result = (
                t["input_ids"][0].tolist(),  # Convert tensor to list
                t["attention_mask"][0].astype(np.int8).tolist(),
                token_type_ids[0].astype(np.int8).tolist(),
            )

            return result

        # Apply the UDF to the DataFrame, then flatten the struct into
        # one column per tokenizer output.
        transformed_df = input_df.withColumn(cols[0], tokenize(input_df[cols[0]]))
        transformed_df = transformed_df.select(
            transformed_df[cols[0]].getItem("input_ids").alias("input_ids"),
            transformed_df[cols[0]].getItem("attention_mask").alias("attention_mask"),
            transformed_df[cols[0]].getItem("token_type_ids").alias("token_type_ids"),
        )
        return transformed_df
    else:
        # The original implicit fall-through returned None for unknown
        # normalizers, which would surface as a confusing error downstream.
        raise ValueError(f"Unknown normalizer requested for huggingface transformation: {bert_norm}")


class DistHFTransformation(DistributedTransformation):
    """Distributed transformation applying a Hugging Face normalization
    (currently tokenization) to a single text column.

    Parameters
    ----------
    cols : Sequence[str]
        List of column names to apply normalization to. Exactly one
        column is supported.
    normalizer : str
        The type of normalization to use. The only valid value is "tokenize_hf".
    bert_model : str
        The name of the huggingface model whose tokenizer to use.
    max_seq_length : int
        The maximal length of the tokenization results.
    """

    def __init__(
        self, cols: Sequence[str], normalizer: str, bert_model: str, max_seq_length: int
    ) -> None:
        super().__init__(cols)
        self.cols = cols
        # Tokenization fans out into several output columns, so only a
        # single input column can be transformed at a time.
        assert len(self.cols) == 1, "Huggingface transformation only supports single column"
        self.bert_norm = normalizer
        self.bert_model = bert_model
        self.max_seq_length = max_seq_length

    def apply(self, input_df: DataFrame) -> DataFrame:
        """Apply the configured Hugging Face normalization to ``input_df``."""
        return apply_transform(
            self.cols, self.bert_norm, self.bert_model, self.max_seq_length, input_df
        )

    @staticmethod
    def get_transformation_name() -> str:
        """Return the human-readable name of this transformation."""
        return "DistHFTransformation"
Loading
Loading