Skip to content

Commit

Permalink
Handle cli interrupt gracefully
Browse files Browse the repository at this point in the history
Print an informative message on how to proceed when any `send`,
`receive`, or `resume` command is interrupted.

Use `tokio::signal::ctrl_c` to do this without introducing any new
dependency.
  • Loading branch information
DanGould committed Jun 27, 2024
1 parent 1656f90 commit 56c587d
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 78 deletions.
88 changes: 57 additions & 31 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

133 changes: 86 additions & 47 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use payjoin::bitcoin::Amount;
use payjoin::receive::v2::ActiveSession;
use payjoin::send::RequestContext;
use payjoin::{base64, bitcoin, Error, Uri};
use tokio::signal;
use tokio::sync::watch;

use super::config::AppConfig;
use super::App as AppTrait;
Expand All @@ -19,13 +21,16 @@ use crate::db::Database;
pub(crate) struct App {
config: AppConfig,
db: Arc<Database>,
interrupt: watch::Receiver<()>,
}

#[async_trait::async_trait]
impl AppTrait for App {
fn new(config: AppConfig) -> Result<Self> {
let db = Arc::new(Database::create(&config.db_path)?);
let app = Self { config, db };
let (interrupt_tx, interrupt_rx) = watch::channel(());
tokio::spawn(handle_interrupt(interrupt_tx));
let app = Self { config, db, interrupt: interrupt_rx };
app.bitcoind()?
.get_blockchain_info()
.context("Failed to connect to bitcoind. Check config RPC connection.")?;
Expand Down Expand Up @@ -102,9 +107,16 @@ impl AppTrait for App {

impl App {
async fn spawn_payjoin_sender(&self, mut req_ctx: RequestContext) -> Result<()> {
let res = self.long_poll_post(&mut req_ctx).await?;
self.process_pj_response(res)?;
self.db.clear_send_session(req_ctx.endpoint())?;
let mut interrupt = self.interrupt.clone();
tokio::select! {
res = self.long_poll_post(&mut req_ctx) => {
self.process_pj_response(res?)?;
self.db.clear_send_session(req_ctx.endpoint())?;
}
_ = interrupt.changed() => {
println!("Interrupted. Call `send` with the same arguments to resume this session or `resume` to resume all sessions.");
}
}
Ok(())
}

Expand All @@ -123,60 +135,80 @@ impl App {
println!("Request Payjoin by sharing this Payjoin Uri:");
println!("{}", pj_uri);

let res = self.long_poll_fallback(&mut session).await?;
println!("Fallback transaction received. Consider broadcasting this to get paid if the Payjoin fails:");
println!("{}", serialize_hex(&res.extract_tx_to_schedule_broadcast()));
let mut payjoin_proposal = self
.process_v2_proposal(res)
.map_err(|e| anyhow!("Failed to process proposal {}", e))?;
let (req, ohttp_ctx) = payjoin_proposal
.extract_v2_req()
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
println!("Got a request from the sender. Responding with a Payjoin proposal.");
let http = http_agent()?;
let res = http
.post(req.url)
.header("Content-Type", payjoin::V2_REQ_CONTENT_TYPE)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
payjoin_proposal
.process_res(res.bytes().await?.to_vec(), ohttp_ctx)
.map_err(|e| anyhow!("Failed to deserialize response {}", e))?;
let payjoin_psbt = payjoin_proposal.psbt().clone();
println!(
"Response successful. Watch mempool for successful Payjoin. TXID: {}",
payjoin_psbt.extract_tx().clone().txid()
);
self.db.clear_recv_session()?;
let mut interrupt = self.interrupt.clone();
tokio::select! {
res = self.long_poll_fallback(&mut session) => {
let res = res?;
println!("Fallback transaction received. Consider broadcasting this to get paid if the Payjoin fails:");
println!("{}", serialize_hex(&res.extract_tx_to_schedule_broadcast()));
let mut payjoin_proposal = self
.process_v2_proposal(res)
.map_err(|e| anyhow!("Failed to process proposal {}", e))?;
let (req, ohttp_ctx) = payjoin_proposal
.extract_v2_req()
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
println!("Got a request from the sender. Responding with a Payjoin proposal.");
let http = http_agent()?;
let res = http
.post(req.url)
.header("Content-Type", payjoin::V2_REQ_CONTENT_TYPE)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
payjoin_proposal
.process_res(res.bytes().await?.to_vec(), ohttp_ctx)
.map_err(|e| anyhow!("Failed to deserialize response {}", e))?;
let payjoin_psbt = payjoin_proposal.psbt().clone();
println!(
"Response successful. Watch mempool for successful Payjoin. TXID: {}",
payjoin_psbt.extract_tx().clone().txid()
);
self.db.clear_recv_session()?;
}
_ = interrupt.changed() => {
println!("Interrupted. Call the `resume` command to resume all sessions.");

}
}
Ok(())
}

pub async fn resume_payjoins(&self) -> Result<()> {
let recv_sessions = self.db.get_recv_sessions()?;
let send_sessions = self.db.get_send_sessions()?;

if recv_sessions.is_empty() && send_sessions.is_empty() {
println!("No sessions to resume.");
return Ok(());
}

let mut tasks = Vec::new();

let recv_sessions = self.db.get_recv_sessions()?;
for recv_session in recv_sessions {
for session in recv_sessions {
let self_clone = self.clone();
tasks.push(tokio::task::spawn(async move {
self_clone.spawn_payjoin_receiver(recv_session, None).await
tasks.push(tokio::spawn(async move {
self_clone.spawn_payjoin_receiver(session, None).await
}));
}
let send_sessions = self.db.get_send_sessions()?;
for send_session in send_sessions {

for session in send_sessions {
let self_clone = self.clone();
tasks.push(tokio::task::spawn(async move {
self_clone.spawn_payjoin_sender(send_session).await
}));
tasks.push(tokio::spawn(async move { self_clone.spawn_payjoin_sender(session).await }));
}
if tasks.is_empty() {
println!("No sessions to resume.");
} else {
for task in tasks {
let _ = task.await?;

let mut interrupt = self.interrupt.clone();
tokio::select! {
_ = async {
for task in tasks {
let _ = task.await;
}
} => {
println!("All resumed sessions completed.");
}
_ = interrupt.changed() => {
println!("Resumed sessions were interrupted.");
}
println!("All resumed sessions completed.");
}
Ok(())
}
Expand All @@ -199,7 +231,7 @@ impl App {
Ok(Some(psbt)) => return Ok(psbt),
Ok(None) => {
println!("No response yet.");
std::thread::sleep(std::time::Duration::from_secs(5))
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
}
Err(re) => {
println!("{}", re);
Expand Down Expand Up @@ -351,6 +383,13 @@ async fn unwrap_ohttp_keys_or_else_fetch(config: &AppConfig) -> Result<payjoin::
}
}

async fn handle_interrupt(tx: watch::Sender<()>) {
if let Err(e) = signal::ctrl_c().await {
eprintln!("Error setting up Ctrl-C handler: {}", e);
}
let _ = tx.send(());
}

fn map_reqwest_err(e: reqwest::Error) -> anyhow::Error {
match e.status() {
Some(status_code) => anyhow!("HTTP request failed: {} {}", status_code, e),
Expand Down

0 comments on commit 56c587d

Please sign in to comment.