From f4048bc2c6dd469688ab2683002714d0fd83470e Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Fri, 29 Nov 2024 11:01:06 +0100 Subject: [PATCH 01/14] add type validation for fk and o2o --- tortoise/models.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tortoise/models.py b/tortoise/models.py index 7581ff060..73175fe97 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -30,6 +30,7 @@ from tortoise.exceptions import ( ConfigurationError, DoesNotExist, + FieldError, IncompleteInstanceError, IntegrityError, ObjectDoesNotExistError, @@ -699,6 +700,14 @@ def _set_kwargs(self, kwargs: dict) -> Set[str]: raise OperationalError( f"You should first call .save() on {value} before referring to it" ) + if type(value) is not meta.fields_map[key].related_model: + expected_model = meta.fields_map[key].related_model.__name__ + received_model = type(value).__name__ + raise FieldError( + f"Invalid type for foreign key '{key}'. " + f"Expected model type '{expected_model}', but got '{received_model}'. " + f"Make sure you're using the correct model class for this relationship." + ) setattr(self, key, value) passed_fields.add(meta.fields_map[key].source_field) elif key in meta.fields_db_projection: From 1e98c111296ee6cee38120672433bf02e2e33a8d Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Fri, 29 Nov 2024 11:14:52 +0100 Subject: [PATCH 02/14] handle None case --- tortoise/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tortoise/models.py b/tortoise/models.py index 73175fe97..102a5cf2a 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -700,7 +700,7 @@ def _set_kwargs(self, kwargs: dict) -> Set[str]: raise OperationalError( f"You should first call .save() on {value} before referring to it" ) - if type(value) is not meta.fields_map[key].related_model: + if value and type(value) is not meta.fields_map[key].related_model: expected_model = meta.fields_map[key].related_model.__name__ received_model = type(value).__name__ raise FieldError( From 9a7afdb028e245d89b6585a18b9214e758e95630 Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Fri, 29 Nov 2024 11:21:42 +0100 Subject: [PATCH 03/14] add changelog --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3b40aecbb..9ceaf6c46 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,7 @@ Changelog Added ^^^^^ - Implement savepoints for transactions (#1816) +- Added type validation for foreign key fields to ensure type safety. Now raises ``FieldError`` when assigning foreign key values with incorrect model types (#1792) Fixed ^^^^^ From 115b4c6463ed29d2685af2dee002b627b3025275 Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Sat, 30 Nov 2024 07:47:59 +0100 Subject: [PATCH 04/14] move validation logic to __setattr__ --- tortoise/models.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tortoise/models.py b/tortoise/models.py index 102a5cf2a..3e0b4b3dd 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -686,6 +686,15 @@ def __setattr__(self, key, value) -> None: # set field value override async default function if hasattr(self, "_await_when_save"): self._await_when_save.pop(key, None) + if value is not None and key in (self._meta.fk_fields | self._meta.o2o_fields): + expected_model = self._meta.fields_map[key].related_model + received_model = type(value) + if received_model is not expected_model: + raise FieldError( + f"Invalid type for relationship field '{key}'. " + f"Expected model type '{expected_model.__name__}', but got '{received_model.__name__}'. " + "Make sure you're using the correct model class for this relationship." + ) super().__setattr__(key, value) def _set_kwargs(self, kwargs: dict) -> Set[str]: @@ -700,14 +709,6 @@ def _set_kwargs(self, kwargs: dict) -> Set[str]: raise OperationalError( f"You should first call .save() on {value} before referring to it" ) - if value and type(value) is not meta.fields_map[key].related_model: - expected_model = meta.fields_map[key].related_model.__name__ - received_model = type(value).__name__ - raise FieldError( - f"Invalid type for foreign key '{key}'. " - f"Expected model type '{expected_model}', but got '{received_model}'. " - f"Make sure you're using the correct model class for this relationship." - ) setattr(self, key, value) passed_fields.add(meta.fields_map[key].source_field) elif key in meta.fields_db_projection: From 027ca7b3ff26252ffb676ae560f4c78144c55c18 Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Sat, 30 Nov 2024 10:34:28 +0100 Subject: [PATCH 05/14] add fk init/assign tests --- tests/fields/test_fk.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/fields/test_fk.py b/tests/fields/test_fk.py index 299e459d9..b0771ecaa 100644 --- a/tests/fields/test_fk.py +++ b/tests/fields/test_fk.py @@ -1,6 +1,11 @@ from tests import testmodels from tortoise.contrib import test -from tortoise.exceptions import IntegrityError, NoValuesFetched, OperationalError +from tortoise.exceptions import ( + FieldError, + IntegrityError, + NoValuesFetched, + OperationalError, +) from tortoise.queryset import QuerySet @@ -151,6 +156,11 @@ async def test_minimal__instantiated_create(self): tour = await testmodels.Tournament.create(name="Team1") await testmodels.MinRelation.create(tournament=tour) + async def test_minimal__instantiated_create_wrong_type(self): + author = await testmodels.Author.create(name="Author1") + with self.assertRaises(FieldError): + await testmodels.MinRelation.create(tournament=author) + async def test_minimal__instantiated_iterate(self): tour = await testmodels.Tournament.create(name="Team1") async for _ in tour.minrelations: @@ -229,3 +239,28 @@ async def test_event__offset(self): event2 = await testmodels.Event.create(name="Event2", tournament=tour) event3 = await testmodels.Event.create(name="Event3", tournament=tour) self.assertEqual(await tour.events.offset(1).order_by("name"), [event2, event3]) + + async def test_fk_correct_type_assignment(self): + tour1 = await testmodels.Tournament.create(name="Team1") + tour2 = await testmodels.Tournament.create(name="Team2") + event = await testmodels.Event(name="Event1", tournament=tour1) + + event.tournament = tour2 + await event.save() + self.assertEqual(event.tournament_id, tour2.id) + + async def test_fk_wrong_type_assignment(self): + tour = await testmodels.Tournament.create(name="Team1") + author = await testmodels.Author.create(name="Author") + rel = await testmodels.MinRelation.create(tournament=tour) + + with self.assertRaises(FieldError): + rel.tournament = author + + async def test_fk_none_assignment(self): + manager = await testmodels.Employee.create(name="Manager") + employee = await testmodels.Employee.create(name="Employee", manager=manager) + + employee.manager = None + await employee.save() + self.assertIsNone(employee.manager) From 836f6a975bbbfa501482760dca160aa3b81f117f Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Sun, 1 Dec 2024 08:16:38 +0100 Subject: [PATCH 06/14] add update validation --- tests/fields/test_fk.py | 8 ++++++++ tortoise/models.py | 20 ++++++++++++-------- tortoise/queryset.py | 1 + 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/tests/fields/test_fk.py b/tests/fields/test_fk.py index b0771ecaa..f3c0f59a1 100644 --- a/tests/fields/test_fk.py +++ b/tests/fields/test_fk.py @@ -264,3 +264,11 @@ async def test_fk_none_assignment(self): employee.manager = None await employee.save() self.assertIsNone(employee.manager) + + async def test_fk_update_wrong_type(self): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament=tour) + author = await testmodels.Author.create(name="Author1") + + with self.assertRaises(FieldError): + await testmodels.MinRelation.filter(id=rel.id).update(tournament=author) diff --git a/tortoise/models.py b/tortoise/models.py index 3e0b4b3dd..b49b43158 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -687,14 +687,7 @@ def __setattr__(self, key, value) -> None: if hasattr(self, "_await_when_save"): self._await_when_save.pop(key, None) if value is not None and key in (self._meta.fk_fields | self._meta.o2o_fields): - expected_model = self._meta.fields_map[key].related_model - received_model = type(value) - if received_model is not expected_model: - raise FieldError( - f"Invalid type for relationship field '{key}'. " - f"Expected model type '{expected_model.__name__}', but got '{received_model.__name__}'. " - "Make sure you're using the correct model class for this relationship." - ) + self._validate_relation_type(key, value) super().__setattr__(key, value) def _set_kwargs(self, kwargs: dict) -> Set[str]: @@ -816,6 +809,17 @@ def _set_pk_val(self, value: Any) -> None: Can be used as a field name when doing filtering e.g. ``.filter(pk=...)`` etc... """ + @classmethod + def _validate_relation_type(cls, field_key: str, value: Optional["Model"]) -> None: + expected_model = cls._meta.fields_map[field_key].related_model + received_model = type(value) + if received_model is not expected_model: + raise FieldError( + f"Invalid type for relationship field '{field_key}'. " + f"Expected model type '{expected_model.__name__}', but got '{received_model.__name__}'. " + "Make sure you're using the correct model class for this relationship." + ) + @classmethod async def _getbypk(cls: Type[MODEL], key: Any) -> MODEL: try: diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 0911421e7..636221797 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1204,6 +1204,7 @@ def _make_query(self) -> None: if field_object.pk: raise IntegrityError(f"Field {key} is PK and can not be updated") if isinstance(field_object, (ForeignKeyFieldInstance, OneToOneFieldInstance)): + self.model._validate_relation_type(key, value) fk_field: str = field_object.source_field # type: ignore db_field = self.model._meta.fields_map[fk_field].source_field value = executor.column_map[fk_field]( From 1d9b93eb4c1a3823fc043cdca4673e147629ac83 Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Sun, 1 Dec 2024 13:52:44 +0100 Subject: [PATCH 07/14] update changelog --- CHANGELOG.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9ceaf6c46..174ba3873 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1498,4 +1498,4 @@ Docs/examples: await Tournament.filter( events__name__in=['1', '3'] - ).order_by('-events__participants__name').distinct() + ).order_by('-events__participants__name').distinct() \ No newline at end of file From c2ecd6cee0b4fb69d299c8481aff30a2a5505177 Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Sun, 1 Dec 2024 14:07:40 +0100 Subject: [PATCH 08/14] add bulk create/update tests --- tests/fields/test_fk.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/fields/test_fk.py b/tests/fields/test_fk.py index f3c0f59a1..c713ae21e 100644 --- a/tests/fields/test_fk.py +++ b/tests/fields/test_fk.py @@ -272,3 +272,22 @@ async def test_fk_update_wrong_type(self): with self.assertRaises(FieldError): await testmodels.MinRelation.filter(id=rel.id).update(tournament=author) + + async def test_fk_bulk_create_wrong_type(self): + author = await testmodels.Author.create(name="Author") + with self.assertRaises(FieldError): + await testmodels.MinRelation.bulk_create( + [testmodels.MinRelation(tournament=author) for _ in range(10)] + ) + + async def test_fk_bulk_update_wrong_type(self): + tour = await testmodels.Tournament.create(name="Team1") + await testmodels.MinRelation.bulk_create( + [testmodels.MinRelation(id=rel_id, tournament=tour) for rel_id in range(1, 10)] + ) + author = await testmodels.Author.create(name="Author") + + with self.assertRaises(FieldError): + await testmodels.MinRelation.bulk_update( + [testmodels.MinRelation(id=rel_id, tournament=author) for rel_id in range(1, 10)] + ) From 8735c1a7a208b7fceb8f1625d2e5264d8ec27cd0 Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Sun, 1 Dec 2024 14:14:58 +0100 Subject: [PATCH 09/14] use RaisesRegex --- tests/fields/test_fk.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/fields/test_fk.py b/tests/fields/test_fk.py index c713ae21e..38098bd01 100644 --- a/tests/fields/test_fk.py +++ b/tests/fields/test_fk.py @@ -10,6 +10,11 @@ class TestForeignKeyField(test.TestCase): + def assertRaisesWrongTypeException(self, relation_name: str): + return self.assertRaisesRegex( + FieldError, f"Invalid type for relationship field '{relation_name}'" + ) + async def test_empty(self): with self.assertRaises(IntegrityError): await testmodels.MinRelation.create() @@ -158,7 +163,7 @@ async def test_minimal__instantiated_create(self): async def test_minimal__instantiated_create_wrong_type(self): author = await testmodels.Author.create(name="Author1") - with self.assertRaises(FieldError): + with self.assertRaisesWrongTypeException("tournament"): await testmodels.MinRelation.create(tournament=author) async def test_minimal__instantiated_iterate(self): @@ -254,7 +259,7 @@ async def test_fk_wrong_type_assignment(self): author = await testmodels.Author.create(name="Author") rel = await testmodels.MinRelation.create(tournament=tour) - with self.assertRaises(FieldError): + with self.assertRaisesWrongTypeException("tournament"): rel.tournament = author async def test_fk_none_assignment(self): @@ -270,12 +275,12 @@ async def test_fk_update_wrong_type(self): rel = await testmodels.MinRelation.create(tournament=tour) author = await testmodels.Author.create(name="Author1") - with self.assertRaises(FieldError): + with self.assertRaisesWrongTypeException("tournament"): await testmodels.MinRelation.filter(id=rel.id).update(tournament=author) async def test_fk_bulk_create_wrong_type(self): author = await testmodels.Author.create(name="Author") - with self.assertRaises(FieldError): + with self.assertRaisesWrongTypeException("tournament"): await testmodels.MinRelation.bulk_create( [testmodels.MinRelation(tournament=author) for _ in range(10)] ) @@ -287,7 +292,7 @@ async def test_fk_bulk_update_wrong_type(self): ) author = await testmodels.Author.create(name="Author") - with self.assertRaises(FieldError): + with self.assertRaisesWrongTypeException("tournament"): await testmodels.MinRelation.bulk_update( [testmodels.MinRelation(id=rel_id, tournament=author) for rel_id in range(1, 10)] ) From 1503caaad4b0a4feaa6c2e288682c53361a10a7c Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Sun, 1 Dec 2024 14:16:58 +0100 Subject: [PATCH 10/14] make None a valid value --- tortoise/models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tortoise/models.py b/tortoise/models.py index b49b43158..0f339fc00 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -686,7 +686,7 @@ def __setattr__(self, key, value) -> None: # set field value override async default function if hasattr(self, "_await_when_save"): self._await_when_save.pop(key, None) - if value is not None and key in (self._meta.fk_fields | self._meta.o2o_fields): + if key in self._meta.fk_fields or key in self._meta.o2o_fields: self._validate_relation_type(key, value) super().__setattr__(key, value) @@ -811,6 +811,9 @@ def _set_pk_val(self, value: Any) -> None: @classmethod def _validate_relation_type(cls, field_key: str, value: Optional["Model"]) -> None: + if value is None: + return + expected_model = cls._meta.fields_map[field_key].related_model received_model = type(value) if received_model is not expected_model: From e3f4f266fbdaaca9df1040494790643edddbd28b Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Sun, 1 Dec 2024 14:27:28 +0100 Subject: [PATCH 11/14] use ValidationError instead of FieldError --- CHANGELOG.rst | 2 +- tests/fields/test_fk.py | 4 ++-- tortoise/models.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 174ba3873..4fa3bf246 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,7 +14,7 @@ Changelog Added ^^^^^ - Implement savepoints for transactions (#1816) -- Added type validation for foreign key fields to ensure type safety. Now raises ``FieldError`` when assigning foreign key values with incorrect model types (#1792) +- Added type validation for foreign key fields to ensure type safety. Now raises `ValidationError` when assigning foreign key values with incorrect model types (#1792) Fixed ^^^^^ diff --git a/tests/fields/test_fk.py b/tests/fields/test_fk.py index 38098bd01..585a6988a 100644 --- a/tests/fields/test_fk.py +++ b/tests/fields/test_fk.py @@ -1,10 +1,10 @@ from tests import testmodels from tortoise.contrib import test from tortoise.exceptions import ( - FieldError, IntegrityError, NoValuesFetched, OperationalError, + ValidationError, ) from tortoise.queryset import QuerySet @@ -12,7 +12,7 @@ class TestForeignKeyField(test.TestCase): def assertRaisesWrongTypeException(self, relation_name: str): return self.assertRaisesRegex( - FieldError, f"Invalid type for relationship field '{relation_name}'" + ValidationError, f"Invalid type for relationship field '{relation_name}'" ) async def test_empty(self): diff --git a/tortoise/models.py b/tortoise/models.py index 0f339fc00..630e6e01b 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -30,12 +30,12 @@ from tortoise.exceptions import ( ConfigurationError, DoesNotExist, - FieldError, IncompleteInstanceError, IntegrityError, ObjectDoesNotExistError, OperationalError, ParamsError, + ValidationError, ) from tortoise.expressions import Expression from tortoise.fields.base import Field @@ -817,7 +817,7 @@ def _validate_relation_type(cls, field_key: str, value: Optional["Model"]) -> No expected_model = cls._meta.fields_map[field_key].related_model received_model = type(value) if received_model is not expected_model: - raise FieldError( + raise ValidationError( f"Invalid type for relationship field '{field_key}'. " f"Expected model type '{expected_model.__name__}', but got '{received_model.__name__}'. " "Make sure you're using the correct model class for this relationship." From 6e3995636fc361171ce805b1481fbe1e7fc00055 Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Sun, 1 Dec 2024 14:31:02 +0100 Subject: [PATCH 12/14] handle wrong type passed to validate relation --- tortoise/models.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tortoise/models.py b/tortoise/models.py index 630e6e01b..779ef1033 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -30,6 +30,7 @@ from tortoise.exceptions import ( ConfigurationError, DoesNotExist, + FieldError, IncompleteInstanceError, IntegrityError, ObjectDoesNotExistError, @@ -814,7 +815,14 @@ def _validate_relation_type(cls, field_key: str, value: Optional["Model"]) -> No if value is None: return - expected_model = cls._meta.fields_map[field_key].related_model + field = cls._meta.fields_map[field_key] + if not isinstance(field, (OneToOneFieldInstance, ForeignKeyFieldInstance)): + raise FieldError( + f"Field '{field_key}' must be a OneToOne or ForeignKey relation, " + f"got {type(field).__name__}" + ) + + expected_model = field.related_model received_model = type(value) if received_model is not expected_model: raise ValidationError( From 9e667a9a4980c67df08449664b5235f40c858769 Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Thu, 19 Dec 2024 12:17:31 +0100 Subject: [PATCH 13/14] remove id from test --- tests/fields/test_fk.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/fields/test_fk.py b/tests/fields/test_fk.py index 585a6988a..a01651d44 100644 --- a/tests/fields/test_fk.py +++ b/tests/fields/test_fk.py @@ -288,11 +288,12 @@ async def test_fk_bulk_create_wrong_type(self): async def test_fk_bulk_update_wrong_type(self): tour = await testmodels.Tournament.create(name="Team1") await testmodels.MinRelation.bulk_create( - [testmodels.MinRelation(id=rel_id, tournament=tour) for rel_id in range(1, 10)] + [testmodels.MinRelation(tournament=tour) for _ in range(1, 10)] ) author = await testmodels.Author.create(name="Author") with self.assertRaisesWrongTypeException("tournament"): + relations = await testmodels.MinRelation.all() await testmodels.MinRelation.bulk_update( - [testmodels.MinRelation(id=rel_id, tournament=author) for rel_id in range(1, 10)] + [testmodels.MinRelation(id=rel.id, tournament=author) for rel in relations] ) From 65c1cf80ebe3aa23db450a7f8b4d074da1dd9f66 Mon Sep 17 00:00:00 2001 From: Abdeldjalil Hezouat Date: Thu, 19 Dec 2024 13:27:16 +0100 Subject: [PATCH 14/14] add fields to bulk update test --- tests/fields/test_fk.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/fields/test_fk.py b/tests/fields/test_fk.py index a01651d44..ff69cbbd2 100644 --- a/tests/fields/test_fk.py +++ b/tests/fields/test_fk.py @@ -295,5 +295,6 @@ async def test_fk_bulk_update_wrong_type(self): with self.assertRaisesWrongTypeException("tournament"): relations = await testmodels.MinRelation.all() await testmodels.MinRelation.bulk_update( - [testmodels.MinRelation(id=rel.id, tournament=author) for rel in relations] + [testmodels.MinRelation(id=rel.id, tournament=author) for rel in relations], + fields=["tournament"], )