Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions rust/ruby-rbs/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,15 @@ fn generate(config: &Config) -> Result<(), Box<dyn Error>> {
)?;
writeln!(file, " }}")?;
}
"rbs_hash" => {
writeln!(file, " pub fn {}(&self) -> RBSHash {{", field.name)?;
writeln!(
file,
" RBSHash::new(self.parser, unsafe {{ (*self.pointer).{} }})",
field.c_name()
)?;
writeln!(file, " }}")?;
}
"rbs_location" => {
writeln!(file, " pub fn {}(&self) -> RBSLocation {{", field.name)?;
writeln!(
Expand Down
89 changes: 89 additions & 0 deletions rust/ruby-rbs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,47 @@ impl NodeList {
}
}

pub struct RBSHash {
parser: *mut rbs_parser_t,
pointer: *mut rbs_hash,
}

impl RBSHash {
pub fn new(parser: *mut rbs_parser_t, pointer: *mut rbs_hash) -> Self {
Self { parser, pointer }
}

/// Returns an iterator over the key-value pairs.
#[must_use]
pub fn iter(&self) -> RBSHashIter {
RBSHashIter {
parser: self.parser,
current: unsafe { (*self.pointer).head },
}
}
}

pub struct RBSHashIter {
parser: *mut rbs_parser_t,
current: *mut rbs_hash_node_t,
}

impl Iterator for RBSHashIter {
type Item = (Node, Node);

fn next(&mut self) -> Option<Self::Item> {
if self.current.is_null() {
None
} else {
let pointer_data = unsafe { *self.current };
let key = unsafe { Node::new(self.parser, pointer_data.key) };
let value = unsafe { Node::new(self.parser, pointer_data.value) };
self.current = pointer_data.next;
Some((key, value))
}
}
}

pub struct RBSLocation {
pointer: *const rbs_location_t,
#[allow(dead_code)]
Expand Down Expand Up @@ -237,4 +278,52 @@ mod tests {
panic!("No literal type node found");
}
}

#[test]
fn test_rbs_hash_via_record_type() {
// RecordType stores its fields in an RBSHash via all_fields()
let rbs_code = r#"type foo = { name: String, age: Integer }"#;
let signature = parse(rbs_code.as_bytes());
assert!(signature.is_ok(), "Failed to parse RBS signature");

let signature_node = signature.unwrap();
if let Node::TypeAlias(type_alias) = signature_node.declarations().iter().next().unwrap()
&& let Node::RecordType(record) = type_alias.type_()
{
let hash = record.all_fields();
let fields: Vec<_> = hash.iter().collect();
assert_eq!(fields.len(), 2, "Expected 2 fields in record");

// Build a map of field names to type names
let mut field_types: Vec<(String, String)> = Vec::new();
for (key, value) in &fields {
let Node::Symbol(sym) = key else {
panic!("Expected Symbol key");
};
let Node::RecordFieldType(field_type) = value else {
panic!("Expected RecordFieldType value");
};
let Node::ClassInstanceType(class_type) = field_type.type_() else {
panic!("Expected ClassInstanceType");
};

let key_name = String::from_utf8(sym.name().to_vec()).unwrap();
let type_name_node = class_type.name();
let type_name_sym = type_name_node.name();
let type_name = String::from_utf8(type_name_sym.name().to_vec()).unwrap();
field_types.push((key_name, type_name));
}

assert!(
field_types.contains(&("name".to_string(), "String".to_string())),
"Expected 'name: String'"
);
assert!(
field_types.contains(&("age".to_string(), "Integer".to_string())),
"Expected 'age: Integer'"
);
} else {
panic!("Expected TypeAlias with RecordType");
}
}
}
Loading