diff --git a/internal/db/postgres/transformers/cmd.go b/internal/db/postgres/transformers/cmd.go index e978fb48..c778432f 100644 --- a/internal/db/postgres/transformers/cmd.go +++ b/internal/db/postgres/transformers/cmd.go @@ -455,7 +455,7 @@ func cmdValidateSkipBehaviour(p *toolkit.Parameter, v toolkit.ParamsValue) (tool if value != skipOnAnyName && value != skipOnAllName { return toolkit.ValidationWarnings{ toolkit.NewValidationWarning(). - AddMeta("ParameterName", p.Name). + SetSeverity(toolkit.ErrorValidationSeverity). AddMeta("ParameterValue", value). SetMsg(`unsupported skip_on type: must be one of "all" or "any"`), }, nil diff --git a/internal/db/postgres/transformers/hash.go b/internal/db/postgres/transformers/hash.go index ffccb953..f10209d5 100644 --- a/internal/db/postgres/transformers/hash.go +++ b/internal/db/postgres/transformers/hash.go @@ -16,24 +16,19 @@ package transformers import ( "context" - crand "crypto/rand" - "encoding/base64" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" "fmt" - "slices" - - "golang.org/x/crypto/scrypt" + "hash" + "strconv" "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" "github.com/greenmaskio/greenmask/pkg/toolkit" ) -// TODO: Make length truncation - -const ( - saltLength = 32 - bufLength = 1024 -) - var HashTransformerDefinition = utils.NewDefinition( utils.NewTransformerProperties( "Hash", @@ -51,17 +46,27 @@ var HashTransformerDefinition = utils.NewDefinition( ).SetRequired(true), toolkit.MustNewParameter( - "salt", - "salt for hash function base64 encoded", - ), + "function", + "hash function name. Possible values sha1, sha256, sha512, md5", + ).SetDefaultValue([]byte("sha1")). + SetRawValueValidator(validateHashFunctionsParameter), + + toolkit.MustNewParameter( + "max_length", + "limit length of hash function result", + ).SetDefaultValue([]byte("0")). + SetRawValueValidator(validateMaxLengthParameter), ) type HashTransformer struct { - salt toolkit.ParamsValue - columnName string - affectedColumns map[int]string - columnIdx int - res []byte + columnName string + affectedColumns map[int]string + columnIdx int + h hash.Hash + maxLength int + encodedOutputLength int + hashBuf []byte + resultBuf []byte } func NewHashTransformer( @@ -80,28 +85,46 @@ func NewHashTransformer( affectedColumns := make(map[int]string) affectedColumns[idx] = columnName - var salt toolkit.ParamsValue - p = parameters["salt"] - var err error - if len(p.RawValue()) > 0 { - salt, err = base64.StdEncoding.DecodeString(string(p.RawValue())) - if err != nil { - return nil, nil, fmt.Errorf("error decoding \"salt\" value from base64: %w", err) - } - } else { - b := make(toolkit.ParamsValue, saltLength) - if _, err := crand.Read(b); err != nil { - return nil, nil, err - } - salt = b + p = parameters["function"] + var hashFunctionName string + if _, err := p.Scan(&hashFunctionName); err != nil { + return nil, nil, fmt.Errorf("unable to scan \"function\" parameter: %w", err) + } + + p = parameters["max_length"] + var maxLength int + if _, err := p.Scan(&maxLength); err != nil { + return nil, nil, fmt.Errorf("unable to scan \"max_length\" parameter: %w", err) + } + + var h hash.Hash + var hashFunctionLength int + switch hashFunctionName { + case "md5": + h = md5.New() + hashFunctionLength = 16 + case "sha1": + h = sha1.New() + hashFunctionLength = 20 + case "sha256": + h = sha256.New() + hashFunctionLength = 32 + case "sha512": + h = sha512.New() + hashFunctionLength = 64 + default: + return nil, nil, fmt.Errorf("unknown hash function \"%s\"", hashFunctionName) } return &HashTransformer{ - salt: salt, - columnName: columnName, - affectedColumns: affectedColumns, - columnIdx: idx, - res: make([]byte, 0, bufLength), + columnName: columnName, + affectedColumns: affectedColumns, + columnIdx: idx, + maxLength: maxLength, + hashBuf: make([]byte, 0, hashFunctionLength), + resultBuf: make([]byte, hex.EncodedLen(hashFunctionLength)), + encodedOutputLength: hex.EncodedLen(hashFunctionLength), + h: h, }, nil, nil } @@ -126,26 +149,56 @@ func (ht *HashTransformer) Transform(ctx context.Context, r *toolkit.Record) (*t return r, nil } - dk, err := scrypt.Key(val.Data, ht.salt, 32768, 8, 1, 32) + defer ht.h.Reset() + _, err = ht.h.Write(val.Data) if err != nil { - return nil, fmt.Errorf("cannot perform hash calculation: %w", err) + return nil, fmt.Errorf("unable to write raw data into writer: %w", err) } + ht.hashBuf = ht.hashBuf[:0] + ht.hashBuf = ht.h.Sum(ht.hashBuf) + + hex.Encode(ht.resultBuf, ht.hashBuf) - length := base64.StdEncoding.EncodedLen(len(dk)) - if len(ht.res) < length { - slices.Grow(ht.res, length) + maxLength := ht.encodedOutputLength + if ht.maxLength > 0 && ht.encodedOutputLength > ht.maxLength { + maxLength = ht.maxLength } - ht.res = ht.res[0:length] - //base64.StdEncoding.EncodeToString(ht.res) - base64.StdEncoding.Encode(ht.res, dk) - if err := r.SetRawColumnValueByIdx(ht.columnIdx, toolkit.NewRawValue(ht.res, false)); err != nil { + if err := r.SetRawColumnValueByIdx(ht.columnIdx, toolkit.NewRawValue(ht.resultBuf[:maxLength], false)); err != nil { return nil, fmt.Errorf("unable to set new value: %w", err) } return r, nil } +func validateHashFunctionsParameter(p *toolkit.Parameter, v toolkit.ParamsValue) (toolkit.ValidationWarnings, error) { + functionName := string(v) + switch functionName { + case "md5", "sha1", "sha256", "sha512": + return nil, nil + } + return toolkit.ValidationWarnings{ + toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + AddMeta("ParameterValue", functionName). + SetMsg(`unknown hash function name`)}, nil +} + +func validateMaxLengthParameter(p *toolkit.Parameter, v toolkit.ParamsValue) (toolkit.ValidationWarnings, error) { + max_length, err := strconv.ParseInt(string(v), 10, 32) + if err != nil { + return nil, fmt.Errorf("error parsing \"max_length\" as integer: %w", err) + } + if max_length >= 0 { + return nil, nil + } + return toolkit.ValidationWarnings{ + toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + AddMeta("ParameterValue", string(v)). + SetMsg(`max_length parameter cannot be less than zero`)}, nil +} + func init() { utils.DefaultTransformerRegistry.MustRegister(HashTransformerDefinition) } diff --git a/internal/db/postgres/transformers/hash_test.go b/internal/db/postgres/transformers/hash_test.go index 4b27d7ec..1bc040a8 100644 --- a/internal/db/postgres/transformers/hash_test.go +++ b/internal/db/postgres/transformers/hash_test.go @@ -23,46 +23,212 @@ import ( "github.com/greenmaskio/greenmask/pkg/toolkit" ) -func TestHashTransformer_Transform(t *testing.T) { - var attrName = "data" - var originalValue = "old_value" - var expectedValue = toolkit.NewValue("jzTVGK2UHz3ERhrYiZDoDzcKeMxSsgxHHgWlL9OrkZ4=", false) - driver, record := getDriverAndRecord(attrName, originalValue) +func TestHashTransformer_Transform_all_functions(t *testing.T) { + columnValue := toolkit.ParamsValue("data") + tests := []struct { + name string + params map[string]toolkit.ParamsValue + original string + result string + }{ + { + name: "md5", + params: map[string]toolkit.ParamsValue{ + "column": columnValue, + "function": []byte("md5"), + }, + original: "123", + result: "202cb962ac59075b964b07152d234b70", + }, + { + name: "sha1", + params: map[string]toolkit.ParamsValue{ + "column": columnValue, + "function": []byte("sha1"), + }, + original: "123", + result: "40bd001563085fc35165329ea1ff5c5ecbdbbeef", + }, + { + name: "sha256", + params: map[string]toolkit.ParamsValue{ + "column": columnValue, + "function": []byte("sha256"), + }, + original: "123", + result: "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3", + }, + { + name: "sha512", + params: map[string]toolkit.ParamsValue{ + "column": columnValue, + "function": []byte("sha512"), + }, + original: "123", + result: "3c9909afec25354d551dae21590bb26e38d53f2173b8d3dc3eee4c047e7ab1c1eb8b85103e3be7ba613b31bb5c9c36214dc9f14a42fd7a2fdb84856bca5c44c2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver, record := getDriverAndRecord(string(tt.params["column"]), tt.original) + transformer, warnings, err := HashTransformerDefinition.Instance( + context.Background(), + driver, tt.params, + nil, + ) + require.NoError(t, err) + require.Empty(t, warnings) + r, err := transformer.Transform( + context.Background(), + record, + ) + require.NoError(t, err) + + res, err := r.GetRawColumnValueByName(string(tt.params["column"])) + require.NoError(t, err) + + require.False(t, res.IsNull) + require.Equal(t, tt.result, string(res.Data)) + + }) + } +} + +func Test_validateHashFunctionsParameter(t *testing.T) { + + tests := []struct { + name string + value []byte + }{ + { + name: "md5", + value: []byte("md5"), + }, + { + name: "sha1", + value: []byte("md5"), + }, + { + name: "sha256", + value: []byte("md5"), + }, + { + name: "sha512", + value: []byte("md5"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + warns, err := validateHashFunctionsParameter(nil, tt.value) + require.NoError(t, err) + require.Empty(t, warns) + }) + } + + t.Run("wrong value", func(t *testing.T) { + warns, err := validateHashFunctionsParameter(nil, []byte("md8")) + require.NoError(t, err) + require.Len(t, warns, 1) + warn := warns[0] + require.Equal(t, toolkit.ErrorValidationSeverity, warn.Severity) + require.Equal(t, "unknown hash function name", warn.Msg) + }) + +} + +func TestHashTransformer_Transform_length_truncation(t *testing.T) { + + params := map[string]toolkit.ParamsValue{ + "column": toolkit.ParamsValue("data"), + "max_length": toolkit.ParamsValue("4"), + "function": toolkit.ParamsValue("sha1"), + } + original := "123" + expected := "40bd" + // Check that internal buffers wipes correctly without data lost + driver, record := getDriverAndRecord(string(params["column"]), original) transformer, warnings, err := HashTransformerDefinition.Instance( context.Background(), - driver, map[string]toolkit.ParamsValue{ - "column": toolkit.ParamsValue(attrName), - "salt": toolkit.ParamsValue("MTIzNDU2Nw=="), - }, + driver, params, nil, ) require.NoError(t, err) require.Empty(t, warnings) - r, err := transformer.Transform( context.Background(), record, ) require.NoError(t, err) - res, err := r.GetColumnValueByName(attrName) + + res, err := r.GetRawColumnValueByName(string(params["column"])) require.NoError(t, err) - require.Equal(t, expectedValue.IsNull, res.IsNull) - require.Equal(t, expectedValue.Value, res.Value) + require.False(t, res.IsNull) + require.Equal(t, expected, string(res.Data)) +} + +func TestHashTransformer_Transform_multiple_iterations(t *testing.T) { + columnValue := toolkit.ParamsValue("data") - originalValue = "123asdasdasdaasdlmaklsdmklamsdlkmalksdmlkamsdlkmalkdmlkasds" - expectedValue = toolkit.NewValue("kZsJbWbVoBGMqniHTCzU6fJrxQdlfeqhYIUxOo3JniA=", false) - _, record = getDriverAndRecord(attrName, originalValue) - r, err = transformer.Transform( + params := map[string]toolkit.ParamsValue{ + "column": toolkit.ParamsValue("data"), + "function": toolkit.ParamsValue("sha1"), + } + original := "123" + // Check that internal buffers wipes correctly without data lost + driver, record := getDriverAndRecord(string(params["column"]), original) + transformer, warnings, err := HashTransformerDefinition.Instance( context.Background(), - record, + driver, params, + nil, ) require.NoError(t, err) - res, err = r.GetColumnValueByName(attrName) - require.NoError(t, err) + require.Empty(t, warnings) + + tests := []struct { + name string + original string + expected string + }{ + { + name: "run1", + original: "123", + expected: "40bd001563085fc35165329ea1ff5c5ecbdbbeef", + }, + { + name: "run2", + original: "456", + expected: "51eac6b471a284d3341d8c0c63d0f1a286262a18", + }, + { + name: "run3", + original: "789", + expected: "fc1200c7a7aa52109d762a9f005b149abef01479", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer record.Row.Encode() + + err = record.Row.Decode([]byte(tt.original)) + require.NoError(t, err) + + _, err = transformer.Transform( + context.Background(), + record, + ) + require.NoError(t, err) + + res, err := record.GetRawColumnValueByName(string(columnValue)) + require.NoError(t, err) - require.Equal(t, expectedValue.IsNull, res.IsNull) - require.Equal(t, expectedValue.Value, res.Value) + require.False(t, res.IsNull) + require.Equal(t, tt.expected, string(res.Data)) + }) + } } diff --git a/pkg/toolkit/parameter.go b/pkg/toolkit/parameter.go index 81b98bd9..9e93754d 100644 --- a/pkg/toolkit/parameter.go +++ b/pkg/toolkit/parameter.go @@ -351,15 +351,15 @@ func (p *Parameter) Init(driver *Driver, types []*Type, params []*Parameter, raw } if p.RawValueValidator != nil { - w, err := p.RawValueValidator(p, p.rawValue) + rawValueValidatorWarns, err := p.RawValueValidator(p, p.rawValue) if err != nil { return nil, fmt.Errorf("error performing parameter raw value validation: %w", err) } - for _, w := range warnings { + for _, w := range rawValueValidatorWarns { w.AddMeta("ParameterName", p.Name) } - warnings = append(warnings, w...) - if w.IsFatal() { + warnings = append(warnings, rawValueValidatorWarns...) + if rawValueValidatorWarns.IsFatal() { return warnings, nil } }