diff --git a/dask_planner/Cargo.lock b/dask_planner/Cargo.lock index ab94e31b5..03470b06a 100644 --- a/dask_planner/Cargo.lock +++ b/dask_planner/Cargo.lock @@ -54,6 +54,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anyhow" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" + [[package]] name = "arrayref" version = "0.3.6" @@ -288,6 +294,17 @@ dependencies = [ "xz2", ] +[[package]] +name = "async-recursion" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b015a331cc64ebd1774ba119538573603427eaace0a1950c423ab971f903796" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.64" @@ -626,6 +643,7 @@ dependencies = [ "datafusion-expr", "datafusion-optimizer", "datafusion-sql", + "datafusion-substrait", "env_logger", "log", "mimalloc", @@ -783,6 +801,22 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "datafusion-substrait" +version = "17.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e5af8bc23708f6d9d1721947c8486c96153ce671269522d7d917bb428d2fa73" +dependencies = [ + "async-recursion", + "datafusion", + "itertools", + "prost 0.11.6", + "prost-build 0.9.0", + "prost-types 0.11.6", + "substrait", + "tokio", +] + [[package]] name = "digest" version = "0.10.6" @@ -800,6 +834,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +[[package]] +name = "dyn-clone" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9b0705efd4599c15a38151f4721f7bc388306f61084d3bfd50bd07fbca5cb60" + [[package]] name = "either" version = "1.8.1" @@ -849,6 +889,12 @@ dependencies = [ "instant", ] +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flatbuffers" version = "22.9.29" @@ -1019,6 +1065,15 @@ dependencies = [ "ahash", ] +[[package]] +name = "heck" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "heck" version = "0.4.1" @@ -1040,6 +1095,15 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +[[package]] +name = "home" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "747309b4b440c06d57b0b25f2aee03ee9b5e5397d288c60e21fc709bb98a7408" +dependencies = [ + "winapi", +] + [[package]] name = "humantime" version = "2.1.0" @@ -1371,6 +1435,12 @@ dependencies = [ "adler", ] +[[package]] +name = "multimap" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" + [[package]] name = "num" version = "0.4.0" @@ -1561,6 +1631,26 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" +[[package]] +name = "pest" +version = "2.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "028accff104c4e513bad663bbcd2ad7cfd5304144404c31ed0a77ac103d00660" +dependencies = [ + "thiserror", + "ucd-trie", +] + +[[package]] +name = "petgraph" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" +dependencies = [ + "fixedbitset", + "indexmap", +] + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -1600,6 +1690,112 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "444879275cb4fd84958b1a1d5420d15e6fcf7c235fe47f053c9c2a80aceb6001" +dependencies = [ + "bytes", + "prost-derive 0.9.0", +] + +[[package]] +name = "prost" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21dc42e00223fc37204bd4aa177e69420c604ca4a183209a8f9de30c6d934698" +dependencies = [ + "bytes", + "prost-derive 0.11.6", +] + +[[package]] +name = "prost-build" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62941722fb675d463659e49c4f3fe1fe792ff24fe5bbaa9c08cd3b98a1c354f5" +dependencies = [ + "bytes", + "heck 0.3.3", + "itertools", + "lazy_static", + "log", + "multimap", + "petgraph", + "prost 0.9.0", + "prost-types 0.9.0", + "regex", + "tempfile", + "which", +] + +[[package]] +name = "prost-build" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f8ad728fb08fe212df3c05169e940fbb6d9d16a877ddde14644a983ba2012e" +dependencies = [ + "bytes", + "heck 0.4.1", + "itertools", + "lazy_static", + "log", + "multimap", + "petgraph", + "prost 0.11.6", + "prost-types 0.11.6", + "regex", + "tempfile", + "which", +] + +[[package]] +name = "prost-derive" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9cc1a3263e07e0bf68e96268f37665207b49560d98739662cdfaae215c720fe" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-derive" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bda8c0881ea9f722eb9629376db3d0b903b462477c1aafcb0566610ac28ac5d" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-types" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534b7a0e836e3c482d2693070f982e39e7611da9695d4d1f5a4b186b51faef0a" +dependencies = [ + "bytes", + "prost 0.9.0", +] + +[[package]] +name = "prost-types" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e0526209433e96d83d750dd81a99118edbc55739e7e61a46764fd2ad537788" +dependencies = [ + "bytes", + "prost 0.11.6", +] + [[package]] name = "pyo3" version = "0.18.1" @@ -1731,6 +1927,15 @@ version = "0.6.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" +[[package]] +name = "regress" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a92ff21fe8026ce3f2627faaf43606f0b67b014dbc9ccf027181a804f75d92e" +dependencies = [ + "memchr", +] + [[package]] name = "remove_dir_all" version = "0.5.3" @@ -1740,6 +1945,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "rustfmt-wrapper" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed729e3bee08ec2befd593c27e90ca9fdd25efdc83c94c3b82eaef16e4f7406e" +dependencies = [ + "serde", + "tempfile", + "thiserror", + "toml", + "toolchain_find", +] + [[package]] name = "rustix" version = "0.36.8" @@ -1775,6 +1993,30 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schemars" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a5fb6c61f29e723026dc8e923d94c694313212abbecbbe5f55a7748eec5b307" +dependencies = [ + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f188d036977451159430f3b8dc82ec76364a42b7e289c2b18a9a18f4470058e9" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + [[package]] name = "scopeguard" version = "1.1.0" @@ -1787,6 +2029,24 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ddccb15bcce173023b3fedd9436f882a0739b8dfb45e4f6b6002bee5929f61b2" +[[package]] +name = "semver" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6" +dependencies = [ + "semver-parser", +] + +[[package]] +name = "semver-parser" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0bef5b7f9e0df16536d3961cfb6e84331c065b4066afb39768d0e319411f7" +dependencies = [ + "pest", +] + [[package]] name = "seq-macro" version = "0.3.2" @@ -1798,6 +2058,31 @@ name = "serde" version = "1.0.152" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.152" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_derive_internals" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85bf8229e7920a9f636479437026331ce11aa132b4dde37d121944a44d6e5f3c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "serde_json" @@ -1810,6 +2095,30 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_tokenstream" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "274f512d6748a01e67cbcde5b4307ab2c9d52a98a2b870a980ef0793a351deff" +dependencies = [ + "proc-macro2", + "serde", + "syn", +] + +[[package]] +name = "serde_yaml" +version = "0.9.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fb06d4b6cdaef0e0c51fa881acb721bed3c924cfaa71d9c94a3b771dfdf6567" +dependencies = [ + "indexmap", + "itoa 1.0.5", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "sha2" version = "0.10.6" @@ -1852,7 +2161,7 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "475b3bbe5245c26f2d8a6f62d67c1f30eb9fffeccee721c45d162c3ebbdf81b2" dependencies = [ - "heck", + "heck 0.4.1", "proc-macro2", "quote", "syn", @@ -1903,13 +2212,31 @@ version = "0.24.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59" dependencies = [ - "heck", + "heck 0.4.1", "proc-macro2", "quote", "rustversion", "syn", ] +[[package]] +name = "substrait" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2feb96a6a106e21161551af32dc4e0fdab3aceb926b940d7e92a086b640fc7c" +dependencies = [ + "heck 0.4.1", + "prost 0.11.6", + "prost-build 0.11.6", + "prost-types 0.11.6", + "schemars", + "serde", + "serde_json", + "serde_yaml", + "typify", + "walkdir", +] + [[package]] name = "subtle" version = "2.4.1" @@ -2062,6 +2389,28 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" +dependencies = [ + "serde", +] + +[[package]] +name = "toolchain_find" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e85654a10e7a07a47c6f19d93818f3f343e22927f2fa280c84f7c8042743413" +dependencies = [ + "home", + "lazy_static", + "regex", + "semver", + "walkdir", +] + [[package]] name = "tracing" version = "0.1.37" @@ -2110,6 +2459,57 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +[[package]] +name = "typify" +version = "0.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e8486352f3c946e69f983558cfc09b295250b01e01b381ec67a05a812d01d63" +dependencies = [ + "typify-impl", + "typify-macro", +] + +[[package]] +name = "typify-impl" +version = "0.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7624d0b911df6e2bbf34a236f76281f93b294cdde1d4df1dbdb748e5a7fefa5" +dependencies = [ + "heck 0.4.1", + "log", + "proc-macro2", + "quote", + "regress", + "rustfmt-wrapper", + "schemars", + "serde_json", + "syn", + "thiserror", + "unicode-ident", +] + +[[package]] +name = "typify-macro" +version = "0.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c42802aa033cee7650a4e1509ba7d5848a56f84be7c4b31e4385ee12445e942" +dependencies = [ + "proc-macro2", + "quote", + "schemars", + "serde", + "serde_json", + "serde_tokenstream", + "syn", + "typify-impl", +] + +[[package]] +name = "ucd-trie" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" + [[package]] name = "unicode-bidi" version = "0.3.10" @@ -2149,6 +2549,12 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" +[[package]] +name = "unsafe-libyaml" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc7ed8ba44ca06be78ea1ad2c3682a43349126c8818054231ee6f4748012aed2" + [[package]] name = "url" version = "2.3.1" @@ -2246,6 +2652,17 @@ version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" +[[package]] +name = "which" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2441c784c52b289a054b7201fc93253e288f094e2f4be9058343127c4226a269" +dependencies = [ + "either", + "libc", + "once_cell", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml index e53b7d837..0d4ff4c7b 100644 --- a/dask_planner/Cargo.toml +++ b/dask_planner/Cargo.toml @@ -15,6 +15,7 @@ datafusion-common = "17.0.0" datafusion-expr = "17.0.0" datafusion-optimizer = "17.0.0" datafusion-sql = "17.0.0" +datafusion-substrait = "17.0.0" env_logger = "0.10" log = "^0.4" mimalloc = { version = "*", default-features = false } diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 43a037dba..b1ba4af70 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -9,9 +9,17 @@ pub mod statement; pub mod table; pub mod types; -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, future::Future, sync::Arc}; -use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use datafusion::{ + arrow::datatypes::{DataType, Field, Schema, TimeUnit}, + catalog::{ + catalog::{CatalogProvider, MemoryCatalogProvider}, + schema::MemorySchemaProvider, + }, + datasource::TableProvider, + prelude::SessionContext, +}; use datafusion_common::{config::ConfigOptions, DFSchema, DataFusionError}; use datafusion_expr::{ logical_plan::Extension, @@ -34,7 +42,9 @@ use datafusion_sql::{ ResolvedTableReference, TableReference, }; +use datafusion_substrait::{consumer, serializer}; use pyo3::prelude::*; +use tokio::runtime::Runtime; use self::logical::{ create_catalog_schema::CreateCatalogSchemaPlanNode, @@ -63,6 +73,7 @@ use crate::{ show_tables::ShowTablesPlanNode, PyLogicalPlan, }, + table::DaskTableSource, }, }; @@ -86,12 +97,13 @@ use crate::{ /// # } /// ``` #[pyclass(name = "DaskSQLContext", module = "dask_planner", subclass)] -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct DaskSQLContext { current_catalog: String, current_schema: String, schemas: HashMap, options: ConfigOptions, + session_ctx: SessionContext, } impl ContextProvider for DaskSQLContext { @@ -108,6 +120,7 @@ impl ContextProvider for DaskSQLContext { reference.catalog ))); } + match self.schemas.get(reference.schema) { Some(schema) => { let mut resp = None; @@ -411,6 +424,7 @@ impl DaskSQLContext { current_schema: default_schema_name.to_owned(), schemas: HashMap::new(), options: ConfigOptions::new(), + session_ctx: SessionContext::new(), } } @@ -432,7 +446,27 @@ impl DaskSQLContext { schema_name: String, schema: schema::DaskSchema, ) -> PyResult { - self.schemas.insert(schema_name, schema); + self.schemas.insert(schema_name.clone(), schema); + + match self.session_ctx.catalog(&self.current_catalog) { + Some(catalog) => { + let schema_provider = MemorySchemaProvider::new(); + let _result = catalog.register_schema(&schema_name, Arc::new(schema_provider)); + + self.session_ctx + .register_catalog(self.current_catalog.clone(), catalog); + } + None => { + let mem_catalog = MemoryCatalogProvider::new(); + let schema_provider = MemorySchemaProvider::new(); + let _result = mem_catalog.register_schema(&schema_name, Arc::new(schema_provider)); + + // Insert the new schema into this newly created catalog + self.session_ctx + .register_catalog(self.current_catalog.clone(), Arc::new(mem_catalog)); + } + } + Ok(true) } @@ -444,7 +478,30 @@ impl DaskSQLContext { ) -> PyResult { match self.schemas.get_mut(&schema_name) { Some(schema) => { - schema.add_table(table); + schema.add_table(table.clone()); + + let tbl_ref = TableReference::Partial { + schema: &self.current_schema, + table: table.table_name.as_str(), + }; + let tbl_src = self.get_table_provider(tbl_ref).unwrap(); + let provider = tbl_src + .as_any() + .downcast_ref::() + .expect("Invalid DefaulTableSource instance"); + let tbl_provider = provider.provider.clone() as Arc; + + let catalog = self.session_ctx.catalog(&self.current_catalog).unwrap(); + let schema = catalog.schema(&table.schema_name.unwrap()).unwrap(); + let _result = schema.register_table(table.table_name.clone(), tbl_provider.clone()); + + let bare_tbl_ref = TableReference::Bare { + table: table.table_name.as_str(), + }; + let _result = self + .session_ctx + .register_table(bare_tbl_ref, tbl_provider.clone()); + Ok(true) } None => Err(py_runtime_err(format!( @@ -509,10 +566,39 @@ impl DaskSQLContext { Err(e) => Err(py_optimization_exp(e)), } } + + /// Loads a `LogicalPlan` from a local Substrait protobuf file. + pub fn plan_from_substrait( + &self, + plan_path: String, + py: Python, + ) -> PyResult { + let result = serializer::deserialize(plan_path.as_str()); + let plan = Self::wait_for_future(py, result).unwrap(); + + let result = Self::wait_for_future( + py, + consumer::from_substrait_plan(&mut self.session_ctx.clone(), &plan), + ) + .map_err(DataFusionError::from) + .unwrap(); + + Ok(PyLogicalPlan::from(result)) + } } /// non-Python methods impl DaskSQLContext { + /// Utility to collect rust futures with GIL released + pub fn wait_for_future(py: Python, f: F) -> F::Output + where + F: Send, + F::Output: Send, + { + let rt = Runtime::new().unwrap(); + py.allow_threads(|| rt.block_on(f)) + } + /// Creates a non-optimized Relational Algebra LogicalPlan from an AST Statement pub fn _logical_relational_algebra( &self, diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index 246390829..449abf080 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -1,7 +1,10 @@ use std::{any::Any, sync::Arc}; use async_trait::async_trait; -use datafusion::arrow::datatypes::{DataType, Field, SchemaRef}; +use datafusion::{ + arrow::datatypes::{DataType, Field, SchemaRef}, + datasource::empty::EmptyTable, +}; use datafusion_common::DFField; use datafusion_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableSource}; use datafusion_optimizer::utils::split_conjunction; @@ -25,12 +28,16 @@ use crate::{ /// DaskTable wrapper that is compatible with DataFusion logical query plans pub struct DaskTableSource { schema: SchemaRef, + pub provider: Arc, } impl DaskTableSource { /// Initialize a new `EmptyTable` from a schema. pub fn new(schema: SchemaRef) -> Self { - Self { schema } + Self { + schema: schema.clone(), + provider: Arc::new(EmptyTable::new(schema)), + } } } diff --git a/dask_sql/context.py b/dask_sql/context.py index 101200b9c..f6909568c 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -457,6 +457,7 @@ def sql( return_futures: bool = True, dataframes: Dict[str, Union[dd.DataFrame, pd.DataFrame]] = None, gpu: bool = False, + substrait: bool = False, config_options: Dict[str, Any] = None, ) -> Union[dd.DataFrame, pd.DataFrame]: """ @@ -483,6 +484,9 @@ def sql( to register before executing this query gpu (:obj:`bool`): Whether or not to load the additional Dask or pandas dataframes (if any) on GPU; requires cuDF / dask-cuDF if enabled. Defaults to False. + substrait (:obj:`str`): If True the `sql` argument specifies a path to a Substrait plan file which is loaded + and ran as is without any optimizations. Otherwise it is treated as a standard SQL string and parsed by + the parsing engine. config_options (:obj:`Dict[str,Any]`): Specific configuration options to pass during query execution Returns: @@ -493,14 +497,19 @@ def sql( for df_name, df in dataframes.items(): self.create_table(df_name, df, gpu=gpu) - if isinstance(sql, str): - rel, _ = self._get_ral(sql) - elif isinstance(sql, LogicalPlan): - rel = sql + if substrait: + logger.debug(f"Executing query using substrait plan: '{sql}'") + plan = self.context.plan_from_substrait(sql) + print(f"LogicalPlan from substrait: \n{plan}") else: - raise RuntimeError( - f"Encountered unsupported `LogicalPlan` sql type: {type(sql)}" - ) + if isinstance(sql, str): + rel, _ = self._get_ral(sql) + elif isinstance(sql, LogicalPlan): + rel = sql + else: + raise RuntimeError( + f"Encountered unsupported `LogicalPlan` sql type: {type(sql)}" + ) return self._compute_table_from_rel(rel, return_futures) diff --git a/df_simple.json b/df_simple.json new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/proto/df_simple.proto b/tests/integration/proto/df_simple.proto new file mode 100644 index 000000000..71259977a Binary files /dev/null and b/tests/integration/proto/df_simple.proto differ diff --git a/tests/integration/test_substrait.py b/tests/integration/test_substrait.py new file mode 100644 index 000000000..52bc1605b --- /dev/null +++ b/tests/integration/test_substrait.py @@ -0,0 +1,12 @@ +import pandas as pd + +from tests.utils import assert_eq + + +def test_usertable_substrait_join(c): + return_df = c.sql("./tests/integration/proto/df_simple.proto", substrait=True) + expected_df = pd.DataFrame( + {"user_id": [1, 1, 2, 2], "b": [3, 3, 1, 3], "c": [1, 2, 3, 3]} + ) + + assert_eq(return_df, expected_df, check_index=False)