Skip to content

Commit ba83202

Browse files
committed
add ts query
1 parent 9b6c7aa commit ba83202

File tree

9 files changed

+387
-197
lines changed

9 files changed

+387
-197
lines changed

crates/pgt_treesitter_queries/src/lib.rs

+43-3
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ impl<'a> Iterator for QueryResultIter<'a> {
6868
#[cfg(test)]
6969
mod tests {
7070

71-
use crate::{TreeSitterQueriesExecutor, queries::RelationMatch};
71+
use crate::{
72+
TreeSitterQueriesExecutor,
73+
queries::{ParameterMatch, RelationMatch},
74+
};
7275

7376
#[test]
7477
fn finds_all_relations_and_ignores_functions() {
@@ -137,11 +140,11 @@ where
137140
select
138141
*
139142
from (
140-
select *
143+
select *
141144
from (
142145
select *
143146
from private.something
144-
) as sq2
147+
) as sq2
145148
join private.tableau pt1
146149
on sq2.id = pt1.id
147150
) as sq1
@@ -185,4 +188,41 @@ on sq1.id = pt.id;
185188
assert_eq!(results[0].get_schema(sql), Some("private".into()));
186189
assert_eq!(results[0].get_table(sql), "something");
187190
}
191+
192+
#[test]
193+
fn extracts_parameters() {
194+
let sql = r#"select v_test + fn_name.custom_type.v_test2 + $3 + custom_type.v_test3;"#;
195+
196+
let mut parser = tree_sitter::Parser::new();
197+
parser.set_language(tree_sitter_sql::language()).unwrap();
198+
199+
let tree = parser.parse(sql, None).unwrap();
200+
201+
let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);
202+
203+
executor.add_query_results::<ParameterMatch>();
204+
205+
let results: Vec<&ParameterMatch> = executor
206+
.get_iter(None)
207+
.filter_map(|q| q.try_into().ok())
208+
.collect();
209+
210+
assert_eq!(results.len(), 4);
211+
212+
assert_eq!(results[0].get_root(sql), None);
213+
assert_eq!(results[0].get_path(sql), None);
214+
assert_eq!(results[0].get_field(sql), "v_test");
215+
216+
assert_eq!(results[1].get_root(sql), Some("fn_name".into()));
217+
assert_eq!(results[1].get_path(sql), Some("custom_type".into()));
218+
assert_eq!(results[1].get_field(sql), "v_test2");
219+
220+
assert_eq!(results[2].get_root(sql), None);
221+
assert_eq!(results[2].get_path(sql), None);
222+
assert_eq!(results[2].get_field(sql), "$3");
223+
224+
assert_eq!(results[3].get_root(sql), None);
225+
assert_eq!(results[3].get_path(sql), Some("custom_type".into()));
226+
assert_eq!(results[3].get_field(sql), "v_test3");
227+
}
188228
}

crates/pgt_treesitter_queries/src/queries/mod.rs

+13
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
mod parameters;
12
mod relations;
23

4+
pub use parameters::*;
35
pub use relations::*;
46

57
#[derive(Debug)]
68
pub enum QueryResult<'a> {
79
Relation(RelationMatch<'a>),
10+
Parameter(ParameterMatch<'a>),
811
}
912

1013
impl QueryResult<'_> {
@@ -18,6 +21,16 @@ impl QueryResult<'_> {
1821

1922
let end = rm.table.end_position();
2023

24+
start >= range.start_point && end <= range.end_point
25+
}
26+
Self::Parameter(pm) => {
27+
let start = match pm.root {
28+
Some(s) => s.start_position(),
29+
None => pm.path.as_ref().unwrap().start_position(),
30+
};
31+
32+
let end = pm.field.end_position();
33+
2134
start >= range.start_point && end <= range.end_point
2235
}
2336
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
use std::sync::LazyLock;
2+
3+
use crate::{Query, QueryResult};
4+
5+
use super::QueryTryFrom;
6+
7+
static TS_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
8+
static QUERY_STR: &str = r#"
9+
[
10+
(field
11+
(identifier)) @reference
12+
(field
13+
(object_reference)
14+
"." (identifier)) @reference
15+
(parameter) @parameter
16+
]
17+
"#;
18+
tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query")
19+
});
20+
21+
#[derive(Debug)]
22+
pub struct ParameterMatch<'a> {
23+
pub(crate) root: Option<tree_sitter::Node<'a>>,
24+
pub(crate) path: Option<tree_sitter::Node<'a>>,
25+
26+
pub(crate) field: tree_sitter::Node<'a>,
27+
}
28+
29+
impl ParameterMatch<'_> {
30+
pub fn get_root(&self, sql: &str) -> Option<String> {
31+
let str = self
32+
.root
33+
.as_ref()?
34+
.utf8_text(sql.as_bytes())
35+
.expect("Failed to get schema from RelationMatch");
36+
37+
Some(str.to_string())
38+
}
39+
40+
pub fn get_path(&self, sql: &str) -> Option<String> {
41+
let str = self
42+
.path
43+
.as_ref()?
44+
.utf8_text(sql.as_bytes())
45+
.expect("Failed to get table from RelationMatch");
46+
47+
Some(str.to_string())
48+
}
49+
50+
pub fn get_field(&self, sql: &str) -> String {
51+
self.field
52+
.utf8_text(sql.as_bytes())
53+
.expect("Failed to get table from RelationMatch")
54+
.to_string()
55+
}
56+
}
57+
58+
impl<'a> TryFrom<&'a QueryResult<'a>> for &'a ParameterMatch<'a> {
59+
type Error = String;
60+
61+
fn try_from(q: &'a QueryResult<'a>) -> Result<Self, Self::Error> {
62+
match q {
63+
QueryResult::Parameter(r) => Ok(r),
64+
65+
#[allow(unreachable_patterns)]
66+
_ => Err("Invalid QueryResult type".into()),
67+
}
68+
}
69+
}
70+
71+
impl<'a> QueryTryFrom<'a> for ParameterMatch<'a> {
72+
type Ref = &'a ParameterMatch<'a>;
73+
}
74+
75+
impl<'a> Query<'a> for ParameterMatch<'a> {
76+
fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec<crate::QueryResult<'a>> {
77+
let mut cursor = tree_sitter::QueryCursor::new();
78+
79+
let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes());
80+
81+
matches
82+
.filter_map(|m| {
83+
let captures = m.captures;
84+
85+
// We expect exactly one capture for a parameter
86+
if captures.len() != 1 {
87+
return None;
88+
}
89+
90+
let field = captures[0].node;
91+
let text = match field.utf8_text(stmt.as_bytes()) {
92+
Ok(t) => t,
93+
Err(_) => return None,
94+
};
95+
let parts: Vec<&str> = text.split('.').collect();
96+
97+
let param_match = match parts.len() {
98+
// Simple field: field_name
99+
1 => ParameterMatch {
100+
root: None,
101+
path: None,
102+
field,
103+
},
104+
// Table qualified: table.field_name
105+
2 => ParameterMatch {
106+
root: None,
107+
path: field.named_child(0),
108+
field: field.named_child(1)?,
109+
},
110+
// Fully qualified: schema.table.field_name
111+
3 => ParameterMatch {
112+
root: field.named_child(0).and_then(|n| n.named_child(0)),
113+
path: field.named_child(0).and_then(|n| n.named_child(1)),
114+
field: field.named_child(1)?,
115+
},
116+
// Unexpected number of parts
117+
_ => return None,
118+
};
119+
120+
Some(QueryResult::Parameter(param_match))
121+
})
122+
.collect()
123+
}
124+
}

crates/pgt_typecheck/src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
mod diagnostics;
22
mod typed_identifier;
33

4-
use diagnostics::create_type_error;
54
pub use diagnostics::TypecheckDiagnostic;
6-
pub use typed_identifier::TypedIdentifier;
5+
use diagnostics::create_type_error;
76
use pgt_text_size::TextRange;
87
use sqlx::postgres::PgDatabaseError;
98
pub use sqlx::postgres::PgSeverity;
109
use sqlx::{Executor, PgPool};
10+
pub use typed_identifier::TypedIdentifier;
1111
use typed_identifier::apply_identifiers;
1212

1313
#[derive(Debug)]

crates/pgt_typecheck/src/typed_identifier.rs

+104-9
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
1-
#[derive(Debug)]
2-
pub struct Type {
3-
pub schema: Option<String>,
4-
pub name: String,
5-
pub oid: i32,
6-
}
7-
81
#[derive(Debug)]
92
pub struct TypedIdentifier {
103
pub schema: Option<String>,
114
pub relation: Option<String>,
125
pub name: String,
13-
pub type_: Type,
6+
pub type_: (Option<String>, String),
147
}
158

169
impl TypedIdentifier {
1710
pub fn new(
1811
schema: Option<String>,
1912
relation: Option<String>,
2013
name: String,
21-
type_: Type,
14+
type_: (Option<String>, String),
2215
) -> Self {
2316
TypedIdentifier {
2417
schema,
@@ -44,5 +37,107 @@ pub fn apply_identifiers<'a>(
4437
println!("Applying identifiers to SQL: {}", sql);
4538
println!("Identifiers: {:?}", identifiers);
4639
println!("CST: {:#?}", cst);
40+
4741
sql
4842
}
43+
44+
#[cfg(test)]
45+
mod tests {
46+
use pgt_test_utils::test_database::get_new_test_db;
47+
use sqlx::Executor;
48+
49+
#[tokio::test]
50+
async fn test_apply_identifiers() {
51+
let input = "select v_test + fn_name.custom_type.v_test2 + $3 + test.field;";
52+
53+
let test_db = get_new_test_db().await;
54+
55+
let mut parser = tree_sitter::Parser::new();
56+
parser
57+
.set_language(tree_sitter_sql::language())
58+
.expect("Error loading sql language");
59+
60+
let schema_cache = pgt_schema_cache::SchemaCache::load(&test_db)
61+
.await
62+
.expect("Failed to load Schema Cache");
63+
64+
let root = pgt_query_ext::parse(input).unwrap();
65+
let tree = parser.parse(input, None).unwrap();
66+
67+
println!("Parsed SQL: {:?}", root);
68+
println!("Parsed CST: {:?}", tree);
69+
70+
// let mut parameters = Vec::new();
71+
72+
enum Parameter {
73+
Identifier {
74+
range: (usize, usize),
75+
name: (Option<String>, String),
76+
},
77+
Parameter {
78+
range: (usize, usize),
79+
idx: usize,
80+
},
81+
}
82+
83+
let mut c = tree.walk();
84+
85+
'outer: loop {
86+
// 0. Add the current node to the map.
87+
println!("Current node: {:?}", c.node());
88+
match c.node().kind() {
89+
"identifier" => {
90+
println!(
91+
"Found identifier: {:?}",
92+
c.node().utf8_text(input.as_bytes()).unwrap()
93+
);
94+
}
95+
"parameter" => {
96+
println!(
97+
"Found parameter: {:?}",
98+
c.node().utf8_text(input.as_bytes()).unwrap()
99+
);
100+
}
101+
"object_reference" => {
102+
println!(
103+
"Found object reference: {:?}",
104+
c.node().utf8_text(input.as_bytes()).unwrap()
105+
);
106+
107+
// let source = self.text;
108+
// ts_node.utf8_text(source.as_bytes()).ok().map(|txt| {
109+
// if SanitizedCompletionParams::is_sanitized_token(txt) {
110+
// NodeText::Replaced
111+
// } else {
112+
// NodeText::Original(txt)
113+
// }
114+
// })
115+
}
116+
_ => {}
117+
}
118+
119+
// 1. Go to its child and continue.
120+
if c.goto_first_child() {
121+
continue 'outer;
122+
}
123+
124+
// 2. We've reached a leaf (node without a child). We will go to a sibling.
125+
if c.goto_next_sibling() {
126+
continue 'outer;
127+
}
128+
129+
// 3. If there are no more siblings, we need to go back up.
130+
'inner: loop {
131+
// 4. Check if we've reached the root node. If so, we're done.
132+
if !c.goto_parent() {
133+
break 'outer;
134+
}
135+
// 5. Go to the previous node's sibling.
136+
if c.goto_next_sibling() {
137+
// And break out of the inner loop.
138+
break 'inner;
139+
}
140+
}
141+
}
142+
}
143+
}

0 commit comments

Comments
 (0)