diff --git a/ext/rbs_extension/main.c b/ext/rbs_extension/main.c index 8ab04615d..3209ea587 100644 --- a/ext/rbs_extension/main.c +++ b/ext/rbs_extension/main.c @@ -272,6 +272,58 @@ static VALUE rbsparser_parse_signature(VALUE self, VALUE buffer, VALUE start_pos return result; } + +struct parse_type_params_arg { + VALUE buffer; + rb_encoding *encoding; + rbs_parser_t *parser; + VALUE module_type_params; +}; + +static VALUE parse_type_params_try(VALUE a) { + struct parse_type_params_arg *arg = (struct parse_type_params_arg *)a; + rbs_parser_t *parser = arg->parser; + + if (parser->next_token.type == pEOF) { + return Qnil; + } + + rbs_node_list_t *params = NULL; + rbs_parse_type_params(parser, arg->module_type_params, ¶ms); + + raise_error_if_any(parser, arg->buffer); + + rbs_translation_context_t ctx = rbs_translation_context_create( + &parser->constant_pool, + arg->buffer, + arg->encoding + ); + + return rbs_node_list_to_ruby_array(ctx, params); +} + + +static VALUE rbsparser_parse_type_params(VALUE self, VALUE buffer, VALUE start_pos, VALUE end_pos, VALUE module_type_params) { + VALUE string = rb_funcall(buffer, rb_intern("content"), 0); + StringValue(string); + rb_encoding *encoding = rb_enc_get(string); + + rbs_parser_t *parser = alloc_parser_from_buffer(buffer, FIX2INT(start_pos), FIX2INT(end_pos)); + struct parse_type_params_arg arg = { + .buffer = buffer, + .encoding = encoding, + .parser = parser, + .module_type_params = module_type_params + }; + + VALUE result = rb_ensure(parse_type_params_try, (VALUE)&arg, ensure_free_parser, (VALUE)parser); + + RB_GC_GUARD(string); + + return result; +} + + static VALUE parse_inline_leading_annotation_try(VALUE a) { struct parse_type_arg *arg = (struct parse_type_arg *) a; rbs_parser_t *parser = arg->parser; @@ -391,6 +443,7 @@ void rbs__init_parser(void) { rb_define_singleton_method(RBS_Parser, "_parse_type", rbsparser_parse_type, 5); rb_define_singleton_method(RBS_Parser, "_parse_method_type", rbsparser_parse_method_type, 5); rb_define_singleton_method(RBS_Parser, "_parse_signature", rbsparser_parse_signature, 3); + rb_define_singleton_method(RBS_Parser, "_parse_type_params", rbsparser_parse_type_params, 4); rb_define_singleton_method(RBS_Parser, "_parse_inline_leading_annotation", rbsparser_parse_inline_leading_annotation, 4); rb_define_singleton_method(RBS_Parser, "_parse_inline_trailing_annotation", rbsparser_parse_inline_trailing_annotation, 4); rb_define_singleton_method(RBS_Parser, "_lex", rbsparser_lex, 2); diff --git a/include/rbs/parser.h b/include/rbs/parser.h index 0b974e94c..dc3f82ca0 100644 --- a/include/rbs/parser.h +++ b/include/rbs/parser.h @@ -130,6 +130,8 @@ bool rbs_parse_type(rbs_parser_t *parser, rbs_node_t **type); bool rbs_parse_method_type(rbs_parser_t *parser, rbs_method_type_t **method_type); bool rbs_parse_signature(rbs_parser_t *parser, rbs_signature_t **signature); +bool rbs_parse_type_params(rbs_parser_t *parser, bool module_type_params, rbs_node_list_t **params); + /** * Parse an inline leading annotation from a string. * diff --git a/lib/rbs/parser_aux.rb b/lib/rbs/parser_aux.rb index cb81cd204..79171faa3 100644 --- a/lib/rbs/parser_aux.rb +++ b/lib/rbs/parser_aux.rb @@ -35,6 +35,11 @@ def self.parse_signature(source) [buf, dirs, decls] end + def self.parse_type_params(source, module_type_params: true) + buf = buffer(source) + _parse_type_params(buf, 0, buf.last_position, module_type_params) + end + def self.magic_comment(buf) start_pos = 0 diff --git a/sig/parser.rbs b/sig/parser.rbs index c19c09de3..c405212d9 100644 --- a/sig/parser.rbs +++ b/sig/parser.rbs @@ -68,6 +68,24 @@ module RBS # def self.parse_signature: (Buffer | String) -> [Buffer, Array[AST::Directives::t], Array[AST::Declarations::t]] + # Parse a list of type parameters and return it + # + # ```ruby + # RBS::Parser.parse_type_params("") # => nil + # RBS::Parser.parse_type_params("[U, V]") # => `[:U, :V]` + # RBS::Parser.parse_type_params("[in U, V < Integer]") # => `[:U, :V]` + # ``` + # + # When `module_type_params` is `false`, an error is raised if `unchecked`, `in` or `out` are used. + # + # ```ruby + # RBS::Parser.parse_type_params("[unchecked U]", module_type_params: false) # => Raises an error + # RBS::Parser.parse_type_params("[out U]", module_type_params: false) # => Raises an error + # RBS::Parser.parse_type_params("[in U]", module_type_params: false) # => Raises an error + # ``` + # + def self.parse_type_params: (Buffer | String, ?module_type_params: bool) -> Array[AST::TypeParam] + # Returns the magic comment from the buffer # def self.magic_comment: (Buffer) -> AST::Directives::ResolveTypeNames? @@ -104,6 +122,8 @@ module RBS def self._parse_signature: (Buffer, Integer start_pos, Integer end_pos) -> [Array[AST::Directives::t], Array[AST::Declarations::t]] + def self._parse_type_params: (Buffer, Integer start_pos, Integer end_pos, bool module_type_params) -> Array[AST::TypeParam] + def self._lex: (Buffer, Integer end_pos) -> Array[[Symbol, Location[untyped, untyped]]] def self._parse_inline_leading_annotation: (Buffer, Integer start_pos, Integer end_pos, Array[Symbol] variables) -> AST::Ruby::Annotations::leading_annotation diff --git a/src/parser.c b/src/parser.c index 5b559e40a..9f2187bc4 100644 --- a/src/parser.c +++ b/src/parser.c @@ -3258,6 +3258,26 @@ bool rbs_parse_signature(rbs_parser_t *parser, rbs_signature_t **signature) { return true; } +bool rbs_parse_type_params(rbs_parser_t *parser, bool module_type_params, rbs_node_list_t **params) { + if (parser->next_token.type != pLBRACKET) { + rbs_parser_set_error(parser, parser->next_token, true, "expected a token `pLBRACKET`"); + return false; + } + + rbs_range_t rg = NULL_RANGE; + rbs_parser_push_typevar_table(parser, true); + bool res = parse_type_params(parser, &rg, module_type_params, params); + rbs_parser_push_typevar_table(parser, false); + + rbs_parser_advance(parser); + if (parser->current_token.type != pEOF) { + rbs_parser_set_error(parser, parser->current_token, true, "expected a token `%s`", rbs_token_type_str(pEOF)); + return false; + } + + return res; +} + id_table *alloc_empty_table(rbs_allocator_t *allocator) { id_table *table = rbs_allocator_alloc(allocator, id_table); diff --git a/test/rbs/parser_test.rb b/test/rbs/parser_test.rb index dd5bc1848..5c12249f4 100644 --- a/test/rbs/parser_test.rb +++ b/test/rbs/parser_test.rb @@ -820,6 +820,60 @@ def test_proc__untyped_function end end + def test_parse_type_params + RBS::Parser.parse_type_params(buffer("[T]")).tap do |params| + assert_equal 1, params.size + assert_equal :T, params[0].name + assert_nil params[0].upper_bound + end + + RBS::Parser.parse_type_params(buffer("[T < Integer, U = String]")).tap do |params| + assert_equal 2, params.size + assert_equal :T, params[0].name + assert_equal "Integer", params[0].upper_bound.to_s + assert_equal :U, params[1].name + assert_equal "String", params[1].default_type.to_s + end + + RBS::Parser.parse_type_params(buffer("[T, in U, out V]")).tap do |params| + assert_equal 3, params.size + assert_equal :T, params[0].name + assert_equal "invariant", params[0].variance.to_s + assert_equal :U, params[1].name + assert_equal "contravariant", params[1].variance.to_s + assert_equal :V, params[2].name + assert_equal "covariant", params[2].variance.to_s + end + + RBS::Parser.parse_type_params(buffer("[T, unchecked U, unchecked out V = Integer]")).tap do |params| + assert_equal 3, params.size + assert_equal :T, params[0].name + refute params[0].unchecked? + assert_equal :U, params[1].name + assert params[1].unchecked? + assert_equal :V, params[2].name + assert params[2].unchecked? + assert_equal "covariant", params[2].variance.to_s + assert_equal "Integer", params[2].default_type.to_s + end + + assert_raises RBS::ParsingError do + RBS::Parser.parse_type_params(buffer("[]")) + end + + assert_raises RBS::ParsingError do + RBS::Parser.parse_type_params(buffer("[T]A")) + end + + assert_raises RBS::ParsingError do + RBS::Parser.parse_type_params(buffer("[in T]"), module_type_params: false) + end + + assert_raises RBS::ParsingError do + RBS::Parser.parse_type_params(buffer("[unchecked T]"), module_type_params: false) + end + end + def test__lex content = <<~RBS # LineComment