diff --git a/curated_transformers/layers/activations.py b/curated_transformers/layers/activations.py
index d3503133..4867a77e 100644
--- a/curated_transformers/layers/activations.py
+++ b/curated_transformers/layers/activations.py
@@ -1,5 +1,5 @@
 import math
-from enum import Enum, EnumMeta
+from enum import Enum
 from typing import Type
 
 import torch
@@ -7,46 +7,7 @@ from torch.nn import Module
 
 
-class _ActivationMeta(EnumMeta):
-    """
-    ``Enum`` metaclass to override the class ``__call__`` method with a more
-    fine-grained exception for unknown activation functions.
-    """
-
-    def __call__(
-        cls,
-        value,
-        names=None,
-        *,
-        module=None,
-        qualname=None,
-        type=None,
-        start=1,
-    ):
-        # Wrap superclass __call__ to give a nicer error message when
-        # an unknown activation is used.
-        if names is None:
-            try:
-                return EnumMeta.__call__(
-                    cls,
-                    value,
-                    names,
-                    module=module,
-                    qualname=qualname,
-                    type=type,
-                    start=start,
-                )
-            except ValueError:
-                supported_activations = ", ".join(sorted(v.value for v in cls))
-                raise ValueError(
-                    f"Invalid activation function `{value}`. "
-                    f"Supported functions: {supported_activations}"
-                )
-        else:
-            return EnumMeta.__call__(cls, value, names, module, qualname, type, start)
-
-
-class Activation(Enum, metaclass=_ActivationMeta):
+class Activation(Enum):
     """
     Activation functions.
 
@@ -71,6 +32,14 @@ class Activation(Enum, metaclass=_ActivationMeta):
     #: Sigmoid Linear Unit (`Hendrycks et al., 2016`_).
     SiLU = "silu"
 
+    @classmethod
+    def _missing_(cls, value):
+        supported_activations = ", ".join(sorted(v.value for v in cls))
+        raise ValueError(
+            f"Invalid activation function `{value}`. "
+            f"Supported functions: {supported_activations}"
+        )
+
     @property
     def module(self) -> Type[torch.nn.Module]:
         """
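
Not part of the patch: the sketch below is a minimal, self-contained illustration of the `_missing_` hook that the added classmethod relies on. When `Enum("value")` lookup fails, Python calls `_missing_` and re-raises any `ValueError` it raises, so the custom message replaces the generic "'geglu' is not a valid Activation" error without needing a metaclass. The stand-in enum here lists only a few illustrative members (GELU, ReLU, SiLU); the library's real `Activation` defines more, but the error path is the same.

from enum import Enum


class Activation(Enum):
    # Trimmed-down stand-in for the library's Activation enum; only a few
    # illustrative members are included here.
    GELU = "gelu"
    ReLU = "relu"
    SiLU = "silu"

    @classmethod
    def _missing_(cls, value):
        # Called by Enum value lookup (e.g. Activation("geglu")) when no
        # member matches; the ValueError raised here propagates to the caller.
        supported_activations = ", ".join(sorted(v.value for v in cls))
        raise ValueError(
            f"Invalid activation function `{value}`. "
            f"Supported functions: {supported_activations}"
        )


print(Activation("relu"))  # Activation.ReLU

try:
    Activation("geglu")
except ValueError as exc:
    # Invalid activation function `geglu`. Supported functions: gelu, relu, silu
    print(exc)
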