Commit: add
EC2 Default User committed Jan 23, 2024
1 parent 32da0c8 commit 18d61e4
Showing 1 changed file with 38 additions and 17 deletions.
@@ -17,9 +17,9 @@
 from typing import Optional, Sequence
 import uuid
 
-from pyspark.sql import DataFrame
+from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql import functions as F
-from pyspark.sql.types import MapType, ArrayType, IntegerType, StringType
+from pyspark.sql.types import MapType, ArrayType, IntegerType, StringType, StructType, StructField
 from pyspark.ml.stat import Summarizer
 from pyspark.ml import Pipeline
 from pyspark.ml.functions import array_to_vector, vector_to_array
@@ -45,26 +45,32 @@ def apply_norm(
     cols : Sequence[str]
         List of column names to apply normalization to.
     bert_norm : str
-        The type of normalization to use. Valid value is "tokenize"
-    max_seq_length : int
-        The maximal length of the tokenization results.
+        The type of normalization to use. The only valid value is "tokenize".
     input_df : DataFrame
         The input DataFrame to apply normalization to.
     """

     if bert_norm == "tokenize":
+        scaled_df = input_df
+
         # Initialize the tokenizer
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
+        # Define the schema of the UDF's return type
+        schema = StructType([
+            StructField("input_ids", ArrayType(IntegerType())),
+            StructField("attention_mask", ArrayType(IntegerType())),
+            StructField("token_type_ids", ArrayType(IntegerType()))
+        ])
+
         # Define UDF
+        @udf(returnType=schema, useArrow=True)
         def tokenize(text):
             # Check if text is a string
             if not isinstance(text, str):
                 raise ValueError("The input of the tokenizer has to be a string.")
 
             # Tokenize the text
+            # Rather than mirroring the GConstruct implementation, cast the
+            # dtypes with numpy here so that no torch dependency is required.
             t = tokenizer(text, max_length=max_seq_length, truncation=True, padding='max_length', return_tensors='np')
             token_type_ids = t.get('token_type_ids', np.zeros_like(t['input_ids'], dtype=np.int8))
             result = {
@@ -74,12 +80,30 @@ def tokenize(text):
             }
             return result
 
-        # Define the UDF with the appropriate return type
-        tokenize_udf = udf(tokenize, MapType(StringType(), ArrayType(IntegerType())))
-
         # Apply the UDF to the DataFrame
-        scaled_df = input_df.withColumn(cols[0], tokenize_udf(input_df[cols[0]]))
+        scaled_df = input_df.withColumn(cols[0], tokenize(input_df[cols[0]]))
     return scaled_df
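
For context, here is a minimal sketch of how the new Arrow-backed struct UDF can be exercised end to end. It is illustrative only, not part of the commit: it assumes a local SparkSession, Spark >= 3.5 (the first release where udf accepts useArrow), and a hypothetical input column named "text"; the positional argument order follows the apply_norm call in DistBertTransformation.apply further down.

# Hypothetical driver script, not part of this commit.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
toy_df = spark.createDataFrame(
    [("hello world",), ("graph neural networks",)], ["text"]
)

# (cols, bert_norm, max_seq_length, input_df), matching the apply() call below.
scaled_df = apply_norm(["text"], "tokenize", 16, toy_df)

# Because the UDF now returns a StructType rather than
# MapType(StringType(), ArrayType(IntegerType())), each field keeps its own
# element type and is addressable by name.
scaled_df.select("text.input_ids", "text.attention_mask").show(truncate=False)

With useArrow=True the tokenized batches also cross the Python/JVM boundary as Arrow columns rather than pickled rows, which is likely the motivation for pairing the StructType schema with an Arrow-optimized UDF.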


@@ -92,10 +116,6 @@ class DistBertTransformation(DistributedTransformation):
         List of column names to apply normalization to.
     bert_norm : str
         The type of normalization to use. The only valid value is "tokenize".
-    bert_model: str
-        The name of the lm model.
-    max_seq_length: int
-        The maximal length of the tokenization results.
     """
 
     def __init__(
@@ -107,6 +127,7 @@ def __init__(
         self.bert_norm = normalizer
         self.bert_model = bert_model
         self.max_seq_length = max_seq_length
+        self.spark = spark
 
     def apply(self, input_df: DataFrame) -> DataFrame:
         scaled_df = apply_norm(self.cols, self.bert_norm, self.max_seq_length, input_df)
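A matching sketch for the class itself. The constructor signature is folded out of view above, so the argument order here is inferred from the assignments in __init__ and should be treated as an assumption:

# Hypothetical usage; the positional order (cols, normalizer, bert_model,
# max_seq_length, spark) is inferred from the assignments above.
spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame([("a sample sentence",)], ["text"])

transformation = DistBertTransformation(
    ["text"], "tokenize", "bert-base-uncased", 16, spark
)
tokenized = transformation.apply(df)

# Note: apply_norm hard-codes the bert-base-uncased tokenizer, so bert_model
# is stored on the instance but not yet forwarded to the tokenizer.
tokenized.printSchema()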
