auth.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. package client
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "encoding/binary"
  6. "fmt"
  7. . "github.com/go-mysql-org/go-mysql/mysql"
  8. "github.com/go-mysql-org/go-mysql/packet"
  9. "github.com/pingcap/errors"
  10. )
  11. const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
  12. // defines the supported auth plugins
  13. var supportedAuthPlugins = []string{AUTH_NATIVE_PASSWORD, AUTH_SHA256_PASSWORD, AUTH_CACHING_SHA2_PASSWORD}
  14. // helper function to determine what auth methods are allowed by this client
  15. func authPluginAllowed(pluginName string) bool {
  16. for _, p := range supportedAuthPlugins {
  17. if pluginName == p {
  18. return true
  19. }
  20. }
  21. return false
  22. }
  23. // See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
  24. func (c *Conn) readInitialHandshake() error {
  25. data, err := c.ReadPacket()
  26. if err != nil {
  27. return errors.Trace(err)
  28. }
  29. if data[0] == ERR_HEADER {
  30. return errors.Annotate(c.handleErrorPacket(data), "read initial handshake error")
  31. }
  32. if data[0] < MinProtocolVersion {
  33. return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
  34. }
  35. // skip mysql version
  36. // mysql version end with 0x00
  37. pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1
  38. // connection id length is 4
  39. c.connectionID = binary.LittleEndian.Uint32(data[pos : pos+4])
  40. pos += 4
  41. c.salt = []byte{}
  42. c.salt = append(c.salt, data[pos:pos+8]...)
  43. // skip filter
  44. pos += 8 + 1
  45. // capability lower 2 bytes
  46. c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
  47. // check protocol
  48. if c.capability&CLIENT_PROTOCOL_41 == 0 {
  49. return errors.New("the MySQL server can not support protocol 41 and above required by the client")
  50. }
  51. if c.capability&CLIENT_SSL == 0 && c.tlsConfig != nil {
  52. return errors.New("the MySQL Server does not support TLS required by the client")
  53. }
  54. pos += 2
  55. if len(data) > pos {
  56. // skip server charset
  57. //c.charset = data[pos]
  58. pos += 1
  59. c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
  60. pos += 2
  61. // capability flags (upper 2 bytes)
  62. c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
  63. pos += 2
  64. // auth_data is end with 0x00, min data length is 13 + 8 = 21
  65. // ref to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
  66. maxAuthDataLen := 21
  67. if c.capability&CLIENT_PLUGIN_AUTH != 0 && int(data[pos]) > maxAuthDataLen {
  68. maxAuthDataLen = int(data[pos])
  69. }
  70. // skip reserved (all [00])
  71. pos += 10 + 1
  72. // auth_data is end with 0x00, so we need to trim 0x00
  73. resetOfAuthDataEndPos := pos + maxAuthDataLen - 8 - 1
  74. c.salt = append(c.salt, data[pos:resetOfAuthDataEndPos]...)
  75. // skip reset of end pos
  76. pos = resetOfAuthDataEndPos + 1
  77. if c.capability&CLIENT_PLUGIN_AUTH != 0 {
  78. c.authPluginName = string(data[pos : len(data)-1])
  79. }
  80. }
  81. // if server gives no default auth plugin name, use a client default
  82. if c.authPluginName == "" {
  83. c.authPluginName = defaultAuthPluginName
  84. }
  85. return nil
  86. }
  87. // generate auth response data according to auth plugin
  88. //
  89. // NOTE: the returned boolean value indicates whether to add a \NUL to the end of data.
  90. // it is quite tricky because MySQL server expects different formats of responses in different auth situations.
  91. // here the \NUL needs to be added when sending back the empty password or cleartext password in 'sha256_password'
  92. // authentication.
  93. func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) {
  94. // password hashing
  95. switch c.authPluginName {
  96. case AUTH_NATIVE_PASSWORD:
  97. return CalcPassword(authData[:20], []byte(c.password)), false, nil
  98. case AUTH_CACHING_SHA2_PASSWORD:
  99. return CalcCachingSha2Password(authData, c.password), false, nil
  100. case AUTH_SHA256_PASSWORD:
  101. if len(c.password) == 0 {
  102. return nil, true, nil
  103. }
  104. if c.tlsConfig != nil || c.proto == "unix" {
  105. // write cleartext auth packet
  106. // see: https://dev.mysql.com/doc/refman/8.0/en/sha256-pluggable-authentication.html
  107. return []byte(c.password), true, nil
  108. } else {
  109. // request public key from server
  110. // see: https://dev.mysql.com/doc/internals/en/public-key-retrieval.html
  111. return []byte{1}, false, nil
  112. }
  113. default:
  114. // not reachable
  115. return nil, false, fmt.Errorf("auth plugin '%s' is not supported", c.authPluginName)
  116. }
  117. }
  118. // generate connection attributes data
  119. func (c *Conn) genAttributes() []byte {
  120. if len(c.attributes) == 0 {
  121. return nil
  122. }
  123. attrData := make([]byte, 0)
  124. for k, v := range c.attributes {
  125. attrData = append(attrData, PutLengthEncodedString([]byte(k))...)
  126. attrData = append(attrData, PutLengthEncodedString([]byte(v))...)
  127. }
  128. return append(PutLengthEncodedInt(uint64(len(attrData))), attrData...)
  129. }
  130. // See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
  131. func (c *Conn) writeAuthHandshake() error {
  132. if !authPluginAllowed(c.authPluginName) {
  133. return fmt.Errorf("unknow auth plugin name '%s'", c.authPluginName)
  134. }
  135. // Set default client capabilities that reflect the abilities of this library
  136. capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION |
  137. CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH
  138. // Adjust client capability flags based on server support
  139. capability |= c.capability & CLIENT_LONG_FLAG
  140. // Adjust client capability flags on specific client requests
  141. // Only flags that would make any sense setting and aren't handled elsewhere
  142. // in the library are supported here
  143. capability |= c.ccaps&CLIENT_FOUND_ROWS | c.ccaps&CLIENT_IGNORE_SPACE |
  144. c.ccaps&CLIENT_MULTI_STATEMENTS | c.ccaps&CLIENT_MULTI_RESULTS |
  145. c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS
  146. // To enable TLS / SSL
  147. if c.tlsConfig != nil {
  148. capability |= CLIENT_SSL
  149. }
  150. auth, addNull, err := c.genAuthResponse(c.salt)
  151. if err != nil {
  152. return err
  153. }
  154. // encode length of the auth plugin data
  155. // here we use the Length-Encoded-Integer(LEI) as the data length may not fit into one byte
  156. // see: https://dev.mysql.com/doc/internals/en/integer.html#length-encoded-integer
  157. var authRespLEIBuf [9]byte
  158. authRespLEI := AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth)))
  159. if len(authRespLEI) > 1 {
  160. // if the length can not be written in 1 byte, it must be written as a
  161. // length encoded integer
  162. capability |= CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
  163. }
  164. //packet length
  165. //capability 4
  166. //max-packet size 4
  167. //charset 1
  168. //reserved all[0] 23
  169. //username
  170. //auth
  171. //mysql_native_password + null-terminated
  172. length := 4 + 4 + 1 + 23 + len(c.user) + 1 + len(authRespLEI) + len(auth) + 21 + 1
  173. if addNull {
  174. length++
  175. }
  176. // db name
  177. if len(c.db) > 0 {
  178. capability |= CLIENT_CONNECT_WITH_DB
  179. length += len(c.db) + 1
  180. }
  181. // connection attributes
  182. attrData := c.genAttributes()
  183. if len(attrData) > 0 {
  184. capability |= CLIENT_CONNECT_ATTRS
  185. length += len(attrData)
  186. }
  187. data := make([]byte, length+4)
  188. // capability [32 bit]
  189. data[4] = byte(capability)
  190. data[5] = byte(capability >> 8)
  191. data[6] = byte(capability >> 16)
  192. data[7] = byte(capability >> 24)
  193. // MaxPacketSize [32 bit] (none)
  194. data[8] = 0x00
  195. data[9] = 0x00
  196. data[10] = 0x00
  197. data[11] = 0x00
  198. // Charset [1 byte]
  199. // use default collation id 33 here, is utf-8
  200. data[12] = DEFAULT_COLLATION_ID
  201. // SSL Connection Request Packet
  202. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
  203. if c.tlsConfig != nil {
  204. // Send TLS / SSL request packet
  205. if err := c.WritePacket(data[:(4+4+1+23)+4]); err != nil {
  206. return err
  207. }
  208. // Switch to TLS
  209. tlsConn := tls.Client(c.Conn.Conn, c.tlsConfig)
  210. if err := tlsConn.Handshake(); err != nil {
  211. return err
  212. }
  213. currentSequence := c.Sequence
  214. c.Conn = packet.NewConn(tlsConn)
  215. c.Sequence = currentSequence
  216. }
  217. // Filler [23 bytes] (all 0x00)
  218. pos := 13
  219. for ; pos < 13+23; pos++ {
  220. data[pos] = 0
  221. }
  222. // User [null terminated string]
  223. if len(c.user) > 0 {
  224. pos += copy(data[pos:], c.user)
  225. }
  226. data[pos] = 0x00
  227. pos++
  228. // auth [length encoded integer]
  229. pos += copy(data[pos:], authRespLEI)
  230. pos += copy(data[pos:], auth)
  231. if addNull {
  232. data[pos] = 0x00
  233. pos++
  234. }
  235. // db [null terminated string]
  236. if len(c.db) > 0 {
  237. pos += copy(data[pos:], c.db)
  238. data[pos] = 0x00
  239. pos++
  240. }
  241. // Assume native client during response
  242. pos += copy(data[pos:], c.authPluginName)
  243. data[pos] = 0x00
  244. pos++
  245. // connection attributes
  246. if len(attrData) > 0 {
  247. copy(data[pos:], attrData)
  248. }
  249. return c.WritePacket(data)
  250. }