Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
cherry-pick #15163
Browse files Browse the repository at this point in the history
kwannoel committed Feb 27, 2024
1 parent 1f95a0d commit 31be2c6
Showing 17 changed files with 260 additions and 24 deletions.
1 change: 1 addition & 0 deletions ci/scripts/run-e2e-test.sh
Original file line number Diff line number Diff line change
@@ -90,6 +90,7 @@ pkill python3

sqllogictest -p 4566 -d dev './e2e_test/udf/alter_function.slt'
sqllogictest -p 4566 -d dev './e2e_test/udf/graceful_shutdown_python.slt'
sqllogictest -p 4566 -d dev './e2e_test/udf/always_retry_python.slt'
# FIXME: flaky test
# sqllogictest -p 4566 -d dev './e2e_test/udf/retry_python.slt'

75 changes: 75 additions & 0 deletions e2e_test/udf/always_retry_python.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
system ok
python3 e2e_test/udf/test.py &

# wait for server to start
sleep 10s

statement ok
CREATE FUNCTION sleep_always_retry(INT) RETURNS INT AS 'sleep' USING LINK 'http://localhost:8815' WITH ( always_retry_on_network_error = true );

statement ok
CREATE FUNCTION sleep_no_retry(INT) RETURNS INT AS 'sleep' USING LINK 'http://localhost:8815';

# Create a table with 30 records
statement ok
CREATE TABLE t (v1 int);

statement ok
INSERT INTO t select 0 from generate_series(1, 30);

statement ok
flush;

statement ok
SET STREAMING_RATE_LIMIT=1;

statement ok
SET BACKGROUND_DDL=true;

statement ok
CREATE MATERIALIZED VIEW mv_no_retry AS SELECT sleep_no_retry(v1) as s1 from t;

# Create a Materialized View
statement ok
CREATE MATERIALIZED VIEW mv_always_retry AS SELECT sleep_always_retry(v1) as s1 from t;

# Immediately kill the server, sleep for 1minute.
system ok
pkill -9 -i python && sleep 60

# Restart the server
system ok
python3 e2e_test/udf/test.py &

# Wait for materialized view to be complete
statement ok
wait;

query I
SELECT count(*) FROM mv_always_retry where s1 is NULL;
----
0

query B
SELECT count(*) > 0 FROM mv_no_retry where s1 is NULL;
----
t

statement ok
SET STREAMING_RATE_LIMIT=0;

statement ok
SET BACKGROUND_DDL=false;

# close the server
system ok
pkill -i python

statement ok
DROP FUNCTION sleep_always_retry;

statement ok
DROP FUNCTION sleep_no_retry;

statement ok
DROP TABLE t CASCADE;
1 change: 1 addition & 0 deletions proto/catalog.proto
Original file line number Diff line number Diff line change
@@ -220,6 +220,7 @@ message Function {
optional string link = 8;
optional string identifier = 10;
optional string body = 14;
bool always_retry_on_network_error = 16;

oneof kind {
ScalarFunction scalar = 11;
1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
@@ -492,6 +492,7 @@ message UserDefinedFunction {
optional string identifier = 6;
// For JavaScript UDF, it's the body of the function.
optional string body = 7;
bool always_retry_on_network_error = 9;
}

// Additional information for user defined table functions.
53 changes: 32 additions & 21 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
@@ -54,6 +54,8 @@ pub struct UserDefinedFunction {
/// On each successful call, the count will be decreased by 1.
/// See <https://github.com/risingwavelabs/risingwave/issues/13791>.
disable_retry_count: AtomicU8,
/// Always retry. Overrides `disable_retry_count`.
always_retry_on_network_error: bool,
}

const INITIAL_RETRY_COUNT: u8 = 16;
@@ -128,32 +130,40 @@ impl UserDefinedFunction {
UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &input)?,
UdfImpl::External(client) => {
let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
let result = if disable_retry_count != 0 {
let result = if self.always_retry_on_network_error {
client
.call(&self.identifier, input)
.call_with_always_retry_on_network_error(&self.identifier, input)
.instrument_await(self.span.clone())
.await
} else {
client
.call_with_retry(&self.identifier, input)
.instrument_await(self.span.clone())
.await
let result = if disable_retry_count != 0 {
client
.call(&self.identifier, input)
.instrument_await(self.span.clone())
.await
} else {
client
.call_with_retry(&self.identifier, input)
.instrument_await(self.span.clone())
.await
};
let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
let connection_error = matches!(&result, Err(e) if e.is_connection_error());
if connection_error && disable_retry_count != INITIAL_RETRY_COUNT {
// reset count on connection error
self.disable_retry_count
.store(INITIAL_RETRY_COUNT, Ordering::Relaxed);
} else if !connection_error && disable_retry_count != 0 {
// decrease count on success, ignore if exchange failed
_ = self.disable_retry_count.compare_exchange(
disable_retry_count,
disable_retry_count - 1,
Ordering::Relaxed,
Ordering::Relaxed,
);
}
result
};
let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
let connection_error = matches!(&result, Err(e) if e.is_connection_error());
if connection_error && disable_retry_count != INITIAL_RETRY_COUNT {
// reset count on connection error
self.disable_retry_count
.store(INITIAL_RETRY_COUNT, Ordering::Relaxed);
} else if !connection_error && disable_retry_count != 0 {
// decrease count on success, ignore if exchange failed
_ = self.disable_retry_count.compare_exchange(
disable_retry_count,
disable_retry_count - 1,
Ordering::Relaxed,
Ordering::Relaxed,
);
}
result?
}
};
@@ -248,6 +258,7 @@ impl Build for UserDefinedFunction {
identifier: identifier.clone(),
span: format!("udf_call({})", identifier).into(),
disable_retry_count: AtomicU8::new(0),
always_retry_on_network_error: udf.always_retry_on_network_error,
})
}
}
19 changes: 19 additions & 0 deletions src/expr/udf/src/external.rs
Original file line number Diff line number Diff line change
@@ -208,6 +208,25 @@ impl ArrowFlightUdfClient {
unreachable!()
}

/// Always retry on connection error
pub async fn call_with_always_retry_on_network_error(
&self,
id: &str,
input: RecordBatch,
) -> Result<RecordBatch> {
let mut backoff = Duration::from_millis(100);
loop {
match self.call(id, input.clone()).await {
Err(err) if err.is_connection_error() => {
tracing::error!(error = %err.as_report(), "UDF connection error. retry...");
}
ret => return ret,
}
tokio::time::sleep(backoff).await;
backoff *= 2;
}
}

/// Call a function with streaming input and output.
#[panic_return = "Result<stream::Empty<_>>"]
pub async fn call_stream(
2 changes: 2 additions & 0 deletions src/frontend/src/catalog/function_catalog.rs
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@ pub struct FunctionCatalog {
pub identifier: Option<String>,
pub body: Option<String>,
pub link: Option<String>,
pub always_retry_on_network_error: bool,
}

#[derive(Clone, Display, PartialEq, Eq, Hash, Debug)]
@@ -68,6 +69,7 @@ impl From<&PbFunction> for FunctionCatalog {
identifier: prost.identifier.clone(),
body: prost.body.clone(),
link: prost.link.clone(),
always_retry_on_network_error: prost.always_retry_on_network_error,
}
}
}
2 changes: 2 additions & 0 deletions src/frontend/src/expr/user_defined_function.rs
Original file line number Diff line number Diff line change
@@ -58,6 +58,7 @@ impl UserDefinedFunction {
identifier: udf.identifier.clone(),
body: udf.body.clone(),
link: udf.link.clone(),
always_retry_on_network_error: udf.always_retry_on_network_error,
};

Ok(Self {
@@ -92,6 +93,7 @@ impl Expr for UserDefinedFunction {
identifier: self.catalog.identifier.clone(),
link: self.catalog.link.clone(),
body: self.catalog.body.clone(),
always_retry_on_network_error: self.catalog.always_retry_on_network_error,
})),
}
}
4 changes: 4 additions & 0 deletions src/frontend/src/handler/create_function.rs
Original file line number Diff line number Diff line change
@@ -42,6 +42,7 @@ pub async fn handle_create_function(
args: Option<Vec<OperateFunctionArg>>,
returns: Option<CreateFunctionReturns>,
params: CreateFunctionBody,
with_options: CreateFunctionWithOptions,
) -> Result<RwPgResponse> {
if or_replace {
bail_not_implemented!("CREATE OR REPLACE FUNCTION");
@@ -247,6 +248,9 @@ pub async fn handle_create_function(
link,
body,
owner: session.user_id(),
always_retry_on_network_error: with_options
.always_retry_on_network_error
.unwrap_or_default(),
};

let catalog_writer = session.catalog_writer()?;
1 change: 1 addition & 0 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
@@ -236,6 +236,7 @@ pub async fn handle_create_sql_function(
body: Some(body),
link: None,
owner: session.user_id(),
always_retry_on_network_error: false,
};

let catalog_writer = session.catalog_writer()?;
2 changes: 2 additions & 0 deletions src/frontend/src/handler/mod.rs
Original file line number Diff line number Diff line change
@@ -206,6 +206,7 @@ pub async fn handle(
args,
returns,
params,
with_options,
} => {
// For general udf, `language` clause could be ignored
// refer: https://github.com/risingwavelabs/risingwave/pull/10608
@@ -226,6 +227,7 @@ pub async fn handle(
args,
returns,
params,
with_options,
)
.await
} else {
6 changes: 6 additions & 0 deletions src/meta/model_v2/migration/src/m20230908_072257_init.rs
Original file line number Diff line number Diff line change
@@ -715,6 +715,11 @@ impl MigrationTrait for Migration {
.col(ColumnDef::new(Function::Identifier).string())
.col(ColumnDef::new(Function::Body).string())
.col(ColumnDef::new(Function::Kind).string().not_null())
.col(
ColumnDef::new(Function::AlwaysRetryOnNetworkError)
.boolean()
.not_null(),
)
.foreign_key(
&mut ForeignKey::create()
.name("FK_function_object_id")
@@ -1113,6 +1118,7 @@ enum Function {
Identifier,
Body,
Kind,
AlwaysRetryOnNetworkError,
}

#[derive(DeriveIden)]
2 changes: 2 additions & 0 deletions src/meta/model_v2/src/function.rs
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ pub struct Model {
pub identifier: Option<String>,
pub body: Option<String>,
pub kind: FunctionKind,
pub always_retry_on_network_error: bool,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
@@ -100,6 +101,7 @@ impl From<PbFunction> for ActiveModel {
identifier: Set(function.identifier),
body: Set(function.body),
kind: Set(function.kind.unwrap().into()),
always_retry_on_network_error: Set(function.always_retry_on_network_error),
}
}
}
1 change: 1 addition & 0 deletions src/meta/src/controller/mod.rs
Original file line number Diff line number Diff line change
@@ -287,6 +287,7 @@ impl From<ObjectModel<function::Model>> for PbFunction {
identifier: value.0.identifier,
body: value.0.body,
kind: Some(value.0.kind.into()),
always_retry_on_network_error: value.0.always_retry_on_network_error,
}
}
}
101 changes: 101 additions & 0 deletions src/sqlparser/src/ast/mod.rs
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@ use alloc::{
vec::Vec,
};
use core::fmt;
use core::fmt::Display;

use itertools::Itertools;
#[cfg(feature = "serde")]
@@ -1168,6 +1169,7 @@ pub enum Statement {
returns: Option<CreateFunctionReturns>,
/// Optional parameters.
params: CreateFunctionBody,
with_options: CreateFunctionWithOptions,
},
/// CREATE AGGREGATE
///
@@ -1536,6 +1538,7 @@ impl fmt::Display for Statement {
args,
returns,
params,
with_options,
} => {
write!(
f,
@@ -1550,6 +1553,7 @@ impl fmt::Display for Statement {
write!(f, " {}", return_type)?;
}
write!(f, "{params}")?;
write!(f, "{with_options}")?;
Ok(())
}
Statement::CreateAggregate {
@@ -2718,6 +2722,57 @@ impl fmt::Display for CreateFunctionBody {
Ok(())
}
}
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CreateFunctionWithOptions {
/// Always retry on network errors.
pub always_retry_on_network_error: Option<bool>,
}

/// TODO(kwannoel): Generate from the struct definition instead.
impl CreateFunctionWithOptions {
fn is_empty(&self) -> bool {
self.always_retry_on_network_error.is_none()
}
}

/// TODO(kwannoel): Generate from the struct definition instead.
impl TryFrom<Vec<SqlOption>> for CreateFunctionWithOptions {
type Error = ParserError;

fn try_from(with_options: Vec<SqlOption>) -> Result<Self, Self::Error> {
let mut always_retry_on_network_error = None;
for option in with_options {
if option.name.to_string().to_lowercase() == "always_retry_on_network_error" {
always_retry_on_network_error = Some(option.value == Value::Boolean(true));
} else {
return Err(ParserError::ParserError(format!(
"Unsupported option: {}",
option.name
)));
}
}
Ok(Self {
always_retry_on_network_error,
})
}
}

impl Display for CreateFunctionWithOptions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.is_empty() {
return Ok(());
}
let mut options = vec![];
if let Some(always_retry_on_network_error) = self.always_retry_on_network_error {
options.push(format!(
"ALWAYS_RETRY_NETWORK_ERRORS = {}",
always_retry_on_network_error
));
}
write!(f, " WITH ( {} )", display_comma_separated(&options))
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -2918,4 +2973,50 @@ mod tests {
};
assert_eq!("NOT true IS NOT FALSE", format!("{}", unary_op));
}

#[test]
fn test_create_function_display() {
let create_function = Statement::CreateFunction {
temporary: false,
or_replace: false,
name: ObjectName(vec![Ident::new_unchecked("foo")]),
args: Some(vec![OperateFunctionArg::unnamed(DataType::Int)]),
returns: Some(CreateFunctionReturns::Value(DataType::Int)),
params: CreateFunctionBody {
language: Some(Ident::new_unchecked("python")),
behavior: Some(FunctionBehavior::Immutable),
as_: Some(FunctionDefinition::SingleQuotedDef("SELECT 1".to_string())),
return_: None,
using: None,
},
with_options: CreateFunctionWithOptions {
always_retry_on_network_error: None,
},
};
assert_eq!(
"CREATE FUNCTION foo(INT) RETURNS INT LANGUAGE python IMMUTABLE AS 'SELECT 1'",
format!("{}", create_function)
);
let create_function = Statement::CreateFunction {
temporary: false,
or_replace: false,
name: ObjectName(vec![Ident::new_unchecked("foo")]),
args: Some(vec![OperateFunctionArg::unnamed(DataType::Int)]),
returns: Some(CreateFunctionReturns::Value(DataType::Int)),
params: CreateFunctionBody {
language: Some(Ident::new_unchecked("python")),
behavior: Some(FunctionBehavior::Immutable),
as_: Some(FunctionDefinition::SingleQuotedDef("SELECT 1".to_string())),
return_: None,
using: None,
},
with_options: CreateFunctionWithOptions {
always_retry_on_network_error: Some(true),
},
};
assert_eq!(
"CREATE FUNCTION foo(INT) RETURNS INT LANGUAGE python IMMUTABLE AS 'SELECT 1' WITH ( ALWAYS_RETRY_NETWORK_ERRORS = true )",
format!("{}", create_function)
);
}
}
4 changes: 3 additions & 1 deletion src/sqlparser/src/parser.rs
Original file line number Diff line number Diff line change
@@ -2238,14 +2238,16 @@ impl Parser {
};

let params = self.parse_create_function_body()?;

let with_options = self.parse_options_with_preceding_keyword(Keyword::WITH)?;
let with_options = with_options.try_into()?;
Ok(Statement::CreateFunction {
or_replace,
temporary,
name,
args,
returns: return_type,
params,
with_options,
})
}

9 changes: 7 additions & 2 deletions src/sqlparser/tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
@@ -765,6 +765,7 @@ fn parse_create_function() {
)),
..Default::default()
},
with_options: Default::default(),
}
);

@@ -786,7 +787,8 @@ fn parse_create_function() {
"select $1 - $2;".into()
)),
..Default::default()
}
},
with_options: Default::default(),
},
);

@@ -811,7 +813,8 @@ fn parse_create_function() {
right: Box::new(Expr::Parameter { index: 2 }),
}),
..Default::default()
}
},
with_options: Default::default(),
},
);

@@ -842,6 +845,7 @@ fn parse_create_function() {
}),
..Default::default()
},
with_options: Default::default(),
}
);

@@ -865,6 +869,7 @@ fn parse_create_function() {
return_: Some(Expr::Identifier("a".into())),
..Default::default()
},
with_options: Default::default(),
}
);
}

0 comments on commit 31be2c6

Please sign in to comment.