@@ -48,7 +48,8 @@
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..utils import (get_model_extra_attrs,
                      set_per_request_piecewise_cuda_graph_flag,
-                     set_torch_compiling, with_model_extra_attrs)
+                     set_prefer_mem_pool, set_torch_compiling,
+                     with_model_extra_attrs)
 from .config import PyTorchConfig
 from .config_utils import is_mla
 from .cuda_graph_runner import CUDAGraphRunner

@@ -2186,35 +2187,35 @@ def forward(
             new_tensors_device, cache_indirection_buffer)

         self.iter_counter += 1
+        with set_prefer_mem_pool(self.cuda_graph_runner.get_graph_pool()):
+            if not maybe_graph:
+                # Fallback to eager execution if graph was not used
+                with MoeLoadBalancerIterContext(moe_load_balancer):
+                    outputs = self._forward_step(inputs, gather_ids,
+                                                 gather_context_logits)
+            else:
+                if self.cuda_graph_runner.needs_capture(key):

-        if not maybe_graph:
-            # Fallback to eager execution if graph was not used
-            with MoeLoadBalancerIterContext(moe_load_balancer):
-                outputs = self._forward_step(inputs, gather_ids,
-                                             gather_context_logits)
-        else:
-            if self.cuda_graph_runner.needs_capture(key):
-
-                def capture_forward_fn(inputs: Dict[str, Any]):
-                    with MoeLoadBalancerIterContext(moe_load_balancer):
-                        return self._forward_step(
-                            inputs,
-                            gather_ids=gather_ids,
-                            gather_context_logits=gather_context_logits)
+                    def capture_forward_fn(inputs: Dict[str, Any]):
+                        with MoeLoadBalancerIterContext(moe_load_balancer):
+                            return self._forward_step(
+                                inputs,
+                                gather_ids=gather_ids,
+                                gather_context_logits=gather_context_logits)

-                def capture_postprocess_fn(inputs: Dict[str, Any]):
-                    self._postprocess_inputs(inputs)
+                    def capture_postprocess_fn(inputs: Dict[str, Any]):
+                        self._postprocess_inputs(inputs)

-                self.cuda_graph_runner.capture(key, capture_forward_fn,
-                                               inputs,
-                                               capture_postprocess_fn)
+                    self.cuda_graph_runner.capture(key, capture_forward_fn,
+                                                   inputs,
+                                                   capture_postprocess_fn)

-                # here we don't need to use context since cuda graph capture didn't run kernel.
-                # maybe we need a cleaner way to do this.
-                outputs = self.cuda_graph_runner.replay(key, inputs)
-            else:
-                with MoeLoadBalancerIterContext(moe_load_balancer):
+                    # here we don't need to use context since cuda graph capture didn't run kernel.
+                    # maybe we need a cleaner way to do this.
                     outputs = self.cuda_graph_runner.replay(key, inputs)
+                else:
+                    with MoeLoadBalancerIterContext(moe_load_balancer):
+                        outputs = self.cuda_graph_runner.replay(key, inputs)

         self._execute_logit_post_processors(scheduled_requests, outputs)
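
Note: the implementation of set_prefer_mem_pool is not shown in this diff. Below is a minimal sketch of what such a context manager could look like, assuming it simply records the preferred pool (here, whatever self.cuda_graph_runner.get_graph_pool() returns) in module-level state for allocation-aware code to consult; the names _preferred_mem_pool and get_prefer_mem_pool are illustrative, not taken from this commit.

import contextlib
from typing import Any, Optional

# Illustrative module-level state; the real helper in tensorrt_llm may
# store this differently (e.g. thread-local).
_preferred_mem_pool: Optional[Any] = None


@contextlib.contextmanager
def set_prefer_mem_pool(mem_pool: Any):
    # Temporarily mark `mem_pool` (e.g. the CUDA graph memory pool) as the
    # preferred pool, restoring the previous value even if the body raises.
    global _preferred_mem_pool
    prev = _preferred_mem_pool
    _preferred_mem_pool = mem_pool
    try:
        yield
    finally:
        _preferred_mem_pool = prev


def get_prefer_mem_pool() -> Optional[Any]:
    # Allocation-aware code can consult this to place workspace tensors in the
    # graph pool instead of the default caching allocator.
    return _preferred_mem_pool

Presumably, keeping allocations made during the graph-wrapped forward (eager fallback, capture, and replay alike) in the same pool as the CUDA graphs themselves avoids fragmenting the default allocator across iterations.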