Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ requires-python = ">=3.10,<3.13"
dynamic = ["version"]
dependencies = ["pybind11", "transformers", "zmq", "xxhash", "fastapi", "psutil", "protobuf", "uvicorn", "aiohttp"]

[project.optional-dependencies]
dev = ["pytest>=7.0.0", "pytest-cov>=4.0.0"]

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --tb=short"

[project.urls]
Homepage = "https://github.com/ROCm/ATOM"
Repository = "https://github.com/ROCm/ATOM"
Expand Down
2 changes: 2 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# ATOM Unit Tests

53 changes: 53 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Unit tests for config classes."""

import pytest
from atom.config import CompilationConfig, CompilationLevel, CUDAGraphMode


class TestCUDAGraphMode:

def test_mode_values(self):
assert CUDAGraphMode.NONE.value == 0
assert CUDAGraphMode.PIECEWISE.value == 1
assert CUDAGraphMode.FULL.value == 2
assert CUDAGraphMode.FULL_DECODE_ONLY.value == (2, 0)
assert CUDAGraphMode.FULL_AND_PIECEWISE.value == (2, 1)

def test_separate_routine(self):
assert CUDAGraphMode.NONE.separate_routine() is False
assert CUDAGraphMode.FULL_DECODE_ONLY.separate_routine() is True

def test_decode_and_mixed_mode(self):
assert CUDAGraphMode.FULL_DECODE_ONLY.decode_mode() == CUDAGraphMode.FULL
assert CUDAGraphMode.FULL_DECODE_ONLY.mixed_mode() == CUDAGraphMode.NONE
assert CUDAGraphMode.FULL_AND_PIECEWISE.mixed_mode() == CUDAGraphMode.PIECEWISE

def test_has_full_cudagraphs(self):
assert CUDAGraphMode.NONE.has_full_cudagraphs() is False
assert CUDAGraphMode.FULL.has_full_cudagraphs() is True

def test_requires_piecewise_compilation(self):
assert CUDAGraphMode.PIECEWISE.requires_piecewise_compilation() is True
assert CUDAGraphMode.FULL.requires_piecewise_compilation() is False


class TestCompilationConfig:

def test_default_values(self):
config = CompilationConfig()
assert config.level == 0
assert config.use_cudagraph is True
assert config.cuda_graph_sizes == [512]

def test_invalid_level_raises_error(self):
with pytest.raises(ValueError, match="level must in 0-3"):
CompilationConfig(level=5)

def test_compute_hash_consistency(self):
config = CompilationConfig(level=1)
assert config.compute_hash() == config.compute_hash()

def test_set_splitting_ops_for_v1(self):
config = CompilationConfig(level=CompilationLevel.PIECEWISE)
config.set_splitting_ops_for_v1()
assert "aiter.unified_attention_with_output" in config.splitting_ops
31 changes: 31 additions & 0 deletions tests/test_sampling_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Unit tests for SamplingParams class."""

import pytest
from atom.sampling_params import SamplingParams


class TestSamplingParams:

def test_default_values(self):
params = SamplingParams()
assert params.temperature == 1.0
assert params.max_tokens == 64
assert params.ignore_eos is False
assert params.stop_strings is None

def test_custom_values(self):
params = SamplingParams(
temperature=0.5,
max_tokens=256,
ignore_eos=True,
stop_strings=["<|end|>", "STOP"]
)
assert params.temperature == 0.5
assert params.max_tokens == 256
assert params.ignore_eos is True
assert params.stop_strings == ["<|end|>", "STOP"]

def test_equality(self):
params1 = SamplingParams(temperature=0.8, max_tokens=100)
params2 = SamplingParams(temperature=0.8, max_tokens=100)
assert params1 == params2
8 changes: 8 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Simple sanity test."""


def test_import():
"""Test that core modules can be imported."""
from atom.sampling_params import SamplingParams
assert SamplingParams is not None