Hello ashawkey, do you know how to solve this problem? Thank you so much!
Run training:
UnfilteredStackTrace Traceback (most recent call last)
in
18 experiment_dir=experiment_dir, work_unit_dir=work_unit_dir, rng=rng,
---> 19 yield_results=True)):
20
UnfilteredStackTrace: TypeError: Value '<jaxlib.tpu_client_extension.PyTpuBuffer object at 0x7f02455d0650>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
in run_train(self, experiment_dir, work_unit_dir, rng, yield_results)
125 template = config.get('query_template', '{query}')
126 query = template.format(query=config.query)
--> 127 z_clip = encode_text(tokenize_fn(query))
128 del encode_text, tokenize_fn # Clean up text encoder.
129
TypeError: Value '<jaxlib.tpu_client_extension.PyTpuBuffer object at 0x7f02455d0650>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.