Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion libs/qec/include/cudaq/qec/realtime/decoding_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ struct single_error_lut_config {
from_heterogeneous_map(const cudaqx::heterogeneous_map &map);
};

struct trt_decoder_config {
std::optional<std::string> onnx_load_path;
std::optional<std::string> engine_load_path;
std::optional<std::string> engine_save_path;
std::optional<std::string> precision;
std::optional<std::size_t> memory_workspace;

bool operator==(const trt_decoder_config &) const = default;

__attribute__((visibility("default"))) cudaqx::heterogeneous_map
to_heterogeneous_map() const;

__attribute__((visibility("default"))) static trt_decoder_config
from_heterogeneous_map(const cudaqx::heterogeneous_map &map);
};

struct sliding_window_config {
std::optional<std::size_t> window_size;
std::optional<std::size_t> step_size;
Expand Down Expand Up @@ -122,7 +138,8 @@ struct decoder_config {
std::vector<std::int64_t> O_sparse;
std::vector<std::int64_t> D_sparse;
std::variant<single_error_lut_config, multi_error_lut_config,
nv_qldpc_decoder_config, sliding_window_config>
nv_qldpc_decoder_config, sliding_window_config,
trt_decoder_config>
decoder_custom_args;

bool operator==(const decoder_config &) const = default;
Expand All @@ -144,6 +161,10 @@ struct decoder_config {
decoder_custom_args)) {
return std::get<sliding_window_config>(decoder_custom_args)
.to_heterogeneous_map();
} else if (std::holds_alternative<trt_decoder_config>(
decoder_custom_args)) {
return std::get<trt_decoder_config>(decoder_custom_args)
.to_heterogeneous_map();
}
return cudaqx::heterogeneous_map();
}
Expand All @@ -161,6 +182,8 @@ struct decoder_config {
nv_qldpc_decoder_config::from_heterogeneous_map(map);
} else if (type == "sliding_window") {
decoder_custom_args = sliding_window_config::from_heterogeneous_map(map);
} else if (type == "trt_decoder") {
decoder_custom_args = trt_decoder_config::from_heterogeneous_map(map);
}
}

Expand Down
40 changes: 40 additions & 0 deletions libs/qec/lib/realtime/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,31 @@ single_error_lut_config single_error_lut_config::from_heterogeneous_map(
return config;
}

// ------ trt_decoder_config ------
cudaqx::heterogeneous_map trt_decoder_config::to_heterogeneous_map() const {
cudaqx::heterogeneous_map config_map;

INSERT_ARG(onnx_load_path);
INSERT_ARG(engine_load_path);
INSERT_ARG(engine_save_path);
INSERT_ARG(precision);
INSERT_ARG(memory_workspace);

return config_map;
}

trt_decoder_config trt_decoder_config::from_heterogeneous_map(
const cudaqx::heterogeneous_map &map) {
trt_decoder_config config;
GET_ARG(onnx_load_path);
GET_ARG(engine_load_path);
GET_ARG(engine_save_path);
GET_ARG(precision);
GET_ARG(memory_workspace);

return config;
}

// ------ sliding_window_config ------
cudaqx::heterogeneous_map sliding_window_config::to_heterogeneous_map() const {
cudaqx::heterogeneous_map config_map;
Expand Down Expand Up @@ -317,6 +342,18 @@ struct MappingTraits<cudaq::qec::decoding::config::single_error_lut_config> {
cudaq::qec::decoding::config::single_error_lut_config &config) {}
};

template <>
struct MappingTraits<cudaq::qec::decoding::config::trt_decoder_config> {
static void
mapping(IO &io, cudaq::qec::decoding::config::trt_decoder_config &config) {
io.mapOptional("onnx_load_path", config.onnx_load_path);
io.mapOptional("engine_load_path", config.engine_load_path);
io.mapOptional("engine_save_path", config.engine_save_path);
io.mapOptional("precision", config.precision);
io.mapOptional("memory_workspace", config.memory_workspace);
}
};

template <>
struct MappingTraits<cudaq::qec::decoding::config::sliding_window_config> {
static void
Expand Down Expand Up @@ -418,6 +455,9 @@ struct MappingTraits<cudaq::qec::decoding::config::decoder_config> {
} else if (config.type == "single_error_lut") {
INIT_AND_MAP_DECODER_CUSTOM_ARGS(
cudaq::qec::decoding::config::single_error_lut_config);
} else if (config.type == "trt_decoder") {
INIT_AND_MAP_DECODER_CUSTOM_ARGS(
cudaq::qec::decoding::config::trt_decoder_config);
} else if (config.type == "sliding_window") {
INIT_AND_MAP_DECODER_CUSTOM_ARGS(
cudaq::qec::decoding::config::sliding_window_config);
Expand Down
18 changes: 18 additions & 0 deletions libs/qec/python/bindings/py_decoding_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,24 @@ void bindDecodingConfig(py::module &mod) {
&multi_error_lut_config::from_heterogeneous_map,
py::arg("map"));

// trt_decoder_config
py::class_<config::trt_decoder_config>(mod_cfg, "trt_decoder_config",
"TensorRT decoder configuration.")
.def(py::init<>())
.def(py::init([](const cudaqx::heterogeneous_map &map) {
return trt_decoder_config::from_heterogeneous_map(map);
}),
py::arg("map"))
.def_readwrite("onnx_load_path", &trt_decoder_config::onnx_load_path)
.def_readwrite("engine_load_path", &trt_decoder_config::engine_load_path)
.def_readwrite("engine_save_path", &trt_decoder_config::engine_save_path)
.def_readwrite("precision", &trt_decoder_config::precision)
.def_readwrite("memory_workspace", &trt_decoder_config::memory_workspace)
.def("to_heterogeneous_map", &trt_decoder_config::to_heterogeneous_map,
py::return_value_policy::move)
.def_static("from_heterogeneous_map",
&trt_decoder_config::from_heterogeneous_map, py::arg("map"));

// single_error_lut_config
py::class_<config::single_error_lut_config>(
mod_cfg, "single_error_lut_config",
Expand Down
1 change: 1 addition & 0 deletions libs/qec/python/cudaq_qec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
decoder_config = qecrt.config.decoder_config
nv_qldpc_decoder_config = qecrt.config.nv_qldpc_decoder_config
multi_error_lut_config = qecrt.config.multi_error_lut_config
trt_decoder_config = qecrt.config.trt_decoder_config
configure_decoders_from_file = qecrt.config.configure_decoders_from_file
configure_decoders_from_str = qecrt.config.configure_decoders_from_str
finalize_decoders = qecrt.config.finalize_decoders
Expand Down
78 changes: 78 additions & 0 deletions libs/qec/python/tests/test_decoding_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,16 @@ def test_nv_qldpc_decoder_config_toggle_multiple_fields_and_clear():
"lut_error_depth": (int, 1, 3),
}

# trt_decoder_config tests

FIELDS_TRT_DECODER = {
"onnx_load_path": (str, "/path/to/model.onnx", "/other/path/model.onnx"),
"engine_load_path": (str, "/path/to/engine.trt", "/other/engine.trt"),
"engine_save_path": (str, "/path/to/save.trt", "/other/save.trt"),
"precision": (str, "fp16", "fp32"),
"memory_workspace": (int, 1073741824, 2147483648), # 1GB, 2GB
}


def test_multi_error_lut_config_defaults_are_none():
m = qec.multi_error_lut_config()
Expand Down Expand Up @@ -195,6 +205,74 @@ def test_configure_valid_multi_error_lut_decoders():
assert ret == 0


# trt_decoder_config tests


def test_trt_decoder_config_defaults_are_none():
trt = qec.trt_decoder_config()
for name in FIELDS_TRT_DECODER:
assert getattr(trt, name) is None, f"Expected {name} to default to None"


@pytest.mark.parametrize("name, meta", list(FIELDS_TRT_DECODER.items()))
def test_trt_decoder_config_set_and_get_each_optional(name, meta):
trt = qec.trt_decoder_config()

py_type, sample_val, alt_val = meta

# Initially None
assert getattr(trt, name) is None

# Set to a valid value and get back
setattr(trt, name, sample_val)
got = getattr(trt, name)
assert isinstance(got, py_type)
assert got == sample_val

# Change to an alternate valid value
setattr(trt, name, alt_val)
got2 = getattr(trt, name)
assert got2 == alt_val

# Set value to None
setattr(trt, name, None)
assert getattr(trt, name) is None


def test_trt_decoder_config_yaml_roundtrip():
trt = qec.trt_decoder_config()
trt.engine_load_path = "/path/to/engine.trt"
trt.precision = "fp16"
trt.memory_workspace = 1073741824 # 1GB

dc = qec.decoder_config()
dc.id = 0
dc.type = "trt_decoder"
dc.block_size = 10
dc.syndrome_size = 3
dc.H_sparse = [1, 2, 3, -1, 6, 7, 8, -1, -1]
dc.set_decoder_custom_args(trt)

yaml_text = dc.to_yaml_str()
assert isinstance(yaml_text, str) and len(yaml_text) > 0

dc2 = qec.decoder_config.from_yaml_str(yaml_text)

# Basic scalar fields
assert dc2 is not None
assert dc2.id == 0
assert dc2.type == "trt_decoder"
assert dc2.block_size == 10
assert dc2.syndrome_size == 3

# Recover TRT config from decoder_custom_args
trt2 = dc2.decoder_custom_args
assert trt2 is not None
assert trt2.engine_load_path == "/path/to/engine.trt"
assert trt2.precision == "fp16"
assert trt2.memory_workspace == 1073741824


# decoder_config tests


Expand Down
Loading