diff --git a/src/import/ingest.rs b/src/import/ingest.rs index b46ff25fb4b..7ce8d6ca0d1 100644 --- a/src/import/ingest.rs +++ b/src/import/ingest.rs @@ -162,6 +162,11 @@ pub(super) fn async_snapshot( ..Default::default() }); async move { + fail::fail_point!("failed_to_async_snapshot", |_| { + let mut e = errorpb::Error::default(); + e.set_message("faild to get snapshot".to_string()); + Err(e) + }); res.await.map_err(|e| { let err: storage::Error = e.into(); if let Some(e) = extract_region_error_from_error(&err) { diff --git a/src/import/sst_service.rs b/src/import/sst_service.rs index f5c4bf809f7..2973dda3caf 100644 --- a/src/import/sst_service.rs +++ b/src/import/sst_service.rs @@ -11,7 +11,8 @@ use engine_traits::{CompactExt, CF_DEFAULT, CF_WRITE}; use file_system::{set_io_type, IoType}; use futures::{sink::SinkExt, stream::TryStreamExt, FutureExt, TryFutureExt}; use grpcio::{ - ClientStreamingSink, RequestStream, RpcContext, ServerStreamingSink, UnarySink, WriteFlags, + ClientStreamingSink, RequestStream, RpcContext, RpcStatus, RpcStatusCode, ServerStreamingSink, + UnarySink, WriteFlags, }; use kvproto::{ encryptionpb::EncryptionMethod, @@ -1136,15 +1137,18 @@ impl ImportSst for ImportSstService { IMPORT_RPC_DURATION .with_label_values(&[label, "ok"]) .observe(timer.saturating_elapsed_secs()); + let _ = sink.close().await; } Err(e) => { warn!( "connection send message fail"; "err" => %e ); + let status = + RpcStatus::with_message(RpcStatusCode::UNKNOWN, format!("{:?}", e)); + let _ = sink.fail(status).await; } } - let _ = sink.close().await; return; } }; @@ -1160,7 +1164,10 @@ impl ImportSst for ImportSstService { "connection send message fail"; "err" => %e ); - break; + let status = + RpcStatus::with_message(RpcStatusCode::UNKNOWN, format!("{:?}", e)); + let _ = sink.fail(status).await; + return; } } let _ = sink.close().await; diff --git a/tests/failpoints/cases/test_import_service.rs b/tests/failpoints/cases/test_import_service.rs index 57504d2c722..bd139f3859d 100644 --- a/tests/failpoints/cases/test_import_service.rs +++ b/tests/failpoints/cases/test_import_service.rs @@ -6,7 +6,7 @@ use std::{ }; use file_system::calc_crc32; -use futures::executor::block_on; +use futures::{executor::block_on, stream::StreamExt}; use grpcio::{ChannelBuilder, Environment}; use kvproto::{disk_usage::DiskUsage, import_sstpb::*, tikvpb_grpc::TikvClient}; use tempfile::{Builder, TempDir}; @@ -499,3 +499,91 @@ fn test_flushed_applied_index_after_ingset() { fail::remove("on_apply_ingest"); fail::remove("on_flush_completed"); } + +#[test] +fn test_duplicate_detect_with_client_stop() { + let (_cluster, ctx, _, import) = new_cluster_and_tikv_import_client(); + let mut req = SwitchModeRequest::default(); + req.set_mode(SwitchMode::Import); + import.switch_mode(&req).unwrap(); + + let data_count: u64 = 4096; + for commit_ts in 0..4 { + let mut meta = new_sst_meta(0, 0); + meta.set_region_id(ctx.get_region_id()); + meta.set_region_epoch(ctx.get_region_epoch().clone()); + + let mut keys = vec![]; + let mut values = vec![]; + for i in 1000..data_count { + let key = i.to_string(); + keys.push(key.as_bytes().to_vec()); + values.push(key.as_bytes().to_vec()); + } + let resp = send_write_sst(&import, &meta, keys, values, commit_ts).unwrap(); + for m in resp.metas.into_iter() { + must_ingest_sst(&import, ctx.clone(), m.clone()); + } + } + + let mut duplicate = DuplicateDetectRequest::default(); + duplicate.set_context(ctx); + duplicate.set_start_key((0_u64).to_string().as_bytes().to_vec()); + + // failed to get snapshot. and stream is normal, it will get response with err. + fail::cfg("failed_to_async_snapshot", "return()").unwrap(); + let mut stream = import.duplicate_detect(&duplicate).unwrap(); + let resp = block_on(async move { + let resp: DuplicateDetectResponse = stream.next().await.unwrap().unwrap(); + resp + }); + assert_eq!( + resp.get_region_error().get_message(), + "faild to get snapshot" + ); + + // failed to get snapshot, and stream stops. + // A stopeed remote don't cause panic in server. + let stream = import.duplicate_detect(&duplicate).unwrap(); + drop(stream); + + // drop stream after received part of response. + // A stopped remote must not cause panic at server. + fail::remove("failed_to_async_snapshot"); + let mut stream = import.duplicate_detect(&duplicate).unwrap(); + let ret: Vec = block_on(async move { + let mut resp: DuplicateDetectResponse = stream.next().await.unwrap().unwrap(); + let pairs = resp.take_pairs(); + // drop stream, Do not cause panic at server. + drop(stream); + pairs.into() + }); + + assert_eq!(ret.len(), 4096); + + // call duplicate_detect() successfully. + let mut stream = import.duplicate_detect(&duplicate).unwrap(); + let ret = block_on(async move { + let mut ret: Vec = vec![]; + while let Some(resp) = stream.next().await { + match resp { + Ok(mut resp) => { + if resp.has_key_error() || resp.has_region_error() { + break; + } + let pairs = resp.take_pairs(); + ret.append(&mut pairs.into()); + } + Err(e) => { + println!("receive error: {:?}", e); + break; + } + } + } + + ret + }); + assert_eq!(ret.len(), (data_count - 1000) as usize * 4); + req.set_mode(SwitchMode::Normal); + import.switch_mode(&req).unwrap(); +}