@@ -265,13 +265,35 @@ def __float__(self):
265265 DType (TensorProto .DOUBLE ),
266266 DType (TensorProto .FLOAT16 ),
267267 DType (TensorProto .BFLOAT16 ),
268+ DType (TensorProto .COMPLEX64 ),
269+ DType (TensorProto .COMPLEX128 ),
268270 }:
269271 raise TypeError (
270272 f"Conversion to float only works for float scalar, "
271273 f"not for dtype={ self .dtype } ."
272274 )
273275 return float (self ._tensor )
274276
277+ def __complex__ (self ):
278+ "Implicit conversion to complex."
279+ if self .shape :
280+ raise ValueError (
281+ f"Conversion to bool only works for scalar, not for { self !r} ."
282+ )
283+ if self .dtype not in {
284+ DType (TensorProto .FLOAT ),
285+ DType (TensorProto .DOUBLE ),
286+ DType (TensorProto .FLOAT16 ),
287+ DType (TensorProto .BFLOAT16 ),
288+ DType (TensorProto .COMPLEX64 ),
289+ DType (TensorProto .COMPLEX128 ),
290+ }:
291+ raise TypeError (
292+ f"Conversion to float only works for float scalar, "
293+ f"not for dtype={ self .dtype } ."
294+ )
295+ return complex (self ._tensor )
296+
275297 def __iter__ (self ):
276298 """
277299 The :epkg:`Array API` does not define this function (2022/12).
0 commit comments