diff --git a/authserver.go b/authserver.go new file mode 100644 index 0000000..7678d25 --- /dev/null +++ b/authserver.go @@ -0,0 +1,74 @@ +package psdbproxy + +import ( + "crypto/rand" + "net" + + "vitess.io/vitess/go/mysql" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +// cachingSha2AuthServerNone takes all comers. +type cachingSha2AuthServerNone struct{} + +type noneGetter struct{} + +// AuthMethods returns the list of registered auth methods +// implemented by this auth server. +func (a *cachingSha2AuthServerNone) AuthMethods() []mysql.AuthMethod { + return []mysql.AuthMethod{&mysqlCachingSha2AuthMethod{}} +} + +// DefaultAuthMethodDescription returns MysqlNativePassword as the default +// authentication method for the auth server implementation. +func (a *cachingSha2AuthServerNone) DefaultAuthMethodDescription() mysql.AuthMethodDescription { + return mysql.CachingSha2Password +} + +// Get returns the empty string +func (ng *noneGetter) Get() *querypb.VTGateCallerID { + return &querypb.VTGateCallerID{Username: "root"} +} + +type mysqlCachingSha2AuthMethod struct{} + +func (n *mysqlCachingSha2AuthMethod) Name() mysql.AuthMethodDescription { + return mysql.CachingSha2Password +} + +func (n *mysqlCachingSha2AuthMethod) HandleUser(conn *mysql.Conn, user string) bool { + return true +} + +func (n *mysqlCachingSha2AuthMethod) AuthPluginData() ([]byte, error) { + salt, err := newSalt() + if err != nil { + return nil, err + } + return append(salt, 0), nil +} + +func (n *mysqlCachingSha2AuthMethod) AllowClearTextWithoutTLS() bool { + return true +} + +func (n *mysqlCachingSha2AuthMethod) HandleAuthPluginData(c *mysql.Conn, user string, serverAuthPluginData []byte, clientAuthPluginData []byte, remoteAddr net.Addr) (mysql.Getter, error) { + return &noneGetter{}, nil +} + +func newSalt() ([]byte, error) { + salt := make([]byte, 20) + if _, err := rand.Read(salt); err != nil { + return nil, err + } + + // Salt must be a legal UTF8 string. + for i := range len(salt) { + salt[i] &= 0x7f + if salt[i] == '\x00' || salt[i] == '$' { + salt[i]++ + } + } + + return salt, nil +} diff --git a/cmd/psdbproxy/main.go b/cmd/psdbproxy/main.go index 7fcccea..f3b4b03 100644 --- a/cmd/psdbproxy/main.go +++ b/cmd/psdbproxy/main.go @@ -7,6 +7,7 @@ import ( "github.com/planetscale/psdb/auth" "github.com/planetscale/psdbproxy" "github.com/spf13/pflag" + "vitess.io/vitess/go/mysql" ) var ( @@ -40,7 +41,7 @@ func main() { ch := make(chan error) go func() { - ch <- s.ListenAndServe() + ch <- s.ListenAndServe(mysql.CachingSha2Password) }() logger.Info( diff --git a/server.go b/server.go index d9e2f5a..3a67759 100644 --- a/server.go +++ b/server.go @@ -26,7 +26,7 @@ type Server struct { mysql.UnimplementedHandler } -func (s *Server) Serve(l net.Listener) error { +func (s *Server) Serve(l net.Listener, authMethod mysql.AuthMethodDescription) error { s.ensureSetup() handler, err := s.handler() @@ -37,9 +37,19 @@ func (s *Server) Serve(l net.Listener) error { return err } + var auth mysql.AuthServer + switch authMethod { + case mysql.CachingSha2Password: + auth = &cachingSha2AuthServerNone{} + case mysql.MysqlNativePassword: + auth = mysql.NewAuthServerNone() + default: + return fmt.Errorf("unsupported auth method: %v", authMethod) + } + listener, err := mysql.NewListenerWithConfig(mysql.ListenerConfig{ Listener: l, - AuthServer: mysql.NewAuthServerNone(), + AuthServer: auth, Handler: handler, ConnReadTimeout: s.ReadTimeout, ConnWriteTimeout: 30 * time.Second, @@ -55,12 +65,12 @@ func (s *Server) Serve(l net.Listener) error { return nil } -func (s *Server) ListenAndServe() error { +func (s *Server) ListenAndServe(authMethod mysql.AuthMethodDescription) error { l, err := net.Listen("tcp", s.Addr) if err != nil { return err } - return s.Serve(l) + return s.Serve(l, authMethod) } func (s *Server) Shutdown() {