diff --git a/Cargo.lock b/Cargo.lock index c55895ff204..93ae11d0591 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2986,7 +2986,7 @@ dependencies = [ "libc", "option-ext", "redox_users 0.5.2", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -3225,7 +3225,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -4198,7 +4198,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.1", + "socket2 0.5.10", "system-configuration", "tokio", "tower-service", @@ -4523,7 +4523,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -4618,7 +4618,7 @@ dependencies = [ "portable-atomic", "portable-atomic-util", "serde_core", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -5246,10 +5246,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00a21b43fe2a373896727b97927adedd2683d2907683f294f62cf8815fbf6a01" +version = "0.4.5" dependencies = [ + "async-trait", "reqwest", "serde", "serde_json", @@ -6007,7 +6006,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -7112,7 +7111,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls 0.23.35", - "socket2 0.6.1", + "socket2 0.5.10", "thiserror 2.0.17", "tokio", "tracing", @@ -7149,9 +7148,9 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.6.1", + "socket2 0.5.10", "tracing", - "windows-sys 0.60.2", + "windows-sys 0.59.0", ] [[package]] @@ -7743,7 +7742,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.11.0", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -8778,7 +8777,7 @@ dependencies = [ "getrandom 0.3.4", "once_cell", "rustix 1.1.3", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -9691,7 +9690,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.48.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 66e5c0a99f0..8041769514b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,7 +63,7 @@ lance-io = { version = "=2.0.0-beta.5", path = "./rust/lance-io", default-featur lance-linalg = { version = "=2.0.0-beta.5", path = "./rust/lance-linalg" } lance-namespace = { version = "=2.0.0-beta.5", path = "./rust/lance-namespace" } lance-namespace-impls = { version = "=2.0.0-beta.5", path = "./rust/lance-namespace-impls" } -lance-namespace-reqwest-client = "0.3.1" +lance-namespace-reqwest-client = { path = "../lance-namespace/rust/lance-namespace-reqwest-client" } lance-table = { version = "=2.0.0-beta.5", path = "./rust/lance-table" } lance-test-macros = { version = "=2.0.0-beta.5", path = "./rust/lance-test-macros" } lance-testing = { version = "=2.0.0-beta.5", path = "./rust/lance-testing" } diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock index 1ee870160f5..30901005dae 100644 --- a/java/lance-jni/Cargo.lock +++ b/java/lance-jni/Cargo.lock @@ -4302,10 +4302,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00a21b43fe2a373896727b97927adedd2683d2907683f294f62cf8815fbf6a01" +version = "0.4.5" dependencies = [ + "async-trait", "reqwest", "serde", "serde_json", diff --git a/python/Cargo.lock b/python/Cargo.lock index d4e19a95a6d..0c793a05379 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -4640,10 +4640,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00a21b43fe2a373896727b97927adedd2683d2907683f294f62cf8815fbf6a01" +version = "0.4.5" dependencies = [ + "async-trait", "reqwest", "serde", "serde_json", diff --git a/rust/lance-io/src/object_store/storage_options.rs b/rust/lance-io/src/object_store/storage_options.rs index f809df8d1d3..d5d8dc97ee5 100644 --- a/rust/lance-io/src/object_store/storage_options.rs +++ b/rust/lance-io/src/object_store/storage_options.rs @@ -115,6 +115,7 @@ impl StorageOptionsProvider for LanceNamespaceStorageOptionsProvider { id: Some(self.table_id.clone()), version: None, with_table_uri: None, + ..Default::default() }; let response = self diff --git a/rust/lance-namespace-impls/src/dir.rs b/rust/lance-namespace-impls/src/dir.rs index 91714d73d90..ef391799ff2 100644 --- a/rust/lance-namespace-impls/src/dir.rs +++ b/rust/lance-namespace-impls/src/dir.rs @@ -934,6 +934,7 @@ impl LanceNamespace for DirectoryNamespace { schema: Some(Box::new(json_schema)), storage_options, stats: None, + ..Default::default() }) } Err(err) => { @@ -960,6 +961,7 @@ impl LanceNamespace for DirectoryNamespace { schema: None, storage_options, stats: None, + ..Default::default() }) } else { Err(Error::Namespace { @@ -1156,7 +1158,6 @@ impl LanceNamespace for DirectoryNamespace { Ok(CreateEmptyTableResponse { transaction_id: None, location: Some(table_uri), - properties: None, storage_options: self.storage_options.clone(), }) } @@ -1987,6 +1988,7 @@ mod tests { id: Some(vec![]), page_token: None, limit: None, + ..Default::default() }; let result = namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -2020,6 +2022,7 @@ mod tests { id: Some(vec!["parent".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -2033,6 +2036,7 @@ mod tests { id: Some(vec![]), page_token: None, limit: None, + ..Default::default() }; let result = namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -2065,6 +2069,7 @@ mod tests { id: Some(vec!["test_ns".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = namespace.list_tables(list_req).await; assert!(result.is_ok()); @@ -2113,6 +2118,7 @@ mod tests { id: Some(vec!["test_ns".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = namespace.list_tables(list_req).await; assert!(result.is_ok()); @@ -2248,6 +2254,7 @@ mod tests { // Describe namespace and verify properties let describe_req = DescribeNamespaceRequest { id: Some(vec!["test_ns".to_string()]), + ..Default::default() }; let result = namespace.describe_namespace(describe_req).await; assert!(result.is_ok()); @@ -2326,6 +2333,7 @@ mod tests { id: Some(vec!["ns1".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = namespace.list_tables(list_req).await.unwrap(); assert_eq!(result.tables.len(), 1); @@ -2335,6 +2343,7 @@ mod tests { id: Some(vec!["ns2".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = namespace.list_tables(list_req).await.unwrap(); assert_eq!(result.tables.len(), 1); diff --git a/rust/lance-namespace-impls/src/dir/manifest.rs b/rust/lance-namespace-impls/src/dir/manifest.rs index 4791bbb9df5..df433160d74 100644 --- a/rust/lance-namespace-impls/src/dir/manifest.rs +++ b/rust/lance-namespace-impls/src/dir/manifest.rs @@ -1113,6 +1113,7 @@ impl LanceNamespace for ManifestNamespace { schema: Some(Box::new(json_schema)), storage_options: self.storage_options.clone(), stats: None, + ..Default::default() }) } Err(_) => { @@ -1126,6 +1127,7 @@ impl LanceNamespace for ManifestNamespace { schema: None, storage_options: self.storage_options.clone(), stats: None, + ..Default::default() }) } } @@ -1624,7 +1626,6 @@ impl LanceNamespace for ManifestNamespace { Ok(CreateEmptyTableResponse { transaction_id: None, location: Some(table_uri), - properties: None, storage_options: self.storage_options.clone(), }) } @@ -2172,6 +2173,7 @@ mod tests { // Verify namespace exists let exists_req = NamespaceExistsRequest { id: Some(vec!["ns1".to_string()]), + ..Default::default() }; let result = dir_namespace.namespace_exists(exists_req).await; assert!(result.is_ok(), "Namespace should exist"); @@ -2181,6 +2183,7 @@ mod tests { id: Some(vec![]), page_token: None, limit: None, + ..Default::default() }; let result = dir_namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -2225,6 +2228,7 @@ mod tests { // Verify nested namespace exists let exists_req = NamespaceExistsRequest { id: Some(vec!["parent".to_string(), "child".to_string()]), + ..Default::default() }; let result = dir_namespace.namespace_exists(exists_req).await; assert!(result.is_ok(), "Nested namespace should exist"); @@ -2234,6 +2238,7 @@ mod tests { id: Some(vec!["parent".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = dir_namespace.list_namespaces(list_req).await; assert!(result.is_ok()); @@ -2301,6 +2306,7 @@ mod tests { // Verify namespace no longer exists let exists_req = NamespaceExistsRequest { id: Some(vec!["ns1".to_string()]), + ..Default::default() }; let result = dir_namespace.namespace_exists(exists_req).await; assert!(result.is_err(), "Namespace should not exist after drop"); @@ -2379,6 +2385,7 @@ mod tests { id: Some(vec!["ns1".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = dir_namespace.list_tables(list_req).await; assert!(result.is_ok()); @@ -2415,6 +2422,7 @@ mod tests { // Describe the namespace let describe_req = DescribeNamespaceRequest { id: Some(vec!["ns1".to_string()]), + ..Default::default() }; let result = dir_namespace.describe_namespace(describe_req).await; assert!( diff --git a/rust/lance-namespace-impls/src/lib.rs b/rust/lance-namespace-impls/src/lib.rs index 88248841bcb..df5c7026053 100644 --- a/rust/lance-namespace-impls/src/lib.rs +++ b/rust/lance-namespace-impls/src/lib.rs @@ -104,7 +104,7 @@ pub use credentials::azure::{AzureCredentialVendor, AzureCredentialVendorConfig} pub use credentials::azure_props; #[cfg(feature = "rest")] -pub use rest::{RestNamespace, RestNamespaceBuilder}; +pub use rest::{HeaderProvider, RestNamespace, RestNamespaceBuilder, StaticHeaderProvider}; #[cfg(feature = "rest-adapter")] pub use rest_adapter::{RestAdapter, RestAdapterConfig, RestAdapterHandle}; diff --git a/rust/lance-namespace-impls/src/rest.rs b/rust/lance-namespace-impls/src/rest.rs index 3b5d0650659..f1e5d012ed8 100644 --- a/rust/lance-namespace-impls/src/rest.rs +++ b/rust/lance-namespace-impls/src/rest.rs @@ -4,13 +4,13 @@ //! REST implementation of Lance Namespace use std::collections::HashMap; +use std::sync::Arc; use async_trait::async_trait; use bytes::Bytes; -use lance_namespace::apis::{ - configuration::Configuration, namespace_api, table_api, tag_api, transaction_api, -}; +use lance_namespace::apis::configuration::{Configuration, RequestMiddleware}; +use lance_namespace::apis::{namespace_api, table_api, tag_api, transaction_api}; use lance_namespace::models::{ AlterTableAddColumnsRequest, AlterTableAddColumnsResponse, AlterTableAlterColumnsRequest, AlterTableAlterColumnsResponse, AlterTableDropColumnsRequest, AlterTableDropColumnsResponse, @@ -41,6 +41,98 @@ use lance_core::{box_error, Error, Result}; use lance_namespace::LanceNamespace; +/// Provides headers for REST API requests. +/// +/// Implementations can cache tokens and refresh them as needed. +/// The `get_headers` method is called before each API request. +/// +/// # Examples +/// +/// ``` +/// # use lance_namespace_impls::rest::{HeaderProvider, StaticHeaderProvider}; +/// # use std::collections::HashMap; +/// // Static headers +/// let mut headers = HashMap::new(); +/// headers.insert("Authorization".to_string(), "Bearer token".to_string()); +/// let provider = StaticHeaderProvider::new(headers); +/// ``` +#[async_trait] +pub trait HeaderProvider: Send + Sync + std::fmt::Debug { + /// Get headers to send with requests. + /// + /// Called before API requests - implementation should handle caching/refresh. + async fn get_headers(&self) -> Result>; + + /// Return self as Any for downcasting. + fn as_any(&self) -> &dyn std::any::Any; +} + +/// Static header provider that returns the same headers for every request. +/// +/// Use this for headers that don't change during the lifetime of the namespace, +/// such as API keys or user agent strings. +#[derive(Debug, Clone)] +pub struct StaticHeaderProvider { + headers: HashMap, +} + +impl StaticHeaderProvider { + /// Create a new static header provider with the given headers. + pub fn new(headers: HashMap) -> Self { + Self { headers } + } + + /// Create a new static header provider from an iterator of key-value pairs. + pub fn from_iter(iter: I) -> Self + where + I: IntoIterator, + K: Into, + V: Into, + { + Self { + headers: iter + .into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(), + } + } +} + +#[async_trait] +impl HeaderProvider for StaticHeaderProvider { + async fn get_headers(&self) -> Result> { + Ok(self.headers.clone()) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +/// Middleware that adds headers from a HeaderProvider to each request. +#[derive(Debug)] +struct HeaderMiddleware { + provider: Arc, +} + +#[async_trait] +impl RequestMiddleware for HeaderMiddleware { + async fn process(&self, mut request: reqwest::Request) -> reqwest::Request { + if let Ok(headers) = self.provider.get_headers().await { + let header_map = request.headers_mut(); + for (name, value) in headers { + if let (Ok(n), Ok(v)) = ( + reqwest::header::HeaderName::from_bytes(name.as_bytes()), + reqwest::header::HeaderValue::from_str(&value), + ) { + header_map.insert(n, v); + } + } + } + request + } +} + /// Builder for creating a RestNamespace. /// /// This builder provides a fluent API for configuring and establishing @@ -59,17 +151,47 @@ use lance_namespace::LanceNamespace; /// # Ok(()) /// # } /// ``` -#[derive(Debug, Clone)] pub struct RestNamespaceBuilder { uri: String, delimiter: String, - headers: HashMap, + header_provider: Option>, cert_file: Option, key_file: Option, ssl_ca_cert: Option, assert_hostname: bool, } +impl std::fmt::Debug for RestNamespaceBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RestNamespaceBuilder") + .field("uri", &self.uri) + .field("delimiter", &self.delimiter) + .field( + "header_provider", + &self.header_provider.as_ref().map(|_| "..."), + ) + .field("cert_file", &self.cert_file) + .field("key_file", &self.key_file) + .field("ssl_ca_cert", &self.ssl_ca_cert) + .field("assert_hostname", &self.assert_hostname) + .finish() + } +} + +impl Clone for RestNamespaceBuilder { + fn clone(&self) -> Self { + Self { + uri: self.uri.clone(), + delimiter: self.delimiter.clone(), + header_provider: self.header_provider.clone(), + cert_file: self.cert_file.clone(), + key_file: self.key_file.clone(), + ssl_ca_cert: self.ssl_ca_cert.clone(), + assert_hostname: self.assert_hostname, + } + } +} + impl RestNamespaceBuilder { /// Default delimiter for object identifiers const DEFAULT_DELIMITER: &'static str = "$"; @@ -83,7 +205,7 @@ impl RestNamespaceBuilder { Self { uri: uri.into(), delimiter: Self::DEFAULT_DELIMITER.to_string(), - headers: HashMap::new(), + header_provider: None, cert_file: None, key_file: None, ssl_ca_cert: None, @@ -155,6 +277,13 @@ impl RestNamespaceBuilder { } } + // Create header provider if any headers were specified + let header_provider: Option> = if headers.is_empty() { + None + } else { + Some(Arc::new(StaticHeaderProvider::new(headers))) + }; + // Extract TLS options let cert_file = properties.get("tls.cert_file").cloned(); let key_file = properties.get("tls.key_file").cloned(); @@ -167,7 +296,7 @@ impl RestNamespaceBuilder { Ok(Self { uri, delimiter, - headers, + header_provider, cert_file, key_file, ssl_ca_cert, @@ -185,27 +314,70 @@ impl RestNamespaceBuilder { self } + /// Set a dynamic header provider for HTTP requests. + /// + /// The header provider will be called before each API request to get + /// the headers to send. Use this for dynamic headers like auth tokens + /// that need to be refreshed periodically. + /// + /// # Arguments + /// + /// * `provider` - The header provider implementation + pub fn header_provider(mut self, provider: Arc) -> Self { + self.header_provider = Some(provider); + self + } + /// Add a custom header to the HTTP requests. /// + /// This is a convenience method that creates or updates a [`StaticHeaderProvider`]. + /// For dynamic headers (e.g., auth tokens that need refresh), use [`header_provider`] + /// instead. + /// /// # Arguments /// /// * `name` - Header name /// * `value` - Header value pub fn header(mut self, name: impl Into, value: impl Into) -> Self { - self.headers.insert(name.into(), value.into()); + let name = name.into(); + let value = value.into(); + + // Get existing headers or create empty map + let mut headers = self.extract_static_headers(); + headers.insert(name, value); + self.header_provider = Some(Arc::new(StaticHeaderProvider::new(headers))); self } /// Add multiple custom headers to the HTTP requests. /// + /// This is a convenience method that creates or updates a [`StaticHeaderProvider`]. + /// For dynamic headers (e.g., auth tokens that need refresh), use [`header_provider`] + /// instead. + /// /// # Arguments /// /// * `headers` - HashMap of headers to add pub fn headers(mut self, headers: HashMap) -> Self { - self.headers.extend(headers); + let mut existing = self.extract_static_headers(); + existing.extend(headers); + self.header_provider = Some(Arc::new(StaticHeaderProvider::new(existing))); self } + /// Extract headers from an existing StaticHeaderProvider, or return empty map. + fn extract_static_headers(&self) -> HashMap { + self.header_provider + .as_ref() + .and_then(|p| { + // Try to downcast to StaticHeaderProvider using as_any() + p.as_any() + .downcast_ref::() + .map(|sp| sp.headers.clone()) + }) + .unwrap_or_default() + } + /// Set the client certificate file for mTLS. /// /// # Arguments @@ -252,7 +424,46 @@ impl RestNamespaceBuilder { /// /// Returns a `RestNamespace` instance. pub fn build(self) -> RestNamespace { - RestNamespace::from_builder(self) + let mut client_builder = reqwest::Client::builder(); + + // Configure mTLS if certificate and key files are provided + if let (Some(cert_file), Some(key_file)) = (&self.cert_file, &self.key_file) { + if let (Ok(cert), Ok(key)) = (std::fs::read(cert_file), std::fs::read(key_file)) { + if let Ok(identity) = reqwest::Identity::from_pem(&[&cert[..], &key[..]].concat()) { + client_builder = client_builder.identity(identity); + } + } + } + + // Load CA certificate for server verification + if let Some(ca_cert_file) = &self.ssl_ca_cert { + if let Ok(ca_cert) = std::fs::read(ca_cert_file) { + if let Ok(ca_cert) = reqwest::Certificate::from_pem(&ca_cert) { + client_builder = client_builder.add_root_certificate(ca_cert); + } + } + } + + // Configure hostname verification + client_builder = client_builder.danger_accept_invalid_hostnames(!self.assert_hostname); + + let client = client_builder + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + + // Create configuration with middleware if header provider is set + let mut config = Configuration::new(); + config.client = client; + config.base_path = self.uri.clone(); + + if let Some(provider) = self.header_provider { + config.request_middleware = Some(Arc::new(HeaderMiddleware { provider })); + } + + RestNamespace { + config, + delimiter: self.delimiter, + } } } @@ -304,10 +515,18 @@ fn convert_api_error(err: lance_namespace::apis::Error) - /// # Ok(()) /// # } /// ``` -#[derive(Clone)] pub struct RestNamespace { + config: Configuration, delimiter: String, - reqwest_config: Configuration, +} + +impl Clone for RestNamespace { + fn clone(&self) -> Self { + Self { + config: self.config.clone(), + delimiter: self.delimiter.clone(), + } + } } impl std::fmt::Debug for RestNamespace { @@ -323,72 +542,15 @@ impl std::fmt::Display for RestNamespace { } impl RestNamespace { - /// Create a new REST namespace from builder - pub(crate) fn from_builder(builder: RestNamespaceBuilder) -> Self { - // Build reqwest client with custom headers if provided - let mut client_builder = reqwest::Client::builder(); - - // Add custom headers to the client - if !builder.headers.is_empty() { - let mut headers = reqwest::header::HeaderMap::new(); - for (key, value) in &builder.headers { - if let (Ok(header_name), Ok(header_value)) = ( - reqwest::header::HeaderName::from_bytes(key.as_bytes()), - reqwest::header::HeaderValue::from_str(value), - ) { - headers.insert(header_name, header_value); - } - } - client_builder = client_builder.default_headers(headers); - } - - // Configure mTLS if certificate and key files are provided - if let (Some(cert_file), Some(key_file)) = (&builder.cert_file, &builder.key_file) { - if let (Ok(cert), Ok(key)) = (std::fs::read(cert_file), std::fs::read(key_file)) { - if let Ok(identity) = reqwest::Identity::from_pem(&[&cert[..], &key[..]].concat()) { - client_builder = client_builder.identity(identity); - } - } - } - - // Load CA certificate for server verification - if let Some(ca_cert_file) = &builder.ssl_ca_cert { - if let Ok(ca_cert) = std::fs::read(ca_cert_file) { - if let Ok(ca_cert) = reqwest::Certificate::from_pem(&ca_cert) { - client_builder = client_builder.add_root_certificate(ca_cert); - } - } - } - - // Configure hostname verification - client_builder = client_builder.danger_accept_invalid_hostnames(!builder.assert_hostname); - - let client = client_builder - .build() - .unwrap_or_else(|_| reqwest::Client::new()); - - let mut reqwest_config = Configuration::new(); - reqwest_config.client = client; - reqwest_config.base_path = builder.uri; - - Self { - delimiter: builder.delimiter, - reqwest_config, - } - } - /// Create a new REST namespace with custom configuration (for testing) #[cfg(test)] - pub fn with_configuration(delimiter: String, reqwest_config: Configuration) -> Self { - Self { - delimiter, - reqwest_config, - } + pub fn with_configuration(delimiter: String, config: Configuration) -> Self { + Self { config, delimiter } } /// Get the base endpoint URL for this namespace pub fn endpoint(&self) -> &str { - &self.reqwest_config.base_path + &self.config.base_path } } @@ -401,7 +563,7 @@ impl LanceNamespace for RestNamespace { let id = object_id_str(&request.id, &self.delimiter)?; namespace_api::list_namespaces( - &self.reqwest_config, + &self.config, &id, Some(&self.delimiter), request.page_token.as_deref(), @@ -417,7 +579,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - namespace_api::describe_namespace(&self.reqwest_config, &id, request, Some(&self.delimiter)) + namespace_api::describe_namespace(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -428,7 +590,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - namespace_api::create_namespace(&self.reqwest_config, &id, request, Some(&self.delimiter)) + namespace_api::create_namespace(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -436,7 +598,7 @@ impl LanceNamespace for RestNamespace { async fn drop_namespace(&self, request: DropNamespaceRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - namespace_api::drop_namespace(&self.reqwest_config, &id, request, Some(&self.delimiter)) + namespace_api::drop_namespace(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -444,7 +606,7 @@ impl LanceNamespace for RestNamespace { async fn namespace_exists(&self, request: NamespaceExistsRequest) -> Result<()> { let id = object_id_str(&request.id, &self.delimiter)?; - namespace_api::namespace_exists(&self.reqwest_config, &id, request, Some(&self.delimiter)) + namespace_api::namespace_exists(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -453,7 +615,7 @@ impl LanceNamespace for RestNamespace { let id = object_id_str(&request.id, &self.delimiter)?; table_api::list_tables( - &self.reqwest_config, + &self.config, &id, Some(&self.delimiter), request.page_token.as_deref(), @@ -467,11 +629,12 @@ impl LanceNamespace for RestNamespace { let id = object_id_str(&request.id, &self.delimiter)?; table_api::describe_table( - &self.reqwest_config, + &self.config, &id, request.clone(), Some(&self.delimiter), request.with_table_uri, + request.load_detailed_metadata, ) .await .map_err(convert_api_error) @@ -480,7 +643,7 @@ impl LanceNamespace for RestNamespace { async fn register_table(&self, request: RegisterTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::register_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::register_table(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -488,7 +651,7 @@ impl LanceNamespace for RestNamespace { async fn table_exists(&self, request: TableExistsRequest) -> Result<()> { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::table_exists(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::table_exists(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -496,7 +659,7 @@ impl LanceNamespace for RestNamespace { async fn drop_table(&self, request: DropTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::drop_table(&self.reqwest_config, &id, Some(&self.delimiter)) + table_api::drop_table(&self.config, &id, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -507,7 +670,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::deregister_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::deregister_table(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -515,7 +678,7 @@ impl LanceNamespace for RestNamespace { async fn count_table_rows(&self, request: CountTableRowsRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::count_table_rows(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::count_table_rows(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -528,7 +691,7 @@ impl LanceNamespace for RestNamespace { let id = object_id_str(&request.id, &self.delimiter)?; table_api::create_table( - &self.reqwest_config, + &self.config, &id, request_data.to_vec(), Some(&self.delimiter), @@ -544,7 +707,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::create_empty_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::create_empty_table(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -557,7 +720,7 @@ impl LanceNamespace for RestNamespace { let id = object_id_str(&request.id, &self.delimiter)?; table_api::insert_into_table( - &self.reqwest_config, + &self.config, &id, request_data.to_vec(), Some(&self.delimiter), @@ -580,7 +743,7 @@ impl LanceNamespace for RestNamespace { })?; table_api::merge_insert_into_table( - &self.reqwest_config, + &self.config, &id, on, request_data.to_vec(), @@ -600,7 +763,7 @@ impl LanceNamespace for RestNamespace { async fn update_table(&self, request: UpdateTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::update_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::update_table(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -611,7 +774,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::delete_from_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::delete_from_table(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -619,10 +782,9 @@ impl LanceNamespace for RestNamespace { async fn query_table(&self, request: QueryTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - let response = - table_api::query_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) - .await - .map_err(convert_api_error)?; + let response = table_api::query_table(&self.config, &id, request, Some(&self.delimiter)) + .await + .map_err(convert_api_error)?; // Convert response to bytes let bytes = response.bytes().await.map_err(|e| Error::IO { @@ -639,7 +801,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::create_table_index(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::create_table_index(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -650,7 +812,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::list_table_indices(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::list_table_indices(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -666,7 +828,7 @@ impl LanceNamespace for RestNamespace { let index_name = ""; // This should come from somewhere in the request table_api::describe_table_index_stats( - &self.reqwest_config, + &self.config, &id, index_name, request, @@ -682,14 +844,9 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - transaction_api::describe_transaction( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + transaction_api::describe_transaction(&self.config, &id, request, Some(&self.delimiter)) + .await + .map_err(convert_api_error) } async fn alter_transaction( @@ -698,14 +855,9 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - transaction_api::alter_transaction( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + transaction_api::alter_transaction(&self.config, &id, request, Some(&self.delimiter)) + .await + .map_err(convert_api_error) } async fn create_table_scalar_index( @@ -714,14 +866,9 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::create_table_scalar_index( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + table_api::create_table_scalar_index(&self.config, &id, request, Some(&self.delimiter)) + .await + .map_err(convert_api_error) } async fn drop_table_index( @@ -732,14 +879,14 @@ impl LanceNamespace for RestNamespace { let index_name = request.index_name.as_deref().unwrap_or(""); - table_api::drop_table_index(&self.reqwest_config, &id, index_name, Some(&self.delimiter)) + table_api::drop_table_index(&self.config, &id, index_name, Some(&self.delimiter)) .await .map_err(convert_api_error) } async fn list_all_tables(&self, request: ListTablesRequest) -> Result { table_api::list_all_tables( - &self.reqwest_config, + &self.config, Some(&self.delimiter), request.page_token.as_deref(), request.limit, @@ -751,7 +898,7 @@ impl LanceNamespace for RestNamespace { async fn restore_table(&self, request: RestoreTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::restore_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::restore_table(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -759,7 +906,7 @@ impl LanceNamespace for RestNamespace { async fn rename_table(&self, request: RenameTableRequest) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::rename_table(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::rename_table(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -771,7 +918,7 @@ impl LanceNamespace for RestNamespace { let id = object_id_str(&request.id, &self.delimiter)?; table_api::list_table_versions( - &self.reqwest_config, + &self.config, &id, Some(&self.delimiter), request.page_token.as_deref(), @@ -790,7 +937,7 @@ impl LanceNamespace for RestNamespace { let metadata = request.metadata.unwrap_or_default(); let result = table_api::update_table_schema_metadata( - &self.reqwest_config, + &self.config, &id, metadata, Some(&self.delimiter), @@ -810,7 +957,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::get_table_stats(&self.reqwest_config, &id, request, Some(&self.delimiter)) + table_api::get_table_stats(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -821,14 +968,9 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::explain_table_query_plan( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + table_api::explain_table_query_plan(&self.config, &id, request, Some(&self.delimiter)) + .await + .map_err(convert_api_error) } async fn analyze_table_query_plan( @@ -837,14 +979,9 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::analyze_table_query_plan( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + table_api::analyze_table_query_plan(&self.config, &id, request, Some(&self.delimiter)) + .await + .map_err(convert_api_error) } async fn alter_table_add_columns( @@ -853,14 +990,9 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::alter_table_add_columns( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + table_api::alter_table_add_columns(&self.config, &id, request, Some(&self.delimiter)) + .await + .map_err(convert_api_error) } async fn alter_table_alter_columns( @@ -869,14 +1001,9 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::alter_table_alter_columns( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + table_api::alter_table_alter_columns(&self.config, &id, request, Some(&self.delimiter)) + .await + .map_err(convert_api_error) } async fn alter_table_drop_columns( @@ -885,14 +1012,9 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - table_api::alter_table_drop_columns( - &self.reqwest_config, - &id, - request, - Some(&self.delimiter), - ) - .await - .map_err(convert_api_error) + table_api::alter_table_drop_columns(&self.config, &id, request, Some(&self.delimiter)) + .await + .map_err(convert_api_error) } async fn list_table_tags( @@ -902,7 +1024,7 @@ impl LanceNamespace for RestNamespace { let id = object_id_str(&request.id, &self.delimiter)?; tag_api::list_table_tags( - &self.reqwest_config, + &self.config, &id, Some(&self.delimiter), request.page_token.as_deref(), @@ -918,7 +1040,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - tag_api::get_table_tag_version(&self.reqwest_config, &id, request, Some(&self.delimiter)) + tag_api::get_table_tag_version(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -929,7 +1051,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - tag_api::create_table_tag(&self.reqwest_config, &id, request, Some(&self.delimiter)) + tag_api::create_table_tag(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -940,7 +1062,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - tag_api::delete_table_tag(&self.reqwest_config, &id, request, Some(&self.delimiter)) + tag_api::delete_table_tag(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -951,7 +1073,7 @@ impl LanceNamespace for RestNamespace { ) -> Result { let id = object_id_str(&request.id, &self.delimiter)?; - tag_api::update_table_tag(&self.reqwest_config, &id, request, Some(&self.delimiter)) + tag_api::update_table_tag(&self.config, &id, request, Some(&self.delimiter)) .await .map_err(convert_api_error) } @@ -959,7 +1081,7 @@ impl LanceNamespace for RestNamespace { fn namespace_id(&self) -> String { format!( "RestNamespace {{ endpoint: {:?}, delimiter: {:?} }}", - self.reqwest_config.base_path, self.delimiter + self.config.base_path, self.delimiter ) } } @@ -1031,6 +1153,7 @@ mod tests { id: Some(vec!["test".to_string()]), page_token: None, limit: None, + ..Default::default() }; let result = namespace.list_namespaces(request).await; @@ -1145,15 +1268,16 @@ mod tests { .await; // Create namespace with mock server URL - let mut reqwest_config = Configuration::new(); - reqwest_config.base_path = mock_server.uri(); + let mut config = Configuration::new(); + config.base_path = mock_server.uri(); - let namespace = RestNamespace::with_configuration("$".to_string(), reqwest_config); + let namespace = RestNamespace::with_configuration("$".to_string(), config); let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), page_token: None, limit: Some(10), + ..Default::default() }; let result = namespace.list_namespaces(request).await; @@ -1184,15 +1308,16 @@ mod tests { .await; // Create namespace with mock server URL - let mut reqwest_config = Configuration::new(); - reqwest_config.base_path = mock_server.uri(); + let mut config = Configuration::new(); + config.base_path = mock_server.uri(); - let namespace = RestNamespace::with_configuration("$".to_string(), reqwest_config); + let namespace = RestNamespace::with_configuration("$".to_string(), config); let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), page_token: None, limit: Some(10), + ..Default::default() }; let result = namespace.list_namespaces(request).await; @@ -1220,15 +1345,16 @@ mod tests { .await; // Create namespace with mock server URL - let mut reqwest_config = Configuration::new(); - reqwest_config.base_path = mock_server.uri(); + let mut config = Configuration::new(); + config.base_path = mock_server.uri(); - let namespace = RestNamespace::with_configuration("$".to_string(), reqwest_config); + let namespace = RestNamespace::with_configuration("$".to_string(), config); let request = CreateNamespaceRequest { id: Some(vec!["test".to_string(), "newnamespace".to_string()]), properties: None, mode: None, + ..Default::default() }; let result = namespace.create_namespace(request).await; @@ -1257,10 +1383,10 @@ mod tests { .await; // Create namespace with mock server URL - let mut reqwest_config = Configuration::new(); - reqwest_config.base_path = mock_server.uri(); + let mut config = Configuration::new(); + config.base_path = mock_server.uri(); - let namespace = RestNamespace::with_configuration("$".to_string(), reqwest_config); + let namespace = RestNamespace::with_configuration("$".to_string(), config); let request = CreateTableRequest { id: Some(vec![ @@ -1269,6 +1395,7 @@ mod tests { "table".to_string(), ]), mode: Some("Create".to_string()), + ..Default::default() }; let data = Bytes::from("arrow data here"); @@ -1294,10 +1421,10 @@ mod tests { .await; // Create namespace with mock server URL - let mut reqwest_config = Configuration::new(); - reqwest_config.base_path = mock_server.uri(); + let mut config = Configuration::new(); + config.base_path = mock_server.uri(); - let namespace = RestNamespace::with_configuration("$".to_string(), reqwest_config); + let namespace = RestNamespace::with_configuration("$".to_string(), config); let request = InsertIntoTableRequest { id: Some(vec![ @@ -1306,6 +1433,7 @@ mod tests { "table".to_string(), ]), mode: Some("Append".to_string()), + ..Default::default() }; let data = Bytes::from("arrow data here"); @@ -1316,4 +1444,241 @@ mod tests { let response = result.unwrap(); assert_eq!(response.transaction_id, Some("txn-123".to_string())); } + + // ==================== HeaderProvider Tests ==================== + + #[test] + fn test_static_header_provider_new() { + let mut headers = HashMap::new(); + headers.insert("Authorization".to_string(), "Bearer token".to_string()); + headers.insert("X-Custom".to_string(), "value".to_string()); + + let provider = StaticHeaderProvider::new(headers.clone()); + assert_eq!(provider.headers, headers); + } + + #[test] + fn test_static_header_provider_from_iter() { + let provider = StaticHeaderProvider::from_iter([ + ("Authorization", "Bearer token"), + ("X-Custom", "value"), + ]); + + assert_eq!(provider.headers.len(), 2); + assert_eq!( + provider.headers.get("Authorization"), + Some(&"Bearer token".to_string()) + ); + assert_eq!(provider.headers.get("X-Custom"), Some(&"value".to_string())); + } + + #[tokio::test] + async fn test_static_header_provider_get_headers() { + let provider = StaticHeaderProvider::from_iter([("Authorization", "Bearer token")]); + + let headers = provider.get_headers().await.unwrap(); + assert_eq!(headers.len(), 1); + assert_eq!( + headers.get("Authorization"), + Some(&"Bearer token".to_string()) + ); + } + + #[test] + fn test_builder_header_provider_method() { + #[derive(Debug)] + struct TestProvider; + + #[async_trait] + impl HeaderProvider for TestProvider { + async fn get_headers(&self) -> Result> { + let mut headers = HashMap::new(); + headers.insert("X-Test".to_string(), "test-value".to_string()); + Ok(headers) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + } + + let provider = Arc::new(TestProvider); + let builder = RestNamespaceBuilder::new("http://example.com").header_provider(provider); + + assert!(builder.header_provider.is_some()); + } + + #[test] + fn test_builder_header_method_creates_static_provider() { + let builder = RestNamespaceBuilder::new("http://example.com") + .header("Authorization", "Bearer token") + .header("X-Custom", "value"); + + assert!(builder.header_provider.is_some()); + + // Verify headers were accumulated + let headers = builder.extract_static_headers(); + assert_eq!(headers.len(), 2); + } + + #[test] + fn test_builder_headers_method_creates_static_provider() { + let mut headers = HashMap::new(); + headers.insert("Authorization".to_string(), "Bearer token".to_string()); + headers.insert("X-Custom".to_string(), "value".to_string()); + + let builder = RestNamespaceBuilder::new("http://example.com").headers(headers); + + assert!(builder.header_provider.is_some()); + } + + #[test] + fn test_from_properties_creates_static_provider() { + let mut properties = HashMap::new(); + properties.insert("uri".to_string(), "http://example.com".to_string()); + properties.insert( + "header.Authorization".to_string(), + "Bearer token".to_string(), + ); + properties.insert("header.X-Custom".to_string(), "value".to_string()); + + let builder = + RestNamespaceBuilder::from_properties(properties).expect("Failed to create builder"); + + assert!(builder.header_provider.is_some()); + + let headers = builder.extract_static_headers(); + assert_eq!(headers.len(), 2); + assert_eq!( + headers.get("Authorization"), + Some(&"Bearer token".to_string()) + ); + } + + #[test] + fn test_from_properties_no_headers() { + let mut properties = HashMap::new(); + properties.insert("uri".to_string(), "http://example.com".to_string()); + + let builder = + RestNamespaceBuilder::from_properties(properties).expect("Failed to create builder"); + + // No headers specified, so no provider + assert!(builder.header_provider.is_none()); + } + + #[tokio::test] + async fn test_dynamic_header_provider() { + use std::sync::atomic::{AtomicU32, Ordering}; + + // A provider that returns different headers each time + #[derive(Debug)] + struct CountingProvider { + counter: AtomicU32, + } + + #[async_trait] + impl HeaderProvider for CountingProvider { + async fn get_headers(&self) -> Result> { + let count = self.counter.fetch_add(1, Ordering::SeqCst); + let mut headers = HashMap::new(); + headers.insert("X-Request-Count".to_string(), count.to_string()); + Ok(headers) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + } + + let provider = Arc::new(CountingProvider { + counter: AtomicU32::new(0), + }); + + // Verify the provider returns different values + let headers1 = provider.get_headers().await.unwrap(); + let headers2 = provider.get_headers().await.unwrap(); + + assert_eq!(headers1.get("X-Request-Count"), Some(&"0".to_string())); + assert_eq!(headers2.get("X-Request-Count"), Some(&"1".to_string())); + } + + #[tokio::test] + async fn test_dynamic_header_provider_with_mock_server() { + use std::sync::atomic::{AtomicU32, Ordering}; + + // Start a mock server + let mock_server = MockServer::start().await; + + // Provider that increments a counter in the header + #[derive(Debug)] + struct CountingProvider { + counter: AtomicU32, + } + + #[async_trait] + impl HeaderProvider for CountingProvider { + async fn get_headers(&self) -> Result> { + let count = self.counter.fetch_add(1, Ordering::SeqCst); + let mut headers = HashMap::new(); + headers.insert("X-Request-Count".to_string(), count.to_string()); + Ok(headers) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + } + + // Create mock that expects the header + Mock::given(method("GET")) + .and(path("/v1/namespace/test/list")) + .and(wiremock::matchers::header("X-Request-Count", "0")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "namespaces": ["ns1"] + }))) + .mount(&mock_server) + .await; + + Mock::given(method("GET")) + .and(path("/v1/namespace/test/list")) + .and(wiremock::matchers::header("X-Request-Count", "1")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "namespaces": ["ns2"] + }))) + .mount(&mock_server) + .await; + + let provider = Arc::new(CountingProvider { + counter: AtomicU32::new(0), + }); + + let namespace = RestNamespaceBuilder::new(mock_server.uri()) + .header_provider(provider) + .build(); + + let request1 = ListNamespacesRequest { + id: Some(vec!["test".to_string()]), + page_token: None, + limit: None, + ..Default::default() + }; + + let request2 = ListNamespacesRequest { + id: Some(vec!["test".to_string()]), + page_token: None, + limit: None, + ..Default::default() + }; + + // First request should have X-Request-Count: 0 + let result1 = namespace.list_namespaces(request1).await; + assert!(result1.is_ok()); + assert_eq!(result1.unwrap().namespaces, vec!["ns1"]); + + // Second request should have X-Request-Count: 1 + let result2 = namespace.list_namespaces(request2).await; + assert!(result2.is_ok()); + assert_eq!(result2.unwrap().namespaces, vec!["ns2"]); + } } diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 1030faafe14..7044e1c4d04 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -823,7 +823,7 @@ impl Dataset { let request = CreateEmptyTableRequest { id: Some(table_id.clone()), location: None, - properties: None, + ..Default::default() }; let response = namespace @@ -872,6 +872,7 @@ impl Dataset { id: Some(table_id.clone()), version: None, with_table_uri: None, + ..Default::default() }; let response = namespace diff --git a/rust/lance/src/dataset/builder.rs b/rust/lance/src/dataset/builder.rs index 3d463ce6ca4..4233cf4b236 100644 --- a/rust/lance/src/dataset/builder.rs +++ b/rust/lance/src/dataset/builder.rs @@ -138,6 +138,7 @@ impl DatasetBuilder { id: Some(table_id.clone()), version: None, with_table_uri: None, + ..Default::default() }; let response = namespace