Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add post query cycle script execution hook #32

Merged
merged 4 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions benchmarks/scripts/cache_cleaning_coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from presto_utils import execute_cluster_call
import argparse
import sys

def clean_directory_list_cache(hostname, username, password, catalog_name):
    """Invalidate the directory-list cache of *catalog_name* via a system CALL.

    Returns the result rows produced by ``execute_cluster_call``.
    """
    invalidate_query = f"CALL {catalog_name}.system.invalidate_directory_list_cache()"
    return execute_cluster_call(hostname, username, password, catalog_name, invalidate_query)

def clean_metastore_cache(hostname, username, password, catalog_name):
    """Invalidate the metastore cache of *catalog_name* via a system CALL.

    Returns the result rows produced by ``execute_cluster_call``.
    """
    invalidate_query = f"CALL {catalog_name}.system.invalidate_metastore_cache()"
    return execute_cluster_call(hostname, username, password, catalog_name, invalidate_query)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Connect to PrestoDB')
parser.add_argument('--host', required=True, help='Hostname of the Presto coordinator')
parser.add_argument('--username', required=True, help='Username to connect to Presto')
parser.add_argument('--password', required=True, help='Password to connect to Presto')

args = parser.parse_args()

catalog_list = ["hive"]
is_list_cache_cleanup_enabled = True
is_metadata_cache_cleanup_enabled = False

# Directory list cache clean up
if is_list_cache_cleanup_enabled:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@agrawalreetika Is it possible to attach the output of these? What would the output look like if output rows are multiple lines?

for catalog_name in catalog_list:
print("Cleaning up directory list cache for ", catalog_name)
rows = clean_directory_list_cache(args.host, args.username, args.password, catalog_name)
print("directory_list_cache_cleanup_query Query Result: ", rows)
if rows[0][0] == True:
print("Directory list cache clean up is successful for ", catalog_name)
else:
print("Directory list cache clean up is failed for ", catalog_name)
sys.exit(1)

# Metadata cache clean up
if is_metadata_cache_cleanup_enabled:
for catalog_name in catalog_list:
print("Cleaning up metadata cache for ", catalog_name)
rows = clean_metastore_cache(args.host, args.username, args.password, catalog_name)
print("metastore_cache_cleanup_query Query Result: ", rows)
if rows[0][0] == True:
print("Metastore cache clean up is successful for ", catalog_name)
else:
print("Metastore cache clean up is failed for ", catalog_name)
sys.exit(1)
74 changes: 74 additions & 0 deletions benchmarks/scripts/cache_cleaning_workers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from mysql_utils import create_connection
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this being called?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from mysql_utils import execute_mysql_query
from system_utils import execute_ssh_command
import json
import argparse

def get_workers_public_ips(data):
    """Extract Presto worker public IPs from a cluster-detail JSON string.

    *data* is a JSON document whose "output" object maps names to values;
    entries whose key starts with "swarm_presto_worker" and ends with
    "public_ip" are the workers. Returns the IPs in document order.
    A document without an "output" object yields [] instead of raising
    KeyError (robustness fix over the original).
    """
    parsed = json.loads(data)
    outputs = parsed.get("output", {})
    return [
        value
        for key, value in outputs.items()
        if key.startswith("swarm_presto_worker") and key.endswith("public_ip")
    ]

def cleanup_worker_disk_cache(worker_public_ips, directory_to_cleanup, login_user, ssh_key_path):
    """Delete everything under *directory_to_cleanup* on each worker via SSH."""
    wipe_command = 'sudo rm -rf ' + directory_to_cleanup + '/*'
    for ip in worker_public_ips:
        execute_ssh_command(ip, login_user, ssh_key_path, wipe_command)

def cleanup_worker_os_cache(worker_public_ips, login_user, ssh_key_path):
    """Drop the OS page cache and reset swap on each worker via SSH."""
    # Command list is loop-invariant, so build it once.
    os_cache_clean_commands = (
        "sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches",
        "sudo swapoff -a; sudo swapon -a",
    )
    for ip in worker_public_ips:
        for cleanup_command in os_cache_clean_commands:
            execute_ssh_command(ip, login_user, ssh_key_path, cleanup_command)

# Main entry point: look up the cluster's worker public IPs in the benchmark
# MySQL database, then clean the native disk cache and the OS cache on every
# worker over SSH.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Connect to PrestoDB')
    parser.add_argument('--mysql', required=True, help='Mysql database details')
    parser.add_argument('--clustername', required=True, help='Presto cluster name')
    parser.add_argument('--sshkey', required=True, help='SSH key to connect to Presto Vms')

    args = parser.parse_args()

    with open(args.mysql, 'r') as file:
        mysql_details = json.load(file)
    username = mysql_details.get("username")
    password = mysql_details.get("password")
    server = mysql_details.get("server")
    database = mysql_details.get("database")

    connection = create_connection(
        host_name=server,
        user_name=username,
        user_password=password,
        db_name=database
    )
    # create_connection returns None on failure; bail out early instead of
    # crashing on connection.cursor() / connection.is_connected() below.
    if connection is None:
        print("Could not connect to the benchmark database.")
        raise SystemExit(1)

    cluster_name = args.clustername
    select_query = "SELECT detail_json FROM presto_clusters WHERE cluster_name = %s"
    results = execute_mysql_query(connection, select_query, cluster_name)

    worker_public_ips = []
    if results:
        worker_public_ips = get_workers_public_ips(results[0][0])

    print("worker count = ", len(worker_public_ips))
    print("======= worker_public_ips ======")
    print(worker_public_ips)

    is_worker_disk_cache_cleanup_enabled = True
    is_worker_os_cache_cleanup_enabled = True

    if is_worker_disk_cache_cleanup_enabled:
        # NOTE(review): cache path and "centos" login user are CentOS-image
        # specific — confirm before running against other worker images.
        native_cache_directory_worker = "/home/centos/presto/async_data_cache"
        cleanup_worker_disk_cache(worker_public_ips, native_cache_directory_worker, "centos", args.sshkey)

    if is_worker_os_cache_cleanup_enabled:
        cleanup_worker_os_cache(worker_public_ips, "centos", args.sshkey)

    if connection.is_connected():
        connection.close()
        print("The connection is closed.")
26 changes: 26 additions & 0 deletions benchmarks/scripts/mysql_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import mysql.connector
from mysql.connector import Error

def create_connection(host_name, user_name, user_password, db_name):
    """Open a connection to the benchmark MySQL database.

    Returns the live connection, or None when connecting fails (the error
    is printed rather than raised, so callers must check for None).
    """
    try:
        benchmark_db = mysql.connector.connect(
            host=host_name,
            user=user_name,
            passwd=user_password,
            database=db_name,
        )
    except Error as e:
        print(f"The error '{e}' occurred")
        return None
    print("Connection to Benchmark Database is successful")
    return benchmark_db

def execute_mysql_query(connection, query, cluster_name):
    """Run a parameterized query with *cluster_name* as the single bind value.

    Returns all fetched rows on success, or None when the query fails (the
    error is printed rather than raised; callers test the result for
    truthiness). The cursor is now always closed — the original leaked it
    on both the success and the error path.
    """
    cursor = connection.cursor()
    try:
        cursor.execute(query, (cluster_name,))
        return cursor.fetchall()
    except Error as e:
        print(f"The error '{e}' occurred")
        return None
    finally:
        cursor.close()
24 changes: 24 additions & 0 deletions benchmarks/scripts/presto_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import prestodb

# Establish Presto connection
# Establish Presto connection
def create_connection(hostname, username, password, catalog_name, port=443):
    """Create a Presto DB-API connection over HTTPS with basic auth.

    *port* defaults to 443, matching the previously hard-coded value, so
    existing callers are unaffected; pass another port for non-standard
    coordinator endpoints.
    """
    conn = prestodb.dbapi.connect(
        host=hostname,
        port=port,
        user=username,
        catalog=catalog_name,
        schema='',
        http_scheme='https',
        auth=prestodb.auth.BasicAuthentication(username, password)
    )
    return conn

def execute_presto_query(hostname, username, password, cluster_name, query):
    """Run a single query against Presto and return all result rows.

    The connection is always closed, even when the query raises.
    NOTE(review): cluster_name is passed to create_connection as the
    catalog name — confirm the naming with callers.
    """
    connection = create_connection(hostname, username, password, cluster_name)
    try:
        cursor = connection.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    finally:
        connection.close()
33 changes: 33 additions & 0 deletions benchmarks/scripts/system_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import paramiko
import sys

def execute_ssh_command(worker_ip, login_user, ssh_key_path, command):
    """Run *command* on *worker_ip* over SSH using an Ed25519 private key.

    Prints the outcome. Any failure — connection error, missing key file,
    or the remote command writing anything to stderr — terminates the whole
    process with exit status 1.
    NOTE(review): treating any stderr output as failure will abort on
    commands that merely emit warnings — confirm this is intended.
    """
    client = None
    try:
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        private_key = paramiko.Ed25519Key(filename=ssh_key_path)

        client.connect(hostname=worker_ip, username=login_user, pkey=private_key)

        _stdin, stdout, stderr = client.exec_command(command)
        out_text = stdout.read().decode()
        err_text = stderr.read().decode()

        if err_text:
            print(f'Error on {worker_ip}: {err_text}')
            sys.exit(1)
        else:
            print(f'Successfully finished running command on {worker_ip}')
    except paramiko.SSHException as ssh_err:
        print(f'SSH error on {worker_ip}: {ssh_err}')
        sys.exit(1)
    except FileNotFoundError as key_err:
        print(f'SSH key file not found: {key_err}')
        sys.exit(1)
    except Exception as e:
        print(f'Failed to connect to {worker_ip}: {str(e)}')
        sys.exit(1)
    finally:
        # Always release the SSH session, even on the sys.exit paths
        # (SystemExit propagates through finally).
        if client:
            client.close()
11 changes: 11 additions & 0 deletions benchmarks/test/my_post_query_cycle_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import sys

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is for the two .py file names. Why did you add my_ in front? I think it's better to remove it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this folder is for testing and demoing.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does "my" stand for?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No particular meaning. It is for pbench testing, not for production.

from utils import increment_file_value

# Entry point: bump the integer counter stored in the file given as the
# single command-line argument.
if __name__ == "__main__":
    # Exactly one argument — the counter file path — is required.
    if len(sys.argv) != 2:
        print("Missing <file_path>")
        sys.exit(-1)

    increment_file_value(sys.argv[1])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we touched this file in this script, then I guess the unit test assert result should be updated?

11 changes: 11 additions & 0 deletions benchmarks/test/my_pre_query_cycle_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import sys
from utils import increment_file_value

# Entry point: bump the integer counter stored in the file given as the
# single command-line argument.
if __name__ == "__main__":
    # Exactly one argument — the counter file path — is required.
    if len(sys.argv) != 2:
        print("Missing <file_path>")
        sys.exit(-1)

    increment_file_value(sys.argv[1])
11 changes: 11 additions & 0 deletions benchmarks/test/my_pre_stage_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import sys
from utils import increment_file_value

# Entry point: bump the integer counter stored in the file given as the
# single command-line argument.
if __name__ == "__main__":
    # Exactly one argument — the counter file path — is required.
    if len(sys.argv) != 2:
        print("Missing <file_path>")
        sys.exit(-1)

    increment_file_value(sys.argv[1])
12 changes: 12 additions & 0 deletions benchmarks/test/stage_4.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,26 @@
"query_files": [
"stage_4.sql"
],
"pre_stage_scripts": [
"echo \"run this script before this stage is started\"",
"python3 my_pre_stage_script.py count.txt"
],
"post_stage_scripts": [
"echo \"run this script after this stage is complete\"",
"python3 my_post_stage_script.py count.txt"
],
"pre_query_cycle_scripts": [
"echo \"execute this script before starting all runs of the same query in this stage\"",
"python3 my_pre_query_cycle_script.py count.txt"
],
"post_query_scripts": [
"echo \"run this script after each query in this stage is complete\"",
"python3 my_post_query_script.py count.txt"
],
"post_query_cycle_scripts": [
"echo \"execute this script after all runs of the same query in this stage have completed\"",
"python3 my_post_query_cycle_script.py count.txt"
],
"next": [
"stage_5.json"
],
Expand Down
20 changes: 20 additions & 0 deletions stage/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,16 @@ type Stage struct {
// If a stage has both Queries and QueryFiles, the queries in the Queries array will be executed first then
// the QueryFiles will be read and executed.
QueryFiles []string `json:"query_files,omitempty"`
// Run shell scripts before starting the execution of queries in a stage.
PreStageShellScripts []string `json:"pre_stage_scripts,omitempty"`
// Run shell scripts after executing all the queries in a stage.
PostStageShellScripts []string `json:"post_stage_scripts,omitempty"`
// Run shell scripts after executing each query.
PostQueryShellScripts []string `json:"post_query_scripts,omitempty"`
// Run shell scripts before starting query cycle runs of each query.
PreQueryCycleShellScripts []string `json:"pre_query_cycle_scripts,omitempty"`
// Run shell scripts after finishing the full query cycle runs of each query.
PostQueryCycleShellScripts []string `json:"post_query_cycle_scripts,omitempty"`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@steveburnett for doc

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A full cycle here means when we set cold_runs and warm_runs, each query in the benchmark will be run cold_runs + warm_runs times in total.

post_query_scripts will be called after each query execution, post_query_cycle_scripts will be called after all the cold_runs and warm_runs are done for a unique query.

// A map from [catalog.schema] to arrays of integers as expected row counts for all the queries we run
// under different schemas. This includes the queries from both Queries and QueryFiles. Queries first and QueryFiles follows.
// Can use regexp as key to match multiple [catalog.schema] pairs.
Expand Down Expand Up @@ -227,6 +233,10 @@ func (s *Stage) run(ctx context.Context) (returnErr error) {
s.setDefaults()
s.prepareClient()
s.propagateStates()
preStageErr := s.runShellScripts(ctx, s.PreStageShellScripts)
if preStageErr != nil {
return fmt.Errorf("pre-stage script execution failed: %w", preStageErr)
}
if len(s.Queries)+len(s.QueryFiles) > 0 {
if *s.RandomExecution {
returnErr = s.runRandomly(ctx)
Expand Down Expand Up @@ -400,6 +410,11 @@ func (s *Stage) runShellScripts(ctx context.Context, shellScripts []string) erro
func (s *Stage) runQueries(ctx context.Context, queries []string, queryFile *string, expectedRowCountStartIndex int) (retErr error) {
batchSize := len(queries)
for i, queryText := range queries {
// run pre query cycle shell scripts
preQueryCycleErr := s.runShellScripts(ctx, s.PreQueryCycleShellScripts)
if preQueryCycleErr != nil {
return fmt.Errorf("pre-query script execution failed: %w", preQueryCycleErr)
}
for j := 0; j < s.ColdRuns+s.WarmRuns; j++ {
query := &Query{
Text: queryText,
Expand Down Expand Up @@ -438,6 +453,11 @@ func (s *Stage) runQueries(ctx context.Context, queries []string, queryFile *str
}
log.Info().EmbedObject(result).Msgf("query finished")
}
// run post query cycle shell scripts
postQueryCycleErr := s.runShellScripts(ctx, s.PostQueryCycleShellScripts)
if postQueryCycleErr != nil {
return fmt.Errorf("post-query script execution failed: %w", postQueryCycleErr)
}
}
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions stage/stage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ func testParseAndExecute(t *testing.T, abortOnError bool, totalQueryCount int, e
// TestParseStageGraph runs the demo stage graph end to end under both
// abort-on-error policies and checks the aggregate expectations reported
// by testParseAndExecute.
func TestParseStageGraph(t *testing.T) {
	t.Run("abortOnError = true", func(t *testing.T) {
		// Final argument (9) is the expected value of the count file —
		// presumably bumped by the pre/post stage and query-cycle hook
		// scripts added in this change; TODO confirm against
		// testParseAndExecute's implementation.
		testParseAndExecute(t, true, 10, 16, []string{
			"SYNTAX_ERROR: Table tpch.sf1.foo does not exist"}, 9)
	})
	t.Run("abortOnError = false", func(t *testing.T) {
		// With abortOnError disabled, more queries (and hook-script
		// invocations) run, so the expected count-file value is higher (13).
		testParseAndExecute(t, false, 15, 24, []string{
			"SYNTAX_ERROR: Table tpch.sf1.foo does not exist",
			"SYNTAX_ERROR: line 1:11: Function sum1 not registered"}, 13)
	})
}

Expand Down
3 changes: 3 additions & 0 deletions stage/stage_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ func (s *Stage) MergeWith(other *Stage) *Stage {
s.NextStagePaths = append(s.NextStagePaths, other.NextStagePaths...)
s.BaseDir = other.BaseDir

s.PreStageShellScripts = append(s.PreStageShellScripts, other.PreStageShellScripts...)
s.PostQueryShellScripts = append(s.PostQueryShellScripts, other.PostQueryShellScripts...)
s.PostStageShellScripts = append(s.PostStageShellScripts, other.PostStageShellScripts...)
s.PreQueryCycleShellScripts = append(s.PreQueryCycleShellScripts, other.PreQueryCycleShellScripts...)
s.PostQueryCycleShellScripts = append(s.PostQueryCycleShellScripts, other.PostQueryCycleShellScripts...)

return s
}
Expand Down