Skip to content

Commit

Permalink
Allow waiting for N established listeners before returning from listen.
Browse files Browse the repository at this point in the history
Fixes #465
  • Loading branch information
plorenz committed Dec 12, 2023
1 parent 74c8fd3 commit 6e17377
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 96 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Release 0.21.0

- New `ListenOptions` field: `WaitForNEstablishedListeners`. Allows specifying that you want at least N listeners to be established before the `Listen` method returns. Defaults to 0.

# Release 0.20.145

- New `Context` API method: `RefreshService`, which allows refreshing a single service, when that's all that's needed.
Expand Down
2 changes: 1 addition & 1 deletion version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.20
0.21
22 changes: 22 additions & 0 deletions ziti/edge/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type RouterConn interface {
IsClosed() bool
Key() string
GetRouterName() string
GetBoolHeader(key int32) bool
}

type Identifiable interface {
Expand Down Expand Up @@ -211,6 +212,12 @@ func (d DialOptions) GetConnectTimeout() time.Duration {
return d.ConnectTimeout
}

func NewListenOptions() *ListenOptions {
return &ListenOptions{
eventC: make(chan *ListenerEvent, 3),
}
}

type ListenOptions struct {
Cost uint16
Precedence Precedence
Expand All @@ -222,6 +229,11 @@ type ListenOptions struct {
ManualStart bool
ListenerId string
KeyPair *kx.KeyPair
eventC chan *ListenerEvent
}

func (options *ListenOptions) GetEventChannel() chan *ListenerEvent {
return options.eventC
}

func (options *ListenOptions) GetConnectTimeout() time.Duration {
Expand All @@ -231,3 +243,13 @@ func (options *ListenOptions) GetConnectTimeout() time.Duration {
func (options *ListenOptions) String() string {
return fmt.Sprintf("[ListenOptions cost=%v, max-connections=%v]", options.Cost, options.MaxConnections)
}

type ListenerEventType int

const (
ListenerEstablished ListenerEventType = 1
)

type ListenerEvent struct {
EventType ListenerEventType
}
3 changes: 3 additions & 0 deletions ziti/edge/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ const (
ContentTypeTraceRouteResponse = 60797
ContentTypeConnInspectRequest = 60798
ContentTypeConnInspectResponse = 60799
ContentTypeBindSuccess = 60800

ConnIdHeader = 1000
SeqHeader = 1001
Expand All @@ -67,6 +68,7 @@ const (
ListenerId = 1021
ConnTypeHeader = 1022
SupportsInspectHeader = 1023
SupportsBindSuccessHeader = 1024

ErrorCodeInternal = 1
ErrorCodeInvalidApiSession = 2
Expand Down Expand Up @@ -208,6 +210,7 @@ func NewDialMsg(connId uint32, token string, callerId string) *channel.Message {
func NewBindMsg(connId uint32, token string, pubKey []byte, options *ListenOptions) *channel.Message {
msg := newMsg(ContentTypeBind, connId, 0, []byte(token))
msg.PutBoolHeader(SupportsInspectHeader, true)
msg.PutBoolHeader(SupportsBindSuccessHeader, true)

if pubKey != nil {
msg.Headers[PublicKeyHeader] = pubKey
Expand Down
15 changes: 14 additions & 1 deletion ziti/edge/network/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,18 @@ func (conn *edgeConn) Accept(msg *channel.Message) {
go conn.newChildConnection(msg)
} else if msg.ContentType == edge.ContentTypeStateClosed {
conn.close(true)
} else if msg.ContentType == edge.ContentTypeBindSuccess {
for entry := range conn.hosting.IterBuffered() {
entry.Val.established.Store(true)
event := &edge.ListenerEvent{
EventType: edge.ListenerEstablished,
}
select {
case entry.Val.eventC <- event:
default:
logrus.WithFields(edge.GetLoggerFields(msg)).Warn("unable to send listener established event")
}
}
}
default:
logrus.WithFields(edge.GetLoggerFields(msg)).Errorf("invalid connection type: %v", conn.connType)
Expand Down Expand Up @@ -341,7 +353,7 @@ func (conn *edgeConn) establishServerCrypto(keypair *kx.KeyPair, peerKey []byte,
return txHeader, nil
}

func (conn *edgeConn) Listen(session *rest_model.SessionDetail, service *rest_model.ServiceDetail, options *edge.ListenOptions) (edge.Listener, error) {
func (conn *edgeConn) listen(session *rest_model.SessionDetail, service *rest_model.ServiceDetail, options *edge.ListenOptions) (*edgeListener, error) {
logger := pfxlog.ContextLogger(conn.Channel.Label()).
WithField("connId", conn.Id()).
WithField("serviceName", *service.Name).
Expand All @@ -356,6 +368,7 @@ func (conn *edgeConn) Listen(session *rest_model.SessionDetail, service *rest_mo
token: *session.Token,
edgeChan: conn,
manualStart: options.ManualStart,
eventC: options.GetEventChannel(),
}
logger.Debug("adding listener for session")
conn.hosting.Set(*session.Token, listener)
Expand Down
11 changes: 10 additions & 1 deletion ziti/edge/network/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ type routerConn struct {
owner RouterConnOwner
}

func (conn *routerConn) GetBoolHeader(key int32) bool {
val, _ := conn.ch.Underlay().Headers()[key]

Check failure on line 41 in ziti/edge/network/factory.go

View workflow job for this annotation

GitHub Actions / lint

S1005: unnecessary assignment to the blank identifier (gosimple)
return len(val) == 1 && val[0] == 1
}

func (conn *routerConn) Key() string {
return conn.key
}
Expand Down Expand Up @@ -69,6 +74,7 @@ func (conn *routerConn) BindChannel(binding channel.Binding) error {
binding.AddReceiveHandlerF(edge.ContentTypeStateClosed, conn.msgMux.HandleReceive)
binding.AddReceiveHandlerF(edge.ContentTypeTraceRoute, conn.msgMux.HandleReceive)
binding.AddReceiveHandlerF(edge.ContentTypeConnInspectRequest, conn.msgMux.HandleReceive)
binding.AddReceiveHandlerF(edge.ContentTypeBindSuccess, conn.msgMux.HandleReceive)

// Since data is the common message type, it gets to be dispatched directly
binding.AddTypedReceiveHandler(conn.msgMux)
Expand Down Expand Up @@ -151,7 +157,7 @@ func (conn *routerConn) Listen(service *rest_model.ServiceDetail, session *rest_
WithField("serviceId", *service.ID).
WithField("serviceName", *service.Name)

listener, err := ec.Listen(session, service, options)
listener, err := ec.listen(session, service, options)
if err != nil {
log.WithError(err).Error("failed to establish listener")

Expand All @@ -160,6 +166,9 @@ func (conn *routerConn) Listen(service *rest_model.ServiceDetail, session *rest_
Error("failed to cleanup listener for service after failed bind")
}
} else {
if !conn.GetBoolHeader(edge.SupportsBindSuccessHeader) {
listener.established.Store(true)
}
log.Debug("established listener")
}
return listener, err
Expand Down
Loading

0 comments on commit 6e17377

Please sign in to comment.