Skip to content

Commit 05fff03

Browse files
committed
a
Signed-off-by: yizhang-nv <[email protected]>
1 parent 4ab7d6e commit 05fff03

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

tensorrt_llm/_torch/compilation/piecewise_optimizer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,10 @@ def __call__(self, *args):
172172
return self.default_callable(*args)
173173

174174
if self.is_first_runner or self.is_last_runner:
175-
set_piecewise_running(self.is_first_runner)
175+
if self.is_first_runner == self.is_last_runner:
176+
set_piecewise_running(False)
177+
else:
178+
set_piecewise_running(self.is_first_runner)
176179

177180
entry = self.entries[runtime_num_of_token]
178181

tensorrt_llm/_torch/modules/attention.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,14 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
7777

7878

7979
def maybe_compile(func):
80-
if is_piecewise_running():
81-
# When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
82-
return func
83-
return torch.compile(func)
80+
81+
def wrapper(*args, **kwargs):
82+
if is_piecewise_running():
83+
# When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
84+
return func(*args, **kwargs)
85+
return torch.compile(func)(*args, **kwargs)
86+
87+
return wrapper
8488

8589

8690
@maybe_compile

0 commit comments

Comments
 (0)