From 7468d73fa46f90526e3b3f59720bf6a226c9b017 Mon Sep 17 00:00:00 2001 From: Trenton Date: Tue, 5 Dec 2023 23:59:13 -0500 Subject: [PATCH] rsa fixes and refactors --- bwt/rsa_bitvector.go | 88 ++++++++++++++++++--------------------- bwt/rsa_bitvector_test.go | 51 ++++++++++++++++------- bwt/wavelet_test.go | 2 - 3 files changed, 76 insertions(+), 65 deletions(-) diff --git a/bwt/rsa_bitvector.go b/bwt/rsa_bitvector.go index 5212fbbe..c33374b3 100644 --- a/bwt/rsa_bitvector.go +++ b/bwt/rsa_bitvector.go @@ -5,68 +5,60 @@ import "math/bits" // TODO: doc what rsa is, why these DSAs, and why we take in a bit vector // TODO: clarks select type rsaBitVector struct { - bv bitvector - totalOnesRank int - jrc []chunk - jrBitsPerChunk int - jrBitsPerSubChunk int - oneSelectMap map[int]int - zeroSelectMap map[int]int + bv bitvector + totalOnesRank int + jrc []chunk + jrSubChunksPerChunk int + jrBitsPerChunk int + jrBitsPerSubChunk int + oneSelectMap map[int]int + zeroSelectMap map[int]int } // TODO: talk about why bv should never be modidifed after building the RSA bit vector func newRSABitVectorFromBitVector(bv bitvector) rsaBitVector { - jacobsonRankChunks, jrBitsPerChunk, jrBitsPerSubChunk, totalOnesRank := buildJacobsonRank(bv) + jacobsonRankChunks, jrSubChunksPerChunk, jrBitsPerSubChunk, totalOnesRank := buildJacobsonRank(bv) ones, zeros := buildSelectMaps(bv) return rsaBitVector{ - bv: bv, - totalOnesRank: totalOnesRank, - jrc: jacobsonRankChunks, - jrBitsPerChunk: jrBitsPerChunk, - jrBitsPerSubChunk: jrBitsPerSubChunk, - oneSelectMap: ones, - zeroSelectMap: zeros, + bv: bv, + totalOnesRank: totalOnesRank, + jrc: jacobsonRankChunks, + jrSubChunksPerChunk: jrSubChunksPerChunk, + jrBitsPerChunk: jrSubChunksPerChunk * jrBitsPerSubChunk, + jrBitsPerSubChunk: jrBitsPerSubChunk, + oneSelectMap: ones, + zeroSelectMap: zeros, } } func (rsa rsaBitVector) Rank(val bool, i int) int { - c := 0 - for j := 0; j < i; j++ { - if rsa.bv.getBit(j) { - c++ + if i > rsa.bv.len()-1 { + if val { + return rsa.totalOnesRank } + return rsa.bv.len() - rsa.totalOnesRank } + + chunkPos := (i / rsa.jrBitsPerChunk) + chunk := rsa.jrc[chunkPos] + + subChunkPos := (i % rsa.jrBitsPerChunk) / rsa.jrBitsPerSubChunk + subChunk := chunk.subChunks[subChunkPos] + + bitOffset := i % rsa.jrBitsPerSubChunk + + bitSet := rsa.bv.getBitSet(chunkPos*rsa.jrSubChunksPerChunk + subChunkPos) + + shiftRightAmount := uint64(rsa.jrBitsPerSubChunk - bitOffset) if val { - return c + remaining := bitSet >> shiftRightAmount + return chunk.onesCumulativeRank + subChunk.onesCumulativeRank + bits.OnesCount64(remaining) } - return i - c - // if i > rsa.bv.len()-1 { - // if val { - // return rsa.totalOnesRank - // } - // return rsa.bv.len() - rsa.totalOnesRank - // } - // - // chunkPos := (i / rsa.jrBitsPerChunk) - // chunk := rsa.jrc[chunkPos] - // - // subChunkPos := (i % rsa.jrBitsPerChunk) / rsa.jrBitsPerSubChunk - // subChunk := chunk.subChunks[subChunkPos] - // - // bitOffset := i % rsa.jrBitsPerSubChunk - // - // bitSet := rsa.bv.getBitSet(chunkPos*len(rsa.jrc) + subChunkPos) - // - // shiftRightAmount := uint64(rsa.jrBitsPerSubChunk - bitOffset) - // if val { - // remaining := bitSet >> shiftRightAmount - // return chunk.onesCumulativeRank + subChunk.onesCumulativeRank + bits.OnesCount64(remaining) - // } - // remaining := ^bitSet >> shiftRightAmount - // - // // cumulative ranks for 0 should just be the sum of the compliment of cumulative ranks for 1 - // return (chunkPos*rsa.jrBitsPerChunk - chunk.onesCumulativeRank) + (subChunkPos*rsa.jrBitsPerSubChunk - subChunk.onesCumulativeRank) + bits.OnesCount64(remaining) + remaining := ^bitSet >> shiftRightAmount + + // cumulative ranks for 0 should just be the sum of the compliment of cumulative ranks for 1 + return (chunkPos*rsa.jrBitsPerChunk - chunk.onesCumulativeRank) + (subChunkPos*rsa.jrBitsPerSubChunk - subChunk.onesCumulativeRank) + bits.OnesCount64(remaining) } func (rsa rsaBitVector) Select(val bool, rank int) (i int, ok bool) { @@ -130,7 +122,7 @@ func buildJacobsonRank(inBv bitvector) (jacobsonRankChunks []chunk, numOfSubChun }) } - return jacobsonRankChunks, numOfSubChunksPerChunk * wordSize, wordSize, totalRank + return jacobsonRankChunks, numOfSubChunksPerChunk, wordSize, totalRank } // TODO: talk about how this could be improved memory wise. Talk about how clarks select exists, but keeping it "simple for now" but maybe worth diff --git a/bwt/rsa_bitvector_test.go b/bwt/rsa_bitvector_test.go index 0a8f9880..a2c55b7b 100644 --- a/bwt/rsa_bitvector_test.go +++ b/bwt/rsa_bitvector_test.go @@ -164,27 +164,50 @@ func TestRSARank_singleCompleteChunk(t *testing.T) { } func TestRSARank_multipleChunks(t *testing.T) { - numBitsToTruncate := 17 - initialNumberOfBits := wordSize*15 - numBitsToTruncate - rsa := newTestRSAFromWords(initialNumberOfBits, + rsa := newTestRSAFromWords((8*4+3)*64, + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, 0x0000000000000000, + 0xffffffffffffffff, 0x0000000000000000, 0xffffffffffffffff, + 0x0000000000000000, 0xffffffffffffffff, 0x0000000000000000, 0xffffffffffffffff, 0x0000000000000000, + 0xffffffffffffffff, 0x0000000000000000, 0xffffffffffffffff, 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, + 0xffffffffffffffff, + 0x0000000000000000, 0xffffffffffffffff, 0x0000000000000000, - 0xffffffffffffffff, // this should end up getting truncated + 0xffffffffffffffff, ) testCases := []rsaRankTestCase{ @@ -207,7 +230,10 @@ func TestRSARank_multipleChunks(t *testing.T) { {true, 832, 448}, {false, 832, 384}, {true, 896, 448}, {false, 896, 448}, - {true, 896 + wordSize - numBitsToTruncate - 1, 448 + wordSize - numBitsToTruncate - 1}, {false, 896 + wordSize - numBitsToTruncate - 1, 448}, + + {true, 1024, 512}, {false, 1024, 512}, + + {true, 2048, 1024}, {false, 2048, 1024}, } for _, tc := range testCases { @@ -310,16 +336,11 @@ func TestRSASelect_notOk(t *testing.T) { func newTestRSAFromWords(sizeInBits int, wordsToCopy ...uint64) rsaBitVector { bv := newBitVector(sizeInBits) - for i := 0; i < len(wordsToCopy); i++ { - w := wordsToCopy[i] - for j := 0; j < 64; j++ { - if i*64+j == sizeInBits { - break - } - mask := uint64(1) << uint64(63-j%64) - bit := w&mask != 0 - bv.setBit(i*64+j, bit) - } + for i := 0; i < sizeInBits; i++ { + w := wordsToCopy[i/64] + mask := uint64(1) << uint64(63-i%64) + bit := w&mask != 0 + bv.setBit(i, bit) } return newRSABitVectorFromBitVector(bv) } diff --git a/bwt/wavelet_test.go b/bwt/wavelet_test.go index 3bae6a0b..f027809f 100644 --- a/bwt/wavelet_test.go +++ b/bwt/wavelet_test.go @@ -1,7 +1,6 @@ package bwt import ( - "fmt" "strings" "testing" ) @@ -182,7 +181,6 @@ func TestWaveletTree_Access_Reconstruction(t *testing.T) { for _, str := range testCases { wt := NewWaveletTreeFromString(str) - fmt.Println(len(str)) actual := "" for i := 0; i < len(str); i++ { actual += string(wt.Access(i))