123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
- package packet
- import (
- "bufio"
- "bytes"
- "crypto/rand"
- "crypto/rsa"
- "crypto/sha1"
- "crypto/x509"
- "encoding/pem"
- "io"
- "net"
- "sync"
- . "github.com/go-mysql-org/go-mysql/mysql"
- "github.com/go-mysql-org/go-mysql/utils"
- "github.com/pingcap/errors"
- )
// BufPool is a small typed wrapper around sync.Pool that hands out
// *bytes.Buffer values for reuse.
type BufPool struct {
	pool *sync.Pool
}

// NewBufPool returns a BufPool whose underlying pool allocates a new, empty
// bytes.Buffer whenever it has nothing cached.
func NewBufPool() *BufPool {
	p := new(sync.Pool)
	p.New = func() interface{} { return &bytes.Buffer{} }
	return &BufPool{pool: p}
}

// Get takes a buffer from the pool, allocating one if the pool is empty.
// Buffers put back through Return are reset first, so (assuming the pool is
// only fed via Return) the buffer comes back empty.
func (b *BufPool) Get() *bytes.Buffer {
	item := b.pool.Get()
	return item.(*bytes.Buffer)
}

// Return empties buf and places it back in the pool. The caller must not use
// buf after calling Return.
func (b *BufPool) Return(buf *bytes.Buffer) {
	buf.Reset()
	b.pool.Put(buf)
}
// Conn wraps a net.Conn and implements MySQL packet framing on top of it:
// reading and writing 4-byte packet headers and tracking the packet
// sequence number.
type Conn struct {
	net.Conn

	// bufPool is created by both constructors; it is not referenced by the
	// packet code shown here. NOTE(review): confirm whether other files use it.
	bufPool *BufPool

	// br is a 64 KiB buffered reader over the connection, set only by
	// NewConn. NewTLSConn leaves it nil and reads unbuffered, because a
	// buffered reader blocks the SSLRequest flow: the TLS handshake cannot
	// read the "Client Hello" data once it has been pulled into br's buffer.
	br *bufio.Reader
	// reader is what every packet read goes through: br for NewConn, the
	// Conn itself (unbuffered) for NewTLSConn.
	reader io.Reader

	// copyNBuf is scratch space reused by copyN when copying payloads.
	copyNBuf []byte

	// header holds the 4-byte header of the packet currently being read:
	// a 3-byte little-endian payload length followed by a sequence number.
	header [4]byte

	// Sequence is the sequence number expected on the next packet read and
	// stamped on the next packet written; each packet read or written
	// increments it.
	Sequence uint8
}
- func NewConn(conn net.Conn) *Conn {
- c := new(Conn)
- c.Conn = conn
- c.bufPool = NewBufPool()
- c.br = bufio.NewReaderSize(c, 65536) // 64kb
- c.reader = c.br
- c.copyNBuf = make([]byte, 16*1024)
- return c
- }
- func NewTLSConn(conn net.Conn) *Conn {
- c := new(Conn)
- c.Conn = conn
- c.bufPool = NewBufPool()
- c.reader = c
- c.copyNBuf = make([]byte, 16*1024)
- return c
- }
- func (c *Conn) ReadPacket() ([]byte, error) {
- return c.ReadPacketReuseMem(nil)
- }
- func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
- // Here we use `sync.Pool` to avoid allocate/destroy buffers frequently.
- buf := utils.BytesBufferGet()
- defer utils.BytesBufferPut(buf)
- if err := c.ReadPacketTo(buf); err != nil {
- return nil, errors.Trace(err)
- }
- readBytes := buf.Bytes()
- readSize := len(readBytes)
- var result []byte
- if len(dst) > 0 {
- result = append(dst, readBytes...)
- // if read block is big, do not cache buf any more
- if readSize > utils.TooBigBlockSize {
- buf = nil
- }
- } else {
- if readSize > utils.TooBigBlockSize {
- // if read block is big, use read block as result and do not cache buf any more
- result = readBytes
- buf = nil
- } else {
- result = append(dst, readBytes...)
- }
- }
- return result, nil
- }
- func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
- for n > 0 {
- bcap := cap(c.copyNBuf)
- if int64(bcap) > n {
- bcap = int(n)
- }
- buf := c.copyNBuf[:bcap]
- rd, err := io.ReadAtLeast(src, buf, bcap)
- n -= int64(rd)
- if err != nil {
- return written, errors.Trace(err)
- }
- wr, err := dst.Write(buf)
- written += int64(wr)
- if err != nil {
- return written, errors.Trace(err)
- }
- }
- return written, nil
- }
- func (c *Conn) ReadPacketTo(w io.Writer) error {
- if _, err := io.ReadFull(c.reader, c.header[:4]); err != nil {
- return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err)
- }
- length := int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16)
- sequence := c.header[3]
- if sequence != c.Sequence {
- return errors.Errorf("invalid sequence %d != %d", sequence, c.Sequence)
- }
- c.Sequence++
- if buf, ok := w.(*bytes.Buffer); ok {
- // Allocate the buffer with expected length directly instead of call `grow` and migrate data many times.
- buf.Grow(length)
- }
- if n, err := c.copyN(w, c.reader, int64(length)); err != nil {
- return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length)
- } else if n != int64(length) {
- return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length)
- } else {
- if length < MaxPayloadLen {
- return nil
- }
- if err := c.ReadPacketTo(w); err != nil {
- return errors.Wrap(err, "ReadPacketTo failed")
- }
- }
- return nil
- }
- // WritePacket: data already has 4 bytes header
- // will modify data inplace
- func (c *Conn) WritePacket(data []byte) error {
- length := len(data) - 4
- for length >= MaxPayloadLen {
- data[0] = 0xff
- data[1] = 0xff
- data[2] = 0xff
- data[3] = c.Sequence
- if n, err := c.Write(data[:4+MaxPayloadLen]); err != nil {
- return errors.Wrapf(ErrBadConn, "Write(payload portion) failed. err %v", err)
- } else if n != (4 + MaxPayloadLen) {
- return errors.Wrapf(ErrBadConn, "Write(payload portion) failed. only %v bytes written, while %v expected", n, 4+MaxPayloadLen)
- } else {
- c.Sequence++
- length -= MaxPayloadLen
- data = data[MaxPayloadLen:]
- }
- }
- data[0] = byte(length)
- data[1] = byte(length >> 8)
- data[2] = byte(length >> 16)
- data[3] = c.Sequence
- if n, err := c.Write(data); err != nil {
- return errors.Wrapf(ErrBadConn, "Write failed. err %v", err)
- } else if n != len(data) {
- return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
- } else {
- c.Sequence++
- return nil
- }
- }
- // WriteClearAuthPacket: Client clear text authentication packet
- // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
- func (c *Conn) WriteClearAuthPacket(password string) error {
- // Calculate the packet length and add a tailing 0
- pktLen := len(password) + 1
- data := make([]byte, 4+pktLen)
- // Add the clear password [null terminated string]
- copy(data[4:], password)
- data[4+pktLen-1] = 0x00
- return errors.Wrap(c.WritePacket(data), "WritePacket failed")
- }
- // WritePublicKeyAuthPacket: Caching sha2 authentication. Public key request and send encrypted password
- // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
- func (c *Conn) WritePublicKeyAuthPacket(password string, cipher []byte) error {
- // request public key
- data := make([]byte, 4+1)
- data[4] = 2 // cachingSha2PasswordRequestPublicKey
- if err := c.WritePacket(data); err != nil {
- return errors.Wrap(err, "WritePacket(single byte) failed")
- }
- data, err := c.ReadPacket()
- if err != nil {
- return errors.Wrap(err, "ReadPacket failed")
- }
- block, _ := pem.Decode(data[1:])
- pub, err := x509.ParsePKIXPublicKey(block.Bytes)
- if err != nil {
- return errors.Wrap(err, "x509.ParsePKIXPublicKey failed")
- }
- plain := make([]byte, len(password)+1)
- copy(plain, password)
- for i := range plain {
- j := i % len(cipher)
- plain[i] ^= cipher[j]
- }
- sha1v := sha1.New()
- enc, _ := rsa.EncryptOAEP(sha1v, rand.Reader, pub.(*rsa.PublicKey), plain, nil)
- data = make([]byte, 4+len(enc))
- copy(data[4:], enc)
- return errors.Wrap(c.WritePacket(data), "WritePacket failed")
- }
- func (c *Conn) WriteEncryptedPassword(password string, seed []byte, pub *rsa.PublicKey) error {
- enc, err := EncryptPassword(password, seed, pub)
- if err != nil {
- return errors.Wrap(err, "EncryptPassword failed")
- }
- return errors.Wrap(c.WriteAuthSwitchPacket(enc, false), "WriteAuthSwitchPacket failed")
- }
- // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
- func (c *Conn) WriteAuthSwitchPacket(authData []byte, addNUL bool) error {
- pktLen := 4 + len(authData)
- if addNUL {
- pktLen++
- }
- data := make([]byte, pktLen)
- // Add the auth data [EOF]
- copy(data[4:], authData)
- if addNUL {
- data[pktLen-1] = 0x00
- }
- return errors.Wrap(c.WritePacket(data), "WritePacket failed")
- }
- func (c *Conn) ResetSequence() {
- c.Sequence = 0
- }
- func (c *Conn) Close() error {
- c.Sequence = 0
- if c.Conn != nil {
- return errors.Wrap(c.Conn.Close(), "Conn.Close failed")
- }
- return nil
- }
|