
Some of the step tasks have been OOM Killed. #189

Open
@shubhamgp47


I am getting "oom_kill event in StepId=866679.batch. Some of the step tasks have been OOM Killed." while using the avg_confidence strategy on my multilabel dataset of around 38,000 images of size 224×224. I use a torch DataLoader with batch size 8 to load the data. Here is a snippet of the active learning loop:

import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score

# learner, the pool/test arrays, the cumulative arrays, power, total_samples,
# batch(), save_checkpoint() and the early-stopping state are defined before
# this snippet (not shown).
n_queries = 14
for i in range(n_queries):
    # Query-size schedule: 8 samples first, then about 10**power per query
    if i == 0:
        n_instances = 8
    else:
        power += 0.25
        n_instances = batch(int(np.ceil(np.power(10, power))), batch_size)
    total_samples += n_instances
    n_instances_list.append(total_samples)

    print(f"\nQuery {i + 1}: Requesting {n_instances} samples.")
    print(f"Number of samples in pool before query: {X_pool.shape[0]}")

    # Score the pool with the query strategy and pick the next samples
    with torch.device("cpu"):
        query_idx, _ = learner.query(X_pool, n_instances=n_instances)
        query_idx = np.unique(query_idx)
        query_idx = np.array(query_idx).flatten()

    # Extract the samples based on the query indices
    X_query = X_pool[query_idx]
    y_query = y_pool[query_idx]
    filenames_query = [filenames_pool[idx] for idx in query_idx]

    print("Shape of X_query after indexing:", X_query.shape)

    # Sanity-check the batch: expect (n, 224, 224, 3) images
    if X_query.ndim != 4:
        raise ValueError(f"Unexpected number of dimensions in X_query: {X_query.ndim}")
    if X_query.shape[1:] != (224, 224, 3):
        raise ValueError(f"Unexpected shape in X_query dimensions: {X_query.shape}")

    # Append the new samples to the cumulative labeled set
    X_cumulative = np.vstack((X_cumulative, X_query))
    y_cumulative = np.vstack((y_cumulative, y_query))
    filenames_cumulative.extend(filenames_query)

    save_checkpoint(i + 1, X_cumulative, y_cumulative, filenames_cumulative, save_dir)

    # Retrain, passing the full cumulative labeled set
    learner.teach(X=X_cumulative, y=y_cumulative)

    # Evaluate on the held-out test set
    y_pred = learner.predict(X_test_np)
    accuracy = accuracy_score(y_test_np, y_pred)
    f1 = f1_score(y_test_np, y_pred, average='macro')
    acc_test_data.append(accuracy)
    f1_test_data.append(f1)

    print(f"Accuracy after query {i + 1}: {accuracy}")
    print(f"F1 Score after query {i + 1}: {f1}")

    # Early stopping check
    if f1 > best_f1_score:
        best_f1_score = f1
        wait = 0  # reset the wait counter
    else:
        wait += 1  # increment the wait counter
        if wait >= patience:
            print("Stopping early due to no improvement in F1 score.")
            break

    # Remove queried instances from the pool
    X_pool = np.delete(X_pool, query_idx, axis=0)
    y_pool = np.delete(y_pool, query_idx, axis=0)
    filenames_pool = [filename for idx, filename in enumerate(filenames_pool) if idx not in query_idx]
    print(f"Number of samples in pool after query: {X_pool.shape[0]}")
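For scale, here is a rough estimate of how much host RAM the image arrays alone need, assuming the ~38,000 images are held as one dense (N, 224, 224, 3) NumPy array. The snippet does not show the dtype, so uint8 and float32 are both shown as assumptions. Note also that np.vstack and np.delete each return a new array, so while those lines run the old and new copies briefly coexist.

```python
import numpy as np

def array_gib(n_images: int, dtype) -> float:
    """Size in GiB of a dense (n_images, 224, 224, 3) array of the given dtype."""
    bytes_per_image = 224 * 224 * 3 * np.dtype(dtype).itemsize
    return n_images * bytes_per_image / 2**30

n_images = 38_000  # approximate dataset size from the issue
for dtype in (np.uint8, np.float32):
    # Prints the footprint of the full image array for each candidate dtype
    print(f"{np.dtype(dtype).name}: {array_gib(n_images, dtype):.1f} GiB")
```

At float32 the pool alone is over 20 GiB, so one extra full copy of it can already exceed a typical per-job memory allocation.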

This code runs fine for the first 11 iterations, but the 12th iteration fails with the OOM-kill error.
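One possible cause, offered as a hypothesis: the loop passes X_cumulative to learner.teach on every query. modAL's ActiveLearner.teach documents that it adds X and y to the training data the learner already stores and retrains on the combined set (unless only_new=True). If that is the learner in use, each call re-appends every sample labeled so far, so the learner's stored copy grows much faster than the labeled set, which would fit a failure that only appears late in the loop. The sketch below only counts rows and does not call modAL. round_up and query_sizes are hypothetical stand-ins: round_up assumes batch() rounds up to a multiple of batch_size, and power0 = 1.0 is an assumed starting value because the snippet does not show it.

```python
import math
from itertools import accumulate

def round_up(n: int, multiple: int) -> int:
    # Hypothetical stand-in for the issue's batch() helper: round n up to a multiple.
    return -(-n // multiple) * multiple

def query_sizes(n_queries: int, power0: float, batch_size: int = 8) -> list[int]:
    # Mirrors the loop's schedule: 8 samples first, then about 10**power samples,
    # with power growing by 0.25 per query. power0 is an assumed starting value.
    sizes, power = [], power0
    for i in range(n_queries):
        if i == 0:
            sizes.append(8)
        else:
            power += 0.25
            sizes.append(round_up(math.ceil(10 ** power), batch_size))
    return sizes

def stored_rows(sizes: list[int], pass_cumulative: bool = True) -> list[int]:
    """Rows held by a learner whose teach() appends its input, after each query,
    starting from no stored data.

    pass_cumulative=True mimics teach(X=X_cumulative, ...): every call re-appends
    all samples labeled so far. False mimics teach(X=X_query, ...).
    """
    cumulative = list(accumulate(sizes))  # size of X_cumulative after each query
    passed = cumulative if pass_cumulative else list(sizes)
    return list(accumulate(passed))

sizes = query_sizes(12, power0=1.0)
print("labeled samples:        ", stored_rows(sizes, pass_cumulative=False)[-1])
print("rows stored by learner: ", stored_rows(sizes, pass_cumulative=True)[-1])
```

If the appending behavior is confirmed, passing only the new batch (learner.teach(X=X_query, y=y_query)) would keep the stored training set equal to the labeled set, and the model would be retrained on each labeled sample once rather than on repeated copies of the earlier ones.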

I am using an A100 GPU with 40 GB of memory, which should be sufficient for this loop. Could you please help me identify what is causing such a large memory requirement? Is there a bottleneck in my code that I should address? Could it be that the data from every iteration is kept in main memory, and if so, can it be freed without breaking the code or distorting the results?
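On the question of main memory: Slurm's oom_kill event means the kernel OOM killer terminated a process in the job step, typically because the step exceeded its host-memory allocation. It concerns CPU RAM, not the 40 GB on the A100; running out of GPU memory would instead surface in PyTorch as a CUDA out-of-memory error. Below is a sketch for trimming the pool and watching host memory per iteration. remove_queried and peak_rss_mib are hypothetical helpers, and the resource module is Unix-only (ru_maxrss is reported in KiB on Linux). Boolean indexing copies the array just as np.delete does, so the mask mainly replaces the per-filename `idx not in query_idx` scan with a single pass and keeps the three pool structures in sync.

```python
import resource

import numpy as np

def remove_queried(X_pool, y_pool, filenames_pool, query_idx):
    """Drop the queried rows from the pool with one boolean mask.

    The same mask filters the arrays and the filename list, so the filenames
    are filtered in one O(n) pass instead of searching query_idx per filename.
    """
    keep = np.ones(len(X_pool), dtype=bool)
    keep[query_idx] = False
    filenames_kept = [name for name, k in zip(filenames_pool, keep) if k]
    return X_pool[keep], y_pool[keep], filenames_kept

def peak_rss_mib() -> float:
    # Peak resident set size of this process so far; ru_maxrss is in KiB on Linux.
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024

# Small self-contained demo on toy data
X = np.arange(10).reshape(5, 2)
y = np.arange(5)
names = ["a", "b", "c", "d", "e"]
X, y, names = remove_queried(X, y, names, np.array([1, 3]))
print(X.shape, y.tolist(), names, f"peak RSS {peak_rss_mib():.0f} MiB")
```

Printing peak_rss_mib() at the end of each query would show whether host memory climbs steadily across iterations, which would point to accumulated data (such as the duplicated training set) rather than a single oversized step.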
