diff --git a/api/server.go b/api/server.go index fea1150..f5d06b4 100644 --- a/api/server.go +++ b/api/server.go @@ -9,7 +9,7 @@ import ( "github.com/golang/groupcache/lru" pb "github.com/pomerium/cli/proto" - "github.com/pomerium/cli/tcptunnel" + "github.com/pomerium/cli/tunnel" ) // ConfigProvider provides interface to the configuration persistence @@ -35,9 +35,9 @@ type ListenerStatus interface { SetListenerError(id string, err error) error } -// Tunnel is abstraction over tcptunnel.Tunnel to allow mocking +// Tunnel is abstraction over tunnel.Tunnel to allow mocking type Tunnel interface { - Run(context.Context, io.ReadWriter, tcptunnel.TunnelEvents) error + Run(context.Context, io.ReadWriter, tunnel.EventSink) error } // Server implements both config and listener interfaces diff --git a/api/tunnel.go b/api/tunnel.go index 0b17ba9..96b8560 100644 --- a/api/tunnel.go +++ b/api/tunnel.go @@ -18,7 +18,7 @@ import ( "github.com/pomerium/cli/certstore" pb "github.com/pomerium/cli/proto" - "github.com/pomerium/cli/tcptunnel" + "github.com/pomerium/cli/tunnel" ) func newTunnel(conn *pb.Connection, browserCmd, serviceAccount, serviceAccountFile string) (Tunnel, string, error) { @@ -27,7 +27,7 @@ func newTunnel(conn *pb.Connection, browserCmd, serviceAccount, serviceAccountFi listenAddr = *conn.ListenAddr } - destinationAddr, proxyURL, err := tcptunnel.ParseURLs(conn.GetRemoteAddr(), conn.GetPomeriumUrl()) + destinationAddr, proxyURL, err := tunnel.ParseURLs(conn.GetRemoteAddr(), conn.GetPomeriumUrl()) if err != nil { return nil, "", err } @@ -40,13 +40,13 @@ func newTunnel(conn *pb.Connection, browserCmd, serviceAccount, serviceAccountFi } } - return tcptunnel.New( - tcptunnel.WithDestinationHost(destinationAddr), - tcptunnel.WithProxyHost(proxyURL.Host), - tcptunnel.WithServiceAccount(serviceAccount), - tcptunnel.WithServiceAccountFile(serviceAccountFile), - tcptunnel.WithTLSConfig(tlsCfg), - tcptunnel.WithBrowserCommand(browserCmd), + return tunnel.New( + tunnel.WithDestinationHost(destinationAddr), + tunnel.WithProxyHost(proxyURL.Host), + tunnel.WithServiceAccount(serviceAccount), + tunnel.WithServiceAccountFile(serviceAccountFile), + tunnel.WithTLSConfig(tlsCfg), + tunnel.WithBrowserCommand(browserCmd), ), listenAddr, nil } diff --git a/cmd/pomerium-cli/proxy.go b/cmd/pomerium-cli/proxy.go index 698f538..6f69eb0 100644 --- a/cmd/pomerium-cli/proxy.go +++ b/cmd/pomerium-cli/proxy.go @@ -17,7 +17,7 @@ import ( "github.com/rs/zerolog/log" "github.com/spf13/cobra" - "github.com/pomerium/cli/tcptunnel" + "github.com/pomerium/cli/tunnel" ) var proxyCmdOptions struct { @@ -100,7 +100,7 @@ func makeDomainRegexes() ([]*regexp.Regexp, error) { return domainRegexes, nil } -func newTCPTunnel(dstHost string, specificPomeriumURL string) (*tcptunnel.Tunnel, error) { +func newTCPTunnel(dstHost string, specificPomeriumURL string) (*tunnel.Tunnel, error) { dstHostname, dstPort, err := net.SplitHostPort(dstHost) if err != nil { return nil, fmt.Errorf("invalid destination: %w", err) @@ -138,12 +138,12 @@ func newTCPTunnel(dstHost string, specificPomeriumURL string) (*tcptunnel.Tunnel } } - return tcptunnel.New( - tcptunnel.WithDestinationHost(net.JoinHostPort(dstHostname, dstPort)), - tcptunnel.WithProxyHost(pomeriumURL.Host), - tcptunnel.WithServiceAccount(serviceAccountOptions.serviceAccount), - tcptunnel.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile), - tcptunnel.WithTLSConfig(tlsConfig), + return tunnel.New( + tunnel.WithDestinationHost(net.JoinHostPort(dstHostname, dstPort)), + tunnel.WithProxyHost(pomeriumURL.Host), + tunnel.WithServiceAccount(serviceAccountOptions.serviceAccount), + tunnel.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile), + tunnel.WithTLSConfig(tlsConfig), ), nil } @@ -165,7 +165,7 @@ func hijackProxyConnect(req *http.Request, client net.Conn, ctx *goproxy.ProxyCt log.Error().Err(err).Msg("Failed to send response to client") return } - if err := tun.Run(req.Context(), client, tcptunnel.DiscardEvents()); err != nil { + if err := tun.Run(req.Context(), client, tunnel.DiscardEvents()); err != nil { log.Error().Err(err).Msg("Failed to run TCP tunnel") } } diff --git a/cmd/pomerium-cli/tcp.go b/cmd/pomerium-cli/tcp.go index 0e70ebe..7d28df2 100644 --- a/cmd/pomerium-cli/tcp.go +++ b/cmd/pomerium-cli/tcp.go @@ -11,7 +11,7 @@ import ( "github.com/spf13/cobra" - "github.com/pomerium/cli/tcptunnel" + "github.com/pomerium/cli/tunnel" ) var tcpCmdOptions struct { @@ -36,7 +36,7 @@ var tcpCmd = &cobra.Command{ Short: "creates a TCP tunnel through Pomerium", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - destinationAddr, proxyURL, err := tcptunnel.ParseURLs(args[0], tcpCmdOptions.pomeriumURL) + destinationAddr, proxyURL, err := tunnel.ParseURLs(args[0], tcpCmdOptions.pomeriumURL) if err != nil { return err } @@ -57,17 +57,17 @@ var tcpCmd = &cobra.Command{ cancel() }() - tun := tcptunnel.New( - tcptunnel.WithBrowserCommand(browserOptions.command), - tcptunnel.WithDestinationHost(destinationAddr), - tcptunnel.WithProxyHost(proxyURL.Host), - tcptunnel.WithServiceAccount(serviceAccountOptions.serviceAccount), - tcptunnel.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile), - tcptunnel.WithTLSConfig(tlsConfig), + tun := tunnel.New( + tunnel.WithBrowserCommand(browserOptions.command), + tunnel.WithDestinationHost(destinationAddr), + tunnel.WithProxyHost(proxyURL.Host), + tunnel.WithServiceAccount(serviceAccountOptions.serviceAccount), + tunnel.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile), + tunnel.WithTLSConfig(tlsConfig), ) if tcpCmdOptions.listen == "-" { - err = tun.Run(ctx, readWriter{Reader: os.Stdin, Writer: os.Stdout}, tcptunnel.DiscardEvents()) + err = tun.Run(ctx, readWriter{Reader: os.Stdin, Writer: os.Stdout}, tunnel.DiscardEvents()) } else { err = tun.RunListener(ctx, tcpCmdOptions.listen) } diff --git a/tcptunnel/config.go b/tunnel/config.go similarity index 99% rename from tcptunnel/config.go rename to tunnel/config.go index dc0f773..f9ce489 100644 --- a/tcptunnel/config.go +++ b/tunnel/config.go @@ -1,4 +1,4 @@ -package tcptunnel +package tunnel import ( "crypto/tls" diff --git a/tcptunnel/events.go b/tunnel/events.go similarity index 73% rename from tcptunnel/events.go rename to tunnel/events.go index 093a37e..e824d96 100644 --- a/tcptunnel/events.go +++ b/tunnel/events.go @@ -1,11 +1,11 @@ -package tcptunnel +package tunnel import ( "context" ) -// TunnelEvents is used to notify on the tunnel state transition -type TunnelEvents interface { +// EventSink is used to notify on the tunnel state transition +type EventSink interface { // OnConnecting is called when listener is accepting a new connection from client OnConnecting(context.Context) // OnConnected is called when a connection is successfully @@ -19,22 +19,22 @@ type TunnelEvents interface { } // DiscardEvents returns a broadcaster that discards all events -func DiscardEvents() TunnelEvents { +func DiscardEvents() EventSink { return discardEvents{} } type discardEvents struct{} // OnConnecting is called when listener is accepting a new connection from client -func (d discardEvents) OnConnecting(_ context.Context) {} +func (discardEvents) OnConnecting(_ context.Context) {} // OnConnected is called when a connection is successfully // established to the remote destination via pomerium proxy -func (d discardEvents) OnConnected(_ context.Context) {} +func (discardEvents) OnConnected(_ context.Context) {} // OnAuthRequired is called after listener accepted a new connection from client, // but has to perform user authentication first -func (d discardEvents) OnAuthRequired(_ context.Context, _ string) {} +func (discardEvents) OnAuthRequired(_ context.Context, _ string) {} // OnDisconnected is called when connection to client was closed -func (d discardEvents) OnDisconnected(_ context.Context, _ error) {} +func (discardEvents) OnDisconnected(_ context.Context, _ error) {} diff --git a/tcptunnel/tcptunnel.go b/tunnel/tunnel.go similarity index 90% rename from tcptunnel/tcptunnel.go rename to tunnel/tunnel.go index d758db3..9f47a96 100644 --- a/tcptunnel/tcptunnel.go +++ b/tunnel/tunnel.go @@ -1,5 +1,5 @@ -// Package tcptunnel contains an implementation of a TCP tunnel via HTTP Connect. -package tcptunnel +// Package tunnel contains an implementation of a TCP tunnel via HTTP Connect. +package tunnel import ( "bufio" @@ -90,7 +90,7 @@ func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) erro } // Run establishes a TCP tunnel via HTTP Connect and forwards all traffic from/to local. -func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter, evt TunnelEvents) error { +func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter, eventSink EventSink) error { rawJWT, err := tun.cfg.jwtCache.LoadJWT(tun.jwtCacheKey()) switch { // if there is no error, or it is one of the pre-defined cliutil errors, @@ -100,13 +100,13 @@ func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter, evt TunnelEvent errors.Is(err, jwt.ErrInvalid), errors.Is(err, jwt.ErrNotFound): default: - return fmt.Errorf("tcptunnel: failed to load JWT: %w", err) + return fmt.Errorf("tunnel: failed to load JWT: %w", err) } - return tun.run(ctx, evt, local, rawJWT, 0) + return tun.run(ctx, eventSink, local, rawJWT, 0) } -func (tun *Tunnel) run(ctx context.Context, evt TunnelEvents, local io.ReadWriter, rawJWT string, retryCount int) error { - evt.OnConnecting(ctx) +func (tun *Tunnel) run(ctx context.Context, eventSink EventSink, local io.ReadWriter, rawJWT string, retryCount int) error { + eventSink.OnConnecting(ctx) hdr := http.Header{} if rawJWT != "" { @@ -174,7 +174,7 @@ func (tun *Tunnel) run(ctx context.Context, evt TunnelEvents, local io.ReadWrite serverURL.Scheme = "https" } - rawJWT, err = tun.auth.GetJWT(ctx, serverURL, func(authURL string) { evt.OnAuthRequired(ctx, authURL) }) + rawJWT, err = tun.auth.GetJWT(ctx, serverURL, func(authURL string) { eventSink.OnAuthRequired(ctx, authURL) }) if err != nil { return fmt.Errorf("failed to get authentication JWT: %w", err) } @@ -184,7 +184,7 @@ func (tun *Tunnel) run(ctx context.Context, evt TunnelEvents, local io.ReadWrite return fmt.Errorf("failed to store JWT: %w", err) } - return tun.run(ctx, evt, local, rawJWT, retryCount+1) + return tun.run(ctx, eventSink, local, rawJWT, retryCount+1) } fallthrough default: @@ -193,7 +193,7 @@ func (tun *Tunnel) run(ctx context.Context, evt TunnelEvents, local io.ReadWrite } log.Println("connection established") - evt.OnConnected(ctx) + eventSink.OnConnected(ctx) errc := make(chan error, 2) go func() { diff --git a/tcptunnel/tcptunnel_test.go b/tunnel/tunnel_test.go similarity index 99% rename from tcptunnel/tcptunnel_test.go rename to tunnel/tunnel_test.go index 0a285f7..b84408e 100644 --- a/tcptunnel/tcptunnel_test.go +++ b/tunnel/tunnel_test.go @@ -1,4 +1,4 @@ -package tcptunnel +package tunnel import ( "bufio" diff --git a/tcptunnel/urls.go b/tunnel/urls.go similarity index 98% rename from tcptunnel/urls.go rename to tunnel/urls.go index 8c99f46..3878abf 100644 --- a/tcptunnel/urls.go +++ b/tunnel/urls.go @@ -1,4 +1,4 @@ -package tcptunnel +package tunnel import ( "fmt" diff --git a/tcptunnel/urls_test.go b/tunnel/urls_test.go similarity index 99% rename from tcptunnel/urls_test.go rename to tunnel/urls_test.go index 0e3b1f4..6519e0b 100644 --- a/tcptunnel/urls_test.go +++ b/tunnel/urls_test.go @@ -1,4 +1,4 @@ -package tcptunnel +package tunnel import ( "errors"