From e7e7e9dacf5e9013c9d1ed16e7233232369147b4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Wed, 17 Apr 2024 16:55:05 +0200
Subject: [PATCH] Fix activation lookup with Python 3.12.3 (#375)

We used the metaclass `EnumMeta`/`EnumType` to override reporting of
missing enum values (to give the full set of supported activations).
However, in Python 3.12.3, the default value of the `names` parameter
of the `EnumType.__call__` method was changed from `None` to
`_not_given`:

https://github.com/python/cpython/commit/d771729679d39904768f60b3352e02f5f491966c

Even though this is a public API (which now uses a private default
value), it seems too risky to continue using it. So in this change,
we implement `Enum._missing_` instead for improved error reporting.
---
 curated_transformers/layers/activations.py | 51 +++++-----------
 1 file changed, 10 insertions(+), 41 deletions(-)

diff --git a/curated_transformers/layers/activations.py b/curated_transformers/layers/activations.py
index d3503133..4867a77e 100644
--- a/curated_transformers/layers/activations.py
+++ b/curated_transformers/layers/activations.py
@@ -1,5 +1,5 @@
 import math
-from enum import Enum, EnumMeta
+from enum import Enum
 from typing import Type
 
 import torch
@@ -7,46 +7,7 @@
 from torch.nn import Module
 
 
-class _ActivationMeta(EnumMeta):
-    """
-    ``Enum`` metaclass to override the class ``__call__`` method with a more
-    fine-grained exception for unknown activation functions.
-    """
-
-    def __call__(
-        cls,
-        value,
-        names=None,
-        *,
-        module=None,
-        qualname=None,
-        type=None,
-        start=1,
-    ):
-        # Wrap superclass __call__ to give a nicer error message when
-        # an unknown activation is used.
-        if names is None:
-            try:
-                return EnumMeta.__call__(
-                    cls,
-                    value,
-                    names,
-                    module=module,
-                    qualname=qualname,
-                    type=type,
-                    start=start,
-                )
-            except ValueError:
-                supported_activations = ", ".join(sorted(v.value for v in cls))
-                raise ValueError(
-                    f"Invalid activation function `{value}`. "
-                    f"Supported functions: {supported_activations}"
-                )
-        else:
-            return EnumMeta.__call__(cls, value, names, module, qualname, type, start)
-
-
-class Activation(Enum, metaclass=_ActivationMeta):
+class Activation(Enum):
     """
     Activation functions.
 
@@ -71,6 +32,14 @@ class Activation(Enum, metaclass=_ActivationMeta):
     #: Sigmoid Linear Unit (`Hendrycks et al., 2016`_).
     SiLU = "silu"
 
+    @classmethod
+    def _missing_(cls, value):
+        supported_activations = ", ".join(sorted(v.value for v in cls))
+        raise ValueError(
+            f"Invalid activation function `{value}`. "
+            f"Supported functions: {supported_activations}"
+        )
+
     @property
     def module(self) -> Type[torch.nn.Module]:
         """
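
For reference, `_missing_` is invoked during value lookup (`Activation(...)`) when no member
matches, so raising `ValueError` from it surfaces the improved message without a custom
metaclass. Below is a minimal standalone sketch of the resulting behavior, not part of the
library code: it uses only the `SiLU` member visible in the hunk above, while the real enum
defines more members.

    from enum import Enum


    class Activation(Enum):
        # Only the member shown in the diff above; the real enum defines more.
        SiLU = "silu"

        @classmethod
        def _missing_(cls, value):
            # Called by the Enum machinery when no member matches `value`.
            supported_activations = ", ".join(sorted(v.value for v in cls))
            raise ValueError(
                f"Invalid activation function `{value}`. "
                f"Supported functions: {supported_activations}"
            )


    print(Activation("silu"))  # Activation.SiLU

    try:
        Activation("swish")
    except ValueError as exc:
        print(exc)  # Invalid activation function `swish`. Supported functions: silu

Because `_missing_` raises `ValueError`, callers see the same exception type as with the old
metaclass-based implementation; only the lookup path differs.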