diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index f2014af..334b626 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -40,7 +40,7 @@ class User: import threading import types import warnings -from enum import EnumMeta +from enum import Enum, EnumMeta from functools import lru_cache, partial from typing import ( Any, @@ -648,7 +648,22 @@ def field_for_schema( if typing_inspect.is_literal_type(typ): arguments = typing_inspect.get_args(typ) - return marshmallow.fields.Raw( + + field_type = marshmallow.fields.Raw + + # If all fields are an enum of the same type, interpret our literal as + # an enum instead. + if ( + len(arguments) > 0 + and isinstance(arguments[0], Enum) + and all(type(arguments[0]) == type(arg) for arg in arguments) + ): + import marshmallow_enum + + metadata["enum"] = type(arguments[0]) + field_type = marshmallow_enum.EnumField + + return field_type( validate=( marshmallow.validate.Equal(arguments[0]) if len(arguments) == 1 diff --git a/tests/test_field_for_schema.py b/tests/test_field_for_schema.py index b56a4a4..1412285 100644 --- a/tests/test_field_for_schema.py +++ b/tests/test_field_for_schema.py @@ -19,6 +19,12 @@ ) +class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + class TestFieldForSchema(unittest.TestCase): def assertFieldsEqual(self, a: fields.Field, b: fields.Field): self.assertEqual(a.__class__, b.__class__, "field class") @@ -90,11 +96,6 @@ def test_optional_str(self): def test_enum(self): import marshmallow_enum - class Color(Enum): - RED: 1 - GREEN: 2 - BLUE: 3 - self.assertFieldsEqual( field_for_schema(Color), marshmallow_enum.EnumField(enum=Color, required=True), @@ -112,6 +113,16 @@ def test_literal_multiple_types(self): fields.Raw(required=True, validate=validate.OneOf(("a", 1, 1.23, True))), ) + def test_literal_enum(self): + import marshmallow_enum + + self.assertFieldsEqual( + field_for_schema(Literal[Color.BLUE]), + marshmallow_enum.EnumField( + required=True, enum=Color, validate=validate.Equal(Color.BLUE) + ), + ) + def test_final(self): self.assertFieldsEqual( field_for_schema(Final[str]), fields.String(required=True)