framedec.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "bytes"
  7. "encoding/hex"
  8. "errors"
  9. "hash"
  10. "io"
  11. "sync"
  12. "github.com/klauspost/compress/zstd/internal/xxhash"
  13. )
// frameDec holds the state needed to decode a single zstd frame: the values
// parsed from the frame header, the shared block history and the plumbing
// used to hand blocks to the asynchronous decoder.
type frameDec struct {
	// o holds the decoder options this frame decoder was created with.
	o decoderOptions
	// crc accumulates the checksum of decoded output.
	// It is created lazily by reset when a frame declares a checksum.
	crc hash.Hash64
	// offset is only advanced by startDecoder when debugDecoder is enabled.
	offset int64
	// WindowSize is the window size resolved from the frame header.
	WindowSize uint64
	// In order queue of blocks being decoded.
	decoding chan *blockDec
	// Frame history passed between blocks
	history history
	// rawInput is the input the frame header was read from; reset stores it
	// so block decoding and checkCRC can continue reading from it.
	rawInput byteBuffer
	// Byte buffer that can be reused for small input blocks.
	bBuf byteBuf
	// FrameContentSize is the Frame_Content_Size header field,
	// or 0 when the header does not carry one.
	FrameContentSize uint64
	// frameDone is signalled with Done when startDecoder returns.
	frameDone sync.WaitGroup
	// DictionaryID is nil when the header carries no dictionary ID or the ID is 0.
	DictionaryID *uint32
	// HasCheckSum reports whether the frame header announces a content checksum.
	HasCheckSum bool
	// SingleSegment reports whether the Single_Segment flag is set in the header.
	SingleSegment bool
	// asyncRunning indicates whether the async routine processes input on 'decoding'.
	// asyncRunningMu must be held when reading or writing asyncRunning.
	asyncRunningMu sync.Mutex
	asyncRunning bool
}
const (
	// MinWindowSize is the minimum Window Size, which is 1 KB.
	// reset rejects frames whose resolved window size is smaller than this.
	MinWindowSize = 1 << 10
	// MaxWindowSize is the maximum encoder window size
	// and the default decoder maximum window size (1<<29 bytes, 512 MB).
	MaxWindowSize = 1 << 29
)
var (
	// frameMagic is the four-byte signature a regular frame must start with.
	frameMagic = []byte{0x28, 0xb5, 0x2f, 0xfd}
	// skippableFrameMagic holds bytes 1-3 of a skippable frame signature.
	// The first signature byte is matched separately by its high nibble (0x5?).
	skippableFrameMagic = []byte{0x2a, 0x4d, 0x18}
)
  46. func newFrameDec(o decoderOptions) *frameDec {
  47. if o.maxWindowSize > o.maxDecodedSize {
  48. o.maxWindowSize = o.maxDecodedSize
  49. }
  50. d := frameDec{
  51. o: o,
  52. }
  53. return &d
  54. }
  55. // reset will read the frame header and prepare for block decoding.
  56. // If nothing can be read from the input, io.EOF will be returned.
  57. // Any other error indicated that the stream contained data, but
  58. // there was a problem.
  59. func (d *frameDec) reset(br byteBuffer) error {
  60. d.HasCheckSum = false
  61. d.WindowSize = 0
  62. var signature [4]byte
  63. for {
  64. var err error
  65. // Check if we can read more...
  66. b, err := br.readSmall(1)
  67. switch err {
  68. case io.EOF, io.ErrUnexpectedEOF:
  69. return io.EOF
  70. default:
  71. return err
  72. case nil:
  73. signature[0] = b[0]
  74. }
  75. // Read the rest, don't allow io.ErrUnexpectedEOF
  76. b, err = br.readSmall(3)
  77. switch err {
  78. case io.EOF:
  79. return io.EOF
  80. default:
  81. return err
  82. case nil:
  83. copy(signature[1:], b)
  84. }
  85. if !bytes.Equal(signature[1:4], skippableFrameMagic) || signature[0]&0xf0 != 0x50 {
  86. if debugDecoder {
  87. println("Not skippable", hex.EncodeToString(signature[:]), hex.EncodeToString(skippableFrameMagic))
  88. }
  89. // Break if not skippable frame.
  90. break
  91. }
  92. // Read size to skip
  93. b, err = br.readSmall(4)
  94. if err != nil {
  95. if debugDecoder {
  96. println("Reading Frame Size", err)
  97. }
  98. return err
  99. }
  100. n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  101. println("Skipping frame with", n, "bytes.")
  102. err = br.skipN(int(n))
  103. if err != nil {
  104. if debugDecoder {
  105. println("Reading discarded frame", err)
  106. }
  107. return err
  108. }
  109. }
  110. if !bytes.Equal(signature[:], frameMagic) {
  111. if debugDecoder {
  112. println("Got magic numbers: ", signature, "want:", frameMagic)
  113. }
  114. return ErrMagicMismatch
  115. }
  116. // Read Frame_Header_Descriptor
  117. fhd, err := br.readByte()
  118. if err != nil {
  119. if debugDecoder {
  120. println("Reading Frame_Header_Descriptor", err)
  121. }
  122. return err
  123. }
  124. d.SingleSegment = fhd&(1<<5) != 0
  125. if fhd&(1<<3) != 0 {
  126. return errors.New("reserved bit set on frame header")
  127. }
  128. // Read Window_Descriptor
  129. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
  130. d.WindowSize = 0
  131. if !d.SingleSegment {
  132. wd, err := br.readByte()
  133. if err != nil {
  134. if debugDecoder {
  135. println("Reading Window_Descriptor", err)
  136. }
  137. return err
  138. }
  139. printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
  140. windowLog := 10 + (wd >> 3)
  141. windowBase := uint64(1) << windowLog
  142. windowAdd := (windowBase / 8) * uint64(wd&0x7)
  143. d.WindowSize = windowBase + windowAdd
  144. }
  145. // Read Dictionary_ID
  146. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
  147. d.DictionaryID = nil
  148. if size := fhd & 3; size != 0 {
  149. if size == 3 {
  150. size = 4
  151. }
  152. b, err := br.readSmall(int(size))
  153. if err != nil {
  154. println("Reading Dictionary_ID", err)
  155. return err
  156. }
  157. var id uint32
  158. switch size {
  159. case 1:
  160. id = uint32(b[0])
  161. case 2:
  162. id = uint32(b[0]) | (uint32(b[1]) << 8)
  163. case 4:
  164. id = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  165. }
  166. if debugDecoder {
  167. println("Dict size", size, "ID:", id)
  168. }
  169. if id > 0 {
  170. // ID 0 means "sorry, no dictionary anyway".
  171. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format
  172. d.DictionaryID = &id
  173. }
  174. }
  175. // Read Frame_Content_Size
  176. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
  177. var fcsSize int
  178. v := fhd >> 6
  179. switch v {
  180. case 0:
  181. if d.SingleSegment {
  182. fcsSize = 1
  183. }
  184. default:
  185. fcsSize = 1 << v
  186. }
  187. d.FrameContentSize = 0
  188. if fcsSize > 0 {
  189. b, err := br.readSmall(fcsSize)
  190. if err != nil {
  191. println("Reading Frame content", err)
  192. return err
  193. }
  194. switch fcsSize {
  195. case 1:
  196. d.FrameContentSize = uint64(b[0])
  197. case 2:
  198. // When FCS_Field_Size is 2, the offset of 256 is added.
  199. d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
  200. case 4:
  201. d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
  202. case 8:
  203. d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  204. d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
  205. d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
  206. }
  207. if debugDecoder {
  208. println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize)
  209. }
  210. }
  211. // Move this to shared.
  212. d.HasCheckSum = fhd&(1<<2) != 0
  213. if d.HasCheckSum {
  214. if d.crc == nil {
  215. d.crc = xxhash.New()
  216. }
  217. d.crc.Reset()
  218. }
  219. if d.WindowSize == 0 && d.SingleSegment {
  220. // We may not need window in this case.
  221. d.WindowSize = d.FrameContentSize
  222. if d.WindowSize < MinWindowSize {
  223. d.WindowSize = MinWindowSize
  224. }
  225. }
  226. if d.WindowSize > uint64(d.o.maxWindowSize) {
  227. if debugDecoder {
  228. printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize)
  229. }
  230. return ErrWindowSizeExceeded
  231. }
  232. // The minimum Window_Size is 1 KB.
  233. if d.WindowSize < MinWindowSize {
  234. if debugDecoder {
  235. println("got window size: ", d.WindowSize)
  236. }
  237. return ErrWindowSizeTooSmall
  238. }
  239. d.history.windowSize = int(d.WindowSize)
  240. if d.o.lowMem && d.history.windowSize < maxBlockSize {
  241. d.history.maxSize = d.history.windowSize * 2
  242. } else {
  243. d.history.maxSize = d.history.windowSize + maxBlockSize
  244. }
  245. // history contains input - maybe we do something
  246. d.rawInput = br
  247. return nil
  248. }
  249. // next will start decoding the next block from stream.
  250. func (d *frameDec) next(block *blockDec) error {
  251. if debugDecoder {
  252. printf("decoding new block %p:%p", block, block.data)
  253. }
  254. err := block.reset(d.rawInput, d.WindowSize)
  255. if err != nil {
  256. println("block error:", err)
  257. // Signal the frame decoder we have a problem.
  258. d.sendErr(block, err)
  259. return err
  260. }
  261. block.input <- struct{}{}
  262. if debugDecoder {
  263. println("next block:", block)
  264. }
  265. d.asyncRunningMu.Lock()
  266. defer d.asyncRunningMu.Unlock()
  267. if !d.asyncRunning {
  268. return nil
  269. }
  270. if block.Last {
  271. // We indicate the frame is done by sending io.EOF
  272. d.decoding <- block
  273. return io.EOF
  274. }
  275. d.decoding <- block
  276. return nil
  277. }
  278. // sendEOF will queue an error block on the frame.
  279. // This will cause the frame decoder to return when it encounters the block.
  280. // Returns true if the decoder was added.
  281. func (d *frameDec) sendErr(block *blockDec, err error) bool {
  282. d.asyncRunningMu.Lock()
  283. defer d.asyncRunningMu.Unlock()
  284. if !d.asyncRunning {
  285. return false
  286. }
  287. println("sending error", err.Error())
  288. block.sendErr(err)
  289. d.decoding <- block
  290. return true
  291. }
  292. // checkCRC will check the checksum if the frame has one.
  293. // Will return ErrCRCMismatch if crc check failed, otherwise nil.
  294. func (d *frameDec) checkCRC() error {
  295. if !d.HasCheckSum {
  296. return nil
  297. }
  298. var tmp [4]byte
  299. got := d.crc.Sum64()
  300. // Flip to match file order.
  301. tmp[0] = byte(got >> 0)
  302. tmp[1] = byte(got >> 8)
  303. tmp[2] = byte(got >> 16)
  304. tmp[3] = byte(got >> 24)
  305. // We can overwrite upper tmp now
  306. want, err := d.rawInput.readSmall(4)
  307. if err != nil {
  308. println("CRC missing?", err)
  309. return err
  310. }
  311. if !bytes.Equal(tmp[:], want) {
  312. if debugDecoder {
  313. println("CRC Check Failed:", tmp[:], "!=", want)
  314. }
  315. return ErrCRCMismatch
  316. }
  317. if debugDecoder {
  318. println("CRC ok", tmp[:])
  319. }
  320. return nil
  321. }
  322. func (d *frameDec) initAsync() {
  323. if !d.o.lowMem && !d.SingleSegment {
  324. // set max extra size history to 2MB.
  325. d.history.maxSize = d.history.windowSize + maxBlockSize
  326. }
  327. // re-alloc if more than one extra block size.
  328. if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize {
  329. d.history.b = make([]byte, 0, d.history.maxSize)
  330. }
  331. if cap(d.history.b) < d.history.maxSize {
  332. d.history.b = make([]byte, 0, d.history.maxSize)
  333. }
  334. if cap(d.decoding) < d.o.concurrent {
  335. d.decoding = make(chan *blockDec, d.o.concurrent)
  336. }
  337. if debugDecoder {
  338. h := d.history
  339. printf("history init. len: %d, cap: %d", len(h.b), cap(h.b))
  340. }
  341. d.asyncRunningMu.Lock()
  342. d.asyncRunning = true
  343. d.asyncRunningMu.Unlock()
  344. }
// startDecoder will start decoding blocks and write them to the writer.
// The decoder will stop as soon as an error occurs or at end of frame.
//
// Blocks are taken from d.decoding in order. Each block is handed the frame
// history, its result is awaited, and the result is forwarded on output.
// When the function returns, any blocks still queued are drained and
// d.frameDone.Done is called.
func (d *frameDec) startDecoder(output chan decodeOutput) {
	// written counts the decoded bytes of this frame so far.
	written := int64(0)
	defer func() {
		// From here on next and sendErr will no longer queue blocks.
		d.asyncRunningMu.Lock()
		d.asyncRunning = false
		d.asyncRunningMu.Unlock()
		// Drain the currently decoding.
		// The history is flagged as errored before it is handed to the
		// remaining blocks; their results are still forwarded on output.
		d.history.error = true
	flushdone:
		for {
			select {
			case b := <-d.decoding:
				b.history <- &d.history
				output <- <-b.result
			default:
				// Queue is empty; stop draining.
				break flushdone
			}
		}
		println("frame decoder done, signalling done")
		d.frameDone.Done()
	}()
	// Get decoder for first block.
	block := <-d.decoding
	block.history <- &d.history
	for {
		var next *blockDec
		// Get result
		r := <-block.result
		if r.err != nil {
			println("Result contained error", r.err)
			output <- r
			return
		}
		if debugDecoder {
			println("got result, from ", d.offset, "to", d.offset+int64(len(r.b)))
			d.offset += int64(len(r.b))
		}
		if !block.Last {
			// Send history to next block
			// This is a non-blocking attempt; if no block is queued yet,
			// it is fetched after the current result has been sent.
			select {
			case next = <-d.decoding:
				if debugDecoder {
					println("Sending ", len(d.history.b), "bytes as history")
				}
				next.history <- &d.history
			default:
				// Wait until we have sent the block, so
				// other decoders can potentially get the decoder.
				next = nil
			}
		}
		// Add checksum, async to decoding.
		if d.HasCheckSum {
			n, err := d.crc.Write(r.b)
			if err != nil {
				r.err = err
				// A short write replaces the reported error.
				if n != len(r.b) {
					r.err = io.ErrShortWrite
				}
				output <- r
				return
			}
		}
		written += int64(len(r.b))
		// A single segment frame must not produce more than its declared content size.
		if d.SingleSegment && uint64(written) > d.FrameContentSize {
			println("runDecoder: single segment and", uint64(written), ">", d.FrameContentSize)
			r.err = ErrFrameSizeExceeded
			output <- r
			return
		}
		if block.Last {
			// The final result carries the outcome of the checksum check.
			r.err = d.checkCRC()
			output <- r
			return
		}
		output <- r
		if next == nil {
			// There was no decoder available, we wait for one now that we have sent to the writer.
			if debugDecoder {
				println("Sending ", len(d.history.b), " bytes as history")
			}
			next = <-d.decoding
			next.history <- &d.history
		}
		block = next
	}
}
  436. // runDecoder will create a sync decoder that will decode a block of data.
  437. func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
  438. saved := d.history.b
  439. // We use the history for output to avoid copying it.
  440. d.history.b = dst
  441. // Store input length, so we only check new data.
  442. crcStart := len(dst)
  443. var err error
  444. for {
  445. err = dec.reset(d.rawInput, d.WindowSize)
  446. if err != nil {
  447. break
  448. }
  449. if debugDecoder {
  450. println("next block:", dec)
  451. }
  452. err = dec.decodeBuf(&d.history)
  453. if err != nil || dec.Last {
  454. break
  455. }
  456. if uint64(len(d.history.b)) > d.o.maxDecodedSize {
  457. err = ErrDecoderSizeExceeded
  458. break
  459. }
  460. if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize {
  461. println("runDecoder: single segment and", uint64(len(d.history.b)), ">", d.o.maxDecodedSize)
  462. err = ErrFrameSizeExceeded
  463. break
  464. }
  465. }
  466. dst = d.history.b
  467. if err == nil {
  468. if d.HasCheckSum {
  469. var n int
  470. n, err = d.crc.Write(dst[crcStart:])
  471. if err == nil {
  472. if n != len(dst)-crcStart {
  473. err = io.ErrShortWrite
  474. } else {
  475. err = d.checkCRC()
  476. }
  477. }
  478. }
  479. }
  480. d.history.b = saved
  481. return dst, err
  482. }