diff --git a/payjoin-cli/src/app/v2/mod.rs b/payjoin-cli/src/app/v2/mod.rs index a43dcdf8f..be12de0d2 100644 --- a/payjoin-cli/src/app/v2/mod.rs +++ b/payjoin-cli/src/app/v2/mod.rs @@ -9,13 +9,14 @@ use payjoin::receive::v2::{ replay_event_log as replay_receiver_event_log, HasReplyableError, Initialized, MaybeInputsOwned, MaybeInputsSeen, Monitor, OutputsUnknown, PayjoinProposal, PendingFallback as ReceiverPendingFallback, ProvisionalProposal, ReceiveSession, Receiver, - ReceiverBuilder, SessionOutcome as ReceiverSessionOutcome, UncheckedOriginalPayload, - WantsFeeRange, WantsInputs, WantsOutputs, + ReceiverBuilder, SessionOutcome as ReceiverSessionOutcome, + SessionStatus as ReceiverSessionStatus, UncheckedOriginalPayload, WantsFeeRange, WantsInputs, + WantsOutputs, }; use payjoin::send::v2::{ replay_event_log as replay_sender_event_log, PendingFallback as SenderPendingFallback, PollingForProposal, SendSession, Sender, SenderBuilder, SessionOutcome as SenderSessionOutcome, - WithReplyKey, + SessionStatus as SenderSessionStatus, WithReplyKey, }; use payjoin::{ImplementationError, PjParam, Uri}; use tokio::sync::watch; @@ -33,7 +34,6 @@ mod ohttp; const W_ID: usize = 12; const W_ROLE: usize = 25; -const W_DONE: usize = 15; const W_STATUS: usize = 15; #[derive(Clone)] @@ -46,27 +46,32 @@ pub(crate) struct App { } trait StatusText { - fn status_text(&self) -> &'static str; + fn status_text(&self, has_fallback_tx: bool) -> String; } impl StatusText for SendSession { - fn status_text(&self) -> &'static str { + fn status_text(&self, has_fallback_tx: bool) -> String { match self { SendSession::WithReplyKey(_) | SendSession::PollingForProposal(_) => - "Waiting for proposal", + "Waiting for proposal".to_string(), SendSession::Closed(session_outcome) => match session_outcome { - SenderSessionOutcome::Aborted => "Session aborted", - SenderSessionOutcome::Success(_) => "Session success", + SenderSessionOutcome::Aborted => + if has_fallback_tx { + "Session aborted, Fallback transaction available".to_string() + } else { + "Session aborted".to_string() + }, + SenderSessionOutcome::Success(_) => "Session success".to_string(), }, - SendSession::PendingFallback(_) => "Session awaiting fallback", + SendSession::PendingFallback(_) => "Session awaiting fallback".to_string(), } } } impl StatusText for ReceiveSession { - fn status_text(&self) -> &'static str { + fn status_text(&self, has_fallback_tx: bool) -> String { match self { - ReceiveSession::Initialized(_) => "Waiting for original proposal", + ReceiveSession::Initialized(_) => "Waiting for original proposal".to_string(), ReceiveSession::UncheckedOriginalPayload(_) | ReceiveSession::MaybeInputsOwned(_) | ReceiveSession::MaybeInputsSeen(_) @@ -74,28 +79,31 @@ impl StatusText for ReceiveSession { | ReceiveSession::WantsOutputs(_) | ReceiveSession::WantsInputs(_) | ReceiveSession::WantsFeeRange(_) - | ReceiveSession::ProvisionalProposal(_) => "Processing original proposal", - ReceiveSession::PayjoinProposal(_) => "Payjoin proposal sent", + | ReceiveSession::ProvisionalProposal(_) => "Processing original proposal".to_string(), + ReceiveSession::PayjoinProposal(_) => "Payjoin proposal sent".to_string(), ReceiveSession::HasReplyableError(_) => - "Session failure, waiting to post error response", - ReceiveSession::Monitor(_) => "Monitoring payjoin proposal", - ReceiveSession::PendingFallback(_) => "Pending fallback handling", + "Session failure, waiting to post error response".to_string(), + ReceiveSession::Monitor(_) => "Monitoring payjoin proposal".to_string(), + ReceiveSession::PendingFallback(_) => "Pending fallback handling".to_string(), ReceiveSession::Closed(session_outcome) => match session_outcome { - ReceiverSessionOutcome::Aborted => "Session aborted", - ReceiverSessionOutcome::Success(_) => "Session success, Payjoin proposal was broadcasted", - ReceiverSessionOutcome::FallbackBroadcasted => "Fallback broadcasted", + ReceiverSessionOutcome::Aborted => if has_fallback_tx { + "Session aborted, Fallback Tx available".to_string() + } else { + "Session aborted".to_string() + }, + ReceiverSessionOutcome::Success(_) => + "Session success, Payjoin proposal was broadcasted".to_string(), + ReceiverSessionOutcome::FallbackBroadcasted => "Fallback broadcasted".to_string(), ReceiverSessionOutcome::PayjoinProposalSent => - "Payjoin proposal sent, skipping monitoring as the sender is spending non-SegWit inputs", + "Payjoin proposal sent, skipping monitoring as the sender is spending non-SegWit inputs" + .to_string(), }, } } } fn print_header() { - println!( - "{: { session_id: SessionId, role: Role, status: Status, - completed_at: Option, + has_fallback_tx: bool, error_message: Option, } impl fmt::Display for SessionHistoryRow { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let status_text = self.status.status_text(self.has_fallback_tx); write!( f, - "{: "Not Completed".to_string(), - Some(secs) => { - // TODO: human readable time - secs.to_string() - } - }, - self.error_message.as_deref().unwrap_or(self.status.status_text()) + self.error_message.as_deref().unwrap_or(status_text.as_str()) ) } } @@ -243,6 +245,7 @@ impl AppTrait for App { let psbt = self.create_original_psbt(&address, amount, fee_rate)?; let persister = SenderPersister::new(self.db.clone(), bip21, receiver_pubkey)?; + println!("Send session {} established", persister.session_id()); let sender = SenderBuilder::from_parts(psbt, pj_param, &address, Some(amount)) .build_recommended(fee_rate)? @@ -288,7 +291,7 @@ impl AppTrait for App { } let session = receiver_builder.build().save(&persister)?; - println!("Receive session established"); + println!("Receive session {} established", persister.session_id()); let pj_uri = session.pj_uri(); println!("Request Payjoin by sharing this Payjoin Uri:"); println!("{pj_uri}"); @@ -302,11 +305,6 @@ impl AppTrait for App { let recv_session_ids = self.db.get_recv_session_ids()?; let send_session_ids = self.db.get_send_session_ids()?; - if recv_session_ids.is_empty() && send_session_ids.is_empty() { - println!("No sessions to resume."); - return Ok(()); - } - let mut tasks: Vec<(String, tokio::task::JoinHandle>)> = Vec::new(); // Process receiver sessions @@ -314,15 +312,17 @@ impl AppTrait for App { let self_clone = self.clone(); let recv_persister = ReceiverPersister::from_id(self.db.clone(), session_id.clone()); match replay_receiver_event_log(&recv_persister) { - Ok((receiver_state, _)) => { - tasks.push(( - session_id.to_string(), - tokio::spawn(async move { - self_clone - .process_receiver_session(receiver_state, &recv_persister) - .await - }), - )); + Ok((receiver_state, history)) => { + if history.status() == ReceiverSessionStatus::Active { + tasks.push(( + session_id.to_string(), + tokio::spawn(async move { + self_clone + .process_receiver_session(receiver_state, &recv_persister) + .await + }), + )); + } } Err(e) => { if e.is_expired() { @@ -334,7 +334,7 @@ impl AppTrait for App { ); println!("Session {session_id} receiver failed to replay - {e}"); } - Self::close_failed_session(&recv_persister, &session_id, "receiver"); + self.close_failed_receiver_session(e, &recv_persister, &session_id); } } } @@ -343,15 +343,18 @@ impl AppTrait for App { for session_id in send_session_ids { let sender_persister = SenderPersister::from_id(self.db.clone(), session_id.clone()); match replay_sender_event_log(&sender_persister) { - Ok((sender_state, _)) => { - let self_clone = self.clone(); - tasks.push(( - session_id.clone().to_string(), - tokio::spawn(async move { - self_clone.process_sender_session(sender_state, &sender_persister).await - }), - )); - } + Ok((sender_state, history)) => + if history.status() == SenderSessionStatus::Active { + let self_clone = self.clone(); + tasks.push(( + session_id.clone().to_string(), + tokio::spawn(async move { + self_clone + .process_sender_session(sender_state, &sender_persister) + .await + }), + )); + }, Err(e) => { if e.is_expired() { println!("Session {session_id} sender expired."); @@ -359,11 +362,16 @@ impl AppTrait for App { tracing::error!("An error {:?} occurred while replaying Sender session", e); println!("Session {session_id} sender failed to replay - {e}"); } - Self::close_failed_session(&sender_persister, &session_id, "sender"); + self.close_failed_sender_session(e, &sender_persister, &session_id); } } } + if tasks.is_empty() { + println!("No sessions to resume."); + return Ok(()); + } + let mut interrupt = self.interrupt.clone(); tokio::select! { _ = async { @@ -402,7 +410,7 @@ impl AppTrait for App { session_id, role: Role::Sender, status: sender_state.clone(), - completed_at: None, + has_fallback_tx: true, error_message: None, }; send_rows.push(row); @@ -412,7 +420,7 @@ impl AppTrait for App { session_id, role: Role::Sender, status: SendSession::Closed(SenderSessionOutcome::Aborted), - completed_at: None, + has_fallback_tx: false, error_message: Some(e.to_string()), }; send_rows.push(row); @@ -423,12 +431,12 @@ impl AppTrait for App { self.db.get_recv_session_ids()?.into_iter().for_each(|session_id| { let persister = ReceiverPersister::from_id(self.db.clone(), session_id.clone()); match replay_receiver_event_log(&persister) { - Ok((receiver_state, _)) => { + Ok((receiver_state, history)) => { let row = SessionHistoryRow { session_id, role: Role::Receiver, status: receiver_state.clone(), - completed_at: None, + has_fallback_tx: history.fallback_tx().is_some(), error_message: None, }; recv_rows.push(row); @@ -438,7 +446,7 @@ impl AppTrait for App { session_id, role: Role::Receiver, status: ReceiveSession::Closed(ReceiverSessionOutcome::Aborted), - completed_at: None, + has_fallback_tx: false, error_message: Some(e.to_string()), }; recv_rows.push(row); @@ -446,62 +454,6 @@ impl AppTrait for App { } }); - self.db.get_inactive_send_session_ids()?.into_iter().for_each( - |(session_id, completed_at)| { - let persister = SenderPersister::from_id(self.db.clone(), session_id.clone()); - match replay_sender_event_log(&persister) { - Ok((sender_state, _)) => { - let row = SessionHistoryRow { - session_id, - role: Role::Sender, - status: sender_state.clone(), - completed_at: Some(completed_at), - error_message: None, - }; - send_rows.push(row); - } - Err(e) => { - let row = SessionHistoryRow { - session_id, - role: Role::Sender, - status: SendSession::Closed(SenderSessionOutcome::Aborted), - completed_at: Some(completed_at), - error_message: Some(e.to_string()), - }; - send_rows.push(row); - } - } - }, - ); - - self.db.get_inactive_recv_session_ids()?.into_iter().for_each( - |(session_id, completed_at)| { - let persister = ReceiverPersister::from_id(self.db.clone(), session_id.clone()); - match replay_receiver_event_log(&persister) { - Ok((receiver_state, _)) => { - let row = SessionHistoryRow { - session_id, - role: Role::Receiver, - status: receiver_state.clone(), - completed_at: Some(completed_at), - error_message: None, - }; - recv_rows.push(row); - } - Err(e) => { - let row = SessionHistoryRow { - session_id, - role: Role::Receiver, - status: ReceiveSession::Closed(ReceiverSessionOutcome::Aborted), - completed_at: Some(completed_at), - error_message: Some(e.to_string()), - }; - recv_rows.push(row); - } - } - }, - ); - // Print receiver and sender rows separately for row in send_rows { println!("{row}"); @@ -656,14 +608,65 @@ impl App { Ok(()) } - fn close_failed_session

(persister: &P, session_id: &SessionId, role: &str) - where - P: SessionPersister, - { - if let Err(close_err) = SessionPersister::close(persister) { - tracing::error!("Failed to close {} session {}: {:?}", role, session_id, close_err); - } else { - tracing::error!("Closed failed {} session: {}", role, session_id); + fn close_failed_receiver_session( + &self, + error: payjoin::error::ReplayError, + persister: &ReceiverPersister, + session_id: &SessionId, + ) { + if let Some(session) = error.into_session() { + match session.cancel().save(persister) { + Ok(Some(pending)) => { + println!( + "Session {session_id} receiver failed. \ + Broadcast the fallback transaction manually:\n{}", + serialize_hex(pending.fallback_tx()) + ); + if let Err(e) = pending.close().save(persister) { + tracing::error!("Failed to close receiver session {session_id}: {e:?}"); + } + return; + } + Ok(None) => { + tracing::debug!("Closed failed receiver session: {session_id}"); + return; + } + Err(e) => tracing::error!("Failed to cancel receiver session {session_id}: {e:?}"), + } + } + if let Err(e) = SessionPersister::close(persister) { + tracing::error!("Failed to close receiver session {session_id}: {e:?}"); + } + } + + fn close_failed_sender_session( + &self, + error: payjoin::error::ReplayError, + persister: &SenderPersister, + session_id: &SessionId, + ) { + if let Some(session) = error.into_session() { + match session.cancel().save(persister) { + Ok(Some(pending)) => { + println!( + "Session {session_id} sender failed. \ + Broadcast the fallback transaction manually:\n{}", + serialize_hex(pending.fallback_tx()) + ); + if let Err(e) = pending.close().save(persister) { + tracing::error!("Failed to close sender session {session_id}: {e:?}"); + } + return; + } + Ok(None) => { + tracing::debug!("Closed failed sender session: {session_id}"); + return; + } + Err(e) => tracing::error!("Failed to cancel sender session {session_id}: {e:?}"), + } + } + if let Err(e) = SessionPersister::close(persister) { + tracing::error!("Failed to close sender session {session_id}: {e:?}"); } } @@ -716,24 +719,46 @@ impl App { let mut session = sender.clone(); // Long poll until we get a response loop { - let (response, ctx) = - self.post_via_relay(|relay| session.create_poll_request(relay)).await?; - let res = session.process_response(&response.bytes().await?, ctx).save(persister); - match res { - Ok(OptionalTransitionOutcome::Progress(psbt)) => { - println!("Proposal received. Processing..."); - self.process_pj_response(psbt)?; + let relay = self.relay_manager.choose_relay()?; + let (req, ctx) = match session.create_poll_request(relay.as_str()) { + Ok(r) => r, + Err(e) if e.is_expired() => { + let pending = session.cancel().save(persister)?; + println!( + "Session expired. Broadcast the fallback transaction manually:\n{}", + serialize_hex(pending.fallback_tx()) + ); + pending.close().save(persister)?; return Ok(()); } - Ok(OptionalTransitionOutcome::Stasis(current_state)) => { - println!("No response yet."); - session = current_state; - continue; + Err(e) => return Err(e.into()), + }; + match self.post_request(req).await { + Ok(response) => { + let bytes = response.bytes().await?; + let res = session.process_response(&bytes, ctx).save(persister); + match res { + Ok(OptionalTransitionOutcome::Progress(psbt)) => { + println!("Proposal received. Processing..."); + self.process_pj_response(psbt)?; + return Ok(()); + } + Ok(OptionalTransitionOutcome::Stasis(current_state)) => { + println!("No response yet."); + session = current_state; + continue; + } + Err(re) => { + println!("{re}"); + tracing::debug!("{re:?}"); + return Err(anyhow!("Response error").context(re)); + } + } } - Err(re) => { - println!("{re}"); - tracing::debug!("{re:?}"); - return Err(anyhow!("Response error").context(re)); + Err(e) => { + tracing::debug!("Request to relay {relay} failed: {e:?}"); + self.relay_manager.add_failed_relay(relay); + continue; } } } diff --git a/payjoin-cli/src/db/mod.rs b/payjoin-cli/src/db/mod.rs index e44f51894..31e2e597e 100644 --- a/payjoin-cli/src/db/mod.rs +++ b/payjoin-cli/src/db/mod.rs @@ -42,16 +42,14 @@ impl Database { "CREATE TABLE IF NOT EXISTS send_sessions ( session_id INTEGER PRIMARY KEY AUTOINCREMENT, pj_uri TEXT NOT NULL, - receiver_pubkey BLOB NOT NULL, - completed_at INTEGER + receiver_pubkey BLOB NOT NULL )", [], )?; conn.execute( "CREATE TABLE IF NOT EXISTS receive_sessions ( - session_id INTEGER PRIMARY KEY AUTOINCREMENT, - completed_at INTEGER + session_id INTEGER PRIMARY KEY AUTOINCREMENT )", [], )?; diff --git a/payjoin-cli/src/db/v2.rs b/payjoin-cli/src/db/v2.rs index 4d61cb7f2..f3fa2ecaf 100644 --- a/payjoin-cli/src/db/v2.rs +++ b/payjoin-cli/src/db/v2.rs @@ -1,10 +1,14 @@ use std::sync::Arc; use payjoin::persist::SessionPersister; -use payjoin::receive::v2::SessionEvent as ReceiverSessionEvent; -use payjoin::send::v2::SessionEvent as SenderSessionEvent; +use payjoin::receive::v2::{ + SessionEvent as ReceiverSessionEvent, SessionOutcome as ReceiverSessionOutcome, +}; +use payjoin::send::v2::{ + SessionEvent as SenderSessionEvent, SessionOutcome as SenderSessionOutcome, +}; use payjoin::HpkePublicKey; -use rusqlite::params; +use rusqlite::{params, OptionalExtension}; use super::*; @@ -109,13 +113,24 @@ impl SessionPersister for SenderPersister { } fn close(&self) -> std::result::Result<(), Self::InternalStorageError> { - let conn = self.db.get_connection()?; - - conn.execute( - "UPDATE send_sessions SET completed_at = ?1 WHERE session_id = ?2", - params![now(), *self.session_id], - )?; - + let already_closed = { + let conn = self.db.get_connection()?; + let last_event: Option = conn + .query_row( + "SELECT event_data FROM send_session_events + WHERE session_id = ?1 ORDER BY id DESC LIMIT 1", + params![*self.session_id], + |row| row.get(0), + ) + .optional()?; + last_event + .and_then(|data| serde_json::from_str::(&data).ok()) + .map(|e| matches!(e, SenderSessionEvent::Closed(_))) + .unwrap_or(false) + }; + if !already_closed { + self.save_event(SenderSessionEvent::Closed(SenderSessionOutcome::Aborted))?; + } Ok(()) } } @@ -192,13 +207,24 @@ impl SessionPersister for ReceiverPersister { } fn close(&self) -> std::result::Result<(), Self::InternalStorageError> { - let conn = self.db.get_connection()?; - - conn.execute( - "UPDATE receive_sessions SET completed_at = ?1 WHERE session_id = ?2", - params![now(), *self.session_id], - )?; - + let already_closed = { + let conn = self.db.get_connection()?; + let last_event: Option = conn + .query_row( + "SELECT event_data FROM receive_session_events + WHERE session_id = ?1 ORDER BY id DESC LIMIT 1", + params![*self.session_id], + |row| row.get(0), + ) + .optional()?; + last_event + .and_then(|data| serde_json::from_str::(&data).ok()) + .map(|e| matches!(e, ReceiverSessionEvent::Closed(_))) + .unwrap_or(false) + }; + if !already_closed { + self.save_event(ReceiverSessionEvent::Closed(ReceiverSessionOutcome::Aborted))?; + } Ok(()) } } @@ -206,8 +232,7 @@ impl SessionPersister for ReceiverPersister { impl Database { pub(crate) fn get_recv_session_ids(&self) -> Result> { let conn = self.get_connection()?; - let mut stmt = - conn.prepare("SELECT session_id FROM receive_sessions WHERE completed_at IS NULL")?; + let mut stmt = conn.prepare("SELECT session_id FROM receive_sessions")?; let session_rows = stmt.query_map([], |row| { let session_id: i64 = row.get(0)?; @@ -225,8 +250,7 @@ impl Database { pub(crate) fn get_send_session_ids(&self) -> Result> { let conn = self.get_connection()?; - let mut stmt = - conn.prepare("SELECT session_id FROM send_sessions WHERE completed_at IS NULL")?; + let mut stmt = conn.prepare("SELECT session_id FROM send_sessions")?; let session_rows = stmt.query_map([], |row| { let session_id: i64 = row.get(0)?; @@ -253,44 +277,6 @@ impl Database { Ok(HpkePublicKey::from_compressed_bytes(&receiver_pubkey).expect("Valid receiver pubkey")) } - pub(crate) fn get_inactive_send_session_ids(&self) -> Result> { - let conn = self.get_connection()?; - let mut stmt = conn.prepare( - "SELECT session_id, completed_at FROM send_sessions WHERE completed_at IS NOT NULL", - )?; - let session_rows = stmt.query_map([], |row| { - let session_id: i64 = row.get(0)?; - let completed_at: u64 = row.get(1)?; - Ok((SessionId(session_id), completed_at)) - })?; - - let mut session_ids = Vec::new(); - for session_row in session_rows { - let (session_id, completed_at) = session_row?; - session_ids.push((session_id, completed_at)); - } - Ok(session_ids) - } - - pub(crate) fn get_inactive_recv_session_ids(&self) -> Result> { - let conn = self.get_connection()?; - let mut stmt = conn.prepare( - "SELECT session_id, completed_at FROM receive_sessions WHERE completed_at IS NOT NULL", - )?; - let session_rows = stmt.query_map([], |row| { - let session_id: i64 = row.get(0)?; - let completed_at: u64 = row.get(1)?; - Ok((SessionId(session_id), completed_at)) - })?; - - let mut session_ids = Vec::new(); - for session_row in session_rows { - let (session_id, completed_at) = session_row?; - session_ids.push((session_id, completed_at)); - } - Ok(session_ids) - } - /// Look up a sender session by ID regardless of active/inactive state. pub(crate) fn send_session_exists(&self, session_id: &SessionId) -> Result { let conn = self.get_connection()?; diff --git a/payjoin/src/core/error.rs b/payjoin/src/core/error.rs index e8f546a05..766c4687f 100644 --- a/payjoin/src/core/error.rs +++ b/payjoin/src/core/error.rs @@ -51,7 +51,7 @@ impl std::fmt::Display Some(session) => write!(f, "Invalid event ({event:?}) for session ({session:?})",), None => write!(f, "Invalid first event ({event:?}) for session",), }, - Expired(time) => write!(f, "Session expired at {time:?}"), + Expired(time, _) => write!(f, "Session expired at {time:?}"), PersistenceFailure(e) => write!(f, "Persistence failure: {e}"), } } @@ -73,7 +73,21 @@ impl From ReplayError { /// Returns `true` if the event log could not be replayed because the /// session has expired. - pub fn is_expired(&self) -> bool { matches!(self.0, InternalReplayError::Expired(_)) } + pub fn is_expired(&self) -> bool { matches!(self.0, InternalReplayError::Expired(..)) } + + /// Returns the partial session state when the event log was replayable + /// but the session could not be returned normally. This covers + /// `InvalidEvent` (a bad event mid-replay) and `Expired` (session + /// replayed successfully but is expired). Returns `None` when no + /// session could be reconstructed (load failure, no events, bad first + /// event). + pub fn into_session(self) -> Option { + match self.0 { + InternalReplayError::InvalidEvent(_, Some(session)) => Some(*session), + InternalReplayError::Expired(_, session) => Some(*session), + _ => None, + } + } } #[cfg(feature = "v2")] @@ -83,8 +97,10 @@ pub(crate) enum InternalReplayError { NoEvents, /// Invalid initial event InvalidEvent(Box, Option>), - /// Session is expired - Expired(crate::time::Time), + /// Session is expired. The partial session state is included so the + /// caller can gracefully fail (e.g. surface a fallback transaction) + /// before closing. + Expired(crate::time::Time, Box), /// Application storage error PersistenceFailure(ImplementationError), } @@ -96,7 +112,8 @@ mod tests { #[test] fn replay_error_is_expired() { - let expired: ReplayError<(), ()> = ReplayError(InternalReplayError::Expired(Time::now())); + let expired: ReplayError<(), ()> = + ReplayError(InternalReplayError::Expired(Time::now(), Box::new(()))); assert!(expired.is_expired()); let other: ReplayError<(), ()> = ReplayError(InternalReplayError::NoEvents); diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 81b872b13..10f54489c 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -264,6 +264,20 @@ impl ReceiveSession { .into()), } } + + /// Cancel the session, transitioning to [`PendingFallback`] if a fallback + /// transaction exists, or closing the session as `Aborted` otherwise. + /// This provides type-erased cancellation over the [`ReceiveSession`] enum, + /// dispatching to the appropriate per-state `cancel()` implementation. + pub fn cancel(self) -> MaybeTerminalTransition> { + match try_pending_fallback(self) { + Ok(ReceiveSession::PendingFallback(pending)) => + MaybeTerminalTransition::advance(SessionEvent::Cancelled, pending), + Ok(_) => unreachable!("try_pending_fallback only returns PendingFallback"), + Err(_) => + MaybeTerminalTransition::terminate(SessionEvent::Closed(SessionOutcome::Aborted)), + } + } } fn pending_fallback_from(r: Receiver) -> ReceiveSession { diff --git a/payjoin/src/core/receive/v2/session.rs b/payjoin/src/core/receive/v2/session.rs index b79431787..4d4f9e82e 100644 --- a/payjoin/src/core/receive/v2/session.rs +++ b/payjoin/src/core/receive/v2/session.rs @@ -26,17 +26,17 @@ fn replay_events( fn construct_history( session_events: Vec, - receiver: &ReceiveSession, -) -> Result> { + receiver: ReceiveSession, +) -> Result<(ReceiveSession, SessionHistory), ReplayError> { let history = SessionHistory::new(session_events); // Closed sessions terminated before expiration; do not surface an expired error for them. - if !matches!(receiver, ReceiveSession::Closed(_)) { + if !matches!(&receiver, ReceiveSession::Closed(_)) { let ctx = history.session_context(); if ctx.expiration.elapsed() { - return Err(InternalReplayError::Expired(ctx.expiration).into()); + return Err(InternalReplayError::Expired(ctx.expiration, Box::new(receiver)).into()); } } - Ok(history) + Ok((receiver, history)) } /// Replay a receiver event log to get the receiver in its current state [ReceiveSession] @@ -53,17 +53,9 @@ where .load() .map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?; - let (receiver, session_events) = match replay_events(logs.map(|e| e.into())) { - Ok(r) => r, - Err(e) => { - persister.close().map_err(|ce| { - InternalReplayError::PersistenceFailure(ImplementationError::new(ce)) - })?; - return Err(e); - } - }; + let (receiver, session_events) = replay_events(logs.map(|e| e.into()))?; - let history = construct_history(session_events, &receiver)?; + let (receiver, history) = construct_history(session_events, receiver)?; Ok((receiver, history)) } @@ -81,17 +73,9 @@ where .await .map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?; - let (receiver, session_events) = match replay_events(logs.map(|e| e.into())) { - Ok(r) => r, - Err(e) => { - persister.close().await.map_err(|ce| { - InternalReplayError::PersistenceFailure(ImplementationError::new(ce)) - })?; - return Err(e); - } - }; + let (receiver, session_events) = replay_events(logs.map(|e| e.into()))?; - let history = construct_history(session_events, &receiver)?; + let (receiver, history) = construct_history(session_events, receiver)?; Ok((receiver, history)) } @@ -394,18 +378,25 @@ mod tests { .save_event(SessionEvent::Created(session_context.clone())) .expect("in memory persister save should not fail"); let err = replay_event_log(&persister).expect_err("session should be expired"); - let expected_err: ReplayError = - InternalReplayError::Expired(expiration).into(); + let expected_err: ReplayError = InternalReplayError::Expired( + expiration, + Box::new(ReceiveSession::new(session_context)), + ) + .into(); assert_eq!(err.to_string(), expected_err.to_string()); + let session_context = SessionContext { expiration, ..SHARED_CONTEXT.clone() }; let persister = InMemoryAsyncPersister::::default(); persister - .save_event(SessionEvent::Created(session_context)) + .save_event(SessionEvent::Created(session_context.clone())) .await .expect("in memory async persister save should not fail"); let err = replay_event_log_async(&persister).await.expect_err("session should be expired"); - let expected_err: ReplayError = - InternalReplayError::Expired(expiration).into(); + let expected_err: ReplayError = InternalReplayError::Expired( + expiration, + Box::new(ReceiveSession::new(session_context)), + ) + .into(); assert_eq!(err.to_string(), expected_err.to_string()); } @@ -483,7 +474,6 @@ mod tests { ) .into(); assert_eq!(err.to_string(), expected_err.to_string()); - assert!(persister.inner.lock().expect("lock should not be poisoned").is_closed); let persister = InMemoryAsyncPersister::::default(); persister @@ -500,7 +490,6 @@ mod tests { ) .into(); assert_eq!(err.to_string(), expected_err.to_string()); - assert!(persister.inner.lock().await.is_closed); } #[tokio::test] diff --git a/payjoin/src/core/send/v2/mod.rs b/payjoin/src/core/send/v2/mod.rs index e2dda4a11..938fdc3c8 100644 --- a/payjoin/src/core/send/v2/mod.rs +++ b/payjoin/src/core/send/v2/mod.rs @@ -46,8 +46,8 @@ use crate::error::{InternalReplayError, ReplayError}; use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeSecretKey}; use crate::ohttp::{ohttp_encapsulate, process_get_res, process_post_res}; use crate::persist::{ - MaybeFatalTransition, MaybeSuccessTransitionWithNoResults, NextStateTransition, - TerminalTransition, + MaybeFatalTransition, MaybeSuccessTransitionWithNoResults, MaybeTerminalTransition, + NextStateTransition, TerminalTransition, }; use crate::uri::v2::PjParam; use crate::uri::ShortId; @@ -327,6 +327,27 @@ impl SendSession { .into()), } } + + /// Cancel the session, transitioning to [`PendingFallback`] if the sender + /// has an original PSBT, or closing the session as `Aborted` if already + /// closed. This provides type-erased cancellation over the [`SendSession`] + /// enum. + pub fn cancel(self) -> MaybeTerminalTransition> { + match self { + SendSession::WithReplyKey(sender) => { + let (_, pending) = sender.cancel().deconstruct(); + MaybeTerminalTransition::advance(SessionEvent::Cancelled(), pending) + } + SendSession::PollingForProposal(sender) => { + let (_, pending) = sender.cancel().deconstruct(); + MaybeTerminalTransition::advance(SessionEvent::Cancelled(), pending) + } + SendSession::PendingFallback(sender) => + MaybeTerminalTransition::advance(SessionEvent::Cancelled(), sender), + SendSession::Closed(_) => + MaybeTerminalTransition::terminate(SessionEvent::Closed(SessionOutcome::Aborted)), + } + } } /// A payjoin V2 sender, allowing the construction of a payjoin V2 request diff --git a/payjoin/src/core/send/v2/session.rs b/payjoin/src/core/send/v2/session.rs index fc6955951..8b8e8f28a 100644 --- a/payjoin/src/core/send/v2/session.rs +++ b/payjoin/src/core/send/v2/session.rs @@ -23,17 +23,19 @@ fn replay_events( fn construct_history( session_events: Vec, - sender: &SendSession, -) -> Result> { + sender: SendSession, +) -> Result<(SendSession, SessionHistory), ReplayError> { let history = SessionHistory::new(session_events); // Closed sessions terminated before expiration; do not surface an expired error for them. - if !matches!(sender, SendSession::Closed(_)) { + if !matches!(&sender, SendSession::Closed(_)) { let pj_param = history.pj_param(); if pj_param.expiration().elapsed() { - return Err(InternalReplayError::Expired(pj_param.expiration()).into()); + return Err( + InternalReplayError::Expired(pj_param.expiration(), Box::new(sender)).into() + ); } } - Ok(history) + Ok((sender, history)) } /// Replay a sender event log to get the sender in its current state [SendSession] @@ -50,17 +52,9 @@ where .load() .map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?; - let (sender, session_events) = match replay_events(logs.map(|e| e.into())) { - Ok(r) => r, - Err(e) => { - persister.close().map_err(|ce| { - InternalReplayError::PersistenceFailure(ImplementationError::new(ce)) - })?; - return Err(e); - } - }; + let (sender, session_events) = replay_events(logs.map(|e| e.into()))?; - let history = construct_history(session_events, &sender)?; + let (sender, history) = construct_history(session_events, sender)?; Ok((sender, history)) } @@ -78,17 +72,9 @@ where .await .map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?; - let (sender, session_events) = match replay_events(logs.map(|e| e.into())) { - Ok(r) => r, - Err(e) => { - persister.close().await.map_err(|ce| { - InternalReplayError::PersistenceFailure(ImplementationError::new(ce)) - })?; - return Err(e); - } - }; + let (sender, session_events) = replay_events(logs.map(|e| e.into()))?; - let history = construct_history(session_events, &sender)?; + let (sender, history) = construct_history(session_events, sender)?; Ok((sender, history)) } @@ -513,26 +499,22 @@ mod tests { persister .save_event(SessionEvent::PostedOriginalPsbt()) .expect("in memory persister save should not fail"); - assert!(!persister.inner.lock().expect("session read should succeed").is_closed); let err = replay_event_log(&persister).expect_err("session replay should be fail"); let expected_err: ReplayError = InternalReplayError::InvalidEvent(Box::new(SessionEvent::PostedOriginalPsbt()), None) .into(); assert_eq!(err.to_string(), expected_err.to_string()); - assert!(persister.inner.lock().expect("lock should not be poisoned").is_closed); let persister = InMemoryAsyncPersister::::default(); persister .save_event(SessionEvent::PostedOriginalPsbt()) .await .expect("in memory async persister save should not fail"); - assert!(!persister.inner.lock().await.is_closed); let err = replay_event_log_async(&persister).await.expect_err("session replay should be fail"); let expected_err: ReplayError = InternalReplayError::InvalidEvent(Box::new(SessionEvent::PostedOriginalPsbt()), None) .into(); assert_eq!(err.to_string(), expected_err.to_string()); - assert!(persister.inner.lock().await.is_closed); } }