
Commit

feat: try to reduce memory usage in scaling (#15193)
Signed-off-by: Shanicky Chen <[email protected]>
shanicky authored Feb 23, 2024
1 parent 0c329e9 commit c6ed6d1
Showing 2 changed files with 188 additions and 40 deletions.
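Judging from the diff below, the memory saving comes from two changes in src/meta/src/stream/scale.rs: reschedule bookkeeping now uses slim CustomActorInfo / CustomFragmentInfo structs built from borrowed PbStreamActor / PbFragment values instead of cloning the full protobuf structs, and the V1 metadata path builds its indexes while iterating under the fragment manager's read guard rather than materializing every TableFragments via list_all_table_fragments(). A minimal, self-contained sketch of the first pattern follows; FullActor and SlimActorInfo are hypothetical stand-ins for PbStreamActor and CustomActorInfo, not the commit's real types, and the real structs carry more fields (dispatchers, upstream actor ids, the stream node plan, and so on).

use std::collections::HashMap;

#[derive(Clone)]
struct FullActor {
    actor_id: u32,
    fragment_id: u32,
    plan: Vec<u8>, // stands in for the large per-actor plan body
    vnode_bitmap: Option<Vec<u8>>,
}

// Keeps only the fields the rescheduler actually reads.
struct SlimActorInfo {
    actor_id: u32,
    fragment_id: u32,
    vnode_bitmap: Option<Vec<u8>>,
}

impl From<&FullActor> for SlimActorInfo {
    fn from(a: &FullActor) -> Self {
        SlimActorInfo {
            actor_id: a.actor_id,
            fragment_id: a.fragment_id,
            vnode_bitmap: a.vnode_bitmap.clone(),
        }
    }
}

fn main() {
    let actors = vec![
        FullActor { actor_id: 1, fragment_id: 10, plan: vec![0; 1 << 20], vnode_bitmap: None },
        FullActor { actor_id: 2, fragment_id: 10, plan: vec![0; 1 << 20], vnode_bitmap: Some(vec![0xff]) },
    ];

    // Build the index by borrowing each actor and extracting the slim view,
    // rather than cloning the full actors (plan bodies included).
    let actor_map: HashMap<u32, SlimActorInfo> = actors
        .iter()
        .map(|a| (a.actor_id, SlimActorInfo::from(a)))
        .collect();

    assert_eq!(actor_map.len(), 2);
}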
220 changes: 184 additions & 36 deletions src/meta/src/stream/scale.rs
@@ -31,15 +31,19 @@ use risingwave_common::catalog::TableId;
use risingwave_common::hash::{ActorMapping, ParallelUnitId, VirtualNode};
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_meta_model_v2::StreamingParallelism;
use risingwave_pb::common::{ActorInfo, ParallelUnit, WorkerNode};
use risingwave_pb::common::{ActorInfo, Buffer, ParallelUnit, ParallelUnitMapping, WorkerNode};
use risingwave_pb::meta::get_reschedule_plan_request::{Policy, StableResizePolicy};
use risingwave_pb::meta::subscribe_response::{Info, Operation};
use risingwave_pb::meta::table_fragments::actor_status::ActorState;
use risingwave_pb::meta::table_fragments::fragment::FragmentDistributionType;
use risingwave_pb::meta::table_fragments::{self, ActorStatus, Fragment, State};
use risingwave_pb::meta::table_fragments::fragment::{
FragmentDistributionType, PbFragmentDistributionType,
};
use risingwave_pb::meta::table_fragments::{self, ActorStatus, PbFragment, State};
use risingwave_pb::meta::FragmentParallelUnitMappings;
use risingwave_pb::stream_plan::stream_node::NodeBody;
use risingwave_pb::stream_plan::{DispatcherType, FragmentTypeFlag, StreamActor, StreamNode};
use risingwave_pb::stream_plan::{
Dispatcher, DispatcherType, FragmentTypeFlag, PbStreamActor, StreamNode,
};
use thiserror_ext::AsReport;
use tokio::sync::oneshot::Receiver;
use tokio::sync::{oneshot, RwLock, RwLockReadGuard, RwLockWriteGuard};
@@ -105,15 +109,85 @@ pub struct ParallelUnitReschedule {
pub removed_parallel_units: BTreeSet<ParallelUnitId>,
}

pub struct CustomFragmentInfo {
pub fragment_id: u32,
pub fragment_type_mask: u32,
pub distribution_type: PbFragmentDistributionType,
pub vnode_mapping: Option<ParallelUnitMapping>,
pub state_table_ids: Vec<u32>,
pub upstream_fragment_ids: Vec<u32>,
pub actor_template: PbStreamActor,
pub actors: Vec<CustomActorInfo>,
}

#[derive(Default)]
pub struct CustomActorInfo {
pub actor_id: u32,
pub fragment_id: u32,
pub dispatcher: Vec<Dispatcher>,
pub upstream_actor_id: Vec<u32>,
pub vnode_bitmap: Option<Buffer>,
}

impl From<&PbStreamActor> for CustomActorInfo {
fn from(
PbStreamActor {
actor_id,
fragment_id,
dispatcher,
upstream_actor_id,
vnode_bitmap,
..
}: &PbStreamActor,
) -> Self {
CustomActorInfo {
actor_id: *actor_id,
fragment_id: *fragment_id,
dispatcher: dispatcher.clone(),
upstream_actor_id: upstream_actor_id.clone(),
vnode_bitmap: vnode_bitmap.clone(),
}
}
}

impl From<&PbFragment> for CustomFragmentInfo {
fn from(fragment: &PbFragment) -> Self {
CustomFragmentInfo {
fragment_id: fragment.fragment_id,
fragment_type_mask: fragment.fragment_type_mask,
distribution_type: fragment.distribution_type(),
vnode_mapping: fragment.vnode_mapping.clone(),
state_table_ids: fragment.state_table_ids.clone(),
upstream_fragment_ids: fragment.upstream_fragment_ids.clone(),
actor_template: fragment
.actors
.first()
.cloned()
.expect("no actor in fragment"),
actors: fragment.actors.iter().map(CustomActorInfo::from).collect(),
}
}
}

impl CustomFragmentInfo {
pub fn get_fragment_type_mask(&self) -> u32 {
self.fragment_type_mask
}

pub fn distribution_type(&self) -> FragmentDistributionType {
self.distribution_type
}
}

pub struct RescheduleContext {
/// Index used to map `ParallelUnitId` to `WorkerId`
parallel_unit_id_to_worker_id: BTreeMap<ParallelUnitId, WorkerId>,
/// Meta information for all Actors
actor_map: HashMap<ActorId, StreamActor>,
actor_map: HashMap<ActorId, CustomActorInfo>,
/// Status of all Actors, used to find the location of the `Actor`
actor_status: BTreeMap<ActorId, ActorStatus>,
/// Meta information of all `Fragment`, used to find the `Fragment`'s `Actor`
fragment_map: HashMap<FragmentId, Fragment>,
fragment_map: HashMap<FragmentId, CustomFragmentInfo>,
/// Indexes for all `Worker`s
worker_nodes: HashMap<WorkerId, WorkerNode>,
/// Index of all `Actor` upstreams, specific to `Dispatcher`
@@ -180,7 +254,7 @@ impl RescheduleContext {
///
/// The return value is the bitmap distribution after scaling, which covers all virtual node indexes
pub fn rebalance_actor_vnode(
actors: &[StreamActor],
actors: &[CustomActorInfo],
actors_to_remove: &BTreeSet<ActorId>,
actors_to_create: &BTreeSet<ActorId>,
) -> HashMap<ActorId, Bitmap> {
@@ -464,16 +538,29 @@ impl ScaleController {
let mut fragment_state = HashMap::new();
let mut fragment_to_table = HashMap::new();

let all_table_fragments = self.list_all_table_fragments().await?;

for table_fragments in all_table_fragments {
// We are reusing code for the metadata manager of both V1 and V2, which will be deprecated in the future.
fn fulfill_index_by_table_fragments_ref(
actor_map: &mut HashMap<u32, CustomActorInfo>,
fragment_map: &mut HashMap<FragmentId, CustomFragmentInfo>,
actor_status: &mut BTreeMap<ActorId, ActorStatus>,
fragment_state: &mut HashMap<FragmentId, State>,
fragment_to_table: &mut HashMap<FragmentId, TableId>,
table_fragments: &TableFragments,
) {
fragment_state.extend(
table_fragments
.fragment_ids()
.map(|f| (f, table_fragments.state())),
);
fragment_map.extend(table_fragments.fragments.clone());
actor_map.extend(table_fragments.actor_map());

for (fragment_id, fragment) in &table_fragments.fragments {
for actor in &fragment.actors {
actor_map.insert(actor.actor_id, CustomActorInfo::from(actor));
}

fragment_map.insert(*fragment_id, CustomFragmentInfo::from(fragment));
}

actor_status.extend(table_fragments.actor_status.clone());

fragment_to_table.extend(
@@ -483,6 +570,37 @@ impl ScaleController {
);
}

match &self.metadata_manager {
MetadataManager::V1(mgr) => {
let guard = mgr.fragment_manager.get_fragment_read_guard().await;

for table_fragments in guard.table_fragments().values() {
fulfill_index_by_table_fragments_ref(
&mut actor_map,
&mut fragment_map,
&mut actor_status,
&mut fragment_state,
&mut fragment_to_table,
table_fragments,
);
}
}
MetadataManager::V2(_) => {
let all_table_fragments = self.list_all_table_fragments().await?;

for table_fragments in &all_table_fragments {
fulfill_index_by_table_fragments_ref(
&mut actor_map,
&mut fragment_map,
&mut actor_status,
&mut fragment_state,
&mut fragment_to_table,
table_fragments,
);
}
}
};

// NoShuffle relation index
let mut no_shuffle_source_fragment_ids = HashSet::new();
let mut no_shuffle_target_fragment_ids = HashSet::new();
@@ -608,7 +726,7 @@ impl ScaleController {
}

if (fragment.get_fragment_type_mask() & FragmentTypeFlag::Source as u32) != 0 {
let stream_node = fragment.actors.first().unwrap().get_nodes().unwrap();
let stream_node = fragment.actor_template.nodes.as_ref().unwrap();
if TableFragments::find_stream_source(stream_node).is_some() {
stream_source_fragment_ids.insert(*fragment_id);
}
@@ -698,7 +816,7 @@ impl ScaleController {
&self,
worker_nodes: &HashMap<WorkerId, WorkerNode>,
actor_infos_to_broadcast: BTreeMap<ActorId, ActorInfo>,
node_actors_to_create: HashMap<WorkerId, Vec<StreamActor>>,
node_actors_to_create: HashMap<WorkerId, Vec<PbStreamActor>>,
broadcast_worker_ids: HashSet<WorkerId>,
) -> MetaResult<()> {
self.stream_rpc_manager
@@ -963,7 +1081,7 @@ impl ScaleController {

for (actor_to_create, sample_actor) in actors_to_create
.iter()
.zip_eq_debug(repeat(fragment.actors.first().unwrap()).take(actors_to_create.len()))
.zip_eq_debug(repeat(&fragment.actor_template).take(actors_to_create.len()))
{
let new_actor_id = actor_to_create.0;
let mut new_actor = sample_actor.clone();
@@ -1407,7 +1525,7 @@ impl ScaleController {
fragment_actor_bitmap: &HashMap<FragmentId, HashMap<ActorId, Bitmap>>,
no_shuffle_upstream_actor_map: &HashMap<ActorId, HashMap<FragmentId, ActorId>>,
no_shuffle_downstream_actors_map: &HashMap<ActorId, HashMap<FragmentId, ActorId>>,
new_actor: &mut StreamActor,
new_actor: &mut PbStreamActor,
) -> MetaResult<()> {
let fragment = &ctx.fragment_map.get(&new_actor.fragment_id).unwrap();
let mut applied_upstream_fragment_actor_ids = HashMap::new();
@@ -1953,33 +2071,63 @@ impl ScaleController {
})
.collect::<HashMap<_, _>>();

let all_table_fragments = self.list_all_table_fragments().await?;

// FIXME: only need actor id and dispatcher info; avoid cloning it.
let mut actor_map = HashMap::new();
let mut actor_status = HashMap::new();
// FIXME: only need fragment distribution info; should avoid cloning it.
let mut fragment_map = HashMap::new();
let mut fragment_parallelism = HashMap::new();

for table_fragments in all_table_fragments {
for (fragment_id, fragment) in table_fragments.fragments {
fragment
.actors
.iter()
.map(|actor| (actor.actor_id, actor))
.for_each(|(id, actor)| {
actor_map.insert(id as ActorId, actor.clone());
});
// We are reusing code for the metadata manager of both V1 and V2, which will be deprecated in the future.
fn fulfill_index_by_table_fragments_ref(
actor_map: &mut HashMap<u32, CustomActorInfo>,
actor_status: &mut HashMap<ActorId, ActorStatus>,
fragment_map: &mut HashMap<FragmentId, CustomFragmentInfo>,
fragment_parallelism: &mut HashMap<FragmentId, TableParallelism>,
table_fragments: &TableFragments,
) {
for (fragment_id, fragment) in &table_fragments.fragments {
for actor in &fragment.actors {
actor_map.insert(actor.actor_id, CustomActorInfo::from(actor));
}

fragment_map.insert(fragment_id, fragment);
fragment_map.insert(*fragment_id, CustomFragmentInfo::from(fragment));

fragment_parallelism.insert(fragment_id, table_fragments.assigned_parallelism);
fragment_parallelism.insert(*fragment_id, table_fragments.assigned_parallelism);
}

actor_status.extend(table_fragments.actor_status);
actor_status.extend(table_fragments.actor_status.clone());
}

match &self.metadata_manager {
MetadataManager::V1(mgr) => {
let guard = mgr.fragment_manager.get_fragment_read_guard().await;

for table_fragments in guard.table_fragments().values() {
fulfill_index_by_table_fragments_ref(
&mut actor_map,
&mut actor_status,
&mut fragment_map,
&mut fragment_parallelism,
table_fragments,
);
}
}
MetadataManager::V2(_) => {
let all_table_fragments = self.list_all_table_fragments().await?;

for table_fragments in &all_table_fragments {
fulfill_index_by_table_fragments_ref(
&mut actor_map,
&mut actor_status,
&mut fragment_map,
&mut fragment_parallelism,
table_fragments,
);
}
}
};

let mut no_shuffle_source_fragment_ids = HashSet::new();
let mut no_shuffle_target_fragment_ids = HashSet::new();

@@ -2034,7 +2182,7 @@ impl ScaleController {
},
) in fragment_worker_changes
{
let fragment = match fragment_map.get(&fragment_id).cloned() {
let fragment = match fragment_map.get(&fragment_id) {
None => bail!("Fragment id {} not found", fragment_id),
Some(fragment) => fragment,
};
@@ -2122,7 +2270,7 @@ impl ScaleController {
// then we re-add the limited parallel units from the limited workers
target_parallel_unit_ids.extend(limited_worker_parallel_unit_ids.into_iter());
}
match fragment.get_distribution_type().unwrap() {
match fragment.distribution_type() {
FragmentDistributionType::Unspecified => unreachable!(),
FragmentDistributionType::Single => {
let single_parallel_unit_id =
@@ -2274,7 +2422,7 @@ impl ScaleController {
}

pub fn build_no_shuffle_relation_index(
actor_map: &HashMap<ActorId, StreamActor>,
actor_map: &HashMap<ActorId, CustomActorInfo>,
no_shuffle_source_fragment_ids: &mut HashSet<FragmentId>,
no_shuffle_target_fragment_ids: &mut HashSet<FragmentId>,
) {
@@ -2302,7 +2450,7 @@ impl ScaleController {
}

pub fn build_fragment_dispatcher_index(
actor_map: &HashMap<ActorId, StreamActor>,
actor_map: &HashMap<ActorId, CustomActorInfo>,
fragment_dispatcher_map: &mut HashMap<FragmentId, HashMap<FragmentId, DispatcherType>>,
) {
for actor in actor_map.values() {
@@ -2324,7 +2472,7 @@ impl ScaleController {

pub fn resolve_no_shuffle_upstream_tables(
fragment_ids: HashSet<FragmentId>,
fragment_map: &HashMap<FragmentId, Fragment>,
fragment_map: &HashMap<FragmentId, CustomFragmentInfo>,
no_shuffle_source_fragment_ids: &HashSet<FragmentId>,
no_shuffle_target_fragment_ids: &HashSet<FragmentId>,
fragment_to_table: &HashMap<FragmentId, TableId>,
@@ -2394,7 +2542,7 @@ impl ScaleController {

pub fn resolve_no_shuffle_upstream_fragments<T>(
reschedule: &mut HashMap<FragmentId, T>,
fragment_map: &HashMap<FragmentId, Fragment>,
fragment_map: &HashMap<FragmentId, CustomFragmentInfo>,
no_shuffle_source_fragment_ids: &HashSet<FragmentId>,
no_shuffle_target_fragment_ids: &HashSet<FragmentId>,
) -> MetaResult<()>
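For the second pattern, reading V1 metadata under the fragment manager's read guard instead of cloning every TableFragments, a rough sketch follows. It uses hypothetical types, std::sync::RwLock in place of the meta service's tokio::sync::RwLock, and a fulfill_index_by_ref helper that mirrors the fulfill_index_by_table_fragments_ref closure in the diff above; none of these are the commit's real definitions.

use std::collections::HashMap;
use std::sync::RwLock;

struct TableFragments {
    fragments: HashMap<u32, String>, // fragment_id -> stand-in for a large PbFragment
}

struct FragmentManagerCore {
    table_fragments: HashMap<u32, TableFragments>,
}

// Borrow one TableFragments entry and copy out only the small pieces needed.
fn fulfill_index_by_ref(fragment_ids: &mut Vec<u32>, table_fragments: &TableFragments) {
    fragment_ids.extend(table_fragments.fragments.keys().copied());
}

fn main() {
    let core = RwLock::new(FragmentManagerCore {
        table_fragments: HashMap::from([(
            1u32,
            TableFragments {
                fragments: HashMap::from([(10u32, "large fragment body".to_string())]),
            },
        )]),
    });

    let mut fragment_ids = Vec::new();
    {
        // Iterate the shared state under a read guard and build the slim index,
        // rather than cloning the whole map into an owned Vec first.
        let guard = core.read().unwrap();
        for table_fragments in guard.table_fragments.values() {
            fulfill_index_by_ref(&mut fragment_ids, table_fragments);
        }
    } // guard dropped here; only the small index outlives it

    assert_eq!(fragment_ids, vec![10]);
}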
8 changes: 4 additions & 4 deletions src/meta/src/stream/test_scale.rs
@@ -21,10 +21,10 @@ mod tests {
use risingwave_common::buffer::Bitmap;
use risingwave_common::hash::{ActorMapping, ParallelUnitId, ParallelUnitMapping, VirtualNode};
use risingwave_pb::common::ParallelUnit;
use risingwave_pb::stream_plan::StreamActor;

use crate::model::ActorId;
use crate::stream::scale::rebalance_actor_vnode;
use crate::stream::CustomActorInfo;

fn simulated_parallel_unit_nums(min: Option<usize>, max: Option<usize>) -> Vec<usize> {
let mut raw = vec![1, 3, 12, 42, VirtualNode::COUNT];
@@ -39,13 +39,13 @@ mod tests {
raw
}

fn build_fake_actors(info: &[(ActorId, ParallelUnitId)]) -> Vec<StreamActor> {
fn build_fake_actors(info: &[(ActorId, ParallelUnitId)]) -> Vec<CustomActorInfo> {
let parallel_units = generate_parallel_units(info);

let vnode_bitmaps = ParallelUnitMapping::build(&parallel_units).to_bitmaps();

info.iter()
.map(|(actor_id, parallel_unit_id)| StreamActor {
.map(|(actor_id, parallel_unit_id)| CustomActorInfo {
actor_id: *actor_id,
vnode_bitmap: vnode_bitmaps
.get(parallel_unit_id)
@@ -64,7 +64,7 @@ mod tests {
.collect_vec()
}

fn check_affinity_for_scale_in(bitmap: &Bitmap, actor: &StreamActor) {
fn check_affinity_for_scale_in(bitmap: &Bitmap, actor: &CustomActorInfo) {
let prev_bitmap = Bitmap::from(actor.vnode_bitmap.as_ref().unwrap());

for idx in 0..VirtualNode::COUNT {
