Skip to content

Commit db1c65d

Browse files
committed
Add yaml module
1 parent f3fdad6 commit db1c65d

File tree

5 files changed

+372
-0
lines changed

5 files changed

+372
-0
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ vendored = ["mlua/vendored"]
1818

1919
json = ["mlua/serde", "dep:ouroboros", "dep:serde", "dep:serde_json"]
2020
regex = ["dep:regex", "dep:ouroboros", "dep:quick_cache"]
21+
yaml = ["mlua/serde", "dep:ouroboros", "dep:serde", "dep:serde_yaml"]
2122

2223
[dependencies]
2324
mlua = { version = "0.11" }
2425
ouroboros = { version = "0.18", optional = true }
2526
serde = { version = "1.0", optional = true }
2627
serde_json = { version = "1.0", optional = true }
28+
serde_yaml = { version = "0.9", optional = true }
2729
owo-colors = "4"
2830
regex = { version = "1.0", optional = true }
2931
quick_cache = { version = "0.6", optional = true }

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@ pub mod testing;
1717
pub mod json;
1818
#[cfg(feature = "regex")]
1919
pub mod regex;
20+
#[cfg(feature = "yaml")]
21+
pub mod yaml;

src/yaml.rs

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
use std::result::Result as StdResult;
2+
use std::sync::Arc;
3+
4+
use mlua::{
5+
AnyUserData, Error, Function, Integer as LuaInteger, IntoLuaMulti, Lua, LuaSerdeExt, MetaMethod,
6+
MultiValue, Result, SerializeOptions, Table, UserData, UserDataMethods, UserDataRefMut, Value,
7+
};
8+
use ouroboros::self_referencing;
9+
use serde::{Serialize, Serializer};
10+
11+
use crate::bytes::StringOrBytes;
12+
13+
/// Represents a native YAML object in Lua.
14+
#[derive(Clone)]
15+
pub(crate) struct YamlObject {
16+
root: Arc<serde_yaml::Value>,
17+
current: *const serde_yaml::Value,
18+
}
19+
20+
impl Serialize for YamlObject {
21+
fn serialize<S: Serializer>(&self, serializer: S) -> StdResult<S::Ok, S::Error> {
22+
self.current().serialize(serializer)
23+
}
24+
}
25+
26+
impl YamlObject {
27+
/// Creates a new `YamlObject` from the given YAML value.
28+
///
29+
/// SAFETY:
30+
/// The caller must ensure that `current` is a value inside `root`.
31+
unsafe fn new(root: &Arc<serde_yaml::Value>, current: &serde_yaml::Value) -> Self {
32+
let root = root.clone();
33+
YamlObject { root, current }
34+
}
35+
36+
/// Returns a reference to the current YAML value.
37+
#[inline(always)]
38+
fn current(&self) -> &serde_yaml::Value {
39+
unsafe { &*self.current }
40+
}
41+
42+
/// Returns a new `YamlObject` which points to the value at the given key.
43+
///
44+
/// This operation is cheap and does not clone the underlying data.
45+
fn get(&self, key: Value) -> Option<Self> {
46+
let value = match key {
47+
Value::Integer(index) if index > 0 => self.current().get(index as usize - 1),
48+
Value::String(key) => key.to_str().ok().and_then(|s| self.current().get(&*s)),
49+
_ => None,
50+
}?;
51+
unsafe { Some(Self::new(&self.root, value)) }
52+
}
53+
54+
/// Converts this `YamlObject` into a Lua `Value`.
55+
fn into_lua(self, lua: &Lua) -> Result<Value> {
56+
match self.current() {
57+
serde_yaml::Value::Null => Ok(Value::NULL),
58+
serde_yaml::Value::Bool(b) => Ok(Value::Boolean(*b)),
59+
serde_yaml::Value::Number(n) => {
60+
if let Some(n) = n.as_i64() {
61+
Ok(Value::Integer(n as _))
62+
} else if let Some(n) = n.as_f64() {
63+
Ok(Value::Number(n))
64+
} else {
65+
Err(Error::ToLuaConversionError {
66+
from: "number".to_string(),
67+
to: "integer or float",
68+
message: Some("number is too big to fit in a Lua integer".to_owned()),
69+
})
70+
}
71+
}
72+
serde_yaml::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),
73+
value @ serde_yaml::Value::Sequence(_) | value @ serde_yaml::Value::Mapping(_) => {
74+
let obj_ud = lua.create_ser_userdata(unsafe { YamlObject::new(&self.root, value) })?;
75+
Ok(Value::UserData(obj_ud))
76+
}
77+
serde_yaml::Value::Tagged(tagged) => {
78+
// For tagged values, we'll return the value part and ignore the tag for simplicity
79+
let obj = unsafe { YamlObject::new(&self.root, &tagged.value) };
80+
obj.into_lua(lua)
81+
}
82+
}
83+
}
84+
85+
fn lua_iterator(&self, lua: &Lua) -> Result<MultiValue> {
86+
match self.current() {
87+
serde_yaml::Value::Sequence(_) => {
88+
let next = Self::lua_array_iterator(lua)?;
89+
let iter_ud = AnyUserData::wrap(LuaYamlArrayIter {
90+
value: self.clone(),
91+
next: 1, // index starts at 1
92+
});
93+
(next, iter_ud).into_lua_multi(lua)
94+
}
95+
serde_yaml::Value::Mapping(_) => {
96+
let next = Self::lua_map_iterator(lua)?;
97+
let iter_builder = LuaYamlMapIterBuilder {
98+
value: self.clone(),
99+
iter_builder: |value| value.current().as_mapping().unwrap().iter(),
100+
};
101+
let iter_ud = AnyUserData::wrap(iter_builder.build());
102+
(next, iter_ud).into_lua_multi(lua)
103+
}
104+
_ => ().into_lua_multi(lua),
105+
}
106+
}
107+
108+
/// Returns an iterator function for arrays.
109+
fn lua_array_iterator(lua: &Lua) -> Result<Function> {
110+
if let Ok(Some(f)) = lua.named_registry_value("__yaml_array_iterator") {
111+
return Ok(f);
112+
}
113+
114+
let f = lua.create_function(|lua, mut it: UserDataRefMut<LuaYamlArrayIter>| {
115+
it.next += 1;
116+
match it.value.get(Value::Integer(it.next - 1)) {
117+
Some(next_value) => (it.next - 1, next_value.into_lua(lua)?).into_lua_multi(lua),
118+
None => ().into_lua_multi(lua),
119+
}
120+
})?;
121+
lua.set_named_registry_value("__yaml_array_iterator", &f)?;
122+
Ok(f)
123+
}
124+
125+
/// Returns an iterator function for objects.
126+
fn lua_map_iterator(lua: &Lua) -> Result<Function> {
127+
if let Ok(Some(f)) = lua.named_registry_value("__yaml_map_iterator") {
128+
return Ok(f);
129+
}
130+
131+
let f = lua.create_function(|lua, mut it: UserDataRefMut<LuaYamlMapIter>| {
132+
let root = it.borrow_value().root.clone();
133+
it.with_iter_mut(move |iter| match iter.next() {
134+
Some((key, value)) => {
135+
// Convert YAML key to Lua value
136+
let key = match key {
137+
serde_yaml::Value::Null
138+
| serde_yaml::Value::Bool(..)
139+
| serde_yaml::Value::String(..)
140+
| serde_yaml::Value::Number(..) => unsafe {
141+
YamlObject::new(&root, key).into_lua(lua)?
142+
},
143+
_ => {
144+
let err =
145+
Error::runtime("only string/number/boolean keys are supported in YAML maps");
146+
return Err(err);
147+
}
148+
};
149+
let value = unsafe { YamlObject::new(&root, value) }.into_lua(lua)?;
150+
(key, value).into_lua_multi(lua)
151+
}
152+
None => ().into_lua_multi(lua),
153+
})
154+
})?;
155+
lua.set_named_registry_value("__yaml_map_iterator", &f)?;
156+
Ok(f)
157+
}
158+
}
159+
160+
impl From<serde_yaml::Value> for YamlObject {
161+
fn from(value: serde_yaml::Value) -> Self {
162+
let root = Arc::new(value);
163+
unsafe { Self::new(&root, &root) }
164+
}
165+
}
166+
167+
impl UserData for YamlObject {
168+
fn register(registry: &mut mlua::UserDataRegistry<Self>) {
169+
registry.add_method("dump", |lua, this, ()| lua.to_value(this));
170+
171+
registry.add_method("iter", |lua, this, ()| this.lua_iterator(lua));
172+
173+
registry.add_meta_method(MetaMethod::Index, |lua, this, key: Value| {
174+
this.get(key)
175+
.map(|obj| obj.into_lua(lua))
176+
.unwrap_or(Ok(Value::Nil))
177+
});
178+
179+
registry.add_meta_method(crate::METAMETHOD_ITER, |lua, this, ()| this.lua_iterator(lua));
180+
}
181+
}
182+
183+
struct LuaYamlArrayIter {
184+
value: YamlObject,
185+
next: LuaInteger,
186+
}
187+
188+
#[self_referencing]
189+
struct LuaYamlMapIter {
190+
value: YamlObject,
191+
192+
#[borrows(value)]
193+
#[covariant]
194+
iter: serde_yaml::mapping::Iter<'this>,
195+
}
196+
197+
fn decode(lua: &Lua, (data, opts): (StringOrBytes, Option<Table>)) -> Result<StdResult<Value, String>> {
198+
let opts = opts.as_ref();
199+
let mut options = SerializeOptions::new();
200+
if let Some(enabled) = opts.and_then(|t| t.get::<bool>("set_array_metatable").ok()) {
201+
options = options.set_array_metatable(enabled);
202+
}
203+
204+
let mut yaml: serde_yaml::Value = lua_try!(serde_yaml::from_slice(&data.as_bytes_deref()));
205+
lua_try!(yaml.apply_merge());
206+
Ok(Ok(lua.to_value_with(&yaml, options)?))
207+
}
208+
209+
fn decode_native(lua: &Lua, data: StringOrBytes) -> Result<StdResult<Value, String>> {
210+
let mut yaml: serde_yaml::Value = lua_try!(serde_yaml::from_slice(&data.as_bytes_deref()));
211+
lua_try!(yaml.apply_merge());
212+
Ok(Ok(lua_try!(YamlObject::from(yaml).into_lua(lua))))
213+
}
214+
215+
fn encode(value: Value, opts: Option<Table>) -> StdResult<String, String> {
216+
let opts = opts.as_ref();
217+
let mut value = value.to_serializable();
218+
219+
if opts.and_then(|t| t.get::<bool>("relaxed").ok()) == Some(true) {
220+
value = value.deny_recursive_tables(false).deny_unsupported_types(false);
221+
}
222+
223+
serde_yaml::to_string(&value).map_err(|e| e.to_string())
224+
}
225+
226+
/// A loader for the `yaml` module.
227+
fn loader(lua: &Lua) -> Result<Table> {
228+
let t = lua.create_table()?;
229+
t.set("decode", lua.create_function(decode)?)?;
230+
t.set("decode_native", lua.create_function(decode_native)?)?;
231+
t.set("encode", Function::wrap_raw(encode))?;
232+
Ok(t)
233+
}
234+
235+
/// Registers the `yaml` module in the given Lua state.
236+
pub fn register(lua: &Lua, name: Option<&str>) -> Result<Table> {
237+
let name = name.unwrap_or("@yaml");
238+
let value = loader(lua)?;
239+
lua.register_module(name, &value)?;
240+
Ok(value)
241+
}

tests/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ fn run_file(modname: &str) -> Result<()> {
1313

1414
#[cfg(feature = "json")]
1515
mlua_stdlib::json::register(&lua, None)?;
16+
#[cfg(feature = "yaml")]
17+
mlua_stdlib::yaml::register(&lua, None)?;
1618
#[cfg(feature = "regex")]
1719
mlua_stdlib::regex::register(&lua, None)?;
1820

@@ -54,3 +56,5 @@ include_tests! {
5456
include_tests!(json);
5557
#[cfg(feature = "regex")]
5658
include_tests!(regex);
59+
#[cfg(feature = "yaml")]
60+
include_tests!(yaml);

0 commit comments

Comments
 (0)