Skip to content

Commit

Permalink
Use seq for node cps
Browse files Browse the repository at this point in the history
  • Loading branch information
nitely committed Dec 11, 2024
1 parent 1d3f085 commit 3eb89a9
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 48 deletions.
31 changes: 29 additions & 2 deletions src/regex/common.nim
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import std/unicode
import std/strutils
import std/algorithm

type
RegexError* = object of ValueError
Expand All @@ -23,10 +24,10 @@ func toRune*(c: char): Rune =
result = Rune(c.ord)

func `<=`*(x, y: Rune): bool =
x.int <= y.int
x.int32 <= y.int32

func cmp*(x, y: Rune): int =
x.int - y.int
x.int32 - y.int32

func bwRuneAt*(s: string, n: int): Rune =
## Take rune ending at ``n``
Expand Down Expand Up @@ -106,3 +107,29 @@ func verifyUtf8*(s: string): int =
inc i
if state == vusStart:
result = -1

type
  SortedSeq*[T] = object
    ## Thin wrapper around a ``seq`` that is re-sorted (using ``cmp``)
    ## after every ``add``, so membership can be tested with a binary
    ## search. Duplicate elements are kept.
    s: seq[T]

func initSortedSeq*[T]: SortedSeq[T] {.inline.} =
  ## Return an empty ``SortedSeq``.
  result = SortedSeq[T](s: newSeq[T](0))

#func toSeq*[T](s: SortedSeq[T]): seq[T] =
#  result = s.s

func len*[T](s: SortedSeq[T]): int {.inline.} =
  ## Number of stored elements (duplicates included).
  result = len(s.s)

func add*[T](s: var SortedSeq[T], x: openArray[T]) =
  ## Append every element of ``x`` and restore the ordering.
  ## Adding an empty ``x`` leaves the container untouched.
  if x.len > 0:
    s.s.add x
    # The whole backing seq is sorted again so that
    # ``contains`` can rely on the ordering.
    sort(s.s, cmp)

func contains*[T](s: SortedSeq[T], x: T): bool =
  ## Return ``true`` if ``x`` is stored in ``s``.
  ## The lookup is a binary search over the sorted backing seq.
  let foundAt = binarySearch(s.s, x, cmp)
  result = foundAt >= 0

iterator items*[T](s: SortedSeq[T]): T {.inline.} =
  ## Yield the stored elements in ascending order.
  let total = s.s.len
  for idx in 0 ..< total:
    yield s.s[idx]
12 changes: 6 additions & 6 deletions src/regex/exptransformation.nim
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,12 @@ func applyFlag(n: var Node, f: Flag) =
# todo: apply recursively to
# shorthands of reInSet/reNotSet (i.e: [:ascii:])
if n.kind in {reInSet, reNotSet}:
var cps = initHashSet[Rune](2)
cps.incl(n.cps)
for cp in cps:
let cpsc = cp.swapCase()
if cp != cpsc:
n.cps.incl(cpsc)
var cps = newSeq[Rune]()
for cp in items n.cps:
let cp2 = cp.swapCase()
if cp != cp2:
cps.add cp2
n.cps.add cps
for sl in n.ranges[0 .. ^1]:
let
cpa = sl.a.swapCase()
Expand Down
1 change: 0 additions & 1 deletion src/regex/nodematch.nim
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import std/unicode except `==`
import std/sets

import pkg/unicodedb/properties
import pkg/unicodedb/types as utypes
Expand Down
64 changes: 37 additions & 27 deletions src/regex/parser.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import std/unicode
import std/strutils
import std/sets
import std/parseutils
import std/sequtils

import pkg/unicodedb/properties

Expand Down Expand Up @@ -291,69 +292,79 @@ func parseAsciiSet(sc: Scanner[Rune]): Node =
break
name.add(r.toUTF8)
prettyCheck(
sc.peek == ']'.toRune,
"Invalid ascii set. Expected [:name:]")
sc.peek == ']'.toRune, "Invalid ascii set. Expected [:name:]"
)
discard sc.next
case name
of "alpha":
result.ranges.add([
'a'.toRune .. 'z'.toRune,
'A'.toRune .. 'Z'.toRune])
'A'.toRune .. 'Z'.toRune
])
of "alnum":
result.ranges.add([
'0'.toRune .. '9'.toRune,
'a'.toRune .. 'z'.toRune,
'A'.toRune .. 'Z'.toRune])
'A'.toRune .. 'Z'.toRune
])
of "ascii":
result.ranges.add(
'\x00'.toRune .. '\x7F'.toRune)
'\x00'.toRune .. '\x7F'.toRune
)
of "blank":
result.cps.incl(toHashSet([
'\t'.toRune, ' '.toRune]))
result.cps.add(['\t'.toRune, ' '.toRune])
of "cntrl":
result.ranges.add(
'\x00'.toRune .. '\x1F'.toRune)
result.cps.incl('\x7F'.toRune)
'\x00'.toRune .. '\x1F'.toRune
)
result.cps.add(['\x7F'.toRune])
of "digit":
result.ranges.add(
'0'.toRune .. '9'.toRune)
'0'.toRune .. '9'.toRune
)
of "graph":
result.ranges.add(
'!'.toRune .. '~'.toRune)
'!'.toRune .. '~'.toRune
)
of "lower":
result.ranges.add(
'a'.toRune .. 'z'.toRune)
'a'.toRune .. 'z'.toRune
)
of "print":
result.ranges.add(
' '.toRune .. '~'.toRune)
' '.toRune .. '~'.toRune
)
of "punct":
result.ranges.add([
'!'.toRune .. '/'.toRune,
':'.toRune .. '@'.toRune,
'['.toRune .. '`'.toRune,
'{'.toRune .. '~'.toRune])
'{'.toRune .. '~'.toRune
])
of "space":
result.cps.incl(toHashSet([
result.cps.add([
'\t'.toRune, '\L'.toRune, '\v'.toRune,
'\f'.toRune, '\r'.toRune, ' '.toRune]))
'\f'.toRune, '\r'.toRune, ' '.toRune
])
of "upper":
result.ranges.add(
'A'.toRune .. 'Z'.toRune)
result.ranges.add('A'.toRune .. 'Z'.toRune)
of "word":
result.ranges.add([
'0'.toRune .. '9'.toRune,
'a'.toRune .. 'z'.toRune,
'A'.toRune .. 'Z'.toRune])
result.cps.incl('_'.toRune)
'A'.toRune .. 'Z'.toRune
])
result.cps.add(['_'.toRune])
of "xdigit":
result.ranges.add([
'0'.toRune .. '9'.toRune,
'a'.toRune .. 'f'.toRune,
'A'.toRune .. 'F'.toRune])
'A'.toRune .. 'F'.toRune
])
else:
prettyCheck(
false,
"Invalid ascii set. `$#` is not a valid name" %% name)
false, "Invalid ascii set. `$#` is not a valid name" %% name
)

func parseSet(sc: Scanner[Rune]): Node =
## parse a set atom (i.e ``[a-z]``) into a
Expand Down Expand Up @@ -430,11 +441,10 @@ func parseSet(sc: Scanner[Rune]): Node =
cps.add(cp)
else:
cps.add(cp)
# todo: use ref and set to nil when empty
result.cps.incl(cps.toHashSet)
result.cps.add toSeq(cps.toHashSet)
prettyCheck(
hasEnd,
"Invalid set. Missing `]`")
hasEnd, "Invalid set. Missing `]`"
)

func noRepeatCheck(sc: Scanner[Rune]) =
## Check next symbol is not a repetition
Expand Down
18 changes: 6 additions & 12 deletions src/regex/types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
{.used.}

import std/unicode
import std/sets
from std/algorithm import sorted
from std/sequtils import toSeq

import pkg/unicodedb/properties
Expand Down Expand Up @@ -112,7 +110,7 @@ type
# reRepRange
min*, max*: int16
# reInSet, reNotSet
cps*: HashSet[Rune]
cps*: SortedSeq[Rune]
ranges*: seq[Slice[Rune]] # todo: interval tree
shorthands*: seq[Node]
# reUCC, reNotUCC
Expand Down Expand Up @@ -148,9 +146,10 @@ template initSetNodeImpl(result: var Node, k: NodeKind) =
result = Node(
kind: k,
cp: '#'.toRune,
cps: initHashSet[Rune](2),
cps: initSortedSeq[Rune](),
ranges: @[],
shorthands: @[])
shorthands: @[]
)

func initSetNode*(): Node =
## return a set ``Node``,
Expand Down Expand Up @@ -193,7 +192,8 @@ func isEmpty*(n: Node): bool =
result = (
n.cps.len == 0 and
n.ranges.len == 0 and
n.shorthands.len == 0)
n.shorthands.len == 0
)

const
opKind* = {
Expand Down Expand Up @@ -317,13 +317,7 @@ func `$`*(n: Node): string =
str.add '['
if n.kind == reNotSet:
str.add '^'
var
cps = newSeq[Rune](n.cps.len)
i = 0
for cp in n.cps:
cps[i] = cp
inc i
for cp in cps.sorted(cmp):
str.add $cp
for sl in n.ranges:
str.add($sl.a & '-' & $sl.b)
Expand Down

0 comments on commit 3eb89a9

Please sign in to comment.