diff --git a/CHANGELOG.rst b/CHANGELOG.rst index fe0d7d9c2..412bdb1a2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,12 @@ Changelog 0.21 ==== +0.21.3 +------ +Fixed +^^^^^ +- Fix `bulk_update` when using source_field for pk (#1633) + 0.21.2 ------ Added diff --git a/pyproject.toml b/pyproject.toml index d6ee24b3c..1308a1108 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tortoise-orm" -version = "0.21.2" +version = "0.21.3" description = "Easy async ORM for python, built with relations in mind" authors = ["Andrey Bondar ", "Nickolas Grigoriadis ", "long2ice "] license = "Apache-2.0" diff --git a/tests/test_update.py b/tests/test_update.py index c425aef14..2c8dc4e4b 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -15,6 +15,7 @@ JSONFields, Service, SmallIntFields, + SourceFieldPk, Tournament, UUIDFields, ) @@ -68,6 +69,18 @@ async def test_bulk_update_pk_uuid(self): self.assertEqual((await UUIDFields.get(pk=objs[0].pk)).data, objs[0].data) self.assertEqual((await UUIDFields.get(pk=objs[1].pk)).data, objs[1].data) + async def test_bulk_renamed_pk_source_field(self): + objs = [ + await SourceFieldPk.create(name="Model 1"), + await SourceFieldPk.create(name="Model 2"), + ] + objs[0].name = "Model 3" + objs[1].name = "Model 4" + rows_affected = await SourceFieldPk.bulk_update(objs, fields=["name"]) + self.assertEqual(rows_affected, 2) + self.assertEqual((await SourceFieldPk.get(pk=objs[0].pk)).name, objs[0].name) + self.assertEqual((await SourceFieldPk.get(pk=objs[1].pk)).name, objs[1].name) + async def test_bulk_update_json_value(self): objs = [ await JSONFields.create(data={}), diff --git a/tests/testmodels.py b/tests/testmodels.py index f4f78d00a..6a293baf2 100644 --- a/tests/testmodels.py +++ b/tests/testmodels.py @@ -731,6 +731,11 @@ class Meta: ordering = ["-one"] +class SourceFieldPk(Model): + id = fields.IntField(primary_key=True, source_field="counter") + name = fields.CharField(max_length=255) + + class DefaultOrderedInvalid(Model): one = fields.TextField() second = fields.IntField() diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 80f024756..00df3bba1 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1799,7 +1799,8 @@ def _make_query(self) -> None: ) executor = self._db.executor_class(model=self.model, db=self._db) pk_attr = self.model._meta.pk_attr - pk = Field(pk_attr) + source_pk_attr = self.model._meta.fields_map["id"].source_field or pk_attr + pk = Field(source_pk_attr) for objects_item in chunk(self.objects, self.batch_size): query = copy(self.query) for field in self.fields: