Skip to content
This repository has been archived by the owner on Nov 15, 2023. It is now read-only.

Commit

Permalink
Support Marshmallow 3.0.0 that just got released 🎉
Browse files Browse the repository at this point in the history
  • Loading branch information
ramnes committed Aug 21, 2019
1 parent 13575ca commit 1690fda
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 83 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
language: python
python:
- "3.4"
- "3.5"
- "3.6"
- "pypy3"
Expand Down
35 changes: 19 additions & 16 deletions flask_stupe/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@

class Color(marshmallow.fields.String):
default_error_messages = {
"type": "Invalid type.",
"invalid": "Not a valid color."
}

def _deserialize(self, value, attr, data):
def _deserialize(self, value, attr, data, **kwargs):
try:
value = value.lower()
except AttributeError:
self.fail("type")
raise self.make_error("type")
if not hexcolor_re.match(value):
self.fail("invalid")
raise self.make_error("invalid")
return value

class Cron(marshmallow.fields.String):
Expand All @@ -35,19 +36,19 @@ class Cron(marshmallow.fields.String):
limits = (("minute", 59), ("hour", 23), ("dom", 31), ("month", 12),
("dow", 7))

def _deserialize(self, value, attr, data):
def _deserialize(self, value, attr, data, **kwargs):
fields = value.split()
if len(fields) != 5:
self.fail("invalid")
raise self.make_error("invalid")
for field, (name, limit) in zip(fields, self.limits):
try:
if field != "*":
if int(field) > limit or int(field) < 0:
raise Exception
except ValueError:
self.fail("invalid")
raise self.make_error("invalid")
except Exception:
self.fail(name)
raise self.make_error(name)
return value

currencies = ("ADF", "ADP", "AED", "AFA", "AFN", "ALL", "AMD", "ANG",
Expand Down Expand Up @@ -83,16 +84,17 @@ def _deserialize(self, value, attr, data):

class Currency(marshmallow.fields.String):
default_error_messages = {
"type": "Invalid type.",
"invalid": "Not a valid currency."
}

def _deserialize(self, value, attr, data):
def _deserialize(self, value, attr, data, **kwargs):
try:
value = value.upper()
except AttributeError:
self.fail("type")
raise self.make_error("type")
if value not in currencies:
self.fail("invalid")
raise self.make_error("invalid")
return value

class IP(marshmallow.fields.String):
Expand All @@ -101,11 +103,11 @@ class IP(marshmallow.fields.String):
}
ip_type = staticmethod(ipaddress.ip_address)

def _deserialize(self, value, attr, data):
def _deserialize(self, value, attr, data, **kwargs):
try:
self.ip_type(value)
except ValueError:
self.fail("invalid")
raise self.make_error("invalid")
return value

class IPv4(IP):
Expand Down Expand Up @@ -152,23 +154,24 @@ def _deserialize(self, value, *args, **kwargs):
return field._deserialize(value, *args, **kwargs)
except marshmallow.exceptions.ValidationError:
pass
self.fail("invalid")
raise self.make_error("invalid")

__all__.extend(["Color", "Currency", "IP", "IPv4", "IPv6", "OneOf"])


if bson and marshmallow:
class ObjectId(marshmallow.fields.String):
default_error_messages = {
"type": "Invalid type.",
"invalid": "Not a valid ObjectId."
}

def _deserialize(self, value, attr, data):
def _deserialize(self, value, attr, data, **kwargs):
try:
return bson.ObjectId(value)
except TypeError:
self.fail("type")
raise self.make_error("type")
except bson.objectid.InvalidId:
self.fail("invalid")
raise self.make_error("invalid")

__all__.append("ObjectId")
18 changes: 5 additions & 13 deletions flask_stupe/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,11 @@


if marshmallow:
if marshmallow.__version__.startswith('3'): # pragma: no cover
def _load_schema(schema, json):
try:
return schema.load(json)
except marshmallow.exceptions.ValidationError as e:
abort(400, e.messages)

else:
def _load_schema(schema, json):
results = schema.load(json)
if results.errors:
abort(400, results.errors)
return results.data
def _load_schema(schema, json):
try:
return schema.load(json)
except marshmallow.exceptions.ValidationError as e:
abort(400, e.messages)

def schema_required(schema):
"""Validate body of the request against the schema.
Expand Down
125 changes: 72 additions & 53 deletions tests/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import bson
import pytest
from marshmallow import Schema
from marshmallow import Schema, ValidationError
from marshmallow.fields import Integer, String

from flask_stupe.fields import (IP, Color, Cron, Currency, IPv4, IPv6,
Expand All @@ -13,25 +13,28 @@ class TestSchema(Schema):

schema = TestSchema()
result = schema.load({"IP": "127.0.0.1"})
assert result.data["IP"] == "127.0.0.1"
assert result["IP"] == "127.0.0.1"

result = schema.load({"IP": "127.0.0"})
assert result.errors["IP"] == ["Not a valid IPv4 or IPv6 address."]
with pytest.raises(ValidationError) as error:
schema.load({"IP": "127.0.0"})
assert error.value.messages["IP"] == ["Not a valid IPv4 or IPv6 address."]

result = schema.load({"IP": "256.256.256.256"})
assert result.errors["IP"] == ["Not a valid IPv4 or IPv6 address."]
with pytest.raises(ValidationError) as error:
schema.load({"IP": "256.256.256.256"})
assert error.value.messages["IP"] == ["Not a valid IPv4 or IPv6 address."]

result = schema.load({"IP": "2001:0db8:0000:0000:0000:ff00:0042:8329"})
assert result.data["IP"] == "2001:0db8:0000:0000:0000:ff00:0042:8329"
assert result["IP"] == "2001:0db8:0000:0000:0000:ff00:0042:8329"

result = schema.load({"IP": "2001:db8:0:0:0:ff00:42:8329"})
assert result.data["IP"] == "2001:db8:0:0:0:ff00:42:8329"
assert result["IP"] == "2001:db8:0:0:0:ff00:42:8329"

result = schema.load({"IP": "2001:db8::ff00:42:8329"})
assert result.data["IP"] == "2001:db8::ff00:42:8329"
assert result["IP"] == "2001:db8::ff00:42:8329"

result = schema.load({"IP": "2001:gb8::ff00:42:8329"})
assert result.errors["IP"] == ["Not a valid IPv4 or IPv6 address."]
with pytest.raises(ValidationError) as error:
schema.load({"IP": "2001:gb8::ff00:42:8329"})
assert error.value.messages["IP"] == ["Not a valid IPv4 or IPv6 address."]


def test_ipv4():
Expand All @@ -40,16 +43,19 @@ class TestSchema(Schema):

schema = TestSchema()
result = schema.load({"IP": "127.0.0.1"})
assert result.data["IP"] == "127.0.0.1"
result["IP"] == "127.0.0.1"

result = schema.load({"IP": "127.0.0"})
assert result.errors["IP"] == ["Not a valid IPv4 address."]
with pytest.raises(ValidationError) as error:
result = schema.load({"IP": "127.0.0"})
assert error.value.messages["IP"] == ["Not a valid IPv4 address."]

result = schema.load({"IP": "256.256.256.256"})
assert result.errors["IP"] == ["Not a valid IPv4 address."]
with pytest.raises(ValidationError) as error:
schema.load({"IP": "256.256.256.256"})
assert error.value.messages["IP"] == ["Not a valid IPv4 address."]

result = schema.load({"IP": "2001:db8::ff00:42:8329"})
assert result.errors["IP"] == ["Not a valid IPv4 address."]
with pytest.raises(ValidationError) as error:
schema.load({"IP": "2001:db8::ff00:42:8329"})
assert error.value.messages["IP"] == ["Not a valid IPv4 address."]


def test_ipv6():
Expand All @@ -58,19 +64,21 @@ class TestSchema(Schema):

schema = TestSchema()
result = schema.load({"IP": "2001:0db8:0000:0000:0000:ff00:0042:8329"})
assert result.data["IP"] == "2001:0db8:0000:0000:0000:ff00:0042:8329"
assert result["IP"] == "2001:0db8:0000:0000:0000:ff00:0042:8329"

result = schema.load({"IP": "2001:db8:0:0:0:ff00:42:8329"})
assert result.data["IP"] == "2001:db8:0:0:0:ff00:42:8329"
assert result["IP"] == "2001:db8:0:0:0:ff00:42:8329"

result = schema.load({"IP": "2001:db8::ff00:42:8329"})
assert result.data["IP"] == "2001:db8::ff00:42:8329"
assert result["IP"] == "2001:db8::ff00:42:8329"

result = schema.load({"IP": "2001:gb8::ff00:42:8329"})
assert result.errors["IP"] == ["Not a valid IPv6 address."]
with pytest.raises(ValidationError) as error:
schema.load({"IP": "2001:gb8::ff00:42:8329"})
assert error.value.messages["IP"] == ["Not a valid IPv6 address."]

result = schema.load({"IP": "255.255.255.255"})
assert result.errors["IP"] == ["Not a valid IPv6 address."]
with pytest.raises(ValidationError) as error:
schema.load({"IP": "255.255.255.255"})
assert error.value.messages["IP"] == ["Not a valid IPv6 address."]


def test_color():
Expand All @@ -79,13 +87,15 @@ class TestSchema(Schema):

schema = TestSchema()
result = schema.load({"color": "#ec068d"})
assert result.data["color"] == "#ec068d"
assert result["color"] == "#ec068d"

result = schema.load({"color": "test"})
assert result.errors["color"] == ["Not a valid color."]
with pytest.raises(ValidationError) as error:
schema.load({"color": "test"})
assert error.value.messages["color"] == ["Not a valid color."]

result = schema.load({"color": ["test", "test"]})
assert result.errors["color"] == ["Invalid input type."]
with pytest.raises(ValidationError) as error:
schema.load({"color": ["test", "test"]})
assert error.value.messages["color"] == ["Invalid type."]


def test_cron(app):
Expand All @@ -94,16 +104,20 @@ class TestSchema(Schema):

schema = TestSchema()
result = schema.load({"schedule": "* * 4 * *"})
assert result.data["schedule"] == "* * 4 * *"
assert result["schedule"] == "* * 4 * *"

result = schema.load({"schedule": "* * 1 * * *"})
assert result.errors["schedule"] == ["Not a valid cron expression."]
with pytest.raises(ValidationError) as error:
schema.load({"schedule": "* * 1 * * *"})
assert error.value.messages["schedule"] == ["Not a valid cron expression."]

result = schema.load({"schedule": "60 * * * *"})
assert result.errors["schedule"] == ["The minutes field is invalid."]
with pytest.raises(ValidationError) as error:
schema.load({"schedule": "60 * * * *"})
assert error.value.messages["schedule"] == ["The minutes field is "
"invalid."]

result = schema.load({"schedule": "a * * * *"})
assert result.errors["schedule"] == ["Not a valid cron expression."]
with pytest.raises(ValidationError) as error:
schema.load({"schedule": "a * * * *"})
assert error.value.messages["schedule"] == ["Not a valid cron expression."]


def test_currency():
Expand All @@ -112,13 +126,15 @@ class TestSchema(Schema):

schema = TestSchema()
result = schema.load({"currency": "EUR"})
assert result.data["currency"] == "EUR"
assert result["currency"] == "EUR"

result = schema.load({"currency": "1MD"})
assert result.errors["currency"] == ["Not a valid currency."]
with pytest.raises(ValidationError) as error:
schema.load({"currency": "1MD"})
assert error.value.messages["currency"] == ["Not a valid currency."]

result = schema.load({"currency": ["ILS", "EUR"]})
assert result.errors["currency"] == ["Invalid input type."]
with pytest.raises(ValidationError) as error:
schema.load({"currency": ["ILS", "EUR"]})
assert error.value.messages["currency"] == ["Invalid type."]


def test_oneof(app):
Expand All @@ -127,14 +143,15 @@ class TestSchema(Schema):

schema = TestSchema()
result = schema.load({"value_type": 42})
assert result.data["value_type"] == 42
assert result["value_type"] == 42

result = schema.load({"value_type": "test"})
assert result.data["value_type"] == "test"
assert result["value_type"] == "test"

result = schema.load({"value_type": ["42", 42]})
assert result.errors["value_type"] == [("Object type doesn't match any "
"valid type")]
with pytest.raises(ValidationError) as error:
schema.load({"value_type": ["42", 42]})
assert error.value.messages["value_type"] == [("Object type doesn't match "
"any valid type")]

with pytest.raises(ValueError) as error:
class TestSchema2(Schema):
Expand All @@ -161,7 +178,7 @@ class TestSchema6(Schema):

schema = TestSchema6()
result = schema.load({"value_type": 42})
assert result.data["value_type"] == 42
assert result["value_type"] == 42


def test_object_id():
Expand All @@ -172,10 +189,12 @@ class TestSchema(Schema):

schema = TestSchema()
result = schema.load({"id": test_id})
assert result.data["id"] == test_id
assert result["id"] == test_id

result = schema.load({"id": "fail"})
assert result.errors["id"] == ["Not a valid ObjectId."]
with pytest.raises(ValidationError) as error:
schema.load({"id": "fail"})
assert error.value.messages["id"] == ["Not a valid ObjectId."]

result = schema.load({"id": 42})
assert result.errors["id"] == ["Invalid input type."]
with pytest.raises(ValidationError) as error:
schema.load({"id": 42})
assert error.value.messages["id"] == ["Invalid type."]

0 comments on commit 1690fda

Please sign in to comment.