Skip to content

Commit

Permalink
fix: compress (#621)
Browse files Browse the repository at this point in the history
  • Loading branch information
fracasula authored Sep 3, 2024
1 parent 3c43022 commit ae791c9
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 56 deletions.
78 changes: 34 additions & 44 deletions compress/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,6 @@ func (c CompressionAlgorithm) String() string {
}
}

func NewCompressionAlgorithm(s string) (CompressionAlgorithm, error) {
switch s {
case "zstd":
return CompressionAlgoZstd, nil
case "zstd-cgo":
return CompressionAlgoZstdCgo, nil
default:
return 0, fmt.Errorf("unknown compression algorithm: %s", s)
}
}

// CompressionLevel is the interface that wraps the compression level method.
type CompressionLevel int

Expand All @@ -50,18 +39,34 @@ func (c CompressionLevel) String() string {
}
}

func NewCompressionLevel(s string) (CompressionLevel, error) {
switch s {
case "fastest":
return CompressionLevelZstdFastest, nil
case "default":
return CompressionLevelZstdDefault, nil
case "better":
return CompressionLevelZstdBetter, nil
case "best":
return CompressionLevelZstdBest, nil
func NewSettings(algo, level string) (CompressionAlgorithm, CompressionLevel, error) {
switch algo {
case "zstd":
switch level {
case "fastest":
return CompressionAlgoZstd, CompressionLevelZstdFastest, nil
case "default":
return CompressionAlgoZstd, CompressionLevelZstdDefault, nil
case "better":
return CompressionAlgoZstd, CompressionLevelZstdBetter, nil
case "best":
return CompressionAlgoZstd, CompressionLevelZstdBest, nil
default:
return 0, 0, fmt.Errorf("unknown compression level for %s: %s", algo, level)
}
case "zstd-cgo":
switch level {
case "fastest":
return CompressionAlgoZstdCgo, CompressionLevelZstdCgoFastest, nil
case "default":
return CompressionAlgoZstdCgo, CompressionLevelZstdCgoDefault, nil
case "best":
return CompressionAlgoZstdCgo, CompressionLevelZstdCgoBest, nil
default:
return 0, 0, fmt.Errorf("unknown compression level for %s: %s", algo, level)
}
default:
return 0, fmt.Errorf("unknown compression level: %s", s)
return 0, 0, fmt.Errorf("unknown compression algorithm: %s", algo)
}
}

Expand All @@ -79,17 +84,14 @@ var (
)

func New(algo CompressionAlgorithm, level CompressionLevel) (*Compressor, error) {
var err error
algo, level, err = NewSettings(algo.String(), level.String())
if err != nil {
return nil, err
}

switch algo {
case CompressionAlgoZstd:
switch level {
case CompressionLevelZstdFastest,
CompressionLevelZstdDefault,
CompressionLevelZstdBetter,
CompressionLevelZstdBest:
default:
return nil, fmt.Errorf("invalid compression level for %q: %d", algo, level)
}

encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.EncoderLevel(level)))
if err != nil {
return nil, fmt.Errorf("cannot create zstd encoder: %w", err)
Expand All @@ -105,20 +107,8 @@ func New(algo CompressionAlgorithm, level CompressionLevel) (*Compressor, error)
decoder: decoder,
}}, nil
case CompressionAlgoZstdCgo:
var cgoLevel int
switch level {
case CompressionLevelZstdCgoFastest:
cgoLevel = zstdcgo.BestSpeed
case CompressionLevelZstdCgoDefault:
cgoLevel = zstdcgo.DefaultCompression
case CompressionLevelZstdCgoBest:
cgoLevel = zstdcgo.BestCompression
default:
return nil, fmt.Errorf("invalid compression level for %q: %d", algo, level)
}

return &Compressor{
compressorZstdCgo: &compressorZstdCgo{level: cgoLevel},
compressorZstdCgo: &compressorZstdCgo{level: int(level)},
}, nil
default:
return nil, fmt.Errorf("unknown compression algorithm: %d", algo)
Expand Down
42 changes: 30 additions & 12 deletions compress/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,39 @@ func TestCompress(t *testing.T) {
}

func TestSerialization(t *testing.T) {
algo, err := NewCompressionAlgorithm("zstd")
require.NoError(t, err)
require.Equal(t, CompressionAlgoZstd, algo)
type testCase struct {
algo, level string
expectedSerialized string
expectedAlgo CompressionAlgorithm
expectedLevel CompressionLevel
}
testCases := []testCase{
{"zstd", "fastest", "1:1", CompressionAlgoZstd, CompressionLevelZstdFastest},
{"zstd", "default", "1:2", CompressionAlgoZstd, CompressionLevelZstdDefault},
{"zstd", "better", "1:3", CompressionAlgoZstd, CompressionLevelZstdBetter},
{"zstd", "best", "1:4", CompressionAlgoZstd, CompressionLevelZstdBest},

{"zstd-cgo", "fastest", "2:1", CompressionAlgoZstdCgo, CompressionLevelZstdCgoFastest},
{"zstd-cgo", "default", "2:5", CompressionAlgoZstdCgo, CompressionLevelZstdCgoDefault},
{"zstd-cgo", "best", "2:20", CompressionAlgoZstdCgo, CompressionLevelZstdCgoBest},
}

level, err := NewCompressionLevel("best")
require.NoError(t, err)
require.Equal(t, CompressionLevelZstdBest, level)
for _, tc := range testCases {
t.Run(tc.algo+"-"+tc.level, func(t *testing.T) {
algo, level, err := NewSettings(tc.algo, tc.level)
require.NoError(t, err)
require.Equal(t, tc.expectedAlgo, algo)
require.Equal(t, tc.expectedLevel, level)

serialized := SerializeSettings(algo, level)
require.Equal(t, "1:4", serialized)
serialized := SerializeSettings(algo, level)
require.Equal(t, tc.expectedSerialized, serialized)

algo, level, err = DeserializeSettings(serialized)
require.NoError(t, err)
require.Equal(t, CompressionAlgoZstd, algo)
require.Equal(t, CompressionLevelZstdBest, level)
algo, level, err = DeserializeSettings(serialized)
require.NoError(t, err)
require.Equal(t, tc.expectedAlgo, algo)
require.Equal(t, tc.expectedLevel, level)
})
}
}

func TestDeserializationError(t *testing.T) {
Expand Down

0 comments on commit ae791c9

Please sign in to comment.