zstd.go 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. package sarama
  2. import (
  3. "sync"
  4. "github.com/klauspost/compress/zstd"
  5. )
  // ZstdEncoderParams selects the configuration of a cached zstd encoder.
  // Values are used directly as sync.Map keys, so the struct must remain
  // comparable (no slice, map, or func fields).
  type ZstdEncoderParams struct {
  	// Level is the zstd compression level. CompressionLevelDefault selects
  	// zstd.SpeedDefault; any other value is translated with
  	// zstd.EncoderLevelFromZstd.
  	Level int
  }
  // ZstdDecoderParams selects the configuration of a cached zstd decoder.
  // It currently carries no settings, so every value maps to the same cached
  // decoder; it is kept as a struct so options can be added later without
  // changing call sites.
  type ZstdDecoderParams struct{}
  // Package-level caches of zstd codecs, keyed by their params struct.
  var (
  	// zstdEncMap maps ZstdEncoderParams to a shared *zstd.Encoder.
  	zstdEncMap sync.Map
  	// zstdDecMap maps ZstdDecoderParams to a shared *zstd.Decoder.
  	zstdDecMap sync.Map
  )
  12. func getEncoder(params ZstdEncoderParams) *zstd.Encoder {
  13. if ret, ok := zstdEncMap.Load(params); ok {
  14. return ret.(*zstd.Encoder)
  15. }
  16. // It's possible to race and create multiple new writers.
  17. // Only one will survive GC after use.
  18. encoderLevel := zstd.SpeedDefault
  19. if params.Level != CompressionLevelDefault {
  20. encoderLevel = zstd.EncoderLevelFromZstd(params.Level)
  21. }
  22. zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true),
  23. zstd.WithEncoderLevel(encoderLevel))
  24. zstdEncMap.Store(params, zstdEnc)
  25. return zstdEnc
  26. }
  27. func getDecoder(params ZstdDecoderParams) *zstd.Decoder {
  28. if ret, ok := zstdDecMap.Load(params); ok {
  29. return ret.(*zstd.Decoder)
  30. }
  31. // It's possible to race and create multiple new readers.
  32. // Only one will survive GC after use.
  33. zstdDec, _ := zstd.NewReader(nil)
  34. zstdDecMap.Store(params, zstdDec)
  35. return zstdDec
  36. }
  37. func zstdDecompress(params ZstdDecoderParams, dst, src []byte) ([]byte, error) {
  38. return getDecoder(params).DecodeAll(src, dst)
  39. }
  40. func zstdCompress(params ZstdEncoderParams, dst, src []byte) ([]byte, error) {
  41. return getEncoder(params).EncodeAll(src, dst), nil
  42. }