Skip to content

feat(completions): complete in WITH CHECK and USING clauses #422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 54 commits into from
Jun 12, 2025
Merged
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
daf8efd
sure
juleswritescode May 27, 2025
dcf2439
so far!
juleswritescode May 27, 2025
9abf459
so far
juleswritescode May 28, 2025
8d4836e
yayyy
juleswritescode May 28, 2025
363c2ea
setup roles
juleswritescode May 28, 2025
0e1f0a8
use distinct method
juleswritescode May 28, 2025
8e4d17c
better…
juleswritescode May 28, 2025
d1d8453
sqlx prepare
juleswritescode May 29, 2025
9021bc0
ok
juleswritescode May 29, 2025
0753b50
better
juleswritescode May 29, 2025
1be61b4
ok
juleswritescode May 29, 2025
e55271f
ok
juleswritescode May 29, 2025
9cd04dd
ok
juleswritescode May 29, 2025
7e1c565
ok
juleswritescode May 29, 2025
b37adda
ok
juleswritescode May 29, 2025
cc38757
adjust test
juleswritescode May 29, 2025
a1e1a9c
ok
juleswritescode May 29, 2025
403cf82
resolve conflicts
juleswritescode Jun 2, 2025
9c3184e
ok
juleswritescode Jun 2, 2025
299e469
ok
juleswritescode Jun 3, 2025
2d37803
quicksave
juleswritescode Jun 3, 2025
027324f
reading the card…
juleswritescode Jun 3, 2025
0dd285f
wowa wiwa
juleswritescode Jun 3, 2025
f72297a
ok
juleswritescode Jun 5, 2025
578741e
lowercase…
juleswritescode Jun 5, 2025
ef5cb98
wowa wiwa
juleswritescode Jun 5, 2025
9214c91
add tests
juleswritescode Jun 5, 2025
f2b4b44
linty
juleswritescode Jun 5, 2025
b5e82ed
format
juleswritescode Jun 5, 2025
e736bd0
Merge branch 'main' of https://github.com/supabase-community/postgres…
juleswritescode Jun 5, 2025
cab8ead
Merge branch 'main' into feat/to-role
juleswritescode Jun 5, 2025
f109ece
Merge branch 'main' into feat/to-role
juleswritescode Jun 6, 2025
4beb1f2
merged
juleswritescode Jun 6, 2025
3117d07
simplify word parser
juleswritescode Jun 6, 2025
6640417
infer position
juleswritescode Jun 6, 2025
4c04ba9
hm
juleswritescode Jun 6, 2025
9a4e9fc
ok
juleswritescode Jun 6, 2025
eb54d62
cool
juleswritescode Jun 6, 2025
e9205c4
ok
juleswritescode Jun 6, 2025
f0ffd1d
ok
juleswritescode Jun 6, 2025
c1cbed8
fix tests
juleswritescode Jun 6, 2025
e8dbfdf
Update crates/pgt_completions/src/relevance/filtering.rs
juleswritescode Jun 6, 2025
11a4f57
fix sad bug
juleswritescode Jun 6, 2025
142c8e8
Merge branch 'feat/check-using' of https://github.com/supabase-commun…
juleswritescode Jun 6, 2025
ef6df7d
Update crates/pgt_completions/src/sanitization.rs
juleswritescode Jun 6, 2025
db31640
lintci
juleswritescode Jun 9, 2025
4da85cf
Merge branch 'feat/check-using' of https://github.com/supabase-commun…
juleswritescode Jun 9, 2025
6f12fcc
fmt
juleswritescode Jun 9, 2025
673b8d6
merge with main
juleswritescode Jun 10, 2025
2379f0f
Merge branch 'main' into feat/check-using
juleswritescode Jun 11, 2025
c1050c8
Merge branch 'main' of https://github.com/supabase-community/postgres…
juleswritescode Jun 12, 2025
d9497dd
ok
juleswritescode Jun 12, 2025
f03af23
ack
juleswritescode Jun 12, 2025
5606cf3
Merge branch 'feat/check-using' of https://github.com/supabase-commun…
juleswritescode Jun 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 121 additions & 54 deletions crates/pgt_completions/src/context/base_parser.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::iter::Peekable;

use pgt_text_size::{TextRange, TextSize};
use std::iter::Peekable;

pub(crate) struct TokenNavigator {
tokens: Peekable<std::vec::IntoIter<WordWithIndex>>,
@@ -101,73 +100,139 @@ impl WordWithIndex {
}
}

/// Note: A policy name within quotation marks will be considered a single word.
pub(crate) fn sql_to_words(sql: &str) -> Result<Vec<WordWithIndex>, String> {
let mut words = vec![];

let mut start_of_word: Option<usize> = None;
let mut current_word = String::new();
let mut in_quotation_marks = false;

for (current_position, current_char) in sql.char_indices() {
if (current_char.is_ascii_whitespace() || current_char == ';')
&& !current_word.is_empty()
&& start_of_word.is_some()
&& !in_quotation_marks
{
words.push(WordWithIndex {
word: current_word,
start: start_of_word.unwrap(),
end: current_position,
});

current_word = String::new();
start_of_word = None;
} else if (current_char.is_ascii_whitespace() || current_char == ';')
&& current_word.is_empty()
{
// do nothing
} else if current_char == '"' && start_of_word.is_none() {
in_quotation_marks = true;
current_word.push(current_char);
start_of_word = Some(current_position);
} else if current_char == '"' && start_of_word.is_some() {
current_word.push(current_char);
in_quotation_marks = false;
} else if start_of_word.is_some() {
current_word.push(current_char)
pub(crate) struct SubStatementParser {
start_of_word: Option<usize>,
current_word: String,
in_quotation_marks: bool,
is_fn_call: bool,
words: Vec<WordWithIndex>,
}

impl SubStatementParser {
pub(crate) fn parse(sql: &str) -> Result<Vec<WordWithIndex>, String> {
let mut parser = SubStatementParser {
start_of_word: None,
current_word: String::new(),
in_quotation_marks: false,
is_fn_call: false,
words: vec![],
};

parser.collect_words(sql);

if parser.in_quotation_marks {
Err("String was not closed properly.".into())
} else {
start_of_word = Some(current_position);
current_word.push(current_char);
Ok(parser.words)
}
}

if let Some(start_of_word) = start_of_word {
if !current_word.is_empty() {
words.push(WordWithIndex {
word: current_word,
start: start_of_word,
end: sql.len(),
});
pub fn collect_words(&mut self, sql: &str) {
for (pos, c) in sql.char_indices() {
match c {
'"' => {
if !self.has_started_word() {
self.in_quotation_marks = true;
self.add_char(c);
self.start_word(pos);
} else {
self.in_quotation_marks = false;
self.add_char(c);
}
}

'(' => {
if !self.has_started_word() {
self.push_char_as_word(c, pos);
} else {
self.add_char(c);
self.is_fn_call = true;
}
}

')' => {
if self.is_fn_call {
self.add_char(c);
self.is_fn_call = false;
} else {
if self.has_started_word() {
self.push_word(pos);
}
self.push_char_as_word(c, pos);
}
}

_ => {
if c.is_ascii_whitespace() || c == ';' {
if self.in_quotation_marks {
self.add_char(c);
} else if !self.is_empty() && self.has_started_word() {
self.push_word(pos);
}
} else if self.has_started_word() {
self.add_char(c);
} else {
self.start_word(pos);
self.add_char(c)
}
}
}
}

if self.has_started_word() && !self.is_empty() {
self.push_word(sql.len())
}
}

if in_quotation_marks {
Err("String was not closed properly.".into())
} else {
Ok(words)
fn is_empty(&self) -> bool {
self.current_word.is_empty()
}

fn add_char(&mut self, c: char) {
self.current_word.push(c)
}

fn start_word(&mut self, pos: usize) {
self.start_of_word = Some(pos);
}

fn has_started_word(&self) -> bool {
self.start_of_word.is_some()
}

fn push_char_as_word(&mut self, c: char, pos: usize) {
self.words.push(WordWithIndex {
word: String::from(c),
start: pos,
end: pos + 1,
});
}

fn push_word(&mut self, current_position: usize) {
self.words.push(WordWithIndex {
word: self.current_word.clone(),
start: self.start_of_word.unwrap(),
end: current_position,
});
self.current_word = String::new();
self.start_of_word = None;
}
}

/// Note: A policy name within quotation marks will be considered a single word.
pub(crate) fn sql_to_words(sql: &str) -> Result<Vec<WordWithIndex>, String> {
SubStatementParser::parse(sql)
}

#[cfg(test)]
mod tests {
use crate::context::base_parser::{WordWithIndex, sql_to_words};
use crate::context::base_parser::{SubStatementParser, WordWithIndex, sql_to_words};

#[test]
fn determines_positions_correctly() {
let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string();
let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (auth.uid());".to_string();

let words = sql_to_words(query.as_str()).unwrap();
let words = SubStatementParser::parse(query.as_str()).unwrap();

assert_eq!(words[0], to_word("create", 1, 7));
assert_eq!(words[1], to_word("policy", 8, 14));
@@ -181,7 +246,9 @@ mod tests {
assert_eq!(words[9], to_word("to", 73, 75));
assert_eq!(words[10], to_word("public", 78, 84));
assert_eq!(words[11], to_word("using", 87, 92));
assert_eq!(words[12], to_word("(true)", 93, 99));
assert_eq!(words[12], to_word("(", 93, 94));
assert_eq!(words[13], to_word("auth.uid()", 94, 104));
assert_eq!(words[14], to_word(")", 104, 105));
}

#[test]
27 changes: 25 additions & 2 deletions crates/pgt_completions/src/context/mod.rs
Original file line number Diff line number Diff line change
@@ -47,6 +47,15 @@ pub enum WrappingClause<'a> {
SetStatement,
AlterRole,
DropRole,

/// `PolicyCheck` refers to either the `WITH CHECK` or the `USING` clause
/// in a policy statement.
/// ```sql
/// CREATE POLICY "my pol" ON PUBLIC.USERS
/// FOR SELECT
/// USING (...) -- this one!
/// ```
PolicyCheck,
}

#[derive(PartialEq, Eq, Hash, Debug, Clone)]
@@ -78,6 +87,7 @@ pub(crate) enum NodeUnderCursor<'a> {
text: NodeText,
range: TextRange,
kind: String,
previous_node_kind: Option<String>,
},
}

@@ -222,6 +232,7 @@ impl<'a> CompletionContext<'a> {
text: revoke_context.node_text.into(),
range: revoke_context.node_range,
kind: revoke_context.node_kind.clone(),
previous_node_kind: None,
});

if revoke_context.node_kind == "revoke_table" {
@@ -249,6 +260,7 @@ impl<'a> CompletionContext<'a> {
text: grant_context.node_text.into(),
range: grant_context.node_range,
kind: grant_context.node_kind.clone(),
previous_node_kind: None,
});

if grant_context.node_kind == "grant_table" {
@@ -276,6 +288,7 @@ impl<'a> CompletionContext<'a> {
text: policy_context.node_text.into(),
range: policy_context.node_range,
kind: policy_context.node_kind.clone(),
previous_node_kind: Some(policy_context.previous_node_kind),
});

if policy_context.node_kind == "policy_table" {
@@ -295,7 +308,13 @@ impl<'a> CompletionContext<'a> {
}
"policy_role" => Some(WrappingClause::ToRoleAssignment),
"policy_table" => Some(WrappingClause::From),
_ => None,
_ => {
if policy_context.in_check_or_using_clause {
Some(WrappingClause::PolicyCheck)
} else {
None
}
}
};
}

@@ -785,7 +804,11 @@ impl<'a> CompletionContext<'a> {
.is_some_and(|sib| kinds.contains(&sib.kind()))
}

NodeUnderCursor::CustomNode { .. } => false,
NodeUnderCursor::CustomNode {
previous_node_kind, ..
} => previous_node_kind
.as_ref()
.is_some_and(|k| kinds.contains(&k.as_str())),
}
})
}
187 changes: 181 additions & 6 deletions crates/pgt_completions/src/context/policy_parser.rs
Original file line number Diff line number Diff line change
@@ -22,6 +22,10 @@ pub(crate) struct PolicyContext {
pub node_text: String,
pub node_range: TextRange,
pub node_kind: String,
pub previous_node_text: String,
pub previous_node_range: TextRange,
pub previous_node_kind: String,
pub in_check_or_using_clause: bool,
}

/// Simple parser that'll turn a policy-related statement into a context object required for
@@ -32,6 +36,7 @@ pub(crate) struct PolicyParser {
navigator: TokenNavigator,
context: PolicyContext,
cursor_position: usize,
in_check_or_using_clause: bool,
}

impl CompletionStatementParser for PolicyParser {
@@ -63,6 +68,7 @@ impl CompletionStatementParser for PolicyParser {
navigator: tokens.into(),
context: PolicyContext::default(),
cursor_position,
in_check_or_using_clause: false,
}
}
}
@@ -73,6 +79,8 @@ impl PolicyParser {
return;
}

self.context.in_check_or_using_clause = self.in_check_or_using_clause;

let previous = self.navigator.previous_token.take().unwrap();

match previous
@@ -84,6 +92,8 @@ impl PolicyParser {
self.context.node_range = token.get_range();
self.context.node_kind = "policy_name".into();
self.context.node_text = token.get_word();

self.context.previous_node_kind = "keyword_policy".into();
}
"on" => {
if token.get_word_without_quotes().contains('.') {
@@ -112,17 +122,35 @@ impl PolicyParser {
self.context.node_text = token.get_word();
self.context.node_kind = "policy_table".into();
}

self.context.previous_node_kind = "keyword_on".into();
}
"to" => {
self.context.node_range = token.get_range();
self.context.node_kind = "policy_role".into();
self.context.node_text = token.get_word();

self.context.previous_node_kind = "keyword_to".into();
}
_ => {

other => {
self.context.node_range = token.get_range();
self.context.node_text = token.get_word();

self.context.previous_node_range = previous.get_range();
self.context.previous_node_text = previous.get_word();

match other {
"(" | "=" => self.context.previous_node_kind = other.into(),
"and" => self.context.previous_node_kind = "keyword_and".into(),

_ => self.context.previous_node_kind = "".into(),
}
}
}

self.context.previous_node_range = previous.get_range();
self.context.previous_node_text = previous.get_word();
}

fn handle_token(&mut self, token: WordWithIndex) {
@@ -142,6 +170,13 @@ impl PolicyParser {
}
"on" => self.table_with_schema(),

"(" if self.navigator.prev_matches(&["using", "check"]) => {
self.in_check_or_using_clause = true;
}
")" => {
self.in_check_or_using_clause = false;
}

// skip the "to" so we don't parse it as the TO rolename when it's under the cursor
"rename" if self.navigator.next_matches(&["to"]) => {
self.navigator.advance();
@@ -218,7 +253,11 @@ mod tests {
statement_kind: PolicyStmtKind::Create,
node_text: "REPLACED_TOKEN".into(),
node_range: TextRange::new(TextSize::new(25), TextSize::new(39)),
node_kind: "policy_name".into()
node_kind: "policy_name".into(),
in_check_or_using_clause: false,
previous_node_kind: "keyword_policy".into(),
previous_node_range: TextRange::new(18.into(), 24.into()),
previous_node_text: "policy".into(),
}
);

@@ -241,6 +280,10 @@ mod tests {
node_text: "REPLACED_TOKEN".into(),
node_kind: "".into(),
node_range: TextRange::new(TextSize::new(42), TextSize::new(56)),
in_check_or_using_clause: false,
previous_node_kind: "".into(),
previous_node_range: TextRange::new(25.into(), 41.into()),
previous_node_text: "\"my cool policy\"".into(),
}
);

@@ -263,6 +306,10 @@ mod tests {
node_text: "REPLACED_TOKEN".into(),
node_kind: "policy_table".into(),
node_range: TextRange::new(TextSize::new(45), TextSize::new(59)),
in_check_or_using_clause: false,
previous_node_kind: "keyword_on".into(),
previous_node_range: TextRange::new(42.into(), 44.into()),
previous_node_text: "on".into(),
}
);

@@ -285,6 +332,10 @@ mod tests {
node_text: "REPLACED_TOKEN".into(),
node_kind: "policy_table".into(),
node_range: TextRange::new(TextSize::new(50), TextSize::new(64)),
in_check_or_using_clause: false,
previous_node_kind: "keyword_on".into(),
previous_node_range: TextRange::new(42.into(), 44.into()),
previous_node_text: "on".into(),
}
);

@@ -308,6 +359,10 @@ mod tests {
node_text: "REPLACED_TOKEN".into(),
node_kind: "".into(),
node_range: TextRange::new(TextSize::new(72), TextSize::new(86)),
in_check_or_using_clause: false,
previous_node_kind: "".into(),
previous_node_range: TextRange::new(69.into(), 71.into()),
previous_node_text: "as".into(),
}
);

@@ -332,6 +387,10 @@ mod tests {
node_text: "REPLACED_TOKEN".into(),
node_kind: "".into(),
node_range: TextRange::new(TextSize::new(95), TextSize::new(109)),
in_check_or_using_clause: false,
previous_node_kind: "".into(),
previous_node_range: TextRange::new(72.into(), 82.into()),
previous_node_text: "permissive".into(),
}
);

@@ -356,6 +415,10 @@ mod tests {
node_text: "REPLACED_TOKEN".into(),
node_kind: "policy_role".into(),
node_range: TextRange::new(TextSize::new(98), TextSize::new(112)),
in_check_or_using_clause: false,
previous_node_kind: "keyword_to".into(),
previous_node_range: TextRange::new(95.into(), 97.into()),
previous_node_text: "to".into(),
}
);
}
@@ -383,7 +446,11 @@ mod tests {
statement_kind: PolicyStmtKind::Create,
node_text: "REPLACED_TOKEN".into(),
node_range: TextRange::new(TextSize::new(57), TextSize::new(71)),
node_kind: "policy_table".into()
node_kind: "policy_table".into(),
in_check_or_using_clause: false,
previous_node_kind: "keyword_on".into(),
previous_node_range: TextRange::new(54.into(), 56.into()),
previous_node_text: "on".into(),
}
)
}
@@ -411,7 +478,11 @@ mod tests {
statement_kind: PolicyStmtKind::Create,
node_text: "REPLACED_TOKEN".into(),
node_range: TextRange::new(TextSize::new(62), TextSize::new(76)),
node_kind: "policy_table".into()
node_kind: "policy_table".into(),
in_check_or_using_clause: false,
previous_node_kind: "keyword_on".into(),
previous_node_range: TextRange::new(54.into(), 56.into()),
previous_node_text: "on".into(),
}
)
}
@@ -436,7 +507,11 @@ mod tests {
statement_kind: PolicyStmtKind::Drop,
node_text: "REPLACED_TOKEN".into(),
node_range: TextRange::new(TextSize::new(23), TextSize::new(37)),
node_kind: "policy_name".into()
node_kind: "policy_name".into(),
in_check_or_using_clause: false,
previous_node_kind: "keyword_policy".into(),
previous_node_range: TextRange::new(16.into(), 22.into()),
previous_node_text: "policy".into(),
}
);

@@ -459,7 +534,11 @@ mod tests {
statement_kind: PolicyStmtKind::Drop,
node_text: "\"REPLACED_TOKEN\"".into(),
node_range: TextRange::new(TextSize::new(23), TextSize::new(39)),
node_kind: "policy_name".into()
node_kind: "policy_name".into(),
in_check_or_using_clause: false,
previous_node_kind: "keyword_policy".into(),
previous_node_range: TextRange::new(16.into(), 22.into()),
previous_node_text: "policy".into(),
}
);
}
@@ -477,4 +556,100 @@ mod tests {

assert_eq!(context, PolicyContext::default());
}

#[test]
fn correctly_determines_we_are_inside_checks() {
{
let (pos, query) = with_pos(format!(
r#"
create policy "my cool policy"
on auth.users
to all
using (id = {})
"#,
CURSOR_POS
));

let context = PolicyParser::get_context(query.as_str(), pos);

assert_eq!(
context,
PolicyContext {
policy_name: Some(r#""my cool policy""#.into()),
table_name: Some("users".into()),
schema_name: Some("auth".into()),
statement_kind: PolicyStmtKind::Create,
node_text: "REPLACED_TOKEN".into(),
node_range: TextRange::new(TextSize::new(112), TextSize::new(126)),
node_kind: "".into(),
in_check_or_using_clause: true,
previous_node_kind: "=".into(),
previous_node_range: TextRange::new(110.into(), 111.into()),
previous_node_text: "=".into(),
}
);
}

{
let (pos, query) = with_pos(format!(
r#"
create policy "my cool policy"
on auth.users
to all
using ({}
"#,
CURSOR_POS
));

let context = PolicyParser::get_context(query.as_str(), pos);

assert_eq!(
context,
PolicyContext {
policy_name: Some(r#""my cool policy""#.into()),
table_name: Some("users".into()),
schema_name: Some("auth".into()),
statement_kind: PolicyStmtKind::Create,
node_text: "REPLACED_TOKEN".into(),
node_range: TextRange::new(TextSize::new(106), TextSize::new(120)),
node_kind: "".into(),
in_check_or_using_clause: true,
previous_node_kind: "(".into(),
previous_node_range: TextRange::new(105.into(), 106.into()),
previous_node_text: "(".into(),
}
)
}

{
let (pos, query) = with_pos(format!(
r#"
create policy "my cool policy"
on auth.users
to all
with check ({}
"#,
CURSOR_POS
));

let context = PolicyParser::get_context(query.as_str(), pos);

assert_eq!(
context,
PolicyContext {
policy_name: Some(r#""my cool policy""#.into()),
table_name: Some("users".into()),
schema_name: Some("auth".into()),
statement_kind: PolicyStmtKind::Create,
node_text: "REPLACED_TOKEN".into(),
node_range: TextRange::new(TextSize::new(111), TextSize::new(125)),
node_kind: "".into(),
in_check_or_using_clause: true,
previous_node_kind: "(".into(),
previous_node_range: TextRange::new(110.into(), 111.into()),
previous_node_text: "(".into(),
}
)
}
}
}
48 changes: 48 additions & 0 deletions crates/pgt_completions/src/providers/columns.rs
Original file line number Diff line number Diff line change
@@ -817,4 +817,52 @@ mod tests {
.await;
}
}

#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
async fn suggests_columns_policy_using_clause(pool: PgPool) {
let setup = r#"
create table instruments (
id bigint primary key generated always as identity,
name text not null,
z text,
created_at timestamp with time zone default now()
);
"#;

pool.execute(setup).await.unwrap();

let col_queries = vec![
format!(
r#"create policy "my_pol" on public.instruments for select using ({})"#,
CURSOR_POS
),
format!(
r#"create policy "my_pol" on public.instruments for insert with check ({})"#,
CURSOR_POS
),
format!(
r#"create policy "my_pol" on public.instruments for update using (id = 1 and {})"#,
CURSOR_POS
),
format!(
r#"create policy "my_pol" on public.instruments for insert with check (id = 1 and {})"#,
CURSOR_POS
),
];

for query in col_queries {
assert_complete_results(
query.as_str(),
vec![
CompletionAssertion::Label("created_at".into()),
CompletionAssertion::Label("id".into()),
CompletionAssertion::Label("name".into()),
CompletionAssertion::Label("z".into()),
],
None,
&pool,
)
.await;
}
}
}
87 changes: 85 additions & 2 deletions crates/pgt_completions/src/providers/functions.rs
Original file line number Diff line number Diff line change
@@ -65,11 +65,14 @@ fn get_completion_text(ctx: &CompletionContext, func: &Function) -> CompletionTe

#[cfg(test)]
mod tests {
use sqlx::PgPool;
use sqlx::{Executor, PgPool};

use crate::{
CompletionItem, CompletionItemKind, complete,
test_helper::{CURSOR_POS, get_test_deps, get_test_params},
test_helper::{
CURSOR_POS, CompletionAssertion, assert_complete_results, get_test_deps,
get_test_params,
},
};

#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
@@ -201,4 +204,84 @@ mod tests {
assert_eq!(label, "cool");
assert_eq!(kind, CompletionItemKind::Function);
}

#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
async fn only_allows_functions_and_procedures_in_policy_checks(pool: PgPool) {
let setup = r#"
create table coos (
id serial primary key,
name text
);
create or replace function my_cool_foo()
returns trigger
language plpgsql
security invoker
as $$
begin
raise exception 'dont matter';
end;
$$;
create or replace procedure my_cool_proc()
language plpgsql
security invoker
as $$
begin
raise exception 'dont matter';
end;
$$;
create or replace function string_concat_state(
state text,
value text,
separator text)
returns text
language plpgsql
as $$
begin
if state is null then
return value;
else
return state || separator || value;
end if;
end;
$$;
create aggregate string_concat(text, text) (
sfunc = string_concat_state,
stype = text,
initcond = ''
);
"#;

pool.execute(setup).await.unwrap();

let query = format!(
r#"create policy "my_pol" on public.instruments for insert with check (id = {})"#,
CURSOR_POS
);

assert_complete_results(
query.as_str(),
vec![
CompletionAssertion::LabelNotExists("string_concat".into()),
CompletionAssertion::LabelAndKind(
"my_cool_foo".into(),
CompletionItemKind::Function,
),
CompletionAssertion::LabelAndKind(
"my_cool_proc".into(),
CompletionItemKind::Function,
),
CompletionAssertion::LabelAndKind(
"string_concat_state".into(),
CompletionItemKind::Function,
),
],
None,
&pool,
)
.await;
}
}
24 changes: 18 additions & 6 deletions crates/pgt_completions/src/relevance/filtering.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use pgt_schema_cache::ProcKind;

use crate::context::{CompletionContext, NodeUnderCursor, WrappingClause, WrappingNode};

use super::CompletionRelevanceData;
@@ -137,17 +139,27 @@ impl CompletionFilter<'_> {
&& ctx.parent_matches_one_of_kind(&["field"]))
}

WrappingClause::PolicyCheck => {
ctx.before_cursor_matches_kind(&["keyword_and", "("])
}

_ => false,
}
}

CompletionRelevanceData::Function(_) => matches!(
clause,
CompletionRelevanceData::Function(f) => match clause {
WrappingClause::From
| WrappingClause::Select
| WrappingClause::Where
| WrappingClause::Join { .. }
),
| WrappingClause::Select
| WrappingClause::Where
| WrappingClause::Join { .. } => true,

WrappingClause::PolicyCheck => {
ctx.before_cursor_matches_kind(&["="])
&& matches!(f.kind, ProcKind::Function | ProcKind::Procedure)
}

_ => false,
},

CompletionRelevanceData::Schema(_) => match clause {
WrappingClause::Select
28 changes: 25 additions & 3 deletions crates/pgt_completions/src/sanitization.rs
Original file line number Diff line number Diff line change
@@ -257,10 +257,15 @@ fn cursor_between_parentheses(sql: &str, position: TextSize) -> bool {
.find(|c| !c.is_whitespace())
.unwrap_or_default();

let before_matches = before == ',' || before == '(';
let after_matches = after == ',' || after == ')';
// (.. and |)
let after_and_keyword = &sql[position.saturating_sub(4)..position] == "and " && after == ')';
let after_eq_sign = before == '=' && after == ')';

before_matches && after_matches
let head_of_list = before == '(' && after == ',';
let end_of_list = before == ',' && after == ')';
let between_list_items = before == ',' && after == ',';

head_of_list || end_of_list || between_list_items || after_and_keyword || after_eq_sign
}

#[cfg(test)]
@@ -444,5 +449,22 @@ mod tests {
"insert into instruments (name) values (a_function(name, ))",
TextSize::new(56)
));

// will sanitize after =
assert!(cursor_between_parentheses(
// create policy my_pol on users using (id = |),
"create policy my_pol on users using (id = )",
TextSize::new(42)
));

// will sanitize after and
assert!(cursor_between_parentheses(
// create policy my_pol on users using (id = 1 and |),
"create policy my_pol on users using (id = 1 and )",
TextSize::new(48)
));

// does not break if sql is really short
assert!(!cursor_between_parentheses("(a)", TextSize::new(2)));
}
}
2 changes: 1 addition & 1 deletion crates/pgt_schema_cache/src/lib.rs
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@ mod types;
mod versions;

pub use columns::*;
pub use functions::{Behavior, Function, FunctionArg, FunctionArgs};
pub use functions::{Behavior, Function, FunctionArg, FunctionArgs, ProcKind};
pub use policies::{Policy, PolicyCommand};
pub use roles::*;
pub use schema_cache::SchemaCache;