diff --git a/gradio-inference.py b/gradio-inference.py index b157fd8..cf820ab 100644 --- a/gradio-inference.py +++ b/gradio-inference.py @@ -80,7 +80,8 @@ def load_model(model_type: str = "efficientvit") -> nn.Module: logging.warning( f"Default model path '{model_path}' not found. Using untrained model." ) - # Set model to evaluation mode + # Move model to device and set to evaluation mode + model.to(device) model.eval() return model @@ -182,7 +183,6 @@ def predict_image(image: np.ndarray, model_type: str) -> tuple[dict, np.ndarray] # Load model model = load_model(model_type) - model.to(device) # Preprocess image logging.info("Preprocessing image...")