From b69cd802b339bb82cc1590092dcfc68ce3633e2f Mon Sep 17 00:00:00 2001 From: abhash-er Date: Fri, 7 Feb 2025 14:22:12 +0100 Subject: [PATCH] test(test_benchmarks): fix tests for nb201 --- tests/test_benchmarks.py | 48 +++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 3ba3354e..d3f3cd9a 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -126,55 +126,57 @@ def test_nb201_benchmark(self) -> None: from confopt.benchmark import NB201Benchmark api = NB201Benchmark() + query_result = api.query(nb201_genotype) # check cifar 10 - query_result = api.query(nb201_genotype, dataset="cifar10") + dataset = "cifar10" train_result = 99.78 test_result = 92.32 - assert query_result["benchmark/train_top1"] == train_result - assert query_result["benchmark/test_top1"] == test_result + assert query_result[f"benchmark/{dataset}/train_top1"] == train_result + assert query_result[f"benchmark/{dataset}/test_top1"] == test_result # check cifar100 - query_result = api.query(nb201_genotype, dataset="cifar100") + dataset = "cifar100" train_result = 91.19 valid_result = 67.7 test_result = 67.94 - assert query_result["benchmark/train_top1"] == train_result - assert query_result["benchmark/valid_top1"] == valid_result - assert query_result["benchmark/test_top1"] == test_result + assert query_result[f"benchmark/{dataset}/train_top1"] == train_result + assert query_result[f"benchmark/{dataset}/valid_top1"] == valid_result + assert query_result[f"benchmark/{dataset}/test_top1"] == test_result # check imagenet - query_result = api.query(nb201_genotype, dataset="imagenet16") + dataset="imagenet16" train_result = 46.84 valid_result = 41.0 test_result = 41.47 - assert query_result["benchmark/train_top1"] == train_result - assert query_result["benchmark/valid_top1"] == valid_result - assert query_result["benchmark/test_top1"] == test_result + assert query_result[f"benchmark/{dataset}/train_top1"] == train_result + assert query_result[f"benchmark/{dataset}/valid_top1"] == valid_result + assert query_result[f"benchmark/{dataset}/test_top1"] == test_result @pytest.mark.benchmark() # type: ignore def test_nb201_benchmark_fail(self) -> None: from confopt.benchmark import NB201Benchmark api = NB201Benchmark() + query_result = api.query(nb201_genotype_fail) # check cifar 10 - query_result = api.query(nb201_genotype_fail, dataset="cifar10") - assert query_result["benchmark/train_top1"] == 10.0 - assert query_result["benchmark/valid_top1"] == 0.0 - assert query_result["benchmark/test_top1"] == 10.0 + dataset = "cifar10" + assert query_result[f"benchmark/{dataset}/train_top1"] == 10.0 + assert query_result[f"benchmark/{dataset}/valid_top1"] == 0.0 + assert query_result[f"benchmark/{dataset}/test_top1"] == 10.0 # check cifar100 - query_result = api.query(nb201_genotype_fail, dataset="cifar100") - assert query_result["benchmark/train_top1"] == 1.0 - assert query_result["benchmark/valid_top1"] == 1.0 - assert query_result["benchmark/test_top1"] == 1.0 + dataset="cifar100" + assert query_result[f"benchmark/{dataset}/train_top1"] == 1.0 + assert query_result[f"benchmark/{dataset}/valid_top1"] == 1.0 + assert query_result[f"benchmark/{dataset}/test_top1"] == 1.0 # check imagenet - query_result = api.query(nb201_genotype_fail, dataset="imagenet16") - assert query_result["benchmark/train_top1"] == 0.86 - assert query_result["benchmark/valid_top1"] == 0.83 - assert query_result["benchmark/test_top1"] == 0.83 + dataset="imagenet16" + assert query_result[f"benchmark/{dataset}/train_top1"] == 0.86 + assert query_result[f"benchmark/{dataset}/valid_top1"] == 0.83 + assert query_result[f"benchmark/{dataset}/test_top1"] == 0.83 @pytest.mark.benchmark() # type: ignore def test_nb301_benchmark(self) -> None: