diff --git a/core/mock/headerHandlerStub.go b/core/mock/headerHandlerStub.go deleted file mode 100644 index fe319899..00000000 --- a/core/mock/headerHandlerStub.go +++ /dev/null @@ -1,255 +0,0 @@ -package mock - -import ( - "math/big" - - "github.com/multiversx/mx-chain-core-go/data" -) - -// HeaderHandlerStub - -type HeaderHandlerStub struct { - EpochField uint32 - TimestampField uint64 - GetMiniBlockHeadersWithDstCalled func(destId uint32) map[string]uint32 - GetOrderedCrossMiniblocksWithDstCalled func(destId uint32) []*data.MiniBlockInfo - GetPubKeysBitmapCalled func() []byte - GetSignatureCalled func() []byte - GetRootHashCalled func() []byte - GetRandSeedCalled func() []byte - GetPrevRandSeedCalled func() []byte - GetPrevHashCalled func() []byte - CloneCalled func() data.HeaderHandler - GetChainIDCalled func() []byte - CheckChainIDCalled func(reference []byte) error - GetReservedCalled func() []byte - IsStartOfEpochBlockCalled func() bool -} - -// GetAccumulatedFees - -func (hhs *HeaderHandlerStub) GetAccumulatedFees() *big.Int { - return big.NewInt(0) -} - -// GetDeveloperFees - -func (hhs *HeaderHandlerStub) GetDeveloperFees() *big.Int { - return big.NewInt(0) -} - -// SetAccumulatedFees - -func (hhs *HeaderHandlerStub) SetAccumulatedFees(_ *big.Int) { -} - -// SetDeveloperFees - -func (hhs *HeaderHandlerStub) SetDeveloperFees(_ *big.Int) { -} - -// GetReceiptsHash - -func (hhs *HeaderHandlerStub) GetReceiptsHash() []byte { - return []byte("receipt") -} - -// SetShardID - -func (hhs *HeaderHandlerStub) SetShardID(_ uint32) { -} - -// IsStartOfEpochBlock - -func (hhs *HeaderHandlerStub) IsStartOfEpochBlock() bool { - if hhs.IsStartOfEpochBlockCalled != nil { - return hhs.IsStartOfEpochBlockCalled() - } - - return false -} - -// Clone - -func (hhs *HeaderHandlerStub) Clone() data.HeaderHandler { - return hhs.CloneCalled() -} - -// GetShardID - -func (hhs *HeaderHandlerStub) GetShardID() uint32 { - return 1 -} - -// GetNonce - -func (hhs *HeaderHandlerStub) GetNonce() uint64 { - return 1 -} - -// GetEpoch - -func (hhs *HeaderHandlerStub) GetEpoch() uint32 { - return hhs.EpochField -} - -// GetRound - -func (hhs *HeaderHandlerStub) GetRound() uint64 { - return 1 -} - -// GetTimeStamp - -func (hhs *HeaderHandlerStub) GetTimeStamp() uint64 { - return hhs.TimestampField -} - -// GetRootHash - -func (hhs *HeaderHandlerStub) GetRootHash() []byte { - return hhs.GetRootHashCalled() -} - -// GetPrevHash - -func (hhs *HeaderHandlerStub) GetPrevHash() []byte { - return hhs.GetPrevHashCalled() -} - -// GetPrevRandSeed - -func (hhs *HeaderHandlerStub) GetPrevRandSeed() []byte { - return hhs.GetPrevRandSeedCalled() -} - -// GetRandSeed - -func (hhs *HeaderHandlerStub) GetRandSeed() []byte { - return hhs.GetRandSeedCalled() -} - -// GetPubKeysBitmap - -func (hhs *HeaderHandlerStub) GetPubKeysBitmap() []byte { - return hhs.GetPubKeysBitmapCalled() -} - -// GetSignature - -func (hhs *HeaderHandlerStub) GetSignature() []byte { - return hhs.GetSignatureCalled() -} - -// GetLeaderSignature - -func (hhs *HeaderHandlerStub) GetLeaderSignature() []byte { - return hhs.GetSignatureCalled() -} - -// GetChainID - -func (hhs *HeaderHandlerStub) GetChainID() []byte { - return hhs.GetChainIDCalled() -} - -// GetTxCount - -func (hhs *HeaderHandlerStub) GetTxCount() uint32 { - return 0 -} - -// GetReserved - -func (hhs *HeaderHandlerStub) GetReserved() []byte { - if hhs.GetReservedCalled != nil { - return hhs.GetReservedCalled() - } - - return nil -} - -// SetNonce - -func (hhs *HeaderHandlerStub) SetNonce(_ uint64) { - panic("implement me") -} - -// SetEpoch - -func (hhs *HeaderHandlerStub) SetEpoch(_ uint32) { - panic("implement me") -} - -// SetRound - -func (hhs *HeaderHandlerStub) SetRound(_ uint64) { - panic("implement me") -} - -// SetTimeStamp - -func (hhs *HeaderHandlerStub) SetTimeStamp(_ uint64) { - panic("implement me") -} - -// SetRootHash - -func (hhs *HeaderHandlerStub) SetRootHash(_ []byte) { - panic("implement me") -} - -// SetPrevHash - -func (hhs *HeaderHandlerStub) SetPrevHash(_ []byte) { - panic("implement me") -} - -// SetPrevRandSeed - -func (hhs *HeaderHandlerStub) SetPrevRandSeed(_ []byte) { - panic("implement me") -} - -// SetRandSeed - -func (hhs *HeaderHandlerStub) SetRandSeed(_ []byte) { - panic("implement me") -} - -// SetPubKeysBitmap - -func (hhs *HeaderHandlerStub) SetPubKeysBitmap(_ []byte) { - panic("implement me") -} - -// SetSignature - -func (hhs *HeaderHandlerStub) SetSignature(_ []byte) { - panic("implement me") -} - -// SetLeaderSignature - -func (hhs *HeaderHandlerStub) SetLeaderSignature(_ []byte) { - panic("implement me") -} - -// SetChainID - -func (hhs *HeaderHandlerStub) SetChainID(_ []byte) { - panic("implement me") -} - -// SetTxCount - -func (hhs *HeaderHandlerStub) SetTxCount(_ uint32) { - panic("implement me") -} - -// GetMiniBlockHeadersWithDst - -func (hhs *HeaderHandlerStub) GetMiniBlockHeadersWithDst(destId uint32) map[string]uint32 { - return hhs.GetMiniBlockHeadersWithDstCalled(destId) -} - -// GetOrderedCrossMiniblocksWithDst - -func (hhs *HeaderHandlerStub) GetOrderedCrossMiniblocksWithDst(destId uint32) []*data.MiniBlockInfo { - return hhs.GetOrderedCrossMiniblocksWithDstCalled(destId) -} - -// GetMiniBlockHeadersHashes - -func (hhs *HeaderHandlerStub) GetMiniBlockHeadersHashes() [][]byte { - panic("implement me") -} - -// GetValidatorStatsRootHash - -func (hhs *HeaderHandlerStub) GetValidatorStatsRootHash() []byte { - return []byte("vs root hash") -} - -// SetValidatorStatsRootHash - -func (hhs *HeaderHandlerStub) SetValidatorStatsRootHash(_ []byte) { - panic("implement me") -} - -// IsInterfaceNil returns true if there is no value under the interface -func (hhs *HeaderHandlerStub) IsInterfaceNil() bool { - return hhs == nil -} - -// GetEpochStartMetaHash - -func (hhs *HeaderHandlerStub) GetEpochStartMetaHash() []byte { - panic("implement me") -} - -// GetSoftwareVersion - -func (hhs *HeaderHandlerStub) GetSoftwareVersion() []byte { - return []byte("softwareVersion") -} - -// SetSoftwareVersion - -func (hhs *HeaderHandlerStub) SetSoftwareVersion(_ []byte) { -} diff --git a/data/block/block.go b/data/block/block.go index 5147be5d..cc8a1fd8 100644 --- a/data/block/block.go +++ b/data/block/block.go @@ -2,6 +2,7 @@ package block import ( + "fmt" "math/big" "github.com/multiversx/mx-chain-core-go/data" @@ -581,3 +582,36 @@ func (h *Header) GetAdditionalData() headerVersionData.HeaderAdditionalData { // no extra data for the initial version of shard block header return nil } + +// CheckFieldsForNil checks a predefined set of fields for nil values +func (h *Header) CheckFieldsForNil() error { + if h == nil { + return data.ErrNilPointerReceiver + } + if h.PrevHash == nil { + return fmt.Errorf("%w in Header.PrevHash", data.ErrNilValue) + } + if h.PrevRandSeed == nil { + return fmt.Errorf("%w in Header.PrevRandSeed", data.ErrNilValue) + } + if h.RandSeed == nil { + return fmt.Errorf("%w in Header.RandSeed", data.ErrNilValue) + } + if h.RootHash == nil { + return fmt.Errorf("%w in Header.RootHash", data.ErrNilValue) + } + if h.ChainID == nil { + return fmt.Errorf("%w in Header.ChainID", data.ErrNilValue) + } + if h.SoftwareVersion == nil { + return fmt.Errorf("%w in Header.SoftwareVersion", data.ErrNilValue) + } + if h.AccumulatedFees == nil { + return fmt.Errorf("%w in Header.AccumulatedFees", data.ErrNilValue) + } + if h.DeveloperFees == nil { + return fmt.Errorf("%w in Header.DeveloperFees", data.ErrNilValue) + } + + return nil +} diff --git a/data/block/blockChecks_test.go b/data/block/blockChecks_test.go new file mode 100644 index 00000000..e4d99a80 --- /dev/null +++ b/data/block/blockChecks_test.go @@ -0,0 +1,150 @@ +package block + +import ( + "fmt" + "math/big" + "reflect" + "testing" + + "github.com/multiversx/mx-chain-core-go/data" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var headerV1ExceptionFields = []string{ + "Signature", + "LeaderSignature", + "PubKeysBitmap", + "MetaBlockHashes", + "EpochStartMetaHash", + "ReceiptsHash", + "Reserved", +} + +type field struct { + name string + typeValue string + objFieldIndex int +} + +type fieldsChecker interface { + CheckFieldsForNil() error +} + +func prepareFieldsList(object interface{}, fieldNameExceptions ...string) []field { + list := make([]field, 0) + val := reflect.ValueOf(object).Elem() + for i := 0; i < val.NumField(); i++ { + fieldName := val.Type().Field(i).Name + fieldType := fmt.Sprintf("%v", val.Field(i).Type()) + switch fieldType { + case "uint64", "uint32", "int", "string": + continue + case "block.Type", "[]block.MiniBlockHeader", "[]block.PeerChange": + continue + } + if search(fieldName, fieldNameExceptions...) { + continue + } + + list = append(list, field{ + name: fieldName, + typeValue: fieldType, + objFieldIndex: i, + }) + } + + return list +} + +func search(needle string, haystack ...string) bool { + for _, item := range haystack { + if item == needle { + return true + } + } + return false +} + +func populateFieldsWithRandomValue(tb testing.TB, object interface{}, fields []field) { + val := reflect.ValueOf(object) + for counter, f := range fields { + fieldValue := val.Elem().FieldByName(f.name) + + switch f.typeValue { + case "[]uint8": + fieldValue.SetBytes([]byte(fmt.Sprintf("test field %d", counter))) + case "[][]uint8": + fieldValue.Set(reflect.ValueOf([][]byte{ + []byte(fmt.Sprintf("test field1 %d", counter)), + []byte(fmt.Sprintf("test field2 %d", counter)), + })) + case "*big.Int": + fieldValue.Set(reflect.ValueOf(big.NewInt(int64(counter)))) + default: + assert.Fail(tb, "unimplemented field type "+f.typeValue+" for field "+f.name) + } + } +} + +func unsetField(tb testing.TB, object interface{}, f field) { + v := reflect.ValueOf(object) + + fieldValue := v.Elem().FieldByName(f.name) + switch f.typeValue { + case "[]uint8", "[][]uint8", "*big.Int": + fieldValue.Set(reflect.Zero(fieldValue.Type())) + default: + assert.Fail(tb, "unimplemented field type "+f.typeValue+" for field "+f.name) + } +} + +func testField(tb testing.TB, object interface{}, fields []field, fieldIndex int) { + f := fields[fieldIndex] + fmt.Printf(" testing field %s of type %s\n", f.name, f.typeValue) + populateFieldsWithRandomValue(tb, object, fields) + unsetField(tb, object, fields[fieldIndex]) + + checker := object.(fieldsChecker) + err := checker.CheckFieldsForNil() + require.NotNil(tb, err, "should have return a non nil error for nil field %s", f.name) + assert.ErrorIs(tb, err, data.ErrNilValue) +} + +func TestBlockHeader_Checks(t *testing.T) { + t.Parallel() + + t.Run("nil pointer receiver", func(t *testing.T) { + t.Parallel() + + var objectToTest *Header + err := objectToTest.CheckFieldsForNil() + require.NotNil(t, err) + assert.ErrorIs(t, err, data.ErrNilPointerReceiver) + }) + t.Run("test all fields when set", func(t *testing.T) { + t.Parallel() + + objectToTest := &Header{} + + fields := prepareFieldsList(objectToTest, headerV1ExceptionFields...) + assert.NotEmpty(t, fields) + }) + t.Run("test all fields when one is unset", func(t *testing.T) { + t.Parallel() + + objectToTest := &Header{} + + fields := prepareFieldsList(objectToTest, headerV1ExceptionFields...) + assert.NotEmpty(t, fields) + + populateFieldsWithRandomValue(t, objectToTest, fields) + err := objectToTest.CheckFieldsForNil() + require.Nil(t, err) + + fmt.Printf("fields tests on %T\n", objectToTest) + for i := 0; i < len(fields); i++ { + testField(t, objectToTest, fields, i) + } + }) +} diff --git a/data/block/blockV2.go b/data/block/blockV2.go index 1d020c06..0a690e2e 100644 --- a/data/block/blockV2.go +++ b/data/block/blockV2.go @@ -2,6 +2,7 @@ package block import ( + "fmt" "math/big" "github.com/multiversx/mx-chain-core-go/core/check" @@ -629,3 +630,23 @@ func (hv2 *HeaderV2) GetAdditionalData() headerVersionData.HeaderAdditionalData } return additionalVersionData } + +// CheckFieldsForNil checks a predefined set of fields for nil values +func (hv2 *HeaderV2) CheckFieldsForNil() error { + if hv2 == nil { + return data.ErrNilPointerReceiver + } + err := hv2.Header.CheckFieldsForNil() + if err != nil { + return err + } + + if hv2.ScheduledAccumulatedFees == nil { + return fmt.Errorf("%w in HeaderV2.ScheduledAccumulatedFees", data.ErrNilValue) + } + if hv2.ScheduledDeveloperFees == nil { + return fmt.Errorf("%w in HeaderV2.ScheduledDeveloperFees", data.ErrNilValue) + } + + return nil +} diff --git a/data/block/blockV2Checks_test.go b/data/block/blockV2Checks_test.go new file mode 100644 index 00000000..b44e2d43 --- /dev/null +++ b/data/block/blockV2Checks_test.go @@ -0,0 +1,90 @@ +package block + +import ( + "fmt" + "testing" + + "github.com/multiversx/mx-chain-core-go/data" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var headerV2ExceptionFields = []string{ + "Header", + "ScheduledRootHash", +} + +func TestBlockHeaderV2_Checks(t *testing.T) { + t.Parallel() + + t.Run("nil pointer receiver", func(t *testing.T) { + t.Parallel() + + var objectToTest *HeaderV2 + err := objectToTest.CheckFieldsForNil() + require.NotNil(t, err) + assert.ErrorIs(t, err, data.ErrNilPointerReceiver) + }) + t.Run("inner header is a nil pointer receiver", func(t *testing.T) { + t.Parallel() + + objectToTest := &HeaderV2{} + err := objectToTest.CheckFieldsForNil() + require.NotNil(t, err) + assert.ErrorIs(t, err, data.ErrNilPointerReceiver) + }) + t.Run("test all fields when set", func(t *testing.T) { + t.Parallel() + + objectToTest := &HeaderV2{} + + fields := prepareFieldsList(objectToTest, headerV1ExceptionFields...) + assert.NotEmpty(t, fields) + }) + t.Run("test all fields when one is unset on inner Header", func(t *testing.T) { + t.Parallel() + + objectToTest := &HeaderV2{ + Header: &Header{}, + } + + fieldsForHeaderV2 := prepareFieldsList(objectToTest, headerV2ExceptionFields...) + assert.NotEmpty(t, fieldsForHeaderV2) + populateFieldsWithRandomValue(t, objectToTest, fieldsForHeaderV2) + + fieldsForHeaderV1 := prepareFieldsList(objectToTest.Header, headerV1ExceptionFields...) + assert.NotEmpty(t, fieldsForHeaderV1) + + populateFieldsWithRandomValue(t, objectToTest.Header, fieldsForHeaderV1) + err := objectToTest.CheckFieldsForNil() + require.Nil(t, err) + + fmt.Printf("fields tests on %T\n", objectToTest.Header) + for i := 0; i < len(fieldsForHeaderV1); i++ { + testField(t, objectToTest.Header, fieldsForHeaderV1, i) + } + }) + t.Run("test all fields when one is unset on HeaderV2", func(t *testing.T) { + t.Parallel() + + objectToTest := &HeaderV2{ + Header: &Header{}, + } + + fieldsForHeaderV1 := prepareFieldsList(objectToTest.Header, headerV1ExceptionFields...) + assert.NotEmpty(t, fieldsForHeaderV1) + populateFieldsWithRandomValue(t, objectToTest.Header, fieldsForHeaderV1) + + fields := prepareFieldsList(objectToTest, headerV2ExceptionFields...) + assert.NotEmpty(t, fields) + + populateFieldsWithRandomValue(t, objectToTest, fields) + err := objectToTest.CheckFieldsForNil() + require.Nil(t, err) + + fmt.Printf("fields tests on %T\n", objectToTest) + for i := 0; i < len(fields); i++ { + testField(t, objectToTest, fields, i) + } + }) +} diff --git a/data/block/metaBlock.go b/data/block/metaBlock.go index 3d78d575..ebaf8da8 100644 --- a/data/block/metaBlock.go +++ b/data/block/metaBlock.go @@ -2,6 +2,7 @@ package block import ( + "fmt" "math/big" "sort" @@ -523,3 +524,45 @@ func (m *MetaBlock) GetAdditionalData() headerVersionData.HeaderAdditionalData { // no extra data for the initial version of meta block header return nil } + +// CheckFieldsForNil checks a predefined set of fields for nil values +func (m *MetaBlock) CheckFieldsForNil() error { + if m == nil { + return data.ErrNilPointerReceiver + } + if m.PrevHash == nil { + return fmt.Errorf("%w in MetaBlock.PrevHash", data.ErrNilValue) + } + if m.PrevRandSeed == nil { + return fmt.Errorf("%w in MetaBlock.PrevRandSeed", data.ErrNilValue) + } + if m.RandSeed == nil { + return fmt.Errorf("%w in MetaBlock.RandSeed", data.ErrNilValue) + } + if m.RootHash == nil { + return fmt.Errorf("%w in MetaBlock.RootHash", data.ErrNilValue) + } + if m.ValidatorStatsRootHash == nil { + return fmt.Errorf("%w in MetaBlock.ValidatorStatsRootHash", data.ErrNilValue) + } + if m.ChainID == nil { + return fmt.Errorf("%w in MetaBlock.ChainID", data.ErrNilValue) + } + if m.SoftwareVersion == nil { + return fmt.Errorf("%w in MetaBlock.SoftwareVersion", data.ErrNilValue) + } + if m.AccumulatedFees == nil { + return fmt.Errorf("%w in MetaBlock.AccumulatedFees", data.ErrNilValue) + } + if m.AccumulatedFeesInEpoch == nil { + return fmt.Errorf("%w in MetaBlock.AccumulatedFeesInEpoch", data.ErrNilValue) + } + if m.DeveloperFees == nil { + return fmt.Errorf("%w in MetaBlock.DeveloperFees", data.ErrNilValue) + } + if m.DevFeesInEpoch == nil { + return fmt.Errorf("%w in MetaBlock.DevFeesInEpoch", data.ErrNilValue) + } + + return nil +} diff --git a/data/block/metaBlockChecks_test.go b/data/block/metaBlockChecks_test.go new file mode 100644 index 00000000..befb3377 --- /dev/null +++ b/data/block/metaBlockChecks_test.go @@ -0,0 +1,59 @@ +package block + +import ( + "fmt" + "testing" + + "github.com/multiversx/mx-chain-core-go/data" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var metablockExceptionFields = []string{ + "ShardInfo", + "PeerInfo", + "EpochStart", + "Signature", + "LeaderSignature", + "PubKeysBitmap", + "ReceiptsHash", + "Reserved", +} + +func TestMetaBlockHeader_Checks(t *testing.T) { + t.Parallel() + + t.Run("nil pointer receiver", func(t *testing.T) { + t.Parallel() + + var objectToTest *MetaBlock + err := objectToTest.CheckFieldsForNil() + require.NotNil(t, err) + assert.ErrorIs(t, err, data.ErrNilPointerReceiver) + }) + t.Run("test all fields when set", func(t *testing.T) { + t.Parallel() + + objectToTest := &MetaBlock{} + + fields := prepareFieldsList(objectToTest, headerV1ExceptionFields...) + assert.NotEmpty(t, fields) + }) + t.Run("test all fields when one is unset", func(t *testing.T) { + t.Parallel() + + objectToTest := &MetaBlock{} + + fields := prepareFieldsList(objectToTest, metablockExceptionFields...) + assert.NotEmpty(t, fields) + + populateFieldsWithRandomValue(t, objectToTest, fields) + err := objectToTest.CheckFieldsForNil() + require.Nil(t, err) + + fmt.Printf("fields tests on %T\n", objectToTest) + for i := 0; i < len(fields); i++ { + testField(t, objectToTest, fields, i) + } + }) +} diff --git a/data/interface.go b/data/interface.go index 270bd6bd..8d4fed91 100644 --- a/data/interface.go +++ b/data/interface.go @@ -82,6 +82,7 @@ type HeaderHandler interface { SetAdditionalData(headerVersionData headerVersionData.HeaderAdditionalData) error IsStartOfEpochBlock() bool ShallowClone() HeaderHandler + CheckFieldsForNil() error IsInterfaceNil() bool }