@@ -17,6 +17,7 @@ use std::sync::Arc;
1717use std:: time:: Duration ;
1818use std:: time:: Instant ;
1919
20+ use regex:: Regex ;
2021use reqwest:: header:: HeaderMap ;
2122use reqwest:: header:: HeaderValue ;
2223use reqwest:: Client ;
@@ -41,12 +42,43 @@ pub struct HttpClient {
4142#[ derive( serde:: Deserialize , Debug ) ]
4243struct QueryResponse {
4344 session : Option < HttpSessionConf > ,
45+ schema : Vec < SchemaItem > ,
4446 data : Option < serde_json:: Value > ,
4547 next_uri : Option < String > ,
4648
4749 error : Option < serde_json:: Value > ,
4850}
4951
52+ #[ derive( serde:: Deserialize , Debug ) ]
53+ struct SchemaItem {
54+ #[ allow( dead_code) ]
55+ pub name : String ,
56+ pub r#type : String ,
57+ }
58+
59+ impl SchemaItem {
60+ fn parse_type ( & self ) -> Result < DefaultColumnType > {
61+ let nullable = Regex :: new ( r"^Nullable\((.+)\)$" ) . unwrap ( ) ;
62+ let value = match nullable. captures ( & self . r#type ) {
63+ Some ( captures) => {
64+ let ( _, [ value] ) = captures. extract ( ) ;
65+ value
66+ }
67+ None => & self . r#type ,
68+ } ;
69+ let typ = match value {
70+ "String" => DefaultColumnType :: Text ,
71+ "Int8" | "Int16" | "Int32" | "Int64" | "UInt8" | "UInt16" | "UInt32" | "UInt64" => {
72+ DefaultColumnType :: Integer
73+ }
74+ "Float32" | "Float64" => DefaultColumnType :: FloatingPoint ,
75+ decimal if decimal. starts_with ( "Decimal" ) => DefaultColumnType :: FloatingPoint ,
76+ _ => DefaultColumnType :: Any ,
77+ } ;
78+ Ok ( typ)
79+ }
80+ }
81+
5082// make error message the same with ErrorCode::display
5183fn format_error ( value : serde_json:: Value ) -> String {
5284 let value = value. as_object ( ) . unwrap ( ) ;
@@ -125,14 +157,20 @@ impl HttpClient {
125157
126158 pub async fn query ( & mut self , sql : & str ) -> Result < DBOutput < DefaultColumnType > > {
127159 let start = Instant :: now ( ) ;
160+ let port = self . port ;
161+ let mut response = self
162+ . post_query ( sql, & format ! ( "http://127.0.0.1:{port}/v1/query" ) )
163+ . await ?;
128164
129- let url = format ! ( "http://127.0.0.1:{}/v1/query" , self . port ) ;
165+ let mut schema = std :: mem :: take ( & mut response . schema ) ;
130166 let mut parsed_rows = vec ! [ ] ;
131- let mut response = self . post_query ( sql, & url) . await ?;
132167 self . handle_response ( & response, & mut parsed_rows) ?;
133168 while let Some ( next_uri) = & response. next_uri {
134- let url = format ! ( "http://127.0.0.1:{}{next_uri}" , self . port) ;
135- let new_response = self . poll_query_result ( & url) . await ?;
169+ let url = format ! ( "http://127.0.0.1:{port}{next_uri}" ) ;
170+ let mut new_response = self . poll_query_result ( & url) . await ?;
171+ if schema. is_empty ( ) && !new_response. schema . is_empty ( ) {
172+ schema = std:: mem:: take ( & mut new_response. schema ) ;
173+ }
136174 if new_response. next_uri . is_some ( ) {
137175 self . handle_response ( & new_response, & mut parsed_rows) ?;
138176 response = new_response;
@@ -143,11 +181,6 @@ impl HttpClient {
143181 if let Some ( error) = response. error {
144182 return Err ( format_error ( error) . into ( ) ) ;
145183 }
146- // Todo: add types to compare
147- let mut types = vec ! [ ] ;
148- if !parsed_rows. is_empty ( ) {
149- types = vec ! [ DefaultColumnType :: Any ; parsed_rows[ 0 ] . len( ) ] ;
150- }
151184
152185 if self . debug {
153186 println ! (
@@ -156,6 +189,11 @@ impl HttpClient {
156189 ) ;
157190 }
158191
192+ let types = schema
193+ . iter ( )
194+ . map ( |item| item. parse_type ( ) . unwrap_or ( DefaultColumnType :: Any ) )
195+ . collect ( ) ;
196+
159197 Ok ( DBOutput :: Rows {
160198 types,
161199 rows : parsed_rows,
0 commit comments