encoder.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "crypto/rand"
  7. "fmt"
  8. "io"
  9. rdebug "runtime/debug"
  10. "sync"
  11. "github.com/klauspost/compress/zstd/internal/xxhash"
  12. )
// Encoder provides encoding to Zstandard.
// An Encoder can be used either for compressing a stream via the
// io.WriteCloser interface it implements, or for running multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	// o holds the encoder options (applied by NewWriter).
	o encoderOptions
	// encoders is the pool of block encoders borrowed by EncodeAll.
	// It is created by initialize, which runs once via init.
	encoders chan encoder
	// state is the streaming state used by Write, ReadFrom, Flush and Close.
	state encoderState
	// init guards the one-time call to initialize.
	init sync.Once
}
  25. type encoder interface {
  26. Encode(blk *blockEnc, src []byte)
  27. EncodeNoHist(blk *blockEnc, src []byte)
  28. Block() *blockEnc
  29. CRC() *xxhash.Digest
  30. AppendCRC([]byte) []byte
  31. WindowSize(size int64) int32
  32. UseBlock(*blockEnc)
  33. Reset(d *dict, singleBlock bool)
  34. }
  35. type encoderState struct {
  36. w io.Writer
  37. filling []byte
  38. current []byte
  39. previous []byte
  40. encoder encoder
  41. writing *blockEnc
  42. err error
  43. writeErr error
  44. nWritten int64
  45. nInput int64
  46. frameContentSize int64
  47. headerWritten bool
  48. eofWritten bool
  49. fullFrameWritten bool
  50. // This waitgroup indicates an encode is running.
  51. wg sync.WaitGroup
  52. // This waitgroup indicates we have a block encoding/writing.
  53. wWg sync.WaitGroup
  54. }
  55. // NewWriter will create a new Zstandard encoder.
  56. // If the encoder will be used for encoding blocks a nil writer can be used.
  57. func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
  58. initPredefined()
  59. var e Encoder
  60. e.o.setDefault()
  61. for _, o := range opts {
  62. err := o(&e.o)
  63. if err != nil {
  64. return nil, err
  65. }
  66. }
  67. if w != nil {
  68. e.Reset(w)
  69. }
  70. return &e, nil
  71. }
  72. func (e *Encoder) initialize() {
  73. if e.o.concurrent == 0 {
  74. e.o.setDefault()
  75. }
  76. e.encoders = make(chan encoder, e.o.concurrent)
  77. for i := 0; i < e.o.concurrent; i++ {
  78. enc := e.o.encoder()
  79. e.encoders <- enc
  80. }
  81. }
  82. // Reset will re-initialize the writer and new writes will encode to the supplied writer
  83. // as a new, independent stream.
  84. func (e *Encoder) Reset(w io.Writer) {
  85. s := &e.state
  86. s.wg.Wait()
  87. s.wWg.Wait()
  88. if cap(s.filling) == 0 {
  89. s.filling = make([]byte, 0, e.o.blockSize)
  90. }
  91. if cap(s.current) == 0 {
  92. s.current = make([]byte, 0, e.o.blockSize)
  93. }
  94. if cap(s.previous) == 0 {
  95. s.previous = make([]byte, 0, e.o.blockSize)
  96. }
  97. if s.encoder == nil {
  98. s.encoder = e.o.encoder()
  99. }
  100. if s.writing == nil {
  101. s.writing = &blockEnc{lowMem: e.o.lowMem}
  102. s.writing.init()
  103. }
  104. s.writing.initNewEncode()
  105. s.filling = s.filling[:0]
  106. s.current = s.current[:0]
  107. s.previous = s.previous[:0]
  108. s.encoder.Reset(e.o.dict, false)
  109. s.headerWritten = false
  110. s.eofWritten = false
  111. s.fullFrameWritten = false
  112. s.w = w
  113. s.err = nil
  114. s.nWritten = 0
  115. s.nInput = 0
  116. s.writeErr = nil
  117. s.frameContentSize = 0
  118. }
  119. // ResetContentSize will reset and set a content size for the next stream.
  120. // If the bytes written does not match the size given an error will be returned
  121. // when calling Close().
  122. // This is removed when Reset is called.
  123. // Sizes <= 0 results in no content size set.
  124. func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
  125. e.Reset(w)
  126. if size >= 0 {
  127. e.state.frameContentSize = size
  128. }
  129. }
  130. // Write data to the encoder.
  131. // Input data will be buffered and as the buffer fills up
  132. // content will be compressed and written to the output.
  133. // When done writing, use Close to flush the remaining output
  134. // and write CRC if requested.
  135. func (e *Encoder) Write(p []byte) (n int, err error) {
  136. s := &e.state
  137. for len(p) > 0 {
  138. if len(p)+len(s.filling) < e.o.blockSize {
  139. if e.o.crc {
  140. _, _ = s.encoder.CRC().Write(p)
  141. }
  142. s.filling = append(s.filling, p...)
  143. return n + len(p), nil
  144. }
  145. add := p
  146. if len(p)+len(s.filling) > e.o.blockSize {
  147. add = add[:e.o.blockSize-len(s.filling)]
  148. }
  149. if e.o.crc {
  150. _, _ = s.encoder.CRC().Write(add)
  151. }
  152. s.filling = append(s.filling, add...)
  153. p = p[len(add):]
  154. n += len(add)
  155. if len(s.filling) < e.o.blockSize {
  156. return n, nil
  157. }
  158. err := e.nextBlock(false)
  159. if err != nil {
  160. return n, err
  161. }
  162. if debugAsserts && len(s.filling) > 0 {
  163. panic(len(s.filling))
  164. }
  165. }
  166. return n, nil
  167. }
  168. // nextBlock will synchronize and start compressing input in e.state.filling.
  169. // If an error has occurred during encoding it will be returned.
  170. func (e *Encoder) nextBlock(final bool) error {
  171. s := &e.state
  172. // Wait for current block.
  173. s.wg.Wait()
  174. if s.err != nil {
  175. return s.err
  176. }
  177. if len(s.filling) > e.o.blockSize {
  178. return fmt.Errorf("block > maxStoreBlockSize")
  179. }
  180. if !s.headerWritten {
  181. // If we have a single block encode, do a sync compression.
  182. if final && len(s.filling) == 0 && !e.o.fullZero {
  183. s.headerWritten = true
  184. s.fullFrameWritten = true
  185. s.eofWritten = true
  186. return nil
  187. }
  188. if final && len(s.filling) > 0 {
  189. s.current = e.EncodeAll(s.filling, s.current[:0])
  190. var n2 int
  191. n2, s.err = s.w.Write(s.current)
  192. if s.err != nil {
  193. return s.err
  194. }
  195. s.nWritten += int64(n2)
  196. s.nInput += int64(len(s.filling))
  197. s.current = s.current[:0]
  198. s.filling = s.filling[:0]
  199. s.headerWritten = true
  200. s.fullFrameWritten = true
  201. s.eofWritten = true
  202. return nil
  203. }
  204. var tmp [maxHeaderSize]byte
  205. fh := frameHeader{
  206. ContentSize: uint64(s.frameContentSize),
  207. WindowSize: uint32(s.encoder.WindowSize(s.frameContentSize)),
  208. SingleSegment: false,
  209. Checksum: e.o.crc,
  210. DictID: e.o.dict.ID(),
  211. }
  212. dst, err := fh.appendTo(tmp[:0])
  213. if err != nil {
  214. return err
  215. }
  216. s.headerWritten = true
  217. s.wWg.Wait()
  218. var n2 int
  219. n2, s.err = s.w.Write(dst)
  220. if s.err != nil {
  221. return s.err
  222. }
  223. s.nWritten += int64(n2)
  224. }
  225. if s.eofWritten {
  226. // Ensure we only write it once.
  227. final = false
  228. }
  229. if len(s.filling) == 0 {
  230. // Final block, but no data.
  231. if final {
  232. enc := s.encoder
  233. blk := enc.Block()
  234. blk.reset(nil)
  235. blk.last = true
  236. blk.encodeRaw(nil)
  237. s.wWg.Wait()
  238. _, s.err = s.w.Write(blk.output)
  239. s.nWritten += int64(len(blk.output))
  240. s.eofWritten = true
  241. }
  242. return s.err
  243. }
  244. // Move blocks forward.
  245. s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
  246. s.nInput += int64(len(s.current))
  247. s.wg.Add(1)
  248. go func(src []byte) {
  249. if debugEncoder {
  250. println("Adding block,", len(src), "bytes, final:", final)
  251. }
  252. defer func() {
  253. if r := recover(); r != nil {
  254. s.err = fmt.Errorf("panic while encoding: %v", r)
  255. rdebug.PrintStack()
  256. }
  257. s.wg.Done()
  258. }()
  259. enc := s.encoder
  260. blk := enc.Block()
  261. enc.Encode(blk, src)
  262. blk.last = final
  263. if final {
  264. s.eofWritten = true
  265. }
  266. // Wait for pending writes.
  267. s.wWg.Wait()
  268. if s.writeErr != nil {
  269. s.err = s.writeErr
  270. return
  271. }
  272. // Transfer encoders from previous write block.
  273. blk.swapEncoders(s.writing)
  274. // Transfer recent offsets to next.
  275. enc.UseBlock(s.writing)
  276. s.writing = blk
  277. s.wWg.Add(1)
  278. go func() {
  279. defer func() {
  280. if r := recover(); r != nil {
  281. s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
  282. rdebug.PrintStack()
  283. }
  284. s.wWg.Done()
  285. }()
  286. err := errIncompressible
  287. // If we got the exact same number of literals as input,
  288. // assume the literals cannot be compressed.
  289. if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
  290. err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
  291. }
  292. switch err {
  293. case errIncompressible:
  294. if debugEncoder {
  295. println("Storing incompressible block as raw")
  296. }
  297. blk.encodeRaw(src)
  298. // In fast mode, we do not transfer offsets, so we don't have to deal with changing the.
  299. case nil:
  300. default:
  301. s.writeErr = err
  302. return
  303. }
  304. _, s.writeErr = s.w.Write(blk.output)
  305. s.nWritten += int64(len(blk.output))
  306. }()
  307. }(s.current)
  308. return nil
  309. }
  310. // ReadFrom reads data from r until EOF or error.
  311. // The return value n is the number of bytes read.
  312. // Any error except io.EOF encountered during the read is also returned.
  313. //
  314. // The Copy function uses ReaderFrom if available.
  315. func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
  316. if debugEncoder {
  317. println("Using ReadFrom")
  318. }
  319. // Flush any current writes.
  320. if len(e.state.filling) > 0 {
  321. if err := e.nextBlock(false); err != nil {
  322. return 0, err
  323. }
  324. }
  325. e.state.filling = e.state.filling[:e.o.blockSize]
  326. src := e.state.filling
  327. for {
  328. n2, err := r.Read(src)
  329. if e.o.crc {
  330. _, _ = e.state.encoder.CRC().Write(src[:n2])
  331. }
  332. // src is now the unfilled part...
  333. src = src[n2:]
  334. n += int64(n2)
  335. switch err {
  336. case io.EOF:
  337. e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
  338. if debugEncoder {
  339. println("ReadFrom: got EOF final block:", len(e.state.filling))
  340. }
  341. return n, nil
  342. case nil:
  343. default:
  344. if debugEncoder {
  345. println("ReadFrom: got error:", err)
  346. }
  347. e.state.err = err
  348. return n, err
  349. }
  350. if len(src) > 0 {
  351. if debugEncoder {
  352. println("ReadFrom: got space left in source:", len(src))
  353. }
  354. continue
  355. }
  356. err = e.nextBlock(false)
  357. if err != nil {
  358. return n, err
  359. }
  360. e.state.filling = e.state.filling[:e.o.blockSize]
  361. src = e.state.filling
  362. }
  363. }
  364. // Flush will send the currently written data to output
  365. // and block until everything has been written.
  366. // This should only be used on rare occasions where pushing the currently queued data is critical.
  367. func (e *Encoder) Flush() error {
  368. s := &e.state
  369. if len(s.filling) > 0 {
  370. err := e.nextBlock(false)
  371. if err != nil {
  372. return err
  373. }
  374. }
  375. s.wg.Wait()
  376. s.wWg.Wait()
  377. if s.err != nil {
  378. return s.err
  379. }
  380. return s.writeErr
  381. }
  382. // Close will flush the final output and close the stream.
  383. // The function will block until everything has been written.
  384. // The Encoder can still be re-used after calling this.
  385. func (e *Encoder) Close() error {
  386. s := &e.state
  387. if s.encoder == nil {
  388. return nil
  389. }
  390. err := e.nextBlock(true)
  391. if err != nil {
  392. return err
  393. }
  394. if s.frameContentSize > 0 {
  395. if s.nInput != s.frameContentSize {
  396. return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
  397. }
  398. }
  399. if e.state.fullFrameWritten {
  400. return s.err
  401. }
  402. s.wg.Wait()
  403. s.wWg.Wait()
  404. if s.err != nil {
  405. return s.err
  406. }
  407. if s.writeErr != nil {
  408. return s.writeErr
  409. }
  410. // Write CRC
  411. if e.o.crc && s.err == nil {
  412. // heap alloc.
  413. var tmp [4]byte
  414. _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
  415. s.nWritten += 4
  416. }
  417. // Add padding with content from crypto/rand.Reader
  418. if s.err == nil && e.o.pad > 0 {
  419. add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
  420. frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
  421. if err != nil {
  422. return err
  423. }
  424. _, s.err = s.w.Write(frame)
  425. }
  426. return s.err
  427. }
// EncodeAll will encode all input in src and append it to dst.
// This function can be called concurrently, but each call will only run on a single goroutine.
// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
// Encoded blocks can be concatenated and the result will be the combined input stream.
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
//
// Each call borrows one encoder from the e.encoders pool (blocking until one
// is available) and returns it when done. Inputs of at most
// maxCompressedBlockSize bytes are written as a single block; larger inputs
// are split into blocks of e.o.blockSize bytes.
// EncodeAll panics if writing the frame header, encoding a block or
// generating the padding frame fails.
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst, _ = fh.appendTo(dst)
			// Write raw block as last one only.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	// Build the encoder pool on first use.
	e.init.Do(e.initialize)
	enc := <-e.encoders
	defer func() {
		// Release encoder reference to last block.
		// If a non-single block is needed the encoder will reset again.
		e.encoders <- enc
	}()
	// Use single segments when above minimum window and below 1MB.
	single := len(src) < 1<<20 && len(src) > MinWindowSize
	if e.o.single != nil {
		// An explicit option overrides the size heuristic.
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(int64(len(src)))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        e.o.dict.ID(),
	}
	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
		dst = make([]byte, 0, len(src))
	}
	dst, err := fh.appendTo(dst)
	if err != nil {
		panic(err)
	}
	// If we can do everything in one block, prefer that.
	if len(src) <= maxCompressedBlockSize {
		enc.Reset(e.o.dict, true)
		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk := enc.Block()
		blk.last = true
		if e.o.dict == nil {
			enc.EncodeNoHist(blk, src)
		} else {
			enc.Encode(blk, src)
		}
		// If we got the exact same number of literals as input,
		// assume the literals cannot be compressed.
		err := errIncompressible
		// Remember the block's own output buffer; it is temporarily
		// replaced by dst and restored below.
		oldout := blk.output
		if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
			// Output directly to dst
			blk.output = dst
			err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		}
		switch err {
		case errIncompressible:
			if debugEncoder {
				println("Storing incompressible block as raw")
			}
			dst = blk.encodeRawTo(dst, src)
		case nil:
			// blk.encode appended the compressed block onto dst.
			dst = blk.output
		default:
			panic(err)
		}
		blk.output = oldout
	} else {
		enc.Reset(e.o.dict, false)
		blk := enc.Block()
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			// Save offsets so they can be restored if this block ends up raw.
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
				err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
			}
			switch err {
			case errIncompressible:
				if debugEncoder {
					println("Storing incompressible block as raw")
				}
				dst = blk.encodeRawTo(dst, todo)
				blk.popOffsets()
			case nil:
				dst = append(dst, blk.output...)
			default:
				panic(err)
			}
			blk.reset(nil)
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}