Skip to content

Commit

Permalink
Code changes: import EtcdElectionClient, update import path, define…
Browse files Browse the repository at this point in the history
… `ElectionHandle` as tuple.
  • Loading branch information
shanicky committed Sep 18, 2023
1 parent 382c686 commit 9bfe704
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 40 deletions.
3 changes: 2 additions & 1 deletion src/meta/src/rpc/election/etcd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ mod tests {
use tokio::sync::watch::Sender;
use tokio::time;

use crate::rpc::election_client::{ElectionClient, EtcdElectionClient, META_ELECTION_KEY};
use crate::rpc::election::etcd::EtcdElectionClient;
use crate::rpc::election::{ElectionClient, META_ELECTION_KEY};

type ElectionHandle = (Sender<()>, Arc<dyn ElectionClient>);

Expand Down
161 changes: 122 additions & 39 deletions src/meta/src/rpc/election/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ pub(crate) trait SqlDriverCommon {
const ELECTION_LEADER_TABLE_NAME: &'static str = "election_leader";
const ELECTION_MEMBER_TABLE_NAME: &'static str = "election_members";

fn election_table_name(&self) -> &'static str {
fn election_table_name() -> &'static str {
Self::ELECTION_LEADER_TABLE_NAME
}
fn member_table_name(&self) -> &'static str {
fn member_table_name() -> &'static str {
Self::ELECTION_MEMBER_TABLE_NAME
}
}
Expand Down Expand Up @@ -88,7 +88,7 @@ ON CONFLICT (id, service)
DO
UPDATE SET last_heartbeat = EXCLUDED.last_heartbeat;
"#,
table = self.member_table_name()
table = Self::member_table_name()
))
.bind(id)
.bind(service_name)
Expand Down Expand Up @@ -120,7 +120,7 @@ ON CONFLICT (service)
END
RETURNING service, id, last_heartbeat;
"#,
table = self.election_table_name()
table = Self::election_table_name()
))
.bind(service_name)
.bind(id)
Expand All @@ -134,7 +134,7 @@ RETURNING service, id, last_heartbeat;
async fn leader(&self, service_name: &str) -> MetaResult<Option<ElectionRow>> {
let row = sqlx::query_as::<_, ElectionRow>(&format!(
r#"SELECT service, id, last_heartbeat FROM {table} WHERE service = $1;"#,
table = self.election_table_name()
table = Self::election_table_name()
))
.bind(service_name)
.fetch_optional(&self.pool)
Expand All @@ -146,7 +146,7 @@ RETURNING service, id, last_heartbeat;
async fn candidates(&self, service_name: &str) -> MetaResult<Vec<ElectionRow>> {
let row = sqlx::query_as::<_, ElectionRow>(&format!(
r#"SELECT service, id, last_heartbeat FROM {table} WHERE service = $1;"#,
table = self.member_table_name()
table = Self::member_table_name()
))
.bind(service_name)
.fetch_all(&self.pool)
Expand All @@ -161,7 +161,7 @@ RETURNING service, id, last_heartbeat;
r#"
DELETE FROM {table} WHERE service = $1 AND id = $2;
"#,
table = self.election_table_name()
table = Self::election_table_name()
))
.bind(service_name)
.bind(id)
Expand All @@ -172,7 +172,7 @@ RETURNING service, id, last_heartbeat;
r#"
DELETE FROM {table} WHERE service = $1 AND id = $2;
"#,
table = self.member_table_name()
table = Self::member_table_name()
))
.bind(service_name)
.bind(id)
Expand All @@ -194,7 +194,7 @@ VALUES(?, ?, NOW())
ON duplicate KEY
UPDATE last_heartbeat = VALUES(last_heartbeat);
"#,
table = self.member_table_name()
table = Self::member_table_name()
))
.bind(id)
.bind(service_name)
Expand All @@ -210,7 +210,7 @@ ON duplicate KEY
id: &str,
ttl: i64,
) -> MetaResult<ElectionRow> {
let row = sqlx::query::<MySql>(&format!(
let _ = sqlx::query::<MySql>(&format!(
r#"INSERT
IGNORE
INTO {table} (service, id, last_heartbeat)
Expand All @@ -221,19 +221,17 @@ ON duplicate KEY
last_heartbeat = if(id =
VALUES(id),
VALUES(last_heartbeat), last_heartbeat);"#,
table = self.election_table_name()
table = Self::election_table_name()
))
.bind(service_name)
.bind(id)
.bind(ttl)
.execute(&self.pool)
.await?;

println!("row {:?}", row);

let row = sqlx::query_as::<MySql, ElectionRow>(&format!(
r#"SELECT service, id, last_heartbeat FROM {table} WHERE service = ?;"#,
table = self.election_table_name(),
table = Self::election_table_name(),
))
.bind(service_name)
.fetch_one(&self.pool)
Expand All @@ -245,7 +243,7 @@ ON duplicate KEY
async fn leader(&self, service_name: &str) -> MetaResult<Option<ElectionRow>> {
let row = sqlx::query_as::<MySql, ElectionRow>(&format!(
r#"SELECT service, id, last_heartbeat FROM {table} WHERE service = ?;"#,
table = self.election_table_name()
table = Self::election_table_name()
))
.bind(service_name)
.fetch_optional(&self.pool)
Expand All @@ -257,7 +255,7 @@ ON duplicate KEY
async fn candidates(&self, service_name: &str) -> MetaResult<Vec<ElectionRow>> {
let row = sqlx::query_as::<MySql, ElectionRow>(&format!(
r#"SELECT service, id, last_heartbeat FROM {table} WHERE service = ?;"#,
table = self.member_table_name()
table = Self::member_table_name()
))
.bind(service_name)
.fetch_all(&self.pool)
Expand All @@ -272,7 +270,7 @@ ON duplicate KEY
r#"
DELETE FROM {table} WHERE service = ? AND id = ?;
"#,
table = self.election_table_name()
table = Self::election_table_name()
))
.bind(service_name)
.bind(id)
Expand All @@ -283,7 +281,7 @@ ON duplicate KEY
r#"
DELETE FROM {table} WHERE service = ? AND id = ?;
"#,
table = self.member_table_name()
table = Self::member_table_name()
))
.bind(service_name)
.bind(id)
Expand All @@ -306,7 +304,7 @@ ON CONFLICT (id, service)
DO
UPDATE SET last_heartbeat = EXCLUDED.last_heartbeat;
"#,
table = self.member_table_name()
table = Self::member_table_name()
))
.bind(id)
.bind(service_name)
Expand Down Expand Up @@ -338,23 +336,21 @@ ON CONFLICT (service)
END
RETURNING service, id, last_heartbeat;
"#,
table = self.election_table_name()
table = Self::election_table_name()
))
.bind(service_name)
.bind(id)
.bind(Duration::from_secs(ttl as u64))
.fetch_one(&self.pool)
.await?;

println!("row {:?}", row);

Ok(row)
}

async fn leader(&self, service_name: &str) -> MetaResult<Option<ElectionRow>> {
let row = sqlx::query_as::<Postgres, ElectionRow>(&format!(
r#"SELECT service, id, last_heartbeat FROM {table} WHERE service = $1;"#,
table = self.election_table_name()
table = Self::election_table_name()
))
.bind(service_name)
.fetch_optional(&self.pool)
Expand All @@ -366,7 +362,7 @@ RETURNING service, id, last_heartbeat;
async fn candidates(&self, service_name: &str) -> MetaResult<Vec<ElectionRow>> {
let row = sqlx::query_as::<Postgres, ElectionRow>(&format!(
r#"SELECT service, id, last_heartbeat FROM {table} WHERE service = $1;"#,
table = self.member_table_name()
table = Self::member_table_name()
))
.bind(service_name)
.fetch_all(&self.pool)
Expand All @@ -381,7 +377,7 @@ RETURNING service, id, last_heartbeat;
r#"
DELETE FROM {table} WHERE service = $1 AND id = $2;
"#,
table = self.election_table_name()
table = Self::election_table_name()
))
.bind(service_name)
.bind(id)
Expand All @@ -392,7 +388,7 @@ RETURNING service, id, last_heartbeat;
r#"
DELETE FROM {table} WHERE service = $1 AND id = $2;
"#,
table = self.member_table_name()
table = Self::member_table_name()
))
.bind(service_name)
.bind(id)
Expand All @@ -417,8 +413,6 @@ where
async fn run_once(&self, ttl: i64, stop: Receiver<()>) -> MetaResult<()> {
let stop = stop.clone();

let mut election_ticker = tokio::time::interval(Duration::from_secs(1));

let member_refresh_driver = self.driver.clone();

let id = self.id.clone();
Expand All @@ -431,6 +425,7 @@ where
loop {
tokio::select! {
_ = ticker.tick() => {

if let Err(e) = member_refresh_driver
.update_heartbeat(META_ELECTION_KEY, id.as_str())
.await {
Expand All @@ -456,13 +451,15 @@ where

let mut is_leader = false;

let mut election_ticker = time::interval(Duration::from_secs(1));

loop {
tokio::select! {
_ = election_ticker.tick() => {
let election_row = self
.driver
.try_campaign(META_ELECTION_KEY, self.id.as_str(), ttl)
.await?;
_ = election_ticker.tick() => {
let election_row = self
.driver
.try_campaign(META_ELECTION_KEY, self.id.as_str(), ttl)
.await?;

assert_eq!(election_row.service, META_ELECTION_KEY);

Expand All @@ -472,13 +469,12 @@ where
is_leader = true;
}
} else if is_leader {
tracing::warn!("leader has been changed to {}", election_row.id);
break;

tracing::warn!("leader has been changed to {}", election_row.id);
break;
}

timeout_ticker.reset();
}
timeout_ticker.reset();
}
_ = timeout_ticker.tick() => {
tracing::error!("member {} election timeout", self.id);
break;
Expand All @@ -497,7 +493,6 @@ where
}
}
}

self.is_leader_sender.send_replace(false);

return Ok(());
Expand Down Expand Up @@ -539,3 +534,91 @@ where
*self.is_leader_sender.borrow()
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use sqlx::sqlite::SqlitePoolOptions;
use sqlx::SqlitePool;
use tokio::sync::watch;

use crate::rpc::election::sql::{SqlBackendElectionClient, SqlDriverCommon, SqliteDriver};
use crate::{ElectionClient, MetaResult};

async fn prepare_sqlite_env() -> MetaResult<SqlitePool> {
let pool = SqlitePoolOptions::new().connect("sqlite::memory:").await?;
let _ = sqlx::query(
&format!("CREATE TABLE {table} (service VARCHAR(256) PRIMARY KEY, id VARCHAR(256), last_heartbeat DATETIME)",
table = SqliteDriver::election_table_name()))
.execute(&pool).await?;

let _ = sqlx::query(
&format!("CREATE TABLE {table} (service VARCHAR(256), id VARCHAR(256), last_heartbeat DATETIME, PRIMARY KEY (service, id))",
table = SqliteDriver::member_table_name()))
.execute(&pool).await?;

Ok(pool)
}

#[tokio::test]
async fn test_sql_election() {
let id = "test_id".to_string();
let pool = prepare_sqlite_env().await.unwrap();

let provider = SqliteDriver { pool };
let (sender, _) = watch::channel(false);
let sql_election_client: Arc<dyn ElectionClient> = Arc::new(SqlBackendElectionClient {
id,
driver: Arc::new(provider),
is_leader_sender: sender,
});
let (stop_sender, _) = watch::channel(());

let stop_receiver = stop_sender.subscribe();

let mut receiver = sql_election_client.subscribe();
let client_ = sql_election_client.clone();
tokio::spawn(async move { client_.run_once(10, stop_receiver).await.unwrap() });

loop {
receiver.changed().await.unwrap();
if *receiver.borrow() {
assert!(sql_election_client.is_leader().await);
break;
}
}
}

#[tokio::test]
async fn test_sql_election_multi() {
let (stop_sender, _) = watch::channel(());

let mut clients = vec![];

let pool = prepare_sqlite_env().await.unwrap();
for i in 1..3 {
let id = format!("test_id_{}", i);
let provider = SqliteDriver { pool: pool.clone() };
let (sender, _) = watch::channel(false);
let sql_election_client: Arc<dyn ElectionClient> = Arc::new(SqlBackendElectionClient {
id,
driver: Arc::new(provider),
is_leader_sender: sender,
});

let stop_receiver = stop_sender.subscribe();
let client_ = sql_election_client.clone();
tokio::spawn(async move { client_.run_once(10, stop_receiver).await.unwrap() });
clients.push(sql_election_client);
}

let mut is_leaders = vec![];

for client in clients {
is_leaders.push(client.is_leader().await);
}

assert!(is_leaders.iter().filter(|&x| *x).count() <= 1);
}
}

0 comments on commit 9bfe704

Please sign in to comment.