@@ -64,9 +64,7 @@ impl IConnection for RestAPIConnection {
6464
6565    async  fn  exec ( & self ,  sql :  & str )  -> Result < i64 >  { 
6666        info ! ( "exec: {}" ,  sql) ; 
67-         let  page = self . client . query_all ( sql) . await ?; 
68-         let  affected_rows = parse_affected_rows_from_page ( & page) ?; 
69-         Ok ( affected_rows) 
67+         self . calculate_affected_rows_from_iter ( sql) . await 
7068    } 
7169
7270    async  fn  kill_query ( & self ,  query_id :  & str )  -> Result < ( ) >  { 
@@ -178,13 +176,13 @@ impl IConnection for RestAPIConnection {
178176    } 
179177} 
180178
181- impl < ' o >  RestAPIConnection  { 
179+ impl  RestAPIConnection  { 
182180    pub  async  fn  try_create ( dsn :  & str ,  name :  String )  -> Result < Self >  { 
183181        let  client = APIClient :: new ( dsn,  Some ( name) ) . await ?; 
184182        Ok ( Self  {  client } ) 
185183    } 
186184
187-     fn  default_file_format_options ( )  -> BTreeMap < & ' o  str ,  & ' o  str >  { 
185+     fn  default_file_format_options ( )  -> BTreeMap < & ' static  str ,  & ' static  str >  { 
188186        vec ! [ 
189187            ( "type" ,  "CSV" ) , 
190188            ( "field_delimiter" ,  "," ) , 
@@ -195,9 +193,68 @@ impl<'o> RestAPIConnection {
195193        . collect ( ) 
196194    } 
197195
198-     fn  default_copy_options ( )  -> BTreeMap < & ' o  str ,  & ' o  str >  { 
196+     fn  default_copy_options ( )  -> BTreeMap < & ' static  str ,  & ' static  str >  { 
199197        vec ! [ ( "purge" ,  "true" ) ] . into_iter ( ) . collect ( ) 
200198    } 
199+     fn  parse_row_count_string ( value_str :  & str )  -> Result < i64 ,  String >  { 
200+         let  trimmed = value_str. trim ( ) ; 
201+ 
202+         if  trimmed. is_empty ( )  { 
203+             return  Ok ( 0 ) ; 
204+         } 
205+ 
206+         if  let  Ok ( count)  = trimmed. parse :: < i64 > ( )  { 
207+             return  Ok ( count) ; 
208+         } 
209+ 
210+         if  let  Ok ( count)  = serde_json:: from_str :: < i64 > ( trimmed)  { 
211+             return  Ok ( count) ; 
212+         } 
213+ 
214+         let  unquoted = trimmed. trim_matches ( '"' ) ; 
215+         if  let  Ok ( count)  = unquoted. parse :: < i64 > ( )  { 
216+             return  Ok ( count) ; 
217+         } 
218+ 
219+         Err ( format ! ( 
220+             "failed to parse affected rows from: '{}'" , 
221+             value_str
222+         ) ) 
223+     } 
224+ 
225+     async  fn  calculate_affected_rows_from_iter ( & self ,  sql :  & str )  -> Result < i64 >  { 
226+         let  mut  rows = IConnection :: query_iter ( self ,  sql) . await ?; 
227+         let  mut  count = 0i64 ; 
228+ 
229+         use  tokio_stream:: StreamExt ; 
230+         // Get the first row to check if it has affected rows info 
231+         if  let  Some ( first_row)  = rows. next ( ) . await  { 
232+             let  row = first_row?; 
233+             let  schema = row. schema ( ) ; 
234+ 
235+             // Check if this is an affected rows response 
236+             if  !schema. fields ( ) . is_empty ( )  && schema. fields ( ) [ 0 ] . name . contains ( "number of rows" )  { 
237+                 let  values = row. values ( ) ; 
238+                 if  !values. is_empty ( )  { 
239+                     let  value = & values[ 0 ] ; 
240+                     let  s:  String  = value. clone ( ) . try_into ( ) . map_err ( |e| { 
241+                         Error :: InvalidResponse ( format ! ( "Failed to convert value to string: {}" ,  e) ) 
242+                     } ) ?; 
243+                     count = Self :: parse_row_count_string ( & s) . map_err ( Error :: InvalidResponse ) ?; 
244+                 } 
245+             }  else  { 
246+                 // If it's not affected rows info, count normally 
247+                 count = 1 ; 
248+                 // Continue counting the rest 
249+                 while  let  Some ( row_result)  = rows. next ( ) . await  { 
250+                     row_result?; 
251+                     count += 1 ; 
252+                 } 
253+             } 
254+         } 
255+ 
256+         Ok ( count) 
257+     } 
201258} 
202259
203260pub  struct  RestAPIRows < T >  { 
@@ -288,49 +345,3 @@ impl FromRowStats for RawRowWithStats {
288345        Ok ( RawRowWithStats :: Row ( RawRow :: new ( rows,  row) ) ) 
289346    } 
290347} 
291- 
292- fn  parse_affected_rows_from_page ( page :  & databend_client:: Page )  -> Result < i64 >  { 
293-     if  page. schema . is_empty ( )  { 
294-         return  Ok ( 0 ) ; 
295-     } 
296- 
297-     let  first_field = & page. schema [ 0 ] ; 
298-     if  !first_field. name . contains ( "number of rows" )  { 
299-         return  Ok ( 0 ) ; 
300-     } 
301- 
302-     if  page. data . is_empty ( )  || page. data [ 0 ] . is_empty ( )  { 
303-         return  Ok ( 0 ) ; 
304-     } 
305- 
306-     match  & page. data [ 0 ] [ 0 ]  { 
307-         Some ( value_str)  => parse_row_count_string ( value_str) . map_err ( Error :: InvalidResponse ) , 
308-         None  => Ok ( 0 ) , 
309-     } 
310- } 
311- 
312- fn  parse_row_count_string ( value_str :  & str )  -> Result < i64 ,  String >  { 
313-     let  trimmed = value_str. trim ( ) ; 
314- 
315-     if  trimmed. is_empty ( )  { 
316-         return  Ok ( 0 ) ; 
317-     } 
318- 
319-     if  let  Ok ( count)  = trimmed. parse :: < i64 > ( )  { 
320-         return  Ok ( count) ; 
321-     } 
322- 
323-     if  let  Ok ( count)  = serde_json:: from_str :: < i64 > ( trimmed)  { 
324-         return  Ok ( count) ; 
325-     } 
326- 
327-     let  unquoted = trimmed. trim_matches ( '"' ) ; 
328-     if  let  Ok ( count)  = unquoted. parse :: < i64 > ( )  { 
329-         return  Ok ( count) ; 
330-     } 
331- 
332-     Err ( format ! ( 
333-         "failed to parse affected rows from: '{}'" , 
334-         value_str
335-     ) ) 
336- } 
0 commit comments