Skip to content

Commit

Permalink
Adds unit tests for the WarmpoolJob query loop (interactive and streaming dispatch)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shri Saran Raj N committed Jan 13, 2025
1 parent adef5b6 commit e195862
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ case class WarmpoolJob(
queryWaitTimeoutMillis, // Used only for interactive queries
queryLoopExecutionFrequency)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
val statementExecutionManager = instantiateStatementExecutionManager(commandContext)

try {
FlintREPL.exponentialBackoffRetry(maxRetries = 5, initialDelay = 2.seconds) {
queryLoop(commandContext)
queryLoop(commandContext, statementExecutionManager)
}
} finally {
sparkSession.stop()
Expand All @@ -93,10 +95,10 @@ case class WarmpoolJob(
}
}

def queryLoop(commandContext: CommandContext): Unit = {
def queryLoop(
commandContext: CommandContext,
statementExecutionManager: StatementExecutionManager): Unit = {
import commandContext._

val statementExecutionManager = instantiateStatementExecutionManager(commandContext)
var canProceed = true

try {
Expand Down Expand Up @@ -158,7 +160,7 @@ case class WarmpoolJob(
Thread.sleep(commandContext.queryLoopExecutionFrequency)
}

private def processStreamingJob(
def processStreamingJob(
applicationId: String,
jobId: String,
query: String,
Expand Down Expand Up @@ -196,7 +198,7 @@ case class WarmpoolJob(
jobOperator.start()
}

private def processInteractiveJob(
def processInteractiveJob(
sparkSession: SparkSession,
commandContext: CommandContext,
flintStatement: FlintStatement,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import java.time.Instant

import scala.concurrent.duration.{Duration, MINUTES}
import scala.util.control.NonFatal

import org.mockito.ArgumentMatchersSugar
import org.mockito.Mockito.{spy, times, verify, when}
import org.opensearch.flint.common.model.FlintStatement
import org.opensearch.flint.common.scheduler.model.LangType
import org.scalatestplus.mockito.MockitoSugar

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.sql.FlintREPLConfConstants.DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY
import org.apache.spark.sql.flint.config.FlintSparkConf

/**
 * Unit tests for the WarmpoolJob query loop: verifies that a statement pulled
 * from the StatementExecutionManager is dispatched to processInteractiveJob or
 * processStreamingJob according to the configured Flint job type.
 *
 * NOTE(review): the job instance is wrapped in a Mockito spy — verify() on a
 * plain (non-mock) object throws NotAMockException, which the previous
 * blanket catch silently swallowed, making both tests pass vacuously.
 */
class WarmpoolJobTest
    extends SparkFunSuite
    with MockitoSugar
    with ArgumentMatchersSugar
    with JobMatchers {

  private val jobId = "testJobId"
  private val applicationId = "testApplicationId"
  private val INTERACTIVE_JOB_TYPE = "interactive"
  private val STREAMING_JOB_TYPE = "streaming"
  private val resultIndex = "resultIndex"
  private val dataSource = "testDataSource"

  /**
   * Builds a mocked SparkSession whose runtime conf answers the three keys
   * WarmpoolJob reads (request index, data-source name, job type), returning
   * the given job type so the test can steer the dispatch decision.
   */
  private def mockSparkSessionWithJobType(jobType: String): SparkSession = {
    val sparkSession = mock[SparkSession]
    val runtimeConf = mock[RuntimeConfig]
    when(sparkSession.conf).thenReturn(runtimeConf)
    when(runtimeConf.get(FlintSparkConf.REQUEST_INDEX.key, ""))
      .thenReturn("someSessionIndex")
    when(runtimeConf.get(FlintSparkConf.DATA_SOURCE_NAME.key, ""))
      .thenReturn("datasourceName")
    when(runtimeConf.get(FlintSparkConf.JOB_TYPE.key, FlintJobType.BATCH))
      .thenReturn(jobType)
    sparkSession
  }

  /** Shared CommandContext fixture; values match the original inline setup. */
  private def buildCommandContext(sparkSession: SparkSession): CommandContext =
    CommandContext(
      applicationId,
      jobId,
      sparkSession,
      dataSource,
      "",
      "",
      new SessionManagerImpl(sparkSession, Some(resultIndex)),
      Duration(10, MINUTES),
      60,
      60,
      DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY)

  /** A minimal "running" SQL statement for the manager mock to hand out. */
  private def runningStatement(): FlintStatement =
    new FlintStatement(
      "running",
      "select 1",
      "30",
      "10",
      LangType.SQL,
      Instant.now().toEpochMilli(),
      None)

  /**
   * Drives one pass of queryLoop on a spied job and returns the spy so the
   * caller can verify which dispatch method was hit. Collaborators this test
   * does not stub may still throw inside queryLoop; only non-fatal errors are
   * swallowed, and verification happens in the caller AFTER the loop so a
   * failed verify is never hidden by a loop failure.
   */
  private def runQueryLoop(jobType: String): WarmpoolJob = {
    val sparkSession = mockSparkSessionWithJobType(jobType)
    val commandContext = buildCommandContext(sparkSession)
    val warmpoolJob = spy(WarmpoolJob(mock[SparkConf], sparkSession, Some(resultIndex)))
    val statementExecutionManager = mock[StatementExecutionManager]
    when(statementExecutionManager.getNextStatement()).thenReturn(Some(runningStatement()))

    try warmpoolJob.queryLoop(commandContext, statementExecutionManager)
    catch { case NonFatal(_) => () } // best-effort: only the dispatch call is asserted

    warmpoolJob
  }

  test("queryLoop calls processInteractiveQuery when interactive query is received") {
    val warmpoolJob = runQueryLoop(INTERACTIVE_JOB_TYPE)
    verify(warmpoolJob, times(1)).processInteractiveJob(*, *, *, *, *)
  }

  test("queryLoop calls processStreamingQuery when streaming query is received") {
    val warmpoolJob = runQueryLoop(STREAMING_JOB_TYPE)
    verify(warmpoolJob, times(1)).processStreamingJob(*, *, *, *, *, *, *, *, *)
  }
}

0 comments on commit e195862

Please sign in to comment.