@@ -339,6 +339,7 @@ def target_device(self):
   @target_device.setter
   def target_device(self, device: str):
     self._target_device = device.lower()
+    self.default_device_or_sharding = jax.local_devices()[0]
 
   def manual_seed(self, key):
     self._prng_key = mutable_array(jax.random.key(key))
@@ -359,10 +360,10 @@ def get_as_jax_device(self, device: Any):
       return jax.devices("cpu")[0]
 
     if self.config.treat_cuda_as_jax_device and device.startswith("cuda"):
-      return jax.local_devices()[0]
+      return self.default_device_or_sharding
 
     if device.startswith("xla"):
-      return jax.local_devices()[0]
+      return self.default_device_or_sharding
 
     # TODO (wen): jax is NOT a device type,
     # once we can register more than one backend, revisit
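
A minimal sketch of the placement call this cached attribute feeds into, assuming only standard JAX behavior (jax.device_put accepts either a concrete Device or a Sharding, which is presumably why the attribute is named default_device_or_sharding); the names dev, x, a, b below are illustrative and not part of the patch:

    import jax
    import jax.numpy as jnp
    from jax.sharding import SingleDeviceSharding

    dev = jax.local_devices()[0]            # same device the setter caches
    x = jnp.ones((2, 2))
    a = jax.device_put(x, dev)                         # placement given a Device
    b = jax.device_put(x, SingleDeviceSharding(dev))   # placement given a Sharding
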
@@ -461,6 +462,7 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
       return the_tensor
 
     jax_device = self.get_as_jax_device(new_device)
+
     if jax_device:
       arr = self.t2j_copy(the_tensor)
       arr = jax.device_put(arr, jax_device)
@@ -488,15 +490,16 @@ def _handle_tensor_constructor(self, func, args, kwargs):
       # let torch handle it
       with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
         return func(*args, **kwargs)
-    with jax.default_device(jax_device):
-      requires_grad = kwargs.get("requires_grad", False)
-      op = self._get_op_or_decomp(func)
-      res = op.func(*args, **kwargs)
-      if isinstance(res, jax.Array):
-        res = Tensor(res, self)
-      if requires_grad:
-        res.requires_grad = True
-      return res
+
+    requires_grad = kwargs.get("requires_grad", False)
+    op = self._get_op_or_decomp(func)
+    res = op.func(*args, **kwargs)
+    if isinstance(res, jax.Array):
+      res = jax.device_put(res, jax_device)
+      res = Tensor(res, self)
+    if requires_grad:
+      res.requires_grad = True
+    return res
 
   def _torch_Tensor_to(self, args, kwargs):
     the_tensor = args[0]
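
For context, a minimal sketch (not part of the patch) of how the removed jax.default_device context manager differs from the explicit jax.device_put used above, assuming a host that exposes a CPU device:

    import jax
    import jax.numpy as jnp

    cpu = jax.devices("cpu")[0]

    # Old style: arrays created inside the context default to the given device.
    with jax.default_device(cpu):
      a = jnp.zeros(3)
    # New style: create the array first, then move it explicitly.
    b = jax.device_put(jnp.zeros(3), cpu)

    assert a.devices() == b.devices() == {cpu}
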
@@ -593,6 +596,10 @@ def is_not_torchax_tensor(x):
 
     if self.config.debug_accuracy_for_each_op:
       debug_accuracy(func, old_args, old_kwargs, res)
+
+    for r in torch_pytree.tree_flatten(res)[0]:
+      if isinstance(r, Tensor) and r.dtype != super(torch.Tensor, r).dtype:
+        breakpoint()
     return res
 
   def enable_torch_modes(self):
@@ -642,6 +649,9 @@ def to_jax(x):
       if isinstance(
           x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
         x = x.wait()
+      if self.config.allow_mixed_tensor_for_scalar_tensor and not isinstance(x, Tensor):
+        if x.squeeze().ndim == 0:
+          return x.item()
       assert isinstance(x, Tensor) or isinstance(x, View), (
           f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
       )
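
A small illustration (not part of the patch) of the scalar fallback added above: a plain torch.Tensor holding exactly one element is handed on as a Python scalar when the new allow_mixed_tensor_for_scalar_tensor flag is enabled:

    import torch

    x = torch.tensor([[3.0]])   # regular torch.Tensor, not a torchax Tensor
    if x.squeeze().ndim == 0:   # true for any tensor with exactly one element
      print(x.item())           # 3.0 -- passed on as a Python scalar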