Skip to content

Commit

Permalink
Merge branch 'feature/customize-class-bases-option' (thanks @jnastarot)
Browse files Browse the repository at this point in the history
  • Loading branch information
pthom committed Dec 2, 2024
2 parents b04f247 + 6a08822 commit 2e56f48
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/litgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from litgen.litgen_generator import (
LitgenGenerator,
GeneratedCodes,
GeneratedCodeType,
write_generated_code_for_files,
write_generated_code_for_file,
generate_code,
Expand All @@ -37,6 +38,7 @@
# When it is needed to have different options per c++ header file
"LitgenGenerator",
"GeneratedCodes",
"GeneratedCodeType",
"generate_code_for_file",
# Configure replacements
"standard_type_replacements",
Expand Down
35 changes: 28 additions & 7 deletions src/litgen/internal/adapted_types/adapted_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,24 @@ def stub_lines(self) -> list[str]:

def str_parent_classes_python() -> str:
parents: list[str] = []
if not self.cpp_element().has_base_classes():
custom_derived = (
[]
if not self.options.class_custom_inheritance_callback
else self.options.class_custom_inheritance_callback(self, litgen.GeneratedCodeType.stub)
)

if not custom_derived and not self.cpp_element().has_base_classes():
return ""
for _access_type, base_class in self.cpp_element().base_classes():
class_python_scope = cpp_to_python.cpp_scope_to_pybind_scope_str(
self.options, base_class, include_self=True
)
parents.append(class_python_scope)

if custom_derived:
for custom_base in custom_derived:
parents.append(custom_base)
else:
for _access_type, base_class in self.cpp_element().base_classes():
class_python_scope = cpp_to_python.cpp_scope_to_pybind_scope_str(
self.options, base_class, include_self=True
)
parents.append(class_python_scope)
if len(parents) == 0:
return ""
else:
Expand Down Expand Up @@ -604,11 +615,21 @@ def make_pyclass_creation_code() -> str:

# fill py::class_ additional template params (base classes, nodelete, etc)
other_template_params_list = []
if self.cpp_element().has_base_classes():
custom_derived = (
[]
if not self.options.class_custom_inheritance_callback
else self.options.class_custom_inheritance_callback(self, litgen.GeneratedCodeType.pydef)
)

if custom_derived:
for custom_base in custom_derived:
other_template_params_list.append(custom_base)
elif self.cpp_element().has_base_classes():
base_classes = self.cpp_element().base_classes()
for access_type, base_class in base_classes:
if access_type == CppAccessType.public or access_type == CppAccessType.protected:
other_template_params_list.append(base_class.cpp_scope_str(include_self=True))

if self.cpp_element().has_private_destructor() and options.bind_library == BindLibraryType.pybind11:
# nanobind does not support nodelete
other_template_params_list.append(f"std::unique_ptr<{qualified_struct_name}, py::nodelete>")
Expand Down
7 changes: 7 additions & 0 deletions src/litgen/litgen_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import subprocess
from dataclasses import dataclass
from enum import Enum

from codemanip import code_utils

Expand Down Expand Up @@ -61,6 +62,12 @@ def add_python_exe_folder_to_env_path() -> None:
return _apply_black_formatter_pyi_via_subprocess(options, file)


class GeneratedCodeType(Enum):
pydef = 1
stub = 2
glue = 3


@dataclass
class _GeneratedCode:
source_filename: CppFilename
Expand Down
10 changes: 9 additions & 1 deletion src/litgen/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from litgen.internal.class_iterable_info import ClassIterablesInfos

if TYPE_CHECKING:
from litgen.internal.adapted_types import AdaptedFunction
from litgen.internal.adapted_types import AdaptedFunction, AdaptedClass
from litgen.litgen_generator import GeneratedCodeType


class BindLibraryType(Enum):
Expand Down Expand Up @@ -505,6 +506,13 @@ class LitgenOptions:
# - [Understanding Holder Types in pybind11](https://pybind11.readthedocs.io/en/stable/advanced/classes.html#custom-smart-pointers)
class_held_as_shared__regex: str = ""

# class_custom_inheritance_callback:
# (advanced) A callback to customize the base classes used in generated bindings.
# The first parameter is the AdaptedClass, representing the C++ class being adapted.
# The second parameter is the GeneratedCodeType, indicating whether stub or pydef code is being generated.
# An example usage can be found in: src/litgen/tests/option_class_custom_derivation__callback_test.py
class_custom_inheritance_callback: Callable[[AdaptedClass, GeneratedCodeType], list[str]] | None = None

# ------------------------------------------------------------------------------
# Templated class options
# ------------------------------------------------------------------------------
Expand Down
85 changes: 85 additions & 0 deletions src/litgen/tests/option_class_custom_derivation__callback_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations
from codemanip import code_utils
import litgen
from litgen.internal.adapted_types import AdaptedClass
from litgen import GeneratedCodeType


def test_class_custom_derivation__callback():
"""Example of how the callback mechanism `options.class_custom_inheritance_callback`
can be used in practice to add a base class
"""

# Let's suppose that we have the following C++ code in another file,
# which was not processed by litgen (or not yet)
"""
namespace CustomNS {
class FirstClass {
public:
FirstClass();
private:
int _value1;
};
}
"""

# And we are processing the following code:
code = """
class SecondClass : CustomNS::FirstClass {
public:
SecondClass();
private:
int _value2;
};
class ThirdClass : SecondClass {
public:
ThirdClass();
private:
int _value3;
};
"""

# This will be our callback to add the base class: it returns the base class which we should add
# (with a syntax that depends slightly on the generated code type)
def handle_classes_base(cls: AdaptedClass, generated_code_type: GeneratedCodeType) -> list[str]:
bases = []
elem = cls.cpp_element()
if elem.class_name == "SecondClass":
bases.append("FirstClass" if generated_code_type == GeneratedCodeType.stub else "CustomNS::FirstClass")
return bases

options = litgen.LitgenOptions()
options.class_custom_inheritance_callback = handle_classes_base
generated_code = litgen.generate_code(options, code)

code_utils.assert_are_codes_equal(
generated_code.pydef_code,
"""
auto pyClassSecondClass =
py::class_<SecondClass, CustomNS::FirstClass>
(m, "SecondClass", "")
.def(py::init<>())
;
auto pyClassThirdClass =
py::class_<ThirdClass, SecondClass>
(m, "ThirdClass", "")
.def(py::init<>())
;
""",
)

code_utils.assert_are_codes_equal(
generated_code.stub_code,
"""
class SecondClass(FirstClass):
def __init__(self) -> None:
pass
class ThirdClass(SecondClass):
def __init__(self) -> None:
pass
""",
)

0 comments on commit 2e56f48

Please sign in to comment.