Skip to content

Commit

Permalink
WIP: Unit tests for channel multiplexer
Browse files Browse the repository at this point in the history
  • Loading branch information
jrauh01 committed Nov 29, 2023
1 parent bbed616 commit 90a7bb3
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 4 deletions.
33 changes: 29 additions & 4 deletions pkg/sync/channel-mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sync

import (
"context"
"golang.org/x/sync/errgroup"
"sync/atomic"
)

Expand Down Expand Up @@ -74,10 +75,32 @@ func (mux *channelMux[T]) Run(ctx context.Context) error {

outChannels := append(mux.addedOutChannels, mux.createdOutChannels...)

for {
for _, inChannel := range mux.inChannels {
sink := make(chan T)

g, ctx := errgroup.WithContext(ctx)

for _, ch := range mux.inChannels {
ch := ch

g.Go(func() error {
for {
select {
case spread, more := <-ch:
if !more {
return nil
}
sink <- spread
case <-ctx.Done():
return ctx.Err()
}
}
})
}

g.Go(func() error {
for {
select {
case spread, more := <-inChannel:
case spread, more := <-sink:
if !more {
return nil
}
Expand All @@ -93,5 +116,7 @@ func (mux *channelMux[T]) Run(ctx context.Context) error {
return ctx.Err()
}
}
}
})

return g.Wait()
}
207 changes: 207 additions & 0 deletions pkg/sync/channel-mux_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package sync

import (
"context"
"golang.org/x/sync/errgroup"
"testing"
"time"
)

type outputTest struct {
arg1, want int
}

var outputTests = []outputTest{
{0, 0},
{5, 5},
{35253, 35253},
{999999, 999999},
{-7, -7},
}

func TestAddedOutputChannels(t *testing.T) {
for _, test := range outputTests {
multiplexChannel := make(chan int)
multiplexer := NewChannelMux(multiplexChannel)

outputChannel1 := make(chan int)
outputChannel2 := make(chan int)
outputChannel3 := make(chan int)
multiplexer.AddOutChannel(outputChannel1)
multiplexer.AddOutChannel(outputChannel2)
multiplexer.AddOutChannel(outputChannel3)

g, ctx := errgroup.WithContext(context.Background())

g.Go(func() error {
return multiplexer.Run(ctx)
})

multiplexChannel <- test.arg1

if got := <-outputChannel1; got != test.want {
t.Errorf("got '%d' for 1st test channel, wanted '%d'", got, test.want)
}
if got := <-outputChannel2; got != test.want {
t.Errorf("got '%d' for 2nd test channel, wanted '%d'", got, test.want)
}
if got := <-outputChannel3; got != test.want {
t.Errorf("got '%d' for 3rd test channel, wanted '%d'", got, test.want)
}
}
}

func TestCreatedOutputChannels(t *testing.T) {
for _, test := range outputTests {
multiplexChannel := make(chan int)
multiplexer := NewChannelMux(multiplexChannel)

outputChannel1 := multiplexer.NewOutChannel()
outputChannel2 := multiplexer.NewOutChannel()
outputChannel3 := multiplexer.NewOutChannel()

g, ctx := errgroup.WithContext(context.Background())

g.Go(func() error {
return multiplexer.Run(ctx)
})

multiplexChannel <- test.arg1

if got := <-outputChannel1; got != test.want {
t.Errorf("got '%d' for 1st test channel, wanted '%d'", got, test.want)
}
if got := <-outputChannel2; got != test.want {
t.Errorf("got '%d' for 2nd test channel, wanted '%d'", got, test.want)
}
if got := <-outputChannel3; got != test.want {
t.Errorf("got '%d' for 3rd test channel, wanted '%d'", got, test.want)
}
}
}

type inputTest struct {
arg1, arg2, arg3, want int
}

var inputTests = []inputTest{
{0, 0, 0, 0},
{1, 2, 3, 6},
{535, 64, 6432, 7031},
{353632, 636232, 64674, 1054538},
{-1, -2, -3, -6},
}

func TestAddedInputChannels(t *testing.T) {
for _, test := range inputTests {
multiplexChannel1 := make(chan int)
multiplexChannel2 := make(chan int)
multiplexChannel3 := make(chan int)

multiplexer := NewChannelMux(multiplexChannel1, multiplexChannel2, multiplexChannel3)

outputChannel := multiplexer.NewOutChannel()

ctx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(ctx)

g.Go(func() error {
return multiplexer.Run(ctx)
})

g.Go(func() error {
select {
case multiplexChannel1 <- test.arg1:
case <-ctx.Done():
return ctx.Err()
}

select {
case multiplexChannel2 <- test.arg2:
case <-ctx.Done():
return ctx.Err()
}

select {
case multiplexChannel3 <- test.arg3:
case <-ctx.Done():
return ctx.Err()
}

close(multiplexChannel1)
close(multiplexChannel2)
close(multiplexChannel3)

return nil
})

stop := false
got := 0

g.Go(func() error {
for !stop {
select {
case output, more := <-outputChannel:
if !more {
stop = true
break
}

got += output
case <-time.After(time.Second * 1):
stop = true
break
case <-ctx.Done():
return ctx.Err()
}
}

if got != test.want {
t.Errorf("Got %d, wanted %d", got, test.want)
}

cancel()

return nil
})

g.Wait()
}
}

func TestClosedChannels(t *testing.T) {
multiplexChannel := make(chan int)
multiplexer := NewChannelMux(multiplexChannel)

outputChannel1 := multiplexer.NewOutChannel()
outputChannel2 := multiplexer.NewOutChannel()
outputChannel3 := multiplexer.NewOutChannel()

ctx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(ctx)

g.Go(func() error {
return multiplexer.Run(ctx)
})

cancel()

select {
case <-outputChannel1:
case <-time.After(time.Second):
t.Error("1st channel is still open, should be closed")
}

select {
case <-outputChannel2:
case <-time.After(time.Second):
t.Error("2nd channel is still open, should be closed")
}

select {
case <-outputChannel3:
case <-time.After(time.Second):
t.Error("3rd channel is still open, should be closed")
}

}

0 comments on commit 90a7bb3

Please sign in to comment.