Skip to content

Commit

Permalink
Use caching_sha2_password for proxy auth (#24)
Browse files Browse the repository at this point in the history
Up until now, mysql_native_password was used for auth. This is however
removed in MySQL 9.x and this is the default that Homebrew installs on
MacOS.

While we can also try to deal with installing older versions on MacOS,
alternatively we update the auth for the proxy to
caching_sha2_password.

The one thing that this breaks is very old MySQL 5.7 clients. Anything
older than MySQL 5.7.23 (released 2018-07-27) would break with this. We
don't really support 5.7 for the proxy anyway though.

Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink authored Oct 9, 2024
1 parent 61feaf3 commit 7fdfa92
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 5 deletions.
74 changes: 74 additions & 0 deletions authserver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package psdbproxy

import (
"crypto/rand"
"net"

"vitess.io/vitess/go/mysql"
querypb "vitess.io/vitess/go/vt/proto/query"
)

// cachingSha2AuthServerNone takes all comers.
type cachingSha2AuthServerNone struct{}

type noneGetter struct{}

// AuthMethods returns the list of registered auth methods
// implemented by this auth server.
func (a *cachingSha2AuthServerNone) AuthMethods() []mysql.AuthMethod {
return []mysql.AuthMethod{&mysqlCachingSha2AuthMethod{}}
}

// DefaultAuthMethodDescription returns MysqlNativePassword as the default
// authentication method for the auth server implementation.
func (a *cachingSha2AuthServerNone) DefaultAuthMethodDescription() mysql.AuthMethodDescription {
return mysql.CachingSha2Password
}

// Get returns the empty string
func (ng *noneGetter) Get() *querypb.VTGateCallerID {
return &querypb.VTGateCallerID{Username: "root"}
}

type mysqlCachingSha2AuthMethod struct{}

func (n *mysqlCachingSha2AuthMethod) Name() mysql.AuthMethodDescription {
return mysql.CachingSha2Password
}

func (n *mysqlCachingSha2AuthMethod) HandleUser(conn *mysql.Conn, user string) bool {
return true
}

func (n *mysqlCachingSha2AuthMethod) AuthPluginData() ([]byte, error) {
salt, err := newSalt()
if err != nil {
return nil, err
}
return append(salt, 0), nil
}

func (n *mysqlCachingSha2AuthMethod) AllowClearTextWithoutTLS() bool {
return true
}

func (n *mysqlCachingSha2AuthMethod) HandleAuthPluginData(c *mysql.Conn, user string, serverAuthPluginData []byte, clientAuthPluginData []byte, remoteAddr net.Addr) (mysql.Getter, error) {
return &noneGetter{}, nil
}

func newSalt() ([]byte, error) {
salt := make([]byte, 20)
if _, err := rand.Read(salt); err != nil {
return nil, err
}

// Salt must be a legal UTF8 string.
for i := range len(salt) {
salt[i] &= 0x7f
if salt[i] == '\x00' || salt[i] == '$' {
salt[i]++
}
}

return salt, nil
}
3 changes: 2 additions & 1 deletion cmd/psdbproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/planetscale/psdb/auth"
"github.com/planetscale/psdbproxy"
"github.com/spf13/pflag"
"vitess.io/vitess/go/mysql"
)

var (
Expand Down Expand Up @@ -40,7 +41,7 @@ func main() {

ch := make(chan error)
go func() {
ch <- s.ListenAndServe()
ch <- s.ListenAndServe(mysql.CachingSha2Password)
}()

logger.Info(
Expand Down
18 changes: 14 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Server struct {
mysql.UnimplementedHandler
}

func (s *Server) Serve(l net.Listener) error {
func (s *Server) Serve(l net.Listener, authMethod mysql.AuthMethodDescription) error {
s.ensureSetup()

handler, err := s.handler()
Expand All @@ -37,9 +37,19 @@ func (s *Server) Serve(l net.Listener) error {
return err
}

var auth mysql.AuthServer
switch authMethod {
case mysql.CachingSha2Password:
auth = &cachingSha2AuthServerNone{}
case mysql.MysqlNativePassword:
auth = mysql.NewAuthServerNone()
default:
return fmt.Errorf("unsupported auth method: %v", authMethod)
}

listener, err := mysql.NewListenerWithConfig(mysql.ListenerConfig{
Listener: l,
AuthServer: mysql.NewAuthServerNone(),
AuthServer: auth,
Handler: handler,
ConnReadTimeout: s.ReadTimeout,
ConnWriteTimeout: 30 * time.Second,
Expand All @@ -55,12 +65,12 @@ func (s *Server) Serve(l net.Listener) error {
return nil
}

func (s *Server) ListenAndServe() error {
func (s *Server) ListenAndServe(authMethod mysql.AuthMethodDescription) error {
l, err := net.Listen("tcp", s.Addr)
if err != nil {
return err
}
return s.Serve(l)
return s.Serve(l, authMethod)
}

func (s *Server) Shutdown() {
Expand Down

0 comments on commit 7fdfa92

Please sign in to comment.