Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 1 addition & 30 deletions server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,7 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err
return c.handleAuthSwitchResponse()
}

switch authPluginName {
case mysql.AUTH_NATIVE_PASSWORD:
return c.compareNativePasswordAuthData(clientAuthData, c.credential)

case mysql.AUTH_CACHING_SHA2_PASSWORD:
if !c.cachingSha2FullAuth {
if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil {
return err
}
if c.cachingSha2FullAuth {
return c.handleAuthSwitchResponse()
}
return nil
}
// AuthMoreData packet already sent, do full auth
return c.handleCachingSha2PasswordFullAuth(clientAuthData)

case mysql.AUTH_SHA256_PASSWORD:
cont, err := c.handlePublicKeyRetrieval(clientAuthData)
if err != nil {
return err
}
if !cont {
return nil
}
return c.compareSha256PasswordAuthData(clientAuthData, c.credential)

default:
return errors.Errorf("unknown authentication plugin name '%s'", authPluginName)
}
return c.serverConf.authProvider.Authenticate(c, authPluginName, clientAuthData)
}

func (c *Conn) acquirePassword() error {
Expand Down
50 changes: 50 additions & 0 deletions server/authentication_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package server

import (
"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/errors"
)

type AuthenticationProvider interface {
Authenticate(c *Conn, authPluginName string, clientAuthData []byte) error
Validate(authPluginName string) bool
}

type DefaultAuthenticationProvider struct{}

func (d *DefaultAuthenticationProvider) Authenticate(c *Conn, authPluginName string, clientAuthData []byte) error {
switch authPluginName {
case mysql.AUTH_NATIVE_PASSWORD:
return c.compareNativePasswordAuthData(clientAuthData, c.credential)

case mysql.AUTH_CACHING_SHA2_PASSWORD:
if !c.cachingSha2FullAuth {
if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil {
return err
}
if c.cachingSha2FullAuth {
return c.handleAuthSwitchResponse()
}
return nil
}
// AuthMoreData packet already sent, do full auth
return c.handleCachingSha2PasswordFullAuth(clientAuthData)

case mysql.AUTH_SHA256_PASSWORD:
cont, err := c.handlePublicKeyRetrieval(clientAuthData)
if err != nil {
return err
}
if !cont {
return nil
}
return c.compareSha256PasswordAuthData(clientAuthData, c.credential)

default:
return errors.Errorf("unknown authentication plugin name '%s'", authPluginName)
}
}

func (d *DefaultAuthenticationProvider) Validate(authPluginName string) bool {
return isAuthMethodSupported(authPluginName)
}
1 change: 0 additions & 1 deletion server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ func (s *Server) NewConn(conn net.Conn, user string, password string, h Handler)
return s.NewCustomizedConn(conn, p, h)
}

// NewCustomizedConn: create connection with customized server settings
func (s *Server) NewCustomizedConn(conn net.Conn, p CredentialProvider, h Handler) (*Conn, error) {
var packetConn *packet.Conn
if s.tlsConfig != nil {
Expand Down
10 changes: 9 additions & 1 deletion server/server_conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Server struct {
pubKey []byte
tlsConfig *tls.Config
cacheShaPassword *sync.Map // 'user@host' -> SHA256(SHA256(PASSWORD))
authProvider AuthenticationProvider
}

// NewDefaultServer: New mysql server with default settings.
Expand All @@ -56,6 +57,7 @@ func NewDefaultServer() *Server {
pubKey: getPublicKeyFromCert(certPem),
tlsConfig: tlsConf,
cacheShaPassword: new(sync.Map),
authProvider: &DefaultAuthenticationProvider{},
}
}

Expand All @@ -69,7 +71,12 @@ func NewDefaultServer() *Server {
// And for TLS support, you can specify self-signed or CA-signed certificates and decide whether the client needs to provide
// a signed or unsigned certificate to provide different level of security.
func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string, pubKey []byte, tlsConfig *tls.Config) *Server {
if !isAuthMethodSupported(defaultAuthMethod) {
authProvider := &DefaultAuthenticationProvider{}
return NewServerWithAuth(serverVersion, collationId, defaultAuthMethod, pubKey, tlsConfig, authProvider)
}

func NewServerWithAuth(serverVersion string, collationId uint8, defaultAuthMethod string, pubKey []byte, tlsConfig *tls.Config, authProvider AuthenticationProvider) *Server {
if authProvider == nil || !authProvider.Validate(defaultAuthMethod) {
panic(fmt.Sprintf("server authentication method '%s' is not supported", defaultAuthMethod))
}

Expand All @@ -91,6 +98,7 @@ func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string
pubKey: pubKey,
tlsConfig: tlsConfig,
cacheShaPassword: new(sync.Map),
authProvider: authProvider,
}
}

Expand Down
Loading