diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 106cde2..1e01f6d 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -578,13 +578,18 @@ def _field_for_generic_type( ) from . import union_field + union_metadata = { + k: v + for k, v in metadata.items() + if k not in ("allow_none", "dump_default", "load_default", "required") + } return union_field.Union( [ ( subtyp, field_for_schema( subtyp, - metadata={"required": True}, + metadata=union_metadata, base_schema=base_schema, typ_frame=typ_frame, ), diff --git a/tests/test_field_for_schema.py b/tests/test_field_for_schema.py index 0e60f0b..2d6d7af 100644 --- a/tests/test_field_for_schema.py +++ b/tests/test_field_for_schema.py @@ -11,6 +11,7 @@ from typing_extensions import Final, Literal # type: ignore[assignment] from marshmallow import fields, Schema, validate +from marshmallow.warnings import RemovedInMarshmallow4Warning from marshmallow_dataclass import ( field_for_schema, @@ -132,6 +133,56 @@ class Color(Enum): marshmallow_enum.EnumField(enum=Color, required=True), ) + def test_union_enum(self): + class Fruit(Enum): + apple = "Apple" + banana = "Banana" + tomato = "Tomato" + + with self.assertWarns(RemovedInMarshmallow4Warning): + if hasattr(fields, "Enum"): + self.assertFieldsEqual( + field_for_schema(Union[Fruit, str], metadata={"by_value": True}), + union_field.Union( + [ + ( + Fruit, + fields.Enum(enum=Fruit, required=True, by_value=True), + ), + ( + str, + fields.String( + required=True, metadata={"by_value": True} + ), + ), + ], + required=True, + metadata={"by_value": True}, + ), + ) + else: + import marshmallow_enum + + self.assertFieldsEqual( + field_for_schema(Union[Fruit, str], metadata={"by_value": True}), + marshmallow_enum.EnumField( + [ + ( + Fruit, + fields.Enum(enum=Fruit, required=True, by_value=True), + ), + ( + str, + fields.String( + required=True, metadata={"by_value": True} + ), + ), + ], + required=True, + metadata={"by_value": True}, + ), + ) + def test_literal(self): self.assertFieldsEqual( field_for_schema(Literal["a"]), diff --git a/tests/test_union.py b/tests/test_union.py index 5f0eb58..aca4db7 100644 --- a/tests/test_union.py +++ b/tests/test_union.py @@ -1,8 +1,10 @@ from dataclasses import field +from enum import Enum import sys import unittest from typing import List, Optional, Union, Dict +from marshmallow.warnings import RemovedInMarshmallow4Warning import marshmallow from marshmallow_dataclass import dataclass @@ -196,3 +198,22 @@ class PEP604IntOrStr: data_in = {"value": 42} self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + def test_union_enum(self): + class Fruit(Enum): + apple = "Apple" + banana = "Banana" + tomato = "Tomato" + + @dataclass + class Dclass: + value: Union[Fruit, dict] = field(metadata={"by_value": True}) + + with self.assertWarns(RemovedInMarshmallow4Warning): + schema = Dclass.Schema() + + data_in = {"value": "Apple"} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + data_in = {"value": {"fruit": "Orange"}} + self.assertEqual(schema.dump(schema.load(data_in)), data_in)