123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232 |
- package client
- import (
- "encoding/binary"
- "encoding/json"
- "fmt"
- "math"
- . "github.com/go-mysql-org/go-mysql/mysql"
- "github.com/pingcap/errors"
- )
- type Stmt struct {
- conn *Conn
- id uint32
- params int
- columns int
- warnings int
- }
- func (s *Stmt) ParamNum() int {
- return s.params
- }
- func (s *Stmt) ColumnNum() int {
- return s.columns
- }
- func (s *Stmt) WarningsNum() int {
- return s.warnings
- }
- func (s *Stmt) Execute(args ...interface{}) (*Result, error) {
- if err := s.write(args...); err != nil {
- return nil, errors.Trace(err)
- }
- return s.conn.readResult(true)
- }
- func (s *Stmt) ExecuteSelectStreaming(result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback, args ...interface{}) error {
- if err := s.write(args...); err != nil {
- return errors.Trace(err)
- }
- return s.conn.readResultStreaming(true, result, perRowCb, perResCb)
- }
- func (s *Stmt) Close() error {
- if err := s.conn.writeCommandUint32(COM_STMT_CLOSE, s.id); err != nil {
- return errors.Trace(err)
- }
- return nil
- }
- func (s *Stmt) write(args ...interface{}) error {
- paramsNum := s.params
- if len(args) != paramsNum {
- return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
- }
- paramTypes := make([]byte, paramsNum<<1)
- paramValues := make([][]byte, paramsNum)
- //NULL-bitmap, length: (num-params+7)
- nullBitmap := make([]byte, (paramsNum+7)>>3)
- length := 1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1)
- var newParamBoundFlag byte = 0
- for i := range args {
- if args[i] == nil {
- nullBitmap[i/8] |= (1 << (uint(i) % 8))
- paramTypes[i<<1] = MYSQL_TYPE_NULL
- continue
- }
- newParamBoundFlag = 1
- switch v := args[i].(type) {
- case int8:
- paramTypes[i<<1] = MYSQL_TYPE_TINY
- paramValues[i] = []byte{byte(v)}
- case int16:
- paramTypes[i<<1] = MYSQL_TYPE_SHORT
- paramValues[i] = Uint16ToBytes(uint16(v))
- case int32:
- paramTypes[i<<1] = MYSQL_TYPE_LONG
- paramValues[i] = Uint32ToBytes(uint32(v))
- case int:
- paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
- paramValues[i] = Uint64ToBytes(uint64(v))
- case int64:
- paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
- paramValues[i] = Uint64ToBytes(uint64(v))
- case uint8:
- paramTypes[i<<1] = MYSQL_TYPE_TINY
- paramTypes[(i<<1)+1] = 0x80
- paramValues[i] = []byte{v}
- case uint16:
- paramTypes[i<<1] = MYSQL_TYPE_SHORT
- paramTypes[(i<<1)+1] = 0x80
- paramValues[i] = Uint16ToBytes(v)
- case uint32:
- paramTypes[i<<1] = MYSQL_TYPE_LONG
- paramTypes[(i<<1)+1] = 0x80
- paramValues[i] = Uint32ToBytes(v)
- case uint:
- paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
- paramTypes[(i<<1)+1] = 0x80
- paramValues[i] = Uint64ToBytes(uint64(v))
- case uint64:
- paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
- paramTypes[(i<<1)+1] = 0x80
- paramValues[i] = Uint64ToBytes(v)
- case bool:
- paramTypes[i<<1] = MYSQL_TYPE_TINY
- if v {
- paramValues[i] = []byte{1}
- } else {
- paramValues[i] = []byte{0}
- }
- case float32:
- paramTypes[i<<1] = MYSQL_TYPE_FLOAT
- paramValues[i] = Uint32ToBytes(math.Float32bits(v))
- case float64:
- paramTypes[i<<1] = MYSQL_TYPE_DOUBLE
- paramValues[i] = Uint64ToBytes(math.Float64bits(v))
- case string:
- paramTypes[i<<1] = MYSQL_TYPE_STRING
- paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
- case []byte:
- paramTypes[i<<1] = MYSQL_TYPE_STRING
- paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
- case json.RawMessage:
- paramTypes[i<<1] = MYSQL_TYPE_STRING
- paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
- default:
- return fmt.Errorf("invalid argument type %T", args[i])
- }
- length += len(paramValues[i])
- }
- data := make([]byte, 4, 4+length)
- data = append(data, COM_STMT_EXECUTE)
- data = append(data, byte(s.id), byte(s.id>>8), byte(s.id>>16), byte(s.id>>24))
- //flag: CURSOR_TYPE_NO_CURSOR
- data = append(data, 0x00)
- //iteration-count, always 1
- data = append(data, 1, 0, 0, 0)
- if s.params > 0 {
- data = append(data, nullBitmap...)
- //new-params-bound-flag
- data = append(data, newParamBoundFlag)
- if newParamBoundFlag == 1 {
- //type of each parameter, length: num-params * 2
- data = append(data, paramTypes...)
- //value of each parameter
- for _, v := range paramValues {
- data = append(data, v...)
- }
- }
- }
- s.conn.ResetSequence()
- return s.conn.WritePacket(data)
- }
- func (c *Conn) Prepare(query string) (*Stmt, error) {
- if err := c.writeCommandStr(COM_STMT_PREPARE, query); err != nil {
- return nil, errors.Trace(err)
- }
- data, err := c.ReadPacket()
- if err != nil {
- return nil, errors.Trace(err)
- }
- if data[0] == ERR_HEADER {
- return nil, c.handleErrorPacket(data)
- } else if data[0] != OK_HEADER {
- return nil, ErrMalformPacket
- }
- s := new(Stmt)
- s.conn = c
- pos := 1
- //for statement id
- s.id = binary.LittleEndian.Uint32(data[pos:])
- pos += 4
- //number columns
- s.columns = int(binary.LittleEndian.Uint16(data[pos:]))
- pos += 2
- //number params
- s.params = int(binary.LittleEndian.Uint16(data[pos:]))
- pos += 2
- //warnings
- s.warnings = int(binary.LittleEndian.Uint16(data[pos:]))
- // pos += 2
- if s.params > 0 {
- if err := s.conn.readUntilEOF(); err != nil {
- return nil, errors.Trace(err)
- }
- }
- if s.columns > 0 {
- if err := s.conn.readUntilEOF(); err != nil {
- return nil, errors.Trace(err)
- }
- }
- return s, nil
- }
|