|
5 | 5 |
|
6 | 6 | use crate::{ |
7 | 7 | error::{OnnxError, Result}, |
8 | | - graph::{Graph, Node}, |
| 8 | + graph::Graph, |
9 | 9 | operators, |
10 | 10 | tensor::Tensor, |
11 | 11 | }; |
12 | 12 | use std::collections::HashMap; |
13 | 13 |
|
| 14 | +#[cfg(feature = "parallel")] |
| 15 | +use rayon::prelude::*; |
| 16 | + |
14 | 17 | /// Runtime execution engine for ONNX models |
15 | 18 | pub struct Runtime { |
16 | 19 | /// Whether to enable debug logging |
@@ -113,13 +116,86 @@ impl Runtime { |
113 | 116 | context.add_tensor(name.clone(), tensor.clone()); |
114 | 117 | } |
115 | 118 |
|
116 | | - // Get execution order |
117 | | - let execution_order = graph.topological_sort()?; |
118 | | - |
119 | | - // Execute nodes in order |
120 | | - for &node_idx in &execution_order { |
121 | | - let node = &graph.nodes[node_idx]; |
122 | | - self.execute_node(node, &mut context)?; |
| 119 | + // Group nodes into independent waves; nodes within the same wave |
| 120 | + // have no data dependencies on each other and can run in parallel. |
| 121 | + let levels = graph.topological_levels()?; |
| 122 | + let debug = self.debug; |
| 123 | + |
| 124 | + for level_nodes in &levels { |
| 125 | + // Phase 1: gather inputs for every node in this wave (sequential, |
| 126 | + // read-only access to context). |
| 127 | + let work: Vec<(usize, Vec<Tensor>)> = level_nodes |
| 128 | + .iter() |
| 129 | + .map(|&node_idx| { |
| 130 | + let node = &graph.nodes[node_idx]; |
| 131 | + let inputs = node |
| 132 | + .inputs |
| 133 | + .iter() |
| 134 | + .map(|name| { |
| 135 | + context.get_tensor(name).cloned().ok_or_else(|| { |
| 136 | + OnnxError::runtime_error(format!( |
| 137 | + "Node '{}' references unknown tensor '{}'", |
| 138 | + node.name, name |
| 139 | + )) |
| 140 | + }) |
| 141 | + }) |
| 142 | + .collect::<Result<Vec<_>>>()?; |
| 143 | + Ok((node_idx, inputs)) |
| 144 | + }) |
| 145 | + .collect::<Result<Vec<_>>>()?; |
| 146 | + |
| 147 | + // Phase 2: run operators — parallel when the `parallel` feature is |
| 148 | + // enabled, sequential otherwise. |
| 149 | + let run = |(node_idx, inputs): (usize, Vec<Tensor>)| -> (usize, Result<Vec<Tensor>>) { |
| 150 | + let node = &graph.nodes[node_idx]; |
| 151 | + if debug { |
| 152 | + log::debug!("Executing node '{}' ({})", node.name, node.op_type); |
| 153 | + for (i, t) in inputs.iter().enumerate() { |
| 154 | + log::debug!(" Input {}: shape {:?}", i, t.shape()); |
| 155 | + } |
| 156 | + } |
| 157 | + let result = node.get_operator_type().and_then(|op_type| { |
| 158 | + operators::execute_operator(&op_type, &inputs, &node.attributes).map_err(|e| { |
| 159 | + OnnxError::runtime_error(format!( |
| 160 | + "Failed to execute {:?} ({}): {}", |
| 161 | + op_type, node.name, e |
| 162 | + )) |
| 163 | + }) |
| 164 | + }); |
| 165 | + (node_idx, result) |
| 166 | + }; |
| 167 | + |
| 168 | + #[cfg(feature = "parallel")] |
| 169 | + let results: Vec<(usize, Result<Vec<Tensor>>)> = |
| 170 | + work.into_par_iter().map(run).collect(); |
| 171 | + #[cfg(not(feature = "parallel"))] |
| 172 | + let results: Vec<(usize, Result<Vec<Tensor>>)> = work.into_iter().map(run).collect(); |
| 173 | + |
| 174 | + // Phase 3: store outputs sequentially and update stats. |
| 175 | + for (node_idx, outputs_result) in results { |
| 176 | + let node = &graph.nodes[node_idx]; |
| 177 | + let output_tensors = outputs_result?; |
| 178 | + |
| 179 | + if output_tensors.len() != node.outputs.len() { |
| 180 | + return Err(OnnxError::runtime_error(format!( |
| 181 | + "Node '{}' produced {} outputs but expected {}", |
| 182 | + node.name, |
| 183 | + output_tensors.len(), |
| 184 | + node.outputs.len() |
| 185 | + ))); |
| 186 | + } |
| 187 | + |
| 188 | + if debug { |
| 189 | + for (i, t) in output_tensors.iter().enumerate() { |
| 190 | + log::debug!(" Output {}: shape {:?}", i, t.shape()); |
| 191 | + } |
| 192 | + } |
| 193 | + |
| 194 | + for (name, tensor) in node.outputs.iter().zip(output_tensors) { |
| 195 | + context.add_tensor(name.clone(), tensor); |
| 196 | + } |
| 197 | + context.stats.ops_executed += 1; |
| 198 | + } |
123 | 199 | } |
124 | 200 |
|
125 | 201 | // Extract outputs |
@@ -175,87 +251,6 @@ impl Runtime { |
175 | 251 | Ok(()) |
176 | 252 | } |
177 | 253 |
|
178 | | - /// Execute a single node |
179 | | - fn execute_node(&self, node: &Node, context: &mut ExecutionContext) -> Result<()> { |
180 | | - let node_start = std::time::Instant::now(); |
181 | | - |
182 | | - if self.debug { |
183 | | - log::debug!("Executing node '{}' ({})", node.name, node.op_type); |
184 | | - } |
185 | | - |
186 | | - // Gather input tensors |
187 | | - let input_tensors: Vec<Tensor> = node |
188 | | - .inputs |
189 | | - .iter() |
190 | | - .map(|name| { |
191 | | - context |
192 | | - .get_tensor(name) |
193 | | - .ok_or_else(|| { |
194 | | - OnnxError::runtime_error(format!( |
195 | | - "Node '{}' references unknown tensor '{}'", |
196 | | - node.name, name |
197 | | - )) |
198 | | - }) |
199 | | - .cloned() |
200 | | - }) |
201 | | - .collect::<Result<Vec<_>>>()?; |
202 | | - |
203 | | - // Log input shapes for debugging |
204 | | - if self.debug { |
205 | | - for (i, tensor) in input_tensors.iter().enumerate() { |
206 | | - log::debug!(" Input {}: shape {:?}", i, tensor.shape()); |
207 | | - } |
208 | | - } |
209 | | - |
210 | | - // Execute the operator |
211 | | - let op_type = node.get_operator_type()?; |
212 | | - let output_tensors = |
213 | | - operators::execute_operator(&op_type, &input_tensors, &node.attributes).map_err( |
214 | | - |e| { |
215 | | - OnnxError::runtime_error(format!( |
216 | | - "Failed to execute {:?} ({}): {}", |
217 | | - op_type, node.name, e |
218 | | - )) |
219 | | - }, |
220 | | - )?; |
221 | | - |
222 | | - // Log output shapes for debugging |
223 | | - if self.debug { |
224 | | - for (i, tensor) in output_tensors.iter().enumerate() { |
225 | | - log::debug!(" Output {}: shape {:?}", i, tensor.shape()); |
226 | | - } |
227 | | - } |
228 | | - |
229 | | - // Store output tensors |
230 | | - if output_tensors.len() != node.outputs.len() { |
231 | | - return Err(OnnxError::runtime_error(format!( |
232 | | - "Node '{}' produced {} outputs but expected {}", |
233 | | - node.name, |
234 | | - output_tensors.len(), |
235 | | - node.outputs.len() |
236 | | - ))); |
237 | | - } |
238 | | - |
239 | | - for (output_name, output_tensor) in node.outputs.iter().zip(output_tensors.into_iter()) { |
240 | | - context.add_tensor(output_name.clone(), output_tensor); |
241 | | - } |
242 | | - |
243 | | - // Update statistics |
244 | | - let execution_time = node_start.elapsed().as_millis() as f64; |
245 | | - context.stats.ops_executed += 1; |
246 | | - *context |
247 | | - .stats |
248 | | - .op_times |
249 | | - .entry(node.op_type.clone()) |
250 | | - .or_insert(0.0) += execution_time; |
251 | | - |
252 | | - if self.debug { |
253 | | - log::debug!("Node '{}' executed in {:.2}ms", node.name, execution_time); |
254 | | - } |
255 | | - |
256 | | - Ok(()) |
257 | | - } |
258 | | - |
259 | 254 | /// Extract output tensors from the execution context |
260 | 255 | fn extract_outputs( |
261 | 256 | &self, |
|
0 commit comments