Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use faster multi-contains in rlike regex rewrite #11810

Merged
merged 8 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import scala.collection.mutable.ListBuffer
import com.nvidia.spark.rapids.GpuOverrides.regexMetaChars
import com.nvidia.spark.rapids.RegexParser.toReadableString

import org.apache.spark.unsafe.types.UTF8String

/**
* Regular expression parser based on a Pratt Parser design.
*
Expand Down Expand Up @@ -1988,7 +1990,7 @@ object RegexOptimizationType {
case class Contains(literal: String) extends RegexOptimizationType
case class PrefixRange(literal: String, length: Int, rangeStart: Int, rangeEnd: Int)
extends RegexOptimizationType
case class MultipleContains(literals: Seq[String]) extends RegexOptimizationType
case class MultipleContains(literals: Seq[UTF8String]) extends RegexOptimizationType
case object NoOptimization extends RegexOptimizationType
}

Expand Down Expand Up @@ -2057,16 +2059,17 @@ object RegexRewrite {
}
}

private def getMultipleContainsLiterals(ast: RegexAST): Seq[String] = {
private def getMultipleContainsLiterals(ast: RegexAST): Seq[UTF8String] = {
ast match {
case RegexGroup(_, term, _) => getMultipleContainsLiterals(term)
case RegexChoice(RegexSequence(parts), ls) if isLiteralString(parts) => {
getMultipleContainsLiterals(ls) match {
case Seq() => Seq.empty
case literals => RegexCharsToString(parts) +: literals
case literals => UTF8String.fromString(RegexCharsToString(parts)) +: literals
}
}
case RegexSequence(parts) if (isLiteralString(parts)) => Seq(RegexCharsToString(parts))
case RegexSequence(parts) if (isLiteralString(parts)) =>
Seq(UTF8String.fromString(RegexCharsToString(parts)))
case _ => Seq.empty
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ class GpuRLikeMeta(
}
case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType))
case Contains(s) => GpuContains(lhs, GpuLiteral(UTF8String.fromString(s), StringType))
case MultipleContains(ls) => GpuMultipleContains(lhs, ls)
case MultipleContains(ls) => GpuContainsAny(lhs, ls)
case PrefixRange(s, length, start, end) =>
GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end)
case _ => throw new IllegalStateException("Unexpected optimization type")
Expand Down Expand Up @@ -1233,7 +1233,7 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String)
override def dataType: DataType = BooleanType
}

case class GpuMultipleContains(input: Expression, searchList: Seq[String])
case class GpuContainsAny(input: Expression, targets: Seq[UTF8String])
extends GpuUnaryExpression with ImplicitCastInputTypes with NullIntolerantShim {

override def dataType: DataType = BooleanType
Expand All @@ -1243,19 +1243,13 @@ case class GpuMultipleContains(input: Expression, searchList: Seq[String])
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

override def doColumnar(input: GpuColumnVector): ColumnVector = {
assert(searchList.length > 1)
val accInit = withResource(Scalar.fromString(searchList.head)) { searchScalar =>
input.getBase.stringContains(searchScalar)
val targetsBytes = targets.map(t => t.getBytes).toArray
val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv =>
input.getBase.stringContains(targetsCv)
}
searchList.tail.foldLeft(accInit) { (acc, search) =>
val containsSearch = withResource(Scalar.fromString(search)) { searchScalar =>
input.getBase.stringContains(searchScalar)
}
withResource(acc) { _ =>
withResource(containsSearch) { _ =>
acc.or(containsSearch)
}
}
withResource(boolCvs.tail) { _ =>
// boolCvs.head and intermediate values are closed within the withResource in the lambda
boolCvs.tail.foldLeft(boolCvs.head)((l, r) => withResource(l) { _ => l.or(r)})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still has not closed item.

val a = Array(1, 2, 5)
a.tail.foldLeft(a.head)((acc, c) => {println(s"acc is $acc, current is $c"); acc + c})

outputs:

acc is 1, current is 2
acc is 3, current is 5

Your approach closed 3 times, but we need to close 4 times(including the intermediate result)
Items need to close: 1, 2, 5 and 3(intermediate result)

Do the following chang:

    withResource(boolCvs.tail) { _ =>
      // boolCvs.head and intermediate values are closed within the withResource in the lambda
      boolCvs.tail.foldLeft(boolCvs.head)((l, r) => withResource(l) { _ => l.or(r)})

==>>

      boolCvs.tail.foldLeft(boolCvs.head)((acc, c) => 
        withResource(acc) { _ => 
          withResource(c) {_ =>
            acc.or(c)
        }
      )

Copy link
Collaborator

@revans2 revans2 Dec 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just do a for loop. We are trying really hard to do something fancy when we can just make it simple.

var ret: ColumnVector = null
withResource(boolCvs) { _ =>
  boolCvs.indicies.foreach { i =>
    if (ret == null) {
      ret = boolCvs[i]
      boolCvs[i] = null
    } else {
      val tmp = ret.or(boolCvs[i])
      ret.close()
      ret = tmp
      boolCvs[i].close()
      boolCvs[i] = null
    }
  }
}

Once we have that working, then you can play around with ways to make it cleaner and/or faster.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for my previous comment
Haoyang's method is right, I ignored that withResource(boolCvs.tail) will close multiple times if boolCvs.size > 3.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just do a for loop. We are trying really hard to do something fancy when we can just make it simple.

var ret: ColumnVector = null
withResource(boolCvs) { _ =>
  boolCvs.indicies.foreach { i =>
    if (ret == null) {
      ret = boolCvs[i]
      boolCvs[i] = null
    } else {
      val tmp = ret.or(boolCvs[i])
      ret.close()
      ret = tmp
      boolCvs[i].close()
      boolCvs[i] = null
    }
  }
}

Once we have that working, then you can play around with ways to make it cleaner and/or faster.

Talked with @res-life offline, we think the following method is correct:

withResource(boolCvs.tail) { _ =>
	boolCvs.tail.foldLeft(boolCvs.head)((l, r) => withResource(l) { _ => l.or(r)})
}

because all elements in boolCvs.tail can be closed in outer withResource, boolCvs.head and intermediate results are closed in inner withResource.

But your approach looks better because it can close items in boolCvs earlier after use, so the memory usage is less. Updated to this method with some changes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok updated again, this should be cleaner and close values as early as the foreach method:

closeOnExcept(boolCvs.tail) { _ =>
  boolCvs.tail.foldLeft(boolCvs.head) {
    (l, r) => withResource(l) { _ =>
      withResource(r) { _ =>
        l.or(r)
      }
    }
  }
}

please take a look.

}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ package com.nvidia.spark.rapids

import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.unsafe.types.UTF8String

class RegularExpressionRewriteSuite extends AnyFunSuite {

private def verifyRewritePattern(patterns: Seq[String], excepted: Seq[RegexOptimizationType]):
Unit = {
private def verifyRewritePattern(patterns: Seq[String],
excepted: Seq[RegexOptimizationType]): Unit = {
val results = patterns.map { pattern =>
val ast = new RegexParser(pattern).parse()
RegexRewrite.matchSimplePattern(ast)
Expand Down Expand Up @@ -87,11 +89,11 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
"(火花|急流)"
)
val excepted = Seq(
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("abc", "def", "ghi")),
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("火花", "急流"))
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def", "ghi").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("火花", "急流").map(UTF8String.fromString))
)
verifyRewritePattern(patterns, excepted)
}
Expand Down
Loading