Skip to content

Commit 6abaeff

Browse files
committed
perf: add parallel wave execution via rayon
- Add rayon as an optional dependency behind a new `parallel` feature flag
- Add Graph::topological_levels() which groups nodes into independent execution waves using tensor availability levels
- Refactor Runtime::execute() to use wave-based scheduling: gather inputs sequentially, run operators in parallel per wave via rayon par_iter, then store outputs sequentially; falls back to iter when feature is off
- Remove the now-unused execute_node() helper
1 parent 82e1221 commit 6abaeff

File tree

3 files changed

+137
-89
lines changed

3 files changed

+137
-89
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ env_logger = "0.11"
3434

3535
# Optional features
3636
tokio = { version = "1.0", features = ["full"], optional = true }
37+
rayon = { version = "1.10", optional = true }
3738

3839
[build-dependencies]
3940
prost-build = "0.14.1"
@@ -53,6 +54,7 @@ imageproc = "0.25"
5354
[features]
5455
default = []
5556
async = ["tokio"]
57+
parallel = ["rayon"]
5658
formal-verification = [] # Enable formal verification checks and contracts
5759

5860
[[bench]]

src/graph.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,57 @@ impl Graph {
275275
Ok(result)
276276
}
277277

278+
/// Group nodes into parallel execution waves.
279+
///
280+
/// Returns a list of levels where every node in a level is independent of
281+
/// every other node in that level (no data edges between them). Nodes in
282+
/// the same level can be executed concurrently; levels must be executed in
283+
/// order.
284+
pub fn topological_levels(&self) -> Result<Vec<Vec<usize>>> {
285+
let n = self.nodes.len();
286+
if n == 0 {
287+
return Ok(vec![]);
288+
}
289+
290+
// tensor_level[t] = the wave after which tensor t is available.
291+
// Graph inputs and initializers are available before wave 0 → level 0.
292+
let mut tensor_level: HashMap<&str, usize> = HashMap::new();
293+
for input in &self.inputs {
294+
tensor_level.insert(input.name.as_str(), 0);
295+
}
296+
for name in self.initializers.keys() {
297+
tensor_level.insert(name.as_str(), 0);
298+
}
299+
300+
// Process nodes in topological order so dependencies are resolved first.
301+
let topo_order = self.topological_sort()?;
302+
let mut node_level = vec![0usize; n];
303+
304+
for &idx in &topo_order {
305+
let node = &self.nodes[idx];
306+
// A node's wave = max wave of all its input tensors.
307+
let level = node
308+
.inputs
309+
.iter()
310+
.filter_map(|name| tensor_level.get(name.as_str()).copied())
311+
.max()
312+
.unwrap_or(0);
313+
node_level[idx] = level;
314+
// Outputs produced by this node become available at level + 1.
315+
for output in &node.outputs {
316+
tensor_level.insert(output.as_str(), level + 1);
317+
}
318+
}
319+
320+
let max_level = node_level.iter().copied().max().unwrap_or(0);
321+
let mut levels: Vec<Vec<usize>> = vec![vec![]; max_level + 1];
322+
for (idx, &lvl) in node_level.iter().enumerate() {
323+
levels[lvl].push(idx);
324+
}
325+
326+
Ok(levels)
327+
}
328+
278329
/// Print the graph structure in a visual ASCII format
279330
pub fn print_graph(&self) {
280331
// Calculate the width needed for the graph name

src/runtime.rs

Lines changed: 84 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
66
use crate::{
77
error::{OnnxError, Result},
8-
graph::{Graph, Node},
8+
graph::Graph,
99
operators,
1010
tensor::Tensor,
1111
};
1212
use std::collections::HashMap;
1313

14+
#[cfg(feature = "parallel")]
15+
use rayon::prelude::*;
16+
1417
/// Runtime execution engine for ONNX models
1518
pub struct Runtime {
1619
/// Whether to enable debug logging
@@ -113,13 +116,86 @@ impl Runtime {
113116
context.add_tensor(name.clone(), tensor.clone());
114117
}
115118

116-
// Get execution order
117-
let execution_order = graph.topological_sort()?;
118-
119-
// Execute nodes in order
120-
for &node_idx in &execution_order {
121-
let node = &graph.nodes[node_idx];
122-
self.execute_node(node, &mut context)?;
119+
// Group nodes into independent waves; nodes within the same wave
120+
// have no data dependencies on each other and can run in parallel.
121+
let levels = graph.topological_levels()?;
122+
let debug = self.debug;
123+
124+
for level_nodes in &levels {
125+
// Phase 1: gather inputs for every node in this wave (sequential,
126+
// read-only access to context).
127+
let work: Vec<(usize, Vec<Tensor>)> = level_nodes
128+
.iter()
129+
.map(|&node_idx| {
130+
let node = &graph.nodes[node_idx];
131+
let inputs = node
132+
.inputs
133+
.iter()
134+
.map(|name| {
135+
context.get_tensor(name).cloned().ok_or_else(|| {
136+
OnnxError::runtime_error(format!(
137+
"Node '{}' references unknown tensor '{}'",
138+
node.name, name
139+
))
140+
})
141+
})
142+
.collect::<Result<Vec<_>>>()?;
143+
Ok((node_idx, inputs))
144+
})
145+
.collect::<Result<Vec<_>>>()?;
146+
147+
// Phase 2: run operators — parallel when the `parallel` feature is
148+
// enabled, sequential otherwise.
149+
let run = |(node_idx, inputs): (usize, Vec<Tensor>)| -> (usize, Result<Vec<Tensor>>) {
150+
let node = &graph.nodes[node_idx];
151+
if debug {
152+
log::debug!("Executing node '{}' ({})", node.name, node.op_type);
153+
for (i, t) in inputs.iter().enumerate() {
154+
log::debug!(" Input {}: shape {:?}", i, t.shape());
155+
}
156+
}
157+
let result = node.get_operator_type().and_then(|op_type| {
158+
operators::execute_operator(&op_type, &inputs, &node.attributes).map_err(|e| {
159+
OnnxError::runtime_error(format!(
160+
"Failed to execute {:?} ({}): {}",
161+
op_type, node.name, e
162+
))
163+
})
164+
});
165+
(node_idx, result)
166+
};
167+
168+
#[cfg(feature = "parallel")]
169+
let results: Vec<(usize, Result<Vec<Tensor>>)> =
170+
work.into_par_iter().map(run).collect();
171+
#[cfg(not(feature = "parallel"))]
172+
let results: Vec<(usize, Result<Vec<Tensor>>)> = work.into_iter().map(run).collect();
173+
174+
// Phase 3: store outputs sequentially and update stats.
175+
for (node_idx, outputs_result) in results {
176+
let node = &graph.nodes[node_idx];
177+
let output_tensors = outputs_result?;
178+
179+
if output_tensors.len() != node.outputs.len() {
180+
return Err(OnnxError::runtime_error(format!(
181+
"Node '{}' produced {} outputs but expected {}",
182+
node.name,
183+
output_tensors.len(),
184+
node.outputs.len()
185+
)));
186+
}
187+
188+
if debug {
189+
for (i, t) in output_tensors.iter().enumerate() {
190+
log::debug!(" Output {}: shape {:?}", i, t.shape());
191+
}
192+
}
193+
194+
for (name, tensor) in node.outputs.iter().zip(output_tensors) {
195+
context.add_tensor(name.clone(), tensor);
196+
}
197+
context.stats.ops_executed += 1;
198+
}
123199
}
124200

125201
// Extract outputs
@@ -175,87 +251,6 @@ impl Runtime {
175251
Ok(())
176252
}
177253

178-
/// Execute a single node
179-
fn execute_node(&self, node: &Node, context: &mut ExecutionContext) -> Result<()> {
180-
let node_start = std::time::Instant::now();
181-
182-
if self.debug {
183-
log::debug!("Executing node '{}' ({})", node.name, node.op_type);
184-
}
185-
186-
// Gather input tensors
187-
let input_tensors: Vec<Tensor> = node
188-
.inputs
189-
.iter()
190-
.map(|name| {
191-
context
192-
.get_tensor(name)
193-
.ok_or_else(|| {
194-
OnnxError::runtime_error(format!(
195-
"Node '{}' references unknown tensor '{}'",
196-
node.name, name
197-
))
198-
})
199-
.cloned()
200-
})
201-
.collect::<Result<Vec<_>>>()?;
202-
203-
// Log input shapes for debugging
204-
if self.debug {
205-
for (i, tensor) in input_tensors.iter().enumerate() {
206-
log::debug!(" Input {}: shape {:?}", i, tensor.shape());
207-
}
208-
}
209-
210-
// Execute the operator
211-
let op_type = node.get_operator_type()?;
212-
let output_tensors =
213-
operators::execute_operator(&op_type, &input_tensors, &node.attributes).map_err(
214-
|e| {
215-
OnnxError::runtime_error(format!(
216-
"Failed to execute {:?} ({}): {}",
217-
op_type, node.name, e
218-
))
219-
},
220-
)?;
221-
222-
// Log output shapes for debugging
223-
if self.debug {
224-
for (i, tensor) in output_tensors.iter().enumerate() {
225-
log::debug!(" Output {}: shape {:?}", i, tensor.shape());
226-
}
227-
}
228-
229-
// Store output tensors
230-
if output_tensors.len() != node.outputs.len() {
231-
return Err(OnnxError::runtime_error(format!(
232-
"Node '{}' produced {} outputs but expected {}",
233-
node.name,
234-
output_tensors.len(),
235-
node.outputs.len()
236-
)));
237-
}
238-
239-
for (output_name, output_tensor) in node.outputs.iter().zip(output_tensors.into_iter()) {
240-
context.add_tensor(output_name.clone(), output_tensor);
241-
}
242-
243-
// Update statistics
244-
let execution_time = node_start.elapsed().as_millis() as f64;
245-
context.stats.ops_executed += 1;
246-
*context
247-
.stats
248-
.op_times
249-
.entry(node.op_type.clone())
250-
.or_insert(0.0) += execution_time;
251-
252-
if self.debug {
253-
log::debug!("Node '{}' executed in {:.2}ms", node.name, execution_time);
254-
}
255-
256-
Ok(())
257-
}
258-
259254
/// Extract output tensors from the execution context
260255
fn extract_outputs(
261256
&self,

0 commit comments

Comments
 (0)