This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[Bug] Segment fault when using alpa to parallelize llama with jax 0.4.6 environment #970

Open
zigzagcai opened this issue Dec 6, 2023 · 2 comments

Comments


zigzagcai commented Dec 6, 2023

Please describe the bug
When I try to use Alpa to parallelize LLaMA, it throws a segmentation fault.

Please describe the expected behavior
Alpa is expected to compile LLaMA and run training normally.

System information and environment

  • OS Platform and Distribution: Ubuntu 20.04 docker
  • Python version: 3.10.13
  • CUDA version: 11.8
  • NCCL version: 2.16.2
  • cupy version: cupy-cuda11x==12.2.0
  • GPU model and memory: NVIDIA A800 80GB
  • Alpa version: 1.1.0.dev0, built from source (alpa dev branch)
  • TensorFlow version: 2.11.0
  • JAX version: 0.4.6
  • jaxlib version: 0.4.6, built from source (alpa dev branch and tensorflow-alpa dev branch)
  • Ray version: 2.8.1 (commit 82a8df138fe7fcc5c42536ebf26e8c3665704fee)

To Reproduce
Steps to reproduce the behavior:
LLaMa model used: https://github.com/young-geng/EasyLM/tree/main/EasyLM/models/llama

  1. ray start --head --system-config='{"object_spilling_threshold":0.99}'
  2. cd examples/llama_finetune
  3. bash run_llama.sh
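
For context, here is a minimal sketch of the code path the finetune script goes through, assuming the documented Alpa API (alpa.init on a Ray cluster, alpa.parallelize with a PipeshardParallel method). The toy loss, shapes, and option values are placeholders, not the actual run_easylm_flax.py code, and it assumes at least two GPUs so two pipeline stages can be placed:

```python
# Illustrative sketch only -- not run_easylm_flax.py.
import alpa
import jax
import jax.numpy as jnp

alpa.init(cluster="ray")  # connects to the Ray cluster started in step 1

def train_step(params, batch):
    def loss_fn(p):
        h = jnp.tanh(batch["x"] @ p["w1"])
        pred = h @ p["w2"]
        return jnp.mean((pred - batch["y"]) ** 2)
    # alpa.grad (rather than jax.grad) is needed so gradients can be
    # accumulated across micro-batches under pipeline parallelism.
    grads = alpa.grad(loss_fn)(params)
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

# operator_parallel=2 / pipeline_parallel=2 in TrainingArguments roughly map
# to a PipeshardParallel method; the option values below are placeholders.
method = alpa.PipeshardParallel(
    num_micro_batches=4,
    layer_option=alpa.AutoLayerOption(layer_num=2))
p_train_step = alpa.parallelize(train_step, method=method)

params = {"w1": jnp.ones((128, 64)), "w2": jnp.ones((64, 1))}
batch = {"x": jnp.ones((256, 128)), "y": jnp.ones((256, 1))}
params = p_train_step(params, batch)  # backend compilation happens here
alpa.shutdown()
```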

Screenshots

/root/miniconda3/envs/test_126/lib/python3.10/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
/root/miniconda3/envs/test_126/lib/python3.10/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
/root/miniconda3/envs/test_126/lib/python3.10/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
2023-12-06 09:43:23.960016: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /fs-computility/llm/zigzagcai/gcc_10.2.0/lib64:/fs-computility/llm/zigzagcai/cuda118/lib64:/fs-computility/llm/zigzagcai/nccl/build/lib:/usr/local/nccl-rdma-sharp-plugins/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-12-06 09:43:23.960104: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /fs-computility/llm/zigzagcai/gcc_10.2.0/lib64:/fs-computility/llm/zigzagcai/cuda118/lib64:/fs-computility/llm/zigzagcai/nccl/build/lib:/usr/local/nccl-rdma-sharp-plugins/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-12-06 09:43:23.960114: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-12-06 09:43:25,065 INFO worker.py:1489 -- Connecting to existing Ray cluster at address: 172.28.32.161:6379...
2023-12-06 09:43:25,075 INFO worker.py:1673 -- Connected to Ray cluster.
/root/miniconda3/envs/test_126/lib/python3.10/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
12/06/2023 09:45:35 - INFO - __main__ - Training/evaluation parameters TrainingArguments(output_dir='./output', overwrite_output_dir=True, do_train=True, do_eval=False, per_device_train_batch_size=32, per_device_eval_batch_size=16, num_micro_batches=32, operator_parallel=2, pipeline_parallel=2, use_remat=True, learning_rate=0.0005, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, adafactor=False, num_train_epochs=3.0, warmup_ratio=0.03, logging_steps=1, save_steps=3000, eval_steps=1000, seed=42, push_to_hub=False, hub_model_id=None, hub_token=None)
Model config LLaMAConfig {
  "attn_pdrop": 0.0,
  "bos_token_id": 0,
  "embd_pdrop": 0.0,
  "eos_token_id": 1,
  "fcm_max_ratio": 0.0,
  "fcm_min_ratio": 0.0,
  "gradient_checkpointing": "nothing_saveable",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_sequence_length": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "resid_pdrop": 0.0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}

loading file tokenizer.model from cache at /root/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16/tokenizer.model
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at /root/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16/special_tokens_map.json
loading file tokenizer_config.json from cache at /root/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16/tokenizer_config.json
12/06/2023 09:45:45 - INFO - jax._src.xla_bridge - Remote TPU is not linked into jax; skipping remote TPU.
12/06/2023 09:45:45 - INFO - jax._src.xla_bridge - Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
12/06/2023 09:45:46 - INFO - jax._src.xla_bridge - Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
12/06/2023 09:45:46 - INFO - jax._src.xla_bridge - Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
12/06/2023 09:45:46 - INFO - jax._src.xla_bridge - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
Generate config GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "transformers_version": "4.28.1"
}

Model weights are not initialized as `_do_init` is set to `False`. Make sure to call `FlaxLLaMAForCausalLM.init_weights` manually to initialize the weights.
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16/config.json
Model config LlamaConfig {
  "_name_or_path": "huggyllama/llama-7b",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 2048,
  "max_sequence_length": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}

loading weights file model.safetensors from cache at /root/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16/model.safetensors.index.json
Generate config GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 0,
  "transformers_version": "4.28.1"
}

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.53it/s]
All model checkpoint weights were used when initializing LlamaForCausalLM.

All the weights of LlamaForCausalLM were initialized from the model checkpoint at huggyllama/llama-7b.
If your task is similar to the task the model of the checkpoint was trained on, you can already use LlamaForCausalLM for predictions without further training.
loading configuration file generation_config.json from cache at /root/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16/generation_config.json
Generate config GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 0,
  "transformers_version": "4.28.1"
}

Loading data...
#train 44425, #eval 907
Formatting inputs...Skip in lazy mode
Formatting inputs...Skip in lazy mode
12/06/2023 09:47:09 - INFO - __main__ - ***** Build dataset *****
12/06/2023 09:48:09 - INFO - __main__ - ***** Running training *****
12/06/2023 09:48:09 - INFO - __main__ -   Num examples = 44425
12/06/2023 09:48:09 - INFO - __main__ -   Num Epochs = 3
12/06/2023 09:48:09 - INFO - __main__ -   Batch size per device (w. accumulation) = 32
12/06/2023 09:48:09 - INFO - __main__ -   Global train batch size (w. parallel & distributed) = 256
12/06/2023 09:48:09 - INFO - __main__ -   Total optimization steps = 519
Initial compilation. This might take some minutes...
Epoch ... :   0%|                                                                                                                                                                                                                                                       | 0/3 [00:00<?, ?it/s(pid=945590) /root/miniconda3/envs/test_126/lib/python3.10/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.                  | 0/173 [00:00<?, ?it/s]
(pid=945590)   jax.tree_util.register_keypaths(data_clz, keypaths)
(pid=945590) /root/miniconda3/envs/test_126/lib/python3.10/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
(pid=945590)   jax.tree_util.register_keypaths(data_clz, keypaths)
(CompileWorker pid=945590) 2023-12-06 09:48:32.950735: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:
(CompileWorker pid=945590)
(CompileWorker pid=945590)   dynamic-slice.321 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).
(CompileWorker pid=945590)
(CompileWorker pid=945590) This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.
(CompileWorker pid=945590)
(CompileWorker pid=945590) If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
(CompileWorker pid=945590) 2023-12-06 09:48:34.623874: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 2.673245469s
(CompileWorker pid=945590) Constant folding an instruction is taking > 1s:
(CompileWorker pid=945590)
(CompileWorker pid=945590)   dynamic-slice.321 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).
(CompileWorker pid=945590)
(CompileWorker pid=945590) This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.
(CompileWorker pid=945590)
(CompileWorker pid=945590) If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
(pid=945589) /root/miniconda3/envs/test_126/lib/python3.10/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead. [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(pid=945589)   jax.tree_util.register_keypaths(data_clz, keypaths) [repeated 4x across cluster]
(CompileWorker pid=945590) 2023-12-06 09:48:36.646609: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 2s:
(CompileWorker pid=945590)
(CompileWorker pid=945590)   dynamic-slice.324 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).
(CompileWorker pid=945590)
(CompileWorker pid=945590) This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.
(CompileWorker pid=945590)
(CompileWorker pid=945590) If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
(CompileWorker pid=945590) 2023-12-06 09:48:37.395438: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 2.748873783s
(CompileWorker pid=945590) Constant folding an instruction is taking > 2s:
(CompileWorker pid=945589) 2023-12-06 09:48:37.060684: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:
(CompileWorker pid=945589)  [repeated 9x across cluster]
(CompileWorker pid=945589)   dynamic-slice.327 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above). [repeated 3x across cluster]
(CompileWorker pid=945589) This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time. [repeated 3x across cluster]
(CompileWorker pid=945589) If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results. [repeated 3x across cluster]
(CompileWorker pid=945589) 2023-12-06 09:48:40.733518: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 2s:
(CompileWorker pid=945589) 2023-12-06 09:48:38.710697: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 2.650110556s
(CompileWorker pid=945589) Constant folding an instruction is taking > 1s:
(CompileWorker pid=945589) 2023-12-06 09:48:41.451697: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 2.718232627s
(CompileWorker pid=945589) Constant folding an instruction is taking > 2s:
(pid=947265) /root/miniconda3/envs/test_126/lib/python3.10/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
(pid=947265)   jax.tree_util.register_keypaths(data_clz, keypaths)
(pid=947265) /root/miniconda3/envs/test_126/lib/python3.10/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
(pid=947265)   jax.tree_util.register_keypaths(data_clz, keypaths)
(CompileWorker pid=945589)  [repeated 6x across cluster]
(CompileWorker pid=945589)   dynamic-slice.330 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above). [repeated 2x across cluster]
(CompileWorker pid=945589) This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time. [repeated 2x across cluster]
(CompileWorker pid=945589) If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results. [repeated 2x across cluster]
(MeshHostWorker pid=947267) *** SIGSEGV received at time=1701856315 on cpu 66 ***
(MeshHostWorker pid=947267) PC: @     0x7f54dc73d71e  (unknown)  xla::HloInstruction::IsRoot()
(MeshHostWorker pid=947267)     @     0x7f8408612420  (unknown)  (unknown)
(MeshHostWorker pid=947267)     @     0x7f54dc6acb67         80  xla::HloLiveRange::Run()
(pid=947267) /root/miniconda3/envs/test_126/lib/python3.10/site-packages/flax/struct.py:136: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead. [repeated 10x across cluster]
(pid=947267)   jax.tree_util.register_keypaths(data_clz, keypaths) [repeated 10x across cluster]
(MeshHostWorker pid=947267)     @     0x7f54dc678183        928  xla::BufferAssigner::CreateAssignment()
(MeshHostWorker pid=947267)     @     0x7f54dc679f26        384  xla::BufferAssigner::Run()
(MeshHostWorker pid=947267)     @     0x7f54d900c6e7       1424  xla::gpu::CompileModuleToLlvmIrImpl()
(MeshHostWorker pid=947267)     @     0x7f54d900e9c7       2272  xla::gpu::GpuCompiler::RunBackend()
(MeshHostWorker pid=947267)     @     0x7f54db5fed57        320  xla::LLVMCompiler::Compile()
(MeshHostWorker pid=947267)     @     0x7f54d9ed1e62        704  xla::Service::BuildExecutables()
(MeshHostWorker pid=947267)     @     0x7f54d9eca419       1184  xla::LocalService::CompileExecutables()
(MeshHostWorker pid=947267)     @     0x7f54d9ec65ab       3392  xla::LocalClient::Compile()
(MeshHostWorker pid=947267)     @     0x7f54d97c02bc        800  xla::PjRtStreamExecutorClient::Compile()
(MeshHostWorker pid=947267)     @     0x7f54d97d01e3       1456  xla::PjRtStreamExecutorClient::Compile()
(MeshHostWorker pid=947267)     @     0x7f54d8f26fbb       2128  xla::ifrt::PjRtLoadedExecutable::Create()
(MeshHostWorker pid=947267)     @     0x7f54d8f2126e       1168  xla::ifrt::PjRtCompiler::Compile()
(MeshHostWorker pid=947267)     @     0x7f54d8e9cfdd       1344  xla::PyClient::Compile()
(MeshHostWorker pid=947267)     @     0x7f54d8489fba       2464  pybind11::detail::argument_loader<>::call_impl<>()
(MeshHostWorker pid=947267)     @     0x7f54d848a3d2        240  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
(MeshHostWorker pid=947267)     @     0x7f54d8452274        720  pybind11::cpp_function::dispatcher()
(MeshHostWorker pid=947267)     @           0x4fc697  (unknown)  cfunction_call
(MeshHostWorker pid=947267)     @ ... and at least 1 more frames
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361: *** SIGSEGV received at time=1701856316 on cpu 66 ***
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361: PC: @     0x7f54dc73d71e  (unknown)  xla::HloInstruction::IsRoot()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f8408612420  (unknown)  (unknown)
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54dc6acb67         80  xla::HloLiveRange::Run()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54dc678183        928  xla::BufferAssigner::CreateAssignment()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54dc679f26        384  xla::BufferAssigner::Run()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d900c6e7       1424  xla::gpu::CompileModuleToLlvmIrImpl()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d900e9c7       2272  xla::gpu::GpuCompiler::RunBackend()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54db5fed57        320  xla::LLVMCompiler::Compile()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d9ed1e62        704  xla::Service::BuildExecutables()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d9eca419       1184  xla::LocalService::CompileExecutables()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d9ec65ab       3392  xla::LocalClient::Compile()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d97c02bc        800  xla::PjRtStreamExecutorClient::Compile()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d97d01e3       1456  xla::PjRtStreamExecutorClient::Compile()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d8f26fbb       2128  xla::ifrt::PjRtLoadedExecutable::Create()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d8f2126e       1168  xla::ifrt::PjRtCompiler::Compile()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d8e9cfdd       1344  xla::PyClient::Compile()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d8489fba       2464  pybind11::detail::argument_loader<>::call_impl<>()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d848a3d2        240  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @     0x7f54d8452274        720  pybind11::cpp_function::dispatcher()
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @           0x4fc697  (unknown)  cfunction_call
(MeshHostWorker pid=947267) [2023-12-06 09:51:56,194 E 947267 947267] logging.cc:361:     @ ... and at least 1 more frames
(MeshHostWorker pid=947267) Fatal Python error: Segmentation fault
(MeshHostWorker pid=947267)
(MeshHostWorker pid=947267) Stack (most recent call first):
(MeshHostWorker pid=947267)   File "/fs-computility/llm/zigzagcai/alpa/alpa/shard_parallel/auto_sharding.py", line 459 in run_backend_compilation
(MeshHostWorker pid=947267)   File "/fs-computility/llm/zigzagcai/alpa/alpa/mesh_executable.py", line 440 in __init__
(MeshHostWorker pid=947267)   File "/fs-computility/llm/zigzagcai/alpa/alpa/mesh_executable.py", line 1019 in __init__
(MeshHostWorker pid=947267)   File "/fs-computility/llm/zigzagcai/alpa/alpa/device_mesh.py", line 277 in put_executable
(MeshHostWorker pid=947267)   File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 467 in _resume_span
(MeshHostWorker pid=947267)   File "/fs-computility/llm/zigzagcai/alpa/alpa/pipeline_parallel/pipeshard_executable.py", line 472 in __init__
(MeshHostWorker pid=947267)   File "/fs-computility/llm/zigzagcai/alpa/alpa/device_mesh.py", line 277 in put_executable
(MeshHostWorker pid=947267)   File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 467 in _resume_span
(MeshHostWorker pid=947267)   File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/function_manager.py", line 726 in actor_method_executor
(MeshHostWorker pid=947267)   File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/worker.py", line 797 in main_loop
(MeshHostWorker pid=947267)   File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/workers/default_worker.py", line 282 in <module>
(MeshHostWorker pid=947267)
(MeshHostWorker pid=947267) Extension modules: msgpack._cmsgpack, google.protobuf.pyext._message, psutil._psutil_linux, psutil._psutil_posix, setproctitle, yaml._yaml, ray._raylet, charset_normalizer.md, jaxlib.cpu_feature_guard, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, scipy._lib._ccallback_c, numpy.linalg.lapack_lite, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.sparse.linalg._isolve._iterative, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.linalg._flinalg, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.spatial.transform._rotation, scipy.ndimage._nd_image, _ni_label, scipy.ndimage._ni_label, scipy.optimize._minpack2, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize.__nnls, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.special.cython_special, scipy.stats._stats, scipy.stats.beta_ufunc, scipy.stats._boost.beta_ufunc, scipy.stats.binom_ufunc, scipy.stats._boost.binom_ufunc, scipy.stats.nbinom_ufunc, scipy.stats._boost.nbinom_ufunc, scipy.stats.hypergeom_ufunc, scipy.stats._boost.hypergeom_ufunc, scipy.stats.ncf_ufunc, scipy.stats._boost.ncf_ufunc, scipy.stats.ncx2_ufunc, scipy.stats._boost.ncx2_ufunc, scipy.stats.nct_ufunc, scipy.stats._boost.nct_ufunc, scipy.stats.skewnorm_ufunc, scipy.stats._boost.skewnorm_ufunc, scipy.stats.invgauss_ufunc, scipy.stats._boost.invgauss_ufunc, scipy.interpolate._fitpack, scipy.interpolate.dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._statlib, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._mvn, scipy.stats._rcont.rcont, cupy_backends.cuda.api._runtime_enum, cupy_backends.cuda.api.runtime, cupy_backends.cuda.stream, 
cupy_backends.cuda.libs.cublas, cupy_backends.cuda.libs.cusolver, cupy_backends.cuda._softlink, cupy_backends.cuda.libs.cusparse, cupy._util, cupy.cuda.device, fastrlock.rlock, cupy.cuda.memory_hook, cupy.cuda.graph, cupy.cuda.stream, cupy_backends.cuda.api._driver_enum, cupy_backends.cuda.api.driver, cupy.cuda.memory, cupy._core.internal, cupy._core._carray, cupy.cuda.texture, cupy.cuda.function, cupy_backends.cuda.libs.nvrtc, cupy.cuda.jitify, cupy.cuda.pinned_memory, cupy_backends.cuda.libs.curand, cupy_backends.cuda.libs.profiler, cupy.cuda.common, cupy.cuda.cub, cupy_backends.cuda.libs.nvtx, cupy.cuda.thrust, cupy._core._dtype, cupy._core._scalar, cupy._core._accelerator, cupy._core._memory_range, cupy._core._fusion_thread_local, cupy._core._kernel, cupy._core._routines_manipulation, cupy._core._routines_binary, cupy._core._optimize_config, cupy._core._cub_reduction, cupy._core._reduction, cupy._core._routines_math, cupy._core._routines_indexing, cupy._core._routines_linalg, cupy._core._routines_logic, cupy._core._routines_sorting, cupy._core._routines_statistics, cupy._core.dlpack, cupy._core.flags, cupy._core.core, cupy._core._fusion_variable, cupy._core._fusion_trace, cupy._core._fusion_kernel, cupy._core.new_fusion, cupy._core.fusion, cupy._core.raw, cupyx.cusolver, cupy.cuda.cufft, cupy.fft._cache, cupy.fft._callback, cupy.random._generator_api, cupy.random._bit_generator, cupy.lib._polynomial, cupy_backends.cuda.libs.nccl, numba.core.typeconv._typeconv, numba._helperlib, numba._dynfunc, numba._dispatcher, numba.core.runtime._nrt_python, numba.np.ufunc._internal, numba.experimental.jitclass._box, pydantic.typing, pydantic.errors, pydantic.version, pydantic.utils, pydantic.class_validators, pydantic.config, pydantic.color, pydantic.datetime_parse, pydantic.validators, pydantic.networks, pydantic.types, pydantic.json, pydantic.error_wrappers, pydantic.fields, pydantic.parse, pydantic.schema, pydantic.main, pydantic.dataclasses, pydantic.annotated_types, pydantic.decorator, pydantic.env_settings, pydantic.tools, pydantic, pyarrow.lib, pyarrow._hdfsio, pyarrow._json (total: 220)
2023-12-06 09:51:56,673 WARNING worker.py:2074 -- A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffff5464d9520c63a70fc0121fca02000000 Worker ID: 4e0d5f9a17a15aad8a02ae6086dd6c61b1c535e4560a9b1ca0cd2dd6 Node ID: 862ba43f8d259b801fb37e72059afad3321e132866273935ac8b5743 Worker IP address: 172.28.32.161 Worker port: 10123 Worker PID: 947267 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
2023-12-06 09:51:57,219 WARNING worker.py:2074 -- A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffff19feb963c8ce433c78a3374602000000 Worker ID: 9c5223e56f7cf8d34752e503b847d7f2f88c98f7738289ef7df9ae9b Node ID: 862ba43f8d259b801fb37e72059afad3321e132866273935ac8b5743 Worker IP address: 172.28.32.161 Worker port: 10124 Worker PID: 947266 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
Epoch ... :   0%|                                                                                                                                                                                                                                                       | 0/3 [04:31<?, ?it/s]
Traceback (most recent call last):
  File "/fs-computility/llm/zigzagcai/alpa/examples/llama_finetune/run_easylm_flax.py", line 886, in <module>
    main()
  File "/fs-computility/llm/zigzagcai/alpa/examples/llama_finetune/run_easylm_flax.py", line 760, in main
    executable.sync()
  File "/fs-computility/llm/zigzagcai/alpa/alpa/pipeline_parallel/pipeshard_executable.py", line 411, in sync
    self.mesh_group.sync_workers()
  File "/fs-computility/llm/zigzagcai/alpa/alpa/device_mesh.py", line 2058, in sync_workers
    ray.get([w.sync.remote() for w in all_workers])
  File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/worker.py", line 2565, in get
    raise value
ray.exceptions.RayActorError: The actor died unexpectedly before finishing this task.
        class_name: MeshHostWorker
        actor_id: 19feb963c8ce433c78a3374602000000
        pid: 947266
        namespace: alpa_default_space
        ip: 172.28.32.161
The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
(MeshHostWorker pid=947266) *** SIGSEGV received at time=1701856316 on cpu 94 ***
(MeshHostWorker pid=947266) PC: @     0x7ee2d473d71e  (unknown)  xla::HloInstruction::IsRoot()
(MeshHostWorker pid=947266)     @     0x7f11fdc1e420  (unknown)  (unknown)
(MeshHostWorker pid=947266)     @     0x7ee2d46acb67         80  xla::HloLiveRange::Run()
(MeshHostWorker pid=947266)     @     0x7ee2d4678183        928  xla::BufferAssigner::CreateAssignment()
(MeshHostWorker pid=947266)     @     0x7ee2d4679f26        384  xla::BufferAssigner::Run()
(MeshHostWorker pid=947266)     @     0x7ee2d100c6e7       1424  xla::gpu::CompileModuleToLlvmIrImpl()
(MeshHostWorker pid=947266)     @     0x7ee2d100e9c7       2272  xla::gpu::GpuCompiler::RunBackend()
(MeshHostWorker pid=947266)     @     0x7ee2d35fed57        320  xla::LLVMCompiler::Compile()
(MeshHostWorker pid=947266)     @     0x7ee2d1ed1e62        704  xla::Service::BuildExecutables()
(MeshHostWorker pid=947266)     @     0x7ee2d1eca419       1184  xla::LocalService::CompileExecutables()
(MeshHostWorker pid=947266)     @     0x7ee2d1ec65ab       3392  xla::LocalClient::Compile()
(MeshHostWorker pid=947266)     @     0x7ee2d17d01e3       1456  xla::PjRtStreamExecutorClient::Compile() [repeated 2x across cluster]
(MeshHostWorker pid=947266)     @     0x7ee2d0f26fbb       2128  xla::ifrt::PjRtLoadedExecutable::Create()
(MeshHostWorker pid=947266)     @     0x7ee2d0f2126e       1168  xla::ifrt::PjRtCompiler::Compile()
(MeshHostWorker pid=947266)     @     0x7ee2d0e9cfdd       1344  xla::PyClient::Compile()
(MeshHostWorker pid=947266)     @     0x7ee2d0452274        704  pybind11::cpp_function::dispatcher() [repeated 3x across cluster]
(MeshHostWorker pid=947266)     @           0x4fc697  (unknown)  cfunction_call
(MeshHostWorker pid=947266)     @ ... and at least 1 more frames
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361: *** SIGSEGV received at time=1701856316 on cpu 94 ***
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361: PC: @     0x7ee2d473d71e  (unknown)  xla::HloInstruction::IsRoot()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7f11fdc1e420  (unknown)  (unknown)
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d46acb67         80  xla::HloLiveRange::Run()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d4678183        928  xla::BufferAssigner::CreateAssignment()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d4679f26        384  xla::BufferAssigner::Run()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d100c6e7       1424  xla::gpu::CompileModuleToLlvmIrImpl()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d100e9c7       2272  xla::gpu::GpuCompiler::RunBackend()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d35fed57        320  xla::LLVMCompiler::Compile()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d1ed1e62        704  xla::Service::BuildExecutables()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d1eca419       1184  xla::LocalService::CompileExecutables()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d1ec65ab       3392  xla::LocalClient::Compile()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d17d01e3       1456  xla::PjRtStreamExecutorClient::Compile() [repeated 2x across cluster]
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d0f26fbb       2128  xla::ifrt::PjRtLoadedExecutable::Create()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d0f2126e       1168  xla::ifrt::PjRtCompiler::Compile()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d0e9cfdd       1344  xla::PyClient::Compile()
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @     0x7ee2d0452274        704  pybind11::cpp_function::dispatcher() [repeated 3x across cluster]
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @           0x4fc697  (unknown)  cfunction_call
(MeshHostWorker pid=947266) [2023-12-06 09:51:56,617 E 947266 947266] logging.cc:361:     @ ... and at least 1 more frames
(MeshHostWorker pid=947266) Fatal Python error: Segmentation fault
(MeshHostWorker pid=947266)  [repeated 2x across cluster]
(MeshHostWorker pid=947266) Stack (most recent call first):
(MeshHostWorker pid=947266)   File "/fs-computility/llm/zigzagcai/alpa/alpa/shard_parallel/auto_sharding.py", line 459 in run_backend_compilation
(MeshHostWorker pid=947266)   File "/fs-computility/llm/zigzagcai/alpa/alpa/mesh_executable.py", line 1019 in __init__ [repeated 2x across cluster]
(MeshHostWorker pid=947266)   File "/fs-computility/llm/zigzagcai/alpa/alpa/device_mesh.py", line 277 in put_executable [repeated 2x across cluster]
(MeshHostWorker pid=947266)   File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 467 in _resume_span [repeated 2x across cluster]
(MeshHostWorker pid=947266)   File "/fs-computility/llm/zigzagcai/alpa/alpa/pipeline_parallel/pipeshard_executable.py", line 472 in __init__
(MeshHostWorker pid=947266)   File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/function_manager.py", line 726 in actor_method_executor
(MeshHostWorker pid=947266)   File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/worker.py", line 797 in main_loop
(MeshHostWorker pid=947266)   File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/workers/default_worker.py", line 282 in <module>
(MeshHostWorker pid=947266) Extension modules: msgpack._cmsgpack, google.protobuf.pyext._message, psutil._psutil_linux, psutil._psutil_posix, setproctitle, yaml._yaml, ray._raylet, charset_normalizer.md, jaxlib.cpu_feature_guard, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, scipy._lib._ccallback_c, numpy.linalg.lapack_lite, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.sparse.linalg._isolve._iterative, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.linalg._flinalg, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.spatial.transform._rotation, scipy.ndimage._nd_image, _ni_label, scipy.ndimage._ni_label, scipy.optimize._minpack2, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize.__nnls, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.special.cython_special, scipy.stats._stats, scipy.stats.beta_ufunc, scipy.stats._boost.beta_ufunc, scipy.stats.binom_ufunc, scipy.stats._boost.binom_ufunc, scipy.stats.nbinom_ufunc, scipy.stats._boost.nbinom_ufunc, scipy.stats.hypergeom_ufunc, scipy.stats._boost.hypergeom_ufunc, scipy.stats.ncf_ufunc, scipy.stats._boost.ncf_ufunc, scipy.stats.ncx2_ufunc, scipy.stats._boost.ncx2_ufunc, scipy.stats.nct_ufunc, scipy.stats._boost.nct_ufunc, scipy.stats.skewnorm_ufunc, scipy.stats._boost.skewnorm_ufunc, scipy.stats.invgauss_ufunc, scipy.stats._boost.invgauss_ufunc, scipy.interpolate._fitpack, scipy.interpolate.dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._statlib, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._mvn, scipy.stats._rcont.rcont, cupy_backends.cuda.api._runtime_enum, cupy_backends.cuda.api.runtime, cupy_backends.cuda.stream, 
cupy_backends.cuda.libs.cublas, cupy_backends.cuda.libs.cusolver, cupy_backends.cuda._softlink, cupy_backends.cuda.libs.cusparse, cupy._util, cupy.cuda.device, fastrlock.rlock, cupy.cuda.memory_hook, cupy.cuda.graph, cupy.cuda.stream, cupy_backends.cuda.api._driver_enum, cupy_backends.cuda.api.driver, cupy.cuda.memory, cupy._core.internal, cupy._core._carray, cupy.cuda.texture, cupy.cuda.function, cupy_backends.cuda.libs.nvrtc, cupy.cuda.jitify, cupy.cuda.pinned_memory, cupy_backends.cuda.libs.curand, cupy_backends.cuda.libs.profiler, cupy.cuda.common, cupy.cuda.cub, cupy_backends.cuda.libs.nvtx, cupy.cuda.thrust, cupy._core._dtype, cupy._core._scalar, cupy._core._accelerator, cupy._core._memory_range, cupy._core._fusion_thread_local, cupy._core._kernel, cupy._core._routines_manipulation, cupy._core._routines_binary, cupy._core._optimize_config, cupy._core._cub_reduction, cupy._core._reduction, cupy._core._routines_math, cupy._core._routines_indexing, cupy._core._routines_linalg, cupy._core._routines_logic, cupy._core._routines_sorting, cupy._core._routines_statistics, cupy._core.dlpack, cupy._core.flags, cupy._core.core, cupy._core._fusion_variable, cupy._core._fusion_trace, cupy._core._fusion_kernel, cupy._core.new_fusion, cupy._core.fusion, cupy._core.raw, cupyx.cusolver, cupy.cuda.cufft, cupy.fft._cache, cupy.fft._callback, cupy.random._generator_api, cupy.random._bit_generator, cupy.lib._polynomial, cupy_backends.cuda.libs.nccl, numba.core.typeconv._typeconv, numba._helperlib, numba._dynfunc, numba._dispatcher, numba.core.runtime._nrt_python, numba.np.ufunc._internal, numba.experimental.jitclass._box, pydantic.typing, pydantic.errors, pydantic.version, pydantic.utils, pydantic.class_validators, pydantic.config, pydantic.color, pydantic.datetime_parse, pydantic.validators, pydantic.networks, pydantic.types, pydantic.json, pydantic.error_wrappers, pydantic.fields, pydantic.parse, pydantic.schema, pydantic.main, pydantic.dataclasses, pydantic.annotated_types, pydantic.decorator, pydantic.env_settings, pydantic.tools, pydantic, pyarrow.lib, pyarrow._hdfsio, pyarrow._json (total: 220)
Exception ignored in: <function RemoteArrayRef.__del__ at 0x7f9891183d00>
Traceback (most recent call last):
  File "/fs-computility/llm/zigzagcai/alpa/alpa/device_mesh.py", line 1488, in __del__
  File "/fs-computility/llm/zigzagcai/alpa/alpa/device_mesh.py", line 1271, in delete_remote_buffers
  File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/actor.py", line 165, in remote
  File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 23, in auto_init_wrapper
  File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 16, in auto_init_ray
  File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
  File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/_private/worker.py", line 1436, in init
  File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/job_config.py", line 70, in __init__
  File "/root/miniconda3/envs/test_126/lib/python3.10/site-packages/ray/job_config.py", line 137, in set_default_actor_lifetime
ImportError: sys.meta_path is None, Python is likely shutting down

Code snippet to reproduce the problem
cd examples/llama_finetune && bash run_llama.sh

Additional information
Unit tests such as python3 -m alpa.test_install work normally, but using Alpa to parallelize the LLaMA model throws a segmentation fault.
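
One hedged way to narrow this down (not verified on our setup): since the crash happens inside run_backend_compilation while the pipeshard executable is being built, compiling a trivial function with a shard-only method on the same Ray cluster would tell whether plain XLA backend compilation also segfaults. The function and shapes below are made up for illustration:

```python
# Hypothetical narrowing step: if this shard-only compilation also crashes,
# the problem is in XLA backend compilation itself rather than in the
# pipeline (pipeshard) path.
import alpa
import jax
import jax.numpy as jnp

alpa.init(cluster="ray")

def step(params, x):
    return jax.tree_util.tree_map(lambda p: p + jnp.mean(x), params)

p_step = alpa.parallelize(step, method=alpa.ShardParallel())
out = p_step({"w": jnp.ones((1024, 1024))}, jnp.ones((32, 1024)))
alpa.shutdown()
```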

@zigzagcai zigzagcai changed the title Segment fault when running llama with alpa [Bug] Segment fault when running llama with alpa Dec 6, 2023
@zigzagcai zigzagcai changed the title [Bug] Segment fault when running llama with alpa [Bug] Segment fault when using alpa to parallelize llama with jax 0.4.6 environment Dec 6, 2023

zigzagcai commented Dec 7, 2023

We can use Alpa to parallelize LLaMA with jax/jaxlib==0.3.22, but it fails with jax/jaxlib==0.4.6.

The reason we want to switch to jax/jaxlib==0.4.6 is that we want to use the new features of jax. We also found that jax/jaxlib==0.4.6 fails with the built-in workload example/gpt2.
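
For completeness, a plain sanity check (nothing Alpa-specific) to confirm which jax/jaxlib build and backend are actually being imported in the environment:

```python
# Simple version/backend check; confirms the custom-built jaxlib 0.4.6 and
# the CUDA backend are the ones actually picked up at import time.
import jax
import jaxlib

print("jax    :", jax.__version__)
print("jaxlib :", jaxlib.__version__)
print("backend:", jax.default_backend())  # expect "gpu"
print("devices:", jax.devices())
```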


Lssyes commented Dec 19, 2023

A lot of changes were made in jax between version 0.3 and version 0.4.
