util.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. package mysql
  2. import (
  3. "crypto/rand"
  4. "crypto/rsa"
  5. "crypto/sha1"
  6. "crypto/sha256"
  7. "encoding/binary"
  8. "fmt"
  9. "io"
  10. mrand "math/rand"
  11. "runtime"
  12. "strings"
  13. "time"
  14. "github.com/pingcap/errors"
  15. "github.com/siddontang/go/hack"
  16. )
  // Pstack returns the stack trace of the calling goroutine as a string.
//
// runtime.Stack truncates silently when the buffer is too small, so the
// buffer is doubled until the returned length is strictly less than its
// capacity, which guarantees the whole trace was captured.
func Pstack() string {
	buf := make([]byte, 1024)
	for {
		n := runtime.Stack(buf, false)
		if n < len(buf) {
			return string(buf[:n])
		}
		// The buffer was filled completely; the trace may be cut off.
		buf = make([]byte, 2*len(buf))
	}
}
  // CalcPassword builds the mysql_native_password scramble response:
//
//	SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// An empty password yields nil. The result is a fresh 20-byte slice; neither
// argument is modified.
func CalcPassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	stage1 := sha1.Sum(password)  // SHA1(password)
	stage2 := sha1.Sum(stage1[:]) // SHA1(SHA1(password))

	outer := sha1.New()
	outer.Write(scramble)
	outer.Write(stage2[:])
	token := outer.Sum(nil) // SHA1(scramble + SHA1(SHA1(password)))

	for i := range token {
		token[i] ^= stage1[i]
	}
	return token
}
  // CalcCachingSha2Password builds the caching_sha2_password (MySQL 8+)
// scramble response:
//
//	SHA256(password) XOR SHA256(SHA256(SHA256(password)) + scramble)
//
// An empty password yields nil; otherwise a fresh 32-byte slice is returned.
func CalcCachingSha2Password(scramble []byte, password string) []byte {
	if password == "" {
		return nil
	}

	digest1 := sha256.Sum256([]byte(password)) // SHA256(password)
	digest2 := sha256.Sum256(digest1[:])       // SHA256(SHA256(password))

	h := sha256.New()
	h.Write(digest2[:])
	h.Write(scramble)
	mix := h.Sum(nil)

	out := make([]byte, len(digest1))
	for i := range out {
		out[i] = digest1[i] ^ mix[i]
	}
	return out
}
  // EncryptPassword XORs the NUL-terminated password with seed (repeated
// cyclically) and encrypts the result with RSA-OAEP (SHA-1) under pub.
//
// An empty seed is rejected with an error; previously the cyclic index
// i % len(seed) panicked with an integer divide by zero.
func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
	if len(seed) == 0 {
		return nil, fmt.Errorf("encrypt password: empty seed")
	}
	// One extra zero byte: the password is sent NUL-terminated.
	plain := make([]byte, len(password)+1)
	copy(plain, password)
	for i := range plain {
		plain[i] ^= seed[i%len(seed)]
	}
	return rsa.EncryptOAEP(sha1.New(), rand.Reader, pub, plain, nil)
}
  // AppendLengthEncodedInteger appends n to b as a MySQL length-encoded
// integer and returns the extended slice:
//
//	n <= 250       1 byte holding n
//	n <= 0xffff    0xfc + 2 little-endian bytes
//	n <= 0xffffff  0xfd + 3 little-endian bytes
//	otherwise      0xfe + 8 little-endian bytes
func AppendLengthEncodedInteger(b []byte, n uint64) []byte {
	if n <= 250 {
		return append(b, byte(n))
	}

	var prefix byte
	var width int
	switch {
	case n <= 0xffff:
		prefix, width = 0xfc, 2
	case n <= 0xffffff:
		prefix, width = 0xfd, 3
	default:
		prefix, width = 0xfe, 8
	}

	// Assemble in a scratch array so b is extended with a single append.
	var enc [9]byte
	enc[0] = prefix
	for i := 0; i < width; i++ {
		enc[1+i] = byte(n >> (8 * uint(i)))
	}
	return append(b, enc[:1+width]...)
}
  // RandomBuf returns size pseudo-random bytes, each in the range [30, 127).
//
// It draws from a private math/rand generator instead of reseeding the
// process-wide source on every call. rand.Seed is deprecated (and ignored by
// default since Go 1.24), and reseeding the global source from the clock
// disturbed every other user of math/rand. The private generator is seeded
// from crypto/rand entropy mixed with the current time. An error reading
// that entropy is returned to the caller.
//
// NOTE(review): the bytes still come from math/rand and are not suitable as
// a cryptographic secret; if callers use this as an auth salt, consider
// filling it from crypto/rand directly.
func RandomBuf(size int) ([]byte, error) {
	const lo, hi = 30, 127

	var entropy [8]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return nil, err
	}
	seed := int64(binary.LittleEndian.Uint64(entropy[:])) ^ time.Now().UnixNano()
	rng := mrand.New(mrand.NewSource(seed))

	buf := make([]byte, size)
	for i := range buf {
		buf[i] = byte(lo + rng.Intn(hi-lo))
	}
	return buf, nil
}
  // FixedLengthInt decodes buf as an unsigned little-endian integer.
// Bytes past the eighth are shifted out of range and do not contribute.
func FixedLengthInt(buf []byte) uint64 {
	var v uint64
	// Walk from the most significant (last) byte down to the first.
	for i := len(buf) - 1; i >= 0; i-- {
		v = v<<8 | uint64(buf[i])
	}
	return v
}
  // BFixedLengthInt decodes buf as an unsigned big-endian integer.
// Only the last eight bytes contribute; earlier bytes are shifted out.
func BFixedLengthInt(buf []byte) uint64 {
	var v uint64
	for _, octet := range buf {
		v = v<<8 | uint64(octet)
	}
	return v
}
  // LengthEncodedInt decodes a MySQL length-encoded integer at the start of b.
// It returns the value, whether it denotes NULL, and how many bytes it used.
// The first byte selects the form:
//
//	empty b     -> NULL, 0 bytes used
//	0xfb        -> NULL, 1 byte used
//	0xfc        -> 2-byte little-endian value follows (3 bytes used)
//	0xfd        -> 3-byte little-endian value follows (4 bytes used)
//	0xfe        -> 8-byte little-endian value follows (9 bytes used)
//	anything else (including 0xff) -> the byte itself, 1 byte used
//
// It panics if b is shorter than the form announced by its first byte.
func LengthEncodedInt(b []byte) (num uint64, isNull bool, n int) {
	if len(b) == 0 {
		return 0, true, 0
	}

	var width int
	switch b[0] {
	case 0xfb:
		return 0, true, 1
	case 0xfc:
		width = 2
	case 0xfd:
		width = 3
	case 0xfe:
		width = 8
	default:
		return uint64(b[0]), false, 1
	}

	var v uint64
	for i := 1; i <= width; i++ {
		v |= uint64(b[i]) << (8 * uint(i-1))
	}
	return v, false, width + 1
}
  // PutLengthEncodedInt encodes n as a MySQL length-encoded integer in a newly
// allocated slice: 1 byte for n <= 250, otherwise a 0xfc/0xfd/0xfe marker
// followed by 2, 3 or 8 little-endian bytes.
//
// Every uint64 has an encoding, so the final form is a plain default rather
// than an always-true comparison followed by an unreachable `return nil`.
func PutLengthEncodedInt(n uint64) []byte {
	switch {
	case n <= 250:
		return []byte{byte(n)}
	case n <= 0xffff:
		return []byte{0xfc, byte(n), byte(n >> 8)}
	case n <= 0xffffff:
		return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
	default:
		return []byte{0xfe, byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24),
			byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)}
	}
}
  153. // LengthEncodedString returns the string read as a bytes slice, whether the value is NULL,
  154. // the number of bytes read and an error, in case the string is longer than
  155. // the input slice
  156. func LengthEncodedString(b []byte) ([]byte, bool, int, error) {
  157. // Get length
  158. num, isNull, n := LengthEncodedInt(b)
  159. if num < 1 {
  160. return b[n:n], isNull, n, nil
  161. }
  162. n += int(num)
  163. // Check data length
  164. if len(b) >= n {
  165. return b[n-int(num) : n : n], false, n, nil
  166. }
  167. return nil, false, n, io.EOF
  168. }
  169. func SkipLengthEncodedString(b []byte) (int, error) {
  170. // Get length
  171. num, _, n := LengthEncodedInt(b)
  172. if num < 1 {
  173. return n, nil
  174. }
  175. n += int(num)
  176. // Check data length
  177. if len(b) >= n {
  178. return n, nil
  179. }
  180. return n, io.EOF
  181. }
  182. func PutLengthEncodedString(b []byte) []byte {
  183. data := make([]byte, 0, len(b)+9)
  184. data = append(data, PutLengthEncodedInt(uint64(len(b)))...)
  185. data = append(data, b...)
  186. return data
  187. }
  // Uint16ToBytes returns n as 2 little-endian bytes in a new slice.
func Uint16ToBytes(n uint16) []byte {
	out := make([]byte, 2)
	binary.LittleEndian.PutUint16(out, n)
	return out
}
  // Uint32ToBytes returns n as 4 little-endian bytes in a new slice.
func Uint32ToBytes(n uint32) []byte {
	out := make([]byte, 4)
	binary.LittleEndian.PutUint32(out, n)
	return out
}
  // Uint64ToBytes returns n as 8 little-endian bytes in a new slice.
func Uint64ToBytes(n uint64) []byte {
	out := make([]byte, 8)
	binary.LittleEndian.PutUint64(out, n)
	return out
}
  214. func FormatBinaryDate(n int, data []byte) ([]byte, error) {
  215. switch n {
  216. case 0:
  217. return []byte("0000-00-00"), nil
  218. case 4:
  219. return []byte(fmt.Sprintf("%04d-%02d-%02d",
  220. binary.LittleEndian.Uint16(data[:2]),
  221. data[2],
  222. data[3])), nil
  223. default:
  224. return nil, errors.Errorf("invalid date packet length %d", n)
  225. }
  226. }
  227. func FormatBinaryDateTime(n int, data []byte) ([]byte, error) {
  228. switch n {
  229. case 0:
  230. return []byte("0000-00-00 00:00:00"), nil
  231. case 4:
  232. return []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00",
  233. binary.LittleEndian.Uint16(data[:2]),
  234. data[2],
  235. data[3])), nil
  236. case 7:
  237. return []byte(fmt.Sprintf(
  238. "%04d-%02d-%02d %02d:%02d:%02d",
  239. binary.LittleEndian.Uint16(data[:2]),
  240. data[2],
  241. data[3],
  242. data[4],
  243. data[5],
  244. data[6])), nil
  245. case 11:
  246. return []byte(fmt.Sprintf(
  247. "%04d-%02d-%02d %02d:%02d:%02d.%06d",
  248. binary.LittleEndian.Uint16(data[:2]),
  249. data[2],
  250. data[3],
  251. data[4],
  252. data[5],
  253. data[6],
  254. binary.LittleEndian.Uint32(data[7:11]))), nil
  255. default:
  256. return nil, errors.Errorf("invalid datetime packet length %d", n)
  257. }
  258. }
  259. func FormatBinaryTime(n int, data []byte) ([]byte, error) {
  260. if n == 0 {
  261. return []byte("0000-00-00"), nil
  262. }
  263. var sign byte
  264. if data[0] == 1 {
  265. sign = byte('-')
  266. }
  267. switch n {
  268. case 8:
  269. return []byte(fmt.Sprintf(
  270. "%c%02d:%02d:%02d",
  271. sign,
  272. uint16(data[1])*24+uint16(data[5]),
  273. data[6],
  274. data[7],
  275. )), nil
  276. case 12:
  277. return []byte(fmt.Sprintf(
  278. "%c%02d:%02d:%02d.%06d",
  279. sign,
  280. uint16(data[1])*24+uint16(data[5]),
  281. data[6],
  282. data[7],
  283. binary.LittleEndian.Uint32(data[8:12]),
  284. )), nil
  285. default:
  286. return nil, errors.Errorf("invalid time packet length %d", n)
  287. }
  288. }
  var (
	// DONTESCAPE is the EncodeMap value marking a byte that Escape copies
	// through unchanged.
	DONTESCAPE byte = 0xff

	// EncodeMap maps every byte to the character Escape writes after a
	// backslash, or to DONTESCAPE when the byte needs no escaping. It is
	// populated by this file's init function.
	EncodeMap [256]byte
)
  293. // Escape: only support utf-8
  294. func Escape(sql string) string {
  295. dest := make([]byte, 0, 2*len(sql))
  296. for _, w := range hack.Slice(sql) {
  297. if c := EncodeMap[w]; c == DONTESCAPE {
  298. dest = append(dest, w)
  299. } else {
  300. dest = append(dest, '\\', c)
  301. }
  302. }
  303. return string(dest)
  304. }
  // GetNetProto returns the network name for addr: "unix" when addr contains
// a '/' (treated as a socket path), otherwise "tcp".
func GetNetProto(addr string) string {
	if strings.Contains(addr, "/") {
		return "unix"
	}
	return "tcp"
}
  312. // ErrorEqual returns a boolean indicating whether err1 is equal to err2.
  313. func ErrorEqual(err1, err2 error) bool {
  314. e1 := errors.Cause(err1)
  315. e2 := errors.Cause(err2)
  316. if e1 == e2 {
  317. return true
  318. }
  319. if e1 == nil || e2 == nil {
  320. return e1 == e2
  321. }
  322. return e1.Error() == e2.Error()
  323. }
  // encodeRef lists each byte that Escape must backslash-escape, mapped to the
// character written after the backslash. init copies it into EncodeMap.
var encodeRef = map[byte]byte{
	0x00: '0',
	0x08: 'b', // backspace
	0x09: 't', // tab
	0x0a: 'n', // newline
	0x0d: 'r', // carriage return
	0x1a: 'Z', // Ctrl-Z
	'"':  '"',
	'\'': '\'',
	'\\': '\\',
}
  335. func init() {
  336. for i := range EncodeMap {
  337. EncodeMap[i] = DONTESCAPE
  338. }
  339. for i := range EncodeMap {
  340. if to, ok := encodeRef[byte(i)]; ok {
  341. EncodeMap[byte(i)] = to
  342. }
  343. }
  344. }