diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 39821cd..ee123a7 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -1201,9 +1201,9 @@ class DictExpr : public Expression { class SliceExpr : public Expression { public: - std::shared_ptr start, end; - SliceExpr(const Location & loc, std::shared_ptr && s, std::shared_ptr && e) - : Expression(loc), start(std::move(s)), end(std::move(e)) {} + std::shared_ptr start, end, step; + SliceExpr(const Location & loc, std::shared_ptr && s, std::shared_ptr && e, std::shared_ptr && st = nullptr) + : Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {} Value do_evaluate(const std::shared_ptr &) const override { throw std::runtime_error("SliceExpr not implemented"); } @@ -1220,18 +1220,35 @@ class SubscriptExpr : public Expression { if (!index) throw std::runtime_error("SubscriptExpr.index is null"); auto target_value = base->evaluate(context); if (auto slice = dynamic_cast(index.get())) { - auto start = slice->start ? slice->start->evaluate(context).get() : 0; - auto end = slice->end ? slice->end->evaluate(context).get() : (int64_t) target_value.size(); + auto len = target_value.size(); + auto wrap = [len](int64_t i) -> int64_t { + if (i < 0) { + return i + len; + } + return i; + }; + int64_t step = slice->step ? slice->step->evaluate(context).get() : 1; + if (!step) { + throw std::runtime_error("slice step cannot be zero"); + } + int64_t start = slice->start ? wrap(slice->start->evaluate(context).get()) : (step < 0 ? len - 1 : 0); + int64_t end = slice->end ? wrap(slice->end->evaluate(context).get()) : (step < 0 ? -1 : len); if (target_value.is_string()) { std::string s = target_value.get(); - if (start < 0) start = s.size() + start; - if (end < 0) end = s.size() + end; - return s.substr(start, end - start); - } else if (target_value.is_array()) { - if (start < 0) start = target_value.size() + start; - if (end < 0) end = target_value.size() + end; + + std::string result; + if (start < end && step == 1) { + result = s.substr(start, end - start); + } else { + for (int64_t i = start; step > 0 ? i < end : i > end; i += step) { + result += s[i]; + } + } + return result; + + } else if (target_value.is_array()) { auto result = Value::array(); - for (auto i = start; i < end; ++i) { + for (int64_t i = start; step > 0 ? i < end : i > end; i += step) { result.push_back(target_value.at(i)); } return result; @@ -1523,6 +1540,10 @@ class MethodCallExpr : public Expression { vargs.expectArgs("endswith method", {1, 1}, {0, 0}); auto suffix = vargs.args[0].get(); return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); + } else if (method->get_name() == "startswith") { + vargs.expectArgs("startswith method", {1, 1}, {0, 0}); + auto prefix = vargs.args[0].get(); + return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin()); } else if (method->get_name() == "title") { vargs.expectArgs("title method", {0, 0}, {0, 0}); auto res = str; @@ -2085,28 +2106,37 @@ class Parser { while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) { if (!consumeToken("[").empty()) { - std::shared_ptr index; + std::shared_ptr index; + auto slice_loc = get_location(); + std::shared_ptr start, end, step; + bool has_first_colon = false, has_second_colon = false; + + if (!peekSymbols({ ":" })) { + start = parseExpression(); + } + + if (!consumeToken(":").empty()) { + has_first_colon = true; + if (!peekSymbols({ ":", "]" })) { + end = parseExpression(); + } if (!consumeToken(":").empty()) { - auto slice_end = parseExpression(); - index = std::make_shared(slice_end->location, nullptr, std::move(slice_end)); - } else { - auto slice_start = parseExpression(); - if (!consumeToken(":").empty()) { - consumeSpaces(); - if (peekSymbols({ "]" })) { - index = std::make_shared(slice_start->location, std::move(slice_start), nullptr); - } else { - auto slice_end = parseExpression(); - index = std::make_shared(slice_start->location, std::move(slice_start), std::move(slice_end)); - } - } else { - index = std::move(slice_start); + has_second_colon = true; + if (!peekSymbols({ "]" })) { + step = parseExpression(); } } - if (!index) throw std::runtime_error("Empty index in subscript"); - if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); + } + + if ((has_first_colon || has_second_colon) && (start || end || step)) { + index = std::make_shared(slice_loc, std::move(start), std::move(end), std::move(step)); + } else { + index = std::move(start); + } + if (!index) throw std::runtime_error("Empty index in subscript"); + if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); - value = std::make_shared(value->location, std::move(value), std::move(index)); + value = std::make_shared(value->location, std::move(value), std::move(index)); } else if (!consumeToken(".").empty()) { auto identifier = parseIdentifier(); if (!identifier) throw std::runtime_error("Expected identifier in subscript"); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index aa3a756..09323b3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -318,6 +318,7 @@ set(MODEL_IDS ValiantLabs/Llama3.1-8B-Enigma xwen-team/Xwen-72B-Chat xwen-team/Xwen-7B-Chat + Qwen/Qwen3-4B # Broken, TODO: # ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8 diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index a5a2707..a628aa2 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -184,6 +184,9 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "1", render(R"({{ 1 | safe }})", {}, {})); + EXPECT_EQ( + "True,False", + render(R"({{ 'abc'.startswith('ab') }},{{ ''.startswith('a') }})", {}, {})); EXPECT_EQ( "True,False", render(R"({{ 'abc'.endswith('bc') }},{{ ''.endswith('a') }})", {}, {})); @@ -477,6 +480,15 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "[1, 2, 3][0, 1][1, 2]", render("{% set x = [0, 1, 2, 3] %}{{ x[1:] }}{{ x[:2] }}{{ x[1:3] }}", {}, {})); + EXPECT_EQ( + "123;01;12", + render("{% set x = '0123' %}{{ x[1:] }};{{ x[:2] }};{{ x[1:3] }}", {}, {})); + EXPECT_EQ( + "[3, 2, 1, 0][3, 2, 1][2, 1, 0][2, 1][0, 2][3, 1][2, 0]", + render("{% set x = [0, 1, 2, 3] %}{{ x[::-1] }}{{ x[:0:-1] }}{{ x[2::-1] }}{{ x[2:0:-1] }}{{ x[::2] }}{{ x[::-2] }}{{ x[-2::-2] }}", {}, {})); + EXPECT_EQ( + "3210;321;210;21;02;31;20", + render("{% set x = '0123' %}{{ x[::-1] }};{{ x[:0:-1] }};{{ x[2::-1] }};{{ x[2:0:-1] }};{{ x[::2] }};{{ x[::-2] }};{{ x[-2::-2] }}", {}, {})); EXPECT_EQ( "a", render("{{ ' a ' | trim }}", {}, {}));