feat: add target_parallelism support for resize command. (#11557)
shanicky authored Aug 9, 2023
1 parent 77c0b1c commit ce2a0c8
Showing 5 changed files with 82 additions and 59 deletions.
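
In short, the commit threads an optional target parallelism through the whole resize path: the `WorkerChanges` proto message gains an `optional uint32 target_parallelism` field, `risectl`'s `ScaleResizeCommands` gains a matching flag that can satisfy the worker-selection argument group on its own, and the meta node caps the scheduled parallel units at the requested value. Assuming clap's default kebab-case flag spelling and that the subcommand is exposed as `scale resize`, an invocation such as `risectl scale resize --include-workers all --target-parallelism 8` would now be accepted (the worker inputs here are purely illustrative).
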
1 change: 1 addition & 0 deletions proto/meta.proto
@@ -418,6 +418,7 @@ message GetReschedulePlanRequest {
message WorkerChanges {
repeated uint32 include_worker_ids = 1;
repeated uint32 exclude_worker_ids = 2;
optional uint32 target_parallelism = 3;
}

message StableResizePolicy {
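
Since proto3 `optional` scalar fields map to `Option<u32>` in the prost-generated Rust types (which is why the test changes below set `target_parallelism: None`), a caller populates the new field roughly as in the following minimal sketch; the import path is an assumption based on prost's nesting convention for `GetReschedulePlanRequest.WorkerChanges`, not something taken from the commit.

// Minimal sketch, not part of the commit.
use risingwave_pb::meta::get_reschedule_plan_request::WorkerChanges;

fn changes_with_cap(exclude_worker_ids: Vec<u32>, cap: Option<u32>) -> WorkerChanges {
    WorkerChanges {
        include_worker_ids: vec![],
        exclude_worker_ids,
        // `optional uint32 target_parallelism = 3` becomes `Option<u32>`;
        // `None` keeps the old behavior of applying no parallelism cap.
        target_parallelism: cap,
    }
}
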
111 changes: 54 additions & 57 deletions src/ctl/src/cmd_impl/scale/resize.rs
@@ -29,6 +29,13 @@ use crate::cmd_impl::meta::ReschedulePayload;
use crate::common::CtlContext;
use crate::ScaleResizeCommands;

macro_rules! fail {
($($arg:tt)*) => {{
println!($($arg)*);
exit(1);
}};
}

pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow::Result<()> {
let meta_client = context.meta_client().await?;

@@ -41,8 +48,7 @@ pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow
} = match meta_client.get_cluster_info().await {
Ok(resp) => resp,
Err(e) => {
println!("Failed to fetch cluster info: {}", e);
exit(1);
fail!("Failed to fetch cluster info: {}", e);
}
};

@@ -79,9 +85,13 @@ pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow
})
.collect::<HashMap<_, _>>();

let worker_input_to_worker_id = |inputs: Vec<String>| -> Vec<u32> {
let worker_input_to_worker_ids = |inputs: Vec<String>, support_all: bool| -> Vec<u32> {
let mut result: HashSet<_> = HashSet::new();

if inputs.len() == 1 && inputs[0].to_lowercase() == "all" && support_all {
return streaming_workers_index_by_id.keys().cloned().collect();
}

for input in inputs {
let worker_id = input.parse::<u32>().ok().or_else(|| {
streaming_workers_index_by_host
@@ -94,8 +104,7 @@ pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow
println!("warn: {} and {} are the same worker", input, worker_id);
}
} else {
println!("Invalid worker input: {}", input);
exit(1);
fail!("Invalid worker input: {}", input);
}
}

Expand All @@ -110,53 +119,50 @@ pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow
let ScaleResizeCommands {
exclude_workers,
include_workers,
target_parallelism,
generate,
output,
yes,
fragments,
} = resize;

let worker_changes = match (exclude_workers, include_workers) {
(None, None) => unreachable!(),
(exclude, include) => {
let excludes = worker_input_to_worker_id(exclude.unwrap_or_default());
let includes = worker_input_to_worker_id(include.unwrap_or_default());
let worker_changes = {
let exclude_worker_ids =
worker_input_to_worker_ids(exclude_workers.unwrap_or_default(), false);
let include_worker_ids =
worker_input_to_worker_ids(include_workers.unwrap_or_default(), true);

for worker_input in excludes.iter().chain(includes.iter()) {
if !streaming_workers_index_by_id.contains_key(worker_input) {
println!("Invalid worker id: {}", worker_input);
exit(1);
}
}
if let Some(target) = target_parallelism && target == 0 {
fail!("Target parallelism must be greater than 0");
}

for include_worker_id in &includes {
let worker_is_unschedulable = streaming_workers_index_by_id
.get(include_worker_id)
.and_then(|worker| worker.property.as_ref())
.map(|property| property.is_unschedulable)
.unwrap_or(false);

if worker_is_unschedulable {
println!(
"Worker {} is unschedulable, should not be included",
include_worker_id
);
exit(1);
}
for worker_id in exclude_worker_ids.iter().chain(include_worker_ids.iter()) {
if !streaming_workers_index_by_id.contains_key(worker_id) {
fail!("Invalid worker id: {}", worker_id);
}
}

WorkerChanges {
include_worker_ids: includes,
exclude_worker_ids: excludes,
for include_worker_id in &include_worker_ids {
let worker_is_unschedulable = streaming_workers_index_by_id
.get(include_worker_id)
.and_then(|worker| worker.property.as_ref())
.map(|property| property.is_unschedulable)
.unwrap_or(false);

if worker_is_unschedulable {
fail!(
"Worker {} is unschedulable, should not be included",
include_worker_id
);
}
}
};

if worker_changes.exclude_worker_ids.is_empty() && worker_changes.include_worker_ids.is_empty()
{
println!("No worker nodes provided");
exit(1)
}
WorkerChanges {
include_worker_ids,
exclude_worker_ids,
target_parallelism,
}
};

let all_fragment_ids: HashSet<_> = table_fragments
.iter()
@@ -171,13 +177,12 @@ pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow
.iter()
.any(|fragment_id| !all_fragment_ids.contains(fragment_id))
{
println!(
fail!(
"Invalid fragment ids: {:?}",
provide_fragment_ids
.difference(&all_fragment_ids)
.collect_vec()
);
exit(1);
}

provide_fragment_ids.into_iter().collect()
@@ -200,14 +205,12 @@ pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow
} = match response {
Ok(response) => response,
Err(e) => {
println!("Failed to generate plan: {:?}", e);
exit(1);
fail!("Failed to generate plan: {:?}", e);
}
};

if !success {
println!("Failed to generate plan, current revision is {}", revision);
exit(1);
fail!("Failed to generate plan, current revision is {}", revision);
}

if reschedules.is_empty() {
@@ -254,12 +257,10 @@ pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow
{
Ok(true) => println!("Processing..."),
Ok(false) => {
println!("Abort.");
exit(1);
fail!("Abort.");
}
Err(_) => {
println!("Error with questionnaire, try again later");
exit(-1);
fail!("Error with questionnaire, try again later");
}
}
}
@@ -268,14 +269,12 @@ pub async fn resize(context: &CtlContext, resize: ScaleResizeCommands) -> anyhow
match meta_client.reschedule(reschedules, revision, false).await {
Ok(response) => response,
Err(e) => {
println!("Failed to execute plan: {:?}", e);
exit(1);
fail!("Failed to execute plan: {:?}", e);
}
};

if !success {
println!("Failed to execute plan, current revision is {}", revision);
exit(1);
fail!("Failed to execute plan, current revision is {}", revision);
}

println!(
@@ -297,8 +296,7 @@ pub async fn update_schedulability(
let GetClusterInfoResponse { worker_nodes, .. } = match meta_client.get_cluster_info().await {
Ok(resp) => resp,
Err(e) => {
println!("Failed to get cluster info: {:?}", e);
exit(1);
fail!("Failed to get cluster info: {:?}", e);
}
};

Expand All @@ -325,8 +323,7 @@ pub async fn update_schedulability(
println!("Warn: {} and {} are the same worker", worker, worker_id);
}
} else {
println!("Invalid worker id: {}", worker);
exit(1);
fail!("Invalid worker id: {}", worker);
}
}

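
The new `worker_input_to_worker_ids` helper above accepts numeric worker ids, worker hosts, or the literal `all` (only where `support_all` is true, i.e. for the include list). The following standalone sketch mirrors that resolution logic with plain `HashMap`s standing in for the real `streaming_workers_index_by_id` / `_by_host` maps of `WorkerNode`s; it is an illustration, not the committed code.

use std::collections::{HashMap, HashSet};

fn resolve_worker_inputs(
    inputs: Vec<String>,
    support_all: bool,
    workers_by_id: &HashMap<u32, String>,   // worker_id -> host (stand-in)
    workers_by_host: &HashMap<String, u32>, // host -> worker_id
) -> Result<Vec<u32>, String> {
    // "all" expands to every streaming worker, but only where it is allowed.
    if support_all && inputs.len() == 1 && inputs[0].eq_ignore_ascii_case("all") {
        return Ok(workers_by_id.keys().copied().collect());
    }

    let mut result = HashSet::new();
    for input in inputs {
        // Try a numeric worker id first, then fall back to a host lookup.
        let worker_id = input
            .parse::<u32>()
            .ok()
            .or_else(|| workers_by_host.get(&input).copied())
            .ok_or_else(|| format!("Invalid worker input: {}", input))?;
        if !result.insert(worker_id) {
            eprintln!("warn: {} resolves to an already-selected worker", input);
        }
    }
    Ok(result.into_iter().collect())
}
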
10 changes: 8 additions & 2 deletions src/ctl/src/lib.rs
@@ -262,7 +262,7 @@ enum TableCommands {
}

#[derive(clap::Args, Debug)]
#[clap(group(clap::ArgGroup::new("workers_group").required(true).multiple(true).args(&["include_workers", "exclude_workers"])))]
#[clap(group(clap::ArgGroup::new("workers_group").required(true).multiple(true).args(&["include_workers", "exclude_workers", "target_parallelism"])))]
pub struct ScaleResizeCommands {
/// The worker that needs to be excluded during scheduling, worker_id and worker_host are both
/// supported
@@ -278,10 +278,16 @@ pub struct ScaleResizeCommands {
#[clap(
long,
value_delimiter = ',',
value_name = "worker_id or worker_host, ..."
value_name = "all or worker_id or worker_host, ..."
)]
include_workers: Option<Vec<String>>,

/// The target parallelism. Currently it only serves as an upper bound: it takes effect
/// only when the actual parallelism exceeds this value. Can be used in conjunction
/// with exclude/include_workers.
#[clap(long)]
target_parallelism: Option<u32>,

/// Will generate a plan supported by the `reschedule` command and save it to the provided path
/// by the `--output`.
#[clap(long, default_value_t = false)]
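
With `target_parallelism` added to the `workers_group` argument group, any one of the three flags now satisfies the `required(true)` constraint, while `multiple(true)` still allows combining them. Assuming clap's default kebab-case long-flag spelling, `--target-parallelism 4` alone is accepted, `--exclude-workers <worker_id_or_host> --target-parallelism 4` is accepted, and passing none of the three is still rejected at parse time.
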
14 changes: 14 additions & 0 deletions src/meta/src/stream/scale.rs
@@ -1680,6 +1680,7 @@ where
struct WorkerChanges {
include_worker_ids: BTreeSet<WorkerId>,
exclude_worker_ids: BTreeSet<WorkerId>,
target_parallelism: Option<usize>,
}

let mut fragment_worker_changes: HashMap<_, _> = fragment_worker_changes
@@ -1690,6 +1691,7 @@
WorkerChanges {
include_worker_ids: changes.include_worker_ids.into_iter().collect(),
exclude_worker_ids: changes.exclude_worker_ids.into_iter().collect(),
target_parallelism: changes.target_parallelism.map(|p| p as usize),
},
)
})
@@ -1707,6 +1709,7 @@ where
WorkerChanges {
include_worker_ids,
exclude_worker_ids,
target_parallelism,
},
) in fragment_worker_changes
{
@@ -1810,6 +1813,17 @@ where
);
}

if let Some(target_parallelism) = target_parallelism {
if target_parallel_unit_ids.len() < target_parallelism {
bail!("Target parallelism {} is greater than schedulable ParallelUnits {}", target_parallelism, target_parallel_unit_ids.len());
}

target_parallel_unit_ids = target_parallel_unit_ids
.into_iter()
.take(target_parallelism)
.collect();
}

let to_expand_parallel_units = target_parallel_unit_ids
.difference(&fragment_parallel_unit_ids)
.cloned()
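
The net effect of the new block in `scale.rs` is to bound the candidate set: it bails out when fewer schedulable parallel units exist than requested, and otherwise keeps only the first `target_parallelism` of them. A self-contained sketch of just that step, assuming an ordered `BTreeSet` of parallel-unit ids and simplifying the error type to `String`:

use std::collections::BTreeSet;

fn cap_parallel_units(
    candidates: BTreeSet<u32>,
    target_parallelism: Option<usize>,
) -> Result<BTreeSet<u32>, String> {
    match target_parallelism {
        // No target requested: keep every schedulable parallel unit.
        None => Ok(candidates),
        // Asking for more parallelism than is schedulable mirrors the `bail!` above.
        Some(target) if candidates.len() < target => Err(format!(
            "Target parallelism {} is greater than schedulable ParallelUnits {}",
            target,
            candidates.len()
        )),
        // Otherwise keep only the first `target` ids (lowest first, since a
        // BTreeSet iterates in ascending order).
        Some(target) => Ok(candidates.into_iter().take(target).collect()),
    }
}
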
5 changes: 5 additions & 0 deletions src/tests/simulation/tests/integration_tests/scale/plan.rs
@@ -65,6 +65,7 @@ async fn test_resize_normal() -> Result<()> {
WorkerChanges {
include_worker_ids: vec![],
exclude_worker_ids: removed_workers,
target_parallelism: None,
},
)]),
}))
@@ -147,6 +148,7 @@ async fn test_resize_single() -> Result<()> {
WorkerChanges {
include_worker_ids: vec![],
exclude_worker_ids: vec![prev_worker.id],
target_parallelism: None,
},
)]),
}))
@@ -221,13 +223,15 @@ async fn test_resize_single_failed() -> Result<()> {
WorkerChanges {
include_worker_ids: vec![],
exclude_worker_ids: vec![worker_a.id],
target_parallelism: None,
},
),
(
downstream_fragment_id,
WorkerChanges {
include_worker_ids: vec![],
exclude_worker_ids: vec![worker_b.id],
target_parallelism: None,
},
),
]),
@@ -298,6 +302,7 @@
WorkerChanges {
include_worker_ids: vec![],
exclude_worker_ids: removed_worker_ids,
target_parallelism: None,
},
)]),
}))
