diff --git a/rust/ruby-rbs/build.rs b/rust/ruby-rbs/build.rs index fb7d61661..ce05d2e62 100644 --- a/rust/ruby-rbs/build.rs +++ b/rust/ruby-rbs/build.rs @@ -172,6 +172,15 @@ fn generate(config: &Config) -> Result<(), Box> { )?; writeln!(file, " }}")?; } + "rbs_hash" => { + writeln!(file, " pub fn {}(&self) -> RBSHash {{", field.name)?; + writeln!( + file, + " RBSHash::new(self.parser, unsafe {{ (*self.pointer).{} }})", + field.c_name() + )?; + writeln!(file, " }}")?; + } "rbs_location" => { writeln!(file, " pub fn {}(&self) -> RBSLocation {{", field.name)?; writeln!( diff --git a/rust/ruby-rbs/src/lib.rs b/rust/ruby-rbs/src/lib.rs index 86e0ff5a3..d92e767ca 100644 --- a/rust/ruby-rbs/src/lib.rs +++ b/rust/ruby-rbs/src/lib.rs @@ -91,6 +91,47 @@ impl NodeList { } } +pub struct RBSHash { + parser: *mut rbs_parser_t, + pointer: *mut rbs_hash, +} + +impl RBSHash { + pub fn new(parser: *mut rbs_parser_t, pointer: *mut rbs_hash) -> Self { + Self { parser, pointer } + } + + /// Returns an iterator over the key-value pairs. + #[must_use] + pub fn iter(&self) -> RBSHashIter { + RBSHashIter { + parser: self.parser, + current: unsafe { (*self.pointer).head }, + } + } +} + +pub struct RBSHashIter { + parser: *mut rbs_parser_t, + current: *mut rbs_hash_node_t, +} + +impl Iterator for RBSHashIter { + type Item = (Node, Node); + + fn next(&mut self) -> Option { + if self.current.is_null() { + None + } else { + let pointer_data = unsafe { *self.current }; + let key = unsafe { Node::new(self.parser, pointer_data.key) }; + let value = unsafe { Node::new(self.parser, pointer_data.value) }; + self.current = pointer_data.next; + Some((key, value)) + } + } +} + pub struct RBSLocation { pointer: *const rbs_location_t, #[allow(dead_code)] @@ -237,4 +278,52 @@ mod tests { panic!("No literal type node found"); } } + + #[test] + fn test_rbs_hash_via_record_type() { + // RecordType stores its fields in an RBSHash via all_fields() + let rbs_code = r#"type foo = { name: String, age: Integer }"#; + let signature = parse(rbs_code.as_bytes()); + assert!(signature.is_ok(), "Failed to parse RBS signature"); + + let signature_node = signature.unwrap(); + if let Node::TypeAlias(type_alias) = signature_node.declarations().iter().next().unwrap() + && let Node::RecordType(record) = type_alias.type_() + { + let hash = record.all_fields(); + let fields: Vec<_> = hash.iter().collect(); + assert_eq!(fields.len(), 2, "Expected 2 fields in record"); + + // Build a map of field names to type names + let mut field_types: Vec<(String, String)> = Vec::new(); + for (key, value) in &fields { + let Node::Symbol(sym) = key else { + panic!("Expected Symbol key"); + }; + let Node::RecordFieldType(field_type) = value else { + panic!("Expected RecordFieldType value"); + }; + let Node::ClassInstanceType(class_type) = field_type.type_() else { + panic!("Expected ClassInstanceType"); + }; + + let key_name = String::from_utf8(sym.name().to_vec()).unwrap(); + let type_name_node = class_type.name(); + let type_name_sym = type_name_node.name(); + let type_name = String::from_utf8(type_name_sym.name().to_vec()).unwrap(); + field_types.push((key_name, type_name)); + } + + assert!( + field_types.contains(&("name".to_string(), "String".to_string())), + "Expected 'name: String'" + ); + assert!( + field_types.contains(&("age".to_string(), "Integer".to_string())), + "Expected 'age: Integer'" + ); + } else { + panic!("Expected TypeAlias with RecordType"); + } + } }