diff --git a/Cargo.lock b/Cargo.lock index ac7c1ea5..85d0ad3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1689,6 +1689,15 @@ dependencies = [ "syn 2.0.38", ] +[[package]] +name = "futures-option" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01141bf1c1a803403a2e7ae6a7eb20506c6bd2ebd209fc2424d05572a78af5f5" +dependencies = [ + "futures-core", +] + [[package]] name = "futures-sink" version = "0.3.28" @@ -3663,6 +3672,7 @@ dependencies = [ "fallible-iterator 0.3.0", "futures", "futures-core", + "futures-option", "hmac", "hyper", "hyper-rustls 0.24.1 (git+https://github.com/rustls/hyper-rustls.git?rev=163b3f5)", diff --git a/sqld/Cargo.toml b/sqld/Cargo.toml index 8cb7f8ea..a06b491e 100644 --- a/sqld/Cargo.toml +++ b/sqld/Cargo.toml @@ -71,6 +71,7 @@ rustls-pemfile = "1.0.3" rustls = "0.21.7" async-stream = "0.3.5" libsql = { git = "https://github.com/tursodatabase/libsql.git", rev = "bea8863", optional = true } +futures-option = "0.2.0" [dev-dependencies] proptest = "1.0.0" diff --git a/sqld/proto/proxy.proto b/sqld/proto/proxy.proto index 065c95a2..d85bbeb0 100644 --- a/sqld/proto/proxy.proto +++ b/sqld/proto/proxy.proto @@ -4,7 +4,7 @@ package proxy; message Queries { repeated Query queries = 1; // Uuid - string clientId = 2; + string client_id = 2; } message Query { @@ -34,10 +34,10 @@ message QueryResult { message Error { enum ErrorCode { - SQLError = 0; - TxBusy = 1; - TxTimeout = 2; - Internal = 3; + SQL_ERROR = 0; + TX_BUSY = 1; + TX_TIMEOUT = 2; + INTERNAL = 3; } ErrorCode code = 1; @@ -71,7 +71,7 @@ message Description { message Value { /// bincode encoded Value - bytes data = 1; + bytes data = 1; } message Row { @@ -84,18 +84,19 @@ message Column { } message DisconnectMessage { - string clientId = 1; + string client_id = 1; } message Ack { } +enum State { + INIT = 0; + INVALID = 1; + TXN = 2; +} + message ExecuteResults { repeated QueryResult results = 1; - enum State { - Init = 0; - Invalid = 1; - Txn = 
2; - } /// State after executing the queries State state = 2; /// Primary frame_no after executing the request. @@ -110,7 +111,6 @@ message Step { optional Cond cond = 1; Query query = 2; } - message Cond { oneof cond { OkCond ok = 1; @@ -150,7 +150,113 @@ message ProgramReq { Program pgm = 2; } +/// Streaming exec request +message ExecReq { + /// id of the request. The response will contain this id. + uint32 request_id = 1; + oneof request { + StreamProgramReq execute = 2; + StreamDescribeReq describe = 3; + } +} + +/// Program execution request for the streaming protocol +message StreamProgramReq { + Program pgm = 1; +} + +/// Describe request for the streaming protocol +message StreamDescribeReq { + string stmt = 1; +} + +/// Response messages for the streaming protocol + +/// Request response types +message Init { } +message BeginStep { } +message FinishStep { + uint64 affected_row_count = 1; + optional int64 last_insert_rowid = 2; +} +message StepError { + Error error = 1; +} +message ColsDescription { + repeated Column columns = 1; +} +message RowValue { + oneof value { + string text = 1; + int64 integer = 2; + double real = 3; + bytes blob = 4; + // null if present + bool null = 5; + } +} +message BeginRows { } +message BeginRow { } +message AddRowValue { + RowValue val = 1; +} +message FinishRow { } +message FinishRows { } +message Finish { + optional uint64 last_frame_no = 1; + State state = 2; +} + +/// Stream exec describe response messages +message DescribeParam { + optional string name = 1; +} + +message DescribeCol { + string name = 1; + optional string decltype = 2; +} + +message DescribeResp { + repeated DescribeParam params = 1; + repeated DescribeCol cols = 2; + bool is_explain = 3; + bool is_readonly = 4; +} + +message RespStep { + oneof step { + Init init = 1; + BeginStep begin_step = 2; + FinishStep finish_step = 3; + StepError step_error = 4; + ColsDescription cols_description = 5; + BeginRows begin_rows = 6; + BeginRow begin_row = 7; + AddRowValue 
add_row_value = 8; + FinishRow finish_row = 9; + FinishRows finish_rows = 10; + Finish finish = 11; + } +} + +message ProgramResp { + repeated RespStep steps = 1; +} + +message ExecResp { + uint32 request_id = 1; + oneof response { + ProgramResp program_resp = 2; + DescribeResp describe_resp = 3; + Error error = 4; + } +} + service Proxy { + rpc StreamExec(stream ExecReq) returns (stream ExecResp) {} + + // Deprecated: rpc Execute(ProgramReq) returns (ExecuteResults) {} rpc Describe(DescribeRequest) returns (DescribeResult) {} rpc Disconnect(DisconnectMessage) returns (Ack) {} diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs index e70e66d3..392e9ee0 100644 --- a/sqld/src/connection/libsql.rs +++ b/sqld/src/connection/libsql.rs @@ -4,7 +4,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use parking_lot::{Mutex, RwLock}; -use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus}; +use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus, TransactionState}; use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalMethodsHook}; use tokio::sync::{watch, Notify}; use tokio::time::{Duration, Instant}; @@ -13,14 +13,14 @@ use crate::auth::{Authenticated, Authorized, Permission}; use crate::error::Error; use crate::libsql_bindings::wal_hook::WalHook; use crate::query::Query; -use crate::query_analysis::{State, StmtKind}; +use crate::query_analysis::{StmtKind, TxnStatus}; use crate::query_result_builder::{QueryBuilderConfig, QueryResultBuilder}; use crate::replication::FrameNo; use crate::stats::Stats; use crate::Result; use super::config::DatabaseConfigStore; -use super::program::{Cond, DescribeCol, DescribeParam, DescribeResponse, DescribeResult}; +use super::program::{Cond, DescribeCol, DescribeParam, DescribeResponse}; use super::{MakeConnection, Program, Step, TXN_TIMEOUT}; pub struct MakeLibSqlConn { @@ -144,7 +144,6 @@ where } } -#[derive(Clone)] pub struct LibSqlConnection { inner: Arc>>, } @@ 
-160,6 +159,14 @@ impl std::fmt::Debug for LibSqlConnection { } } +impl Clone for LibSqlConnection { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + pub fn open_conn( path: &Path, wal_methods: &'static WalMethodsHook, @@ -219,6 +226,38 @@ where inner: Arc::new(Mutex::new(conn)), }) } + + pub fn txn_status(&self) -> crate::Result { + Ok(self + .inner + .lock() + .conn + .transaction_state(Some(DatabaseName::Main))? + .into()) + } +} + +#[cfg(test)] +impl LibSqlConnection { + pub fn new_test(path: &Path) -> Self { + let (_snd, rcv) = watch::channel(None); + let conn = Connection::new( + path, + Arc::new([]), + &crate::libsql_bindings::wal_hook::TRANSPARENT_METHODS, + (), + Default::default(), + DatabaseConfigStore::new_test().into(), + QueryBuilderConfig::default(), + rcv, + Default::default(), + ) + .unwrap(); + + Self { + inner: Arc::new(Mutex::new(conn)), + } + } } struct Connection { @@ -351,6 +390,16 @@ unsafe extern "C" fn busy_handler(state: *mut c_void, _retries: c_in }) } +impl From for TxnStatus { + fn from(value: TransactionState) -> Self { + use TransactionState as Tx; + match value { + Tx::None => TxnStatus::Init, + Tx::Read | Tx::Write => TxnStatus::Txn, + _ => unreachable!(), + } + } +} impl Connection { fn new( path: &Path, @@ -405,7 +454,7 @@ impl Connection { this: Arc>, pgm: Program, mut builder: B, - ) -> Result<(B, State)> { + ) -> Result { use rusqlite::TransactionState as Tx; let state = this.lock().state.clone(); @@ -469,20 +518,18 @@ impl Connection { results.push(res); } - builder.finish(*this.lock().current_frame_no_receiver.borrow_and_update())?; + let status = this + .lock() + .conn + .transaction_state(Some(DatabaseName::Main))? 
+ .into(); - let state = if matches!( - this.lock() - .conn - .transaction_state(Some(DatabaseName::Main))?, - Tx::Read | Tx::Write - ) { - State::Txn - } else { - State::Init - }; + builder.finish( + *this.lock().current_frame_no_receiver.borrow_and_update(), + status, + )?; - Ok((builder, state)) + Ok(builder) } fn execute_step( @@ -628,7 +675,7 @@ impl Connection { self.stats.inc_rows_written(rows_written as u64); } - fn describe(&self, sql: &str) -> DescribeResult { + fn describe(&self, sql: &str) -> crate::Result { let stmt = self.conn.prepare(sql)?; let params = (1..=stmt.parameter_count()) @@ -733,7 +780,7 @@ where auth: Authenticated, builder: B, _replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result { check_program_auth(auth, &pgm)?; let conn = self.inner.clone(); tokio::task::spawn_blocking(move || Connection::run(conn, pgm, builder)) @@ -746,7 +793,7 @@ where sql: String, auth: Authenticated, _replication_index: Option, - ) -> Result { + ) -> Result> { check_describe_auth(auth)?; let conn = self.inner.clone(); let res = tokio::task::spawn_blocking(move || conn.lock().describe(&sql)) @@ -825,7 +872,7 @@ mod test { fn test_libsql_conn_builder_driver() { test_driver(1000, |b| { let conn = setup_test_conn(); - Connection::run(conn, Program::seq(&["select * from test"]), b).map(|x| x.0) + Connection::run(conn, Program::seq(&["select * from test"]), b) }) } @@ -849,23 +896,23 @@ mod test { tokio::time::pause(); let conn = make_conn.make_connection().await.unwrap(); - let (_builder, state) = Connection::run( + let _builder = Connection::run( conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); tokio::time::advance(TXN_TIMEOUT * 2).await; - let (builder, state) = Connection::run( + let builder = Connection::run( conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - 
assert_eq!(state, State::Init); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Init); assert!(matches!(builder.into_ret()[0], Err(Error::LibSqlTxTimeout))); } @@ -893,13 +940,13 @@ mod test { for _ in 0..10 { let conn = make_conn.make_connection().await.unwrap(); set.spawn_blocking(move || { - let (builder, state) = Connection::run( - conn.inner, + let builder = Connection::run( + conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); }); } @@ -934,15 +981,15 @@ mod test { let conn1 = make_conn.make_connection().await.unwrap(); tokio::task::spawn_blocking({ - let conn = conn1.inner.clone(); + let conn = conn1.clone(); move || { - let (builder, state) = Connection::run( - conn, + let builder = Connection::run( + conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); } }) @@ -951,16 +998,16 @@ mod test { let conn2 = make_conn.make_connection().await.unwrap(); let handle = tokio::task::spawn_blocking({ - let conn = conn2.inner.clone(); + let conn = conn2.clone(); move || { let before = Instant::now(); - let (builder, state) = Connection::run( - conn, + let builder = Connection::run( + conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); before.elapsed() } @@ -970,12 +1017,15 @@ mod test { tokio::time::sleep(wait_time).await; tokio::task::spawn_blocking({ - let conn = conn1.inner.clone(); + let conn = conn1.clone(); move || { - let (builder, state) = - Connection::run(conn, Program::seq(&["COMMIT"]), TestBuilder::default()) - .unwrap(); - assert_eq!(state, 
State::Init); + let builder = Connection::run( + conn.inner.clone(), + Program::seq(&["COMMIT"]), + TestBuilder::default(), + ) + .unwrap(); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Init); assert!(builder.into_ret()[0].is_ok()); } }) diff --git a/sqld/src/connection/mod.rs b/sqld/src/connection/mod.rs index 2d539057..98e5ffc3 100644 --- a/sqld/src/connection/mod.rs +++ b/sqld/src/connection/mod.rs @@ -8,12 +8,12 @@ use tokio::{sync::Semaphore, time::timeout}; use crate::auth::Authenticated; use crate::error::Error; use crate::query::{Params, Query}; -use crate::query_analysis::{State, Statement}; +use crate::query_analysis::Statement; use crate::query_result_builder::{IgnoreResult, QueryResultBuilder}; use crate::replication::FrameNo; use crate::Result; -use self::program::{Cond, DescribeResult, Program, Step}; +use self::program::{Cond, DescribeResponse, Program, Step}; pub mod config; pub mod dump; @@ -32,7 +32,7 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, response_builder: B, replication_index: Option, - ) -> Result<(B, State)>; + ) -> Result; /// Execute all the queries in the batch sequentially. /// If an query in the batch fails, the remaining queries are ignores, and the batch current @@ -43,7 +43,7 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, result_builder: B, replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result { let batch_len = batch.len(); let mut steps = make_batch_program(batch); @@ -67,11 +67,11 @@ pub trait Connection: Send + Sync + 'static { // ignore the rollback result let builder = result_builder.take(batch_len); - let (builder, state) = self + let builder = self .execute_program(pgm, auth, builder, replication_index) .await?; - Ok((builder.into_inner(), state)) + Ok(builder.into_inner()) } /// Execute all the queries in the batch sequentially. 
@@ -82,7 +82,7 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, result_builder: B, replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result { let steps = make_batch_program(batch); let pgm = Program::new(steps); self.execute_program(pgm, auth, result_builder, replication_index) @@ -111,7 +111,7 @@ pub trait Connection: Send + Sync + 'static { sql: String, auth: Authenticated, replication_index: Option, - ) -> Result; + ) -> Result>; /// Check whether the connection is in autocommit mode. async fn is_autocommit(&self) -> Result; @@ -312,7 +312,7 @@ impl Connection for TrackedConnection { auth: Authenticated, builder: B, replication_index: Option, - ) -> crate::Result<(B, State)> { + ) -> crate::Result { self.atime.store(now_millis(), Ordering::Relaxed); self.inner .execute_program(pgm, auth, builder, replication_index) @@ -325,7 +325,7 @@ impl Connection for TrackedConnection { sql: String, auth: Authenticated, replication_index: Option, - ) -> crate::Result { + ) -> crate::Result> { self.atime.store(now_millis(), Ordering::Relaxed); self.inner.describe(sql, auth, replication_index).await } @@ -353,7 +353,7 @@ impl Connection for TrackedConnection { } #[cfg(test)] -mod test { +pub mod test { use super::*; #[derive(Debug)] @@ -367,7 +367,7 @@ mod test { _auth: Authenticated, _builder: B, _replication_index: Option, - ) -> crate::Result<(B, State)> { + ) -> crate::Result { unreachable!() } @@ -376,7 +376,7 @@ mod test { _sql: String, _auth: Authenticated, _replication_index: Option, - ) -> crate::Result { + ) -> crate::Result> { unreachable!() } diff --git a/sqld/src/connection/program.rs b/sqld/src/connection/program.rs index fabfbd18..3017232a 100644 --- a/sqld/src/connection/program.rs +++ b/sqld/src/connection/program.rs @@ -60,8 +60,6 @@ pub enum Cond { IsAutocommit, } -pub type DescribeResult = crate::Result; - #[derive(Debug, Clone)] pub struct DescribeResponse { pub params: Vec, diff --git a/sqld/src/connection/write_proxy.rs 
b/sqld/src/connection/write_proxy.rs index 26be6ef9..be030fef 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -1,34 +1,32 @@ use std::path::PathBuf; use std::sync::Arc; +use futures_core::future::BoxFuture; +use futures_core::Stream; use parking_lot::Mutex as PMutex; -use rusqlite::types::ValueRef; use sqld_libsql_bindings::wal_hook::{TransparentMethods, TRANSPARENT_METHODS}; -use tokio::sync::{watch, Mutex}; +use tokio::sync::{mpsc, watch, Mutex}; +use tokio_stream::StreamExt; use tonic::metadata::BinaryMetadataValue; use tonic::transport::Channel; -use tonic::Request; -use uuid::Uuid; +use tonic::{Request, Streaming}; use crate::auth::Authenticated; +use crate::connection::program::{DescribeCol, DescribeParam}; use crate::error::Error; use crate::namespace::NamespaceName; -use crate::query::Value; -use crate::query_analysis::State; -use crate::query_result_builder::{ - Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, -}; +use crate::query_analysis::TxnStatus; +use crate::query_result_builder::{QueryBuilderConfig, QueryResultBuilder}; use crate::replication::FrameNo; use crate::rpc::proxy::rpc::proxy_client::ProxyClient; -use crate::rpc::proxy::rpc::query_result::RowResult; -use crate::rpc::proxy::rpc::{DisconnectMessage, ExecuteResults}; +use crate::rpc::proxy::rpc::{self, exec_req, exec_resp, ExecReq, ExecResp}; use crate::rpc::NAMESPACE_METADATA_KEY; use crate::stats::Stats; use crate::{Result, DEFAULT_AUTO_CHECKPOINT}; use super::config::DatabaseConfigStore; use super::libsql::{LibSqlConnection, MakeLibSqlConn}; -use super::program::DescribeResult; +use super::program::DescribeResponse; use super::Connection; use super::{MakeConnection, Program}; @@ -104,13 +102,11 @@ impl MakeConnection for MakeWriteProxyConn { } } -#[derive(Debug)] -pub struct WriteProxyConnection { +pub struct WriteProxyConnection> { /// Lazily initialized read connection read_conn: LibSqlConnection, write_proxy: 
ProxyClient, - state: Mutex, - client_id: Uuid, + state: Mutex, /// FrameNo of the last write performed by this connection on the primary. /// any subsequent read on this connection must wait for the replicator to catch up with this /// frame_no @@ -120,51 +116,8 @@ pub struct WriteProxyConnection { builder_config: QueryBuilderConfig, stats: Arc, namespace: NamespaceName, -} - -fn execute_results_to_builder( - execute_result: ExecuteResults, - mut builder: B, - config: &QueryBuilderConfig, -) -> Result { - builder.init(config)?; - for result in execute_result.results { - match result.row_result { - Some(RowResult::Row(rows)) => { - builder.begin_step()?; - builder.cols_description(rows.column_descriptions.iter().map(|c| Column { - name: &c.name, - decl_ty: c.decltype.as_deref(), - }))?; - - builder.begin_rows()?; - for row in rows.rows { - builder.begin_row()?; - for value in row.values { - let value: Value = bincode::deserialize(&value.data) - // something is wrong, better stop right here - .map_err(QueryResultBuilderError::from_any)?; - builder.add_row_value(ValueRef::from(&value))?; - } - builder.finish_row()?; - } - - builder.finish_rows()?; - - builder.finish_step(rows.affected_row_count, rows.last_insert_rowid)?; - } - Some(RowResult::Error(err)) => { - builder.begin_step()?; - builder.step_error(Error::RpcQueryError(err))?; - builder.finish_step(0, None)?; - } - None => (), - } - } - - builder.finish(execute_result.current_frame_no)?; - Ok(builder) + remote_conn: Mutex>>, } impl WriteProxyConnection { @@ -180,56 +133,73 @@ impl WriteProxyConnection { Ok(Self { read_conn, write_proxy, - state: Mutex::new(State::Init), - client_id: Uuid::new_v4(), - last_write_frame_no: PMutex::new(None), + state: Mutex::new(TxnStatus::Init), + last_write_frame_no: Default::default(), applied_frame_no_receiver, builder_config, stats, namespace, + remote_conn: Default::default(), }) } + async fn with_remote_conn( + &self, + auth: Authenticated, + builder_config: 
QueryBuilderConfig, + cb: F, + ) -> crate::Result + where + F: FnOnce(&mut RemoteConnection) -> BoxFuture<'_, crate::Result>, + { + let mut remote_conn = self.remote_conn.lock().await; + if remote_conn.is_some() { + cb(remote_conn.as_mut().unwrap()).await + } else { + let conn = RemoteConnection::connect( + self.write_proxy.clone(), + self.namespace.clone(), + auth, + builder_config, + ) + .await?; + let conn = remote_conn.insert(conn); + cb(conn).await + } + } + async fn execute_remote( &self, pgm: Program, - state: &mut State, + status: &mut TxnStatus, auth: Authenticated, builder: B, - ) -> Result<(B, State)> { + ) -> Result { self.stats.inc_write_requests_delegated(); - let mut client = self.write_proxy.clone(); - - let mut req = Request::new(crate::rpc::proxy::rpc::ProgramReq { - client_id: self.client_id.to_string(), - pgm: Some(pgm.into()), - }); - - let namespace = BinaryMetadataValue::from_bytes(self.namespace.as_slice()); - req.metadata_mut() - .insert_bin(NAMESPACE_METADATA_KEY, namespace); - auth.upgrade_grpc_request(&mut req); - - match client.execute(req).await { - Ok(r) => { - let execute_result = r.into_inner(); - *state = execute_result.state().into(); - let current_frame_no = execute_result.current_frame_no; - let builder = - execute_results_to_builder(execute_result, builder, &self.builder_config)?; - if let Some(current_frame_no) = current_frame_no { - self.update_last_write_frame_no(current_frame_no); - } - - Ok((builder, *state)) - } - Err(e) => { - // Set state to invalid, so next call is sent to remote, and we have a chance - // to recover state. 
- *state = State::Invalid; - Err(Error::RpcQueryExecutionError(e)) + *status = TxnStatus::Invalid; + let res = self + .with_remote_conn(auth, self.builder_config, |conn| { + Box::pin(conn.execute(pgm, builder)) + }) + .await; + + let (builder, new_status, new_frame_no) = match res { + Ok(res) => res, + Err(e @ (Error::PrimaryStreamDisconnect | Error::PrimaryStreamMisuse)) => { + // drop the connection, and reset the state. + self.remote_conn.lock().await.take(); + *status = TxnStatus::Init; + return Err(e); } + Err(e) => return Err(e), + }; + + *status = new_status; + if let Some(current_frame_no) = new_frame_no { + self.update_last_write_frame_no(current_frame_no); } + + Ok(builder) } fn update_last_write_frame_no(&self, new_frame_no: FrameNo) { @@ -261,6 +231,164 @@ impl WriteProxyConnection { } } +struct RemoteConnection> { + response_stream: R, + request_sender: mpsc::Sender, + current_request_id: u32, + builder_config: QueryBuilderConfig, +} + +impl RemoteConnection { + async fn connect( + mut client: ProxyClient, + namespace: NamespaceName, + auth: Authenticated, + builder_config: QueryBuilderConfig, + ) -> crate::Result { + let (request_sender, receiver) = mpsc::channel(1); + + let stream = tokio_stream::wrappers::ReceiverStream::new(receiver); + let mut req = Request::new(stream); + let namespace = BinaryMetadataValue::from_bytes(namespace.as_slice()); + req.metadata_mut() + .insert_bin(NAMESPACE_METADATA_KEY, namespace); + auth.upgrade_grpc_request(&mut req); + let response_stream = client.stream_exec(req).await.unwrap().into_inner(); + + Ok(Self { + response_stream, + request_sender, + current_request_id: 0, + builder_config, + }) + } +} + +impl RemoteConnection +where + R: Stream> + Unpin, +{ + /// Perform a request to the remote peer, and call `response_cb` for every message received for + /// that request. `response_cb` should return whether to expect more messages for that request. 
+ async fn make_request( + &mut self, + req: exec_req::Request, + mut response_cb: impl FnMut(exec_resp::Response) -> crate::Result, + ) -> crate::Result<()> { + let request_id = self.current_request_id; + self.current_request_id += 1; + + let req = ExecReq { + request_id, + request: Some(req), + }; + + self.request_sender + .send(req) + .await + .map_err(|_| Error::PrimaryStreamDisconnect)?; + + while let Some(resp) = self.response_stream.next().await { + match resp { + Ok(resp) => { + // there was an interruption, and we moved to the next query + if resp.request_id > request_id { + return Err(Error::PrimaryStreamInterupted); + } + + // we can ignore responses for previously interrupted requests + if resp.request_id < request_id { + continue; + } + + if !response_cb(resp.response.ok_or(Error::PrimaryStreamMisuse)?)? { + break; + } + } + Err(e) => { + tracing::error!("received an error from connection stream: {e}"); + return Err(Error::PrimaryStreamDisconnect); + } + } + } + + Ok(()) + } + + async fn execute( + &mut self, + program: Program, + mut builder: B, + ) -> crate::Result<(B, TxnStatus, Option)> { + let mut txn_status = TxnStatus::Invalid; + let mut new_frame_no = None; + let builder_config = self.builder_config; + let cb = |response: exec_resp::Response| match response { + exec_resp::Response::ProgramResp(resp) => { + crate::rpc::streaming_exec::apply_program_resp_to_builder( + &builder_config, + &mut builder, + resp, + |last_frame_no, status| { + txn_status = status; + new_frame_no = last_frame_no; + }, + ) + } + exec_resp::Response::DescribeResp(_) => Err(Error::PrimaryStreamMisuse), + exec_resp::Response::Error(e) => Err(Error::RpcQueryError(e)), + }; + + self.make_request( + exec_req::Request::Execute(rpc::StreamProgramReq { + pgm: Some(program.into()), + }), + cb, + ) + .await?; + + Ok((builder, txn_status, new_frame_no)) + } + + #[allow(dead_code)] // reference implementation + async fn describe(&mut self, stmt: String) -> crate::Result { + let mut out 
= None; + let cb = |response: exec_resp::Response| match response { + exec_resp::Response::DescribeResp(resp) => { + out = Some(DescribeResponse { + params: resp + .params + .into_iter() + .map(|p| DescribeParam { name: p.name }) + .collect(), + cols: resp + .cols + .into_iter() + .map(|c| DescribeCol { + name: c.name, + decltype: c.decltype, + }) + .collect(), + is_explain: resp.is_explain, + is_readonly: resp.is_readonly, + }); + + Ok(false) + } + exec_resp::Response::Error(e) => Err(Error::RpcQueryError(e)), + exec_resp::Response::ProgramResp(_) => Err(Error::PrimaryStreamMisuse), + }; + + self.make_request( + exec_req::Request::Describe(rpc::StreamDescribeReq { stmt }), + cb, + ) + .await?; + + out.ok_or(Error::PrimaryStreamMisuse) + } +} + #[async_trait::async_trait] impl Connection for WriteProxyConnection { async fn execute_program( @@ -269,26 +397,30 @@ impl Connection for WriteProxyConnection { auth: Authenticated, builder: B, replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result { let mut state = self.state.lock().await; // This is a fresh namespace, and it is not replicated yet, proxy the first request. if self.applied_frame_no_receiver.borrow().is_none() { self.execute_remote(pgm, &mut state, auth, builder).await - } else if *state == State::Init && pgm.is_read_only() { + } else if *state == TxnStatus::Init && pgm.is_read_only() { + // set the state to invalid before doing anything, and set it to a valid state after. + *state = TxnStatus::Invalid; self.wait_replication_sync(replication_index).await?; // We know that this program won't perform any writes. We attempt to run it on the // replica. If it leaves an open transaction, then this program is an interactive // transaction, so we rollback the replica, and execute again on the primary. 
- let (builder, new_state) = self + let builder = self .read_conn .execute_program(pgm.clone(), auth.clone(), builder, replication_index) .await?; - if new_state != State::Init { + let new_state = self.read_conn.txn_status()?; + if new_state != TxnStatus::Init { self.read_conn.rollback(auth.clone()).await?; self.execute_remote(pgm, &mut state, auth, builder).await } else { - Ok((builder, new_state)) + *state = new_state; + Ok(builder) } } else { self.execute_remote(pgm, &mut state, auth, builder).await @@ -300,7 +432,7 @@ impl Connection for WriteProxyConnection { sql: String, auth: Authenticated, replication_index: Option, - ) -> Result { + ) -> Result> { self.wait_replication_sync(replication_index).await?; self.read_conn.describe(sql, auth, replication_index).await } @@ -308,8 +440,8 @@ impl Connection for WriteProxyConnection { async fn is_autocommit(&self) -> Result { let state = self.state.lock().await; Ok(match *state { - State::Txn => false, - State::Init | State::Invalid => true, + TxnStatus::Txn => false, + TxnStatus::Init | TxnStatus::Invalid => true, }) } @@ -328,25 +460,21 @@ impl Connection for WriteProxyConnection { } } -impl Drop for WriteProxyConnection { - fn drop(&mut self) { - // best effort attempt to disconnect - let mut remote = self.write_proxy.clone(); - let client_id = self.client_id.to_string(); - tokio::spawn(async move { - let _ = remote.disconnect(DisconnectMessage { client_id }).await; - }); - } -} - #[cfg(test)] pub mod test { use arbitrary::{Arbitrary, Unstructured}; use bytes::Bytes; use rand::Fill; + use rusqlite::types::ValueRef; use super::*; - use crate::query_result_builder::test::test_driver; + use crate::{ + query_result_builder::{test::test_driver, Column, QueryResultBuilderError}, + rpc::{ + proxy::rpc::{query_result::RowResult, ExecuteResults}, + streaming_exec::test::random_valid_program_resp, + }, + }; /// generate an arbitraty rpc value. see build.rs for usage. 
pub fn arbitrary_rpc_value(u: &mut Unstructured) -> arbitrary::Result> { @@ -362,15 +490,85 @@ pub mod test { Ok(v.into()) } + fn execute_results_to_builder( + execute_result: ExecuteResults, + mut builder: B, + config: &QueryBuilderConfig, + ) -> Result { + builder.init(config)?; + for result in execute_result.results { + match result.row_result { + Some(RowResult::Row(rows)) => { + builder.begin_step()?; + builder.cols_description(rows.column_descriptions.iter().map(|c| Column { + name: &c.name, + decl_ty: c.decltype.as_deref(), + }))?; + + builder.begin_rows()?; + for row in rows.rows { + builder.begin_row()?; + for value in row.values { + let value: crate::query::Value = bincode::deserialize(&value.data) + // something is wrong, better stop right here + .map_err(QueryResultBuilderError::from_any)?; + builder.add_row_value(ValueRef::from(&value))?; + } + builder.finish_row()?; + } + + builder.finish_rows()?; + + builder.finish_step(rows.affected_row_count, rows.last_insert_rowid)?; + } + Some(RowResult::Error(err)) => { + builder.begin_step()?; + builder.step_error(Error::RpcQueryError(err))?; + builder.finish_step(0, None)?; + } + None => (), + } + } + + builder.finish(execute_result.current_frame_no, TxnStatus::Init)?; + + Ok(builder) + } + /// In this test, we generate random ExecuteResults, and ensures that the `execute_results_to_builder` drives the builder FSM correctly. 
#[test] fn test_execute_results_to_builder() { - test_driver(1000, |b| { - let mut data = [0; 10_000]; - data.try_fill(&mut rand::thread_rng()).unwrap(); - let mut un = Unstructured::new(&data); - let res = ExecuteResults::arbitrary(&mut un).unwrap(); - execute_results_to_builder(res, b, &QueryBuilderConfig::default()) - }); + test_driver( + 1000, + |b| -> std::result::Result { + let mut data = [0; 10_000]; + data.try_fill(&mut rand::thread_rng()).unwrap(); + let mut un = Unstructured::new(&data); + let res = ExecuteResults::arbitrary(&mut un).unwrap(); + execute_results_to_builder(res, b, &QueryBuilderConfig::default()) + }, + ); + } + + #[tokio::test] + // in this test we do a roundtrip: generate a random valid program, stream it to + // RemoteConnection, and make sure that the remote connection drives the builder with the same + // state transitions. + async fn validate_random_stream_response() { + let (response_stream, validator) = random_valid_program_resp(500, 150); + let (request_sender, _request_recver) = mpsc::channel(1); + let mut remote = RemoteConnection { + response_stream: response_stream.map(Ok), + request_sender, + current_request_id: 0, + builder_config: QueryBuilderConfig::default(), + }; + + remote + .execute(Program::seq(&[]), validator) + .await + .unwrap() + .0 + .into_ret(); } } diff --git a/sqld/src/error.rs b/sqld/src/error.rs index 66ca58dc..97df108b 100644 --- a/sqld/src/error.rs +++ b/sqld/src/error.rs @@ -79,6 +79,13 @@ pub enum Error { ConflictingRestoreParameters, #[error("failed to fork database: {0}")] Fork(#[from] ForkError), + + #[error("Connection with primary broken")] + PrimaryStreamDisconnect, + #[error("Proxy protocal misuse")] + PrimaryStreamMisuse, + #[error("Proxy request interupted")] + PrimaryStreamInterupted, } trait ResponseError: std::error::Error { @@ -129,6 +136,9 @@ impl IntoResponse for Error { LoadDumpExistingDb => self.format_err(StatusCode::BAD_REQUEST), ConflictingRestoreParameters => 
self.format_err(StatusCode::BAD_REQUEST), Fork(e) => e.into_response(), + PrimaryStreamDisconnect => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + PrimaryStreamMisuse => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + PrimaryStreamInterupted => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), } } } diff --git a/sqld/src/hrana/batch.rs b/sqld/src/hrana/batch.rs index cb0deb63..a2ddd9e2 100644 --- a/sqld/src/hrana/batch.rs +++ b/sqld/src/hrana/batch.rs @@ -110,7 +110,7 @@ pub async fn execute_batch( replication_index: Option, ) -> Result { let batch_builder = HranaBatchProtoBuilder::default(); - let (builder, _state) = db + let builder = db .execute_program(pgm, auth, batch_builder, replication_index) .await .map_err(catch_batch_error)?; @@ -151,7 +151,7 @@ pub async fn execute_sequence( replication_index: Option, ) -> Result<()> { let builder = StepResultsBuilder::default(); - let (builder, _state) = db + let builder = db .execute_program(pgm, auth, builder, replication_index) .await .map_err(catch_batch_error)?; diff --git a/sqld/src/hrana/cursor.rs b/sqld/src/hrana/cursor.rs index 005799a7..c67079a9 100644 --- a/sqld/src/hrana/cursor.rs +++ b/sqld/src/hrana/cursor.rs @@ -8,6 +8,7 @@ use tokio::sync::{mpsc, oneshot}; use crate::auth::Authenticated; use crate::connection::program::Program; use crate::connection::Connection; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; @@ -255,7 +256,11 @@ impl QueryResultBuilder for CursorResultBuilder { Ok(()) } - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { self.emit_entry(Ok(SizedEntry { entry: proto::CursorEntry::ReplicationIndex { replication_index: last_frame_no, diff --git a/sqld/src/hrana/result_builder.rs b/sqld/src/hrana/result_builder.rs index 
d2e19910..70d1890a 100644 --- a/sqld/src/hrana/result_builder.rs +++ b/sqld/src/hrana/result_builder.rs @@ -6,6 +6,7 @@ use bytes::Bytes; use rusqlite::types::ValueRef; use crate::hrana::stmt::{proto_error_from_stmt_error, stmt_error_from_sqld_error}; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, TOTAL_RESPONSE_SIZE, }; @@ -225,7 +226,11 @@ impl QueryResultBuilder for SingleStatementBuilder { Ok(()) } - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { self.last_frame_no = last_frame_no; Ok(()) } @@ -344,7 +349,11 @@ impl QueryResultBuilder for HranaBatchProtoBuilder { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { Ok(()) } diff --git a/sqld/src/hrana/stmt.rs b/sqld/src/hrana/stmt.rs index 2021b384..46cd3684 100644 --- a/sqld/src/hrana/stmt.rs +++ b/sqld/src/hrana/stmt.rs @@ -58,7 +58,7 @@ pub async fn execute_stmt( replication_index: Option, ) -> Result { let builder = SingleStatementBuilder::default(); - let (stmt_res, _) = db + let stmt_res = db .execute_batch(vec![query], auth, builder, replication_index) .await .map_err(catch_stmt_error)?; diff --git a/sqld/src/http/user/mod.rs b/sqld/src/http/user/mod.rs index 367f7e77..503222c0 100644 --- a/sqld/src/http/user/mod.rs +++ b/sqld/src/http/user/mod.rs @@ -37,7 +37,7 @@ use crate::http::user::types::HttpQuery; use crate::namespace::{MakeNamespace, NamespaceStore}; use crate::net::Accept; use crate::query::{self, Query}; -use crate::query_analysis::{predict_final_state, State, Statement}; +use crate::query_analysis::{predict_final_state, Statement, TxnStatus}; use 
crate::query_result_builder::QueryResultBuilder; use crate::rpc::proxy::rpc::proxy_server::{Proxy, ProxyServer}; use crate::rpc::replication_log::rpc::replication_log_server::ReplicationLog; @@ -97,15 +97,15 @@ fn parse_queries(queries: Vec) -> crate::Result> { out.push(query); } - match predict_final_state(State::Init, out.iter().map(|q| &q.stmt)) { - State::Txn => { + match predict_final_state(TxnStatus::Init, out.iter().map(|q| &q.stmt)) { + TxnStatus::Txn => { return Err(Error::QueryError( "interactive transaction not allowed in HTTP queries".to_string(), )) } - State::Init => (), + TxnStatus::Init => (), // maybe we should err here, but let's sqlite deal with that. - State::Invalid => (), + TxnStatus::Invalid => (), } Ok(out) @@ -121,7 +121,7 @@ async fn handle_query( let db = connection_maker.create().await?; let builder = JsonHttpPayloadBuilder::new(); - let (builder, _) = db + let builder = db .execute_batch_or_rollback(batch, auth, builder, query.replication_index) .await?; diff --git a/sqld/src/http/user/result_builder.rs b/sqld/src/http/user/result_builder.rs index fa7c4710..4f56f7db 100644 --- a/sqld/src/http/user/result_builder.rs +++ b/sqld/src/http/user/result_builder.rs @@ -6,6 +6,7 @@ use serde::{Serialize, Serializer}; use serde_json::ser::{CompactFormatter, Formatter}; use std::sync::atomic::Ordering; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, JsonFormatter, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, TOTAL_RESPONSE_SIZE, @@ -293,7 +294,11 @@ impl QueryResultBuilder for JsonHttpPayloadBuilder { } // TODO: how do we return last_frame_no? 
- fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { self.formatter.end_array(&mut self.buffer)?; Ok(()) @@ -306,7 +311,8 @@ impl QueryResultBuilder for JsonHttpPayloadBuilder { #[cfg(test)] mod test { - use crate::query_result_builder::test::random_builder_driver; + + use crate::query_result_builder::test::{fsm_builder_driver, random_transition}; use super::*; @@ -314,7 +320,8 @@ mod test { fn test_json_builder() { for _ in 0..1000 { let builder = JsonHttpPayloadBuilder::new(); - let ret = random_builder_driver(100, builder).into_ret(); + let trace = random_transition(100); + let ret = fsm_builder_driver(&trace, builder).into_ret(); println!("{}", std::str::from_utf8(&ret).unwrap()); // we produce valid json serde_json::from_slice::>(&ret).unwrap(); diff --git a/sqld/src/lib.rs b/sqld/src/lib.rs index a0461bd9..ba984f16 100644 --- a/sqld/src/lib.rs +++ b/sqld/src/lib.rs @@ -500,6 +500,8 @@ where let proxy_service = ProxyService::new(namespaces.clone(), None, self.disable_namespaces); // Garbage collect proxy clients every 30 seconds + // TODO: this will no longer be necessary once client have adopted the streaming proxy + // protocol self.join_set.spawn({ let clients = proxy_service.clients(); async move { diff --git a/sqld/src/query_analysis.rs b/sqld/src/query_analysis.rs index 5f0e4f37..32352620 100644 --- a/sqld/src/query_analysis.rs +++ b/sqld/src/query_analysis.rs @@ -171,7 +171,7 @@ impl StmtKind { /// The state of a transaction for a series of statement #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum State { +pub enum TxnStatus { /// The txn in an opened state Txn, /// The txn in a closed state @@ -180,19 +180,21 @@ pub enum State { Invalid, } -impl State { +impl TxnStatus { pub fn step(&mut self, kind: StmtKind) { *self = match (*self, kind) { - (State::Txn, StmtKind::TxnBegin) | (State::Init, 
StmtKind::TxnEnd) => State::Invalid, - (State::Txn, StmtKind::TxnEnd) => State::Init, + (TxnStatus::Txn, StmtKind::TxnBegin) | (TxnStatus::Init, StmtKind::TxnEnd) => { + TxnStatus::Invalid + } + (TxnStatus::Txn, StmtKind::TxnEnd) => TxnStatus::Init, (state, StmtKind::Other | StmtKind::Write | StmtKind::Read) => state, - (State::Invalid, _) => State::Invalid, - (State::Init, StmtKind::TxnBegin) => State::Txn, + (TxnStatus::Invalid, _) => TxnStatus::Invalid, + (TxnStatus::Init, StmtKind::TxnBegin) => TxnStatus::Txn, }; } pub fn reset(&mut self) { - *self = State::Init + *self = TxnStatus::Init } } @@ -284,9 +286,9 @@ impl Statement { /// Given a an initial state and an array of queries, attempts to predict what the final state will /// be pub fn predict_final_state<'a>( - mut state: State, + mut state: TxnStatus, stmts: impl Iterator, -) -> State { +) -> TxnStatus { for stmt in stmts { state.step(stmt.kind); } diff --git a/sqld/src/query_result_builder.rs b/sqld/src/query_result_builder.rs index 914037ee..121d1a00 100644 --- a/sqld/src/query_result_builder.rs +++ b/sqld/src/query_result_builder.rs @@ -8,12 +8,14 @@ use serde::Serialize; use serde_json::ser::Formatter; use std::sync::atomic::AtomicUsize; +use crate::query_analysis::TxnStatus; use crate::replication::FrameNo; pub static TOTAL_RESPONSE_SIZE: AtomicUsize = AtomicUsize::new(0); #[derive(Debug)] pub enum QueryResultBuilderError { + /// The response payload is too large ResponseTooLarge(u64), Internal(anyhow::Error), } @@ -120,7 +122,11 @@ pub trait QueryResultBuilder: Send + 'static { /// end adding rows fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError>; /// finish serialization. 
- fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError>; + fn finish( + &mut self, + last_frame_no: Option, + state: TxnStatus, + ) -> Result<(), QueryResultBuilderError>; /// returns the inner ret fn into_ret(self) -> Self::Ret; /// Returns a `QueryResultBuilder` that wraps Self and takes at most `n` steps @@ -311,7 +317,11 @@ impl QueryResultBuilder for StepResultsBuilder { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { Ok(()) } @@ -372,7 +382,11 @@ impl QueryResultBuilder for IgnoreResult { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { Ok(()) } @@ -481,8 +495,12 @@ impl QueryResultBuilder for Take { } } - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError> { - self.inner.finish(last_frame_no) + fn finish( + &mut self, + last_frame_no: Option, + state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { + self.inner.finish(last_frame_no, state) } fn into_ret(self) -> Self::Ret { @@ -599,6 +617,7 @@ pub mod test { fn finish( &mut self, _last_frame_no: Option, + _txn_status: TxnStatus, ) -> Result<(), QueryResultBuilderError> { Ok(()) } @@ -614,7 +633,7 @@ pub mod test { #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[repr(usize)] // do not reorder! 
- enum FsmState { + pub enum FsmState { Init = 0, Finish, BeginStep, @@ -683,11 +702,31 @@ pub mod test { } } - pub fn random_builder_driver(mut max_steps: usize, mut b: B) -> B { + pub fn random_transition(mut max_steps: usize) -> Vec { + let mut trace = Vec::with_capacity(max_steps); + let mut state = Init; + trace.push(state); + loop { + if max_steps > 0 { + state = state.rand_transition(false); + } else { + state = state.toward_finish() + } + + trace.push(state); + if state == FsmState::Finish { + break; + } + + max_steps = max_steps.saturating_sub(1); + } + trace + } + + pub fn fsm_builder_driver(trace: &[FsmState], mut b: B) -> B { let mut rand_data = [0; 10_000]; rand_data.try_fill(&mut rand::thread_rng()).unwrap(); let mut u = Unstructured::new(&rand_data); - let mut trace = Vec::new(); #[derive(Arbitrary)] pub enum ValueRef<'a> { @@ -710,9 +749,7 @@ pub mod test { } } - let mut state = Init; - trace.push(state); - loop { + for state in trace { match state { Init => b.init(&QueryBuilderConfig::default()).unwrap(), BeginStep => b.begin_step().unwrap(), @@ -734,27 +771,114 @@ pub mod test { FinishRow => b.finish_row().unwrap(), FinishRows => b.finish_rows().unwrap(), Finish => { - b.finish(Some(0)).unwrap(); + b.finish(Some(0), TxnStatus::Init).unwrap(); break; } BuilderError => return b, } + } - if max_steps > 0 { - state = state.rand_transition(false); - } else { - state = state.toward_finish() - } + b + } - trace.push(state); + /// A Builder that validates a given execution trace + pub struct ValidateTraceBuilder { + trace: Vec, + current: usize, + } - max_steps = max_steps.saturating_sub(1); + impl ValidateTraceBuilder { + pub fn new(trace: Vec) -> Self { + Self { trace, current: 0 } } + } - // this can be usefull to help debug the generated test case - dbg!(trace); + impl QueryResultBuilder for ValidateTraceBuilder { + type Ret = (); - b + fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + 
assert_eq!(self.trace[self.current], FsmState::Init); + self.current += 1; + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::BeginStep); + self.current += 1; + Ok(()) + } + + fn finish_step( + &mut self, + _affected_row_count: u64, + _last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::FinishStep); + self.current += 1; + Ok(()) + } + + fn step_error( + &mut self, + _error: crate::error::Error, + ) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::StepError); + self.current += 1; + Ok(()) + } + + fn cols_description<'a>( + &mut self, + _cols: impl IntoIterator>>, + ) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::ColsDescription); + self.current += 1; + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::BeginRows); + self.current += 1; + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::BeginRow); + self.current += 1; + Ok(()) + } + + fn add_row_value(&mut self, _v: ValueRef) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::AddRowValue); + self.current += 1; + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::FinishRow); + self.current += 1; + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::FinishRows); + self.current += 1; + Ok(()) + } + + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::Finish); + self.current += 1; + Ok(()) + } + + fn into_ret(self) -> Self::Ret { 
+ assert_eq!(self.current, self.trace.len()); + } } pub struct FsmQueryBuilder { @@ -882,6 +1006,7 @@ pub mod test { fn finish( &mut self, _last_frame_no: Option, + _txn_status: TxnStatus, ) -> Result<(), QueryResultBuilderError> { self.maybe_inject_error()?; self.transition(Finish) @@ -930,7 +1055,7 @@ pub mod test { builder.finish_rows().unwrap(); builder.finish_step(0, None).unwrap(); - builder.finish(Some(0)).unwrap(); + builder.finish(Some(0), TxnStatus::Init).unwrap(); } #[test] diff --git a/sqld/src/rpc/mod.rs b/sqld/src/rpc/mod.rs index 252d58e7..5dc51c33 100644 --- a/sqld/src/rpc/mod.rs +++ b/sqld/src/rpc/mod.rs @@ -19,6 +19,7 @@ pub mod proxy; pub mod replica_proxy; pub mod replication_log; pub mod replication_log_proxy; +pub mod streaming_exec; /// A tonic error code to signify that a namespace doesn't exist. pub const NAMESPACE_DOESNT_EXIST: &str = "NAMESPACE_DOESNT_EXIST"; diff --git a/sqld/src/rpc/proxy.rs b/sqld/src/rpc/proxy.rs index 5ba004a7..490d3080 100644 --- a/sqld/src/rpc/proxy.rs +++ b/sqld/src/rpc/proxy.rs @@ -1,24 +1,29 @@ use std::collections::HashMap; +use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; use async_lock::{RwLock, RwLockUpgradableReadGuard}; +use futures_core::Stream; +use rusqlite::types::ValueRef; use uuid::Uuid; use crate::auth::{Auth, Authenticated}; use crate::connection::Connection; use crate::database::{Database, PrimaryConnection}; use crate::namespace::{NamespaceStore, PrimaryNamespaceMaker}; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; use crate::replication::FrameNo; +use crate::rpc::streaming_exec::make_proxy_stream; use self::rpc::proxy_server::Proxy; use self::rpc::query_result::RowResult; use self::rpc::{ - describe_result, Ack, DescribeRequest, DescribeResult, Description, DisconnectMessage, - ExecuteResults, QueryResult, ResultRows, Row, + describe_result, Ack, DescribeRequest, DescribeResult, 
Description, DisconnectMessage, ExecReq, + ExecResp, ExecuteResults, QueryResult, ResultRows, Row, }; use super::NAMESPACE_DOESNT_EXIST; @@ -32,7 +37,7 @@ pub mod rpc { use crate::query_analysis::Statement; use crate::{connection, error::Error as SqldError}; - use self::{error::ErrorCode, execute_results::State}; + use self::error::ErrorCode; tonic::include_proto!("proxy"); impl From for Error { @@ -55,22 +60,22 @@ pub mod rpc { } } - impl From for State { - fn from(other: crate::query_analysis::State) -> Self { + impl From for State { + fn from(other: crate::query_analysis::TxnStatus) -> Self { match other { - crate::query_analysis::State::Txn => Self::Txn, - crate::query_analysis::State::Init => Self::Init, - crate::query_analysis::State::Invalid => Self::Invalid, + crate::query_analysis::TxnStatus::Txn => Self::Txn, + crate::query_analysis::TxnStatus::Init => Self::Init, + crate::query_analysis::TxnStatus::Invalid => Self::Invalid, } } } - impl From for crate::query_analysis::State { + impl From for crate::query_analysis::TxnStatus { fn from(other: State) -> Self { match other { - State::Txn => crate::query_analysis::State::Txn, - State::Init => crate::query_analysis::State::Init, - State::Invalid => crate::query_analysis::State::Invalid, + State::Txn => crate::query_analysis::TxnStatus::Txn, + State::Init => crate::query_analysis::TxnStatus::Init, + State::Invalid => crate::query_analysis::TxnStatus::Invalid, } } } @@ -290,7 +295,8 @@ impl ProxyService { } #[derive(Debug, Default)] -struct ExecuteResultBuilder { +struct ExecuteResultsBuilder { + output: Option, results: Vec, current_rows: Vec, current_row: rpc::Row, @@ -301,8 +307,8 @@ struct ExecuteResultBuilder { current_step_size: u64, } -impl QueryResultBuilder for ExecuteResultBuilder { - type Ret = Vec; +impl QueryResultBuilder for ExecuteResultsBuilder { + type Ret = ExecuteResults; fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { *self = Self { @@ -397,10 +403,7 
@@ impl QueryResultBuilder for ExecuteResultBuilder { Ok(()) } - fn add_row_value( - &mut self, - v: rusqlite::types::ValueRef, - ) -> Result<(), QueryResultBuilderError> { + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { let data = bincode::serialize( &crate::query::Value::try_from(v).map_err(QueryResultBuilderError::from_any)?, ) @@ -435,12 +438,21 @@ impl QueryResultBuilder for ExecuteResultBuilder { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + last_frame_no: Option, + txn_status: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { + self.output = Some(ExecuteResults { + results: std::mem::take(&mut self.results), + state: rpc::State::from(txn_status).into(), + current_frame_no: last_frame_no, + }); Ok(()) } fn into_ret(self) -> Self::Ret { - self.results + self.output.unwrap() } } @@ -457,6 +469,42 @@ pub async fn garbage_collect(clients: &mut HashMap> #[tonic::async_trait] impl Proxy for ProxyService { + type StreamExecStream = Pin> + Send>>; + + async fn stream_exec( + &self, + req: tonic::Request>, + ) -> Result, tonic::Status> { + let auth = if let Some(auth) = &self.auth { + auth.authenticate_grpc(&req, self.disable_namespaces)? + } else { + Authenticated::from_proxy_grpc_request(&req, self.disable_namespaces)? 
+ }; + + let namespace = super::extract_namespace(self.disable_namespaces, &req)?; + let (connection_maker, _new_frame_notifier) = self + .namespaces + .with(namespace, |ns| { + let connection_maker = ns.db.connection_maker(); + let notifier = ns.db.logger.new_frame_notifier.subscribe(); + (connection_maker, notifier) + }) + .await + .map_err(|e| { + if let crate::error::Error::NamespaceDoesntExist(_) = e { + tonic::Status::failed_precondition(NAMESPACE_DOESNT_EXIST) + } else { + tonic::Status::internal(e.to_string()) + } + })?; + + let conn = connection_maker.create().await.unwrap(); + + let stream = make_proxy_stream(conn, auth, req.into_inner()); + + Ok(tonic::Response::new(Box::pin(stream))) + } + async fn execute( &self, req: tonic::Request, @@ -472,13 +520,9 @@ impl Proxy for ProxyService { .map_err(|e| tonic::Status::new(tonic::Code::InvalidArgument, e.to_string()))?; let client_id = Uuid::from_str(&req.client_id).unwrap(); - let (connection_maker, new_frame_notifier) = self + let connection_maker = self .namespaces - .with(namespace, |ns| { - let connection_maker = ns.db.connection_maker(); - let notifier = ns.db.logger.new_frame_notifier.subscribe(); - (connection_maker, notifier) - }) + .with(namespace, |ns| ns.db.connection_maker()) .await .map_err(|e| { if let crate::error::Error::NamespaceDoesntExist(_) = e { @@ -507,19 +551,14 @@ impl Proxy for ProxyService { tracing::debug!("executing request for {client_id}"); - let builder = ExecuteResultBuilder::default(); - let (results, state) = db + let builder = ExecuteResultsBuilder::default(); + let builder = db .execute_program(pgm, auth, builder, None) .await // TODO: this is no necessarily a permission denied error! 
.map_err(|e| tonic::Status::new(tonic::Code::PermissionDenied, e.to_string()))?; - let current_frame_no = *new_frame_notifier.borrow(); - Ok(tonic::Response::new(ExecuteResults { - current_frame_no, - results: results.into_ret(), - state: rpc::execute_results::State::from(state).into(), - })) + Ok(tonic::Response::new(builder.into_ret())) } //TODO: also handle cleanup on peer disconnect diff --git a/sqld/src/rpc/replica_proxy.rs b/sqld/src/rpc/replica_proxy.rs index c4aa7179..1792caf0 100644 --- a/sqld/src/rpc/replica_proxy.rs +++ b/sqld/src/rpc/replica_proxy.rs @@ -1,13 +1,14 @@ use std::sync::Arc; use hyper::Uri; +use tokio_stream::StreamExt; use tonic::{transport::Channel, Request, Status}; use crate::auth::Auth; use super::proxy::rpc::{ self, proxy_client::ProxyClient, proxy_server::Proxy, Ack, DescribeRequest, DescribeResult, - DisconnectMessage, ExecuteResults, + DisconnectMessage, ExecReq, ExecResp, ExecuteResults, }; pub struct ReplicaProxyService { @@ -32,6 +33,31 @@ impl ReplicaProxyService { #[tonic::async_trait] impl Proxy for ReplicaProxyService { + type StreamExecStream = tonic::codec::Streaming; + + async fn stream_exec( + &self, + req: tonic::Request>, + ) -> Result, tonic::Status> { + let (meta, ext, mut stream) = req.into_parts(); + let stream = async_stream::stream! 
{ + while let Some(it) = stream.next().await { + match it { + Ok(it) => yield it, + Err(e) => { + // close the stream on error + tracing::error!("error proxying stream request: {e}"); + break + }, + } + } + }; + let mut req = tonic::Request::from_parts(meta, ext, stream); + self.do_auth(&mut req)?; + let mut client = self.client.clone(); + client.stream_exec(req).await + } + async fn execute( &self, mut req: tonic::Request, diff --git a/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__describe.snap b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__describe.snap new file mode 100644 index 00000000..06794184 --- /dev/null +++ b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__describe.snap @@ -0,0 +1,28 @@ +--- +source: sqld/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap() +--- +ExecResp { + request_id: 0, + response: Some( + DescribeResp( + DescribeResp { + params: [ + DescribeParam { + name: Some( + "$hello", + ), + }, + ], + cols: [ + DescribeCol { + name: "$hello", + decltype: None, + }, + ], + is_explain: false, + is_readonly: true, + }, + ), + ), +} diff --git a/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__interupt_query.snap b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__interupt_query.snap new file mode 100644 index 00000000..e957ad95 --- /dev/null +++ b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__interupt_query.snap @@ -0,0 +1,255 @@ +--- +source: sqld/src/rpc/streaming_exec.rs +expression: builder.into_ret() +--- +[ + Ok( + [ + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something 
moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + 
Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + ], + ), +] diff --git a/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__invalid_request.snap b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__invalid_request.snap new file mode 100644 index 00000000..d3090106 --- /dev/null +++ b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__invalid_request.snap @@ -0,0 +1,5 @@ +--- +source: sqld/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap_err().to_string() +--- +status: InvalidArgument, message: "invalid request", details: [], metadata: MetadataMap { headers: {} } diff --git a/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_query_simple.snap b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_query_simple.snap new file mode 100644 index 00000000..25b566a0 --- /dev/null +++ b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_query_simple.snap @@ -0,0 +1,72 @@ +--- +source: sqld/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap() +--- +ExecResp { + request_id: 0, + response: Some( + ProgramResp( + ProgramResp { + steps: [ + RespStep { + step: Some( + Init( + Init, + ), + ), + }, + RespStep { + step: Some( + BeginStep( + BeginStep, + ), + ), + }, + RespStep { + step: Some( + ColsDescription( + ColsDescription { + columns: [], + }, + ), + ), + }, + RespStep { + step: Some( + BeginRows( + BeginRows, + ), + ), + }, + RespStep { + step: Some( + FinishRows( + FinishRows, + ), + ), + }, + RespStep { + step: Some( + FinishStep( + FinishStep { + affected_row_count: 0, + last_insert_rowid: None, + }, + ), + ), + }, + RespStep { + step: Some( + Finish( + Finish { + last_frame_no: None, + state: Init, + }, + ), + ), + }, + ], + }, + ), + ), +} diff --git a/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__single_query_split_response.snap 
b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__single_query_split_response.snap new file mode 100644 index 00000000..e957ad95 --- /dev/null +++ b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__single_query_split_response.snap @@ -0,0 +1,255 @@ +--- +source: sqld/src/rpc/streaming_exec.rs +expression: builder.into_ret() +--- +[ + Ok( + [ + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + 
], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + ], + ), +] diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs new file mode 100644 index 00000000..b5826da3 --- /dev/null +++ b/sqld/src/rpc/streaming_exec.rs @@ -0,0 +1,570 @@ +use std::sync::Arc; + +use futures_core::future::BoxFuture; +use futures_core::Stream; +use futures_option::OptionExt; +use prost::Message; +use rusqlite::types::ValueRef; +use tokio::pin; +use tokio::sync::mpsc; +use tokio_stream::StreamExt; +use tonic::{Code, Status}; + +use crate::auth::Authenticated; +use crate::connection::Connection; +use crate::error::Error; +use crate::query_analysis::TxnStatus; +use crate::query_result_builder::{ + Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, +}; +use crate::replication::FrameNo; +use crate::rpc::proxy::rpc::exec_req::Request; +use crate::rpc::proxy::rpc::exec_resp::{self, Response}; +use crate::rpc::proxy::rpc::{DescribeCol, DescribeParam, DescribeResp, StreamDescribeReq}; + +use super::proxy::rpc::resp_step::Step; +use 
super::proxy::rpc::row_value::Value; +use super::proxy::rpc::{ + self, AddRowValue, ColsDescription, ExecReq, ExecResp, Finish, FinishStep, ProgramResp, + RespStep, RowValue, StepError, +}; + +const MAX_RESPONSE_SIZE: usize = bytesize::ByteSize::mb(1).as_u64() as usize; + +pub fn make_proxy_stream( + conn: C, + auth: Authenticated, + request_stream: S, +) -> impl Stream> +where + S: Stream>, + C: Connection, +{ + make_proxy_stream_inner(conn, auth, request_stream, MAX_RESPONSE_SIZE) +} + +fn make_proxy_stream_inner( + conn: C, + auth: Authenticated, + request_stream: S, + max_program_resp_size: usize, +) -> impl Stream> +where + S: Stream>, + C: Connection, +{ + async_stream::stream! { + let mut current_request_fut: Option, u32)>> = None; + let (snd, mut recv) = mpsc::channel(1); + let conn = Arc::new(conn); + + pin!(request_stream); + + loop { + tokio::select! { + biased; + maybe_req = request_stream.next() => { + let Some(maybe_req) = maybe_req else { break }; + match maybe_req { + Err(e) => { + tracing::error!("stream error: {e}"); + break + } + Ok(req) => { + let request_id = req.request_id; + match req.request { + Some(Request::Execute(pgm)) => { + let Ok(pgm) = + crate::connection::program::Program::try_from(pgm.pgm.unwrap()) else { + yield Err(Status::new(Code::InvalidArgument, "invalid program")); + break + }; + let conn = conn.clone(); + let auth = auth.clone(); + let sender = snd.clone(); + + let fut = async move { + let builder = StreamResponseBuilder { + request_id, + sender, + current: None, + current_size: 0, + max_program_resp_size, + }; + + let ret = conn.execute_program(pgm, auth, builder, None).await.map(|_| ()); + (ret, request_id) + }; + + current_request_fut.replace(Box::pin(fut)); + } + Some(Request::Describe(StreamDescribeReq { stmt })) => { + let auth = auth.clone(); + let sender = snd.clone(); + let conn = conn.clone(); + let fut = async move { + let do_describe = || async move { + let ret = conn.describe(stmt, auth, None).await??; + 
Ok(DescribeResp { + cols: ret.cols.into_iter().map(|c| DescribeCol { name: c.name, decltype: c.decltype }).collect(), + params: ret.params.into_iter().map(|p| DescribeParam { name: p.name }).collect(), + is_explain: ret.is_explain, + is_readonly: ret.is_readonly + }) + }; + + let ret: crate::Result<()> = match do_describe().await { + Ok(resp) => { + let _ = sender.send(ExecResp { request_id, response: Some(Response::DescribeResp(resp)) }).await; + Ok(()) + } + Err(e) => Err(e), + }; + + (ret, request_id) + }; + + current_request_fut.replace(Box::pin(fut)); + + }, + None => { + yield Err(Status::new(Code::InvalidArgument, "invalid request")); + break + } + } + } + } + }, + Some(res) = recv.recv() => { + yield Ok(res); + }, + (ret, request_id) = current_request_fut.current(), if current_request_fut.is_some() => { + if let Err(e) = ret { + yield Ok(ExecResp { request_id, response: Some(Response::Error(e.into())) }) + } + }, + else => break, + } + } + } +} + +struct StreamResponseBuilder { + request_id: u32, + sender: mpsc::Sender, + current: Option, + current_size: usize, + max_program_resp_size: usize, +} + +impl StreamResponseBuilder { + fn current(&mut self) -> &mut ProgramResp { + self.current + .get_or_insert_with(|| ProgramResp { steps: Vec::new() }) + } + + fn push(&mut self, step: Step) -> Result<(), QueryResultBuilderError> { + let current = self.current(); + let step = RespStep { step: Some(step) }; + let size = step.encoded_len(); + current.steps.push(step); + self.current_size += size; + + if self.current_size >= self.max_program_resp_size { + self.flush()?; + } + + Ok(()) + } + + fn flush(&mut self) -> Result<(), QueryResultBuilderError> { + if let Some(current) = self.current.take() { + let resp = ExecResp { + request_id: self.request_id, + response: Some(exec_resp::Response::ProgramResp(current)), + }; + self.current_size = 0; + self.sender + .blocking_send(resp) + .map_err(|_| QueryResultBuilderError::Internal(anyhow::anyhow!("stream closed")))?; + } + 
+ Ok(()) + } +} + +/// Apply the response to the builder, and return whether the builder needs more steps (`false` once a `Finish` step was applied) +pub fn apply_program_resp_to_builder( + config: &QueryBuilderConfig, + builder: &mut B, + resp: ProgramResp, + mut on_finish: impl FnMut(Option, TxnStatus), +) -> crate::Result { + for step in resp.steps { + let Some(step) = step.step else { + return Err(Error::PrimaryStreamMisuse); + }; + match step { + Step::Init(_) => builder.init(config)?, + Step::BeginStep(_) => builder.begin_step()?, + Step::FinishStep(FinishStep { + affected_row_count, + last_insert_rowid, + }) => builder.finish_step(affected_row_count, last_insert_rowid)?, + Step::StepError(StepError { error: Some(err) }) => { + builder.step_error(crate::error::Error::RpcQueryError(err))? + } + Step::ColsDescription(ColsDescription { columns }) => { + let cols = columns.iter().map(|c| Column { + name: &c.name, + decl_ty: c.decltype.as_deref(), + }); + builder.cols_description(cols)? + } + Step::BeginRows(_) => builder.begin_rows()?, + Step::BeginRow(_) => builder.begin_row()?, + Step::AddRowValue(AddRowValue { + val: Some(RowValue { value: Some(val) }), + }) => { + let val = match &val { + Value::Text(s) => ValueRef::Text(s.as_bytes()), + Value::Integer(i) => ValueRef::Integer(*i), + Value::Real(x) => ValueRef::Real(*x), + Value::Blob(b) => ValueRef::Blob(b.as_slice()), + Value::Null(_) => ValueRef::Null, + }; + builder.add_row_value(val)?; + } + Step::FinishRow(_) => builder.finish_row()?, + Step::FinishRows(_) => builder.finish_rows()?, + Step::Finish(f @ Finish { last_frame_no, .. 
}) => { + let txn_status = TxnStatus::from(f.state()); + on_finish(last_frame_no, txn_status); + builder.finish(last_frame_no, txn_status)?; + return Ok(false); + } + _ => return Err(Error::PrimaryStreamMisuse), + } + } + + Ok(true) +} + +impl QueryResultBuilder for StreamResponseBuilder { + type Ret = (); + + fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + self.push(Step::Init(rpc::Init {}))?; + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Step::BeginStep(rpc::BeginStep {}))?; + Ok(()) + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.push(Step::FinishStep(rpc::FinishStep { + affected_row_count, + last_insert_rowid, + }))?; + Ok(()) + } + + fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { + self.push(Step::StepError(rpc::StepError { + error: Some(error.into()), + }))?; + Ok(()) + } + + fn cols_description<'a>( + &mut self, + cols: impl IntoIterator>>, + ) -> Result<(), QueryResultBuilderError> { + self.push(Step::ColsDescription(rpc::ColsDescription { + columns: cols + .into_iter() + .map(Into::into) + .map(|c| rpc::Column { + name: c.name.into(), + decltype: c.decl_ty.map(Into::into), + }) + .collect::>(), + }))?; + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Step::BeginRows(rpc::BeginRows {}))?; + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Step::BeginRow(rpc::BeginRow {}))?; + Ok(()) + } + + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { + self.push(Step::AddRowValue(rpc::AddRowValue { + val: Some(v.into()), + }))?; + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Step::FinishRow(rpc::FinishRow {}))?; + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), 
QueryResultBuilderError> { + self.push(Step::FinishRows(rpc::FinishRows {}))?; + Ok(()) + } + + fn finish( + &mut self, + last_frame_no: Option, + state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { + self.push(Step::Finish(rpc::Finish { + last_frame_no, + state: rpc::State::from(state).into(), + }))?; + self.flush()?; + Ok(()) + } + + fn into_ret(self) -> Self::Ret {} +} + +impl From> for RowValue { + fn from(value: ValueRef<'_>) -> Self { + let value = Some(match value { + ValueRef::Null => Value::Null(true), + ValueRef::Integer(i) => Value::Integer(i), + ValueRef::Real(x) => Value::Real(x), + ValueRef::Text(s) => Value::Text(String::from_utf8(s.to_vec()).unwrap()), + ValueRef::Blob(b) => Value::Blob(b.to_vec()), + }); + + RowValue { value } + } +} + +#[cfg(test)] +pub mod test { + use insta::{assert_debug_snapshot, assert_snapshot}; + use tempfile::tempdir; + use tokio_stream::wrappers::ReceiverStream; + + use crate::auth::{Authorized, Permission}; + use crate::connection::libsql::LibSqlConnection; + use crate::connection::program::Program; + use crate::query_result_builder::test::{ + fsm_builder_driver, random_transition, TestBuilder, ValidateTraceBuilder, + }; + use crate::rpc::proxy::rpc::StreamProgramReq; + + use super::*; + + fn exec_req_stmt(s: &str, id: u32) -> ExecReq { + ExecReq { + request_id: id, + request: Some(Request::Execute(StreamProgramReq { + pgm: Some(Program::seq(&[s]).into()), + })), + } + } + + #[tokio::test] + async fn invalid_request() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let stream = make_proxy_stream(conn, Authenticated::Anonymous, ReceiverStream::new(rcv)); + pin!(stream); + + let req = ExecReq { + request_id: 0, + request: None, + }; + + snd.send(Ok(req)).await.unwrap(); + + assert_snapshot!(stream.next().await.unwrap().unwrap_err().to_string()); + } + + #[tokio::test] + async fn request_stream_dropped() { + let tmp = 
tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + drop(snd); + + assert!(stream.next().await.is_none()); + } + + #[tokio::test] + async fn perform_query_simple() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + let req = exec_req_stmt("create table test (foo)", 0); + + snd.send(Ok(req)).await.unwrap(); + + assert_debug_snapshot!(stream.next().await.unwrap().unwrap()); + } + + #[tokio::test] + async fn single_query_split_response() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + // limit the size of the response to force a split + let stream = make_proxy_stream_inner(conn, auth, ReceiverStream::new(rcv), 500); + + pin!(stream); + + let req = exec_req_stmt("create table test (foo)", 0); + snd.send(Ok(req)).await.unwrap(); + let resp = stream.next().await.unwrap().unwrap(); + assert_eq!(resp.request_id, 0); + for i in 1..50 { + let req = exec_req_stmt( + r#"insert into test values ("something moderately long")"#, + i, + ); + snd.send(Ok(req)).await.unwrap(); + let resp = stream.next().await.unwrap().unwrap(); + assert_eq!(resp.request_id, i); + } + + let req = exec_req_stmt("select * from test", 100); + snd.send(Ok(req)).await.unwrap(); + + let mut num_resp = 0; + let mut builder = TestBuilder::default(); + loop { + let 
Response::ProgramResp(resp) = + stream.next().await.unwrap().unwrap().response.unwrap() + else { + panic!() + }; + if !apply_program_resp_to_builder( + &QueryBuilderConfig::default(), + &mut builder, + resp, + |_, _| (), + ) + .unwrap() + { + break; + } + num_resp += 1; + } + + assert_eq!(num_resp, 3); + assert_debug_snapshot!(builder.into_ret()); + } + + #[tokio::test] + async fn request_interupted() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(2); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + // request 0 should be dropped, and request 1 should be processed instead + let req1 = exec_req_stmt("create table test (foo)", 0); + let req2 = exec_req_stmt("create table test (foo)", 1); + snd.send(Ok(req1)).await.unwrap(); + snd.send(Ok(req2)).await.unwrap(); + + let resp = stream.next().await.unwrap().unwrap(); + assert_eq!(resp.request_id, 1); + } + + #[tokio::test] + async fn describe() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + // send a single describe request (id 0) and snapshot the response + let req = ExecReq { + request_id: 0, + request: Some(Request::Describe(StreamDescribeReq { + stmt: "select $hello".into(), + })), + }; + + snd.send(Ok(req)).await.unwrap(); + + assert_debug_snapshot!(stream.next().await.unwrap().unwrap()); + } + + /// This function returns a random, valid program resp for use in other tests + pub fn random_valid_program_resp( + size: usize, + max_resp_size: usize, + ) -> (impl Stream, ValidateTraceBuilder) { + let 
(sender, receiver) = mpsc::channel(1); + let builder = StreamResponseBuilder { + request_id: 0, + sender, + current: None, + current_size: 0, + max_program_resp_size: max_resp_size, + }; + + let trace = random_transition(size); + tokio::task::spawn_blocking({ + let trace = trace.clone(); + move || fsm_builder_driver(&trace, builder) + }); + + ( + ReceiverStream::new(receiver), + ValidateTraceBuilder::new(trace), + ) + } +} diff --git a/sqld/tests/cluster/mod.rs b/sqld/tests/cluster/mod.rs index 86cf8e44..aa227c36 100644 --- a/sqld/tests/cluster/mod.rs +++ b/sqld/tests/cluster/mod.rs @@ -2,9 +2,7 @@ use super::common; -use insta::assert_snapshot; use libsql::{Database, Value}; -use serde_json::json; use sqld::config::{AdminApiConfig, RpcClientConfig, RpcServerConfig, UserApiConfig}; use tempfile::tempdir; use tokio::{task::JoinSet, time::Duration}; @@ -205,40 +203,3 @@ fn sync_many_replica() { sim.run().unwrap(); } - -#[test] -fn create_namespace() { - let mut sim = Builder::new().build(); - make_cluster(&mut sim, 0, false); - - sim.client("client", async { - let db = - Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; - let conn = db.connect()?; - - let Err(e) = conn.execute("create table test (x)", ()).await else { - panic!() - }; - assert_snapshot!(e.to_string()); - - let client = Client::new(); - let resp = client - .post( - "http://foo.primary:9090/v1/namespaces/foo/create", - json!({}), - ) - .await?; - assert_eq!(resp.status(), 200); - - conn.execute("create table test (x)", ()).await.unwrap(); - let mut rows = conn.query("select count(*) from test", ()).await.unwrap(); - assert!(matches!( - rows.next().unwrap().unwrap().get_value(0).unwrap(), - Value::Integer(0) - )); - - Ok(()) - }); - - sim.run().unwrap(); -} diff --git a/sqld/tests/namespaces/mod.rs b/sqld/tests/namespaces/mod.rs index e6b2ee88..a0b90c96 100644 --- a/sqld/tests/namespaces/mod.rs +++ b/sqld/tests/namespaces/mod.rs @@ -4,6 +4,7 @@ use 
std::path::PathBuf; use crate::common::http::Client; use crate::common::net::{init_tracing, TestServer, TurmoilAcceptor, TurmoilConnector}; +use insta::assert_snapshot; use libsql::{Database, Value}; use serde_json::json; use sqld::config::{AdminApiConfig, RpcServerConfig, UserApiConfig}; @@ -41,6 +42,44 @@ fn make_primary(sim: &mut Sim, path: PathBuf) { }); } +#[test] +fn create_namespace() { + let mut sim = Builder::new().build(); + let tmp = tempdir().unwrap(); + make_primary(&mut sim, tmp.path().into()); + + sim.client("client", async { + let db = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; + let conn = db.connect()?; + + let Err(e) = conn.execute("create table test (x)", ()).await else { + panic!() + }; + assert_snapshot!(e.to_string()); + + let client = Client::new(); + let resp = client + .post( + "http://foo.primary:9090/v1/namespaces/foo/create", + json!({}), + ) + .await?; + assert_eq!(resp.status(), 200); + + conn.execute("create table test (x)", ()).await.unwrap(); + let mut rows = conn.query("select count(*) from test", ()).await.unwrap(); + assert!(matches!( + rows.next().unwrap().unwrap().get_value(0).unwrap(), + Value::Integer(0) + )); + + Ok(()) + }); + + sim.run().unwrap(); +} + #[test] fn fork_namespace() { let mut sim = Builder::new().build(); diff --git a/sqld/tests/namespaces/snapshots/tests__namespaces__create_namespace.snap b/sqld/tests/namespaces/snapshots/tests__namespaces__create_namespace.snap new file mode 100644 index 00000000..3f190961 --- /dev/null +++ b/sqld/tests/namespaces/snapshots/tests__namespaces__create_namespace.snap @@ -0,0 +1,5 @@ +--- +source: sqld/tests/namespaces/mod.rs +expression: e.to_string() +--- +Hrana: `api error: `{"error":"Namespace `foo` doesn't exist"}``