Update sdt dask
Style: add Sphinx-style docstrings
Separate execute() into set_up() and get_result()
zhanghaoc committed Mar 25, 2024
1 parent d136935 commit e841c4e
Showing 2 changed files with 429 additions and 94 deletions.
72 changes: 56 additions & 16 deletions sdt_dask/dask_tool/sdt_dask.py
@@ -1,34 +1,69 @@
"""
This module provides a class to run the SolarDataTools pipeline on a Dask cluster.
It takes a data plug and a dask client as input and runs the pipeline on the data plug
See the README and tool_demo_SDTDask.ipynb for more information
"""
import os
import pandas as pd
from dask import delayed
from dask.distributed import performance_report
from solardatatools import DataHandler

# Define the pipeline run for a single dataset
def run_pipeline(datahandler, **kwargs):
    """Run the SolarDataTools pipeline on a DataHandler object.

    Any keyword arguments are passed through to DataHandler.run_pipeline.

    Args:
        datahandler (:obj:`DataHandler`): The DataHandler object.
        **kwargs: Optional parameters passed to the pipeline.

    Returns:
        DataHandler: The DataHandler object after the pipeline has run.
    """
    # This wrapper is needed for the call to work correctly in the task graph,
    # since DataHandler.run_pipeline itself does not return anything.
    # TODO: add loss analysis
    # TODO: if a dataset fails to run, raise a Python error
    datahandler.run_pipeline(**kwargs)
    return datahandler
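
# Example (illustrative; fix_shifts and verbose are assumed to be valid
# DataHandler.run_pipeline keyword arguments):
#
#     dh = run_pipeline(DataHandler(df), fix_shifts=True, verbose=False)
#
# In SDTDask.set_up() below, the same call is wrapped with dask.delayed so it
# becomes a node in the task graph instead of running eagerly.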

class SDTDask:
"""A class to run the SolarDataTools pipeline on a Dask cluster.
Will handle invalid data keys and failed datasets.
Attributes:
data_plug (:obj:`DataPlug`): The data plug object.
client (:obj:`Client`): The Dask client object.
output_path (str): The path to save the results.
def __init__(self, data_plug, client):
"""

def __init__(self, data_plug, client, output_path="../results/"):
self.data_plug = data_plug
self.client = client
self.output_path = output_path

    def set_up(self, KEYS, **kwargs):
        """Set up the Dask task graph for the data plug.

        For each key, the data is loaded, the pipeline is run via
        run_pipeline, and the results are collected in a delayed DataFrame.

        Args:
            KEYS (list): List of (site ID, year) tuples.
            **kwargs: Optional parameters passed to run_pipeline.
        """

        reports = []
        runtimes = []

        # KEYS example: [(34, 2011), (35, 2015), (51, 2012)]  # site ID and year pairs
        for key in KEYS:
            # TODO: check explicitly whether a key is valid

            # TODO: if a dataset fails to run, raise a Python error

            df = delayed(self.data_plug.get_data)(key)
            dh = delayed(DataHandler)(df)
            dh_run = delayed(run_pipeline)(dh, **kwargs)
@@ -45,17 +80,22 @@ def execute(self, KEYS, **kwargs):
        self.df_reports = delayed(pd.DataFrame)(reports)
        self.df_reports = delayed(self.df_reports.assign)(runtime=runtimes, keys=KEYS)
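        # self.df_reports is a delayed DataFrame with one row per key and the
        # runtime/keys columns appended; nothing is computed until get_report().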

    def visualize(self, filename="sdt_graph.png"):
        """Visualize the task graph; the user should have graphviz installed."""
        self.df_reports.visualize(filename)

    def get_result(self):
        """Compute the results and save the summary report."""
        self.get_report()

    def get_report(self):
        """Compute the tasks on the cluster and save the summary report."""
        # Check whether the output path exists; if not, create it
        if not os.path.exists(self.output_path):
            print("output path does not exist, creating it...")
            os.makedirs(self.output_path)
        # Compute tasks on the cluster and save the results
        with performance_report(self.output_path + "/dask-report.html"):
            summary_table = self.client.compute(self.df_reports)
            df = summary_table.result()
            df.to_csv(self.output_path + "/summary_report.csv")

        self.client.shutdown()
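
# Illustrative usage sketch: CSVDataPlug below is a hypothetical stand-in; any
# object exposing get_data(key) -> pandas.DataFrame can serve as the data plug,
# and fix_shifts is assumed to be a valid run_pipeline keyword argument.
if __name__ == "__main__":
    from dask.distributed import Client

    class CSVDataPlug:
        """Hypothetical data plug: loads one CSV per (site ID, year) key."""

        def get_data(self, key):
            site_id, year = key
            return pd.read_csv(f"data/{site_id}_{year}.csv",
                               index_col=0, parse_dates=True)

    client = Client()  # local Dask cluster
    tool = SDTDask(CSVDataPlug(), client, output_path="../results/")
    tool.set_up([(34, 2011), (35, 2015), (51, 2012)], fix_shifts=True)
    tool.visualize()    # optional; requires graphviz
    tool.get_result()   # computes on the cluster and writes summary_report.csv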