Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@
([GH-1511](https://github.com/NVIDIA/warp/issues/1511)).
- Fix `@wp.overload` kernel stubs defined in nested scopes to register correctly instead of raising
`IndentationError` ([GH-1557](https://github.com/NVIDIA/warp/issues/1557)).
- Fix `wp.array[...]`-style subscript annotations to allow for PEP 604 unions (for example, `wp.array[float] | float`)
on Python methods ([GH-1548](https://github.com/NVIDIA/warp/issues/1548)).

### Documentation

Expand Down
10 changes: 10 additions & 0 deletions warp/_src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5023,6 +5023,12 @@ def __eq__(self, other):
def __hash__(self):
return hash((self._concrete_cls, self.dtype, self.ndim))

def __or__(self, other):
return Union[self, other] # noqa: UP007

def __ror__(self, other):
return Union[other, self] # noqa: UP007


class _ArrayAnnotation(_ArrayAnnotationBase):
"""Lightweight annotation for :class:`array` types."""
Expand Down Expand Up @@ -7201,6 +7207,10 @@ def get_type_code(arg_type) -> str:
# This must come before isinstance(arg_type, type) check
arg_types = arg_type.__args__
return f"tpl{len(arg_types)}{''.join(get_type_code(x) for x in arg_types)}"
elif get_origin(arg_type) is Union:
raise TypeError(
"Union type annotations are only supported at Python scope and are invalid in Warp kernels/functions"
)
elif isinstance(arg_type, type):
if hasattr(arg_type, "_wp_scalar_type_"):
# vector/matrix type
Expand Down
42 changes: 41 additions & 1 deletion warp/tests/test_subscript_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""Tests for subscript-style type annotations on arrays and tiles, plus the internal Vector/Matrix/Quaternion/Transformation generics."""

import unittest
from typing import Any, Literal, TypeVar, get_origin
from typing import Any, Literal, TypeVar, Union, get_origin

import numpy as np

Expand Down Expand Up @@ -485,6 +485,46 @@ def test_annotation_vars_caching(self):
vars2 = arr_ann.vars
self.assertIs(vars1, vars2)

def test_annotation_union_operator(self):
"""Array annotations support runtime union expressions in both operand orders."""
arr_ann = wp.array[float]
ia_ann = wp.indexedarray[wp.float64]

u1 = arr_ann | float
self.assertIs(get_origin(u1), Union)
self.assertEqual(u1.__args__, (arr_ann, float))

u2 = float | arr_ann
self.assertIs(get_origin(u2), Union)
self.assertEqual(u2.__args__, (float, arr_ann))

u3 = arr_ann | None
self.assertIs(get_origin(u3), Union)
self.assertCountEqual(u3.__args__, (arr_ann, type(None)))

u4 = None | ia_ann
self.assertIs(get_origin(u4), Union)
self.assertCountEqual(u4.__args__, (ia_ann, type(None)))

# Chaining produces a regular Union annotation.
chained = arr_ann | float | ia_ann | float
self.assertIs(get_origin(chained), Union)
self.assertIn(arr_ann, chained.__args__)
self.assertIn(float, chained.__args__)
self.assertIn(ia_ann, chained.__args__)

def test_annotation_union_invalid_for_codegen(self):
"""Union annotations are intentionally invalid in Warp codegen."""

with self.assertRaisesRegex(
RuntimeError,
r"Union type annotations are only supported at Python scope",
):

@wp.func
def _invalid_union_arg(a: wp.array[float] | float) -> float:
return 0.0

def test_annotation_helpers(self):
"""Test matches_array_class and concrete_array_type helpers."""
arr_ann = wp.array[float]
Expand Down