diff --git a/lua/rest-nvim/parser/curl.lua b/lua/rest-nvim/parser/curl.lua new file mode 100644 index 00000000..9be2d941 --- /dev/null +++ b/lua/rest-nvim/parser/curl.lua @@ -0,0 +1,140 @@ +---@mod rest-nvim.parser.curl rest.nvim curl parsing module +--- +---@brief [[ +--- +--- rest.nvim curl command parsing module +--- rest.nvim uses `tree-sitter-bash` as a core parser to parse raw curl commands +--- +---@brief ]] + +local curl_parser = {} + +local utils = require("rest-nvim.utils") +local logger = require("rest-nvim.logger") + +---@param node TSNode Tree-sitter request node +---@param source Source +function curl_parser.parse_command(node, source) + assert(node:type() == "command") + assert(utils.ts_field_text(node, "name", source) == "curl") + local arg_nodes = node:field("argument") + if #arg_nodes < 1 then + logger.error("can't parse curl command with 0 arguments") + return + end + local args = {} + for _, arg_node in ipairs(arg_nodes) do + local arg_type = arg_node:type() + if arg_type == "word" then + table.insert(args, vim.treesitter.get_node_text(arg_node, source)) + elseif arg_type == "raw_string" then + -- FIXME: expand escaped sequences like `\n` + table.insert(args, vim.treesitter.get_node_text(arg_node, source):sub(2, -2)) + else + logger.error(("can't parse argument type: '%s'"):format(arg_type)) + return + end + end + return args +end + +-- -X, --request +-- The request method to use. +-- -H, --header +-- The request header to include in the request. +-- -u, --user | --basic | --digest +-- The user's credentials to be provided with the request, and the authorization method to use. +-- -d, --data, --data-ascii | --data-binary | --data-raw | --data-urlencode +-- The data to be sent in a POST request. +-- -F, --form +-- The multipart/form-data message to be sent in a POST request. +-- --url +-- The URL to fetch (mostly used when specifying URLs in a config file). +-- -i, --include +-- Defines whether the HTTP response headers are included in the output. +-- -v, --verbose +-- Enables the verbose operating mode. +-- -L, --location +-- Enables resending the request in case the requested page has moved to a different location. + +---@param args string[] +function curl_parser.parse_arguments(args) + local iter = vim.iter(args) + ---@type rest.Request + local req = { + -- TODO: add this to rest.Request type + meta = { + redirect = false, + }, + url = "", + method = "GET", + headers = {}, + cookies = {}, + handlers = {}, + } + local function any(value, list) + return vim.list_contains(list, value) + end + while true do + local arg = iter:next() + if not arg then + break + end + if any(arg, { "-X", "--request" }) then + req.method = iter:next() + elseif any(arg, { "-H", "--header" }) then + local pair = iter:next() + local key, value = pair:match("(%S+):%s*(.*)") + if not key then + logger.error("can't parse header:" .. pair) + else + key = key:lower() + req.headers[key] = req.headers[key] or {} + if value then + table.insert(req.headers[key], value) + end + end + -- TODO: handle more arguments + -- elseif any(arg, { "-u", "--user" }) then + -- elseif arg == "--basic" then + -- elseif arg == "--digest" then + elseif any(arg, { "-d", "--data", "--data-ascii", "--data-raw" }) then + -- handle external body with `@` syntax + local body = iter:next() + if arg ~= "--data-raw" and body:sub(1, 1) == "@" then + req.body = { + __TYPE = "external", + data = { + name = "", + path = body:sub(2), + }, + } + else + req.body = { + __TYPE = "raw", + data = body + } + end + -- elseif arg == "--data-binary" then + -- elseif any(arg, { "-F", "--form" }) then + elseif arg == "--url" then + req.url = iter:next() + elseif any(arg, { "-L", "--location" }) then + req.meta.redirect = true + elseif arg:match("^-%a+$") then + local flags_iter = vim.gsplit(arg:sub(2), "") + for flag in flags_iter do + if flag == "L" then + req.meta.redirect = true + end + end + elseif req.url == "" and not vim.startswith(arg, "-") then + req.url = arg + else + logger.warn("unknown argument: " .. arg) + end + end + return req +end + +return curl_parser diff --git a/lua/rest-nvim/parser/init.lua b/lua/rest-nvim/parser/init.lua index 5550c9be..69ab1b0a 100644 --- a/lua/rest-nvim/parser/init.lua +++ b/lua/rest-nvim/parser/init.lua @@ -32,15 +32,6 @@ local NAMED_REQUEST_QUERY = vim.treesitter.query.parse( ]] ) ----@param node TSNode ----@param field string ----@param source Source ----@return string|nil -local function get_node_field_text(node, field, source) - local n = node:field(field)[1] - return n and vim.treesitter.get_node_text(n, source) or nil -end - ---@param src string ---@param context rest.Context ---@return string @@ -63,8 +54,8 @@ local function parse_headers(req_node, source, context) end) local header_nodes = req_node:field("header") for _, node in ipairs(header_nodes) do - local key = assert(get_node_field_text(node, "name", source)) - local value = get_node_field_text(node, "value", source) + local key = assert(utils.ts_field_text(node, "name", source)) + local value = utils.ts_field_text(node, "value", source) key = expand_variables(key, context):lower() if value then value = expand_variables(value, context) @@ -106,6 +97,7 @@ local function parse_urlencoded_form(str) logger.error(("Error while parsing query '%s' from urlencoded form '%s'"):format(query_pairs, str)) return nil end + -- TODO: encode value here return vim.trim(key) .. "=" .. vim.trim(value) end) :join("&") @@ -122,7 +114,7 @@ function parser.parse_body(content_type, body_node, source, context) ---@cast body rest.Request.Body if node_type == "external_body" then body.__TYPE = "external" - local path = assert(get_node_field_text(body_node, "path", source)) + local path = assert(utils.ts_field_text(body_node, "path", source)) if type(source) ~= "number" then logger.error("can't parse external body on non-existing http file") return @@ -133,7 +125,7 @@ function parser.parse_body(content_type, body_node, source, context) basepath = basepath:gsub("^" .. vim.pesc(vim.uv.cwd() .. "/"), "") path = vim.fs.normalize(vim.fs.joinpath(basepath, path)) body.data = { - name = get_node_field_text(body_node, "name", source), + name = utils.ts_field_text(body_node, "name", source), path = path, } elseif node_type == "json_body" or content_type == "application/json" then @@ -217,7 +209,7 @@ end ---@param source Source ---@return TSNode[] function parser.get_all_request_nodes(source) - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local result = {} for node, _ in tree:root():iter_children() do if node:type() == "section" and #node:field("request") > 0 then @@ -230,7 +222,7 @@ end ---@return TSNode? function parser.get_request_node_by_name(name) local source = 0 - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local query = NAMED_REQUEST_QUERY for id, node, _metadata, _match in query:iter_captures(tree:root(), source) do local capture_name = query.captures[id] @@ -248,8 +240,8 @@ end ---@param ctx rest.Context function parser.parse_variable_declaration(vd_node, source, ctx) vim.validate({ node = utils.ts_node_spec(vd_node, "variable_declaration") }) - local name = assert(get_node_field_text(vd_node, "name", source)) - local value = vim.trim(assert(get_node_field_text(vd_node, "value", source))) + local name = assert(utils.ts_field_text(vd_node, "name", source)) + local value = vim.trim(assert(utils.ts_field_text(vd_node, "value", source))) value = expand_variables(value, ctx) ctx:set_global(name, value) end @@ -261,8 +253,8 @@ end local function parse_script(node, source) local lang = "javascript" local prev_node = utils.ts_upper_node(node) - if prev_node and prev_node:type() == "comment" and get_node_field_text(prev_node, "name", source) == "lang" then - local value = get_node_field_text(prev_node, "value", source) + if prev_node and prev_node:type() == "comment" and utils.ts_field_text(prev_node, "name", source) == "lang" then + local value = utils.ts_field_text(prev_node, "value", source) if value then lang = value end @@ -304,7 +296,7 @@ end ---@param source Source ---@return string[] function parser.get_request_names(source) - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local query = NAMED_REQUEST_QUERY local result = {} for id, node, _metadata, _match in query:iter_captures(tree:root(), source) do @@ -365,7 +357,7 @@ function parser.parse(node, source, ctx) local start_row = node:range() parser.eval_context(source, ctx, start_row) end - local method = get_node_field_text(req_node, "method", source) + local method = utils.ts_field_text(req_node, "method", source) if not method then logger.info("no method provided, falling back to 'GET'") method = "GET" @@ -379,7 +371,7 @@ function parser.parse(node, source, ctx) for child, _ in node:iter_children() do local child_type = child:type() if child_type == "request" then - url = expand_variables(assert(get_node_field_text(req_node, "url", source)), ctx) + url = expand_variables(assert(utils.ts_field_text(req_node, "url", source)), ctx) url = url:gsub("\n%s+", "") elseif child_type == "pre_request_script" then parser.parse_pre_request_script(child, source, ctx) @@ -390,9 +382,9 @@ function parser.parse(node, source, ctx) table.insert(handlers, handler) end elseif child_type == "request_separator" then - name = get_node_field_text(child, "value", source) - elseif child_type == "comment" and get_node_field_text(child, "name", source) == "name" then - name = get_node_field_text(child, "value", source) or name + name = utils.ts_field_text(child, "value", source) + elseif child_type == "comment" and utils.ts_field_text(child, "name", source) == "name" then + name = utils.ts_field_text(child, "value", source) or name elseif child_type == "variable_declaration" then parser.parse_variable_declaration(child, source, ctx) end @@ -455,7 +447,7 @@ function parser.parse(node, source, ctx) name = name, method = method, url = url, - http_version = get_node_field_text(req_node, "version", source), + http_version = utils.ts_field_text(req_node, "version", source), headers = headers, cookies = {}, body = body, diff --git a/lua/rest-nvim/utils.lua b/lua/rest-nvim/utils.lua index 0071c4aa..cefbbf5c 100644 --- a/lua/rest-nvim/utils.lua +++ b/lua/rest-nvim/utils.lua @@ -7,7 +7,6 @@ ---@brief ]] local logger = require("rest-nvim.logger") --- local config = require("rest-nvim.config") local utils = {} @@ -86,11 +85,11 @@ end function utils.parse_http_time(time_str) local pattern = "(%a+), (%d+)[%s-](%a+)[%s-](%d+) (%d+):(%d+):(%d+) GMT" local _, day, month_name, year, hour, min, sec = time_str:match(pattern) - -- stylua: ignore - local months = { - Jan = 1, Feb = 2, Mar = 3, Apr = 4, May = 5, Jun = 6, - Jul = 7, Aug = 8, Sep = 9, Oct = 10, Nov = 11, Dec = 12, - } + -- stylua: ignore + local months = { + Jan = 1, Feb = 2, Mar = 3, Apr = 4, May = 5, Jun = 6, + Jul = 7, Aug = 8, Sep = 9, Oct = 10, Nov = 11, Dec = 12, + } local time_table = { year = tonumber(year), month = months[month_name], @@ -186,20 +185,21 @@ function utils.ts_highlight_node(bufnr, node, ns, timeout) end ---@param source string|integer +---@param lang string ---@return vim.treesitter.LanguageTree -function utils.ts_get_parser(source) +function utils.ts_get_parser(source, lang) if type(source) == "string" then - return vim.treesitter.get_string_parser(source, "http") + return vim.treesitter.get_string_parser(source, lang) else - return vim.treesitter.get_parser(source, "http") + return vim.treesitter.get_parser(source, lang) end end ---@param source string|integer ---@return vim.treesitter.LanguageTree ---@return TSTree -function utils.ts_parse_source(source) - local ts_parser = utils.ts_get_parser(source) +function utils.ts_parse_source(source, lang) + local ts_parser = utils.ts_get_parser(source, lang) return ts_parser, assert(ts_parser:parse(false)[1]) end @@ -238,6 +238,15 @@ function utils.ts_upper_node(node) return min_node end +---@param node TSNode +---@param field string +---@param source Source +---@return string|nil +function utils.ts_field_text(node, field, source) + local n = node:field(field)[1] + return n and vim.treesitter.get_node_text(n, source) or nil +end + ---@param node TSNode ---@param expected_type string ---@return table diff --git a/spec/examples/examples_spec.lua b/spec/examples/examples_spec.lua index 46f84c2f..ed217b3a 100644 --- a/spec/examples/examples_spec.lua +++ b/spec/examples/examples_spec.lua @@ -14,7 +14,7 @@ end describe("multi-line-url", function() it("line breaks should be ignored", function() local source = open("spec/examples/multi_line_url.http") - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local req_node = assert(tree:root():child(0)) local req = parser.parse(req_node, source) assert.not_nil(req) @@ -108,7 +108,7 @@ describe("builtin request hooks", function() describe("set_content_type", function() it("with external body", function() local source = open("spec/examples/post_with_external_body.http") - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local req_node = assert(tree:root():child(0)) local req = assert(parser.parse(req_node, source)) _G.rest_request = req diff --git a/spec/parser/curl_parser_spec.lua b/spec/parser/curl_parser_spec.lua new file mode 100644 index 00000000..6172e287 --- /dev/null +++ b/spec/parser/curl_parser_spec.lua @@ -0,0 +1,36 @@ +---@module 'luassert' + +require("spec.minimal_init") + +local parser = require("rest-nvim.parser.curl") +local utils = require("rest-nvim.utils") + +describe("curl cli parser", function() + it("parse curl command", function() + local source = [[ + curl -sSL -X POST https://example.com \ + -H 'Content-Type: application/json' \ + -d '{ "foo": 123 }' + ]] + local _, tree = utils.ts_parse_source(source, "bash") + local curl_node = assert(tree:root():child(0)) + local args = parser.parse_command(curl_node, source) + assert(args) + assert.same({ + method = "POST", + url = "https://example.com", + headers = { + ["content-type"] = { "application/json" }, + }, + body = { + __TYPE = "raw", + data = '{ "foo": 123 }', + }, + meta = { + redirect = true, + }, + cookies = {}, + handlers = {}, + }, parser.parse_arguments(args)) + end) +end) diff --git a/spec/parser/http_parser_spec.lua b/spec/parser/http_parser_spec.lua index 4e4cf33d..25833b63 100644 --- a/spec/parser/http_parser_spec.lua +++ b/spec/parser/http_parser_spec.lua @@ -23,7 +23,7 @@ describe("http parser", function() end) it("parse from http string", function() local source = "GET https://github.com\n" - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local req_node = assert(tree:root():child(0)) assert.same({ method = "GET", @@ -35,7 +35,7 @@ describe("http parser", function() end) it("parse from http file", function() local source = open("spec/examples/basic_get.http") - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local req_node = assert(tree:root():child(0)) assert.same({ name = "basic get statement", @@ -66,7 +66,7 @@ GET http://localhost:80 GET /some/path HOST: localhost:8000 ]] - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local req_node = assert(tree:root():child(0)) local req = assert(parser.parse(req_node, source)) assert.same("http://localhost:8000/some/path", req.url) @@ -78,7 +78,7 @@ X-Header1: value1 X-Header2: X-Header1: value2 ]] - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local req_node = assert(tree:root():child(0)) assert.same({ url = "http://example.com/api", @@ -95,7 +95,7 @@ X-Header1: value2 describe("parse body", function() it("json body", function() local source = 'POST https://example.com\n\n{\n\t"blah": 1}\n' - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local req_node = assert(tree:root():child(0)) assert.same({ method = "POST", @@ -111,7 +111,7 @@ X-Header1: value2 end) it("invalid json body", function() local source = 'POST https://example.com\n\n{\n\t"blah": 1\n' - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local req_node = assert(tree:root():child(0)) local spy_log_warn = spy.on(logger, "warn") parser.parse(req_node, source) @@ -127,7 +127,7 @@ X-Header1: value2 password ]] - local _, tree = utils.ts_parse_source(source) + local _, tree = utils.ts_parse_source(source, "http") local req_node = assert(tree:root():child(0)) assert.same({ method = "POST", @@ -148,7 +148,7 @@ X-Header1: value2 it("parse invalid xml", function() logger.info("hi") local source = "POST https://example.com\n\n