@@ -8,7 +8,9 @@ use std::sync::Arc;
88
99use async_trait:: async_trait;
1010use chrono:: { DateTime , Utc } ;
11- use tracing:: info;
11+ use devolutions_gateway_task:: { ShutdownSignal , Task } ;
12+ use devolutions_pedm_shared:: policy:: ElevationResult ;
13+ use tracing:: { info, warn} ;
1214
1315mod err;
1416mod util;
@@ -209,4 +211,110 @@ pub(crate) trait Database: Send + Sync {
209211 async fn update_accounts ( & self , diff : & AccountsDiff ) -> Result < ( ) , DbError > ;
210212
211213 async fn insert_elevate_tmp_request ( & self , req_id : i32 , seconds : i32 ) -> Result < ( ) , DbError > ;
214+
215+ async fn insert_jit_elevation_result ( & self , result : & ElevationResult ) -> Result < ( ) , DbError > ;
216+ }
217+
218+ // Bridge for DB operations from synchronous functions.
219+ // This may or may not be a temporary workaround.
220+
221+ pub ( crate ) struct DbHandleError < T > {
222+ pub ( crate ) db_error : Option < DbError > ,
223+ pub ( crate ) value : T ,
224+ }
225+
226+ #[ derive( Clone ) ]
227+ pub ( crate ) struct DbHandle {
228+ tx : tokio:: sync:: mpsc:: Sender < DbRequest > ,
229+ }
230+
231+ impl DbHandle {
232+ pub ( crate ) fn insert_jit_elevation_result (
233+ & self ,
234+ result : ElevationResult ,
235+ ) -> Result < ( ) , DbHandleError < ElevationResult > > {
236+ let ( tx, rx) = tokio:: sync:: oneshot:: channel ( ) ;
237+
238+ match self
239+ . tx
240+ . blocking_send ( DbRequest :: InsertJitElevationResult { result, tx } )
241+ {
242+ Ok ( ( ) ) => match rx. blocking_recv ( ) {
243+ Ok ( db_result) => db_result,
244+ Err ( _) => {
245+ warn ! ( "Did not receive the response from the async bridge task" ) ;
246+ Ok ( ( ) )
247+ }
248+ } ,
249+ Err ( error) => {
250+ let DbRequest :: InsertJitElevationResult { result, .. } = error. 0 else {
251+ unreachable ! ( )
252+ } ;
253+
254+ Err ( DbHandleError {
255+ db_error : None ,
256+ value : result,
257+ } )
258+ }
259+ }
260+ }
261+ }
262+
263+ pub ( crate ) enum DbRequest {
264+ InsertJitElevationResult {
265+ result : ElevationResult ,
266+ tx : tokio:: sync:: oneshot:: Sender < Result < ( ) , DbHandleError < ElevationResult > > > ,
267+ } ,
268+ }
269+
270+ pub ( crate ) struct DbAsyncBridgeTask {
271+ db : Db ,
272+ rx : tokio:: sync:: mpsc:: Receiver < DbRequest > ,
273+ }
274+
275+ impl DbAsyncBridgeTask {
276+ pub fn new ( db : Db ) -> ( DbHandle , Self ) {
277+ let ( tx, rx) = tokio:: sync:: mpsc:: channel ( 8 ) ;
278+ ( DbHandle { tx } , Self { db, rx } )
279+ }
280+ }
281+
282+ #[ async_trait]
283+ impl Task for DbAsyncBridgeTask {
284+ type Output = anyhow:: Result < ( ) > ;
285+
286+ const NAME : & ' static str = "db-async-bridge" ;
287+
288+ async fn run ( mut self , mut shutdown_signal : ShutdownSignal ) -> anyhow:: Result < ( ) > {
289+ loop {
290+ tokio:: select! {
291+ req = self . rx. recv( ) => {
292+ let Some ( req) = req else {
293+ break ;
294+ } ;
295+
296+ match req {
297+ DbRequest :: InsertJitElevationResult { result, tx } => {
298+ match self . db. insert_jit_elevation_result( & result) . await {
299+ Ok ( ( ) ) => {
300+ let _ = tx. send( Ok ( ( ) ) ) ;
301+ }
302+ Err ( error) => {
303+ let _ = tx. send( Err ( DbHandleError {
304+ db_error: Some ( error) ,
305+ value: result,
306+ } ) ) ;
307+ }
308+ }
309+ }
310+ }
311+ }
312+ _ = shutdown_signal. wait( ) => {
313+ break ;
314+ }
315+ }
316+ }
317+
318+ Ok ( ( ) )
319+ }
212320}
0 commit comments