From 5684eacf52b47bddccd39e4c0df38cdd4da1a6a0 Mon Sep 17 00:00:00 2001 From: Kim Gert Nielsen Date: Mon, 20 Oct 2025 09:47:58 +0200 Subject: [PATCH] Make authentication an interface for easy customization --- server/auth.go | 31 +------------------ server/authentication_provider.go | 50 +++++++++++++++++++++++++++++++ server/conn.go | 1 - server/server_conf.go | 10 ++++++- 4 files changed, 60 insertions(+), 32 deletions(-) create mode 100644 server/authentication_provider.go diff --git a/server/auth.go b/server/auth.go index 13402dd45..8d6830727 100644 --- a/server/auth.go +++ b/server/auth.go @@ -27,36 +27,7 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err return c.handleAuthSwitchResponse() } - switch authPluginName { - case mysql.AUTH_NATIVE_PASSWORD: - return c.compareNativePasswordAuthData(clientAuthData, c.credential) - - case mysql.AUTH_CACHING_SHA2_PASSWORD: - if !c.cachingSha2FullAuth { - if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil { - return err - } - if c.cachingSha2FullAuth { - return c.handleAuthSwitchResponse() - } - return nil - } - // AuthMoreData packet already sent, do full auth - return c.handleCachingSha2PasswordFullAuth(clientAuthData) - - case mysql.AUTH_SHA256_PASSWORD: - cont, err := c.handlePublicKeyRetrieval(clientAuthData) - if err != nil { - return err - } - if !cont { - return nil - } - return c.compareSha256PasswordAuthData(clientAuthData, c.credential) - - default: - return errors.Errorf("unknown authentication plugin name '%s'", authPluginName) - } + return c.serverConf.authProvider.Authenticate(c, authPluginName, clientAuthData) } func (c *Conn) acquirePassword() error { diff --git a/server/authentication_provider.go b/server/authentication_provider.go new file mode 100644 index 000000000..0d91753c4 --- /dev/null +++ b/server/authentication_provider.go @@ -0,0 +1,50 @@ +package server + +import ( + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/pingcap/errors" +) + +type AuthenticationProvider interface { + Authenticate(c *Conn, authPluginName string, clientAuthData []byte) error + Validate(authPluginName string) bool +} + +type DefaultAuthenticationProvider struct{} + +func (d *DefaultAuthenticationProvider) Authenticate(c *Conn, authPluginName string, clientAuthData []byte) error { + switch authPluginName { + case mysql.AUTH_NATIVE_PASSWORD: + return c.compareNativePasswordAuthData(clientAuthData, c.credential) + + case mysql.AUTH_CACHING_SHA2_PASSWORD: + if !c.cachingSha2FullAuth { + if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil { + return err + } + if c.cachingSha2FullAuth { + return c.handleAuthSwitchResponse() + } + return nil + } + // AuthMoreData packet already sent, do full auth + return c.handleCachingSha2PasswordFullAuth(clientAuthData) + + case mysql.AUTH_SHA256_PASSWORD: + cont, err := c.handlePublicKeyRetrieval(clientAuthData) + if err != nil { + return err + } + if !cont { + return nil + } + return c.compareSha256PasswordAuthData(clientAuthData, c.credential) + + default: + return errors.Errorf("unknown authentication plugin name '%s'", authPluginName) + } +} + +func (d *DefaultAuthenticationProvider) Validate(authPluginName string) bool { + return isAuthMethodSupported(authPluginName) +} diff --git a/server/conn.go b/server/conn.go index d68e1efc3..12e160bcb 100644 --- a/server/conn.go +++ b/server/conn.go @@ -68,7 +68,6 @@ func (s *Server) NewConn(conn net.Conn, user string, password string, h Handler) return s.NewCustomizedConn(conn, p, h) } -// NewCustomizedConn: create connection with customized server settings func (s *Server) NewCustomizedConn(conn net.Conn, p CredentialProvider, h Handler) (*Conn, error) { var packetConn *packet.Conn if s.tlsConfig != nil { diff --git a/server/server_conf.go b/server/server_conf.go index de049f000..1777bfdc7 100644 --- a/server/server_conf.go +++ b/server/server_conf.go @@ -32,6 +32,7 @@ type Server struct { pubKey []byte tlsConfig *tls.Config cacheShaPassword *sync.Map // 'user@host' -> SHA256(SHA256(PASSWORD)) + authProvider AuthenticationProvider } // NewDefaultServer: New mysql server with default settings. @@ -56,6 +57,7 @@ func NewDefaultServer() *Server { pubKey: getPublicKeyFromCert(certPem), tlsConfig: tlsConf, cacheShaPassword: new(sync.Map), + authProvider: &DefaultAuthenticationProvider{}, } } @@ -69,7 +71,12 @@ func NewDefaultServer() *Server { // And for TLS support, you can specify self-signed or CA-signed certificates and decide whether the client needs to provide // a signed or unsigned certificate to provide different level of security. func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string, pubKey []byte, tlsConfig *tls.Config) *Server { - if !isAuthMethodSupported(defaultAuthMethod) { + authProvider := &DefaultAuthenticationProvider{} + return NewServerWithAuth(serverVersion, collationId, defaultAuthMethod, pubKey, tlsConfig, authProvider) +} + +func NewServerWithAuth(serverVersion string, collationId uint8, defaultAuthMethod string, pubKey []byte, tlsConfig *tls.Config, authProvider AuthenticationProvider) *Server { + if authProvider == nil || !authProvider.Validate(defaultAuthMethod) { panic(fmt.Sprintf("server authentication method '%s' is not supported", defaultAuthMethod)) } @@ -91,6 +98,7 @@ func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string pubKey: pubKey, tlsConfig: tlsConfig, cacheShaPassword: new(sync.Map), + authProvider: authProvider, } }