conn.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. package client
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/tls"
  6. "fmt"
  7. "net"
  8. "strings"
  9. "time"
  10. . "github.com/go-mysql-org/go-mysql/mysql"
  11. "github.com/go-mysql-org/go-mysql/packet"
  12. "github.com/go-mysql-org/go-mysql/utils"
  13. "github.com/pingcap/errors"
  14. )
  15. type Conn struct {
  16. *packet.Conn
  17. user string
  18. password string
  19. db string
  20. tlsConfig *tls.Config
  21. proto string
  22. // server capabilities
  23. capability uint32
  24. // client-set capabilities only
  25. ccaps uint32
  26. attributes map[string]string
  27. status uint16
  28. charset string
  29. salt []byte
  30. authPluginName string
  31. connectionID uint32
  32. }
  33. // This function will be called for every row in resultset from ExecuteSelectStreaming.
  34. type SelectPerRowCallback func(row []FieldValue) error
  35. // This function will be called once per result from ExecuteSelectStreaming
  36. type SelectPerResultCallback func(result *Result) error
  37. // This function will be called once per result from ExecuteMultiple
  38. type ExecPerResultCallback func(result *Result, err error)
  39. func getNetProto(addr string) string {
  40. proto := "tcp"
  41. if strings.Contains(addr, "/") {
  42. proto = "unix"
  43. }
  44. return proto
  45. }
  46. // Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock.
  47. // Accepts a series of configuration functions as a variadic argument.
  48. func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
  49. proto := getNetProto(addr)
  50. ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
  51. defer cancel()
  52. dialer := &net.Dialer{}
  53. return ConnectWithDialer(ctx, proto, addr, user, password, dbName, dialer.DialContext, options...)
  54. }
  55. // Dialer connects to the address on the named network using the provided context.
  56. type Dialer func(ctx context.Context, network, address string) (net.Conn, error)
  57. // Connect to a MySQL server using the given Dialer.
  58. func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) {
  59. c := new(Conn)
  60. var err error
  61. conn, err := dialer(ctx, network, addr)
  62. if err != nil {
  63. return nil, errors.Trace(err)
  64. }
  65. if c.tlsConfig != nil {
  66. c.Conn = packet.NewTLSConn(conn)
  67. } else {
  68. c.Conn = packet.NewConn(conn)
  69. }
  70. c.user = user
  71. c.password = password
  72. c.db = dbName
  73. c.proto = network
  74. // use default charset here, utf-8
  75. c.charset = DEFAULT_CHARSET
  76. // Apply configuration functions.
  77. for i := range options {
  78. options[i](c)
  79. }
  80. if err = c.handshake(); err != nil {
  81. return nil, errors.Trace(err)
  82. }
  83. return c, nil
  84. }
  85. func (c *Conn) handshake() error {
  86. var err error
  87. if err = c.readInitialHandshake(); err != nil {
  88. c.Close()
  89. return errors.Trace(err)
  90. }
  91. if err := c.writeAuthHandshake(); err != nil {
  92. c.Close()
  93. return errors.Trace(err)
  94. }
  95. if err := c.handleAuthResult(); err != nil {
  96. c.Close()
  97. return errors.Trace(err)
  98. }
  99. return nil
  100. }
  101. func (c *Conn) Close() error {
  102. return c.Conn.Close()
  103. }
  104. func (c *Conn) Ping() error {
  105. if err := c.writeCommand(COM_PING); err != nil {
  106. return errors.Trace(err)
  107. }
  108. if _, err := c.readOK(); err != nil {
  109. return errors.Trace(err)
  110. }
  111. return nil
  112. }
  113. // SetCapability enables the use of a specific capability
  114. func (c *Conn) SetCapability(cap uint32) {
  115. c.ccaps |= cap
  116. }
  117. // UnsetCapability disables the use of a specific capability
  118. func (c *Conn) UnsetCapability(cap uint32) {
  119. c.ccaps &= ^cap
  120. }
  121. // UseSSL: use default SSL
  122. // pass to options when connect
  123. func (c *Conn) UseSSL(insecureSkipVerify bool) {
  124. c.tlsConfig = &tls.Config{InsecureSkipVerify: insecureSkipVerify}
  125. }
  126. // SetTLSConfig: use user-specified TLS config
  127. // pass to options when connect
  128. func (c *Conn) SetTLSConfig(config *tls.Config) {
  129. c.tlsConfig = config
  130. }
  131. func (c *Conn) UseDB(dbName string) error {
  132. if c.db == dbName {
  133. return nil
  134. }
  135. if err := c.writeCommandStr(COM_INIT_DB, dbName); err != nil {
  136. return errors.Trace(err)
  137. }
  138. if _, err := c.readOK(); err != nil {
  139. return errors.Trace(err)
  140. }
  141. c.db = dbName
  142. return nil
  143. }
  144. func (c *Conn) GetDB() string {
  145. return c.db
  146. }
  147. func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
  148. if len(args) == 0 {
  149. return c.exec(command)
  150. } else {
  151. if s, err := c.Prepare(command); err != nil {
  152. return nil, errors.Trace(err)
  153. } else {
  154. var r *Result
  155. r, err = s.Execute(args...)
  156. s.Close()
  157. return r, err
  158. }
  159. }
  160. }
  161. // ExecuteMultiple will call perResultCallback for every result of the multiple queries
  162. // that are executed.
  163. //
  164. // When ExecuteMultiple is used, the connection should have the SERVER_MORE_RESULTS_EXISTS
  165. // flag set to signal the server multiple queries are executed. Handling the responses
  166. // is up to the implementation of perResultCallback.
  167. //
  168. // Example:
  169. //
  170. // queries := "SELECT 1; SELECT NOW();"
  171. // conn.ExecuteMultiple(queries, func(result *mysql.Result, err error) {
  172. // // Use the result as you want
  173. // })
  174. //
  175. func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) {
  176. if err := c.writeCommandStr(COM_QUERY, query); err != nil {
  177. return nil, errors.Trace(err)
  178. }
  179. var err error
  180. var result *Result
  181. bs := utils.ByteSliceGet(16)
  182. defer utils.ByteSlicePut(bs)
  183. for {
  184. bs.B, err = c.ReadPacketReuseMem(bs.B[:0])
  185. if err != nil {
  186. return nil, errors.Trace(err)
  187. }
  188. switch bs.B[0] {
  189. case OK_HEADER:
  190. result, err = c.handleOKPacket(bs.B)
  191. case ERR_HEADER:
  192. err = c.handleErrorPacket(bytes.Repeat(bs.B, 1))
  193. result = nil
  194. case LocalInFile_HEADER:
  195. err = ErrMalformPacket
  196. result = nil
  197. default:
  198. result, err = c.readResultset(bs.B, false)
  199. }
  200. // call user-defined callback
  201. perResultCallback(result, err)
  202. // if there was an error of this was the last result, stop looping
  203. if err != nil || result.Status&SERVER_MORE_RESULTS_EXISTS == 0 {
  204. break
  205. }
  206. }
  207. // return an empty result(set) signaling we're done streaming a multiple
  208. // streaming session
  209. // if this would end up in WriteValue, it would just be ignored as all
  210. // responses should have been handled in perResultCallback
  211. return &Result{Resultset: &Resultset{
  212. Streaming: StreamingMultiple,
  213. StreamingDone: true,
  214. }}, nil
  215. }
  216. // ExecuteSelectStreaming will call perRowCallback for every row in resultset
  217. // WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
  218. // When given, perResultCallback will be called once per result
  219. //
  220. // ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving.
  221. //
  222. // Example:
  223. //
  224. // var result mysql.Result
  225. // conn.ExecuteSelectStreaming(`SELECT ... LIMIT 100500`, &result, func(row []mysql.FieldValue) error {
  226. // // Use the row as you want.
  227. // // You must not save FieldValue.AsString() value after this callback is done. Copy it if you need.
  228. // return nil
  229. // }, nil)
  230. //
  231. func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error {
  232. if err := c.writeCommandStr(COM_QUERY, command); err != nil {
  233. return errors.Trace(err)
  234. }
  235. return c.readResultStreaming(false, result, perRowCallback, perResultCallback)
  236. }
  237. func (c *Conn) Begin() error {
  238. _, err := c.exec("BEGIN")
  239. return errors.Trace(err)
  240. }
  241. func (c *Conn) Commit() error {
  242. _, err := c.exec("COMMIT")
  243. return errors.Trace(err)
  244. }
  245. func (c *Conn) Rollback() error {
  246. _, err := c.exec("ROLLBACK")
  247. return errors.Trace(err)
  248. }
  249. func (c *Conn) SetAttributes(attributes map[string]string) {
  250. c.attributes = attributes
  251. }
  252. func (c *Conn) SetCharset(charset string) error {
  253. if c.charset == charset {
  254. return nil
  255. }
  256. if _, err := c.exec(fmt.Sprintf("SET NAMES %s", charset)); err != nil {
  257. return errors.Trace(err)
  258. } else {
  259. c.charset = charset
  260. return nil
  261. }
  262. }
  263. func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
  264. if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
  265. return nil, errors.Trace(err)
  266. }
  267. data, err := c.ReadPacket()
  268. if err != nil {
  269. return nil, errors.Trace(err)
  270. }
  271. fs := make([]*Field, 0, 4)
  272. var f *Field
  273. if data[0] == ERR_HEADER {
  274. return nil, c.handleErrorPacket(data)
  275. }
  276. for {
  277. if data, err = c.ReadPacket(); err != nil {
  278. return nil, errors.Trace(err)
  279. }
  280. // EOF Packet
  281. if c.isEOFPacket(data) {
  282. return fs, nil
  283. }
  284. if f, err = FieldData(data).Parse(); err != nil {
  285. return nil, errors.Trace(err)
  286. }
  287. fs = append(fs, f)
  288. }
  289. }
  290. func (c *Conn) SetAutoCommit() error {
  291. if !c.IsAutoCommit() {
  292. if _, err := c.exec("SET AUTOCOMMIT = 1"); err != nil {
  293. return errors.Trace(err)
  294. }
  295. }
  296. return nil
  297. }
  298. func (c *Conn) IsAutoCommit() bool {
  299. return c.status&SERVER_STATUS_AUTOCOMMIT > 0
  300. }
  301. func (c *Conn) IsInTransaction() bool {
  302. return c.status&SERVER_STATUS_IN_TRANS > 0
  303. }
  304. func (c *Conn) GetCharset() string {
  305. return c.charset
  306. }
  307. func (c *Conn) GetConnectionID() uint32 {
  308. return c.connectionID
  309. }
  310. func (c *Conn) HandleOKPacket(data []byte) *Result {
  311. r, _ := c.handleOKPacket(data)
  312. return r
  313. }
  314. func (c *Conn) HandleErrorPacket(data []byte) error {
  315. return c.handleErrorPacket(data)
  316. }
  317. func (c *Conn) ReadOKPacket() (*Result, error) {
  318. return c.readOK()
  319. }
  320. func (c *Conn) exec(query string) (*Result, error) {
  321. if err := c.writeCommandStr(COM_QUERY, query); err != nil {
  322. return nil, errors.Trace(err)
  323. }
  324. return c.readResult(false)
  325. }