Skip to content

Commit

Permalink
fix: performance improvements (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
ralphrass authored Sep 16, 2024
1 parent f6c5db6 commit 11cc5d5
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 33 deletions.
14 changes: 12 additions & 2 deletions butterfree/_cli/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import pkgutil
import sys
from typing import Set
from typing import Set, Type

import boto3
import setuptools
Expand Down Expand Up @@ -90,8 +90,18 @@ def __fs_objects(path: str) -> Set[FeatureSetPipeline]:

instances.add(value)

def create_instance(cls: Type[FeatureSetPipeline]) -> FeatureSetPipeline:
sig = inspect.signature(cls.__init__)
parameters = sig.parameters

if "run_date" in parameters:
run_date = datetime.datetime.today().strftime("%y-%m-%d")
return cls(run_date)

return cls()

logger.info("Creating instances...")
return set(value() for value in instances) # type: ignore
return set(create_instance(value) for value in instances) # type: ignore


PATH = typer.Argument(
Expand Down
14 changes: 10 additions & 4 deletions butterfree/extract/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List, Optional

from pyspark.sql import DataFrame
from pyspark.storagelevel import StorageLevel

from butterfree.clients import SparkClient
from butterfree.extract.readers.reader import Reader
Expand Down Expand Up @@ -95,16 +96,21 @@ def construct(
DataFrame with the query result against all readers.
"""
# Step 1: Build temporary views for each reader
for reader in self.readers:
reader.build(
client=client, start_date=start_date, end_date=end_date
) # create temporary views for each reader
reader.build(client=client, start_date=start_date, end_date=end_date)

# Step 2: Execute SQL query on the combined readers
dataframe = client.sql(self.query)

# Step 3: Cache the dataframe if necessary, using memory and disk storage
if not dataframe.isStreaming and self.eager_evaluation:
dataframe.cache().count()
# Persist to ensure the DataFrame is stored in mem and disk (if necessary)
dataframe.persist(StorageLevel.MEMORY_AND_DISK)
# Trigger the cache/persist operation by performing an action
dataframe.count()

# Step 4: Run post-processing hooks on the dataframe
post_hook_df = self.run_post_hooks(dataframe)

return post_hook_df
29 changes: 21 additions & 8 deletions butterfree/pipelines/feature_set_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import List, Optional

from pyspark.storagelevel import StorageLevel

from butterfree.clients import SparkClient
from butterfree.dataframe_service import repartition_sort_df
from butterfree.extract import Source
Expand Down Expand Up @@ -209,35 +211,46 @@ def run(
soon. Use only if strictly necessary.
"""
# Step 1: Construct input dataframe from the source.
dataframe = self.source.construct(
client=self.spark_client,
start_date=self.feature_set.define_start_date(start_date),
end_date=end_date,
)

# Step 2: Repartition and sort if required, avoid if not necessary.
if partition_by:
order_by = order_by or partition_by
dataframe = repartition_sort_df(
dataframe, partition_by, order_by, num_processors
)

dataframe = self.feature_set.construct(
current_partitions = dataframe.rdd.getNumPartitions()
optimal_partitions = num_processors or current_partitions
if current_partitions != optimal_partitions:
dataframe = repartition_sort_df(
dataframe, partition_by, order_by, num_processors
)

# Step 3: Construct the feature set dataframe using defined transformations.
transformed_dataframe = self.feature_set.construct(
dataframe=dataframe,
client=self.spark_client,
start_date=start_date,
end_date=end_date,
num_processors=num_processors,
)

if dataframe.storageLevel != StorageLevel.NONE:
dataframe.unpersist() # Clear the data from the cache (disk and memory)

# Step 4: Load the data into the configured sink.
self.sink.flush(
dataframe=dataframe,
dataframe=transformed_dataframe,
feature_set=self.feature_set,
spark_client=self.spark_client,
)

if not dataframe.isStreaming:
# Step 5: Validate the output if not streaming and data volume is reasonable.
if not transformed_dataframe.isStreaming:
self.sink.validate(
dataframe=dataframe,
dataframe=transformed_dataframe,
feature_set=self.feature_set,
spark_client=self.spark_client,
)
Expand Down
31 changes: 18 additions & 13 deletions butterfree/transform/aggregated_feature_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def _aggregate(
]

groupby = self.keys_columns.copy()

if window is not None:
dataframe = dataframe.withColumn("window", window.get())
groupby.append("window")
Expand All @@ -410,19 +411,23 @@ def _aggregate(
"keep_rn", functions.row_number().over(partition_window)
).filter("keep_rn = 1")

# repartition to have all rows for each group at the same partition
# by doing that, we won't have to shuffle data on grouping by id
dataframe = repartition_df(
dataframe,
partition_by=groupby,
num_processors=num_processors,
)
current_partitions = dataframe.rdd.getNumPartitions()
optimal_partitions = num_processors or current_partitions

if current_partitions != optimal_partitions:
dataframe = repartition_df(
dataframe,
partition_by=groupby,
num_processors=optimal_partitions,
)

grouped_data = dataframe.groupby(*groupby)

if self._pivot_column:
if self._pivot_column and self._pivot_values:
grouped_data = grouped_data.pivot(self._pivot_column, self._pivot_values)

aggregated = grouped_data.agg(*aggregations)

return self._with_renamed_columns(aggregated, features, window)

def _with_renamed_columns(
Expand Down Expand Up @@ -637,12 +642,12 @@ def construct(
output_df = output_df.select(*self.columns).replace( # type: ignore
float("nan"), None
)
if not output_df.isStreaming:
if self.deduplicate_rows:
output_df = self._filter_duplicated_rows(output_df)
if self.eager_evaluation:
output_df.cache().count()
if not output_df.isStreaming and self.deduplicate_rows:
output_df = self._filter_duplicated_rows(output_df)

post_hook_df = self.run_post_hooks(output_df)

if not output_df.isStreaming and self.eager_evaluation:
post_hook_df.cache().count()

return post_hook_df
7 changes: 2 additions & 5 deletions butterfree/transform/feature_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,11 +436,8 @@ def construct(
pre_hook_df,
).select(*self.columns)

if not output_df.isStreaming:
if self.deduplicate_rows:
output_df = self._filter_duplicated_rows(output_df)
if self.eager_evaluation:
output_df.cache().count()
if not output_df.isStreaming and self.deduplicate_rows:
output_df = self._filter_duplicated_rows(output_df)

output_df = self.incremental_strategy.filter_with_incremental_strategy(
dataframe=output_df, start_date=start_date, end_date=end_date
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/butterfree/transform/test_feature_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def test_construct(
+ feature_divide.get_output_columns()
)
assert_dataframe_equality(result_df, feature_set_dataframe)
assert result_df.is_cached
assert not result_df.is_cached

def test_construct_invalid_df(
self, key_id, timestamp_c, feature_add, feature_divide
Expand Down

0 comments on commit 11cc5d5

Please sign in to comment.