diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2e07fea9..d745b401 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,6 +39,7 @@ jobs: '8.4', # LTS '8.0', '5.7', + 'mariadb-11.7', # in order to test parsec 'mariadb-11.4', # LTS 'mariadb-11.2', 'mariadb-11.1', diff --git a/AUTHORS b/AUTHORS index 510b869b..a261819f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -37,6 +37,7 @@ Daniel Montoya Daniel Nichter DaniĆ«l van Eeden Dave Protasowski +Diego Dupin Dirkjan Bussink DisposaBoy Egor Smolyakov diff --git a/README.md b/README.md index da4593cc..3851601b 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ db.SetMaxIdleConns(10) The Data Source Name has a common format, like e.g. [PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by squared brackets): ``` -[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] +[[username][:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] ``` A DSN in its fullest form: @@ -172,6 +172,16 @@ Default: false `allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. +##### `AllowDialogPasswords` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +`AllowDialogPasswords=true` allows using the [PAM client side plugin](https://mariadb.com/kb/en/authentication-plugin-pam/) if required by an account, such as one defined with the PAM authentication plugin. Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. + ##### `allowFallbackToPlaintext` @@ -453,6 +463,16 @@ Default: none [Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time. + +##### `OtherPasswd` + +``` +Type: comma-delimited string of password for MariaDB PAM authentication, if requiring more than one password +Valid Values: (,,...) +Default: none +``` + + ##### System Variables Any other parameters are interpreted as system variables: @@ -534,6 +554,19 @@ See [context support in the database/sql package](https://golang.org/doc/go1.8#d > The `QueryContext`, `ExecContext`, etc. variants provided by `database/sql` will cause the connection to be closed if the provided context is cancelled or timed out before the result is received by the driver. +### Authentication Plugin System + +The driver implements a pluggable authentication system that supports various authentication methods used by MySQL and MariaDB servers. The built-in authentication plugins include: + +- `mysql_native_password` - The default MySQL authentication method +- `caching_sha2_password` - Default authentication method in MySQL 8.0+ +- `mysql_clear_password` - Cleartext authentication (requires `allowCleartextPasswords=true`) +- `mysql_old_password` - Old MySQL authentication (requires `allowOldPasswords=true`) +- `sha256_password` - SHA256 authentication +- `parsec` - MariaDB 11.6+ PARSEC authentication +- `client_ed25519` - MariaDB Ed25519 authentication +- `dialog` - MariaDB PAM authentication (requires `AllowDialogPasswords=true`) + ### `LOAD DATA LOCAL INFILE` support For this feature you need direct access to the package. Therefore you must change the import path (no `_`): ```go diff --git a/auth.go b/auth.go index 74e1bd03..0470bb9b 100644 --- a/auth.go +++ b/auth.go @@ -9,17 +9,10 @@ package mysql import ( - "crypto/rand" + "bytes" "crypto/rsa" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "crypto/x509" - "encoding/pem" "fmt" "sync" - - "filippo.io/edwards25519" ) // server pub keys registry @@ -137,348 +130,122 @@ func pwHash(password []byte) (result [2]uint32) { return } -// Hash password using insecure pre 4.1 method -func scrambleOldPassword(scramble []byte, password string) []byte { - scramble = scramble[:8] - - hashPw := pwHash([]byte(password)) - hashSc := pwHash(scramble) - - r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) - - var out [8]byte - for i := range out { - out[i] = r.NextByte() + 64 - } - - mask := r.NextByte() - for i := range out { - out[i] ^= mask - } - - return out[:] -} - -// Hash password using 4.1+ method (SHA1) -func scramblePassword(scramble []byte, password string) []byte { - if len(password) == 0 { - return nil - } - - // stage1Hash = SHA1(password) - crypt := sha1.New() - crypt.Write([]byte(password)) - stage1 := crypt.Sum(nil) - - // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) - // inner Hash - crypt.Reset() - crypt.Write(stage1) - hash := crypt.Sum(nil) - - // outer Hash - crypt.Reset() - crypt.Write(scramble) - crypt.Write(hash) - scramble = crypt.Sum(nil) - - // token = scrambleHash XOR stage1Hash - for i := range scramble { - scramble[i] ^= stage1[i] - } - return scramble -} - -// Hash password using MySQL 8+ method (SHA256) -func scrambleSHA256Password(scramble []byte, password string) []byte { - if len(password) == 0 { - return nil +// handleAuthResult processes the initial authentication packet and manages subsequent +// authentication flow. It reads the first authentication packet and hands off processing +// to the appropriate auth plugin. +// +// Parameters: +// - initialSeed: The initial random seed sent from server to client +// - authPlugin: The authentication plugin to use for this connection +// +// Returns an error if authentication fails or if there's a network/protocol error. +func (mc *mysqlConn) handleAuthResult(initialSeed []byte, authPlugin AuthPlugin) error { + data, err := mc.readPacket() + if err != nil { + return err } - // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) - - crypt := sha256.New() - crypt.Write([]byte(password)) - message1 := crypt.Sum(nil) - - crypt.Reset() - crypt.Write(message1) - message1Hash := crypt.Sum(nil) - - crypt.Reset() - crypt.Write(message1Hash) - crypt.Write(scramble) - message2 := crypt.Sum(nil) - - for i := range message1 { - message1[i] ^= message2[i] + data, err = authPlugin.ProcessAuthResponse(data, initialSeed, mc) + if err != nil { + return err } - return message1 + return mc.processAuthResponse(data, initialSeed) } -func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { - plain := make([]byte, len(password)+1) - copy(plain, password) - for i := range plain { - j := i % len(seed) - plain[i] ^= seed[j] +// processAuthResponse handles the different types of server responses during +// the authentication phase, routing each response type to the appropriate handler. +// +// Parameters: +// - data: The packet data received from the server +// - initialSeed: The initial random seed sent from server to client +// +// Returns an error if authentication fails or if there's a protocol error. +func (mc *mysqlConn) processAuthResponse(data []byte, initialSeed []byte) error { + switch data[0] { + case iOK: + return mc.resultUnchanged().handleOkPacket(data) + case iERR: + return mc.handleErrorPacket(data) + case iEOF: + return mc.handleAuthSwitch(data, initialSeed) + default: + return ErrMalformPkt } - sha1 := sha1.New() - return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) } -// authEd25519 does ed25519 authentication used by MariaDB. -func authEd25519(scramble []byte, password string) ([]byte, error) { - // Derived from https://github.com/MariaDB/server/blob/d8e6bb00888b1f82c031938f4c8ac5d97f6874c3/plugin/auth_ed25519/ref10/sign.c - // Code style is from https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/crypto/ed25519/ed25519.go;l=207 - h := sha512.Sum512([]byte(password)) - - s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32]) - if err != nil { - return nil, err - } - A := (&edwards25519.Point{}).ScalarBaseMult(s) +// handleAuthSwitch processes an authentication plugin switch request from the server. +// This happens when the server wants to use a different authentication method than +// what was initially negotiated. +// +// Parameters: +// - data: The packet data received from the server containing switch request information +// - initialSeed: The initial random seed from the server +// +// Returns an error if the requested plugin is not supported, or if there's an error +// during the authentication process. +func (mc *mysqlConn) handleAuthSwitch(data []byte, initialSeed []byte) error { + plugin, authData := mc.parseAuthSwitchData(data, initialSeed) - mh := sha512.New() - mh.Write(h[32:]) - mh.Write(scramble) - messageDigest := mh.Sum(nil) - r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest) - if err != nil { - return nil, err + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + return fmt.Errorf("this authentication plugin '%s' is not supported", plugin) } - R := (&edwards25519.Point{}).ScalarBaseMult(r) - - kh := sha512.New() - kh.Write(R.Bytes()) - kh.Write(A.Bytes()) - kh.Write(scramble) - hramDigest := kh.Sum(nil) - k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest) - if err != nil { - return nil, err - } - - S := k.MultiplyAdd(k, s, r) - - return append(R.Bytes(), S.Bytes()...), nil -} - -func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { - enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) + cachedEncryptPassword, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { return err } - return mc.writeAuthSwitchPacket(enc) -} - -func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { - switch plugin { - case "caching_sha2_password": - authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) - return authResp, nil - - case "mysql_old_password": - if !mc.cfg.AllowOldPasswords { - return nil, ErrOldPassword - } - if len(mc.cfg.Passwd) == 0 { - return nil, nil - } - // Note: there are edge cases where this should work but doesn't; - // this is currently "wontfix": - // https://github.com/go-sql-driver/mysql/issues/184 - authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) - return authResp, nil - - case "mysql_clear_password": - if !mc.cfg.AllowCleartextPasswords { - return nil, ErrCleartextPassword - } - // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html - // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - return append([]byte(mc.cfg.Passwd), 0), nil - - case "mysql_native_password": - if !mc.cfg.AllowNativePasswords { - return nil, ErrNativePassword - } - // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html - // Native password authentication only need and will need 20-byte challenge. - authResp := scramblePassword(authData[:20], mc.cfg.Passwd) - return authResp, nil - - case "sha256_password": - if len(mc.cfg.Passwd) == 0 { - return []byte{0}, nil - } - // unlike caching_sha2_password, sha256_password does not accept - // cleartext password on unix transport. - if mc.cfg.TLS != nil { - // write cleartext auth packet - return append([]byte(mc.cfg.Passwd), 0), nil - } - - pubKey := mc.cfg.pubKey - if pubKey == nil { - // request public key from server - return []byte{1}, nil - } - - // encrypted password - enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) - return enc, err - - case "client_ed25519": - if len(authData) != 32 { - return nil, ErrMalformPkt - } - return authEd25519(authData, mc.cfg.Passwd) - default: - mc.log("unknown auth plugin:", plugin) - return nil, ErrUnknownPlugin + if err := mc.writeAuthSwitchPacket(cachedEncryptPassword); err != nil { + return err } -} -func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { - // Read Result Packet - authData, newPlugin, err := mc.readAuthResult() + data, err = mc.readPacket() if err != nil { return err } - // handle auth plugin switch, if requested - if newPlugin != "" { - // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is - // sent and we have to keep using the cipher sent in the init packet. - if authData == nil { - authData = oldAuthData - } else { - // copy data from read buffer to owned slice - copy(oldAuthData, authData) - } - - plugin = newPlugin - - authResp, err := mc.auth(authData, plugin) - if err != nil { - return err - } - if err = mc.writeAuthSwitchPacket(authResp); err != nil { - return err - } - - // Read Result Packet - authData, newPlugin, err = mc.readAuthResult() + switch data[0] { + case iERR, iOK, iEOF: + return mc.processAuthResponse(data, initialSeed) + default: + data, err = authPlugin.ProcessAuthResponse(data, authData, mc) if err != nil { return err } - - // Do not allow to change the auth plugin more than once - if newPlugin != "" { - return ErrMalformPkt - } + return mc.processAuthResponse(data, initialSeed) } +} - switch plugin { - - // https://dev.mysql.com/blog-archive/preparing-your-community-connector-for-mysql-8-part-2-sha256/ - case "caching_sha2_password": - switch len(authData) { - case 0: - return nil // auth successful - case 1: - switch authData[0] { - case cachingSha2PasswordFastAuthSuccess: - if err = mc.resultUnchanged().readResultOK(); err == nil { - return nil // auth successful - } - - case cachingSha2PasswordPerformFullAuthentication: - if mc.cfg.TLS != nil || mc.cfg.Net == "unix" { - // write cleartext auth packet - err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) - if err != nil { - return err - } - } else { - pubKey := mc.cfg.pubKey - if pubKey == nil { - // request public key from server - data, err := mc.buf.takeSmallBuffer(4 + 1) - if err != nil { - return err - } - data[4] = cachingSha2PasswordRequestPublicKey - err = mc.writePacket(data) - if err != nil { - return err - } - - if data, err = mc.readPacket(); err != nil { - return err - } - - if data[0] != iAuthMoreData { - return fmt.Errorf("unexpected resp from server for caching_sha2_password, perform full authentication") - } - - // parse public key - block, rest := pem.Decode(data[1:]) - if block == nil { - return fmt.Errorf("no pem data found, data: %s", rest) - } - pkix, err := x509.ParsePKIXPublicKey(block.Bytes) - if err != nil { - return err - } - pubKey = pkix.(*rsa.PublicKey) - } - - // send encrypted password - err = mc.sendEncryptedPassword(oldAuthData, pubKey) - if err != nil { - return err - } - } - return mc.resultUnchanged().readResultOK() - - default: - return ErrMalformPkt - } - default: - return ErrMalformPkt - } +// parseAuthSwitchData extracts the authentication plugin name and associated data +// from an authentication switch request packet. +// +// Parameters: +// - data: The packet data from an authentication switch request +// - initialSeed: The initial seed, used as fallback for old authentication method +// +// Returns: +// - string: The name of the requested authentication plugin +// - []byte: The authentication data to be used with the plugin +func (mc *mysqlConn) parseAuthSwitchData(data []byte, initialSeed []byte) (string, []byte) { + if len(data) == 1 { + // Special case for the old authentication protocol + return "mysql_old_password", initialSeed + } - case "sha256_password": - switch len(authData) { - case 0: - return nil // auth successful - default: - block, _ := pem.Decode(authData) - if block == nil { - return fmt.Errorf("no Pem data found, data: %s", authData) - } - - pub, err := x509.ParsePKIXPublicKey(block.Bytes) - if err != nil { - return err - } - - // send encrypted password - err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) - if err != nil { - return err - } - return mc.resultUnchanged().readResultOK() - } + pluginEndIndex := bytes.IndexByte(data, 0x00) + if pluginEndIndex < 0 { + return "", nil + } - default: - return nil // auth successful + plugin := string(data[1:pluginEndIndex]) + authData := data[pluginEndIndex+1:] + if len(authData) > 0 && authData[len(authData)-1] == 0 { + authData = authData[:len(authData)-1] } - return err + savedAuthData := make([]byte, len(authData)) + copy(savedAuthData, authData) + return plugin, savedAuthData } diff --git a/auth_caching_sha2.go b/auth_caching_sha2.go new file mode 100644 index 00000000..37479125 --- /dev/null +++ b/auth_caching_sha2.go @@ -0,0 +1,177 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2023 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "fmt" +) + +// CachingSha2PasswordPlugin implements the caching_sha2_password authentication +// This plugin provides secure password-based authentication using SHA256 and RSA encryption, +// with server-side caching of password verifiers for improved performance. +type CachingSha2PasswordPlugin struct { + AuthPlugin +} + +func init() { + RegisterAuthPlugin(&CachingSha2PasswordPlugin{}) +} + +func (p *CachingSha2PasswordPlugin) GetPluginName() string { + return "caching_sha2_password" +} + +// InitAuth initializes the authentication process by scrambling the password. +// +// The scrambling process uses a three-step SHA256 hash: +// 1. SHA256(password) +// 2. SHA256(SHA256(password)) +// 3. XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) +func (p *CachingSha2PasswordPlugin) InitAuth(authData []byte, cfg *Config) ([]byte, error) { + return scrambleSHA256Password(authData, cfg.Passwd), nil +} + +// ProcessAuthResponse processes the server's response to our authentication attempt. +// +// The authentication flow can take several paths: +// 1. Fast auth success (password found in cache) +// 2. Full authentication needed: +// a. With TLS: send cleartext password +// b. Without TLS: +// - Request server's public key if not cached +// - Encrypt password with RSA public key +// - Send encrypted password +func (p *CachingSha2PasswordPlugin) ProcessAuthResponse(packet []byte, authData []byte, mc *mysqlConn) ([]byte, error) { + + if len(packet) == 0 { + return nil, fmt.Errorf("%w: empty auth response packet", ErrMalformPkt) + } + + switch packet[0] { + case iOK, iERR, iEOF: + return packet, nil + case iAuthMoreData: + switch len(packet) { + case 1: + return mc.readPacket() // Auth successful + + case 2: + switch packet[1] { + case 3: + // the password was found in the server's cache + return mc.readPacket() + + case 4: + // indicates full authentication is needed + // For TLS connections or Unix socket, send cleartext password + if mc.cfg.TLS != nil || mc.cfg.Net == "unix" { + err := mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) + if err != nil { + return nil, fmt.Errorf("failed to send cleartext password: %w", err) + } + } else { + // For non-TLS connections, use RSA encryption + pubKey := mc.cfg.pubKey + if pubKey == nil { + // Request public key from server + packet, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + return nil, fmt.Errorf("failed to allocate buffer: %w", err) + } + packet[4] = 2 + if err = mc.writePacket(packet); err != nil { + return nil, fmt.Errorf("failed to request public key: %w", err) + } + + // Read public key packet + if packet, err = mc.readPacket(); err != nil { + return nil, fmt.Errorf("failed to read public key: %w", err) + } + + if packet[0] != iAuthMoreData { + return nil, fmt.Errorf("unexpected packet type %d when requesting public key", packet[0]) + } + + // Parse public key from PEM format + block, rest := pem.Decode(packet[1:]) + if block == nil { + return nil, fmt.Errorf("invalid PEM data in auth response: %q", rest) + } + + // Parse the public key + pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + pubKey = pkix.(*rsa.PublicKey) + } + + // Encrypt and send password + enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) + if err != nil { + return nil, fmt.Errorf("failed to encrypt password: %w", err) + } + if err = mc.writeAuthSwitchPacket(enc); err != nil { + return nil, fmt.Errorf("failed to send encrypted password: %w", err) + } + } + return mc.readPacket() + + default: + return nil, fmt.Errorf("%w: unknown auth state %d", ErrMalformPkt, packet[1]) + } + + default: + return nil, fmt.Errorf("%w: unexpected packet length %d", ErrMalformPkt, len(packet)) + } + default: + return nil, fmt.Errorf("%w: expected auth more data packet", ErrMalformPkt) + } +} + +// scrambleSHA256Password implements MySQL 8+ password scrambling. +// +// The algorithm is: +// 1. SHA256(password) +// 2. SHA256(SHA256(SHA256(password))) +// 3. XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) +// +// This provides a way to verify the password without storing it in cleartext. +func scrambleSHA256Password(scramble []byte, password string) []byte { + if len(password) == 0 { + return []byte{} + } + + // First hash: SHA256(password) + crypt := sha256.New() + crypt.Write([]byte(password)) + message1 := crypt.Sum(nil) + + // Second hash: SHA256(SHA256(password)) + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + // Third hash: SHA256(SHA256(SHA256(password)), scramble) + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + // XOR the first hash with the third hash + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} diff --git a/auth_cleartext.go b/auth_cleartext.go new file mode 100644 index 00000000..d5e7a615 --- /dev/null +++ b/auth_cleartext.go @@ -0,0 +1,53 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2023 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +// ClearPasswordPlugin implements the mysql_clear_password authentication. +// +// This plugin sends passwords in cleartext and should only be used: +// 1. Over TLS/SSL connections +// 2. Over Unix domain sockets +// 3. When required by authentication methods like PAM +// +// See: http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html +// +// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html +type ClearPasswordPlugin struct { + AuthPlugin +} + +func init() { + RegisterAuthPlugin(&ClearPasswordPlugin{}) +} + +func (p *ClearPasswordPlugin) GetPluginName() string { + return "mysql_clear_password" +} + +// InitAuth implements the cleartext password authentication. +// It will return an error if AllowCleartextPasswords is false. +// +// The cleartext password is sent as a null-terminated string. +// This is required by the server to support external authentication +// systems that need access to the original password. +func (p *ClearPasswordPlugin) InitAuth(authData []byte, cfg *Config) ([]byte, error) { + if !cfg.AllowCleartextPasswords { + return nil, ErrCleartextPassword + } + + // Send password as null-terminated string + return append([]byte(cfg.Passwd), 0), nil +} + +// ProcessAuthResponse handles the server's response to our authentication attempt. +// For cleartext authentication, we simply return the packet as is since no +// additional processing is needed. +func (p *ClearPasswordPlugin) ProcessAuthResponse(packet []byte, authData []byte, mc *mysqlConn) ([]byte, error) { + return packet, nil +} diff --git a/auth_dialog.go b/auth_dialog.go new file mode 100644 index 00000000..915166b6 --- /dev/null +++ b/auth_dialog.go @@ -0,0 +1,92 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "fmt" + "strings" +) + +const ( + dialogPluginName = "dialog" +) + +// dialogAuthPlugin implements the MariaDB PAM authentication plugin +type dialogAuthPlugin struct { + AuthPlugin +} + +func init() { + RegisterAuthPlugin(&dialogAuthPlugin{}) +} + +// GetPluginName returns the name of the authentication plugin +func (p *dialogAuthPlugin) GetPluginName() string { + return dialogPluginName +} + +// InitAuth initializes the authentication process +func (p *dialogAuthPlugin) InitAuth(authData []byte, cfg *Config) ([]byte, error) { + if !cfg.AllowDialogPasswords { + return nil, ErrDialogAuth + } + return append([]byte(cfg.Passwd), 0), nil +} + +// ProcessAuthResponse processes the authentication response from the server +func (p *dialogAuthPlugin) ProcessAuthResponse(packet []byte, authData []byte, conn *mysqlConn) ([]byte, error) { + + if len(packet) == 0 { + return nil, fmt.Errorf("%w: empty auth response packet", ErrMalformPkt) + } + + switch packet[0] { + case iOK, iERR, iEOF: + return packet, nil + default: + // Initialize passwords from Config + var passwords []string + if conn.cfg.OtherPasswd != "" { + // Additional passwords from OtherPasswd (comma separated) + otherPasswords := strings.Split(conn.cfg.OtherPasswd, ",") + passwords = append(passwords, otherPasswords...) + } + currentPasswordIndex := 0 + for { + var authResp []byte + if len(passwords) >= currentPasswordIndex+1 { + authResp = append([]byte(passwords[currentPasswordIndex]), 0) + } else { + authResp = []byte{0} + } + currentPasswordIndex++ + + // Send the authentication response + if err := conn.writeAuthSwitchPacket(authResp); err != nil { + return nil, fmt.Errorf("failed to write dialog packet: %w", err) + } + + // Read the final authentication result + response, err := conn.readPacket() + if err != nil { + return nil, fmt.Errorf("failed to read dialog packet: %w", err) + } + if len(response) == 0 { + return nil, fmt.Errorf("%w: empty auth response packet", ErrMalformPkt) + } + + switch response[0] { + case iOK, iERR, iEOF: + return response, nil + default: + continue + } + } + } +} diff --git a/auth_ed25519.go b/auth_ed25519.go new file mode 100644 index 00000000..0883b3f7 --- /dev/null +++ b/auth_ed25519.go @@ -0,0 +1,71 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2023 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/sha512" + "filippo.io/edwards25519" +) + +// ClientEd25519Plugin implements the client_ed25519 authentication +type ClientEd25519Plugin struct { + AuthPlugin +} + +func init() { + RegisterAuthPlugin(&ClientEd25519Plugin{}) +} + +func (p *ClientEd25519Plugin) GetPluginName() string { + return "client_ed25519" +} + +func (p *ClientEd25519Plugin) InitAuth(authData []byte, cfg *Config) ([]byte, error) { + // Derived from https://github.com/MariaDB/server/blob/d8e6bb00888b1f82c031938f4c8ac5d97f6874c3/plugin/auth_ed25519/ref10/sign.c + // Code style is from https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/crypto/ed25519/ed25519.go;l=207 + h := sha512.Sum512([]byte(cfg.Passwd)) + + s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32]) + if err != nil { + return nil, err + } + A := (&edwards25519.Point{}).ScalarBaseMult(s) + + mh := sha512.New() + mh.Write(h[32:]) + mh.Write(authData) + messageDigest := mh.Sum(nil) + r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest) + if err != nil { + return nil, err + } + + R := (&edwards25519.Point{}).ScalarBaseMult(r) + + kh := sha512.New() + kh.Write(R.Bytes()) + kh.Write(A.Bytes()) + kh.Write(authData) + hramDigest := kh.Sum(nil) + k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest) + if err != nil { + return nil, err + } + + S := k.MultiplyAdd(k, s, r) + + return append(R.Bytes(), S.Bytes()...), nil +} + +// ProcessAuthResponse handles the server's response to our authentication attempt. +// For Ed25519 authentication, we simply return the packet as is since the server +// will verify the signature we sent in InitAuth. +func (p *ClientEd25519Plugin) ProcessAuthResponse(packet []byte, authData []byte, mc *mysqlConn) ([]byte, error) { + return packet, nil +} diff --git a/auth_mysql_native.go b/auth_mysql_native.go new file mode 100644 index 00000000..d10f9c12 --- /dev/null +++ b/auth_mysql_native.go @@ -0,0 +1,68 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2023 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import "crypto/sha1" + +// NativePasswordPlugin implements the mysql_native_password authentication +type NativePasswordPlugin struct { + AuthPlugin +} + +func init() { + RegisterAuthPlugin(&NativePasswordPlugin{}) +} + +func (p *NativePasswordPlugin) GetPluginName() string { + return "mysql_native_password" +} + +func (p *NativePasswordPlugin) InitAuth(authData []byte, cfg *Config) ([]byte, error) { + if !cfg.AllowNativePasswords { + return nil, ErrNativePassword + } + if cfg.Passwd == "" { + return nil, nil + } + return p.scramblePassword(authData[:20], cfg.Passwd), nil +} + +// Hash password using 4.1+ method (SHA1) +func (p *NativePasswordPlugin) scramblePassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA1(password) + crypt := sha1.New() + crypt.Write([]byte(password)) + stage1 := crypt.Sum(nil) + + // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) + // inner Hash + crypt.Reset() + crypt.Write(stage1) + hash := crypt.Sum(nil) + + // outer Hash + crypt.Reset() + crypt.Write(scramble) + crypt.Write(hash) + scramble = crypt.Sum(nil) + + // token = scrambleHash XOR stage1Hash + for i := range scramble { + scramble[i] ^= stage1[i] + } + return scramble +} + +func (p *NativePasswordPlugin) ProcessAuthResponse(packet []byte, authData []byte, mc *mysqlConn) ([]byte, error) { + return packet, nil +} diff --git a/auth_old_password.go b/auth_old_password.go new file mode 100644 index 00000000..5428c7be --- /dev/null +++ b/auth_old_password.go @@ -0,0 +1,59 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2023 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +// OldPasswordPlugin implements the mysql_old_password authentication +type OldPasswordPlugin struct{ AuthPlugin } + +func init() { + RegisterAuthPlugin(&OldPasswordPlugin{}) +} + +func (p *OldPasswordPlugin) GetPluginName() string { + return "mysql_old_password" +} + +func (p *OldPasswordPlugin) InitAuth(authData []byte, cfg *Config) ([]byte, error) { + if !cfg.AllowOldPasswords { + return nil, ErrOldPassword + } + if cfg.Passwd == "" { + return nil, nil + } + // Note: there are edge cases where this should work but doesn't; + // this is currently "wontfix": + // https://github.com/go-sql-driver/mysql/issues/184 + return append(p.scrambleOldPassword(authData[:8], cfg.Passwd), 0), nil +} + +func (p *OldPasswordPlugin) ProcessAuthResponse(packet []byte, authData []byte, mc *mysqlConn) ([]byte, error) { + return packet, nil +} + +// Hash password using insecure pre 4.1 method +func (p *OldPasswordPlugin) scrambleOldPassword(scramble []byte, password string) []byte { + scramble = scramble[:8] + + hashPw := pwHash([]byte(password)) + hashSc := pwHash(scramble) + + r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) + + var out [8]byte + for i := range out { + out[i] = r.NextByte() + 64 + } + + mask := r.NextByte() + for i := range out { + out[i] ^= mask + } + + return out[:] +} diff --git a/auth_parsec.go b/auth_parsec.go new file mode 100644 index 00000000..29721155 --- /dev/null +++ b/auth_parsec.go @@ -0,0 +1,120 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2023 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/sha512" + "fmt" + + "golang.org/x/crypto/pbkdf2" +) + +// ParsecPlugin implements the parsec authentication +type ParsecPlugin struct{ AuthPlugin } + +func init() { + RegisterAuthPlugin(&ParsecPlugin{}) +} + +func (p *ParsecPlugin) GetPluginName() string { + return "parsec" +} + +func (p *ParsecPlugin) InitAuth(authData []byte, cfg *Config) ([]byte, error) { + return []byte{}, nil +} + +func (p *ParsecPlugin) processParsecExtSalt(extSalt, serverScramble []byte, password string) ([]byte, error) { + return ProcessParsecExtSalt(extSalt, serverScramble, password) +} + +func (p *ParsecPlugin) ProcessAuthResponse(packet []byte, authData []byte, mc *mysqlConn) ([]byte, error) { + // Process the ext-salt and generate the client nonce and signature + authResp, err := p.processParsecExtSalt(packet, authData, mc.cfg.Passwd) + if err != nil { + return nil, fmt.Errorf("parsec auth failed: %w", err) + } + + // Send the authentication response + if err = mc.writeAuthSwitchPacket(authResp); err != nil { + return nil, fmt.Errorf("failed to write auth switch packet: %w", err) + } + + // Read the final authentication result + return mc.readPacket() +} + +// ProcessParsecExtSalt processes the ext-salt received from the server and generates +// the authentication response for PARSEC authentication. +// +// The ext-salt format is: 'P' + iteration factor + salt +// The iteration count is 1024 << iteration factor (0x0 means 1024, 0x1 means 2048, etc.) +// +// The authentication process: +// 1. Validates the ext-salt format and iteration factor +// 2. Generates a random 32-byte client nonce +// 3. Derives a key using PBKDF2-HMAC-SHA512 with the password and salt +// 4. Uses the derived key as an Ed25519 seed to generate a signing key +// 5. Signs the concatenation of server scramble and client nonce +// 6. Returns the concatenation of client nonce and signature +// +// This function is exported for testing purposes +func ProcessParsecExtSalt(extSalt, serverScramble []byte, password string) ([]byte, error) { + // Validate ext-salt format and length + if len(extSalt) < 3 { + return nil, fmt.Errorf("%w: ext-salt too short", ErrParsecAuth) + } + if extSalt[0] != 'P' { + return nil, fmt.Errorf("%w: invalid ext-salt prefix", ErrParsecAuth) + } + + // Parse and validate iteration factor + iterationFactor := int(extSalt[1]) + if iterationFactor < 0 || iterationFactor > 3 { + return nil, fmt.Errorf("%w: invalid iteration factor", ErrParsecAuth) + } + + // Calculate iterations + iterations := 1024 << iterationFactor + + // Extract the salt (everything after prefix and iteration factor) + salt := extSalt[2:] + if len(salt) == 0 { + return nil, fmt.Errorf("%w: empty salt", ErrParsecAuth) + } + + // Generate a random 32-byte client nonce + clientNonce := make([]byte, 32) + if _, err := rand.Read(clientNonce); err != nil { + return nil, fmt.Errorf("failed to generate client nonce: %w", err) + } + + // Generate the PBKDF2 key using SHA-512 and the configured iterations + derivedKey := pbkdf2.Key([]byte(password), salt, iterations, ed25519.SeedSize, sha512.New) + + // Create message to sign (server scramble + client nonce) + message := make([]byte, len(serverScramble)+len(clientNonce)) + copy(message, serverScramble) + copy(message[len(serverScramble):], clientNonce) + + // Generate Ed25519 private key from derived key + privateKey := ed25519.NewKeyFromSeed(derivedKey[:ed25519.SeedSize]) + + // Sign the message + signature := ed25519.Sign(privateKey, message) + + // Prepare the authentication response: client nonce (32 bytes) + signature (64 bytes) + response := make([]byte, len(clientNonce)+len(signature)) + copy(response, clientNonce) + copy(response[len(clientNonce):], signature) + + return response, nil +} diff --git a/auth_plugin.go b/auth_plugin.go new file mode 100644 index 00000000..4ee0efdb --- /dev/null +++ b/auth_plugin.go @@ -0,0 +1,55 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2023 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +// AuthPlugin represents an authentication plugin for MySQL/MariaDB +type AuthPlugin interface { + // GetPluginName returns the name of the authentication plugin + GetPluginName() string + + // InitAuth initializes the authentication process and returns the initial response + // authData is the challenge data from the server + // password is the password for authentication + InitAuth(authData []byte, cfg *Config) ([]byte, error) + + // ProcessAuthResponse processes the authentication response from the server + // packet is the data from the server's auth response + // authData is the initial auth data from the server + // conn is the MySQL connection (for performing additional interactions if needed) + ProcessAuthResponse(packet []byte, authData []byte, conn *mysqlConn) ([]byte, error) +} + +// PluginRegistry is a registry of available authentication plugins +type PluginRegistry struct { + plugins map[string]AuthPlugin +} + +// NewPluginRegistry creates a new plugin registry +func NewPluginRegistry() *PluginRegistry { + registry := &PluginRegistry{ + plugins: make(map[string]AuthPlugin), + } + return registry +} + +// Register adds a plugin to the registry +func (r *PluginRegistry) Register(plugin AuthPlugin) { + r.plugins[plugin.GetPluginName()] = plugin +} + +// GetPlugin returns the plugin for the given name +func (r *PluginRegistry) GetPlugin(name string) (AuthPlugin, bool) { + plugin, ok := r.plugins[name] + return plugin, ok +} + +// RegisterAuthPlugin registers the plugin to the global plugin registry +func RegisterAuthPlugin(plugin AuthPlugin) { + globalPluginRegistry.Register(plugin) +} diff --git a/auth_sha256.go b/auth_sha256.go new file mode 100644 index 00000000..7cd2a0df --- /dev/null +++ b/auth_sha256.go @@ -0,0 +1,133 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2023 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "encoding/pem" + "fmt" +) + +// Sha256PasswordPlugin implements the sha256_password authentication +// This plugin provides secure password-based authentication using SHA256 and RSA encryption. +type Sha256PasswordPlugin struct{ AuthPlugin } + +func init() { + RegisterAuthPlugin(&Sha256PasswordPlugin{}) +} + +func (p *Sha256PasswordPlugin) GetPluginName() string { + return "sha256_password" +} + +// InitAuth initializes the authentication process. +// +// The function follows these rules: +// 1. If no password is configured, sends a single byte indicating empty password +// 2. If TLS is enabled, sends the password in cleartext +// 3. If a public key is available, encrypts the password and sends it +// 4. Otherwise, requests the server's public key +func (p *Sha256PasswordPlugin) InitAuth(authData []byte, cfg *Config) ([]byte, error) { + if len(cfg.Passwd) == 0 { + return []byte{0}, nil + } + + // Unlike caching_sha2_password, sha256_password does not accept + // cleartext password on unix transport. + if cfg.TLS != nil { + // Write cleartext auth packet + return append([]byte(cfg.Passwd), 0), nil + } + + if cfg.pubKey == nil { + // Request public key from server + return []byte{1}, nil + } + + // Encrypt password using the public key + enc, err := encryptPassword(cfg.Passwd, authData, cfg.pubKey) + if err != nil { + return nil, fmt.Errorf("failed to encrypt password: %w", err) + } + return enc, nil +} + +// ProcessAuthResponse processes the server's response to our authentication attempt. +// +// The server can respond in three ways: +// 1. OK packet - Authentication successful +// 2. Error packet - Authentication failed +// 3. More data packet - Contains the server's public key for password encryption +func (p *Sha256PasswordPlugin) ProcessAuthResponse(packet []byte, authData []byte, mc *mysqlConn) ([]byte, error) { + if len(packet) == 0 { + return nil, fmt.Errorf("%w: empty auth response packet", ErrMalformPkt) + } + + switch packet[0] { + case iOK, iERR, iEOF: + return packet, nil + + case iAuthMoreData: + // Parse public key from PEM format + block, rest := pem.Decode(packet[1:]) + if block == nil { + return nil, fmt.Errorf("invalid PEM data in auth response: %q", rest) + } + + // Parse the public key + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + + // Send encrypted password + enc, err := encryptPassword(mc.cfg.Passwd, authData, pub.(*rsa.PublicKey)) + if err != nil { + return nil, fmt.Errorf("failed to encrypt password with server key: %w", err) + } + + // Send the encrypted password + if err = mc.writeAuthSwitchPacket(enc); err != nil { + return nil, fmt.Errorf("failed to send encrypted password: %w", err) + } + + return mc.readPacket() + + default: + return nil, fmt.Errorf("%w: unexpected packet type %d", ErrMalformPkt, packet[0]) + } +} + +// encryptPassword encrypts the password using RSA-OAEP with SHA1 hash. +// +// The process: +// 1. XORs the password with the auth seed to prevent replay attacks +// 2. Encrypts the XORed password using RSA-OAEP with SHA1 +// +// The encryption uses OAEP padding which is more secure than PKCS#1 v1.5 padding. +func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + if pub == nil { + return nil, fmt.Errorf("public key is nil") + } + + // Create the plaintext by XORing password with seed + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(seed) + plain[i] ^= seed[j] + } + + // Encrypt using RSA-OAEP with SHA1 + sha1Hash := sha1.New() + return rsa.EncryptOAEP(sha1Hash, rand.Reader, pub, plain, nil) +} diff --git a/auth_test.go b/auth_test.go index 46e1e3b4..906a1c0b 100644 --- a/auth_test.go +++ b/auth_test.go @@ -16,6 +16,8 @@ import ( "encoding/pem" "fmt" "testing" + + osuser "os/user" ) var testPubKey = []byte("-----BEGIN PUBLIC KEY-----\n" + @@ -50,8 +52,12 @@ func TestScrambleOldPass(t *testing.T) { {"123\t456", "575c47505b5b5559"}, {"C0mpl!ca ted#PASS123", "5d5d554849584a45"}, } + + // Send Client Authentication Packet + authPlugin := OldPasswordPlugin{} + for _, tuple := range vectors { - ours := scrambleOldPassword(scramble, tuple.pass) + ours := authPlugin.scrambleOldPassword(scramble, tuple.pass) if tuple.out != fmt.Sprintf("%x", ours) { t.Errorf("Failed old password %q", tuple.pass) } @@ -75,6 +81,47 @@ func TestScrambleSHA256Pass(t *testing.T) { } } +func TestDefaultUser(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "" + mc.cfg.Passwd = "secret" + + authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, + 22, 41, 84, 32, 123, 43, 118} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + var expectedUsername string + currentUser, err := osuser.Current() + if err != nil { + expectedUsername = "" + } else { + expectedUsername = currentUser.Username + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + authRespEnd := authRespStart + len(expectedUsername) + writtenAuthResp := conn.written[authRespStart:authRespEnd] + expectedAuthResp := []byte(expectedUsername) + if !bytes.Equal(writtenAuthResp, expectedAuthResp) || conn.written[authRespEnd] != 0 { + t.Fatalf("unexpected written auth response: %v", writtenAuthResp) + } +} + func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { conn, mc := newRWMockConn(1) mc.cfg.User = "root" @@ -85,7 +132,12 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -114,8 +166,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { } conn.maxReads = 1 - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } } @@ -130,7 +181,12 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -156,8 +212,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { } conn.maxReads = 1 - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } } @@ -172,12 +227,17 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { + + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { t.Fatal(err) } @@ -207,8 +267,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { } conn.maxReads = 3 - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } @@ -228,7 +287,12 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -261,7 +325,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } @@ -280,7 +344,12 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -317,10 +386,9 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { conn.maxReads = 3 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } - if !bytes.Equal(conn.written, []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0}) { t.Errorf("unexpected written data: %v", conn.written) } @@ -336,7 +404,12 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - _, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + _, err := authPlugin.InitAuth(authData, mc.cfg) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) } @@ -353,7 +426,12 @@ func TestAuthFastCleartextPassword(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -379,8 +457,7 @@ func TestAuthFastCleartextPassword(t *testing.T) { } conn.maxReads = 1 - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } } @@ -396,7 +473,12 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -422,12 +504,151 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { } conn.maxReads = 1 - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } } +func TestAuthFastDialogPasswordNotAllowed(t *testing.T) { + _, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "dialog" + + // Send Client Authentication Packet + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + _, err := authPlugin.InitAuth(authData, mc.cfg) + if err != ErrDialogAuth { + t.Errorf("expected ErrDialogPassword, got %v", err) + } +} + +func TestAuthFastDialogPassword(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.AllowDialogPasswords = true + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "dialog" + + // Send Client Authentication Packet + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + if err = mc.handleAuthResult(authData, authPlugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastDialogPasswordMultiple(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.AllowDialogPasswords = true + mc.cfg.OtherPasswd = "secret2,secret3" + + // auth switch request + conn.data = []byte{43, 0, 0, 2, 254, 100, 105, 97, 108, 111, 103, 0, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{ + {10, 0, 0, 4, 1, 112, 98, 115, 115, 119, 111, 114, 100, 0}, + {10, 0, 0, 6, 1, 112, 98, 115, 115, 119, 111, 114, 100, 0}, + {7, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0}, // OK + } + conn.maxReads = 4 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { + t.Errorf("got error: %v", err) + } + expectedReply := []byte{ + 7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0, + 8, 0, 0, 5, 115, 101, 99, 114, 101, 116, 50, 0, + 8, 0, 0, 7, 115, 101, 99, 114, 101, 116, 51, 0, + } + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthFastDialogPasswordMultipleNotSet(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.AllowDialogPasswords = true + mc.cfg.OtherPasswd = "" + + // auth switch request + conn.data = []byte{43, 0, 0, 2, 254, 100, 105, 97, 108, 111, 103, 0, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{ + {10, 0, 0, 4, 1, 112, 98, 115, 115, 119, 111, 114, 100, 0}, + {10, 0, 0, 6, 1, 112, 98, 115, 115, 119, 111, 114, 100, 0}, + {7, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0}, // OK + } + conn.maxReads = 4 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { + t.Errorf("got error: %v", err) + } + expectedReply := []byte{ + 7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0, + 1, 0, 0, 5, 0, + 1, 0, 0, 7, 0, + } + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + func TestAuthFastNativePasswordNotAllowed(t *testing.T) { _, mc := newRWMockConn(1) mc.cfg.User = "root" @@ -439,7 +660,12 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - _, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + _, err := authPlugin.InitAuth(authData, mc.cfg) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) } @@ -455,7 +681,12 @@ func TestAuthFastNativePassword(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -482,8 +713,7 @@ func TestAuthFastNativePassword(t *testing.T) { } conn.maxReads = 1 - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } } @@ -498,7 +728,12 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -524,8 +759,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { } conn.maxReads = 1 - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } } @@ -540,7 +774,12 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -569,10 +808,9 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } - if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { t.Errorf("unexpected written data: %v", conn.written) } @@ -588,7 +826,12 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -617,7 +860,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } @@ -637,7 +880,12 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -651,7 +899,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } } @@ -670,7 +918,12 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { plugin := "sha256_password" // send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -698,8 +951,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} conn.maxReads = 1 - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, authPlugin); err != nil { t.Errorf("got error: %v", err) } @@ -726,9 +978,7 @@ func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } @@ -759,9 +1009,7 @@ func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } @@ -795,12 +1043,9 @@ func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } - expectedReplyPrefix := []byte{ // 1. Packet: Hash 32, 0, 0, 3, 219, 72, 64, 97, 56, 197, 167, 203, 64, 236, 168, 80, 223, @@ -840,12 +1085,9 @@ func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } - expectedReplyPrefix := []byte{ // 1. Packet: Hash 32, 0, 0, 3, 219, 72, 64, 97, 56, 197, 167, 203, 64, 236, 168, 80, 223, @@ -883,12 +1125,9 @@ func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } - expectedReply := []byte{ // 1. Packet: Hash 32, 0, 0, 3, 219, 72, 64, 97, 56, 197, 167, 203, 64, 236, 168, 80, 223, @@ -911,8 +1150,7 @@ func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) { conn.maxReads = 1 authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, &NativePasswordPlugin{}) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) } @@ -933,12 +1171,9 @@ func TestAuthSwitchCleartextPassword(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } - expectedReply := []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0} if !bytes.Equal(conn.written, expectedReply) { t.Errorf("got unexpected data: %v", conn.written) @@ -960,12 +1195,9 @@ func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } - expectedReply := []byte{1, 0, 0, 3, 0} if !bytes.Equal(conn.written, expectedReply) { t.Errorf("got unexpected data: %v", conn.written) @@ -983,8 +1215,7 @@ func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) { conn.maxReads = 1 authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, 31} - plugin := "caching_sha2_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, &NativePasswordPlugin{}) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) } @@ -1007,9 +1238,7 @@ func TestAuthSwitchNativePassword(t *testing.T) { authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, 31} - plugin := "caching_sha2_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } @@ -1037,12 +1266,9 @@ func TestAuthSwitchNativePasswordEmpty(t *testing.T) { authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, 31} - plugin := "caching_sha2_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } - expectedReply := []byte{0, 0, 0, 3} if !bytes.Equal(conn.written, expectedReply) { t.Errorf("got unexpected data: %v", conn.written) @@ -1058,8 +1284,7 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { conn.maxReads = 1 authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, &NativePasswordPlugin{}) if err != ErrOldPassword { t.Errorf("expected ErrOldPassword, got %v", err) } @@ -1074,8 +1299,7 @@ func TestOldAuthSwitchNotAllowed(t *testing.T) { conn.maxReads = 1 authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, &NativePasswordPlugin{}) if err != ErrOldPassword { t.Errorf("expected ErrOldPassword, got %v", err) } @@ -1097,9 +1321,7 @@ func TestAuthSwitchOldPassword(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } @@ -1124,9 +1346,7 @@ func TestOldAuthSwitch(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } @@ -1149,11 +1369,7 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) { conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} conn.maxReads = 2 - authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, - 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult([]byte{}, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } @@ -1176,11 +1392,7 @@ func TestOldAuthSwitchPasswordEmpty(t *testing.T) { conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} conn.maxReads = 2 - authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, - 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult([]byte{}, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } @@ -1205,11 +1417,7 @@ func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { } conn.maxReads = 3 - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult([]byte{}, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } @@ -1240,11 +1448,7 @@ func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { } conn.maxReads = 3 - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult([]byte{}, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } @@ -1276,11 +1480,7 @@ func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { } conn.maxReads = 2 - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult([]byte{}, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } @@ -1312,11 +1512,7 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { } conn.maxReads = 2 - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult([]byte{}, &NativePasswordPlugin{}); err != nil { t.Errorf("got error: %v", err) } @@ -1339,7 +1535,12 @@ func TestEd25519Auth(t *testing.T) { plugin := "client_ed25519" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + t.Fatalf("plugin not registered") + } + + authResp, err := authPlugin.InitAuth(authData, mc.cfg) if err != nil { t.Fatal(err) } @@ -1348,11 +1549,6 @@ func TestEd25519Auth(t *testing.T) { t.Fatal(err) } - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] expectedAuthResp := []byte{ 232, 61, 201, 63, 67, 63, 51, 53, 86, 73, 238, 35, 170, 117, 146, 214, 26, 17, 35, 9, 8, 132, 245, 141, 48, 99, 66, 58, 36, 228, 48, @@ -1360,11 +1556,11 @@ func TestEd25519Auth(t *testing.T) { 68, 117, 56, 135, 171, 47, 20, 14, 133, 79, 15, 229, 124, 160, 176, 100, 138, 14, } - if writtenAuthRespLen != 64 { - t.Fatalf("expected 64 bytes from client, got %d", writtenAuthRespLen) + if len(authResp) != 64 { + t.Fatalf("expected 64 bytes from client, got %d", len(authResp)) } - if !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("auth response did not match expected value:\n%v\n%v", writtenAuthResp, expectedAuthResp) + if !bytes.Equal(authResp, expectedAuthResp) { + t.Fatalf("auth response did not match expected value:\n%v\n%v", authResp, expectedAuthResp) } conn.written = nil @@ -1375,7 +1571,7 @@ func TestEd25519Auth(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult([]byte{}, authPlugin); err != nil { t.Errorf("got error: %v", err) } } diff --git a/connector.go b/connector.go index bc1d46af..2223de87 100644 --- a/connector.go +++ b/connector.go @@ -140,14 +140,22 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { if plugin == "" { plugin = defaultAuthPlugin } + authPlugin, exists := globalPluginRegistry.GetPlugin(plugin) + if !exists { + return nil, fmt.Errorf("this authentication plugin '%s' is not supported", plugin) + } // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { + authResp, err := authPlugin.InitAuth(authData, mc.cfg) + if err != nil && plugin != defaultAuthPlugin { // try the default auth plugin, if using the requested plugin failed c.cfg.Logger.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) plugin = defaultAuthPlugin - authResp, err = mc.auth(authData, plugin) + authPlugin, exists = globalPluginRegistry.GetPlugin(plugin) + if !exists { + return nil, fmt.Errorf("this authentication plugin '%s' is not supported", plugin) + } + authResp, err = authPlugin.InitAuth(authData, mc.cfg) if err != nil { mc.cleanup() return nil, err @@ -159,7 +167,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } // Handle response to auth packet, switch methods if possible - if err = mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, authPlugin); err != nil { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. diff --git a/const.go b/const.go index 4aadcd64..8a8471bc 100644 --- a/const.go +++ b/const.go @@ -184,9 +184,3 @@ const ( statusInTransReadonly statusSessionStateChanged ) - -const ( - cachingSha2PasswordRequestPublicKey = 2 - cachingSha2PasswordFastAuthSuccess = 3 - cachingSha2PasswordPerformFullAuthentication = 4 -) diff --git a/driver.go b/driver.go index 105316b8..03114c3b 100644 --- a/driver.go +++ b/driver.go @@ -41,6 +41,9 @@ type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error) var ( dialsLock sync.RWMutex dials map[string]DialContextFunc + + // The global plugin registry for authentication methods + globalPluginRegistry = NewPluginRegistry() ) // RegisterDialContext registers a custom dial function. It can then be used by the diff --git a/driver_test.go b/driver_test.go index 00e82865..1e768c2a 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1630,13 +1630,46 @@ func TestCollation(t *testing.T) { } runTests(t, tdsn, func(dbt *DBTest) { + // see https://mariadb.com/kb/en/setting-character-sets-and-collations/#changing-default-collation + // when character_set_collations is set for the charset, it overrides the default collation + // so we need to check if the default collation is overridden + forceExpected := expected + var defaultCollations string + err := dbt.db.QueryRow("SELECT @@character_set_collations").Scan(&defaultCollations) + if err == nil { + // Query succeeded, need to check if we should override expected collation + collationMap := make(map[string]string) + pairs := strings.Split(defaultCollations, ",") + for _, pair := range pairs { + parts := strings.Split(pair, "=") + if len(parts) == 2 { + collationMap[parts[0]] = parts[1] + } + } + + // Get charset prefix from expected collation + parts := strings.Split(expected, "_") + if len(parts) > 0 { + charset := parts[0] + if newCollation, ok := collationMap[charset]; ok { + forceExpected = newCollation + } + } + } + var got string if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { dbt.Fatal(err) } if got != expected { - dbt.Fatalf("expected connection collation %s but got %s", expected, got) + if forceExpected != expected { + if got != forceExpected { + dbt.Fatalf("expected forced connection collation %s but got %s", forceExpected, got) + } + } else { + dbt.Fatalf("expected connection collation %s but got %s", expected, got) + } } }) } @@ -1685,7 +1718,7 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) { } func TestTimezoneConversion(t *testing.T) { - zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + zones := []string{"UTC", "America/New_York", "Asia/Hong_Kong", "Local"} // Regression test for timezone handling tzTest := func(dbt *DBTest) { @@ -1693,8 +1726,8 @@ func TestTimezoneConversion(t *testing.T) { dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") // Insert local time into database (should be converted) - usCentral, _ := time.LoadLocation("US/Central") - reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + newYorkTz, _ := time.LoadLocation("America/New_York") + reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(newYorkTz) dbt.mustExec("INSERT INTO test VALUE (?)", reftime) // Retrieve time from DB @@ -1713,7 +1746,7 @@ func TestTimezoneConversion(t *testing.T) { // Check that dates match if reftime.Unix() != dbTime.Unix() { dbt.Errorf("times do not match.\n") - dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) + dbt.Errorf(" Now(%v)=%v\n", newYorkTz, reftime) dbt.Errorf(" Now(UTC)=%v\n", dbTime) } } @@ -3541,6 +3574,15 @@ func TestConnectionAttributes(t *testing.T) { dbt := &DBTest{t, db} + var varName string + var varValue string + err := dbt.db.QueryRow("SHOW VARIABLES LIKE 'performance_schema'").Scan(&varName, &varValue) + if err != nil { + t.Fatalf("error: %s", err.Error()) + } + if varValue != "ON" { + t.Skipf("Performance schema is not enabled. skipping") + } queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()" rows := dbt.mustQuery(queryString) defer rows.Close() @@ -3645,3 +3687,395 @@ func TestIssue1567(t *testing.T) { } }) } + +// TestParsecAuth tests the Parsec authentication method with MariaDB 11.6+ +func TestParsecAuth(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // Connect to the database + db, err := sql.Open(driverNameTest, dsn) + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + defer db.Close() + + // Check MariaDB version + var version string + err = db.QueryRow("SELECT VERSION()").Scan(&version) + if err != nil { + t.Fatalf("Error checking version: %s", err.Error()) + } + + t.Logf("Database version: %s", version) + + // Check if this is a MariaDB server + isMariaDB := strings.Contains(strings.ToLower(version), "mariadb") + if !isMariaDB { + t.Skip("Parsec authentication requires MariaDB 11.6+") + } + + // Extract version number + parts := strings.Split(version, "-") + versionParts := strings.Split(parts[0], ".") + if len(versionParts) < 2 { + t.Skip("Cannot determine MariaDB version format") + } + + major, err := strconv.Atoi(versionParts[0]) + if err != nil { + t.Skip("Cannot parse MariaDB major version") + } + + minor, err := strconv.Atoi(versionParts[1]) + if err != nil { + t.Skip("Cannot parse MariaDB minor version") + } + + // Skip if version is below 11.6 + if major < 11 || (major == 11 && minor < 6) { + t.Skip("Parsec authentication requires MariaDB 11.6+") + } + + t.Logf("MariaDB version %d.%d detected", major, minor) + + // Check if the parsec plugin is installed + var count int + err = db.QueryRow("SELECT COUNT(*) FROM information_schema.plugins WHERE PLUGIN_NAME = 'parsec' AND PLUGIN_STATUS = 'ACTIVE'").Scan(&count) + if err != nil { + t.Fatalf("Error checking parsec plugin: %s", err.Error()) + } + + if count == 0 { + // Try to install the plugin + _, err = db.Exec("INSTALL SONAME 'auth_parsec'") + if err != nil { + t.Skipf("Parsec authentication plugin is not available: %s", err.Error()) + } + + // Check again if the plugin is active + err = db.QueryRow("SELECT COUNT(*) FROM information_schema.plugins WHERE PLUGIN_NAME = 'parsec' AND PLUGIN_STATUS = 'ACTIVE'").Scan(&count) + if err != nil { + t.Fatalf("Error checking parsec plugin after installation: %s", err.Error()) + } + + if count == 0 { + t.Skip("Parsec authentication plugin could not be activated") + } + } + + t.Log("Parsec authentication plugin is active") + + // Create a test user with parsec authentication + username := "parsec_test_user" + password := "parsec_password" + + // Drop the user if it exists + _, err = db.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'localhost'", username)) + if err != nil { + t.Fatalf("Error dropping user: %s", err.Error()) + } + + // Create the user with parsec authentication + createUserSQL := fmt.Sprintf("CREATE USER '%s'@'localhost' IDENTIFIED VIA parsec USING PASSWORD('%s')", username, password) + _, err = db.Exec(createUserSQL) + if err != nil { + t.Fatalf("Error creating user: %s\nSQL: %s", err.Error(), createUserSQL) + } + defer func() { + _, err := db.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'localhost'", username)) + if err != nil { + t.Logf("Error dropping user during cleanup: %s", err.Error()) + } + }() + + // Grant privileges to the test user + grantSQL := fmt.Sprintf("GRANT ALL ON *.* TO '%s'@'localhost' WITH GRANT OPTION", username) + _, err = db.Exec(grantSQL) + if err != nil { + t.Fatalf("Error granting privileges: %s\nSQL: %s", err.Error(), grantSQL) + } + + // Flush privileges + _, err = db.Exec("FLUSH PRIVILEGES") + if err != nil { + t.Fatalf("Error flushing privileges: %s", err.Error()) + } + + // Connect with the new user using parsec authentication + parsecDSN := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", username, password, netAddr, dbname) + + t.Logf("Attempting to connect with DSN: %s", parsecDSN) + + parsecDB, err := sql.Open(driverNameTest, parsecDSN) + if err != nil { + t.Fatalf("Error opening connection with parsec: %s", err.Error()) + } + defer parsecDB.Close() + + // Verify that we can run a query + var result int + err = parsecDB.QueryRow("SELECT 1").Scan(&result) + if err != nil { + t.Fatalf("Error running query with parsec authentication: %s", err.Error()) + } + if result != 1 { + t.Fatalf("Unexpected result: expected 1, got %d", result) + } + + t.Log("Successfully authenticated using parsec authentication") +} + +func TestEd25519AuthIntegration(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // Connect to the database + db, err := sql.Open(driverNameTest, dsn) + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + defer db.Close() + + // Check MariaDB version + var version string + err = db.QueryRow("SELECT VERSION()").Scan(&version) + if err != nil { + t.Fatalf("Error checking version: %s", err.Error()) + } + + t.Logf("Database version: %s", version) + + // Check if this is a MariaDB server + isMariaDB := strings.Contains(strings.ToLower(version), "mariadb") + if !isMariaDB { + t.Skip("ed25519 authentication test requires MariaDB") + } + + // Check if the ed25519 plugin is installed + var count int + err = db.QueryRow("SELECT COUNT(*) FROM information_schema.plugins WHERE PLUGIN_NAME = 'ed25519' AND PLUGIN_STATUS = 'ACTIVE'").Scan(&count) + if err != nil { + t.Fatalf("Error checking ed25519 plugin: %s", err.Error()) + } + + if count == 0 { + // Try to install the plugin + _, err = db.Exec("INSTALL SONAME 'auth_ed25519'") + if err != nil { + t.Skipf("ed25519 authentication plugin is not available: %s", err.Error()) + } + + // Check again if the plugin is active + err = db.QueryRow("SELECT COUNT(*) FROM information_schema.plugins WHERE PLUGIN_NAME = 'ed25519' AND PLUGIN_STATUS = 'ACTIVE'").Scan(&count) + if err != nil { + t.Fatalf("Error checking ed25519 plugin after installation: %s", err.Error()) + } + + if count == 0 { + t.Skip("ed25519 authentication plugin could not be activated") + } + } + + t.Log("ed25519 authentication plugin is active") + + // Create a test user with ed25519 authentication + username := "ed25519_test_user" + password := "ed25519_test_password" + + // Drop the user if it exists + _, err = db.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'localhost'", username)) + if err != nil { + t.Fatalf("Error dropping user: %s", err.Error()) + } + + // Create the user with ed25519 authentication + createUserSQL := fmt.Sprintf("CREATE USER '%s'@'localhost' IDENTIFIED VIA ed25519 USING PASSWORD('%s')", username, password) + _, err = db.Exec(createUserSQL) + if err != nil { + t.Fatalf("Error creating user: %s\nSQL: %s", err.Error(), createUserSQL) + } + defer func() { + _, err := db.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'localhost'", username)) + if err != nil { + t.Logf("Error dropping user during cleanup: %s", err.Error()) + } + }() + + // Grant privileges to the test user + grantSQL := fmt.Sprintf("GRANT ALL ON *.* TO '%s'@'localhost'", username) + _, err = db.Exec(grantSQL) + if err != nil { + t.Fatalf("Error granting privileges: %s\nSQL: %s", err.Error(), grantSQL) + } + + // Flush privileges + _, err = db.Exec("FLUSH PRIVILEGES") + if err != nil { + t.Fatalf("Error flushing privileges: %s", err.Error()) + } + + // Connect with the new user using ed25519 authentication + ed25519DSN := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", username, password, netAddr, dbname) + t.Logf("Attempting to connect with DSN: %s", ed25519DSN) + + ed25519DB, err := sql.Open(driverNameTest, ed25519DSN) + if err != nil { + t.Fatalf("Error opening connection with ed25519: %s", err.Error()) + } + defer ed25519DB.Close() + + // Verify that we can run a query + var result int + err = ed25519DB.QueryRow("SELECT 1").Scan(&result) + if err != nil { + t.Fatalf("Error running query with ed25519 authentication: %s", err.Error()) + } + if result != 1 { + t.Fatalf("Unexpected result: expected 1, got %d", result) + } + + t.Log("Successfully authenticated using ed25519 authentication") +} + +// TestMultiAuthIntegration tests the multiple authentication methods with MariaDB +// there will be 3 authentication methods: mysql_native_password, ed25519, parsec +// * first native password will be wrong, +// * server will send a authentication switch request with ed25519 that will fail +// * server will send another authentication switch request with parsec that will succeed +func TestMultiAuthIntegration(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // Connect to the database + db, err := sql.Open(driverNameTest, dsn) + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + defer db.Close() + + // Check MariaDB version + var version string + err = db.QueryRow("SELECT VERSION()").Scan(&version) + if err != nil { + t.Fatalf("Error checking version: %s", err.Error()) + } + + t.Logf("Database version: %s", version) + + // Check if this is a MariaDB server + isMariaDB := strings.Contains(strings.ToLower(version), "mariadb") + if !isMariaDB { + t.Skip("Parsec authentication requires MariaDB 11.6+") + } + + // Extract version number + parts := strings.Split(version, "-") + versionParts := strings.Split(parts[0], ".") + if len(versionParts) < 2 { + t.Skip("Cannot determine MariaDB version format") + } + + major, err := strconv.Atoi(versionParts[0]) + if err != nil { + t.Skip("Cannot parse MariaDB major version") + } + + minor, err := strconv.Atoi(versionParts[1]) + if err != nil { + t.Skip("Cannot parse MariaDB minor version") + } + + // Skip if version is below 11.6 + if major < 11 || (major == 11 && minor < 6) { + t.Skip("Parsec authentication requires MariaDB 11.6+") + } + + // Check if the ed25519 plugin is installed + var count int + err = db.QueryRow("SELECT COUNT(*) FROM information_schema.plugins WHERE PLUGIN_NAME = 'ed25519' AND PLUGIN_STATUS = 'ACTIVE'").Scan(&count) + if err != nil { + t.Fatalf("Error checking ed25519 plugin: %s", err.Error()) + } + + if count == 0 { + // Try to install the plugin + _, err = db.Exec("INSTALL SONAME 'auth_ed25519'") + if err != nil { + t.Skipf("ed25519 authentication plugin is not available: %s", err.Error()) + } + + // Check again if the plugin is active + err = db.QueryRow("SELECT COUNT(*) FROM information_schema.plugins WHERE PLUGIN_NAME = 'ed25519' AND PLUGIN_STATUS = 'ACTIVE'").Scan(&count) + if err != nil { + t.Fatalf("Error checking ed25519 plugin after installation: %s", err.Error()) + } + + if count == 0 { + t.Skip("ed25519 authentication plugin could not be activated") + } + } + + t.Log("ed25519 authentication plugin is active") + + // Create a test user with multiple authentication methods + username := "multi_auth_test_user" + password := "multi_auth_test_password" + + // Drop the user if it exists + _, err = db.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'localhost'", username)) + if err != nil { + t.Fatalf("Error dropping user: %s", err.Error()) + } + + // Create the user with multiple authentication methods, password only for parsec + createUserSQL := fmt.Sprintf("CREATE USER '%s'@'localhost' IDENTIFIED VIA mysql_native_password USING PASSWORD('wrongPwd') OR ed25519 USING PASSWORD('anotherWrongPwd') OR parsec USING PASSWORD('%s')", username, password) + _, err = db.Exec(createUserSQL) + if err != nil { + t.Fatalf("Error creating user: %s\nSQL: %s", err.Error(), createUserSQL) + } + defer func() { + _, err := db.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'localhost'", username)) + if err != nil { + t.Logf("Error dropping user during cleanup: %s", err.Error()) + } + }() + + // Grant privileges to the test user + grantSQL := fmt.Sprintf("GRANT ALL ON *.* TO '%s'@'localhost'", username) + _, err = db.Exec(grantSQL) + if err != nil { + t.Fatalf("Error granting privileges: %s\nSQL: %s", err.Error(), grantSQL) + } + + // Flush privileges + _, err = db.Exec("FLUSH PRIVILEGES") + if err != nil { + t.Fatalf("Error flushing privileges: %s", err.Error()) + } + + // Test parsec authentication with password + parsecDSN := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", username, password, netAddr, dbname) + t.Logf("Attempting to connect with parsec DSN: %s", parsecDSN) + + parsecDB, err := sql.Open(driverNameTest, parsecDSN) + if err != nil { + t.Fatalf("Error opening connection with parsec: %s", err.Error()) + } + defer parsecDB.Close() + + // Verify that we can run a query using parsec auth + var result int + err = parsecDB.QueryRow("SELECT 1").Scan(&result) + if err != nil { + t.Fatalf("Error running query with parsec authentication: %s", err.Error()) + } + if result != 1 { + t.Fatalf("Unexpected result with parsec auth: expected 1, got %d", result) + } + + t.Log("Successfully authenticated using parsec authentication with password") +} diff --git a/dsn.go b/dsn.go index ecf62567..fbc85931 100644 --- a/dsn.go +++ b/dsn.go @@ -39,6 +39,7 @@ type Config struct { User string // Username Passwd string // Password (requires User) + OtherPasswd string // Other Passwords, comma-delimited passwords for dialog authentication Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix") DBName string // Database name @@ -61,6 +62,7 @@ type Config struct { AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin + AllowDialogPasswords bool // Allows the dialog authentication plugin AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS AllowNativePasswords bool // Allows the native password authentication method AllowOldPasswords bool // Allows the old insecure password method @@ -254,7 +256,7 @@ func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) { func (cfg *Config) FormatDSN() string { var buf bytes.Buffer - // [username[:password]@] + // [[username][:password]@] if len(cfg.User) > 0 { buf.WriteString(cfg.User) if len(cfg.Passwd) > 0 { @@ -262,6 +264,10 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.Passwd) } buf.WriteByte('@') + } else if len(cfg.Passwd) > 0 { + buf.WriteByte(':') + buf.WriteString(cfg.Passwd) + buf.WriteByte('@') } // [protocol[(address)]] @@ -290,6 +296,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true") } + if cfg.AllowDialogPasswords { + writeDSNParam(&buf, &hasParam, "allowDialogPasswords", "true") + } + if cfg.AllowFallbackToPlaintext { writeDSNParam(&buf, &hasParam, "allowFallbackToPlaintext", "true") } @@ -342,6 +352,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "multiStatements", "true") } + if cfg.OtherPasswd != "" { + writeDSNParam(&buf, &hasParam, "OtherPasswd", url.QueryEscape(cfg.OtherPasswd)) + } + if cfg.ParseTime { writeDSNParam(&buf, &hasParam, "parseTime", "true") } @@ -398,7 +412,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) { // New config with some default values cfg = NewConfig() - // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] + // [[username][:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') foundSlash := false for i := len(dsn) - 1; i >= 0; i-- { @@ -408,11 +422,11 @@ func ParseDSN(dsn string) (cfg *Config, err error) { // left part is empty if i <= 0 if i > 0 { - // [username[:password]@][protocol[(address)]] + // [[username][:password]@][protocol[(address)]] // Find the last '@' in dsn[:i] for j = i; j >= 0; j-- { if dsn[j] == '@' { - // username[:password] + // [username][:password] // Find the first ':' in dsn[:j] for k = 0; k < j; k++ { if dsn[k] == ':' { @@ -501,6 +515,14 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // Use dialog authentication + case "allowDialogPasswords": + var isBool bool + cfg.AllowDialogPasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + // Allow fallback to unencrypted connection if server does not support TLS case "allowFallbackToPlaintext": var isBool bool @@ -678,6 +700,13 @@ func parseDSNParams(cfg *Config, params string) (err error) { } cfg.ConnectionAttributes = connectionAttributes + case "OtherPasswd": + otherPasswd, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid OtherPasswd value: %v", err) + } + cfg.OtherPasswd = otherPasswd + default: // lazy init if cfg.Params == nil { diff --git a/dsn_test.go b/dsn_test.go index 436f7799..f350c9de 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -47,6 +47,9 @@ var testDSNs = []struct { }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + ":p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", + &Config{Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "/dbname", &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, diff --git a/errors.go b/errors.go index 584617b1..8c47c5d3 100644 --- a/errors.go +++ b/errors.go @@ -29,6 +29,8 @@ var ( ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`") ErrBusyBuffer = errors.New("busy buffer") + ErrParsecAuth = errors.New("malformed parsec authentication data") + ErrDialogAuth = errors.New("this user requires PAM authentication. If you still want to use it, please add 'AllowDialogPasswords=1' to your DSN") // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn diff --git a/go.mod b/go.mod index 187aff17..53300d69 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,5 @@ module github.com/go-sql-driver/mysql go 1.21.0 +require golang.org/x/crypto v0.16.0 require filippo.io/edwards25519 v1.1.0 diff --git a/go.sum b/go.sum index 359ca94b..0b79c8f5 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= +golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= diff --git a/packets.go b/packets.go index 4b836216..54a18dbf 100644 --- a/packets.go +++ b/packets.go @@ -17,6 +17,7 @@ import ( "fmt" "io" "math" + osuser "os/user" "strconv" "time" ) @@ -303,8 +304,17 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // length encoded integer clientFlags |= clientPluginAuthLenEncClientData } + var userName string + if len(mc.cfg.User) > 0 { + userName = mc.cfg.User + } else { + // Get current user if username is empty + if currentUser, err := osuser.Current(); err == nil { + userName = currentUser.Username + } + } - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 + pktLen := 4 + 4 + 1 + 23 + len(userName) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -372,8 +382,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // User [null terminated string] - if len(mc.cfg.User) > 0 { - pos += copy(data[pos:], mc.cfg.User) + if len(userName) > 0 { + pos += copy(data[pos:], userName) } data[pos] = 0x00 pos++ @@ -482,44 +492,6 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { * Result Packets * ******************************************************************************/ -func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { - data, err := mc.readPacket() - if err != nil { - return nil, "", err - } - - // packet indicator - switch data[0] { - - case iOK: - // resultUnchanged, since auth happens before any queries or - // commands have been executed. - return nil, "", mc.resultUnchanged().handleOkPacket(data) - - case iAuthMoreData: - return data[1:], "", err - - case iEOF: - if len(data) == 1 { - // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest - return nil, "mysql_old_password", nil - } - pluginEndIndex := bytes.IndexByte(data, 0x00) - if pluginEndIndex < 0 { - return nil, "", ErrMalformPkt - } - plugin := string(data[1:pluginEndIndex]) - authData := data[pluginEndIndex+1:] - if len(authData) > 0 && authData[len(authData)-1] == 0 { - authData = authData[:len(authData)-1] - } - return authData, plugin, nil - - default: // Error otherwise - return nil, "", mc.handleErrorPacket(data) - } -} - // Returns error if Packet is not a 'Result OK'-Packet func (mc *okHandler) readResultOK() error { data, err := mc.conn().readPacket()