diff --git a/benchmarks/scripts/cache_cleaning_coordinator.py b/benchmarks/scripts/cache_cleaning_coordinator.py new file mode 100644 index 0000000..87f3cdc --- /dev/null +++ b/benchmarks/scripts/cache_cleaning_coordinator.py @@ -0,0 +1,47 @@ +from presto_utils import execute_cluster_call +import argparse +import sys + +def clean_directory_list_cache(hostname, username, password, catalog_name): + query = "CALL " + catalog_name + ".system.invalidate_directory_list_cache()" + return execute_cluster_call(hostname, username, password, catalog_name, query) + +def clean_metastore_cache(hostname, username, password, catalog_name): + query = "CALL " + catalog_name + ".system.invalidate_metastore_cache()" + return execute_cluster_call(hostname, username, password, catalog_name, query) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Connect to PrestoDB') + parser.add_argument('--host', required=True, help='Hostname of the Presto coordinator') + parser.add_argument('--username', required=True, help='Username to connect to Presto') + parser.add_argument('--password', required=True, help='Password to connect to Presto') + + args = parser.parse_args() + + catalog_list = ["hive"] + is_list_cache_cleanup_enabled = True + is_metadata_cache_cleanup_enabled = False + + # Directory list cache clean up + if is_list_cache_cleanup_enabled: + for catalog_name in catalog_list: + print("Cleaning up directory list cache for ", catalog_name) + rows = clean_directory_list_cache(args.host, args.username, args.password, catalog_name) + print("directory_list_cache_cleanup_query Query Result: ", rows) + if rows[0][0] == True: + print("Directory list cache clean up is successful for ", catalog_name) + else: + print("Directory list cache clean up is failed for ", catalog_name) + sys.exit(1) + + # Metadata cache clean up + if is_metadata_cache_cleanup_enabled: + for catalog_name in catalog_list: + print("Cleaning up metadata cache for ", catalog_name) + rows = clean_metastore_cache(args.host, args.username, args.password, catalog_name) + print("metastore_cache_cleanup_query Query Result: ", rows) + if rows[0][0] == True: + print("Metastore cache clean up is successful for ", catalog_name) + else: + print("Metastore cache clean up is failed for ", catalog_name) + sys.exit(1) diff --git a/benchmarks/scripts/cache_cleaning_workers.py b/benchmarks/scripts/cache_cleaning_workers.py new file mode 100644 index 0000000..0cca2f2 --- /dev/null +++ b/benchmarks/scripts/cache_cleaning_workers.py @@ -0,0 +1,74 @@ +from mysql_utils import create_connection +from mysql_utils import execute_mysql_query +from system_utils import execute_ssh_command +import json +import argparse + +def get_workers_public_ips(data): + data = json.loads(data) + + worker_ips = [] + for key, value in data["output"].items(): + if key.startswith("swarm_presto_worker") and key.endswith("public_ip"): + worker_ips.append(value) + return worker_ips + +def cleanup_worker_disk_cache(worker_public_ips, directory_to_cleanup, login_user, ssh_key_path): + cleanup_command = f'sudo rm -rf {directory_to_cleanup}/*' + for worker_ip in worker_public_ips: + execute_ssh_command(worker_ip, login_user, ssh_key_path, cleanup_command) + +def cleanup_worker_os_cache(worker_public_ips, login_user, ssh_key_path): + for worker_ip in worker_public_ips: + os_cache_clean_commands = ["sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches", "sudo swapoff -a; sudo swapon -a"] + for command in os_cache_clean_commands : + execute_ssh_command(worker_ip, login_user, ssh_key_path, command) + +# Main function to connect and run queries +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Connect to PrestoDB') + parser.add_argument('--mysql', required=True, help='Mysql database details') + parser.add_argument('--clustername', required=True, help='Presto cluster name') + parser.add_argument('--sshkey', required=True, help='SSH key to connect to Presto Vms') + + args = parser.parse_args() + + with open(args.mysql, 'r') as file: + mysql_details = json.load(file) + username = mysql_details.get("username") + password = mysql_details.get("password") + server = mysql_details.get("server") + database = mysql_details.get("database") + + connection = create_connection( + host_name=server, + user_name=username, + user_password=password, + db_name=database + ) + + cluster_name = args.clustername + select_query = "SELECT detail_json FROM presto_clusters WHERE cluster_name = %s" + results = execute_mysql_query(connection, select_query, cluster_name) + + worker_public_ips = [] + if results: + worker_public_ips = get_workers_public_ips(results[0][0]) + + print("worker count = ", len(worker_public_ips)) + print("======= worker_public_ips ======") + print(worker_public_ips) + + is_worker_disk_cache_cleanup_enabled = True + is_worker_os_cache_cleanup_enabled = True + + if is_worker_disk_cache_cleanup_enabled: + native_cache_directory_worker = "/home/centos/presto/async_data_cache" + cleanup_worker_disk_cache(worker_public_ips, native_cache_directory_worker, "centos", args.sshkey) + + if is_worker_os_cache_cleanup_enabled: + cleanup_worker_os_cache(worker_public_ips, "centos", args.sshkey) + + if connection.is_connected(): + connection.close() + print("The connection is closed.") diff --git a/benchmarks/scripts/mysql_utils.py b/benchmarks/scripts/mysql_utils.py new file mode 100644 index 0000000..fd03f8a --- /dev/null +++ b/benchmarks/scripts/mysql_utils.py @@ -0,0 +1,26 @@ +import mysql.connector +from mysql.connector import Error + +def create_connection(host_name, user_name, user_password, db_name): + connection = None + try: + connection = mysql.connector.connect( + host=host_name, + user=user_name, + passwd=user_password, + database=db_name + ) + print("Connection to Benchmark Database is successful") + except Error as e: + print(f"The error '{e}' occurred") + + return connection + +def execute_mysql_query(connection, query, cluster_name): + cursor = connection.cursor() + try: + cursor.execute(query, (cluster_name,)) + result = cursor.fetchall() + return result + except Error as e: + print(f"The error '{e}' occurred") diff --git a/benchmarks/scripts/presto_utils.py b/benchmarks/scripts/presto_utils.py new file mode 100644 index 0000000..64efcd0 --- /dev/null +++ b/benchmarks/scripts/presto_utils.py @@ -0,0 +1,24 @@ +import prestodb + +# Establish Presto connection +def create_connection(hostname, username, password, catalog_name): + conn = prestodb.dbapi.connect( + host=hostname, + port=443, + user=username, + catalog=catalog_name, + schema='', + http_scheme='https', + auth=prestodb.auth.BasicAuthentication(username, password) + ) + return conn + +def execute_presto_query(hostname, username, password, cluster_name, query): + conn = create_connection(hostname, username, password, cluster_name) + try: + cur = conn.cursor() + cur.execute(query) + rows = cur.fetchall() + finally: + conn.close() + return rows \ No newline at end of file diff --git a/benchmarks/scripts/system_utils.py b/benchmarks/scripts/system_utils.py new file mode 100644 index 0000000..0f8afdb --- /dev/null +++ b/benchmarks/scripts/system_utils.py @@ -0,0 +1,33 @@ +import paramiko +import sys + +def execute_ssh_command(worker_ip, login_user, ssh_key_path, command): + ssh = None + try: + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + private_key = paramiko.Ed25519Key(filename=ssh_key_path) + + ssh.connect(hostname=worker_ip, username=login_user, pkey=private_key) + + stdin, stdout, stderr = ssh.exec_command(command) + stdout_output = stdout.read().decode() + stderr_output = stderr.read().decode() + + if stderr_output: + print(f'Error on {worker_ip}: {stderr_output}') + sys.exit(1) + else: + print(f'Successfully finished running command on {worker_ip}') + except paramiko.SSHException as ssh_err: + print(f'SSH error on {worker_ip}: {ssh_err}') + sys.exit(1) + except FileNotFoundError as key_err: + print(f'SSH key file not found: {key_err}') + sys.exit(1) + except Exception as e: + print(f'Failed to connect to {worker_ip}: {str(e)}') + sys.exit(1) + finally: + if ssh: + ssh.close() \ No newline at end of file diff --git a/benchmarks/test/my_post_query_cycle_script.py b/benchmarks/test/my_post_query_cycle_script.py new file mode 100644 index 0000000..2e6dd56 --- /dev/null +++ b/benchmarks/test/my_post_query_cycle_script.py @@ -0,0 +1,11 @@ +import sys +from utils import increment_file_value + +# Main function to handle the command-line argument +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Missing ") + sys.exit(-1) + + file_path = sys.argv[1] + increment_file_value(file_path) diff --git a/benchmarks/test/my_pre_query_cycle_script.py b/benchmarks/test/my_pre_query_cycle_script.py new file mode 100644 index 0000000..2e6dd56 --- /dev/null +++ b/benchmarks/test/my_pre_query_cycle_script.py @@ -0,0 +1,11 @@ +import sys +from utils import increment_file_value + +# Main function to handle the command-line argument +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Missing ") + sys.exit(-1) + + file_path = sys.argv[1] + increment_file_value(file_path) diff --git a/benchmarks/test/my_pre_stage_script.py b/benchmarks/test/my_pre_stage_script.py new file mode 100644 index 0000000..2e6dd56 --- /dev/null +++ b/benchmarks/test/my_pre_stage_script.py @@ -0,0 +1,11 @@ +import sys +from utils import increment_file_value + +# Main function to handle the command-line argument +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Missing ") + sys.exit(-1) + + file_path = sys.argv[1] + increment_file_value(file_path) diff --git a/benchmarks/test/stage_4.json b/benchmarks/test/stage_4.json index 9f1479f..232a466 100644 --- a/benchmarks/test/stage_4.json +++ b/benchmarks/test/stage_4.json @@ -6,14 +6,26 @@ "query_files": [ "stage_4.sql" ], + "pre_stage_scripts": [ + "echo \"run this script before this stage is started\"", + "python3 my_pre_stage_script.py count.txt" + ], "post_stage_scripts": [ "echo \"run this script after this stage is complete\"", "python3 my_post_stage_script.py count.txt" ], + "pre_query_cycle_scripts": [ + "echo \"execute this script before starting all runs of the same query in this stage\"", + "python3 my_pre_query_cycle_script.py count.txt" + ], "post_query_scripts": [ "echo \"run this script after each query in this stage is complete\"", "python3 my_post_query_script.py count.txt" ], + "post_query_cycle_scripts": [ + "echo \"execute this script after all runs of the same query in this stage have completed\"", + "python3 my_post_query_cycle_script.py count.txt" + ], "next": [ "stage_5.json" ], diff --git a/stage/stage.go b/stage/stage.go index 47cc774..2918309 100644 --- a/stage/stage.go +++ b/stage/stage.go @@ -40,10 +40,16 @@ type Stage struct { // If a stage has both Queries and QueryFiles, the queries in the Queries array will be executed first then // the QueryFiles will be read and executed. QueryFiles []string `json:"query_files,omitempty"` + // Run shell scripts before starting the execution of queries in a stage. + PreStageShellScripts []string `json:"pre_stage_scripts,omitempty"` // Run shell scripts after executing all the queries in a stage. PostStageShellScripts []string `json:"post_stage_scripts,omitempty"` // Run shell scripts after executing each query. PostQueryShellScripts []string `json:"post_query_scripts,omitempty"` + // Run shell scripts before starting query cycle runs of each query. + PreQueryCycleShellScripts []string `json:"pre_query_cycle_scripts,omitempty"` + // Run shell scripts after finishing full query cycle runs each query. + PostQueryCycleShellScripts []string `json:"post_query_cycle_scripts,omitempty"` // A map from [catalog.schema] to arrays of integers as expected row counts for all the queries we run // under different schemas. This includes the queries from both Queries and QueryFiles. Queries first and QueryFiles follows. // Can use regexp as key to match multiple [catalog.schema] pairs. @@ -227,6 +233,10 @@ func (s *Stage) run(ctx context.Context) (returnErr error) { s.setDefaults() s.prepareClient() s.propagateStates() + preStageErr := s.runShellScripts(ctx, s.PreStageShellScripts) + if preStageErr != nil { + return fmt.Errorf("pre-stage script execution failed: %w", preStageErr) + } if len(s.Queries)+len(s.QueryFiles) > 0 { if *s.RandomExecution { returnErr = s.runRandomly(ctx) @@ -400,6 +410,11 @@ func (s *Stage) runShellScripts(ctx context.Context, shellScripts []string) erro func (s *Stage) runQueries(ctx context.Context, queries []string, queryFile *string, expectedRowCountStartIndex int) (retErr error) { batchSize := len(queries) for i, queryText := range queries { + // run pre query cycle shell scripts + preQueryCycleErr := s.runShellScripts(ctx, s.PreQueryCycleShellScripts) + if preQueryCycleErr != nil { + return fmt.Errorf("pre-query script execution failed: %w", preQueryCycleErr) + } for j := 0; j < s.ColdRuns+s.WarmRuns; j++ { query := &Query{ Text: queryText, @@ -438,6 +453,11 @@ func (s *Stage) runQueries(ctx context.Context, queries []string, queryFile *str } log.Info().EmbedObject(result).Msgf("query finished") } + // run post query cycle shell scripts + postQueryCycleErr := s.runShellScripts(ctx, s.PostQueryCycleShellScripts) + if postQueryCycleErr != nil { + return fmt.Errorf("post-query script execution failed: %w", postQueryCycleErr) + } } return nil } diff --git a/stage/stage_test.go b/stage/stage_test.go index 0329337..0ed9aee 100644 --- a/stage/stage_test.go +++ b/stage/stage_test.go @@ -83,12 +83,12 @@ func testParseAndExecute(t *testing.T, abortOnError bool, totalQueryCount int, e func TestParseStageGraph(t *testing.T) { t.Run("abortOnError = true", func(t *testing.T) { testParseAndExecute(t, true, 10, 16, []string{ - "SYNTAX_ERROR: Table tpch.sf1.foo does not exist"}, 3) + "SYNTAX_ERROR: Table tpch.sf1.foo does not exist"}, 9) }) t.Run("abortOnError = false", func(t *testing.T) { testParseAndExecute(t, false, 15, 24, []string{ "SYNTAX_ERROR: Table tpch.sf1.foo does not exist", - "SYNTAX_ERROR: line 1:11: Function sum1 not registered"}, 4) + "SYNTAX_ERROR: line 1:11: Function sum1 not registered"}, 13) }) } diff --git a/stage/stage_utils.go b/stage/stage_utils.go index 37a7caa..5b3174a 100644 --- a/stage/stage_utils.go +++ b/stage/stage_utils.go @@ -88,8 +88,11 @@ func (s *Stage) MergeWith(other *Stage) *Stage { s.NextStagePaths = append(s.NextStagePaths, other.NextStagePaths...) s.BaseDir = other.BaseDir + s.PreStageShellScripts = append(s.PreStageShellScripts, other.PreStageShellScripts...) s.PostQueryShellScripts = append(s.PostQueryShellScripts, other.PostQueryShellScripts...) s.PostStageShellScripts = append(s.PostStageShellScripts, other.PostStageShellScripts...) + s.PreQueryCycleShellScripts = append(s.PreQueryCycleShellScripts, other.PreQueryCycleShellScripts...) + s.PostQueryCycleShellScripts = append(s.PostQueryCycleShellScripts, other.PostQueryCycleShellScripts...) return s }