diff --git a/libs/qec/include/cudaq/qec/realtime/decoding_config.h b/libs/qec/include/cudaq/qec/realtime/decoding_config.h index c19c93b3..bbbeaea8 100644 --- a/libs/qec/include/cudaq/qec/realtime/decoding_config.h +++ b/libs/qec/include/cudaq/qec/realtime/decoding_config.h @@ -88,6 +88,22 @@ struct single_error_lut_config { from_heterogeneous_map(const cudaqx::heterogeneous_map &map); }; +struct trt_decoder_config { + std::optional onnx_load_path; + std::optional engine_load_path; + std::optional engine_save_path; + std::optional precision; + std::optional memory_workspace; + + bool operator==(const trt_decoder_config &) const = default; + + __attribute__((visibility("default"))) cudaqx::heterogeneous_map + to_heterogeneous_map() const; + + __attribute__((visibility("default"))) static trt_decoder_config + from_heterogeneous_map(const cudaqx::heterogeneous_map &map); +}; + struct sliding_window_config { std::optional window_size; std::optional step_size; @@ -122,7 +138,8 @@ struct decoder_config { std::vector O_sparse; std::vector D_sparse; std::variant + nv_qldpc_decoder_config, sliding_window_config, + trt_decoder_config> decoder_custom_args; bool operator==(const decoder_config &) const = default; @@ -144,6 +161,10 @@ struct decoder_config { decoder_custom_args)) { return std::get(decoder_custom_args) .to_heterogeneous_map(); + } else if (std::holds_alternative( + decoder_custom_args)) { + return std::get(decoder_custom_args) + .to_heterogeneous_map(); } return cudaqx::heterogeneous_map(); } @@ -161,6 +182,8 @@ struct decoder_config { nv_qldpc_decoder_config::from_heterogeneous_map(map); } else if (type == "sliding_window") { decoder_custom_args = sliding_window_config::from_heterogeneous_map(map); + } else if (type == "trt_decoder") { + decoder_custom_args = trt_decoder_config::from_heterogeneous_map(map); } } diff --git a/libs/qec/lib/realtime/config.cpp b/libs/qec/lib/realtime/config.cpp index 87d4b54e..bd4a066a 100644 --- a/libs/qec/lib/realtime/config.cpp +++ b/libs/qec/lib/realtime/config.cpp @@ -188,6 +188,31 @@ single_error_lut_config single_error_lut_config::from_heterogeneous_map( return config; } +// ------ trt_decoder_config ------ +cudaqx::heterogeneous_map trt_decoder_config::to_heterogeneous_map() const { + cudaqx::heterogeneous_map config_map; + + INSERT_ARG(onnx_load_path); + INSERT_ARG(engine_load_path); + INSERT_ARG(engine_save_path); + INSERT_ARG(precision); + INSERT_ARG(memory_workspace); + + return config_map; +} + +trt_decoder_config trt_decoder_config::from_heterogeneous_map( + const cudaqx::heterogeneous_map &map) { + trt_decoder_config config; + GET_ARG(onnx_load_path); + GET_ARG(engine_load_path); + GET_ARG(engine_save_path); + GET_ARG(precision); + GET_ARG(memory_workspace); + + return config; +} + // ------ sliding_window_config ------ cudaqx::heterogeneous_map sliding_window_config::to_heterogeneous_map() const { cudaqx::heterogeneous_map config_map; @@ -317,6 +342,18 @@ struct MappingTraits { cudaq::qec::decoding::config::single_error_lut_config &config) {} }; +template <> +struct MappingTraits { + static void + mapping(IO &io, cudaq::qec::decoding::config::trt_decoder_config &config) { + io.mapOptional("onnx_load_path", config.onnx_load_path); + io.mapOptional("engine_load_path", config.engine_load_path); + io.mapOptional("engine_save_path", config.engine_save_path); + io.mapOptional("precision", config.precision); + io.mapOptional("memory_workspace", config.memory_workspace); + } +}; + template <> struct MappingTraits { static void @@ -418,6 +455,9 @@ struct MappingTraits { } else if (config.type == "single_error_lut") { INIT_AND_MAP_DECODER_CUSTOM_ARGS( cudaq::qec::decoding::config::single_error_lut_config); + } else if (config.type == "trt_decoder") { + INIT_AND_MAP_DECODER_CUSTOM_ARGS( + cudaq::qec::decoding::config::trt_decoder_config); } else if (config.type == "sliding_window") { INIT_AND_MAP_DECODER_CUSTOM_ARGS( cudaq::qec::decoding::config::sliding_window_config); diff --git a/libs/qec/python/bindings/py_decoding_config.cpp b/libs/qec/python/bindings/py_decoding_config.cpp index e68493b9..a0f13d54 100644 --- a/libs/qec/python/bindings/py_decoding_config.cpp +++ b/libs/qec/python/bindings/py_decoding_config.cpp @@ -98,6 +98,24 @@ void bindDecodingConfig(py::module &mod) { &multi_error_lut_config::from_heterogeneous_map, py::arg("map")); + // trt_decoder_config + py::class_(mod_cfg, "trt_decoder_config", + "TensorRT decoder configuration.") + .def(py::init<>()) + .def(py::init([](const cudaqx::heterogeneous_map &map) { + return trt_decoder_config::from_heterogeneous_map(map); + }), + py::arg("map")) + .def_readwrite("onnx_load_path", &trt_decoder_config::onnx_load_path) + .def_readwrite("engine_load_path", &trt_decoder_config::engine_load_path) + .def_readwrite("engine_save_path", &trt_decoder_config::engine_save_path) + .def_readwrite("precision", &trt_decoder_config::precision) + .def_readwrite("memory_workspace", &trt_decoder_config::memory_workspace) + .def("to_heterogeneous_map", &trt_decoder_config::to_heterogeneous_map, + py::return_value_policy::move) + .def_static("from_heterogeneous_map", + &trt_decoder_config::from_heterogeneous_map, py::arg("map")); + // single_error_lut_config py::class_( mod_cfg, "single_error_lut_config", diff --git a/libs/qec/python/cudaq_qec/__init__.py b/libs/qec/python/cudaq_qec/__init__.py index d72b537d..19a2c9d3 100644 --- a/libs/qec/python/cudaq_qec/__init__.py +++ b/libs/qec/python/cudaq_qec/__init__.py @@ -48,6 +48,7 @@ decoder_config = qecrt.config.decoder_config nv_qldpc_decoder_config = qecrt.config.nv_qldpc_decoder_config multi_error_lut_config = qecrt.config.multi_error_lut_config +trt_decoder_config = qecrt.config.trt_decoder_config configure_decoders_from_file = qecrt.config.configure_decoders_from_file configure_decoders_from_str = qecrt.config.configure_decoders_from_str finalize_decoders = qecrt.config.finalize_decoders diff --git a/libs/qec/python/tests/test_decoding_config.py b/libs/qec/python/tests/test_decoding_config.py index c6c32c85..8e89a673 100644 --- a/libs/qec/python/tests/test_decoding_config.py +++ b/libs/qec/python/tests/test_decoding_config.py @@ -166,6 +166,16 @@ def test_nv_qldpc_decoder_config_toggle_multiple_fields_and_clear(): "lut_error_depth": (int, 1, 3), } +# trt_decoder_config tests + +FIELDS_TRT_DECODER = { + "onnx_load_path": (str, "/path/to/model.onnx", "/other/path/model.onnx"), + "engine_load_path": (str, "/path/to/engine.trt", "/other/engine.trt"), + "engine_save_path": (str, "/path/to/save.trt", "/other/save.trt"), + "precision": (str, "fp16", "fp32"), + "memory_workspace": (int, 1073741824, 2147483648), # 1GB, 2GB +} + def test_multi_error_lut_config_defaults_are_none(): m = qec.multi_error_lut_config() @@ -195,6 +205,74 @@ def test_configure_valid_multi_error_lut_decoders(): assert ret == 0 +# trt_decoder_config tests + + +def test_trt_decoder_config_defaults_are_none(): + trt = qec.trt_decoder_config() + for name in FIELDS_TRT_DECODER: + assert getattr(trt, name) is None, f"Expected {name} to default to None" + + +@pytest.mark.parametrize("name, meta", list(FIELDS_TRT_DECODER.items())) +def test_trt_decoder_config_set_and_get_each_optional(name, meta): + trt = qec.trt_decoder_config() + + py_type, sample_val, alt_val = meta + + # Initially None + assert getattr(trt, name) is None + + # Set to a valid value and get back + setattr(trt, name, sample_val) + got = getattr(trt, name) + assert isinstance(got, py_type) + assert got == sample_val + + # Change to an alternate valid value + setattr(trt, name, alt_val) + got2 = getattr(trt, name) + assert got2 == alt_val + + # Set value to None + setattr(trt, name, None) + assert getattr(trt, name) is None + + +def test_trt_decoder_config_yaml_roundtrip(): + trt = qec.trt_decoder_config() + trt.engine_load_path = "/path/to/engine.trt" + trt.precision = "fp16" + trt.memory_workspace = 1073741824 # 1GB + + dc = qec.decoder_config() + dc.id = 0 + dc.type = "trt_decoder" + dc.block_size = 10 + dc.syndrome_size = 3 + dc.H_sparse = [1, 2, 3, -1, 6, 7, 8, -1, -1] + dc.set_decoder_custom_args(trt) + + yaml_text = dc.to_yaml_str() + assert isinstance(yaml_text, str) and len(yaml_text) > 0 + + dc2 = qec.decoder_config.from_yaml_str(yaml_text) + + # Basic scalar fields + assert dc2 is not None + assert dc2.id == 0 + assert dc2.type == "trt_decoder" + assert dc2.block_size == 10 + assert dc2.syndrome_size == 3 + + # Recover TRT config from decoder_custom_args + trt2 = dc2.decoder_custom_args + assert trt2 is not None + assert trt2.engine_load_path == "/path/to/engine.trt" + assert trt2.precision == "fp16" + assert trt2.memory_workspace == 1073741824 + + # decoder_config tests