DLPack Support for np.uint16 seems broken #24680

Open
realquantumcookie opened this issue Nov 3, 2024 · 2 comments
Labels: bug (Something isn't working)

Comments


realquantumcookie commented Nov 3, 2024

Description

JAX_PLATFORMS=cpu CUDA_VISIBLE_DEVICES="" python
Python 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import jax.numpy as jnp
>>> import torch
>>> import numpy as np
>>> a = jnp.zeros((1,3,5,8), dtype=np.uint16)
>>> jnp.all(a)
Array(False, dtype=bool)
>>> b = torch.utils.dlpack.from_dlpack(a)
>>> b
tensor([[[[0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0]],

         [[0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0]],

         [[0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0]]]], dtype=torch.uint16)
>>> b.shape
torch.Size([1, 3, 5, 8])
>>> c = jax.dlpack.from_dlpack(b)
>>> a.layout
Layout(device_local_layout=DeviceLocalLayout(major_to_minor=(0, 1, 2, 3), _tiling=(), _sub_byte_element_size_in_bits=0), sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host))
>>> c.layout
Layout(device_local_layout=DeviceLocalLayout(major_to_minor=(1, 2, 0, 3), _tiling=(), _sub_byte_element_size_in_bits=0), sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host))
>>> jnp.all(c)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/quantumcookie/miniconda3/envs/UniEnvPy/lib/python3.10/site-packages/jax/_src/numpy/reductions.py", line 596, in all
    return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
AssertionError: Unexpected XLA layout override: (XLA) DeviceLocalLayout(major_to_minor=(0, 1, 2, 3), _tiling=(), _sub_byte_element_size_in_bits=0) != DeviceLocalLayout(major_to_minor=(1, 2, 0, 3), _tiling=(), _sub_byte_element_size_in_bits=0) (User input layout)
>>> exit()
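
A possible workaround, sketched here but not verified beyond this transcript, is to copy the tensor through NumPy (np.from_dlpack needs NumPy >= 1.22 and a CPU tensor) instead of importing it zero-copy with jax.dlpack.from_dlpack, so that JAX builds the array with its default layout:

import numpy as np
import jax.numpy as jnp

# Workaround sketch: round-trip through NumPy so the resulting JAX array is a
# fresh copy with the default row-major layout instead of the imported one.
c_copy = jnp.asarray(np.from_dlpack(b))
print(c_copy.layout)    # expected: major_to_minor=(0, 1, 2, 3)
print(jnp.all(c_copy))  # expected: Array(False, dtype=bool), no layout assertion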

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.35
jaxlib: 0.4.34
numpy: 1.26.4
python: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='LabPC', release='6.8.0-47-generic', version='#47~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Oct 2 16:16:55 UTC 2', machine='x86_64')

$ nvidia-smi
Sun Nov  3 14:56:27 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
|  0%   43C    P2             17W /  450W |     655MiB /  24564MiB |      6%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    242483      G   /usr/lib/xorg/Xorg                             132MiB |
|    0   N/A  N/A    242640      G   /usr/bin/gnome-shell                            12MiB |
|    0   N/A  N/A    245588      G   ...erProcess --variations-seed-version          30MiB |
|    0   N/A  N/A    268770      G   ...irefox/5187/usr/lib/firefox/firefox          65MiB |
|    0   N/A  N/A   2020781      C   python                                         386MiB |
+-----------------------------------------------------------------------------------------+

realquantumcookie added the bug label Nov 3, 2024
realquantumcookie (Author) commented Nov 3, 2024

Looks like this also affects np.float32.
I have torch 2.5.1 + CUDA 12.4 and jax[cuda12]==0.4.35.

JAX_PLATFORMS=cpu python
Python 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> a = torch.zeros((3,1,6,2), dtype=torch.float32)
>>> b = jax.dlpack.from_dlpack(a)
>>> c = jax.dlpack.from_dlpack(a.contiguous())
>>> b.layout
Layout(device_local_layout=DeviceLocalLayout(major_to_minor=(0, 2, 1, 3), _tiling=(), _sub_byte_element_size_in_bits=0), sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host))
>>> c.layout
Layout(device_local_layout=DeviceLocalLayout(major_to_minor=(0, 2, 1, 3), _tiling=(), _sub_byte_element_size_in_bits=0), sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host))
>>> jnp.all(c)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/quantumcookie/miniconda3/envs/UniEnvPy/lib/python3.10/site-packages/jax/_src/numpy/reductions.py", line 596, in all
    return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
AssertionError: Unexpected XLA layout override: (XLA) DeviceLocalLayout(major_to_minor=(0, 1, 2, 3), _tiling=(), _sub_byte_element_size_in_bits=0) != DeviceLocalLayout(major_to_minor=(0, 2, 1, 3), _tiling=(), _sub_byte_element_size_in_bits=0) (User input layout)
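
For reference, here is a condensed, self-contained version of the same check. It is only a sketch that assumes the attribute names match the Layout reprs shown above; it compares the imported device-local layout against JAX's default row-major ordering:

import torch
import jax

# Condensed repro sketch: import a CPU torch tensor via DLPack and compare the
# resulting device-local layout against the default row-major ordering.
a = torch.zeros((3, 1, 6, 2), dtype=torch.float32)
b = jax.dlpack.from_dlpack(a)

row_major = tuple(range(b.ndim))  # (0, 1, 2, 3), as seen on freshly created arrays
imported = b.layout.device_local_layout.major_to_minor
print(imported, row_major, imported == row_major)
# Per the transcript above: (0, 2, 1, 3) vs. (0, 1, 2, 3) -> False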

dfm (Collaborator) commented Nov 3, 2024

Thanks for the report! I can also reproduce the error, and there seem to be a few strange things going on.

  • Your intuition to check the layouts is a good one, and I don't understand why from_dlpack produces this (incorrect, I'd say) layout. I wonder if @superbobry might have any ideas about what's going on there?

  • But, even with this strange layout, I wouldn't expect jnp.all to crash with a layout error. @yashk2810, do you have an idea for why that would happen here?
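
One way to test that second point, taking c to be either of the imported arrays above, is to strip the DLPack-provided layout by forcing a copy and retrying the reduction; if it then succeeds, the crash is tied to the layout metadata rather than to the values or dtype. This is only a sketch under that assumption:

import numpy as np
import jax.numpy as jnp

# Sketch: rebuild the imported array without the custom layout and retry jnp.all.
c_default = jnp.asarray(np.asarray(c))  # forces a host copy with the default layout
print(c_default.layout)                 # expected: major_to_minor=(0, 1, 2, 3)
print(jnp.all(c_default))               # expected to run without the layout assertion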
