Problems with jax #98

Open
WilliamLi0623 opened this issue Sep 3, 2022 · 9 comments

@WilliamLi0623

I installed the required libraries with pip install -r requirement.txt. CUDA works and TensorFlow can find the GPU. However, when I try to run the code, the following error occurs:

"Unable to initialize backend 'cuda'": module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'

I searched the web and found that it might be caused by mismatched library versions.

Could you please share the versions of the packages you used? Thank you very much.

@prateekiiest

prateekiiest commented Jan 21, 2023

Yes, same for me, @WilliamLi0623.
Current workaround: run pip install -U jaxlib=={version no}+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

cc @PetarV-

@prateekiiest

It still doesn't work.

@prateekiiest

Log dump:

2023-01-21 20:01:53.374343: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-21 20:01:53.537534: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-21 20:01:53.539034: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-21 20:01:54.230774: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-21 20:01:55.316097: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-01-21 20:01:55.316175: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-01-21 20:01:55.316187: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-01-21 20:01:59.913842: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.913949: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublas.so.11'; dlerror: libcublas.so.11: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914012: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublasLt.so.11'; dlerror: libcublasLt.so.11: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914076: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcufft.so.10'; dlerror: libcufft.so.10: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914133: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcurand.so.10'; dlerror: libcurand.so.10: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914191: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusolver.so.11'; dlerror: libcusolver.so.11: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914254: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusparse.so.11'; dlerror: libcusparse.so.11: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914312: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory
2023-01-21 20:01:59.914338: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1934] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
I0121 20:01:59.939431 139832542189376 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0121 20:02:00.399612 139832542189376 xla_bridge.py:355] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Host Interpreter
I0121 20:02:00.400197 139832542189376 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0121 20:02:00.400516 139832542189376 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
2023-01-21 20:02:01.082623: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:109] Couldn't get ptxas version : FAILED_PRECONDITION: Couldn't get ptxas version string: INTERNAL: Couldn't invoke ptxas --version
2023-01-21 20:02:01.083637: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:451] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: Failed to launch ptxas'  If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.
Fatal Python error: Aborted

Thread 0x00007f2d4d015740 (most recent call first):
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 1014 in backend_compile
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/profiler.py", line 314 in wrapper
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 1079 in compile_or_get_cached
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 3439 in from_hlo
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 3170 in _compile_unloaded
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 3202 in compile
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 359 in _xla_callable_uncached
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 202 in xla_primitive_callable
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/util.py", line 247 in cached
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/util.py", line 254 in wrapper
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/dispatch.py", line 118 in apply_primitive
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 712 in process_primitive
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 332 in bind_with_trace
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 329 in bind
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 509 in shift_right_logical
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/prng.py", line 827 in threefry_seed
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/prng.py", line 592 in random_seed_impl_base
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/prng.py", line 587 in random_seed_impl
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 712 in process_primitive
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 332 in bind_with_trace
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/core.py", line 329 in bind
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/prng.py", line 575 in random_seed
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/prng.py", line 267 in seed_with_impl
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/jax/_src/random.py", line 133 in PRNGKey
  File "/mnt/infonas/data/prateekch/clrs/clrs/examples/run.py", line 379 in main
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/absl/app.py", line 254 in _run_main
  File "/mnt/infonas/data/prateekch/penv/lib/python3.8/site-packages/absl/app.py", line 308 in run
  File "/mnt/infonas/data/prateekch/clrs/clrs/examples/run.py", line 537 in <module>
  File "/usr/lib/python3.8/runpy.py", line 87 in _run_code
  File "/usr/lib/python3.8/runpy.py", line 194 in _run_module_as_main
Aborted (core dumped)
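
The warnings in the log above all point to CUDA shared libraries and the ptxas tool not being found at run time. A minimal sketch for checking this outside of JAX, using only the Python standard library. The list of library names is copied from the warnings; the helper name check_cuda_environment is hypothetical, and ctypes.util.find_library follows the system's own lookup rules, so a miss here is a hint rather than proof that the library is absent.

```python
import ctypes.util
import shutil

# Library names taken from the "Could not load dynamic library" warnings.
CUDA_LIBS = [
    "cudart", "cublas", "cublasLt", "cufft",
    "curand", "cusolver", "cusparse", "cudnn",
]


def check_cuda_environment(lib_names=CUDA_LIBS):
    """Map each library name (plus 'ptxas') to where it was found, or None."""
    report = {name: ctypes.util.find_library(name) for name in lib_names}
    # ptxas is an executable, so it is looked up on PATH instead.
    report["ptxas"] = shutil.which("ptxas")
    return report


if __name__ == "__main__":
    for name, location in check_cuda_environment().items():
        status = f"found: {location}" if location else "NOT FOUND"
        print(f"{name:10s} {status}")
```

If most entries come back as NOT FOUND, the CUDA toolkit is probably either not installed or not on the library search path and PATH of the shell that launches the run.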

@bibarzgoogle
Collaborator

Sorry that the package sometimes doesn't work on GPU out of the box. You need to make sure that a CUDA-compatible version of JAX is installed. Try the following steps:

python3 -m venv clrs_env
source clrs_env/bin/activate
pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install git+https://github.com/deepmind/clrs.git

If the jax[cuda] install is successful, you should be able to run

import jax
print(jax.local_devices())

in a Python interpreter and see your GPU listed among the devices.

The JAX installation guide has good pointers to potential problems with the JAX installation on GPU. One thing to keep in mind is that JAX expects the CUDA installation to be at /usr/local/cuda-X.X. If the CUDA libraries are somewhere else, try creating a symlink:
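
The device check can also be scripted so that it fails loudly instead of relying on reading the printed list. A small sketch, assuming that GPU devices report their platform attribute as "gpu"; the helper name gpu_devices is hypothetical.

```python
def gpu_devices(devices):
    """Return the devices whose platform is reported as 'gpu'."""
    return [d for d in devices if getattr(d, "platform", None) == "gpu"]


if __name__ == "__main__":
    try:
        import jax  # imported here so the helper itself does not need JAX
    except ImportError:
        print("JAX is not installed in this environment.")
    else:
        found = gpu_devices(jax.local_devices())
        if found:
            print("GPU devices visible to JAX:", found)
        else:
            print("No GPU visible to JAX; devices are:", jax.local_devices())
```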

sudo ln -s /path/to/cuda /usr/local/cuda-X.X

@mcleish7
Contributor

Good afternoon,

I am still experiencing issues. When I run python3 -m clrs.examples.run, I get this error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure. The full error is shown below. I have had a look but couldn't find anything to fix this. Any help would be appreciated.

Thanks,
Sean

I am running Python 3.9 with JAX:
jax 0.4.4
jaxlib 0.4.4+cuda11.cudnn86

	_PyObject_MakeTpCall
	
	_PyEval_EvalFrameDefault
	
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	PyObject_Call
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalCodeWithName
	PyEval_EvalCode
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	PyObject_Call
	
	Py_RunMain
	Py_BytesMain
	__libc_start_main
	_start
*** End stack trace ***

Traceback (most recent call last):
  File "/usr/lib64/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib64/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 537, in <module>
    app.run(main)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 380, in main
    rng_key = jax.random.PRNGKey(rng.randint(2**32))
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/random.py", line 136, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/prng.py", line 267, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/prng.py", line 570, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 343, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 346, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 789, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/prng.py", line 582, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/prng.py", line 587, in random_seed_impl_base
    return seed(seeds)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/prng.py", line 822, in threefry_seed
    lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/lax/lax.py", line 511, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 343, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 346, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/core.py", line 789, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 123, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 202, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 355, in _xla_callable_uncached
    return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3254, in compile
    executable = self._compile_unloaded(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3225, in _compile_unloaded
    return UnloadedMeshExecutable.from_hlo(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3512, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 1095, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 1040, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:627) dnn != nullptr 

@bibarzgoogle
Collaborator

According to this, jax/jaxlib >= 0.4.3 seems to be incompatible with cuDNN 8.6. It would seem you have to use either jax/jaxlib 0.4.2 or cuDNN 8.8; give it a try.
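
To check an environment against this report, the jaxlib version string can be split into its release and CUDA/cuDNN tags. A rough sketch, assuming the string follows the 0.4.4+cuda11.cudnn86 pattern quoted earlier in this thread. Both function names are hypothetical, and the tag only describes which cuDNN the wheel was built for, not which cuDNN is actually installed on the machine.

```python
import re


def parse_jaxlib_version(version):
    """Split e.g. '0.4.4+cuda11.cudnn86' into ((0, 4, 4), 11, 86)."""
    match = re.fullmatch(
        r"(\d+)\.(\d+)\.(\d+)(?:\+cuda(\d+)\.cudnn(\d+))?", version
    )
    if match is None:
        raise ValueError(f"unrecognised jaxlib version string: {version!r}")
    release = tuple(int(part) for part in match.group(1, 2, 3))
    cuda = int(match.group(4)) if match.group(4) else None
    cudnn = int(match.group(5)) if match.group(5) else None
    return release, cuda, cudnn


def matches_reported_combination(version):
    """True for the combination reported here: jaxlib >= 0.4.3 with a cudnn86 tag."""
    release, _, cudnn = parse_jaxlib_version(version)
    return release >= (0, 4, 3) and cudnn == 86


if __name__ == "__main__":
    print(matches_reported_combination("0.4.4+cuda11.cudnn86"))
```

The installed string can usually be read from jaxlib.__version__ or from the output of pip list.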

@mcleish7
Contributor

Good evening,

Thank you for the help. That seems to have got the benchmark code started with GPU support, but I am now seeing this error: jax._src.traceback_util.UnfilteredStackTrace: TypeError: Subscripted generics cannot be used with class and instance checks.
From a Google search, I think this is to do with Python 3.9 and above. I don't see anything related in the JAX GitHub issues. Is there a recommended version of Python for the CLRS code, or do you think there is another cause?

Thanks,
Sean

2023-02-23 19:07:41.985101: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /local/java/cuda-11.6.0/lib64/:/local/java/cudnn-linux-x86_64-8.5.0.96_cuda11-archive/lib/
2023-02-23 19:07:41.985869: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /local/java/cuda-11.6.0/lib64/:/local/java/cudnn-linux-x86_64-8.5.0.96_cuda11-archive/lib/
2023-02-23 19:07:41.985881: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
I0223 19:07:49.846865 139926348851008 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0223 19:07:49.906153 139926348851008 xla_bridge.py:355] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Host Interpreter CUDA
I0223 19:07:49.906481 139926348851008 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0223 19:07:49.906545 139926348851008 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
I0223 19:07:55.778950 139926348851008 run.py:299] Creating samplers for algo bfs
W0223 19:07:55.779258 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
W0223 19:07:55.779516 139926348851008 samplers.py:100] Sampling dataset on-the-fly, unlimited samples.
W0223 19:07:56.005018 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
W0223 19:07:56.005211 139926348851008 samplers.py:100] Sampling dataset on-the-fly, unlimited samples.
W0223 19:07:56.241070 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
W0223 19:07:56.241253 139926348851008 samplers.py:100] Sampling dataset on-the-fly, unlimited samples.
W0223 19:07:56.525957 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
W0223 19:07:56.526148 139926348851008 samplers.py:100] Sampling dataset on-the-fly, unlimited samples.
W0223 19:07:56.848360 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
W0223 19:07:56.848542 139926348851008 samplers.py:100] Sampling dataset on-the-fly, unlimited samples.
W0223 19:07:57.231205 139926348851008 samplers.py:277] Ignoring kwargs {'length_needle'} when building sampler class <class 'clrs._src.samplers.BfsSampler'>
I0223 19:07:57.231393 139926348851008 samplers.py:112] Creating a dataset with 64 samples.
I0223 19:07:57.256827 139926348851008 run.py:158] Dataset not found in /tmp/CLRS30/CLRS30_v1.0.0. Downloading...
I0223 19:08:11.790027 139926348851008 dataset_info.py:482] Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/bfs_test/1.0.0
I0223 19:08:11.791674 139926348851008 dataset_info.py:482] Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/bfs_test/1.0.0
I0223 19:08:11.792154 139926348851008 dataset_builder.py:366] Reusing dataset clrs_dataset (/tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/bfs_test/1.0.0)
I0223 19:08:11.792221 139926348851008 logging_logger.py:44] Constructing tf.data.Dataset clrs_dataset for split test, from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/bfs_test/1.0.0
WARNING:tensorflow:From /dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W0223 19:08:11.990011 139926348851008 deprecation.py:350] From /dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Traceback (most recent call last):
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 537, in <module>
    app.run(main)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 464, in main
    cur_loss = train_model.feedback(rng_key, feedback, length_and_algo_idx)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 370, in feedback
    loss, self._device_params, self._device_opt_state = self.jitted_feedback(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/api.py", line 564, in cache_miss
    execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 241, in _xla_call_impl_lazy
    return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 357, in _xla_callable_uncached
    computation = sharded_lowering(fun, device, backend, name, donated_invars,
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/dispatch.py", line 348, in sharded_lowering
    return pxla.lower_sharding_computation(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/interpreters/pxla.py", line 2790, in lower_sharding_computation
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2073, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2006, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 318, in _feedback
    params, opt_state = self._update_params(params, grads, opt_state,
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 428, in _update_params
    updates, opt_state = filter_null_grads(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 769, in filter_null_grads
    flat_opt_state = jax.tree_util.tree_map(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/tree_util.py", line 207, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/jax/_src/tree_util.py", line 207, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 771, in <lambda>
    if not isinstance(x, _Array) else x, opt_state_skeleton, opt_state)
  File "/usr/lib64/python3.9/typing.py", line 720, in __instancecheck__
    return self.__subclasscheck__(type(obj))
  File "/usr/lib64/python3.9/typing.py", line 723, in __subclasscheck__
    raise TypeError("Subscripted generics cannot be used with"
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Subscripted generics cannot be used with class and instance checks

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib64/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib64/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 537, in <module>
    app.run(main)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/examples/run.py", line 464, in main
    cur_loss = train_model.feedback(rng_key, feedback, length_and_algo_idx)
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 370, in feedback
    loss, self._device_params, self._device_opt_state = self.jitted_feedback(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 318, in _feedback
    params, opt_state = self._update_params(params, grads, opt_state,
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 428, in _update_params
    updates, opt_state = filter_null_grads(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 769, in filter_null_grads
    flat_opt_state = jax.tree_util.tree_map(
  File "/dcs/large/user/deep-thinking/clrs_env/lib64/python3.9/site-packages/clrs/_src/baselines.py", line 771, in <lambda>
    if not isinstance(x, _Array) else x, opt_state_skeleton, opt_state)
  File "/usr/lib64/python3.9/typing.py", line 720, in __instancecheck__
    return self.__subclasscheck__(type(obj))
  File "/usr/lib64/python3.9/typing.py", line 723, in __subclasscheck__
    raise TypeError("Subscripted generics cannot be used with"
TypeError: Subscripted generics cannot be used with class and instance checks
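
The final TypeError can be reproduced without JAX or CLRS: Python's typing module refuses isinstance checks against a subscripted generic such as List[int]. A minimal sketch, in which IntList is only a stand-in for whatever _Array resolves to in baselines.py (its real definition is not shown in this thread), and the workaround shown is a generic one, not necessarily what the library does.

```python
from typing import List

# A subscripted generic alias; a stand-in for `_Array`, whose real
# definition is an unknown here.
IntList = List[int]


def is_int_list_broken(x):
    # Raises TypeError: typing aliases with parameters do not support
    # isinstance checks.
    return isinstance(x, IntList)


def is_int_list(x):
    # Generic workaround: test against the runtime class, then check
    # the elements explicitly.
    return isinstance(x, list) and all(isinstance(item, int) for item in x)


if __name__ == "__main__":
    try:
        is_int_list_broken([1, 2, 3])
    except TypeError as err:
        print("TypeError:", err)
    print(is_int_list([1, 2, 3]))
```

The general pattern is to pass isinstance a concrete runtime class (or a tuple of classes) rather than a typing alias.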

@hbq1
Collaborator

hbq1 commented Feb 28, 2023

@mcleish7 it should be fixed in 2b37ff3

@mcleish7
Contributor

mcleish7 commented Mar 1, 2023

@hbq1 @bibarzgoogle Thank you for your help, it is now working.
