diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_bert_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_bert_transformation.py
index e778658408..27990b68f8 100644
--- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_bert_transformation.py
+++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_bert_transformation.py
@@ -17,9 +17,9 @@
 from typing import Optional, Sequence
 import uuid
 
-from pyspark.sql import DataFrame
+from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql import functions as F
-from pyspark.sql.types import MapType, ArrayType, IntegerType, StringType
+from pyspark.sql.types import MapType, ArrayType, IntegerType, StringType, StructType, StructField
 from pyspark.ml.stat import Summarizer
 from pyspark.ml import Pipeline
 from pyspark.ml.functions import array_to_vector, vector_to_array
@@ -45,26 +45,32 @@ def apply_norm(
     cols : Sequence[str]
         List of column names to apply normalization to.
     bert_norm : str
-        The type of normalization to use. Valid value is "tokenize"
-    max_seq_length : int
-        The maximal length of the tokenization results.
+        The type of normalization to use. Valid values are "tokenize"
     input_df : DataFrame
         The input DataFrame to apply normalization to.
     """
     if bert_norm == "tokenize":
+        scaled_df = input_df
+
         # Initialize the tokenizer
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
+        # Define the schema of the return type
+        schema = StructType([
+            StructField("input_ids", ArrayType(IntegerType())),
+            StructField("attention_mask", ArrayType(IntegerType())),
+            StructField("token_type_ids", ArrayType(IntegerType()))
+        ])
+
+        # Define the UDF
+        @udf(returnType=schema, useArrow=True)
         def tokenize(text):
            # Check if text is a string
            if not isinstance(text, str):
                raise ValueError("The input of the tokenizer has to be a string.")
 
            # Tokenize the text
-           # Instead of doing the similar thing as what we do in the GConstruct, it is suggested
-           # to use numpy here to refactor the data type. So it is not necessary to introduce the
-           # torch dependency here
            t = tokenizer(text, max_length=max_seq_length, truncation=True, padding='max_length', return_tensors='np')
            token_type_ids = t.get('token_type_ids', np.zeros_like(t['input_ids'], dtype=np.int8))
            result = {
@@ -74,12 +80,30 @@ def tokenize(text):
            }
            return result
 
-        # Define the UDF with the appropriate return type
-        tokenize_udf = udf(tokenize, MapType(StringType(), ArrayType(IntegerType())))
-        # Apply the UDF to the DataFrame
-        scaled_df = input_df.withColumn(cols[0], tokenize_udf(input_df[cols[0]]))
-
+        scaled_df = input_df.withColumn(cols[0], tokenize(input_df[cols[0]]))
+
+        # @udf(returnType=schema, useArrow=True)
+        # def tokenize(text):
+        #     # Check if text is a string
+        #     if not isinstance(text, str):
+        #         raise ValueError("The input of the tokenizer has to be a string.")
+        #
+        #     # Tokenize the text
+        #     t = tokenizer(text, max_length=max_seq_length, truncation=True, padding='max_length', return_tensors='np')
+        #     token_type_ids = t.get('token_type_ids', np.zeros_like(t['input_ids'], dtype=np.int8))
+        #     result = {
+        #         'input_ids': t['input_ids'][0].tolist(),  # Convert tensor to list
+        #         'attention_mask': t['attention_mask'][0].astype(np.int8).tolist(),
+        #         'token_type_ids': token_type_ids[0].astype(np.int8).tolist()
+        #     }
+        #     return result
+        #
+        #     # Define the UDF with the appropriate return type
+        #     tokenize_udf = udf(tokenize, MapType(StringType(), ArrayType(IntegerType())))
+        #
+        #     # Apply the UDF to the DataFrame
+        #     scaled_df = input_df.withColumn(cols[0], tokenize_udf(input_df[cols[0]]))
 
     return scaled_df
@@ -92,10 +116,6 @@ class DistBertTransformation(DistributedTransformation):
         List of column names to apply normalization to.
     bert_norm : str
         The type of normalization to use. Valid values is "tokenize"
-    bert_model: str
-        The name of the lm model.
-    max_seq_length: int
-        The maximal length of the tokenization results.
     """
 
     def __init__(
@@ -107,6 +127,7 @@ def __init__(
         self.bert_norm = normalizer
        self.bert_model = bert_model
        self.max_seq_length = max_seq_length
+        self.spark = spark
 
    def apply(self, input_df: DataFrame) -> DataFrame:
        scaled_df = apply_norm(self.cols, self.bert_norm, self.max_seq_length, input_df)
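
Note (outside the diff): the change replaces the old MapType-returning tokenize_udf with an Arrow-optimized UDF (useArrow=True, available in PySpark 3.5+) whose return type is an explicit StructType of integer arrays. The standalone sketch below only illustrates that pattern; it is not part of the change, and it assumes PySpark >= 3.5 and the transformers package are installed, with the example DataFrame, the app name, and the max_seq_length value of 16 chosen purely for demonstration (the "bert-base-uncased" model name is the one used in the diff).

# Standalone sketch of the Arrow-optimized tokenization UDF pattern.
# Assumptions: PySpark >= 3.5 (for the `useArrow` flag) and `transformers` installed;
# the input data and max_seq_length are illustrative only.
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, StructField, StructType
from transformers import AutoTokenizer

spark = SparkSession.builder.appName("tokenize-sketch").getOrCreate()
df = spark.createDataFrame([("hello world",), ("graph neural networks",)], ["text"])

max_seq_length = 16  # illustrative value
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Struct return type mirroring the schema introduced in the diff.
schema = StructType([
    StructField("input_ids", ArrayType(IntegerType())),
    StructField("attention_mask", ArrayType(IntegerType())),
    StructField("token_type_ids", ArrayType(IntegerType())),
])

@udf(returnType=schema, useArrow=True)
def tokenize(text):
    # Tokenize a single string and return plain Python lists so Spark can map
    # them onto the struct fields declared above.
    t = tokenizer(text, max_length=max_seq_length, truncation=True,
                  padding="max_length", return_tensors="np")
    token_type_ids = t.get("token_type_ids", np.zeros_like(t["input_ids"]))
    return {
        "input_ids": t["input_ids"][0].tolist(),
        "attention_mask": t["attention_mask"][0].tolist(),
        "token_type_ids": token_type_ids[0].tolist(),
    }

df.withColumn("text", tokenize(df["text"])).show(truncate=False)

Compared with the removed MapType(StringType(), ArrayType(IntegerType())) return type, the struct schema keeps each tokenizer output as a named, typed field, and the Arrow path exchanges rows between the JVM and the Python worker via Arrow instead of pickle, which is the usual motivation for useArrow=True.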