Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/FunSQL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ export
funsql_from,
funsql_fun,
funsql_group,
funsql_hide,
funsql_highlight,
funsql_in,
funsql_into,
funsql_iterate,
funsql_is_not_null,
funsql_is_null,
Expand All @@ -82,6 +84,7 @@ export
funsql_rank,
funsql_row_number,
funsql_select,
funsql_show,
funsql_sort,
funsql_sum,
funsql_with
Expand Down
168 changes: 88 additions & 80 deletions src/link.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,28 @@ struct LinkContext
knot_refs)
end

function link(q::SQLQuery)
@dissect(q, (local tail) |> WithContext(catalog = (local catalog))) || throw(IllFormedError())
ctx = LinkContext(catalog)
t = row_type(tail)
function _select(t::RowType)
refs = SQLQuery[]
t.visible || return refs
for (f, ft) in t.fields
if ft isa ScalarType
ft.visible || continue
push!(refs, Get(f))
else
nested_refs = _select(ft)
for nested_ref in nested_refs
push!(refs, Nested(name = f, tail = nested_ref))
end
end
end
refs
end

function link(q::SQLQuery)
@dissect(q, (local tail) |> WithContext(catalog = (local catalog))) || throw(IllFormedError())
ctx = LinkContext(catalog)
t = row_type(tail)
refs = _select(t)
tail′ = Linked(refs, tail = link(dismantle(tail, ctx), ctx, refs))
WithContext(tail = tail′, catalog = catalog, defs = ctx.defs)
end
Expand Down Expand Up @@ -123,19 +135,15 @@ function dismantle(n::GroupNode, ctx)
Group(by = by′, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′)
end

function dismantle(n::IterateNode, ctx)
function dismantle(n::IntoNode, ctx)
tail′ = dismantle(ctx)
iterator′ = dismantle(n.iterator, ctx)
Iterate(iterator = iterator′, tail = tail′)
Into(name = n.name, tail = tail′)
end

function dismantle(n::JoinNode, ctx)
rt = row_type(n.joinee)
router = JoinRouter(Set(keys(rt.fields)), !isa(rt.group, EmptyType))
function dismantle(n::IterateNode, ctx)
tail′ = dismantle(ctx)
joinee′ = dismantle(n.joinee, ctx)
on′ = dismantle_scalar(n.on, ctx)
RoutedJoin(joinee = joinee′, on = on′, router = router, left = n.left, right = n.right, optional = n.optional, tail = tail′)
iterator′ = dismantle(n.iterator, ctx)
Iterate(iterator = iterator′, tail = tail′)
end

function dismantle(n::LimitNode, ctx)
Expand Down Expand Up @@ -181,6 +189,13 @@ function dismantle_scalar(n::ResolvedNode, ctx)
end
end

function dismantle(n::RoutedJoinNode, ctx)
tail′ = dismantle(ctx)
joinee′ = dismantle(n.joinee, ctx)
on′ = dismantle_scalar(n.on, ctx)
RoutedJoin(joinee = joinee′, on = on′, name = n.name, left = n.left, right = n.right, optional = n.optional, tail = tail′)
end

function dismantle(n::SelectNode, ctx)
tail′ = dismantle(ctx)
args′ = dismantle_scalar(n.args, ctx)
Expand Down Expand Up @@ -232,16 +247,7 @@ function link(n::AppendNode, ctx)
end

function link(n::AsNode, ctx)
refs = SQLQuery[]
for ref in ctx.refs
if @dissect(ref, (local tail) |> Nested(name = (local name)))
@assert name == n.name
push!(refs, tail)
else
error()
end
end
tail′ = link(ctx.tail, ctx, refs)
tail′ = link(ctx)
As(name = n.name, tail = tail′)
end

Expand Down Expand Up @@ -289,10 +295,8 @@ function link(n::FromIterateNode, ctx)
end

function link(n::FromTableExpressionNode, ctx)
refs = ctx.cte_refs[(n.name, n.depth)]
for ref in ctx.refs
push!(refs, Nested(name = n.name, tail = ref))
end
cte_refs = ctx.cte_refs[(n.name, n.depth)]
append!(cte_refs, ctx.refs)
n
end

Expand Down Expand Up @@ -333,6 +337,20 @@ function link(n::GroupNode, ctx)
Group(by = n.by, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′)
end

function link(n::IntoNode, ctx)
refs = SQLQuery[]
for ref in ctx.refs
if @dissect(ref, (local tail) |> Nested(name = (local name)))
@assert name == n.name
push!(refs, tail)
else
error()
end
end
tail′ = link(ctx.tail, ctx, refs)
Into(name = n.name, tail = tail′)
end

function link(n::IterateNode, ctx)
iterator′ = n.iterator
defs = copy(ctx.defs)
Expand Down Expand Up @@ -364,53 +382,6 @@ function link(n::IterateNode, ctx)
Padding(tail = q′)
end

function route(r::JoinRouter, ref::SQLQuery)
if @dissect(ref, Nested(name = (local name))) && name in r.label_set
return 1
end
if @dissect(ref, Get(name = (local name))) && name in r.label_set
return 1
end
if @dissect(ref, Agg()) && r.group
return 1
end
return -1
end

function link(n::RoutedJoinNode, ctx)
lrefs = SQLQuery[]
rrefs = SQLQuery[]
for ref in ctx.refs
turn = route(n.router, ref)
push!(turn < 0 ? lrefs : rrefs, ref)
end
if n.optional && isempty(rrefs)
return link(ctx)
end
ln_ext_refs = length(lrefs)
rn_ext_refs = length(rrefs)
refs′ = SQLQuery[]
lateral_refs = SQLQuery[]
gather!(n.joinee, ctx, lateral_refs)
append!(lrefs, lateral_refs)
lateral = !isempty(lateral_refs)
gather!(n.on, ctx, refs′)
for ref in refs′
turn = route(n.router, ref)
push!(turn < 0 ? lrefs : rrefs, ref)
end
tail′ = Linked(lrefs, ln_ext_refs, tail = link(ctx.tail, ctx, lrefs))
joinee′ = Linked(rrefs, rn_ext_refs, tail = link(n.joinee, ctx, rrefs))
RoutedJoin(
joinee = joinee′,
on = n.on,
router = n.router,
left = n.left,
right = n.right,
lateral = lateral,
tail = tail′)
end

function link(n::LimitNode, ctx)
tail′ = Linked(ctx.refs, tail = link(ctx))
Limit(offset = n.offset, limit = n.limit, tail = tail′)
Expand Down Expand Up @@ -459,6 +430,46 @@ function link(n::PartitionNode, ctx)
Partition(by = n.by, order_by = n.order_by, frame = n.frame, name = n.name, tail = tail′)
end

function link(n::RoutedJoinNode, ctx)
lrefs = SQLQuery[]
rrefs = SQLQuery[]
for ref in ctx.refs
if @dissect(ref, Nested(name = (local name))) && name === n.name
push!(rrefs, ref)
else
push!(lrefs, ref)
end
end
if n.optional && isempty(rrefs)
return link(ctx)
end
ln_ext_refs = length(lrefs)
rn_ext_refs = length(rrefs)
refs′ = SQLQuery[]
lateral_refs = SQLQuery[]
gather!(n.joinee, ctx, lateral_refs)
append!(lrefs, lateral_refs)
lateral = !isempty(lateral_refs)
gather!(n.on, ctx, refs′)
for ref in refs′
if @dissect(ref, Nested(name = (local name))) && name === n.name
push!(rrefs, ref)
else
push!(lrefs, ref)
end
end
tail′ = Linked(lrefs, ln_ext_refs, tail = link(ctx.tail, ctx, lrefs))
joinee′ = Linked(rrefs, rn_ext_refs, tail = link(Into(name = n.name, tail = n.joinee), ctx, rrefs))
RoutedJoin(
joinee = joinee′,
on = n.on,
name = n.name,
left = n.left,
right = n.right,
lateral = lateral,
tail = tail′)
end

function link(n::SelectNode, ctx)
refs = SQLQuery[]
gather!(n.args, ctx, refs)
Expand Down Expand Up @@ -556,12 +567,9 @@ end
function gather!(n::IsolatedNode, ctx)
def = ctx.defs[n.idx]
!@dissect(def, Linked()) || return
refs = SQLQuery[]
for (f, ft) in n.type.fields
if ft isa ScalarType
push!(refs, Get(f))
break
end
refs = _select(n.type)
if !isempty(refs)
refs = refs[1:1]
end
def′ = Linked(refs, tail = link(def, ctx, refs))
ctx.defs[n.idx] = def′
Expand Down
16 changes: 10 additions & 6 deletions src/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,19 @@ terminal(q::SQLQuery) =
Chain(q′, q) =
convert(SQLQuery, q)(q′)

label(q::SQLQuery) =
@something label(q.head) label(q.tail)
function label(q::SQLQuery; default = :_)
l = label(q.head)
l !== nothing ? l : label(q.tail; default)
end

label(n::AbstractSQLNode) =
nothing

label(::Nothing) =
:_
label(::Nothing; default = :_) =
default

label(q) =
label(convert(SQLQuery, q))
label(q; default = :_) =
label(convert(SQLQuery, q); default)


# A variant of SQLQuery for assembling a chain of identifiers.
Expand Down Expand Up @@ -913,6 +915,7 @@ include("nodes/get.jl")
include("nodes/group.jl")
include("nodes/highlight.jl")
include("nodes/internal.jl")
include("nodes/into.jl")
include("nodes/iterate.jl")
include("nodes/join.jl")
include("nodes/limit.jl")
Expand All @@ -921,6 +924,7 @@ include("nodes/order.jl")
include("nodes/over.jl")
include("nodes/partition.jl")
include("nodes/select.jl")
include("nodes/show.jl")
include("nodes/sort.jl")
include("nodes/variable.jl")
include("nodes/where.jl")
Expand Down
17 changes: 8 additions & 9 deletions src/nodes/as.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ AsNode(name) =
As(name; tail = nothing)
name => tail

In a scalar context, `As` specifies the name of the output column. When
applied to tabular data, `As` wraps the data in a nested record.
`As` specifies the name of the output column.

The arrow operator (`=>`) is a shorthand notation for `As`.

Expand All @@ -35,19 +34,19 @@ SELECT "person_1"."person_id" AS "id"
FROM "person" AS "person_1"
```

*Show all patients together with their state of residence.*
*Show all patients together with their primary care provider.*

```jldoctest
julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :location_id]);
julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :provider_id]);

julia> location = SQLTable(:location, columns = [:location_id, :state]);
julia> provider = SQLTable(:provider, columns = [:provider_id, :provider_name]);

julia> q = From(:person) |>
Join(From(:location) |> As(:location),
on = Get.location_id .== Get.location.location_id) |>
Select(Get.person_id, Get.location.state);
Join(From(:provider) |> As(:pcp),
on = Get.provider_id .== Get.pcp.provider_id) |>
Select(Get.person_id, Get.pcp.provider_name);

julia> print(render(q, tables = [person, location]))
julia> print(render(q, tables = [person, provider]))
SELECT
"person_1"."person_id",
"location_1"."state"
Expand Down
42 changes: 42 additions & 0 deletions src/nodes/hide.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Hide node

mutable struct HideNode <: TabularNode
over::Union{SQLNode, Nothing}
names::Vector{Symbol}
label_map::FunSQL.OrderedDict{Symbol, Int}

function HideNode(; over = nothing, names = [], label_map = nothing)
if label_map !== nothing
new(over, names, label_map)
else
n = new(over, names, FunSQL.OrderedDict{Symbol, Int}())
for (i, name) in enumerate(n.names)
if name in keys(n.label_map)
err = FunSQL.DuplicateLabelError(name, path = [n])
throw(err)
end
n.label_map[name] = i
end
n
end
end
end

HideNode(names...; over = nothing) =
HideNode(over = over, names = Symbol[names...])

Hide(args...; kws...) =
HideNode(args...; kws...) |> SQLNode

const funsql_hide = Hide

dissect(scr::Symbol, ::typeof(Hide), pats::Vector{Any}) =
dissect(scr, HideNode, pats)

function FunSQL.PrettyPrinting.quoteof(n::HideNode, ctx::FunSQL.QuoteContext)
ex = Expr(:call, nameof(Hide), quoteof(n.names, ctx)...)
if n.over !== nothing
ex = Expr(:call, :|>, FunSQL.quoteof(n.over, ctx), ex)
end
ex
end
Loading
Loading