diff --git a/pyproject.toml b/pyproject.toml index 4b729791e..4e09fe997 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,16 @@ requires-python = ">=3.10,<3.13" dynamic = ["version"] dependencies = ["pybind11", "transformers", "zmq", "xxhash", "fastapi", "psutil", "protobuf", "uvicorn", "aiohttp"] +[project.optional-dependencies] +dev = ["pytest>=7.0.0", "pytest-cov>=4.0.0"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = "-v --tb=short" + [project.urls] Homepage = "https://github.com/ROCm/ATOM" Repository = "https://github.com/ROCm/ATOM" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..0538f32cb --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# ATOM Unit Tests + diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 000000000..91c745de1 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,53 @@ +"""Unit tests for config classes.""" + +import pytest +from atom.config import CompilationConfig, CompilationLevel, CUDAGraphMode + + +class TestCUDAGraphMode: + + def test_mode_values(self): + assert CUDAGraphMode.NONE.value == 0 + assert CUDAGraphMode.PIECEWISE.value == 1 + assert CUDAGraphMode.FULL.value == 2 + assert CUDAGraphMode.FULL_DECODE_ONLY.value == (2, 0) + assert CUDAGraphMode.FULL_AND_PIECEWISE.value == (2, 1) + + def test_separate_routine(self): + assert CUDAGraphMode.NONE.separate_routine() is False + assert CUDAGraphMode.FULL_DECODE_ONLY.separate_routine() is True + + def test_decode_and_mixed_mode(self): + assert CUDAGraphMode.FULL_DECODE_ONLY.decode_mode() == CUDAGraphMode.FULL + assert CUDAGraphMode.FULL_DECODE_ONLY.mixed_mode() == CUDAGraphMode.NONE + assert CUDAGraphMode.FULL_AND_PIECEWISE.mixed_mode() == CUDAGraphMode.PIECEWISE + + def test_has_full_cudagraphs(self): + assert CUDAGraphMode.NONE.has_full_cudagraphs() is False + assert CUDAGraphMode.FULL.has_full_cudagraphs() is True + + def test_requires_piecewise_compilation(self): + assert CUDAGraphMode.PIECEWISE.requires_piecewise_compilation() is True + assert CUDAGraphMode.FULL.requires_piecewise_compilation() is False + + +class TestCompilationConfig: + + def test_default_values(self): + config = CompilationConfig() + assert config.level == 0 + assert config.use_cudagraph is True + assert config.cuda_graph_sizes == [512] + + def test_invalid_level_raises_error(self): + with pytest.raises(ValueError, match="level must in 0-3"): + CompilationConfig(level=5) + + def test_compute_hash_consistency(self): + config = CompilationConfig(level=1) + assert config.compute_hash() == config.compute_hash() + + def test_set_splitting_ops_for_v1(self): + config = CompilationConfig(level=CompilationLevel.PIECEWISE) + config.set_splitting_ops_for_v1() + assert "aiter.unified_attention_with_output" in config.splitting_ops diff --git a/tests/test_sampling_params.py b/tests/test_sampling_params.py new file mode 100644 index 000000000..081edca7b --- /dev/null +++ b/tests/test_sampling_params.py @@ -0,0 +1,31 @@ +"""Unit tests for SamplingParams class.""" + +import pytest +from atom.sampling_params import SamplingParams + + +class TestSamplingParams: + + def test_default_values(self): + params = SamplingParams() + assert params.temperature == 1.0 + assert params.max_tokens == 64 + assert params.ignore_eos is False + assert params.stop_strings is None + + def test_custom_values(self): + params = SamplingParams( + temperature=0.5, + max_tokens=256, + ignore_eos=True, + stop_strings=["<|end|>", "STOP"] + ) + assert params.temperature == 0.5 + assert params.max_tokens == 256 + assert params.ignore_eos is True + assert params.stop_strings == ["<|end|>", "STOP"] + + def test_equality(self): + params1 = SamplingParams(temperature=0.8, max_tokens=100) + params2 = SamplingParams(temperature=0.8, max_tokens=100) + assert params1 == params2 diff --git a/tests/test_simple.py b/tests/test_simple.py new file mode 100644 index 000000000..114cbcb23 --- /dev/null +++ b/tests/test_simple.py @@ -0,0 +1,8 @@ +"""Simple sanity test.""" + + +def test_import(): + """Test that core modules can be imported.""" + from atom.sampling_params import SamplingParams + assert SamplingParams is not None +