123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663 |
- package msgpack
- import (
- "bufio"
- "bytes"
- "errors"
- "fmt"
- "io"
- "reflect"
- "sync"
- "time"
- "github.com/vmihailenco/msgpack/v5/msgpcode"
- )
// Option bits kept in Decoder.flags. Every option owns a distinct bit, so the
// options can be switched on and off independently of each other.
const (
	looseInterfaceDecodingFlag uint32 = 1 << 0
	disallowUnknownFieldsFlag  uint32 = 1 << 1
)
// Size limits used by the decoder.
const (
	// bytesAllocLimit is the largest number of bytes readN allocates in a
	// single step; longer reads grow the buffer chunk by chunk.
	bytesAllocLimit = 1e6 // 1mb

	// sliceAllocLimit is not referenced in this file.
	// NOTE(review): presumably bounds slice preallocation - confirm at its use sites.
	sliceAllocLimit = 1e4

	// maxMapSize is not referenced in this file.
	// NOTE(review): presumably bounds map preallocation - confirm at its use sites.
	maxMapSize = 1e6
)
// bufReader is the reader shape the Decoder works with: bulk reads through
// io.Reader plus single-byte read and unread through io.ByteScanner.
type bufReader interface {
	io.ByteScanner
	io.Reader
}
- //------------------------------------------------------------------------------
- var decPool = sync.Pool{
- New: func() interface{} {
- return NewDecoder(nil)
- },
- }
- func GetDecoder() *Decoder {
- return decPool.Get().(*Decoder)
- }
- func PutDecoder(dec *Decoder) {
- dec.r = nil
- dec.s = nil
- decPool.Put(dec)
- }
- //------------------------------------------------------------------------------
- // Unmarshal decodes the MessagePack-encoded data and stores the result
- // in the value pointed to by v.
- func Unmarshal(data []byte, v interface{}) error {
- dec := GetDecoder()
- dec.Reset(bytes.NewReader(data))
- err := dec.Decode(v)
- PutDecoder(dec)
- return err
- }
// A Decoder reads and decodes MessagePack values from an input stream.
//
// Build one with NewDecoder, or borrow one with GetDecoder and attach an input
// with Reset.
type Decoder struct {
	// r is the input used for multi-byte reads.
	r io.Reader
	// s is the same input seen as a byte scanner; it serves single-byte reads
	// and lets a code be pushed back after it was read.
	s io.ByteScanner

	// buf is a scratch buffer reused by readN between reads.
	buf []byte
	rec []byte // accumulates read data if not nil

	// dict is the string dictionary installed by ResetDict and WithDict.
	dict []string
	// flags is the bit set of decoder options.
	flags uint32
	// structTag is the fallback struct tag set by SetCustomStructTag.
	structTag string
	// mapDecoder is the custom map decoding hook set by SetMapDecoder.
	mapDecoder func(*Decoder) (interface{}, error)
}
- // NewDecoder returns a new decoder that reads from r.
- //
- // The decoder introduces its own buffering and may read data from r
- // beyond the requested msgpack values. Buffering can be disabled
- // by passing a reader that implements io.ByteScanner interface.
- func NewDecoder(r io.Reader) *Decoder {
- d := new(Decoder)
- d.Reset(r)
- return d
- }
- // Reset discards any buffered data, resets all state, and switches the buffered
- // reader to read from r.
- func (d *Decoder) Reset(r io.Reader) {
- d.ResetDict(r, nil)
- }
- // ResetDict is like Reset, but also resets the dict.
- func (d *Decoder) ResetDict(r io.Reader, dict []string) {
- d.resetReader(r)
- d.flags = 0
- d.structTag = ""
- d.mapDecoder = nil
- d.dict = dict
- }
- func (d *Decoder) WithDict(dict []string, fn func(*Decoder) error) error {
- oldDict := d.dict
- d.dict = dict
- err := fn(d)
- d.dict = oldDict
- return err
- }
- func (d *Decoder) resetReader(r io.Reader) {
- if br, ok := r.(bufReader); ok {
- d.r = br
- d.s = br
- } else {
- br := bufio.NewReader(r)
- d.r = br
- d.s = br
- }
- }
- func (d *Decoder) SetMapDecoder(fn func(*Decoder) (interface{}, error)) {
- d.mapDecoder = fn
- }
- // UseLooseInterfaceDecoding causes decoder to use DecodeInterfaceLoose
- // to decode msgpack value into Go interface{}.
- func (d *Decoder) UseLooseInterfaceDecoding(on bool) {
- if on {
- d.flags |= looseInterfaceDecodingFlag
- } else {
- d.flags &= ^looseInterfaceDecodingFlag
- }
- }
- // SetCustomStructTag causes the decoder to use the supplied tag as a fallback option
- // if there is no msgpack tag.
- func (d *Decoder) SetCustomStructTag(tag string) {
- d.structTag = tag
- }
- // DisallowUnknownFields causes the Decoder to return an error when the destination
- // is a struct and the input contains object keys which do not match any
- // non-ignored, exported fields in the destination.
- func (d *Decoder) DisallowUnknownFields(on bool) {
- if on {
- d.flags |= disallowUnknownFieldsFlag
- } else {
- d.flags &= ^disallowUnknownFieldsFlag
- }
- }
- // UseInternedStrings enables support for decoding interned strings.
- func (d *Decoder) UseInternedStrings(on bool) {
- if on {
- d.flags |= useInternedStringsFlag
- } else {
- d.flags &= ^useInternedStringsFlag
- }
- }
- // Buffered returns a reader of the data remaining in the Decoder's buffer.
- // The reader is valid until the next call to Decode.
- func (d *Decoder) Buffered() io.Reader {
- return d.r
- }
- //nolint:gocyclo
- func (d *Decoder) Decode(v interface{}) error {
- var err error
- switch v := v.(type) {
- case *string:
- if v != nil {
- *v, err = d.DecodeString()
- return err
- }
- case *[]byte:
- if v != nil {
- return d.decodeBytesPtr(v)
- }
- case *int:
- if v != nil {
- *v, err = d.DecodeInt()
- return err
- }
- case *int8:
- if v != nil {
- *v, err = d.DecodeInt8()
- return err
- }
- case *int16:
- if v != nil {
- *v, err = d.DecodeInt16()
- return err
- }
- case *int32:
- if v != nil {
- *v, err = d.DecodeInt32()
- return err
- }
- case *int64:
- if v != nil {
- *v, err = d.DecodeInt64()
- return err
- }
- case *uint:
- if v != nil {
- *v, err = d.DecodeUint()
- return err
- }
- case *uint8:
- if v != nil {
- *v, err = d.DecodeUint8()
- return err
- }
- case *uint16:
- if v != nil {
- *v, err = d.DecodeUint16()
- return err
- }
- case *uint32:
- if v != nil {
- *v, err = d.DecodeUint32()
- return err
- }
- case *uint64:
- if v != nil {
- *v, err = d.DecodeUint64()
- return err
- }
- case *bool:
- if v != nil {
- *v, err = d.DecodeBool()
- return err
- }
- case *float32:
- if v != nil {
- *v, err = d.DecodeFloat32()
- return err
- }
- case *float64:
- if v != nil {
- *v, err = d.DecodeFloat64()
- return err
- }
- case *[]string:
- return d.decodeStringSlicePtr(v)
- case *map[string]string:
- return d.decodeMapStringStringPtr(v)
- case *map[string]interface{}:
- return d.decodeMapStringInterfacePtr(v)
- case *time.Duration:
- if v != nil {
- vv, err := d.DecodeInt64()
- *v = time.Duration(vv)
- return err
- }
- case *time.Time:
- if v != nil {
- *v, err = d.DecodeTime()
- return err
- }
- }
- vv := reflect.ValueOf(v)
- if !vv.IsValid() {
- return errors.New("msgpack: Decode(nil)")
- }
- if vv.Kind() != reflect.Ptr {
- return fmt.Errorf("msgpack: Decode(non-pointer %T)", v)
- }
- if vv.IsNil() {
- return fmt.Errorf("msgpack: Decode(non-settable %T)", v)
- }
- vv = vv.Elem()
- if vv.Kind() == reflect.Interface {
- if !vv.IsNil() {
- vv = vv.Elem()
- if vv.Kind() != reflect.Ptr {
- return fmt.Errorf("msgpack: Decode(non-pointer %s)", vv.Type().String())
- }
- }
- }
- return d.DecodeValue(vv)
- }
- func (d *Decoder) DecodeMulti(v ...interface{}) error {
- for _, vv := range v {
- if err := d.Decode(vv); err != nil {
- return err
- }
- }
- return nil
- }
- func (d *Decoder) decodeInterfaceCond() (interface{}, error) {
- if d.flags&looseInterfaceDecodingFlag != 0 {
- return d.DecodeInterfaceLoose()
- }
- return d.DecodeInterface()
- }
- func (d *Decoder) DecodeValue(v reflect.Value) error {
- decode := getDecoder(v.Type())
- return decode(d, v)
- }
- func (d *Decoder) DecodeNil() error {
- c, err := d.readCode()
- if err != nil {
- return err
- }
- if c != msgpcode.Nil {
- return fmt.Errorf("msgpack: invalid code=%x decoding nil", c)
- }
- return nil
- }
- func (d *Decoder) decodeNilValue(v reflect.Value) error {
- err := d.DecodeNil()
- if v.IsNil() {
- return err
- }
- if v.Kind() == reflect.Ptr {
- v = v.Elem()
- }
- v.Set(reflect.Zero(v.Type()))
- return err
- }
- func (d *Decoder) DecodeBool() (bool, error) {
- c, err := d.readCode()
- if err != nil {
- return false, err
- }
- return d.bool(c)
- }
- func (d *Decoder) bool(c byte) (bool, error) {
- if c == msgpcode.Nil {
- return false, nil
- }
- if c == msgpcode.False {
- return false, nil
- }
- if c == msgpcode.True {
- return true, nil
- }
- return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c)
- }
- func (d *Decoder) DecodeDuration() (time.Duration, error) {
- n, err := d.DecodeInt64()
- if err != nil {
- return 0, err
- }
- return time.Duration(n), nil
- }
- // DecodeInterface decodes value into interface. It returns following types:
- // - nil,
- // - bool,
- // - int8, int16, int32, int64,
- // - uint8, uint16, uint32, uint64,
- // - float32 and float64,
- // - string,
- // - []byte,
- // - slices of any of the above,
- // - maps of any of the above.
- //
- // DecodeInterface should be used only when you don't know the type of value
- // you are decoding. For example, if you are decoding number it is better to use
- // DecodeInt64 for negative numbers and DecodeUint64 for positive numbers.
- func (d *Decoder) DecodeInterface() (interface{}, error) {
- c, err := d.readCode()
- if err != nil {
- return nil, err
- }
- if msgpcode.IsFixedNum(c) {
- return int8(c), nil
- }
- if msgpcode.IsFixedMap(c) {
- err = d.s.UnreadByte()
- if err != nil {
- return nil, err
- }
- return d.decodeMapDefault()
- }
- if msgpcode.IsFixedArray(c) {
- return d.decodeSlice(c)
- }
- if msgpcode.IsFixedString(c) {
- return d.string(c)
- }
- switch c {
- case msgpcode.Nil:
- return nil, nil
- case msgpcode.False, msgpcode.True:
- return d.bool(c)
- case msgpcode.Float:
- return d.float32(c)
- case msgpcode.Double:
- return d.float64(c)
- case msgpcode.Uint8:
- return d.uint8()
- case msgpcode.Uint16:
- return d.uint16()
- case msgpcode.Uint32:
- return d.uint32()
- case msgpcode.Uint64:
- return d.uint64()
- case msgpcode.Int8:
- return d.int8()
- case msgpcode.Int16:
- return d.int16()
- case msgpcode.Int32:
- return d.int32()
- case msgpcode.Int64:
- return d.int64()
- case msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
- return d.bytes(c, nil)
- case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32:
- return d.string(c)
- case msgpcode.Array16, msgpcode.Array32:
- return d.decodeSlice(c)
- case msgpcode.Map16, msgpcode.Map32:
- err = d.s.UnreadByte()
- if err != nil {
- return nil, err
- }
- return d.decodeMapDefault()
- case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
- msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
- return d.decodeInterfaceExt(c)
- }
- return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
- }
- // DecodeInterfaceLoose is like DecodeInterface except that:
- // - int8, int16, and int32 are converted to int64,
- // - uint8, uint16, and uint32 are converted to uint64,
- // - float32 is converted to float64.
- // - []byte is converted to string.
- func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) {
- c, err := d.readCode()
- if err != nil {
- return nil, err
- }
- if msgpcode.IsFixedNum(c) {
- return int64(int8(c)), nil
- }
- if msgpcode.IsFixedMap(c) {
- err = d.s.UnreadByte()
- if err != nil {
- return nil, err
- }
- return d.decodeMapDefault()
- }
- if msgpcode.IsFixedArray(c) {
- return d.decodeSlice(c)
- }
- if msgpcode.IsFixedString(c) {
- return d.string(c)
- }
- switch c {
- case msgpcode.Nil:
- return nil, nil
- case msgpcode.False, msgpcode.True:
- return d.bool(c)
- case msgpcode.Float, msgpcode.Double:
- return d.float64(c)
- case msgpcode.Uint8, msgpcode.Uint16, msgpcode.Uint32, msgpcode.Uint64:
- return d.uint(c)
- case msgpcode.Int8, msgpcode.Int16, msgpcode.Int32, msgpcode.Int64:
- return d.int(c)
- case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32,
- msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
- return d.string(c)
- case msgpcode.Array16, msgpcode.Array32:
- return d.decodeSlice(c)
- case msgpcode.Map16, msgpcode.Map32:
- err = d.s.UnreadByte()
- if err != nil {
- return nil, err
- }
- return d.decodeMapDefault()
- case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
- msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
- return d.decodeInterfaceExt(c)
- }
- return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
- }
- // Skip skips next value.
- func (d *Decoder) Skip() error {
- c, err := d.readCode()
- if err != nil {
- return err
- }
- if msgpcode.IsFixedNum(c) {
- return nil
- }
- if msgpcode.IsFixedMap(c) {
- return d.skipMap(c)
- }
- if msgpcode.IsFixedArray(c) {
- return d.skipSlice(c)
- }
- if msgpcode.IsFixedString(c) {
- return d.skipBytes(c)
- }
- switch c {
- case msgpcode.Nil, msgpcode.False, msgpcode.True:
- return nil
- case msgpcode.Uint8, msgpcode.Int8:
- return d.skipN(1)
- case msgpcode.Uint16, msgpcode.Int16:
- return d.skipN(2)
- case msgpcode.Uint32, msgpcode.Int32, msgpcode.Float:
- return d.skipN(4)
- case msgpcode.Uint64, msgpcode.Int64, msgpcode.Double:
- return d.skipN(8)
- case msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
- return d.skipBytes(c)
- case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32:
- return d.skipBytes(c)
- case msgpcode.Array16, msgpcode.Array32:
- return d.skipSlice(c)
- case msgpcode.Map16, msgpcode.Map32:
- return d.skipMap(c)
- case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
- msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
- return d.skipExt(c)
- }
- return fmt.Errorf("msgpack: unknown code %x", c)
- }
- func (d *Decoder) DecodeRaw() (RawMessage, error) {
- d.rec = make([]byte, 0)
- if err := d.Skip(); err != nil {
- return nil, err
- }
- msg := RawMessage(d.rec)
- d.rec = nil
- return msg, nil
- }
- // PeekCode returns the next MessagePack code without advancing the reader.
- // Subpackage msgpack/codes defines the list of available msgpcode.
- func (d *Decoder) PeekCode() (byte, error) {
- c, err := d.s.ReadByte()
- if err != nil {
- return 0, err
- }
- return c, d.s.UnreadByte()
- }
- // ReadFull reads exactly len(buf) bytes into the buf.
- func (d *Decoder) ReadFull(buf []byte) error {
- _, err := readN(d.r, buf, len(buf))
- return err
- }
- func (d *Decoder) hasNilCode() bool {
- code, err := d.PeekCode()
- return err == nil && code == msgpcode.Nil
- }
- func (d *Decoder) readCode() (byte, error) {
- c, err := d.s.ReadByte()
- if err != nil {
- return 0, err
- }
- if d.rec != nil {
- d.rec = append(d.rec, c)
- }
- return c, nil
- }
- func (d *Decoder) readFull(b []byte) error {
- _, err := io.ReadFull(d.r, b)
- if err != nil {
- return err
- }
- if d.rec != nil {
- d.rec = append(d.rec, b...)
- }
- return nil
- }
- func (d *Decoder) readN(n int) ([]byte, error) {
- var err error
- d.buf, err = readN(d.r, d.buf, n)
- if err != nil {
- return nil, err
- }
- if d.rec != nil {
- // TODO: read directly into d.rec?
- d.rec = append(d.rec, d.buf...)
- }
- return d.buf, nil
- }
- func readN(r io.Reader, b []byte, n int) ([]byte, error) {
- if b == nil {
- if n == 0 {
- return make([]byte, 0), nil
- }
- switch {
- case n < 64:
- b = make([]byte, 0, 64)
- case n <= bytesAllocLimit:
- b = make([]byte, 0, n)
- default:
- b = make([]byte, 0, bytesAllocLimit)
- }
- }
- if n <= cap(b) {
- b = b[:n]
- _, err := io.ReadFull(r, b)
- return b, err
- }
- b = b[:cap(b)]
- var pos int
- for {
- alloc := min(n-len(b), bytesAllocLimit)
- b = append(b, make([]byte, alloc)...)
- _, err := io.ReadFull(r, b[pos:])
- if err != nil {
- return b, err
- }
- if len(b) == n {
- break
- }
- pos = len(b)
- }
- return b, nil
- }
// min returns the smaller of a and b.
func min(a, b int) int { //nolint:unparam
	if b < a {
		return b
	}
	return a
}
|