Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type validation for foreign key and one to one model consistency #1792

Merged
merged 14 commits into from
Dec 20, 2024
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Changelog
Added
^^^^^
- Implement savepoints for transactions (#1816)
- Added type validation for foreign key fields to ensure type safety. Now raises `ValidationError` when assigning foreign key values with incorrect model types (#1792)

Fixed
^^^^^
Expand Down Expand Up @@ -1497,4 +1498,4 @@ Docs/examples:

await Tournament.filter(
events__name__in=['1', '3']
).order_by('-events__participants__name').distinct()
).order_by('-events__participants__name').distinct()
71 changes: 70 additions & 1 deletion tests/fields/test_fk.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from tests import testmodels
from tortoise.contrib import test
from tortoise.exceptions import IntegrityError, NoValuesFetched, OperationalError
from tortoise.exceptions import (
IntegrityError,
NoValuesFetched,
OperationalError,
ValidationError,
)
from tortoise.queryset import QuerySet


class TestForeignKeyField(test.TestCase):
def assertRaisesWrongTypeException(self, relation_name: str):
return self.assertRaisesRegex(
ValidationError, f"Invalid type for relationship field '{relation_name}'"
)

async def test_empty(self):
with self.assertRaises(IntegrityError):
await testmodels.MinRelation.create()
Expand Down Expand Up @@ -151,6 +161,11 @@ async def test_minimal__instantiated_create(self):
tour = await testmodels.Tournament.create(name="Team1")
await testmodels.MinRelation.create(tournament=tour)

async def test_minimal__instantiated_create_wrong_type(self):
author = await testmodels.Author.create(name="Author1")
with self.assertRaisesWrongTypeException("tournament"):
await testmodels.MinRelation.create(tournament=author)

async def test_minimal__instantiated_iterate(self):
tour = await testmodels.Tournament.create(name="Team1")
async for _ in tour.minrelations:
Expand Down Expand Up @@ -229,3 +244,57 @@ async def test_event__offset(self):
event2 = await testmodels.Event.create(name="Event2", tournament=tour)
event3 = await testmodels.Event.create(name="Event3", tournament=tour)
self.assertEqual(await tour.events.offset(1).order_by("name"), [event2, event3])

async def test_fk_correct_type_assignment(self):
tour1 = await testmodels.Tournament.create(name="Team1")
tour2 = await testmodels.Tournament.create(name="Team2")
event = await testmodels.Event(name="Event1", tournament=tour1)

event.tournament = tour2
await event.save()
self.assertEqual(event.tournament_id, tour2.id)

async def test_fk_wrong_type_assignment(self):
tour = await testmodels.Tournament.create(name="Team1")
author = await testmodels.Author.create(name="Author")
rel = await testmodels.MinRelation.create(tournament=tour)

with self.assertRaisesWrongTypeException("tournament"):
rel.tournament = author

async def test_fk_none_assignment(self):
manager = await testmodels.Employee.create(name="Manager")
employee = await testmodels.Employee.create(name="Employee", manager=manager)

employee.manager = None
await employee.save()
self.assertIsNone(employee.manager)

async def test_fk_update_wrong_type(self):
tour = await testmodels.Tournament.create(name="Team1")
rel = await testmodels.MinRelation.create(tournament=tour)
author = await testmodels.Author.create(name="Author1")

with self.assertRaisesWrongTypeException("tournament"):
await testmodels.MinRelation.filter(id=rel.id).update(tournament=author)

async def test_fk_bulk_create_wrong_type(self):
author = await testmodels.Author.create(name="Author")
with self.assertRaisesWrongTypeException("tournament"):
await testmodels.MinRelation.bulk_create(
[testmodels.MinRelation(tournament=author) for _ in range(10)]
)

async def test_fk_bulk_update_wrong_type(self):
tour = await testmodels.Tournament.create(name="Team1")
await testmodels.MinRelation.bulk_create(
[testmodels.MinRelation(tournament=tour) for _ in range(1, 10)]
)
author = await testmodels.Author.create(name="Author")

with self.assertRaisesWrongTypeException("tournament"):
relations = await testmodels.MinRelation.all()
await testmodels.MinRelation.bulk_update(
[testmodels.MinRelation(id=rel.id, tournament=author) for rel in relations],
fields=["tournament"],
)
25 changes: 25 additions & 0 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@
from tortoise.exceptions import (
ConfigurationError,
DoesNotExist,
FieldError,
IncompleteInstanceError,
IntegrityError,
ObjectDoesNotExistError,
OperationalError,
ParamsError,
ValidationError,
)
from tortoise.expressions import Expression
from tortoise.fields.base import Field
Expand Down Expand Up @@ -685,6 +687,8 @@ def __setattr__(self, key, value) -> None:
# set field value override async default function
if hasattr(self, "_await_when_save"):
self._await_when_save.pop(key, None)
if key in self._meta.fk_fields or key in self._meta.o2o_fields:
self._validate_relation_type(key, value)
super().__setattr__(key, value)

def _set_kwargs(self, kwargs: dict) -> Set[str]:
Expand Down Expand Up @@ -806,6 +810,27 @@ def _set_pk_val(self, value: Any) -> None:
Can be used as a field name when doing filtering e.g. ``.filter(pk=...)`` etc...
"""

@classmethod
def _validate_relation_type(cls, field_key: str, value: Optional["Model"]) -> None:
Abdeldjalil-H marked this conversation as resolved.
Show resolved Hide resolved
if value is None:
return

field = cls._meta.fields_map[field_key]
if not isinstance(field, (OneToOneFieldInstance, ForeignKeyFieldInstance)):
raise FieldError(
f"Field '{field_key}' must be a OneToOne or ForeignKey relation, "
f"got {type(field).__name__}"
)

expected_model = field.related_model
received_model = type(value)
if received_model is not expected_model:
raise ValidationError(
f"Invalid type for relationship field '{field_key}'. "
f"Expected model type '{expected_model.__name__}', but got '{received_model.__name__}'. "
"Make sure you're using the correct model class for this relationship."
)

@classmethod
async def _getbypk(cls: Type[MODEL], key: Any) -> MODEL:
try:
Expand Down
1 change: 1 addition & 0 deletions tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,7 @@ def _make_query(self) -> None:
if field_object.pk:
raise IntegrityError(f"Field {key} is PK and can not be updated")
if isinstance(field_object, (ForeignKeyFieldInstance, OneToOneFieldInstance)):
self.model._validate_relation_type(key, value)
Abdeldjalil-H marked this conversation as resolved.
Show resolved Hide resolved
fk_field: str = field_object.source_field # type: ignore
db_field = self.model._meta.fields_map[fk_field].source_field
value = executor.column_map[fk_field](
Expand Down