decode.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
  1. package msgpack
  2. import (
  3. "bufio"
  4. "bytes"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "reflect"
  9. "sync"
  10. "time"
  11. "github.com/vmihailenco/msgpack/v5/msgpcode"
  12. )
  13. const (
  14. looseInterfaceDecodingFlag uint32 = 1 << iota
  15. disallowUnknownFieldsFlag
  16. )
  17. const (
  18. bytesAllocLimit = 1e6 // 1mb
  19. sliceAllocLimit = 1e4
  20. maxMapSize = 1e6
  21. )
  22. type bufReader interface {
  23. io.Reader
  24. io.ByteScanner
  25. }
  26. //------------------------------------------------------------------------------
  27. var decPool = sync.Pool{
  28. New: func() interface{} {
  29. return NewDecoder(nil)
  30. },
  31. }
  32. func GetDecoder() *Decoder {
  33. return decPool.Get().(*Decoder)
  34. }
  35. func PutDecoder(dec *Decoder) {
  36. dec.r = nil
  37. dec.s = nil
  38. decPool.Put(dec)
  39. }
  40. //------------------------------------------------------------------------------
  41. // Unmarshal decodes the MessagePack-encoded data and stores the result
  42. // in the value pointed to by v.
  43. func Unmarshal(data []byte, v interface{}) error {
  44. dec := GetDecoder()
  45. dec.Reset(bytes.NewReader(data))
  46. err := dec.Decode(v)
  47. PutDecoder(dec)
  48. return err
  49. }
  50. // A Decoder reads and decodes MessagePack values from an input stream.
  51. type Decoder struct {
  52. r io.Reader
  53. s io.ByteScanner
  54. buf []byte
  55. rec []byte // accumulates read data if not nil
  56. dict []string
  57. flags uint32
  58. structTag string
  59. mapDecoder func(*Decoder) (interface{}, error)
  60. }
  61. // NewDecoder returns a new decoder that reads from r.
  62. //
  63. // The decoder introduces its own buffering and may read data from r
  64. // beyond the requested msgpack values. Buffering can be disabled
  65. // by passing a reader that implements io.ByteScanner interface.
  66. func NewDecoder(r io.Reader) *Decoder {
  67. d := new(Decoder)
  68. d.Reset(r)
  69. return d
  70. }
  71. // Reset discards any buffered data, resets all state, and switches the buffered
  72. // reader to read from r.
  73. func (d *Decoder) Reset(r io.Reader) {
  74. d.ResetDict(r, nil)
  75. }
  76. // ResetDict is like Reset, but also resets the dict.
  77. func (d *Decoder) ResetDict(r io.Reader, dict []string) {
  78. d.resetReader(r)
  79. d.flags = 0
  80. d.structTag = ""
  81. d.mapDecoder = nil
  82. d.dict = dict
  83. }
  84. func (d *Decoder) WithDict(dict []string, fn func(*Decoder) error) error {
  85. oldDict := d.dict
  86. d.dict = dict
  87. err := fn(d)
  88. d.dict = oldDict
  89. return err
  90. }
  91. func (d *Decoder) resetReader(r io.Reader) {
  92. if br, ok := r.(bufReader); ok {
  93. d.r = br
  94. d.s = br
  95. } else {
  96. br := bufio.NewReader(r)
  97. d.r = br
  98. d.s = br
  99. }
  100. }
  101. func (d *Decoder) SetMapDecoder(fn func(*Decoder) (interface{}, error)) {
  102. d.mapDecoder = fn
  103. }
  104. // UseLooseInterfaceDecoding causes decoder to use DecodeInterfaceLoose
  105. // to decode msgpack value into Go interface{}.
  106. func (d *Decoder) UseLooseInterfaceDecoding(on bool) {
  107. if on {
  108. d.flags |= looseInterfaceDecodingFlag
  109. } else {
  110. d.flags &= ^looseInterfaceDecodingFlag
  111. }
  112. }
  113. // SetCustomStructTag causes the decoder to use the supplied tag as a fallback option
  114. // if there is no msgpack tag.
  115. func (d *Decoder) SetCustomStructTag(tag string) {
  116. d.structTag = tag
  117. }
  118. // DisallowUnknownFields causes the Decoder to return an error when the destination
  119. // is a struct and the input contains object keys which do not match any
  120. // non-ignored, exported fields in the destination.
  121. func (d *Decoder) DisallowUnknownFields(on bool) {
  122. if on {
  123. d.flags |= disallowUnknownFieldsFlag
  124. } else {
  125. d.flags &= ^disallowUnknownFieldsFlag
  126. }
  127. }
  128. // UseInternedStrings enables support for decoding interned strings.
  129. func (d *Decoder) UseInternedStrings(on bool) {
  130. if on {
  131. d.flags |= useInternedStringsFlag
  132. } else {
  133. d.flags &= ^useInternedStringsFlag
  134. }
  135. }
  136. // Buffered returns a reader of the data remaining in the Decoder's buffer.
  137. // The reader is valid until the next call to Decode.
  138. func (d *Decoder) Buffered() io.Reader {
  139. return d.r
  140. }
  141. //nolint:gocyclo
  142. func (d *Decoder) Decode(v interface{}) error {
  143. var err error
  144. switch v := v.(type) {
  145. case *string:
  146. if v != nil {
  147. *v, err = d.DecodeString()
  148. return err
  149. }
  150. case *[]byte:
  151. if v != nil {
  152. return d.decodeBytesPtr(v)
  153. }
  154. case *int:
  155. if v != nil {
  156. *v, err = d.DecodeInt()
  157. return err
  158. }
  159. case *int8:
  160. if v != nil {
  161. *v, err = d.DecodeInt8()
  162. return err
  163. }
  164. case *int16:
  165. if v != nil {
  166. *v, err = d.DecodeInt16()
  167. return err
  168. }
  169. case *int32:
  170. if v != nil {
  171. *v, err = d.DecodeInt32()
  172. return err
  173. }
  174. case *int64:
  175. if v != nil {
  176. *v, err = d.DecodeInt64()
  177. return err
  178. }
  179. case *uint:
  180. if v != nil {
  181. *v, err = d.DecodeUint()
  182. return err
  183. }
  184. case *uint8:
  185. if v != nil {
  186. *v, err = d.DecodeUint8()
  187. return err
  188. }
  189. case *uint16:
  190. if v != nil {
  191. *v, err = d.DecodeUint16()
  192. return err
  193. }
  194. case *uint32:
  195. if v != nil {
  196. *v, err = d.DecodeUint32()
  197. return err
  198. }
  199. case *uint64:
  200. if v != nil {
  201. *v, err = d.DecodeUint64()
  202. return err
  203. }
  204. case *bool:
  205. if v != nil {
  206. *v, err = d.DecodeBool()
  207. return err
  208. }
  209. case *float32:
  210. if v != nil {
  211. *v, err = d.DecodeFloat32()
  212. return err
  213. }
  214. case *float64:
  215. if v != nil {
  216. *v, err = d.DecodeFloat64()
  217. return err
  218. }
  219. case *[]string:
  220. return d.decodeStringSlicePtr(v)
  221. case *map[string]string:
  222. return d.decodeMapStringStringPtr(v)
  223. case *map[string]interface{}:
  224. return d.decodeMapStringInterfacePtr(v)
  225. case *time.Duration:
  226. if v != nil {
  227. vv, err := d.DecodeInt64()
  228. *v = time.Duration(vv)
  229. return err
  230. }
  231. case *time.Time:
  232. if v != nil {
  233. *v, err = d.DecodeTime()
  234. return err
  235. }
  236. }
  237. vv := reflect.ValueOf(v)
  238. if !vv.IsValid() {
  239. return errors.New("msgpack: Decode(nil)")
  240. }
  241. if vv.Kind() != reflect.Ptr {
  242. return fmt.Errorf("msgpack: Decode(non-pointer %T)", v)
  243. }
  244. if vv.IsNil() {
  245. return fmt.Errorf("msgpack: Decode(non-settable %T)", v)
  246. }
  247. vv = vv.Elem()
  248. if vv.Kind() == reflect.Interface {
  249. if !vv.IsNil() {
  250. vv = vv.Elem()
  251. if vv.Kind() != reflect.Ptr {
  252. return fmt.Errorf("msgpack: Decode(non-pointer %s)", vv.Type().String())
  253. }
  254. }
  255. }
  256. return d.DecodeValue(vv)
  257. }
  258. func (d *Decoder) DecodeMulti(v ...interface{}) error {
  259. for _, vv := range v {
  260. if err := d.Decode(vv); err != nil {
  261. return err
  262. }
  263. }
  264. return nil
  265. }
  266. func (d *Decoder) decodeInterfaceCond() (interface{}, error) {
  267. if d.flags&looseInterfaceDecodingFlag != 0 {
  268. return d.DecodeInterfaceLoose()
  269. }
  270. return d.DecodeInterface()
  271. }
  272. func (d *Decoder) DecodeValue(v reflect.Value) error {
  273. decode := getDecoder(v.Type())
  274. return decode(d, v)
  275. }
  276. func (d *Decoder) DecodeNil() error {
  277. c, err := d.readCode()
  278. if err != nil {
  279. return err
  280. }
  281. if c != msgpcode.Nil {
  282. return fmt.Errorf("msgpack: invalid code=%x decoding nil", c)
  283. }
  284. return nil
  285. }
  286. func (d *Decoder) decodeNilValue(v reflect.Value) error {
  287. err := d.DecodeNil()
  288. if v.IsNil() {
  289. return err
  290. }
  291. if v.Kind() == reflect.Ptr {
  292. v = v.Elem()
  293. }
  294. v.Set(reflect.Zero(v.Type()))
  295. return err
  296. }
  297. func (d *Decoder) DecodeBool() (bool, error) {
  298. c, err := d.readCode()
  299. if err != nil {
  300. return false, err
  301. }
  302. return d.bool(c)
  303. }
  304. func (d *Decoder) bool(c byte) (bool, error) {
  305. if c == msgpcode.Nil {
  306. return false, nil
  307. }
  308. if c == msgpcode.False {
  309. return false, nil
  310. }
  311. if c == msgpcode.True {
  312. return true, nil
  313. }
  314. return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c)
  315. }
  316. func (d *Decoder) DecodeDuration() (time.Duration, error) {
  317. n, err := d.DecodeInt64()
  318. if err != nil {
  319. return 0, err
  320. }
  321. return time.Duration(n), nil
  322. }
  323. // DecodeInterface decodes value into interface. It returns following types:
  324. // - nil,
  325. // - bool,
  326. // - int8, int16, int32, int64,
  327. // - uint8, uint16, uint32, uint64,
  328. // - float32 and float64,
  329. // - string,
  330. // - []byte,
  331. // - slices of any of the above,
  332. // - maps of any of the above.
  333. //
  334. // DecodeInterface should be used only when you don't know the type of value
  335. // you are decoding. For example, if you are decoding number it is better to use
  336. // DecodeInt64 for negative numbers and DecodeUint64 for positive numbers.
  337. func (d *Decoder) DecodeInterface() (interface{}, error) {
  338. c, err := d.readCode()
  339. if err != nil {
  340. return nil, err
  341. }
  342. if msgpcode.IsFixedNum(c) {
  343. return int8(c), nil
  344. }
  345. if msgpcode.IsFixedMap(c) {
  346. err = d.s.UnreadByte()
  347. if err != nil {
  348. return nil, err
  349. }
  350. return d.decodeMapDefault()
  351. }
  352. if msgpcode.IsFixedArray(c) {
  353. return d.decodeSlice(c)
  354. }
  355. if msgpcode.IsFixedString(c) {
  356. return d.string(c)
  357. }
  358. switch c {
  359. case msgpcode.Nil:
  360. return nil, nil
  361. case msgpcode.False, msgpcode.True:
  362. return d.bool(c)
  363. case msgpcode.Float:
  364. return d.float32(c)
  365. case msgpcode.Double:
  366. return d.float64(c)
  367. case msgpcode.Uint8:
  368. return d.uint8()
  369. case msgpcode.Uint16:
  370. return d.uint16()
  371. case msgpcode.Uint32:
  372. return d.uint32()
  373. case msgpcode.Uint64:
  374. return d.uint64()
  375. case msgpcode.Int8:
  376. return d.int8()
  377. case msgpcode.Int16:
  378. return d.int16()
  379. case msgpcode.Int32:
  380. return d.int32()
  381. case msgpcode.Int64:
  382. return d.int64()
  383. case msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
  384. return d.bytes(c, nil)
  385. case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32:
  386. return d.string(c)
  387. case msgpcode.Array16, msgpcode.Array32:
  388. return d.decodeSlice(c)
  389. case msgpcode.Map16, msgpcode.Map32:
  390. err = d.s.UnreadByte()
  391. if err != nil {
  392. return nil, err
  393. }
  394. return d.decodeMapDefault()
  395. case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
  396. msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
  397. return d.decodeInterfaceExt(c)
  398. }
  399. return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
  400. }
  401. // DecodeInterfaceLoose is like DecodeInterface except that:
  402. // - int8, int16, and int32 are converted to int64,
  403. // - uint8, uint16, and uint32 are converted to uint64,
  404. // - float32 is converted to float64.
  405. // - []byte is converted to string.
  406. func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) {
  407. c, err := d.readCode()
  408. if err != nil {
  409. return nil, err
  410. }
  411. if msgpcode.IsFixedNum(c) {
  412. return int64(int8(c)), nil
  413. }
  414. if msgpcode.IsFixedMap(c) {
  415. err = d.s.UnreadByte()
  416. if err != nil {
  417. return nil, err
  418. }
  419. return d.decodeMapDefault()
  420. }
  421. if msgpcode.IsFixedArray(c) {
  422. return d.decodeSlice(c)
  423. }
  424. if msgpcode.IsFixedString(c) {
  425. return d.string(c)
  426. }
  427. switch c {
  428. case msgpcode.Nil:
  429. return nil, nil
  430. case msgpcode.False, msgpcode.True:
  431. return d.bool(c)
  432. case msgpcode.Float, msgpcode.Double:
  433. return d.float64(c)
  434. case msgpcode.Uint8, msgpcode.Uint16, msgpcode.Uint32, msgpcode.Uint64:
  435. return d.uint(c)
  436. case msgpcode.Int8, msgpcode.Int16, msgpcode.Int32, msgpcode.Int64:
  437. return d.int(c)
  438. case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32,
  439. msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
  440. return d.string(c)
  441. case msgpcode.Array16, msgpcode.Array32:
  442. return d.decodeSlice(c)
  443. case msgpcode.Map16, msgpcode.Map32:
  444. err = d.s.UnreadByte()
  445. if err != nil {
  446. return nil, err
  447. }
  448. return d.decodeMapDefault()
  449. case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
  450. msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
  451. return d.decodeInterfaceExt(c)
  452. }
  453. return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
  454. }
  455. // Skip skips next value.
  456. func (d *Decoder) Skip() error {
  457. c, err := d.readCode()
  458. if err != nil {
  459. return err
  460. }
  461. if msgpcode.IsFixedNum(c) {
  462. return nil
  463. }
  464. if msgpcode.IsFixedMap(c) {
  465. return d.skipMap(c)
  466. }
  467. if msgpcode.IsFixedArray(c) {
  468. return d.skipSlice(c)
  469. }
  470. if msgpcode.IsFixedString(c) {
  471. return d.skipBytes(c)
  472. }
  473. switch c {
  474. case msgpcode.Nil, msgpcode.False, msgpcode.True:
  475. return nil
  476. case msgpcode.Uint8, msgpcode.Int8:
  477. return d.skipN(1)
  478. case msgpcode.Uint16, msgpcode.Int16:
  479. return d.skipN(2)
  480. case msgpcode.Uint32, msgpcode.Int32, msgpcode.Float:
  481. return d.skipN(4)
  482. case msgpcode.Uint64, msgpcode.Int64, msgpcode.Double:
  483. return d.skipN(8)
  484. case msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
  485. return d.skipBytes(c)
  486. case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32:
  487. return d.skipBytes(c)
  488. case msgpcode.Array16, msgpcode.Array32:
  489. return d.skipSlice(c)
  490. case msgpcode.Map16, msgpcode.Map32:
  491. return d.skipMap(c)
  492. case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
  493. msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
  494. return d.skipExt(c)
  495. }
  496. return fmt.Errorf("msgpack: unknown code %x", c)
  497. }
  498. func (d *Decoder) DecodeRaw() (RawMessage, error) {
  499. d.rec = make([]byte, 0)
  500. if err := d.Skip(); err != nil {
  501. return nil, err
  502. }
  503. msg := RawMessage(d.rec)
  504. d.rec = nil
  505. return msg, nil
  506. }
  507. // PeekCode returns the next MessagePack code without advancing the reader.
  508. // Subpackage msgpack/codes defines the list of available msgpcode.
  509. func (d *Decoder) PeekCode() (byte, error) {
  510. c, err := d.s.ReadByte()
  511. if err != nil {
  512. return 0, err
  513. }
  514. return c, d.s.UnreadByte()
  515. }
  516. // ReadFull reads exactly len(buf) bytes into the buf.
  517. func (d *Decoder) ReadFull(buf []byte) error {
  518. _, err := readN(d.r, buf, len(buf))
  519. return err
  520. }
  521. func (d *Decoder) hasNilCode() bool {
  522. code, err := d.PeekCode()
  523. return err == nil && code == msgpcode.Nil
  524. }
  525. func (d *Decoder) readCode() (byte, error) {
  526. c, err := d.s.ReadByte()
  527. if err != nil {
  528. return 0, err
  529. }
  530. if d.rec != nil {
  531. d.rec = append(d.rec, c)
  532. }
  533. return c, nil
  534. }
  535. func (d *Decoder) readFull(b []byte) error {
  536. _, err := io.ReadFull(d.r, b)
  537. if err != nil {
  538. return err
  539. }
  540. if d.rec != nil {
  541. d.rec = append(d.rec, b...)
  542. }
  543. return nil
  544. }
  545. func (d *Decoder) readN(n int) ([]byte, error) {
  546. var err error
  547. d.buf, err = readN(d.r, d.buf, n)
  548. if err != nil {
  549. return nil, err
  550. }
  551. if d.rec != nil {
  552. // TODO: read directly into d.rec?
  553. d.rec = append(d.rec, d.buf...)
  554. }
  555. return d.buf, nil
  556. }
  557. func readN(r io.Reader, b []byte, n int) ([]byte, error) {
  558. if b == nil {
  559. if n == 0 {
  560. return make([]byte, 0), nil
  561. }
  562. switch {
  563. case n < 64:
  564. b = make([]byte, 0, 64)
  565. case n <= bytesAllocLimit:
  566. b = make([]byte, 0, n)
  567. default:
  568. b = make([]byte, 0, bytesAllocLimit)
  569. }
  570. }
  571. if n <= cap(b) {
  572. b = b[:n]
  573. _, err := io.ReadFull(r, b)
  574. return b, err
  575. }
  576. b = b[:cap(b)]
  577. var pos int
  578. for {
  579. alloc := min(n-len(b), bytesAllocLimit)
  580. b = append(b, make([]byte, alloc)...)
  581. _, err := io.ReadFull(r, b[pos:])
  582. if err != nil {
  583. return b, err
  584. }
  585. if len(b) == n {
  586. break
  587. }
  588. pos = len(b)
  589. }
  590. return b, nil
  591. }
  592. func min(a, b int) int { //nolint:unparam
  593. if a <= b {
  594. return a
  595. }
  596. return b
  597. }