diff --git a/pkg/util/blockutil/block_time_calculator.go b/pkg/util/blockutil/block_time_calculator.go new file mode 100644 index 0000000000..71726fcb14 --- /dev/null +++ b/pkg/util/blockutil/block_time_calculator.go @@ -0,0 +1,61 @@ +package blockutil + +import ( + "errors" + "math" + "time" +) + +type ( + // BlockTimeCalculator calculates block time of a given height. + BlockTimeCalculator struct { + getBlockInterval getBlockIntervalFn + getTipHeight getTipHeightFn + getHistoryBlockTime getHistoryblockTimeFn + } + + getBlockIntervalFn func(uint64) time.Duration + getTipHeightFn func() uint64 + getHistoryblockTimeFn func(uint64) (time.Time, error) +) + +// NewBlockTimeCalculator creates a new BlockTimeCalculator. +func NewBlockTimeCalculator(getBlockInterval getBlockIntervalFn, getTipHeight getTipHeightFn, getHistoryBlockTime getHistoryblockTimeFn) (*BlockTimeCalculator, error) { + if getBlockInterval == nil { + return nil, errors.New("nil getBlockInterval") + } + if getTipHeight == nil { + return nil, errors.New("nil getTipHeight") + } + if getHistoryBlockTime == nil { + return nil, errors.New("nil getHistoryBlockTime") + } + return &BlockTimeCalculator{ + getBlockInterval: getBlockInterval, + getTipHeight: getTipHeight, + getHistoryBlockTime: getHistoryBlockTime, + }, nil +} + +// CalculateBlockTime returns the block time of the given height. +// If the height is in the future, it will predict the block time according to the tip block time and interval. +// If the height is in the past, it will get the block time from indexer. +func (btc *BlockTimeCalculator) CalculateBlockTime(height uint64) (time.Time, error) { + // get block time from indexer if height is in the past + tipHeight := btc.getTipHeight() + if height <= tipHeight { + return btc.getHistoryBlockTime(height) + } + + // predict block time according to tip block time and interval + blockInterval := btc.getBlockInterval(tipHeight) + blockNumer := time.Duration(height - tipHeight) + if blockNumer > math.MaxInt64/blockInterval { + return time.Time{}, errors.New("height overflow") + } + tipBlockTime, err := btc.getHistoryBlockTime(tipHeight) + if err != nil { + return time.Time{}, err + } + return tipBlockTime.Add(blockNumer * blockInterval), nil +} diff --git a/pkg/util/blockutil/block_time_calculator_test.go b/pkg/util/blockutil/block_time_calculator_test.go new file mode 100644 index 0000000000..6cdbdbff3a --- /dev/null +++ b/pkg/util/blockutil/block_time_calculator_test.go @@ -0,0 +1,54 @@ +package blockutil + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestBlockTimeCalculator_CalculateBlockTime(t *testing.T) { + r := require.New(t) + interval := 5 * time.Second + intervalFn := func(h uint64) time.Duration { + return 5 * time.Second + } + tipHeight := uint64(100) + tipHeightF := func() uint64 { return tipHeight } + baseTime, err := time.Parse("2006-01-02T15:04:05.000Z", "2022-01-01T00:00:00.000Z") + r.NoError(err) + historyBlockTimeF := func(height uint64) (time.Time, error) { return baseTime.Add(time.Hour * time.Duration(height)), nil } + btc, err := NewBlockTimeCalculator(intervalFn, tipHeightF, historyBlockTimeF) + r.NoError(err) + + historyWrapper := func(height uint64) time.Time { + t, err := historyBlockTimeF(height) + r.NoError(err) + return t + } + cases := []struct { + name string + height uint64 + want time.Time + errMsg string + }{ + {"height is in the past", tipHeight - 1, historyWrapper(tipHeight - 1), ""}, + {"height is in the past I", tipHeight, historyWrapper(tipHeight), ""}, + {"height is in the future", tipHeight + 1, historyWrapper(tipHeight).Add(interval), ""}, + {"height is in the future I", tipHeight + 2, historyWrapper(tipHeight).Add(2 * interval), ""}, + {"height is not overflow", tipHeight + (1<<63-1)/uint64(interval), historyWrapper(tipHeight).Add((1<<63 - 1) / interval * interval), ""}, + {"height is overflow", tipHeight + (1<<63-1)/uint64(interval) + 1, time.Time{}, "height overflow"}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := btc.CalculateBlockTime(c.height) + if c.errMsg != "" { + r.ErrorContains(err, c.errMsg) + return + } + r.NoError(err) + r.Equal(c.want, got) + }) + } +}