Please describe the bug
IndexError: InlinedVector::at(size_type) const failed bounds check
System information and environment
OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker):
Python version: 3.8.10
CUDA version: 11.3
NCCL version: 2.9
cupy version: 11.3
GPU model and memory: 2x A100 (80 GB)
Alpa version: 0.2.3
TensorFlow version: 2.8.0
JAX version: 0.3.22
To Reproduce
Steps to reproduce the behavior:
1. Train a LLaMA model implemented in Flax with alpa.parallelize.
2. The following error is produced:
2023-09-24 12:29:49,782 INFO worker.py:1342 -- Connecting to existing Ray cluster at address: 10.233.115.148:6379...
2023-09-24 12:29:49,795 INFO worker.py:1528 -- Connected to Ray cluster.
Training/epoch 0: 0%| | 0/7473 [00:01<?, ?it/s]
Traceback (most recent call last):
File "./Trainer/train_ray_batch.py", line 149, in
main()
File "./Trainer/train_ray_batch.py", line 139, in main
state, loss = train_step(state, seq, seq_mask, labels, labels_mask)
File "/home/mpi/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 121, in call
self._decode_args_and_get_executable(*args))
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 191, in _decode_args_and_get_executable
executable = _compile_parallel_executable(f, in_tree, out_tree_hashable,
File "/home/mpi/.local/lib/python3.8/site-packages/jax/linear_util.py", line 309, in memoized_fun
ans = call(fun, *args)
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 223, in _compile_parallel_executable
return method.compile_executable(fun, in_tree, out_tree_thunk,
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/parallel_method.py", line 108, in compile_executable
return compile_shard_executable(fun, in_tree, out_tree_thunk,
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/compile_executable.py", line 78, in compile_shard_executable
return shard_parallel_internal(fun, in_tree, out_tree_thunk,
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/compile_executable.py", line 139, in shard_parallel_internal
hlo, stage_plan = run_auto_sharding_pass(hlo, logical_mesh_choices[0],
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/auto_sharding.py", line 345, in run_auto_sharding_pass
xe.run_auto_sharding(hlo.get_module(), compile_options)
jax._src.traceback_util.UnfilteredStackTrace: IndexError: InlinedVector::at(size_type) const failed bounds check
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "./Trainer/train_ray_batch.py", line 149, in
main()
File "./Trainer/train_ray_batch.py", line 139, in main
state, loss = train_step(state, seq, seq_mask, labels, labels_mask)
File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/auto_sharding.py", line 345, in run_auto_sharding_pass
xe.run_auto_sharding(hlo.get_module(), compile_options)
IndexError: InlinedVector::at(size_type) const failed bounds check
Screenshots
Code snippet to reproduce the problem
@alpa.parallelize(batch_argnums=(1, 2, 3, 4))
def train_step(state, seq, seq_mask, labels, labels_mask):
    def train_forward(params):
        # seq, seq_mask, labels, labels_mask = data_batch
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(seq).shape[-1]), seq.shape)
        outputs = state.apply_fn(
            params,
            seq,
            seq_mask,
            position_ids,
            deterministic=False,
            return_dict=False,
        )
        logits = outputs[0]
        loss = cross_entropy_loss(logits, labels, mask=labels_mask)
        return loss

    # Gradients are computed through DynamicScale (fp16 loss scaling).
    dynamic_scale = state.dynamic_scale
    if dynamic_scale:
        grad_fn = dynamic_scale.value_and_grad(train_forward)
        dynamic_scale, is_fin, loss, grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    if dynamic_scale:
        # Keep the previous state wherever the scaled gradients were not finite.
        new_state = new_state.replace(
            opt_state=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.opt_state, state.opt_state),
            params=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.params, state.params),
            master_copy=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.master_copy, state.master_copy),
            dynamic_scale=dynamic_scale)
    return new_state, loss
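
Note that train_step above only computes gradients inside the if dynamic_scale: branch; that branch is always taken here because main() below always creates a DynamicScale. For reference only, a minimal hypothetical sketch of how the non-scaled branch is usually mirrored (this is not part of the original script):

    # Hypothetical fallback for the case where state.dynamic_scale is None.
    # The original script always enables DynamicScale, so this path is unused.
    if dynamic_scale:
        grad_fn = dynamic_scale.value_and_grad(train_forward)
        dynamic_scale, is_fin, loss, grads = grad_fn(state.params)
    else:
        # Plain JAX value-and-grad: no loss scaling, no is_fin masking.
        loss, grads = jax.value_and_grad(train_forward)(state.params)
    new_state = state.apply_gradients(grads=grads)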
def main() -> None:
    global llama_model
    alpa.init(cluster="ray")

    lr = 0.001
    batch_size = 1
    max_len = 640
    n_epochs = 7
    load_pretrained_model = False
    ckpt_dir = "./JAX_model/7B"

    # prepare dataset
    tokenizer = LLaMATokenizer("./JAX_model/tokenizer.model")
    dataset = GSMDataset(split='train')
    collate_fn = partial(gsm_collate_fn_train, tokenizer=tokenizer, max_len=max_len)
    dataloader = LlamaDataLoader(dataset, batch_size, collate_fn)

    # set config
    if load_pretrained_model:
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            config_params = json.loads(f.read())
        config_params.update({'vocab_size': len(tokenizer), 'max_seq_len': max_len})
        llama_config = LLaMAConfig(**config_params)
    else:
        llama_config = LLaMAConfig(num_hidden_layers=4)
    llama_model = LLaMAForCausalLMModule(llama_config)

    # init model
    input_ids = jnp.ones((batch_size, max_len))
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(
        jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
    params = llama_model.init(input_ids, attention_mask, position_ids,
                              return_dict=False, init_cache=False)
    if load_pretrained_model:
        param = restore(Path(ckpt_dir) / "consolidated.nra", replace_keys=False)
        params['param'] = param

    n_steps = math.ceil(len(dataloader))
    schedule = warmup_cosine_decay_schedule(
        init_value=0.,
        peak_value=lr,
        warmup_steps=n_steps,
        decay_steps=n_steps + 1,
        end_value=lr,
    )
    optimizer = adamw(learning_rate=schedule)

    use_master_copy = True
    dynamic_scale = DynamicScale()
    alpa.global_config.flax_always_use_fp16_embedding = True
    state = TrainState.create(apply_fn=llama_model.run, params=params, tx=optimizer,
                              dynamic_scale=dynamic_scale, use_master_copy=use_master_copy)

    for epoch in range(n_epochs):
        with tqdm(dataloader) as tepoch:
            tepoch.set_description(f"Training/epoch {epoch}")
            for batch in tepoch:
                seq, seq_mask, labels, labels_mask = batch
                state, loss = train_step(state, seq, seq_mask, labels, labels_mask)


if __name__ == '__main__':
    main()
Additional information
Add any other context about the problem here or include any logs that would be helpful to diagnose the problem.
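
The crash is raised from xe.run_auto_sharding, i.e. inside Alpa's auto-sharding pass while compiling the shard-parallel executable, not inside the model code itself. As one possible way to narrow it down, here is a minimal sketch that exercises the same alpa.parallelize compilation path on the same Ray cluster with a toy Flax model; the model, shapes, and optimizer below are made up for illustration (only alpa.init/alpa.parallelize as already used above plus standard Flax and Optax APIs are assumed). If a toy case like this compiles and runs, the bounds-check failure is more likely tied to something specific in the LLaMA module's HLO.

    # Hypothetical toy reproduction sketch (not the original script).
    import alpa
    import jax
    import jax.numpy as jnp
    import optax
    from flax import linen as nn
    from flax.training.train_state import TrainState

    class ToyMLP(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Dense(1024)(x)
            x = nn.relu(x)
            return nn.Dense(16)(x)

    @alpa.parallelize(batch_argnums=(1, 2))
    def toy_step(state, x, y):
        def loss_fn(params):
            pred = state.apply_fn(params, x)
            return jnp.mean((pred - y) ** 2)
        loss, grads = jax.value_and_grad(loss_fn)(state.params)
        return state.apply_gradients(grads=grads), loss

    alpa.init(cluster="ray")
    model = ToyMLP()
    x = jnp.ones((8, 128))
    y = jnp.ones((8, 16))
    params = model.init(jax.random.PRNGKey(0), x)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=optax.adamw(1e-3))
    # The first call triggers run_auto_sharding_pass, the pass that fails above.
    state, loss = toy_step(state, x, y)
    print(loss)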