Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 218 additions & 8 deletions src/discover/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use lazy_static::lazy_static;
use regex::{Regex, RegexSet};

use super::lexer::{split_on_operators, tokenize, TokenKind};
use super::lexer::{shell_split, split_on_operators, tokenize, TokenKind};
use super::rules::{IGNORED_EXACT, IGNORED_PREFIXES, RULES};

/// Result of classifying a command.
Expand Down Expand Up @@ -498,6 +498,10 @@ pub fn rewrite_command(
return None;
}

if has_input_redirect(trimmed) {
return None;
}

let compiled = compile_exclude_patterns(excluded);
let normalized_prefixes = normalize_transparent_prefixes(transparent_prefixes);

Expand Down Expand Up @@ -558,11 +562,19 @@ fn rewrite_compound(
}
TokenKind::Pipe => {
let seg = cmd[seg_start..tok.offset].trim();
let pipe_group_end = tokens.iter().find(|t| {
t.offset > tok.offset
&& (t.kind == TokenKind::Operator
|| (t.kind == TokenKind::Shellism && t.value == "&"))
});
let pipe_tail_end = pipe_group_end.map(|t| t.offset).unwrap_or(cmd.len());
let pipe_tail = cmd[tok.offset + tok.value.len()..pipe_tail_end].trim();
let has_stdin_consumer = pipe_tail_has_stdin_sensitive_consumer(pipe_tail);
let is_pipe_incompatible = seg.starts_with("find ")
|| seg == "find"
|| seg.starts_with("fd ")
|| seg == "fd";
let rewritten = if is_pipe_incompatible {
let rewritten = if is_pipe_incompatible || has_stdin_consumer {
seg.to_string()
} else {
rewrite_segment(seg, excluded, transparent_prefixes)
Expand All @@ -573,12 +585,6 @@ fn rewrite_compound(
}
result.push_str(&rewritten);

let pipe_group_end = tokens.iter().find(|t| {
t.offset > tok.offset
&& (t.kind == TokenKind::Operator
|| (t.kind == TokenKind::Shellism && t.value == "&"))
});

match pipe_group_end {
Some(next_op) => {
result.push(' ');
Expand Down Expand Up @@ -625,6 +631,129 @@ fn rewrite_compound(
}
}

fn has_input_redirect(cmd: &str) -> bool {
tokenize(cmd)
.iter()
.any(|t| t.kind == TokenKind::Redirect && t.value.starts_with('<'))
}

fn pipe_tail_has_stdin_sensitive_consumer(tail: &str) -> bool {
split_on_operators(tail, false)
.iter()
.any(|seg| is_stdin_sensitive_segment(seg))
}

fn is_stdin_sensitive_segment(seg: &str) -> bool {
is_stdin_sensitive_segment_inner(seg, 0)
}

fn is_stdin_sensitive_segment_inner(seg: &str, depth: usize) -> bool {
let trimmed = seg.trim();
if trimmed.is_empty() || depth >= MAX_PREFIX_DEPTH {
return false;
}

let (_env_prefix, rest_after_env) = strip_disabled_prefix(trimmed);
if rest_after_env != trimmed {
return is_stdin_sensitive_segment_inner(rest_after_env, depth + 1);
}

for &prefix in BUILTIN_TRANSPARENT_PREFIXES {
if let Some(rest) = strip_word_prefix(trimmed, prefix) {
return is_stdin_sensitive_segment_inner(rest, depth + 1);
}
}

let (cmd_part, _redirect_suffix) = strip_trailing_redirects(trimmed);
let args = shell_split(cmd_part);
if args.is_empty() {
return false;
}

let command = command_basename(&args[0]).to_ascii_lowercase();
match command.as_str() {
"kubectl" | "oc" => kubectl_args_read_stdin(&args),
"docker" | "podman" => docker_args_read_stdin(&args),
"git" | "yadm" => git_args_read_stdin(&args),
"wrangler" => wrangler_args_read_stdin(&args),
_ => false,
}
}

fn command_basename(command: &str) -> &str {
let base = command.rsplit(['/', '\\']).next().unwrap_or(command);
if base.len() >= 4 && base[base.len() - 4..].eq_ignore_ascii_case(".exe") {
&base[..base.len() - 4]
} else {
base
}
}

fn kubectl_args_read_stdin(args: &[String]) -> bool {
let Some(subcommand_idx) = args
.iter()
.position(|arg| matches!(arg.as_str(), "apply" | "create" | "replace" | "delete"))
else {
return false;
};
has_filename_stdin_arg(&args[subcommand_idx + 1..])
}

fn has_filename_stdin_arg(args: &[String]) -> bool {
let mut iter = args.iter().peekable();
while let Some(arg) = iter.next() {
match arg.as_str() {
"-f" | "--filename" => {
if iter.peek().is_some_and(|next| next.as_str() == "-") {
return true;
}
}
"-f-" | "--filename=-" => return true,
_ => {}
}
}
false
}

fn docker_args_read_stdin(args: &[String]) -> bool {
let Some(build_idx) = args.iter().position(|arg| arg == "build") else {
return false;
};
args[build_idx + 1..].iter().any(|arg| arg == "-")
}

fn git_args_read_stdin(args: &[String]) -> bool {
let mut i = 1;
while i < args.len() {
match args[i].as_str() {
"-C" | "-c" | "--git-dir" | "--work-tree" => i += 2,
"--no-pager" | "--no-optional-locks" | "--bare" | "--literal-pathspecs" => i += 1,
arg if arg.starts_with("--git-dir=") || arg.starts_with("--work-tree=") => i += 1,
arg if arg.starts_with('-') => i += 1,
_ => break,
}
}

if args.get(i).map(String::as_str) != Some("apply") {
return false;
}

let rest = &args[i + 1..];
rest.is_empty()
|| rest.iter().any(|arg| arg == "-")
|| rest.iter().all(|arg| arg.starts_with('-'))
}

fn wrangler_args_read_stdin(args: &[String]) -> bool {
matches!(
(
args.get(1).map(String::as_str),
args.get(2).map(String::as_str)
),
(Some("secret"), Some("put" | "bulk"))
)
}

fn rewrite_line_range(cmd: &str) -> Option<String> {
for re in [&*HEAD_N, &*HEAD_LINES] {
if let Some(caps) = re.captures(cmd) {
Expand Down Expand Up @@ -791,6 +920,10 @@ fn rewrite_segment_inner(
return Some(trimmed.to_string());
}

if is_stdin_sensitive_segment(cmd_part) {
return None;
}

if cmd_part.starts_with("head -") || cmd_part.starts_with("tail ") {
return rewrite_line_range(cmd_part).map(|r| format!("{}{}", r, redirect_suffix));
}
Expand Down Expand Up @@ -1523,6 +1656,83 @@ mod tests {
);
}

#[test]
fn test_rewrite_kubectl_apply_stdin_skipped() {
assert_eq!(rewrite_command_no_prefixes("kubectl apply -f -", &[]), None);
}

#[test]
fn test_rewrite_cat_pipe_kubectl_apply_stdin_skipped() {
assert_eq!(
rewrite_command_no_prefixes("cat manifest.yaml | kubectl apply -f -", &[]),
None
);
}

#[test]
fn test_rewrite_cat_pipe_wrangler_secret_bulk_skipped() {
assert_eq!(
rewrite_command_no_prefixes(
"cat secret.json | wrangler secret bulk --name worker",
&[]
),
None
);
}

#[test]
fn test_rewrite_wrangler_secret_put_stdin_skipped() {
assert_eq!(
rewrite_command_no_prefixes("wrangler secret put MY_KEY --name worker", &[]),
None
);
assert_eq!(
rewrite_command_no_prefixes(
"cat secret.txt | wrangler secret put MY_KEY --name worker",
&[]
),
None
);
}

#[test]
fn test_rewrite_input_redirect_skipped() {
assert_eq!(
rewrite_command_no_prefixes("psql postgres://db < dump.sql", &[]),
None
);
}

#[test]
fn test_rewrite_git_apply_stdin_skipped() {
assert_eq!(rewrite_command_no_prefixes("git apply", &[]), None);
assert_eq!(rewrite_command_no_prefixes("git apply -", &[]), None);
}

#[test]
fn test_rewrite_cat_pipe_docker_build_stdin_skipped() {
assert_eq!(
rewrite_command_no_prefixes("cat Dockerfile | docker build -", &[]),
None
);
}

#[test]
fn test_rewrite_kubectl_apply_file_still_rewritten() {
assert_eq!(
rewrite_command_no_prefixes("kubectl apply -f manifest.yaml", &[]),
Some("rtk kubectl apply -f manifest.yaml".into())
);
}

#[test]
fn test_rewrite_docker_build_context_still_rewritten() {
assert_eq!(
rewrite_command_no_prefixes("docker build .", &[]),
Some("rtk docker build .".into())
);
}

#[test]
fn test_rewrite_heredoc_returns_none() {
assert_eq!(
Expand Down
Loading