Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generate m2m connects and disconnects in the compiler #5153

Merged
merged 9 commits into from
Feb 7, 2025
37 changes: 36 additions & 1 deletion quaint/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ impl<'a> Comparable<'a> for Expression<'a> {
where
T: Into<Expression<'a>>,
{
Compare::In(Box::new(self), Box::new(selection.into()))
let expr = extract_single_var_row(selection.into());
Compare::In(Box::new(self), Box::new(expr))
}

fn not_in_selection<T>(self, selection: T) -> Compare<'a>
Expand Down Expand Up @@ -521,3 +522,37 @@ impl<'a> Comparable<'a> for Expression<'a> {
Compare::All(Box::new(self))
}
}

/// Converts a row consisting of a single var into the var itself.
/// Any other expression is returned as is.
fn extract_single_var_row(expr: Expression) -> Expression {
let Expression {
kind: ExpressionKind::Row(values),
..
} = &expr
else {
return expr;
};

let Some((
val @ Expression {
kind:
ExpressionKind::Parameterized(Value {
typed: ValueType::Var(_, _),
..
}),
..
},
[],
)) = values.values.split_first()
else {
return expr;
};

val.clone()
.decorate(
Some("prisma-comma-repeatable-start"),
Some("prisma-comma-repeatable-end"),
)
.into()
}
13 changes: 13 additions & 0 deletions quaint/src/ast/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,19 @@ where
}
}

impl<'a, A> FromIterator<A> for Row<'a>
where
A: Into<Expression<'a>>,
{
fn from_iter<T>(iter: T) -> Self
where
T: IntoIterator<Item = A>,
{
let inner = iter.into_iter().map(Into::into).collect::<Vec<_>>();
Self { values: inner }
}
}

impl<'a> Comparable<'a> for Row<'a> {
fn equals<T>(self, comparison: T) -> Compare<'a>
where
Expand Down
68 changes: 41 additions & 27 deletions query-compiler/query-compiler/src/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use itertools::Itertools;
use query::translate_query;
use query_builder::QueryBuilder;
use query_core::{EdgeRef, Node, NodeRef, Query, QueryGraph, QueryGraphBuilderError, QueryGraphDependency};
use query_structure::{PlaceholderType, PrismaValue, SelectionResult};
use query_structure::{PlaceholderType, PrismaValue, SelectedField, SelectionResult};
use thiserror::Error;

use super::expression::{Binding, Expression};
Expand Down Expand Up @@ -91,7 +91,7 @@ impl<'a, 'b> NodeTranslator<'a, 'b> {
(
field.clone(),
PrismaValue::Placeholder {
name: self.graph.edge_source(edge).id(),
name: generate_projected_dependency_name(self.graph.edge_source(edge), field),
r#type: PlaceholderType::Any,
},
)
Expand Down Expand Up @@ -152,7 +152,7 @@ impl<'a, 'b> NodeTranslator<'a, 'b> {
// doesn't belong into results, and is executed before all result scopes.
let mut expressions: Vec<Expression> = child_pairs
.into_iter()
.map(|(edge, node)| self.process_child_with_dependency(edge, node))
.map(|(_, node)| self.process_child_with_dependencies(node))
.collect::<Result<Vec<_>, _>>()?;

// Fold result scopes into one expression.
Expand All @@ -169,9 +169,9 @@ impl<'a, 'b> NodeTranslator<'a, 'b> {
// if not, we can separate them with a getfirstnonempty
let bindings = result_subgraphs
.into_iter()
.map(|(edge, node)| {
.map(|(_, node)| {
let name = node.id();
let expr = self.process_child_with_dependency(edge, node)?;
let expr = self.process_child_with_dependencies(node)?;
Ok(Binding { name, expr })
})
.collect::<TranslateResult<Vec<_>>>()?;
Expand Down Expand Up @@ -199,39 +199,53 @@ impl<'a, 'b> NodeTranslator<'a, 'b> {
}
}

fn process_child_with_dependency(&mut self, edge: EdgeRef, node: NodeRef) -> TranslateResult<Expression> {
let edge_content = self.graph.edge_content(&edge);
let field = if let Some(QueryGraphDependency::ProjectedDataDependency(selection, _)) = edge_content {
let mut fields = selection.selections();
if let Some(first) = fields.next().filter(|_| fields.len() == 0) {
Some(first.db_name().to_string())
} else {
// we need to handle MapField with multiple fields?
todo!()
}
} else {
None
};
fn process_child_with_dependencies(&mut self, node: NodeRef) -> TranslateResult<Expression> {
let bindings = self
.graph
.incoming_edges(&node)
.into_iter()
.filter_map(|edge| {
let field = if let Some(QueryGraphDependency::ProjectedDataDependency(selection, _)) =
self.graph.edge_content(&edge)
{
let mut fields = selection.selections();
if let Some(first) = fields.next().filter(|_| fields.len() == 0) {
first
} else {
// we need to handle MapField with multiple fields?
todo!()
}
} else {
return None;
};

let source = self.graph.edge_source(&edge);
Some(Binding::new(
generate_projected_dependency_name(source, field),
Expression::MapField {
field: field.prisma_name().into_owned(),
records: Box::new(Expression::Get { name: source.id() }),
},
))
})
.collect::<Vec<_>>();

// translate plucks the edges coming into node, we need to avoid accessing it afterwards
let edges = self.graph.incoming_edges(&node);
let source = self.graph.edge_source(&edge);
let expr = NodeTranslator::new(self.graph, node, &edges, self.query_builder).translate()?;

// we insert a MapField expression if the edge was a projected data dependency
if let Some(field) = field {
if !bindings.is_empty() {
Ok(Expression::Let {
bindings: vec![Binding::new(
source.id(),
Expression::MapField {
field,
records: Box::new(Expression::Get { name: source.id() }),
},
)],
bindings,
expr: Box::new(expr),
})
} else {
Ok(expr)
}
}
}

fn generate_projected_dependency_name(source: NodeRef, field: &SelectedField) -> String {
format!("{}${}", source.id(), field.prisma_name())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

even when we have just one edge, doing this instead of shadowing makes the query plan much nicer to read 👍

}
140 changes: 80 additions & 60 deletions query-compiler/query-compiler/src/translate/query/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
};
use itertools::Itertools;
use query_builder::{QueryArgumentsExt, QueryBuilder, RelationLink};
use query_core::{FilteredQuery, ReadQuery};
use query_core::{FilteredQuery, ReadQuery, RelatedRecordsQuery};
use query_structure::{
ConditionValue, FieldSelection, Filter, PrismaValue, QueryArguments, QueryMode, RelationField, ScalarCondition,
ScalarFilter, ScalarProjection,
Expand Down Expand Up @@ -61,7 +61,10 @@ pub(crate) fn translate_read_query(query: ReadQuery, builder: &dyn QueryBuilder)
}
}

ReadQuery::RelatedRecordsQuery(_) => unreachable!("related records query should not be at the top-level"),
ReadQuery::RelatedRecordsQuery(rrq) => {
let (expr, _) = build_read_related_records(rrq, None, builder)?;
expr
}

_ => todo!(),
})
Expand Down Expand Up @@ -99,7 +102,7 @@ fn add_inmemory_join(
})
.map(|rrq| -> TranslateResult<JoinExpression> {
let parent_field_name = rrq.parent_field.name().to_owned();

let left_scalars = rrq.parent_field.left_scalars();
let conditions = rrq
.parent_field
.left_scalars()
Expand All @@ -116,27 +119,15 @@ fn add_inmemory_join(
}
})
.collect();

let selected_fields = rrq.selected_fields.without_relations().into_virtuals_last();
let needs_reversed_order = rrq.args.needs_reversed_order();

let (mut child_query, join_on) = if rrq.parent_field.relation().is_many_to_many() {
build_read_m2m_query(rrq.parent_field, conditions, rrq.args, &selected_fields, builder)?
} else {
build_read_one2m_query(rrq.parent_field, conditions, rrq.args, &selected_fields, builder)?
};

if needs_reversed_order {
child_query = Expression::Reverse(Box::new(child_query));
}

if !rrq.nested.is_empty() {
child_query = add_inmemory_join(child_query, rrq.nested, builder)?;
};
let (child, join_fields) = build_read_related_records(rrq, Some(conditions), builder)?;

Ok(JoinExpression {
child: child_query,
on: join_on,
child,
on: left_scalars
.into_iter()
.map(|sf| sf.name().to_owned())
.zip(join_fields)
.collect(),
parent_field: parent_field_name,
})
})
Expand All @@ -157,64 +148,82 @@ fn add_inmemory_join(
})
}

fn build_read_related_records(
rrq: RelatedRecordsQuery,
conditions: Option<Vec<ScalarCondition>>,
builder: &dyn QueryBuilder,
) -> TranslateResult<(Expression, JoinFields)> {
let selected_fields = rrq.selected_fields.without_relations().into_virtuals_last();
let needs_reversed_order = rrq.args.needs_reversed_order();

let (mut child_query, join_on) = if rrq.parent_field.relation().is_many_to_many() {
build_read_m2m_query(rrq.parent_field, conditions, rrq.args, &selected_fields, builder)?
} else {
build_read_one2m_query(rrq.parent_field, conditions, rrq.args, &selected_fields, builder)?
};

if needs_reversed_order {
child_query = Expression::Reverse(Box::new(child_query));
}

if !rrq.nested.is_empty() {
child_query = add_inmemory_join(child_query, rrq.nested, builder)?;
};
Ok((child_query, join_on))
}

fn build_read_m2m_query(
field: RelationField,
mut conditions: Vec<ScalarCondition>,
conditions: Option<Vec<ScalarCondition>>,
args: QueryArguments,
selected_fields: &FieldSelection,
builder: &dyn QueryBuilder,
) -> TranslateResult<(Expression, Vec<(String, String)>)> {
let condition = conditions
.pop()
.expect("should have at least one condition in m2m relation");
assert!(
conditions.is_empty(),
"should have at most one condition in m2m relation"
);
) -> TranslateResult<(Expression, JoinFields)> {
let condition = conditions.map(|mut conditions| {
let condition = conditions
.pop()
.expect("should have at least one condition in m2m relation");
assert!(
conditions.is_empty(),
"should have at most one condition in m2m relation"
);
condition
});

let link = RelationLink::new(field, condition);
let join_expr = link
.field()
.linking_fields()
.scalars()
.map(|left| (left.name().to_owned(), link.to_string()))
.collect_vec();
let link_name = link.to_string();

let query = builder
.build_get_related_records(link, args, selected_fields)
.map_err(TranslateError::QueryBuildFailure)?;

Ok((Expression::Query(query), join_expr))
Ok((Expression::Query(query), JoinFields(vec![link_name])))
}

fn build_read_one2m_query(
field: RelationField,
conditions: Vec<ScalarCondition>,
conditions: Option<Vec<ScalarCondition>>,
mut args: QueryArguments,
selected_fields: &FieldSelection,
builder: &dyn QueryBuilder,
) -> TranslateResult<(Expression, Vec<(String, String)>)> {
let join_expr = field
.linking_fields()
.scalars()
.zip(field.related_field().left_scalars())
.map(|(left, right)| (left.name().to_owned(), right.name().to_owned()))
.collect_vec();
) -> TranslateResult<(Expression, JoinFields)> {
let related_scalars = field.related_field().left_scalars();
let join_fields = related_scalars.iter().map(|sf| sf.name().to_owned()).collect();

// TODO: we ignore chunking for now
let linking_scalars = field.related_field().left_scalars();

assert_eq!(
linking_scalars.len(),
conditions.len(),
"linking fields should match conditions"
);
for (condition, child_field) in conditions.into_iter().zip(linking_scalars) {
args.add_filter(Filter::Scalar(ScalarFilter {
condition,
projection: ScalarProjection::Single(child_field.clone()),
mode: QueryMode::Default,
}));
if let Some(conditions) = conditions {
assert_eq!(
related_scalars.len(),
conditions.len(),
"linking fields should match conditions"
);
for (condition, child_field) in conditions.into_iter().zip(related_scalars) {
args.add_filter(Filter::Scalar(ScalarFilter {
condition,
projection: ScalarProjection::Single(child_field.clone()),
mode: QueryMode::Default,
}));
}
}

let to_one_relation = !field.arity().is_list();
Expand All @@ -227,5 +236,16 @@ fn build_read_one2m_query(
if to_one_relation {
expr = Expression::Unique(Box::new(expr));
}
Ok((expr, join_expr))
Ok((expr, JoinFields(join_fields)))
}

struct JoinFields(Vec<String>);

impl IntoIterator for JoinFields {
type Item = String;
type IntoIter = std::vec::IntoIter<Self::Item>;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
Loading
Loading