diff --git a/destination_middleware.go b/destination_middleware.go index ab80882..60975bf 100644 --- a/destination_middleware.go +++ b/destination_middleware.go @@ -174,7 +174,7 @@ func (*destinationWithBatch) setBatchConfig(ctx context.Context, cfg Destination if ok { *ctxCfg = cfg } else { - ctx = context.WithValue(ctx, ctxKeyBatchConfig{}, cfg) + ctx = context.WithValue(ctx, ctxKeyBatchConfig{}, &cfg) } return ctx } @@ -425,15 +425,12 @@ type destinationWithSchemaExtraction struct { Destination config *DestinationWithSchemaExtraction - payloadEnabled bool - keyEnabled bool - payloadWarnOnce sync.Once keyWarnOnce sync.Once } func (d *destinationWithSchemaExtraction) Write(ctx context.Context, records []opencdc.Record) (int, error) { - if d.keyEnabled { + if *d.config.KeyEnabled { for i := range records { if err := d.decodeKey(ctx, &records[i]); err != nil { if len(records) > 0 { @@ -443,7 +440,7 @@ func (d *destinationWithSchemaExtraction) Write(ctx context.Context, records []o } } } - if d.payloadEnabled { + if *d.config.PayloadEnabled { for i := range records { if err := d.decodePayload(ctx, &records[i]); err != nil { if len(records) > 0 { diff --git a/destination_middleware_test.go b/destination_middleware_test.go index 2f1131d..f9b2316 100644 --- a/destination_middleware_test.go +++ b/destination_middleware_test.go @@ -36,7 +36,6 @@ import ( func TestDestinationWithBatch_Configure(t *testing.T) { is := is.New(t) - ctx := context.Background() dst := NewMockDestination(gomock.NewController(t)) dst.EXPECT().Open(gomock.Any()).Return(nil) @@ -47,6 +46,7 @@ func TestDestinationWithBatch_Configure(t *testing.T) { } d := mw.Wrap(dst) + ctx := (&destinationWithBatch{}).setBatchConfig(context.Background(), DestinationWithBatch{}) err := d.Open(ctx) is.NoErr(err) @@ -57,9 +57,6 @@ func TestDestinationWithBatch_Configure(t *testing.T) { // -- DestinationWithRateLimit ------------------------------------------------- func TestDestinationWithRateLimit_Configure(t *testing.T) { - ctrl := gomock.NewController(t) - dst := NewMockDestination(ctrl) - testCases := []struct { name string middleware DestinationWithRateLimit @@ -71,7 +68,7 @@ func TestDestinationWithRateLimit_Configure(t *testing.T) { middleware: DestinationWithRateLimit{}, wantLimiter: false, }, { - name: "empty config, custom defaults", + name: "custom defaults", middleware: DestinationWithRateLimit{ RatePerSecond: 1.23, Burst: 4, @@ -95,13 +92,13 @@ func TestDestinationWithRateLimit_Configure(t *testing.T) { Burst: 4, }, wantLimiter: true, - wantLimit: rate.Limit(12.34), - wantBurst: 5, + wantLimit: rate.Limit(1.23), + wantBurst: 4, }, { name: "config with zero burst", middleware: DestinationWithRateLimit{ RatePerSecond: 1.23, - Burst: 4, + Burst: 0, }, wantLimiter: true, wantLimit: rate.Limit(1.23), @@ -111,6 +108,10 @@ func TestDestinationWithRateLimit_Configure(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { is := is.New(t) + + dst := NewMockDestination(gomock.NewController(t)) + dst.EXPECT().Open(gomock.Any()).Return(nil) + d := tt.middleware.Wrap(dst).(*destinationWithRateLimit) err := d.Open(context.Background()) @@ -128,14 +129,17 @@ func TestDestinationWithRateLimit_Configure(t *testing.T) { func TestDestinationWithRateLimit_Write(t *testing.T) { is := is.New(t) - ctrl := gomock.NewController(t) - dst := NewMockDestination(ctrl) ctx := context.Background() + dst := NewMockDestination(gomock.NewController(t)) + dst.EXPECT().Open(gomock.Any()).Return(nil) + d := (&DestinationWithRateLimit{ RatePerSecond: 8, Burst: 2, }).Wrap(dst) + err := d.Open(ctx) + is.NoErr(err) recs := []opencdc.Record{{}, {}, {}, {}} @@ -161,7 +165,7 @@ func TestDestinationWithRateLimit_Write(t *testing.T) { // happens instantly, the second after 250ms (2/8 seconds) expectWriteAfter(0) expectWriteAfter(250 * time.Millisecond) - _, err := d.Write(ctx, recs) + _, err = d.Write(ctx, recs) is.NoErr(err) // third and fourth writes are again delayed by 250ms each @@ -174,26 +178,28 @@ func TestDestinationWithRateLimit_Write(t *testing.T) { func TestDestinationWithRateLimit_Write_CancelledContext(t *testing.T) { is := is.New(t) - ctrl := gomock.NewController(t) - dst := NewMockDestination(ctrl) + ctx := context.Background() - d := (&DestinationWithRateLimit{ + dst := NewMockDestination(gomock.NewController(t)) + dst.EXPECT().Open(gomock.Any()).Return(nil) + + underTest := (&DestinationWithRateLimit{ RatePerSecond: 10, }).Wrap(dst) - ctx, cancel := context.WithCancel(context.Background()) + err := underTest.Open(ctx) + is.NoErr(err) + + ctx, cancel := context.WithCancel(ctx) cancel() - _, err := d.Write(ctx, []opencdc.Record{{}}) + _, err = underTest.Write(ctx, []opencdc.Record{{}}) is.True(errors.Is(err, ctx.Err())) } // -- DestinationWithRecordFormat ---------------------------------------------- func TestDestinationWithRecordFormat_Configure(t *testing.T) { - ctrl := gomock.NewController(t) - dst := NewMockDestination(ctrl) - testCases := []struct { name string middleware DestinationWithRecordFormat @@ -218,7 +224,14 @@ func TestDestinationWithRecordFormat_Configure(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { is := is.New(t) + + dst := NewMockDestination(gomock.NewController(t)) + dst.EXPECT().Open(gomock.Any()).Return(nil) + d := tt.middleware.Wrap(dst).(*destinationWithRecordFormat) + err := d.Open(context.Background()) + is.NoErr(err) + is.Equal(d.serializer, tt.wantSerializer) }) } @@ -227,9 +240,6 @@ func TestDestinationWithRecordFormat_Configure(t *testing.T) { // -- DestinationWithSchemaExtraction ------------------------------------------ func TestDestinationWithSchemaExtraction_Configure(t *testing.T) { - ctrl := gomock.NewController(t) - dst := NewMockDestination(ctrl) - testCases := []struct { name string middleware DestinationWithSchemaExtraction @@ -238,30 +248,40 @@ func TestDestinationWithSchemaExtraction_Configure(t *testing.T) { wantErr error wantPayloadEnabled bool wantKeyEnabled bool - }{{ - name: "empty config", - middleware: DestinationWithSchemaExtraction{}, - have: config.Config{}, - - wantPayloadEnabled: true, - wantKeyEnabled: true, - }, { - name: "disabled by default", - middleware: DestinationWithSchemaExtraction{ - PayloadEnabled: lang.Ptr(false), - KeyEnabled: lang.Ptr(false), + }{ + { + name: "both disabled", + middleware: DestinationWithSchemaExtraction{ + PayloadEnabled: lang.Ptr(false), + KeyEnabled: lang.Ptr(false), + }, + wantPayloadEnabled: false, + wantKeyEnabled: false, }, - wantPayloadEnabled: false, - wantKeyEnabled: false, - }} + { + name: "payload enabled, key disabled", + middleware: DestinationWithSchemaExtraction{ + PayloadEnabled: lang.Ptr(true), + KeyEnabled: lang.Ptr(false), + }, + wantPayloadEnabled: true, + wantKeyEnabled: false, + }, + } for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { is := is.New(t) + + dst := NewMockDestination(gomock.NewController(t)) + dst.EXPECT().Open(gomock.Any()).Return(nil) + s := tt.middleware.Wrap(dst).(*destinationWithSchemaExtraction) + err := s.Open(context.Background()) + is.NoErr(err) - is.Equal(s.payloadEnabled, tt.wantPayloadEnabled) - is.Equal(s.keyEnabled, tt.wantKeyEnabled) + is.Equal(*s.config.PayloadEnabled, tt.wantPayloadEnabled) + is.Equal(*s.config.KeyEnabled, tt.wantKeyEnabled) }) } }