
Commit 05716d2 ("changes")

1 parent 7e3efc5

File tree: 3 files changed, +37 −11 lines


torchax/torchax/config.py

Lines changed: 5 additions & 0 deletions

@@ -10,6 +10,11 @@ class Configuration:
 
   use_int32_for_index: bool = False
 
+  # Normally, math between a CPU torch.Tensor and a torchax.Tensor is not
+  # allowed. However, if that torch.Tensor happens to be a scalar, then we
+  # can use scalar * tensor math to handle it.
+  allow_mixed_tensor_for_scalar_tensor: bool = True
+
   # If true, we will convert Views into torchax.Tensors eagerly
   force_materialize_views: bool = False
 
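
For context, a minimal sketch of the mixed scalar/tensor math this new flag is meant to allow. It is an illustration only: torchax.enable_globally() and the 'jax' device string are assumed to behave as described elsewhere in torchax, and are not part of this commit.

import torch
import torchax

torchax.enable_globally()

scale = torch.tensor(2.0)        # plain CPU torch.Tensor, 0-dim scalar
x = torch.ones(3, 3).to('jax')   # torchax.Tensor backed by a jax.Array

# With allow_mixed_tensor_for_scalar_tensor = True, the CPU scalar is
# unwrapped via .item() instead of tripping the mixed-tensor assert.
y = scale * x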

torchax/torchax/ops/jtorch.py

Lines changed: 11 additions & 0 deletions

@@ -50,6 +50,11 @@ def _tensor(data, *, dtype=None, **kwargs):
   leaves = jax.tree_util.tree_leaves(data)
   if len(leaves) > 0:
     dtype = python_types_to_torch_types.get(type(leaves[0]))
+    def to_scalar(x):
+      if isinstance(x, torch.Tensor):
+        return x.item()
+      return x
+    data = jax.tree.map(to_scalar, data)
 
   return jnp.array(
       data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype()))

@@ -566,3 +571,9 @@ def torch_Tensor_repeat_interleave(self,
                                    *,
                                    output_size=None):
   return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
+
+@register_function(torch.nn.functional.conv2d)
+def torch_conv2d(
+    input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1
+):
+  return jaten._aten_conv2d(input, weight, bias, stride, padding, dilation, groups)
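
A hedged sketch of what the to_scalar step in _tensor enables: calling torch.tensor() with scalar torch.Tensors among its data leaves while a torchax environment is active. The torchax.default_env() context-manager usage is an assumption borrowed from the rest of the project, not something this commit adds.

import torch
import torchax

env = torchax.default_env()
with env:
  # Each 0-dim torch.Tensor leaf is unwrapped to a Python scalar by
  # to_scalar() before the data is handed to jnp.array().
  t = torch.tensor([torch.tensor(1.0), torch.tensor(2.0), 3.0])

The new registration at the end of the file routes torch.nn.functional.conv2d directly to the existing jaten._aten_conv2d lowering.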

torchax/torchax/tensor.py

Lines changed: 21 additions & 11 deletions

@@ -339,6 +339,7 @@ def target_device(self):
   @target_device.setter
   def target_device(self, device: str):
     self._target_device = device.lower()
+    self.default_device_or_sharding = jax.local_devices()[0]
 
   def manual_seed(self, key):
     self._prng_key = mutable_array(jax.random.key(key))

@@ -359,10 +360,10 @@ def get_as_jax_device(self, device: Any):
       return jax.devices("cpu")[0]
 
     if self.config.treat_cuda_as_jax_device and device.startswith("cuda"):
-      return jax.local_devices()[0]
+      return self.default_device_or_sharding
 
     if device.startswith("xla"):
-      return jax.local_devices()[0]
+      return self.default_device_or_sharding
 
     # TODO (wen): jax is NOT a device type,
     # once we can register more than one backend, revisit

@@ -461,6 +462,7 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
       return the_tensor
 
     jax_device = self.get_as_jax_device(new_device)
+
    if jax_device:
      arr = self.t2j_copy(the_tensor)
      arr = jax.device_put(arr, jax_device)

@@ -488,15 +490,16 @@ def _handle_tensor_constructor(self, func, args, kwargs):
      # let torch handle it
      with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
        return func(*args, **kwargs)
-    with jax.default_device(jax_device):
-      requires_grad = kwargs.get("requires_grad", False)
-      op = self._get_op_or_decomp(func)
-      res = op.func(*args, **kwargs)
-      if isinstance(res, jax.Array):
-        res = Tensor(res, self)
-      if requires_grad:
-        res.requires_grad = True
-      return res
+
+    requires_grad = kwargs.get("requires_grad", False)
+    op = self._get_op_or_decomp(func)
+    res = op.func(*args, **kwargs)
+    if isinstance(res, jax.Array):
+      res = jax.device_put(res, jax_device)
+      res = Tensor(res, self)
+    if requires_grad:
+      res.requires_grad = True
+    return res
 
  def _torch_Tensor_to(self, args, kwargs):
    the_tensor = args[0]

@@ -593,6 +596,10 @@ def is_not_torchax_tensor(x):
 
    if self.config.debug_accuracy_for_each_op:
      debug_accuracy(func, old_args, old_kwargs, res)
+
+    for r in torch_pytree.tree_flatten(res)[0]:
+      if isinstance(r, Tensor) and r.dtype != super(torch.Tensor, r).dtype:
+        breakpoint()
    return res
 
  def enable_torch_modes(self):

@@ -642,6 +649,9 @@ def to_jax(x):
      if isinstance(
          x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
        x = x.wait()
+      if self.config.allow_mixed_tensor_for_scalar_tensor and not isinstance(x, Tensor):
+        if x.squeeze().ndim == 0:
+          return x.item()
      assert isinstance(x, Tensor) or isinstance(x, View), (
          f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
      )
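
To make the constructor-path change concrete, a small sketch under stated assumptions (torchax.enable_globally() and treat_cuda_as_jax_device at its usual default, with a single local JAX device); none of these names are introduced by this commit except default_device_or_sharding.

import torch
import torchax

torchax.enable_globally()

# Tensor constructors with an explicit device go through
# _handle_tensor_constructor: the registered op runs first, and if the
# result is a jax.Array it is jax.device_put onto the cached jax_device
# (default_device_or_sharding once target_device has been set), instead of
# relying on the jax.default_device context manager as before.
x = torch.zeros(2, 2, device='cuda')
y = torch.ones(4, device='cuda', requires_grad=True)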
