@@ -170,78 +170,58 @@ def default_dtypes(self, *, device=None):
170170 "indexing" : default_integral ,
171171 }
172172
173-
174173 def _dtypes (self , kind ):
175- bool = torch .bool
176- int8 = torch .int8
177- int16 = torch .int16
178- int32 = torch .int32
179- int64 = torch .int64
180- uint8 = torch .uint8
181- # uint16, uint32, and uint64 are present in newer versions of pytorch,
182- # but they aren't generally supported by the array API functions, so
183- # we omit them from this function.
184- float32 = torch .float32
185- float64 = torch .float64
186- complex64 = torch .complex64
187- complex128 = torch .complex128
188-
189174 if kind is None :
190- return {
191- "bool" : bool ,
192- "int8" : int8 ,
193- "int16" : int16 ,
194- "int32" : int32 ,
195- "int64" : int64 ,
196- "uint8" : uint8 ,
197- "float32" : float32 ,
198- "float64" : float64 ,
199- "complex64" : complex64 ,
200- "complex128" : complex128 ,
201- }
175+ return self ._dtypes (
176+ (
177+ "bool" ,
178+ "signed integer" ,
179+ "unsigned integer" ,
180+ "real floating" ,
181+ "complex floating" ,
182+ )
183+ )
202184 if kind == "bool" :
203- return {"bool" : bool }
185+ return {"bool" : torch . bool }
204186 if kind == "signed integer" :
205187 return {
206- "int8" : int8 ,
207- "int16" : int16 ,
208- "int32" : int32 ,
209- "int64" : int64 ,
188+ "int8" : torch . int8 ,
189+ "int16" : torch . int16 ,
190+ "int32" : torch . int32 ,
191+ "int64" : torch . int64 ,
210192 }
211193 if kind == "unsigned integer" :
212- return {
213- "uint8" : uint8 ,
214- }
194+ try :
195+ # torch >=2.3
196+ return {
197+ "uint8" : torch .uint8 ,
198+ "uint16" : torch .uint16 ,
199+ "uint32" : torch .uint32 ,
200+ "uint64" : torch .uint32 ,
201+ }
202+ except AttributeError :
203+ return {"uint8" : torch .uint8 }
215204 if kind == "integral" :
216- return {
217- "int8" : int8 ,
218- "int16" : int16 ,
219- "int32" : int32 ,
220- "int64" : int64 ,
221- "uint8" : uint8 ,
222- }
205+ return self ._dtypes (("signed integer" , "unsigned integer" ))
223206 if kind == "real floating" :
224207 return {
225- "float32" : float32 ,
226- "float64" : float64 ,
208+ "float32" : torch . float32 ,
209+ "float64" : torch . float64 ,
227210 }
228211 if kind == "complex floating" :
229212 return {
230- "complex64" : complex64 ,
231- "complex128" : complex128 ,
213+ "complex64" : torch . complex64 ,
214+ "complex128" : torch . complex128 ,
232215 }
233216 if kind == "numeric" :
234- return {
235- "int8" : int8 ,
236- "int16" : int16 ,
237- "int32" : int32 ,
238- "int64" : int64 ,
239- "uint8" : uint8 ,
240- "float32" : float32 ,
241- "float64" : float64 ,
242- "complex64" : complex64 ,
243- "complex128" : complex128 ,
244- }
217+ return self ._dtypes (
218+ (
219+ "signed integer" ,
220+ "unsigned integer" ,
221+ "real floating" ,
222+ "complex floating" ,
223+ )
224+ )
245225 if isinstance (kind , tuple ):
246226 res = {}
247227 for k in kind :
@@ -261,7 +241,6 @@ def dtypes(self, *, device=None, kind=None):
261241 ----------
262242 device : Device, optional
263243 The device to get the data types for.
264- Unused for PyTorch, as all devices use the same dtypes.
265244 kind : str or tuple of str, optional
266245 The kind of data types to return. If ``None``, all data types are
267246 returned. If a string, only data types of that kind are returned.
0 commit comments