From 565f925fac9a365bad558f8c053232e3ad1a2ee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabien=20P=C3=A9an?= Date: Mon, 15 Jun 2026 14:27:00 +0200 Subject: [PATCH] Implement union operator support for array annotations and add related tests (GH-1548) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabien Péan --- CHANGELOG.md | 2 ++ warp/_src/types.py | 10 +++++++ warp/tests/test_subscript_types.py | 42 +++++++++++++++++++++++++++++- 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df48b5640d..a7707c8f13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -99,6 +99,8 @@ - Fix `wp.copy()` ignoring `src_offset`, `dest_offset`, and `count` for 1D non-contiguous arrays such as strided slices ([GH-1533](https://github.com/NVIDIA/warp/issues/1533)). - Fix the gradient of `wp.copy()` when `src_offset` and `dest_offset` differ. +- Fix `wp.array[...]`-style subscript annotations to allow for PEP 604 unions (for example, `wp.array[float] | float`) + on Python methods ([GH-1548](https://github.com/NVIDIA/warp/issues/1548)). ### Documentation diff --git a/warp/_src/types.py b/warp/_src/types.py index 74aff6176c..2e9d4c7987 100644 --- a/warp/_src/types.py +++ b/warp/_src/types.py @@ -5023,6 +5023,12 @@ def __eq__(self, other): def __hash__(self): return hash((self._concrete_cls, self.dtype, self.ndim)) + def __or__(self, other): + return Union[self, other] # noqa: UP007 + + def __ror__(self, other): + return Union[other, self] # noqa: UP007 + class _ArrayAnnotation(_ArrayAnnotationBase): """Lightweight annotation for :class:`array` types.""" @@ -7201,6 +7207,10 @@ def get_type_code(arg_type) -> str: # This must come before isinstance(arg_type, type) check arg_types = arg_type.__args__ return f"tpl{len(arg_types)}{''.join(get_type_code(x) for x in arg_types)}" + elif get_origin(arg_type) is Union: + raise TypeError( + "Union type annotations are only supported at Python scope and are invalid in Warp kernels/functions" + ) elif isinstance(arg_type, type): if hasattr(arg_type, "_wp_scalar_type_"): # vector/matrix type diff --git a/warp/tests/test_subscript_types.py b/warp/tests/test_subscript_types.py index 32286dc2b7..beb51aaefd 100644 --- a/warp/tests/test_subscript_types.py +++ b/warp/tests/test_subscript_types.py @@ -4,7 +4,7 @@ """Tests for subscript-style type annotations on arrays and tiles, plus the internal Vector/Matrix/Quaternion/Transformation generics.""" import unittest -from typing import Any, Literal, TypeVar, get_origin +from typing import Any, Literal, TypeVar, Union, get_origin import numpy as np @@ -485,6 +485,46 @@ def test_annotation_vars_caching(self): vars2 = arr_ann.vars self.assertIs(vars1, vars2) + def test_annotation_union_operator(self): + """Array annotations support runtime union expressions in both operand orders.""" + arr_ann = wp.array[float] + ia_ann = wp.indexedarray[wp.float64] + + u1 = arr_ann | float + self.assertIs(get_origin(u1), Union) + self.assertEqual(u1.__args__, (arr_ann, float)) + + u2 = float | arr_ann + self.assertIs(get_origin(u2), Union) + self.assertEqual(u2.__args__, (float, arr_ann)) + + u3 = arr_ann | None + self.assertIs(get_origin(u3), Union) + self.assertCountEqual(u3.__args__, (arr_ann, type(None))) + + u4 = None | ia_ann + self.assertIs(get_origin(u4), Union) + self.assertCountEqual(u4.__args__, (ia_ann, type(None))) + + # Chaining produces a regular Union annotation. + chained = arr_ann | float | ia_ann | float + self.assertIs(get_origin(chained), Union) + self.assertIn(arr_ann, chained.__args__) + self.assertIn(float, chained.__args__) + self.assertIn(ia_ann, chained.__args__) + + def test_annotation_union_invalid_for_codegen(self): + """Union annotations are intentionally invalid in Warp codegen.""" + + with self.assertRaisesRegex( + RuntimeError, + r"Union type annotations are only supported at Python scope", + ): + + @wp.func + def _invalid_union_arg(a: wp.array[float] | float) -> float: + return 0.0 + def test_annotation_helpers(self): """Test matches_array_class and concrete_array_type helpers.""" arr_ann = wp.array[float]