diff --git a/pkg/discovery/game.go b/pkg/discovery/game.go index 625d3d4d..b4823119 100644 --- a/pkg/discovery/game.go +++ b/pkg/discovery/game.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 - for information on the respective copyright owner +// Copyright (c) 2021-2024 - for information on the respective copyright owner // see the NOTICE file and/or the repository https://github.com/carbynestack/ephemeral. // // SPDX-License-Identifier: Apache-2.0 @@ -79,6 +79,7 @@ func NewGame(ctx context.Context, id string, bus mb.MessageBus, stateTimeout tim fsm.WhenIn(Playing).GotEvent(GameFinishedWithError).GoTo(GameError), fsm.WhenIn(Playing).GotEvent(GameSuccess).GoTo(GameDone), fsm.WhenIn(Playing).GotEvent(GameError).GoTo(GameError), + fsm.WhenInAnyState().GotEvent(GameFinishedWithError).GoTo(GameError), fsm.WhenInAnyState().GotEvent(StateTimeoutError).GoTo(GameError), fsm.WhenInAnyState().GotEvent(GameDone).GoTo(GameDone), } diff --git a/pkg/ephemeral/network/proxy.go b/pkg/ephemeral/network/proxy.go index 8d8fb5b6..5d5118cf 100644 --- a/pkg/ephemeral/network/proxy.go +++ b/pkg/ephemeral/network/proxy.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 - for information on the respective copyright owner +// Copyright (c) 2021-2024 - for information on the respective copyright owner // see the NOTICE file and/or the repository https://github.com/carbynestack/ephemeral. // // SPDX-License-Identifier: Apache-2.0 @@ -110,7 +110,7 @@ func (p *Proxy) checkConnectionToPeers() error { proxyEntry := proxyEntry waitGroup.Add(1) go func() { - err := p.checkTCPConnectionToPeer(proxyEntry) + err := p.checkTCPConnectionToPeer(p.ctx.Context, proxyEntry) defer waitGroup.Done() if err != nil { errorsCheckingConnection = append(errorsCheckingConnection, err) @@ -142,9 +142,9 @@ func (p *Proxy) addProxyEntry(config *ProxyConfig) *PingAwareTarget { return pat } -func (p *Proxy) checkTCPConnectionToPeer(config *ProxyConfig) error { +func (p *Proxy) checkTCPConnectionToPeer(ctx context.Context, config *ProxyConfig) error { p.logger.Info(fmt.Sprintf("Checking if connection to peer works for config: %s", config)) - err := p.tcpChecker.Verify(config.Host, config.Port) + err := p.tcpChecker.Verify(ctx, config.Host, config.Port) if err != nil { return fmt.Errorf("error checking connection to the peer '%s:%s': %s", config.Host, config.Port, err) } diff --git a/pkg/ephemeral/network/tcpchecker.go b/pkg/ephemeral/network/tcpchecker.go index 0a30e29e..6cb012c6 100644 --- a/pkg/ephemeral/network/tcpchecker.go +++ b/pkg/ephemeral/network/tcpchecker.go @@ -1,10 +1,11 @@ -// Copyright (c) 2021 - for information on the respective copyright owner +// Copyright (c) 2021-2024 - for information on the respective copyright owner // see the NOTICE file and/or the repository https://github.com/carbynestack/ephemeral. // // SPDX-License-Identifier: Apache-2.0 package network import ( + "context" "fmt" "io" "net" @@ -15,7 +16,7 @@ import ( // NetworkChecker verifies the network connectivity between the players before starting the computation. type NetworkChecker interface { - Verify(string, string) error + Verify(context.Context, string, string) error } // NoopChecker verifies the network for all MPC players is in place. @@ -23,7 +24,7 @@ type NoopChecker struct { } // Verify checks network connectivity between the players and communicates its results to discovery and players FSM. -func (t *NoopChecker) Verify(host, port string) error { +func (t *NoopChecker) Verify(context.Context, string, string) error { return nil } @@ -48,10 +49,12 @@ type TCPChecker struct { } // Verify checks network connectivity between the players and communicates its results to discovery and players FSM. -func (t *TCPChecker) Verify(host, port string) error { +func (t *TCPChecker) Verify(ctx context.Context, host, port string) error { done := time.After(t.conf.RetryTimeout) for { select { + case <-ctx.Done(): + return fmt.Errorf("TCPCheck for '%s:%s' aborted after %d attempts", host, port, t.retries) case <-done: return fmt.Errorf("TCPCheck for '%s:%s' failed after %s and %d attempts", host, port, t.conf.RetryTimeout.String(), t.retries) default: diff --git a/pkg/ephemeral/network/tcpchecker_test.go b/pkg/ephemeral/network/tcpchecker_test.go index 60743695..40eb8641 100644 --- a/pkg/ephemeral/network/tcpchecker_test.go +++ b/pkg/ephemeral/network/tcpchecker_test.go @@ -1,10 +1,11 @@ -// Copyright (c) 2021 - for information on the respective copyright owner +// Copyright (c) 2021-2024 - for information on the respective copyright owner // see the NOTICE file and/or the repository https://github.com/carbynestack/ephemeral. // // SPDX-License-Identifier: Apache-2.0 package network import ( + "context" "io" "net" "sync" @@ -48,7 +49,7 @@ var _ = Describe("TcpChecker", func() { Logger: zap.NewNop().Sugar(), } checker := NewTCPChecker(conf) - err := checker.Verify(host, port) + err := checker.Verify(context.TODO(), host, port) Expect(err).NotTo(HaveOccurred()) wg.Wait() }) @@ -59,7 +60,7 @@ var _ = Describe("TcpChecker", func() { Logger: zap.NewNop().Sugar(), } checker := NewTCPChecker(conf) - err := checker.Verify(host, port) + err := checker.Verify(context.TODO(), host, port) Expect(err).To(HaveOccurred()) }) It("returns an error if dialing succeeds but the connection is closed down shortly", func() { @@ -87,7 +88,7 @@ var _ = Describe("TcpChecker", func() { Logger: zap.NewNop().Sugar(), } checker := NewTCPChecker(conf) - err := checker.Verify(host, port) + err := checker.Verify(context.TODO(), host, port) Expect(err).To(HaveOccurred()) Expect(checker.retries > 1).To(BeTrue()) wg.Wait() @@ -100,8 +101,22 @@ var _ = Describe("TcpChecker", func() { Logger: zap.NewNop().Sugar(), } checker := NewTCPChecker(conf) - err := checker.Verify(host, port) + err := checker.Verify(context.TODO(), host, port) Expect(err).To(HaveOccurred()) Expect(checker.retries > 1).To(BeTrue()) }) + It("aborts if context is closed", func() { + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + conf := &TCPCheckerConf{ + DialTimeout: 50 * time.Millisecond, + RetryTimeout: 100 * time.Millisecond, + Logger: zap.NewNop().Sugar(), + } + checker := NewTCPChecker(conf) + err := checker.Verify(ctx, host, port) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("TCPCheck for 'localhost:9999' aborted after 0 attempts")) + Expect(checker.retries == 0).To(BeTrue()) + }) }) diff --git a/pkg/ephemeral/player.go b/pkg/ephemeral/player.go index de7156b9..ddc14004 100644 --- a/pkg/ephemeral/player.go +++ b/pkg/ephemeral/player.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 - for information on the respective copyright owner +// Copyright (c) 2021-2024 - for information on the respective copyright owner // see the NOTICE file and/or the repository https://github.com/carbynestack/ephemeral. // // SPDX-License-Identifier: Apache-2.0 @@ -49,8 +49,8 @@ func NewPlayer(ctx context.Context, bus mb.MessageBus, stateTimeout time.Duratio fsm.WhenIn(Init).GotEvent(Register).GoTo(Registering), fsm.WhenIn(Registering).GotEvent(PlayersReady).GoTo(Playing).WithTimeout(computationTimeout), fsm.WhenIn(Playing).GotEvent(PlayerFinishedWithSuccess).GoTo(PlayerFinishedWithSuccess), - fsm.WhenIn(Playing).GotEvent(PlayingError).GoTo(PlayerFinishedWithError), fsm.WhenInAnyState().GotEvent(GameError).GoTo(PlayerFinishedWithError), + fsm.WhenInAnyState().GotEvent(PlayingError).GoTo(PlayerFinishedWithError), fsm.WhenInAnyState().GotEvent(PlayerDone).GoTo(PlayerDone), fsm.WhenInAnyState().GotEvent(StateTimeoutError).GoTo(PlayerFinishedWithError), }