conn.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. package packet
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/rand"
  6. "crypto/rsa"
  7. "crypto/sha1"
  8. "crypto/x509"
  9. "encoding/pem"
  10. "io"
  11. "net"
  12. "sync"
  13. . "github.com/go-mysql-org/go-mysql/mysql"
  14. "github.com/go-mysql-org/go-mysql/utils"
  15. "github.com/pingcap/errors"
  16. )
  17. type BufPool struct {
  18. pool *sync.Pool
  19. }
  20. func NewBufPool() *BufPool {
  21. return &BufPool{
  22. pool: &sync.Pool{
  23. New: func() interface{} {
  24. return new(bytes.Buffer)
  25. },
  26. },
  27. }
  28. }
  29. func (b *BufPool) Get() *bytes.Buffer {
  30. return b.pool.Get().(*bytes.Buffer)
  31. }
  32. func (b *BufPool) Return(buf *bytes.Buffer) {
  33. buf.Reset()
  34. b.pool.Put(buf)
  35. }
  36. /*
  37. Conn is the base class to handle MySQL protocol.
  38. */
  39. type Conn struct {
  40. net.Conn
  41. // we removed the buffer reader because it will cause the SSLRequest to block (tls connection handshake won't be
  42. // able to read the "Client Hello" data since it has been buffered into the buffer reader)
  43. bufPool *BufPool
  44. br *bufio.Reader
  45. reader io.Reader
  46. copyNBuf []byte
  47. header [4]byte
  48. Sequence uint8
  49. }
  50. func NewConn(conn net.Conn) *Conn {
  51. c := new(Conn)
  52. c.Conn = conn
  53. c.bufPool = NewBufPool()
  54. c.br = bufio.NewReaderSize(c, 65536) // 64kb
  55. c.reader = c.br
  56. c.copyNBuf = make([]byte, 16*1024)
  57. return c
  58. }
  59. func NewTLSConn(conn net.Conn) *Conn {
  60. c := new(Conn)
  61. c.Conn = conn
  62. c.bufPool = NewBufPool()
  63. c.reader = c
  64. c.copyNBuf = make([]byte, 16*1024)
  65. return c
  66. }
  67. func (c *Conn) ReadPacket() ([]byte, error) {
  68. return c.ReadPacketReuseMem(nil)
  69. }
  70. func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
  71. // Here we use `sync.Pool` to avoid allocate/destroy buffers frequently.
  72. buf := utils.BytesBufferGet()
  73. defer utils.BytesBufferPut(buf)
  74. if err := c.ReadPacketTo(buf); err != nil {
  75. return nil, errors.Trace(err)
  76. }
  77. readBytes := buf.Bytes()
  78. readSize := len(readBytes)
  79. var result []byte
  80. if len(dst) > 0 {
  81. result = append(dst, readBytes...)
  82. // if read block is big, do not cache buf any more
  83. if readSize > utils.TooBigBlockSize {
  84. buf = nil
  85. }
  86. } else {
  87. if readSize > utils.TooBigBlockSize {
  88. // if read block is big, use read block as result and do not cache buf any more
  89. result = readBytes
  90. buf = nil
  91. } else {
  92. result = append(dst, readBytes...)
  93. }
  94. }
  95. return result, nil
  96. }
  97. func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
  98. for n > 0 {
  99. bcap := cap(c.copyNBuf)
  100. if int64(bcap) > n {
  101. bcap = int(n)
  102. }
  103. buf := c.copyNBuf[:bcap]
  104. rd, err := io.ReadAtLeast(src, buf, bcap)
  105. n -= int64(rd)
  106. if err != nil {
  107. return written, errors.Trace(err)
  108. }
  109. wr, err := dst.Write(buf)
  110. written += int64(wr)
  111. if err != nil {
  112. return written, errors.Trace(err)
  113. }
  114. }
  115. return written, nil
  116. }
  117. func (c *Conn) ReadPacketTo(w io.Writer) error {
  118. if _, err := io.ReadFull(c.reader, c.header[:4]); err != nil {
  119. return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err)
  120. }
  121. length := int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16)
  122. sequence := c.header[3]
  123. if sequence != c.Sequence {
  124. return errors.Errorf("invalid sequence %d != %d", sequence, c.Sequence)
  125. }
  126. c.Sequence++
  127. if buf, ok := w.(*bytes.Buffer); ok {
  128. // Allocate the buffer with expected length directly instead of call `grow` and migrate data many times.
  129. buf.Grow(length)
  130. }
  131. if n, err := c.copyN(w, c.reader, int64(length)); err != nil {
  132. return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length)
  133. } else if n != int64(length) {
  134. return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length)
  135. } else {
  136. if length < MaxPayloadLen {
  137. return nil
  138. }
  139. if err := c.ReadPacketTo(w); err != nil {
  140. return errors.Wrap(err, "ReadPacketTo failed")
  141. }
  142. }
  143. return nil
  144. }
  145. // WritePacket: data already has 4 bytes header
  146. // will modify data inplace
  147. func (c *Conn) WritePacket(data []byte) error {
  148. length := len(data) - 4
  149. for length >= MaxPayloadLen {
  150. data[0] = 0xff
  151. data[1] = 0xff
  152. data[2] = 0xff
  153. data[3] = c.Sequence
  154. if n, err := c.Write(data[:4+MaxPayloadLen]); err != nil {
  155. return errors.Wrapf(ErrBadConn, "Write(payload portion) failed. err %v", err)
  156. } else if n != (4 + MaxPayloadLen) {
  157. return errors.Wrapf(ErrBadConn, "Write(payload portion) failed. only %v bytes written, while %v expected", n, 4+MaxPayloadLen)
  158. } else {
  159. c.Sequence++
  160. length -= MaxPayloadLen
  161. data = data[MaxPayloadLen:]
  162. }
  163. }
  164. data[0] = byte(length)
  165. data[1] = byte(length >> 8)
  166. data[2] = byte(length >> 16)
  167. data[3] = c.Sequence
  168. if n, err := c.Write(data); err != nil {
  169. return errors.Wrapf(ErrBadConn, "Write failed. err %v", err)
  170. } else if n != len(data) {
  171. return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
  172. } else {
  173. c.Sequence++
  174. return nil
  175. }
  176. }
  177. // WriteClearAuthPacket: Client clear text authentication packet
  178. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
  179. func (c *Conn) WriteClearAuthPacket(password string) error {
  180. // Calculate the packet length and add a tailing 0
  181. pktLen := len(password) + 1
  182. data := make([]byte, 4+pktLen)
  183. // Add the clear password [null terminated string]
  184. copy(data[4:], password)
  185. data[4+pktLen-1] = 0x00
  186. return errors.Wrap(c.WritePacket(data), "WritePacket failed")
  187. }
  188. // WritePublicKeyAuthPacket: Caching sha2 authentication. Public key request and send encrypted password
  189. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
  190. func (c *Conn) WritePublicKeyAuthPacket(password string, cipher []byte) error {
  191. // request public key
  192. data := make([]byte, 4+1)
  193. data[4] = 2 // cachingSha2PasswordRequestPublicKey
  194. if err := c.WritePacket(data); err != nil {
  195. return errors.Wrap(err, "WritePacket(single byte) failed")
  196. }
  197. data, err := c.ReadPacket()
  198. if err != nil {
  199. return errors.Wrap(err, "ReadPacket failed")
  200. }
  201. block, _ := pem.Decode(data[1:])
  202. pub, err := x509.ParsePKIXPublicKey(block.Bytes)
  203. if err != nil {
  204. return errors.Wrap(err, "x509.ParsePKIXPublicKey failed")
  205. }
  206. plain := make([]byte, len(password)+1)
  207. copy(plain, password)
  208. for i := range plain {
  209. j := i % len(cipher)
  210. plain[i] ^= cipher[j]
  211. }
  212. sha1v := sha1.New()
  213. enc, _ := rsa.EncryptOAEP(sha1v, rand.Reader, pub.(*rsa.PublicKey), plain, nil)
  214. data = make([]byte, 4+len(enc))
  215. copy(data[4:], enc)
  216. return errors.Wrap(c.WritePacket(data), "WritePacket failed")
  217. }
  218. func (c *Conn) WriteEncryptedPassword(password string, seed []byte, pub *rsa.PublicKey) error {
  219. enc, err := EncryptPassword(password, seed, pub)
  220. if err != nil {
  221. return errors.Wrap(err, "EncryptPassword failed")
  222. }
  223. return errors.Wrap(c.WriteAuthSwitchPacket(enc, false), "WriteAuthSwitchPacket failed")
  224. }
  225. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
  226. func (c *Conn) WriteAuthSwitchPacket(authData []byte, addNUL bool) error {
  227. pktLen := 4 + len(authData)
  228. if addNUL {
  229. pktLen++
  230. }
  231. data := make([]byte, pktLen)
  232. // Add the auth data [EOF]
  233. copy(data[4:], authData)
  234. if addNUL {
  235. data[pktLen-1] = 0x00
  236. }
  237. return errors.Wrap(c.WritePacket(data), "WritePacket failed")
  238. }
  239. func (c *Conn) ResetSequence() {
  240. c.Sequence = 0
  241. }
  242. func (c *Conn) Close() error {
  243. c.Sequence = 0
  244. if c.Conn != nil {
  245. return errors.Wrap(c.Conn.Close(), "Conn.Close failed")
  246. }
  247. return nil
  248. }