stmt.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. package client
  2. import (
  3. "encoding/binary"
  4. "encoding/json"
  5. "fmt"
  6. "math"
  7. . "github.com/go-mysql-org/go-mysql/mysql"
  8. "github.com/pingcap/errors"
  9. )
  10. type Stmt struct {
  11. conn *Conn
  12. id uint32
  13. params int
  14. columns int
  15. warnings int
  16. }
  17. func (s *Stmt) ParamNum() int {
  18. return s.params
  19. }
  20. func (s *Stmt) ColumnNum() int {
  21. return s.columns
  22. }
  23. func (s *Stmt) WarningsNum() int {
  24. return s.warnings
  25. }
  26. func (s *Stmt) Execute(args ...interface{}) (*Result, error) {
  27. if err := s.write(args...); err != nil {
  28. return nil, errors.Trace(err)
  29. }
  30. return s.conn.readResult(true)
  31. }
  32. func (s *Stmt) ExecuteSelectStreaming(result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback, args ...interface{}) error {
  33. if err := s.write(args...); err != nil {
  34. return errors.Trace(err)
  35. }
  36. return s.conn.readResultStreaming(true, result, perRowCb, perResCb)
  37. }
  38. func (s *Stmt) Close() error {
  39. if err := s.conn.writeCommandUint32(COM_STMT_CLOSE, s.id); err != nil {
  40. return errors.Trace(err)
  41. }
  42. return nil
  43. }
  44. func (s *Stmt) write(args ...interface{}) error {
  45. paramsNum := s.params
  46. if len(args) != paramsNum {
  47. return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
  48. }
  49. paramTypes := make([]byte, paramsNum<<1)
  50. paramValues := make([][]byte, paramsNum)
  51. //NULL-bitmap, length: (num-params+7)
  52. nullBitmap := make([]byte, (paramsNum+7)>>3)
  53. length := 1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1)
  54. var newParamBoundFlag byte = 0
  55. for i := range args {
  56. if args[i] == nil {
  57. nullBitmap[i/8] |= (1 << (uint(i) % 8))
  58. paramTypes[i<<1] = MYSQL_TYPE_NULL
  59. continue
  60. }
  61. newParamBoundFlag = 1
  62. switch v := args[i].(type) {
  63. case int8:
  64. paramTypes[i<<1] = MYSQL_TYPE_TINY
  65. paramValues[i] = []byte{byte(v)}
  66. case int16:
  67. paramTypes[i<<1] = MYSQL_TYPE_SHORT
  68. paramValues[i] = Uint16ToBytes(uint16(v))
  69. case int32:
  70. paramTypes[i<<1] = MYSQL_TYPE_LONG
  71. paramValues[i] = Uint32ToBytes(uint32(v))
  72. case int:
  73. paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
  74. paramValues[i] = Uint64ToBytes(uint64(v))
  75. case int64:
  76. paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
  77. paramValues[i] = Uint64ToBytes(uint64(v))
  78. case uint8:
  79. paramTypes[i<<1] = MYSQL_TYPE_TINY
  80. paramTypes[(i<<1)+1] = 0x80
  81. paramValues[i] = []byte{v}
  82. case uint16:
  83. paramTypes[i<<1] = MYSQL_TYPE_SHORT
  84. paramTypes[(i<<1)+1] = 0x80
  85. paramValues[i] = Uint16ToBytes(v)
  86. case uint32:
  87. paramTypes[i<<1] = MYSQL_TYPE_LONG
  88. paramTypes[(i<<1)+1] = 0x80
  89. paramValues[i] = Uint32ToBytes(v)
  90. case uint:
  91. paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
  92. paramTypes[(i<<1)+1] = 0x80
  93. paramValues[i] = Uint64ToBytes(uint64(v))
  94. case uint64:
  95. paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
  96. paramTypes[(i<<1)+1] = 0x80
  97. paramValues[i] = Uint64ToBytes(v)
  98. case bool:
  99. paramTypes[i<<1] = MYSQL_TYPE_TINY
  100. if v {
  101. paramValues[i] = []byte{1}
  102. } else {
  103. paramValues[i] = []byte{0}
  104. }
  105. case float32:
  106. paramTypes[i<<1] = MYSQL_TYPE_FLOAT
  107. paramValues[i] = Uint32ToBytes(math.Float32bits(v))
  108. case float64:
  109. paramTypes[i<<1] = MYSQL_TYPE_DOUBLE
  110. paramValues[i] = Uint64ToBytes(math.Float64bits(v))
  111. case string:
  112. paramTypes[i<<1] = MYSQL_TYPE_STRING
  113. paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
  114. case []byte:
  115. paramTypes[i<<1] = MYSQL_TYPE_STRING
  116. paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
  117. case json.RawMessage:
  118. paramTypes[i<<1] = MYSQL_TYPE_STRING
  119. paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
  120. default:
  121. return fmt.Errorf("invalid argument type %T", args[i])
  122. }
  123. length += len(paramValues[i])
  124. }
  125. data := make([]byte, 4, 4+length)
  126. data = append(data, COM_STMT_EXECUTE)
  127. data = append(data, byte(s.id), byte(s.id>>8), byte(s.id>>16), byte(s.id>>24))
  128. //flag: CURSOR_TYPE_NO_CURSOR
  129. data = append(data, 0x00)
  130. //iteration-count, always 1
  131. data = append(data, 1, 0, 0, 0)
  132. if s.params > 0 {
  133. data = append(data, nullBitmap...)
  134. //new-params-bound-flag
  135. data = append(data, newParamBoundFlag)
  136. if newParamBoundFlag == 1 {
  137. //type of each parameter, length: num-params * 2
  138. data = append(data, paramTypes...)
  139. //value of each parameter
  140. for _, v := range paramValues {
  141. data = append(data, v...)
  142. }
  143. }
  144. }
  145. s.conn.ResetSequence()
  146. return s.conn.WritePacket(data)
  147. }
  148. func (c *Conn) Prepare(query string) (*Stmt, error) {
  149. if err := c.writeCommandStr(COM_STMT_PREPARE, query); err != nil {
  150. return nil, errors.Trace(err)
  151. }
  152. data, err := c.ReadPacket()
  153. if err != nil {
  154. return nil, errors.Trace(err)
  155. }
  156. if data[0] == ERR_HEADER {
  157. return nil, c.handleErrorPacket(data)
  158. } else if data[0] != OK_HEADER {
  159. return nil, ErrMalformPacket
  160. }
  161. s := new(Stmt)
  162. s.conn = c
  163. pos := 1
  164. //for statement id
  165. s.id = binary.LittleEndian.Uint32(data[pos:])
  166. pos += 4
  167. //number columns
  168. s.columns = int(binary.LittleEndian.Uint16(data[pos:]))
  169. pos += 2
  170. //number params
  171. s.params = int(binary.LittleEndian.Uint16(data[pos:]))
  172. pos += 2
  173. //warnings
  174. s.warnings = int(binary.LittleEndian.Uint16(data[pos:]))
  175. // pos += 2
  176. if s.params > 0 {
  177. if err := s.conn.readUntilEOF(); err != nil {
  178. return nil, errors.Trace(err)
  179. }
  180. }
  181. if s.columns > 0 {
  182. if err := s.conn.readUntilEOF(); err != nil {
  183. return nil, errors.Trace(err)
  184. }
  185. }
  186. return s, nil
  187. }