Skip to content

Commit 0deedfd

Browse files
committed
support dict as Schema
1 parent 5ac7296 commit 0deedfd

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

flask_smorest/arguments.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from functools import wraps
55
import http
66

7+
import marshmallow as ma
78
from webargs.flaskparser import FlaskParser
89

910
from .utils import deepupdate
@@ -28,8 +29,8 @@ def arguments(
2829
):
2930
"""Decorator specifying the schema used to deserialize parameters
3031
31-
:param type|Schema schema: Marshmallow ``Schema`` class or instance
32-
used to deserialize and validate the argument.
32+
:param type|Schema|dict schema: Marshmallow ``Schema`` class or instance
33+
or dict used to deserialize and validate the argument.
3334
:param str location: Location of the argument.
3435
:param str content_type: Content type of the argument.
3536
Should only be used in conjunction with ``json``, ``form`` or
@@ -56,6 +57,8 @@ def arguments(
5657
5758
See :doc:`Arguments <arguments>`.
5859
"""
60+
if isinstance(schema, dict):
61+
schema = ma.Schema.from_dict(schema)
5962
# At this stage, put schema instance in doc dictionary. Il will be
6063
# replaced later on by $ref or json.
6164
parameters = {

tests/conftest.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ class ClientErrorSchema(ma.Schema):
7878
error_id = ma.fields.Str()
7979
text = ma.fields.Str()
8080

81-
return namedtuple("Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema"))(
82-
DocSchema, QueryArgsSchema, ClientErrorSchema
81+
DictSchema = {
82+
"item_id": ma.fields.Int(dump_only=True),
83+
"field": ma.fields.Int(attribute="db_field"),
84+
}
85+
86+
return namedtuple("Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema", "DictSchema"))(
87+
DocSchema, QueryArgsSchema, ClientErrorSchema, DictSchema
8388
)

tests/test_blueprint.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,47 @@ def func(document, query_args):
307307
"query_args": {"arg1": "test"},
308308
}
309309

310+
@pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2"))
311+
def test_blueprint_dict_arguments(self, app, schemas, openapi_version):
312+
app.config["OPENAPI_VERSION"] = openapi_version
313+
api = Api(app)
314+
blp = Blueprint("test", __name__, url_prefix="/test")
315+
client = app.test_client()
316+
317+
@blp.route("/", methods=("POST",))
318+
@blp.arguments(schemas.DictSchema)
319+
def func(document):
320+
return {"document": document}
321+
322+
api.register_blueprint(blp)
323+
spec = api.spec.to_dict()
324+
325+
# Check parameters are documented
326+
if openapi_version == "2.0":
327+
parameters = spec["paths"]["/test/"]["post"]["parameters"]
328+
assert len(parameters) == 1
329+
assert parameters[0]["in"] == "body"
330+
assert "schema" in parameters[0]
331+
else:
332+
assert (
333+
"schema"
334+
in spec["paths"]["/test/"]["post"]["requestBody"]["content"][
335+
"application/json"
336+
]
337+
)
338+
339+
# Check parameters are passed as arguments to view function
340+
item_data = {"field": 12}
341+
response = client.post(
342+
"/test/",
343+
data=json.dumps(item_data),
344+
content_type="application/json",
345+
)
346+
assert response.status_code == 200
347+
assert response.json == {
348+
"document": {"db_field": 12},
349+
}
350+
310351
@pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2"))
311352
def test_blueprint_arguments_files_multipart(self, app, schemas, openapi_version):
312353
app.config["OPENAPI_VERSION"] = openapi_version

0 commit comments

Comments
 (0)