Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ The interface will automatically open in your default web browser, typically at
| `--source` | Input source (predict only) | **Required** | `image.jpg` |
| `--im-size` | Input image size | 640 | Any positive integer |
| `--batch-size` | Batch size | 16 | Powers of 2 recommended |
| `--device` | Compute device | `cuda` | `cuda`, `cpu`, `mps` |
| `--device` | Compute device | `cuda` | `cuda`, `cpu` |
| `--workers` | Data loading workers | 4 | 0-16 recommended |
| `--output-dir` | Output directory | Auto-generated | Any valid path |

Expand Down
25 changes: 16 additions & 9 deletions focoos/infer/infer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
from pathlib import Path
from time import perf_counter
from typing import Optional, Tuple, Union
from typing import Literal, Optional, Tuple, Union

import numpy as np
import supervision as sv
Expand All @@ -41,7 +41,7 @@
)
from focoos.processor.processor_manager import ProcessorManager
from focoos.utils.logger import get_logger
from focoos.utils.system import get_cpu_name, get_device_name
from focoos.utils.system import get_cpu_name, get_device_name, get_device_type
from focoos.utils.vision import (
annotate_frame,
image_loader,
Expand All @@ -55,6 +55,7 @@ def __init__(
self,
model_dir: Union[str, Path],
runtime_type: Optional[RuntimeType] = None,
device: Literal["cuda", "cpu", "auto"] = "auto",
):
"""
Initialize a LocalModel instance.
Expand Down Expand Up @@ -90,7 +91,12 @@ def __init__(
# Determine runtime type and model format
runtime_type = runtime_type or FOCOOS_CONFIG.runtime_type
extension = ModelExtension.from_runtime_type(runtime_type)

if device == "auto":
self.device = get_device_type()
elif runtime_type == RuntimeType.ONNX_CPU:
self.device = "cpu"
else:
self.device = device
# Set model directory and path
self.model_dir: Union[str, Path] = model_dir
self.model_path = os.path.join(model_dir, f"model.{extension.value}")
Expand All @@ -111,7 +117,7 @@ def __init__(
model_config = ConfigManager.from_dict(self.model_info.model_family, self.model_info.config)
self.processor = ProcessorManager.get_processor(
self.model_info.model_family, model_config, self.model_info.im_size
)
).eval()
except Exception as e:
logger.error(f"Error creating model config: {e}")
raise e
Expand All @@ -123,10 +129,11 @@ def __init__(

# Load runtime for inference
self.runtime: BaseRuntime = load_runtime(
runtime_type,
str(self.model_path),
self.model_info,
FOCOOS_CONFIG.warmup_iter,
runtime_type=runtime_type,
model_path=str(self.model_path),
model_info=self.model_info,
warmup_iter=FOCOOS_CONFIG.warmup_iter,
device=self.device,
)

def _read_model_info(self) -> ModelInfo:
Expand Down Expand Up @@ -175,7 +182,7 @@ def infer(
t0 = perf_counter()
im = image_loader(image)
t1 = perf_counter()
tensors, _ = self.processor.preprocess(inputs=im, device="cuda")
tensors, _ = self.processor.preprocess(inputs=im, device=self.device)
# logger.debug(f"Input image size: {im.shape}")
t2 = perf_counter()

Expand Down
5 changes: 4 additions & 1 deletion focoos/infer/runtimes/load_runtime.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

from focoos.infer.runtimes.base import BaseRuntime
from focoos.ports import ModelInfo, OnnxRuntimeOpts, RuntimeType, TorchscriptRuntimeOpts
from focoos.utils.logger import get_logger
Expand Down Expand Up @@ -25,6 +27,7 @@ def load_runtime(
model_path: str,
model_info: ModelInfo,
warmup_iter: int = 50,
device: Literal["cuda", "cpu", "auto"] = "auto",
) -> BaseRuntime:
"""
Creates and returns a runtime instance based on the specified runtime type.
Expand Down Expand Up @@ -57,7 +60,7 @@ def load_runtime(
from focoos.infer.runtimes.torchscript import TorchscriptRuntime

opts = TorchscriptRuntimeOpts(warmup_iter=warmup_iter)
return TorchscriptRuntime(model_path=model_path, opts=opts, model_info=model_info)
return TorchscriptRuntime(model_path=model_path, opts=opts, model_info=model_info, device=device)
else:
if not ORT_AVAILABLE:
logger.error(
Expand Down
12 changes: 8 additions & 4 deletions focoos/infer/runtimes/torchscript.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from time import perf_counter
from typing import Tuple, Union
from typing import Literal, Tuple, Union

import numpy as np
import torch

from focoos.infer.runtimes.base import BaseRuntime
from focoos.ports import LatencyMetrics, ModelInfo, Task, TorchscriptRuntimeOpts
from focoos.utils.logger import get_logger
from focoos.utils.system import get_cpu_name, get_device_name
from focoos.utils.system import get_cpu_name, get_device_name, get_device_type

logger = get_logger("TorchscriptRuntime")

Expand All @@ -32,8 +32,12 @@ def __init__(
model_path: str,
opts: TorchscriptRuntimeOpts,
model_info: ModelInfo,
device: Literal["cuda", "cpu", "auto"] = "auto",
):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == "auto":
self.device = torch.device(get_device_type())
else:
self.device = torch.device(device)
logger.info(f"🔧 Device: {self.device}")
self.opts = opts
self.model_info = model_info
Expand All @@ -49,7 +53,7 @@ def __init__(
)
logger.info(f"⏱️ Warming up model {self.model_info.name} on {self.device}, size: {size}x{size}..")
with torch.no_grad():
np_image = torch.rand(1, 3, size, size, device=self.device)
np_image = torch.rand(1, 3, size, size).to(self.device)
for _ in range(self.opts.warmup_iter):
self.model(np_image)

Expand Down
13 changes: 8 additions & 5 deletions focoos/models/focoos_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from focoos.utils.distributed.dist import launch
from focoos.utils.env import TORCH_VERSION
from focoos.utils.logger import get_logger
from focoos.utils.system import get_cpu_name, get_device_name, get_focoos_version, get_system_info
from focoos.utils.system import get_cpu_name, get_device_name, get_device_type, get_focoos_version, get_system_info
from focoos.utils.vision import annotate_frame, image_loader

logger = get_logger("FocoosModel")
Expand Down Expand Up @@ -393,7 +393,7 @@ def export(
runtime_type: RuntimeType = RuntimeType.TORCHSCRIPT_32,
onnx_opset: int = 17,
out_dir: Optional[str] = None,
device: Literal["cuda", "cpu"] = "cuda",
device: Literal["cuda", "cpu", "auto"] = "auto",
overwrite: bool = True,
image_size: Optional[Union[int, Tuple[int, int]]] = None,
) -> InferModel:
Expand All @@ -416,9 +416,12 @@ def export(
Raises:
ValueError: If unsupported PyTorch version or export format.
"""
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
logger.warning("CUDA is not available. Using CPU for export.")
if device == "auto":
device = get_device_type() # type: ignore
else:
device = device

logger.info(f"🔧 Export Device: {device}")
if out_dir is None:
out_dir = os.path.join(MODELS_DIR, self.model_info.ref or self.model_info.name)

Expand Down
11 changes: 10 additions & 1 deletion focoos/utils/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import time
import zipfile
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Literal, Optional, Union

import torch

from focoos.ports import GPUInfo
from focoos.utils.distributed import comm
Expand Down Expand Up @@ -413,3 +415,10 @@ def get_device_name() -> str:
else:
cpu_name = get_cpu_name()
return cpu_name if cpu_name is not None else "CPU"


def get_device_type() -> Literal["cuda", "cpu"]:
if torch.cuda.is_available():
return "cuda"
else:
return "cpu"
18 changes: 13 additions & 5 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,19 @@ def test_load_runtime(mocker: MockerFixture, tmp_path, runtime_type, expected_op

# assertions
assert runtime is not None
mock_runtime_class.assert_called_once_with(
model_path,
expected_opts,
mock_model_metadata,
)
if runtime_type == RuntimeType.TORCHSCRIPT_32:
mock_runtime_class.assert_called_once_with(
model_path=model_path,
opts=expected_opts,
model_info=mock_model_metadata,
device="auto",
)
else:
mock_runtime_class.assert_called_once_with(
model_path,
expected_opts,
mock_model_metadata,
)


def test_load_unavailable_runtime(mocker: MockerFixture):
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading