Skip to content

Commit

Permalink
Finish pbench forward
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanyzhang committed Jun 20, 2024
1 parent 4a2ae07 commit e189d61
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 46 deletions.
6 changes: 6 additions & 0 deletions cmd/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ var forwardCmd = &cobra.Command{
return fmt.Errorf("the forward target server host at position %d is identical to the source server host %s", i, sourceUrl.Host)
}
}
for _, isTrino := range forward.PrestoFlagsArray.IsTrino {
if isTrino {
return fmt.Errorf("forward command does not support Trino yet")
}
}
return nil
},
Short: "Watch incoming query workloads from the first Presto cluster (cluster 0) and forward them to the rest clusters.",
Expand All @@ -39,6 +44,7 @@ var forwardCmd = &cobra.Command{
func init() {
RootCmd.AddCommand(forwardCmd)
forward.PrestoFlagsArray.Install(forwardCmd)
_ = forwardCmd.Flags().MarkHidden("trino")
wd, _ := os.Getwd()
forwardCmd.Flags().StringVarP(&forward.OutputPath, "output-path", "o", wd, "Output directory path")
forwardCmd.Flags().StringVarP(&forward.RunName, "name", "n", fmt.Sprintf("forward_%s", time.Now().Format(utils.DirectoryNameTimeFormat)), `Assign a name to this run. (default: "forward_<current time>")`)
Expand Down
149 changes: 121 additions & 28 deletions cmd/forward/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ package forward

import (
"context"
"fmt"
"github.com/spf13/cobra"
"net/http"
"os"
"os/signal"
"path/filepath"
"pbench/log"
"pbench/presto"
"pbench/utils"
"sync"
"sync/atomic"
"syscall"
"time"
)

Expand All @@ -17,49 +22,137 @@ var (
RunName string
PollInterval time.Duration

runningTasks sync.WaitGroup
runningTasks sync.WaitGroup
failedToForward atomic.Uint32
forwarded atomic.Uint32
)

type QueryHistory struct {
QueryId string `presto:"query_id"`
Query string `presto:"query"`
Created *time.Time `presto:"created"`
}

func Run(_ *cobra.Command, _ []string) {
//OutputPath = filepath.Join(OutputPath, RunName)
//utils.PrepareOutputDirectory(OutputPath)
//
//// also start to write logs to the output directory from this point on.
//logPath := filepath.Join(OutputPath, "forward.log")
//flushLog := utils.InitLogFile(logPath)
//defer flushLog()
OutputPath = filepath.Join(OutputPath, RunName)
utils.PrepareOutputDirectory(OutputPath)

prestoClusters := PrestoFlagsArray.Assemble()
// also start to write logs to the output directory from this point on.
logPath := filepath.Join(OutputPath, "forward.log")
flushLog := utils.InitLogFile(logPath)
defer flushLog()

ctx, cancel := context.WithCancel(context.Background())
timeToExit := make(chan os.Signal, 1)
signal.Notify(timeToExit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
// Handle SIGINT, SIGTERM, and SIGQUIT. When ctx is canceled, in-progress MySQL transactions and InfluxDB operations will roll back.
go func() {
sig := <-timeToExit
if sig != nil {
log.Info().Msg("abort forwarding")
cancel()
}
}()

prestoClusters := PrestoFlagsArray.Pivot()
// The design here is to forward the traffic from cluster 0 to the rest.
sourceClusterSize := 0
clients := make([]*presto.Client, 0, len(prestoClusters))
for i, cluster := range prestoClusters {
clients = append(clients, cluster.NewPrestoClient())
if stats, _, err := clients[i].GetClusterInfo(context.Background()); err != nil {
log.Fatal().Err(err).Msgf("cannot connect to cluster at position %d", i)
// Check if we can connect to the cluster.
if stats, _, err := clients[i].GetClusterInfo(ctx); err != nil {
log.Fatal().Err(err).Msgf("cannot connect to cluster at position %d: %s", i, cluster.ServerUrl)
} else if i == 0 {
sourceClusterSize = stats.ActiveWorkers
} else if stats.ActiveWorkers != sourceClusterSize {
log.Warn().Msgf("source cluster size does not match target cluster %d size (%d != %d)", i, stats.ActiveWorkers, sourceClusterSize)
log.Warn().Msgf("the source cluster and target cluster %d do not match in size (%d != %d)", i, sourceClusterSize, stats.ActiveWorkers)
}
}

sourceClient := clients[0]
trueValue := true
states, _, err := sourceClient.GetQueryState(context.Background(), &presto.GetQueryStatsOptions{
IncludeAllQueries: &trueValue,
IncludeAllQueryProgressStats: nil,
ExcludeResourceGroupPathInfo: nil,
QueryTextSizeLimit: nil,
})
if err != nil {
log.Fatal().Err(err).Msgf("cannot get query states")
// lastQueryStateCheckCutoffTime is the query create time of the most recent query in the previous batch.
// We only look at queries created later than this timestamp in the following batch.
lastQueryStateCheckCutoffTime := time.Time{}
// Keep running until the source cluster becomes unavailable or the user interrupts or quits using Ctrl + C or Ctrl + D.
for ctx.Err() == nil {
states, _, err := sourceClient.GetQueryState(ctx, &presto.GetQueryStatsOptions{IncludeAllQueries: &trueValue})
if err != nil {
log.Error().Err(err).Msgf("failed to get query states")
break
}
newCutoffTime := time.Time{}
for _, state := range states {
if !state.CreateTime.After(lastQueryStateCheckCutoffTime) {
// We looked at this query in the previous batch.
continue
}
if newCutoffTime.Before(state.CreateTime) {
newCutoffTime = state.CreateTime
}
runningTasks.Add(1)
go forwardQuery(ctx, &state, clients)
}
if newCutoffTime.After(lastQueryStateCheckCutoffTime) {
lastQueryStateCheckCutoffTime = newCutoffTime
}
timer := time.NewTimer(PollInterval)
select {
case <-ctx.Done():
case <-timer.C:
}
}
runningTasks.Wait()
// This causes the signal handler to exit.
close(timeToExit)
log.Info().Uint32("forwarded", forwarded.Load()).Uint32("failed_to_forward", failedToForward.Load()).
Msgf("finished forwarding queries")
}

func forwardQuery(ctx context.Context, queryState *presto.QueryStateInfo, clients []*presto.Client) {
defer runningTasks.Done()
queryInfo, _, queryInfoErr := clients[0].GetQueryInfo(ctx, queryState.QueryId, false, nil)
if queryInfoErr != nil {
log.Error().Str("query_id", queryState.QueryId).Err(queryInfoErr).Msg("failed to get query info for forwarding")
failedToForward.Add(1)
return
}
SessionPropertyHeader := clients[0].GenerateSessionParamsHeaderValue(queryInfo.Session.CollectSessionProperties())
successful, failed := atomic.Uint32{}, atomic.Uint32{}
forwardedQueries := sync.WaitGroup{}
for i := 1; i < len(clients); i++ {
forwardedQueries.Add(1)
go func(client *presto.Client) {
defer forwardedQueries.Done()
clientResult, _, queryErr := client.Query(ctx, queryInfo.Query, func(req *http.Request) {
if queryInfo.Session.Catalog != nil {
req.Header.Set(presto.CatalogHeader, *queryInfo.Session.Catalog)
}
if queryInfo.Session.Schema != nil {
req.Header.Set(presto.SchemaHeader, *queryInfo.Session.Schema)
}
req.Header.Set(presto.SessionHeader, SessionPropertyHeader)
req.Header.Set(presto.SourceHeader, queryInfo.QueryId)
})
if queryErr != nil {
log.Error().Str("source_query_id", queryInfo.QueryId).
Str("target_host", client.GetHost()).Err(queryErr).Msg("failed to execute query")
failed.Add(1)
return
}
rowCount := 0
drainErr := clientResult.Drain(ctx, func(qr *presto.QueryResults) error {
rowCount += len(qr.Data)
return nil
})
if drainErr != nil {
log.Error().Str("source_query_id", queryInfo.QueryId).
Str("target_host", client.GetHost()).Err(drainErr).Msg("failed to fetch query result")
failed.Add(1)
return
}
successful.Add(1)
log.Info().Str("source_query_id", queryInfo.QueryId).
Str("target_host", client.GetHost()).Int("row_count", rowCount).Msg("query executed successfully")
}(clients[i])
}
fmt.Printf("%#v", states)
forwardedQueries.Wait()
log.Info().Str("source_query_id", queryInfo.QueryId).Uint32("successful", successful.Load()).
Uint32("failed", failed.Load()).Msg("query forwarding finished")
forwarded.Add(1)
}
4 changes: 4 additions & 0 deletions cmd/replay.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ We also expect the queries in this CSV file are sorted by "create_time" in ascen
}
utils.ExpandHomeDirectory(&replay.OutputPath)
utils.ExpandHomeDirectory(&args[0])
if replay.PrestoFlags.IsTrino {
return fmt.Errorf("replay command does not support Trino yet")
}
return nil
},
Run: replay.Run,
Expand All @@ -33,6 +36,7 @@ func init() {
RootCmd.AddCommand(replayCmd)
wd, _ := os.Getwd()
replay.PrestoFlags.Install(replayCmd)
_ = replayCmd.Flags().MarkHidden("trino")
replayCmd.Flags().StringVarP(&replay.OutputPath, "output-path", "o", wd, "Output directory path")
replayCmd.Flags().StringVarP(&replay.RunName, "name", "n", fmt.Sprintf("replay_%s", time.Now().Format(utils.DirectoryNameTimeFormat)), `Assign a name to this run. (default: "replay_<current time>")`)
}
4 changes: 4 additions & 0 deletions presto/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ func GenerateHttpQueryParameter(v any) string {
return queryBuilder.String()
}

func (c *Client) GetHost() string {
return c.serverUrl.Host
}

func (c *Client) setHeader(key, value string) {
if c.isTrino {
key = strings.Replace(key, "X-Presto", "X-Trino", 1)
Expand Down
8 changes: 7 additions & 1 deletion presto/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"github.com/stretchr/testify/assert"
"pbench/presto"
"pbench/presto/query_json"
"strings"
"syscall"
"testing"
Expand Down Expand Up @@ -42,9 +43,14 @@ func TestQuery(t *testing.T) {
assert.Equal(t, 150000, rowCount)

buf := &strings.Builder{}
_, err = client.GetQueryInfo(context.Background(), qr.Id, false, buf)
var queryInfo *query_json.QueryInfo
queryInfo, _, err = client.GetQueryInfo(context.Background(), qr.Id, false, buf)
assert.Nil(t, err)
assert.Nil(t, queryInfo)
assert.Greater(t, buf.Len(), 0)
queryInfo, _, err = client.GetQueryInfo(context.Background(), qr.Id, true, nil)
assert.Nil(t, err)
assert.Equal(t, qr.Id, queryInfo.QueryId)
}

func TestGenerateQueryParameter(t *testing.T) {
Expand Down
22 changes: 17 additions & 5 deletions presto/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"io"
"net/http"
"pbench/presto/query_json"
)

func (c *Client) requestQueryResults(ctx context.Context, req *http.Request) (*QueryResults, *http.Response, error) {
Expand Down Expand Up @@ -49,19 +50,30 @@ func (c *Client) CancelQuery(ctx context.Context, nextUri string, opts ...Reques
return c.requestQueryResults(ctx, req)
}

func (c *Client) GetQueryInfo(ctx context.Context, queryId string, pretty bool, writer io.Writer, opts ...RequestOption) (*http.Response, error) {
// GetQueryInfo retrieves the query JSON for the given query ID.
// If writer is nil, we return deserialized QueryInfo. Otherwise, we just return the raw buffer.
func (c *Client) GetQueryInfo(ctx context.Context, queryId string, pretty bool, writer io.Writer, opts ...RequestOption) (*query_json.QueryInfo, *http.Response, error) {
urlStr := "v1/query/" + queryId
if pretty {
urlStr += "?pretty"
}
req, err := c.NewRequest("GET",
urlStr, nil, opts...)
if err != nil {
return nil, err
return nil, nil, err
}
var (
resp *http.Response
queryInfo *query_json.QueryInfo
)
if writer != nil {
resp, err = c.Do(ctx, req, writer)
} else {
queryInfo = new(query_json.QueryInfo)
resp, err = c.Do(ctx, req, queryInfo)
}
resp, err := c.Do(ctx, req, writer)
if err != nil {
return resp, err
return nil, resp, err
}
return resp, nil
return queryInfo, resp, nil
}
16 changes: 16 additions & 0 deletions presto/query_json/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,19 @@ func (s *Session) PrepareForInsert() {
s.SessionPropertiesJson = string(jsonBytes[:len(jsonBytes)-1])
}
}

func (s *Session) CollectSessionProperties() map[string]any {
sessionParams := make(map[string]any)
if s == nil {
return sessionParams
}
for k, v := range s.SystemProperties {
sessionParams[k] = v
}
for catalog, catalogProps := range s.CatalogProperties {
for k, v := range catalogProps {
sessionParams[catalog+"."+k] = v
}
}
return sessionParams
}
20 changes: 10 additions & 10 deletions presto/query_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ import (
// https://github.com/prestodb/presto/blob/master/presto-main/src/main/java/com/facebook/presto/server/QueryStateInfo.java
// Unused fields are commented out for now.
type QueryStateInfo struct {
QueryId string `json:"queryId"`
QueryState string `json:"queryState"`
QueryId string `json:"queryId"`
//QueryState string `json:"queryState"`
//ResourceGroupId []string `json:"resourceGroupId"`
Query string `json:"query"`
QueryTruncated bool `json:"queryTruncated"`
CreateTime time.Time `json:"createTime"`
User string `json:"user"`
Authenticated bool `json:"authenticated"`
Source string `json:"source"`
Catalog string `json:"catalog"`
//Query string `json:"query"`
//QueryTruncated bool `json:"queryTruncated"`
CreateTime time.Time `json:"createTime"`
//User string `json:"user"`
//Authenticated bool `json:"authenticated"`
//Source string `json:"source,omitempty"`
//Catalog string `json:"catalog"`
//Progress struct {
// ElapsedTimeMillis int `json:"elapsedTimeMillis"`
// QueuedTimeMillis int `json:"queuedTimeMillis"`
Expand Down Expand Up @@ -61,7 +61,7 @@ func (c *Client) GetQueryState(ctx context.Context, reqOpt *GetQueryStatsOptions
}

infoArray := make([]QueryStateInfo, 0, 16)
resp, err := c.Do(ctx, req, infoArray)
resp, err := c.Do(ctx, req, &infoArray)
if err != nil {
return nil, resp, err
}
Expand Down
2 changes: 1 addition & 1 deletion stage/stage_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func (s *Stage) saveQueryJsonFile(result *QueryResult) {
checkErr(err)
if err == nil {
// We need to save the query json file even if the stage context is canceled.
_, err = s.Client.GetQueryInfo(utils.GetCtxWithTimeout(time.Second*5), result.QueryId, false, queryJsonFile)
_, _, err = s.Client.GetQueryInfo(utils.GetCtxWithTimeout(time.Second*5), result.QueryId, false, queryJsonFile)
checkErr(err)
checkErr(queryJsonFile.Close())
}
Expand Down
3 changes: 2 additions & 1 deletion utils/presto_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ func (a *PrestoFlagsArray) Install(cmd *cobra.Command) {
cmd.Flags().StringArrayVarP(&a.Password, "password", "p", []string{""}, "Presto user password (optional)")
}

func (a *PrestoFlagsArray) Assemble() []PrestoFlags {
// Pivot generates PrestoFlags array that is suitable for creating Presto clients conveniently.
func (a *PrestoFlagsArray) Pivot() []PrestoFlags {
ret := make([]PrestoFlags, 0, len(a.ServerUrl))
for _, url := range a.ServerUrl {
ret = append(ret, PrestoFlags{
Expand Down

0 comments on commit e189d61

Please sign in to comment.