resp.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. package client
  2. import (
  3. "bytes"
  4. "crypto/rsa"
  5. "crypto/x509"
  6. "encoding/binary"
  7. "encoding/pem"
  8. "github.com/pingcap/errors"
  9. "github.com/siddontang/go/hack"
  10. . "github.com/go-mysql-org/go-mysql/mysql"
  11. "github.com/go-mysql-org/go-mysql/utils"
  12. )
  13. func (c *Conn) readUntilEOF() (err error) {
  14. var data []byte
  15. for {
  16. data, err = c.ReadPacket()
  17. if err != nil {
  18. return
  19. }
  20. // EOF Packet
  21. if c.isEOFPacket(data) {
  22. return
  23. }
  24. }
  25. }
  26. func (c *Conn) isEOFPacket(data []byte) bool {
  27. return data[0] == EOF_HEADER && len(data) <= 5
  28. }
  // handleOKPacket parses an OK packet into a Result.
  //
  // Layout read here:
  //   - header byte
  //   - affected rows (length-encoded int)
  //   - last insert id (length-encoded int)
  //   - under CLIENT_PROTOCOL_41: 2-byte status flags, then 2-byte warning count
  //   - else under CLIENT_TRANSACTIONS: 2-byte status flags only
  //
  // When status flags are parsed they are also stored in c.status. Trailing
  // info / session-state data is not parsed.
  //
  // Data comes from the network, so a packet too short for the fields the
  // negotiated capabilities require yields ErrMalformPacket instead of a panic.
  // NOTE(review): truncation *inside* a length-encoded integer is left to
  // LengthEncodedInt.
  func (c *Conn) handleOKPacket(data []byte) (*Result, error) {
  	// header (1) + affected rows (>=1) + last insert id (>=1)
  	if len(data) < 3 {
  		return nil, ErrMalformPacket
  	}
  	r := new(Result)
  	pos := 1
  	var n int
  	r.AffectedRows, _, n = LengthEncodedInt(data[pos:])
  	pos += n
  	if pos >= len(data) {
  		return nil, ErrMalformPacket
  	}
  	r.InsertId, _, n = LengthEncodedInt(data[pos:])
  	pos += n
  	switch {
  	case c.capability&CLIENT_PROTOCOL_41 > 0:
  		if len(data) < pos+4 {
  			return nil, ErrMalformPacket
  		}
  		r.Status = binary.LittleEndian.Uint16(data[pos:])
  		// todo: strict_mode, check warnings as error
  		r.Warnings = binary.LittleEndian.Uint16(data[pos+2:])
  		c.status = r.Status
  	case c.capability&CLIENT_TRANSACTIONS > 0:
  		if len(data) < pos+2 {
  			return nil, ErrMalformPacket
  		}
  		r.Status = binary.LittleEndian.Uint16(data[pos:])
  		c.status = r.Status
  	}
  	// The new OK packet format may carry CLIENT_SESSION_TRACK data; it is not supported here.
  	return r, nil
  }
  53. func (c *Conn) handleErrorPacket(data []byte) error {
  54. e := new(MyError)
  55. var pos = 1
  56. e.Code = binary.LittleEndian.Uint16(data[pos:])
  57. pos += 2
  58. if c.capability&CLIENT_PROTOCOL_41 > 0 {
  59. //skip '#'
  60. pos++
  61. e.State = hack.String(data[pos : pos+5])
  62. pos += 5
  63. }
  64. e.Message = hack.String(data[pos:])
  65. return e
  66. }
  // handleAuthResult reads the server's reply to the client's auth response and
  // drives any follow-up exchange until authentication finishes or fails.
  //
  // If the server asks for an auth-plugin switch, the switch data is copied
  // into c.salt, a new response is generated and sent, and the reply is read.
  // Only one switch is allowed.
  //
  // Plugin-specific continuation:
  //   - caching_sha2_password: handles fast auth (read OK) and full auth. Full
  //     auth sends a clear password over TLS or a unix socket, and a public-key
  //     encrypted password otherwise, then reads OK.
  //   - sha256_password: expects a PEM-encoded RSA public key, sends the
  //     password encrypted with it, then reads OK.
  //
  // A nil data from readAuthResult means an OK packet was already received.
  // Malformed continuation data (empty caching_sha2 payload, non-PEM or non-RSA
  // sha256 key) is reported as an error instead of panicking.
  func (c *Conn) handleAuthResult() error {
  	data, switchToPlugin, err := c.readAuthResult()
  	if err != nil {
  		return err
  	}
  	// handle auth switch, only support 'sha256_password', and 'caching_sha2_password'
  	if switchToPlugin != "" {
  		if data == nil {
  			data = c.salt
  		} else {
  			copy(c.salt, data)
  		}
  		c.authPluginName = switchToPlugin
  		auth, addNull, err := c.genAuthResponse(data)
  		if err != nil {
  			return err
  		}
  		if err = c.WriteAuthSwitchPacket(auth, addNull); err != nil {
  			return err
  		}
  		// Read Result Packet
  		data, switchToPlugin, err = c.readAuthResult()
  		if err != nil {
  			return err
  		}
  		// Do not allow to change the auth plugin more than once
  		if switchToPlugin != "" {
  			return errors.Errorf("can not switch auth plugin more than once")
  		}
  	}

  	switch c.authPluginName {
  	case AUTH_CACHING_SHA2_PASSWORD:
  		if data == nil {
  			return nil // auth already succeeded
  		}
  		if len(data) == 0 {
  			// A more-data packet with no payload; indexing data[0] would panic.
  			return errors.New("invalid packet: empty caching_sha2_password auth data")
  		}
  		switch data[0] {
  		case CACHE_SHA2_FAST_AUTH:
  			_, err = c.readOK()
  			return err
  		case CACHE_SHA2_FULL_AUTH:
  			// need full authentication: a clear password is only sent over a secure channel
  			if c.tlsConfig != nil || c.proto == "unix" {
  				err = c.WriteClearAuthPacket(c.password)
  			} else {
  				err = c.WritePublicKeyAuthPacket(c.password, c.salt)
  			}
  			if err != nil {
  				return err
  			}
  			_, err = c.readOK()
  			return err
  		default:
  			return errors.Errorf("invalid packet %x", data[0])
  		}
  	case AUTH_SHA256_PASSWORD:
  		if len(data) == 0 {
  			return nil // auth already succeeded
  		}
  		block, _ := pem.Decode(data)
  		if block == nil {
  			return errors.New("invalid packet: no PEM public key in sha256_password auth data")
  		}
  		pub, err := x509.ParsePKIXPublicKey(block.Bytes)
  		if err != nil {
  			return err
  		}
  		rsaPub, ok := pub.(*rsa.PublicKey)
  		if !ok {
  			return errors.Errorf("unsupported sha256_password public key type %T", pub)
  		}
  		// send encrypted password
  		if err = c.WriteEncryptedPassword(c.password, c.salt, rsaPub); err != nil {
  			return err
  		}
  		_, err = c.readOK()
  		return err
  	}
  	return nil
  }
  // readAuthResult reads one packet during the auth phase and classifies it by
  // its first byte:
  //   - OK: parsed with handleOKPacket; returns (nil, "", err).
  //   - more-data: returns the payload after the header byte.
  //   - auth switch: returns the plugin data and the NUL-terminated plugin name.
  //     A bare 1-byte switch packet is the OldAuthSwitchRequest and maps to
  //     AUTH_MYSQL_OLD_PASSWORD.
  //   - anything else: treated as an error packet.
  //
  // An empty packet yields ErrMalformPacket instead of panicking on data[0].
  // see: https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
  func (c *Conn) readAuthResult() ([]byte, string, error) {
  	data, err := c.ReadPacket()
  	if err != nil {
  		return nil, "", err
  	}
  	if len(data) == 0 {
  		return nil, "", ErrMalformPacket
  	}
  	// packet indicator
  	switch data[0] {
  	case OK_HEADER:
  		_, err = c.handleOKPacket(data)
  		return nil, "", err
  	case MORE_DATE_HEADER:
  		return data[1:], "", nil
  	case EOF_HEADER:
  		// server wants to switch auth
  		if len(data) == 1 {
  			// The old switch request is just the header byte. The previous
  			// `len(data) < 1` test could never be true once data[0] had been read.
  			// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
  			return nil, AUTH_MYSQL_OLD_PASSWORD, nil
  		}
  		pluginEndIndex := bytes.IndexByte(data, 0x00)
  		if pluginEndIndex < 0 {
  			return nil, "", errors.New("invalid packet")
  		}
  		plugin := string(data[1:pluginEndIndex])
  		authData := data[pluginEndIndex+1:]
  		return authData, plugin, nil
  	default: // Error otherwise
  		return nil, "", c.handleErrorPacket(data)
  	}
  }
  171. func (c *Conn) readOK() (*Result, error) {
  172. data, err := c.ReadPacket()
  173. if err != nil {
  174. return nil, errors.Trace(err)
  175. }
  176. if data[0] == OK_HEADER {
  177. return c.handleOKPacket(data)
  178. } else if data[0] == ERR_HEADER {
  179. return nil, c.handleErrorPacket(data)
  180. } else {
  181. return nil, errors.New("invalid ok packet")
  182. }
  183. }
  // readResult reads the first response packet of a command into a pooled
  // buffer and dispatches on its header:
  //   - OK: returns the parsed OK result.
  //   - ERR: returns the server error.
  //   - LOCAL INFILE request: rejected with ErrMalformPacket.
  //   - anything else: treated as a column count, and a full result set is read.
  //
  // isBinary selects binary-protocol row decoding. It was renamed from `binary`
  // so it no longer shadows the encoding/binary package. An empty first packet
  // yields ErrMalformPacket instead of panicking.
  func (c *Conn) readResult(isBinary bool) (*Result, error) {
  	bs := utils.ByteSliceGet(16)
  	defer utils.ByteSlicePut(bs)
  	var err error
  	bs.B, err = c.ReadPacketReuseMem(bs.B[:0])
  	if err != nil {
  		return nil, errors.Trace(err)
  	}
  	if len(bs.B) == 0 {
  		return nil, ErrMalformPacket
  	}
  	switch bs.B[0] {
  	case OK_HEADER:
  		return c.handleOKPacket(bs.B)
  	case ERR_HEADER:
  		// bs goes back to the pool when this function returns. The error's strings
  		// may alias the bytes they were built from, so hand over a private copy.
  		return nil, c.handleErrorPacket(bytes.Repeat(bs.B, 1))
  	case LocalInFile_HEADER:
  		return nil, ErrMalformPacket
  	default:
  		return c.readResultset(bs.B, isBinary)
  	}
  }
  // readResultStreaming is the streaming counterpart of readResult. It fills the
  // caller-owned result instead of allocating a new one.
  //
  //   - OK packet (a zero-column response): copies the status, affected rows,
  //     insert id and warnings into result and leaves an empty Resultset.
  //   - ERR packet: returns the server error.
  //   - LOCAL INFILE request: rejected with ErrMalformPacket.
  //   - anything else: streams the result set through perResCb / perRowCb.
  //
  // isBinary was renamed from `binary` to stop shadowing encoding/binary. An
  // empty first packet yields ErrMalformPacket instead of panicking.
  func (c *Conn) readResultStreaming(isBinary bool, result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error {
  	bs := utils.ByteSliceGet(16)
  	defer utils.ByteSlicePut(bs)
  	var err error
  	bs.B, err = c.ReadPacketReuseMem(bs.B[:0])
  	if err != nil {
  		return errors.Trace(err)
  	}
  	if len(bs.B) == 0 {
  		return ErrMalformPacket
  	}
  	switch bs.B[0] {
  	case OK_HEADER:
  		// https://dev.mysql.com/doc/internals/en/com-query-response.html
  		// 14.6.4.1 COM_QUERY Response
  		// If the number of columns in the resultset is 0, this is a OK_Packet.
  		okResult, okErr := c.handleOKPacket(bs.B)
  		if okErr != nil {
  			return errors.Trace(okErr)
  		}
  		result.Status = okResult.Status
  		result.AffectedRows = okResult.AffectedRows
  		result.InsertId = okResult.InsertId
  		result.Warnings = okResult.Warnings
  		if result.Resultset == nil {
  			result.Resultset = NewResultset(0)
  		} else {
  			result.Reset(0)
  		}
  		return nil
  	case ERR_HEADER:
  		// bs is returned to the pool on exit; pass a private copy.
  		return c.handleErrorPacket(bytes.Repeat(bs.B, 1))
  	case LocalInFile_HEADER:
  		return ErrMalformPacket
  	default:
  		return c.readResultsetStreaming(bs.B, isBinary, result, perRowCb, perResCb)
  	}
  }
  238. func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) {
  239. // column count
  240. count, _, n := LengthEncodedInt(data)
  241. if n-len(data) != 0 {
  242. return nil, ErrMalformPacket
  243. }
  244. result := &Result{
  245. Resultset: NewResultset(int(count)),
  246. }
  247. if err := c.readResultColumns(result); err != nil {
  248. return nil, errors.Trace(err)
  249. }
  250. if err := c.readResultRows(result, binary); err != nil {
  251. return nil, errors.Trace(err)
  252. }
  253. return result, nil
  254. }
  255. func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error {
  256. columnCount, _, n := LengthEncodedInt(data)
  257. if n-len(data) != 0 {
  258. return ErrMalformPacket
  259. }
  260. if result.Resultset == nil {
  261. result.Resultset = NewResultset(int(columnCount))
  262. } else {
  263. // Reuse memory if can
  264. result.Reset(int(columnCount))
  265. }
  266. // this is a streaming resultset
  267. result.Resultset.Streaming = StreamingSelect
  268. if err := c.readResultColumns(result); err != nil {
  269. return errors.Trace(err)
  270. }
  271. if perResCb != nil {
  272. if err := perResCb(result); err != nil {
  273. return err
  274. }
  275. }
  276. if err := c.readResultRowsStreaming(result, binary, perRowCb); err != nil {
  277. return errors.Trace(err)
  278. }
  279. // this resultset is done streaming
  280. result.Resultset.StreamingDone = true
  281. return nil
  282. }
  283. func (c *Conn) readResultColumns(result *Result) (err error) {
  284. var i int = 0
  285. var data []byte
  286. for {
  287. rawPkgLen := len(result.RawPkg)
  288. result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
  289. if err != nil {
  290. return
  291. }
  292. data = result.RawPkg[rawPkgLen:]
  293. // EOF Packet
  294. if c.isEOFPacket(data) {
  295. if c.capability&CLIENT_PROTOCOL_41 > 0 {
  296. result.Warnings = binary.LittleEndian.Uint16(data[1:])
  297. //todo add strict_mode, warning will be treat as error
  298. result.Status = binary.LittleEndian.Uint16(data[3:])
  299. c.status = result.Status
  300. }
  301. if i != len(result.Fields) {
  302. err = ErrMalformPacket
  303. }
  304. return
  305. }
  306. if result.Fields[i] == nil {
  307. result.Fields[i] = &Field{}
  308. }
  309. err = result.Fields[i].Parse(data)
  310. if err != nil {
  311. return
  312. }
  313. result.FieldNames[hack.String(result.Fields[i].Name)] = i
  314. i++
  315. }
  316. }
  317. func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) {
  318. var data []byte
  319. for {
  320. rawPkgLen := len(result.RawPkg)
  321. result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
  322. if err != nil {
  323. return
  324. }
  325. data = result.RawPkg[rawPkgLen:]
  326. // EOF Packet
  327. if c.isEOFPacket(data) {
  328. if c.capability&CLIENT_PROTOCOL_41 > 0 {
  329. result.Warnings = binary.LittleEndian.Uint16(data[1:])
  330. //todo add strict_mode, warning will be treat as error
  331. result.Status = binary.LittleEndian.Uint16(data[3:])
  332. c.status = result.Status
  333. }
  334. break
  335. }
  336. if data[0] == ERR_HEADER {
  337. return c.handleErrorPacket(data)
  338. }
  339. result.RowDatas = append(result.RowDatas, data)
  340. }
  341. if cap(result.Values) < len(result.RowDatas) {
  342. result.Values = make([][]FieldValue, len(result.RowDatas))
  343. } else {
  344. result.Values = result.Values[:len(result.RowDatas)]
  345. }
  346. for i := range result.Values {
  347. result.Values[i], err = result.RowDatas[i].Parse(result.Fields, isBinary, result.Values[i])
  348. if err != nil {
  349. return errors.Trace(err)
  350. }
  351. }
  352. return nil
  353. }
  // readResultRowsStreaming reads row packets until EOF and hands each decoded
  // row to perRowCb, one at a time, without accumulating them in result.
  //
  // Buffer reuse: the packet buffer and the row slice are reused across
  // iterations. The row passed to perRowCb may therefore be overwritten by the
  // next packet, so the callback should copy anything it keeps.
  // NOTE(review): whether FieldValue aliases the packet bytes depends on
  // RowData.Parse — confirm.
  //
  // Behavior:
  //   - Under CLIENT_PROTOCOL_41 the EOF packet's warning count and status flags
  //     are copied into result and c.status.
  //   - An ERR packet aborts with the server error.
  //   - A nil perRowCb drains the rows (still parsing them) instead of panicking.
  //   - Empty packets and short protocol-41 EOF packets yield ErrMalformPacket.
  func (c *Conn) readResultRowsStreaming(result *Result, isBinary bool, perRowCb SelectPerRowCallback) (err error) {
  	var (
  		data []byte
  		row  []FieldValue
  	)
  	for {
  		data, err = c.ReadPacketReuseMem(data[:0])
  		if err != nil {
  			return err
  		}
  		if len(data) == 0 {
  			return ErrMalformPacket
  		}
  		// EOF Packet
  		if c.isEOFPacket(data) {
  			if c.capability&CLIENT_PROTOCOL_41 > 0 {
  				if len(data) < 5 {
  					return ErrMalformPacket
  				}
  				result.Warnings = binary.LittleEndian.Uint16(data[1:])
  				// todo add strict_mode, warning will be treat as error
  				result.Status = binary.LittleEndian.Uint16(data[3:])
  				c.status = result.Status
  			}
  			return nil
  		}
  		if data[0] == ERR_HEADER {
  			return c.handleErrorPacket(data)
  		}
  		// Parse this row
  		row, err = RowData(data).Parse(result.Fields, isBinary, row)
  		if err != nil {
  			return errors.Trace(err)
  		}
  		if perRowCb == nil {
  			continue
  		}
  		// Send the row to "userland" code
  		if err = perRowCb(row); err != nil {
  			return errors.Trace(err)
  		}
  	}
  }