Skip to content

Commit 975bdfa

Browse files
committed
use structural pattern matching
1 parent 57a99b0 commit 975bdfa

File tree

10 files changed

+589
-612
lines changed

10 files changed

+589
-612
lines changed

schema_salad/cpp_codegen.py

Lines changed: 51 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -612,106 +612,72 @@ def convertTypeToCpp(self, type_declaration: list[Any] | dict[str, Any] | str) -
612612
if not isinstance(type_declaration, list):
613613
return self.convertTypeToCpp([type_declaration])
614614

615-
if len(type_declaration) == 1:
616-
if type_declaration[0] in ("null", "https://w3id.org/cwl/salad#null"):
615+
if len(type_declaration) > 1:
616+
type_declaration = list(map(self.convertTypeToCpp, type_declaration))
617+
type_declaration = ", ".join(type_declaration)
618+
return f"std::variant<{type_declaration}>"
619+
620+
match type_declaration[0]:
621+
case "null" | "https://w3id.org/cwl/salad#null":
617622
return "std::monostate"
618-
elif type_declaration[0] in (
619-
"string",
620-
"http://www.w3.org/2001/XMLSchema#string",
621-
):
623+
case "string" | "http://www.w3.org/2001/XMLSchema#string":
622624
return "std::string"
623-
elif type_declaration[0] in ("int", "http://www.w3.org/2001/XMLSchema#int"):
625+
case "int" | "http://www.w3.org/2001/XMLSchema#int":
624626
return "int32_t"
625-
elif type_declaration[0] in (
626-
"long",
627-
"http://www.w3.org/2001/XMLSchema#long",
628-
):
627+
case "long" | "http://www.w3.org/2001/XMLSchema#long":
629628
return "int64_t"
630-
elif type_declaration[0] in (
631-
"float",
632-
"http://www.w3.org/2001/XMLSchema#float",
633-
):
629+
case "float" | "http://www.w3.org/2001/XMLSchema#float":
634630
return "float"
635-
elif type_declaration[0] in (
636-
"double",
637-
"http://www.w3.org/2001/XMLSchema#double",
638-
):
631+
case "double" | "http://www.w3.org/2001/XMLSchema#double":
639632
return "double"
640-
elif type_declaration[0] in (
641-
"boolean",
642-
"http://www.w3.org/2001/XMLSchema#boolean",
643-
):
633+
case "boolean" | "http://www.w3.org/2001/XMLSchema#boolean":
644634
return "bool"
645-
elif type_declaration[0] == "https://w3id.org/cwl/salad#Any":
635+
case "https://w3id.org/cwl/salad#Any":
646636
return "std::any"
647-
elif type_declaration[0] == "https://w3id.org/cwl/cwl#Expression":
637+
case "https://w3id.org/cwl/cwl#Expression":
648638
return "cwl_expression_string"
649-
elif type_declaration[0] in (
650-
"PrimitiveType",
651-
"https://w3id.org/cwl/salad#PrimitiveType",
652-
):
639+
case "PrimitiveType" | "https://w3id.org/cwl/salad#PrimitiveType":
653640
return "std::variant<bool, int32_t, int64_t, float, double, std::string>"
654-
elif isinstance(type_declaration[0], dict):
655-
if "type" in type_declaration[0] and type_declaration[0]["type"] in (
656-
"enum",
657-
"https://w3id.org/cwl/salad#enum",
658-
):
659-
name = type_declaration[0]["name"]
660-
if name not in self.enumDefinitions:
661-
self.enumDefinitions[name] = EnumDefinition(
662-
type_declaration[0]["name"],
663-
list(map(shortname, type_declaration[0]["symbols"])),
664-
)
665-
if len(name.split("#")) != 2:
666-
return safename(name)
667-
(namespace, classname) = name.split("#")
668-
return safenamespacename(namespace) + "::" + safename(classname)
669-
elif "type" in type_declaration[0] and type_declaration[0]["type"] in (
670-
"array",
671-
"https://w3id.org/cwl/salad#array",
672-
):
673-
items = type_declaration[0]["items"]
674-
if isinstance(items, list):
675-
ts = [self.convertTypeToCpp(i) for i in items]
676-
name = ", ".join(ts)
677-
return f"std::vector<std::variant<{name}>>"
678-
else:
679-
i = self.convertTypeToCpp(items)
680-
return f"std::vector<{i}>"
681-
elif "type" in type_declaration[0] and type_declaration[0]["type"] in (
682-
"map",
683-
"https://w3id.org/cwl/salad#map",
684-
):
685-
values = type_declaration[0]["values"]
686-
if isinstance(values, list):
687-
ts = [self.convertTypeToCpp(i) for i in values]
688-
name = ", ".join(ts)
689-
return f"std::map<std::string, std::variant<{name}>>"
690-
else:
691-
i = self.convertTypeToCpp(values)
692-
return f"std::map<std::string, {i}>"
693-
elif "type" in type_declaration[0] and type_declaration[0]["type"] in (
694-
"record",
695-
"https://w3id.org/cwl/salad#record",
696-
):
697-
n = type_declaration[0]["name"]
698-
(namespace, classname) = split_name(n)
699-
return safenamespacename(namespace) + "::" + safename(classname)
700-
641+
case {"type": "enum" | "https://w3id.org/cwl/salad#enum"}:
642+
name = type_declaration[0]["name"]
643+
if name not in self.enumDefinitions:
644+
self.enumDefinitions[name] = EnumDefinition(
645+
type_declaration[0]["name"],
646+
list(map(shortname, type_declaration[0]["symbols"])),
647+
)
648+
if len(name.split("#")) != 2:
649+
return safename(name)
650+
(namespace, classname) = name.split("#")
651+
return safenamespacename(namespace) + "::" + safename(classname)
652+
case {"type": "array" | "https://w3id.org/cwl/salad#array", "items": list(items)}:
653+
ts = [self.convertTypeToCpp(i) for i in items]
654+
name = ", ".join(ts)
655+
return f"std::vector<std::variant<{name}>>"
656+
case {"type": "array" | "https://w3id.org/cwl/salad#array", "items": items}:
657+
i = self.convertTypeToCpp(items)
658+
return f"std::vector<{i}>"
659+
case {"type": "map" | "https://w3id.org/cwl/salad#map", "values": list(values)}:
660+
ts = [self.convertTypeToCpp(i) for i in values]
661+
name = ", ".join(ts)
662+
return f"std::map<std::string, std::variant<{name}>>"
663+
case {"type": "map" | "https://w3id.org/cwl/salad#map", "values": values}:
664+
i = self.convertTypeToCpp(values)
665+
return f"std::map<std::string, {i}>"
666+
case {"type": "record" | "https://w3id.org/cwl/salad#record"}:
667+
n = type_declaration[0]["name"]
668+
(namespace, classname) = split_name(n)
669+
return safenamespacename(namespace) + "::" + safename(classname)
670+
case dict():
701671
n = type_declaration[0]["type"]
702672
(namespace, classname) = split_name(n)
703673
return safenamespacename(namespace) + "::" + safename(classname)
704674

705-
if len(type_declaration[0].split("#")) != 2:
706-
_logger.debug(f"// something weird2 about {type_declaration[0]}")
707-
return cast(str, type_declaration[0])
708-
709-
(namespace, classname) = split_name(type_declaration[0])
710-
return safenamespacename(namespace) + "::" + safename(classname)
675+
if len(type_declaration[0].split("#")) != 2:
676+
_logger.debug(f"// something weird2 about {type_declaration[0]}")
677+
return cast(str, type_declaration[0])
711678

712-
type_declaration = list(map(self.convertTypeToCpp, type_declaration))
713-
type_declaration = ", ".join(type_declaration)
714-
return f"std::variant<{type_declaration}>"
679+
(namespace, classname) = split_name(type_declaration[0])
680+
return safenamespacename(namespace) + "::" + safename(classname)
715681

716682
def epilogue(self, root_loader: TypeDef | None) -> None:
717683
"""Trigger to generate the epilouge code."""

schema_salad/dlang_codegen.py

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -205,41 +205,45 @@ def parse_record_field_type(
205205
else:
206206
annotate_str = ""
207207

208-
if isinstance(type_, str):
209-
stype = shortname(type_)
210-
if stype == "boolean":
211-
type_str = "bool"
212-
elif stype == "null":
213-
type_str = "None"
214-
else:
215-
type_str = stype
216-
elif isinstance(type_, list):
217-
t_str = [
218-
self.parse_record_field_type(t, None, parent_has_idmap=has_idmap)[1] for t in type_
219-
]
220-
if has_default:
221-
t_str = [t for t in t_str if t != "None"]
222-
if len(t_str) == 1:
223-
type_str = t_str[0]
224-
else:
225-
if are_dispatchable(type_, has_idmap):
226-
t_str += ["Any"]
227-
union_types = ", ".join(t_str)
228-
type_str = f"Union!({union_types})"
229-
elif shortname(type_["type"]) == "array":
230-
item_type = self.parse_record_field_type(
231-
type_["items"], None, parent_has_idmap=has_idmap
232-
)[1]
233-
type_str = f"{item_type}[]"
234-
elif shortname(type_["type"]) == "record":
235-
return annotate_str, shortname(type_.get("name", "record"))
236-
elif shortname(type_["type"]) == "enum":
237-
return annotate_str, "'not yet implemented'"
238-
elif shortname(type_["type"]) == "map":
239-
value_type = self.parse_record_field_type(
240-
type_["values"], None, parent_has_idmap=has_idmap, has_default=True
241-
)[1]
242-
type_str = f"{value_type}[string]"
208+
match type_:
209+
case str():
210+
match shortname(type_):
211+
case "boolean":
212+
type_str = "bool"
213+
case "null":
214+
type_str = "None"
215+
case str(stype):
216+
type_str = stype
217+
case list():
218+
t_str = [
219+
self.parse_record_field_type(t, None, parent_has_idmap=has_idmap)[1]
220+
for t in type_
221+
]
222+
if has_default:
223+
t_str = [t for t in t_str if t != "None"]
224+
if len(t_str) == 1:
225+
type_str = t_str[0]
226+
else:
227+
if are_dispatchable(type_, has_idmap):
228+
t_str += ["Any"]
229+
union_types = ", ".join(t_str)
230+
type_str = f"Union!({union_types})"
231+
case dict():
232+
match shortname(type_["type"]):
233+
case "array":
234+
item_type = self.parse_record_field_type(
235+
type_["items"], None, parent_has_idmap=has_idmap
236+
)[1]
237+
type_str = f"{item_type}[]"
238+
case "record":
239+
return annotate_str, shortname(type_.get("name", "record"))
240+
case "enum":
241+
return annotate_str, "'not yet implemented'"
242+
case "map":
243+
value_type = self.parse_record_field_type(
244+
type_["values"], None, parent_has_idmap=has_idmap, has_default=True
245+
)[1]
246+
type_str = f"{value_type}[string]"
243247
return annotate_str, type_str
244248

245249
def parse_record_field(self, field: dict[str, Any], parent_name: str | None = None) -> str:

0 commit comments

Comments
 (0)