@@ -151,7 +151,7 @@ def is_torch_mps_array(obj: Any) -> bool: # pragma: no cover
151
151
def is_tensorflow_array (obj : Any ) -> bool : # pragma: no cover
152
152
if not has_tensorflow :
153
153
return False
154
- elif isinstance (obj , tf .Tensor ):
154
+ elif isinstance (obj , tf .Tensor ): # type: ignore
155
155
return True
156
156
else :
157
157
return False
@@ -164,7 +164,7 @@ def is_tensorflow_gpu_array(obj: Any) -> bool: # pragma: no cover
164
164
def is_mxnet_array (obj : Any ) -> bool : # pragma: no cover
165
165
if not has_mxnet :
166
166
return False
167
- elif isinstance (obj , mx .nd .NDArray ):
167
+ elif isinstance (obj , mx .nd .NDArray ): # type: ignore
168
168
return True
169
169
else :
170
170
return False
@@ -316,15 +316,17 @@ def get_width(
316
316
317
317
def assert_tensorflow_installed () -> None : # pragma: no cover
318
318
"""Raise an ImportError if TensorFlow is not installed."""
319
- template = "TensorFlow support requires {pkg}: pip install thinc[tensorflow]"
319
+ template = "TensorFlow support requires {pkg}: pip install thinc[tensorflow]\n \n Enable TensorFlow support with thinc.api.enable_tensorflow() "
320
320
if not has_tensorflow :
321
- raise ImportError (template .format (pkg = "tensorflow>=2.0.0" ))
321
+ raise ImportError (template .format (pkg = "tensorflow>=2.0.0,<2.6.0 " ))
322
322
323
323
324
324
def assert_mxnet_installed () -> None : # pragma: no cover
325
325
"""Raise an ImportError if MXNet is not installed."""
326
326
if not has_mxnet :
327
- raise ImportError ("MXNet support requires mxnet: pip install thinc[mxnet]" )
327
+ raise ImportError (
328
+ "MXNet support requires mxnet: pip install thinc[mxnet]\n \n Enable MXNet support with thinc.api.enable_mxnet()"
329
+ )
328
330
329
331
330
332
def assert_pytorch_installed () -> None : # pragma: no cover
@@ -429,32 +431,32 @@ def torch2xp(
429
431
430
432
def xp2tensorflow (
431
433
xp_tensor : ArrayXd , requires_grad : bool = False , as_variable : bool = False
432
- ) -> "tf.Tensor" : # pragma: no cover
434
+ ) -> "tf.Tensor" : # type: ignore # pragma: no cover
433
435
"""Convert a numpy or cupy tensor to a TensorFlow Tensor or Variable"""
434
436
assert_tensorflow_installed ()
435
437
if hasattr (xp_tensor , "toDlpack" ):
436
438
dlpack_tensor = xp_tensor .toDlpack () # type: ignore
437
- tf_tensor = tf .experimental .dlpack .from_dlpack (dlpack_tensor )
439
+ tf_tensor = tf .experimental .dlpack .from_dlpack (dlpack_tensor ) # type: ignore
438
440
elif hasattr (xp_tensor , "__dlpack__" ):
439
441
dlpack_tensor = xp_tensor .__dlpack__ () # type: ignore
440
- tf_tensor = tf .experimental .dlpack .from_dlpack (dlpack_tensor )
442
+ tf_tensor = tf .experimental .dlpack .from_dlpack (dlpack_tensor ) # type: ignore
441
443
else :
442
- tf_tensor = tf .convert_to_tensor (xp_tensor )
444
+ tf_tensor = tf .convert_to_tensor (xp_tensor ) # type: ignore
443
445
if as_variable :
444
446
# tf.Variable() automatically puts in GPU if available.
445
447
# So we need to control it using the context manager
446
- with tf .device (tf_tensor .device ):
447
- tf_tensor = tf .Variable (tf_tensor , trainable = requires_grad )
448
+ with tf .device (tf_tensor .device ): # type: ignore
449
+ tf_tensor = tf .Variable (tf_tensor , trainable = requires_grad ) # type: ignore
448
450
if requires_grad is False and as_variable is False :
449
451
# tf.stop_gradient() automatically puts in GPU if available.
450
452
# So we need to control it using the context manager
451
- with tf .device (tf_tensor .device ):
452
- tf_tensor = tf .stop_gradient (tf_tensor )
453
+ with tf .device (tf_tensor .device ): # type: ignore
454
+ tf_tensor = tf .stop_gradient (tf_tensor ) # type: ignore
453
455
return tf_tensor
454
456
455
457
456
458
def tensorflow2xp (
457
- tf_tensor : "tf.Tensor" , * , ops : Optional ["Ops" ] = None
459
+ tf_tensor : "tf.Tensor" , * , ops : Optional ["Ops" ] = None # type: ignore
458
460
) -> ArrayXd : # pragma: no cover
459
461
"""Convert a Tensorflow tensor to numpy or cupy tensor depending on the `ops` parameter.
460
462
If `ops` is `None`, the type of the resultant tensor will be determined by the source tensor's device.
@@ -466,7 +468,7 @@ def tensorflow2xp(
466
468
if isinstance (ops , NumpyOps ):
467
469
return tf_tensor .numpy ()
468
470
else :
469
- dlpack_tensor = tf .experimental .dlpack .to_dlpack (tf_tensor )
471
+ dlpack_tensor = tf .experimental .dlpack .to_dlpack (tf_tensor ) # type: ignore
470
472
return cupy_from_dlpack (dlpack_tensor )
471
473
else :
472
474
if isinstance (ops , NumpyOps ) or ops is None :
@@ -477,21 +479,21 @@ def tensorflow2xp(
477
479
478
480
def xp2mxnet (
479
481
xp_tensor : ArrayXd , requires_grad : bool = False
480
- ) -> "mx.nd.NDArray" : # pragma: no cover
482
+ ) -> "mx.nd.NDArray" : # type: ignore # pragma: no cover
481
483
"""Convert a numpy or cupy tensor to a MXNet tensor."""
482
484
assert_mxnet_installed ()
483
485
if hasattr (xp_tensor , "toDlpack" ):
484
486
dlpack_tensor = xp_tensor .toDlpack () # type: ignore
485
- mx_tensor = mx .nd .from_dlpack (dlpack_tensor )
487
+ mx_tensor = mx .nd .from_dlpack (dlpack_tensor ) # type: ignore
486
488
else :
487
- mx_tensor = mx .nd .from_numpy (xp_tensor )
489
+ mx_tensor = mx .nd .from_numpy (xp_tensor ) # type: ignore
488
490
if requires_grad :
489
491
mx_tensor .attach_grad ()
490
492
return mx_tensor
491
493
492
494
493
495
def mxnet2xp (
494
- mx_tensor : "mx.nd.NDArray" , * , ops : Optional ["Ops" ] = None
496
+ mx_tensor : "mx.nd.NDArray" , * , ops : Optional ["Ops" ] = None # type: ignore
495
497
) -> ArrayXd : # pragma: no cover
496
498
"""Convert a MXNet tensor to a numpy or cupy tensor."""
497
499
from .api import NumpyOps
0 commit comments