123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- package client
- import (
- "bytes"
- "crypto/tls"
- "encoding/binary"
- "fmt"
- . "github.com/go-mysql-org/go-mysql/mysql"
- "github.com/go-mysql-org/go-mysql/packet"
- "github.com/pingcap/errors"
- )
- const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
- // defines the supported auth plugins
- var supportedAuthPlugins = []string{AUTH_NATIVE_PASSWORD, AUTH_SHA256_PASSWORD, AUTH_CACHING_SHA2_PASSWORD}
- // helper function to determine what auth methods are allowed by this client
- func authPluginAllowed(pluginName string) bool {
- for _, p := range supportedAuthPlugins {
- if pluginName == p {
- return true
- }
- }
- return false
- }
- // See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
- func (c *Conn) readInitialHandshake() error {
- data, err := c.ReadPacket()
- if err != nil {
- return errors.Trace(err)
- }
- if data[0] == ERR_HEADER {
- return errors.Annotate(c.handleErrorPacket(data), "read initial handshake error")
- }
- if data[0] < MinProtocolVersion {
- return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
- }
- // skip mysql version
- // mysql version end with 0x00
- pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1
- // connection id length is 4
- c.connectionID = binary.LittleEndian.Uint32(data[pos : pos+4])
- pos += 4
- c.salt = []byte{}
- c.salt = append(c.salt, data[pos:pos+8]...)
- // skip filter
- pos += 8 + 1
- // capability lower 2 bytes
- c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
- // check protocol
- if c.capability&CLIENT_PROTOCOL_41 == 0 {
- return errors.New("the MySQL server can not support protocol 41 and above required by the client")
- }
- if c.capability&CLIENT_SSL == 0 && c.tlsConfig != nil {
- return errors.New("the MySQL Server does not support TLS required by the client")
- }
- pos += 2
- if len(data) > pos {
- // skip server charset
- //c.charset = data[pos]
- pos += 1
- c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
- pos += 2
- // capability flags (upper 2 bytes)
- c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
- pos += 2
- // auth_data is end with 0x00, min data length is 13 + 8 = 21
- // ref to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
- maxAuthDataLen := 21
- if c.capability&CLIENT_PLUGIN_AUTH != 0 && int(data[pos]) > maxAuthDataLen {
- maxAuthDataLen = int(data[pos])
- }
- // skip reserved (all [00])
- pos += 10 + 1
- // auth_data is end with 0x00, so we need to trim 0x00
- resetOfAuthDataEndPos := pos + maxAuthDataLen - 8 - 1
- c.salt = append(c.salt, data[pos:resetOfAuthDataEndPos]...)
- // skip reset of end pos
- pos = resetOfAuthDataEndPos + 1
- if c.capability&CLIENT_PLUGIN_AUTH != 0 {
- c.authPluginName = string(data[pos : len(data)-1])
- }
- }
- // if server gives no default auth plugin name, use a client default
- if c.authPluginName == "" {
- c.authPluginName = defaultAuthPluginName
- }
- return nil
- }
- // generate auth response data according to auth plugin
- //
- // NOTE: the returned boolean value indicates whether to add a \NUL to the end of data.
- // it is quite tricky because MySQL server expects different formats of responses in different auth situations.
- // here the \NUL needs to be added when sending back the empty password or cleartext password in 'sha256_password'
- // authentication.
- func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) {
- // password hashing
- switch c.authPluginName {
- case AUTH_NATIVE_PASSWORD:
- return CalcPassword(authData[:20], []byte(c.password)), false, nil
- case AUTH_CACHING_SHA2_PASSWORD:
- return CalcCachingSha2Password(authData, c.password), false, nil
- case AUTH_SHA256_PASSWORD:
- if len(c.password) == 0 {
- return nil, true, nil
- }
- if c.tlsConfig != nil || c.proto == "unix" {
- // write cleartext auth packet
- // see: https://dev.mysql.com/doc/refman/8.0/en/sha256-pluggable-authentication.html
- return []byte(c.password), true, nil
- } else {
- // request public key from server
- // see: https://dev.mysql.com/doc/internals/en/public-key-retrieval.html
- return []byte{1}, false, nil
- }
- default:
- // not reachable
- return nil, false, fmt.Errorf("auth plugin '%s' is not supported", c.authPluginName)
- }
- }
- // generate connection attributes data
- func (c *Conn) genAttributes() []byte {
- if len(c.attributes) == 0 {
- return nil
- }
- attrData := make([]byte, 0)
- for k, v := range c.attributes {
- attrData = append(attrData, PutLengthEncodedString([]byte(k))...)
- attrData = append(attrData, PutLengthEncodedString([]byte(v))...)
- }
- return append(PutLengthEncodedInt(uint64(len(attrData))), attrData...)
- }
- // See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
- func (c *Conn) writeAuthHandshake() error {
- if !authPluginAllowed(c.authPluginName) {
- return fmt.Errorf("unknow auth plugin name '%s'", c.authPluginName)
- }
- // Set default client capabilities that reflect the abilities of this library
- capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION |
- CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH
- // Adjust client capability flags based on server support
- capability |= c.capability & CLIENT_LONG_FLAG
- // Adjust client capability flags on specific client requests
- // Only flags that would make any sense setting and aren't handled elsewhere
- // in the library are supported here
- capability |= c.ccaps&CLIENT_FOUND_ROWS | c.ccaps&CLIENT_IGNORE_SPACE |
- c.ccaps&CLIENT_MULTI_STATEMENTS | c.ccaps&CLIENT_MULTI_RESULTS |
- c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS
- // To enable TLS / SSL
- if c.tlsConfig != nil {
- capability |= CLIENT_SSL
- }
- auth, addNull, err := c.genAuthResponse(c.salt)
- if err != nil {
- return err
- }
- // encode length of the auth plugin data
- // here we use the Length-Encoded-Integer(LEI) as the data length may not fit into one byte
- // see: https://dev.mysql.com/doc/internals/en/integer.html#length-encoded-integer
- var authRespLEIBuf [9]byte
- authRespLEI := AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth)))
- if len(authRespLEI) > 1 {
- // if the length can not be written in 1 byte, it must be written as a
- // length encoded integer
- capability |= CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
- }
- //packet length
- //capability 4
- //max-packet size 4
- //charset 1
- //reserved all[0] 23
- //username
- //auth
- //mysql_native_password + null-terminated
- length := 4 + 4 + 1 + 23 + len(c.user) + 1 + len(authRespLEI) + len(auth) + 21 + 1
- if addNull {
- length++
- }
- // db name
- if len(c.db) > 0 {
- capability |= CLIENT_CONNECT_WITH_DB
- length += len(c.db) + 1
- }
- // connection attributes
- attrData := c.genAttributes()
- if len(attrData) > 0 {
- capability |= CLIENT_CONNECT_ATTRS
- length += len(attrData)
- }
- data := make([]byte, length+4)
- // capability [32 bit]
- data[4] = byte(capability)
- data[5] = byte(capability >> 8)
- data[6] = byte(capability >> 16)
- data[7] = byte(capability >> 24)
- // MaxPacketSize [32 bit] (none)
- data[8] = 0x00
- data[9] = 0x00
- data[10] = 0x00
- data[11] = 0x00
- // Charset [1 byte]
- // use default collation id 33 here, is utf-8
- data[12] = DEFAULT_COLLATION_ID
- // SSL Connection Request Packet
- // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
- if c.tlsConfig != nil {
- // Send TLS / SSL request packet
- if err := c.WritePacket(data[:(4+4+1+23)+4]); err != nil {
- return err
- }
- // Switch to TLS
- tlsConn := tls.Client(c.Conn.Conn, c.tlsConfig)
- if err := tlsConn.Handshake(); err != nil {
- return err
- }
- currentSequence := c.Sequence
- c.Conn = packet.NewConn(tlsConn)
- c.Sequence = currentSequence
- }
- // Filler [23 bytes] (all 0x00)
- pos := 13
- for ; pos < 13+23; pos++ {
- data[pos] = 0
- }
- // User [null terminated string]
- if len(c.user) > 0 {
- pos += copy(data[pos:], c.user)
- }
- data[pos] = 0x00
- pos++
- // auth [length encoded integer]
- pos += copy(data[pos:], authRespLEI)
- pos += copy(data[pos:], auth)
- if addNull {
- data[pos] = 0x00
- pos++
- }
- // db [null terminated string]
- if len(c.db) > 0 {
- pos += copy(data[pos:], c.db)
- data[pos] = 0x00
- pos++
- }
- // Assume native client during response
- pos += copy(data[pos:], c.authPluginName)
- data[pos] = 0x00
- pos++
- // connection attributes
- if len(attrData) > 0 {
- copy(data[pos:], attrData)
- }
- return c.WritePacket(data)
- }
|