Describe the issue
Summary
This issue is a feature request to expose some of the utilities below as first-class APIs in FlagTree, either as builtins or as official device-library functions.
Context / Motivation
In a production setting, we need Triton kernels with convenient helpers for:
- promoting scalars to tensors
- checking whether a dtype is floating-point
- using program_id/num_programs in int64 form
Right now we implement these ourselves as Triton @triton.jit device functions, but they seem generic enough to be part of FlagTree itself.
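A common reason to want the int64 forms: tl.program_id returns an int32 value, so an offset such as pid * BLOCK_SIZE is computed in 32-bit arithmetic and overflows once it passes 2**31 - 1 (for example, on tensors with more than about 2.1 billion elements). The host-side sketch below simulates this in plain Python, assuming two's-complement wraparound for int32; wrap_int32 and block_start are hypothetical names used only for illustration.

```python
# Host-side simulation of 32-bit signed arithmetic, showing why
# pid * BLOCK_SIZE can overflow when computed in int32.
INT32_MIN = -(2**31)


def wrap_int32(value: int) -> int:
    """Reduce an arbitrary Python int to its two's-complement int32 value."""
    return (value - INT32_MIN) % 2**32 + INT32_MIN


def block_start(pid: int, block_size: int, use_int64: bool) -> int:
    """Offset of the first element handled by program `pid`."""
    start = pid * block_size
    # int64 is modelled as exact (we stay far below 2**63); int32 wraps.
    return start if use_int64 else wrap_int32(start)


BLOCK_SIZE = 1024
pid = 2**21  # program index 2,097,152 along axis 0

print(block_start(pid, BLOCK_SIZE, use_int64=False))  # -2147483648
print(block_start(pid, BLOCK_SIZE, use_int64=True))   # 2147483648
```

Widening the program id to int64 before the multiplication, as the .to(tl.int64) helpers do, keeps the product in 64 bits so the offset stays correct.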
Below is a simplified version of our current "extension" module.
Current local extension (simplified)
```python
import triton
from triton import language as tl

# 1. program_id / num_programs returning int64
@triton.jit
def program_id(axis: int) -> tl.tensor:
    return tl.program_id(axis).to(tl.int64)

@triton.jit
def num_programs(axis: int) -> tl.tensor:
    return tl.num_programs(axis).to(tl.int64)

# 2. scalar → tensor promotion + dtype predicate
@triton.jit
def promote_to_tensor(x):
    # Addition promotes to tensor for us: int1 (bool) is the narrowest
    # dtype, so x's type determines the result dtype.
    return x + tl.zeros((1,), tl.int1)

@triton.jit
def is_floating(x):
    # Promote first so this also works for Python scalars, which have no .dtype.
    return promote_to_tensor(x).dtype.is_floating()
```
Environment details
- Triton: 3.2
- GPUs: NVIDIA A100 & NVIDIA H800