From 84a23de384c4bba8bc28ae3d11682035601b733d Mon Sep 17 00:00:00 2001 From: Kaushik Iska Date: Wed, 28 Jun 2023 09:59:34 -0400 Subject: [PATCH] fix merge command (#171) --- flow/connectors/snowflake/qrep_avro_sync.go | 89 +++++++++++++-------- flow/connectors/utils/identifiers.go | 7 ++ flow/e2e/qrep_flow_test.go | 42 ++++++++++ 3 files changed, 105 insertions(+), 33 deletions(-) create mode 100644 flow/connectors/utils/identifiers.go diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go index b987d6d466..6015872db7 100644 --- a/flow/connectors/snowflake/qrep_avro_sync.go +++ b/flow/connectors/snowflake/qrep_avro_sync.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -234,6 +235,56 @@ func (s *SnowflakeAvroWriteHandler) HandleAppendMode() error { return nil } +func GenerateMergeCommand( + allCols []string, + upsertKeyCols []string, + watermarkCol string, + tempTableName string, + dstTable string, +) (string, error) { + upsertKeys := []string{} + partitionKeyCols := []string{} + for _, key := range upsertKeyCols { + quotedKey := utils.QuoteIdentifier(key) + upsertKeys = append(upsertKeys, fmt.Sprintf("dst.%s = src.%s", quotedKey, quotedKey)) + partitionKeyCols = append(partitionKeyCols, quotedKey) + } + upsertKeyClause := strings.Join(upsertKeys, " AND ") + + updateSetClauses := []string{} + insertColumnsClauses := []string{} + insertValuesClauses := []string{} + for _, column := range allCols { + quotedColumn := utils.QuoteIdentifier(column) + updateSetClauses = append(updateSetClauses, fmt.Sprintf("%s = src.%s", quotedColumn, quotedColumn)) + insertColumnsClauses = append(insertColumnsClauses, quotedColumn) + insertValuesClauses = append(insertValuesClauses, fmt.Sprintf("src.%s", quotedColumn)) + } + updateSetClause := strings.Join(updateSetClauses, ", ") + insertColumnsClause := strings.Join(insertColumnsClauses, ", ") + insertValuesClause := strings.Join(insertValuesClauses, ", ") + + quotedWMC := utils.QuoteIdentifier(watermarkCol) + + selectCmd := fmt.Sprintf(` + SELECT * + FROM %s + QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s DESC) = 1 + `, tempTableName, strings.Join(partitionKeyCols, ","), quotedWMC) + + mergeCmd := fmt.Sprintf(` + MERGE INTO %s dst + USING (%s) src + ON %s + WHEN MATCHED AND src.%s > dst.%s THEN UPDATE SET %s + WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s) + `, dstTable, selectCmd, upsertKeyClause, quotedWMC, quotedWMC, + updateSetClause, insertColumnsClause, insertValuesClause) + + return mergeCmd, nil +} + +// HandleUpsertMode handles the upsert mode func (s *SnowflakeAvroWriteHandler) HandleUpsertMode( allCols []string, upsertKeyCols []string, @@ -262,43 +313,15 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode( } log.Infof("copied file from stage %s to temp table %s", s.stage, tempTableName) - upsertKey := []string{} - // upsert key should be like "dst.key1 = src.key1 AND dst.key2 = src.key2" - for _, key := range upsertKeyCols { - upsertKey = append(upsertKey, fmt.Sprintf("dst.%s = src.%s", key, key)) - } - upsertKeyClause := strings.Join(upsertKey, " AND ") - - updateSetClauses := []string{} - insertColumnsClauses := []string{} - insertValuesClauses := []string{} - for _, column := range allCols { - updateSetClauses = append(updateSetClauses, fmt.Sprintf("%s = src.%s", column, column)) - insertColumnsClauses = append(insertColumnsClauses, column) - insertValuesClauses = append(insertValuesClauses, fmt.Sprintf("src.%s", column)) + mergeCmd, err := GenerateMergeCommand(allCols, upsertKeyCols, watermarkCol, tempTableName, s.dstTableName) + if err != nil { + return fmt.Errorf("failed to generate merge command: %w", err) } - updateSetClause := strings.Join(updateSetClauses, ", ") - insertColumnsClause := strings.Join(insertColumnsClauses, ", ") - insertValuesClause := strings.Join(insertValuesClauses, ", ") - selectCmd := fmt.Sprintf(` - SELECT * - FROM %s - QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s DESC) = 1 - `, tempTableName, strings.Join(upsertKeyCols, ","), watermarkCol) - - //nolint:gosec - mergeCmd := fmt.Sprintf(` - MERGE INTO %s dst - USING (%s) src - ON %s - WHEN MATCHED AND src.%s > dst.%s THEN UPDATE SET %s - WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s) - `, s.dstTableName, selectCmd, upsertKeyClause, watermarkCol, watermarkCol, - updateSetClause, insertColumnsClause, insertValuesClause) if _, err := s.db.Exec(mergeCmd); err != nil { - return fmt.Errorf("failed to merge data into destination table: %w", err) + return fmt.Errorf("failed to merge data into destination table '%s': %w", mergeCmd, err) } + log.Infof("merged data from temp table %s into destination table %s", tempTableName, s.dstTableName) return nil diff --git a/flow/connectors/utils/identifiers.go b/flow/connectors/utils/identifiers.go new file mode 100644 index 0000000000..0b91b9e4f3 --- /dev/null +++ b/flow/connectors/utils/identifiers.go @@ -0,0 +1,7 @@ +package utils + +import "fmt" + +func QuoteIdentifier(identifier string) string { + return fmt.Sprintf(`"%s"`, identifier) +} diff --git a/flow/e2e/qrep_flow_test.go b/flow/e2e/qrep_flow_test.go index 34c43d69ae..3e250344ba 100644 --- a/flow/e2e/qrep_flow_test.go +++ b/flow/e2e/qrep_flow_test.go @@ -364,6 +364,48 @@ func (s *E2EPeerFlowTestSuite) Test_Complete_QRep_Flow_Avro_SF() { env.AssertExpectations(s.T()) } +func (s *E2EPeerFlowTestSuite) Test_Complete_QRep_Flow_Avro_SF_Upsert_Simple() { + env := s.NewTestWorkflowEnvironment() + registerWorkflowsAndActivities(env) + + numRows := 10 + + tblName := "test_qrep_flow_avro_sf_ups" + s.setupSourceTable(tblName, numRows) + s.setupSFDestinationTable(tblName) + + dstSchemaQualified := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, tblName) + + query := fmt.Sprintf("SELECT * FROM e2e_test.%s WHERE updated_at >= {{.start}} AND updated_at < {{.end}}", tblName) + + qrepConfig := s.createQRepWorkflowConfig( + "test_qrep_flow_avro_Sf", + "e2e_test."+tblName, + dstSchemaQualified, + query, + protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO, + s.sfHelper.Peer, + ) + qrepConfig.WriteMode = &protos.QRepWriteMode{ + WriteType: protos.QRepWriteType_QREP_WRITE_MODE_UPSERT, + UpsertKeyColumns: []string{"id"}, + } + + runQrepFlowWorkflow(env, qrepConfig) + + // Verify workflow completes without error + s.True(env.IsWorkflowCompleted()) + + // assert that error contains "invalid connection configs" + err := env.GetWorkflowError() + s.NoError(err) + + sel := getOwnersSelectorString() + s.compareTableContentsSF(tblName, sel) + + env.AssertExpectations(s.T()) +} + func (s *E2EPeerFlowTestSuite) Test_Complete_QRep_Flow_Multi_Insert_PG() { env := s.NewTestWorkflowEnvironment() registerWorkflowsAndActivities(env)