Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(iceberg): fix jni context class loader #17478

Merged
merged 3 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ public class JniSinkCoordinatorHandler {
private static final Logger LOG = LoggerFactory.getLogger(JniSinkCoordinatorHandler.class);

public static void runJniSinkCoordinatorThread(long requestRxPtr, long responseTxPtr) {
// For jni.rs
java.lang.Thread.currentThread()
.setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader());
JniSinkCoordinatorResponseObserver responseObserver =
new JniSinkCoordinatorResponseObserver(responseTxPtr);
SinkCoordinatorStreamObserver sinkCoordinatorStreamObserver =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ public static byte[] validate(byte[] validateSinkRequestBytes)
try {
var request =
ConnectorServiceProto.ValidateSinkRequest.parseFrom(validateSinkRequestBytes);

// For jni.rs
java.lang.Thread.currentThread()
.setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader());

ConnectorServiceProto.SinkParam sinkParam = request.getSinkParam();
TableSchema tableSchema = TableSchema.fromProto(sinkParam.getTableSchema());
String connectorName = getConnectorName(request.getSinkParam());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ public class JniSinkWriterHandler {
private static final Logger LOG = LoggerFactory.getLogger(JniSinkWriterHandler.class);

public static void runJniSinkWriterThread(long requestRxPtr, long responseTxPtr) {
// For jni.rs
java.lang.Thread.currentThread()
.setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader());
JniSinkWriterResponseObserver responseObserver =
new JniSinkWriterResponseObserver(responseTxPtr);
SinkWriterStreamObserver sinkWriterStreamObserver =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ public static byte[] validate(byte[] validateSourceRequestBytes)
var request =
ConnectorServiceProto.ValidateSourceRequest.parseFrom(
validateSourceRequestBytes);

// For jni.rs
java.lang.Thread.currentThread()
.setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader());
validateSource(request);
// validate pass
return ConnectorServiceProto.ValidateSourceResponse.newBuilder().build().toByteArray();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** handler for starting a debezium source connectors for jni */

/** handler for starting a debezium source connectors for jni */
public class JniDbzSourceHandler {
static final Logger LOG = LoggerFactory.getLogger(JniDbzSourceHandler.class);
Expand All @@ -56,10 +54,6 @@ public static void runJniDbzSourceThread(byte[] getEventStreamRequestBytes, long
throws Exception {
var request =
ConnectorServiceProto.GetEventStreamRequest.parseFrom(getEventStreamRequestBytes);

// For jni.rs
java.lang.Thread.currentThread()
.setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader());
// userProps extracted from request, underlying implementation is UnmodifiableMap
Map<String, String> mutableUserProps = new HashMap<>(request.getPropertiesMap());
mutableUserProps.put("source.id", Long.toString(request.getSourceId()));
Expand Down
98 changes: 48 additions & 50 deletions src/connector/src/sink/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use risingwave_common::bail;
use risingwave_common::catalog::{ColumnDesc, ColumnId};
use risingwave_common::session_config::sink_decouple::SinkDecouple;
use risingwave_common::types::DataType;
use risingwave_jni_core::jvm_runtime::JVM;
use risingwave_jni_core::jvm_runtime::{execute_with_jni_env, JVM};
use risingwave_jni_core::{
call_static_method, gen_class_name, JniReceiverType, JniSenderType, JniSinkWriterStreamRequest,
};
Expand Down Expand Up @@ -221,28 +221,29 @@ async fn validate_remote_sink(param: &SinkParam, sink_name: &str) -> ConnectorRe
let sink_param = param.to_proto();

spawn_blocking(move || -> anyhow::Result<()> {
let mut env = jvm.attach_current_thread()?;
let validate_sink_request = ValidateSinkRequest {
sink_param: Some(sink_param),
};
let validate_sink_request_bytes =
env.byte_array_from_slice(&Message::encode_to_vec(&validate_sink_request))?;

let validate_sink_response_bytes = call_static_method!(
env,
{com.risingwave.connector.JniSinkValidationHandler},
{byte[] validate(byte[] validateSourceRequestBytes)},
&validate_sink_request_bytes
)?;

let validate_sink_response: ValidateSinkResponse = Message::decode(
risingwave_jni_core::to_guarded_slice(&validate_sink_response_bytes, &mut env)?.deref(),
)?;

validate_sink_response.error.map_or_else(
|| Ok(()), // If there is no error message, return Ok here.
|err| bail!("sink cannot pass validation: {}", err.error_message),
)
execute_with_jni_env(jvm, |env| {
let validate_sink_request = ValidateSinkRequest {
sink_param: Some(sink_param),
};
let validate_sink_request_bytes =
env.byte_array_from_slice(&Message::encode_to_vec(&validate_sink_request))?;

let validate_sink_response_bytes = call_static_method!(
env,
{com.risingwave.connector.JniSinkValidationHandler},
{byte[] validate(byte[] validateSourceRequestBytes)},
&validate_sink_request_bytes
)?;

let validate_sink_response: ValidateSinkResponse = Message::decode(
risingwave_jni_core::to_guarded_slice(&validate_sink_response_bytes, env)?.deref(),
)?;

validate_sink_response.error.map_or_else(
|| Ok(()), // If there is no error message, return Ok here.
|err| bail!("sink cannot pass validation: {}", err.error_message),
)
})
})
.await
.context("JoinHandle returns error")??;
Expand Down Expand Up @@ -770,34 +771,31 @@ impl EmbeddedConnectorClient {

let jvm = self.jvm;
std::thread::spawn(move || {
let mut env = match jvm
.attach_current_thread()
.context("failed to attach current thread")
{
Ok(env) => env,
Err(e) => {
let _ = response_tx.blocking_send(Err(e));
return;
}
};
let result = execute_with_jni_env(jvm, |env| {
let result = call_static_method!(
env,
class_name,
method_name,
{{void}, {long requestRx, long responseTx}},
&mut request_rx as *mut JniReceiverType<REQ>,
&mut response_tx as *mut JniSenderType<RSP>
);

let result = call_static_method!(
env,
class_name,
method_name,
{{void}, {long requestRx, long responseTx}},
&mut request_rx as *mut JniReceiverType<REQ>,
&mut response_tx as *mut JniSenderType<RSP>
);

match result {
Ok(_) => {
tracing::info!("end of jni call {}::{}", class_name, method_name);
}
Err(e) => {
tracing::error!(error = %e.as_report(), "jni call error");
}
};
match result {
Ok(_) => {
tracing::info!("end of jni call {}::{}", class_name, method_name);
}
Err(e) => {
tracing::error!(error = %e.as_report(), "jni call error");
}
};

Ok(())
});

if let Err(e) = result {
let _ = response_tx.blocking_send(Err(e));
}
Comment on lines +796 to +798
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wenym1 Do I correctly handle the error in this case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

});
response_rx
}
Expand Down
69 changes: 37 additions & 32 deletions src/connector/src/source/cdc/enumerator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use itertools::Itertools;
use prost::Message;
use risingwave_common::util::addr::HostAddr;
use risingwave_jni_core::call_static_method;
use risingwave_jni_core::jvm_runtime::JVM;
use risingwave_jni_core::jvm_runtime::{execute_with_jni_env, JVM};
use risingwave_pb::connector_service::{SourceType, ValidateSourceRequest, ValidateSourceResponse};

use crate::error::ConnectorResult;
Expand Down Expand Up @@ -70,39 +70,44 @@ where
SourceType::from(T::source_type())
);

let jvm = JVM.get_or_init()?;
let source_id = context.info.source_id;
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
let mut env = JVM.get_or_init()?.attach_current_thread()?;

let validate_source_request = ValidateSourceRequest {
source_id: source_id as u64,
source_type: props.get_source_type_pb() as _,
properties: props.properties,
table_schema: Some(table_schema_exclude_additional_columns(&props.table_schema)),
is_source_job: props.is_cdc_source_job,
is_backfill_table: props.is_backfill_table,
};

let validate_source_request_bytes =
env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;

let validate_source_response_bytes = call_static_method!(
env,
{com.risingwave.connector.source.JniSourceValidateHandler},
{byte[] validate(byte[] validateSourceRequestBytes)},
&validate_source_request_bytes
)?;

let validate_source_response: ValidateSourceResponse = Message::decode(
risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, &mut env)?
.deref(),
)?;

if let Some(error) = validate_source_response.error {
return Err(anyhow!(error.error_message).context("source cannot pass validation"));
}

Ok(())
execute_with_jni_env(jvm, |env| {
let validate_source_request = ValidateSourceRequest {
source_id: source_id as u64,
source_type: props.get_source_type_pb() as _,
properties: props.properties,
table_schema: Some(table_schema_exclude_additional_columns(
&props.table_schema,
)),
is_source_job: props.is_cdc_source_job,
is_backfill_table: props.is_backfill_table,
};

let validate_source_request_bytes =
env.byte_array_from_slice(&Message::encode_to_vec(&validate_source_request))?;

let validate_source_response_bytes = call_static_method!(
env,
{com.risingwave.connector.source.JniSourceValidateHandler},
{byte[] validate(byte[] validateSourceRequestBytes)},
&validate_source_request_bytes
)?;

let validate_source_response: ValidateSourceResponse = Message::decode(
risingwave_jni_core::to_guarded_slice(&validate_source_response_bytes, env)?
.deref(),
)?;

if let Some(error) = validate_source_response.error {
return Err(
anyhow!(error.error_message).context("source cannot pass validation")
);
}

Ok(())
})
})
.await
.context("failed to validate source")??;
Expand Down
67 changes: 36 additions & 31 deletions src/connector/src/source/cdc/source/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use prost::Message;
use risingwave_common::bail;
use risingwave_common::metrics::GLOBAL_ERROR_METRICS;
use risingwave_common::util::addr::HostAddr;
use risingwave_jni_core::jvm_runtime::JVM;
use risingwave_jni_core::jvm_runtime::{execute_with_jni_env, JVM};
use risingwave_jni_core::{call_static_method, JniReceiverType, JniSenderType};
use risingwave_pb::connector_service::{GetEventStreamRequest, GetEventStreamResponse};
use thiserror_ext::AsReport;
Expand Down Expand Up @@ -111,38 +111,43 @@ impl<T: CdcSourceTypeTrait> SplitReader for CdcSplitReader<T> {
};

std::thread::spawn(move || {
let result: anyhow::Result<_> = try {
let env = jvm.attach_current_thread()?;
let get_event_stream_request_bytes =
env.byte_array_from_slice(&Message::encode_to_vec(&get_event_stream_request))?;
(env, get_event_stream_request_bytes)
};

let (mut env, get_event_stream_request_bytes) = match result {
Ok(inner) => inner,
Err(e) => {
let _ = tx
.blocking_send(Err(e.context("err before calling runJniDbzSourceThread")));
return;
execute_with_jni_env(jvm, |env| {
let result: anyhow::Result<_> = try {
let get_event_stream_request_bytes = env.byte_array_from_slice(
&Message::encode_to_vec(&get_event_stream_request),
)?;
(env, get_event_stream_request_bytes)
};

let (env, get_event_stream_request_bytes) = match result {
Ok(inner) => inner,
Err(e) => {
let _ = tx.blocking_send(Err(
e.context("err before calling runJniDbzSourceThread")
));
return Ok(());
}
};

let result = call_static_method!(
env,
{com.risingwave.connector.source.core.JniDbzSourceHandler},
{void runJniDbzSourceThread(byte[] getEventStreamRequestBytes, long channelPtr)},
&get_event_stream_request_bytes,
&mut tx as *mut JniSenderType<GetEventStreamResponse>
);

match result {
Ok(_) => {
tracing::info!(?source_id, "end of jni call runJniDbzSourceThread");
}
Err(e) => {
tracing::error!(?source_id, error = %e.as_report(), "jni call error");
}
}
};

let result = call_static_method!(
env,
{com.risingwave.connector.source.core.JniDbzSourceHandler},
{void runJniDbzSourceThread(byte[] getEventStreamRequestBytes, long channelPtr)},
&get_event_stream_request_bytes,
&mut tx as *mut JniSenderType<GetEventStreamResponse>
);

match result {
Ok(_) => {
tracing::info!(?source_id, "end of jni call runJniDbzSourceThread");
}
Err(e) => {
tracing::error!(?source_id, error = %e.as_report(), "jni call error");
}
}
Ok(())
})
});

// wait for the handshake message
Expand Down
Loading
Loading