Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 170 additions & 145 deletions payjoin-cli/src/app/v2/mod.rs

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions payjoin-cli/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,14 @@ impl Database {
"CREATE TABLE IF NOT EXISTS send_sessions (
session_id INTEGER PRIMARY KEY AUTOINCREMENT,
pj_uri TEXT NOT NULL,
receiver_pubkey BLOB NOT NULL,
completed_at INTEGER
receiver_pubkey BLOB NOT NULL
)",
[],
)?;

conn.execute(
"CREATE TABLE IF NOT EXISTS receive_sessions (
session_id INTEGER PRIMARY KEY AUTOINCREMENT,
completed_at INTEGER
session_id INTEGER PRIMARY KEY AUTOINCREMENT
)",
[],
)?;
Expand Down
104 changes: 45 additions & 59 deletions payjoin-cli/src/db/v2.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use std::sync::Arc;

use payjoin::persist::SessionPersister;
use payjoin::receive::v2::SessionEvent as ReceiverSessionEvent;
use payjoin::send::v2::SessionEvent as SenderSessionEvent;
use payjoin::receive::v2::{
SessionEvent as ReceiverSessionEvent, SessionOutcome as ReceiverSessionOutcome,
};
use payjoin::send::v2::{
SessionEvent as SenderSessionEvent, SessionOutcome as SenderSessionOutcome,
};
use payjoin::HpkePublicKey;
use rusqlite::params;
use rusqlite::{params, OptionalExtension};

use super::*;

Expand Down Expand Up @@ -109,13 +113,24 @@ impl SessionPersister for SenderPersister {
}

fn close(&self) -> std::result::Result<(), Self::InternalStorageError> {
let conn = self.db.get_connection()?;

conn.execute(
"UPDATE send_sessions SET completed_at = ?1 WHERE session_id = ?2",
params![now(), *self.session_id],
)?;

let already_closed = {
let conn = self.db.get_connection()?;
let last_event: Option<String> = conn
.query_row(
"SELECT event_data FROM send_session_events
WHERE session_id = ?1 ORDER BY id DESC LIMIT 1",
params![*self.session_id],
|row| row.get(0),
)
.optional()?;
last_event
.and_then(|data| serde_json::from_str::<SenderSessionEvent>(&data).ok())
.map(|e| matches!(e, SenderSessionEvent::Closed(_)))
.unwrap_or(false)
};
if !already_closed {
self.save_event(SenderSessionEvent::Closed(SenderSessionOutcome::Aborted))?;
}
Ok(())
}
}
Expand Down Expand Up @@ -192,22 +207,32 @@ impl SessionPersister for ReceiverPersister {
}

fn close(&self) -> std::result::Result<(), Self::InternalStorageError> {
let conn = self.db.get_connection()?;

conn.execute(
"UPDATE receive_sessions SET completed_at = ?1 WHERE session_id = ?2",
params![now(), *self.session_id],
)?;

let already_closed = {
let conn = self.db.get_connection()?;
let last_event: Option<String> = conn
.query_row(
"SELECT event_data FROM receive_session_events
WHERE session_id = ?1 ORDER BY id DESC LIMIT 1",
params![*self.session_id],
|row| row.get(0),
)
.optional()?;
last_event
.and_then(|data| serde_json::from_str::<ReceiverSessionEvent>(&data).ok())
.map(|e| matches!(e, ReceiverSessionEvent::Closed(_)))
.unwrap_or(false)
};
if !already_closed {
self.save_event(ReceiverSessionEvent::Closed(ReceiverSessionOutcome::Aborted))?;
}
Ok(())
}
}

impl Database {
pub(crate) fn get_recv_session_ids(&self) -> Result<Vec<SessionId>> {
let conn = self.get_connection()?;
let mut stmt =
conn.prepare("SELECT session_id FROM receive_sessions WHERE completed_at IS NULL")?;
let mut stmt = conn.prepare("SELECT session_id FROM receive_sessions")?;

let session_rows = stmt.query_map([], |row| {
let session_id: i64 = row.get(0)?;
Expand All @@ -225,8 +250,7 @@ impl Database {

pub(crate) fn get_send_session_ids(&self) -> Result<Vec<SessionId>> {
let conn = self.get_connection()?;
let mut stmt =
conn.prepare("SELECT session_id FROM send_sessions WHERE completed_at IS NULL")?;
let mut stmt = conn.prepare("SELECT session_id FROM send_sessions")?;

let session_rows = stmt.query_map([], |row| {
let session_id: i64 = row.get(0)?;
Expand All @@ -253,44 +277,6 @@ impl Database {
Ok(HpkePublicKey::from_compressed_bytes(&receiver_pubkey).expect("Valid receiver pubkey"))
}

pub(crate) fn get_inactive_send_session_ids(&self) -> Result<Vec<(SessionId, u64)>> {
let conn = self.get_connection()?;
let mut stmt = conn.prepare(
"SELECT session_id, completed_at FROM send_sessions WHERE completed_at IS NOT NULL",
)?;
let session_rows = stmt.query_map([], |row| {
let session_id: i64 = row.get(0)?;
let completed_at: u64 = row.get(1)?;
Ok((SessionId(session_id), completed_at))
})?;

let mut session_ids = Vec::new();
for session_row in session_rows {
let (session_id, completed_at) = session_row?;
session_ids.push((session_id, completed_at));
}
Ok(session_ids)
}

pub(crate) fn get_inactive_recv_session_ids(&self) -> Result<Vec<(SessionId, u64)>> {
let conn = self.get_connection()?;
let mut stmt = conn.prepare(
"SELECT session_id, completed_at FROM receive_sessions WHERE completed_at IS NOT NULL",
)?;
let session_rows = stmt.query_map([], |row| {
let session_id: i64 = row.get(0)?;
let completed_at: u64 = row.get(1)?;
Ok((SessionId(session_id), completed_at))
})?;

let mut session_ids = Vec::new();
for session_row in session_rows {
let (session_id, completed_at) = session_row?;
session_ids.push((session_id, completed_at));
}
Ok(session_ids)
}

/// Look up a sender session by ID regardless of active/inactive state.
pub(crate) fn send_session_exists(&self, session_id: &SessionId) -> Result<bool> {
let conn = self.get_connection()?;
Expand Down
27 changes: 22 additions & 5 deletions payjoin/src/core/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl<SessionState: Debug, SessionEvent: Debug> std::fmt::Display
Some(session) => write!(f, "Invalid event ({event:?}) for session ({session:?})",),
None => write!(f, "Invalid first event ({event:?}) for session",),
},
Expired(time) => write!(f, "Session expired at {time:?}"),
Expired(time, _) => write!(f, "Session expired at {time:?}"),
PersistenceFailure(e) => write!(f, "Persistence failure: {e}"),
}
}
Expand All @@ -73,7 +73,21 @@ impl<SessionState: Debug, SessionEvent: Debug> From<InternalReplayError<SessionS
impl<SessionState, SessionEvent> ReplayError<SessionState, SessionEvent> {
/// Returns `true` if the event log could not be replayed because the
/// session has expired.
pub fn is_expired(&self) -> bool { matches!(self.0, InternalReplayError::Expired(_)) }
pub fn is_expired(&self) -> bool { matches!(self.0, InternalReplayError::Expired(..)) }

/// Returns the partial session state when the event log was replayable
/// but the session could not be returned normally. This covers
/// `InvalidEvent` (a bad event mid-replay) and `Expired` (session
/// replayed successfully but is expired). Returns `None` when no
/// session could be reconstructed (load failure, no events, bad first
/// event).
pub fn into_session(self) -> Option<SessionState> {
match self.0 {
InternalReplayError::InvalidEvent(_, Some(session)) => Some(*session),
InternalReplayError::Expired(_, session) => Some(*session),
_ => None,
}
}
}

#[cfg(feature = "v2")]
Expand All @@ -83,8 +97,10 @@ pub(crate) enum InternalReplayError<SessionState, SessionEvent> {
NoEvents,
/// Invalid initial event
InvalidEvent(Box<SessionEvent>, Option<Box<SessionState>>),
/// Session is expired
Expired(crate::time::Time),
/// Session is expired. The partial session state is included so the
/// caller can gracefully fail (e.g. surface a fallback transaction)
/// before closing.
Expired(crate::time::Time, Box<SessionState>),
/// Application storage error
PersistenceFailure(ImplementationError),
}
Expand All @@ -96,7 +112,8 @@ mod tests {

#[test]
fn replay_error_is_expired() {
let expired: ReplayError<(), ()> = ReplayError(InternalReplayError::Expired(Time::now()));
let expired: ReplayError<(), ()> =
ReplayError(InternalReplayError::Expired(Time::now(), Box::new(())));
assert!(expired.is_expired());

let other: ReplayError<(), ()> = ReplayError(InternalReplayError::NoEvents);
Expand Down
14 changes: 14 additions & 0 deletions payjoin/src/core/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,20 @@ impl ReceiveSession {
.into()),
}
}

/// Cancel the session, transitioning to [`PendingFallback`] if a fallback
/// transaction exists, or closing the session as `Aborted` otherwise.
/// This provides type-erased cancellation over the [`ReceiveSession`] enum,
/// dispatching to the appropriate per-state `cancel()` implementation.
pub fn cancel(self) -> MaybeTerminalTransition<SessionEvent, Receiver<PendingFallback>> {
match try_pending_fallback(self) {
Ok(ReceiveSession::PendingFallback(pending)) =>
MaybeTerminalTransition::advance(SessionEvent::Cancelled, pending),
Ok(_) => unreachable!("try_pending_fallback only returns PendingFallback"),
Err(_) =>
MaybeTerminalTransition::terminate(SessionEvent::Closed(SessionOutcome::Aborted)),
}
}
}

fn pending_fallback_from<S: HasFallbackTx>(r: Receiver<S>) -> ReceiveSession {
Expand Down
53 changes: 21 additions & 32 deletions payjoin/src/core/receive/v2/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ fn replay_events(

fn construct_history(
session_events: Vec<SessionEvent>,
receiver: &ReceiveSession,
) -> Result<SessionHistory, ReplayError<ReceiveSession, SessionEvent>> {
receiver: ReceiveSession,
) -> Result<(ReceiveSession, SessionHistory), ReplayError<ReceiveSession, SessionEvent>> {
let history = SessionHistory::new(session_events);
// Closed sessions terminated before expiration; do not surface an expired error for them.
if !matches!(receiver, ReceiveSession::Closed(_)) {
if !matches!(&receiver, ReceiveSession::Closed(_)) {
let ctx = history.session_context();
if ctx.expiration.elapsed() {
return Err(InternalReplayError::Expired(ctx.expiration).into());
return Err(InternalReplayError::Expired(ctx.expiration, Box::new(receiver)).into());
}
}
Ok(history)
Ok((receiver, history))
}

/// Replay a receiver event log to get the receiver in its current state [ReceiveSession]
Expand All @@ -53,17 +53,9 @@ where
.load()
.map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?;

let (receiver, session_events) = match replay_events(logs.map(|e| e.into())) {
Ok(r) => r,
Err(e) => {
persister.close().map_err(|ce| {
InternalReplayError::PersistenceFailure(ImplementationError::new(ce))
})?;
return Err(e);
}
};
let (receiver, session_events) = replay_events(logs.map(|e| e.into()))?;

let history = construct_history(session_events, &receiver)?;
let (receiver, history) = construct_history(session_events, receiver)?;
Ok((receiver, history))
}

Expand All @@ -81,17 +73,9 @@ where
.await
.map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?;

let (receiver, session_events) = match replay_events(logs.map(|e| e.into())) {
Ok(r) => r,
Err(e) => {
persister.close().await.map_err(|ce| {
InternalReplayError::PersistenceFailure(ImplementationError::new(ce))
})?;
return Err(e);
}
};
let (receiver, session_events) = replay_events(logs.map(|e| e.into()))?;

let history = construct_history(session_events, &receiver)?;
let (receiver, history) = construct_history(session_events, receiver)?;
Ok((receiver, history))
}

Expand Down Expand Up @@ -394,18 +378,25 @@ mod tests {
.save_event(SessionEvent::Created(session_context.clone()))
.expect("in memory persister save should not fail");
let err = replay_event_log(&persister).expect_err("session should be expired");
let expected_err: ReplayError<ReceiveSession, SessionEvent> =
InternalReplayError::Expired(expiration).into();
let expected_err: ReplayError<ReceiveSession, SessionEvent> = InternalReplayError::Expired(
expiration,
Box::new(ReceiveSession::new(session_context)),
)
.into();
assert_eq!(err.to_string(), expected_err.to_string());

let session_context = SessionContext { expiration, ..SHARED_CONTEXT.clone() };
let persister = InMemoryAsyncPersister::<SessionEvent>::default();
persister
.save_event(SessionEvent::Created(session_context))
.save_event(SessionEvent::Created(session_context.clone()))
.await
.expect("in memory async persister save should not fail");
let err = replay_event_log_async(&persister).await.expect_err("session should be expired");
let expected_err: ReplayError<ReceiveSession, SessionEvent> =
InternalReplayError::Expired(expiration).into();
let expected_err: ReplayError<ReceiveSession, SessionEvent> = InternalReplayError::Expired(
expiration,
Box::new(ReceiveSession::new(session_context)),
)
.into();
assert_eq!(err.to_string(), expected_err.to_string());
}

Expand Down Expand Up @@ -483,7 +474,6 @@ mod tests {
)
.into();
assert_eq!(err.to_string(), expected_err.to_string());
assert!(persister.inner.lock().expect("lock should not be poisoned").is_closed);

let persister = InMemoryAsyncPersister::<SessionEvent>::default();
persister
Expand All @@ -500,7 +490,6 @@ mod tests {
)
.into();
assert_eq!(err.to_string(), expected_err.to_string());
assert!(persister.inner.lock().await.is_closed);
}

#[tokio::test]
Expand Down
25 changes: 23 additions & 2 deletions payjoin/src/core/send/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ use crate::error::{InternalReplayError, ReplayError};
use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeSecretKey};
use crate::ohttp::{ohttp_encapsulate, process_get_res, process_post_res};
use crate::persist::{
MaybeFatalTransition, MaybeSuccessTransitionWithNoResults, NextStateTransition,
TerminalTransition,
MaybeFatalTransition, MaybeSuccessTransitionWithNoResults, MaybeTerminalTransition,
NextStateTransition, TerminalTransition,
};
use crate::uri::v2::PjParam;
use crate::uri::ShortId;
Expand Down Expand Up @@ -327,6 +327,27 @@ impl SendSession {
.into()),
}
}

/// Cancel the session, transitioning to [`PendingFallback`] if the sender
/// has an original PSBT, or closing the session as `Aborted` if already
/// closed. This provides type-erased cancellation over the [`SendSession`]
/// enum.
pub fn cancel(self) -> MaybeTerminalTransition<SessionEvent, Sender<PendingFallback>> {
match self {
SendSession::WithReplyKey(sender) => {
let (_, pending) = sender.cancel().deconstruct();
MaybeTerminalTransition::advance(SessionEvent::Cancelled(), pending)
}
SendSession::PollingForProposal(sender) => {
let (_, pending) = sender.cancel().deconstruct();
MaybeTerminalTransition::advance(SessionEvent::Cancelled(), pending)
}
SendSession::PendingFallback(sender) =>
MaybeTerminalTransition::advance(SessionEvent::Cancelled(), sender),
SendSession::Closed(_) =>
MaybeTerminalTransition::terminate(SessionEvent::Closed(SessionOutcome::Aborted)),
}
}
}

/// A payjoin V2 sender, allowing the construction of a payjoin V2 request
Expand Down
Loading
Loading