Skip to content

Commit

Permalink
Fix non config field check for inherited Configs (#603)
Browse files Browse the repository at this point in the history
Currently, the code below results in failure

```python
from axlearn.common.config import ConfigBase, config_class

@config_class
class A(ConfigBase):
    a: int = 1

@config_class
class B(A):
    a = 2

assert B().a == 2
```

The existing check for missing type annotations relies on attrs.field_dict
during __init__, which would not preserve the information for newly defined
attributes without typehints that were dropped during attrs.define. We are
adding an explicit typehint annotation check during the instantiation of
the config.
  • Loading branch information
soundway committed Jul 24, 2024
1 parent 1a262c1 commit 89c6f75
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 1 deletion.
20 changes: 20 additions & 0 deletions axlearn/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,26 @@ def config_class(cls: Type[T], **kwargs) -> Type[T]:
if not issubclass(cls, ConfigBase):
raise InvalidConfigClassError(f"A config class must be a subclass of ConfigBase: {cls}")

# We check that all attributes are properly type annotated. The danger of not doing this check
# is that the default values of any child class attributes without type annotations will be
# silently ignored, which could cause completely unexpected behaviors.
annotations = cls.__dict__.get("__annotations__", {})
for key, val in cls.__dict__.items():
if key.startswith("__") or key in annotations:
continue
if inspect.isfunction(val) and any(
f"{base_cls.__qualname__}.{key}" == val.__qualname__ for base_cls in inspect.getmro(cls)
):
# When the value is a function, we need to check if the key is part of the config or if
# method belongs to the class. To do so, we check if this function is defined within
# this class or any of its parent classes. A method defined in a class should have the
# joint of the class's qualname and the key as its qualname.
continue
raise NonConfigFieldError(
f"Non-config attribute is not supported: {cls.__qualname__}.{key}. "
"Please make sure all config attributes are annotated with typehints."
)

attr_cls = attr.define(maybe_cls=cls, **_config_class_kwargs(), **kwargs)
# Pytype seems to infer attr_cls as a callable.
return _wrap_config_attr_cls(attr_cls) # pytype: disable=wrong-arg-types
Expand Down
90 changes: 89 additions & 1 deletion axlearn/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,95 @@ def test_non_config_field(self):
class Config(ConfigBase):
foo = 10

_ = Config()
del Config

def test_non_config_field_with_inheritance(self):
"""Repeat test_non_config_field test but with inherited class.
We need to make sure that overriding attributes (including functions) without typehints
result in failures.
"""

@config_class
class PConfig(ConfigBase):
foo: int = 10

# Check overriding attribute without typehints result in failure.
with self.assertRaisesRegex(config.NonConfigFieldError, "foo"):

@config_class
class CConfig1(PConfig):
foo = 10

del CConfig1

# Check that no new __annotation__ doesn't result in failures.
@config_class
class CConfig2(PConfig):
pass

_ = CConfig2()

# Check that overriding with a normal function without typehints results in failure.
with self.assertRaisesRegex(config.NonConfigFieldError, "foo"):

def f():
pass

@config_class
class CConfig3(PConfig):
foo = f

del CConfig3

# Check that overriding with a fake class instance method without typehints raises an error.
with self.assertRaisesRegex(config.NonConfigFieldError, "foo"):

def fake_foo(self):
print(self)

@config_class
class CConfig4(PConfig):
foo = fake_foo

del CConfig4

# Check that callable classes with `self` are caught.
with self.assertRaisesRegex(config.NonConfigFieldError, "foo"):

@dataclasses.dataclass
class CallableClass:
def my_fn(self):
del self

@config_class
class CConfig5(PConfig):
foo = CallableClass()

del CConfig5

# Use lambda defined in the class to fake a class instance method.
with self.assertRaisesRegex(config.NonConfigFieldError, "foo"):

@config_class
class CConfig6(PConfig):
foo = lambda self: self

del CConfig6

# Check that overriding existing class instance methods are fine.
@config_class
class CConfig7(ConfigBase):
def set(self, **kwargs):
pass

_ = CConfig7()

# Check that attributes set after-the-fact are still caught.
with self.assertRaisesRegex(config.NonConfigFieldError, "other_field"):
CConfig7.other_field = 1

_ = CConfig7()

def test_definition(self):
@config_class
Expand Down

0 comments on commit 89c6f75

Please sign in to comment.