Skip to content

Commit 51c301a

Browse files
chore(completions): add tree sitter query for table aliases
1 parent b8a0986 commit 51c301a

File tree

4 files changed

+197
-2
lines changed

4 files changed

+197
-2
lines changed

crates/pgt_completions/src/context.rs

+11
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ pub(crate) struct CompletionContext<'a> {
115115
pub wrapping_statement_range: Option<tree_sitter::Range>,
116116

117117
pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,
118+
119+
pub mentioned_table_aliases: HashMap<String, String>,
118120
}
119121

120122
impl<'a> CompletionContext<'a> {
@@ -131,6 +133,7 @@ impl<'a> CompletionContext<'a> {
131133
wrapping_statement_range: None,
132134
is_invocation: false,
133135
mentioned_relations: HashMap::new(),
136+
mentioned_table_aliases: HashMap::new(),
134137
};
135138

136139
ctx.gather_tree_context();
@@ -146,6 +149,7 @@ impl<'a> CompletionContext<'a> {
146149
let mut executor = TreeSitterQueriesExecutor::new(self.tree.root_node(), sql);
147150

148151
executor.add_query_results::<queries::RelationMatch>();
152+
executor.add_query_results::<queries::TableAliasMatch>();
149153

150154
for relation_match in executor.get_iter(stmt_range) {
151155
match relation_match {
@@ -166,6 +170,13 @@ impl<'a> CompletionContext<'a> {
166170
}
167171
};
168172
}
173+
174+
QueryResult::TableAliases(table_alias_match) => {
175+
self.mentioned_table_aliases.insert(
176+
table_alias_match.get_alias(sql),
177+
table_alias_match.get_table(sql),
178+
);
179+
}
169180
};
170181
}
171182
}

crates/pgt_treesitter_queries/src/lib.rs

+71-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,77 @@ impl<'a> Iterator for QueryResultIter<'a> {
6868
#[cfg(test)]
6969
mod tests {
7070

71-
use crate::{TreeSitterQueriesExecutor, queries::RelationMatch};
71+
use crate::{
72+
TreeSitterQueriesExecutor,
73+
queries::{RelationMatch, TableAliasMatch},
74+
};
75+
76+
#[test]
77+
fn finds_all_table_aliases() {
78+
let sql = r#"
79+
select
80+
*
81+
from
82+
(
83+
select
84+
something
85+
from
86+
public.cool_table pu
87+
join private.cool_tableau pr on pu.id = pr.id
88+
where
89+
x = '123'
90+
union
91+
select
92+
something_else
93+
from
94+
another_table puat
95+
inner join private.another_tableau prat on puat.id = prat.id
96+
union
97+
select
98+
x,
99+
y
100+
from
101+
public.get_something_cool ()
102+
) as cool
103+
join users u on u.id = cool.something
104+
where
105+
col = 17;
106+
"#;
107+
108+
let mut parser = tree_sitter::Parser::new();
109+
parser.set_language(tree_sitter_sql::language()).unwrap();
110+
111+
let tree = parser.parse(sql, None).unwrap();
112+
113+
let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);
114+
115+
executor.add_query_results::<TableAliasMatch>();
116+
117+
let results: Vec<&TableAliasMatch> = executor
118+
.get_iter(None)
119+
.filter_map(|q| q.try_into().ok())
120+
.collect();
121+
122+
assert_eq!(results[0].get_schema(sql), Some("public".into()));
123+
assert_eq!(results[0].get_table(sql), "cool_table");
124+
assert_eq!(results[0].get_alias(sql), "pu");
125+
126+
assert_eq!(results[1].get_schema(sql), Some("private".into()));
127+
assert_eq!(results[1].get_table(sql), "cool_tableau");
128+
assert_eq!(results[1].get_alias(sql), "pr");
129+
130+
assert_eq!(results[2].get_schema(sql), None);
131+
assert_eq!(results[2].get_table(sql), "another_table");
132+
assert_eq!(results[2].get_alias(sql), "puat");
133+
134+
assert_eq!(results[3].get_schema(sql), Some("private".into()));
135+
assert_eq!(results[3].get_table(sql), "another_tableau");
136+
assert_eq!(results[3].get_alias(sql), "prat");
137+
138+
assert_eq!(results[4].get_schema(sql), None);
139+
assert_eq!(results[4].get_table(sql), "users");
140+
assert_eq!(results[4].get_alias(sql), "u");
141+
}
72142

73143
#[test]
74144
fn finds_all_relations_and_ignores_functions() {

crates/pgt_treesitter_queries/src/queries/mod.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
mod relations;
2+
mod table_aliases;
23

34
pub use relations::*;
5+
pub use table_aliases::*;
46

57
#[derive(Debug)]
68
pub enum QueryResult<'a> {
79
Relation(RelationMatch<'a>),
10+
TableAliases(TableAliasMatch<'a>),
811
}
912

1013
impl QueryResult<'_> {
1114
pub fn within_range(&self, range: &tree_sitter::Range) -> bool {
1215
match self {
13-
Self::Relation(rm) => {
16+
QueryResult::Relation(rm) => {
1417
let start = match rm.schema {
1518
Some(s) => s.start_position(),
1619
None => rm.table.start_position(),
@@ -20,6 +23,11 @@ impl QueryResult<'_> {
2023

2124
start >= range.start_point && end <= range.end_point
2225
}
26+
QueryResult::TableAliases(m) => {
27+
let start = m.table.start_position();
28+
let end = m.alias.end_position();
29+
start >= range.start_point && end <= range.end_point
30+
}
2331
}
2432
}
2533
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
use std::sync::LazyLock;
2+
3+
use crate::{Query, QueryResult};
4+
5+
use super::QueryTryFrom;
6+
7+
static TS_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
8+
static QUERY_STR: &str = r#"
9+
(relation
10+
(object_reference
11+
.
12+
(identifier) @schema_or_table
13+
"."?
14+
(identifier)? @table
15+
)
16+
(keyword_as)?
17+
(identifier) @alias
18+
)
19+
"#;
20+
tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query")
21+
});
22+
23+
#[derive(Debug)]
24+
pub struct TableAliasMatch<'a> {
25+
pub(crate) table: tree_sitter::Node<'a>,
26+
pub(crate) alias: tree_sitter::Node<'a>,
27+
pub(crate) schema: Option<tree_sitter::Node<'a>>,
28+
}
29+
30+
impl TableAliasMatch<'_> {
31+
pub fn get_alias(&self, sql: &str) -> String {
32+
self.alias
33+
.utf8_text(sql.as_bytes())
34+
.expect("Failed to get alias from TableAliasMatch")
35+
.to_string()
36+
}
37+
38+
pub fn get_table(&self, sql: &str) -> String {
39+
self.table
40+
.utf8_text(sql.as_bytes())
41+
.expect("Failed to get table from TableAliasMatch")
42+
.to_string()
43+
}
44+
45+
pub fn get_schema(&self, sql: &str) -> Option<String> {
46+
self.schema.as_ref().map(|n| {
47+
n.utf8_text(sql.as_bytes())
48+
.expect("Failed to get table from TableAliasMatch")
49+
.to_string()
50+
})
51+
}
52+
}
53+
54+
impl<'a> TryFrom<&'a QueryResult<'a>> for &'a TableAliasMatch<'a> {
55+
type Error = String;
56+
57+
fn try_from(q: &'a QueryResult<'a>) -> Result<Self, Self::Error> {
58+
match q {
59+
QueryResult::TableAliases(t) => Ok(t),
60+
61+
#[allow(unreachable_patterns)]
62+
_ => Err("Invalid QueryResult type".into()),
63+
}
64+
}
65+
}
66+
67+
impl<'a> QueryTryFrom<'a> for TableAliasMatch<'a> {
68+
type Ref = &'a TableAliasMatch<'a>;
69+
}
70+
71+
impl<'a> Query<'a> for TableAliasMatch<'a> {
72+
fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec<crate::QueryResult<'a>> {
73+
let mut cursor = tree_sitter::QueryCursor::new();
74+
75+
let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes());
76+
77+
let mut to_return = vec![];
78+
79+
for m in matches {
80+
if m.captures.len() == 3 {
81+
let schema = m.captures[0].node;
82+
let table = m.captures[1].node;
83+
let alias = m.captures[2].node;
84+
85+
to_return.push(QueryResult::TableAliases(TableAliasMatch {
86+
table,
87+
alias,
88+
schema: Some(schema),
89+
}));
90+
}
91+
92+
if m.captures.len() == 2 {
93+
let table = m.captures[0].node;
94+
let alias = m.captures[1].node;
95+
96+
to_return.push(QueryResult::TableAliases(TableAliasMatch {
97+
table,
98+
alias,
99+
schema: None,
100+
}));
101+
}
102+
}
103+
104+
to_return
105+
}
106+
}

0 commit comments

Comments
 (0)