diff --git a/authserver.go b/authserver.go new file mode 100644 index 0000000..5fedcb4 --- /dev/null +++ b/authserver.go @@ -0,0 +1,43 @@ +package psdbproxy + +import ( + "net" + + "vitess.io/vitess/go/mysql" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +// cachingSha2AuthServerNone takes all comers. +type cachingSha2AuthServerNone struct{} + +type noneGetter struct{} + +func (a *cachingSha2AuthServerNone) UserEntryWithPassword(conn *mysql.Conn, user string, password string, remoteAddr net.Addr) (mysql.Getter, error) { + return &noneGetter{}, nil +} + +func (a *cachingSha2AuthServerNone) UserEntryWithCacheHash(conn *mysql.Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, mysql.CacheState, error) { + return &noneGetter{}, mysql.AuthAccepted, nil +} + +// AuthMethods returns the list of registered auth methods +// implemented by this auth server. +func (a *cachingSha2AuthServerNone) AuthMethods() []mysql.AuthMethod { + return []mysql.AuthMethod{mysql.NewSha2CachingAuthMethod(a, a, a)} +} + +// DefaultAuthMethodDescription returns MysqlNativePassword as the default +// authentication method for the auth server implementation. +func (a *cachingSha2AuthServerNone) DefaultAuthMethodDescription() mysql.AuthMethodDescription { + return mysql.CachingSha2Password +} + +// HandleUser validates if this user can use this auth method +func (a *cachingSha2AuthServerNone) HandleUser(user string) bool { + return true +} + +// Get returns the empty string +func (ng *noneGetter) Get() *querypb.VTGateCallerID { + return &querypb.VTGateCallerID{Username: "userData1"} +} diff --git a/cmd/psdbproxy/main.go b/cmd/psdbproxy/main.go index 7fcccea..f3b4b03 100644 --- a/cmd/psdbproxy/main.go +++ b/cmd/psdbproxy/main.go @@ -7,6 +7,7 @@ import ( "github.com/planetscale/psdb/auth" "github.com/planetscale/psdbproxy" "github.com/spf13/pflag" + "vitess.io/vitess/go/mysql" ) var ( @@ -40,7 +41,7 @@ func main() { ch := make(chan error) go func() { - ch <- s.ListenAndServe() + ch <- s.ListenAndServe(mysql.CachingSha2Password) }() logger.Info( diff --git a/server.go b/server.go index d9e2f5a..3a67759 100644 --- a/server.go +++ b/server.go @@ -26,7 +26,7 @@ type Server struct { mysql.UnimplementedHandler } -func (s *Server) Serve(l net.Listener) error { +func (s *Server) Serve(l net.Listener, authMethod mysql.AuthMethodDescription) error { s.ensureSetup() handler, err := s.handler() @@ -37,9 +37,19 @@ func (s *Server) Serve(l net.Listener) error { return err } + var auth mysql.AuthServer + switch authMethod { + case mysql.CachingSha2Password: + auth = &cachingSha2AuthServerNone{} + case mysql.MysqlNativePassword: + auth = mysql.NewAuthServerNone() + default: + return fmt.Errorf("unsupported auth method: %v", authMethod) + } + listener, err := mysql.NewListenerWithConfig(mysql.ListenerConfig{ Listener: l, - AuthServer: mysql.NewAuthServerNone(), + AuthServer: auth, Handler: handler, ConnReadTimeout: s.ReadTimeout, ConnWriteTimeout: 30 * time.Second, @@ -55,12 +65,12 @@ func (s *Server) Serve(l net.Listener) error { return nil } -func (s *Server) ListenAndServe() error { +func (s *Server) ListenAndServe(authMethod mysql.AuthMethodDescription) error { l, err := net.Listen("tcp", s.Addr) if err != nil { return err } - return s.Serve(l) + return s.Serve(l, authMethod) } func (s *Server) Shutdown() {