
[Feature Request] Expose common device utilities in triton.language (dtype helpers, int64 program_id) #177

@kiddyjinjin


Describe the issue

Summary

This issue requests that FlagTree expose a few common device utilities (int64 program_id / num_programs, scalar-to-tensor promotion, and a floating-point dtype predicate) as first-class APIs, either as builtins or as official device-library functions.


Context / Motivation

In a production setting, we need Triton kernels to have convenient helpers for:

  • promoting scalars to tensors
  • checking whether a dtype is floating-point
  • using program_id / num_programs in int64 form
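One likely motivation for the int64 forms: tl.program_id and tl.num_programs return int32, so an offset such as pid * BLOCK_SIZE no longer fits once a tensor has more than 2**31 elements. The pure-Python sketch below models that arithmetic without calling Triton. as_int32 and block_start are hypothetical helpers, and modeling the overflowing int32 multiply as two's-complement wraparound is an assumption about what you would typically observe on the GPU.

```python
# Pure-Python model of 32-bit vs 64-bit offset arithmetic (no Triton involved).


def as_int32(value: int) -> int:
    """Hypothetical helper: wrap an integer to 32-bit two's complement."""
    return (value + 2**31) % 2**32 - 2**31


def block_start(pid: int, block_size: int, use_int64: bool) -> int:
    """Hypothetical helper: start offset pid * block_size in int32 or int64."""
    offset = pid * block_size
    # Python ints never overflow, so the int64 path is exact for offsets < 2**63
    return offset if use_int64 else as_int32(offset)


BLOCK_SIZE = 1024
pid = 2**21  # only reached when the tensor has more than 2**31 elements

print(block_start(pid, BLOCK_SIZE, use_int64=False))  # -2147483648 (wrapped)
print(block_start(pid, BLOCK_SIZE, use_int64=True))   # 2147483648
```

Casting the program id to int64 up front, as the helpers in this issue do, means every offset derived from it is computed in 64-bit arithmetic.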

Right now we implement these helpers on our side as Triton @jit device functions, but they seem generic enough to belong in FlagTree itself.
Below is a simplified version of our current "extension" module.


Current local extension (simplified)

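```python
import triton
from triton import language as tl

# 1. program_id / num_programs returning int64
#    (tl.program_id and tl.num_programs return int32, so offsets derived
#    from them can overflow on tensors with more than 2**31 elements)

@triton.jit
def program_id(axis: int) -> tl.tensor:
    return tl.program_id(axis).to(tl.int64)

@triton.jit
def num_programs(axis: int) -> tl.tensor:
    return tl.num_programs(axis).to(tl.int64)


# 2. scalar → tensor promotion + dtype predicate

@triton.jit
def promote_to_tensor(x):
    # Adding a 1-element int1 (bool) zero tensor turns a scalar into a tensor;
    # int1 has the lowest promotion rank, so the result generally keeps x's dtype
    return x + tl.zeros((1,), tl.int1)

@triton.jit
def is_floating(x):
    # Promote first so the predicate works for both scalars and tensors
    return promote_to_tensor(x).dtype.is_floating()
```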

Environment details

Triton: 3.2
GPUs: NVIDIA A100 and NVIDIA H800
