Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
greedy52 committed Jan 20, 2025
1 parent c794c2f commit 7fc45b7
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 82 deletions.
2 changes: 1 addition & 1 deletion lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func (s *localSite) dialAndForwardGit(params reversetunnelclient.DialParams) (_
HostUUID: s.srv.ID,
TargetServer: params.TargetServer,
Clock: s.clock,
KeyManager: s.srv.GitKeyManager,
KeyManager: s.srv.gitKeyManager,
}
remoteServer, err := git.NewForwardServer(serverConfig)
if err != nil {
Expand Down
28 changes: 14 additions & 14 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ type server struct {

// proxySigner is used to sign PROXY headers to securely propagate client IP information
proxySigner multiplexer.PROXYHeaderSigner

// gitKeyManager manages keys for git proxies.
gitKeyManager *git.KeyManager
}

// Config is a reverse tunnel server configuration
Expand Down Expand Up @@ -224,9 +227,6 @@ type Config struct {

// PROXYSigner is used to sign PROXY headers to securely propagate client IP information.
PROXYSigner multiplexer.PROXYHeaderSigner

// GitKeyManager manages keys for git proxies.
GitKeyManager *git.KeyManager
}

// CheckAndSetDefaults checks parameters and sets default values
Expand Down Expand Up @@ -286,17 +286,6 @@ func (cfg *Config) CheckAndSetDefaults() error {
if cfg.CertAuthorityWatcher == nil {
return trace.BadParameter("missing parameter CertAuthorityWatcher")
}
if cfg.GitKeyManager == nil {
var err error
cfg.GitKeyManager, err = git.NewKeyManager(&git.KeyManagerConfig{
ParentContext: cfg.Context,
AuthClient: cfg.LocalAuthClient,
AccessPoint: cfg.LocalAccessPoint,
})
if err != nil {
return trace.Wrap(err)
}
}
return nil
}

Expand Down Expand Up @@ -334,6 +323,16 @@ func NewServer(cfg Config) (reversetunnelclient.Server, error) {
return nil, trace.Wrap(err)
}

gitKeyManager, err := git.NewKeyManager(&git.KeyManagerConfig{
ParentContext: ctx,
AuthClient: cfg.LocalAuthClient,
AccessPoint: cfg.LocalAccessPoint,
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

srv := &server{
Config: cfg,
localAuthClient: cfg.LocalAuthClient,
Expand All @@ -346,6 +345,7 @@ func NewServer(cfg Config) (reversetunnelclient.Server, error) {
logger: cfg.Logger,
offlineThreshold: offlineThreshold,
proxySigner: cfg.PROXYSigner,
gitKeyManager: gitKeyManager,
}

localSite, err := newLocalSite(srv, cfg.ClusterName, cfg.LocalAuthAddresses)
Expand Down
8 changes: 7 additions & 1 deletion lib/srv/git/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,16 @@ func NewForwardServer(cfg *ForwardServerConfig) (*ForwardServer, error) {
return nil, trace.Wrap(err)
}

verifyRemoteHost, err := cfg.KeyManager.HostKeyCallback(cfg.TargetServer)
if err != nil {
return nil, trace.Wrap(err)
}

logger := slog.With(teleport.ComponentKey, teleport.ComponentForwardingGit,
"src_addr", cfg.SrcAddr.String(),
"dst_addr", cfg.DstAddr.String(),
)

s := &ForwardServer{
StreamEmitter: cfg.Emitter,
cfg: cfg,
Expand All @@ -188,7 +194,7 @@ func NewForwardServer(cfg *ForwardServerConfig) (*ForwardServer, error) {
logger: logger,
reply: sshutils.NewReply(logger),
id: uuid.NewString(),
verifyRemoteHost: cfg.KeyManager.HostKeyCallback(cfg.TargetServer),
verifyRemoteHost: verifyRemoteHost,
makeRemoteSigner: makeRemoteSigner,
}
// TODO(greedy52) extract common parts from srv.NewAuthHandlers like
Expand Down
79 changes: 46 additions & 33 deletions lib/srv/git/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package git
import (
"context"
"encoding/json"
"io"
"log/slog"
"net/http"
"sync/atomic"
Expand All @@ -35,6 +36,7 @@ import (
"github.com/gravitational/teleport"
integrationv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/integration/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/sshutils"
Expand All @@ -43,8 +45,7 @@ import (
// githubKeyDownloader downloads SSH keys from the GitHub meta API. The keys
// are used to verify GitHub server when forwarding Git commands to it.
type githubKeyDownloader struct {
keys atomic.Value
etag string
keys atomic.Pointer[[]ssh.PublicKey]

logger *slog.Logger
apiEndpoint string
Expand All @@ -66,20 +67,14 @@ func (d *githubKeyDownloader) Start(ctx context.Context) {
d.logger.InfoContext(ctx, "Starting GitHub key downloader")
defer d.logger.InfoContext(ctx, "GitHub key downloader stopped")

// Fire a refresh immediately.
// Fire a refresh immediately then once a day afterward.
timer := d.clock.NewTimer(0)
defer timer.Stop()
for {
select {
case <-timer.Chan():
// Schedule a refresh in 24 hours upon success and in 5 minutes upon
// failure.
if err := d.refresh(ctx); err != nil {
d.logger.WarnContext(ctx, "Failed to download GitHub server keys", "error", err)
timer.Reset(time.Minute * 5)
} else {
timer.Reset(time.Hour * 24)
}
d.refreshWithRetries(ctx)
timer.Reset(time.Hour * 24)
case <-ctx.Done():
return
}
Expand All @@ -92,47 +87,66 @@ func (d *githubKeyDownloader) GetKnownKeys() ([]ssh.PublicKey, error) {
if keys == nil {
return nil, trace.NotFound("server keys not found for github.com")
}
return keys.([]ssh.PublicKey), nil
return *keys, nil
}

func (d *githubKeyDownloader) refreshWithRetries(ctx context.Context) {
retry, err := retryutils.NewRetryV2(retryutils.RetryV2Config{
Driver: retryutils.NewExponentialDriver(time.Second),
Max: time.Minute * 10,
Jitter: retryutils.HalfJitter,
Clock: d.clock,
})
if err != nil {
d.logger.WarnContext(ctx, "Failed to create retry", "error", err)
return
}

for {
if err := d.refresh(ctx); err != nil {
d.logger.WarnContext(ctx, "Failed to download GitHub server keys", "error", err)
} else {
return
}

select {
case <-ctx.Done():
return
case <-retry.After():
retry.Inc()
}
}
}

func (d *githubKeyDownloader) refresh(ctx context.Context) error {
d.logger.DebugContext(ctx, "Calling GitHub meta API", "endpoint", d.apiEndpoint)
// Meta API reference:
// https://docs.github.com/en/rest/meta/meta#get-github-meta-information
req, err := http.NewRequest("GET", d.apiEndpoint, nil)
req, err := http.NewRequestWithContext(ctx, "GET", d.apiEndpoint, nil)
if err != nil {
return trace.Wrap(err)
}

// Add ETag check.
if d.etag != "" {
req.Header.Set("If-None-Match", d.etag)
}

client := &http.Client{
Timeout: defaults.HTTPRequestTimeout,
client, err := defaults.HTTPClient()
if err != nil {
return trace.Wrap(err, "creating HTTP client")
}
resp, err := client.Do(req)
if err != nil {
return trace.Wrap(err)
}
defer resp.Body.Close()

// Nothing changed. Just update the last check time.
if resp.StatusCode == http.StatusNotModified {
d.logger.DebugContext(ctx, "GitHub metadata is up-to-date")
return nil
body, err := io.ReadAll(resp.Body)
if err != nil {
return trace.Wrap(err, "reading GitHub meta API response body")
}

meta := struct {
SSHKeys []string `json:"ssh_keys"`
}{}
if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil {
return trace.Wrap(err, "decoding meta API response")
}

if len(meta.SSHKeys) == 0 {
return trace.NotFound("no SSH keys found")
if err := json.Unmarshal(body, &meta); err != nil {
return trace.Wrap(err, "decoding GitHub meta API response")
}

var keys []ssh.PublicKey
Expand All @@ -144,9 +158,8 @@ func (d *githubKeyDownloader) refresh(ctx context.Context) error {
keys = append(keys, publicKey)
}

d.etag = resp.Header.Get("ETag")
d.keys.Store(keys)
d.logger.DebugContext(ctx, "Fetched GitHub metadata", "ssh_keys", meta.SSHKeys, "etag", d.etag)
d.keys.Store(&keys)
d.logger.DebugContext(ctx, "Fetched GitHub metadata", "ssh_keys", meta.SSHKeys)
return nil
}

Expand Down
22 changes: 2 additions & 20 deletions lib/srv/git/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"testing"
"time"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -153,7 +152,6 @@ func TestMakeGitHubSigner(t *testing.T) {
type mockGitHubMetaAPIServer struct {
*httptest.Server

etag string
metaResponse []byte
}

Expand All @@ -170,7 +168,6 @@ func newMockGitHubMetaAPIServer(t *testing.T, keys ...ssh.PublicKey) *mockGitHub
require.NoError(t, err)

m := &mockGitHubMetaAPIServer{
etag: uuid.NewString(),
metaResponse: metaResponse,
}
m.Server = httptest.NewServer(m)
Expand All @@ -179,10 +176,6 @@ func newMockGitHubMetaAPIServer(t *testing.T, keys ...ssh.PublicKey) *mockGitHub
}

func (m *mockGitHubMetaAPIServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("If-None-Match") == m.etag {
w.WriteHeader(http.StatusNotModified)
return
}
w.WriteHeader(http.StatusOK)
w.Write(m.metaResponse)
}
Expand Down Expand Up @@ -215,8 +208,7 @@ func Test_githubKeyDownloader(t *testing.T) {
name: "success update",
setup: func(d *githubKeyDownloader) {
d.apiEndpoint = mockSuccessServer.URL
d.etag = "old-etag"
d.keys.Store([]ssh.PublicKey{publicKey, publicKey})
d.keys.Store(&[]ssh.PublicKey{publicKey, publicKey})
},
checkRefreshError: require.NoError,
expectGetCount: 1,
Expand All @@ -225,21 +217,11 @@ func Test_githubKeyDownloader(t *testing.T) {
name: "failure should not override existing keys",
setup: func(d *githubKeyDownloader) {
d.apiEndpoint = mockFailureServer.URL
d.keys.Store([]ssh.PublicKey{publicKey, publicKey})
d.keys.Store(&[]ssh.PublicKey{publicKey, publicKey})
},
checkRefreshError: require.Error,
expectGetCount: 2,
},
{
name: "ETag match",
setup: func(d *githubKeyDownloader) {
d.apiEndpoint = mockSuccessServer.URL
d.etag = mockSuccessServer.etag
d.keys.Store([]ssh.PublicKey{publicKey, publicKey})
},
checkRefreshError: require.NoError,
expectGetCount: 2,
},
}

for _, test := range tests {
Expand Down
21 changes: 10 additions & 11 deletions lib/srv/git/key_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,28 +133,27 @@ func (m *KeyManager) startWatcher(ctx context.Context) error {
}

// HostKeyCallback creates an ssh.HostKeyCallback for verifying the target git-hosting service.
func (m *KeyManager) HostKeyCallback(targetServer types.Server) ssh.HostKeyCallback {
return func(_ string, _ net.Addr, key ssh.PublicKey) error {
switch targetServer.GetSubKind() {
case types.SubKindGitHub:
return trace.Wrap(m.verifyGitHub(key))
default:
return trace.BadParameter("unsupported subkind %q", targetServer.GetSubKind())
}
func (m *KeyManager) HostKeyCallback(targetServer types.Server) (ssh.HostKeyCallback, error) {
switch targetServer.GetSubKind() {
case types.SubKindGitHub:
return m.verifyGitHub, nil
default:
return nil, trace.BadParameter("unsupported subkind %q", targetServer.GetSubKind())
}
}

func (m *KeyManager) verifyGitHub(key ssh.PublicKey) error {
func (m *KeyManager) verifyGitHub(_ string, _ net.Addr, key ssh.PublicKey) error {
knownKeys, err := m.cfg.githubServerKeys.GetKnownKeys()
if err != nil {
return trace.Wrap(err)
}
marshaledKey := key.Marshal()
for _, knownKey := range knownKeys {
if knownKey.Type() == key.Type() {
if bytes.Equal(knownKey.Marshal(), key.Marshal()) {
if bytes.Equal(knownKey.Marshal(), marshaledKey) {
return nil
}
}
}
return trace.BadParameter("cannot verify github.com: unknown server key %q", string(key.Marshal()))
return trace.BadParameter("cannot verify github.com")
}
15 changes: 13 additions & 2 deletions lib/srv/git/key_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,13 @@ func TestKeyManager_verify_github(t *testing.T) {
require.NoError(t, err)

t.Run("connect and verify", func(t *testing.T) {
hostKeyCallback, err := m.HostKeyCallback(githubServer)
require.NoError(t, err)
require.EventuallyWithT(t, func(collect *assert.CollectT) {
conn, err := ssh.Dial("tcp", targetAddress, &ssh.ClientConfig{
User: "git",
Auth: clientAuth,
HostKeyCallback: m.HostKeyCallback(githubServer),
HostKeyCallback: hostKeyCallback,
})
assert.NoError(collect, err)
if conn != nil {
Expand All @@ -99,8 +101,17 @@ func TestKeyManager_verify_github(t *testing.T) {
})

t.Run("unknown key", func(t *testing.T) {
hostKeyCallback, err := m.HostKeyCallback(githubServer)
require.NoError(t, err)
unknownHostKey, err := apisshutils.MakeRealHostCert(caSigner)
require.NoError(t, err)
require.Error(t, m.HostKeyCallback(githubServer)("github.com", utils.MustParseAddr(targetAddress), unknownHostKey.PublicKey()))
require.Error(t, hostKeyCallback("github.com", utils.MustParseAddr(targetAddress), unknownHostKey.PublicKey()))
})

t.Run("unknown Git server type", func(t *testing.T) {
unsupported := githubServer.DeepCopy()
unsupported.SetSubKind("unsupported")
_, err := m.HostKeyCallback(unsupported)
require.Error(t, err)
})
}

0 comments on commit 7fc45b7

Please sign in to comment.