From c426f280c4ccfa0dc637d2b3d174e1edbaf8deca Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 3 Oct 2023 14:04:38 +0200 Subject: [PATCH 01/17] extend proxy rpc with streaming proxy request --- sqld/proto/proxy.proto | 62 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/sqld/proto/proxy.proto b/sqld/proto/proxy.proto index 065c95a2..f060ff93 100644 --- a/sqld/proto/proxy.proto +++ b/sqld/proto/proxy.proto @@ -150,7 +150,69 @@ message ProgramReq { Program pgm = 2; } +message ExecMessage { + /// id of the request. The response will contain this id + uint32 request_id = 1; + oneof request { + ProgramReq execute = 2; + DescribeRequest describe = 3; + } +} + +/// streaming exec proto + +message Init { } +message BeginStep { } +message FinishStep { + uint64 affected_row_count = 1; + optional uint64 last_insert_row_id = 2; +} +message StepError { + string error = 1; +} +message ColsDescription { + repeated Column columns = 1; +} +message BeginRows { } +message BeginRow { } +message AddRowValue { + Value val = 1; +} +message FinishRow { } +message FinishRows { } +message Finish { + optional uint64 last_frame_no = 1; +} + + +message Message { + oneof payload { + Description describe_result = 1; + + Init init = 2; + BeginStep begin_step = 3; + FinishStep finish_step = 4; + StepError step_error = 5; + ColsDescription cols_description = 6; + BeginRows begin_rows = 7; + BeginRow begin_row = 8; + AddRowValue add_row_value = 9; + FinishRow finish_row = 10; + FinishRows finish_rows = 11; + Finish finish = 12; + + Error error = 13; + } +} + +message ExecResponse { + uint32 request_id = 1; + repeated Message messages = 2; +} + service Proxy { + rpc StreamExec(stream ExecMessage) returns (stream ExecResponse) {} + rpc Execute(ProgramReq) returns (ExecuteResults) {} rpc Describe(DescribeRequest) returns (DescribeResult) {} rpc Disconnect(DisconnectMessage) returns (Ack) {} From 565909fbb1b06b6657a38267230bbb7a13410803 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 3 Oct 2023 14:05:01 +0200 Subject: [PATCH 02/17] proxy streaming proxy request to primary --- sqld/src/rpc/replica_proxy.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sqld/src/rpc/replica_proxy.rs b/sqld/src/rpc/replica_proxy.rs index c4aa7179..02f58759 100644 --- a/sqld/src/rpc/replica_proxy.rs +++ b/sqld/src/rpc/replica_proxy.rs @@ -1,13 +1,14 @@ use std::sync::Arc; use hyper::Uri; +use tokio_stream::StreamExt; use tonic::{transport::Channel, Request, Status}; use crate::auth::Auth; use super::proxy::rpc::{ self, proxy_client::ProxyClient, proxy_server::Proxy, Ack, DescribeRequest, DescribeResult, - DisconnectMessage, ExecuteResults, + DisconnectMessage, ExecuteResults, ExecResponse, ExecMessage, }; pub struct ReplicaProxyService { @@ -32,6 +33,16 @@ impl ReplicaProxyService { #[tonic::async_trait] impl Proxy for ReplicaProxyService { + type StreamExecStream = tonic::codec::Streaming; + + async fn stream_exec(&self,req: tonic::Request>) -> Result, tonic::Status> { + let (meta, ext, stream) = req.into_parts(); + let mut req = tonic::Request::from_parts(meta, ext, stream.map(|r| r.unwrap())); // TODO: handle mapping error + self.do_auth(&mut req)?; + let mut client = self.client.clone(); + client.stream_exec(req).await + } + async fn execute( &self, mut req: tonic::Request, From 636bf2eb3f68efbd98965d609982f7570efc9ff4 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 3 Oct 2023 15:20:32 +0200 Subject: [PATCH 03/17] wip --- sqld/proto/proxy.proto | 1 - sqld/src/rpc/proxy.rs | 190 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 189 insertions(+), 2 deletions(-) diff --git a/sqld/proto/proxy.proto b/sqld/proto/proxy.proto index f060ff93..7b9556dc 100644 --- a/sqld/proto/proxy.proto +++ b/sqld/proto/proxy.proto @@ -110,7 +110,6 @@ message Step { optional Cond cond = 1; Query query = 2; } - message Cond { oneof cond { OkCond ok = 1; diff --git a/sqld/src/rpc/proxy.rs b/sqld/src/rpc/proxy.rs index 5ba004a7..b1949299 100644 --- a/sqld/src/rpc/proxy.rs +++ b/sqld/src/rpc/proxy.rs @@ -1,8 +1,13 @@ use std::collections::HashMap; +use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; +use std::task::{ready, Context, Poll}; +use std::future::Future; use async_lock::{RwLock, RwLockUpgradableReadGuard}; +use futures_core::Stream; +use tokio::sync::mpsc; use uuid::Uuid; use crate::auth::{Auth, Authenticated}; @@ -14,11 +19,12 @@ use crate::query_result_builder::{ }; use crate::replication::FrameNo; +use self::rpc::exec_message::Request; use self::rpc::proxy_server::Proxy; use self::rpc::query_result::RowResult; use self::rpc::{ describe_result, Ack, DescribeRequest, DescribeResult, Description, DisconnectMessage, - ExecuteResults, QueryResult, ResultRows, Row, + ExecuteResults, QueryResult, ResultRows, Row, ExecMessage, ExecResponse, }; use super::NAMESPACE_DOESNT_EXIST; @@ -455,8 +461,189 @@ pub async fn garbage_collect(clients: &mut HashMap> tracing::trace!("gc: remaining client handles count: {}", clients.len()); } +pin_project_lite::pin_project! { + pub struct StreamRequestHandler { + #[pin] + request_stream: S, + connection: Arc>, + state: HandlerState, + authenticated: Authenticated, + } +} + +struct StreamResponseBuilder { + request_id: u32, + sender: mpsc::Sender, +} + +impl QueryResultBuilder for StreamResponseBuilder { + type Ret = (); + + fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + todo!() + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + todo!() + } + + fn finish_step( + &mut self, + _affected_row_count: u64, + _last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + todo!() + } + + fn step_error(&mut self, _error: crate::error::Error) -> Result<(), QueryResultBuilderError> { + todo!() + } + + fn cols_description<'a>( + &mut self, + _cols: impl IntoIterator>>, + ) -> Result<(), QueryResultBuilderError> { + todo!() + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + todo!() + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + todo!() + } + + fn add_row_value(&mut self, _v: rusqlite::types::ValueRef) -> Result<(), QueryResultBuilderError> { + todo!() + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + todo!() + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + todo!() + } + + fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + todo!() + } + + fn into_ret(self) -> Self::Ret { + todo!() + } +} + +enum HandlerState { + Execute(Pin>>), + Idle, + Fused, +} + +impl Stream for StreamRequestHandler + where S: Stream> +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match this.state { + HandlerState::Idle => { + match ready!(this.request_stream.poll_next(cx)) { + Some(Err(e)) => { + *this.state = HandlerState::Fused; + return Poll::Ready(Some(Err(e))); + } + Some(Ok(req)) => { + match req.request.unwrap() { + Request::Execute(exec) => { + let pgm = crate::connection::program::Program::try_from(exec.pgm.unwrap()).unwrap(); + let conn = this.connection.clone(); + let authenticated = this.authenticated.clone(); + + let s = async_stream::stream! { + let (sender, receiver) = mpsc::channel(1); + let builder = StreamResponseBuilder { + request_id: req.request_id, + sender, + }; + let fut = conn.execute_program(pgm, authenticated, builder, None); + loop { + tokio::select! { + res = fut => { + // todo check result? + break + } + msg = receiver.recv() => { + if let Some(msg) = msg { + yield msg; + } + } + } + } + }; + let fut = Box::pin(async move { + Ok(()) + }); + *this.state = HandlerState::Execute(Box::pin(s)); + }, + Request::Describe(_) => todo!(), + } + // we have placed the request, poll immediately + cx.waker().wake_by_ref(); + return Poll::Pending; + }, + None => Poll::Ready(None), + } + }, + HandlerState::Fused => Poll::Ready(None), + HandlerState::Execute(_) => todo!(), + } + } +} + #[tonic::async_trait] impl Proxy for ProxyService { + type StreamExecStream = StreamRequestHandler>; + + async fn stream_exec(&self, req: tonic::Request>) ->Result, tonic::Status> { + let authenticated = if let Some(auth) = &self.auth { + auth.authenticate_grpc(&req, self.disable_namespaces)? + } else { + Authenticated::from_proxy_grpc_request(&req, self.disable_namespaces)? + }; + + let namespace = super::extract_namespace(self.disable_namespaces, &req)?; + let (connection_maker, _new_frame_notifier) = self + .namespaces + .with(namespace, |ns| { + let connection_maker = ns.db.connection_maker(); + let notifier = ns.db.logger.new_frame_notifier.subscribe(); + (connection_maker, notifier) + }) + .await + .map_err(|e| { + if let crate::error::Error::NamespaceDoesntExist(_) = e { + tonic::Status::failed_precondition(NAMESPACE_DOESNT_EXIST) + } else { + tonic::Status::internal(e.to_string()) + } + })?; + + let connection = connection_maker.create().await.unwrap(); + + let handler = StreamRequestHandler { + authenticated, + request_stream: req.into_inner(), + connection: connection.into(), + state: HandlerState::Idle, + }; + + Ok(tonic::Response::new(handler)) + } + async fn execute( &self, req: tonic::Request, @@ -615,4 +802,5 @@ impl Proxy for ProxyService { })), })) } + } From 1efbd6d4c182156be3ee7e9886d58fff1087c401 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 3 Oct 2023 19:15:51 +0200 Subject: [PATCH 04/17] streaming proxy primary --- sqld/proto/proxy.proto | 2 +- sqld/src/rpc/proxy.rs | 168 ++++++++++++++++++++++++++-------- sqld/src/rpc/replica_proxy.rs | 7 +- 3 files changed, 134 insertions(+), 43 deletions(-) diff --git a/sqld/proto/proxy.proto b/sqld/proto/proxy.proto index 7b9556dc..be006ae4 100644 --- a/sqld/proto/proxy.proto +++ b/sqld/proto/proxy.proto @@ -164,7 +164,7 @@ message Init { } message BeginStep { } message FinishStep { uint64 affected_row_count = 1; - optional uint64 last_insert_row_id = 2; + optional int64 last_insert_rowid = 2; } message StepError { string error = 1; diff --git a/sqld/src/rpc/proxy.rs b/sqld/src/rpc/proxy.rs index b1949299..884783e1 100644 --- a/sqld/src/rpc/proxy.rs +++ b/sqld/src/rpc/proxy.rs @@ -3,10 +3,11 @@ use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; use std::task::{ready, Context, Poll}; -use std::future::Future; use async_lock::{RwLock, RwLockUpgradableReadGuard}; +use futures::StreamExt; use futures_core::Stream; +use rusqlite::types::ValueRef; use tokio::sync::mpsc; use uuid::Uuid; @@ -20,11 +21,12 @@ use crate::query_result_builder::{ use crate::replication::FrameNo; use self::rpc::exec_message::Request; +use self::rpc::message::Payload; use self::rpc::proxy_server::Proxy; use self::rpc::query_result::RowResult; use self::rpc::{ describe_result, Ack, DescribeRequest, DescribeResult, Description, DisconnectMessage, - ExecuteResults, QueryResult, ResultRows, Row, ExecMessage, ExecResponse, + ExecMessage, ExecResponse, ExecuteResults, Message, QueryResult, ResultRows, Row, }; use super::NAMESPACE_DOESNT_EXIST; @@ -405,7 +407,7 @@ impl QueryResultBuilder for ExecuteResultBuilder { fn add_row_value( &mut self, - v: rusqlite::types::ValueRef, + v: ValueRef, ) -> Result<(), QueryResultBuilderError> { let data = bincode::serialize( &crate::query::Value::try_from(v).map_err(QueryResultBuilderError::from_any)?, @@ -474,75 +476,141 @@ pin_project_lite::pin_project! { struct StreamResponseBuilder { request_id: u32, sender: mpsc::Sender, + current: Option, +} + +impl StreamResponseBuilder { + fn current(&mut self) -> &mut ExecResponse { + self.current.get_or_insert_with(|| ExecResponse { + messages: Vec::new(), + request_id: self.request_id, + }) + } + + fn push(&mut self, payload: Payload) { + const MAX_RESPONSE_MESSAGES: usize = 10; + + let current = self.current(); + current.messages.push(Message { + payload: Some(payload), + }); + + if current.messages.len() > MAX_RESPONSE_MESSAGES { + self.flush() + } + } + + fn flush(&mut self) { + if let Some(current) = self.current.take() { + self.sender.blocking_send(current).unwrap(); + } + } } impl QueryResultBuilder for StreamResponseBuilder { type Ret = (); fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { - todo!() + self.push(Payload::Init(rpc::Init {})); + Ok(()) } fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { - todo!() + self.push(Payload::BeginStep(rpc::BeginStep {})); + Ok(()) } fn finish_step( &mut self, - _affected_row_count: u64, - _last_insert_rowid: Option, + affected_row_count: u64, + last_insert_rowid: Option, ) -> Result<(), QueryResultBuilderError> { - todo!() + self.push(Payload::FinishStep(rpc::FinishStep { + affected_row_count, + last_insert_rowid, + })); + Ok(()) } - fn step_error(&mut self, _error: crate::error::Error) -> Result<(), QueryResultBuilderError> { - todo!() + fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { + self.push(Payload::StepError(rpc::StepError { + error: error.to_string(), + })); + Ok(()) } fn cols_description<'a>( &mut self, - _cols: impl IntoIterator>>, + cols: impl IntoIterator>>, ) -> Result<(), QueryResultBuilderError> { - todo!() + self.push(Payload::ColsDescription(rpc::ColsDescription { + columns: cols + .into_iter() + .map(Into::into) + .map(|c| rpc::Column { + name: c.name.into(), + decltype: c.decl_ty.map(Into::into), + }) + .collect::>(), + })); + Ok(()) } fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - todo!() + self.push(Payload::BeginRows(rpc::BeginRows {})); + Ok(()) } fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { - todo!() + self.push(Payload::BeginRow(rpc::BeginRow {})); + Ok(()) } - fn add_row_value(&mut self, _v: rusqlite::types::ValueRef) -> Result<(), QueryResultBuilderError> { - todo!() + fn add_row_value( + &mut self, + v: ValueRef, + ) -> Result<(), QueryResultBuilderError> { + let data = bincode::serialize( + &crate::query::Value::try_from(v).map_err(QueryResultBuilderError::from_any)?, + ) + .map_err(QueryResultBuilderError::from_any)?; + + let val = Some(rpc::Value { data }); + + self.push(Payload::AddRowValue(rpc::AddRowValue { val })); + Ok(()) } fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - todo!() + self.push(Payload::FinishRow(rpc::FinishRow {})); + Ok(()) } fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - todo!() + self.push(Payload::FinishRows(rpc::FinishRows {})); + Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { - todo!() + fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + self.push(Payload::Finish(rpc::Finish { last_frame_no })); + self.flush(); + Ok(()) } fn into_ret(self) -> Self::Ret { - todo!() + () } } enum HandlerState { - Execute(Pin>>), + Execute(Pin + Send>>), Idle, Fused, } -impl Stream for StreamRequestHandler - where S: Stream> +impl Stream for StreamRequestHandler +where + S: Stream>, { type Item = Result; @@ -557,22 +625,27 @@ impl Stream for StreamRequestHandler return Poll::Ready(Some(Err(e))); } Some(Ok(req)) => { + let request_id = req.request_id; match req.request.unwrap() { Request::Execute(exec) => { - let pgm = crate::connection::program::Program::try_from(exec.pgm.unwrap()).unwrap(); + let pgm = crate::connection::program::Program::try_from( + exec.pgm.unwrap(), + ) + .unwrap(); let conn = this.connection.clone(); let authenticated = this.authenticated.clone(); let s = async_stream::stream! { - let (sender, receiver) = mpsc::channel(1); + let (sender, mut receiver) = mpsc::channel(1); let builder = StreamResponseBuilder { - request_id: req.request_id, + request_id, sender, + current: None, }; - let fut = conn.execute_program(pgm, authenticated, builder, None); + let mut fut = conn.execute_program(pgm, authenticated, builder, None); loop { tokio::select! { - res = fut => { + _res = &mut fut => { // todo check result? break } @@ -584,22 +657,35 @@ impl Stream for StreamRequestHandler } } }; - let fut = Box::pin(async move { - Ok(()) - }); *this.state = HandlerState::Execute(Box::pin(s)); - }, + } Request::Describe(_) => todo!(), } // we have placed the request, poll immediately cx.waker().wake_by_ref(); return Poll::Pending; - }, - None => Poll::Ready(None), + } + None => { + // this would easier if tokio_stream re-exported combinators + *this.state = HandlerState::Fused; + Poll::Ready(None) + } } - }, + } HandlerState::Fused => Poll::Ready(None), - HandlerState::Execute(_) => todo!(), + HandlerState::Execute(stream) => { + let resp = ready!(stream.poll_next_unpin(cx)); + match resp { + Some(resp) => return Poll::Ready(Some(Ok(resp))), + None => { + // finished processing this query. Wake up immediately to prepare for the + // next + *this.state = HandlerState::Idle; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + } } } } @@ -608,7 +694,10 @@ impl Stream for StreamRequestHandler impl Proxy for ProxyService { type StreamExecStream = StreamRequestHandler>; - async fn stream_exec(&self, req: tonic::Request>) ->Result, tonic::Status> { + async fn stream_exec( + &self, + req: tonic::Request>, + ) -> Result, tonic::Status> { let authenticated = if let Some(auth) = &self.auth { auth.authenticate_grpc(&req, self.disable_namespaces)? } else { @@ -623,7 +712,7 @@ impl Proxy for ProxyService { let notifier = ns.db.logger.new_frame_notifier.subscribe(); (connection_maker, notifier) }) - .await + .await .map_err(|e| { if let crate::error::Error::NamespaceDoesntExist(_) = e { tonic::Status::failed_precondition(NAMESPACE_DOESNT_EXIST) @@ -802,5 +891,4 @@ impl Proxy for ProxyService { })), })) } - } diff --git a/sqld/src/rpc/replica_proxy.rs b/sqld/src/rpc/replica_proxy.rs index 02f58759..08fbbaf5 100644 --- a/sqld/src/rpc/replica_proxy.rs +++ b/sqld/src/rpc/replica_proxy.rs @@ -8,7 +8,7 @@ use crate::auth::Auth; use super::proxy::rpc::{ self, proxy_client::ProxyClient, proxy_server::Proxy, Ack, DescribeRequest, DescribeResult, - DisconnectMessage, ExecuteResults, ExecResponse, ExecMessage, + DisconnectMessage, ExecMessage, ExecResponse, ExecuteResults, }; pub struct ReplicaProxyService { @@ -35,7 +35,10 @@ impl ReplicaProxyService { impl Proxy for ReplicaProxyService { type StreamExecStream = tonic::codec::Streaming; - async fn stream_exec(&self,req: tonic::Request>) -> Result, tonic::Status> { + async fn stream_exec( + &self, + req: tonic::Request>, + ) -> Result, tonic::Status> { let (meta, ext, stream) = req.into_parts(); let mut req = tonic::Request::from_parts(meta, ext, stream.map(|r| r.unwrap())); // TODO: handle mapping error self.do_auth(&mut req)?; From 065ee37557d12aac4a90e35236a5c139c402ea3f Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 4 Oct 2023 12:03:19 +0200 Subject: [PATCH 05/17] write proxy streaming request --- sqld/proto/proxy.proto | 22 +- sqld/src/connection/libsql.rs | 21 +- sqld/src/connection/mod.rs | 12 +- sqld/src/connection/write_proxy.rs | 260 +++++++++++++++--------- sqld/src/hrana/cursor.rs | 7 +- sqld/src/hrana/result_builder.rs | 13 +- sqld/src/http/user/mod.rs | 10 +- sqld/src/http/user/result_builder.rs | 7 +- sqld/src/query_analysis.rs | 20 +- sqld/src/query_result_builder.rs | 27 ++- sqld/src/rpc/mod.rs | 1 + sqld/src/rpc/proxy.rs | 288 +++------------------------ sqld/src/rpc/replica_proxy.rs | 6 +- sqld/src/rpc/streaming_exec.rs | 283 ++++++++++++++++++++++++++ 14 files changed, 576 insertions(+), 401 deletions(-) create mode 100644 sqld/src/rpc/streaming_exec.rs diff --git a/sqld/proto/proxy.proto b/sqld/proto/proxy.proto index be006ae4..98586bc7 100644 --- a/sqld/proto/proxy.proto +++ b/sqld/proto/proxy.proto @@ -89,13 +89,14 @@ message DisconnectMessage { message Ack { } +enum State { + INIT = 0; + INVALID = 1; + TXN = 2; +} + message ExecuteResults { repeated QueryResult results = 1; - enum State { - Init = 0; - Invalid = 1; - Txn = 2; - } /// State after executing the queries State state = 2; /// Primary frame_no after executing the request. @@ -149,11 +150,11 @@ message ProgramReq { Program pgm = 2; } -message ExecMessage { +message ExecReq { /// id of the request. The response will contain this id uint32 request_id = 1; oneof request { - ProgramReq execute = 2; + Program execute = 2; DescribeRequest describe = 3; } } @@ -167,7 +168,7 @@ message FinishStep { optional int64 last_insert_rowid = 2; } message StepError { - string error = 1; + Error error = 1; } message ColsDescription { repeated Column columns = 1; @@ -181,6 +182,7 @@ message FinishRow { } message FinishRows { } message Finish { optional uint64 last_frame_no = 1; + State state = 2; } @@ -204,13 +206,13 @@ message Message { } } -message ExecResponse { +message ExecResp { uint32 request_id = 1; repeated Message messages = 2; } service Proxy { - rpc StreamExec(stream ExecMessage) returns (stream ExecResponse) {} + rpc StreamExec(stream ExecReq) returns (stream ExecResp) {} rpc Execute(ProgramReq) returns (ExecuteResults) {} rpc Describe(DescribeRequest) returns (DescribeResult) {} diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs index e70e66d3..9f47bcae 100644 --- a/sqld/src/connection/libsql.rs +++ b/sqld/src/connection/libsql.rs @@ -13,7 +13,7 @@ use crate::auth::{Authenticated, Authorized, Permission}; use crate::error::Error; use crate::libsql_bindings::wal_hook::WalHook; use crate::query::Query; -use crate::query_analysis::{State, StmtKind}; +use crate::query_analysis::{StmtKind, TxnStatus}; use crate::query_result_builder::{QueryBuilderConfig, QueryResultBuilder}; use crate::replication::FrameNo; use crate::stats::Stats; @@ -405,7 +405,7 @@ impl Connection { this: Arc>, pgm: Program, mut builder: B, - ) -> Result<(B, State)> { + ) -> Result<(B, TxnStatus)> { use rusqlite::TransactionState as Tx; let state = this.lock().state.clone(); @@ -469,20 +469,23 @@ impl Connection { results.push(res); } - builder.finish(*this.lock().current_frame_no_receiver.borrow_and_update())?; - - let state = if matches!( + let status = if matches!( this.lock() .conn .transaction_state(Some(DatabaseName::Main))?, Tx::Read | Tx::Write ) { - State::Txn + TxnStatus::Txn } else { - State::Init + TxnStatus::Init }; - Ok((builder, state)) + builder.finish( + *this.lock().current_frame_no_receiver.borrow_and_update(), + status, + )?; + + Ok((builder, status)) } fn execute_step( @@ -733,7 +736,7 @@ where auth: Authenticated, builder: B, _replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result<(B, TxnStatus)> { check_program_auth(auth, &pgm)?; let conn = self.inner.clone(); tokio::task::spawn_blocking(move || Connection::run(conn, pgm, builder)) diff --git a/sqld/src/connection/mod.rs b/sqld/src/connection/mod.rs index 2d539057..1c24a142 100644 --- a/sqld/src/connection/mod.rs +++ b/sqld/src/connection/mod.rs @@ -8,7 +8,7 @@ use tokio::{sync::Semaphore, time::timeout}; use crate::auth::Authenticated; use crate::error::Error; use crate::query::{Params, Query}; -use crate::query_analysis::{State, Statement}; +use crate::query_analysis::{Statement, TxnStatus}; use crate::query_result_builder::{IgnoreResult, QueryResultBuilder}; use crate::replication::FrameNo; use crate::Result; @@ -32,7 +32,7 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, response_builder: B, replication_index: Option, - ) -> Result<(B, State)>; + ) -> Result<(B, TxnStatus)>; /// Execute all the queries in the batch sequentially. /// If an query in the batch fails, the remaining queries are ignores, and the batch current @@ -43,7 +43,7 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, result_builder: B, replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result<(B, TxnStatus)> { let batch_len = batch.len(); let mut steps = make_batch_program(batch); @@ -82,7 +82,7 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, result_builder: B, replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result<(B, TxnStatus)> { let steps = make_batch_program(batch); let pgm = Program::new(steps); self.execute_program(pgm, auth, result_builder, replication_index) @@ -312,7 +312,7 @@ impl Connection for TrackedConnection { auth: Authenticated, builder: B, replication_index: Option, - ) -> crate::Result<(B, State)> { + ) -> crate::Result<(B, TxnStatus)> { self.atime.store(now_millis(), Ordering::Relaxed); self.inner .execute_program(pgm, auth, builder, replication_index) @@ -367,7 +367,7 @@ mod test { _auth: Authenticated, _builder: B, _replication_index: Option, - ) -> crate::Result<(B, State)> { + ) -> crate::Result<(B, TxnStatus)> { unreachable!() } diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index 26be6ef9..8a1f678c 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -1,27 +1,31 @@ use std::path::PathBuf; use std::sync::Arc; +use futures_core::future::BoxFuture; use parking_lot::Mutex as PMutex; use rusqlite::types::ValueRef; use sqld_libsql_bindings::wal_hook::{TransparentMethods, TRANSPARENT_METHODS}; -use tokio::sync::{watch, Mutex}; +use tokio::sync::{mpsc, watch, Mutex}; +use tokio_stream::StreamExt; use tonic::metadata::BinaryMetadataValue; use tonic::transport::Channel; -use tonic::Request; +use tonic::{Request, Streaming}; use uuid::Uuid; use crate::auth::Authenticated; use crate::error::Error; use crate::namespace::NamespaceName; use crate::query::Value; -use crate::query_analysis::State; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; use crate::replication::FrameNo; use crate::rpc::proxy::rpc::proxy_client::ProxyClient; -use crate::rpc::proxy::rpc::query_result::RowResult; -use crate::rpc::proxy::rpc::{DisconnectMessage, ExecuteResults}; +use crate::rpc::proxy::rpc::{ + self, AddRowValue, ColsDescription, DisconnectMessage, ExecReq, ExecResp, Finish, FinishStep, + StepError, +}; use crate::rpc::NAMESPACE_METADATA_KEY; use crate::stats::Stats; use crate::{Result, DEFAULT_AUTO_CHECKPOINT}; @@ -109,7 +113,7 @@ pub struct WriteProxyConnection { /// Lazily initialized read connection read_conn: LibSqlConnection, write_proxy: ProxyClient, - state: Mutex, + state: Mutex, client_id: Uuid, /// FrameNo of the last write performed by this connection on the primary. /// any subsequent read on this connection must wait for the replicator to catch up with this @@ -120,51 +124,8 @@ pub struct WriteProxyConnection { builder_config: QueryBuilderConfig, stats: Arc, namespace: NamespaceName, -} - -fn execute_results_to_builder( - execute_result: ExecuteResults, - mut builder: B, - config: &QueryBuilderConfig, -) -> Result { - builder.init(config)?; - for result in execute_result.results { - match result.row_result { - Some(RowResult::Row(rows)) => { - builder.begin_step()?; - builder.cols_description(rows.column_descriptions.iter().map(|c| Column { - name: &c.name, - decl_ty: c.decltype.as_deref(), - }))?; - - builder.begin_rows()?; - for row in rows.rows { - builder.begin_row()?; - for value in row.values { - let value: Value = bincode::deserialize(&value.data) - // something is wrong, better stop right here - .map_err(QueryResultBuilderError::from_any)?; - builder.add_row_value(ValueRef::from(&value))?; - } - builder.finish_row()?; - } - - builder.finish_rows()?; - builder.finish_step(rows.affected_row_count, rows.last_insert_rowid)?; - } - Some(RowResult::Error(err)) => { - builder.begin_step()?; - builder.step_error(Error::RpcQueryError(err))?; - builder.finish_step(0, None)?; - } - None => (), - } - } - - builder.finish(execute_result.current_frame_no)?; - - Ok(builder) + remote_conn: Mutex>, } impl WriteProxyConnection { @@ -180,56 +141,63 @@ impl WriteProxyConnection { Ok(Self { read_conn, write_proxy, - state: Mutex::new(State::Init), + state: Mutex::new(TxnStatus::Init), client_id: Uuid::new_v4(), - last_write_frame_no: PMutex::new(None), + last_write_frame_no: Default::default(), applied_frame_no_receiver, builder_config, stats, namespace, + remote_conn: Default::default(), }) } + async fn with_remote_conn( + &self, + auth: Authenticated, + builder_config: QueryBuilderConfig, + cb: F, + ) -> crate::Result + where + F: FnOnce(&mut RemoteConnection) -> BoxFuture<'_, crate::Result>, + { + let mut remote_conn = self.remote_conn.lock().await; + // TODO: catch broken connection, and reset it to None. + if remote_conn.is_some() { + cb(remote_conn.as_mut().unwrap()).await + } else { + let conn = RemoteConnection::connect( + self.write_proxy.clone(), + self.namespace.clone(), + auth, + builder_config, + ) + .await?; + let conn = remote_conn.insert(conn); + cb(conn).await + } + } + async fn execute_remote( &self, pgm: Program, - state: &mut State, + status: &mut TxnStatus, auth: Authenticated, builder: B, - ) -> Result<(B, State)> { + ) -> Result<(B, TxnStatus)> { self.stats.inc_write_requests_delegated(); - let mut client = self.write_proxy.clone(); - - let mut req = Request::new(crate::rpc::proxy::rpc::ProgramReq { - client_id: self.client_id.to_string(), - pgm: Some(pgm.into()), - }); - - let namespace = BinaryMetadataValue::from_bytes(self.namespace.as_slice()); - req.metadata_mut() - .insert_bin(NAMESPACE_METADATA_KEY, namespace); - auth.upgrade_grpc_request(&mut req); - - match client.execute(req).await { - Ok(r) => { - let execute_result = r.into_inner(); - *state = execute_result.state().into(); - let current_frame_no = execute_result.current_frame_no; - let builder = - execute_results_to_builder(execute_result, builder, &self.builder_config)?; - if let Some(current_frame_no) = current_frame_no { - self.update_last_write_frame_no(current_frame_no); - } - - Ok((builder, *state)) - } - Err(e) => { - // Set state to invalid, so next call is sent to remote, and we have a chance - // to recover state. - *state = State::Invalid; - Err(Error::RpcQueryExecutionError(e)) - } + *status = TxnStatus::Invalid; + let (builder, new_status, new_frame_no) = self + .with_remote_conn(auth, self.builder_config, |conn| { + Box::pin(conn.execute(pgm, builder)) + }) + .await?; + *status = new_status; + if let Some(current_frame_no) = new_frame_no { + self.update_last_write_frame_no(current_frame_no); } + + Ok((builder, new_status)) } fn update_last_write_frame_no(&self, new_frame_no: FrameNo) { @@ -261,6 +229,118 @@ impl WriteProxyConnection { } } +struct RemoteConnection { + response_stream: Streaming, + request_sender: mpsc::Sender, + current_request_id: u32, + builder_config: QueryBuilderConfig, +} + +impl RemoteConnection { + async fn connect( + mut client: ProxyClient, + namespace: NamespaceName, + auth: Authenticated, + builder_config: QueryBuilderConfig, + ) -> crate::Result { + let (request_sender, receiver) = mpsc::channel(1); + + let stream = tokio_stream::wrappers::ReceiverStream::new(receiver); + let mut req = Request::new(stream); + let namespace = BinaryMetadataValue::from_bytes(namespace.as_slice()); + req.metadata_mut() + .insert_bin(NAMESPACE_METADATA_KEY, namespace); + auth.upgrade_grpc_request(&mut req); + dbg!(); + let response_stream = client.stream_exec(req).await.unwrap().into_inner(); + dbg!(); + + Ok(Self { + response_stream, + request_sender, + current_request_id: 0, + builder_config, + }) + } + + async fn execute( + &mut self, + program: Program, + mut builder: B, + ) -> crate::Result<(B, TxnStatus, Option)> { + let request_id = self.current_request_id; + self.current_request_id += 1; + + let req = ExecReq { + request_id, + request: Some(rpc::exec_req::Request::Execute(program.into())), + }; + + dbg!(); + self.request_sender.send(req).await.unwrap(); // TODO: the stream was close! + dbg!(); + let mut txn_status = TxnStatus::Invalid; + let mut new_frame_no = None; + + 'outer: while let Some(resp) = self.response_stream.next().await { + dbg!(&resp); + match resp { + Ok(resp) => { + if resp.request_id != request_id { + todo!("stream misuse: connection should be serialized"); + } + for message in resp.messages { + use rpc::message::Payload; + + match message.payload.unwrap() { + Payload::DescribeResult(_) => todo!("invalid response"), + + Payload::Init(_) => builder.init(&self.builder_config)?, + Payload::BeginStep(_) => builder.begin_step()?, + Payload::FinishStep(FinishStep { + affected_row_count, + last_insert_rowid, + }) => builder.finish_step(affected_row_count, last_insert_rowid)?, + Payload::StepError(StepError { error }) => builder + .step_error(crate::error::Error::RpcQueryError(error.unwrap()))?, + Payload::ColsDescription(ColsDescription { columns }) => { + let cols = columns.iter().map(|c| Column { + name: &c.name, + decl_ty: c.decltype.as_deref(), + }); + builder.cols_description(cols)? + } + Payload::BeginRows(_) => builder.begin_rows()?, + Payload::BeginRow(_) => builder.begin_row()?, + Payload::AddRowValue(AddRowValue { val }) => { + let value: Value = bincode::deserialize(&val.unwrap().data) + // something is wrong, better stop right here + .map_err(QueryResultBuilderError::from_any)?; + builder.add_row_value(ValueRef::from(&value))?; + } + Payload::FinishRow(_) => builder.finish_row()?, + Payload::FinishRows(_) => builder.finish_rows()?, + Payload::Finish(f @ Finish { last_frame_no, .. }) => { + txn_status = TxnStatus::from(f.state()); + new_frame_no = last_frame_no; + builder.finish(last_frame_no, txn_status)?; + dbg!(); + break 'outer; + } + Payload::Error(error) => { + return Err(crate::error::Error::RpcQueryError(error)) + } + } + } + } + Err(_e) => todo!("handle stream error"), + } + } + + Ok((builder, txn_status, new_frame_no)) + } +} + #[async_trait::async_trait] impl Connection for WriteProxyConnection { async fn execute_program( @@ -269,13 +349,13 @@ impl Connection for WriteProxyConnection { auth: Authenticated, builder: B, replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result<(B, TxnStatus)> { let mut state = self.state.lock().await; // This is a fresh namespace, and it is not replicated yet, proxy the first request. if self.applied_frame_no_receiver.borrow().is_none() { self.execute_remote(pgm, &mut state, auth, builder).await - } else if *state == State::Init && pgm.is_read_only() { + } else if *state == TxnStatus::Init && pgm.is_read_only() { self.wait_replication_sync(replication_index).await?; // We know that this program won't perform any writes. We attempt to run it on the // replica. If it leaves an open transaction, then this program is an interactive @@ -284,7 +364,7 @@ impl Connection for WriteProxyConnection { .read_conn .execute_program(pgm.clone(), auth.clone(), builder, replication_index) .await?; - if new_state != State::Init { + if new_state != TxnStatus::Init { self.read_conn.rollback(auth.clone()).await?; self.execute_remote(pgm, &mut state, auth, builder).await } else { @@ -308,8 +388,8 @@ impl Connection for WriteProxyConnection { async fn is_autocommit(&self) -> Result { let state = self.state.lock().await; Ok(match *state { - State::Txn => false, - State::Init | State::Invalid => true, + TxnStatus::Txn => false, + TxnStatus::Init | TxnStatus::Invalid => true, }) } diff --git a/sqld/src/hrana/cursor.rs b/sqld/src/hrana/cursor.rs index 005799a7..c67079a9 100644 --- a/sqld/src/hrana/cursor.rs +++ b/sqld/src/hrana/cursor.rs @@ -8,6 +8,7 @@ use tokio::sync::{mpsc, oneshot}; use crate::auth::Authenticated; use crate::connection::program::Program; use crate::connection::Connection; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; @@ -255,7 +256,11 @@ impl QueryResultBuilder for CursorResultBuilder { Ok(()) } - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { self.emit_entry(Ok(SizedEntry { entry: proto::CursorEntry::ReplicationIndex { replication_index: last_frame_no, diff --git a/sqld/src/hrana/result_builder.rs b/sqld/src/hrana/result_builder.rs index d2e19910..70d1890a 100644 --- a/sqld/src/hrana/result_builder.rs +++ b/sqld/src/hrana/result_builder.rs @@ -6,6 +6,7 @@ use bytes::Bytes; use rusqlite::types::ValueRef; use crate::hrana::stmt::{proto_error_from_stmt_error, stmt_error_from_sqld_error}; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, TOTAL_RESPONSE_SIZE, }; @@ -225,7 +226,11 @@ impl QueryResultBuilder for SingleStatementBuilder { Ok(()) } - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { self.last_frame_no = last_frame_no; Ok(()) } @@ -344,7 +349,11 @@ impl QueryResultBuilder for HranaBatchProtoBuilder { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { Ok(()) } diff --git a/sqld/src/http/user/mod.rs b/sqld/src/http/user/mod.rs index 367f7e77..c99bba95 100644 --- a/sqld/src/http/user/mod.rs +++ b/sqld/src/http/user/mod.rs @@ -37,7 +37,7 @@ use crate::http::user::types::HttpQuery; use crate::namespace::{MakeNamespace, NamespaceStore}; use crate::net::Accept; use crate::query::{self, Query}; -use crate::query_analysis::{predict_final_state, State, Statement}; +use crate::query_analysis::{predict_final_state, Statement, TxnStatus}; use crate::query_result_builder::QueryResultBuilder; use crate::rpc::proxy::rpc::proxy_server::{Proxy, ProxyServer}; use crate::rpc::replication_log::rpc::replication_log_server::ReplicationLog; @@ -97,15 +97,15 @@ fn parse_queries(queries: Vec) -> crate::Result> { out.push(query); } - match predict_final_state(State::Init, out.iter().map(|q| &q.stmt)) { - State::Txn => { + match predict_final_state(TxnStatus::Init, out.iter().map(|q| &q.stmt)) { + TxnStatus::Txn => { return Err(Error::QueryError( "interactive transaction not allowed in HTTP queries".to_string(), )) } - State::Init => (), + TxnStatus::Init => (), // maybe we should err here, but let's sqlite deal with that. - State::Invalid => (), + TxnStatus::Invalid => (), } Ok(out) diff --git a/sqld/src/http/user/result_builder.rs b/sqld/src/http/user/result_builder.rs index fa7c4710..c6b4d8a2 100644 --- a/sqld/src/http/user/result_builder.rs +++ b/sqld/src/http/user/result_builder.rs @@ -6,6 +6,7 @@ use serde::{Serialize, Serializer}; use serde_json::ser::{CompactFormatter, Formatter}; use std::sync::atomic::Ordering; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, JsonFormatter, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, TOTAL_RESPONSE_SIZE, @@ -293,7 +294,11 @@ impl QueryResultBuilder for JsonHttpPayloadBuilder { } // TODO: how do we return last_frame_no? - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { self.formatter.end_array(&mut self.buffer)?; Ok(()) diff --git a/sqld/src/query_analysis.rs b/sqld/src/query_analysis.rs index 5f0e4f37..32352620 100644 --- a/sqld/src/query_analysis.rs +++ b/sqld/src/query_analysis.rs @@ -171,7 +171,7 @@ impl StmtKind { /// The state of a transaction for a series of statement #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum State { +pub enum TxnStatus { /// The txn in an opened state Txn, /// The txn in a closed state @@ -180,19 +180,21 @@ pub enum State { Invalid, } -impl State { +impl TxnStatus { pub fn step(&mut self, kind: StmtKind) { *self = match (*self, kind) { - (State::Txn, StmtKind::TxnBegin) | (State::Init, StmtKind::TxnEnd) => State::Invalid, - (State::Txn, StmtKind::TxnEnd) => State::Init, + (TxnStatus::Txn, StmtKind::TxnBegin) | (TxnStatus::Init, StmtKind::TxnEnd) => { + TxnStatus::Invalid + } + (TxnStatus::Txn, StmtKind::TxnEnd) => TxnStatus::Init, (state, StmtKind::Other | StmtKind::Write | StmtKind::Read) => state, - (State::Invalid, _) => State::Invalid, - (State::Init, StmtKind::TxnBegin) => State::Txn, + (TxnStatus::Invalid, _) => TxnStatus::Invalid, + (TxnStatus::Init, StmtKind::TxnBegin) => TxnStatus::Txn, }; } pub fn reset(&mut self) { - *self = State::Init + *self = TxnStatus::Init } } @@ -284,9 +286,9 @@ impl Statement { /// Given a an initial state and an array of queries, attempts to predict what the final state will /// be pub fn predict_final_state<'a>( - mut state: State, + mut state: TxnStatus, stmts: impl Iterator, -) -> State { +) -> TxnStatus { for stmt in stmts { state.step(stmt.kind); } diff --git a/sqld/src/query_result_builder.rs b/sqld/src/query_result_builder.rs index 914037ee..03cc219b 100644 --- a/sqld/src/query_result_builder.rs +++ b/sqld/src/query_result_builder.rs @@ -8,6 +8,7 @@ use serde::Serialize; use serde_json::ser::Formatter; use std::sync::atomic::AtomicUsize; +use crate::query_analysis::TxnStatus; use crate::replication::FrameNo; pub static TOTAL_RESPONSE_SIZE: AtomicUsize = AtomicUsize::new(0); @@ -120,7 +121,11 @@ pub trait QueryResultBuilder: Send + 'static { /// end adding rows fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError>; /// finish serialization. - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError>; + fn finish( + &mut self, + last_frame_no: Option, + state: TxnStatus, + ) -> Result<(), QueryResultBuilderError>; /// returns the inner ret fn into_ret(self) -> Self::Ret; /// Returns a `QueryResultBuilder` that wraps Self and takes at most `n` steps @@ -311,7 +316,11 @@ impl QueryResultBuilder for StepResultsBuilder { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { Ok(()) } @@ -372,7 +381,11 @@ impl QueryResultBuilder for IgnoreResult { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { Ok(()) } @@ -481,8 +494,12 @@ impl QueryResultBuilder for Take { } } - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError> { - self.inner.finish(last_frame_no) + fn finish( + &mut self, + last_frame_no: Option, + state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { + self.inner.finish(last_frame_no, state) } fn into_ret(self) -> Self::Ret { diff --git a/sqld/src/rpc/mod.rs b/sqld/src/rpc/mod.rs index 252d58e7..3fda4f4d 100644 --- a/sqld/src/rpc/mod.rs +++ b/sqld/src/rpc/mod.rs @@ -19,6 +19,7 @@ pub mod proxy; pub mod replica_proxy; pub mod replication_log; pub mod replication_log_proxy; +mod streaming_exec; /// A tonic error code to signify that a namespace doesn't exist. pub const NAMESPACE_DOESNT_EXIST: &str = "NAMESPACE_DOESNT_EXIST"; diff --git a/sqld/src/rpc/proxy.rs b/sqld/src/rpc/proxy.rs index 884783e1..e248be81 100644 --- a/sqld/src/rpc/proxy.rs +++ b/sqld/src/rpc/proxy.rs @@ -1,33 +1,28 @@ use std::collections::HashMap; -use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; -use std::task::{ready, Context, Poll}; use async_lock::{RwLock, RwLockUpgradableReadGuard}; -use futures::StreamExt; -use futures_core::Stream; use rusqlite::types::ValueRef; -use tokio::sync::mpsc; use uuid::Uuid; use crate::auth::{Auth, Authenticated}; use crate::connection::Connection; use crate::database::{Database, PrimaryConnection}; use crate::namespace::{NamespaceStore, PrimaryNamespaceMaker}; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; use crate::replication::FrameNo; -use self::rpc::exec_message::Request; -use self::rpc::message::Payload; use self::rpc::proxy_server::Proxy; use self::rpc::query_result::RowResult; use self::rpc::{ - describe_result, Ack, DescribeRequest, DescribeResult, Description, DisconnectMessage, - ExecMessage, ExecResponse, ExecuteResults, Message, QueryResult, ResultRows, Row, + describe_result, Ack, DescribeRequest, DescribeResult, Description, DisconnectMessage, ExecReq, + ExecuteResults, QueryResult, ResultRows, Row, }; +use super::streaming_exec::StreamRequestHandler; use super::NAMESPACE_DOESNT_EXIST; pub mod rpc { @@ -40,7 +35,7 @@ pub mod rpc { use crate::query_analysis::Statement; use crate::{connection, error::Error as SqldError}; - use self::{error::ErrorCode, execute_results::State}; + use self::error::ErrorCode; tonic::include_proto!("proxy"); impl From for Error { @@ -63,22 +58,22 @@ pub mod rpc { } } - impl From for State { - fn from(other: crate::query_analysis::State) -> Self { + impl From for State { + fn from(other: crate::query_analysis::TxnStatus) -> Self { match other { - crate::query_analysis::State::Txn => Self::Txn, - crate::query_analysis::State::Init => Self::Init, - crate::query_analysis::State::Invalid => Self::Invalid, + crate::query_analysis::TxnStatus::Txn => Self::Txn, + crate::query_analysis::TxnStatus::Init => Self::Init, + crate::query_analysis::TxnStatus::Invalid => Self::Invalid, } } } - impl From for crate::query_analysis::State { + impl From for crate::query_analysis::TxnStatus { fn from(other: State) -> Self { match other { - State::Txn => crate::query_analysis::State::Txn, - State::Init => crate::query_analysis::State::Init, - State::Invalid => crate::query_analysis::State::Invalid, + State::Txn => crate::query_analysis::TxnStatus::Txn, + State::Init => crate::query_analysis::TxnStatus::Init, + State::Invalid => crate::query_analysis::TxnStatus::Invalid, } } } @@ -405,10 +400,7 @@ impl QueryResultBuilder for ExecuteResultBuilder { Ok(()) } - fn add_row_value( - &mut self, - v: ValueRef, - ) -> Result<(), QueryResultBuilderError> { + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { let data = bincode::serialize( &crate::query::Value::try_from(v).map_err(QueryResultBuilderError::from_any)?, ) @@ -443,7 +435,11 @@ impl QueryResultBuilder for ExecuteResultBuilder { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { Ok(()) } @@ -463,247 +459,22 @@ pub async fn garbage_collect(clients: &mut HashMap> tracing::trace!("gc: remaining client handles count: {}", clients.len()); } -pin_project_lite::pin_project! { - pub struct StreamRequestHandler { - #[pin] - request_stream: S, - connection: Arc>, - state: HandlerState, - authenticated: Authenticated, - } -} - -struct StreamResponseBuilder { - request_id: u32, - sender: mpsc::Sender, - current: Option, -} - -impl StreamResponseBuilder { - fn current(&mut self) -> &mut ExecResponse { - self.current.get_or_insert_with(|| ExecResponse { - messages: Vec::new(), - request_id: self.request_id, - }) - } - - fn push(&mut self, payload: Payload) { - const MAX_RESPONSE_MESSAGES: usize = 10; - - let current = self.current(); - current.messages.push(Message { - payload: Some(payload), - }); - - if current.messages.len() > MAX_RESPONSE_MESSAGES { - self.flush() - } - } - - fn flush(&mut self) { - if let Some(current) = self.current.take() { - self.sender.blocking_send(current).unwrap(); - } - } -} - -impl QueryResultBuilder for StreamResponseBuilder { - type Ret = (); - - fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { - self.push(Payload::Init(rpc::Init {})); - Ok(()) - } - - fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::BeginStep(rpc::BeginStep {})); - Ok(()) - } - - fn finish_step( - &mut self, - affected_row_count: u64, - last_insert_rowid: Option, - ) -> Result<(), QueryResultBuilderError> { - self.push(Payload::FinishStep(rpc::FinishStep { - affected_row_count, - last_insert_rowid, - })); - Ok(()) - } - - fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { - self.push(Payload::StepError(rpc::StepError { - error: error.to_string(), - })); - Ok(()) - } - - fn cols_description<'a>( - &mut self, - cols: impl IntoIterator>>, - ) -> Result<(), QueryResultBuilderError> { - self.push(Payload::ColsDescription(rpc::ColsDescription { - columns: cols - .into_iter() - .map(Into::into) - .map(|c| rpc::Column { - name: c.name.into(), - decltype: c.decl_ty.map(Into::into), - }) - .collect::>(), - })); - Ok(()) - } - - fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::BeginRows(rpc::BeginRows {})); - Ok(()) - } - - fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::BeginRow(rpc::BeginRow {})); - Ok(()) - } - - fn add_row_value( - &mut self, - v: ValueRef, - ) -> Result<(), QueryResultBuilderError> { - let data = bincode::serialize( - &crate::query::Value::try_from(v).map_err(QueryResultBuilderError::from_any)?, - ) - .map_err(QueryResultBuilderError::from_any)?; - - let val = Some(rpc::Value { data }); - - self.push(Payload::AddRowValue(rpc::AddRowValue { val })); - Ok(()) - } - - fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::FinishRow(rpc::FinishRow {})); - Ok(()) - } - - fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::FinishRows(rpc::FinishRows {})); - Ok(()) - } - - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError> { - self.push(Payload::Finish(rpc::Finish { last_frame_no })); - self.flush(); - Ok(()) - } - - fn into_ret(self) -> Self::Ret { - () - } -} - -enum HandlerState { - Execute(Pin + Send>>), - Idle, - Fused, -} - -impl Stream for StreamRequestHandler -where - S: Stream>, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - match this.state { - HandlerState::Idle => { - match ready!(this.request_stream.poll_next(cx)) { - Some(Err(e)) => { - *this.state = HandlerState::Fused; - return Poll::Ready(Some(Err(e))); - } - Some(Ok(req)) => { - let request_id = req.request_id; - match req.request.unwrap() { - Request::Execute(exec) => { - let pgm = crate::connection::program::Program::try_from( - exec.pgm.unwrap(), - ) - .unwrap(); - let conn = this.connection.clone(); - let authenticated = this.authenticated.clone(); - - let s = async_stream::stream! { - let (sender, mut receiver) = mpsc::channel(1); - let builder = StreamResponseBuilder { - request_id, - sender, - current: None, - }; - let mut fut = conn.execute_program(pgm, authenticated, builder, None); - loop { - tokio::select! { - _res = &mut fut => { - // todo check result? - break - } - msg = receiver.recv() => { - if let Some(msg) = msg { - yield msg; - } - } - } - } - }; - *this.state = HandlerState::Execute(Box::pin(s)); - } - Request::Describe(_) => todo!(), - } - // we have placed the request, poll immediately - cx.waker().wake_by_ref(); - return Poll::Pending; - } - None => { - // this would easier if tokio_stream re-exported combinators - *this.state = HandlerState::Fused; - Poll::Ready(None) - } - } - } - HandlerState::Fused => Poll::Ready(None), - HandlerState::Execute(stream) => { - let resp = ready!(stream.poll_next_unpin(cx)); - match resp { - Some(resp) => return Poll::Ready(Some(Ok(resp))), - None => { - // finished processing this query. Wake up immediately to prepare for the - // next - *this.state = HandlerState::Idle; - cx.waker().wake_by_ref(); - return Poll::Pending; - } - } - } - } - } -} - #[tonic::async_trait] impl Proxy for ProxyService { - type StreamExecStream = StreamRequestHandler>; + type StreamExecStream = StreamRequestHandler>; async fn stream_exec( &self, - req: tonic::Request>, + req: tonic::Request>, ) -> Result, tonic::Status> { + dbg!(); let authenticated = if let Some(auth) = &self.auth { auth.authenticate_grpc(&req, self.disable_namespaces)? } else { Authenticated::from_proxy_grpc_request(&req, self.disable_namespaces)? }; + dbg!(); let namespace = super::extract_namespace(self.disable_namespaces, &req)?; let (connection_maker, _new_frame_notifier) = self .namespaces @@ -721,14 +492,11 @@ impl Proxy for ProxyService { } })?; + dbg!(); let connection = connection_maker.create().await.unwrap(); - let handler = StreamRequestHandler { - authenticated, - request_stream: req.into_inner(), - connection: connection.into(), - state: HandlerState::Idle, - }; + dbg!(); + let handler = StreamRequestHandler::new(req.into_inner(), connection, authenticated); Ok(tonic::Response::new(handler)) } @@ -794,7 +562,7 @@ impl Proxy for ProxyService { Ok(tonic::Response::new(ExecuteResults { current_frame_no, results: results.into_ret(), - state: rpc::execute_results::State::from(state).into(), + state: rpc::State::from(state).into(), })) } diff --git a/sqld/src/rpc/replica_proxy.rs b/sqld/src/rpc/replica_proxy.rs index 08fbbaf5..a5ee1602 100644 --- a/sqld/src/rpc/replica_proxy.rs +++ b/sqld/src/rpc/replica_proxy.rs @@ -8,7 +8,7 @@ use crate::auth::Auth; use super::proxy::rpc::{ self, proxy_client::ProxyClient, proxy_server::Proxy, Ack, DescribeRequest, DescribeResult, - DisconnectMessage, ExecMessage, ExecResponse, ExecuteResults, + DisconnectMessage, ExecReq, ExecResp, ExecuteResults, }; pub struct ReplicaProxyService { @@ -33,11 +33,11 @@ impl ReplicaProxyService { #[tonic::async_trait] impl Proxy for ReplicaProxyService { - type StreamExecStream = tonic::codec::Streaming; + type StreamExecStream = tonic::codec::Streaming; async fn stream_exec( &self, - req: tonic::Request>, + req: tonic::Request>, ) -> Result, tonic::Status> { let (meta, ext, stream) = req.into_parts(); let mut req = tonic::Request::from_parts(meta, ext, stream.map(|r| r.unwrap())); // TODO: handle mapping error diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs new file mode 100644 index 00000000..3908ff8f --- /dev/null +++ b/sqld/src/rpc/streaming_exec.rs @@ -0,0 +1,283 @@ +use std::pin::Pin; +use std::sync::Arc; +use std::task::{ready, Context, Poll}; + +use futures_core::Stream; +use rusqlite::types::ValueRef; +use tokio::sync::mpsc; + +use crate::auth::Authenticated; +use crate::connection::Connection; +use crate::database::PrimaryConnection; +use crate::query_analysis::TxnStatus; +use crate::query_result_builder::{ + Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, +}; +use crate::replication::FrameNo; +use crate::rpc::proxy::rpc::{exec_req::Request, Message}; + +use super::proxy::rpc::{self, message::Payload, ExecReq, ExecResp}; + +pin_project_lite::pin_project! { + pub struct StreamRequestHandler { + #[pin] + request_stream: S, + connection: Arc, + state: HandlerState, + authenticated: Authenticated, + } +} + +impl StreamRequestHandler { + pub fn new( + request_stream: S, + connection: PrimaryConnection, + authenticated: Authenticated, + ) -> Self { + Self { + request_stream, + connection: connection.into(), + state: HandlerState::Idle, + authenticated, + } + } +} + +struct StreamResponseBuilder { + request_id: u32, + sender: mpsc::Sender, + current: Option, +} + +impl StreamResponseBuilder { + fn current(&mut self) -> &mut ExecResp { + self.current.get_or_insert_with(|| ExecResp { + messages: Vec::new(), + request_id: self.request_id, + }) + } + + fn push(&mut self, payload: Payload) { + const MAX_RESPONSE_MESSAGES: usize = 10; + + let current = self.current(); + current.messages.push(Message { + payload: Some(payload), + }); + + if current.messages.len() > MAX_RESPONSE_MESSAGES { + self.flush() + } + } + + fn flush(&mut self) { + if let Some(current) = self.current.take() { + self.sender.blocking_send(current).unwrap(); + } + } +} + +impl QueryResultBuilder for StreamResponseBuilder { + type Ret = (); + + fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + self.push(Payload::Init(rpc::Init {})); + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Payload::BeginStep(rpc::BeginStep {})); + Ok(()) + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.push(Payload::FinishStep(rpc::FinishStep { + affected_row_count, + last_insert_rowid, + })); + Ok(()) + } + + fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { + self.push(Payload::StepError(rpc::StepError { + error: Some(error.into()), + })); + Ok(()) + } + + fn cols_description<'a>( + &mut self, + cols: impl IntoIterator>>, + ) -> Result<(), QueryResultBuilderError> { + self.push(Payload::ColsDescription(rpc::ColsDescription { + columns: cols + .into_iter() + .map(Into::into) + .map(|c| rpc::Column { + name: c.name.into(), + decltype: c.decl_ty.map(Into::into), + }) + .collect::>(), + })); + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Payload::BeginRows(rpc::BeginRows {})); + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Payload::BeginRow(rpc::BeginRow {})); + Ok(()) + } + + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { + let data = bincode::serialize( + &crate::query::Value::try_from(v).map_err(QueryResultBuilderError::from_any)?, + ) + .map_err(QueryResultBuilderError::from_any)?; + + let val = Some(rpc::Value { data }); + + self.push(Payload::AddRowValue(rpc::AddRowValue { val })); + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Payload::FinishRow(rpc::FinishRow {})); + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Payload::FinishRows(rpc::FinishRows {})); + Ok(()) + } + + fn finish( + &mut self, + last_frame_no: Option, + state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { + self.push(Payload::Finish(rpc::Finish { + last_frame_no, + state: rpc::State::from(state).into(), + })); + self.flush(); + Ok(()) + } + + fn into_ret(self) -> Self::Ret { + () + } +} + +enum HandlerState { + Execute(Pin + Send>>), + Idle, + Fused, +} + +impl Stream for StreamRequestHandler +where + S: Stream>, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + dbg!(); + match this.state { + HandlerState::Idle => { + dbg!(); + match ready!(this.request_stream.poll_next(cx)) { + Some(Err(e)) => { + dbg!(); + *this.state = HandlerState::Fused; + return Poll::Ready(Some(Err(e))); + } + Some(Ok(req)) => { + dbg!(&req); + let request_id = req.request_id; + match req.request.unwrap() { + Request::Execute(pgm) => { + dbg!(); + let pgm = + crate::connection::program::Program::try_from(pgm).unwrap(); + let conn = this.connection.clone(); + let authenticated = this.authenticated.clone(); + dbg!(); + + let s = async_stream::stream! { + let (sender, mut receiver) = mpsc::channel(1); + let builder = StreamResponseBuilder { + request_id, + sender, + current: None, + }; + let mut fut = conn.execute_program(pgm, authenticated, builder, None); + loop { + tokio::select! { + _res = &mut fut => { + dbg!(); + if let Err(e) = _res { + dbg!(e); + } + + // drain the receiver + while let Ok(msg) = receiver.try_recv() { + yield msg; + } + // todo check result? + break + } + msg = receiver.recv() => { + dbg!(&msg); + if let Some(msg) = msg { + dbg!(); + yield msg; + } + } + } + } + }; + dbg!(); + *this.state = HandlerState::Execute(Box::pin(s)); + } + Request::Describe(_) => todo!(), + } + // we have placed the request, poll immediately + cx.waker().wake_by_ref(); + return Poll::Pending; + } + None => { + // this would easier if tokio_stream re-exported combinators + *this.state = HandlerState::Fused; + Poll::Ready(None) + } + } + } + HandlerState::Fused => Poll::Ready(None), + HandlerState::Execute(stream) => { + dbg!(); + let resp = ready!(stream.as_mut().poll_next(cx)); + match resp { + Some(resp) => return Poll::Ready(Some(Ok(dbg!(resp)))), + None => { + dbg!(); + // finished processing this query. Wake up immediately to prepare for the + // next + *this.state = HandlerState::Idle; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + } + } + } +} From 1b44f2404323fce7a7cdf452a2f23dca3635827f Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 4 Oct 2023 18:03:33 +0200 Subject: [PATCH 06/17] error handling --- sqld/src/connection/libsql.rs | 12 ++-- sqld/src/connection/write_proxy.rs | 76 ++++++++++++++++++---- sqld/src/error.rs | 2 + sqld/src/query_result_builder.rs | 7 +- sqld/src/rpc/streaming_exec.rs | 101 +++++++++++++++-------------- 5 files changed, 131 insertions(+), 67 deletions(-) diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs index 9f47bcae..17f7ed64 100644 --- a/sqld/src/connection/libsql.rs +++ b/sqld/src/connection/libsql.rs @@ -858,7 +858,7 @@ mod test { TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(state, TxnStatus::Txn); tokio::time::advance(TXN_TIMEOUT * 2).await; @@ -868,7 +868,7 @@ mod test { TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Init); + assert_eq!(state, TxnStatus::Init); assert!(matches!(builder.into_ret()[0], Err(Error::LibSqlTxTimeout))); } @@ -902,7 +902,7 @@ mod test { TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(state, TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); }); } @@ -945,7 +945,7 @@ mod test { TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(state, TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); } }) @@ -963,7 +963,7 @@ mod test { TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(state, TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); before.elapsed() } @@ -978,7 +978,7 @@ mod test { let (builder, state) = Connection::run(conn, Program::seq(&["COMMIT"]), TestBuilder::default()) .unwrap(); - assert_eq!(state, State::Init); + assert_eq!(state, TxnStatus::Init); assert!(builder.into_ret()[0].is_ok()); } }) diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index 8a1f678c..de6b8c92 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -187,11 +187,22 @@ impl WriteProxyConnection { ) -> Result<(B, TxnStatus)> { self.stats.inc_write_requests_delegated(); *status = TxnStatus::Invalid; - let (builder, new_status, new_frame_no) = self + let res = self .with_remote_conn(auth, self.builder_config, |conn| { Box::pin(conn.execute(pgm, builder)) }) - .await?; + .await; + + let (builder, new_status, new_frame_no) = match res { + Ok(res) => res, + Err(e @ Error::StreamDisconnect) => { + // drop the connection + self.remote_conn.lock().await.take(); + return Err(e); + } + Err(e) => return Err(e), + }; + *status = new_status; if let Some(current_frame_no) = new_frame_no { self.update_last_write_frame_no(current_frame_no); @@ -251,9 +262,7 @@ impl RemoteConnection { req.metadata_mut() .insert_bin(NAMESPACE_METADATA_KEY, namespace); auth.upgrade_grpc_request(&mut req); - dbg!(); let response_stream = client.stream_exec(req).await.unwrap().into_inner(); - dbg!(); Ok(Self { response_stream, @@ -276,14 +285,11 @@ impl RemoteConnection { request: Some(rpc::exec_req::Request::Execute(program.into())), }; - dbg!(); self.request_sender.send(req).await.unwrap(); // TODO: the stream was close! - dbg!(); let mut txn_status = TxnStatus::Invalid; let mut new_frame_no = None; 'outer: while let Some(resp) = self.response_stream.next().await { - dbg!(&resp); match resp { Ok(resp) => { if resp.request_id != request_id { @@ -314,7 +320,6 @@ impl RemoteConnection { Payload::BeginRow(_) => builder.begin_row()?, Payload::AddRowValue(AddRowValue { val }) => { let value: Value = bincode::deserialize(&val.unwrap().data) - // something is wrong, better stop right here .map_err(QueryResultBuilderError::from_any)?; builder.add_row_value(ValueRef::from(&value))?; } @@ -324,7 +329,6 @@ impl RemoteConnection { txn_status = TxnStatus::from(f.state()); new_frame_no = last_frame_no; builder.finish(last_frame_no, txn_status)?; - dbg!(); break 'outer; } Payload::Error(error) => { @@ -333,7 +337,10 @@ impl RemoteConnection { } } } - Err(_e) => todo!("handle stream error"), + Err(e) => { + tracing::error!("received error from connection stream: {e}"); + return Err(Error::StreamDisconnect) + }, } } @@ -426,7 +433,7 @@ pub mod test { use rand::Fill; use super::*; - use crate::query_result_builder::test::test_driver; + use crate::{query_result_builder::test::test_driver, rpc::proxy::rpc::{ExecuteResults, query_result::RowResult}}; /// generate an arbitraty rpc value. see build.rs for usage. pub fn arbitrary_rpc_value(u: &mut Unstructured) -> arbitrary::Result> { @@ -442,10 +449,55 @@ pub mod test { Ok(v.into()) } + fn execute_results_to_builder( + execute_result: ExecuteResults, + mut builder: B, + config: &QueryBuilderConfig, + ) -> Result { + builder.init(config)?; + for result in execute_result.results { + match result.row_result { + Some(RowResult::Row(rows)) => { + builder.begin_step()?; + builder.cols_description(rows.column_descriptions.iter().map(|c| Column { + name: &c.name, + decl_ty: c.decltype.as_deref(), + }))?; + + builder.begin_rows()?; + for row in rows.rows { + builder.begin_row()?; + for value in row.values { + let value: Value = bincode::deserialize(&value.data) + // something is wrong, better stop right here + .map_err(QueryResultBuilderError::from_any)?; + builder.add_row_value(ValueRef::from(&value))?; + } + builder.finish_row()?; + } + + builder.finish_rows()?; + + builder.finish_step(rows.affected_row_count, rows.last_insert_rowid)?; + } + Some(RowResult::Error(err)) => { + builder.begin_step()?; + builder.step_error(Error::RpcQueryError(err))?; + builder.finish_step(0, None)?; + } + None => (), + } + } + + builder.finish(execute_result.current_frame_no, TxnStatus::Init)?; + + Ok(builder) + } + /// In this test, we generate random ExecuteResults, and ensures that the `execute_results_to_builder` drives the builder FSM correctly. #[test] fn test_execute_results_to_builder() { - test_driver(1000, |b| { + test_driver(1000, |b| -> std::result::Result { let mut data = [0; 10_000]; data.try_fill(&mut rand::thread_rng()).unwrap(); let mut un = Unstructured::new(&data); diff --git a/sqld/src/error.rs b/sqld/src/error.rs index 66ca58dc..558b16e8 100644 --- a/sqld/src/error.rs +++ b/sqld/src/error.rs @@ -79,6 +79,8 @@ pub enum Error { ConflictingRestoreParameters, #[error("failed to fork database: {0}")] Fork(#[from] ForkError), + #[error("Connection with primary broken")] + StreamDisconnect, } trait ResponseError: std::error::Error { diff --git a/sqld/src/query_result_builder.rs b/sqld/src/query_result_builder.rs index 03cc219b..e7698407 100644 --- a/sqld/src/query_result_builder.rs +++ b/sqld/src/query_result_builder.rs @@ -15,6 +15,7 @@ pub static TOTAL_RESPONSE_SIZE: AtomicUsize = AtomicUsize::new(0); #[derive(Debug)] pub enum QueryResultBuilderError { + /// The response payload is too large ResponseTooLarge(u64), Internal(anyhow::Error), } @@ -616,6 +617,7 @@ pub mod test { fn finish( &mut self, _last_frame_no: Option, + _txn_status: TxnStatus, ) -> Result<(), QueryResultBuilderError> { Ok(()) } @@ -751,7 +753,7 @@ pub mod test { FinishRow => b.finish_row().unwrap(), FinishRows => b.finish_rows().unwrap(), Finish => { - b.finish(Some(0)).unwrap(); + b.finish(Some(0), TxnStatus::Init).unwrap(); break; } BuilderError => return b, @@ -899,6 +901,7 @@ pub mod test { fn finish( &mut self, _last_frame_no: Option, + _txn_status: TxnStatus, ) -> Result<(), QueryResultBuilderError> { self.maybe_inject_error()?; self.transition(Finish) @@ -947,7 +950,7 @@ pub mod test { builder.finish_rows().unwrap(); builder.finish_step(0, None).unwrap(); - builder.finish(Some(0)).unwrap(); + builder.finish(Some(0), TxnStatus::Init).unwrap(); } #[test] diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs index 3908ff8f..ab718d46 100644 --- a/sqld/src/rpc/streaming_exec.rs +++ b/sqld/src/rpc/streaming_exec.rs @@ -5,6 +5,7 @@ use std::task::{ready, Context, Poll}; use futures_core::Stream; use rusqlite::types::ValueRef; use tokio::sync::mpsc; +use tonic::{Code, Status}; use crate::auth::Authenticated; use crate::connection::Connection; @@ -23,7 +24,7 @@ pin_project_lite::pin_project! { #[pin] request_stream: S, connection: Arc, - state: HandlerState, + state: State, authenticated: Authenticated, } } @@ -37,7 +38,7 @@ impl StreamRequestHandler { Self { request_stream, connection: connection.into(), - state: HandlerState::Idle, + state: State::Idle, authenticated, } } @@ -57,7 +58,7 @@ impl StreamResponseBuilder { }) } - fn push(&mut self, payload: Payload) { + fn push(&mut self, payload: Payload) -> Result<(), QueryResultBuilderError> { const MAX_RESPONSE_MESSAGES: usize = 10; let current = self.current(); @@ -66,14 +67,19 @@ impl StreamResponseBuilder { }); if current.messages.len() > MAX_RESPONSE_MESSAGES { - self.flush() + self.flush()?; } + + Ok(()) } - fn flush(&mut self) { + fn flush(&mut self) -> Result<(), QueryResultBuilderError> { if let Some(current) = self.current.take() { - self.sender.blocking_send(current).unwrap(); + self.sender.blocking_send(current) + .map_err(|_| QueryResultBuilderError::Internal(anyhow::anyhow!("stream closed")))?; } + + Ok(()) } } @@ -81,12 +87,12 @@ impl QueryResultBuilder for StreamResponseBuilder { type Ret = (); fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { - self.push(Payload::Init(rpc::Init {})); + self.push(Payload::Init(rpc::Init {}))?; Ok(()) } fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::BeginStep(rpc::BeginStep {})); + self.push(Payload::BeginStep(rpc::BeginStep {}))?; Ok(()) } @@ -98,14 +104,14 @@ impl QueryResultBuilder for StreamResponseBuilder { self.push(Payload::FinishStep(rpc::FinishStep { affected_row_count, last_insert_rowid, - })); + }))?; Ok(()) } fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { self.push(Payload::StepError(rpc::StepError { error: Some(error.into()), - })); + }))?; Ok(()) } @@ -122,17 +128,17 @@ impl QueryResultBuilder for StreamResponseBuilder { decltype: c.decl_ty.map(Into::into), }) .collect::>(), - })); + }))?; Ok(()) } fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::BeginRows(rpc::BeginRows {})); + self.push(Payload::BeginRows(rpc::BeginRows {}))?; Ok(()) } fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::BeginRow(rpc::BeginRow {})); + self.push(Payload::BeginRow(rpc::BeginRow {}))?; Ok(()) } @@ -144,17 +150,17 @@ impl QueryResultBuilder for StreamResponseBuilder { let val = Some(rpc::Value { data }); - self.push(Payload::AddRowValue(rpc::AddRowValue { val })); + self.push(Payload::AddRowValue(rpc::AddRowValue { val }))?; Ok(()) } fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::FinishRow(rpc::FinishRow {})); + self.push(Payload::FinishRow(rpc::FinishRow {}))?; Ok(()) } fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::FinishRows(rpc::FinishRows {})); + self.push(Payload::FinishRows(rpc::FinishRows {}))?; Ok(()) } @@ -166,8 +172,8 @@ impl QueryResultBuilder for StreamResponseBuilder { self.push(Payload::Finish(rpc::Finish { last_frame_no, state: rpc::State::from(state).into(), - })); - self.flush(); + }))?; + self.flush()?; Ok(()) } @@ -176,7 +182,7 @@ impl QueryResultBuilder for StreamResponseBuilder { } } -enum HandlerState { +enum State { Execute(Pin + Send>>), Idle, Fused, @@ -184,34 +190,31 @@ enum HandlerState { impl Stream for StreamRequestHandler where - S: Stream>, + S: Stream>, { - type Item = Result; + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - dbg!(); match this.state { - HandlerState::Idle => { - dbg!(); + State::Idle => { match ready!(this.request_stream.poll_next(cx)) { Some(Err(e)) => { - dbg!(); - *this.state = HandlerState::Fused; + *this.state = State::Fused; return Poll::Ready(Some(Err(e))); } Some(Ok(req)) => { - dbg!(&req); let request_id = req.request_id; - match req.request.unwrap() { - Request::Execute(pgm) => { - dbg!(); - let pgm = - crate::connection::program::Program::try_from(pgm).unwrap(); + match req.request { + Some(Request::Execute(pgm)) => { + let Ok(pgm) = + crate::connection::program::Program::try_from(pgm) else { + *this.state = State::Fused; + return Poll::Ready(Some(Err(Status::new(Code::InvalidArgument, "invalid program")))); + }; let conn = this.connection.clone(); let authenticated = this.authenticated.clone(); - dbg!(); let s = async_stream::stream! { let (sender, mut receiver) = mpsc::channel(1); @@ -223,23 +226,23 @@ where let mut fut = conn.execute_program(pgm, authenticated, builder, None); loop { tokio::select! { - _res = &mut fut => { - dbg!(); - if let Err(e) = _res { - dbg!(e); - } - + res = &mut fut => { // drain the receiver while let Ok(msg) = receiver.try_recv() { yield msg; } + + if let Err(e) = res { + yield ExecResp { + request_id, + messages: vec![rpc::Message { payload: Some(Payload::Error(e.into()))}], + } + } // todo check result? break } msg = receiver.recv() => { - dbg!(&msg); if let Some(msg) = msg { - dbg!(); yield msg; } } @@ -247,9 +250,13 @@ where } }; dbg!(); - *this.state = HandlerState::Execute(Box::pin(s)); + *this.state = State::Execute(Box::pin(s)); + } + Some(Request::Describe(_)) => todo!(), + None => { + *this.state = State::Fused; + return Poll::Ready(Some(Err(Status::new(Code::InvalidArgument, "invalid ExecReq: missing request")))); } - Request::Describe(_) => todo!(), } // we have placed the request, poll immediately cx.waker().wake_by_ref(); @@ -257,13 +264,13 @@ where } None => { // this would easier if tokio_stream re-exported combinators - *this.state = HandlerState::Fused; + *this.state = State::Fused; Poll::Ready(None) } } } - HandlerState::Fused => Poll::Ready(None), - HandlerState::Execute(stream) => { + State::Fused => Poll::Ready(None), + State::Execute(stream) => { dbg!(); let resp = ready!(stream.as_mut().poll_next(cx)); match resp { @@ -272,7 +279,7 @@ where dbg!(); // finished processing this query. Wake up immediately to prepare for the // next - *this.state = HandlerState::Idle; + *this.state = State::Idle; cx.waker().wake_by_ref(); return Poll::Pending; } From 466c8ba86ddee80d2e82f2368bfcb2fa9a0ad20d Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 5 Oct 2023 17:18:44 +0200 Subject: [PATCH 07/17] remove txn state from execute return It is now passed to the result builder --- sqld/src/connection/libsql.rs | 89 ++++++++++++++++++------------ sqld/src/connection/mod.rs | 16 +++--- sqld/src/connection/write_proxy.rs | 40 +++++++++----- sqld/src/error.rs | 1 + sqld/src/hrana/batch.rs | 4 +- sqld/src/hrana/stmt.rs | 2 +- sqld/src/http/user/mod.rs | 2 +- sqld/src/rpc/proxy.rs | 37 ++++++------- sqld/src/rpc/streaming_exec.rs | 8 ++- 9 files changed, 115 insertions(+), 84 deletions(-) diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs index 17f7ed64..108a486d 100644 --- a/sqld/src/connection/libsql.rs +++ b/sqld/src/connection/libsql.rs @@ -4,7 +4,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use parking_lot::{Mutex, RwLock}; -use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus}; +use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus, TransactionState}; use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalMethodsHook}; use tokio::sync::{watch, Notify}; use tokio::time::{Duration, Instant}; @@ -144,7 +144,6 @@ where } } -#[derive(Clone)] pub struct LibSqlConnection { inner: Arc>>, } @@ -160,6 +159,12 @@ impl std::fmt::Debug for LibSqlConnection { } } +impl Clone for LibSqlConnection { + fn clone(&self) -> Self { + Self { inner: self.inner.clone() } + } +} + pub fn open_conn( path: &Path, wal_methods: &'static WalMethodsHook, @@ -219,6 +224,15 @@ where inner: Arc::new(Mutex::new(conn)), }) } + + pub fn txn_status(&self) -> crate::Result { + Ok(self + .inner + .lock() + .conn + .transaction_state(Some(DatabaseName::Main))? + .into()) + } } struct Connection { @@ -351,6 +365,16 @@ unsafe extern "C" fn busy_handler(state: *mut c_void, _retries: c_in }) } +impl From for TxnStatus { + fn from(value: TransactionState) -> Self { + use TransactionState as Tx; + match value { + Tx::None => TxnStatus::Init, + Tx::Read | Tx::Write => TxnStatus::Txn, + _ => unreachable!(), + } + } +} impl Connection { fn new( path: &Path, @@ -405,7 +429,7 @@ impl Connection { this: Arc>, pgm: Program, mut builder: B, - ) -> Result<(B, TxnStatus)> { + ) -> Result { use rusqlite::TransactionState as Tx; let state = this.lock().state.clone(); @@ -469,23 +493,18 @@ impl Connection { results.push(res); } - let status = if matches!( - this.lock() - .conn - .transaction_state(Some(DatabaseName::Main))?, - Tx::Read | Tx::Write - ) { - TxnStatus::Txn - } else { - TxnStatus::Init - }; + let status = this + .lock() + .conn + .transaction_state(Some(DatabaseName::Main))? + .into(); builder.finish( *this.lock().current_frame_no_receiver.borrow_and_update(), status, )?; - Ok((builder, status)) + Ok(builder) } fn execute_step( @@ -736,7 +755,7 @@ where auth: Authenticated, builder: B, _replication_index: Option, - ) -> Result<(B, TxnStatus)> { + ) -> Result { check_program_auth(auth, &pgm)?; let conn = self.inner.clone(); tokio::task::spawn_blocking(move || Connection::run(conn, pgm, builder)) @@ -828,7 +847,7 @@ mod test { fn test_libsql_conn_builder_driver() { test_driver(1000, |b| { let conn = setup_test_conn(); - Connection::run(conn, Program::seq(&["select * from test"]), b).map(|x| x.0) + Connection::run(conn, Program::seq(&["select * from test"]), b) }) } @@ -852,23 +871,23 @@ mod test { tokio::time::pause(); let conn = make_conn.make_connection().await.unwrap(); - let (_builder, state) = Connection::run( + let _builder = Connection::run( conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, TxnStatus::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); tokio::time::advance(TXN_TIMEOUT * 2).await; - let (builder, state) = Connection::run( + let builder = Connection::run( conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, TxnStatus::Init); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Init); assert!(matches!(builder.into_ret()[0], Err(Error::LibSqlTxTimeout))); } @@ -896,13 +915,13 @@ mod test { for _ in 0..10 { let conn = make_conn.make_connection().await.unwrap(); set.spawn_blocking(move || { - let (builder, state) = Connection::run( - conn.inner, + let builder = Connection::run( + conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, TxnStatus::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); }); } @@ -937,15 +956,15 @@ mod test { let conn1 = make_conn.make_connection().await.unwrap(); tokio::task::spawn_blocking({ - let conn = conn1.inner.clone(); + let conn = conn1.clone(); move || { - let (builder, state) = Connection::run( - conn, + let builder = Connection::run( + conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, TxnStatus::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); } }) @@ -954,16 +973,16 @@ mod test { let conn2 = make_conn.make_connection().await.unwrap(); let handle = tokio::task::spawn_blocking({ - let conn = conn2.inner.clone(); + let conn = conn2.clone(); move || { let before = Instant::now(); - let (builder, state) = Connection::run( - conn, + let builder = Connection::run( + conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, TxnStatus::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); before.elapsed() } @@ -973,12 +992,12 @@ mod test { tokio::time::sleep(wait_time).await; tokio::task::spawn_blocking({ - let conn = conn1.inner.clone(); + let conn = conn1.clone(); move || { - let (builder, state) = - Connection::run(conn, Program::seq(&["COMMIT"]), TestBuilder::default()) + let builder = + Connection::run(conn.inner.clone(), Program::seq(&["COMMIT"]), TestBuilder::default()) .unwrap(); - assert_eq!(state, TxnStatus::Init); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Init); assert!(builder.into_ret()[0].is_ok()); } }) diff --git a/sqld/src/connection/mod.rs b/sqld/src/connection/mod.rs index 1c24a142..07c853dc 100644 --- a/sqld/src/connection/mod.rs +++ b/sqld/src/connection/mod.rs @@ -8,7 +8,7 @@ use tokio::{sync::Semaphore, time::timeout}; use crate::auth::Authenticated; use crate::error::Error; use crate::query::{Params, Query}; -use crate::query_analysis::{Statement, TxnStatus}; +use crate::query_analysis::Statement; use crate::query_result_builder::{IgnoreResult, QueryResultBuilder}; use crate::replication::FrameNo; use crate::Result; @@ -32,7 +32,7 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, response_builder: B, replication_index: Option, - ) -> Result<(B, TxnStatus)>; + ) -> Result; /// Execute all the queries in the batch sequentially. /// If an query in the batch fails, the remaining queries are ignores, and the batch current @@ -43,7 +43,7 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, result_builder: B, replication_index: Option, - ) -> Result<(B, TxnStatus)> { + ) -> Result { let batch_len = batch.len(); let mut steps = make_batch_program(batch); @@ -67,11 +67,11 @@ pub trait Connection: Send + Sync + 'static { // ignore the rollback result let builder = result_builder.take(batch_len); - let (builder, state) = self + let builder = self .execute_program(pgm, auth, builder, replication_index) .await?; - Ok((builder.into_inner(), state)) + Ok(builder.into_inner()) } /// Execute all the queries in the batch sequentially. @@ -82,7 +82,7 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, result_builder: B, replication_index: Option, - ) -> Result<(B, TxnStatus)> { + ) -> Result { let steps = make_batch_program(batch); let pgm = Program::new(steps); self.execute_program(pgm, auth, result_builder, replication_index) @@ -312,7 +312,7 @@ impl Connection for TrackedConnection { auth: Authenticated, builder: B, replication_index: Option, - ) -> crate::Result<(B, TxnStatus)> { + ) -> crate::Result { self.atime.store(now_millis(), Ordering::Relaxed); self.inner .execute_program(pgm, auth, builder, replication_index) @@ -367,7 +367,7 @@ mod test { _auth: Authenticated, _builder: B, _replication_index: Option, - ) -> crate::Result<(B, TxnStatus)> { + ) -> crate::Result { unreachable!() } diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index de6b8c92..fd9928f5 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -184,7 +184,7 @@ impl WriteProxyConnection { status: &mut TxnStatus, auth: Authenticated, builder: B, - ) -> Result<(B, TxnStatus)> { + ) -> Result { self.stats.inc_write_requests_delegated(); *status = TxnStatus::Invalid; let res = self @@ -208,7 +208,7 @@ impl WriteProxyConnection { self.update_last_write_frame_no(current_frame_no); } - Ok((builder, new_status)) + Ok(builder) } fn update_last_write_frame_no(&self, new_frame_no: FrameNo) { @@ -339,8 +339,8 @@ impl RemoteConnection { } Err(e) => { tracing::error!("received error from connection stream: {e}"); - return Err(Error::StreamDisconnect) - }, + return Err(Error::StreamDisconnect); + } } } @@ -356,26 +356,30 @@ impl Connection for WriteProxyConnection { auth: Authenticated, builder: B, replication_index: Option, - ) -> Result<(B, TxnStatus)> { + ) -> Result { let mut state = self.state.lock().await; // This is a fresh namespace, and it is not replicated yet, proxy the first request. if self.applied_frame_no_receiver.borrow().is_none() { self.execute_remote(pgm, &mut state, auth, builder).await } else if *state == TxnStatus::Init && pgm.is_read_only() { + // set the state to invalid before doing anything, and set it to a valid state after. + *state = TxnStatus::Invalid; self.wait_replication_sync(replication_index).await?; // We know that this program won't perform any writes. We attempt to run it on the // replica. If it leaves an open transaction, then this program is an interactive // transaction, so we rollback the replica, and execute again on the primary. - let (builder, new_state) = self + let builder = self .read_conn .execute_program(pgm.clone(), auth.clone(), builder, replication_index) .await?; + let new_state = self.read_conn.txn_status()?; if new_state != TxnStatus::Init { self.read_conn.rollback(auth.clone()).await?; self.execute_remote(pgm, &mut state, auth, builder).await } else { - Ok((builder, new_state)) + *state = new_state; + Ok(builder) } } else { self.execute_remote(pgm, &mut state, auth, builder).await @@ -433,7 +437,10 @@ pub mod test { use rand::Fill; use super::*; - use crate::{query_result_builder::test::test_driver, rpc::proxy::rpc::{ExecuteResults, query_result::RowResult}}; + use crate::{ + query_result_builder::test::test_driver, + rpc::proxy::rpc::{query_result::RowResult, ExecuteResults}, + }; /// generate an arbitraty rpc value. see build.rs for usage. pub fn arbitrary_rpc_value(u: &mut Unstructured) -> arbitrary::Result> { @@ -497,12 +504,15 @@ pub mod test { /// In this test, we generate random ExecuteResults, and ensures that the `execute_results_to_builder` drives the builder FSM correctly. #[test] fn test_execute_results_to_builder() { - test_driver(1000, |b| -> std::result::Result { - let mut data = [0; 10_000]; - data.try_fill(&mut rand::thread_rng()).unwrap(); - let mut un = Unstructured::new(&data); - let res = ExecuteResults::arbitrary(&mut un).unwrap(); - execute_results_to_builder(res, b, &QueryBuilderConfig::default()) - }); + test_driver( + 1000, + |b| -> std::result::Result { + let mut data = [0; 10_000]; + data.try_fill(&mut rand::thread_rng()).unwrap(); + let mut un = Unstructured::new(&data); + let res = ExecuteResults::arbitrary(&mut un).unwrap(); + execute_results_to_builder(res, b, &QueryBuilderConfig::default()) + }, + ); } } diff --git a/sqld/src/error.rs b/sqld/src/error.rs index 558b16e8..31b1fa12 100644 --- a/sqld/src/error.rs +++ b/sqld/src/error.rs @@ -131,6 +131,7 @@ impl IntoResponse for Error { LoadDumpExistingDb => self.format_err(StatusCode::BAD_REQUEST), ConflictingRestoreParameters => self.format_err(StatusCode::BAD_REQUEST), Fork(e) => e.into_response(), + StreamDisconnect => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), } } } diff --git a/sqld/src/hrana/batch.rs b/sqld/src/hrana/batch.rs index cb0deb63..a2ddd9e2 100644 --- a/sqld/src/hrana/batch.rs +++ b/sqld/src/hrana/batch.rs @@ -110,7 +110,7 @@ pub async fn execute_batch( replication_index: Option, ) -> Result { let batch_builder = HranaBatchProtoBuilder::default(); - let (builder, _state) = db + let builder = db .execute_program(pgm, auth, batch_builder, replication_index) .await .map_err(catch_batch_error)?; @@ -151,7 +151,7 @@ pub async fn execute_sequence( replication_index: Option, ) -> Result<()> { let builder = StepResultsBuilder::default(); - let (builder, _state) = db + let builder = db .execute_program(pgm, auth, builder, replication_index) .await .map_err(catch_batch_error)?; diff --git a/sqld/src/hrana/stmt.rs b/sqld/src/hrana/stmt.rs index 2021b384..46cd3684 100644 --- a/sqld/src/hrana/stmt.rs +++ b/sqld/src/hrana/stmt.rs @@ -58,7 +58,7 @@ pub async fn execute_stmt( replication_index: Option, ) -> Result { let builder = SingleStatementBuilder::default(); - let (stmt_res, _) = db + let stmt_res = db .execute_batch(vec![query], auth, builder, replication_index) .await .map_err(catch_stmt_error)?; diff --git a/sqld/src/http/user/mod.rs b/sqld/src/http/user/mod.rs index c99bba95..503222c0 100644 --- a/sqld/src/http/user/mod.rs +++ b/sqld/src/http/user/mod.rs @@ -121,7 +121,7 @@ async fn handle_query( let db = connection_maker.create().await?; let builder = JsonHttpPayloadBuilder::new(); - let (builder, _) = db + let builder = db .execute_batch_or_rollback(batch, auth, builder, query.replication_index) .await?; diff --git a/sqld/src/rpc/proxy.rs b/sqld/src/rpc/proxy.rs index e248be81..bd98661f 100644 --- a/sqld/src/rpc/proxy.rs +++ b/sqld/src/rpc/proxy.rs @@ -293,7 +293,8 @@ impl ProxyService { } #[derive(Debug, Default)] -struct ExecuteResultBuilder { +struct ExecuteResultsBuilder { + output: Option, results: Vec, current_rows: Vec, current_row: rpc::Row, @@ -304,8 +305,8 @@ struct ExecuteResultBuilder { current_step_size: u64, } -impl QueryResultBuilder for ExecuteResultBuilder { - type Ret = Vec; +impl QueryResultBuilder for ExecuteResultsBuilder { + type Ret = ExecuteResults; fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { *self = Self { @@ -437,14 +438,19 @@ impl QueryResultBuilder for ExecuteResultBuilder { fn finish( &mut self, - _last_frame_no: Option, - _state: TxnStatus, + last_frame_no: Option, + txn_status: TxnStatus, ) -> Result<(), QueryResultBuilderError> { + self.output = Some(ExecuteResults { + results: std::mem::take(&mut self.results), + state: rpc::State::from(txn_status).into(), + current_frame_no: last_frame_no, + }); Ok(()) } fn into_ret(self) -> Self::Ret { - self.results + self.output.unwrap() } } @@ -516,13 +522,9 @@ impl Proxy for ProxyService { .map_err(|e| tonic::Status::new(tonic::Code::InvalidArgument, e.to_string()))?; let client_id = Uuid::from_str(&req.client_id).unwrap(); - let (connection_maker, new_frame_notifier) = self + let connection_maker = self .namespaces - .with(namespace, |ns| { - let connection_maker = ns.db.connection_maker(); - let notifier = ns.db.logger.new_frame_notifier.subscribe(); - (connection_maker, notifier) - }) + .with(namespace, |ns| ns.db.connection_maker()) .await .map_err(|e| { if let crate::error::Error::NamespaceDoesntExist(_) = e { @@ -551,19 +553,14 @@ impl Proxy for ProxyService { tracing::debug!("executing request for {client_id}"); - let builder = ExecuteResultBuilder::default(); - let (results, state) = db + let builder = ExecuteResultsBuilder::default(); + let builder = db .execute_program(pgm, auth, builder, None) .await // TODO: this is no necessarily a permission denied error! .map_err(|e| tonic::Status::new(tonic::Code::PermissionDenied, e.to_string()))?; - let current_frame_no = *new_frame_notifier.borrow(); - Ok(tonic::Response::new(ExecuteResults { - current_frame_no, - results: results.into_ret(), - state: rpc::State::from(state).into(), - })) + Ok(tonic::Response::new(builder.into_ret())) } //TODO: also handle cleanup on peer disconnect diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs index ab718d46..562790b1 100644 --- a/sqld/src/rpc/streaming_exec.rs +++ b/sqld/src/rpc/streaming_exec.rs @@ -75,7 +75,8 @@ impl StreamResponseBuilder { fn flush(&mut self) -> Result<(), QueryResultBuilderError> { if let Some(current) = self.current.take() { - self.sender.blocking_send(current) + self.sender + .blocking_send(current) .map_err(|_| QueryResultBuilderError::Internal(anyhow::anyhow!("stream closed")))?; } @@ -255,7 +256,10 @@ where Some(Request::Describe(_)) => todo!(), None => { *this.state = State::Fused; - return Poll::Ready(Some(Err(Status::new(Code::InvalidArgument, "invalid ExecReq: missing request")))); + return Poll::Ready(Some(Err(Status::new( + Code::InvalidArgument, + "invalid ExecReq: missing request", + )))); } } // we have placed the request, poll immediately From 637f7f0dfea6f76576ba2612a2087a5c9b6a8059 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 5 Oct 2023 22:20:57 +0200 Subject: [PATCH 08/17] implement streaming describe rpc --- sqld/proto/proxy.proto | 99 ++++++++++++----- sqld/src/connection/libsql.rs | 19 ++-- sqld/src/connection/mod.rs | 8 +- sqld/src/connection/program.rs | 2 - sqld/src/connection/write_proxy.rs | 170 +++++++++++++++++++++-------- sqld/src/rpc/streaming_exec.rs | 97 ++++++++-------- 6 files changed, 264 insertions(+), 131 deletions(-) diff --git a/sqld/proto/proxy.proto b/sqld/proto/proxy.proto index 98586bc7..d85bbeb0 100644 --- a/sqld/proto/proxy.proto +++ b/sqld/proto/proxy.proto @@ -4,7 +4,7 @@ package proxy; message Queries { repeated Query queries = 1; // Uuid - string clientId = 2; + string client_id = 2; } message Query { @@ -34,10 +34,10 @@ message QueryResult { message Error { enum ErrorCode { - SQLError = 0; - TxBusy = 1; - TxTimeout = 2; - Internal = 3; + SQL_ERROR = 0; + TX_BUSY = 1; + TX_TIMEOUT = 2; + INTERNAL = 3; } ErrorCode code = 1; @@ -71,7 +71,7 @@ message Description { message Value { /// bincode encoded Value - bytes data = 1; + bytes data = 1; } message Row { @@ -84,7 +84,7 @@ message Column { } message DisconnectMessage { - string clientId = 1; + string client_id = 1; } message Ack { } @@ -150,17 +150,29 @@ message ProgramReq { Program pgm = 2; } +/// Streaming exec request message ExecReq { - /// id of the request. The response will contain this id + /// id of the request. The response will contain this id. uint32 request_id = 1; oneof request { - Program execute = 2; - DescribeRequest describe = 3; + StreamProgramReq execute = 2; + StreamDescribeReq describe = 3; } } -/// streaming exec proto +/// Describe request for the streaming protocol +message StreamProgramReq { + Program pgm = 1; +} +/// descibre request for the streaming protocol +message StreamDescribeReq { + string stmt = 1; +} + +/// Response message for the streaming proto + +/// Request response types message Init { } message BeginStep { } message FinishStep { @@ -173,10 +185,20 @@ message StepError { message ColsDescription { repeated Column columns = 1; } +message RowValue { + oneof value { + string text = 1; + int64 integer = 2; + double real = 3; + bytes blob = 4; + // null if present + bool null = 5; + } +} message BeginRows { } message BeginRow { } message AddRowValue { - Value val = 1; + RowValue val = 1; } message FinishRow { } message FinishRows { } @@ -185,35 +207,56 @@ message Finish { State state = 2; } +/// Stream execx dexcribe response messages +message DescribeParam { + optional string name = 1; +} -message Message { - oneof payload { - Description describe_result = 1; +message DescribeCol { + string name = 1; + optional string decltype = 2; +} - Init init = 2; - BeginStep begin_step = 3; - FinishStep finish_step = 4; - StepError step_error = 5; - ColsDescription cols_description = 6; - BeginRows begin_rows = 7; - BeginRow begin_row = 8; - AddRowValue add_row_value = 9; - FinishRow finish_row = 10; - FinishRows finish_rows = 11; - Finish finish = 12; +message DescribeResp { + repeated DescribeParam params = 1; + repeated DescribeCol cols = 2; + bool is_explain = 3; + bool is_readonly = 4; +} - Error error = 13; +message RespStep { + oneof step { + Init init = 1; + BeginStep begin_step = 2; + FinishStep finish_step = 3; + StepError step_error = 4; + ColsDescription cols_description = 5; + BeginRows begin_rows = 6; + BeginRow begin_row = 7; + AddRowValue add_row_value = 8; + FinishRow finish_row = 9; + FinishRows finish_rows = 10; + Finish finish = 11; } } +message ProgramResp { + repeated RespStep steps = 1; +} + message ExecResp { uint32 request_id = 1; - repeated Message messages = 2; + oneof response { + ProgramResp program_resp = 2; + DescribeResp describe_resp = 3; + Error error = 4; + } } service Proxy { rpc StreamExec(stream ExecReq) returns (stream ExecResp) {} + // Deprecated: rpc Execute(ProgramReq) returns (ExecuteResults) {} rpc Describe(DescribeRequest) returns (DescribeResult) {} rpc Disconnect(DisconnectMessage) returns (Ack) {} diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs index 108a486d..9ce14cd3 100644 --- a/sqld/src/connection/libsql.rs +++ b/sqld/src/connection/libsql.rs @@ -20,7 +20,7 @@ use crate::stats::Stats; use crate::Result; use super::config::DatabaseConfigStore; -use super::program::{Cond, DescribeCol, DescribeParam, DescribeResponse, DescribeResult}; +use super::program::{Cond, DescribeCol, DescribeParam, DescribeResponse}; use super::{MakeConnection, Program, Step, TXN_TIMEOUT}; pub struct MakeLibSqlConn { @@ -161,7 +161,9 @@ impl std::fmt::Debug for LibSqlConnection { impl Clone for LibSqlConnection { fn clone(&self) -> Self { - Self { inner: self.inner.clone() } + Self { + inner: self.inner.clone(), + } } } @@ -650,7 +652,7 @@ impl Connection { self.stats.inc_rows_written(rows_written as u64); } - fn describe(&self, sql: &str) -> DescribeResult { + fn describe(&self, sql: &str) -> crate::Result { let stmt = self.conn.prepare(sql)?; let params = (1..=stmt.parameter_count()) @@ -768,7 +770,7 @@ where sql: String, auth: Authenticated, _replication_index: Option, - ) -> Result { + ) -> Result> { check_describe_auth(auth)?; let conn = self.inner.clone(); let res = tokio::task::spawn_blocking(move || conn.lock().describe(&sql)) @@ -994,9 +996,12 @@ mod test { tokio::task::spawn_blocking({ let conn = conn1.clone(); move || { - let builder = - Connection::run(conn.inner.clone(), Program::seq(&["COMMIT"]), TestBuilder::default()) - .unwrap(); + let builder = Connection::run( + conn.inner.clone(), + Program::seq(&["COMMIT"]), + TestBuilder::default(), + ) + .unwrap(); assert_eq!(conn.txn_status().unwrap(), TxnStatus::Init); assert!(builder.into_ret()[0].is_ok()); } diff --git a/sqld/src/connection/mod.rs b/sqld/src/connection/mod.rs index 07c853dc..85b2374c 100644 --- a/sqld/src/connection/mod.rs +++ b/sqld/src/connection/mod.rs @@ -13,7 +13,7 @@ use crate::query_result_builder::{IgnoreResult, QueryResultBuilder}; use crate::replication::FrameNo; use crate::Result; -use self::program::{Cond, DescribeResult, Program, Step}; +use self::program::{Cond, DescribeResponse, Program, Step}; pub mod config; pub mod dump; @@ -111,7 +111,7 @@ pub trait Connection: Send + Sync + 'static { sql: String, auth: Authenticated, replication_index: Option, - ) -> Result; + ) -> Result>; /// Check whether the connection is in autocommit mode. async fn is_autocommit(&self) -> Result; @@ -325,7 +325,7 @@ impl Connection for TrackedConnection { sql: String, auth: Authenticated, replication_index: Option, - ) -> crate::Result { + ) -> crate::Result> { self.atime.store(now_millis(), Ordering::Relaxed); self.inner.describe(sql, auth, replication_index).await } @@ -376,7 +376,7 @@ mod test { _sql: String, _auth: Authenticated, _replication_index: Option, - ) -> crate::Result { + ) -> crate::Result> { unreachable!() } diff --git a/sqld/src/connection/program.rs b/sqld/src/connection/program.rs index fabfbd18..3017232a 100644 --- a/sqld/src/connection/program.rs +++ b/sqld/src/connection/program.rs @@ -60,8 +60,6 @@ pub enum Cond { IsAutocommit, } -pub type DescribeResult = crate::Result; - #[derive(Debug, Clone)] pub struct DescribeResponse { pub params: Vec, diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index fd9928f5..c1eb9690 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -13,18 +13,18 @@ use tonic::{Request, Streaming}; use uuid::Uuid; use crate::auth::Authenticated; +use crate::connection::program::{DescribeCol, DescribeParam}; use crate::error::Error; use crate::namespace::NamespaceName; -use crate::query::Value; use crate::query_analysis::TxnStatus; -use crate::query_result_builder::{ - Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, -}; +use crate::query_result_builder::{Column, QueryBuilderConfig, QueryResultBuilder}; use crate::replication::FrameNo; use crate::rpc::proxy::rpc::proxy_client::ProxyClient; +use crate::rpc::proxy::rpc::resp_step::Step; +use crate::rpc::proxy::rpc::row_value::Value; use crate::rpc::proxy::rpc::{ - self, AddRowValue, ColsDescription, DisconnectMessage, ExecReq, ExecResp, Finish, FinishStep, - StepError, + self, exec_req, exec_resp, AddRowValue, ColsDescription, DisconnectMessage, ExecReq, ExecResp, + Finish, FinishStep, RowValue, StepError, }; use crate::rpc::NAMESPACE_METADATA_KEY; use crate::stats::Stats; @@ -32,7 +32,7 @@ use crate::{Result, DEFAULT_AUTO_CHECKPOINT}; use super::config::DatabaseConfigStore; use super::libsql::{LibSqlConnection, MakeLibSqlConn}; -use super::program::DescribeResult; +use super::program::DescribeResponse; use super::Connection; use super::{MakeConnection, Program}; @@ -272,80 +272,160 @@ impl RemoteConnection { }) } - async fn execute( + /// Perform a request on to the remote peer, and call message_cb for every message received for + /// that request. message cb should return whether to expect more message for that request. + async fn make_request( &mut self, - program: Program, - mut builder: B, - ) -> crate::Result<(B, TxnStatus, Option)> { + req: exec_req::Request, + mut response_cb: impl FnMut(exec_resp::Response) -> crate::Result, + ) -> crate::Result<()> { let request_id = self.current_request_id; self.current_request_id += 1; let req = ExecReq { request_id, - request: Some(rpc::exec_req::Request::Execute(program.into())), + request: Some(req), }; - self.request_sender.send(req).await.unwrap(); // TODO: the stream was close! - let mut txn_status = TxnStatus::Invalid; - let mut new_frame_no = None; + self.request_sender + .send(req) + .await + .map_err(|_| Error::StreamDisconnect)?; 'outer: while let Some(resp) = self.response_stream.next().await { match resp { Ok(resp) => { + // todo: handle interuption if resp.request_id != request_id { todo!("stream misuse: connection should be serialized"); } - for message in resp.messages { - use rpc::message::Payload; - match message.payload.unwrap() { - Payload::DescribeResult(_) => todo!("invalid response"), + if !response_cb(resp.response.unwrap())? { + break 'outer; + } + } + Err(e) => { + tracing::error!("received error from connection stream: {e}"); + return Err(Error::StreamDisconnect); + } + } + } + + Ok(()) + } - Payload::Init(_) => builder.init(&self.builder_config)?, - Payload::BeginStep(_) => builder.begin_step()?, - Payload::FinishStep(FinishStep { + async fn execute( + &mut self, + program: Program, + mut builder: B, + ) -> crate::Result<(B, TxnStatus, Option)> { + let mut txn_status = TxnStatus::Invalid; + let mut new_frame_no = None; + let builder_config = self.builder_config; + let cb = |response: exec_resp::Response| { + match response { + exec_resp::Response::ProgramResp(resp) => { + for step in resp.steps { + let Some(step) = step.step else {panic!("invalid pgm")}; + match step { + Step::Init(_) => builder.init(&builder_config)?, + Step::BeginStep(_) => builder.begin_step()?, + Step::FinishStep(FinishStep { affected_row_count, last_insert_rowid, }) => builder.finish_step(affected_row_count, last_insert_rowid)?, - Payload::StepError(StepError { error }) => builder + Step::StepError(StepError { error }) => builder .step_error(crate::error::Error::RpcQueryError(error.unwrap()))?, - Payload::ColsDescription(ColsDescription { columns }) => { + Step::ColsDescription(ColsDescription { columns }) => { let cols = columns.iter().map(|c| Column { name: &c.name, decl_ty: c.decltype.as_deref(), }); builder.cols_description(cols)? } - Payload::BeginRows(_) => builder.begin_rows()?, - Payload::BeginRow(_) => builder.begin_row()?, - Payload::AddRowValue(AddRowValue { val }) => { - let value: Value = bincode::deserialize(&val.unwrap().data) - .map_err(QueryResultBuilderError::from_any)?; - builder.add_row_value(ValueRef::from(&value))?; + Step::BeginRows(_) => builder.begin_rows()?, + Step::BeginRow(_) => builder.begin_row()?, + Step::AddRowValue(AddRowValue { + val: Some(RowValue { value: Some(val) }), + }) => { + let val = match &val { + Value::Text(s) => ValueRef::Text(s.as_bytes()), + Value::Integer(i) => ValueRef::Integer(*i), + Value::Real(x) => ValueRef::Real(*x), + Value::Blob(b) => ValueRef::Blob(b.as_slice()), + Value::Null(_) => ValueRef::Null, + }; + builder.add_row_value(val)?; } - Payload::FinishRow(_) => builder.finish_row()?, - Payload::FinishRows(_) => builder.finish_rows()?, - Payload::Finish(f @ Finish { last_frame_no, .. }) => { + Step::FinishRow(_) => builder.finish_row()?, + Step::FinishRows(_) => builder.finish_rows()?, + Step::Finish(f @ Finish { last_frame_no, .. }) => { txn_status = TxnStatus::from(f.state()); new_frame_no = last_frame_no; builder.finish(last_frame_no, txn_status)?; - break 'outer; - } - Payload::Error(error) => { - return Err(crate::error::Error::RpcQueryError(error)) + return Ok(false); } + _ => todo!("invalid request"), } } } - Err(e) => { - tracing::error!("received error from connection stream: {e}"); - return Err(Error::StreamDisconnect); - } + exec_resp::Response::DescribeResp(_) => todo!("invalid resp"), + exec_resp::Response::Error(_) => todo!(), } - } + + Ok(true) + }; + + self.make_request( + exec_req::Request::Execute(rpc::StreamProgramReq { + pgm: Some(program.into()), + }), + cb, + ) + .await?; Ok((builder, txn_status, new_frame_no)) } + + #[allow(dead_code)] // reference implementation + async fn describe(&mut self, stmt: String) -> crate::Result { + let mut out = None; + let cb = |response: exec_resp::Response| { + match response { + exec_resp::Response::DescribeResp(resp) => { + out = Some(DescribeResponse { + params: resp + .params + .into_iter() + .map(|p| DescribeParam { name: p.name }) + .collect(), + cols: resp + .cols + .into_iter() + .map(|c| DescribeCol { + name: c.name, + decltype: c.decltype, + }) + .collect(), + is_explain: resp.is_explain, + is_readonly: resp.is_readonly, + }); + } + exec_resp::Response::Error(_) => todo!(), + exec_resp::Response::ProgramResp(_) => todo!(), + } + + Ok(false) + }; + + self.make_request( + exec_req::Request::Describe(rpc::StreamDescribeReq { stmt }), + cb, + ) + .await?; + + Ok(out.unwrap()) + } } #[async_trait::async_trait] @@ -391,7 +471,7 @@ impl Connection for WriteProxyConnection { sql: String, auth: Authenticated, replication_index: Option, - ) -> Result { + ) -> Result> { self.wait_replication_sync(replication_index).await?; self.read_conn.describe(sql, auth, replication_index).await } @@ -438,7 +518,7 @@ pub mod test { use super::*; use crate::{ - query_result_builder::test::test_driver, + query_result_builder::{test::test_driver, QueryResultBuilderError}, rpc::proxy::rpc::{query_result::RowResult, ExecuteResults}, }; @@ -475,7 +555,7 @@ pub mod test { for row in rows.rows { builder.begin_row()?; for value in row.values { - let value: Value = bincode::deserialize(&value.data) + let value: crate::query::Value = bincode::deserialize(&value.data) // something is wrong, better stop right here .map_err(QueryResultBuilderError::from_any)?; builder.add_row_value(ValueRef::from(&value))?; diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs index 562790b1..6613a790 100644 --- a/sqld/src/rpc/streaming_exec.rs +++ b/sqld/src/rpc/streaming_exec.rs @@ -15,9 +15,11 @@ use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; use crate::replication::FrameNo; -use crate::rpc::proxy::rpc::{exec_req::Request, Message}; +use crate::rpc::proxy::rpc::exec_req::Request; +use crate::rpc::proxy::rpc::exec_resp; -use super::proxy::rpc::{self, message::Payload, ExecReq, ExecResp}; +use super::proxy::rpc::resp_step::Step; +use super::proxy::rpc::{self, ExecReq, ExecResp, ProgramResp, RespStep, RowValue}; pin_project_lite::pin_project! { pub struct StreamRequestHandler { @@ -47,26 +49,22 @@ impl StreamRequestHandler { struct StreamResponseBuilder { request_id: u32, sender: mpsc::Sender, - current: Option, + current: Option, } impl StreamResponseBuilder { - fn current(&mut self) -> &mut ExecResp { - self.current.get_or_insert_with(|| ExecResp { - messages: Vec::new(), - request_id: self.request_id, - }) + fn current(&mut self) -> &mut ProgramResp { + self.current + .get_or_insert_with(|| ProgramResp { steps: Vec::new() }) } - fn push(&mut self, payload: Payload) -> Result<(), QueryResultBuilderError> { - const MAX_RESPONSE_MESSAGES: usize = 10; + fn push(&mut self, step: Step) -> Result<(), QueryResultBuilderError> { + const MAX_RESPONSE_STEPS: usize = 10; let current = self.current(); - current.messages.push(Message { - payload: Some(payload), - }); + current.steps.push(RespStep { step: Some(step) }); - if current.messages.len() > MAX_RESPONSE_MESSAGES { + if current.steps.len() > MAX_RESPONSE_STEPS { self.flush()?; } @@ -75,8 +73,12 @@ impl StreamResponseBuilder { fn flush(&mut self) -> Result<(), QueryResultBuilderError> { if let Some(current) = self.current.take() { + let resp = ExecResp { + request_id: self.request_id, + response: Some(exec_resp::Response::ProgramResp(current)), + }; self.sender - .blocking_send(current) + .blocking_send(resp) .map_err(|_| QueryResultBuilderError::Internal(anyhow::anyhow!("stream closed")))?; } @@ -88,12 +90,12 @@ impl QueryResultBuilder for StreamResponseBuilder { type Ret = (); fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { - self.push(Payload::Init(rpc::Init {}))?; + self.push(Step::Init(rpc::Init {}))?; Ok(()) } fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::BeginStep(rpc::BeginStep {}))?; + self.push(Step::BeginStep(rpc::BeginStep {}))?; Ok(()) } @@ -102,7 +104,7 @@ impl QueryResultBuilder for StreamResponseBuilder { affected_row_count: u64, last_insert_rowid: Option, ) -> Result<(), QueryResultBuilderError> { - self.push(Payload::FinishStep(rpc::FinishStep { + self.push(Step::FinishStep(rpc::FinishStep { affected_row_count, last_insert_rowid, }))?; @@ -110,7 +112,7 @@ impl QueryResultBuilder for StreamResponseBuilder { } fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { - self.push(Payload::StepError(rpc::StepError { + self.push(Step::StepError(rpc::StepError { error: Some(error.into()), }))?; Ok(()) @@ -120,7 +122,7 @@ impl QueryResultBuilder for StreamResponseBuilder { &mut self, cols: impl IntoIterator>>, ) -> Result<(), QueryResultBuilderError> { - self.push(Payload::ColsDescription(rpc::ColsDescription { + self.push(Step::ColsDescription(rpc::ColsDescription { columns: cols .into_iter() .map(Into::into) @@ -134,34 +136,29 @@ impl QueryResultBuilder for StreamResponseBuilder { } fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::BeginRows(rpc::BeginRows {}))?; + self.push(Step::BeginRows(rpc::BeginRows {}))?; Ok(()) } fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::BeginRow(rpc::BeginRow {}))?; + self.push(Step::BeginRow(rpc::BeginRow {}))?; Ok(()) } fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { - let data = bincode::serialize( - &crate::query::Value::try_from(v).map_err(QueryResultBuilderError::from_any)?, - ) - .map_err(QueryResultBuilderError::from_any)?; - - let val = Some(rpc::Value { data }); - - self.push(Payload::AddRowValue(rpc::AddRowValue { val }))?; + self.push(Step::AddRowValue(rpc::AddRowValue { + val: Some(v.into()), + }))?; Ok(()) } fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::FinishRow(rpc::FinishRow {}))?; + self.push(Step::FinishRow(rpc::FinishRow {}))?; Ok(()) } fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.push(Payload::FinishRows(rpc::FinishRows {}))?; + self.push(Step::FinishRows(rpc::FinishRows {}))?; Ok(()) } @@ -170,7 +167,7 @@ impl QueryResultBuilder for StreamResponseBuilder { last_frame_no: Option, state: TxnStatus, ) -> Result<(), QueryResultBuilderError> { - self.push(Payload::Finish(rpc::Finish { + self.push(Step::Finish(rpc::Finish { last_frame_no, state: rpc::State::from(state).into(), }))?; @@ -178,8 +175,22 @@ impl QueryResultBuilder for StreamResponseBuilder { Ok(()) } - fn into_ret(self) -> Self::Ret { - () + fn into_ret(self) -> Self::Ret { } +} + +impl From> for RowValue { + fn from(value: ValueRef<'_>) -> Self { + use rpc::row_value::Value; + + let value = Some(match value { + ValueRef::Null => Value::Null(true), + ValueRef::Integer(i) => Value::Integer(i), + ValueRef::Real(x) => Value::Real(x), + ValueRef::Text(s) => Value::Text(String::from_utf8(s.to_vec()).unwrap()), + ValueRef::Blob(b) => Value::Blob(b.to_vec()), + }); + + RowValue { value } } } @@ -203,14 +214,14 @@ where match ready!(this.request_stream.poll_next(cx)) { Some(Err(e)) => { *this.state = State::Fused; - return Poll::Ready(Some(Err(e))); + Poll::Ready(Some(Err(e))) } Some(Ok(req)) => { let request_id = req.request_id; match req.request { Some(Request::Execute(pgm)) => { let Ok(pgm) = - crate::connection::program::Program::try_from(pgm) else { + crate::connection::program::Program::try_from(pgm.pgm.unwrap()) else { *this.state = State::Fused; return Poll::Ready(Some(Err(Status::new(Code::InvalidArgument, "invalid program")))); }; @@ -236,10 +247,9 @@ where if let Err(e) = res { yield ExecResp { request_id, - messages: vec![rpc::Message { payload: Some(Payload::Error(e.into()))}], + response: Some(exec_resp::Response::Error(e.into())) } } - // todo check result? break } msg = receiver.recv() => { @@ -250,7 +260,6 @@ where } } }; - dbg!(); *this.state = State::Execute(Box::pin(s)); } Some(Request::Describe(_)) => todo!(), @@ -264,7 +273,7 @@ where } // we have placed the request, poll immediately cx.waker().wake_by_ref(); - return Poll::Pending; + Poll::Pending } None => { // this would easier if tokio_stream re-exported combinators @@ -275,17 +284,15 @@ where } State::Fused => Poll::Ready(None), State::Execute(stream) => { - dbg!(); let resp = ready!(stream.as_mut().poll_next(cx)); match resp { - Some(resp) => return Poll::Ready(Some(Ok(dbg!(resp)))), + Some(resp) => Poll::Ready(Some(Ok(resp))), None => { - dbg!(); // finished processing this query. Wake up immediately to prepare for the // next *this.state = State::Idle; cx.waker().wake_by_ref(); - return Poll::Pending; + Poll::Pending } } } From 2e95e757d009ee269881a43d3a49bbcbb00fcb55 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 5 Oct 2023 22:51:06 +0200 Subject: [PATCH 09/17] more error handling --- sqld/src/connection/write_proxy.rs | 47 +++++++++++++++++------------- sqld/src/error.rs | 9 ++++-- 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index c1eb9690..9554189e 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -195,7 +195,7 @@ impl WriteProxyConnection { let (builder, new_status, new_frame_no) = match res { Ok(res) => res, - Err(e @ Error::StreamDisconnect) => { + Err(e @ (Error::PrimaryStreamDisconnect | Error::PrimaryStreamMisuse)) => { // drop the connection self.remote_conn.lock().await.take(); return Err(e); @@ -290,23 +290,28 @@ impl RemoteConnection { self.request_sender .send(req) .await - .map_err(|_| Error::StreamDisconnect)?; + .map_err(|_| Error::PrimaryStreamDisconnect)?; - 'outer: while let Some(resp) = self.response_stream.next().await { + while let Some(resp) = self.response_stream.next().await { match resp { Ok(resp) => { - // todo: handle interuption - if resp.request_id != request_id { - todo!("stream misuse: connection should be serialized"); + // there was an interuption, and we moved to the next query + if resp.request_id > request_id { + return Err(Error::PrimaryStreamInterupted) } - if !response_cb(resp.response.unwrap())? { - break 'outer; + // we can ignore response for previously interupted requests + if resp.request_id < request_id { + continue; + } + + if !response_cb(resp.response.ok_or(Error::PrimaryStreamMisuse)?)? { + break; } } Err(e) => { - tracing::error!("received error from connection stream: {e}"); - return Err(Error::StreamDisconnect); + tracing::error!("received an error from connection stream: {e}"); + return Err(Error::PrimaryStreamDisconnect); } } } @@ -326,7 +331,7 @@ impl RemoteConnection { match response { exec_resp::Response::ProgramResp(resp) => { for step in resp.steps { - let Some(step) = step.step else {panic!("invalid pgm")}; + let Some(step) = step.step else { return Err(Error::PrimaryStreamMisuse) }; match step { Step::Init(_) => builder.init(&builder_config)?, Step::BeginStep(_) => builder.begin_step()?, @@ -334,8 +339,8 @@ impl RemoteConnection { affected_row_count, last_insert_rowid, }) => builder.finish_step(affected_row_count, last_insert_rowid)?, - Step::StepError(StepError { error }) => builder - .step_error(crate::error::Error::RpcQueryError(error.unwrap()))?, + Step::StepError(StepError { error: Some(err) }) => builder + .step_error(crate::error::Error::RpcQueryError(err))?, Step::ColsDescription(ColsDescription { columns }) => { let cols = columns.iter().map(|c| Column { name: &c.name, @@ -365,12 +370,12 @@ impl RemoteConnection { builder.finish(last_frame_no, txn_status)?; return Ok(false); } - _ => todo!("invalid request"), + _ => return Err(Error::PrimaryStreamMisuse), } } } - exec_resp::Response::DescribeResp(_) => todo!("invalid resp"), - exec_resp::Response::Error(_) => todo!(), + exec_resp::Response::DescribeResp(_) => return Err(Error::PrimaryStreamMisuse), + exec_resp::Response::Error(e) => return Err(Error::RpcQueryError(e)), } Ok(true) @@ -410,12 +415,12 @@ impl RemoteConnection { is_explain: resp.is_explain, is_readonly: resp.is_readonly, }); + + Ok(false) } - exec_resp::Response::Error(_) => todo!(), - exec_resp::Response::ProgramResp(_) => todo!(), + exec_resp::Response::Error(e) => Err(Error::RpcQueryError(e)), + exec_resp::Response::ProgramResp(_) => Err(Error::PrimaryStreamMisuse), } - - Ok(false) }; self.make_request( @@ -424,7 +429,7 @@ impl RemoteConnection { ) .await?; - Ok(out.unwrap()) + out.ok_or(Error::PrimaryStreamMisuse) } } diff --git a/sqld/src/error.rs b/sqld/src/error.rs index 31b1fa12..96b3ce29 100644 --- a/sqld/src/error.rs +++ b/sqld/src/error.rs @@ -79,8 +79,13 @@ pub enum Error { ConflictingRestoreParameters, #[error("failed to fork database: {0}")] Fork(#[from] ForkError), + #[error("Connection with primary broken")] - StreamDisconnect, + PrimaryStreamDisconnect, + #[error("Proxy protocal misuse")] + PrimaryStreamMisuse, + #[error("Proxy request interupted")] + PrimaryStreamInterupted, } trait ResponseError: std::error::Error { @@ -131,7 +136,7 @@ impl IntoResponse for Error { LoadDumpExistingDb => self.format_err(StatusCode::BAD_REQUEST), ConflictingRestoreParameters => self.format_err(StatusCode::BAD_REQUEST), Fork(e) => e.into_response(), - StreamDisconnect => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + PrimaryStreamDisconnect => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), } } } From 1d0c264ec29a51b28477c5d24b7942f9fc8ea603 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 5 Oct 2023 22:59:09 +0200 Subject: [PATCH 10/17] preemtible request stream --- sqld/src/error.rs | 2 + sqld/src/rpc/streaming_exec.rs | 126 ++++++++++++++++----------------- 2 files changed, 65 insertions(+), 63 deletions(-) diff --git a/sqld/src/error.rs b/sqld/src/error.rs index 96b3ce29..97df108b 100644 --- a/sqld/src/error.rs +++ b/sqld/src/error.rs @@ -137,6 +137,8 @@ impl IntoResponse for Error { ConflictingRestoreParameters => self.format_err(StatusCode::BAD_REQUEST), Fork(e) => e.into_response(), PrimaryStreamDisconnect => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + PrimaryStreamMisuse => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + PrimaryStreamInterupted => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), } } } diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs index 6613a790..6d8c28c0 100644 --- a/sqld/src/rpc/streaming_exec.rs +++ b/sqld/src/rpc/streaming_exec.rs @@ -209,79 +209,79 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - match this.state { - State::Idle => { - match ready!(this.request_stream.poll_next(cx)) { - Some(Err(e)) => { - *this.state = State::Fused; - Poll::Ready(Some(Err(e))) - } - Some(Ok(req)) => { - let request_id = req.request_id; - match req.request { - Some(Request::Execute(pgm)) => { - let Ok(pgm) = - crate::connection::program::Program::try_from(pgm.pgm.unwrap()) else { - *this.state = State::Fused; - return Poll::Ready(Some(Err(Status::new(Code::InvalidArgument, "invalid program")))); - }; - let conn = this.connection.clone(); - let authenticated = this.authenticated.clone(); - - let s = async_stream::stream! { - let (sender, mut receiver) = mpsc::channel(1); - let builder = StreamResponseBuilder { - request_id, - sender, - current: None, - }; - let mut fut = conn.execute_program(pgm, authenticated, builder, None); - loop { - tokio::select! { - res = &mut fut => { - // drain the receiver - while let Ok(msg) = receiver.try_recv() { - yield msg; - } + // we always poll from the request stream. If a new request arrive, we interupt the current + // one, and move to the next. + if let Poll::Ready(maybe_req) = this.request_stream.poll_next(cx) { + match maybe_req { + Some(Err(e)) => { + *this.state = State::Fused; + return Poll::Ready(Some(Err(e))) + } + Some(Ok(req)) => { + let request_id = req.request_id; + match req.request { + Some(Request::Execute(pgm)) => { + let Ok(pgm) = + crate::connection::program::Program::try_from(pgm.pgm.unwrap()) else { + *this.state = State::Fused; + return Poll::Ready(Some(Err(Status::new(Code::InvalidArgument, "invalid program")))); + }; + let conn = this.connection.clone(); + let authenticated = this.authenticated.clone(); + + let s = async_stream::stream! { + let (sender, mut receiver) = mpsc::channel(1); + let builder = StreamResponseBuilder { + request_id, + sender, + current: None, + }; + let mut fut = conn.execute_program(pgm, authenticated, builder, None); + loop { + tokio::select! { + res = &mut fut => { + // drain the receiver + while let Ok(msg) = receiver.try_recv() { + yield msg; + } - if let Err(e) = res { - yield ExecResp { - request_id, - response: Some(exec_resp::Response::Error(e.into())) - } + if let Err(e) = res { + yield ExecResp { + request_id, + response: Some(exec_resp::Response::Error(e.into())) } - break } - msg = receiver.recv() => { - if let Some(msg) = msg { - yield msg; - } + break + } + msg = receiver.recv() => { + if let Some(msg) = msg { + yield msg; } } } - }; - *this.state = State::Execute(Box::pin(s)); - } - Some(Request::Describe(_)) => todo!(), - None => { - *this.state = State::Fused; - return Poll::Ready(Some(Err(Status::new( - Code::InvalidArgument, - "invalid ExecReq: missing request", - )))); - } + } + }; + *this.state = State::Execute(Box::pin(s)); + } + Some(Request::Describe(_)) => todo!(), + None => { + *this.state = State::Fused; + return Poll::Ready(Some(Err(Status::new( + Code::InvalidArgument, + "invalid ExecReq: missing request", + )))); } - // we have placed the request, poll immediately - cx.waker().wake_by_ref(); - Poll::Pending - } - None => { - // this would easier if tokio_stream re-exported combinators - *this.state = State::Fused; - Poll::Ready(None) } } + None => { + *this.state = State::Fused; + return Poll::Ready(None) + } } + } + + match this.state { + State::Idle => Poll::Pending, State::Fused => Poll::Ready(None), State::Execute(stream) => { let resp = ready!(stream.as_mut().poll_next(cx)); From b7f1ef7cb369a8f84082b6086184a434d02bee64 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 5 Oct 2023 23:17:25 +0200 Subject: [PATCH 11/17] flush when response exceed limit size --- sqld/src/connection/write_proxy.rs | 4 ++-- sqld/src/rpc/streaming_exec.rs | 13 ++++++++--- sqld/tests/cluster/mod.rs | 37 ------------------------------ sqld/tests/namespaces/mod.rs | 37 ++++++++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 42 deletions(-) diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index 9554189e..4cda299f 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -108,7 +108,6 @@ impl MakeConnection for MakeWriteProxyConn { } } -#[derive(Debug)] pub struct WriteProxyConnection { /// Lazily initialized read connection read_conn: LibSqlConnection, @@ -196,8 +195,9 @@ impl WriteProxyConnection { let (builder, new_status, new_frame_no) = match res { Ok(res) => res, Err(e @ (Error::PrimaryStreamDisconnect | Error::PrimaryStreamMisuse)) => { - // drop the connection + // drop the connection, and reset the state. self.remote_conn.lock().await.take(); + *status = TxnStatus::Init; return Err(e); } Err(e) => return Err(e), diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs index 6d8c28c0..1b72fc98 100644 --- a/sqld/src/rpc/streaming_exec.rs +++ b/sqld/src/rpc/streaming_exec.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use std::task::{ready, Context, Poll}; use futures_core::Stream; +use prost::Message; use rusqlite::types::ValueRef; use tokio::sync::mpsc; use tonic::{Code, Status}; @@ -50,6 +51,7 @@ struct StreamResponseBuilder { request_id: u32, sender: mpsc::Sender, current: Option, + current_size: usize, } impl StreamResponseBuilder { @@ -59,12 +61,15 @@ impl StreamResponseBuilder { } fn push(&mut self, step: Step) -> Result<(), QueryResultBuilderError> { - const MAX_RESPONSE_STEPS: usize = 10; + const MAX_RESPONSE_SIZE: usize = bytesize::ByteSize::mb(1).as_u64() as usize; let current = self.current(); - current.steps.push(RespStep { step: Some(step) }); + let step = RespStep { step: Some(step) }; + let size = step.encoded_len(); + current.steps.push(step); + self.current_size += size; - if current.steps.len() > MAX_RESPONSE_STEPS { + if self.current_size >= MAX_RESPONSE_SIZE { self.flush()?; } @@ -77,6 +82,7 @@ impl StreamResponseBuilder { request_id: self.request_id, response: Some(exec_resp::Response::ProgramResp(current)), }; + self.current_size = 0; self.sender .blocking_send(resp) .map_err(|_| QueryResultBuilderError::Internal(anyhow::anyhow!("stream closed")))?; @@ -235,6 +241,7 @@ where request_id, sender, current: None, + current_size: 0, }; let mut fut = conn.execute_program(pgm, authenticated, builder, None); loop { diff --git a/sqld/tests/cluster/mod.rs b/sqld/tests/cluster/mod.rs index 86cf8e44..0b83c171 100644 --- a/sqld/tests/cluster/mod.rs +++ b/sqld/tests/cluster/mod.rs @@ -205,40 +205,3 @@ fn sync_many_replica() { sim.run().unwrap(); } - -#[test] -fn create_namespace() { - let mut sim = Builder::new().build(); - make_cluster(&mut sim, 0, false); - - sim.client("client", async { - let db = - Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; - let conn = db.connect()?; - - let Err(e) = conn.execute("create table test (x)", ()).await else { - panic!() - }; - assert_snapshot!(e.to_string()); - - let client = Client::new(); - let resp = client - .post( - "http://foo.primary:9090/v1/namespaces/foo/create", - json!({}), - ) - .await?; - assert_eq!(resp.status(), 200); - - conn.execute("create table test (x)", ()).await.unwrap(); - let mut rows = conn.query("select count(*) from test", ()).await.unwrap(); - assert!(matches!( - rows.next().unwrap().unwrap().get_value(0).unwrap(), - Value::Integer(0) - )); - - Ok(()) - }); - - sim.run().unwrap(); -} diff --git a/sqld/tests/namespaces/mod.rs b/sqld/tests/namespaces/mod.rs index e6b2ee88..22f0d785 100644 --- a/sqld/tests/namespaces/mod.rs +++ b/sqld/tests/namespaces/mod.rs @@ -41,6 +41,43 @@ fn make_primary(sim: &mut Sim, path: PathBuf) { }); } +#[test] +fn create_namespace() { + let mut sim = Builder::new().build(); + make_cluster(&mut sim, 0, false); + + sim.client("client", async { + let db = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; + let conn = db.connect()?; + + let Err(e) = conn.execute("create table test (x)", ()).await else { + panic!() + }; + assert_snapshot!(e.to_string()); + + let client = Client::new(); + let resp = client + .post( + "http://foo.primary:9090/v1/namespaces/foo/create", + json!({}), + ) + .await?; + assert_eq!(resp.status(), 200); + + conn.execute("create table test (x)", ()).await.unwrap(); + let mut rows = conn.query("select count(*) from test", ()).await.unwrap(); + assert!(matches!( + rows.next().unwrap().unwrap().get_value(0).unwrap(), + Value::Integer(0) + )); + + Ok(()) + }); + + sim.run().unwrap(); +} + #[test] fn fork_namespace() { let mut sim = Builder::new().build(); From d29081ab236d2e93ef0ce7f449ff583312f3050b Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 7 Oct 2023 11:14:24 +0200 Subject: [PATCH 12/17] simplify stream proxy implementation --- Cargo.lock | 10 ++ sqld/Cargo.toml | 1 + sqld/src/rpc/proxy.rs | 20 ++-- sqld/src/rpc/streaming_exec.rs | 200 ++++++++++++--------------------- 4 files changed, 89 insertions(+), 142 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ac7c1ea5..85d0ad3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1689,6 +1689,15 @@ dependencies = [ "syn 2.0.38", ] +[[package]] +name = "futures-option" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01141bf1c1a803403a2e7ae6a7eb20506c6bd2ebd209fc2424d05572a78af5f5" +dependencies = [ + "futures-core", +] + [[package]] name = "futures-sink" version = "0.3.28" @@ -3663,6 +3672,7 @@ dependencies = [ "fallible-iterator 0.3.0", "futures", "futures-core", + "futures-option", "hmac", "hyper", "hyper-rustls 0.24.1 (git+https://github.com/rustls/hyper-rustls.git?rev=163b3f5)", diff --git a/sqld/Cargo.toml b/sqld/Cargo.toml index 8cb7f8ea..a06b491e 100644 --- a/sqld/Cargo.toml +++ b/sqld/Cargo.toml @@ -71,6 +71,7 @@ rustls-pemfile = "1.0.3" rustls = "0.21.7" async-stream = "0.3.5" libsql = { git = "https://github.com/tursodatabase/libsql.git", rev = "bea8863", optional = true } +futures-option = "0.2.0" [dev-dependencies] proptest = "1.0.0" diff --git a/sqld/src/rpc/proxy.rs b/sqld/src/rpc/proxy.rs index bd98661f..0a51a0b7 100644 --- a/sqld/src/rpc/proxy.rs +++ b/sqld/src/rpc/proxy.rs @@ -1,8 +1,10 @@ use std::collections::HashMap; +use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; use async_lock::{RwLock, RwLockUpgradableReadGuard}; +use futures_core::Stream; use rusqlite::types::ValueRef; use uuid::Uuid; @@ -15,14 +17,14 @@ use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; use crate::replication::FrameNo; +use crate::rpc::streaming_exec::make_proxy_stream; use self::rpc::proxy_server::Proxy; use self::rpc::query_result::RowResult; use self::rpc::{ describe_result, Ack, DescribeRequest, DescribeResult, Description, DisconnectMessage, ExecReq, - ExecuteResults, QueryResult, ResultRows, Row, + ExecuteResults, QueryResult, ResultRows, Row, ExecResp, }; -use super::streaming_exec::StreamRequestHandler; use super::NAMESPACE_DOESNT_EXIST; pub mod rpc { @@ -467,20 +469,18 @@ pub async fn garbage_collect(clients: &mut HashMap> #[tonic::async_trait] impl Proxy for ProxyService { - type StreamExecStream = StreamRequestHandler>; + type StreamExecStream = Pin> + Send>>; async fn stream_exec( &self, req: tonic::Request>, ) -> Result, tonic::Status> { - dbg!(); - let authenticated = if let Some(auth) = &self.auth { + let auth= if let Some(auth) = &self.auth { auth.authenticate_grpc(&req, self.disable_namespaces)? } else { Authenticated::from_proxy_grpc_request(&req, self.disable_namespaces)? }; - dbg!(); let namespace = super::extract_namespace(self.disable_namespaces, &req)?; let (connection_maker, _new_frame_notifier) = self .namespaces @@ -498,13 +498,11 @@ impl Proxy for ProxyService { } })?; - dbg!(); - let connection = connection_maker.create().await.unwrap(); + let conn = connection_maker.create().await.unwrap(); - dbg!(); - let handler = StreamRequestHandler::new(req.into_inner(), connection, authenticated); + let stream = make_proxy_stream(conn, auth, req.into_inner()); - Ok(tonic::Response::new(handler)) + Ok(tonic::Response::new(Box::pin(stream))) } async fn execute( diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs index 1b72fc98..b36b683d 100644 --- a/sqld/src/rpc/streaming_exec.rs +++ b/sqld/src/rpc/streaming_exec.rs @@ -1,11 +1,12 @@ -use std::pin::Pin; use std::sync::Arc; -use std::task::{ready, Context, Poll}; use futures_core::Stream; +use futures_option::OptionExt; use prost::Message; use rusqlite::types::ValueRef; +use tokio::pin; use tokio::sync::mpsc; +use tokio_stream::StreamExt; use tonic::{Code, Status}; use crate::auth::Authenticated; @@ -17,32 +18,76 @@ use crate::query_result_builder::{ }; use crate::replication::FrameNo; use crate::rpc::proxy::rpc::exec_req::Request; -use crate::rpc::proxy::rpc::exec_resp; +use crate::rpc::proxy::rpc::exec_resp::{self, Response}; use super::proxy::rpc::resp_step::Step; use super::proxy::rpc::{self, ExecReq, ExecResp, ProgramResp, RespStep, RowValue}; -pin_project_lite::pin_project! { - pub struct StreamRequestHandler { - #[pin] - request_stream: S, - connection: Arc, - state: State, - authenticated: Authenticated, - } -} - -impl StreamRequestHandler { - pub fn new( - request_stream: S, - connection: PrimaryConnection, - authenticated: Authenticated, - ) -> Self { - Self { - request_stream, - connection: connection.into(), - state: State::Idle, - authenticated, +pub fn make_proxy_stream(conn: PrimaryConnection, auth: Authenticated, request_stream: S) -> impl Stream> +where + S: Stream>, +{ + async_stream::stream! { + let mut current_request_fut = None; + let (snd, mut recv) = mpsc::channel(1); + let conn = Arc::new(conn); + pin!(request_stream); + + loop { + tokio::select! { + biased; + Some(maybe_req) = request_stream.next() => { + match maybe_req { + Err(e) => { + tracing::error!("stream error: {e}"); + break + } + Ok(req) => { + let request_id = req.request_id; + match req.request { + Some(Request::Execute(pgm)) => { + let Ok(pgm) = + crate::connection::program::Program::try_from(pgm.pgm.unwrap()) else { + yield Err(Status::new(Code::InvalidArgument, "invalid program")); + break + }; + let conn = conn.clone(); + let auth = auth.clone(); + let sender = snd.clone(); + + let fut = async move { + let builder = StreamResponseBuilder { + request_id, + sender, + current: None, + current_size: 0, + }; + + let ret = conn.execute_program(pgm, auth, builder, None).await; + (ret, request_id) + }; + + current_request_fut.replace(Box::pin(fut)); + } + Some(Request::Describe(_)) => todo!(), + None => { + yield Err(Status::new(Code::InvalidArgument, "invalid request")); + break + } + } + } + } + }, + Some(res) = recv.recv() => { + yield Ok(res); + }, + (ret, request_id) = current_request_fut.current(), if current_request_fut.is_some() => { + if let Err(e) = ret { + yield Ok(ExecResp { request_id, response: Some(Response::Error(e.into())) }) + } + }, + else => break, + } } } } @@ -199,110 +244,3 @@ impl From> for RowValue { RowValue { value } } } - -enum State { - Execute(Pin + Send>>), - Idle, - Fused, -} - -impl Stream for StreamRequestHandler -where - S: Stream>, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - // we always poll from the request stream. If a new request arrive, we interupt the current - // one, and move to the next. - if let Poll::Ready(maybe_req) = this.request_stream.poll_next(cx) { - match maybe_req { - Some(Err(e)) => { - *this.state = State::Fused; - return Poll::Ready(Some(Err(e))) - } - Some(Ok(req)) => { - let request_id = req.request_id; - match req.request { - Some(Request::Execute(pgm)) => { - let Ok(pgm) = - crate::connection::program::Program::try_from(pgm.pgm.unwrap()) else { - *this.state = State::Fused; - return Poll::Ready(Some(Err(Status::new(Code::InvalidArgument, "invalid program")))); - }; - let conn = this.connection.clone(); - let authenticated = this.authenticated.clone(); - - let s = async_stream::stream! { - let (sender, mut receiver) = mpsc::channel(1); - let builder = StreamResponseBuilder { - request_id, - sender, - current: None, - current_size: 0, - }; - let mut fut = conn.execute_program(pgm, authenticated, builder, None); - loop { - tokio::select! { - res = &mut fut => { - // drain the receiver - while let Ok(msg) = receiver.try_recv() { - yield msg; - } - - if let Err(e) = res { - yield ExecResp { - request_id, - response: Some(exec_resp::Response::Error(e.into())) - } - } - break - } - msg = receiver.recv() => { - if let Some(msg) = msg { - yield msg; - } - } - } - } - }; - *this.state = State::Execute(Box::pin(s)); - } - Some(Request::Describe(_)) => todo!(), - None => { - *this.state = State::Fused; - return Poll::Ready(Some(Err(Status::new( - Code::InvalidArgument, - "invalid ExecReq: missing request", - )))); - } - } - } - None => { - *this.state = State::Fused; - return Poll::Ready(None) - } - } - } - - match this.state { - State::Idle => Poll::Pending, - State::Fused => Poll::Ready(None), - State::Execute(stream) => { - let resp = ready!(stream.as_mut().poll_next(cx)); - match resp { - Some(resp) => Poll::Ready(Some(Ok(resp))), - None => { - // finished processing this query. Wake up immediately to prepare for the - // next - *this.state = State::Idle; - cx.waker().wake_by_ref(); - Poll::Pending - } - } - } - } - } -} From e0edb2169e061316da9aefcc30cc06554ed77008 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 7 Oct 2023 11:20:09 +0200 Subject: [PATCH 13/17] handle replica proxy stream error --- sqld/src/rpc/replica_proxy.rs | 16 ++++++++++++++-- sqld/tests/cluster/mod.rs | 2 -- sqld/tests/namespaces/mod.rs | 4 +++- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/sqld/src/rpc/replica_proxy.rs b/sqld/src/rpc/replica_proxy.rs index a5ee1602..1792caf0 100644 --- a/sqld/src/rpc/replica_proxy.rs +++ b/sqld/src/rpc/replica_proxy.rs @@ -39,8 +39,20 @@ impl Proxy for ReplicaProxyService { &self, req: tonic::Request>, ) -> Result, tonic::Status> { - let (meta, ext, stream) = req.into_parts(); - let mut req = tonic::Request::from_parts(meta, ext, stream.map(|r| r.unwrap())); // TODO: handle mapping error + let (meta, ext, mut stream) = req.into_parts(); + let stream = async_stream::stream! { + while let Some(it) = stream.next().await { + match it { + Ok(it) => yield it, + Err(e) => { + // close the stream on error + tracing::error!("error proxying stream request: {e}"); + break + }, + } + } + }; + let mut req = tonic::Request::from_parts(meta, ext, stream); self.do_auth(&mut req)?; let mut client = self.client.clone(); client.stream_exec(req).await diff --git a/sqld/tests/cluster/mod.rs b/sqld/tests/cluster/mod.rs index 0b83c171..aa227c36 100644 --- a/sqld/tests/cluster/mod.rs +++ b/sqld/tests/cluster/mod.rs @@ -2,9 +2,7 @@ use super::common; -use insta::assert_snapshot; use libsql::{Database, Value}; -use serde_json::json; use sqld::config::{AdminApiConfig, RpcClientConfig, RpcServerConfig, UserApiConfig}; use tempfile::tempdir; use tokio::{task::JoinSet, time::Duration}; diff --git a/sqld/tests/namespaces/mod.rs b/sqld/tests/namespaces/mod.rs index 22f0d785..a0b90c96 100644 --- a/sqld/tests/namespaces/mod.rs +++ b/sqld/tests/namespaces/mod.rs @@ -4,6 +4,7 @@ use std::path::PathBuf; use crate::common::http::Client; use crate::common::net::{init_tracing, TestServer, TurmoilAcceptor, TurmoilConnector}; +use insta::assert_snapshot; use libsql::{Database, Value}; use serde_json::json; use sqld::config::{AdminApiConfig, RpcServerConfig, UserApiConfig}; @@ -44,7 +45,8 @@ fn make_primary(sim: &mut Sim, path: PathBuf) { #[test] fn create_namespace() { let mut sim = Builder::new().build(); - make_cluster(&mut sim, 0, false); + let tmp = tempdir().unwrap(); + make_primary(&mut sim, tmp.path().into()); sim.client("client", async { let db = From df6efbfd8e4bb97b91dbe1e160dc45b5c45c5dc9 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sun, 8 Oct 2023 10:15:17 +0200 Subject: [PATCH 14/17] add streaming_exec tests --- sqld/src/connection/libsql.rs | 25 +- sqld/src/connection/mod.rs | 2 +- sqld/src/connection/write_proxy.rs | 127 +++---- sqld/src/rpc/mod.rs | 2 +- sqld/src/rpc/proxy.rs | 4 +- ..._streaming_exec__test__interupt_query.snap | 255 ++++++++++++++ ...streaming_exec__test__invalid_request.snap | 5 + ...ming_exec__test__perform_query_simple.snap | 72 ++++ ...ec__test__single_query_split_response.snap | 255 ++++++++++++++ sqld/src/rpc/streaming_exec.rs | 317 +++++++++++++++++- .../tests__namespaces__create_namespace.snap | 5 + 11 files changed, 968 insertions(+), 101 deletions(-) create mode 100644 sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__interupt_query.snap create mode 100644 sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__invalid_request.snap create mode 100644 sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_query_simple.snap create mode 100644 sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__single_query_split_response.snap create mode 100644 sqld/tests/namespaces/snapshots/tests__namespaces__create_namespace.snap diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs index 9ce14cd3..6dd8bb34 100644 --- a/sqld/src/connection/libsql.rs +++ b/sqld/src/connection/libsql.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use parking_lot::{Mutex, RwLock}; use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus, TransactionState}; -use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalMethodsHook}; +use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalMethodsHook, }; use tokio::sync::{watch, Notify}; use tokio::time::{Duration, Instant}; @@ -237,6 +237,29 @@ where } } +#[cfg(test)] +impl LibSqlConnection { + pub fn new_test(path: &Path) -> Self { + let (_snd, rcv) = watch::channel(None); + let conn = Connection::new( + path, + Arc::new([]), + &crate::libsql_bindings::wal_hook::TRANSPARENT_METHODS, + (), + Default::default(), + DatabaseConfigStore::new_test().into(), + QueryBuilderConfig::default(), + rcv, + Default::default(), + ) + .unwrap(); + + Self { + inner: Arc::new(Mutex::new(conn)), + } + } +} + struct Connection { conn: sqld_libsql_bindings::Connection, stats: Arc, diff --git a/sqld/src/connection/mod.rs b/sqld/src/connection/mod.rs index 85b2374c..98e5ffc3 100644 --- a/sqld/src/connection/mod.rs +++ b/sqld/src/connection/mod.rs @@ -353,7 +353,7 @@ impl Connection for TrackedConnection { } #[cfg(test)] -mod test { +pub mod test { use super::*; #[derive(Debug)] diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index 4cda299f..c5b1d99d 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use futures_core::future::BoxFuture; use parking_lot::Mutex as PMutex; -use rusqlite::types::ValueRef; use sqld_libsql_bindings::wal_hook::{TransparentMethods, TRANSPARENT_METHODS}; use tokio::sync::{mpsc, watch, Mutex}; use tokio_stream::StreamExt; @@ -17,15 +16,10 @@ use crate::connection::program::{DescribeCol, DescribeParam}; use crate::error::Error; use crate::namespace::NamespaceName; use crate::query_analysis::TxnStatus; -use crate::query_result_builder::{Column, QueryBuilderConfig, QueryResultBuilder}; +use crate::query_result_builder::{QueryBuilderConfig, QueryResultBuilder}; use crate::replication::FrameNo; use crate::rpc::proxy::rpc::proxy_client::ProxyClient; -use crate::rpc::proxy::rpc::resp_step::Step; -use crate::rpc::proxy::rpc::row_value::Value; -use crate::rpc::proxy::rpc::{ - self, exec_req, exec_resp, AddRowValue, ColsDescription, DisconnectMessage, ExecReq, ExecResp, - Finish, FinishStep, RowValue, StepError, -}; +use crate::rpc::proxy::rpc::{self, exec_req, exec_resp, DisconnectMessage, ExecReq, ExecResp}; use crate::rpc::NAMESPACE_METADATA_KEY; use crate::stats::Stats; use crate::{Result, DEFAULT_AUTO_CHECKPOINT}; @@ -297,7 +291,7 @@ impl RemoteConnection { Ok(resp) => { // there was an interuption, and we moved to the next query if resp.request_id > request_id { - return Err(Error::PrimaryStreamInterupted) + return Err(Error::PrimaryStreamInterupted); } // we can ignore response for previously interupted requests @@ -327,58 +321,20 @@ impl RemoteConnection { let mut txn_status = TxnStatus::Invalid; let mut new_frame_no = None; let builder_config = self.builder_config; - let cb = |response: exec_resp::Response| { - match response { - exec_resp::Response::ProgramResp(resp) => { - for step in resp.steps { - let Some(step) = step.step else { return Err(Error::PrimaryStreamMisuse) }; - match step { - Step::Init(_) => builder.init(&builder_config)?, - Step::BeginStep(_) => builder.begin_step()?, - Step::FinishStep(FinishStep { - affected_row_count, - last_insert_rowid, - }) => builder.finish_step(affected_row_count, last_insert_rowid)?, - Step::StepError(StepError { error: Some(err) }) => builder - .step_error(crate::error::Error::RpcQueryError(err))?, - Step::ColsDescription(ColsDescription { columns }) => { - let cols = columns.iter().map(|c| Column { - name: &c.name, - decl_ty: c.decltype.as_deref(), - }); - builder.cols_description(cols)? - } - Step::BeginRows(_) => builder.begin_rows()?, - Step::BeginRow(_) => builder.begin_row()?, - Step::AddRowValue(AddRowValue { - val: Some(RowValue { value: Some(val) }), - }) => { - let val = match &val { - Value::Text(s) => ValueRef::Text(s.as_bytes()), - Value::Integer(i) => ValueRef::Integer(*i), - Value::Real(x) => ValueRef::Real(*x), - Value::Blob(b) => ValueRef::Blob(b.as_slice()), - Value::Null(_) => ValueRef::Null, - }; - builder.add_row_value(val)?; - } - Step::FinishRow(_) => builder.finish_row()?, - Step::FinishRows(_) => builder.finish_rows()?, - Step::Finish(f @ Finish { last_frame_no, .. }) => { - txn_status = TxnStatus::from(f.state()); - new_frame_no = last_frame_no; - builder.finish(last_frame_no, txn_status)?; - return Ok(false); - } - _ => return Err(Error::PrimaryStreamMisuse), - } - } - } - exec_resp::Response::DescribeResp(_) => return Err(Error::PrimaryStreamMisuse), - exec_resp::Response::Error(e) => return Err(Error::RpcQueryError(e)), + let cb = |response: exec_resp::Response| match response { + exec_resp::Response::ProgramResp(resp) => { + crate::rpc::streaming_exec::apply_program_resp_to_builder( + &builder_config, + &mut builder, + resp, + |last_frame_no, status| { + txn_status = status; + new_frame_no = last_frame_no; + }, + ) } - - Ok(true) + exec_resp::Response::DescribeResp(_) => Err(Error::PrimaryStreamMisuse), + exec_resp::Response::Error(e) => Err(Error::RpcQueryError(e)), }; self.make_request( @@ -395,32 +351,30 @@ impl RemoteConnection { #[allow(dead_code)] // reference implementation async fn describe(&mut self, stmt: String) -> crate::Result { let mut out = None; - let cb = |response: exec_resp::Response| { - match response { - exec_resp::Response::DescribeResp(resp) => { - out = Some(DescribeResponse { - params: resp - .params - .into_iter() - .map(|p| DescribeParam { name: p.name }) - .collect(), - cols: resp - .cols - .into_iter() - .map(|c| DescribeCol { - name: c.name, - decltype: c.decltype, - }) - .collect(), - is_explain: resp.is_explain, - is_readonly: resp.is_readonly, - }); - - Ok(false) - } - exec_resp::Response::Error(e) => Err(Error::RpcQueryError(e)), - exec_resp::Response::ProgramResp(_) => Err(Error::PrimaryStreamMisuse), + let cb = |response: exec_resp::Response| match response { + exec_resp::Response::DescribeResp(resp) => { + out = Some(DescribeResponse { + params: resp + .params + .into_iter() + .map(|p| DescribeParam { name: p.name }) + .collect(), + cols: resp + .cols + .into_iter() + .map(|c| DescribeCol { + name: c.name, + decltype: c.decltype, + }) + .collect(), + is_explain: resp.is_explain, + is_readonly: resp.is_readonly, + }); + + Ok(false) } + exec_resp::Response::Error(e) => Err(Error::RpcQueryError(e)), + exec_resp::Response::ProgramResp(_) => Err(Error::PrimaryStreamMisuse), }; self.make_request( @@ -520,10 +474,11 @@ pub mod test { use arbitrary::{Arbitrary, Unstructured}; use bytes::Bytes; use rand::Fill; + use rusqlite::types::ValueRef; use super::*; use crate::{ - query_result_builder::{test::test_driver, QueryResultBuilderError}, + query_result_builder::{test::test_driver, QueryResultBuilderError, Column}, rpc::proxy::rpc::{query_result::RowResult, ExecuteResults}, }; diff --git a/sqld/src/rpc/mod.rs b/sqld/src/rpc/mod.rs index 3fda4f4d..5dc51c33 100644 --- a/sqld/src/rpc/mod.rs +++ b/sqld/src/rpc/mod.rs @@ -19,7 +19,7 @@ pub mod proxy; pub mod replica_proxy; pub mod replication_log; pub mod replication_log_proxy; -mod streaming_exec; +pub mod streaming_exec; /// A tonic error code to signify that a namespace doesn't exist. pub const NAMESPACE_DOESNT_EXIST: &str = "NAMESPACE_DOESNT_EXIST"; diff --git a/sqld/src/rpc/proxy.rs b/sqld/src/rpc/proxy.rs index 0a51a0b7..490d3080 100644 --- a/sqld/src/rpc/proxy.rs +++ b/sqld/src/rpc/proxy.rs @@ -23,7 +23,7 @@ use self::rpc::proxy_server::Proxy; use self::rpc::query_result::RowResult; use self::rpc::{ describe_result, Ack, DescribeRequest, DescribeResult, Description, DisconnectMessage, ExecReq, - ExecuteResults, QueryResult, ResultRows, Row, ExecResp, + ExecResp, ExecuteResults, QueryResult, ResultRows, Row, }; use super::NAMESPACE_DOESNT_EXIST; @@ -475,7 +475,7 @@ impl Proxy for ProxyService { &self, req: tonic::Request>, ) -> Result, tonic::Status> { - let auth= if let Some(auth) = &self.auth { + let auth = if let Some(auth) = &self.auth { auth.authenticate_grpc(&req, self.disable_namespaces)? } else { Authenticated::from_proxy_grpc_request(&req, self.disable_namespaces)? diff --git a/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__interupt_query.snap b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__interupt_query.snap new file mode 100644 index 00000000..e957ad95 --- /dev/null +++ b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__interupt_query.snap @@ -0,0 +1,255 @@ +--- +source: sqld/src/rpc/streaming_exec.rs +expression: builder.into_ret() +--- +[ + Ok( + [ + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + ], + ), +] diff --git a/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__invalid_request.snap b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__invalid_request.snap new file mode 100644 index 00000000..d3090106 --- /dev/null +++ b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__invalid_request.snap @@ -0,0 +1,5 @@ +--- +source: sqld/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap_err().to_string() +--- +status: InvalidArgument, message: "invalid request", details: [], metadata: MetadataMap { headers: {} } diff --git a/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_query_simple.snap b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_query_simple.snap new file mode 100644 index 00000000..25b566a0 --- /dev/null +++ b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_query_simple.snap @@ -0,0 +1,72 @@ +--- +source: sqld/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap() +--- +ExecResp { + request_id: 0, + response: Some( + ProgramResp( + ProgramResp { + steps: [ + RespStep { + step: Some( + Init( + Init, + ), + ), + }, + RespStep { + step: Some( + BeginStep( + BeginStep, + ), + ), + }, + RespStep { + step: Some( + ColsDescription( + ColsDescription { + columns: [], + }, + ), + ), + }, + RespStep { + step: Some( + BeginRows( + BeginRows, + ), + ), + }, + RespStep { + step: Some( + FinishRows( + FinishRows, + ), + ), + }, + RespStep { + step: Some( + FinishStep( + FinishStep { + affected_row_count: 0, + last_insert_rowid: None, + }, + ), + ), + }, + RespStep { + step: Some( + Finish( + Finish { + last_frame_no: None, + state: Init, + }, + ), + ), + }, + ], + }, + ), + ), +} diff --git a/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__single_query_split_response.snap b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__single_query_split_response.snap new file mode 100644 index 00000000..e957ad95 --- /dev/null +++ b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__single_query_split_response.snap @@ -0,0 +1,255 @@ +--- +source: sqld/src/rpc/streaming_exec.rs +expression: builder.into_ret() +--- +[ + Ok( + [ + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + ], + ), +] diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs index b36b683d..30ecd9e3 100644 --- a/sqld/src/rpc/streaming_exec.rs +++ b/sqld/src/rpc/streaming_exec.rs @@ -11,7 +11,7 @@ use tonic::{Code, Status}; use crate::auth::Authenticated; use crate::connection::Connection; -use crate::database::PrimaryConnection; +use crate::error::Error; use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, @@ -21,22 +21,48 @@ use crate::rpc::proxy::rpc::exec_req::Request; use crate::rpc::proxy::rpc::exec_resp::{self, Response}; use super::proxy::rpc::resp_step::Step; -use super::proxy::rpc::{self, ExecReq, ExecResp, ProgramResp, RespStep, RowValue}; +use super::proxy::rpc::row_value::Value; +use super::proxy::rpc::{ + self, AddRowValue, ColsDescription, ExecReq, ExecResp, Finish, FinishStep, ProgramResp, + RespStep, RowValue, StepError, +}; + +const MAX_RESPONSE_SIZE: usize = bytesize::ByteSize::mb(1).as_u64() as usize; -pub fn make_proxy_stream(conn: PrimaryConnection, auth: Authenticated, request_stream: S) -> impl Stream> +pub fn make_proxy_stream( + conn: C, + auth: Authenticated, + request_stream: S, +) -> impl Stream> where S: Stream>, + C: Connection, +{ + make_proxy_stream_inner(conn, auth, request_stream, MAX_RESPONSE_SIZE) +} + +fn make_proxy_stream_inner( + conn: C, + auth: Authenticated, + request_stream: S, + max_program_resp_size: usize, +) -> impl Stream> +where + S: Stream>, + C: Connection, { async_stream::stream! { let mut current_request_fut = None; let (snd, mut recv) = mpsc::channel(1); let conn = Arc::new(conn); + pin!(request_stream); loop { tokio::select! { biased; - Some(maybe_req) = request_stream.next() => { + maybe_req = request_stream.next() => { + let Some(maybe_req) = maybe_req else { break }; match maybe_req { Err(e) => { tracing::error!("stream error: {e}"); @@ -61,6 +87,7 @@ where sender, current: None, current_size: 0, + max_program_resp_size, }; let ret = conn.execute_program(pgm, auth, builder, None).await; @@ -97,6 +124,7 @@ struct StreamResponseBuilder { sender: mpsc::Sender, current: Option, current_size: usize, + max_program_resp_size: usize, } impl StreamResponseBuilder { @@ -106,15 +134,13 @@ impl StreamResponseBuilder { } fn push(&mut self, step: Step) -> Result<(), QueryResultBuilderError> { - const MAX_RESPONSE_SIZE: usize = bytesize::ByteSize::mb(1).as_u64() as usize; - let current = self.current(); let step = RespStep { step: Some(step) }; let size = step.encoded_len(); current.steps.push(step); self.current_size += size; - if self.current_size >= MAX_RESPONSE_SIZE { + if self.current_size >= self.max_program_resp_size { self.flush()?; } @@ -137,6 +163,63 @@ impl StreamResponseBuilder { } } +/// Apply the response to the the builder, and return whether the builder need more steps +pub fn apply_program_resp_to_builder( + config: &QueryBuilderConfig, + builder: &mut B, + resp: ProgramResp, + mut on_finish: impl FnMut(Option, TxnStatus), +) -> crate::Result { + for step in resp.steps { + let Some(step) = step.step else { + return Err(Error::PrimaryStreamMisuse); + }; + match step { + Step::Init(_) => builder.init(config)?, + Step::BeginStep(_) => builder.begin_step()?, + Step::FinishStep(FinishStep { + affected_row_count, + last_insert_rowid, + }) => builder.finish_step(affected_row_count, last_insert_rowid)?, + Step::StepError(StepError { error: Some(err) }) => { + builder.step_error(crate::error::Error::RpcQueryError(err))? + } + Step::ColsDescription(ColsDescription { columns }) => { + let cols = columns.iter().map(|c| Column { + name: &c.name, + decl_ty: c.decltype.as_deref(), + }); + builder.cols_description(cols)? + } + Step::BeginRows(_) => builder.begin_rows()?, + Step::BeginRow(_) => builder.begin_row()?, + Step::AddRowValue(AddRowValue { + val: Some(RowValue { value: Some(val) }), + }) => { + let val = match &val { + Value::Text(s) => ValueRef::Text(s.as_bytes()), + Value::Integer(i) => ValueRef::Integer(*i), + Value::Real(x) => ValueRef::Real(*x), + Value::Blob(b) => ValueRef::Blob(b.as_slice()), + Value::Null(_) => ValueRef::Null, + }; + builder.add_row_value(val)?; + } + Step::FinishRow(_) => builder.finish_row()?, + Step::FinishRows(_) => builder.finish_rows()?, + Step::Finish(f @ Finish { last_frame_no, .. }) => { + let txn_status = TxnStatus::from(f.state()); + on_finish(last_frame_no, txn_status); + builder.finish(last_frame_no, txn_status)?; + return Ok(false); + } + _ => return Err(Error::PrimaryStreamMisuse), + } + } + + Ok(true) +} + impl QueryResultBuilder for StreamResponseBuilder { type Ret = (); @@ -226,13 +309,11 @@ impl QueryResultBuilder for StreamResponseBuilder { Ok(()) } - fn into_ret(self) -> Self::Ret { } + fn into_ret(self) -> Self::Ret {} } impl From> for RowValue { fn from(value: ValueRef<'_>) -> Self { - use rpc::row_value::Value; - let value = Some(match value { ValueRef::Null => Value::Null(true), ValueRef::Integer(i) => Value::Integer(i), @@ -244,3 +325,219 @@ impl From> for RowValue { RowValue { value } } } + +#[cfg(test)] +mod test { + use insta::{assert_debug_snapshot, assert_snapshot}; + use tempfile::tempdir; + use tokio_stream::wrappers::ReceiverStream; + + use crate::auth::{Authorized, Permission}; + use crate::connection::libsql::LibSqlConnection; + use crate::connection::program::Program; + use crate::query_result_builder::test::TestBuilder; + use crate::rpc::proxy::rpc::StreamProgramReq; + + use super::*; + + fn exec_req_stmt(s: &str, id: u32) -> ExecReq { + ExecReq { + request_id: id, + request: Some(Request::Execute(StreamProgramReq { + pgm: Some(Program::seq(&[s]).into()), + })), + } + } + + #[tokio::test] + async fn invalid_request() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let stream = make_proxy_stream(conn, Authenticated::Anonymous, ReceiverStream::new(rcv)); + pin!(stream); + + let req = ExecReq { + request_id: 0, + request: None, + }; + + snd.send(Ok(req)).await.unwrap(); + + assert_snapshot!(stream.next().await.unwrap().unwrap_err().to_string()); + } + + #[tokio::test] + async fn request_stream_dropped() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + drop(snd); + + assert!(stream.next().await.is_none()); + } + + #[tokio::test] + async fn perform_query_simple() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + let req = exec_req_stmt("create table test (foo)", 0); + + snd.send(Ok(req)).await.unwrap(); + + assert_debug_snapshot!(stream.next().await.unwrap().unwrap()); + } + + #[tokio::test] + async fn single_query_split_response() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + // limit the size of the response for force a split + let stream = make_proxy_stream_inner(conn, auth, ReceiverStream::new(rcv), 500); + + pin!(stream); + + let req = exec_req_stmt("create table test (foo)", 0); + snd.send(Ok(req)).await.unwrap(); + let resp = stream.next().await.unwrap().unwrap(); + assert_eq!(resp.request_id, 0); + for i in 1..50 { + let req = exec_req_stmt( + r#"insert into test values ("something moderately long")"#, + i, + ); + snd.send(Ok(req)).await.unwrap(); + let resp = stream.next().await.unwrap().unwrap(); + assert_eq!(resp.request_id, i); + } + + let req = exec_req_stmt("select * from test", 100); + snd.send(Ok(req)).await.unwrap(); + + let mut num_resp = 0; + let mut builder = TestBuilder::default(); + loop { + let Response::ProgramResp(resp) = + stream.next().await.unwrap().unwrap().response.unwrap() + else { + panic!() + }; + if !apply_program_resp_to_builder( + &QueryBuilderConfig::default(), + &mut builder, + resp, + |_, _| (), + ) + .unwrap() + { + break; + } + num_resp += 1; + } + + assert_eq!(num_resp, 3); + assert_debug_snapshot!(builder.into_ret()); + } + + #[tokio::test] + async fn interupt_query() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + // limit the size of the response for force a split + let stream = make_proxy_stream_inner(conn, auth, ReceiverStream::new(rcv), 500); + + pin!(stream); + + let req = exec_req_stmt("create table test (foo)", 0); + snd.send(Ok(req)).await.unwrap(); + let resp = stream.next().await.unwrap().unwrap(); + assert_eq!(resp.request_id, 0); + for i in 1..50 { + let req = exec_req_stmt( + r#"insert into test values ("something moderately long")"#, + i, + ); + snd.send(Ok(req)).await.unwrap(); + let resp = stream.next().await.unwrap().unwrap(); + assert_eq!(resp.request_id, i); + } + + let req = exec_req_stmt("select * from test", 100); + snd.send(Ok(req)).await.unwrap(); + + let mut num_resp = 0; + let mut builder = TestBuilder::default(); + loop { + let Response::ProgramResp(resp) = + stream.next().await.unwrap().unwrap().response.unwrap() + else { + panic!() + }; + if !apply_program_resp_to_builder( + &QueryBuilderConfig::default(), + &mut builder, + resp, + |_, _| (), + ) + .unwrap() + { + break; + } + num_resp += 1; + } + + assert_eq!(num_resp, 3); + assert_debug_snapshot!(builder.into_ret()); + } + + #[tokio::test] + async fn request_interupted() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(2); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + // limit the size of the response for force a split + let stream = make_proxy_stream_inner(conn, auth, ReceiverStream::new(rcv), 500); + + pin!(stream); + + // request 0 should be dropped, and request 1 should be processed instead + let req1 = exec_req_stmt("create table test (foo)", 0); + let req2 = exec_req_stmt("create table test (foo)", 1); + snd.send(Ok(req1)).await.unwrap(); + snd.send(Ok(req2)).await.unwrap(); + + let resp = stream.next().await.unwrap().unwrap(); + assert_eq!(resp.request_id, 1); + } +} diff --git a/sqld/tests/namespaces/snapshots/tests__namespaces__create_namespace.snap b/sqld/tests/namespaces/snapshots/tests__namespaces__create_namespace.snap new file mode 100644 index 00000000..3f190961 --- /dev/null +++ b/sqld/tests/namespaces/snapshots/tests__namespaces__create_namespace.snap @@ -0,0 +1,5 @@ +--- +source: sqld/tests/namespaces/mod.rs +expression: e.to_string() +--- +Hrana: `api error: `{"error":"Namespace `foo` doesn't exist"}`` From 8f4bade697c43a37547cdee3fd6a4844e8ba203c Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sun, 8 Oct 2023 10:45:25 +0200 Subject: [PATCH 15/17] stream exec support describe --- sqld/src/connection/libsql.rs | 2 +- sqld/src/connection/write_proxy.rs | 2 +- ...__rpc__streaming_exec__test__describe.snap | 28 ++++++++ sqld/src/rpc/streaming_exec.rs | 69 ++++++++++++++++--- 4 files changed, 91 insertions(+), 10 deletions(-) create mode 100644 sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__describe.snap diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs index 6dd8bb34..392e9ee0 100644 --- a/sqld/src/connection/libsql.rs +++ b/sqld/src/connection/libsql.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use parking_lot::{Mutex, RwLock}; use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus, TransactionState}; -use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalMethodsHook, }; +use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalMethodsHook}; use tokio::sync::{watch, Notify}; use tokio::time::{Duration, Instant}; diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index c5b1d99d..70fbb90e 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -478,7 +478,7 @@ pub mod test { use super::*; use crate::{ - query_result_builder::{test::test_driver, QueryResultBuilderError, Column}, + query_result_builder::{test::test_driver, Column, QueryResultBuilderError}, rpc::proxy::rpc::{query_result::RowResult, ExecuteResults}, }; diff --git a/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__describe.snap b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__describe.snap new file mode 100644 index 00000000..06794184 --- /dev/null +++ b/sqld/src/rpc/snapshots/sqld__rpc__streaming_exec__test__describe.snap @@ -0,0 +1,28 @@ +--- +source: sqld/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap() +--- +ExecResp { + request_id: 0, + response: Some( + DescribeResp( + DescribeResp { + params: [ + DescribeParam { + name: Some( + "$hello", + ), + }, + ], + cols: [ + DescribeCol { + name: "$hello", + decltype: None, + }, + ], + is_explain: false, + is_readonly: true, + }, + ), + ), +} diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs index 30ecd9e3..6c53582c 100644 --- a/sqld/src/rpc/streaming_exec.rs +++ b/sqld/src/rpc/streaming_exec.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use futures_core::future::BoxFuture; use futures_core::Stream; use futures_option::OptionExt; use prost::Message; @@ -11,6 +12,7 @@ use tonic::{Code, Status}; use crate::auth::Authenticated; use crate::connection::Connection; +use crate::connection::program::Program; use crate::error::Error; use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ @@ -19,6 +21,7 @@ use crate::query_result_builder::{ use crate::replication::FrameNo; use crate::rpc::proxy::rpc::exec_req::Request; use crate::rpc::proxy::rpc::exec_resp::{self, Response}; +use crate::rpc::proxy::rpc::{DescribeResp, StreamDescribeReq, DescribeCol, DescribeParam}; use super::proxy::rpc::resp_step::Step; use super::proxy::rpc::row_value::Value; @@ -52,7 +55,7 @@ where C: Connection, { async_stream::stream! { - let mut current_request_fut = None; + let mut current_request_fut: Option, u32)>> = None; let (snd, mut recv) = mpsc::channel(1); let conn = Arc::new(conn); @@ -90,13 +93,41 @@ where max_program_resp_size, }; - let ret = conn.execute_program(pgm, auth, builder, None).await; + let ret = conn.execute_program(pgm, auth, builder, None).await.map(|_| ()); (ret, request_id) }; current_request_fut.replace(Box::pin(fut)); } - Some(Request::Describe(_)) => todo!(), + Some(Request::Describe(StreamDescribeReq { stmt })) => { + let auth = auth.clone(); + let sender = snd.clone(); + let conn = conn.clone(); + let fut = async move { + let do_describe = || async move { + let ret = conn.describe(stmt, auth, None).await??; + Ok(DescribeResp { + cols: ret.cols.into_iter().map(|c| DescribeCol { name: c.name, decltype: c.decltype }).collect(), + params: ret.params.into_iter().map(|p| DescribeParam { name: p.name }).collect(), + is_explain: ret.is_explain, + is_readonly: ret.is_readonly + }) + }; + + let ret: crate::Result<()> = match do_describe().await { + Ok(resp) => { + let _ = sender.send(ExecResp { request_id, response: Some(Response::DescribeResp(resp)) }).await; + Ok(()) + } + Err(e) => Err(e), + }; + + (ret, request_id) + }; + + current_request_fut.replace(Box::pin(fut)); + + }, None => { yield Err(Status::new(Code::InvalidArgument, "invalid request")); break @@ -414,7 +445,7 @@ mod test { namespace: None, permission: Permission::FullAccess, }); - // limit the size of the response for force a split + // limit the size of the response to force a split let stream = make_proxy_stream_inner(conn, auth, ReceiverStream::new(rcv), 500); pin!(stream); @@ -470,8 +501,7 @@ mod test { namespace: None, permission: Permission::FullAccess, }); - // limit the size of the response for force a split - let stream = make_proxy_stream_inner(conn, auth, ReceiverStream::new(rcv), 500); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); pin!(stream); @@ -526,8 +556,7 @@ mod test { namespace: None, permission: Permission::FullAccess, }); - // limit the size of the response for force a split - let stream = make_proxy_stream_inner(conn, auth, ReceiverStream::new(rcv), 500); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); pin!(stream); @@ -540,4 +569,28 @@ mod test { let resp = stream.next().await.unwrap().unwrap(); assert_eq!(resp.request_id, 1); } + + #[tokio::test] + async fn describe() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + // request 0 should be dropped, and request 1 should be processed instead + let req = ExecReq { + request_id: 0, + request: Some(Request::Describe(StreamDescribeReq { stmt: "select $hello".into() })), + }; + + snd.send(Ok(req)).await.unwrap(); + + assert_debug_snapshot!(stream.next().await.unwrap().unwrap()); + } } From 197cbdee8c8d4ea0b05ac2c339d4546f93460f4d Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sun, 8 Oct 2023 19:16:32 +0200 Subject: [PATCH 16/17] test RemoteConnection --- sqld/src/connection/write_proxy.rs | 49 ++++++---- sqld/src/http/user/result_builder.rs | 6 +- sqld/src/query_result_builder.rs | 134 +++++++++++++++++++++++---- sqld/src/rpc/streaming_exec.rs | 40 ++++++-- 4 files changed, 183 insertions(+), 46 deletions(-) diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index 70fbb90e..9e3ff809 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use std::sync::Arc; use futures_core::future::BoxFuture; +use futures_core::Stream; use parking_lot::Mutex as PMutex; use sqld_libsql_bindings::wal_hook::{TransparentMethods, TRANSPARENT_METHODS}; use tokio::sync::{mpsc, watch, Mutex}; @@ -9,7 +10,6 @@ use tokio_stream::StreamExt; use tonic::metadata::BinaryMetadataValue; use tonic::transport::Channel; use tonic::{Request, Streaming}; -use uuid::Uuid; use crate::auth::Authenticated; use crate::connection::program::{DescribeCol, DescribeParam}; @@ -19,7 +19,7 @@ use crate::query_analysis::TxnStatus; use crate::query_result_builder::{QueryBuilderConfig, QueryResultBuilder}; use crate::replication::FrameNo; use crate::rpc::proxy::rpc::proxy_client::ProxyClient; -use crate::rpc::proxy::rpc::{self, exec_req, exec_resp, DisconnectMessage, ExecReq, ExecResp}; +use crate::rpc::proxy::rpc::{self, exec_req, exec_resp, ExecReq, ExecResp}; use crate::rpc::NAMESPACE_METADATA_KEY; use crate::stats::Stats; use crate::{Result, DEFAULT_AUTO_CHECKPOINT}; @@ -102,12 +102,11 @@ impl MakeConnection for MakeWriteProxyConn { } } -pub struct WriteProxyConnection { +pub struct WriteProxyConnection> { /// Lazily initialized read connection read_conn: LibSqlConnection, write_proxy: ProxyClient, state: Mutex, - client_id: Uuid, /// FrameNo of the last write performed by this connection on the primary. /// any subsequent read on this connection must wait for the replicator to catch up with this /// frame_no @@ -118,7 +117,7 @@ pub struct WriteProxyConnection { stats: Arc, namespace: NamespaceName, - remote_conn: Mutex>, + remote_conn: Mutex>>, } impl WriteProxyConnection { @@ -135,7 +134,6 @@ impl WriteProxyConnection { read_conn, write_proxy, state: Mutex::new(TxnStatus::Init), - client_id: Uuid::new_v4(), last_write_frame_no: Default::default(), applied_frame_no_receiver, builder_config, @@ -234,8 +232,8 @@ impl WriteProxyConnection { } } -struct RemoteConnection { - response_stream: Streaming, +struct RemoteConnection> { + response_stream: R, request_sender: mpsc::Sender, current_request_id: u32, builder_config: QueryBuilderConfig, @@ -265,7 +263,12 @@ impl RemoteConnection { builder_config, }) } +} +impl RemoteConnection +where + R: Stream> + Unpin, +{ /// Perform a request on to the remote peer, and call message_cb for every message received for /// that request. message cb should return whether to expect more message for that request. async fn make_request( @@ -458,17 +461,6 @@ impl Connection for WriteProxyConnection { } } -impl Drop for WriteProxyConnection { - fn drop(&mut self) { - // best effort attempt to disconnect - let mut remote = self.write_proxy.clone(); - let client_id = self.client_id.to_string(); - tokio::spawn(async move { - let _ = remote.disconnect(DisconnectMessage { client_id }).await; - }); - } -} - #[cfg(test)] pub mod test { use arbitrary::{Arbitrary, Unstructured}; @@ -479,7 +471,7 @@ pub mod test { use super::*; use crate::{ query_result_builder::{test::test_driver, Column, QueryResultBuilderError}, - rpc::proxy::rpc::{query_result::RowResult, ExecuteResults}, + rpc::{proxy::rpc::{query_result::RowResult, ExecuteResults}, streaming_exec::test::random_valid_program_resp}, }; /// generate an arbitraty rpc value. see build.rs for usage. @@ -555,4 +547,21 @@ pub mod test { }, ); } + + #[tokio::test] + // in this test we do a roundtrip: generate a random valid program, stream it to + // RemoteConnection, and make sure that the remote connection drives the builder with the same + // state transitions. + async fn validate_random_stream_response() { + let (response_stream, validator) = random_valid_program_resp(500, 150); + let (request_sender, _request_recver) = mpsc::channel(1); + let mut remote = RemoteConnection { + response_stream: response_stream.map(Ok), + request_sender, + current_request_id: 0, + builder_config: QueryBuilderConfig::default(), + }; + + remote.execute(Program::seq(&[]), validator).await.unwrap().0.into_ret(); + } } diff --git a/sqld/src/http/user/result_builder.rs b/sqld/src/http/user/result_builder.rs index c6b4d8a2..92aa2763 100644 --- a/sqld/src/http/user/result_builder.rs +++ b/sqld/src/http/user/result_builder.rs @@ -311,7 +311,8 @@ impl QueryResultBuilder for JsonHttpPayloadBuilder { #[cfg(test)] mod test { - use crate::query_result_builder::test::random_builder_driver; + + use crate::query_result_builder::test::{random_transition, fsm_builder_driver}; use super::*; @@ -319,7 +320,8 @@ mod test { fn test_json_builder() { for _ in 0..1000 { let builder = JsonHttpPayloadBuilder::new(); - let ret = random_builder_driver(100, builder).into_ret(); + let trace = random_transition(100); + let ret = fsm_builder_driver(&trace, builder).into_ret(); println!("{}", std::str::from_utf8(&ret).unwrap()); // we produce valid json serde_json::from_slice::>(&ret).unwrap(); diff --git a/sqld/src/query_result_builder.rs b/sqld/src/query_result_builder.rs index e7698407..506c9785 100644 --- a/sqld/src/query_result_builder.rs +++ b/sqld/src/query_result_builder.rs @@ -633,7 +633,7 @@ pub mod test { #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[repr(usize)] // do not reorder! - enum FsmState { + pub enum FsmState { Init = 0, Finish, BeginStep, @@ -702,11 +702,31 @@ pub mod test { } } - pub fn random_builder_driver(mut max_steps: usize, mut b: B) -> B { + pub fn random_transition(mut max_steps: usize) -> Vec { + let mut trace = Vec::with_capacity(max_steps); + let mut state = Init; + trace.push(state); + loop { + if max_steps > 0 { + state = state.rand_transition(false); + } else { + state = state.toward_finish() + } + + trace.push(state); + if state == FsmState::Finish { + break + } + + max_steps = max_steps.saturating_sub(1); + } + trace + } + + pub fn fsm_builder_driver(trace: &[FsmState], mut b: B) -> B { let mut rand_data = [0; 10_000]; rand_data.try_fill(&mut rand::thread_rng()).unwrap(); let mut u = Unstructured::new(&rand_data); - let mut trace = Vec::new(); #[derive(Arbitrary)] pub enum ValueRef<'a> { @@ -729,9 +749,7 @@ pub mod test { } } - let mut state = Init; - trace.push(state); - loop { + for state in trace { match state { Init => b.init(&QueryBuilderConfig::default()).unwrap(), BeginStep => b.begin_step().unwrap(), @@ -758,22 +776,106 @@ pub mod test { } BuilderError => return b, } + } - if max_steps > 0 { - state = state.rand_transition(false); - } else { - state = state.toward_finish() - } + b + } - trace.push(state); + /// A Builder that validates a given execution trace + pub struct ValidateTraceBuilder { + trace: Vec, + current: usize, + } - max_steps = max_steps.saturating_sub(1); + impl ValidateTraceBuilder { + pub fn new(trace: Vec) -> Self { + Self { trace, current: 0 } } + } - // this can be usefull to help debug the generated test case - dbg!(trace); + impl QueryResultBuilder for ValidateTraceBuilder { + type Ret = (); - b + fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::Init); + self.current += 1; + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::BeginStep); + self.current += 1; + Ok(()) + } + + fn finish_step( + &mut self, + _affected_row_count: u64, + _last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::FinishStep); + self.current += 1; + Ok(()) + } + + fn step_error(&mut self, _error: crate::error::Error) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::StepError); + self.current += 1; + Ok(()) + } + + fn cols_description<'a>( + &mut self, + _cols: impl IntoIterator>>, + ) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::ColsDescription); + self.current += 1; + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::BeginRows); + self.current += 1; + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::BeginRow); + self.current += 1; + Ok(()) + } + + fn add_row_value(&mut self, _v: ValueRef) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::AddRowValue); + self.current += 1; + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::FinishRow); + self.current += 1; + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::FinishRows); + self.current += 1; + Ok(()) + } + + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::Finish); + self.current += 1; + Ok(()) + } + + fn into_ret(self) -> Self::Ret { + assert_eq!(self.current, self.trace.len()); + } } pub struct FsmQueryBuilder { diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs index 6c53582c..0e07a8d8 100644 --- a/sqld/src/rpc/streaming_exec.rs +++ b/sqld/src/rpc/streaming_exec.rs @@ -12,7 +12,6 @@ use tonic::{Code, Status}; use crate::auth::Authenticated; use crate::connection::Connection; -use crate::connection::program::Program; use crate::error::Error; use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ @@ -21,7 +20,7 @@ use crate::query_result_builder::{ use crate::replication::FrameNo; use crate::rpc::proxy::rpc::exec_req::Request; use crate::rpc::proxy::rpc::exec_resp::{self, Response}; -use crate::rpc::proxy::rpc::{DescribeResp, StreamDescribeReq, DescribeCol, DescribeParam}; +use crate::rpc::proxy::rpc::{DescribeCol, DescribeParam, DescribeResp, StreamDescribeReq}; use super::proxy::rpc::resp_step::Step; use super::proxy::rpc::row_value::Value; @@ -106,10 +105,10 @@ where let fut = async move { let do_describe = || async move { let ret = conn.describe(stmt, auth, None).await??; - Ok(DescribeResp { + Ok(DescribeResp { cols: ret.cols.into_iter().map(|c| DescribeCol { name: c.name, decltype: c.decltype }).collect(), params: ret.params.into_iter().map(|p| DescribeParam { name: p.name }).collect(), - is_explain: ret.is_explain, + is_explain: ret.is_explain, is_readonly: ret.is_readonly }) }; @@ -358,7 +357,7 @@ impl From> for RowValue { } #[cfg(test)] -mod test { +pub mod test { use insta::{assert_debug_snapshot, assert_snapshot}; use tempfile::tempdir; use tokio_stream::wrappers::ReceiverStream; @@ -366,7 +365,7 @@ mod test { use crate::auth::{Authorized, Permission}; use crate::connection::libsql::LibSqlConnection; use crate::connection::program::Program; - use crate::query_result_builder::test::TestBuilder; + use crate::query_result_builder::test::{TestBuilder, ValidateTraceBuilder, random_transition, fsm_builder_driver}; use crate::rpc::proxy::rpc::StreamProgramReq; use super::*; @@ -584,13 +583,38 @@ mod test { pin!(stream); // request 0 should be dropped, and request 1 should be processed instead - let req = ExecReq { + let req = ExecReq { request_id: 0, - request: Some(Request::Describe(StreamDescribeReq { stmt: "select $hello".into() })), + request: Some(Request::Describe(StreamDescribeReq { + stmt: "select $hello".into(), + })), }; snd.send(Ok(req)).await.unwrap(); assert_debug_snapshot!(stream.next().await.unwrap().unwrap()); } + + /// This fuction returns a random, valid, program resp for use in other tests + pub fn random_valid_program_resp( + size: usize, + max_resp_size: usize, + ) -> (impl Stream, ValidateTraceBuilder) { + let (sender, receiver) = mpsc::channel(1); + let builder = StreamResponseBuilder { + request_id: 0, + sender, + current: None, + current_size: 0, + max_program_resp_size: max_resp_size, + }; + + let trace = random_transition(size); + tokio::task::spawn_blocking({ + let trace = trace.clone(); + move || fsm_builder_driver(&trace, builder) + }); + + (ReceiverStream::new(receiver), ValidateTraceBuilder::new(trace)) + } } From 6429ba3206ac577c7f91c226d543712cd8e9acd5 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sun, 8 Oct 2023 19:25:02 +0200 Subject: [PATCH 17/17] fmt --- sqld/src/connection/write_proxy.rs | 13 ++++-- sqld/src/http/user/result_builder.rs | 2 +- sqld/src/lib.rs | 2 + sqld/src/query_result_builder.rs | 7 ++- sqld/src/rpc/streaming_exec.rs | 64 +++------------------------- 5 files changed, 25 insertions(+), 63 deletions(-) diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index 9e3ff809..be030fef 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -153,7 +153,6 @@ impl WriteProxyConnection { F: FnOnce(&mut RemoteConnection) -> BoxFuture<'_, crate::Result>, { let mut remote_conn = self.remote_conn.lock().await; - // TODO: catch broken connection, and reset it to None. if remote_conn.is_some() { cb(remote_conn.as_mut().unwrap()).await } else { @@ -471,7 +470,10 @@ pub mod test { use super::*; use crate::{ query_result_builder::{test::test_driver, Column, QueryResultBuilderError}, - rpc::{proxy::rpc::{query_result::RowResult, ExecuteResults}, streaming_exec::test::random_valid_program_resp}, + rpc::{ + proxy::rpc::{query_result::RowResult, ExecuteResults}, + streaming_exec::test::random_valid_program_resp, + }, }; /// generate an arbitraty rpc value. see build.rs for usage. @@ -562,6 +564,11 @@ pub mod test { builder_config: QueryBuilderConfig::default(), }; - remote.execute(Program::seq(&[]), validator).await.unwrap().0.into_ret(); + remote + .execute(Program::seq(&[]), validator) + .await + .unwrap() + .0 + .into_ret(); } } diff --git a/sqld/src/http/user/result_builder.rs b/sqld/src/http/user/result_builder.rs index 92aa2763..4f56f7db 100644 --- a/sqld/src/http/user/result_builder.rs +++ b/sqld/src/http/user/result_builder.rs @@ -312,7 +312,7 @@ impl QueryResultBuilder for JsonHttpPayloadBuilder { #[cfg(test)] mod test { - use crate::query_result_builder::test::{random_transition, fsm_builder_driver}; + use crate::query_result_builder::test::{fsm_builder_driver, random_transition}; use super::*; diff --git a/sqld/src/lib.rs b/sqld/src/lib.rs index a0461bd9..ba984f16 100644 --- a/sqld/src/lib.rs +++ b/sqld/src/lib.rs @@ -500,6 +500,8 @@ where let proxy_service = ProxyService::new(namespaces.clone(), None, self.disable_namespaces); // Garbage collect proxy clients every 30 seconds + // TODO: this will no longer be necessary once client have adopted the streaming proxy + // protocol self.join_set.spawn({ let clients = proxy_service.clients(); async move { diff --git a/sqld/src/query_result_builder.rs b/sqld/src/query_result_builder.rs index 506c9785..121d1a00 100644 --- a/sqld/src/query_result_builder.rs +++ b/sqld/src/query_result_builder.rs @@ -715,7 +715,7 @@ pub mod test { trace.push(state); if state == FsmState::Finish { - break + break; } max_steps = max_steps.saturating_sub(1); @@ -818,7 +818,10 @@ pub mod test { Ok(()) } - fn step_error(&mut self, _error: crate::error::Error) -> Result<(), QueryResultBuilderError> { + fn step_error( + &mut self, + _error: crate::error::Error, + ) -> Result<(), QueryResultBuilderError> { assert_eq!(self.trace[self.current], FsmState::StepError); self.current += 1; Ok(()) diff --git a/sqld/src/rpc/streaming_exec.rs b/sqld/src/rpc/streaming_exec.rs index 0e07a8d8..b5826da3 100644 --- a/sqld/src/rpc/streaming_exec.rs +++ b/sqld/src/rpc/streaming_exec.rs @@ -365,7 +365,9 @@ pub mod test { use crate::auth::{Authorized, Permission}; use crate::connection::libsql::LibSqlConnection; use crate::connection::program::Program; - use crate::query_result_builder::test::{TestBuilder, ValidateTraceBuilder, random_transition, fsm_builder_driver}; + use crate::query_result_builder::test::{ + fsm_builder_driver, random_transition, TestBuilder, ValidateTraceBuilder, + }; use crate::rpc::proxy::rpc::StreamProgramReq; use super::*; @@ -491,61 +493,6 @@ pub mod test { assert_debug_snapshot!(builder.into_ret()); } - #[tokio::test] - async fn interupt_query() { - let tmp = tempdir().unwrap(); - let conn = LibSqlConnection::new_test(tmp.path()); - let (snd, rcv) = mpsc::channel(1); - let auth = Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }); - let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); - - pin!(stream); - - let req = exec_req_stmt("create table test (foo)", 0); - snd.send(Ok(req)).await.unwrap(); - let resp = stream.next().await.unwrap().unwrap(); - assert_eq!(resp.request_id, 0); - for i in 1..50 { - let req = exec_req_stmt( - r#"insert into test values ("something moderately long")"#, - i, - ); - snd.send(Ok(req)).await.unwrap(); - let resp = stream.next().await.unwrap().unwrap(); - assert_eq!(resp.request_id, i); - } - - let req = exec_req_stmt("select * from test", 100); - snd.send(Ok(req)).await.unwrap(); - - let mut num_resp = 0; - let mut builder = TestBuilder::default(); - loop { - let Response::ProgramResp(resp) = - stream.next().await.unwrap().unwrap().response.unwrap() - else { - panic!() - }; - if !apply_program_resp_to_builder( - &QueryBuilderConfig::default(), - &mut builder, - resp, - |_, _| (), - ) - .unwrap() - { - break; - } - num_resp += 1; - } - - assert_eq!(num_resp, 3); - assert_debug_snapshot!(builder.into_ret()); - } - #[tokio::test] async fn request_interupted() { let tmp = tempdir().unwrap(); @@ -615,6 +562,9 @@ pub mod test { move || fsm_builder_driver(&trace, builder) }); - (ReceiverStream::new(receiver), ValidateTraceBuilder::new(trace)) + ( + ReceiverStream::new(receiver), + ValidateTraceBuilder::new(trace), + ) } }