Skip to content

Commit

Permalink
ssh: add ServerConfig.PreAuthConnCallback, ServerPreAuthConn (banner)…
Browse files Browse the repository at this point in the history
… interface

Fixes golang/go#68688

Change-Id: Id5f72b32c61c9383a26ec182339486a432c7cdf5
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/613856
LUCI-TryBot-Result: Go LUCI <[email protected]>
Auto-Submit: Nicola Murino <[email protected]>
Reviewed-by: Jonathan Amsterdam <[email protected]>
Reviewed-by: Nicola Murino <[email protected]>
Reviewed-by: Roland Shoemaker <[email protected]>
  • Loading branch information
bradfitz authored and gopherbot committed Jan 18, 2025
1 parent 71d3a4c commit a8ea4be
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 15 deletions.
14 changes: 12 additions & 2 deletions ssh/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ type handshakeTransport struct {
pendingPackets [][]byte // Used when a key exchange is in progress.
writePacketsLeft uint32
writeBytesLeft int64
userAuthComplete bool // whether the user authentication phase is complete

// If the read loop wants to schedule a kex, it pings this
// channel, and the write loop will send out a kex
Expand Down Expand Up @@ -552,16 +553,25 @@ func (t *handshakeTransport) sendKexInit() error {
return nil
}

var errSendBannerPhase = errors.New("ssh: SendAuthBanner outside of authentication phase")

func (t *handshakeTransport) writePacket(p []byte) error {
t.mu.Lock()
defer t.mu.Unlock()

switch p[0] {
case msgKexInit:
return errors.New("ssh: only handshakeTransport can send kexInit")
case msgNewKeys:
return errors.New("ssh: only handshakeTransport can send newKeys")
case msgUserAuthBanner:
if t.userAuthComplete {
return errSendBannerPhase
}
case msgUserAuthSuccess:
t.userAuthComplete = true
}

t.mu.Lock()
defer t.mu.Unlock()
if t.writeError != nil {
return t.writeError
}
Expand Down
50 changes: 37 additions & 13 deletions ssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,27 @@ type GSSAPIWithMICConfig struct {
Server GSSAPIServer
}

// SendAuthBanner implements [ServerPreAuthConn].
func (s *connection) SendAuthBanner(msg string) error {
return s.transport.writePacket(Marshal(&userAuthBannerMsg{
Message: msg,
}))
}

func (*connection) unexportedMethodForFutureProofing() {}

// ServerPreAuthConn is the interface available on an incoming server
// connection before authentication has completed.
type ServerPreAuthConn interface {
unexportedMethodForFutureProofing() // permits growing ServerPreAuthConn safely later, ala testing.TB

ConnMetadata

// SendAuthBanner sends a banner message to the client.
// It returns an error once the authentication phase has ended.
SendAuthBanner(string) error
}

// ServerConfig holds server specific configuration data.
type ServerConfig struct {
// Config contains configuration shared between client and server.
Expand Down Expand Up @@ -118,6 +139,12 @@ type ServerConfig struct {
// attempts.
AuthLogCallback func(conn ConnMetadata, method string, err error)

// PreAuthConnCallback, if non-nil, is called upon receiving a new connection
// before any authentication has started. The provided ServerPreAuthConn
// can be used at any time before authentication is complete, including
// after this callback has returned.
PreAuthConnCallback func(ServerPreAuthConn)

// ServerVersion is the version identification string to announce in
// the public handshake.
// If empty, a reasonable default is used.
Expand Down Expand Up @@ -488,14 +515,18 @@ func (b *BannerError) Error() string {
}

func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
if config.PreAuthConnCallback != nil {
config.PreAuthConnCallback(s)
}

sessionID := s.transport.getSessionID()
var cache pubKeyCache
var perms *Permissions

authFailures := 0
noneAuthCount := 0
var authErrs []error
var displayedBanner bool
var calledBannerCallback bool
partialSuccessReturned := false
// Set the initial authentication callbacks from the config. They can be
// changed if a PartialSuccessError is returned.
Expand Down Expand Up @@ -542,14 +573,10 @@ userAuthLoop:

s.user = userAuthReq.User

if !displayedBanner && config.BannerCallback != nil {
displayedBanner = true
msg := config.BannerCallback(s)
if msg != "" {
bannerMsg := &userAuthBannerMsg{
Message: msg,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
if !calledBannerCallback && config.BannerCallback != nil {
calledBannerCallback = true
if msg := config.BannerCallback(s); msg != "" {
if err := s.SendAuthBanner(msg); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -762,10 +789,7 @@ userAuthLoop:
var bannerErr *BannerError
if errors.As(authErr, &bannerErr) {
if bannerErr.Message != "" {
bannerMsg := &userAuthBannerMsg{
Message: bannerErr.Message,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
if err := s.SendAuthBanner(bannerErr.Message); err != nil {
return nil, err
}
}
Expand Down
86 changes: 86 additions & 0 deletions ssh/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,92 @@ func TestPublicKeyCallbackLastSeen(t *testing.T) {
}
}

func TestPreAuthConnAndBanners(t *testing.T) {
testDone := make(chan struct{})
defer close(testDone)

authConnc := make(chan ServerPreAuthConn, 1)
serverConfig := &ServerConfig{
PreAuthConnCallback: func(c ServerPreAuthConn) {
t.Logf("got ServerPreAuthConn: %v", c)
authConnc <- c // for use later in the test
for _, s := range []string{"hello1", "hello2"} {
if err := c.SendAuthBanner(s); err != nil {
t.Errorf("failed to send banner %q: %v", s, err)
}
}
// Now start a goroutine to spam SendAuthBanner in hopes
// of hitting a race.
go func() {
for {
select {
case <-testDone:
return
default:
if err := c.SendAuthBanner("attempted-race"); err != nil && err != errSendBannerPhase {
t.Errorf("unexpected error from SendAuthBanner: %v", err)
}
time.Sleep(5 * time.Millisecond)
}
}
}()
},
NoClientAuth: true,
NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
t.Logf("got NoClientAuthCallback")
return &Permissions{}, nil
},
}
serverConfig.AddHostKey(testSigners["rsa"])

var banners []string
clientConfig := &ClientConfig{
User: "test",
HostKeyCallback: InsecureIgnoreHostKey(),
BannerCallback: func(msg string) error {
if msg != "attempted-race" {
banners = append(banners, msg)
}
return nil
},
}

c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go newServer(c1, serverConfig)
c, _, _, err := NewClientConn(c2, "", clientConfig)
if err != nil {
t.Fatalf("client connection failed: %v", err)
}
defer c.Close()

wantBanners := []string{
"hello1",
"hello2",
}
if !reflect.DeepEqual(banners, wantBanners) {
t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
}

// Now that we're authenticated, verify that use of SendBanner
// is an error.
var bc ServerPreAuthConn
select {
case bc = <-authConnc:
default:
t.Fatal("expected ServerPreAuthConn")
}
if err := bc.SendAuthBanner("wrong-phase"); err == nil {
t.Error("unexpected success of SendAuthBanner after authentication")
} else if err != errSendBannerPhase {
t.Errorf("unexpected error: %v; want %v", err, errSendBannerPhase)
}
}

type markerConn struct {
closed uint32
used uint32
Expand Down

0 comments on commit a8ea4be

Please sign in to comment.