diff --git a/benchmarks/test/my_post_query_cycle_script.py b/benchmarks/test/my_post_query_cycle_script.py new file mode 100644 index 0000000..2e6dd56 --- /dev/null +++ b/benchmarks/test/my_post_query_cycle_script.py @@ -0,0 +1,11 @@ +import sys +from utils import increment_file_value + +# Main function to handle the command-line argument +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Missing ") + sys.exit(-1) + + file_path = sys.argv[1] + increment_file_value(file_path) diff --git a/benchmarks/test/stage_4.json b/benchmarks/test/stage_4.json index 9f1479f..f252007 100644 --- a/benchmarks/test/stage_4.json +++ b/benchmarks/test/stage_4.json @@ -14,6 +14,10 @@ "echo \"run this script after each query in this stage is complete\"", "python3 my_post_query_script.py count.txt" ], + "post_query_cycle_scripts": [ + "echo \"execute this script after all runs of the same query in this stage have completed\"", + "python3 my_post_query_cycle_script.py count.txt" + ], "next": [ "stage_5.json" ], diff --git a/stage/stage.go b/stage/stage.go index 47cc774..3edfc2e 100644 --- a/stage/stage.go +++ b/stage/stage.go @@ -44,6 +44,8 @@ type Stage struct { PostStageShellScripts []string `json:"post_stage_scripts,omitempty"` // Run shell scripts after executing each query. PostQueryShellScripts []string `json:"post_query_scripts,omitempty"` + // Run shell scripts after finishing full query cycle runs each query. + PostQueryCycleShellScripts []string `json:"post_query_cycle_scripts,omitempty"` // A map from [catalog.schema] to arrays of integers as expected row counts for all the queries we run // under different schemas. This includes the queries from both Queries and QueryFiles. Queries first and QueryFiles follows. // Can use regexp as key to match multiple [catalog.schema] pairs. @@ -438,6 +440,11 @@ func (s *Stage) runQueries(ctx context.Context, queries []string, queryFile *str } log.Info().EmbedObject(result).Msgf("query finished") } + // run post query cycle shell scripts + postQueryCycleErr := s.runShellScripts(ctx, s.PostQueryCycleShellScripts) + if retErr == nil { + retErr = postQueryCycleErr + } } return nil } diff --git a/stage/stage_test.go b/stage/stage_test.go index 0329337..d99fbcc 100644 --- a/stage/stage_test.go +++ b/stage/stage_test.go @@ -83,12 +83,12 @@ func testParseAndExecute(t *testing.T, abortOnError bool, totalQueryCount int, e func TestParseStageGraph(t *testing.T) { t.Run("abortOnError = true", func(t *testing.T) { testParseAndExecute(t, true, 10, 16, []string{ - "SYNTAX_ERROR: Table tpch.sf1.foo does not exist"}, 3) + "SYNTAX_ERROR: Table tpch.sf1.foo does not exist"}, 5) }) t.Run("abortOnError = false", func(t *testing.T) { testParseAndExecute(t, false, 15, 24, []string{ "SYNTAX_ERROR: Table tpch.sf1.foo does not exist", - "SYNTAX_ERROR: line 1:11: Function sum1 not registered"}, 4) + "SYNTAX_ERROR: line 1:11: Function sum1 not registered"}, 8) }) } diff --git a/stage/stage_utils.go b/stage/stage_utils.go index 37a7caa..178dcea 100644 --- a/stage/stage_utils.go +++ b/stage/stage_utils.go @@ -90,6 +90,7 @@ func (s *Stage) MergeWith(other *Stage) *Stage { s.PostQueryShellScripts = append(s.PostQueryShellScripts, other.PostQueryShellScripts...) s.PostStageShellScripts = append(s.PostStageShellScripts, other.PostStageShellScripts...) + s.PostQueryCycleShellScripts = append(s.PostQueryCycleShellScripts, other.PostQueryCycleShellScripts...) return s }