Torch-cuda backend fails to recover after assertion failed #20920

Open
jobs-git opened this issue Feb 18, 2025 · 2 comments

@jobs-git

jobs-git commented Feb 18, 2025

Reproduction steps:

  1. Put Code 1 (below) in a Jupyter cell.
  2. Run the cell on a CUDA GPU.
  3. In another cell, paste the same code but change the loss to 'mse'.
  4. Run that cell on a CUDA GPU.

Expectation:
Keras runs the second cell without restarting Python.

Actual:
The second cell cannot run until Python is restarted.

System: Python 3.10.16, Keras 3.8, Torch 2.3.1, CUDA 12.4

I am building a GUI component that lets users build custom architectures. While doing some random tests, I found the following.

Code 1 below triggers a failed assertion with the torch backend, whereas TensorFlow completes it gracefully. More troublesome, the torch backend also fails to recover until Python is restarted, which is detrimental for interactive environments such as ipykernel and Python-based IDEs.

An equivalent implementation in plain Torch (Code 2) tends to fail as well, but Torch recovers, so it can be re-run without restarting the whole Python process.

Code 1:

```python
import os
os.environ["KERAS_BACKEND"] = "torch"
import keras
from keras.models import Sequential
from keras.layers import Input, Dense
import numpy as np

x_values = np.array([1, 2, 3, 4], dtype=np.float32)
y_values = np.array([0.50 + i * 0.50 for i in x_values], dtype=np.float32)

model = Sequential()

model.add(Input(shape=(1,)))

model.add(Dense(1, activation='relu'))

model.compile(optimizer='sgd', loss='binary_crossentropy')

model.fit(x_values, y_values, epochs=10, batch_size=1)

loss = model.evaluate(x_values, y_values)
print(f"Loss: {loss}")
```
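Before any tensor reaches the GPU, the targets Code 1 builds can be checked on the host. The sketch below is plain Python; `code1_targets` and `check_bce_targets` are hypothetical helpers (not Keras APIs), and the [0, 1] bound is the standard target domain for binary cross-entropy.

```python
def code1_targets(xs):
    """Reproduce Code 1's target formula: y = 0.50 + x * 0.50."""
    return [0.50 + x * 0.50 for x in xs]


def check_bce_targets(ys):
    """Hypothetical helper: reject targets outside [0, 1] before training.

    Raises ValueError on the host instead of letting an out-of-range
    value reach a CUDA kernel.
    """
    bad = [y for y in ys if not 0.0 <= y <= 1.0]
    if bad:
        raise ValueError(
            f"binary_crossentropy targets must lie in [0, 1]; got {bad}"
        )
    return ys


targets = code1_targets([1, 2, 3, 4])
print(targets)  # [1.0, 1.5, 2.0, 2.5]

try:
    check_bce_targets(targets)
except ValueError as err:
    print(err)
```

Three of the four targets (1.5, 2.0 and 2.5) fall outside the range, so a check like this could turn the device-side assert into an ordinary, recoverable Python exception.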

Code 2:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os

x_values = np.array([1, 2, 3, 4], dtype=np.float32).reshape(-1, 1)
y_values = np.array([0.50 + i * 0.50 for i in x_values], dtype=np.float32).reshape(-1, 1)

# As written, these tensors (and the model below) are created on the CPU;
# no .to("cuda") call is made.
x_train = torch.tensor(x_values)
y_train = torch.tensor(y_values)

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.Linear(1, 1)

    def forward(self, x):
        # relu keeps outputs >= 0 but does not cap them at 1
        x = torch.relu(self.layer1(x))
        return x

model = SimpleModel()

# BCELoss expects both predictions and targets to lie in [0, 1]
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Scale targets into [0, 1] by dividing by the largest target
max_y = torch.max(y_train)
if max_y > 0:
    y_train = y_train / max_y
else:
    raise ValueError("Maximum value of y_train is zero, cannot normalize")

epochs = 10
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}")

model.eval()
with torch.no_grad():
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    print(f"Final Loss: {loss.item()}")
```
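Code 2 rescales its targets by their maximum before training, which maps them into [0, 1]; Code 1 does not. The sketch below redoes that arithmetic in plain Python (the helper names are hypothetical). Rescaling only fixes the targets: `torch.relu` output is still unbounded above, and `nn.BCELoss` also requires its inputs to lie in [0, 1], which may be why Code 2 still fails at times. A sigmoid output is the usual way to keep predictions in range.

```python
import math


def normalize_by_max(ys):
    """Mirror Code 2's step: divide every target by the largest target."""
    max_y = max(ys)
    if max_y <= 0:
        raise ValueError("Maximum value of y_train is zero, cannot normalize")
    return [y / max_y for y in ys]


def relu(z):
    return max(0.0, z)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


targets = [0.50 + x * 0.50 for x in [1, 2, 3, 4]]  # [1.0, 1.5, 2.0, 2.5]
print(normalize_by_max(targets))  # [0.4, 0.6, 0.8, 1.0]

# relu can return values above 1 (an invalid BCE input); sigmoid cannot.
print(relu(3.0), sigmoid(3.0))
```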

Error output:

```
{
	"name": "RuntimeError",
	"message": "CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 15
     11 model = Sequential()
     13 model.add(Input(shape=(1,)))
---> 15 model.add(Dense(1, activation='relu'))
     17 model.compile(optimizer='sgd', loss='binary_crossentropy')
     19 model.fit(x_values, y_values, epochs=10, batch_size=1)

File ~/python3.10/site-packages/keras/src/models/sequential.py:122, in Sequential.add(self, layer, rebuild)
    120 self._layers.append(layer)
    121 if rebuild:
--> 122     self._maybe_rebuild()
    123 else:
    124     self.built = False

File ~/python3.10/site-packages/keras/src/models/sequential.py:141, in Sequential._maybe_rebuild(self)
    139 if isinstance(self._layers[0], InputLayer) and len(self._layers) > 1:
    140     input_shape = self._layers[0].batch_shape
--> 141     self.build(input_shape)
    142 elif hasattr(self._layers[0], \"input_shape\") and len(self._layers) > 1:
    143     # We can build the Sequential model if the first layer has the
    144     # `input_shape` property. This is most commonly found in Functional
    145     # model.
    146     input_shape = self._layers[0].input_shape

File ~/python3.10/site-packages/keras/src/layers/layer.py:228, in Layer.__new__.<locals>.build_wrapper(*args, **kwargs)
    226 with obj._open_name_scope():
    227     obj._path = current_path()
--> 228     original_build_method(*args, **kwargs)
    229 # Record build config.
    230 signature = inspect.signature(original_build_method)

File ~/python3.10/site-packages/keras/src/models/sequential.py:187, in Sequential.build(self, input_shape)
    185 for layer in self._layers[1:]:
    186     try:
--> 187         x = layer(x)
    188     except NotImplementedError:
    189         # Can happen if shape inference is not implemented.
    190         # TODO: consider reverting inbound nodes on layers processed.
    191         return

File ~/python3.10/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/python3.10/site-packages/torch/_dynamo/eval_frame.py:451, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    449 prior = set_eval_frame(callback)
    450 try:
--> 451     return fn(*args, **kwargs)
    452 finally:
    453     set_eval_frame(prior)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
"
}
```
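The error message suggests passing `CUDA_LAUNCH_BLOCKING=1` for debugging. The variable has to be in the environment before CUDA is initialized, so in a notebook one way to set it is at the very top of the first cell of a fresh kernel, before Keras or Torch is imported. This is a minimal sketch under that assumption; with synchronous launches the traceback should point at the call that actually failed rather than at an unrelated later call such as `model.add`.

```python
import os

# Both variables must be set before torch initializes CUDA, so run this
# before the first `import keras` / `import torch` in a fresh kernel.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # make kernel launches synchronous
os.environ["KERAS_BACKEND"] = "torch"

# import keras  # import only after the environment is configured
```

Synchronous launches slow training down, so this is a debugging setting rather than one to leave on.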
@abheesht17
Collaborator

The error occurs because Torch expects targets to lie in the range [0, 1], which makes sense for binary cross-entropy. What exactly is your use case, i.e., why are you setting targets greater than 1?

As for the Python restarting issue, I am unable to replicate it on Colab. Here is the notebook: https://colab.research.google.com/gist/abheesht17/d360499d826b87ba12449b362aee398d/keras-issue-20920.ipynb.
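To illustrate why binary cross-entropy needs targets in [0, 1], here is the per-sample formula, BCE(y, p) = -[y·log p + (1 - y)·log(1 - p)], in plain Python. With y > 1 the (1 - y) weight turns negative, so the loss drops below zero and keeps decreasing as p approaches 1; the objective no longer has a sensible minimum. This is a hand-written sketch, not the Keras or Torch implementation (both add numerical safeguards such as clipping).

```python
import math


def bce(y, p):
    """Per-sample binary cross-entropy for target y and predicted probability p."""
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


# 1.0 and 2.5 are the smallest and largest targets in Code 1.
for y in (1.0, 2.5):
    print(y, [round(bce(y, p), 3) for p in (0.5, 0.9, 0.99)])
```

For y = 1.0 the loss shrinks toward 0 as p approaches 1, as expected; for y = 2.5 it falls from about 0.69 at p = 0.5 to about -6.9 at p = 0.99.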

@abheesht17 abheesht17 added stat:awaiting response from contributor and removed keras-team-review-pending Pending review by a Keras team member. labels Feb 20, 2025
@jobs-git
Author

jobs-git commented Feb 21, 2025

It should be tried on CUDA, followed immediately by another CUDA computation. I added reproduction steps.
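For a GUI that runs user-built architectures, one possible workaround while this is open is to run each training job in a fresh Python process, so a fatal CUDA error takes down only that process and not the interactive kernel. This assumes, as is commonly observed with PyTorch, that a device-side assert leaves the CUDA context unusable for the rest of the process, and that a new process gets a new context. `run_isolated` is a hypothetical helper built only on the standard library's `subprocess` module; the failing job below merely simulates the error with a plain `RuntimeError`.

```python
import subprocess
import sys


def run_isolated(source: str, timeout: float = 600.0) -> subprocess.CompletedProcess:
    """Run a Python snippet in a separate interpreter (hypothetical helper).

    Each child is a new process and, if it uses CUDA, creates its own
    CUDA context. Output is captured as text, e.g. for display in a GUI.
    """
    return subprocess.run(
        [sys.executable, "-c", source],
        capture_output=True,
        text=True,
        timeout=timeout,
    )


# Simulated failing job: the exception stays in the child process.
failed = run_isolated("raise RuntimeError('CUDA error: device-side assert triggered')")
print(failed.returncode != 0, "RuntimeError" in failed.stderr)

# The next job still runs normally, in another fresh process.
ok = run_isolated("print(21 * 2)")
print(ok.returncode, ok.stdout.strip())
```

In a real job, `source` would be the generated Keras script (setting `KERAS_BACKEND` inside it); the trade-off is paying interpreter start-up and import time on every run.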

@jobs-git jobs-git changed the title Torch backend fails to recover after assertion failed Torch-cuda backend fails to recover after assertion failed Feb 21, 2025