Skip to content

Commit

Permalink
test: async function default value
Browse files Browse the repository at this point in the history
  • Loading branch information
YAGregor committed Oct 15, 2023
1 parent cb41996 commit c8f9fe5
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
21 changes: 21 additions & 0 deletions tests/test_callable_default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from tests import testmodels
from tortoise.contrib import test


class TestCallableDefault(test.TestCase):
async def test_default_create(self):
model = await testmodels.CallableDefault.create()
self.assertEqual(model.callable_default, "callable_default")
self.assertEqual(model.async_default, "async_callable_default")

async def test_default_by_save(self):
saved_model = testmodels.CallableDefault()
await saved_model.save()
self.assertEqual(saved_model.callable_default, "callable_default")
self.assertEqual(saved_model.async_default, "async_callable_default")

async def test_async_default_change(self):
default_change = testmodels.CallableDefault()
default_change.async_default = "changed"
await default_change.save()
self.assertEqual(default_change.async_default, "changed")
14 changes: 14 additions & 0 deletions tests/testmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,3 +891,17 @@ class PydanticMeta:
alias_generator=camelize_var,
populate_by_name=True,
)


def callable_default() -> str:
return "callable_default"


async def async_callable_default() -> str:
return "async_callable_default"


class CallableDefault(Model):
id = fields.IntField(pk=True)
callable_default = fields.CharField(max_length=32, default=callable_default)
async_default = fields.CharField(max_length=32, default=async_callable_default)
2 changes: 1 addition & 1 deletion tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ def register_listener(cls, signal: Signals, listener: Callable):
async def _set_async_default_field(self) -> None:
"""retrieve value from field's async default value"""
if hasattr(self, "_await_when_save"):
for k, v in self._await_when_save.items():
for k, v in self._await_when_save.copy().items():
setattr(self, k, await v())
self._await_when_save = {}

Expand Down

0 comments on commit c8f9fe5

Please sign in to comment.