123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469 |
- package client
- import (
- "bytes"
- "crypto/rsa"
- "crypto/x509"
- "encoding/binary"
- "encoding/pem"
- "github.com/pingcap/errors"
- "github.com/siddontang/go/hack"
- . "github.com/go-mysql-org/go-mysql/mysql"
- "github.com/go-mysql-org/go-mysql/utils"
- )
- func (c *Conn) readUntilEOF() (err error) {
- var data []byte
- for {
- data, err = c.ReadPacket()
- if err != nil {
- return
- }
- // EOF Packet
- if c.isEOFPacket(data) {
- return
- }
- }
- }
- func (c *Conn) isEOFPacket(data []byte) bool {
- return data[0] == EOF_HEADER && len(data) <= 5
- }
- func (c *Conn) handleOKPacket(data []byte) (*Result, error) {
- var n int
- var pos = 1
- r := new(Result)
- r.AffectedRows, _, n = LengthEncodedInt(data[pos:])
- pos += n
- r.InsertId, _, n = LengthEncodedInt(data[pos:])
- pos += n
- if c.capability&CLIENT_PROTOCOL_41 > 0 {
- r.Status = binary.LittleEndian.Uint16(data[pos:])
- c.status = r.Status
- pos += 2
- //todo:strict_mode, check warnings as error
- r.Warnings = binary.LittleEndian.Uint16(data[pos:])
- // pos += 2
- } else if c.capability&CLIENT_TRANSACTIONS > 0 {
- r.Status = binary.LittleEndian.Uint16(data[pos:])
- c.status = r.Status
- // pos += 2
- }
- //new ok package will check CLIENT_SESSION_TRACK too, but I don't support it now.
- //skip info
- return r, nil
- }
- func (c *Conn) handleErrorPacket(data []byte) error {
- e := new(MyError)
- var pos = 1
- e.Code = binary.LittleEndian.Uint16(data[pos:])
- pos += 2
- if c.capability&CLIENT_PROTOCOL_41 > 0 {
- //skip '#'
- pos++
- e.State = hack.String(data[pos : pos+5])
- pos += 5
- }
- e.Message = hack.String(data[pos:])
- return e
- }
- func (c *Conn) handleAuthResult() error {
- data, switchToPlugin, err := c.readAuthResult()
- if err != nil {
- return err
- }
- // handle auth switch, only support 'sha256_password', and 'caching_sha2_password'
- if switchToPlugin != "" {
- //fmt.Printf("now switching auth plugin to '%s'\n", switchToPlugin)
- if data == nil {
- data = c.salt
- } else {
- copy(c.salt, data)
- }
- c.authPluginName = switchToPlugin
- auth, addNull, err := c.genAuthResponse(data)
- if err != nil {
- return err
- }
- if err = c.WriteAuthSwitchPacket(auth, addNull); err != nil {
- return err
- }
- // Read Result Packet
- data, switchToPlugin, err = c.readAuthResult()
- if err != nil {
- return err
- }
- // Do not allow to change the auth plugin more than once
- if switchToPlugin != "" {
- return errors.Errorf("can not switch auth plugin more than once")
- }
- }
- // handle caching_sha2_password
- if c.authPluginName == AUTH_CACHING_SHA2_PASSWORD {
- if data == nil {
- return nil // auth already succeeded
- }
- if data[0] == CACHE_SHA2_FAST_AUTH {
- _, err = c.readOK()
- return err
- } else if data[0] == CACHE_SHA2_FULL_AUTH {
- // need full authentication
- if c.tlsConfig != nil || c.proto == "unix" {
- if err = c.WriteClearAuthPacket(c.password); err != nil {
- return err
- }
- } else {
- if err = c.WritePublicKeyAuthPacket(c.password, c.salt); err != nil {
- return err
- }
- }
- _, err = c.readOK()
- return err
- } else {
- return errors.Errorf("invalid packet %x", data[0])
- }
- } else if c.authPluginName == AUTH_SHA256_PASSWORD {
- if len(data) == 0 {
- return nil // auth already succeeded
- }
- block, _ := pem.Decode(data)
- pub, err := x509.ParsePKIXPublicKey(block.Bytes)
- if err != nil {
- return err
- }
- // send encrypted password
- err = c.WriteEncryptedPassword(c.password, c.salt, pub.(*rsa.PublicKey))
- if err != nil {
- return err
- }
- _, err = c.readOK()
- return err
- }
- return nil
- }
- func (c *Conn) readAuthResult() ([]byte, string, error) {
- data, err := c.ReadPacket()
- if err != nil {
- return nil, "", err
- }
- // see: https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
- // packet indicator
- switch data[0] {
- case OK_HEADER:
- _, err := c.handleOKPacket(data)
- return nil, "", err
- case MORE_DATE_HEADER:
- return data[1:], "", err
- case EOF_HEADER:
- // server wants to switch auth
- if len(data) < 1 {
- // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
- return nil, AUTH_MYSQL_OLD_PASSWORD, nil
- }
- pluginEndIndex := bytes.IndexByte(data, 0x00)
- if pluginEndIndex < 0 {
- return nil, "", errors.New("invalid packet")
- }
- plugin := string(data[1:pluginEndIndex])
- authData := data[pluginEndIndex+1:]
- return authData, plugin, nil
- default: // Error otherwise
- return nil, "", c.handleErrorPacket(data)
- }
- }
- func (c *Conn) readOK() (*Result, error) {
- data, err := c.ReadPacket()
- if err != nil {
- return nil, errors.Trace(err)
- }
- if data[0] == OK_HEADER {
- return c.handleOKPacket(data)
- } else if data[0] == ERR_HEADER {
- return nil, c.handleErrorPacket(data)
- } else {
- return nil, errors.New("invalid ok packet")
- }
- }
- func (c *Conn) readResult(binary bool) (*Result, error) {
- bs := utils.ByteSliceGet(16)
- defer utils.ByteSlicePut(bs)
- var err error
- bs.B, err = c.ReadPacketReuseMem(bs.B[:0])
- if err != nil {
- return nil, errors.Trace(err)
- }
- switch bs.B[0] {
- case OK_HEADER:
- return c.handleOKPacket(bs.B)
- case ERR_HEADER:
- return nil, c.handleErrorPacket(bytes.Repeat(bs.B, 1))
- case LocalInFile_HEADER:
- return nil, ErrMalformPacket
- default:
- return c.readResultset(bs.B, binary)
- }
- }
- func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error {
- bs := utils.ByteSliceGet(16)
- defer utils.ByteSlicePut(bs)
- var err error
- bs.B, err = c.ReadPacketReuseMem(bs.B[:0])
- if err != nil {
- return errors.Trace(err)
- }
- switch bs.B[0] {
- case OK_HEADER:
- // https://dev.mysql.com/doc/internals/en/com-query-response.html
- // 14.6.4.1 COM_QUERY Response
- // If the number of columns in the resultset is 0, this is a OK_Packet.
- okResult, err := c.handleOKPacket(bs.B)
- if err != nil {
- return errors.Trace(err)
- }
- result.Status = okResult.Status
- result.AffectedRows = okResult.AffectedRows
- result.InsertId = okResult.InsertId
- result.Warnings = okResult.Warnings
- if result.Resultset == nil {
- result.Resultset = NewResultset(0)
- } else {
- result.Reset(0)
- }
- return nil
- case ERR_HEADER:
- return c.handleErrorPacket(bytes.Repeat(bs.B, 1))
- case LocalInFile_HEADER:
- return ErrMalformPacket
- default:
- return c.readResultsetStreaming(bs.B, binary, result, perRowCb, perResCb)
- }
- }
- func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) {
- // column count
- count, _, n := LengthEncodedInt(data)
- if n-len(data) != 0 {
- return nil, ErrMalformPacket
- }
- result := &Result{
- Resultset: NewResultset(int(count)),
- }
- if err := c.readResultColumns(result); err != nil {
- return nil, errors.Trace(err)
- }
- if err := c.readResultRows(result, binary); err != nil {
- return nil, errors.Trace(err)
- }
- return result, nil
- }
- func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error {
- columnCount, _, n := LengthEncodedInt(data)
- if n-len(data) != 0 {
- return ErrMalformPacket
- }
- if result.Resultset == nil {
- result.Resultset = NewResultset(int(columnCount))
- } else {
- // Reuse memory if can
- result.Reset(int(columnCount))
- }
- // this is a streaming resultset
- result.Resultset.Streaming = StreamingSelect
- if err := c.readResultColumns(result); err != nil {
- return errors.Trace(err)
- }
- if perResCb != nil {
- if err := perResCb(result); err != nil {
- return err
- }
- }
- if err := c.readResultRowsStreaming(result, binary, perRowCb); err != nil {
- return errors.Trace(err)
- }
- // this resultset is done streaming
- result.Resultset.StreamingDone = true
- return nil
- }
- func (c *Conn) readResultColumns(result *Result) (err error) {
- var i int = 0
- var data []byte
- for {
- rawPkgLen := len(result.RawPkg)
- result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
- if err != nil {
- return
- }
- data = result.RawPkg[rawPkgLen:]
- // EOF Packet
- if c.isEOFPacket(data) {
- if c.capability&CLIENT_PROTOCOL_41 > 0 {
- result.Warnings = binary.LittleEndian.Uint16(data[1:])
- //todo add strict_mode, warning will be treat as error
- result.Status = binary.LittleEndian.Uint16(data[3:])
- c.status = result.Status
- }
- if i != len(result.Fields) {
- err = ErrMalformPacket
- }
- return
- }
- if result.Fields[i] == nil {
- result.Fields[i] = &Field{}
- }
- err = result.Fields[i].Parse(data)
- if err != nil {
- return
- }
- result.FieldNames[hack.String(result.Fields[i].Name)] = i
- i++
- }
- }
- func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) {
- var data []byte
- for {
- rawPkgLen := len(result.RawPkg)
- result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
- if err != nil {
- return
- }
- data = result.RawPkg[rawPkgLen:]
- // EOF Packet
- if c.isEOFPacket(data) {
- if c.capability&CLIENT_PROTOCOL_41 > 0 {
- result.Warnings = binary.LittleEndian.Uint16(data[1:])
- //todo add strict_mode, warning will be treat as error
- result.Status = binary.LittleEndian.Uint16(data[3:])
- c.status = result.Status
- }
- break
- }
- if data[0] == ERR_HEADER {
- return c.handleErrorPacket(data)
- }
- result.RowDatas = append(result.RowDatas, data)
- }
- if cap(result.Values) < len(result.RowDatas) {
- result.Values = make([][]FieldValue, len(result.RowDatas))
- } else {
- result.Values = result.Values[:len(result.RowDatas)]
- }
- for i := range result.Values {
- result.Values[i], err = result.RowDatas[i].Parse(result.Fields, isBinary, result.Values[i])
- if err != nil {
- return errors.Trace(err)
- }
- }
- return nil
- }
- func (c *Conn) readResultRowsStreaming(result *Result, isBinary bool, perRowCb SelectPerRowCallback) (err error) {
- var (
- data []byte
- row []FieldValue
- )
- for {
- data, err = c.ReadPacketReuseMem(data[:0])
- if err != nil {
- return
- }
- // EOF Packet
- if c.isEOFPacket(data) {
- if c.capability&CLIENT_PROTOCOL_41 > 0 {
- result.Warnings = binary.LittleEndian.Uint16(data[1:])
- // todo add strict_mode, warning will be treat as error
- result.Status = binary.LittleEndian.Uint16(data[3:])
- c.status = result.Status
- }
- break
- }
- if data[0] == ERR_HEADER {
- return c.handleErrorPacket(data)
- }
- // Parse this row
- row, err = RowData(data).Parse(result.Fields, isBinary, row)
- if err != nil {
- return errors.Trace(err)
- }
- // Send the row to "userland" code
- err = perRowCb(row)
- if err != nil {
- return errors.Trace(err)
- }
- }
- return nil
- }
|