diff --git a/bwt/bwt.go b/bwt/bwt.go index 519180e5..45875b05 100644 --- a/bwt/bwt.go +++ b/bwt/bwt.go @@ -13,25 +13,48 @@ const nullChar = "$" // allows for sub sequence querying. type BWT struct { skipList []skipEntry - l waveletTree + // TODO: talk about how we would want to remove this in favor of a RLFM and or r-index + l waveletTree + // TODO: Talk about how we can cut way down on memory usage by sampling this in a specific way with the RLFM and or r-index + suffixArray []int } func (bwt BWT) Count(pattern string) int { + searchRange := bwt.lfSearch(pattern) + return searchRange.end - searchRange.start +} + +func (bwt BWT) Locate(pattern string) []int { + searchRange := bwt.lfSearch(pattern) + if searchRange.start >= searchRange.end { + return nil + } + + numOfOffsets := searchRange.end - searchRange.start + offsets := make([]int, numOfOffsets) + for i := 0; i < numOfOffsets; i++ { + offsets[i] = bwt.suffixArray[searchRange.start+i] + } + + return offsets +} + +func (bwt BWT) lfSearch(pattern string) interval { searchRange := interval{start: 0, end: bwt.getLenOfOriginalString()} for i := 0; i < len(pattern); i++ { if searchRange.end-searchRange.start <= 0 { - return 0 + return interval{} } c := pattern[len(pattern)-1-i] skip, ok := bwt.lookupSkip(c) if !ok { - return 0 + return interval{} } searchRange.start = skip.openEndedInterval.start + bwt.l.Rank(c, searchRange.start) searchRange.end = skip.openEndedInterval.start + bwt.l.Rank(c, searchRange.end) } - return searchRange.end - searchRange.start + return searchRange } func (bwt BWT) lookupSkip(c byte) (entry skipEntry, ok bool) { @@ -67,10 +90,13 @@ func New(sequence string) BWT { slices.Sort(prefixArray) + suffixArray := make([]int, len(sequence)) lastColBuilder := strings.Builder{} for i := 0; i < len(prefixArray); i++ { currChar := sequence[getBWTIndex(len(sequence), len(prefixArray[i]))] lastColBuilder.WriteByte(currChar) + + suffixArray[i] = len(sequence) - len(prefixArray[i]) } fb := strings.Builder{} for i := 0; i < len(prefixArray); i++ { @@ -78,8 +104,9 @@ func New(sequence string) BWT { } return BWT{ - skipList: buildSkipList(prefixArray), - l: NewWaveletTreeFromString(lastColBuilder.String()), + skipList: buildSkipList(prefixArray), + l: NewWaveletTreeFromString(lastColBuilder.String()), + suffixArray: suffixArray, } } diff --git a/bwt/bwt_test.go b/bwt/bwt_test.go index d9039bd8..fd3d2cda 100644 --- a/bwt/bwt_test.go +++ b/bwt/bwt_test.go @@ -3,6 +3,8 @@ package bwt import ( "strings" "testing" + + "golang.org/x/exp/slices" ) type BWTCountTestCase struct { @@ -10,12 +12,11 @@ type BWTCountTestCase struct { expected int } -const augmentedQuickBrownFoxTest = "thequickbrownfoxjumpsoverthelazydogwithanovertfrownafterfumblingitsparallelogramshapedbananagramallarounddowntown" - -var threeAugmentedQuickBrownFoxTest = strings.Join([]string{augmentedQuickBrownFoxTest, augmentedQuickBrownFoxTest, augmentedQuickBrownFoxTest}, "") - func TestBWT_Count(t *testing.T) { - bwt := New(threeAugmentedQuickBrownFoxTest) + baseTestStr := "thequickbrownfoxjumpsoverthelazydogwithanovertfrownafterfumblingitsparallelogramshapedbananagramallarounddowntown" + testStr := strings.Join([]string{baseTestStr, baseTestStr, baseTestStr}, "") + + bwt := New(testStr) testTable := []BWTCountTestCase{ {"uick", 3}, @@ -38,6 +39,44 @@ func TestBWT_Count(t *testing.T) { } } +type BWTLocateTestCase struct { + seq string + expected []int +} + +func TestBWT_Locate(t *testing.T) { + baseTestStr := "thequickbrownfoxjumpsoverthelazydogwithanovertfrownafterfumblingitsparallelogramshapedbananagramallarounddowntown" // len == 112 + testStr := strings.Join([]string{baseTestStr, baseTestStr, baseTestStr}, "") + + bwt := New(testStr) + + testTable := []BWTLocateTestCase{ + {"uick", []int{4, 117, 230}}, + {"the", []int{0, 25, 113, 138, 226, 251}}, + {"over", []int{21, 41, 134, 154, 247, 267}}, + {"own", []int{10, 48, 106, 110, 123, 161, 219, 223, 236, 274, 332, 336}}, + {"ana", []int{87, 89, 200, 202, 313, 315}}, + {"an", []int{39, 87, 89, 152, 200, 202, 265, 313, 315}}, + {"na", []int{50, 88, 90, 163, 201, 203, 276, 314, 316}}, + {"rown", []int{9, 47, 122, 160, 235, 273}}, + {"townthe", []int{109, 222}}, + {"zzz", nil}, + } + + for _, v := range testTable { + offsets := bwt.Locate(v.seq) + slices.Sort(offsets) + if len(offsets) != len(v.expected) { + t.Fatalf("seq=%s expectedOffsets=%v actualOffsets=%v", v.seq, v.expected, offsets) + } + for i := range offsets { + if offsets[i] != v.expected[i] { + t.Fatalf("seq=%s expectedOffsets=%v actualOffsets=%v", v.seq, v.expected, offsets) + } + } + } +} + func BenchmarkBWTBuildPower12(b *testing.B) { base := "!BANANA!" BaseBenchmarkBWTBuild(base, 12, b)