Skip to content

Commit

Permalink
refactor(cubesql): Extract CubeScanWrappedSqlNode from CubeScanWrappe…
Browse files Browse the repository at this point in the history
…rNode
  • Loading branch information
mcheshkov committed Oct 21, 2024
1 parent d5df5c0 commit 78f3fda
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 373 deletions.
31 changes: 14 additions & 17 deletions rust/cubesql/cubesql/src/compile/engine/df/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use std::{

use crate::{
compile::{
engine::df::wrapper::{CubeScanWrapperNode, SqlQuery},
engine::df::wrapper::{CubeScanWrappedSqlNode, CubeScanWrapperNode, SqlQuery},
rewrite::WrappedSelectType,
test::find_cube_scans_deep_search,
},
Expand Down Expand Up @@ -386,35 +386,32 @@ impl ExtensionPlanner for CubeScanExtensionPlanner {
config_obj: self.config_obj.clone(),
}))
} else if let Some(wrapper_node) = node.as_any().downcast_ref::<CubeScanWrapperNode>() {
return Err(DataFusionError::Internal(format!(
"CubeScanWrapperNode is not executable, SQL should be generated first with QueryEngine::evaluate_wrapped_sql: {:?}",
wrapper_node
)));

Check warning on line 392 in rust/cubesql/cubesql/src/compile/engine/df/scan.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/scan.rs#L389-L392

Added lines #L389 - L392 were not covered by tests
} else if let Some(wrapped_sql_node) =
node.as_any().downcast_ref::<CubeScanWrappedSqlNode>()
{
// TODO
// assert_eq!(logical_inputs.len(), 0, "Inconsistent number of inputs");
// assert_eq!(physical_inputs.len(), 0, "Inconsistent number of inputs");
let scan_node =
find_cube_scans_deep_search(wrapper_node.wrapped_plan.clone(), false)
find_cube_scans_deep_search(wrapped_sql_node.wrapped_plan.clone(), false)
.into_iter()
.next()
.ok_or(DataFusionError::Internal(format!(
"No cube scans found in wrapper node: {:?}",
wrapper_node
wrapped_sql_node
)))?;

let schema = SchemaRef::new(wrapper_node.schema().as_ref().into());
let schema = SchemaRef::new(wrapped_sql_node.schema().as_ref().into());
Some(Arc::new(CubeScanExecutionPlan {
schema,
member_fields: wrapper_node.member_fields.as_ref().ok_or_else(|| {
DataFusionError::Internal(format!(
"Member fields are not set for wrapper node. Optimization wasn't performed: {:?}",
wrapper_node
))
})?.clone(),
member_fields: wrapped_sql_node.member_fields.clone(),
transport: self.transport.clone(),
request: wrapper_node.request.clone().unwrap_or(scan_node.request.clone()),
wrapped_sql: Some(wrapper_node.wrapped_sql.as_ref().ok_or_else(|| {
DataFusionError::Internal(format!(
"Wrapped SQL is not set for wrapper node. Optimization wasn't performed: {:?}",
wrapper_node
))
})?.clone()),
request: wrapped_sql_node.request.clone(),
wrapped_sql: Some(wrapped_sql_node.wrapped_sql.clone()),
auth_context: scan_node.auth_context.clone(),
options: scan_node.options.clone(),
meta: self.meta.clone(),
Expand Down
100 changes: 71 additions & 29 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,75 @@ impl SqlQuery {
}
}

#[derive(Clone, Debug)]
pub struct CubeScanWrappedSqlNode {
// TODO maybe replace wrapped plan with schema + scan_node
pub wrapped_plan: Arc<LogicalPlan>,
pub wrapped_sql: SqlQuery,
pub request: TransportLoadRequestQuery,
pub member_fields: Vec<MemberField>,
}

impl CubeScanWrappedSqlNode {
pub fn new(
wrapped_plan: Arc<LogicalPlan>,
sql: SqlQuery,
request: TransportLoadRequestQuery,
member_fields: Vec<MemberField>,
) -> Self {
Self {
wrapped_plan,
wrapped_sql: sql,
request,
member_fields,
}
}
}

impl UserDefinedLogicalNode for CubeScanWrappedSqlNode {
fn as_any(&self) -> &dyn Any {
self
}

fn inputs(&self) -> Vec<&LogicalPlan> {
vec![]
}

fn schema(&self) -> &DFSchemaRef {
self.wrapped_plan.schema()
}

fn expressions(&self) -> Vec<Expr> {
vec![]
}

fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
// TODO figure out nice plan for wrapped plan
write!(f, "CubeScanWrappedSql")
}

Check warning on line 249 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L246-L249

Added lines #L246 - L249 were not covered by tests

fn from_template(
&self,
exprs: &[datafusion::logical_plan::Expr],
inputs: &[datafusion::logical_plan::LogicalPlan],
) -> std::sync::Arc<dyn UserDefinedLogicalNode + Send + Sync> {
assert_eq!(inputs.len(), 0, "input size inconsistent");
assert_eq!(exprs.len(), 0, "expression size inconsistent");

Arc::new(CubeScanWrappedSqlNode {
wrapped_plan: self.wrapped_plan.clone(),
wrapped_sql: self.wrapped_sql.clone(),
request: self.request.clone(),
member_fields: self.member_fields.clone(),
})
}
}

#[derive(Debug, Clone)]
pub struct CubeScanWrapperNode {
pub wrapped_plan: Arc<LogicalPlan>,
pub meta: Arc<MetaContext>,
pub auth_context: AuthContextRef,
pub wrapped_sql: Option<SqlQuery>,
pub request: Option<TransportLoadRequestQuery>,
pub member_fields: Option<Vec<MemberField>>,
pub span_id: Option<Arc<SpanId>>,
pub config_obj: Arc<dyn ConfigObj>,
}
Expand All @@ -225,31 +286,10 @@ impl CubeScanWrapperNode {
wrapped_plan,
meta,
auth_context,
wrapped_sql: None,
request: None,
member_fields: None,
span_id,
config_obj,
}
}

pub fn with_sql_and_request(
&self,
sql: SqlQuery,
request: TransportLoadRequestQuery,
member_fields: Vec<MemberField>,
) -> Self {
Self {
wrapped_plan: self.wrapped_plan.clone(),
meta: self.meta.clone(),
auth_context: self.auth_context.clone(),
wrapped_sql: Some(sql),
request: Some(request),
member_fields: Some(member_fields),
span_id: self.span_id.clone(),
config_obj: self.config_obj.clone(),
}
}
}

fn expr_name(e: &Expr, schema: &Arc<DFSchema>) -> Result<String> {
Expand Down Expand Up @@ -317,7 +357,7 @@ impl CubeScanWrapperNode {
&self,
transport: Arc<dyn TransportService>,
load_request_meta: Arc<LoadRequestMeta>,
) -> result::Result<Self, CubeError> {
) -> result::Result<CubeScanWrappedSqlNode, CubeError> {
let schema = self.schema();
let wrapped_plan = self.wrapped_plan.clone();
let (sql, request, member_fields) = Self::generate_sql_for_node(
Expand Down Expand Up @@ -361,7 +401,12 @@ impl CubeScanWrapperNode {
sql.finalize_query(sql_templates).map_err(|e| CubeError::internal(e.to_string()))?;
Ok((sql, request, member_fields))
})?;
Ok(self.with_sql_and_request(sql, request, member_fields))
Ok(CubeScanWrappedSqlNode::new(
self.wrapped_plan.clone(),
sql,
request,
member_fields,
))
}

pub fn set_max_limit_for_node(self, node: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
Expand Down Expand Up @@ -2226,9 +2271,6 @@ impl UserDefinedLogicalNode for CubeScanWrapperNode {
wrapped_plan: self.wrapped_plan.clone(),
meta: self.meta.clone(),
auth_context: self.auth_context.clone(),
wrapped_sql: self.wrapped_sql.clone(),
request: self.request.clone(),
member_fields: self.member_fields.clone(),
span_id: self.span_id.clone(),
config_obj: self.config_obj.clone(),
})
Expand Down
Loading

0 comments on commit 78f3fda

Please sign in to comment.