diff --git a/handler.go b/handler.go index 2f94135..691ddef 100644 --- a/handler.go +++ b/handler.go @@ -129,7 +129,7 @@ func (h *handler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes. ) if data.IsOLAP() { - return h.streamExecute(data, query, emptyBindVars, callback) + return h.streamExecute(c, data, query, emptyBindVars, callback) } resp, err := h.client.Execute(context.Background(), connect.NewRequest(&psdbpb.ExecuteRequest{ @@ -181,7 +181,7 @@ func (h *handler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, call data := h.clientData(c) if data.IsOLAP() { - return h.streamExecute(data, prepare.PrepareStmt, castBindVars(prepare.BindVars), callback) + return h.streamExecute(c, data, prepare.PrepareStmt, castBindVars(prepare.BindVars), callback) } resp, err := h.client.Execute(context.Background(), connect.NewRequest(&psdbpb.ExecuteRequest{ @@ -226,7 +226,7 @@ func (h *handler) WarningCount(c *mysql.Conn) uint16 { return uint16(len(session.GetVitessSession().GetWarnings())) } -func (h *handler) streamExecute(data *clientData, query string, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { +func (h *handler) streamExecute(c *mysql.Conn, data *clientData, query string, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error { stream, err := h.client.StreamExecute(context.Background(), connect.NewRequest(&psdbpb.ExecuteRequest{ Session: data.Session, Query: query, @@ -241,18 +241,39 @@ func (h *handler) streamExecute(data *clientData, query string, bindVars map[str for stream.Receive() { resp = stream.Msg() + // NOTE: Some results do not have any Result. This is most likely + // the case when a Session is returned. While Vitess currently (as of v18) + // is implemented such that the last streaming response + // contains a Session, but not Result, I do not want to assume + // this is always the case, so this is implemented to handle + // both existing or none existing, or either existing to cover + // our bases. + + // Some results may contain a Session, if so + // we need to bind it to the mysql.Conn like normal + if resp.Session != nil { + bindSession(c, data, resp.Session) + } + + // If we have ane error, we just return the error if resp.Error != nil { return sqlerror.NewSQLErrorFromError(vterrors.FromVTRPC( castRPCError(resp.Error)), ) } - if fields == nil { - fields = resp.GetResult().GetFields() - } - if err := callback(sqltypes.CustomProto3ToResult( - castFields(fields), castQueryResult(resp.GetResult())), - ); err != nil { - return err + + // Lastly if there are results, we return them to the mysql client. + // messages without results get ignored at this point since they + // likely only contained session data. + if resp.Result != nil { + if fields == nil { + fields = resp.Result.GetFields() + } + if err := callback(sqltypes.CustomProto3ToResult( + castFields(fields), castQueryResult(resp.GetResult())), + ); err != nil { + return err + } } // For each iteration, stream.Msg() is reused to the same struct,