diff --git a/.isort.cfg b/.isort.cfg index 0e960e62..cb9e18d5 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,5 +1,5 @@ [settings] -known_third_party = cloudpickle,cv2,einops,numpy,omegaconf,setuptools,timm,torch,torchvision,yaml +known_third_party = cloudpickle,cv2,einops,numpy,omegaconf,pytest,setuptools,timm,torch,torchvision,yaml multi_line_output=3 include_trailing_comma=True force_grid_wrap=0 diff --git a/tests/root_cfg.py b/tests/root_cfg.py index 139d5553..43b83505 100644 --- a/tests/root_cfg.py +++ b/tests/root_cfg.py @@ -9,5 +9,4 @@ # modification above won't affect future imports from .dir1.dir1_b import dir1b_dict, dir1b_str - lazyobj = L(count)(x=dir1a_str, y=dir1b_str) diff --git a/tests/test_config.py b/tests/test_config.py index b50fcb1c..1022d906 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,15 +1,22 @@ -import pytest -import torch - -from vformer.config import LazyCall, instantiate, get_config -from vformer.models import PVTSegmentation, SwinTransformer, VanillaViT, ViViTModel2 import os import tempfile from itertools import count -from vformer.config import LazyConfig, LazyCall as L +import pytest +import torch from omegaconf import DictConfig +from vformer.config import LazyCall +from vformer.config import LazyCall as L +from vformer.config import LazyConfig, instantiate +from vformer.models import ( + PVTSegmentation, + SwinTransformer, + VanillaViT, + ViViTModel2, + classification, +) + def test_lazy(): # classification models @@ -55,23 +62,30 @@ def test_lazy(): assert vivit(rand_vdo_tensor).shape == (32, 10) +def test_raise_errors(): + a = "strings" + with pytest.raises(TypeError): + obj = LazyConfig(a) # only callable objects -root_filename = os.path.join(os.path.dirname(__file__), "root_cfg.py") def test_load(): + root_filename = os.path.join(os.path.dirname(__file__), "root_cfg.py") cfg = LazyConfig.load(root_filename) assert cfg.dir1a_dict.a == "modified" - assert cfg.dir1b_dict.a == 1 + assert cfg.dir1b_dict.a == 1 assert cfg.lazyobj.x == "base_a_1" cfg.lazyobj.x = "new_x" # reload cfg = LazyConfig.load(root_filename) - assert cfg.lazyobj.x == "base_a_1" + assert cfg.lazyobj.x == "base_a_1" + def test_save_load(): + root_filename = os.path.join(os.path.dirname(__file__), "root_cfg.py") + cfg = LazyConfig.load(root_filename) with tempfile.TemporaryDirectory(prefix="vformer") as d: fname = os.path.join(d, "test_config.yaml") @@ -85,28 +99,42 @@ def test_save_load(): # the rest are equal assert cfg == cfg2 + def test_failed_save(): cfg = DictConfig({"x": lambda: 3}, flags={"allow_objects": True}) with tempfile.TemporaryDirectory(prefix="vformer") as d: fname = os.path.join(d, "test_config.yaml") LazyConfig.save(cfg, fname) assert os.path.exists(fname) == True - assert os.path.exists(fname + ".pkl") == True + assert os.path.exists(fname + ".pkl") == True + def test_overrides(): + root_filename = os.path.join(os.path.dirname(__file__), "root_cfg.py") + cfg = LazyConfig.load(root_filename) LazyConfig.apply_overrides(cfg, ["lazyobj.x=123", 'dir1b_dict.a="123"']) assert cfg.dir1b_dict.a == "123" assert cfg.lazyobj.x == 123 + def test_invalid_overrides(): + root_filename = os.path.join(os.path.dirname(__file__), "root_cfg.py") + cfg = LazyConfig.load(root_filename) with pytest.raises(KeyError): LazyConfig.apply_overrides(cfg, ["lazyobj.x.xxx=123"]) + def test_to_py(): + root_filename = os.path.join(os.path.dirname(__file__), "root_cfg.py") + cfg = LazyConfig.load(root_filename) - cfg.lazyobj.x = {"a": 1, "b": 2, "c": L(count)(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]})} + cfg.lazyobj.x = { + "a": 1, + "b": 2, + "c": L(count)(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]}), + } cfg.list = ["a", 1, "b", 3.2] py_str = LazyConfig.to_py(cfg) expected = """cfg.dir1a_dict.a = "modified" @@ -114,13 +142,18 @@ def test_to_py(): cfg.dir1b_dict.a = 1 cfg.dir1b_dict.b = 2 cfg.lazyobj = itertools.count( -x={ - "a": 1, - "b": 2, - "c": itertools.count(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]}), -}, -y="base_a_1_from_b", + x={ + "a": 1, + "b": 2, + "c": itertools.count(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]}), + }, + y="base_a_1_from_b", ) cfg.list = ["a", 1, "b", 3.2] """ assert py_str == expected + + root_filename = os.path.join(os.path.dirname(__file__), "testing.yaml") + cfg = LazyConfig.load(root_filename) + obj = LazyConfig.to_py(cfg) + print(obj) diff --git a/tests/test_models.py b/tests/test_models.py index 13854f5a..a858f39d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -11,14 +11,14 @@ def test_VanillaViT(): - model = MODEL_REGISTRY.get("trial")( + model = MODEL_REGISTRY.get("VanillaViT")( img_size=256, patch_size=32, n_classes=10, in_channels=3 ) out = model(img_3channels_256) assert out.shape == (2, 10) del model - model = MODEL_REGISTRY.get("trial")( + model = MODEL_REGISTRY.get("VanillaViT")( img_size=256, patch_size=32, n_classes=10, diff --git a/tests/testing.yaml b/tests/testing.yaml new file mode 100644 index 00000000..d4fd8332 --- /dev/null +++ b/tests/testing.yaml @@ -0,0 +1,2 @@ +x : 3 +y : 4 diff --git a/vformer/config/config_utils.py b/vformer/config/config_utils.py index 56a2c959..fa1a0913 100644 --- a/vformer/config/config_utils.py +++ b/vformer/config/config_utils.py @@ -5,6 +5,7 @@ import uuid from collections import abc from typing import Any + from omegaconf import DictConfig, ListConfig diff --git a/vformer/config/lazy.py b/vformer/config/lazy.py index cbebdc32..8a85e8c4 100644 --- a/vformer/config/lazy.py +++ b/vformer/config/lazy.py @@ -151,7 +151,9 @@ def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): has_keys = keys is not None filename = filename.replace("/./", "/") # redundant if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]: - raise ValueError(f"Config file {filename} has to be a python file.") + raise ValueError( + f"Config file {filename} is not supported, supported file types are : [`.py`, `.yaml`]." + ) if filename.endswith(".py"): _validate_py_syntax(filename) @@ -169,8 +171,14 @@ def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): exec(compile(content, filename, "exec"), module_namespace) ret = module_namespace + elif filename.endswith(".yaml"): + + with open(filename) as f: + obj = yaml.unsafe_load(f) + ret = OmegaConf.create(obj, flags={"allow_objects": True}) + else: - raise NotImplementedError("Only python files supported for now. ") + raise NotImplementedError("Only python and yaml files supported for now. ") if has_keys: if isinstance(keys, str): @@ -364,8 +372,7 @@ def get_config_file(config_path): Returns: str: the real path to the config file. """ - cfg_file = open( os.path.join("vformer","configs", config_path) - ) + cfg_file = open(os.path.join("vformer", "configs", config_path)) if not os.path.exists(cfg_file): raise RuntimeError("{} not available in configs!".format(config_path)) return cfg_file diff --git a/vformer/models/dense/dpt.py b/vformer/models/dense/dpt.py index 71fb6ad5..32365c84 100644 --- a/vformer/models/dense/dpt.py +++ b/vformer/models/dense/dpt.py @@ -114,7 +114,7 @@ def __init__( 384, 768, ) - self.model = MODEL_REGISTRY.get("trial")( + self.model = MODEL_REGISTRY.get("VanillaViT")( img_size=img_size, patch_size=16, embedding_dim=768, @@ -131,7 +131,7 @@ def __init__( elif backbone == "vitl16": scratch_in_features = (256, 512, 1024, 1024) - self.model = MODEL_REGISTRY.get("trial")( + self.model = MODEL_REGISTRY.get("VanillaViT")( img_size=img_size, patch_size=16, embedding_dim=1024, @@ -148,7 +148,7 @@ def __init__( elif backbone == "vit_tiny": scratch_in_features = (48, 96, 144, 192) - self.model = MODEL_REGISTRY.get("trial")( + self.model = MODEL_REGISTRY.get("VanillaViT")( img_size=img_size, patch_size=16, embedding_dim=192,