Skip to content

Commit

Permalink
tcptunnel: rename to tunnel (#467)
Browse files Browse the repository at this point in the history
  • Loading branch information
calebdoxsey authored Nov 5, 2024
1 parent 213419f commit 0085418
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 53 deletions.
6 changes: 3 additions & 3 deletions api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/golang/groupcache/lru"

pb "github.com/pomerium/cli/proto"
"github.com/pomerium/cli/tcptunnel"
"github.com/pomerium/cli/tunnel"
)

// ConfigProvider provides interface to the configuration persistence
Expand All @@ -35,9 +35,9 @@ type ListenerStatus interface {
SetListenerError(id string, err error) error
}

// Tunnel is abstraction over tcptunnel.Tunnel to allow mocking
// Tunnel is abstraction over tunnel.Tunnel to allow mocking
type Tunnel interface {
Run(context.Context, io.ReadWriter, tcptunnel.TunnelEvents) error
Run(context.Context, io.ReadWriter, tunnel.EventSink) error
}

// Server implements both config and listener interfaces
Expand Down
18 changes: 9 additions & 9 deletions api/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (

"github.com/pomerium/cli/certstore"
pb "github.com/pomerium/cli/proto"
"github.com/pomerium/cli/tcptunnel"
"github.com/pomerium/cli/tunnel"
)

func newTunnel(conn *pb.Connection, browserCmd, serviceAccount, serviceAccountFile string) (Tunnel, string, error) {
Expand All @@ -27,7 +27,7 @@ func newTunnel(conn *pb.Connection, browserCmd, serviceAccount, serviceAccountFi
listenAddr = *conn.ListenAddr
}

destinationAddr, proxyURL, err := tcptunnel.ParseURLs(conn.GetRemoteAddr(), conn.GetPomeriumUrl())
destinationAddr, proxyURL, err := tunnel.ParseURLs(conn.GetRemoteAddr(), conn.GetPomeriumUrl())
if err != nil {
return nil, "", err
}
Expand All @@ -40,13 +40,13 @@ func newTunnel(conn *pb.Connection, browserCmd, serviceAccount, serviceAccountFi
}
}

return tcptunnel.New(
tcptunnel.WithDestinationHost(destinationAddr),
tcptunnel.WithProxyHost(proxyURL.Host),
tcptunnel.WithServiceAccount(serviceAccount),
tcptunnel.WithServiceAccountFile(serviceAccountFile),
tcptunnel.WithTLSConfig(tlsCfg),
tcptunnel.WithBrowserCommand(browserCmd),
return tunnel.New(
tunnel.WithDestinationHost(destinationAddr),
tunnel.WithProxyHost(proxyURL.Host),
tunnel.WithServiceAccount(serviceAccount),
tunnel.WithServiceAccountFile(serviceAccountFile),
tunnel.WithTLSConfig(tlsCfg),
tunnel.WithBrowserCommand(browserCmd),
), listenAddr, nil
}

Expand Down
18 changes: 9 additions & 9 deletions cmd/pomerium-cli/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"

"github.com/pomerium/cli/tcptunnel"
"github.com/pomerium/cli/tunnel"
)

var proxyCmdOptions struct {
Expand Down Expand Up @@ -100,7 +100,7 @@ func makeDomainRegexes() ([]*regexp.Regexp, error) {
return domainRegexes, nil
}

func newTCPTunnel(dstHost string, specificPomeriumURL string) (*tcptunnel.Tunnel, error) {
func newTCPTunnel(dstHost string, specificPomeriumURL string) (*tunnel.Tunnel, error) {
dstHostname, dstPort, err := net.SplitHostPort(dstHost)
if err != nil {
return nil, fmt.Errorf("invalid destination: %w", err)
Expand Down Expand Up @@ -138,12 +138,12 @@ func newTCPTunnel(dstHost string, specificPomeriumURL string) (*tcptunnel.Tunnel
}
}

return tcptunnel.New(
tcptunnel.WithDestinationHost(net.JoinHostPort(dstHostname, dstPort)),
tcptunnel.WithProxyHost(pomeriumURL.Host),
tcptunnel.WithServiceAccount(serviceAccountOptions.serviceAccount),
tcptunnel.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile),
tcptunnel.WithTLSConfig(tlsConfig),
return tunnel.New(
tunnel.WithDestinationHost(net.JoinHostPort(dstHostname, dstPort)),
tunnel.WithProxyHost(pomeriumURL.Host),
tunnel.WithServiceAccount(serviceAccountOptions.serviceAccount),
tunnel.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile),
tunnel.WithTLSConfig(tlsConfig),
), nil
}

Expand All @@ -165,7 +165,7 @@ func hijackProxyConnect(req *http.Request, client net.Conn, ctx *goproxy.ProxyCt
log.Error().Err(err).Msg("Failed to send response to client")
return
}
if err := tun.Run(req.Context(), client, tcptunnel.DiscardEvents()); err != nil {
if err := tun.Run(req.Context(), client, tunnel.DiscardEvents()); err != nil {
log.Error().Err(err).Msg("Failed to run TCP tunnel")
}
}
20 changes: 10 additions & 10 deletions cmd/pomerium-cli/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

"github.com/spf13/cobra"

"github.com/pomerium/cli/tcptunnel"
"github.com/pomerium/cli/tunnel"
)

var tcpCmdOptions struct {
Expand All @@ -36,7 +36,7 @@ var tcpCmd = &cobra.Command{
Short: "creates a TCP tunnel through Pomerium",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
destinationAddr, proxyURL, err := tcptunnel.ParseURLs(args[0], tcpCmdOptions.pomeriumURL)
destinationAddr, proxyURL, err := tunnel.ParseURLs(args[0], tcpCmdOptions.pomeriumURL)
if err != nil {
return err
}
Expand All @@ -57,17 +57,17 @@ var tcpCmd = &cobra.Command{
cancel()
}()

tun := tcptunnel.New(
tcptunnel.WithBrowserCommand(browserOptions.command),
tcptunnel.WithDestinationHost(destinationAddr),
tcptunnel.WithProxyHost(proxyURL.Host),
tcptunnel.WithServiceAccount(serviceAccountOptions.serviceAccount),
tcptunnel.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile),
tcptunnel.WithTLSConfig(tlsConfig),
tun := tunnel.New(
tunnel.WithBrowserCommand(browserOptions.command),
tunnel.WithDestinationHost(destinationAddr),
tunnel.WithProxyHost(proxyURL.Host),
tunnel.WithServiceAccount(serviceAccountOptions.serviceAccount),
tunnel.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile),
tunnel.WithTLSConfig(tlsConfig),
)

if tcpCmdOptions.listen == "-" {
err = tun.Run(ctx, readWriter{Reader: os.Stdin, Writer: os.Stdout}, tcptunnel.DiscardEvents())
err = tun.Run(ctx, readWriter{Reader: os.Stdin, Writer: os.Stdout}, tunnel.DiscardEvents())
} else {
err = tun.RunListener(ctx, tcpCmdOptions.listen)
}
Expand Down
2 changes: 1 addition & 1 deletion tcptunnel/config.go → tunnel/config.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tcptunnel
package tunnel

import (
"crypto/tls"
Expand Down
16 changes: 8 additions & 8 deletions tcptunnel/events.go → tunnel/events.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package tcptunnel
package tunnel

import (
"context"
)

// TunnelEvents is used to notify on the tunnel state transition
type TunnelEvents interface {
// EventSink is used to notify on the tunnel state transition
type EventSink interface {
// OnConnecting is called when listener is accepting a new connection from client
OnConnecting(context.Context)
// OnConnected is called when a connection is successfully
Expand All @@ -19,22 +19,22 @@ type TunnelEvents interface {
}

// DiscardEvents returns a broadcaster that discards all events
func DiscardEvents() TunnelEvents {
func DiscardEvents() EventSink {
return discardEvents{}
}

type discardEvents struct{}

// OnConnecting is called when listener is accepting a new connection from client
func (d discardEvents) OnConnecting(_ context.Context) {}
func (discardEvents) OnConnecting(_ context.Context) {}

// OnConnected is called when a connection is successfully
// established to the remote destination via pomerium proxy
func (d discardEvents) OnConnected(_ context.Context) {}
func (discardEvents) OnConnected(_ context.Context) {}

// OnAuthRequired is called after listener accepted a new connection from client,
// but has to perform user authentication first
func (d discardEvents) OnAuthRequired(_ context.Context, _ string) {}
func (discardEvents) OnAuthRequired(_ context.Context, _ string) {}

// OnDisconnected is called when connection to client was closed
func (d discardEvents) OnDisconnected(_ context.Context, _ error) {}
func (discardEvents) OnDisconnected(_ context.Context, _ error) {}
20 changes: 10 additions & 10 deletions tcptunnel/tcptunnel.go → tunnel/tunnel.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Package tcptunnel contains an implementation of a TCP tunnel via HTTP Connect.
package tcptunnel
// Package tunnel contains an implementation of a TCP tunnel via HTTP Connect.
package tunnel

import (
"bufio"
Expand Down Expand Up @@ -90,7 +90,7 @@ func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) erro
}

// Run establishes a TCP tunnel via HTTP Connect and forwards all traffic from/to local.
func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter, evt TunnelEvents) error {
func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter, eventSink EventSink) error {
rawJWT, err := tun.cfg.jwtCache.LoadJWT(tun.jwtCacheKey())
switch {
// if there is no error, or it is one of the pre-defined cliutil errors,
Expand All @@ -100,13 +100,13 @@ func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter, evt TunnelEvent
errors.Is(err, jwt.ErrInvalid),
errors.Is(err, jwt.ErrNotFound):
default:
return fmt.Errorf("tcptunnel: failed to load JWT: %w", err)
return fmt.Errorf("tunnel: failed to load JWT: %w", err)
}
return tun.run(ctx, evt, local, rawJWT, 0)
return tun.run(ctx, eventSink, local, rawJWT, 0)
}

func (tun *Tunnel) run(ctx context.Context, evt TunnelEvents, local io.ReadWriter, rawJWT string, retryCount int) error {
evt.OnConnecting(ctx)
func (tun *Tunnel) run(ctx context.Context, eventSink EventSink, local io.ReadWriter, rawJWT string, retryCount int) error {
eventSink.OnConnecting(ctx)

hdr := http.Header{}
if rawJWT != "" {
Expand Down Expand Up @@ -174,7 +174,7 @@ func (tun *Tunnel) run(ctx context.Context, evt TunnelEvents, local io.ReadWrite
serverURL.Scheme = "https"
}

rawJWT, err = tun.auth.GetJWT(ctx, serverURL, func(authURL string) { evt.OnAuthRequired(ctx, authURL) })
rawJWT, err = tun.auth.GetJWT(ctx, serverURL, func(authURL string) { eventSink.OnAuthRequired(ctx, authURL) })
if err != nil {
return fmt.Errorf("failed to get authentication JWT: %w", err)
}
Expand All @@ -184,7 +184,7 @@ func (tun *Tunnel) run(ctx context.Context, evt TunnelEvents, local io.ReadWrite
return fmt.Errorf("failed to store JWT: %w", err)
}

return tun.run(ctx, evt, local, rawJWT, retryCount+1)
return tun.run(ctx, eventSink, local, rawJWT, retryCount+1)
}
fallthrough
default:
Expand All @@ -193,7 +193,7 @@ func (tun *Tunnel) run(ctx context.Context, evt TunnelEvents, local io.ReadWrite
}

log.Println("connection established")
evt.OnConnected(ctx)
eventSink.OnConnected(ctx)

errc := make(chan error, 2)
go func() {
Expand Down
2 changes: 1 addition & 1 deletion tcptunnel/tcptunnel_test.go → tunnel/tunnel_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tcptunnel
package tunnel

import (
"bufio"
Expand Down
2 changes: 1 addition & 1 deletion tcptunnel/urls.go → tunnel/urls.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tcptunnel
package tunnel

import (
"fmt"
Expand Down
2 changes: 1 addition & 1 deletion tcptunnel/urls_test.go → tunnel/urls_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tcptunnel
package tunnel

import (
"errors"
Expand Down

0 comments on commit 0085418

Please sign in to comment.