Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into stable
Browse files Browse the repository at this point in the history
  • Loading branch information
iskakaushik committed Jun 28, 2023
2 parents 794d846 + 84a23de commit 41ebb0b
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 33 deletions.
89 changes: 56 additions & 33 deletions flow/connectors/snowflake/qrep_avro_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/model/qvalue"
Expand Down Expand Up @@ -234,6 +235,56 @@ func (s *SnowflakeAvroWriteHandler) HandleAppendMode() error {
return nil
}

// GenerateMergeCommand builds the MERGE statement that upserts the rows of
// tempTableName into dstTable.
//
// The source of the merge is tempTableName reduced to a single row per
// combination of upsertKeyCols: the row with the greatest watermarkCol value.
// A source row updates the matching destination row only when its watermark
// is strictly greater than the destination's; source rows with no match are
// inserted. Every column name in allCols, upsertKeyCols and watermarkCol is
// passed through utils.QuoteIdentifier; tempTableName and dstTable are
// interpolated as given.
//
// An error is returned when upsertKeyCols or allCols is empty, because the
// resulting statement would have an empty ON / PARTITION BY clause or empty
// column lists and could never be valid SQL.
func GenerateMergeCommand(
	allCols []string,
	upsertKeyCols []string,
	watermarkCol string,
	tempTableName string,
	dstTable string,
) (string, error) {
	if len(upsertKeyCols) == 0 {
		return "", fmt.Errorf("no upsert key columns provided for merge into %s", dstTable)
	}
	if len(allCols) == 0 {
		return "", fmt.Errorf("no columns provided for merge into %s", dstTable)
	}

	// Join predicates look like: dst."key1" = src."key1" AND dst."key2" = src."key2".
	upsertKeys := make([]string, 0, len(upsertKeyCols))
	partitionKeyCols := make([]string, 0, len(upsertKeyCols))
	for _, key := range upsertKeyCols {
		quotedKey := utils.QuoteIdentifier(key)
		upsertKeys = append(upsertKeys, fmt.Sprintf("dst.%s = src.%s", quotedKey, quotedKey))
		partitionKeyCols = append(partitionKeyCols, quotedKey)
	}
	upsertKeyClause := strings.Join(upsertKeys, " AND ")

	updateSetClauses := make([]string, 0, len(allCols))
	insertColumnsClauses := make([]string, 0, len(allCols))
	insertValuesClauses := make([]string, 0, len(allCols))
	for _, column := range allCols {
		quotedColumn := utils.QuoteIdentifier(column)
		updateSetClauses = append(updateSetClauses, fmt.Sprintf("%s = src.%s", quotedColumn, quotedColumn))
		insertColumnsClauses = append(insertColumnsClauses, quotedColumn)
		insertValuesClauses = append(insertValuesClauses, fmt.Sprintf("src.%s", quotedColumn))
	}
	updateSetClause := strings.Join(updateSetClauses, ", ")
	insertColumnsClause := strings.Join(insertColumnsClauses, ", ")
	insertValuesClause := strings.Join(insertValuesClauses, ", ")

	quotedWMC := utils.QuoteIdentifier(watermarkCol)

	// Deduplicate the staged rows: keep only the highest-watermark row per key.
	selectCmd := fmt.Sprintf(`
SELECT *
FROM %s
QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s DESC) = 1
`, tempTableName, strings.Join(partitionKeyCols, ","), quotedWMC)

	mergeCmd := fmt.Sprintf(`
MERGE INTO %s dst
USING (%s) src
ON %s
WHEN MATCHED AND src.%s > dst.%s THEN UPDATE SET %s
WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s)
`, dstTable, selectCmd, upsertKeyClause, quotedWMC, quotedWMC,
		updateSetClause, insertColumnsClause, insertValuesClause)

	return mergeCmd, nil
}

// HandleUpsertMode handles the upsert mode
func (s *SnowflakeAvroWriteHandler) HandleUpsertMode(
allCols []string,
upsertKeyCols []string,
Expand Down Expand Up @@ -262,43 +313,15 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode(
}
log.Infof("copied file from stage %s to temp table %s", s.stage, tempTableName)

upsertKey := []string{}
// upsert key should be like "dst.key1 = src.key1 AND dst.key2 = src.key2"
for _, key := range upsertKeyCols {
upsertKey = append(upsertKey, fmt.Sprintf("dst.%s = src.%s", key, key))
}
upsertKeyClause := strings.Join(upsertKey, " AND ")

updateSetClauses := []string{}
insertColumnsClauses := []string{}
insertValuesClauses := []string{}
for _, column := range allCols {
updateSetClauses = append(updateSetClauses, fmt.Sprintf("%s = src.%s", column, column))
insertColumnsClauses = append(insertColumnsClauses, column)
insertValuesClauses = append(insertValuesClauses, fmt.Sprintf("src.%s", column))
mergeCmd, err := GenerateMergeCommand(allCols, upsertKeyCols, watermarkCol, tempTableName, s.dstTableName)
if err != nil {
return fmt.Errorf("failed to generate merge command: %w", err)
}
updateSetClause := strings.Join(updateSetClauses, ", ")
insertColumnsClause := strings.Join(insertColumnsClauses, ", ")
insertValuesClause := strings.Join(insertValuesClauses, ", ")

selectCmd := fmt.Sprintf(`
SELECT *
FROM %s
QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s DESC) = 1
`, tempTableName, strings.Join(upsertKeyCols, ","), watermarkCol)

//nolint:gosec
mergeCmd := fmt.Sprintf(`
MERGE INTO %s dst
USING (%s) src
ON %s
WHEN MATCHED AND src.%s > dst.%s THEN UPDATE SET %s
WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s)
`, s.dstTableName, selectCmd, upsertKeyClause, watermarkCol, watermarkCol,
updateSetClause, insertColumnsClause, insertValuesClause)
if _, err := s.db.Exec(mergeCmd); err != nil {
return fmt.Errorf("failed to merge data into destination table: %w", err)
return fmt.Errorf("failed to merge data into destination table '%s': %w", mergeCmd, err)
}

log.Infof("merged data from temp table %s into destination table %s",
tempTableName, s.dstTableName)
return nil
Expand Down
7 changes: 7 additions & 0 deletions flow/connectors/utils/identifiers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package utils

import (
	"fmt"
	"strings"
)

// QuoteIdentifier returns identifier wrapped in double quotes for use as a
// quoted SQL identifier. Any double quote inside identifier is doubled so that
// it cannot terminate the quoted identifier early.
func QuoteIdentifier(identifier string) string {
	return fmt.Sprintf(`"%s"`, strings.ReplaceAll(identifier, `"`, `""`))
}
42 changes: 42 additions & 0 deletions flow/e2e/qrep_flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,48 @@ func (s *E2EPeerFlowTestSuite) Test_Complete_QRep_Flow_Avro_SF() {
env.AssertExpectations(s.T())
}

// Test_Complete_QRep_Flow_Avro_SF_Upsert_Simple runs a QRep workflow in Avro
// storage sync mode against the Snowflake peer with the write mode set to
// upsert on the "id" column, then checks that the workflow finished without
// error and compares the destination table contents with the source.
func (s *E2EPeerFlowTestSuite) Test_Complete_QRep_Flow_Avro_SF_Upsert_Simple() {
	env := s.NewTestWorkflowEnvironment()
	registerWorkflowsAndActivities(env)

	const rowCount = 10
	table := "test_qrep_flow_avro_sf_ups"

	// Source table is seeded first, then the Snowflake destination is created.
	s.setupSourceTable(table, rowCount)
	s.setupSFDestinationTable(table)

	dstTable := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, table)
	srcTable := "e2e_test." + table
	watermarkQuery := fmt.Sprintf(
		"SELECT * FROM e2e_test.%s WHERE updated_at >= {{.start}} AND updated_at < {{.end}}",
		table,
	)

	cfg := s.createQRepWorkflowConfig(
		"test_qrep_flow_avro_Sf",
		srcTable,
		dstTable,
		watermarkQuery,
		protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO,
		s.sfHelper.Peer,
	)
	// Switch the flow to upsert mode, keyed on the "id" column.
	cfg.WriteMode = &protos.QRepWriteMode{
		WriteType:        protos.QRepWriteType_QREP_WRITE_MODE_UPSERT,
		UpsertKeyColumns: []string{"id"},
	}

	runQrepFlowWorkflow(env, cfg)

	// The workflow must have run to completion...
	s.True(env.IsWorkflowCompleted())

	// ...and must not have reported an error.
	s.NoError(env.GetWorkflowError())

	// Compare the destination table contents using the owners column selector.
	s.compareTableContentsSF(table, getOwnersSelectorString())

	env.AssertExpectations(s.T())
}

func (s *E2EPeerFlowTestSuite) Test_Complete_QRep_Flow_Multi_Insert_PG() {
env := s.NewTestWorkflowEnvironment()
registerWorkflowsAndActivities(env)
Expand Down

0 comments on commit 41ebb0b

Please sign in to comment.